Mila 0.13.48
Deep Neural Network Library
CudaLinearOp.Quantize.ixx File Reference

Quantize partition of CudaLinearOp.

#include <cstdint>
#include <stdexcept>
#include <format>
#include "Kernels/Quantization/CudaFp8WeightQuantization.cuh"
#include "Kernels/Quantization/CudaFp4WeightQuantization.cuh"
import Serialization.Tensor;
import Dnn.TensorTypes;
import Dnn.ITensor;

Namespaces

namespace  Mila
 Mila main API namespace.
namespace  Mila::Dnn
namespace  Mila::Dnn::Compute
namespace  Mila::Dnn::Compute::Cuda
namespace  Mila::Dnn::Compute::Cuda::Linear
namespace  Mila::Dnn::Compute::Cuda::Linear::Detail

Functions

void Mila::Dnn::Compute::Cuda::Linear::Detail::quantize_fp4_per_group (const Mila::Dnn::Serialization::ITensorBlob &blob, Mila::Dnn::ITensor &weight_out, Mila::Dnn::ITensor &scales_out, const Mila::Dnn::shape_t &expected_shape, int group_size, void *dev_staging, cudaStream_t stream)
 Validate, quantize and upload a BF16 weight blob to packed FP4_E2M1 with per-group float32 scales.
void Mila::Dnn::Compute::Cuda::Linear::Detail::quantize_fp8_per_channel (const Mila::Dnn::Serialization::ITensorBlob &blob, Mila::Dnn::ITensor &weight_out, Mila::Dnn::ITensor &scales_out, const Mila::Dnn::shape_t &expected_shape, void *dev_staging, cudaStream_t stream)
 Validate, quantize and upload a BF16 weight blob to FP8_E4M3 on device, with per-channel scales.
void Mila::Dnn::Compute::Cuda::Linear::Detail::quantize_fp8_per_tensor (const Mila::Dnn::Serialization::ITensorBlob &blob, Mila::Dnn::ITensor &weight_out, Mila::Dnn::ITensor &scales_out, const Mila::Dnn::shape_t &expected_shape, void *dev_staging, cudaStream_t stream)
 Validate, quantize and upload a BF16 weight blob to FP8_E4M3 with a single per-tensor scale, for the Ada (SM 8.9+) cuBLASLt TN path.
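
The three functions above differ mainly in how finely the scales are resolved: one per group of elements, one per channel, or one for the whole tensor. The host-side sketch below illustrates that difference only. It assumes symmetric absmax scaling (scale = max |w| / largest finite value of the target format), which this page does not confirm, and it uses float input instead of BF16 for portability. All function names are hypothetical and are not part of the Mila API.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Largest finite magnitudes of the two target formats.
constexpr float kFp8E4M3Max = 448.0f;  // FP8 E4M3
constexpr float kFp4E2M1Max = 6.0f;    // FP4 E2M1

// Hypothetical helper: absmax scale for one contiguous run of weights.
// Returns 1.0f for an all-zero run so a later division by the scale is safe.
inline float absmax_scale(const float* w, std::size_t n, float format_max) {
    float amax = 0.0f;
    for (std::size_t i = 0; i < n; ++i) {
        amax = std::max(amax, std::fabs(w[i]));
    }
    return amax > 0.0f ? amax / format_max : 1.0f;
}

// One scale per output channel (row) of a row-major [rows x cols] matrix.
inline std::vector<float> per_channel_scales(const std::vector<float>& w,
                                             std::size_t rows, std::size_t cols) {
    assert(w.size() == rows * cols);
    std::vector<float> scales(rows);
    for (std::size_t r = 0; r < rows; ++r) {
        scales[r] = absmax_scale(w.data() + r * cols, cols, kFp8E4M3Max);
    }
    return scales;
}

// A single scale shared by every element of the tensor.
inline float per_tensor_scale(const std::vector<float>& w) {
    return absmax_scale(w.data(), w.size(), kFp8E4M3Max);
}

// One scale per group of group_size consecutive elements. Requiring cols to be
// a multiple of group_size keeps every group inside a single row.
inline std::vector<float> per_group_scales(const std::vector<float>& w,
                                           std::size_t rows, std::size_t cols,
                                           std::size_t group_size) {
    assert(group_size > 0 && cols % group_size == 0);
    assert(w.size() == rows * cols);
    const std::size_t groups = w.size() / group_size;
    std::vector<float> scales(groups);
    for (std::size_t g = 0; g < groups; ++g) {
        scales[g] = absmax_scale(w.data() + g * group_size, group_size, kFp4E2M1Max);
    }
    return scales;
}
```

Finer granularity stores more scale values but keeps one outlier from inflating the scale of unrelated weights. That trade-off is presumably why the 4-bit path, with its much smaller representable range, takes a group_size parameter.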

Detailed Description

Quantize partition of CudaLinearOp.

Exports Detail::quantize_fp8_per_channel() as a non-template function compiled by NVCC. This is the critical module-boundary crossing point for the FP8 quantize-on-load path:

Linear::loadParameter (cl.exe, unchanged)
  → CudaLinearOp::quantize() (class template body, cl.exe instantiation)
  → Detail::quantize_fp8_per_channel() (non-template; pre-compiled by NVCC)
  → cuda_quantize_fp8_per_channel() (plain .cu, NVCC)

Because Detail::quantize_fp8_per_channel() is a non-template function, its body is pre-compiled into the NVCC-generated BMI. When cl.exe instantiates CudaLinearOp::quantize(), it needs only the declaration of this function, not its body, so no CUDA headers or intrinsics reach cl.exe.
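
The boundary described above can be illustrated with a single-file sketch in plain C++, where one translation unit stands in for the two compilers. The class template sees only a declaration of the non-template function, and the definition appears later, as it would in a separately compiled partition. Every name here (detail_sketch::quantize_rows, LinearOpSketch) is hypothetical, the signature is far simpler than the real one, and the function body is a placeholder clamp rather than an FP8 encoding.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

namespace detail_sketch {
// Declaration only: this is all the template below needs in order to compile.
// In the real layout, the consumer would get it from the module interface.
void quantize_rows(const float* src, std::uint8_t* dst, std::size_t n);
}  // namespace detail_sketch

// Stand-in for a class template whose body is instantiated by the host compiler.
template <typename TPrecision>
struct LinearOpSketch {
    std::vector<std::uint8_t> quantize(const std::vector<float>& weights) const {
        std::vector<std::uint8_t> out(weights.size());
        // The call resolves against the declaration above; no device headers
        // or intrinsics are visible at this point.
        detail_sketch::quantize_rows(weights.data(), out.data(), weights.size());
        return out;
    }
};

namespace detail_sketch {
// Stand-in for the body that would be compiled separately by NVCC.
// Placeholder logic: clamp each value to [0, 255] and truncate to a byte.
void quantize_rows(const float* src, std::uint8_t* dst, std::size_t n) {
    assert(src != nullptr && dst != nullptr);
    for (std::size_t i = 0; i < n; ++i) {
        float v = src[i];
        if (v < 0.0f) v = 0.0f;
        if (v > 255.0f) v = 255.0f;
        dst[i] = static_cast<std::uint8_t>(v);
    }
}
}  // namespace detail_sketch
```

If quantize_rows were itself a template, its body would have to be visible wherever it is instantiated, which would pull its dependencies into the host compiler. That is the situation the non-template design avoids.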