Mila 0.13.48
Deep Neural Network Library
Generic plan cache keyed on batch-size bucket.
Public Types
| using | PlanBuilder = std::function<TPlan( int bucket )> |

Public Member Functions
| | CublasLtPlanCache () = default | |
| | CublasLtPlanCache (int max_batch_size, PlanBuilder builder) | Construct and eagerly build all plans. |
| const std::vector<int>& | buckets () const | |
| bool | empty () const | |
| const TPlan& | get (int batch_size) const | Get the plan for the smallest bucket >= batch_size. |
| size_t | size () const | |

Private Attributes
| std::vector<int> | buckets_ |
| std::unordered_map<int, TPlan> | cache_ |
TPlan must be move-constructible (for example, CublasLtMatMulPlan<T>). All plans are built eagerly, one per bucket, when the cache is constructed.
Usage:

    CublasLtPlanCache<CublasLtMatMulPlan<float>> cache(
        max_batch_size,
        [&]( int bucket ) { return build_my_plan( bucket ); } );

    const auto& plan = cache.get( actual_batch_size );
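The move-constructible requirement and the eager build can be illustrated with a small sketch. It is not Mila's implementation: `ToyPlan` is a hypothetical move-only stand-in for CublasLtMatMulPlan<T>, and `build_all` is a hypothetical helper. Because each plan returned by the builder is moved into the `std::unordered_map`, a plan type that owns resources and forbids copying still qualifies. The bucket list is supplied by the caller, since this page does not say how Mila chooses buckets.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <type_traits>
#include <unordered_map>
#include <vector>

// Hypothetical move-only plan standing in for CublasLtMatMulPlan<T>.
// It owns a resource (here a heap-allocated int), so copying is disabled.
struct ToyPlan {
    int bucket;
    std::unique_ptr<int> workspace;

    explicit ToyPlan(int b) : bucket(b), workspace(std::make_unique<int>(b * 16)) {}
    ToyPlan(ToyPlan&&) = default;
    ToyPlan& operator=(ToyPlan&&) = default;
    ToyPlan(const ToyPlan&) = delete;
    ToyPlan& operator=(const ToyPlan&) = delete;
};

static_assert(std::is_move_constructible<ToyPlan>::value, "plans must be movable");
static_assert(!std::is_copy_constructible<ToyPlan>::value, "copyability is not needed");

// Same shape as the documented PlanBuilder alias, instantiated for ToyPlan.
using ToyPlanBuilder = std::function<ToyPlan(int bucket)>;

// Hypothetical helper: eagerly build one plan per bucket.
// Each plan returned by the builder is moved into the map, never copied.
std::unordered_map<int, ToyPlan> build_all(const std::vector<int>& buckets,
                                           const ToyPlanBuilder& builder) {
    std::unordered_map<int, ToyPlan> cache;
    cache.reserve(buckets.size());
    for (int b : buckets) {
        const bool inserted = cache.emplace(b, builder(b)).second;
        assert(inserted && "bucket list must not contain duplicates");
        (void)inserted;  // silence unused-variable warnings under NDEBUG
    }
    return cache;
}
```

For example, `build_all({1, 2, 4, 8}, [](int b) { return ToyPlan(b); })` yields four plans, and the plan for bucket 4 owns a workspace value of 64.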
using Mila::Dnn::Compute::Cuda::CublasLtPlanCache< TPlan >::PlanBuilder = std::function<TPlan( int bucket )>
Callable that builds the plan for a given bucket.
CublasLtPlanCache() = default
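CublasLtPlanCache(int max_batch_size, PlanBuilder builder) [inline]
Construct and eagerly build all plans.
Parameters
| max_batch_size | Upper bound on the batch size the cache must serve (for example, the training batch size or the maximum sequence length). |
| builder | Callable: int bucket -> TPlan, invoked for each bucket at construction time. |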
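Before the constructor can call `builder` for each bucket, it has to decide which buckets exist. This page does not document Mila's policy, so the sketch below assumes a common power-of-two scheme, with `max_batch_size` appended as the final bucket so every batch size from 1 up to the bound has a covering bucket. `make_buckets` is a hypothetical helper.

```cpp
#include <cassert>
#include <stdexcept>
#include <vector>

// Hypothetical bucketing policy (an assumption, not Mila's documented one):
// every power of two below max_batch_size, then max_batch_size itself,
// so each batch size in [1, max_batch_size] has a covering bucket.
std::vector<int> make_buckets(int max_batch_size) {
    if (max_batch_size <= 0) {
        throw std::invalid_argument("max_batch_size must be positive");
    }
    std::vector<int> buckets;
    // long long avoids signed overflow when doubling near INT_MAX.
    for (long long b = 1; b < max_batch_size; b *= 2) {
        buckets.push_back(static_cast<int>(b));
    }
    buckets.push_back(max_batch_size);  // largest bucket equals the upper bound
    assert(buckets.back() == max_batch_size);
    return buckets;
}
```

For example, `make_buckets(48)` yields {1, 2, 4, 8, 16, 32, 48}, so construction would make seven builder calls. Under this assumed scheme the cache holds at most ceil(log2(max_batch_size)) + 1 plans, and the bucket chosen for any batch is always smaller than twice the batch size.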

const std::vector<int>& buckets() const [inline]
bool empty() const [inline]
const TPlan& get(int batch_size) const [inline]
Get the plan for the smallest bucket >= batch_size.
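Selecting "the smallest bucket >= batch_size" is a ceiling lookup. Assuming `buckets_` is kept sorted in ascending order (an assumption; the page does not say so), `std::lower_bound` finds that bucket in logarithmic time, and the plan would then be fetched from `cache_` by that key. `select_bucket` is a hypothetical helper, and because this page does not state what `get()` does for a batch larger than every bucket, the sketch simply throws `std::out_of_range`.

```cpp
#include <algorithm>
#include <cassert>
#include <stdexcept>
#include <vector>

// Hypothetical helper: return the smallest bucket >= batch_size from a
// sorted bucket list. Throwing on an oversized batch is an assumption of
// this sketch, not documented behaviour of get().
int select_bucket(const std::vector<int>& buckets, int batch_size) {
    assert(std::is_sorted(buckets.begin(), buckets.end()));
    // lower_bound returns the first element not less than batch_size,
    // i.e. the smallest bucket that can hold this batch.
    auto it = std::lower_bound(buckets.begin(), buckets.end(), batch_size);
    if (it == buckets.end()) {
        throw std::out_of_range("batch_size exceeds the largest bucket");
    }
    return *it;
}
```

With buckets {1, 2, 4, 8, 16, 32, 48}, a batch of 3 maps to bucket 4 and a batch of 33 maps to bucket 48, while an exact match such as 4 maps to itself.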

size_t size() const [inline]
std::vector<int> buckets_ [private]
std::unordered_map<int, TPlan> cache_ [private]