Mila 0.13.48
Deep Neural Network Library
cuBLASLt matmul plan builder for CudaLinearOp.
#include <stdexcept>
#include <string>
#include <cublasLt.h>
import CublasLt.Error;
import Cuda.Error;
import Logging.Logger;
import Compute.CudaTensorDataType;
import Dnn.TensorDataType;
Classes | |
| struct | Mila::Dnn::Compute::Cuda::CublasLtLinearPlan< TComputePrecision, TParameterPrecision > |
| RAII wrapper owning cuBLASLt descriptors for a Linear matmul. | |
Namespaces | |
| namespace | Mila |
| Mila main API namespace. | |
| namespace | Mila::Dnn |
| namespace | Mila::Dnn::Compute |
| namespace | Mila::Dnn::Compute::Cuda |
Functions | |
| template<TensorDataType TComputePrecision, TensorDataType TParameterPrecision = TComputePrecision> | |
| CublasLtLinearPlan< TComputePrecision, TParameterPrecision > | Mila::Dnn::Compute::Cuda::build_linear_plan (cublasLtHandle_t handle, int outer_size, int in_features, int out_features, bool has_bias, cublasComputeType_t compute_type, cudaDataType_t scale_type, const float *weight_scale=nullptr) |
| Build a cuBLASLt plan for a Linear matmul. | |
| template<TensorDataType TComputePrecision, TensorDataType TParameterPrecision = TComputePrecision> | |
| void | Mila::Dnn::Compute::Cuda::execute_linear_plan (cublasLtHandle_t handle, const CublasLtLinearPlan< TComputePrecision, TParameterPrecision > &plan, const float *alpha, const void *A, const void *B, const float *beta, typename CublasLtLinearPlan< TComputePrecision, TParameterPrecision >::ActivationType *C, const typename CublasLtLinearPlan< TComputePrecision, TParameterPrecision >::ActivationType *bias, const float *weight_scale, cudaStream_t stream, void *workspace=nullptr, size_t workspace_size=0) |
| Execute a previously-built CublasLtLinearPlan. | |
cuBLASLt matmul plan builder for CudaLinearOp.
Provides a mixed-precision-aware plan struct and build/execute functions for the Linear op only. For now, GQA and MHA continue to use the existing build_plan / build_strided_plan / execute_plan path untouched.
Template parameters follow Mila conventions:
TComputePrecision - activation and output element type (e.g. BF16)
TParameterPrecision - weight element type (e.g. FP8_E4M3); defaults to TComputePrecision
TAccumPrecision - accumulator and scale type; always float
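The parameter convention above can be illustrated with a minimal sketch. It is not the actual CublasLtLinearPlan definition: the TensorDataType enum here is a stand-in with assumed enumerators, and LinearPlanSketch, is_quantized and AccumType are hypothetical names introduced only for illustration.

```cpp
#include <type_traits>

// Stand-in for Mila's TensorDataType; the real enumerators are an assumption.
enum class TensorDataType { FP32, FP16, BF16, FP8_E4M3 };

// Hypothetical plan type: the weight precision defaults to the compute precision.
template <TensorDataType TComputePrecision,
          TensorDataType TParameterPrecision = TComputePrecision>
struct LinearPlanSketch {
    static constexpr TensorDataType compute_precision = TComputePrecision;
    static constexpr TensorDataType parameter_precision = TParameterPrecision;

    // "Quantized" here means the weights use a different element type
    // than the activations and output.
    static constexpr bool is_quantized =
        (TParameterPrecision != TComputePrecision);

    // The accumulator and scale type is always float in this convention.
    using AccumType = float;
};

// Non-quantized: one argument, so the weights are BF16 as well.
using Bf16Plan = LinearPlanSketch<TensorDataType::BF16>;

// Quantized: BF16 activations with FP8_E4M3 weights.
using Bf16Fp8Plan =
    LinearPlanSketch<TensorDataType::BF16, TensorDataType::FP8_E4M3>;

static_assert(!Bf16Plan::is_quantized, "default parameter precision matches");
static_assert(Bf16Fp8Plan::is_quantized, "mixed precision is quantized");
static_assert(std::is_same_v<Bf16Plan::AccumType, float>, "float accumulator");
```

Defaulting the second parameter keeps the common non-quantized instantiation short while letting the quantized case be selected purely at compile time.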
When TParameterPrecision == TComputePrecision (non-quantized):
When TParameterPrecision != TComputePrecision (quantized, e.g. BF16 x FP8):
Caller contract: compute_type and scale_type are supplied by CudaLinearOp::getComputeTypes() and must be consistent with TComputePrecision and the active ComputePrecision::Policy. The cudaDataType_t values for A, B, C are derived at compile time from the template parameters via cuda_data_type_v and are not caller-supplied.
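As a rough illustration of how element types can be derived at compile time rather than passed by the caller, the sketch below maps a precision tag to a data-type enum through a trait. It compiles without the CUDA toolkit because StubCudaDataType is a stand-in for cudaDataType_t, and stub_cuda_data_type_v and LayoutTypesSketch are hypothetical names. The actual cuda_data_type_v in Mila may be implemented differently.

```cpp
// Stand-ins so the sketch builds without CUDA headers (assumptions).
enum class TensorDataType { FP32, FP16, BF16, FP8_E4M3 };
// Mirrors CUDA_R_32F, CUDA_R_16F, CUDA_R_16BF and CUDA_R_8F_E4M3.
enum class StubCudaDataType { R_32F, R_16F, R_16BF, R_8F_E4M3 };

// Primary template is left undefined, so an unsupported precision
// is rejected at compile time instead of at run time.
template <TensorDataType T>
struct stub_cuda_data_type;

template <>
struct stub_cuda_data_type<TensorDataType::FP32> {
    static constexpr StubCudaDataType value = StubCudaDataType::R_32F;
};
template <>
struct stub_cuda_data_type<TensorDataType::FP16> {
    static constexpr StubCudaDataType value = StubCudaDataType::R_16F;
};
template <>
struct stub_cuda_data_type<TensorDataType::BF16> {
    static constexpr StubCudaDataType value = StubCudaDataType::R_16BF;
};
template <>
struct stub_cuda_data_type<TensorDataType::FP8_E4M3> {
    static constexpr StubCudaDataType value = StubCudaDataType::R_8F_E4M3;
};

template <TensorDataType T>
inline constexpr StubCudaDataType stub_cuda_data_type_v =
    stub_cuda_data_type<T>::value;

// Activations and output follow the compute precision;
// weights follow the parameter precision.
template <TensorDataType TCompute, TensorDataType TParam = TCompute>
struct LayoutTypesSketch {
    static constexpr StubCudaDataType activation = stub_cuda_data_type_v<TCompute>;
    static constexpr StubCudaDataType weight = stub_cuda_data_type_v<TParam>;
    static constexpr StubCudaDataType output = stub_cuda_data_type_v<TCompute>;
};
```

With this shape, only compute_type and scale_type remain run-time inputs, which matches the contract described above: the caller cannot pass a layout type that disagrees with the template parameters.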