Mila 0.13.48
Deep Neural Network Library
Loading...
Searching...
No Matches
Mila::Dnn::Compute::Operation< TDeviceType, TComputePrecision > Class Template Reference (abstract, export)

Public Types

using DataTypeTraits = TensorDataTypeTraits<TComputePrecision>

Public Member Functions

virtual ~Operation ()=default
virtual void build (const BuildContext &build_context)
 Prepare the operation for a concrete input shape.
virtual void clearGradients () noexcept
 Clear any cached gradient pointers held by the operation.
virtual TensorDataType getDataType () const
 Tensor data type for this operation.
virtual DeviceType getDeviceType () const
 Device type for this operation.
virtual std::string getName () const =0
 Human-readable operation name.
virtual OperationType getOperationType () const =0
 Operation type identifier.
virtual std::size_t getStateMemorySize () const
 Returns the number of bytes of state memory allocated by this operation.
virtual bool isBuilt () const
 Whether build() completed successfully for a concrete input shape.
virtual bool isEvalMode () const
 Query whether the operation is configured for evaluation mode.
virtual void setGradients (ITensor *weight_grad, ITensor *bias_grad)
 Bind module-owned gradient tensors to the operation.
virtual void setParameters (ITensor *weight, ITensor *bias)
 Bind module-owned parameter tensors to the operation.
virtual void setTrainingMode (TrainingMode training_mode)
 Configure operation training-mode behavior.

Static Public Attributes

static constexpr TensorDataType data_type = TComputePrecision
static constexpr DeviceType device_type = TDeviceType

Protected Attributes

bool is_built_ { false }
TrainingMode training_mode_ { TrainingMode::Normal }

Member Typedef Documentation

◆ DataTypeTraits

template<DeviceType TDeviceType, TensorDataType TComputePrecision>
using Mila::Dnn::Compute::Operation< TDeviceType, TComputePrecision >::DataTypeTraits = TensorDataTypeTraits<TComputePrecision>

Constructor & Destructor Documentation

◆ ~Operation()

template<DeviceType TDeviceType, TensorDataType TComputePrecision>
virtual Mila::Dnn::Compute::Operation< TDeviceType, TComputePrecision >::~Operation ( )
virtual, default

Member Function Documentation

◆ build()

template<DeviceType TDeviceType, TensorDataType TComputePrecision>
virtual void Mila::Dnn::Compute::Operation< TDeviceType, TComputePrecision >::build ( const BuildContext & build_context)
inline, virtual

Prepare the operation for a concrete input shape.

Default implementation is a no-op. Operations requiring shape-dependent setup should override this method.

Reimplemented in Mila::Dnn::Compute::CpuAttentionOp, Mila::Dnn::Compute::CpuEncoderOp, Mila::Dnn::Compute::CpuGeluOp, Mila::Dnn::Compute::CpuLayerNormOp, Mila::Dnn::Compute::CpuLinearOp, Mila::Dnn::Compute::CpuSoftmaxOp, Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >, Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::LayerNorm::CudaLayerNormOp< TPrecision >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, NoWeightQuant >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, PerChannelFp8<> >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, PerGroupFp4< 128 > >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, PerGroupFp4< 64 > >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, PerGroupInt4< 128 > >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, PerGroupInt4< 64 > >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::FP32, NoWeightQuant >, Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TInput, TPrecision >, Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TensorDataType::INT32, TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TensorDataType::INT32, TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >, Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::Residual::CudaResidualOp< TInputA, TInputB, TPrecision >, Mila::Dnn::Compute::Cuda::Residual::CudaResidualOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::Residual::CudaResidualOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< 
TPrecision >, Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TComputePrecision >, Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::Softmax::CudaSoftmaxOp< TPrecision >, Mila::Dnn::Compute::Cuda::Softmax::CudaSoftmaxOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::Softmax::CudaSoftmaxOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TPrecision, TLogits, TTargets >, Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TInput, TPrecision >, Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TensorDataType::INT32, TensorDataType::BF16 >, and Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TensorDataType::INT32, TensorDataType::FP32 >.

Here is the caller graph for this function:

◆ clearGradients()

template<DeviceType TDeviceType, TensorDataType TComputePrecision>
virtual void Mila::Dnn::Compute::Operation< TDeviceType, TComputePrecision >::clearGradients ( )
inline, virtual, noexcept

Clear any cached gradient pointers held by the operation.

This is an explicit unbind step that modules call before freeing or resetting module-owned gradient buffers. Implementations MUST null-out any cached raw pointers and MUST NOT throw. Marked noexcept so it is safe to call from destructors or during state transitions.

◆ getDataType()

template<DeviceType TDeviceType, TensorDataType TComputePrecision>
virtual TensorDataType Mila::Dnn::Compute::Operation< TDeviceType, TComputePrecision >::getDataType ( ) const
inline, virtual

Tensor data type for this operation.

◆ getDeviceType()

template<DeviceType TDeviceType, TensorDataType TComputePrecision>
virtual DeviceType Mila::Dnn::Compute::Operation< TDeviceType, TComputePrecision >::getDeviceType ( ) const
inline, virtual

Device type for this operation.

◆ getName()

template<DeviceType TDeviceType, TensorDataType TComputePrecision>
virtual std::string Mila::Dnn::Compute::Operation< TDeviceType, TComputePrecision >::getName ( ) const
pure virtual

Human-readable operation name.

Implemented in Mila::Dnn::Compute::CpuAttentionOp, Mila::Dnn::Compute::CpuCrossEntropyOp, Mila::Dnn::Compute::CpuEncoderOp, Mila::Dnn::Compute::CpuGeluOp, Mila::Dnn::Compute::CpuLayerNormOp, Mila::Dnn::Compute::CpuLinearOp, Mila::Dnn::Compute::CpuResidualOp, Mila::Dnn::Compute::CpuSoftmaxCrossEntropyOp< TPrecision, TLogits, TTargets >, Mila::Dnn::Compute::CpuSoftmaxOp, Mila::Dnn::Compute::Cuda::Gelu::CudaGeluOp< TPrecision >, Mila::Dnn::Compute::Cuda::Gelu::CudaGeluOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::Gelu::CudaGeluOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >, Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::LayerNorm::CudaLayerNormOp< TPrecision >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, NoWeightQuant >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, PerChannelFp8<> >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, PerGroupFp4< 128 > >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, PerGroupFp4< 64 > >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, PerGroupInt4< 128 > >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, PerGroupInt4< 64 > >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::FP32, NoWeightQuant >, Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TInput, TPrecision >, Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TensorDataType::INT32, TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TensorDataType::INT32, TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::MatMulBiasGelu::CudaMatMulBiasGeluOp< TInput, TOutput >, Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >, 
Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::Residual::CudaResidualOp< TInputA, TInputB, TPrecision >, Mila::Dnn::Compute::Cuda::Residual::CudaResidualOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::Residual::CudaResidualOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TPrecision >, Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TComputePrecision >, Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::Softmax::CudaSoftmaxOp< TPrecision >, Mila::Dnn::Compute::Cuda::Softmax::CudaSoftmaxOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::Softmax::CudaSoftmaxOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TPrecision, TLogits, TTargets >, Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::Swiglu::CudaSwigluOp< TPrecision >, Mila::Dnn::Compute::Cuda::Swiglu::CudaSwigluOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::Swiglu::CudaSwigluOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TInput, TPrecision >, Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TensorDataType::INT32, TensorDataType::BF16 >, and Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TensorDataType::INT32, TensorDataType::FP32 >.

◆ getOperationType()

template<DeviceType TDeviceType, TensorDataType TComputePrecision>
virtual OperationType Mila::Dnn::Compute::Operation< TDeviceType, TComputePrecision >::getOperationType ( ) const
pure virtual

Operation type identifier.

Implemented in Mila::Dnn::Compute::CpuAttentionOp, Mila::Dnn::Compute::CpuEncoderOp, Mila::Dnn::Compute::CpuGeluOp, Mila::Dnn::Compute::CpuLayerNormOp, Mila::Dnn::Compute::CpuLinearOp, Mila::Dnn::Compute::CpuResidualOp, Mila::Dnn::Compute::CpuSoftmaxCrossEntropyOp< TPrecision, TLogits, TTargets >, Mila::Dnn::Compute::CpuSoftmaxOp, Mila::Dnn::Compute::Cuda::Gelu::CudaGeluOp< TPrecision >, Mila::Dnn::Compute::Cuda::Gelu::CudaGeluOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::Gelu::CudaGeluOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >, Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::LayerNorm::CudaLayerNormOp< TPrecision >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, NoWeightQuant >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, PerChannelFp8<> >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, PerGroupFp4< 128 > >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, PerGroupFp4< 64 > >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, PerGroupInt4< 128 > >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, PerGroupInt4< 64 > >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::FP32, NoWeightQuant >, Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TInput, TPrecision >, Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TensorDataType::INT32, TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TensorDataType::INT32, TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >, Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< 
TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::Residual::CudaResidualOp< TInputA, TInputB, TPrecision >, Mila::Dnn::Compute::Cuda::Residual::CudaResidualOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::Residual::CudaResidualOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TPrecision >, Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TComputePrecision >, Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::Softmax::CudaSoftmaxOp< TPrecision >, Mila::Dnn::Compute::Cuda::Softmax::CudaSoftmaxOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::Softmax::CudaSoftmaxOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TPrecision, TLogits, TTargets >, Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::Swiglu::CudaSwigluOp< TPrecision >, Mila::Dnn::Compute::Cuda::Swiglu::CudaSwigluOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::Swiglu::CudaSwigluOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TInput, TPrecision >, Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TensorDataType::INT32, TensorDataType::BF16 >, and Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TensorDataType::INT32, TensorDataType::FP32 >.

◆ getStateMemorySize()

template<DeviceType TDeviceType, TensorDataType TComputePrecision>
virtual std::size_t Mila::Dnn::Compute::Operation< TDeviceType, TComputePrecision >::getStateMemorySize ( ) const
inline, virtual

Returns the number of bytes of state memory allocated by this operation.

State memory includes build-time buffers such as caches and scratch allocations. Parameters and gradients are owned at the component level and are not included.

Override in derived operations that allocate device or host state during build().

Reimplemented in Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >, Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TComputePrecision >, Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TensorDataType::BF16 >, and Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TensorDataType::FP32 >.

◆ isBuilt()

template<DeviceType TDeviceType, TensorDataType TComputePrecision>
virtual bool Mila::Dnn::Compute::Operation< TDeviceType, TComputePrecision >::isBuilt ( ) const
inline, virtual

Whether build() completed successfully for a concrete input shape.

◆ isEvalMode()

template<DeviceType TDeviceType, TensorDataType TComputePrecision>
virtual bool Mila::Dnn::Compute::Operation< TDeviceType, TComputePrecision >::isEvalMode ( ) const
inline, virtual

Query whether the operation is configured for evaluation mode.

◆ setGradients()

template<DeviceType TDeviceType, TensorDataType TComputePrecision>
virtual void Mila::Dnn::Compute::Operation< TDeviceType, TComputePrecision >::setGradients ( ITensor * weight_grad,
ITensor * bias_grad )
inline, virtual

Bind module-owned gradient tensors to the operation.

New canonical API for binding gradient buffers. Mirrors semantics of setParameters() but for gradients used during backward().

The operation MUST NOT take ownership of the provided pointers. Implementations may cache rawData() pointers for hot-path writes.

Default: no-op for stateless operations.

Reimplemented in Mila::Dnn::Compute::CpuAttentionOp, Mila::Dnn::Compute::CpuEncoderOp, Mila::Dnn::Compute::CpuLayerNormOp, Mila::Dnn::Compute::CpuLinearOp, Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >, Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::LayerNorm::CudaLayerNormOp< TPrecision >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, NoWeightQuant >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, PerChannelFp8<> >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, PerGroupFp4< 128 > >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, PerGroupFp4< 64 > >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, PerGroupInt4< 128 > >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, PerGroupInt4< 64 > >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::FP32, NoWeightQuant >, Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TInput, TPrecision >, Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TensorDataType::INT32, TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TensorDataType::INT32, TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >, Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TPrecision >, Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TInput, TPrecision >, 
Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TensorDataType::INT32, TensorDataType::BF16 >, and Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TensorDataType::INT32, TensorDataType::FP32 >.

◆ setParameters()

template<DeviceType TDeviceType, TensorDataType TComputePrecision>
virtual void Mila::Dnn::Compute::Operation< TDeviceType, TComputePrecision >::setParameters ( ITensor * weight,
ITensor * bias )
inline, virtual

Bind module-owned parameter tensors to the operation.

The module retains ownership of the provided ITensor objects. Implementations may cache rawData() pointers for hot-path access but MUST NOT free the provided pointers.

Default: no-op for stateless operations.

Reimplemented in Mila::Dnn::Compute::CpuAttentionOp, Mila::Dnn::Compute::CpuEncoderOp, Mila::Dnn::Compute::CpuLayerNormOp, Mila::Dnn::Compute::CpuLinearOp, Mila::Dnn::Compute::CpuSoftmaxCrossEntropyOp< TPrecision, TLogits, TTargets >, Mila::Dnn::Compute::CpuSoftmaxOp, Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >, Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::LayerNorm::CudaLayerNormOp< TPrecision >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TComputePrecision, TWeightQuant >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, NoWeightQuant >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, PerChannelFp8<> >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, PerGroupFp4< 128 > >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, PerGroupFp4< 64 > >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, PerGroupInt4< 128 > >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::BF16, PerGroupInt4< 64 > >, Mila::Dnn::Compute::Cuda::Linear::CudaLinearOp< TensorDataType::FP32, NoWeightQuant >, Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TInput, TPrecision >, Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TensorDataType::INT32, TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TensorDataType::INT32, TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >, Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TPrecision >, Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TensorDataType::FP32 >, 
Mila::Dnn::Compute::Cuda::Softmax::CudaSoftmaxOp< TPrecision >, Mila::Dnn::Compute::Cuda::Softmax::CudaSoftmaxOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::Softmax::CudaSoftmaxOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TPrecision, TLogits, TTargets >, Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TensorDataType::BF16 >, Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TensorDataType::FP32 >, Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TInput, TPrecision >, Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TensorDataType::INT32, TensorDataType::BF16 >, and Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TensorDataType::INT32, TensorDataType::FP32 >.

◆ setTrainingMode()

template<DeviceType TDeviceType, TensorDataType TComputePrecision>
virtual void Mila::Dnn::Compute::Operation< TDeviceType, TComputePrecision >::setTrainingMode ( TrainingMode training_mode)
inline, virtual

Configure operation training-mode behavior.

Implementations may use this to enable/disable training-specific work.

Member Data Documentation

◆ data_type

template<DeviceType TDeviceType, TensorDataType TComputePrecision>
TensorDataType Mila::Dnn::Compute::Operation< TDeviceType, TComputePrecision >::data_type = TComputePrecision
static, constexpr

◆ device_type

template<DeviceType TDeviceType, TensorDataType TComputePrecision>
DeviceType Mila::Dnn::Compute::Operation< TDeviceType, TComputePrecision >::device_type = TDeviceType
static, constexpr

◆ is_built_

template<DeviceType TDeviceType, TensorDataType TComputePrecision>
bool Mila::Dnn::Compute::Operation< TDeviceType, TComputePrecision >::is_built_ { false }
protected

◆ training_mode_

template<DeviceType TDeviceType, TensorDataType TComputePrecision>
TrainingMode Mila::Dnn::Compute::Operation< TDeviceType, TComputePrecision >::training_mode_ { TrainingMode::Normal }
protected

The documentation for this class was generated from the following file: