|
Mila 0.13.48
Deep Neural Network Library
|
Public Types | |
| using | DataTypeTraits = TensorDataTypeTraits<TComputePrecision> |
Public Member Functions | |
| virtual | ~Operation ()=default |
| virtual void | build (const BuildContext &build_context) |
| Prepare the operation for a concrete input shape. | |
| virtual void | clearGradients () noexcept |
| Clear any cached gradient pointers held by the operation. | |
| virtual TensorDataType | getDataType () const |
| Tensor data type for this operation. | |
| virtual DeviceType | getDeviceType () const |
| Device type for this operation. | |
| virtual std::string | getName () const =0 |
| Human-readable operation name. | |
| virtual OperationType | getOperationType () const =0 |
| Operation type identifier. | |
| virtual std::size_t | getStateMemorySize () const |
| Returns the number of bytes of state memory allocated by this operation. | |
| virtual bool | isBuilt () const |
| Whether build() completed successfully for a concrete input shape. | |
| virtual bool | isEvalMode () const |
| Query whether the operation is configured for evaluation mode (i.e. not training). | |
| virtual void | setGradients (ITensor *weight_grad, ITensor *bias_grad) |
| Bind module-owned gradient tensors to the operation. | |
| virtual void | setParameters (ITensor *weight, ITensor *bias) |
| Bind module-owned parameter tensors to the operation. | |
| virtual void | setTrainingMode (TrainingMode training_mode) |
| Configure operation training-mode behavior. | |
Static Public Attributes | |
| static constexpr TensorDataType | data_type = TComputePrecision |
| static constexpr DeviceType | device_type = TDeviceType |
Protected Attributes | |
| bool | is_built_ { false } |
| TrainingMode | training_mode_ { TrainingMode::Normal } |
| using Mila::Dnn::Compute::Operation< TDeviceType, TComputePrecision >::DataTypeTraits = TensorDataTypeTraits<TComputePrecision> |
|
virtual, default |
|
inline, virtual |
Prepare the operation for a concrete input shape.
Default implementation is a no-op. Operations requiring shape-dependent setup should override this method.
Reimplemented in Mila::Dnn::Compute::CpuAttentionOp, Mila::Dnn::Compute::CpuEncoderOp, Mila::Dnn::Compute::CpuGeluOp, Mila::Dnn::Compute::CpuLayerNormOp, Mila::Dnn::Compute::CpuLinearOp, Mila::Dnn::Compute::CpuSoftmaxOp, Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >, Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::LayerNorm::CudaLayerNormOp< TPrecision >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, NoWeightQuant >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, PerChannelFp8<> >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, PerGroupFp4< 128 > >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, PerGroupFp4< 64 > >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, PerGroupInt4< 128 > >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, PerGroupInt4< 64 > >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::FP32, NoWeightQuant >, Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TInput, TPrecision >, Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TensorDataType::INT32, TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TensorDataType::INT32, TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >, Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::Residual::CudaResidualOp< TInputA, TInputB, TPrecision >, Mila::Dnn::Compute::Cuda::Residual::CudaResidualOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::Residual::CudaResidualOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< 
TPrecision >, Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TComputePrecision >, Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::Softmax::CudaSoftmaxOp< TPrecision >, Mila::Dnn::Compute::Cuda::Softmax::CudaSoftmaxOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::Softmax::CudaSoftmaxOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TPrecision, TLogits, TTargets >, Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TInput, TPrecision >, Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TensorDataType::INT32, TensorDataType::BF16 >, and Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TensorDataType::INT32, TensorDataType::FP32 >.

|
inline, virtual, noexcept |
Clear any cached gradient pointers held by the operation.
Explicit unbind called by modules before freeing/resetting module-owned gradient buffers. Implementations MUST null-out any cached raw pointers and MUST NOT throw. Marked noexcept so it is safe to call from destructors or during state transitions.
|
inline, virtual |
Tensor data type for this operation.
|
inline, virtual |
Device type for this operation.
|
pure virtual |
Human-readable operation name.
Implemented in Mila::Dnn::Compute::CpuAttentionOp, Mila::Dnn::Compute::CpuCrossEntropyOp, Mila::Dnn::Compute::CpuEncoderOp, Mila::Dnn::Compute::CpuGeluOp, Mila::Dnn::Compute::CpuLayerNormOp, Mila::Dnn::Compute::CpuLinearOp, Mila::Dnn::Compute::CpuResidualOp, Mila::Dnn::Compute::CpuSoftmaxCrossEntropyOp< TPrecision, TLogits, TTargets >, Mila::Dnn::Compute::CpuSoftmaxOp, Mila::Dnn::Compute::Cuda::Gelu::CudaGeluOp< TPrecision >, Mila::Dnn::Compute::Cuda::Gelu::CudaGeluOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::Gelu::CudaGeluOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >, Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::LayerNorm::CudaLayerNormOp< TPrecision >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, NoWeightQuant >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, PerChannelFp8<> >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, PerGroupFp4< 128 > >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, PerGroupFp4< 64 > >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, PerGroupInt4< 128 > >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, PerGroupInt4< 64 > >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::FP32, NoWeightQuant >, Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TInput, TPrecision >, Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TensorDataType::INT32, TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TensorDataType::INT32, TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::MatMulBiasGelu::CudaMatMulBiasGeluOp< TInput, TOutput >, Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >, 
Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::Residual::CudaResidualOp< TInputA, TInputB, TPrecision >, Mila::Dnn::Compute::Cuda::Residual::CudaResidualOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::Residual::CudaResidualOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TPrecision >, Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TComputePrecision >, Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::Softmax::CudaSoftmaxOp< TPrecision >, Mila::Dnn::Compute::Cuda::Softmax::CudaSoftmaxOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::Softmax::CudaSoftmaxOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TPrecision, TLogits, TTargets >, Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::Swiglu::CudaSwigluOp< TPrecision >, Mila::Dnn::Compute::Cuda::Swiglu::CudaSwigluOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::Swiglu::CudaSwigluOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TInput, TPrecision >, Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TensorDataType::INT32, TensorDataType::BF16 >, and Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TensorDataType::INT32, TensorDataType::FP32 >.
|
pure virtual |
Operation type identifier.
Implemented in Mila::Dnn::Compute::CpuAttentionOp, Mila::Dnn::Compute::CpuEncoderOp, Mila::Dnn::Compute::CpuGeluOp, Mila::Dnn::Compute::CpuLayerNormOp, Mila::Dnn::Compute::CpuLinearOp, Mila::Dnn::Compute::CpuResidualOp, Mila::Dnn::Compute::CpuSoftmaxCrossEntropyOp< TPrecision, TLogits, TTargets >, Mila::Dnn::Compute::CpuSoftmaxOp, Mila::Dnn::Compute::Cuda::Gelu::CudaGeluOp< TPrecision >, Mila::Dnn::Compute::Cuda::Gelu::CudaGeluOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::Gelu::CudaGeluOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >, Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::LayerNorm::CudaLayerNormOp< TPrecision >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, NoWeightQuant >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, PerChannelFp8<> >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, PerGroupFp4< 128 > >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, PerGroupFp4< 64 > >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, PerGroupInt4< 128 > >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, PerGroupInt4< 64 > >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::FP32, NoWeightQuant >, Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TInput, TPrecision >, Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TensorDataType::INT32, TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TensorDataType::INT32, TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >, Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< 
TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::Residual::CudaResidualOp< TInputA, TInputB, TPrecision >, Mila::Dnn::Compute::Cuda::Residual::CudaResidualOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::Residual::CudaResidualOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TPrecision >, Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TComputePrecision >, Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::Softmax::CudaSoftmaxOp< TPrecision >, Mila::Dnn::Compute::Cuda::Softmax::CudaSoftmaxOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::Softmax::CudaSoftmaxOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TPrecision, TLogits, TTargets >, Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::Swiglu::CudaSwigluOp< TPrecision >, Mila::Dnn::Compute::Cuda::Swiglu::CudaSwigluOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::Swiglu::CudaSwigluOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TInput, TPrecision >, Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TensorDataType::INT32, TensorDataType::BF16 >, and Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TensorDataType::INT32, TensorDataType::FP32 >.
|
inline, virtual |
Returns the number of bytes of state memory allocated by this operation.
State memory includes build-time buffers such as caches and scratch allocations. Parameters and gradients are owned at the component level and are not included.
Override in derived operations that allocate device or host state during build().
Reimplemented in Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >, Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TComputePrecision >, Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TensorDataType::BF16 >, and Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TensorDataType::FP32 >.
|
inline, virtual |
Whether build() completed successfully for a concrete input shape.
|
inline, virtual |
Query whether the operation is configured for evaluation mode (i.e. not training).
|
inline, virtual |
Bind module-owned gradient tensors to the operation.
New canonical API for binding gradient buffers. Mirrors semantics of setParameters() but for gradients used during backward().
The operation MUST NOT take ownership of the provided pointers. Implementations may cache rawData() pointers for hot-path writes.
Default: no-op for stateless operations.
Reimplemented in Mila::Dnn::Compute::CpuAttentionOp, Mila::Dnn::Compute::CpuEncoderOp, Mila::Dnn::Compute::CpuLayerNormOp, Mila::Dnn::Compute::CpuLinearOp, Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >, Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::LayerNorm::CudaLayerNormOp< TPrecision >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, NoWeightQuant >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, PerChannelFp8<> >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, PerGroupFp4< 128 > >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, PerGroupFp4< 64 > >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, PerGroupInt4< 128 > >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, PerGroupInt4< 64 > >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::FP32, NoWeightQuant >, Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TInput, TPrecision >, Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TensorDataType::INT32, TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TensorDataType::INT32, TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >, Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TPrecision >, Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TInput, TPrecision >, 
Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TensorDataType::INT32, TensorDataType::BF16 >, and Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TensorDataType::INT32, TensorDataType::FP32 >.
|
inline, virtual |
Bind module-owned parameter tensors to the operation.
The module retains ownership of the provided ITensor objects. Implementations may cache rawData() pointers for hot-path access but MUST NOT free the provided pointers.
Default: no-op for stateless operations.
Reimplemented in Mila::Dnn::Compute::CpuAttentionOp, Mila::Dnn::Compute::CpuEncoderOp, Mila::Dnn::Compute::CpuLayerNormOp, Mila::Dnn::Compute::CpuLinearOp, Mila::Dnn::Compute::CpuSoftmaxCrossEntropyOp< TPrecision, TLogits, TTargets >, Mila::Dnn::Compute::CpuSoftmaxOp, Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >, Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::LayerNorm::CudaLayerNormOp< TPrecision >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, NoWeightQuant >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, PerChannelFp8<> >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, PerGroupFp4< 128 > >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, PerGroupFp4< 64 > >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, PerGroupInt4< 128 > >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, PerGroupInt4< 64 > >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::FP32, NoWeightQuant >, Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TInput, TPrecision >, Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TensorDataType::INT32, TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TensorDataType::INT32, TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >, Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TPrecision >, Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TensorDataType::FP32 >, 
Mila::Dnn::Compute::Cuda::Softmax::CudaSoftmaxOp< TPrecision >, Mila::Dnn::Compute::Cuda::Softmax::CudaSoftmaxOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::Softmax::CudaSoftmaxOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TPrecision, TLogits, TTargets >, Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TInput, TPrecision >, Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TensorDataType::INT32, TensorDataType::BF16 >, and Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TensorDataType::INT32, TensorDataType::FP32 >.
|
inline, virtual |
Configure operation training-mode behavior.
Implementations may use this to enable/disable training-specific work.
|
static, constexpr |
|
static, constexpr |
|
protected |
|
protected |