Mila 0.13.48
Deep Neural Network Library
Mila::Dnn::Compute::Cuda::CublasLtPlanCache< TPlan > Class Template Reference [export]

Generic plan cache keyed on batch size bucket.

Public Types

using PlanBuilder = std::function<TPlan( int bucket )>

Public Member Functions

 CublasLtPlanCache ()=default
 CublasLtPlanCache (int max_batch_size, PlanBuilder builder)
 Construct and eagerly build all plans.
const std::vector< int > & buckets () const
bool empty () const
const TPlan & get (int batch_size) const
 Get the plan for the smallest bucket >= batch_size.
size_t size () const

Private Attributes

std::vector< int > buckets_
std::unordered_map< int, TPlan > cache_

Detailed Description

template<typename TPlan>
class Mila::Dnn::Compute::Cuda::CublasLtPlanCache< TPlan >

Generic plan cache keyed on batch size bucket.

TPlan must be move-constructible (e.g. CublasLtMatMulPlan<T>). Plans are built eagerly at construction time for all buckets.

Usage:

    CublasLtPlanCache<CublasLtMatMulPlan<float>> cache(
        max_batch_size,
        [&]( int bucket ) { return build_my_plan( bucket ); } );

    const auto& plan = cache.get( actual_batch_size );
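
For readers who want to experiment with the pattern outside of Mila, the following is a minimal, self-contained sketch of a cache with the same documented interface. It is not the library's implementation: `ToyPlan` and `ToyPlanCache` are hypothetical stand-ins, the power-of-two bucket scheme is an assumption (this reference does not state how buckets are chosen), and throwing `std::out_of_range` for an oversized batch is likewise assumed behavior.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <stdexcept>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for a real plan type such as CublasLtMatMulPlan<float>.
struct ToyPlan {
    int bucket;  // batch size this plan was built for
};

// Simplified stand-in that mirrors the documented interface.
// ASSUMPTION: buckets are powers of two, capped by max_batch_size.
template <typename TPlan>
class ToyPlanCache {
public:
    using PlanBuilder = std::function<TPlan(int bucket)>;

    ToyPlanCache() = default;

    ToyPlanCache(int max_batch_size, PlanBuilder builder) {
        assert(max_batch_size > 0);
        for (int b = 1; b < max_batch_size; b *= 2) {
            buckets_.push_back(b);
        }
        buckets_.push_back(max_batch_size);  // final bucket covers the upper bound
        for (int b : buckets_) {
            cache_.emplace(b, builder(b));   // eager: every plan is built up front
        }
    }

    // Returns the plan for the smallest bucket >= batch_size.
    const TPlan& get(int batch_size) const {
        auto it = std::lower_bound(buckets_.begin(), buckets_.end(), batch_size);
        if (it == buckets_.end()) {
            // ASSUMPTION: oversized requests are reported by throwing.
            throw std::out_of_range("batch_size exceeds max_batch_size");
        }
        return cache_.at(*it);
    }

    const std::vector<int>& buckets() const { return buckets_; }
    bool empty() const { return cache_.empty(); }
    std::size_t size() const { return cache_.size(); }

private:
    std::vector<int> buckets_;               // sorted ascending
    std::unordered_map<int, TPlan> cache_;   // bucket -> plan
};
```

With `max_batch_size = 48`, this sketch builds plans for the buckets 1, 2, 4, 8, 16, 32 and 48, and a request for batch size 5 resolves to the plan built for bucket 8.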

Member Typedef Documentation

◆ PlanBuilder

template<typename TPlan>
using Mila::Dnn::Compute::Cuda::CublasLtPlanCache< TPlan >::PlanBuilder = std::function<TPlan( int bucket )>

Constructor & Destructor Documentation

◆ CublasLtPlanCache() [1/2]

template<typename TPlan>
Mila::Dnn::Compute::Cuda::CublasLtPlanCache< TPlan >::CublasLtPlanCache ( )
default

◆ CublasLtPlanCache() [2/2]

template<typename TPlan>
Mila::Dnn::Compute::Cuda::CublasLtPlanCache< TPlan >::CublasLtPlanCache ( int max_batch_size,
PlanBuilder builder )
inline

Construct and eagerly build all plans.

Parameters
max_batch_size    Upper bound (training / max seq len)
builder           Callable: int bucket -> TPlan
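
The eager-build step can be illustrated in isolation. The sketch below assumes a hypothetical move-only `ToyPlan` and a free function `build_plans` (neither is part of Mila) and shows a builder being invoked exactly once per bucket, with each result moved into the map. Because the class lists only `buckets_` and `cache_` as data members, the builder appears to be needed only during construction, so a lambda capturing local context by reference, as in the usage example, would only need that context to outlive the constructor call; treat this as an inference from the member list rather than a documented guarantee.

```cpp
#include <cassert>
#include <functional>
#include <unordered_map>
#include <vector>

// Hypothetical move-only plan type; stands in for something like
// CublasLtMatMulPlan<float>, which may own non-copyable handles.
struct ToyPlan {
    int rows;  // bucketed batch dimension
    int cols;  // fixed feature dimension
    ToyPlan(int r, int c) : rows(r), cols(c) {}
    ToyPlan(ToyPlan&&) = default;
    ToyPlan(const ToyPlan&) = delete;
};

using PlanBuilder = std::function<ToyPlan(int bucket)>;

// Calls `builder` exactly once per bucket and moves each result into the map.
std::unordered_map<int, ToyPlan> build_plans(const std::vector<int>& buckets,
                                             const PlanBuilder& builder) {
    assert(builder);  // calling an empty std::function would throw
    std::unordered_map<int, ToyPlan> plans;
    for (int bucket : buckets) {
        plans.emplace(bucket, builder(bucket));  // requires only move construction
    }
    return plans;
}
```

Since every plan exists before the first lookup, the cost of plan construction is paid once at startup rather than on the first request for each batch size.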

Member Function Documentation

◆ buckets()

template<typename TPlan>
const std::vector< int > & Mila::Dnn::Compute::Cuda::CublasLtPlanCache< TPlan >::buckets ( ) const
inline

◆ empty()

template<typename TPlan>
bool Mila::Dnn::Compute::Cuda::CublasLtPlanCache< TPlan >::empty ( ) const
inline

◆ get()

template<typename TPlan>
const TPlan & Mila::Dnn::Compute::Cuda::CublasLtPlanCache< TPlan >::get ( int batch_size) const
inline

Get the plan for the smallest bucket >= batch_size.
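
The lookup rule can be sketched as a binary search over the sorted bucket list. The helper `select_bucket` below is hypothetical and requires C++17 for `std::optional`. This reference does not say what `get` does when `batch_size` exceeds the largest bucket, so the sketch returns `std::nullopt` in that case purely as an illustrative choice.

```cpp
#include <algorithm>
#include <cassert>
#include <optional>
#include <vector>

// Returns the smallest bucket >= batch_size, or std::nullopt when batch_size
// is larger than every bucket. `sorted_buckets` must be in ascending order.
std::optional<int> select_bucket(const std::vector<int>& sorted_buckets,
                                 int batch_size) {
    assert(std::is_sorted(sorted_buckets.begin(), sorted_buckets.end()));
    // lower_bound finds the first element that is not less than batch_size.
    auto it = std::lower_bound(sorted_buckets.begin(), sorted_buckets.end(),
                               batch_size);
    if (it == sorted_buckets.end()) {
        return std::nullopt;  // ASSUMPTION: no bucket is large enough
    }
    return *it;
}
```

An exact match returns the bucket itself, while any size that falls between two buckets rounds up to the larger one.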

◆ size()

template<typename TPlan>
size_t Mila::Dnn::Compute::Cuda::CublasLtPlanCache< TPlan >::size ( ) const
inline

Member Data Documentation

◆ buckets_

template<typename TPlan>
std::vector<int> Mila::Dnn::Compute::Cuda::CublasLtPlanCache< TPlan >::buckets_
private

◆ cache_

template<typename TPlan>
std::unordered_map<int, TPlan> Mila::Dnn::Compute::Cuda::CublasLtPlanCache< TPlan >::cache_
private

The documentation for this class was generated from the following file: