Mila 0.13.48
Deep Neural Network Library
Loading...
Searching...
No Matches
Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant > Class Template Reference [export]

CUDA Linear operation with compile-time weight quantization policy dispatch. More...

Inheritance diagram for Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >:
Collaboration diagram for Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >:

Public Types

using ComputeType = typename TensorDataTypeMap<TComputePrecision>::device_type
using CudaExecutionContext = ExecutionContext<DeviceType::Cuda>
using MR = CudaDeviceMemoryResource
using TensorType = Tensor<TComputePrecision, MR>
using WeightType = typename TensorDataTypeMap<kWeightDtype>::device_type
Public Types inherited from Mila::Dnn::Compute::Operation< DeviceType::Cuda, TComputePrecision >
using DataTypeTraits

Public Member Functions

 CudaLinearOp (IExecutionContext *context, const LinearConfig &config)
 ~CudaLinearOp ()=default
void backward (const TensorType &input, const TensorType &output_grad, TensorType &input_grad) const
 Backward pass.
void build (const BuildContext &build_context) override
 Prepare the operation for a concrete input shape.
void forward (const TensorType &input, TensorType &output) const
 Forward pass: output = input * weight^T + bias.
const LinearConfig & getConfig () const
std::string getName () const override
 Human-readable operation name.
OperationType getOperationType () const override
 Operation type identifier.
void quantize (const ITensorBlob &blob, ITensor &weight_out, ITensor &scales_out, const shape_t &expected_shape)
 Quantize a BF16 host blob to FP8_E4M3 with per-channel FP32 scales.
void setGradients (ITensor *weight_grad, ITensor *bias_grad) override
 Bind module-owned gradient tensors to the operation.
void setParameters (ITensor *weight, ITensor *bias) override
 Bind module-owned parameter tensors to the operation.
void setWeightScales (ITensor *scales)
 Bind the per-channel FP32 weight scale tensor produced by quantize().
void setWeightZeroPoints (ITensor *zero_points)
 Bind the packed INT4 zero-point tensor for per-group asymmetric quantization.
Public Member Functions inherited from Mila::Dnn::Compute::Operation< DeviceType::Cuda, TComputePrecision >
virtual ~Operation ()=default
virtual void clearGradients () noexcept
 Clear any cached gradient pointers held by the operation.
virtual TensorDataType getDataType () const
 Tensor data type for this operation.
virtual DeviceType getDeviceType () const
 Device type for this operation.
virtual std::size_t getStateMemorySize () const
 Returns the number of bytes of state memory allocated by this operation.
virtual bool isBuilt () const
 Whether build() completed successfully for a concrete input shape.
virtual bool isEvalMode () const
 Query whether the operation is configured for evaluation (non-training) mode.
virtual void setTrainingMode (TrainingMode training_mode)
 Configure operation training-mode behavior.

Static Public Attributes

static constexpr bool kIsPerChannelQuantized = kIsQuantized && TWeightQuant::kPerChannel
static constexpr bool kIsPerGroupQuantized = kIsQuantized && !TWeightQuant::kPerChannel
static constexpr bool kIsQuantized = TWeightQuant::kIsQuantized
static constexpr bool kUseW8A16Gemm = false
static constexpr TensorDataType kWeightDtype
Static Public Attributes inherited from Mila::Dnn::Compute::Operation< DeviceType::Cuda, TComputePrecision >
static constexpr TensorDataType data_type
static constexpr DeviceType device_type

Private Member Functions

void buildCublasLtPlans ()
cudaDataType_t getActivationCudaDataType () const
void getComputeTypes (cublasComputeType_t &compute_type, cudaDataType_t &scale_type) const
cudaDataType_t getWeightCudaDataType () const
bool supportsCuBLASLt () const
 Returns true if an optimised batch compute path is available.

Private Attributes

CublasLtPlanCache< CublasLtMatMulPlan< ComputeType > > backward_input_plan_cache_
CublasLtMatMulPlan< ComputeType > backward_weight_plan_
const ComputeType * bias_ { nullptr }
ComputeType * bias_grad_ { nullptr }
cublasLtHandle_t cached_cublaslt_handle_ { nullptr }
int cached_in_features_ { 0 }
int cached_outer_size_ { 0 }
cublasComputeType_t compute_type_ {}
LinearConfig config_
CudaExecutionContext * context_
cudaDataType_t cuda_data_type_ {}
cudaDataType_t cuda_weight_data_type_ {}
CublasLtPlanCache< CublasLtLinearPlan< TComputePrecision > > forward_plan_cache_
int out_features_ { 0 }
cudaDataType_t scale_type_ {}
bool use_cublaslt_ { false }
bool use_wmma_fp4_gemm_ { false }
const WeightType * weight_ { nullptr }
ComputeType * weight_grad_ { nullptr }
int weight_group_size_ { 128 }
int64_t weight_in_features_ { 0 }
int64_t weight_out_features_ { 0 }
const float * weight_scales_ { nullptr }
const uint8_t * weight_zero_points_ { nullptr }

Additional Inherited Members

Protected Attributes inherited from Mila::Dnn::Compute::Operation< DeviceType::Cuda, TComputePrecision >
bool is_built_
TrainingMode training_mode_

Detailed Description

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
requires PrecisionSupportedOnDevice<TComputePrecision, DeviceType::Cuda>
class Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >

CUDA Linear operation with compile-time weight quantization policy dispatch.

Forward: output = input * weight^T + bias.

Backward (NoWeightQuant only):
  input_grad  = output_grad * weight
  weight_grad = output_grad^T * input (accumulated)
  bias_grad   = sum(output_grad, dim=0)

When TWeightQuant = PerChannelFp8<>, weights are stored as FP8_E4M3 with one float32 scale per output channel. quantize() performs the one-time host-side BF16->FP8 conversion at load time.

Forward dispatch on the quantized path:
  - Single vector (outer_size == 1): a fused matvec applies FP8 per-channel dequantization inline. This is optimal for the memory-bandwidth-bound single-vector case.
  - Batch (outer_size > 1): kUseW8A16Gemm selects one of two paths.
    - kUseW8A16Gemm = true: a fused W8A16 GEMM reads the FP8 weights once, dequantizes per channel inline in shared memory, and writes BF16 output directly (no staging buffer).
    - kUseW8A16Gemm = false: a two-phase path dequantizes FP8 into a BF16 staging buffer and runs a standard BF16 cuBLASLt NT GEMM. It then applies cuda_add_bias as a post-pass.

Backward is not supported on the quantized path (inference only).

Template Parameters
TComputePrecision — Activation and accumulation precision.
TWeightQuant — Weight quantization policy. Defaults to NoWeightQuant.

Member Typedef Documentation

◆ ComputeType

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
using Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::ComputeType = typename TensorDataTypeMap<TComputePrecision>::device_type

◆ CudaExecutionContext

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
using Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::CudaExecutionContext = ExecutionContext<DeviceType::Cuda>

◆ MR

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
using Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::MR = CudaDeviceMemoryResource

◆ TensorType

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
using Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::TensorType = Tensor<TComputePrecision, MR>

◆ WeightType

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
using Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::WeightType = typename TensorDataTypeMap<kWeightDtype>::device_type

Constructor & Destructor Documentation

◆ CudaLinearOp()

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::CudaLinearOp ( IExecutionContext * context,
const LinearConfig & config )
inline

◆ ~CudaLinearOp()

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::~CudaLinearOp ( )
default

Member Function Documentation

◆ backward()

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
void Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::backward ( const TensorType & input,
const TensorType & output_grad,
TensorType & input_grad ) const
inline

Backward pass.

Not supported on the quantized path.

Parameters
input — Saved forward input.
output_grad — Upstream gradient.
input_grad — Output: gradient with respect to the forward input.

◆ build()

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
void Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::build ( const BuildContext & build_context)
inlineoverridevirtual

Prepare the operation for a concrete input shape.

Default implementation is a no-op. Operations requiring shape-dependent setup should override this method.

Reimplemented from Mila::Dnn::Compute::Operation< DeviceType::Cuda, TComputePrecision >.

◆ buildCublasLtPlans()

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
void Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::buildCublasLtPlans ( )
inlineprivate
Here is the caller graph for this function:

◆ forward()

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
void Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::forward ( const TensorType & input,
TensorType & output ) const
inline

Forward pass: output = input * weight^T + bias.

Dispatch priority:

  1. outer_size == 1: FP8/non-quantized: fused matvec via cuda_matvec_impl. INT4: M=1 tiled W4A16 GEMM (no dedicated decode matvec yet).
  2. outer_size > 1, use_cublaslt_: kIsPerChannelQuantized: fused W8A16 GEMM — reads FP8 weights once, dequantizes per-channel inline in shared memory, bias added in-kernel. kIsPerGroupQuantized: fused W4A16 GEMM — inline per-group INT4 dequant. !kIsQuantized: NT row-major BF16 cuBLASLt GEMM; bias via epilogue.
  3. outer_size > 1, quantized, no cuBLASLt: per-row fallback loop (SM < 8.0 or plan build failure).
  4. outer_size > 1, !kIsQuantized, no cuBLASLt: error — non-quantized batch compute always requires cuBLASLt.

◆ getActivationCudaDataType()

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
cudaDataType_t Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::getActivationCudaDataType ( ) const
inlineprivate
Here is the caller graph for this function:

◆ getComputeTypes()

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
void Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::getComputeTypes ( cublasComputeType_t & compute_type,
cudaDataType_t & scale_type ) const
inlineprivate
Here is the caller graph for this function:

◆ getConfig()

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
const LinearConfig & Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::getConfig ( ) const
inline

◆ getName()

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
std::string Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::getName ( ) const
inlineoverridevirtual

◆ getOperationType()

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
OperationType Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::getOperationType ( ) const
inlineoverridevirtual

◆ getWeightCudaDataType()

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
cudaDataType_t Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::getWeightCudaDataType ( ) const
inlineprivate
Here is the caller graph for this function:

◆ quantize()

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
void Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::quantize ( const ITensorBlob & blob,
ITensor & weight_out,
ITensor & scales_out,
const shape_t & expected_shape )
inline

Quantize a BF16 host blob to FP8_E4M3 with per-channel FP32 scales.

Runs once at model load time. Delegates to Detail::quantize_fp8_per_channel() (pre-compiled by NVCC in the :Quantize partition), which performs per-channel absmax scaling and uploads both the FP8 weight tensor and the FP32 scale tensor to device. The BF16 source blob is never retained on device.

Parameters
blob — Host BF16 weight blob from the model archive.
weight_out — Device FP8_E4M3 tensor [out_features, in_features].
scales_out — Device Float32 tensor [out_features].
expected_shape — Expected weight shape, used for validation.

◆ setGradients()

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
void Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::setGradients ( ITensor * weight_grad,
ITensor * bias_grad )
inlineoverridevirtual

Bind module-owned gradient tensors to the operation.

New canonical API for binding gradient buffers. Mirrors semantics of setParameters() but for gradients used during backward().

The operation MUST NOT take ownership of the provided pointers. Implementations may cache rawData() pointers for hot-path writes.

Default: no-op for stateless operations.

Reimplemented from Mila::Dnn::Compute::Operation< DeviceType::Cuda, TComputePrecision >.

◆ setParameters()

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
void Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::setParameters ( ITensor * weight,
ITensor * bias )
inlineoverridevirtual

Bind module-owned parameter tensors to the operation.

The module retains ownership of the provided ITensor objects. Implementations may cache rawData() pointers for hot-path access but MUST NOT free the provided pointers.

Default: no-op for stateless operations.

Reimplemented from Mila::Dnn::Compute::Operation< DeviceType::Cuda, TComputePrecision >.

◆ setWeightScales()

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
void Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::setWeightScales ( ITensor * scales)
inline

Bind the per-channel FP32 weight scale tensor produced by quantize().

Must be called after quantize() and before the first forward(). The scale pointer is stored and passed to the cuBLASLt FP8 matmul descriptor.

Parameters
scales — Device tensor of shape [out_features], dtype Float32.

◆ setWeightZeroPoints()

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
void Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::setWeightZeroPoints ( ITensor * zero_points)
inline

Bind the packed INT4 zero-point tensor for per-group asymmetric quantization.

Optional — only required for asymmetric INT4 quantization. Pass nullptr (or omit) for symmetric quantization (implicit zero = 8). The tensor layout must match the kernel expectation: [out_features, in_features / (group_size * 2)], dtype UINT8, with two packed INT4 zero values per byte.

Parameters
zero_points — Device UINT8 tensor, or nullptr for symmetric quantization.

◆ supportsCuBLASLt()

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
bool Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::supportsCuBLASLt ( ) const
inlineprivate

Returns true if an optimised batch compute path is available.

FP8 (kIsPerChannelQuantized): SM >= 8.0 (Ampere+) for both the fused W8A16 GEMM and the 2-phase dequant + cuBLASLt BF16 GEMM baseline. INT4 (kIsPerGroupQuantized): SM >= 8.0 (Ampere+) for BF16. The fused W4A16 kernel reads packed INT4 and dequantizes per-group inline. Non-quantized: requires a cuBLASLt-supported compute type (FP32/FP16/BF16).

Here is the caller graph for this function:

Member Data Documentation

◆ backward_input_plan_cache_

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
CublasLtPlanCache<CublasLtMatMulPlan<ComputeType> > Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::backward_input_plan_cache_
private

◆ backward_weight_plan_

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
CublasLtMatMulPlan<ComputeType> Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::backward_weight_plan_
private

◆ bias_

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
const ComputeType* Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::bias_ { nullptr }
private

◆ bias_grad_

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
ComputeType* Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::bias_grad_ { nullptr }
private

◆ cached_cublaslt_handle_

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
cublasLtHandle_t Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::cached_cublaslt_handle_ { nullptr }
private

◆ cached_in_features_

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
int Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::cached_in_features_ { 0 }
private

◆ cached_outer_size_

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
int Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::cached_outer_size_ { 0 }
private

◆ compute_type_

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
cublasComputeType_t Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::compute_type_ {}
private

◆ config_

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
LinearConfig Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::config_
private

◆ context_

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
CudaExecutionContext* Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::context_
private

◆ cuda_data_type_

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
cudaDataType_t Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::cuda_data_type_ {}
private

◆ cuda_weight_data_type_

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
cudaDataType_t Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::cuda_weight_data_type_ {}
private

◆ forward_plan_cache_

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
CublasLtPlanCache<CublasLtLinearPlan<TComputePrecision> > Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::forward_plan_cache_
private

◆ kIsPerChannelQuantized

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
bool Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::kIsPerChannelQuantized = kIsQuantized && TWeightQuant::kPerChannel
staticconstexpr

◆ kIsPerGroupQuantized

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
bool Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::kIsPerGroupQuantized = kIsQuantized && !TWeightQuant::kPerChannel
staticconstexpr

◆ kIsQuantized

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
bool Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::kIsQuantized = TWeightQuant::kIsQuantized
staticconstexpr

◆ kUseW8A16Gemm

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
bool Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::kUseW8A16Gemm = false
staticconstexpr

◆ kWeightDtype

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
TensorDataType Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::kWeightDtype
staticconstexpr
Initial value:
= kIsQuantized ? TWeightQuant::kStorageDtype : TComputePrecision
(kIsQuantized is a static constexpr bool; definition at CudaLinearOp.ixx:112.)

◆ out_features_

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
int Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::out_features_ { 0 }
private

◆ scale_type_

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
cudaDataType_t Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::scale_type_ {}
private

◆ use_cublaslt_

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
bool Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::use_cublaslt_ { false }
private

◆ use_wmma_fp4_gemm_

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
bool Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::use_wmma_fp4_gemm_ { false }
private

◆ weight_

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
const WeightType* Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::weight_ { nullptr }
private

◆ weight_grad_

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
ComputeType* Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::weight_grad_ { nullptr }
private

◆ weight_group_size_

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
int Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::weight_group_size_ { 128 }
private

◆ weight_in_features_

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
int64_t Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::weight_in_features_ { 0 }
private

◆ weight_out_features_

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
int64_t Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::weight_out_features_ { 0 }
private

◆ weight_scales_

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
const float* Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::weight_scales_ { nullptr }
private

◆ weight_zero_points_

template<TensorDataType TComputePrecision, WeightQuantPolicy TWeightQuant = NoWeightQuant>
const uint8_t* Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >::weight_zero_points_ { nullptr }
private

The documentation for this class was generated from the following file:
  • /__w/Mila/Mila/Mila/Src/Dnn/Compute/Devices/Cuda/Operations/Linear/CudaLinearOp.ixx