Mila 0.13.48
Deep Neural Network Library
CublasLtLinearPlan.ixx File Reference

cuBLASLt matmul plan builder for CudaLinearOp.

#include <stdexcept>
#include <string>
#include <cublasLt.h>
import CublasLt.Error;
import Cuda.Error;
import Logging.Logger;
import Compute.CudaTensorDataType;
import Dnn.TensorDataType;

Classes

struct  Mila::Dnn::Compute::Cuda::CublasLtLinearPlan< TComputePrecision, TParameterPrecision >
 RAII wrapper owning cuBLASLt descriptors for a Linear matmul.

Namespaces

namespace  Mila
 Mila main API namespace.
namespace  Mila::Dnn
namespace  Mila::Dnn::Compute
namespace  Mila::Dnn::Compute::Cuda

Functions

template<TensorDataType TComputePrecision, TensorDataType TParameterPrecision = TComputePrecision>
CublasLtLinearPlan< TComputePrecision, TParameterPrecision > Mila::Dnn::Compute::Cuda::build_linear_plan (cublasLtHandle_t handle, int outer_size, int in_features, int out_features, bool has_bias, cublasComputeType_t compute_type, cudaDataType_t scale_type, const float *weight_scale=nullptr)
 Build a cuBLASLt plan for a Linear matmul.
template<TensorDataType TComputePrecision, TensorDataType TParameterPrecision = TComputePrecision>
void Mila::Dnn::Compute::Cuda::execute_linear_plan (cublasLtHandle_t handle, const CublasLtLinearPlan< TComputePrecision, TParameterPrecision > &plan, const float *alpha, const void *A, const void *B, const float *beta, typename CublasLtLinearPlan< TComputePrecision, TParameterPrecision >::ActivationType *C, const typename CublasLtLinearPlan< TComputePrecision, TParameterPrecision >::ActivationType *bias, const float *weight_scale, cudaStream_t stream, void *workspace=nullptr, size_t workspace_size=0)
 Execute a previously-built CublasLtLinearPlan.

Detailed Description

Provides a mixed-precision-aware plan struct, plus build and execute functions, for the Linear op only. For now, GQA and MHA continue to use the existing build_plan / build_strided_plan / execute_plan path unchanged.
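To make the shape arguments of build_linear_plan concrete, here is a hypothetical CPU reference (not part of Mila) for the result a Linear plan is expected to produce. It assumes row-major storage, an input of [outer_size, in_features], weights stored as [out_features, in_features], an output of [outer_size, out_features], and that the optional bias is added after the alpha/beta combination. The operand order and layouts cuBLASLt actually uses inside the plan may differ.

```cpp
#include <vector>

// Hypothetical CPU reference for a Linear matmul (not Mila code).
// Assumed row-major layouts:
//   X: [outer_size, in_features], W: [out_features, in_features],
//   Y: [outer_size, out_features] (must already be sized),
//   bias: [out_features], or nullptr when has_bias == false.
// Computes Y = alpha * X * W^T + beta * Y, then adds bias[o] if present.
void linear_reference(int outer_size, int in_features, int out_features,
                      float alpha, const std::vector<float>& X,
                      const std::vector<float>& W, float beta,
                      std::vector<float>& Y, const float* bias)
{
    for (int n = 0; n < outer_size; ++n) {
        for (int o = 0; o < out_features; ++o) {
            // Accumulate the dot product in float (the accumulator type).
            float acc = 0.0f;
            for (int i = 0; i < in_features; ++i) {
                acc += X[n * in_features + i] * W[o * in_features + i];
            }
            float& y = Y[n * out_features + o];
            // BLAS convention: when beta == 0 the previous Y is not read.
            const float prev = (beta != 0.0f) ? beta * y : 0.0f;
            y = alpha * acc + prev + (bias != nullptr ? bias[o] : 0.0f);
        }
    }
}
```

With outer_size = 1, in_features = 2, out_features = 2, an identity W, alpha = 1, beta = 0 and bias {10, 20}, the input row {1, 2} maps to {11, 22}.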

Template parameters follow Mila conventions:

  • TComputePrecision - activation and output element type (e.g. BF16)
  • TParameterPrecision - weight element type (e.g. FP8_E4M3); defaults to TComputePrecision
  • TAccumPrecision - accumulator and scale type; always float, so it is fixed rather than exposed as a template parameter of CublasLtLinearPlan or the build/execute functions
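A minimal compile-time sketch of this convention, for illustration only. TensorDataType here is a hypothetical stand-in enum (Mila's real enum may differ), and LinearPlanTraits is an illustrative name, not a Mila type. It shows the default for TParameterPrecision, the fixed float accumulator, and how a quantized configuration can be detected by comparing the two precisions, which is the condition behind has_per_channel_scale in the cases listed below.

```cpp
#include <type_traits>

// Hypothetical stand-in for Mila's TensorDataType; the real enum may differ.
enum class TensorDataType { FP32, FP16, BF16, FP8_E4M3, FP8_E5M2 };

// Illustrative traits following the documented convention (not Mila code).
template <TensorDataType TComputePrecision,
          TensorDataType TParameterPrecision = TComputePrecision>
struct LinearPlanTraits {
    static constexpr TensorDataType activation_type = TComputePrecision; // A and C
    static constexpr TensorDataType weight_type = TParameterPrecision;   // B
    using AccumType = float; // TAccumPrecision: always float, not a parameter
    // Quantized exactly when weights use a different type than activations.
    static constexpr bool has_per_channel_scale =
        TParameterPrecision != TComputePrecision;
};

// Non-quantized: TParameterPrecision defaults to TComputePrecision.
static_assert(!LinearPlanTraits<TensorDataType::BF16>::has_per_channel_scale);
// Quantized BF16 x FP8 configuration.
static_assert(LinearPlanTraits<TensorDataType::BF16,
                               TensorDataType::FP8_E4M3>::has_per_channel_scale);
static_assert(std::is_same_v<LinearPlanTraits<TensorDataType::BF16>::AccumType, float>);
```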

When TParameterPrecision == TComputePrecision (non-quantized):

  • layoutA, layoutB, layoutC all use the same cudaDataType_t
  • has_per_channel_scale is false
  • execute_linear_plan ignores the weight_scale pointer

When TParameterPrecision != TComputePrecision (quantized, e.g. BF16 x FP8):

  • layoutA and layoutC use cuda_data_type_v<TComputePrecision>
  • layoutB uses cuda_data_type_v<TParameterPrecision>
  • has_per_channel_scale is true
  • execute_linear_plan sets CUBLASLT_MATMUL_DESC_B_SCALE_POINTER at execution time
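
For the quantized case, the numeric effect of a per-channel weight scale can be sketched on the CPU. This assumes weight_scale holds one float per output feature and that weights are stored as [out_features, in_features]; FP8 values are modeled as already-decoded floats because standard C++ has no FP8 type. scaled_linear is a hypothetical helper, not Mila's or cuBLASLt's implementation.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical sketch of per-output-channel weight scaling (not Mila code).
//   X:            [outer_size, in_features]   activations
//   Wq:           [out_features, in_features] quantized weights, decoded to float
//   weight_scale: [out_features]              one scale per output feature
// Returns Y: [outer_size, out_features].
std::vector<float> scaled_linear(int outer_size, int in_features, int out_features,
                                 const std::vector<float>& X,
                                 const std::vector<float>& Wq,
                                 const std::vector<float>& weight_scale)
{
    std::vector<float> Y(static_cast<std::size_t>(outer_size) * out_features, 0.0f);
    for (int n = 0; n < outer_size; ++n) {
        for (int o = 0; o < out_features; ++o) {
            float acc = 0.0f; // float accumulator
            for (int i = 0; i < in_features; ++i) {
                acc += X[n * in_features + i] * Wq[o * in_features + i];
            }
            // scale[o] is constant over the reduction, so it is applied once.
            Y[n * out_features + o] = acc * weight_scale[o];
        }
    }
    return Y;
}
```

Because the scale for output feature o is constant over the reduction, applying it once to the float accumulator is mathematically equivalent to dequantizing row o of the weights before the matmul, while the weights themselves stay in their compact FP8 form.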

Caller contract: compute_type and scale_type are supplied by CudaLinearOp::getComputeTypes() and must be consistent with TComputePrecision and the active ComputePrecision::Policy. The cudaDataType_t values for A, B, C are derived at compile time from the template parameters via cuda_data_type_v and are not caller-supplied.