Mila 0.13.48
Deep Neural Network Library
Mila::Dnn::Component< TDeviceType, TPrecision > Class Template Reference (abstract, export)

Abstract base class for neural network components.


Public Member Functions

 Component (const std::string &name)
 Construct component with required name identifier.
virtual ~Component ()=default
virtual void build (const BuildContext &context) final
 Build the component with the provided BuildContext (canonical overload).
virtual DeviceId getDeviceId () const =0
 Get the compute device id associated with this component.
virtual std::vector< ITensor * > getGradients () const =0
 Return non-owning pointers to parameter gradient tensors.
virtual MemoryStats getMemoryStats () const =0
 Return the current memory allocation breakdown for this component.
const std::string getName () const
 Get the component's name identifier.
virtual std::vector< std::string > getParameterNames () const
 List all available parameter names for this component.
virtual std::vector< ITensor * > getParameters () const =0
 Return non-owning pointers to parameter tensors.
RuntimeMode getRuntimeMode () const noexcept
 Get the RuntimeMode (Inference or Training) of this component.
TrainingMode getTrainingMode () const noexcept
 The current runtime behavioral mode of this Component.
virtual const ComponentType getType () const =0
 Get the component type identifier.
virtual bool isBuilt () const final
 Returns true if build() has completed successfully.
bool isInferenceMode () const noexcept
bool isTrainingMode () const noexcept
virtual void loadParameter (const std::string &name, const Serialization::ITensorBlob &blob)
 Load a parameter from serialized tensor data.
virtual size_t parameterCount () const =0
 Return number of trainable parameters.
virtual void save_ (ModelArchive &archive, SerializationMode mode) const =0
void setTrainingMode (TrainingMode mode)
 Set the runtime behavioral mode for this Component.
virtual void synchronize ()=0
 Wait for outstanding device work submitted by this component.
virtual std::string toString () const =0
 Produce a short, human-readable description of the component.
virtual void zeroGradients ()
 Clear all model-owned gradients for this component.

Static Public Member Functions

static constexpr DeviceType getDeviceType ()
 Compile-time device type for this component instance.
static constexpr TensorDataType getPrecision () noexcept
 Compile-time tensor precision for this component instance.

Protected Member Functions

IExecutionContext * getExecutionContext () const
 Get the shared execution context.
bool hasExecutionContext () const noexcept
 Check if execution context has been set.
template<TensorDataType TParameterPrecision, typename TMemoryResource>
void loadParameterFromBlob (const std::string &param_name, const Serialization::ITensorBlob &blob, Tensor< TParameterPrecision, TMemoryResource > &target, const shape_t &expected_shape)
 Load a tensor blob into a parameter tensor with validation.
virtual void onBuilding (const BuildContext &config)
 Hook invoked by build() to allocate component buffers.
virtual void onExecutionContextSet ()
 Lifecycle hook: Called immediately after ExecutionContext is set.
virtual void onTrainingModeChanging (TrainingMode mode)
 Hook called before TrainingMode transitions.
void setExecutionContext (IExecutionContext *context)
 Set the execution context for this component.

Protected Attributes

BuildContext build_context_ { shape_t{ 1 }, RuntimeMode::Training }
 The BuildContext stored at build time.

Private Member Functions

void ensureBuilt (const char *method) const
 Throws if the component has not yet been built.

Static Private Member Functions

static bool isIdentifier (const std::string &s) noexcept
 Checks if a string is a valid component identifier.
static const std::string & validateName (const std::string &name)
 Validates the component name.

Private Attributes

bool built_ { false }
IExecutionContext * exec_context_ { nullptr }
std::string name_
TrainingMode training_mode_ { TrainingMode::Normal }
std::mutex training_mode_mutex_

Friends

template<DeviceType, TensorDataType>
class CompositeComponent
template<DeviceType, TensorDataType>
class Network
std::ostream & operator<< (std::ostream &os, const Component &component)
 Stream output uses toString() to provide a human-readable description of the component.

Detailed Description

template<DeviceType TDeviceType, TensorDataType TPrecision>
requires PrecisionSupportedOnDevice<TPrecision, TDeviceType>
class Mila::Dnn::Component< TDeviceType, TPrecision >

Abstract base class for neural network components.

Component enforces a single ownership model: all components receive a non-owning pointer to an IExecutionContext that is owned by the parent (Network or test fixture). This ensures consistent resource sharing and eliminates dual constructor patterns.

Ownership model

  • Components NEVER own ExecutionContext.
  • Parent (Network/CompositeComponent) owns and provides shared context.
  • Tests explicitly create context and pass raw pointer to components.

Component build lifecycle

Components progress through a well-defined lifecycle. Each stage has a single responsibility and a designated hook for subclass extension.

Stage 1 — Construction

The component is constructed with its name and component config. No device resources are allocated. No ExecutionContext is required.

auto linear = std::make_unique<Linear>( "fc", config, context );

Stage 2 — Build [ onBuilding() ]

build() is called with a BuildContext carrying the leading shape { B, T, ... } and the ExecutionMode that governs buffer allocation.

Mode       allocationSeqLen()  Gradient buffers
Inference  1                   never allocated
Training   leading_shape[1]    allocated on demand
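
The table above reduces to a single rule for the sequence dimension. The following is a minimal sketch of that rule only; `SketchBuildContext` and `Mode` are hypothetical stand-ins introduced for illustration and are not Mila's actual BuildContext API.

```cpp
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical stand-ins; the real BuildContext may differ.
enum class Mode { Inference, Training };

struct SketchBuildContext {
    std::vector<std::int64_t> leading_shape;  // { B, T, ... }
    Mode mode = Mode::Training;

    // Sequence length used to size output buffers:
    // 1 for Inference, leading_shape[1] (T) for Training.
    std::int64_t allocationSeqLen() const {
        if (mode == Mode::Inference) {
            return 1;
        }
        if (leading_shape.size() < 2) {
            throw std::invalid_argument("leading shape must contain { B, T }");
        }
        return leading_shape[1];
    }
};
```

With a leading shape of { 4, 128 }, this sketch yields 128 in Training and 1 in Inference, which is why a component can size its output buffer without branching on the mode itself.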

Allocated in onBuilding():

  • output_: forward output buffer sized by allocationSeqLen()
  • decode_output_: decode output buffer (decode-capable components)
  • kv_cache_: KV cache (decode-capable components)
  • operation_: buffers via operation_->build()

NOT allocated in onBuilding():

  • gradient buffers (deferred to first setEvaluation( false ))
  • backward state (deferred to first setEvaluation( false ))

The same BuildContext is cascaded unchanged through CompositeComponent and Network to all child components.

BuildContext config( shape_t{ batch_size, seq_length } );
config.withExecutionMode( ExecutionMode::Training );
model->build( config );

Stage 3 — Evaluation mode [ onEvaluationChanging() ]

Only valid for Training-built components. setEvaluation( false ) triggers gradient buffer allocation on the first call. Subsequent calls zero existing buffers without reallocating. setEvaluation( true ) zeros gradient buffers and disables the backward path.

Allocated on first setEvaluation( false ):

  • input_grad_: input gradient buffer
  • weight_grad_: weight gradient buffer
  • bias_grad_: bias gradient buffer

model->setEvaluation( true ); // suspend backward (eval checkpoint)
generateSample( model );
model->setEvaluation( false ); // resume training
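
The allocate-once, zero-afterwards behavior described for this stage can be sketched as follows. This is an illustration under assumptions, not the library's implementation: `SketchGradOwner` and its members are hypothetical, and a plain `std::vector<float>` stands in for a device tensor.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical leaf component that owns a single gradient buffer.
class SketchGradOwner {
public:
    explicit SketchGradOwner(std::size_t weight_count)
        : weight_count_(weight_count) {}

    void setEvaluation(bool eval) {
        if (!eval && !allocated_) {
            // First setEvaluation( false ): allocate zero-initialized gradients.
            weight_grad_.assign(weight_count_, 0.0f);
            allocated_ = true;
            ++allocation_count_;
        } else {
            // Every other call: zero in place, never reallocate.
            std::fill(weight_grad_.begin(), weight_grad_.end(), 0.0f);
        }
        eval_ = eval;
    }

    bool isEvaluation() const noexcept { return eval_; }
    std::size_t allocationCount() const noexcept { return allocation_count_; }
    std::vector<float>& gradients() noexcept { return weight_grad_; }

private:
    std::size_t weight_count_;
    bool allocated_ = false;
    bool eval_ = true;
    std::size_t allocation_count_ = 0;
    std::vector<float> weight_grad_;
};
```

Toggling evaluation on and off repeatedly leaves `allocationCount()` at 1 in this sketch, which mirrors the stated guarantee that later calls reuse the existing buffers.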

Stage 4 — Forward / Decode / Backward

Runtime dimensions are read from the input tensor shape on each call. No shape information is cached from build time beyond what is in build_context_.

Lifecycle invariants

  • build() requires ExecutionContext to be set
  • setEvaluation() requires build() to have completed
  • setEvaluation() requires ExecutionMode::Training
  • forward() requires build() to have completed
  • backward() requires isTrainingMode() == true
  • decode() requires build() to have completed

Base class provides

Template Parameters
TDeviceType: Compile-time device identifier for this component.
TPrecision: Tensor data precision for this component.
Trusted Collaborators
This class grants private access to:
  • CompositeComponent: Parent components that manage child execution contexts and aggregate parameters from child components.
  • Network: Top-level graph coordinator that collects parameters and gradients for optimization and serialization.
Examples
/__w/Mila/Mila/Mila/Src/Dnn/Components/Activations/Gelu/Gelu.ixx, and /__w/Mila/Mila/Mila/Src/Dnn/Components/Normalization/Softmax.ixx.

Constructor & Destructor Documentation

◆ Component()

template<DeviceType TDeviceType, TensorDataType TPrecision>
Mila::Dnn::Component< TDeviceType, TPrecision >::Component ( const std::string & name)
inline, explicit, export

Construct component with required name identifier.

The name is used for identification, logging, and serialization. Names must be valid identifiers: start with a letter, contain only letters, digits, '.', '_', '-', and be 1-128 characters long.

Parameters
name: Component name identifier (mandatory).
Exceptions
std::invalid_argument: if name is not a valid identifier.

◆ ~Component()

template<DeviceType TDeviceType, TensorDataType TPrecision>
virtual Mila::Dnn::Component< TDeviceType, TPrecision >::~Component ( )
export, virtual, default

Member Function Documentation

◆ build()

template<DeviceType TDeviceType, TensorDataType TPrecision>
virtual void Mila::Dnn::Component< TDeviceType, TPrecision >::build ( const BuildContext & context)
inline, final, export, virtual

Build the component with the provided BuildContext (canonical overload).

Validates the context, stores it as build_context_, then invokes the onBuilding() hook for component-specific buffer allocation and initialization.

After onBuilding() returns without throwing, the component is marked built and isBuilt() returns true. If onBuilding() throws, built_ remains false and build() may be retried — but only if the onBuilding() implementation leaves component state coherent on failure.

The stored BuildContext is accessible to derived classes via the protected build_context_ member throughout the component lifetime.

Parameters
context: Build-time configuration carrying the leading shape { B, T, ... }, ExecutionMode, and optional micro-batching settings.
Exceptions
std::runtime_error: if the component is already built.
std::runtime_error: if no ExecutionContext has been set.
std::invalid_argument: if context.validate() fails.
Any exception from onBuilding().
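
The contract above (check preconditions, store the context, call the hook, mark as built only on success) can be sketched as a small template-method pattern. Everything here is hypothetical and simplified for illustration: `SketchComponent`, `FlakyComponent`, and `setContextForSketch` are not Mila names, an `int` sequence length stands in for BuildContext, and a `const void*` stands in for the execution context.

```cpp
#include <stdexcept>

// Hypothetical, simplified model of the build() contract.
class SketchComponent {
public:
    virtual ~SketchComponent() = default;

    // Stand-in for setExecutionContext(); any non-null pointer counts as "set".
    void setContextForSketch(const void* ctx) { context_ = ctx; }

    // Fixed entry point: validates, stores, then delegates to the hook.
    void build(int seq_len) {
        if (built_) throw std::runtime_error("already built");
        if (context_ == nullptr) throw std::runtime_error("no execution context");
        if (seq_len <= 0) throw std::invalid_argument("invalid build context");
        stored_seq_len_ = seq_len;
        onBuilding(seq_len);  // may throw; built_ stays false in that case
        built_ = true;
    }

    bool isBuilt() const noexcept { return built_; }

protected:
    virtual void onBuilding(int /*seq_len*/) {}

private:
    bool built_ = false;
    const void* context_ = nullptr;
    int stored_seq_len_ = 0;
};

// Hook that fails on its first call, to show that build() can be retried.
class FlakyComponent : public SketchComponent {
protected:
    void onBuilding(int /*seq_len*/) override {
        if (!failed_once_) {
            failed_once_ = true;
            throw std::runtime_error("transient allocation failure");
        }
    }

private:
    bool failed_once_ = false;
};
```

In this sketch a failed hook leaves `isBuilt()` false, so a second `build()` call is accepted. Whether a retry is safe in practice depends on the hook leaving no partial state behind.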

◆ ensureBuilt()

template<DeviceType TDeviceType, TensorDataType TPrecision>
void Mila::Dnn::Component< TDeviceType, TPrecision >::ensureBuilt ( const char * method) const
inline, export, private

Throws if the component has not yet been built.

Used as a precondition guard in public methods that require build() to have completed.

Parameters
method: Caller name for the error message.
Exceptions
std::runtime_error: if !built_.

◆ getDeviceId()

template<DeviceType TDeviceType, TensorDataType TPrecision>
virtual DeviceId Mila::Dnn::Component< TDeviceType, TPrecision >::getDeviceId ( ) const
export, pure virtual

◆ getDeviceType()

template<DeviceType TDeviceType, TensorDataType TPrecision>
constexpr DeviceType Mila::Dnn::Component< TDeviceType, TPrecision >::getDeviceType ( )
inline, static, constexpr, export

Compile-time device type for this component instance.


◆ getExecutionContext()

template<DeviceType TDeviceType, TensorDataType TPrecision>
IExecutionContext * Mila::Dnn::Component< TDeviceType, TPrecision >::getExecutionContext ( ) const
inline, export, protected

Get the shared execution context.

Provides access to the execution context for derived classes to:

  • Query device information
  • Create tensors on the correct device
  • Pass to backend operations
  • Synchronize device work
Returns
Non-owning pointer to execution context (guaranteed non-null).
Exceptions
std::runtime_error: if context has not been set.

◆ getGradients()

template<DeviceType TDeviceType, TensorDataType TPrecision>
virtual std::vector< ITensor * > Mila::Dnn::Component< TDeviceType, TPrecision >::getGradients ( ) const
export, pure virtual

◆ getMemoryStats()

template<DeviceType TDeviceType, TensorDataType TPrecision>
virtual MemoryStats Mila::Dnn::Component< TDeviceType, TPrecision >::getMemoryStats ( ) const
export, pure virtual

Return the current memory allocation breakdown for this component.

Reflects allocations at the moment of the call. The returned stats naturally track the component lifecycle:

  • After construction: parameters only
  • After build( Inference ): parameters + T=1 state buffers
  • After build( Training ): parameters + T=full state buffers
  • After setEvaluation( false ): parameters + state + gradients

For CompositeComponent and Network, the returned stats are the recursive aggregate of all child components.

May be called at any time — no lifecycle preconditions.

Returns
MemoryStats reflecting current allocations.

Implemented in Mila::Dnn::Gelu< TDeviceType, TPrecision >, Mila::Dnn::GptBlock< TDeviceType, TPrecision >, Mila::Dnn::GptTransformer< TDeviceType, TPrecision >, Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >, Mila::Dnn::GroupedQueryAttention< TDeviceType, TPrecision, TKvPolicy >, Mila::Dnn::LayerNorm< TDeviceType, TPrecision >, Mila::Dnn::Linear< TDeviceType, TComputePrecision, TWeightQuant >, Mila::Dnn::Linear< TDeviceType, TPrecision >, Mila::Dnn::Linear< TDeviceType, TPrecision, TWeightQuant >, Mila::Dnn::Linear< TDeviceType, TPrecision, TWeightQuantization >, Mila::Dnn::LlamaBlock< TDeviceType, TPrecision, TWeightQuant, TKvPolicy >, Mila::Dnn::LlamaBlock< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >, Mila::Dnn::LlamaTransformer< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >, Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >, Mila::Dnn::Lpe< TDeviceType, dtype_t::INT32, TPrecision >, Mila::Dnn::MLP< TDeviceType, TPrecision >, Mila::Dnn::MultiHeadAttention< TDeviceType, TPrecision >, Mila::Dnn::Residual< TDeviceType, TPrecision >, Mila::Dnn::RmsNorm< TDeviceType, TPrecision >, Mila::Dnn::Rope< TDeviceType, TPrecision >, Mila::Dnn::Softmax< TDeviceType, TPrecision >, Mila::Dnn::Swiglu< TDeviceType, TPrecision >, Mila::Dnn::TokenEmbedding< TDeviceType, TIndex, TPrecision >, and Mila::Dnn::TokenEmbedding< TDeviceType, dtype_t::INT32, TPrecision >.

◆ getName()

template<DeviceType TDeviceType, TensorDataType TPrecision>
const std::string Mila::Dnn::Component< TDeviceType, TPrecision >::getName ( ) const
inline, export

Get the component's name identifier.

The name is used for logging, diagnostics, and serialization.

Returns
Component full hierarchical path.

◆ getParameterNames()

template<DeviceType TDeviceType, TensorDataType TPrecision>
virtual std::vector< std::string > Mila::Dnn::Component< TDeviceType, TPrecision >::getParameterNames ( ) const
inline, export, virtual

List all available parameter names for this component.

Returns an empty vector by default. Leaf components with parameters should override to return their canonical parameter name list in the same stable order used by save_() and loadParameter().

◆ getParameters()

template<DeviceType TDeviceType, TensorDataType TPrecision>
virtual std::vector< ITensor * > Mila::Dnn::Component< TDeviceType, TPrecision >::getParameters ( ) const
export, pure virtual

◆ getPrecision()

template<DeviceType TDeviceType, TensorDataType TPrecision>
constexpr TensorDataType Mila::Dnn::Component< TDeviceType, TPrecision >::getPrecision ( )
inline, static, constexpr, export, noexcept

Compile-time tensor precision for this component instance.

◆ getRuntimeMode()

template<DeviceType TDeviceType, TensorDataType TPrecision>
RuntimeMode Mila::Dnn::Component< TDeviceType, TPrecision >::getRuntimeMode ( ) const
inline, export, noexcept

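Get the RuntimeMode of this component.

The return type is RuntimeMode, so the result is RuntimeMode::Inference or RuntimeMode::Training rather than a boolean. To test for Eval behavior, compare getTrainingMode() against TrainingMode::Eval; that check is valid for both RuntimeMode::Inference and RuntimeMode::Training built components.

Returns
The component's RuntimeMode.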

◆ getTrainingMode()

template<DeviceType TDeviceType, TensorDataType TPrecision>
TrainingMode Mila::Dnn::Component< TDeviceType, TPrecision >::getTrainingMode ( ) const
inline, export, noexcept

The current runtime behavioral mode of this Component.

Returns the current TrainingMode for Components built with RuntimeMode::Training. For Components built with RuntimeMode::Inference the return value is always TrainingMode::Eval — inference components never compute gradients.

Returns
Current TrainingMode.

◆ getType()

template<DeviceType TDeviceType, TensorDataType TPrecision>
virtual const ComponentType Mila::Dnn::Component< TDeviceType, TPrecision >::getType ( ) const
export, pure virtual

Get the component type identifier.

Used for serialization and runtime type identification.

Returns
Component type enum value.

Implemented in Mila::Dnn::Dropout< TDeviceType, TInput, TOutput >, Mila::Dnn::Dropout< DeviceType::Cpu, TInput, TOutput >, Mila::Dnn::Dropout< DeviceType::Cuda, TInput, TOutput >, Mila::Dnn::Gelu< TDeviceType, TPrecision >, Mila::Dnn::GptBlock< TDeviceType, TPrecision >, Mila::Dnn::GptTransformer< TDeviceType, TPrecision >, Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >, Mila::Dnn::GroupedQueryAttention< TDeviceType, TPrecision, TKvPolicy >, Mila::Dnn::LayerNorm< TDeviceType, TPrecision >, Mila::Dnn::Linear< TDeviceType, TComputePrecision, TWeightQuant >, Mila::Dnn::Linear< TDeviceType, TPrecision >, Mila::Dnn::Linear< TDeviceType, TPrecision, TWeightQuant >, Mila::Dnn::Linear< TDeviceType, TPrecision, TWeightQuantization >, Mila::Dnn::LlamaBlock< TDeviceType, TPrecision, TWeightQuant, TKvPolicy >, Mila::Dnn::LlamaBlock< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >, Mila::Dnn::LlamaTransformer< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >, Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >, Mila::Dnn::Lpe< TDeviceType, dtype_t::INT32, TPrecision >, Mila::Dnn::MLP< TDeviceType, TPrecision >, Mila::Dnn::MultiHeadAttention< TDeviceType, TPrecision >, Mila::Dnn::Network< TDeviceType, TPrecision >, Mila::Dnn::Residual< TDeviceType, TPrecision >, Mila::Dnn::RmsNorm< TDeviceType, TPrecision >, Mila::Dnn::Rope< TDeviceType, TPrecision >, Mila::Dnn::Softmax< TDeviceType, TPrecision >, Mila::Dnn::Swiglu< TDeviceType, TPrecision >, Mila::Dnn::TokenEmbedding< TDeviceType, TIndex, TPrecision >, and Mila::Dnn::TokenEmbedding< TDeviceType, dtype_t::INT32, TPrecision >.

◆ hasExecutionContext()

template<DeviceType TDeviceType, TensorDataType TPrecision>
bool Mila::Dnn::Component< TDeviceType, TPrecision >::hasExecutionContext ( ) const
inline, export, protected, noexcept

Check if execution context has been set.

Returns
true if context is set, false otherwise.

◆ isBuilt()

template<DeviceType TDeviceType, TensorDataType TPrecision>
virtual bool Mila::Dnn::Component< TDeviceType, TPrecision >::isBuilt ( ) const
inline, final, export, virtual

Returns true if build() has completed successfully.


◆ isIdentifier()

template<DeviceType TDeviceType, TensorDataType TPrecision>
bool Mila::Dnn::Component< TDeviceType, TPrecision >::isIdentifier ( const std::string & s)
inline, static, export, private, noexcept

Checks if a string is a valid component identifier.

Rules: start with A-Za-z, then allow A-Za-z0-9._- ; length 1..128.

Parameters
s: String to check.
Returns
true if valid identifier, false otherwise.
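
The stated rules translate directly into a short check. The function below is a sketch written from the rules as documented (first character A-Za-z, remaining characters from A-Za-z0-9._-, length 1..128); `isIdentifierSketch` is a hypothetical name and this is not the library's actual implementation.

```cpp
#include <string>

// Sketch of the documented identifier rules; not the library's code.
bool isIdentifierSketch(const std::string& s) noexcept {
    if (s.empty() || s.size() > 128) {
        return false;  // length must be 1..128
    }
    auto isLetter = [](char c) {
        return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z');
    };
    auto isDigit = [](char c) { return c >= '0' && c <= '9'; };

    if (!isLetter(s[0])) {
        return false;  // must start with a letter
    }
    for (char c : s) {
        if (!(isLetter(c) || isDigit(c) || c == '.' || c == '_' || c == '-')) {
            return false;  // character outside the allowed set
        }
    }
    return true;
}
```

Under these rules a dotted name such as "block.0.attn_q-proj" is accepted, while an empty string, a name starting with a digit, or a name containing a space is rejected.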

◆ isInferenceMode()

template<DeviceType TDeviceType, TensorDataType TPrecision>
bool Mila::Dnn::Component< TDeviceType, TPrecision >::isInferenceMode ( ) const
inline, export, noexcept

◆ isTrainingMode()

template<DeviceType TDeviceType, TensorDataType TPrecision>
bool Mila::Dnn::Component< TDeviceType, TPrecision >::isTrainingMode ( ) const
inline, export, noexcept

◆ loadParameter()

template<DeviceType TDeviceType, TensorDataType TPrecision>
virtual void Mila::Dnn::Component< TDeviceType, TPrecision >::loadParameter ( const std::string & name,
const Serialization::ITensorBlob & blob )
inline, export, virtual

Load a parameter from serialized tensor data.

Loads raw tensor bytes directly into an existing parameter tensor, handling precision conversion and device upload as needed.

The component validates that the blob's shape matches the parameter's expected shape, then delegates to the backend to perform:

  • Precision conversion (blob dtype → parameter dtype)
  • Device upload (CPU bytes → target device)
Parameters
name: Parameter name used to locate the target tensor.
blob: Serialized tensor metadata and raw bytes.
Exceptions
std::runtime_error: if component has no parameters to load.
std::runtime_error: if blob shape doesn't match parameter shape.

Reimplemented in Mila::Dnn::LayerNorm< TDeviceType, TPrecision >, Mila::Dnn::Linear< TDeviceType, TComputePrecision, TWeightQuant >, Mila::Dnn::Linear< TDeviceType, TPrecision >, Mila::Dnn::Linear< TDeviceType, TPrecision, TWeightQuant >, Mila::Dnn::Linear< TDeviceType, TPrecision, TWeightQuantization >, Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >, Mila::Dnn::Lpe< TDeviceType, dtype_t::INT32, TPrecision >, Mila::Dnn::RmsNorm< TDeviceType, TPrecision >, Mila::Dnn::TokenEmbedding< TDeviceType, TIndex, TPrecision >, and Mila::Dnn::TokenEmbedding< TDeviceType, dtype_t::INT32, TPrecision >.

◆ loadParameterFromBlob()

template<DeviceType TDeviceType, TensorDataType TPrecision>
template<TensorDataType TParameterPrecision, typename TMemoryResource>
void Mila::Dnn::Component< TDeviceType, TPrecision >::loadParameterFromBlob ( const std::string & param_name,
const Serialization::ITensorBlob & blob,
Tensor< TParameterPrecision, TMemoryResource > & target,
const shape_t & expected_shape )
inline, export, protected

Load a tensor blob into a parameter tensor with validation.

Validates dtype and shape match then copies blob data into the tensor. Intended for use in loadParameter() overrides.

Parameters
param_name: Parameter name (used in error messages).
blob: Source tensor blob from the model archive.
target: Destination tensor (must be initialized).
expected_shape: Expected tensor shape for validation.
Exceptions
std::invalid_argument: if dtype or shape mismatch.
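
The validate-then-copy sequence described for this helper can be sketched as follows. All types here (`SketchBlob`, `SketchTensor`, `SketchDType`, `SketchShape`) and the function `loadFromBlobSketch` are hypothetical host-memory stand-ins; the real ITensorBlob and Tensor interfaces, and any device upload, are not modeled.

```cpp
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-ins for the blob and tensor types.
using SketchShape = std::vector<std::size_t>;
enum class SketchDType { FP32, FP16 };

struct SketchBlob {
    SketchDType dtype;
    SketchShape shape;
    std::vector<float> data;
};

struct SketchTensor {
    SketchDType dtype;
    SketchShape shape;
    std::vector<float> data;
};

// Validate dtype and shape first, then copy; the target is untouched on failure.
void loadFromBlobSketch(const std::string& param_name, const SketchBlob& blob,
                        SketchTensor& target, const SketchShape& expected_shape) {
    if (blob.dtype != target.dtype) {
        throw std::invalid_argument(param_name + ": dtype mismatch");
    }
    if (blob.shape != expected_shape) {
        throw std::invalid_argument(param_name + ": shape mismatch");
    }
    target.shape = expected_shape;
    target.data = blob.data;
}
```

Because both checks run before any write, a rejected blob leaves the destination tensor in its previous state in this sketch.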

◆ onBuilding()

template<DeviceType TDeviceType, TensorDataType TPrecision>
virtual void Mila::Dnn::Component< TDeviceType, TPrecision >::onBuilding ( const BuildContext & config)
inline, export, protected, virtual

Hook invoked by build() to allocate component buffers.

Receives the stored BuildContext. Implementations must use config.allocationSeqLen() when sizing output buffers — this is the single call that makes Inference and Training allocate the correct buffer sizes automatically without per-component logic.

// Example: Linear component
shape_t out_shape =
{
    config.batchSize(),
    config.allocationSeqLen(),   // 1 for Inference, T for Training
    config_.getOutputFeatures()
};
output_ = std::make_unique<TensorType>( device, out_shape,
    this->getName() + ".output" );

The default implementation forwards to the legacy onBuilding( const shape_t& ) overload for backwards compatibility. New components should override this overload directly.

Note
Do not call build() or onBuilding() from within this hook.
Implementations should either succeed fully or leave no partial state, as a failed build() may be retried.
Parameters
config: Build-time configuration. Use config.allocationSeqLen() to obtain the correct output buffer sequence dimension.

Reimplemented in Mila::Dnn::Gelu< TDeviceType, TPrecision >, Mila::Dnn::GptBlock< TDeviceType, TPrecision >, Mila::Dnn::GptTransformer< TDeviceType, TPrecision >, Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >, Mila::Dnn::GroupedQueryAttention< TDeviceType, TPrecision, TKvPolicy >, Mila::Dnn::LayerNorm< TDeviceType, TPrecision >, Mila::Dnn::Linear< TDeviceType, TComputePrecision, TWeightQuant >, Mila::Dnn::Linear< TDeviceType, TPrecision >, Mila::Dnn::Linear< TDeviceType, TPrecision, TWeightQuant >, Mila::Dnn::Linear< TDeviceType, TPrecision, TWeightQuantization >, Mila::Dnn::LlamaBlock< TDeviceType, TPrecision, TWeightQuant, TKvPolicy >, Mila::Dnn::LlamaBlock< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >, Mila::Dnn::LlamaTransformer< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >, Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >, Mila::Dnn::Lpe< TDeviceType, dtype_t::INT32, TPrecision >, Mila::Dnn::MLP< TDeviceType, TPrecision >, Mila::Dnn::MultiHeadAttention< TDeviceType, TPrecision >, Mila::Dnn::Residual< TDeviceType, TPrecision >, Mila::Dnn::RmsNorm< TDeviceType, TPrecision >, Mila::Dnn::Rope< TDeviceType, TPrecision >, Mila::Dnn::Softmax< TDeviceType, TPrecision >, Mila::Dnn::Swiglu< TDeviceType, TPrecision >, Mila::Dnn::TokenEmbedding< TDeviceType, TIndex, TPrecision >, and Mila::Dnn::TokenEmbedding< TDeviceType, dtype_t::INT32, TPrecision >.


◆ onExecutionContextSet()

template<DeviceType TDeviceType, TensorDataType TPrecision>
virtual void Mila::Dnn::Component< TDeviceType, TPrecision >::onExecutionContextSet ( )
inline, export, protected, virtual

Lifecycle hook: Called immediately after ExecutionContext is set.

Override this to perform initialization that requires a valid ExecutionContext. At the time this is called, getExecutionContext() is guaranteed to return a valid context.

Common uses:

  • Composite components: Create and configure child components.
  • Device resource allocation: Query device capabilities.

Default implementation does nothing.

Exceptions
Any exception thrown will cause setExecutionContext() to fail and restore the component to a "context not set" state.

Reimplemented in Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >, Mila::Dnn::Gelu< TDeviceType, TPrecision >, Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >, Mila::Dnn::GroupedQueryAttention< TDeviceType, TPrecision, TKvPolicy >, Mila::Dnn::LayerNorm< TDeviceType, TPrecision >, Mila::Dnn::Linear< TDeviceType, TComputePrecision, TWeightQuant >, Mila::Dnn::Linear< TDeviceType, TPrecision >, Mila::Dnn::Linear< TDeviceType, TPrecision, TWeightQuant >, Mila::Dnn::Linear< TDeviceType, TPrecision, TWeightQuantization >, Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >, Mila::Dnn::Lpe< TDeviceType, dtype_t::INT32, TPrecision >, Mila::Dnn::MultiHeadAttention< TDeviceType, TPrecision >, Mila::Dnn::Residual< TDeviceType, TPrecision >, Mila::Dnn::RmsNorm< TDeviceType, TPrecision >, Mila::Dnn::Rope< TDeviceType, TPrecision >, Mila::Dnn::Softmax< TDeviceType, TPrecision >, Mila::Dnn::Swiglu< TDeviceType, TPrecision >, Mila::Dnn::TokenEmbedding< TDeviceType, TIndex, TPrecision >, and Mila::Dnn::TokenEmbedding< TDeviceType, dtype_t::INT32, TPrecision >.


◆ onTrainingModeChanging()

template<DeviceType TDeviceType, TensorDataType TPrecision>
virtual void Mila::Dnn::Component< TDeviceType, TPrecision >::onTrainingModeChanging ( TrainingMode mode)
inline, export, protected, virtual

Hook called before TrainingMode transitions.

Called by setTrainingMode() after validation and lock acquisition, before the internal state is updated. Derived classes override to respond to the transition — e.g. zeroing gradient buffers on transition to Eval, or re-enabling dropout on transition to Training.

The default implementation is a no-op.

Parameters
mode: The incoming TrainingMode.

Reimplemented in Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >, Mila::Dnn::Gelu< TDeviceType, TPrecision >, Mila::Dnn::GptBlock< TDeviceType, TPrecision >, Mila::Dnn::GptTransformer< TDeviceType, TPrecision >, Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >, Mila::Dnn::GroupedQueryAttention< TDeviceType, TPrecision, TKvPolicy >, Mila::Dnn::LayerNorm< TDeviceType, TPrecision >, Mila::Dnn::Linear< TDeviceType, TComputePrecision, TWeightQuant >, Mila::Dnn::Linear< TDeviceType, TPrecision >, Mila::Dnn::Linear< TDeviceType, TPrecision, TWeightQuant >, Mila::Dnn::Linear< TDeviceType, TPrecision, TWeightQuantization >, Mila::Dnn::LlamaBlock< TDeviceType, TPrecision, TWeightQuant, TKvPolicy >, Mila::Dnn::LlamaBlock< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >, Mila::Dnn::LlamaTransformer< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >, Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >, Mila::Dnn::Lpe< TDeviceType, dtype_t::INT32, TPrecision >, Mila::Dnn::MLP< TDeviceType, TPrecision >, Mila::Dnn::MultiHeadAttention< TDeviceType, TPrecision >, Mila::Dnn::Residual< TDeviceType, TPrecision >, Mila::Dnn::RmsNorm< TDeviceType, TPrecision >, Mila::Dnn::Rope< TDeviceType, TPrecision >, Mila::Dnn::Softmax< TDeviceType, TPrecision >, Mila::Dnn::Swiglu< TDeviceType, TPrecision >, Mila::Dnn::TokenEmbedding< TDeviceType, TIndex, TPrecision >, and Mila::Dnn::TokenEmbedding< TDeviceType, dtype_t::INT32, TPrecision >.


◆ parameterCount()

template<DeviceType TDeviceType, TensorDataType TPrecision>
virtual size_t Mila::Dnn::Component< TDeviceType, TPrecision >::parameterCount ( ) const
export, pure virtual

Return number of trainable parameters.

For leaf components this is the element count of owned parameter tensors. CompositeComponent and Network implementations should return the recursive aggregate across all children.
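
The leaf versus composite distinction can be sketched with a small tree. The types below (`SketchNode`, `SketchLeaf`, `SketchComposite`) are hypothetical and only illustrate the recursive aggregation; they do not reproduce Mila's class hierarchy, and element counts are passed in directly instead of being read from tensors.

```cpp
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical minimal hierarchy for illustrating parameterCount().
struct SketchNode {
    virtual ~SketchNode() = default;
    virtual std::size_t parameterCount() const = 0;
};

// Leaf: element count of its own parameter tensors.
struct SketchLeaf : SketchNode {
    std::size_t weights;
    std::size_t biases;
    SketchLeaf(std::size_t w, std::size_t b) : weights(w), biases(b) {}
    std::size_t parameterCount() const override { return weights + biases; }
};

// Composite: recursive sum over all children.
struct SketchComposite : SketchNode {
    std::vector<std::unique_ptr<SketchNode>> children;
    std::size_t parameterCount() const override {
        std::size_t total = 0;
        for (const auto& child : children) {
            total += child->parameterCount();
        }
        return total;
    }
};
```

A root holding a leaf with 12 + 3 elements and a nested composite holding a leaf with 6 + 2 elements reports 23 in this sketch, because the nested composite's total is folded into its parent's sum.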

Implemented in Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >, Mila::Dnn::Dropout< TDeviceType, TInput, TOutput >, Mila::Dnn::Dropout< DeviceType::Cpu, TInput, TOutput >, Mila::Dnn::Dropout< DeviceType::Cuda, TInput, TOutput >, Mila::Dnn::FusedComponent< TDeviceType, TPrecision >, Mila::Dnn::Gelu< TDeviceType, TPrecision >, Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >, Mila::Dnn::GroupedQueryAttention< TDeviceType, TPrecision, TKvPolicy >, Mila::Dnn::LayerNorm< TDeviceType, TPrecision >, Mila::Dnn::Linear< TDeviceType, TComputePrecision, TWeightQuant >, Mila::Dnn::Linear< TDeviceType, TPrecision >, Mila::Dnn::Linear< TDeviceType, TPrecision, TWeightQuant >, Mila::Dnn::Linear< TDeviceType, TPrecision, TWeightQuantization >, Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >, Mila::Dnn::Lpe< TDeviceType, dtype_t::INT32, TPrecision >, Mila::Dnn::MultiHeadAttention< TDeviceType, TPrecision >, Mila::Dnn::Residual< TDeviceType, TPrecision >, Mila::Dnn::RmsNorm< TDeviceType, TPrecision >, Mila::Dnn::Rope< TDeviceType, TPrecision >, Mila::Dnn::Softmax< TDeviceType, TPrecision >, Mila::Dnn::SoftmaxCrossEntropy< TDeviceType, TLogits, TTargets, TPrecision >, Mila::Dnn::Swiglu< TDeviceType, TPrecision >, Mila::Dnn::TokenEmbedding< TDeviceType, TIndex, TPrecision >, and Mila::Dnn::TokenEmbedding< TDeviceType, dtype_t::INT32, TPrecision >.

◆ save_()

template<DeviceType TDeviceType, TensorDataType TPrecision>
virtual void Mila::Dnn::Component< TDeviceType, TPrecision >::save_ ( ModelArchive & archive,
SerializationMode mode ) const
export, pure virtual

Implemented in Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >, Mila::Dnn::FusedComponent< TDeviceType, TPrecision >, Mila::Dnn::Gelu< TDeviceType, TPrecision >, Mila::Dnn::GptBlock< TDeviceType, TPrecision >, Mila::Dnn::GptTransformer< TDeviceType, TPrecision >, Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >, Mila::Dnn::GroupedQueryAttention< TDeviceType, TPrecision, TKvPolicy >, Mila::Dnn::LayerNorm< TDeviceType, TPrecision >, Mila::Dnn::Linear< TDeviceType, TComputePrecision, TWeightQuant >, Mila::Dnn::Linear< TDeviceType, TPrecision >, Mila::Dnn::Linear< TDeviceType, TPrecision, TWeightQuant >, Mila::Dnn::Linear< TDeviceType, TPrecision, TWeightQuantization >, Mila::Dnn::LlamaBlock< TDeviceType, TPrecision, TWeightQuant, TKvPolicy >, Mila::Dnn::LlamaBlock< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >, Mila::Dnn::LlamaTransformer< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >, Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >, Mila::Dnn::Lpe< TDeviceType, dtype_t::INT32, TPrecision >, Mila::Dnn::MLP< TDeviceType, TPrecision >, Mila::Dnn::MultiHeadAttention< TDeviceType, TPrecision >, Mila::Dnn::Network< TDeviceType, TPrecision >, Mila::Dnn::Residual< TDeviceType, TPrecision >, Mila::Dnn::RmsNorm< TDeviceType, TPrecision >, Mila::Dnn::Rope< TDeviceType, TPrecision >, Mila::Dnn::Softmax< TDeviceType, TPrecision >, Mila::Dnn::SoftmaxCrossEntropy< TDeviceType, TLogits, TTargets, TPrecision >, Mila::Dnn::Swiglu< TDeviceType, TPrecision >, Mila::Dnn::TokenEmbedding< TDeviceType, TIndex, TPrecision >, and Mila::Dnn::TokenEmbedding< TDeviceType, dtype_t::INT32, TPrecision >.

◆ setExecutionContext()

template<DeviceType TDeviceType, TensorDataType TPrecision>
void Mila::Dnn::Component< TDeviceType, TPrecision >::setExecutionContext ( IExecutionContext * context)
inline export protected

Set the execution context for this component.

Establishes the device and execution environment. Can only be called once — the execution context is immutable after setting.

Called by:

  • The component itself (standalone mode with owned context)
  • Parent composite when adding child (shared context mode)
  • ComponentFactory during deserialization

After setting the context, the onExecutionContextSet() hook is invoked to allow the component to perform context-dependent initialization.

Parameters
context  Non-owning pointer to execution context (must be non-null).
Exceptions
std::invalid_argument  if context is null.
std::runtime_error  if context has already been set.
std::invalid_argument  if context device type doesn't match TDeviceType.
std::runtime_error  if onExecutionContextSet() throws; context is restored to nullptr on failure.
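
The set-once contract and the rollback on hook failure described above can be sketched as follows. This is a minimal illustration under stated assumptions, not Mila's implementation: ComponentSketch, CpuContext, Probe, the deviceType() accessor, and hasContext() are hypothetical names introduced only for this example.

```cpp
#include <cassert>
#include <stdexcept>

// Hypothetical stand-ins for illustration; these are not Mila's definitions.
enum class DeviceType { Cpu, Cuda };

struct IExecutionContext {
    virtual ~IExecutionContext() = default;
    virtual DeviceType deviceType() const = 0;  // assumed accessor
};

struct CpuContext : IExecutionContext {
    DeviceType deviceType() const override { return DeviceType::Cpu; }
};

template <DeviceType TDeviceType>
class ComponentSketch {
public:
    virtual ~ComponentSketch() = default;
    bool hasContext() const noexcept { return context_ != nullptr; }

protected:
    void setExecutionContext(IExecutionContext* context) {
        if (context == nullptr)
            throw std::invalid_argument("context is null");
        if (context_ != nullptr)
            throw std::runtime_error("context has already been set");
        if (context->deviceType() != TDeviceType)
            throw std::invalid_argument("context device type mismatch");

        context_ = context;
        try {
            onExecutionContextSet();  // context-dependent initialization
        } catch (...) {
            context_ = nullptr;       // roll back so the component stays unset
            throw std::runtime_error("onExecutionContextSet() failed");
        }
    }

    virtual void onExecutionContextSet() {}

private:
    IExecutionContext* context_ = nullptr;  // non-owning
};

// Helper that exposes the protected setter and can force the hook to fail.
struct Probe : ComponentSketch<DeviceType::Cpu> {
    using ComponentSketch::setExecutionContext;
    bool failHook = false;
    void onExecutionContextSet() override {
        if (failHook) throw std::logic_error("hook failed");
    }
};

inline void setExecutionContextExample() {
    CpuContext ctx;
    Probe probe;
    probe.setExecutionContext(&ctx);
    assert(probe.hasContext());
}
```

Because the pointer is reset before the exception propagates, a component whose hook fails is left in the same state as one that never received a context.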

◆ setTrainingMode()

template<DeviceType TDeviceType, TensorDataType TPrecision>
void Mila::Dnn::Component< TDeviceType, TPrecision >::setTrainingMode ( TrainingMode mode)
inline export

Set the runtime behavioral mode for this Component.

Toggles between Training and Eval behavioral states at runtime. Only valid on Components built with RuntimeMode::Training — throws if called on a Component built with RuntimeMode::Inference.

State transitions

From      To        Effect
Training  Eval      Gradients off, dropout off, running stats
Eval      Training  Gradients on, dropout on, batch stats

Derived classes respond to the transition via the onTrainingModeChanging() hook, called before the state is updated.

Parameters
mode  TrainingMode::Normal or TrainingMode::Eval.
Exceptions
std::runtime_error  if the component is not built.
std::runtime_error  if built with RuntimeMode::Inference.
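
The ordering of checks, hook, and state update can be sketched as follows. This is a hedged illustration rather than Mila's code: ModeSketch, DropoutSketch, and markBuilt() are hypothetical, the enumerator names follow the Parameters entry above, and the initial mode of TrainingMode::Normal is an assumption.

```cpp
#include <cassert>
#include <stdexcept>

// Hypothetical stand-ins; enumerator names follow the documentation above.
enum class RuntimeMode { Training, Inference };
enum class TrainingMode { Normal, Eval };

class ModeSketch {
public:
    explicit ModeSketch(RuntimeMode runtime) : runtime_(runtime) {}
    virtual ~ModeSketch() = default;

    void markBuilt() noexcept { built_ = true; }  // hypothetical stand-in for build()

    void setTrainingMode(TrainingMode mode) {
        if (!built_)
            throw std::runtime_error("component is not built");
        if (runtime_ == RuntimeMode::Inference)
            throw std::runtime_error("built with RuntimeMode::Inference");
        onTrainingModeChanging(mode);  // hook runs first: the old mode is still visible
        mode_ = mode;
    }

    TrainingMode getTrainingMode() const noexcept { return mode_; }

protected:
    virtual void onTrainingModeChanging(TrainingMode /*next*/) {}

private:
    RuntimeMode runtime_;
    bool built_ = false;
    TrainingMode mode_ = TrainingMode::Normal;  // assumed initial state
};

// A derived class that reacts to the transition, as a dropout layer might.
struct DropoutSketch : ModeSketch {
    using ModeSketch::ModeSketch;
    bool dropoutActive = true;
    TrainingMode seenInHook = TrainingMode::Normal;

    void onTrainingModeChanging(TrainingMode next) override {
        seenInHook = getTrainingMode();                  // still the previous mode
        dropoutActive = (next == TrainingMode::Normal);  // off when entering Eval
    }
};

inline void setTrainingModeExample() {
    DropoutSketch dropout(RuntimeMode::Training);
    dropout.markBuilt();
    dropout.setTrainingMode(TrainingMode::Eval);
    assert(!dropout.dropoutActive);
    assert(dropout.getTrainingMode() == TrainingMode::Eval);
}
```

Calling the hook before the assignment lets a derived class compare the outgoing mode, read through getTrainingMode(), with the incoming mode passed as the argument.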

◆ synchronize()

template<DeviceType TDeviceType, TensorDataType TPrecision>
virtual void Mila::Dnn::Component< TDeviceType, TPrecision >::synchronize ( )
export pure virtual

◆ toString()

template<DeviceType TDeviceType, TensorDataType TPrecision>
virtual std::string Mila::Dnn::Component< TDeviceType, TPrecision >::toString ( ) const
export pure virtual

Produce a short, human-readable description of the component.

Implementations should keep output concise and avoid throwing.

Implemented in Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >, Mila::Dnn::Dropout< TDeviceType, TInput, TOutput >, Mila::Dnn::Dropout< DeviceType::Cpu, TInput, TOutput >, Mila::Dnn::Dropout< DeviceType::Cuda, TInput, TOutput >, Mila::Dnn::FusedComponent< TDeviceType, TPrecision >, Mila::Dnn::Gelu< TDeviceType, TPrecision >, Mila::Dnn::GptBlock< TDeviceType, TPrecision >, Mila::Dnn::GptTransformer< TDeviceType, TPrecision >, Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >, Mila::Dnn::GroupedQueryAttention< TDeviceType, TPrecision, TKvPolicy >, Mila::Dnn::LayerNorm< TDeviceType, TPrecision >, Mila::Dnn::Linear< TDeviceType, TComputePrecision, TWeightQuant >, Mila::Dnn::Linear< TDeviceType, TPrecision >, Mila::Dnn::Linear< TDeviceType, TPrecision, TWeightQuant >, Mila::Dnn::Linear< TDeviceType, TPrecision, TWeightQuantization >, Mila::Dnn::LlamaTransformer< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >, Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >, Mila::Dnn::Lpe< TDeviceType, dtype_t::INT32, TPrecision >, Mila::Dnn::MLP< TDeviceType, TPrecision >, Mila::Dnn::MultiHeadAttention< TDeviceType, TPrecision >, Mila::Dnn::Network< TDeviceType, TPrecision >, Mila::Dnn::Residual< TDeviceType, TPrecision >, Mila::Dnn::RmsNorm< TDeviceType, TPrecision >, Mila::Dnn::Rope< TDeviceType, TPrecision >, Mila::Dnn::Softmax< TDeviceType, TPrecision >, Mila::Dnn::SoftmaxCrossEntropy< TDeviceType, TLogits, TTargets, TPrecision >, Mila::Dnn::Swiglu< TDeviceType, TPrecision >, Mila::Dnn::TokenEmbedding< TDeviceType, TIndex, TPrecision >, and Mila::Dnn::TokenEmbedding< TDeviceType, dtype_t::INT32, TPrecision >.

◆ validateName()

template<DeviceType TDeviceType, TensorDataType TPrecision>
const std::string & Mila::Dnn::Component< TDeviceType, TPrecision >::validateName ( const std::string & name)
inline static export private

Validates the component name.

Enforces the identifier rules: a name must start with a letter, contain only letters, digits, '.', '_', and '-', and be between 1 and 128 characters long.

Parameters
name  Name to validate.
Exceptions
std::invalid_argument  if name is not a valid identifier.
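
The rules above translate into a short check. The following is a standalone sketch written from the stated rules, not a copy of the private member; it assumes "letter" means an ASCII letter as classified by std::isalpha in the default "C" locale, and the error messages are invented for the example.

```cpp
#include <cassert>
#include <cctype>
#include <stdexcept>
#include <string>

// Standalone sketch of the documented rules; not Mila's actual implementation.
inline const std::string& validateName(const std::string& name) {
    if (name.empty() || name.size() > 128)
        throw std::invalid_argument("name length must be between 1 and 128");

    // Cast to unsigned char: passing a negative char to <cctype> functions is undefined.
    if (!std::isalpha(static_cast<unsigned char>(name.front())))
        throw std::invalid_argument("name must start with a letter");

    for (unsigned char c : name) {
        if (!std::isalnum(c) && c != '.' && c != '_' && c != '-')
            throw std::invalid_argument("name contains an invalid character");
    }
    return name;  // returned by reference so it can be used in an initializer list
}

inline void validateNameExample() {
    const std::string ok = "decoder.layer_0-attn";
    assert(&validateName(ok) == &ok);  // valid names pass through unchanged

    bool rejected = false;
    try {
        validateName("0layer");  // starts with a digit
    } catch (const std::invalid_argument&) {
        rejected = true;
    }
    assert(rejected);
}
```

Returning the argument by const reference matches the documented signature and lets a constructor validate and store the name in a single member-initializer expression.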

◆ zeroGradients()

◆ operator<<

template<DeviceType TDeviceType, TensorDataType TPrecision>
std::ostream & operator<< ( std::ostream & os,
const Component< TDeviceType, TPrecision > & component )
friend

Stream output uses toString() to provide a human-readable description of the component.
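
A minimal sketch of this pattern, assuming a simplified non-template base: Describable and GeluSketch are hypothetical names, and the string returned by toString() is invented for the example rather than taken from Mila's output.

```cpp
#include <cassert>
#include <ostream>
#include <sstream>
#include <string>

// Hypothetical, non-template stand-in for Component.
class Describable {
public:
    virtual ~Describable() = default;
    virtual std::string toString() const = 0;

    // Friend defined inline: found by argument-dependent lookup and
    // dispatches to the most-derived toString().
    friend std::ostream& operator<<(std::ostream& os, const Describable& component) {
        return os << component.toString();
    }
};

struct GeluSketch : Describable {
    std::string toString() const override { return "Gelu(name=act)"; }  // invented format
};

inline void streamExample() {
    GeluSketch gelu;
    std::ostringstream out;
    out << "component: " << gelu;
    assert(out.str() == "component: Gelu(name=act)");
}
```

Since the operator only forwards to toString(), derived classes customize stream output by overriding that one virtual function.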


The documentation for this class was generated from the following file: