Mila 0.13.48
Deep Neural Network Library
CUDA Linear operation with compile-time weight quantization policy dispatch.


Public Types | |
| using | ComputeType = typename TensorDataTypeMap<TComputePrecision>::device_type |
| using | CudaExecutionContext = ExecutionContext<DeviceType::Cuda> |
| using | MR = CudaDeviceMemoryResource |
| using | TensorType = Tensor<TComputePrecision, MR> |
| using | WeightType = typename TensorDataTypeMap<kWeightDtype>::device_type |
| Public Types inherited from Mila::Dnn::Compute::Operation< DeviceType::Cuda, TComputePrecision > | |
| using | DataTypeTraits |
Public Member Functions | |
| CudaLinearOp (IExecutionContext *context, const LinearConfig &config) | |
| ~CudaLinearOp ()=default | |
| void | backward (const TensorType &input, const TensorType &output_grad, TensorType &input_grad) const |
| Backward pass. | |
| void | build (const BuildContext &build_context) override |
| Prepare the operation for a concrete input shape. | |
| void | forward (const TensorType &input, TensorType &output) const |
| Forward pass: output = input * weight^T + bias. | |
| const LinearConfig & | getConfig () const |
| std::string | getName () const override |
| Human-readable operation name. | |
| OperationType | getOperationType () const override |
| Operation type identifier. | |
| void | quantize (const ITensorBlob &blob, ITensor &weight_out, ITensor &scales_out, const shape_t &expected_shape) |
| Quantize a BF16 host blob to FP8_E4M3 with per-channel FP32 scales. | |
| void | setGradients (ITensor *weight_grad, ITensor *bias_grad) override |
| Bind module-owned gradient tensors to the operation. | |
| void | setParameters (ITensor *weight, ITensor *bias) override |
| Bind module-owned parameter tensors to the operation. | |
| void | setWeightScales (ITensor *scales) |
| Bind the per-channel FP32 weight scale tensor produced by quantize(). | |
| void | setWeightZeroPoints (ITensor *zero_points) |
| Bind the packed INT4 zero-point tensor for per-group asymmetric quantization. | |
| Public Member Functions inherited from Mila::Dnn::Compute::Operation< DeviceType::Cuda, TComputePrecision > | |
| virtual | ~Operation ()=default |
| virtual void | clearGradients () noexcept |
| Clear any cached gradient pointers held by the operation. | |
| virtual TensorDataType | getDataType () const |
| Tensor data type for this operation. | |
| virtual DeviceType | getDeviceType () const |
| Device type for this operation. | |
| virtual std::size_t | getStateMemorySize () const |
| Returns the number of bytes of state memory allocated by this operation. | |
| virtual bool | isBuilt () const |
| Whether build() completed successfully for a concrete input shape. | |
| virtual bool | isEvalMode () const |
Query whether the operation is configured for evaluation (inference) mode. |
| virtual void | setTrainingMode (TrainingMode training_mode) |
| Configure operation training-mode behavior. | |
Static Public Attributes | |
| static constexpr bool | kIsPerChannelQuantized = kIsQuantized && TWeightQuant::kPerChannel |
| static constexpr bool | kIsPerGroupQuantized = kIsQuantized && !TWeightQuant::kPerChannel |
| static constexpr bool | kIsQuantized = TWeightQuant::kIsQuantized |
| static constexpr bool | kUseW8A16Gemm = false |
| static constexpr TensorDataType | kWeightDtype |
| Static Public Attributes inherited from Mila::Dnn::Compute::Operation< DeviceType::Cuda, TComputePrecision > | |
| static constexpr TensorDataType | data_type |
| static constexpr DeviceType | device_type |
Private Member Functions | |
| void | buildCublasLtPlans () |
| cudaDataType_t | getActivationCudaDataType () const |
| void | getComputeTypes (cublasComputeType_t &compute_type, cudaDataType_t &scale_type) const |
| cudaDataType_t | getWeightCudaDataType () const |
| bool | supportsCuBLASLt () const |
| Returns true if an optimised batch compute path is available. | |
Private Attributes | |
| CublasLtPlanCache< CublasLtMatMulPlan< ComputeType > > | backward_input_plan_cache_ |
| CublasLtMatMulPlan< ComputeType > | backward_weight_plan_ |
| const ComputeType * | bias_ { nullptr } |
| ComputeType * | bias_grad_ { nullptr } |
| cublasLtHandle_t | cached_cublaslt_handle_ { nullptr } |
| int | cached_in_features_ { 0 } |
| int | cached_outer_size_ { 0 } |
| cublasComputeType_t | compute_type_ {} |
| LinearConfig | config_ |
| CudaExecutionContext * | context_ |
| cudaDataType_t | cuda_data_type_ {} |
| cudaDataType_t | cuda_weight_data_type_ {} |
| CublasLtPlanCache< CublasLtLinearPlan< TComputePrecision > > | forward_plan_cache_ |
| int | out_features_ { 0 } |
| cudaDataType_t | scale_type_ {} |
| bool | use_cublaslt_ { false } |
| bool | use_wmma_fp4_gemm_ { false } |
| const WeightType * | weight_ { nullptr } |
| ComputeType * | weight_grad_ { nullptr } |
| int | weight_group_size_ { 128 } |
| int64_t | weight_in_features_ { 0 } |
| int64_t | weight_out_features_ { 0 } |
| const float * | weight_scales_ { nullptr } |
| const uint8_t * | weight_zero_points_ { nullptr } |
Additional Inherited Members | |
| Protected Attributes inherited from Mila::Dnn::Compute::Operation< DeviceType::Cuda, TComputePrecision > | |
| bool | is_built_ |
| TrainingMode | training_mode_ |
CUDA Linear operation with compile-time weight quantization policy dispatch.
Forward:
  output = input * weight^T + bias
Backward (NoWeightQuant only):
  input_grad  = output_grad * weight
  weight_grad = output_grad^T * input   (accumulated)
  bias_grad   = sum(output_grad, dim=0)
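As a plain reference for these formulas, the following CPU sketch computes the forward pass and the NoWeightQuant backward pass. It assumes a row-major input flattened to [N, in_features], a weight of shape [out_features, in_features] and a bias of shape [out_features]; the layout and the `LinearRef` name are illustrative assumptions, not the CUDA implementation. As in the description, the weight and bias gradients accumulate rather than overwrite.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// CPU reference for a linear layer (illustrative, not the CUDA kernels).
// Assumed layouts: input [n, in], weight [out, in], bias [out], row-major.
struct LinearRef {
    std::size_t n, in, out;

    // output[r, o] = sum_i input[r, i] * weight[o, i] + bias[o]
    std::vector<float> forward(const std::vector<float>& x,
                               const std::vector<float>& w,
                               const std::vector<float>& b) const {
        assert(x.size() == n * in && w.size() == out * in && b.size() == out);
        std::vector<float> y(n * out);
        for (std::size_t r = 0; r < n; ++r) {
            for (std::size_t o = 0; o < out; ++o) {
                float acc = b[o];
                for (std::size_t i = 0; i < in; ++i) acc += x[r * in + i] * w[o * in + i];
                y[r * out + o] = acc;
            }
        }
        return y;
    }

    // input_grad   = output_grad * weight         (overwritten in this sketch)
    // weight_grad += output_grad^T * input        (accumulated)
    // bias_grad   += sum over rows of output_grad (accumulated)
    void backward(const std::vector<float>& x, const std::vector<float>& w,
                  const std::vector<float>& dy, std::vector<float>& dx,
                  std::vector<float>& dw, std::vector<float>& db) const {
        assert(x.size() == n * in && dy.size() == n * out);
        assert(dw.size() == out * in && db.size() == out);
        dx.assign(n * in, 0.0f);
        for (std::size_t r = 0; r < n; ++r) {
            for (std::size_t o = 0; o < out; ++o) {
                const float g = dy[r * out + o];
                db[o] += g;
                for (std::size_t i = 0; i < in; ++i) {
                    dx[r * in + i] += g * w[o * in + i];
                    dw[o * in + i] += g * x[r * in + i];
                }
            }
        }
    }
};
```

Because the gradients accumulate, a caller would zero `dw` and `db` before the first backward of a step.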
When TWeightQuant = PerChannelFp8<>, weights are stored as FP8_E4M3 with one float32 scale per output channel. quantize() performs the one-time host-side BF16->FP8 conversion at load time.
Forward dispatch on the quantized path:
- Single vector (outer_size == 1): a fused matvec applies FP8 per-channel dequantization inline, which is optimal for memory-bandwidth-bound single-vector compute.
- Batch (outer_size > 1): one of two paths, selected at compile time by kUseW8A16Gemm:
  - kUseW8A16Gemm = true: a fused W8A16 GEMM reads the FP8 weights once, dequantizes them per channel inline in shared memory, and writes BF16 output directly (no staging buffer).
  - kUseW8A16Gemm = false: a 2-phase path dequantizes FP8 → BF16 into a staging buffer, runs a standard BF16 cuBLASLt NT GEMM, then applies the bias with a cuda_add_bias post-pass.
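The branching above can be restated as a small compile-time selector. The `ForwardPath` enum and both function names are hypothetical, and so is the assumption that the 2-phase path stages the full [out_features, in_features] weight matrix in BF16 (2 bytes per element).

```cpp
#include <cstdint>

// Hypothetical names for the quantized forward paths described above.
enum class ForwardPath { FusedMatvec, FusedW8A16Gemm, DequantThenGemm };

// Single vector -> fused matvec; batch -> compile-time choice by UseW8A16Gemm.
template <bool UseW8A16Gemm>
constexpr ForwardPath selectQuantizedPath(std::int64_t outer_size) {
    if (outer_size == 1) return ForwardPath::FusedMatvec;
    if constexpr (UseW8A16Gemm) {
        return ForwardPath::FusedW8A16Gemm;
    } else {
        return ForwardPath::DequantThenGemm;
    }
}

// Assumed BF16 staging size for the 2-phase path: the whole weight matrix at
// 2 bytes per element. The fused paths need no staging buffer.
constexpr std::int64_t stagingBytes(ForwardPath path, std::int64_t out_features,
                                    std::int64_t in_features) {
    return path == ForwardPath::DequantThenGemm ? out_features * in_features * 2 : 0;
}
```

Under that assumption, a 4096 x 4096 layer would need a 32 MiB staging buffer on the 2-phase path, which is the memory the fused W8A16 path avoids.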
Backward is not supported on the quantized path (inference only).
Template Parameters
| TComputePrecision | Activation and accumulation precision. |
| TWeightQuant | Weight quantization policy. Defaults to NoWeightQuant. |
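A policy-based sketch of how the derived flags listed in the summary (kIsQuantized, kIsPerChannelQuantized, kIsPerGroupQuantized) can follow from a TWeightQuant policy. The policy structs are hypothetical stand-ins modelled on the names NoWeightQuant and PerChannelFp8<>; `PerGroupInt4` is an invented name for the per-group case, and its default group size of 128 only mirrors the weight_group_size_ default. Every policy must define kPerChannel, even when unquantized, because the flag expressions name it for all policies.

```cpp
// Hypothetical policy types modelled on the names on this page; the real
// Mila definitions may differ.
struct NoWeightQuant {
    static constexpr bool kIsQuantized = false;
    static constexpr bool kPerChannel = false; // must exist even when unquantized
};

template <typename ScaleT = float>
struct PerChannelFp8 {
    using scale_type = ScaleT;                 // one scale per output channel
    static constexpr bool kIsQuantized = true;
    static constexpr bool kPerChannel = true;
};

// Invented name for a per-group INT4 policy.
template <int GroupSize = 128>
struct PerGroupInt4 {
    static constexpr bool kIsQuantized = true;
    static constexpr bool kPerChannel = false;
    static constexpr int kGroupSize = GroupSize;
};

// Derived flags, written the same way as the static constexpr members in the summary.
template <typename TWeightQuant>
struct QuantFlags {
    static constexpr bool kIsQuantized = TWeightQuant::kIsQuantized;
    static constexpr bool kIsPerChannelQuantized = kIsQuantized && TWeightQuant::kPerChannel;
    static constexpr bool kIsPerGroupQuantized = kIsQuantized && !TWeightQuant::kPerChannel;
};
```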
| using Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::ComputeType = typename TensorDataTypeMap<TComputePrecision>::device_type |
| using Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::CudaExecutionContext = ExecutionContext<DeviceType::Cuda> |
| using Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::MR = CudaDeviceMemoryResource |
| using Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::TensorType = Tensor<TComputePrecision, MR> |
| using Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::WeightType = typename TensorDataTypeMap<kWeightDtype>::device_type |
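CudaLinearOp(IExecutionContext *context, const LinearConfig &config) [inline]

~CudaLinearOp() [default]

void backward(const TensorType &input, const TensorType &output_grad, TensorType &input_grad) const [inline]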
Backward pass.
Not supported on the quantized path.
Parameters
| input | Saved forward input. |
| output_grad | Upstream gradient. |
| input_grad | Output: gradient with respect to forward input. |
void build(const BuildContext &build_context) [inline, override, virtual]
Prepare the operation for a concrete input shape.
Default implementation is a no-op. Operations requiring shape-dependent setup should override this method.
Reimplemented from Mila::Dnn::Compute::Operation< DeviceType::Cuda, TComputePrecision >.
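A minimal sketch of the kind of shape-dependent setup build() can perform. It assumes the op flattens every leading dimension into outer_size (the name used in the forward dispatch description) and rebuilds cached state only when (outer_size, in_features) changes, as the private members cached_outer_size_ and cached_in_features_ suggest. `flattenForLinear` and `BuildCache` are hypothetical names, and the counter stands in for GEMM plan construction.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <stdexcept>
#include <vector>

// Flattened view of a linear-layer input: [outer_size, in_features].
struct LinearShape {
    std::int64_t outer_size;
    std::int64_t in_features;
};

// Hypothetical helper: collapse [d0, ..., d(k-2), in_features] to 2-D.
inline LinearShape flattenForLinear(const std::vector<std::int64_t>& shape,
                                    std::int64_t expected_in) {
    assert(expected_in > 0);
    if (shape.empty()) throw std::invalid_argument("linear input must have rank >= 1");
    if (shape.back() != expected_in) throw std::invalid_argument("in_features mismatch");
    const std::int64_t outer = std::accumulate(shape.begin(), shape.end() - 1, std::int64_t{1},
                                               std::multiplies<std::int64_t>{});
    return {outer, shape.back()};
}

// Rebuild shape-dependent state only when the flattened shape changes.
struct BuildCache {
    std::int64_t cached_outer_size = 0;
    std::int64_t cached_in_features = 0;
    int rebuilds = 0; // stand-in for constructing GEMM plans

    void build(const std::vector<std::int64_t>& shape, std::int64_t expected_in) {
        const LinearShape s = flattenForLinear(shape, expected_in);
        if (s.outer_size == cached_outer_size && s.in_features == cached_in_features) return;
        cached_outer_size = s.outer_size;
        cached_in_features = s.in_features;
        ++rebuilds;
    }
};
```

In this sketch a [2, 3, 4] input and a [6, 4] input share one set of plans, because both flatten to outer_size = 6.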
void buildCublasLtPlans() [inline, private]

void forward(const TensorType &input, TensorType &output) const [inline]
Forward pass: output = input * weight^T + bias.
Dispatch priority:
cudaDataType_t getActivationCudaDataType() const [inline, private]

void getComputeTypes(cublasComputeType_t &compute_type, cudaDataType_t &scale_type) const [inline, private]

const LinearConfig & getConfig() const [inline]
std::string getName() const [inline, override, virtual]
Human-readable operation name.
Implements Mila::Dnn::Compute::Operation< DeviceType::Cuda, TComputePrecision >.
OperationType getOperationType() const [inline, override, virtual]
Operation type identifier.
Implements Mila::Dnn::Compute::Operation< DeviceType::Cuda, TComputePrecision >.
cudaDataType_t getWeightCudaDataType() const [inline, private]

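void quantize(const ITensorBlob &blob, ITensor &weight_out, ITensor &scales_out, const shape_t &expected_shape) [inline]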
Quantize a BF16 host blob to FP8_E4M3 with per-channel FP32 scales.
Runs once at model load time. Delegates to Detail::quantize_fp8_per_channel() (pre-compiled by NVCC in the :Quantize partition), which performs per-channel absmax scaling and uploads both the FP8 weight tensor and the FP32 scale tensor to device. The BF16 source blob is never retained on device.
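The per-channel absmax scheme can be sketched on the host as follows. Assumptions: each row's scale is absmax / 448 (448 being the largest finite FP8 E4M3 value), the input is already float rather than BF16, and quantized values are kept as floats on the E4M3 grid instead of packed bytes. `roundToE4M3` and `quantizePerChannel` are hypothetical helpers, not Detail::quantize_fp8_per_channel().

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Round x to the nearest FP8 E4M3 value (finite max 448, smallest subnormal
// 2^-9), saturating instead of producing NaN. Returned as float for clarity.
inline float roundToE4M3(float x) {
    constexpr float kMax = 448.0f;
    const float a = std::min(std::fabs(x), kMax);
    if (a == 0.0f) return std::copysign(0.0f, x);
    int bin_exp = 0;
    std::frexp(a, &bin_exp);                    // a = m * 2^bin_exp, m in [0.5, 1)
    const int e = std::max(bin_exp - 1, -6);    // below 2^-6 use subnormal spacing
    const float step = std::ldexp(1.0f, e - 3); // 3 mantissa bits
    // nearbyint rounds ties to even under the default rounding mode.
    const float r = std::min(std::nearbyint(a / step) * step, kMax);
    return std::copysign(r, x);
}

struct Fp8PerChannel {
    std::vector<float> q;      // E4M3-representable values, held as float here
    std::vector<float> scales; // one FP32 scale per output channel
};

// w is row-major [out, in]; each row gets its own absmax-derived scale.
inline Fp8PerChannel quantizePerChannel(const std::vector<float>& w,
                                        std::size_t out, std::size_t in) {
    assert(w.size() == out * in);
    Fp8PerChannel r{std::vector<float>(w.size()), std::vector<float>(out)};
    for (std::size_t o = 0; o < out; ++o) {
        float absmax = 0.0f;
        for (std::size_t i = 0; i < in; ++i) absmax = std::max(absmax, std::fabs(w[o * in + i]));
        const float scale = absmax > 0.0f ? absmax / 448.0f : 1.0f; // assumed scale rule
        r.scales[o] = scale;
        for (std::size_t i = 0; i < in; ++i) r.q[o * in + i] = roundToE4M3(w[o * in + i] / scale);
    }
    return r;
}
```

Dequantizing with `q[o * in + i] * scales[o]` recovers each weight to within E4M3's relative rounding error of 2^-4, for values that stay in the normal range after scaling.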
void setGradients(ITensor *weight_grad, ITensor *bias_grad) [inline, override, virtual]
Bind module-owned gradient tensors to the operation.
Canonical API for binding gradient buffers. Mirrors the semantics of setParameters(), but for the gradient tensors written during backward().
The operation MUST NOT take ownership of the provided pointers. Implementations may cache rawData() pointers for hot-path writes.
Default: no-op for stateless operations.
Reimplemented from Mila::Dnn::Compute::Operation< DeviceType::Cuda, TComputePrecision >.
void setParameters(ITensor *weight, ITensor *bias) [inline, override, virtual]
Bind module-owned parameter tensors to the operation.
The module retains ownership of the provided ITensor objects. Implementations may cache rawData() pointers for hot-path access but MUST NOT free the provided pointers.
Default: no-op for stateless operations.
Reimplemented from Mila::Dnn::Compute::Operation< DeviceType::Cuda, TComputePrecision >.
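The ownership contract of setParameters() and setGradients() (cache raw pointers, never take ownership) looks roughly like this. `ITensorLike`, `HostTensor`, and `LinearLikeOp` are hypothetical stand-ins that model only rawData(), and treating a null bias as "no bias" is an assumption.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical, minimal stand-in for ITensor: only rawData() is modelled.
struct ITensorLike {
    virtual ~ITensorLike() = default;
    virtual void* rawData() = 0;
};

struct HostTensor : ITensorLike {
    std::vector<float> data;
    explicit HostTensor(std::size_t n) : data(n) {}
    void* rawData() override { return data.data(); }
};

// Sketch of the binding contract: cache raw pointers, never own or free them.
class LinearLikeOp {
public:
    void setParameters(ITensorLike* weight, ITensorLike* bias) {
        assert(weight != nullptr);
        weight_ = static_cast<const float*>(weight->rawData());
        bias_ = bias ? static_cast<const float*>(bias->rawData()) : nullptr; // assumed optional
    }
    const float* weight() const { return weight_; }
    const float* bias() const { return bias_; }
    // No destructor logic: the owning module frees the tensors.

private:
    const float* weight_ = nullptr;
    const float* bias_ = nullptr;
};
```

The flip side of non-ownership is that a cached pointer dangles if the module frees or reallocates its tensor, so parameters would need to be rebound after any such change.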
void setWeightScales(ITensor *scales) [inline]
Bind the per-channel FP32 weight scale tensor produced by quantize().
Must be called after quantize() and before the first forward(). The scale pointer is cached and used by the forward path to dequantize the FP8 weights per output channel.
Parameters
| scales | Device tensor of shape [out_features], dtype Float32. |
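Per-channel scales are what make inline dequantization cheap in the single-vector case: every weight in row o shares scales[o], so the scale can be applied once to that row's dot product instead of once per weight. The CPU sketch below shows this factoring with FP8 values held as floats; `matvecPerChannel` is a hypothetical name, and the sketch illustrates the idea rather than the library's kernel.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Per-channel dequantizing matvec (single-vector case). Because the scale is
// constant across a row, it is applied once after the row's dot product.
inline std::vector<float> matvecPerChannel(const std::vector<float>& q,      // [out, in]
                                           const std::vector<float>& scales, // [out]
                                           const std::vector<float>& x,      // [in]
                                           const std::vector<float>& bias,   // [out]
                                           std::size_t out, std::size_t in) {
    assert(q.size() == out * in && scales.size() == out);
    assert(x.size() == in && bias.size() == out);
    std::vector<float> y(out);
    for (std::size_t o = 0; o < out; ++o) {
        float acc = 0.0f;
        for (std::size_t i = 0; i < in; ++i) acc += q[o * in + i] * x[i];
        y[o] = scales[o] * acc + bias[o]; // one scale multiply per output channel
    }
    return y;
}
```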
void setWeightZeroPoints(ITensor *zero_points) [inline]
Bind the packed INT4 zero-point tensor for per-group asymmetric quantization.
Optional: only required for asymmetric INT4 quantization. Pass nullptr (or do not call this method) for symmetric quantization, which uses an implicit zero point of 8. The tensor layout must match the kernel's expectation: [out_features, in_features / (group_size * 2)], dtype UINT8, with two packed INT4 zero points per byte.
Parameters
| zero_points | Device UINT8 tensor, or nullptr for symmetric quantization. |
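Indexing the packed zero-point layout [out_features, in_features / (group_size * 2)] can be sketched as below. Two details are assumptions not documented here: the low nibble holds the even-numbered group, and dequantization follows the conventional asymmetric form (q - zero) * scale. The implicit zero of 8 for the symmetric case comes from the description above; the helper names are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Zero point for group g of output row o in the packed layout
// [out_features, in_features / (group_size * 2)], two INT4 values per byte.
// Assumption: the low nibble holds the even-numbered group.
inline int zeroPointAt(const std::vector<std::uint8_t>& zp, std::size_t o, std::size_t g,
                       std::size_t in_features, std::size_t group_size) {
    assert(in_features % (group_size * 2) == 0);
    const std::size_t bytes_per_row = in_features / (group_size * 2);
    const std::uint8_t byte = zp.at(o * bytes_per_row + g / 2);
    return (g % 2 == 0) ? (byte & 0x0F) : (byte >> 4);
}

// Dequantize one 4-bit code q in [0, 15] at column i of row o.
// zp == nullptr means symmetric quantization with the implicit zero of 8.
// The (q - zero) * scale form is an assumed, conventional formula.
inline float dequantInt4(int q, float group_scale, const std::vector<std::uint8_t>* zp,
                         std::size_t o, std::size_t i, std::size_t in_features,
                         std::size_t group_size) {
    assert(q >= 0 && q <= 15);
    const std::size_t g = i / group_size;
    const int zero = zp ? zeroPointAt(*zp, o, g, in_features, group_size) : 8;
    return static_cast<float>(q - zero) * group_scale;
}
```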
bool supportsCuBLASLt() const [inline, private]
Returns true if an optimised batch compute path is available.
- FP8 (kIsPerChannelQuantized): SM >= 8.0 (Ampere+), for both the fused W8A16 GEMM and the 2-phase dequantize + cuBLASLt BF16 GEMM baseline.
- INT4 (kIsPerGroupQuantized): SM >= 8.0 (Ampere+) with BF16 activations. The fused W4A16 kernel reads packed INT4 weights and dequantizes them per group inline.
- Non-quantized: requires a cuBLASLt-supported compute type (FP32, FP16, or BF16).
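The capability rules above reduce to a pure function of the quantization kind, the compute type, and the SM major version. This is an interpretation that reads "SM >= 8.0 for BF16" as requiring BF16 activations on the INT4 path; the enums and the function name are hypothetical. On a real device the major version could come from cudaDeviceProp::major, filled in by cudaGetDeviceProperties().

```cpp
// Hypothetical enums for the cases listed above.
enum class QuantKind { None, PerChannelFp8, PerGroupInt4 };
enum class ComputeKind { FP32, FP16, BF16, Other };

// Restates the documented rules; sm_major would come from the device at runtime.
constexpr bool hasOptimisedBatchPath(QuantKind quant, ComputeKind compute, int sm_major) {
    const bool ampere_plus = sm_major >= 8; // SM >= 8.0
    switch (quant) {
    case QuantKind::PerChannelFp8:
        return ampere_plus;                  // fused W8A16 or 2-phase BF16 GEMM
    case QuantKind::PerGroupInt4:
        return ampere_plus && compute == ComputeKind::BF16; // fused W4A16 kernel
    case QuantKind::None:
        return compute == ComputeKind::FP32 || compute == ComputeKind::FP16 ||
               compute == ComputeKind::BF16; // cuBLASLt-supported compute types
    }
    return false;
}
```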
