Mila 0.13.48
Deep Neural Network Library
Abstract base class for neural network components.


Public Member Functions
| Type | Member | Description |
|---|---|---|
| | Component (const std::string &name) | Construct component with required name identifier. |
| virtual | ~Component ()=default | |
| virtual void | build (const BuildContext &context) final | Build the component with the provided BuildContext (canonical overload). |
| virtual DeviceId | getDeviceId () const =0 | Get the compute device id associated with this component. |
| virtual std::vector< ITensor * > | getGradients () const =0 | Return non-owning pointers to parameter gradient tensors. |
| virtual MemoryStats | getMemoryStats () const =0 | Return the current memory allocation breakdown for this component. |
| const std::string | getName () const | Get the component's name identifier. |
| virtual std::vector< std::string > | getParameterNames () const | List all available parameter names for this component. |
| virtual std::vector< ITensor * > | getParameters () const =0 | Return non-owning pointers to parameter tensors. |
| RuntimeMode | getRuntimeMode () const noexcept | Get the RuntimeMode (Inference or Training) of this Component. |
| TrainingMode | getTrainingMode () const noexcept | The current runtime behavioral mode of this Component. |
| virtual const ComponentType | getType () const =0 | Get the component type identifier. |
| virtual bool | isBuilt () const final | Returns true if build() has completed successfully. |
| bool | isInferenceMode () const noexcept | |
| bool | isTrainingMode () const noexcept | |
| virtual void | loadParameter (const std::string &name, const Serialization::ITensorBlob &blob) | Load a parameter from serialized tensor data. |
| virtual size_t | parameterCount () const =0 | Return number of trainable parameters. |
| virtual void | save_ (ModelArchive &archive, SerializationMode mode) const =0 | |
| void | setTrainingMode (TrainingMode mode) | Set the runtime behavioral mode for this Component. |
| virtual void | synchronize ()=0 | Wait for outstanding device work submitted by this component. |
| virtual std::string | toString () const =0 | Produce a short, human-readable description of the component. |
| virtual void | zeroGradients () | Clear all model-owned gradients for this component. |

Static Public Member Functions
| Type | Member | Description |
|---|---|---|
| static constexpr DeviceType | getDeviceType () | Compile-time device type for this component instance. |
| static constexpr TensorDataType | getPrecision () noexcept | Compile-time tensor precision for this component instance. |

Protected Member Functions
| Type | Member | Description |
|---|---|---|
| IExecutionContext * | getExecutionContext () const | Get the shared execution context. |
| bool | hasExecutionContext () const noexcept | Check if execution context has been set. |
| template<TensorDataType TParameterPrecision, typename TMemoryResource> void | loadParameterFromBlob (const std::string &param_name, const Serialization::ITensorBlob &blob, Tensor< TParameterPrecision, TMemoryResource > &target, const shape_t &expected_shape) | Load a tensor blob into a parameter tensor with validation. |
| virtual void | onBuilding (const BuildContext &config) | Hook invoked by build() to allocate component buffers. |
| virtual void | onExecutionContextSet () | Lifecycle hook called immediately after the ExecutionContext is set. |
| virtual void | onTrainingModeChanging (TrainingMode mode) | Hook called before TrainingMode transitions. |
| void | setExecutionContext (IExecutionContext *context) | Set the execution context for this component. |

Protected Attributes
| Type | Member | Description |
|---|---|---|
| BuildContext | build_context_ { shape_t{ 1 }, RuntimeMode::Training } | The BuildContext stored at build time. |

Private Member Functions
| Type | Member | Description |
|---|---|---|
| void | ensureBuilt (const char *method) const | Throws if the component has not yet been built. |

Static Private Member Functions
| Type | Member | Description |
|---|---|---|
| static bool | isIdentifier (const std::string &s) noexcept | Checks if a string is a valid component identifier. |
| static const std::string & | validateName (const std::string &name) | Validates the component name. |

Private Attributes
| Type | Member | Description |
|---|---|---|
| bool | built_ { false } | |
| IExecutionContext * | exec_context_ { nullptr } | |
| std::string | name_ | |
| TrainingMode | training_mode_ { TrainingMode::Normal } | |
| std::mutex | training_mode_mutex_ | |

Friends
| Type | Member | Description |
|---|---|---|
| template<DeviceType, TensorDataType> class | CompositeComponent | |
| template<DeviceType, TensorDataType> class | Network | |
| std::ostream & | operator<< (std::ostream &os, const Component &component) | Stream output uses toString() to provide a human-readable description of the component. |

Abstract base class for neural network components.
Component enforces a single ownership model: all components receive a non-owning pointer to an IExecutionContext that is owned by the parent (Network or test fixture). This ensures consistent resource sharing and eliminates dual constructor patterns.
Components progress through a well-defined lifecycle. Each stage has a single responsibility and a designated hook for subclass extension.
The component is constructed with its name and component config. No device resources are allocated. No ExecutionContext is required.
build() is called with a BuildContext carrying the leading shape { B, T, ... } and the RuntimeMode that governs buffer allocation.
| Mode | allocationSeqLen() | Gradient buffers |
|---|---|---|
| Inference | 1 | never allocated |
| Training | leading_shape[1] | allocated on demand |
onBuilding() allocates the component's state and output buffers, sized with allocationSeqLen().
Gradient buffers are NOT allocated in onBuilding().
The same BuildContext is cascaded unchanged through CompositeComponent and Network to all child components.
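The allocation rule in the table above can be illustrated with a small stand-in type. This is a minimal sketch, not Mila's actual BuildContext: `SketchBuildContext`, its `leading_shape` and `mode` fields, and `outputElements()` are hypothetical names, and the sketch assumes only what the table states (Inference sizes buffers for T = 1, Training for `leading_shape[1]`).

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

namespace sketch {

// Hypothetical stand-in for Mila's RuntimeMode (only the two documented values).
enum class RuntimeMode { Inference, Training };

// Hypothetical stand-in for BuildContext: leading shape { B, T, ... } plus mode.
struct SketchBuildContext {
    std::vector<std::size_t> leading_shape;
    RuntimeMode mode;

    // Inference builds size buffers for a single step (T = 1);
    // Training builds size buffers for the full sequence (leading_shape[1]).
    std::size_t allocationSeqLen() const {
        if (leading_shape.size() < 2) {
            throw std::invalid_argument("leading shape must be at least { B, T }");
        }
        return mode == RuntimeMode::Inference ? 1 : leading_shape[1];
    }
};

// Element count of a hypothetical [B, allocationSeqLen(), C] output buffer.
inline std::size_t outputElements(const SketchBuildContext& ctx, std::size_t channels) {
    return ctx.leading_shape[0] * ctx.allocationSeqLen() * channels;
}

}  // namespace sketch
```

Because the same context is cascaded to every child, each component computes its buffer sizes from one shared value instead of branching on the mode itself.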
Switching between training and evaluation behavior with setTrainingMode() is only valid for components built with RuntimeMode::Training. The first setTrainingMode( TrainingMode::Normal ) call allocates the gradient buffers; subsequent calls zero the existing buffers without reallocating. setTrainingMode( TrainingMode::Eval ) zeros the gradient buffers and disables the backward path.
Runtime dimensions are read from the input tensor shape on each call; no shape information is cached from build time beyond what is stored in build_context_.
Lifecycle preconditions:
| Operation | Precondition |
|---|---|
| build() | ExecutionContext has been set |
| setTrainingMode() | build() has completed and the component was built with RuntimeMode::Training |
| forward() | build() has completed |
| backward() | isTrainingMode() == true |
| decode() | build() has completed |

| Template parameter | Description |
|---|---|
| TDeviceType | Compile-time device identifier for this component. |
| TPrecision | Tensor data precision for this component. |

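The lifecycle preconditions can be modeled as a small state machine. The sketch below is a hypothetical `LifecycleSketch` class that tracks flags only; it assumes that TrainingMode::Normal corresponds to training behavior and that gradient buffers are allocated lazily on the first switch to Normal, as described above. It is not Mila's implementation.

```cpp
#include <cassert>
#include <stdexcept>

namespace sketch {

enum class RuntimeMode { Inference, Training };
enum class TrainingMode { Normal, Eval };

// Hypothetical model of the documented lifecycle; it tracks flags only
// and allocates no real buffers.
class LifecycleSketch {
public:
    void setExecutionContext() noexcept { has_context_ = true; }

    void build(RuntimeMode mode) {
        if (built_) throw std::runtime_error("already built");
        if (!has_context_) throw std::runtime_error("build() requires an ExecutionContext");
        runtime_mode_ = mode;
        built_ = true;
    }

    void setTrainingMode(TrainingMode mode) {
        if (!built_) throw std::runtime_error("setTrainingMode() requires build()");
        if (runtime_mode_ == RuntimeMode::Inference)
            throw std::runtime_error("component was built with RuntimeMode::Inference");
        if (mode == TrainingMode::Normal && !gradients_allocated_)
            gradients_allocated_ = true;  // first switch to Normal allocates gradients
        training_mode_ = mode;
    }

    // Inference-built components always report Eval.
    TrainingMode getTrainingMode() const noexcept {
        return runtime_mode_ == RuntimeMode::Inference ? TrainingMode::Eval : training_mode_;
    }

    bool isTrainingMode() const noexcept { return getTrainingMode() == TrainingMode::Normal; }

    void backward() const {
        if (!isTrainingMode()) throw std::runtime_error("backward() requires isTrainingMode()");
    }

    bool gradientsAllocated() const noexcept { return gradients_allocated_; }

private:
    bool has_context_ = false;
    bool built_ = false;
    bool gradients_allocated_ = false;
    RuntimeMode runtime_mode_ = RuntimeMode::Training;
    TrainingMode training_mode_ = TrainingMode::Normal;
};

}  // namespace sketch
```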
Component::Component (const std::string &name)  [inline, explicit, export]
Construct component with required name identifier.
The name is used for identification, logging, and serialization. Names must be valid identifiers: start with a letter, contain only letters, digits, '.', '_', '-', and be 1-128 characters long.
| Parameter | Description |
|---|---|
| name | Component name identifier (mandatory). |

| Exception | Condition |
|---|---|
| std::invalid_argument | if name is not a valid identifier. |

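The naming rules can be checked with plain ASCII comparisons. This is a minimal sketch of what isIdentifier() and validateName() are documented to enforce; the `sketch` namespace and the exact error message are assumptions, not Mila's code. Explicit range checks are used instead of `std::isalpha` so the result does not depend on the current locale.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

namespace sketch {

// First character A-Z or a-z; remaining characters A-Z, a-z, 0-9, '.', '_' or '-';
// total length 1..128.
inline bool isIdentifier(const std::string& s) noexcept {
    if (s.empty() || s.size() > 128) return false;
    auto is_alpha = [](char c) { return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z'); };
    auto is_digit = [](char c) { return c >= '0' && c <= '9'; };
    if (!is_alpha(s[0])) return false;
    for (char c : s) {
        if (!(is_alpha(c) || is_digit(c) || c == '.' || c == '_' || c == '-')) return false;
    }
    return true;
}

// Returns the name unchanged or throws, mirroring the constructor's documented contract.
inline const std::string& validateName(const std::string& name) {
    if (!isIdentifier(name)) {
        throw std::invalid_argument("invalid component name: '" + name + "'");
    }
    return name;
}

}  // namespace sketch
```

Returning the validated name by reference makes it convenient to call inside a constructor's member-initializer list, e.g. `name_( validateName( name ) )`.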
virtual Component::~Component ()  [export, virtual, default]
virtual void Component::build (const BuildContext &context) final  [inline, final, export, virtual]
Build the component with the provided BuildContext (canonical overload).
Validates the context, stores it as build_context_, then invokes the onBuilding() hook for component-specific buffer allocation and initialization.
If onBuilding() returns without throwing, the component is marked built and isBuilt() returns true. If onBuilding() throws, built_ remains false and build() may be retried, but only if the onBuilding() implementation leaves the component in a coherent state on failure.
The stored BuildContext remains accessible to derived classes through the protected build_context_ member for the lifetime of the component.
| Parameter | Description |
|---|---|
| context | Build-time configuration carrying the leading shape { B, T, ... }, the RuntimeMode, and optional micro-batching settings. |

| Exception | Condition |
|---|---|
| std::runtime_error | if the component is already built. |
| std::runtime_error | if no ExecutionContext has been set. |
| std::invalid_argument | if context.validate() fails. |
| Any | exception thrown by onBuilding(). |

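The build() sequence is a Template Method: the base class owns validation, storage, and the built flag, while derived classes only supply onBuilding(). The sketch below uses hypothetical stand-ins (`Ctx`, `BuildSketch`, `FlakyComponent`) to show the documented retry behavior, where a throwing hook leaves the component unbuilt so a later call can succeed. It also includes a guard in the spirit of the private ensureBuilt() helper. This is an illustration under those assumptions, not Mila's code.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

namespace sketch {

// Hypothetical stand-in for BuildContext with a validate() check.
struct Ctx {
    int batch = 0;
    void validate() const {
        if (batch <= 0) throw std::invalid_argument("batch must be positive");
    }
};

// Template Method sketch of the documented build() sequence.
class BuildSketch {
public:
    virtual ~BuildSketch() = default;

    void build(const Ctx& context) {
        if (built_) throw std::runtime_error("component is already built");
        if (!has_context_) throw std::runtime_error("no ExecutionContext has been set");
        context.validate();          // throws std::invalid_argument on bad input
        build_context_ = context;    // stored before the hook runs
        onBuilding(build_context_);  // any exception propagates; built_ stays false
        built_ = true;
    }

    bool isBuilt() const noexcept { return built_; }
    void setHasContext() noexcept { has_context_ = true; }

    // Example public operation protected by an ensureBuilt()-style guard.
    void forward() const { ensureBuilt("forward"); }

protected:
    virtual void onBuilding(const Ctx&) {}
    Ctx build_context_{};

private:
    void ensureBuilt(const char* method) const {
        if (!built_) throw std::runtime_error(std::string(method) + "() called before build()");
    }

    bool built_ = false;
    bool has_context_ = false;
};

// Hook that fails on its first attempt only and leaves its state coherent.
class FlakyComponent : public BuildSketch {
public:
    int attempts = 0;

protected:
    void onBuilding(const Ctx&) override {
        if (++attempts == 1) throw std::runtime_error("transient allocation failure");
    }
};

}  // namespace sketch
```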
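void Component::ensureBuilt (const char *method) const  [inline, export, private]
Throws if the component has not yet been built.
Used as a precondition guard in public methods that require build() to have completed.
| Parameter | Description |
|---|---|
| method | Caller name for the error message. |

| Exception | Condition |
|---|---|
| std::runtime_error | if !built_. |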

virtual DeviceId Component::getDeviceId () const  [export, pure virtual]
Get the compute device id associated with this component.
Must return the device on which parameters and operations execute.
Implemented in Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >, Mila::Dnn::FusedComponent< TDeviceType, TPrecision >, Mila::Dnn::Gelu< TDeviceType, TPrecision >, Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >, Mila::Dnn::GroupedQueryAttention< TDeviceType, TPrecision, TKvPolicy >, Mila::Dnn::LayerNorm< TDeviceType, TPrecision >, Mila::Dnn::Linear< TDeviceType, TComputePrecision, TWeightQuant >, Mila::Dnn::Linear< TDeviceType, TPrecision >, Mila::Dnn::Linear< TDeviceType, TPrecision, TWeightQuant >, Mila::Dnn::Linear< TDeviceType, TPrecision, TWeightQuantization >, Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >, Mila::Dnn::Lpe< TDeviceType, dtype_t::INT32, TPrecision >, Mila::Dnn::MultiHeadAttention< TDeviceType, TPrecision >, Mila::Dnn::Network< TDeviceType, TPrecision >, Mila::Dnn::Residual< TDeviceType, TPrecision >, Mila::Dnn::RmsNorm< TDeviceType, TPrecision >, Mila::Dnn::Rope< TDeviceType, TPrecision >, Mila::Dnn::Softmax< TDeviceType, TPrecision >, Mila::Dnn::SoftmaxCrossEntropy< TDeviceType, TLogits, TTargets, TPrecision >, Mila::Dnn::Swiglu< TDeviceType, TPrecision >, Mila::Dnn::TokenEmbedding< TDeviceType, TIndex, TPrecision >, and Mila::Dnn::TokenEmbedding< TDeviceType, dtype_t::INT32, TPrecision >.
static constexpr DeviceType Component::getDeviceType ()  [inline, static, constexpr, export]
Compile-time device type for this component instance.

IExecutionContext * Component::getExecutionContext () const  [inline, export, protected]
Get the shared execution context.
Provides derived classes with access to the execution context owned by the parent (Network or test fixture).
| Exception | Condition |
|---|---|
| std::runtime_error | if context has not been set. |

virtual std::vector< ITensor * > Component::getGradients () const  [export, pure virtual]
Return non-owning pointers to parameter gradient tensors.
Only valid when isTrainingMode() returns true.
| Exception | Condition |
|---|---|
| std::runtime_error | if called when not in training mode or before the component has been built. |

Implemented in Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >, Mila::Dnn::FusedComponent< TDeviceType, TPrecision >, Mila::Dnn::Gelu< TDeviceType, TPrecision >, Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >, Mila::Dnn::GroupedQueryAttention< TDeviceType, TPrecision, TKvPolicy >, Mila::Dnn::LayerNorm< TDeviceType, TPrecision >, Mila::Dnn::Linear< TDeviceType, TComputePrecision, TWeightQuant >, Mila::Dnn::Linear< TDeviceType, TPrecision >, Mila::Dnn::Linear< TDeviceType, TPrecision, TWeightQuant >, Mila::Dnn::Linear< TDeviceType, TPrecision, TWeightQuantization >, Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >, Mila::Dnn::Lpe< TDeviceType, dtype_t::INT32, TPrecision >, Mila::Dnn::MultiHeadAttention< TDeviceType, TPrecision >, Mila::Dnn::Residual< TDeviceType, TPrecision >, Mila::Dnn::RmsNorm< TDeviceType, TPrecision >, Mila::Dnn::Rope< TDeviceType, TPrecision >, Mila::Dnn::Softmax< TDeviceType, TPrecision >, Mila::Dnn::SoftmaxCrossEntropy< TDeviceType, TLogits, TTargets, TPrecision >, Mila::Dnn::Swiglu< TDeviceType, TPrecision >, Mila::Dnn::TokenEmbedding< TDeviceType, TIndex, TPrecision >, and Mila::Dnn::TokenEmbedding< TDeviceType, dtype_t::INT32, TPrecision >.
virtual MemoryStats Component::getMemoryStats () const  [export, pure virtual]
Return the current memory allocation breakdown for this component.
Reflects allocations at the moment of the call, so the returned stats track the component lifecycle:
| Lifecycle point | Reported allocations |
|---|---|
| After construction | parameters only |
| After build() with RuntimeMode::Inference | parameters + state buffers sized for T = 1 |
| After build() with RuntimeMode::Training | parameters + state buffers sized for the full T |
| After the first setTrainingMode( TrainingMode::Normal ) | parameters + state + gradients |

For CompositeComponent and Network, the returned stats are the recursive aggregate of all child components.
May be called at any time — no lifecycle preconditions.
Implemented in Mila::Dnn::Gelu< TDeviceType, TPrecision >, Mila::Dnn::GptBlock< TDeviceType, TPrecision >, Mila::Dnn::GptTransformer< TDeviceType, TPrecision >, Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >, Mila::Dnn::GroupedQueryAttention< TDeviceType, TPrecision, TKvPolicy >, Mila::Dnn::LayerNorm< TDeviceType, TPrecision >, Mila::Dnn::Linear< TDeviceType, TComputePrecision, TWeightQuant >, Mila::Dnn::Linear< TDeviceType, TPrecision >, Mila::Dnn::Linear< TDeviceType, TPrecision, TWeightQuant >, Mila::Dnn::Linear< TDeviceType, TPrecision, TWeightQuantization >, Mila::Dnn::LlamaBlock< TDeviceType, TPrecision, TWeightQuant, TKvPolicy >, Mila::Dnn::LlamaBlock< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >, Mila::Dnn::LlamaTransformer< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >, Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >, Mila::Dnn::Lpe< TDeviceType, dtype_t::INT32, TPrecision >, Mila::Dnn::MLP< TDeviceType, TPrecision >, Mila::Dnn::MultiHeadAttention< TDeviceType, TPrecision >, Mila::Dnn::Residual< TDeviceType, TPrecision >, Mila::Dnn::RmsNorm< TDeviceType, TPrecision >, Mila::Dnn::Rope< TDeviceType, TPrecision >, Mila::Dnn::Softmax< TDeviceType, TPrecision >, Mila::Dnn::Swiglu< TDeviceType, TPrecision >, Mila::Dnn::TokenEmbedding< TDeviceType, TIndex, TPrecision >, and Mila::Dnn::TokenEmbedding< TDeviceType, dtype_t::INT32, TPrecision >.
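For CompositeComponent and Network the stats are a recursive aggregate over all children. A minimal Composite-pattern sketch, assuming a hypothetical `MemoryStats` with three byte counters (Mila's real struct may be organized differently) and hypothetical `Node`, `Leaf`, and `Composite` types:

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

namespace sketch {

// Hypothetical byte counters; not Mila's MemoryStats layout.
struct MemoryStats {
    std::size_t parameter_bytes = 0;
    std::size_t state_bytes = 0;
    std::size_t gradient_bytes = 0;

    MemoryStats& operator+=(const MemoryStats& other) {
        parameter_bytes += other.parameter_bytes;
        state_bytes += other.state_bytes;
        gradient_bytes += other.gradient_bytes;
        return *this;
    }
};

struct Node {
    virtual ~Node() = default;
    virtual MemoryStats getMemoryStats() const = 0;
};

// Leaf component: reports only what it owns.
struct Leaf : Node {
    MemoryStats own;
    explicit Leaf(MemoryStats s) : own(s) {}
    MemoryStats getMemoryStats() const override { return own; }
};

// Composite component: recursive sum over all children, however deeply nested.
struct Composite : Node {
    std::vector<std::unique_ptr<Node>> children;
    MemoryStats getMemoryStats() const override {
        MemoryStats total;
        for (const auto& child : children) total += child->getMemoryStats();
        return total;
    }
};

}  // namespace sketch
```

The same recursion is what the parameterCount() documentation asks composite implementations to perform.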
const std::string Component::getName () const  [inline, export]
Get the component's name identifier.
The name is used for logging, diagnostics, and serialization.

virtual std::vector< std::string > Component::getParameterNames () const  [inline, export, virtual]
List all available parameter names for this component.
Returns an empty vector by default. Leaf components with parameters should override to return their canonical parameter name list in the same stable order used by save_() and loadParameter().
virtual std::vector< ITensor * > Component::getParameters () const  [export, pure virtual]
Return non-owning pointers to parameter tensors.
The returned tensor pointers remain valid for the lifetime of the component. Order should be canonical (weights before biases).
Implemented in Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >, Mila::Dnn::FusedComponent< TDeviceType, TPrecision >, Mila::Dnn::Gelu< TDeviceType, TPrecision >, Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >, Mila::Dnn::GroupedQueryAttention< TDeviceType, TPrecision, TKvPolicy >, Mila::Dnn::LayerNorm< TDeviceType, TPrecision >, Mila::Dnn::Linear< TDeviceType, TComputePrecision, TWeightQuant >, Mila::Dnn::Linear< TDeviceType, TPrecision >, Mila::Dnn::Linear< TDeviceType, TPrecision, TWeightQuant >, Mila::Dnn::Linear< TDeviceType, TPrecision, TWeightQuantization >, Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >, Mila::Dnn::Lpe< TDeviceType, dtype_t::INT32, TPrecision >, Mila::Dnn::MultiHeadAttention< TDeviceType, TPrecision >, Mila::Dnn::Residual< TDeviceType, TPrecision >, Mila::Dnn::RmsNorm< TDeviceType, TPrecision >, Mila::Dnn::Rope< TDeviceType, TPrecision >, Mila::Dnn::Softmax< TDeviceType, TPrecision >, Mila::Dnn::SoftmaxCrossEntropy< TDeviceType, TLogits, TTargets, TPrecision >, Mila::Dnn::Swiglu< TDeviceType, TPrecision >, Mila::Dnn::TokenEmbedding< TDeviceType, TIndex, TPrecision >, and Mila::Dnn::TokenEmbedding< TDeviceType, dtype_t::INT32, TPrecision >.
static constexpr TensorDataType Component::getPrecision ()  [inline, static, constexpr, export, noexcept]
Compile-time tensor precision for this component instance.
RuntimeMode Component::getRuntimeMode () const  [inline, export, noexcept]
Get the RuntimeMode of this Component: RuntimeMode::Inference or RuntimeMode::Training.
TrainingMode Component::getTrainingMode () const  [inline, export, noexcept]
The current runtime behavioral mode of this Component.
Returns the current TrainingMode for Components built with RuntimeMode::Training. For Components built with RuntimeMode::Inference the return value is always TrainingMode::Eval — inference components never compute gradients.
virtual const ComponentType Component::getType () const  [export, pure virtual]
Get the component type identifier.
Used for serialization and runtime type identification.
Implemented in Mila::Dnn::Dropout< TDeviceType, TInput, TOutput >, Mila::Dnn::Dropout< DeviceType::Cpu, TInput, TOutput >, Mila::Dnn::Dropout< DeviceType::Cuda, TInput, TOutput >, Mila::Dnn::Gelu< TDeviceType, TPrecision >, Mila::Dnn::GptBlock< TDeviceType, TPrecision >, Mila::Dnn::GptTransformer< TDeviceType, TPrecision >, Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >, Mila::Dnn::GroupedQueryAttention< TDeviceType, TPrecision, TKvPolicy >, Mila::Dnn::LayerNorm< TDeviceType, TPrecision >, Mila::Dnn::Linear< TDeviceType, TComputePrecision, TWeightQuant >, Mila::Dnn::Linear< TDeviceType, TPrecision >, Mila::Dnn::Linear< TDeviceType, TPrecision, TWeightQuant >, Mila::Dnn::Linear< TDeviceType, TPrecision, TWeightQuantization >, Mila::Dnn::LlamaBlock< TDeviceType, TPrecision, TWeightQuant, TKvPolicy >, Mila::Dnn::LlamaBlock< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >, Mila::Dnn::LlamaTransformer< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >, Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >, Mila::Dnn::Lpe< TDeviceType, dtype_t::INT32, TPrecision >, Mila::Dnn::MLP< TDeviceType, TPrecision >, Mila::Dnn::MultiHeadAttention< TDeviceType, TPrecision >, Mila::Dnn::Network< TDeviceType, TPrecision >, Mila::Dnn::Residual< TDeviceType, TPrecision >, Mila::Dnn::RmsNorm< TDeviceType, TPrecision >, Mila::Dnn::Rope< TDeviceType, TPrecision >, Mila::Dnn::Softmax< TDeviceType, TPrecision >, Mila::Dnn::Swiglu< TDeviceType, TPrecision >, Mila::Dnn::TokenEmbedding< TDeviceType, TIndex, TPrecision >, and Mila::Dnn::TokenEmbedding< TDeviceType, dtype_t::INT32, TPrecision >.
bool Component::hasExecutionContext () const  [inline, export, protected, noexcept]
Check if execution context has been set.

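virtual bool Component::isBuilt () const final  [inline, final, export, virtual]
Returns true if build() has completed successfully.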
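static bool Component::isIdentifier (const std::string &s)  [inline, static, export, private, noexcept]
Checks if a string is a valid component identifier.
Rules: the first character must be A-Z or a-z; remaining characters may be A-Z, a-z, 0-9, '.', '_' or '-'; total length 1 to 128 characters.
| Parameter | Description |
|---|---|
| s | String to check. |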

bool Component::isInferenceMode () const  [inline, export, noexcept]

bool Component::isTrainingMode () const  [inline, export, noexcept]

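virtual void Component::loadParameter (const std::string &name, const Serialization::ITensorBlob &blob)  [inline, export, virtual]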
Load a parameter from serialized tensor data.
Loads raw tensor bytes directly into an existing parameter tensor, handling precision conversion and device upload as needed.
The component validates that the blob's shape matches the parameter's expected shape, then delegates to the backend for any required precision conversion and device upload.
| Parameter | Description |
|---|---|
| name | Parameter name used to locate the target tensor. |
| blob | Serialized tensor metadata and raw bytes. |

| Exception | Condition |
|---|---|
| std::runtime_error | if the component has no parameters to load. |
| std::runtime_error | if the blob shape does not match the parameter shape. |

Reimplemented in Mila::Dnn::LayerNorm< TDeviceType, TPrecision >, Mila::Dnn::Linear< TDeviceType, TComputePrecision, TWeightQuant >, Mila::Dnn::Linear< TDeviceType, TPrecision >, Mila::Dnn::Linear< TDeviceType, TPrecision, TWeightQuant >, Mila::Dnn::Linear< TDeviceType, TPrecision, TWeightQuantization >, Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >, Mila::Dnn::Lpe< TDeviceType, dtype_t::INT32, TPrecision >, Mila::Dnn::RmsNorm< TDeviceType, TPrecision >, Mila::Dnn::TokenEmbedding< TDeviceType, TIndex, TPrecision >, and Mila::Dnn::TokenEmbedding< TDeviceType, dtype_t::INT32, TPrecision >.
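template<TensorDataType TParameterPrecision, typename TMemoryResource> void Component::loadParameterFromBlob (const std::string &param_name, const Serialization::ITensorBlob &blob, Tensor< TParameterPrecision, TMemoryResource > &target, const shape_t &expected_shape)  [inline, export, protected]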
Load a tensor blob into a parameter tensor with validation.
Validates that the dtype and shape match, then copies the blob data into the tensor. Intended for use in loadParameter() overrides.
| Parameter | Description |
|---|---|
| param_name | Parameter name (used in error messages). |
| blob | Source tensor blob from the model archive. |
| target | Destination tensor (must be initialized). |
| expected_shape | Expected tensor shape for validation. |

| Exception | Condition |
|---|---|
| std::invalid_argument | if the dtype or shape does not match. |
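A sketch of the validate-then-copy step, assuming a hypothetical host-side `Blob` and `HostTensor` that carry FP32 data only. Mila's ITensorBlob, Tensor, memory resources, precision conversion, and device upload are not modeled; `loadFromBlob` is an illustrative name, not the library function.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <stdexcept>
#include <string>
#include <vector>

namespace sketch {

enum class DType { FP32, FP16 };

// Hypothetical serialized blob: dtype, shape, and raw bytes in host byte order.
struct Blob {
    DType dtype;
    std::vector<std::size_t> shape;
    std::vector<unsigned char> bytes;
};

// Hypothetical host tensor that stores FP32 values only.
struct HostTensor {
    DType dtype = DType::FP32;
    std::vector<std::size_t> shape;
    std::vector<float> data;
};

// Validate dtype and shape first, then copy; the target must already be sized.
inline void loadFromBlob(const std::string& param_name, const Blob& blob,
                         HostTensor& target, const std::vector<std::size_t>& expected_shape) {
    if (blob.dtype != target.dtype)
        throw std::invalid_argument(param_name + ": dtype mismatch");
    if (blob.shape != expected_shape || target.shape != expected_shape)
        throw std::invalid_argument(param_name + ": shape mismatch");
    if (blob.bytes.size() != target.data.size() * sizeof(float))
        throw std::invalid_argument(param_name + ": byte count does not match shape");
    std::memcpy(target.data.data(), blob.bytes.data(), blob.bytes.size());
}

}  // namespace sketch
```

Checking everything before the copy means a failed load never leaves the parameter half-overwritten.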

virtual void Component::onBuilding (const BuildContext &config)  [inline, export, protected, virtual]
Hook invoked by build() to allocate component buffers.
Receives the stored BuildContext. Implementations must size output buffers with config.allocationSeqLen(); this single call is what lets Inference and Training builds allocate correctly sized buffers automatically, without per-component logic.
The default implementation forwards to the legacy onBuilding( const shape_t& ) overload for backward compatibility. New components should override this overload directly.
| Parameter | Description |
|---|---|
| config | Build-time configuration. Use config.allocationSeqLen() to obtain the correct output buffer sequence dimension. |

Reimplemented in Mila::Dnn::Gelu< TDeviceType, TPrecision >, Mila::Dnn::GptBlock< TDeviceType, TPrecision >, Mila::Dnn::GptTransformer< TDeviceType, TPrecision >, Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >, Mila::Dnn::GroupedQueryAttention< TDeviceType, TPrecision, TKvPolicy >, Mila::Dnn::LayerNorm< TDeviceType, TPrecision >, Mila::Dnn::Linear< TDeviceType, TComputePrecision, TWeightQuant >, Mila::Dnn::Linear< TDeviceType, TPrecision >, Mila::Dnn::Linear< TDeviceType, TPrecision, TWeightQuant >, Mila::Dnn::Linear< TDeviceType, TPrecision, TWeightQuantization >, Mila::Dnn::LlamaBlock< TDeviceType, TPrecision, TWeightQuant, TKvPolicy >, Mila::Dnn::LlamaBlock< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >, Mila::Dnn::LlamaTransformer< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >, Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >, Mila::Dnn::Lpe< TDeviceType, dtype_t::INT32, TPrecision >, Mila::Dnn::MLP< TDeviceType, TPrecision >, Mila::Dnn::MultiHeadAttention< TDeviceType, TPrecision >, Mila::Dnn::Residual< TDeviceType, TPrecision >, Mila::Dnn::RmsNorm< TDeviceType, TPrecision >, Mila::Dnn::Rope< TDeviceType, TPrecision >, Mila::Dnn::Softmax< TDeviceType, TPrecision >, Mila::Dnn::Swiglu< TDeviceType, TPrecision >, Mila::Dnn::TokenEmbedding< TDeviceType, TIndex, TPrecision >, and Mila::Dnn::TokenEmbedding< TDeviceType, dtype_t::INT32, TPrecision >.
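The default hook's forwarding to a legacy shape-only overload can be sketched with hypothetical `Ctx`, `Base`, `NewStyle`, and `LegacyStyle` types (none of these are Mila classes; the boolean `inference` flag stands in for RuntimeMode). It also shows a C++ detail that comes with overloaded virtual hooks: a derived class that overrides one overload hides the others in its own scope unless it adds a using-declaration. Dispatch from build() is unaffected because the call is resolved in the base class.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

namespace sketch {

using shape_t = std::vector<std::size_t>;

// Hypothetical stand-in for BuildContext.
struct Ctx {
    shape_t leading_shape;   // { B, T, ... }
    bool inference = false;  // stands in for RuntimeMode::Inference
    std::size_t allocationSeqLen() const { return inference ? 1 : leading_shape[1]; }
};

class Base {
public:
    virtual ~Base() = default;
    void build(const Ctx& ctx) { onBuilding(ctx); }

protected:
    // New-style hook: by default forwards to the legacy shape-only overload.
    virtual void onBuilding(const Ctx& ctx) { onBuilding(ctx.leading_shape); }
    // Legacy hook kept for older components; default does nothing.
    virtual void onBuilding(const shape_t&) {}
};

// New component: overrides the BuildContext overload and sizes its output
// buffer as B * allocationSeqLen() * C, with C = 8 chosen arbitrarily.
class NewStyle : public Base {
public:
    std::size_t output_elems = 0;

protected:
    void onBuilding(const Ctx& ctx) override {
        output_elems = ctx.leading_shape[0] * ctx.allocationSeqLen() * 8;
    }
};

// Older component: overrides only the legacy overload. The using-declaration
// keeps the other overload visible in this class's scope.
class LegacyStyle : public Base {
public:
    shape_t seen;

protected:
    using Base::onBuilding;
    void onBuilding(const shape_t& shape) override { seen = shape; }
};

}  // namespace sketch
```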

virtual void Component::onExecutionContextSet ()  [inline, export, protected, virtual]
Lifecycle hook: Called immediately after ExecutionContext is set.
Override this hook to perform initialization that requires a valid ExecutionContext. When it is called, getExecutionContext() is guaranteed to return a valid context.
The default implementation does nothing.
| Exception | Condition |
|---|---|
| Any | exception thrown here causes setExecutionContext() to fail and restores the component to the "context not set" state. |

Reimplemented in Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >, Mila::Dnn::Gelu< TDeviceType, TPrecision >, Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >, Mila::Dnn::GroupedQueryAttention< TDeviceType, TPrecision, TKvPolicy >, Mila::Dnn::LayerNorm< TDeviceType, TPrecision >, Mila::Dnn::Linear< TDeviceType, TComputePrecision, TWeightQuant >, Mila::Dnn::Linear< TDeviceType, TPrecision >, Mila::Dnn::Linear< TDeviceType, TPrecision, TWeightQuant >, Mila::Dnn::Linear< TDeviceType, TPrecision, TWeightQuantization >, Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >, Mila::Dnn::Lpe< TDeviceType, dtype_t::INT32, TPrecision >, Mila::Dnn::MultiHeadAttention< TDeviceType, TPrecision >, Mila::Dnn::Residual< TDeviceType, TPrecision >, Mila::Dnn::RmsNorm< TDeviceType, TPrecision >, Mila::Dnn::Rope< TDeviceType, TPrecision >, Mila::Dnn::Softmax< TDeviceType, TPrecision >, Mila::Dnn::Swiglu< TDeviceType, TPrecision >, Mila::Dnn::TokenEmbedding< TDeviceType, TIndex, TPrecision >, and Mila::Dnn::TokenEmbedding< TDeviceType, dtype_t::INT32, TPrecision >.

virtual void Component::onTrainingModeChanging (TrainingMode mode)  [inline, export, protected, virtual]
Hook called before TrainingMode transitions.
Called by setTrainingMode() after validation and lock acquisition, before the internal state is updated. Derived classes override it to respond to the transition, e.g. zeroing gradient buffers on a transition to Eval, or re-enabling dropout on a transition to Normal.
The default implementation is a no-op.
| Parameter | Description |
|---|---|
| mode | The incoming TrainingMode. |

Reimplemented in Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >, Mila::Dnn::Gelu< TDeviceType, TPrecision >, Mila::Dnn::GptBlock< TDeviceType, TPrecision >, Mila::Dnn::GptTransformer< TDeviceType, TPrecision >, Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >, Mila::Dnn::GroupedQueryAttention< TDeviceType, TPrecision, TKvPolicy >, Mila::Dnn::LayerNorm< TDeviceType, TPrecision >, Mila::Dnn::Linear< TDeviceType, TComputePrecision, TWeightQuant >, Mila::Dnn::Linear< TDeviceType, TPrecision >, Mila::Dnn::Linear< TDeviceType, TPrecision, TWeightQuant >, Mila::Dnn::Linear< TDeviceType, TPrecision, TWeightQuantization >, Mila::Dnn::LlamaBlock< TDeviceType, TPrecision, TWeightQuant, TKvPolicy >, Mila::Dnn::LlamaBlock< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >, Mila::Dnn::LlamaTransformer< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >, Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >, Mila::Dnn::Lpe< TDeviceType, dtype_t::INT32, TPrecision >, Mila::Dnn::MLP< TDeviceType, TPrecision >, Mila::Dnn::MultiHeadAttention< TDeviceType, TPrecision >, Mila::Dnn::Residual< TDeviceType, TPrecision >, Mila::Dnn::RmsNorm< TDeviceType, TPrecision >, Mila::Dnn::Rope< TDeviceType, TPrecision >, Mila::Dnn::Softmax< TDeviceType, TPrecision >, Mila::Dnn::Swiglu< TDeviceType, TPrecision >, Mila::Dnn::TokenEmbedding< TDeviceType, TIndex, TPrecision >, and Mila::Dnn::TokenEmbedding< TDeviceType, dtype_t::INT32, TPrecision >.
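The documented ordering (validate, lock, call the hook, then update the state) means the hook still observes the old mode. A minimal sketch with a hypothetical `ModeSketch` base and a `Recorder` subclass that logs each transition; the `markBuilt()` helper and `modeName()` are hypothetical, and this is not Mila's implementation.

```cpp
#include <cassert>
#include <mutex>
#include <stdexcept>
#include <string>
#include <vector>

namespace sketch {

enum class TrainingMode { Normal, Eval };

inline const char* modeName(TrainingMode m) { return m == TrainingMode::Normal ? "Normal" : "Eval"; }

class ModeSketch {
public:
    virtual ~ModeSketch() = default;

    void setTrainingMode(TrainingMode mode) {
        if (!built_) throw std::runtime_error("component is not built");
        if (inference_built_) throw std::runtime_error("built with RuntimeMode::Inference");
        std::lock_guard<std::mutex> lock(mutex_);
        onTrainingModeChanging(mode);  // hook runs while the old mode is still stored
        mode_ = mode;                  // updated only after the hook returns normally
    }

    TrainingMode getTrainingMode() const noexcept { return mode_; }
    void markBuilt(bool inference) noexcept { built_ = true; inference_built_ = inference; }

protected:
    virtual void onTrainingModeChanging(TrainingMode) {}

private:
    std::mutex mutex_;
    TrainingMode mode_ = TrainingMode::Normal;
    bool built_ = false;
    bool inference_built_ = false;
};

// Records "old->new" each time the hook fires.
class Recorder : public ModeSketch {
public:
    std::vector<std::string> log;

protected:
    void onTrainingModeChanging(TrainingMode next) override {
        log.push_back(std::string(modeName(getTrainingMode())) + "->" + modeName(next));
    }
};

}  // namespace sketch
```

If the hook throws, the stored mode is left unchanged, which keeps the component consistent with whatever the hook managed to do.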

virtual size_t Component::parameterCount () const  [export, pure virtual]
Return number of trainable parameters.
For leaf components this is the element count of owned parameter tensors. CompositeComponent and Network implementations should return the recursive aggregate across all children.
Implemented in Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >, Mila::Dnn::Dropout< TDeviceType, TInput, TOutput >, Mila::Dnn::Dropout< DeviceType::Cpu, TInput, TOutput >, Mila::Dnn::Dropout< DeviceType::Cuda, TInput, TOutput >, Mila::Dnn::FusedComponent< TDeviceType, TPrecision >, Mila::Dnn::Gelu< TDeviceType, TPrecision >, Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >, Mila::Dnn::GroupedQueryAttention< TDeviceType, TPrecision, TKvPolicy >, Mila::Dnn::LayerNorm< TDeviceType, TPrecision >, Mila::Dnn::Linear< TDeviceType, TComputePrecision, TWeightQuant >, Mila::Dnn::Linear< TDeviceType, TPrecision >, Mila::Dnn::Linear< TDeviceType, TPrecision, TWeightQuant >, Mila::Dnn::Linear< TDeviceType, TPrecision, TWeightQuantization >, Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >, Mila::Dnn::Lpe< TDeviceType, dtype_t::INT32, TPrecision >, Mila::Dnn::MultiHeadAttention< TDeviceType, TPrecision >, Mila::Dnn::Residual< TDeviceType, TPrecision >, Mila::Dnn::RmsNorm< TDeviceType, TPrecision >, Mila::Dnn::Rope< TDeviceType, TPrecision >, Mila::Dnn::Softmax< TDeviceType, TPrecision >, Mila::Dnn::SoftmaxCrossEntropy< TDeviceType, TLogits, TTargets, TPrecision >, Mila::Dnn::Swiglu< TDeviceType, TPrecision >, Mila::Dnn::TokenEmbedding< TDeviceType, TIndex, TPrecision >, and Mila::Dnn::TokenEmbedding< TDeviceType, dtype_t::INT32, TPrecision >.
virtual void Component::save_ (ModelArchive &archive, SerializationMode mode) const  [export, pure virtual]
Implemented in Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >, Mila::Dnn::FusedComponent< TDeviceType, TPrecision >, Mila::Dnn::Gelu< TDeviceType, TPrecision >, Mila::Dnn::GptBlock< TDeviceType, TPrecision >, Mila::Dnn::GptTransformer< TDeviceType, TPrecision >, Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >, Mila::Dnn::GroupedQueryAttention< TDeviceType, TPrecision, TKvPolicy >, Mila::Dnn::LayerNorm< TDeviceType, TPrecision >, Mila::Dnn::Linear< TDeviceType, TComputePrecision, TWeightQuant >, Mila::Dnn::Linear< TDeviceType, TPrecision >, Mila::Dnn::Linear< TDeviceType, TPrecision, TWeightQuant >, Mila::Dnn::Linear< TDeviceType, TPrecision, TWeightQuantization >, Mila::Dnn::LlamaBlock< TDeviceType, TPrecision, TWeightQuant, TKvPolicy >, Mila::Dnn::LlamaBlock< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >, Mila::Dnn::LlamaTransformer< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >, Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >, Mila::Dnn::Lpe< TDeviceType, dtype_t::INT32, TPrecision >, Mila::Dnn::MLP< TDeviceType, TPrecision >, Mila::Dnn::MultiHeadAttention< TDeviceType, TPrecision >, Mila::Dnn::Network< TDeviceType, TPrecision >, Mila::Dnn::Residual< TDeviceType, TPrecision >, Mila::Dnn::RmsNorm< TDeviceType, TPrecision >, Mila::Dnn::Rope< TDeviceType, TPrecision >, Mila::Dnn::Softmax< TDeviceType, TPrecision >, Mila::Dnn::SoftmaxCrossEntropy< TDeviceType, TLogits, TTargets, TPrecision >, Mila::Dnn::Swiglu< TDeviceType, TPrecision >, Mila::Dnn::TokenEmbedding< TDeviceType, TIndex, TPrecision >, and Mila::Dnn::TokenEmbedding< TDeviceType, dtype_t::INT32, TPrecision >.
void Component::setExecutionContext (IExecutionContext *context)  [inline, export, protected]
Set the execution context for this component.
Establishes the device and execution environment. Can only be called once; the execution context is immutable after it has been set.
Called by the parent that owns the IExecutionContext (for example, Network or a test fixture).
After the context is set, the onExecutionContextSet() hook is invoked to let the component perform context-dependent initialization.
| Parameter | Description |
|---|---|
| context | Non-owning pointer to execution context (must be non-null). |

| Exception | Condition |
|---|---|
| std::invalid_argument | if context is null. |
| std::runtime_error | if context has already been set. |
| std::invalid_argument | if the context's device type does not match TDeviceType. |
| std::runtime_error | if onExecutionContextSet() throws; the context is restored to nullptr on failure. |
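A sketch of the documented checks and the rollback on hook failure, assuming a hypothetical `ExecContext` that exposes only a device type and hypothetical `ContextOwner` / `FailingOnce` classes. This version rethrows whatever the hook threw; whether Mila wraps it in std::runtime_error, as the exception table suggests, is not modeled here.

```cpp
#include <cassert>
#include <stdexcept>

namespace sketch {

enum class DeviceType { Cpu, Cuda };

// Hypothetical execution context exposing only its device type.
struct ExecContext {
    DeviceType device;
};

template <DeviceType TDeviceType>
class ContextOwner {
public:
    virtual ~ContextOwner() = default;

    void setExecutionContext(ExecContext* context) {
        if (context == nullptr) throw std::invalid_argument("context is null");
        if (exec_context_ != nullptr) throw std::runtime_error("context already set");
        if (context->device != TDeviceType) throw std::invalid_argument("device type mismatch");
        exec_context_ = context;
        try {
            onExecutionContextSet();
        } catch (...) {
            exec_context_ = nullptr;  // restore the "context not set" state
            throw;
        }
    }

    bool hasExecutionContext() const noexcept { return exec_context_ != nullptr; }

protected:
    virtual void onExecutionContextSet() {}

private:
    ExecContext* exec_context_ = nullptr;  // non-owning
};

// Hook that fails while `fail` is true, to exercise the rollback path.
class FailingOnce : public ContextOwner<DeviceType::Cpu> {
public:
    bool fail = true;

protected:
    void onExecutionContextSet() override {
        if (fail) throw std::runtime_error("context-dependent init failed");
    }
};

}  // namespace sketch
```

Because the context is rolled back on failure, the "can only be called once" rule applies only to successful calls, so a component whose hook failed can be given a context again.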

void Component::setTrainingMode (TrainingMode mode)  [inline, export]
Set the runtime behavioral mode for this Component.
Toggles between the Normal (training) and Eval behavioral states at runtime. Only valid on Components built with RuntimeMode::Training; throws if called on a Component built with RuntimeMode::Inference.
| From | To | Effect |
|---|---|---|
| Normal | Eval | Gradients off, dropout off, running statistics used |
| Eval | Normal | Gradients on, dropout on, batch statistics used |
Derived classes respond to the transition via the onTrainingModeChanging() hook, called before the state is updated.
| Parameter | Description |
|---|---|
| mode | TrainingMode::Normal or TrainingMode::Eval. |

| Exception | Condition |
|---|---|
| std::runtime_error | if the component is not built. |
| std::runtime_error | if built with RuntimeMode::Inference. |

virtual void Component::synchronize ()  [export, pure virtual]
Wait for outstanding device work submitted by this component.
On CPU this may be a no-op. Use to ensure results are visible to the host or to measure synchronous timings.
Implemented in Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >, Mila::Dnn::FusedComponent< TDeviceType, TPrecision >, Mila::Dnn::Gelu< TDeviceType, TPrecision >, Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >, Mila::Dnn::GroupedQueryAttention< TDeviceType, TPrecision, TKvPolicy >, Mila::Dnn::LayerNorm< TDeviceType, TPrecision >, Mila::Dnn::Linear< TDeviceType, TComputePrecision, TWeightQuant >, Mila::Dnn::Linear< TDeviceType, TPrecision >, Mila::Dnn::Linear< TDeviceType, TPrecision, TWeightQuant >, Mila::Dnn::Linear< TDeviceType, TPrecision, TWeightQuantization >, Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >, Mila::Dnn::Lpe< TDeviceType, dtype_t::INT32, TPrecision >, Mila::Dnn::MultiHeadAttention< TDeviceType, TPrecision >, Mila::Dnn::Network< TDeviceType, TPrecision >, Mila::Dnn::Residual< TDeviceType, TPrecision >, Mila::Dnn::RmsNorm< TDeviceType, TPrecision >, Mila::Dnn::Rope< TDeviceType, TPrecision >, Mila::Dnn::Softmax< TDeviceType, TPrecision >, Mila::Dnn::SoftmaxCrossEntropy< TDeviceType, TLogits, TTargets, TPrecision >, Mila::Dnn::Swiglu< TDeviceType, TPrecision >, Mila::Dnn::TokenEmbedding< TDeviceType, TIndex, TPrecision >, and Mila::Dnn::TokenEmbedding< TDeviceType, dtype_t::INT32, TPrecision >.
virtual std::string Component::toString () const  [export, pure virtual]
Produce a short, human-readable description of the component.
Implementations should keep output concise and avoid throwing.
Implemented in Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >, Mila::Dnn::Dropout< TDeviceType, TInput, TOutput >, Mila::Dnn::Dropout< DeviceType::Cpu, TInput, TOutput >, Mila::Dnn::Dropout< DeviceType::Cuda, TInput, TOutput >, Mila::Dnn::FusedComponent< TDeviceType, TPrecision >, Mila::Dnn::Gelu< TDeviceType, TPrecision >, Mila::Dnn::GptBlock< TDeviceType, TPrecision >, Mila::Dnn::GptTransformer< TDeviceType, TPrecision >, Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >, Mila::Dnn::GroupedQueryAttention< TDeviceType, TPrecision, TKvPolicy >, Mila::Dnn::LayerNorm< TDeviceType, TPrecision >, Mila::Dnn::Linear< TDeviceType, TComputePrecision, TWeightQuant >, Mila::Dnn::Linear< TDeviceType, TPrecision >, Mila::Dnn::Linear< TDeviceType, TPrecision, TWeightQuant >, Mila::Dnn::Linear< TDeviceType, TPrecision, TWeightQuantization >, Mila::Dnn::LlamaTransformer< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >, Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >, Mila::Dnn::Lpe< TDeviceType, dtype_t::INT32, TPrecision >, Mila::Dnn::MLP< TDeviceType, TPrecision >, Mila::Dnn::MultiHeadAttention< TDeviceType, TPrecision >, Mila::Dnn::Network< TDeviceType, TPrecision >, Mila::Dnn::Residual< TDeviceType, TPrecision >, Mila::Dnn::RmsNorm< TDeviceType, TPrecision >, Mila::Dnn::Rope< TDeviceType, TPrecision >, Mila::Dnn::Softmax< TDeviceType, TPrecision >, Mila::Dnn::SoftmaxCrossEntropy< TDeviceType, TLogits, TTargets, TPrecision >, Mila::Dnn::Swiglu< TDeviceType, TPrecision >, Mila::Dnn::TokenEmbedding< TDeviceType, TIndex, TPrecision >, and Mila::Dnn::TokenEmbedding< TDeviceType, dtype_t::INT32, TPrecision >.

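static const std::string & Component::validateName (const std::string &name)  [inline, static, export, private]
Validates the component name.
Enforces identifier rules: must start with a letter and contain only letters, digits, '.', '_', '-' with length between 1 and 128 characters.
| Parameter | Description |
|---|---|
| name | Name to validate. |

| Exception | Condition |
|---|---|
| std::invalid_argument | if name is not a valid identifier. |
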
virtual void Component::zeroGradients ()  [inline, export, virtual]
Clear all model-owned gradients for this component.
Default implementation is a no-op. Composite components should override to recurse to children. Leaf components should override to zero their parameter and activation gradients using device-aware helpers.
Reimplemented in Mila::Dnn::GptBlock< TDeviceType, TPrecision >, Mila::Dnn::GptTransformer< TDeviceType, TPrecision >, Mila::Dnn::LayerNorm< TDeviceType, TPrecision >, Mila::Dnn::Linear< TDeviceType, TComputePrecision, TWeightQuant >, Mila::Dnn::Linear< TDeviceType, TPrecision >, Mila::Dnn::Linear< TDeviceType, TPrecision, TWeightQuant >, Mila::Dnn::Linear< TDeviceType, TPrecision, TWeightQuantization >, Mila::Dnn::LlamaBlock< TDeviceType, TPrecision, TWeightQuant, TKvPolicy >, Mila::Dnn::LlamaBlock< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >, Mila::Dnn::LlamaTransformer< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >, Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >, Mila::Dnn::Lpe< TDeviceType, dtype_t::INT32, TPrecision >, Mila::Dnn::MLP< TDeviceType, TPrecision >, Mila::Dnn::RmsNorm< TDeviceType, TPrecision >, Mila::Dnn::Rope< TDeviceType, TPrecision >, Mila::Dnn::TokenEmbedding< TDeviceType, TIndex, TPrecision >, and Mila::Dnn::TokenEmbedding< TDeviceType, dtype_t::INT32, TPrecision >.
std::ostream & operator<< (std::ostream &os, const Component &component)  [friend]
Stream output uses toString() to provide a human-readable description of the component.
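The friend stream operator simply delegates to the virtual toString(). A minimal sketch with hypothetical `Printable` and `Dense` classes (not Mila types); defining the operator as a hidden friend inside the base class lets argument-dependent lookup find it for every derived type.

```cpp
#include <cassert>
#include <ostream>
#include <sstream>
#include <string>
#include <utility>

namespace sketch {

class Printable {
public:
    explicit Printable(std::string name) : name_(std::move(name)) {}
    virtual ~Printable() = default;

    virtual std::string toString() const = 0;
    const std::string& getName() const { return name_; }

    // Hidden friend: found by ADL for Printable and all derived classes.
    friend std::ostream& operator<<(std::ostream& os, const Printable& p) {
        return os << p.toString();
    }

private:
    std::string name_;
};

// Hypothetical leaf component with a concise, non-throwing description.
class Dense : public Printable {
public:
    using Printable::Printable;
    std::string toString() const override { return "Dense(" + getName() + ")"; }
};

}  // namespace sketch
```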