Mila 0.13.48
Deep Neural Network Library
Mila::Dnn::CompositeComponent< TDeviceType, TPrecision > Class Template Reference (export)

A component that contains and manages child components.

Public Types

using ComponentBase = Component<TDeviceType, TPrecision>
using ComponentPtr = std::shared_ptr<Component<TDeviceType, TPrecision>>

Public Member Functions

 CompositeComponent (CompositeComponent &&) noexcept=default
 CompositeComponent (const CompositeComponent &)=delete
 CompositeComponent (const std::string &name)
 Construct composite component with name.
virtual ~CompositeComponent ()=default
CompositeComponent & addComponent (ComponentPtr component)
 Add a pre-constructed child component (chainable).
size_t childCount () const noexcept
 Get the number of direct children.
void clearComponents ()
 Clear all child components.
ComponentPtr findComponent (const std::string &path) const
 Resolve a dot-separated component path within this composite.
ComponentPtr getComponent (const std::string &name) const
 Retrieve a direct child component by name.
const std::vector< ComponentPtr > & getComponents () const
 Get all child components in insertion order.
DeviceId getDeviceId () const override
 Get the compute device for this composite.
std::vector< ITensor * > getGradients () const override
 Get all parameter gradients from all children.
std::vector< ITensor * > getParameters () const override
 Get all parameters from all children.
bool hasChildren () const noexcept
 Check if this composite has any children.
bool hasComponent (const std::string &name) const
 Check if a named child component exists.
CompositeComponent & operator= (CompositeComponent &&) noexcept=default
CompositeComponent & operator= (const CompositeComponent &)=delete
size_t parameterCount () const override
 Count parameters across all children.
bool removeComponent (const std::string &name)
 Remove a child component by name.
void synchronize () override
 Synchronize all child components.
std::string toString () const override
 Generate a human-readable description.
ComponentPtr tryFindComponent (const std::string &path) const
 Try to resolve a dot-separated component path within this composite.
Public Member Functions inherited from Mila::Dnn::Component< TDeviceType, TPrecision >
 Component (const std::string &name)
 Construct component with required name identifier.
virtual ~Component ()=default
virtual void build (const BuildContext &context) final
 Build the component with the provided BuildContext (canonical overload).
virtual MemoryStats getMemoryStats () const =0
 Return the current memory allocation breakdown for this component.
const std::string getName () const
 Get the component's name identifier.
virtual std::vector< std::string > getParameterNames () const
 List all available parameter names for this component.
RuntimeMode getRuntimeMode () const noexcept
 Convenience accessor — true if currently in Eval mode.
TrainingMode getTrainingMode () const noexcept
 The current runtime behavioral mode of this Component.
virtual const ComponentType getType () const =0
 Get the component type identifier.
virtual bool isBuilt () const final
 Returns true if build() has completed successfully.
bool isInferenceMode () const noexcept
bool isTrainingMode () const noexcept
virtual void loadParameter (const std::string &name, const Serialization::ITensorBlob &blob)
 Load a parameter from serialized tensor data.
void setTrainingMode (TrainingMode mode)
 Set the runtime behavioral mode for this Component.
virtual void zeroGradients ()
 Clear all model-owned gradients for this component.

Protected Member Functions

template<typename TComponent>
std::shared_ptr< TComponent > getComponentAs (const std::string &name) const
 Retrieve a typed child component by name.
void onExecutionContextSet () override
 Hook invoked after ExecutionContext is set.
void onTrainingModeChanging (TrainingMode training_mode) override
 Hook invoked when training mode is about to change.
virtual void optimize ()
 Virtual hook for graph optimization after construction.
void save_ (ModelArchive &archive, SerializationMode mode) const override
 Save all child components recursively.
Protected Member Functions inherited from Mila::Dnn::Component< TDeviceType, TPrecision >
IExecutionContext * getExecutionContext () const
 Get the shared execution context.
bool hasExecutionContext () const noexcept
 Check if execution context has been set.
template<TensorDataType TParameterPrecision, typename TMemoryResource>
void loadParameterFromBlob (const std::string &param_name, const Serialization::ITensorBlob &blob, Tensor< TParameterPrecision, TMemoryResource > &target, const shape_t &expected_shape)
 Load a tensor blob into a parameter tensor with validation.
virtual void onBuilding (const BuildContext &config)
 Hook invoked by build() to allocate component buffers.
void setExecutionContext (IExecutionContext *context)
 Set the execution context for this component.

Private Attributes

std::unordered_map< std::string, ComponentPtr, std::hash< std::string_view >, std::equal_to<> > child_component_map_
 Lookup map from child name to component pointer.
std::vector< ComponentPtr > child_components_
 Child components in insertion order.

Friends

class ComponentFactory

Additional Inherited Members

Static Public Member Functions inherited from Mila::Dnn::Component< TDeviceType, TPrecision >
static constexpr DeviceType getDeviceType ()
 Compile-time device type for this component instance.
static constexpr TensorDataType getPrecision () noexcept
 Compile-time tensor precision for this component instance.
Protected Attributes inherited from Mila::Dnn::Component< TDeviceType, TPrecision >
BuildContext build_context_ { shape_t{ 1 }, RuntimeMode::Training }
 The BuildContext stored at build time.

Detailed Description

template<DeviceType TDeviceType, TensorDataType TPrecision>
class Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >

A component that contains and manages child components.

CompositeComponent is a device-parameterized abstract container that manages child component lifecycle, aggregates operations (parameters, gradients, training mode), and provides context propagation. Derived types implement execution semantics (forward/backward) and architecture definition (createGraph()).

Architecture Philosophy:

  • Context-independent graph creation: Architecture defined without device knowledge
  • Three-phase lifecycle: Graph creation -> Context binding -> Shape binding
  • Automatic context propagation: Base class propagates context to all children
  • Component-owns-name: Children manage their own identity via getName()
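
The three phases listed above can be pictured with a small stand-alone sketch. It does not use Mila's types: `SketchComposite`, `Context`, `bindContext()` and the `batch_size` argument are hypothetical names chosen for illustration, and the real build step takes a BuildContext rather than a bare size.

```cpp
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for an execution context; not a Mila type.
struct Context { int device_id; };

class SketchComposite {
public:
    // Phase 1: graph creation. No device knowledge is needed here.
    explicit SketchComposite(std::string name) : name_(std::move(name)) {}

    SketchComposite& add(std::shared_ptr<SketchComposite> child) {
        children_.push_back(std::move(child));
        return *this;
    }

    // Phase 2: context binding. The parent forwards the context to children.
    void bindContext(const Context* ctx) {
        context_ = ctx;
        for (auto& child : children_) child->bindContext(ctx);
    }

    // Phase 3: shape binding. Requires a context to have been bound first.
    void build(std::size_t batch_size) {
        if (!context_) throw std::runtime_error(name_ + ": no context bound");
        batch_size_ = batch_size;
        built_ = true;
        for (auto& child : children_) child->build(batch_size);
    }

    bool isBuilt() const noexcept { return built_; }
    bool hasContext() const noexcept { return context_ != nullptr; }

private:
    std::string name_;
    const Context* context_ = nullptr;
    std::size_t batch_size_ = 0;
    bool built_ = false;
    std::vector<std::shared_ptr<SketchComposite>> children_;
};
```

In this sketch, calling `build()` before `bindContext()` throws. That illustrates the ordering of the phases only; it is not a statement about Mila's exact error behaviour.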

Member Typedef Documentation

◆ ComponentBase

template<DeviceType TDeviceType, TensorDataType TPrecision>
using Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >::ComponentBase = Component<TDeviceType, TPrecision>

◆ ComponentPtr

template<DeviceType TDeviceType, TensorDataType TPrecision>
using Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >::ComponentPtr = std::shared_ptr<Component<TDeviceType, TPrecision>>

Constructor & Destructor Documentation

◆ CompositeComponent() [1/3]

template<DeviceType TDeviceType, TensorDataType TPrecision>
Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >::CompositeComponent ( const std::string & name)
inline explicit

Construct composite component with name.

The composite component is named for identification in hierarchical structures. Derived classes should call createGraph() from their constructor to define the architecture graph (context-independent).

All child components added via addComponent() will receive ExecutionContext automatically when the composite receives its context (via onExecutionContextSet).

Parameters
name: Component name identifier (mandatory)
Exceptions
std::invalid_argument: if name is not a valid identifier

◆ ~CompositeComponent()

template<DeviceType TDeviceType, TensorDataType TPrecision>
virtual Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >::~CompositeComponent ( )
virtual default

◆ CompositeComponent() [2/3]

template<DeviceType TDeviceType, TensorDataType TPrecision>
Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >::CompositeComponent ( const CompositeComponent< TDeviceType, TPrecision > & )
delete

◆ CompositeComponent() [3/3]

template<DeviceType TDeviceType, TensorDataType TPrecision>
Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >::CompositeComponent ( CompositeComponent< TDeviceType, TPrecision > && )
default noexcept

Member Function Documentation

◆ addComponent()

template<DeviceType TDeviceType, TensorDataType TPrecision>
CompositeComponent & Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >::addComponent ( ComponentPtr component)
inline

Add a pre-constructed child component (chainable).

Registers a component that was constructed externally (typically by the derived class in its createGraph() method). The component's getName() is used as the lookup key.

Components are expected to be created in shared mode (no ExecutionContext). Context will be automatically propagated to all children when this composite receives its context via onExecutionContextSet().

Usage pattern in derived class:

    void createGraph()
    {
        auto fc1 = std::make_shared<LinearType>(config, std::nullopt);
        fc1->setName(this->getName() + ".fc1");
        this->addComponent(fc1);
        // ... more components
    }
Parameters
component: Shared pointer to the constructed component
Returns
Reference to *this for method chaining
Exceptions
std::runtime_error: if called after build()
std::invalid_argument: if component is null
std::invalid_argument: if component name already exists
std::invalid_argument: if component already has its own ExecutionContext
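
The chaining and validation rules above can be sketched without Mila's types. `Leaf`, `Container` and `add()` are hypothetical names, and the ExecutionContext check is left out because it depends on Mila internals that this page does not show.

```cpp
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical stand-in for a named component; not a Mila type.
struct Leaf {
    explicit Leaf(std::string n) : name(std::move(n)) {}
    std::string name;
};

class Container {
public:
    // Returns *this so calls can be chained: c.add(a).add(b).
    Container& add(std::shared_ptr<Leaf> child) {
        if (built_) throw std::runtime_error("cannot add after build()");
        if (!child) throw std::invalid_argument("child is null");
        if (by_name_.count(child->name) != 0)
            throw std::invalid_argument("duplicate child name: " + child->name);
        by_name_.emplace(child->name, child);  // name-based lookup
        ordered_.push_back(std::move(child));  // insertion order
        return *this;
    }

    void build() { built_ = true; }
    std::size_t childCount() const noexcept { return ordered_.size(); }

private:
    bool built_ = false;
    std::unordered_map<std::string, std::shared_ptr<Leaf>> by_name_;
    std::vector<std::shared_ptr<Leaf>> ordered_;
};
```

All checks run before either container is modified, so a rejected call leaves the sketch's state unchanged.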

◆ childCount()

template<DeviceType TDeviceType, TensorDataType TPrecision>
size_t Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >::childCount ( ) const
inline noexcept

Get the number of direct children.

Returns
Number of child components

◆ clearComponents()

template<DeviceType TDeviceType, TensorDataType TPrecision>
void Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >::clearComponents ( )
inline

Clear all child components.

Exceptions
std::runtime_error: if called after build()

◆ findComponent()

template<DeviceType TDeviceType, TensorDataType TPrecision>
ComponentPtr Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >::findComponent ( const std::string & path) const
inline

Resolve a dot-separated component path within this composite.

Supports both relative paths ("lenc.wte") and absolute paths ("gpt2.lenc.wte"). If path starts with this component's name, strips it before searching.
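
A minimal sketch of this kind of path resolution follows. It assumes children are keyed by their short segment name (for example "lenc"); `Node` and `resolvePath()` are hypothetical, and this variant returns nullptr on a miss instead of reporting an error.

```cpp
#include <cstddef>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>

// Hypothetical tree node; not a Mila type.
struct Node {
    explicit Node(std::string n) : name(std::move(n)) {}
    std::string name;
    std::unordered_map<std::string, std::shared_ptr<Node>> children;
};

// Resolve a dot-separated path below `root`. Returns nullptr on any miss.
// If the path begins with "<root name>.", that prefix is stripped first,
// so relative and absolute spellings resolve to the same node.
inline std::shared_ptr<Node> resolvePath(const std::shared_ptr<Node>& root,
                                         std::string path) {
    if (!root || path.empty()) return nullptr;

    const std::string prefix = root->name + ".";
    if (path.compare(0, prefix.size(), prefix) == 0) {
        path.erase(0, prefix.size());
    }

    std::shared_ptr<Node> current = root;
    std::size_t start = 0;
    while (start <= path.size()) {
        const std::size_t dot = path.find('.', start);
        const std::size_t end = (dot == std::string::npos) ? path.size() : dot;
        const std::string segment = path.substr(start, end - start);

        const auto it = current->children.find(segment);
        if (it == current->children.end()) return nullptr;  // unknown segment
        current = it->second;

        if (dot == std::string::npos) break;
        start = dot + 1;
    }
    return current;
}
```

One consequence of prefix stripping is that a direct child sharing the root's name becomes ambiguous. The sketch always treats the leading segment as the root in that case.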

◆ getComponent()

template<DeviceType TDeviceType, TensorDataType TPrecision>
ComponentPtr Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >::getComponent ( const std::string & name) const
inline

Retrieve a direct child component by name.

This performs direct (non-recursive) lookup of immediate children only. Use findComponent() to resolve dot-separated paths across the subgraph.

Parameters
name: Name of the direct child component
Returns
Shared pointer to the component
Exceptions
std::out_of_range: if the direct child is not found

◆ getComponentAs()

template<DeviceType TDeviceType, TensorDataType TPrecision>
template<typename TComponent>
std::shared_ptr< TComponent > Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >::getComponentAs ( const std::string & name) const
inline protected

Retrieve a typed child component by name.

Helper method for derived composites (like MLP) that need to cache typed pointers to children in their onBuilding() hook. Performs dynamic_pointer_cast and validates the cast succeeded.

Note: This resolves direct children only. For full-path resolution use findComponent() on the appropriate root composite or Network.

Template Parameters
TComponent: Expected component type
Parameters
name: Name of the direct child component
Returns
Shared pointer to component with correct type
Exceptions
std::out_of_range: if component name not found
std::runtime_error: if dynamic cast fails (type mismatch)
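
The lookup-then-downcast pattern described above can be sketched as follows. `Base`, `Linear`, `Gelu`, `Registry` and `getAs()` are hypothetical names for illustration; only `std::dynamic_pointer_cast` and the two exception types are taken from the description.

```cpp
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>

// Hypothetical component hierarchy; not Mila types.
struct Base { virtual ~Base() = default; };
struct Linear : Base { int out_features = 8; };
struct Gelu : Base {};

using Registry = std::unordered_map<std::string, std::shared_ptr<Base>>;

// Look up `name` and downcast to T, reporting both failure modes separately.
template <typename T>
std::shared_ptr<T> getAs(const Registry& children, const std::string& name) {
    const auto it = children.find(name);
    if (it == children.end()) {
        throw std::out_of_range("no child named '" + name + "'");
    }
    auto typed = std::dynamic_pointer_cast<T>(it->second);
    if (!typed) {
        throw std::runtime_error("child '" + name + "' has an unexpected type");
    }
    return typed;
}
```

Keeping the two failures as distinct exception types lets a caller tell a misspelled name apart from a wrong type assumption.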

◆ getComponents()

template<DeviceType TDeviceType, TensorDataType TPrecision>
const std::vector< ComponentPtr > & Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >::getComponents ( ) const
inline

Get all child components in insertion order.

Returns
Vector of child component pointers

◆ getDeviceId()

template<DeviceType TDeviceType, TensorDataType TPrecision>
DeviceId Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >::getDeviceId ( ) const
inline override virtual

Get the compute device for this composite.

Returns the device from the shared execution context.

Returns
DeviceId for this composite and its children

Implements Mila::Dnn::Component< TDeviceType, TPrecision >.

Reimplemented in Mila::Dnn::Network< TDeviceType, TPrecision >.

◆ getGradients()

template<DeviceType TDeviceType, TensorDataType TPrecision>
std::vector< ITensor * > Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >::getGradients ( ) const
inline override virtual

Get all parameter gradients from all children.

Returns
Vector of non-owning pointers to gradient tensors
Exceptions
std::runtime_error: if called before build() or not in training mode

Implements Mila::Dnn::Component< TDeviceType, TPrecision >.

◆ getParameters()

template<DeviceType TDeviceType, TensorDataType TPrecision>
std::vector< ITensor * > Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >::getParameters ( ) const
inline override virtual

Get all parameters from all children.

Returns
Vector of non-owning pointers to parameter tensors
Exceptions
std::runtime_error: if called before build()
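
As an illustration of aggregating non-owning pointers across children, the sketch below flattens a small tree in child order. `Tensor`, `Module`, `Dense` and `Sequence` are hypothetical stand-ins, and the build-state check from the real method is omitted.

```cpp
#include <memory>
#include <vector>

struct Tensor { float value = 0.0f; };  // hypothetical; not Mila's ITensor

// Hypothetical component interface; not a Mila type.
struct Module {
    virtual ~Module() = default;
    virtual std::vector<Tensor*> parameters() = 0;
};

struct Dense : Module {
    Tensor weight, bias;
    std::vector<Tensor*> parameters() override { return {&weight, &bias}; }
};

struct Sequence : Module {
    std::vector<std::shared_ptr<Module>> children;

    // Concatenate each child's parameters, preserving child order.
    std::vector<Tensor*> parameters() override {
        std::vector<Tensor*> all;
        for (auto& child : children) {
            auto part = child->parameters();
            all.insert(all.end(), part.begin(), part.end());
        }
        return all;
    }
};
```

The returned pointers remain valid only while the owning children are alive, which is the usual caveat for non-owning views like this.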

Implements Mila::Dnn::Component< TDeviceType, TPrecision >.

◆ hasChildren()

template<DeviceType TDeviceType, TensorDataType TPrecision>
bool Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >::hasChildren ( ) const
inline noexcept

Check if this composite has any children.

Returns
true if at least one child component exists

◆ hasComponent()

template<DeviceType TDeviceType, TensorDataType TPrecision>
bool Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >::hasComponent ( const std::string & name) const
inline

Check if a named child component exists.

Parameters
name: Name to query
Returns
true if a child with this name exists

◆ onExecutionContextSet()

template<DeviceType TDeviceType, TensorDataType TPrecision>
void Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >::onExecutionContextSet ( )
inline override protected virtual

Hook invoked after ExecutionContext is set.

Propagates the execution context to all child components that don't already have one. This enables the pattern where composites define their architecture graph in the constructor (context-independent) and context is bound later when available.

Called by Component::setExecutionContext() after the context is registered. Automatically invoked for both standalone mode (component creates own context) and shared mode (parent provides context).

Override this in derived classes if additional context-dependent initialization is required beyond context propagation to children.
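
The propagation rule (forward the context only to children that lack one) can be sketched with hypothetical types. `ExecContext`, `Part`, `Group`, `setContext()` and `onContextSet()` are illustrative names and do not reproduce Mila's signatures.

```cpp
#include <memory>
#include <utility>
#include <vector>

// Hypothetical stand-in for an execution context; not a Mila type.
struct ExecContext { int device_index = 0; };

class Part {
public:
    virtual ~Part() = default;

    void setContext(ExecContext* ctx) {
        context_ = ctx;
        onContextSet();  // hook runs after the context is stored
    }
    ExecContext* context() const noexcept { return context_; }

protected:
    virtual void onContextSet() {}

private:
    ExecContext* context_ = nullptr;
};

class Group : public Part {
public:
    void add(std::shared_ptr<Part> child) { children_.push_back(std::move(child)); }

protected:
    // Forward the context only to children that do not have one yet.
    void onContextSet() override {
        for (auto& child : children_) {
            if (child->context() == nullptr) child->setContext(context());
        }
    }

private:
    std::vector<std::shared_ptr<Part>> children_;
};
```

Because a nested `Group` runs the same hook when it receives the context, propagation reaches the whole subtree without any explicit recursion in the caller.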

Reimplemented from Mila::Dnn::Component< TDeviceType, TPrecision >.

◆ onTrainingModeChanging()

template<DeviceType TDeviceType, TensorDataType TPrecision>
void Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >::onTrainingModeChanging ( TrainingMode training_mode)
inline override protected virtual

Hook invoked when training mode is about to change.

Propagates the new mode to all child components. The hook runs with the Component's training mutex held, so it MUST NOT call setTrainingMode() on this component.

Parameters
training_mode: The training mode being switched to

Reimplemented from Mila::Dnn::Component< TDeviceType, TPrecision >.

Reimplemented in Mila::Dnn::GptBlock< TDeviceType, TPrecision >, Mila::Dnn::GptTransformer< TDeviceType, TPrecision >, Mila::Dnn::LlamaBlock< TDeviceType, TPrecision, TWeightQuant, TKvPolicy >, Mila::Dnn::LlamaBlock< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >, Mila::Dnn::LlamaTransformer< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >, and Mila::Dnn::MLP< TDeviceType, TPrecision >.
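
To illustrate why the hook must not re-enter the setter, here is a sketch that assumes a plain, non-recursive `std::mutex` guards the mode. `Mode`, `Unit`, `Stack`, `setMode()` and `onModeChanging()` are hypothetical, and Mila's actual locking scheme may differ.

```cpp
#include <memory>
#include <mutex>
#include <utility>
#include <vector>

enum class Mode { Training, Eval };  // hypothetical; not Mila's TrainingMode

class Unit {
public:
    virtual ~Unit() = default;

    void setMode(Mode mode) {
        std::lock_guard<std::mutex> lock(mode_mutex_);
        if (mode == mode_) return;
        onModeChanging(mode);  // runs while mode_mutex_ is held
        mode_ = mode;
    }
    Mode mode() const {
        std::lock_guard<std::mutex> lock(mode_mutex_);
        return mode_;
    }

protected:
    // Must not call setMode() on this object: std::mutex is not recursive.
    virtual void onModeChanging(Mode) {}

private:
    mutable std::mutex mode_mutex_;
    Mode mode_ = Mode::Training;
};

class Stack : public Unit {
public:
    void add(std::shared_ptr<Unit> child) { children_.push_back(std::move(child)); }

protected:
    // Each child locks its own mutex, so forwarding to children is safe.
    void onModeChanging(Mode next) override {
        for (auto& child : children_) child->setMode(next);
    }

private:
    std::vector<std::shared_ptr<Unit>> children_;
};
```

In the sketch, calling `setMode()` or `mode()` on the same object from inside the hook would try to lock a mutex the thread already owns, which is undefined behaviour for `std::mutex`.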

◆ operator=() [1/2]

template<DeviceType TDeviceType, TensorDataType TPrecision>
CompositeComponent & Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >::operator= ( CompositeComponent< TDeviceType, TPrecision > && )
default noexcept

◆ operator=() [2/2]

template<DeviceType TDeviceType, TensorDataType TPrecision>
CompositeComponent & Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >::operator= ( const CompositeComponent< TDeviceType, TPrecision > & )
delete

◆ optimize()

template<DeviceType TDeviceType, TensorDataType TPrecision>
virtual void Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >::optimize ( )
inline protected virtual

Virtual hook for graph optimization after construction.

Called automatically after createGraph() completes. Derived classes can override to perform fusion, pruning, or other optimizations.

Default implementation does nothing. Override to perform architecture-specific graph optimizations.

◆ parameterCount()

template<DeviceType TDeviceType, TensorDataType TPrecision>
size_t Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >::parameterCount ( ) const
inline override virtual

Count parameters across all children.

Returns
Total number of trainable parameters
Exceptions
std::runtime_error: if called before build()

Implements Mila::Dnn::Component< TDeviceType, TPrecision >.

◆ removeComponent()

template<DeviceType TDeviceType, TensorDataType TPrecision>
bool Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >::removeComponent ( const std::string & name)
inline

Remove a child component by name.

Parameters
name: Name of the component to remove
Returns
true if removed, false if not found
Exceptions
std::runtime_error: if called after build()

◆ save_()

template<DeviceType TDeviceType, TensorDataType TPrecision>
void Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >::save_ ( ModelArchive & archive, SerializationMode mode ) const
inline override protected virtual

Save all child components recursively.

Follows the component serialization contract:

  • Writes type, version, and configuration metadata
  • Recursively saves all children with scoped namespaces
  • Each child's save_() handles its own state
Parameters
archive: Archive to write to
mode: What to save (Checkpoint, WeightsOnly, Architecture)
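
One way to picture "recursively saves all children with scoped namespaces" is a flat key/value archive with a scope prefix. The `Archive` and `Item` types below, along with `pushScope()`, `popScope()` and the "/" separator, are assumptions made for this sketch and are not ModelArchive's real interface.

```cpp
#include <map>
#include <memory>
#include <string>
#include <vector>

// Hypothetical archive: a flat key/value store with a scope prefix.
class Archive {
public:
    void write(const std::string& key, const std::string& value) {
        entries_[scope_ + key] = value;
    }
    void pushScope(const std::string& name) {
        saved_.push_back(scope_);
        scope_ += name + "/";
    }
    void popScope() {
        scope_ = saved_.back();
        saved_.pop_back();
    }
    const std::map<std::string, std::string>& entries() const { return entries_; }

private:
    std::string scope_;
    std::vector<std::string> saved_;
    std::map<std::string, std::string> entries_;
};

// Hypothetical component tree; not a Mila type.
struct Item {
    std::string name;
    std::string type;
    std::vector<std::shared_ptr<Item>> children;

    void save(Archive& archive) const {
        archive.write("type", type);          // this node's own metadata
        for (const auto& child : children) {  // then each child, in order
            archive.pushScope(child->name);
            child->save(archive);
            archive.popScope();
        }
    }
};
```

Scoping keeps keys from sibling children apart even when they write the same field names, which is what lets each child handle its own state independently.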

Implements Mila::Dnn::Component< TDeviceType, TPrecision >.

Reimplemented in Mila::Dnn::GptBlock< TDeviceType, TPrecision >, Mila::Dnn::GptTransformer< TDeviceType, TPrecision >, Mila::Dnn::LlamaBlock< TDeviceType, TPrecision, TWeightQuant, TKvPolicy >, Mila::Dnn::LlamaBlock< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >, Mila::Dnn::LlamaTransformer< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >, Mila::Dnn::MLP< TDeviceType, TPrecision >, and Mila::Dnn::Network< TDeviceType, TPrecision >.

◆ synchronize()

template<DeviceType TDeviceType, TensorDataType TPrecision>
void Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >::synchronize ( )
inline override virtual

Synchronize all child components.

Waits for outstanding device operations on all children.

Implements Mila::Dnn::Component< TDeviceType, TPrecision >.

Reimplemented in Mila::Dnn::Network< TDeviceType, TPrecision >.

◆ toString()

template<DeviceType TDeviceType, TensorDataType TPrecision>
std::string Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >::toString ( ) const
inline override virtual

Generate a human-readable description.

Returns
String representation showing children

Implements Mila::Dnn::Component< TDeviceType, TPrecision >.

Reimplemented in Mila::Dnn::GptBlock< TDeviceType, TPrecision >, Mila::Dnn::GptTransformer< TDeviceType, TPrecision >, Mila::Dnn::LlamaTransformer< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >, Mila::Dnn::MLP< TDeviceType, TPrecision >, and Mila::Dnn::Network< TDeviceType, TPrecision >.

◆ tryFindComponent()

template<DeviceType TDeviceType, TensorDataType TPrecision>
ComponentPtr Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >::tryFindComponent ( const std::string & path) const
inline

Try to resolve a dot-separated component path within this composite.

Non-throwing version: returns nullptr if any segment is not found or if a path segment attempts to traverse into a non-composite leaf.

Example: auto ptr = composite->tryFindComponent("encoder.mlp.fc1");

Parameters
path: Dot-separated path (e.g. "layer_0.mlp.fc_1")
Returns
ComponentPtr or nullptr when not found / invalid traversal

◆ ComponentFactory

template<DeviceType TDeviceType, TensorDataType TPrecision>
friend class ComponentFactory
friend

Member Data Documentation

◆ child_component_map_

template<DeviceType TDeviceType, TensorDataType TPrecision>
std::unordered_map<std::string, ComponentPtr, std::hash<std::string_view>, std::equal_to<> > Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >::child_component_map_
private

Lookup map from child name to component pointer.

Provides O(1) name-based lookup for children. Keys are component names (obtained via component->getName()) and must be unique within the composite.

Note: insertion order is preserved by child_components_; this unordered_map does not guarantee ordering but provides fast lookup.
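
The pairing of a hash map for lookup with a vector for ordering can be sketched as below. `Child` and `OrderedRegistry` are hypothetical, and the sketch uses a plain `std::string`-keyed map rather than the hasher and comparator shown in the declaration above.

```cpp
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

struct Child { std::string name; };  // hypothetical; not a Mila type

class OrderedRegistry {
public:
    // Rejects null pointers and duplicate names.
    bool insert(std::shared_ptr<Child> child) {
        if (!child || index_.count(child->name) != 0) return false;
        index_.emplace(child->name, child);
        order_.push_back(std::move(child));
        return true;
    }

    // Average O(1) lookup through the hash map.
    std::shared_ptr<Child> find(const std::string& name) const {
        const auto it = index_.find(name);
        if (it == index_.end()) return nullptr;
        return it->second;
    }

    // Removal must update both containers; the vector erase is O(n).
    bool remove(const std::string& name) {
        const auto it = index_.find(name);
        if (it == index_.end()) return false;
        order_.erase(std::remove(order_.begin(), order_.end(), it->second),
                     order_.end());
        index_.erase(it);
        return true;
    }

    const std::vector<std::shared_ptr<Child>>& inOrder() const { return order_; }

private:
    std::unordered_map<std::string, std::shared_ptr<Child>> index_;
    std::vector<std::shared_ptr<Child>> order_;
};
```

The trade-off is that every mutation has to touch both containers, so keeping them in sync becomes an invariant of the class.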

◆ child_components_

template<DeviceType TDeviceType, TensorDataType TPrecision>
std::vector<ComponentPtr> Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >::child_components_
private

Child components in insertion order.

This vector holds shared ownership of the composite's direct children and preserves the order in which children were added. The insertion order is used for build sequencing, ordered iteration, and serialization ordering.

Lifecycle invariants:

  • Children are constructed in createGraph() (called from derived constructor)
  • Children are registered via addComponent() before context is available
  • Context is propagated to children via onExecutionContextSet() hook
  • Children are built by parent's onBuilding() via template method pattern

Threading: access is not internally synchronized. Mutations and lifecycle operations must be externally serialized.


The documentation for this class was generated from the following file: