Mila
Deep Neural Network Library
Mila::Dnn::CompositeModule< TDeviceType, TDataType > Class Template Reference (export)

A module class that can contain and manage child modules.


Public Types

using ModuleBase = Module< TDeviceType, TDataType, TDataType >
 Base class type for the module.
 
using MR = std::conditional_t< TDeviceType==DeviceType::Cuda, CudaMemoryResource, HostMemoryResource >
 Memory resource type based on device type.
 
- Public Types inherited from Mila::Dnn::Module< TDeviceType, TInput, TOutput >
using MR = std::conditional_t< TDeviceType==DeviceType::Cuda, CudaMemoryResource, CpuMemoryResource >
 

Public Member Functions

 CompositeModule ()
 Default constructor.
 
 CompositeModule (const std::string &device_name, const ComponentConfig &config)
 Constructor with device name.
 
 CompositeModule (std::shared_ptr< DeviceContext > context, const ComponentConfig &config)
 Constructor with device context.
 
virtual ~CompositeModule ()=default
 Virtual destructor.
 
CompositeModule & addModule (const std::string &name, std::shared_ptr< Module< TDeviceType, TDataType, TDataType > > module)
 Add a named child module to this module.
 
CompositeModule & addModule (std::shared_ptr< Module< TDeviceType, TDataType, TDataType > > module)
 Add an unnamed child module to this module.
 
std::shared_ptr< Module< TDeviceType, TDataType, TDataType > > getModule (const std::string &name) const
 Get a specific sub-module by name.
 
const std::vector< std::shared_ptr< Module< TDeviceType, TDataType, TDataType > > > & getModules () const
 Get all sub-modules contained in this module.
 
const std::unordered_map< std::string, std::shared_ptr< Module< TDeviceType, TDataType, TDataType > > > & getNamedModules () const
 Get all named sub-modules contained in this module.
 
bool hasModule (const std::string &name) const
 Check if a sub-module with the given name exists.
 
void load (ModelArchive &archive) override
 Default load implementation for container modules.
 
size_t parameterCount () const override
 Count the total number of parameters in this module and all sub-modules.
 
bool removeModule (const std::string &name)
 Remove a sub-module by name.
 
bool replaceModule (const std::string &name, std::shared_ptr< Module< TDeviceType, TDataType, TDataType > > module)
 Replace an existing sub-module with a new one.
 
void save (ModelArchive &archive) const override
 Default save implementation for container modules.
 
void setTraining (bool is_training) override
 Set the training mode for this module and all its sub-modules.
 
std::string toString () const override
 Default toString implementation for container modules.
 
- Public Member Functions inherited from Mila::Dnn::Module< TDeviceType, TInput, TOutput >
 Module (const std::string &device_name, const ComponentConfig &config)
 Constructor with device name.
 
 Module (std::shared_ptr< DeviceContext > context, const ComponentConfig &config)
 Constructor with a specific device context.
 
virtual ~Module ()=default
 Virtual destructor for proper cleanup in derived classes.
 
std::shared_ptr< Compute::DeviceContext > getDeviceContext () const
 Get the device context for this module.
 
Compute::DeviceType getDeviceType () const
 Get the device type of the current device context.
 
std::string getName () const
 Get the name of the module.
 
const auto & getParameterTensors () const
 Get the parameter tensors of this module.
 
const ComputePrecision::Policy & getPrecision () const
 
const auto & getStateTensors () const
 Get the state tensors of this module.
 
bool isTraining () const
 Check if the module is in training mode.
 

Private Attributes

std::unordered_map< std::string, std::shared_ptr< Module< TDeviceType, TDataType, TDataType > > > child_module_map_
 Named child modules for efficient lookup by name.
 
std::vector< std::shared_ptr< Module< TDeviceType, TDataType, TDataType > > > child_modules_
 Child modules, in the order they were added.
 

Additional Inherited Members

- Protected Member Functions inherited from Mila::Dnn::Module< TDeviceType, TInput, TOutput >
const std::string parametersToString () const
 Helper method to convert parameters to string representation.
 
const std::string stateToString () const
 Helper method to convert state tensors to string representation.
 
- Protected Attributes inherited from Mila::Dnn::Module< TDeviceType, TInput, TOutput >
std::unordered_map< std::string, std::shared_ptr< Tensor< TOutput, MR > > > parameter_map_ = {}
 Map of parameter names to parameter tensors.
 
std::unordered_map< std::string, std::shared_ptr< Tensor< TOutput, MR > > > state_map_ = {}
 Map of state names to state tensors.
 

Detailed Description

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
requires ValidTensorType<TDataType>
class Mila::Dnn::CompositeModule< TDeviceType, TDataType >

A module class that can contain and manage child modules.

CompositeModule extends the base Module class with functionality to add, remove, and manage child modules. It is used for composite neural network components, such as MLPs and transformer blocks, that are built by composing simpler modules.

A single type parameter is used for data consistency across the module, as the output of one layer becomes the input of the next in a feed-forward composite architecture.
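The feed-forward chaining described above can be sketched in standard C++. This is a minimal illustration, not Mila's implementation: `Layer`, `Scale`, `AddBias`, `Sequence`, and the `forward` method are hypothetical names introduced here (this page does not document a forward method). Because every child consumes and produces the same element type `T`, the composite can hand each child's output directly to the next child.

```cpp
#include <cassert>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical stand-in for Module<TDeviceType, T, T>: input and output share one element type.
template <typename T>
struct Layer {
    virtual ~Layer() = default;
    virtual std::vector<T> forward(const std::vector<T>& input) const = 0;
};

// Hypothetical element-wise layer: multiplies every element by a constant.
template <typename T>
struct Scale : Layer<T> {
    explicit Scale(T factor) : factor_(factor) {}
    std::vector<T> forward(const std::vector<T>& input) const override {
        std::vector<T> out(input);
        for (auto& v : out) v *= factor_;
        return out;
    }
    T factor_;
};

// Hypothetical element-wise layer: adds a constant to every element.
template <typename T>
struct AddBias : Layer<T> {
    explicit AddBias(T bias) : bias_(bias) {}
    std::vector<T> forward(const std::vector<T>& input) const override {
        std::vector<T> out(input);
        for (auto& v : out) v += bias_;
        return out;
    }
    T bias_;
};

// Composite: children run in insertion order, and each output becomes the next input.
// This only type-checks because every child maps std::vector<T> to std::vector<T>.
template <typename T>
struct Sequence : Layer<T> {
    Sequence& add(std::shared_ptr<Layer<T>> child) {
        children_.push_back(std::move(child));
        return *this;  // allows add(...).add(...)
    }
    std::vector<T> forward(const std::vector<T>& input) const override {
        std::vector<T> x = input;
        for (const auto& child : children_) x = child->forward(x);
        return x;
    }
    std::vector<std::shared_ptr<Layer<T>>> children_;
};
```

For example, a `Sequence<float>` holding `Scale(2)` followed by `AddBias(1)` maps `{1, 3}` to `{3, 7}`.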

Template Parameters
TDeviceType: The device type (CPU or CUDA) on which the module will operate.
TDataType: Data type used for both input and output tensor elements.

Member Typedef Documentation

◆ ModuleBase

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
using Mila::Dnn::CompositeModule< TDeviceType, TDataType >::ModuleBase = Module<TDeviceType, TDataType, TDataType>

Base class type for the module.

◆ MR

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
using Mila::Dnn::CompositeModule< TDeviceType, TDataType >::MR = std::conditional_t<TDeviceType == DeviceType::Cuda, CudaMemoryResource, HostMemoryResource>

Memory resource type based on device type.
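The `MR` alias picks a memory resource type at compile time with `std::conditional_t`. The sketch below reproduces only that mechanism; the `DeviceType` enumerators and the empty `CudaMemoryResource` / `HostMemoryResource` structs are placeholders assumed for illustration, not Mila's definitions.

```cpp
#include <cassert>
#include <type_traits>

// Assumed enumerators; only the comparison against Cuda matters here.
enum class DeviceType { Cpu, Cuda };

// Placeholder types standing in for the real memory resources.
struct CudaMemoryResource {};
struct HostMemoryResource {};

// Same selection rule as the documented alias: Cuda -> CudaMemoryResource, otherwise host.
template <DeviceType TDeviceType>
using MR = std::conditional_t<TDeviceType == DeviceType::Cuda, CudaMemoryResource, HostMemoryResource>;

// The choice is resolved entirely at compile time.
static_assert(std::is_same_v<MR<DeviceType::Cuda>, CudaMemoryResource>);
static_assert(std::is_same_v<MR<DeviceType::Cpu>, HostMemoryResource>);
```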

Constructor & Destructor Documentation

◆ CompositeModule() [1/3]

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
Mila::Dnn::CompositeModule< TDeviceType, TDataType >::CompositeModule ( )
inline

Default constructor.

◆ CompositeModule() [2/3]

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
Mila::Dnn::CompositeModule< TDeviceType, TDataType >::CompositeModule ( const std::string &  device_name,
const ComponentConfig &  config 
)
inline, explicit

Constructor with device name.

Parameters
device_name: The device name to use for this module.
config: The component configuration for this module.

◆ CompositeModule() [3/3]

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
Mila::Dnn::CompositeModule< TDeviceType, TDataType >::CompositeModule ( std::shared_ptr< DeviceContext >  context,
const ComponentConfig &  config 
)
inline, explicit

Constructor with device context.

Parameters
context: The device context to use for this module.
config: The component configuration for this module.

◆ ~CompositeModule()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
virtual Mila::Dnn::CompositeModule< TDeviceType, TDataType >::~CompositeModule ( )
virtual, default

Virtual destructor.

Member Function Documentation

◆ addModule() [1/2]

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
CompositeModule & Mila::Dnn::CompositeModule< TDeviceType, TDataType >::addModule ( const std::string &  name,
std::shared_ptr< Module< TDeviceType, TDataType, TDataType > >  module 
)
inline

Add a named child module to this module.

Parameters
name: The name to identify the child module.
module: The child module to register.
Exceptions
std::invalid_argument: If the name is empty or already exists.
std::invalid_argument: If the module pointer is null.
Returns
Reference to this module for method chaining.
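A minimal sketch of the validation and chaining contract documented for the named `addModule` overload, assuming the child is recorded both in a name map and in an insertion-ordered vector (mirroring the private `child_module_map_` and `child_modules_` members). `Child` and `Container` are hypothetical stand-ins, and the exception messages are invented.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical child type; stands in for Module<TDeviceType, TDataType, TDataType>.
struct Child {
    virtual ~Child() = default;
};

class Container {
public:
    // Reject empty or duplicate names and null pointers, then record the child
    // both by name and in insertion order.
    Container& addModule(const std::string& name, std::shared_ptr<Child> module) {
        if (name.empty()) {
            throw std::invalid_argument("Module name cannot be empty");
        }
        if (!module) {
            throw std::invalid_argument("Module pointer cannot be null");
        }
        if (named_.count(name) != 0) {
            throw std::invalid_argument("Module name already exists: " + name);
        }
        named_.emplace(name, module);
        ordered_.push_back(std::move(module));
        return *this;  // returning a reference enables method chaining
    }

    std::size_t size() const { return ordered_.size(); }

private:
    std::unordered_map<std::string, std::shared_ptr<Child>> named_;
    std::vector<std::shared_ptr<Child>> ordered_;
};
```

Because the method returns `*this`, registrations can be chained: `c.addModule("fc1", a).addModule("fc2", b);`.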

◆ addModule() [2/2]

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
CompositeModule & Mila::Dnn::CompositeModule< TDeviceType, TDataType >::addModule ( std::shared_ptr< Module< TDeviceType, TDataType, TDataType > >  module)
inline

Add an unnamed child module to this module.

Parameters
module: The child module to register.
Exceptions
std::invalid_argument: If the module pointer is null.
Returns
Reference to this module for method chaining.

◆ getModule()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
std::shared_ptr< Module< TDeviceType, TDataType, TDataType > > Mila::Dnn::CompositeModule< TDeviceType, TDataType >::getModule ( const std::string &  name) const
inline

Get a specific sub-module by name.

Parameters
name: The name of the sub-module to retrieve.
Returns
std::shared_ptr<Module<TDeviceType, TDataType, TDataType>> The requested module.
Exceptions
std::out_of_range: If no module with the given name exists.
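`getModule` reports a missing name with `std::out_of_range`, while `hasModule` answers the same question without throwing. A sketch of both, assuming lookups go through an `std::unordered_map` like `child_module_map_`; `Child`, `Container`, and the helper `add` are hypothetical.

```cpp
#include <cassert>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>

// Hypothetical child type carrying an id so lookups can be checked.
struct Child {
    explicit Child(int id) : id(id) {}
    int id;
};

class Container {
public:
    void add(const std::string& name, std::shared_ptr<Child> module) {
        named_[name] = std::move(module);
    }

    // Throws std::out_of_range for unknown names, as documented for getModule().
    std::shared_ptr<Child> getModule(const std::string& name) const {
        auto it = named_.find(name);
        if (it == named_.end()) {
            throw std::out_of_range("No module named '" + name + "'");
        }
        return it->second;
    }

    // Non-throwing probe, usable as a guard before calling getModule().
    bool hasModule(const std::string& name) const {
        return named_.find(name) != named_.end();
    }

private:
    std::unordered_map<std::string, std::shared_ptr<Child>> named_;
};
```

Guarding with `hasModule` avoids using exceptions for control flow when a child is optional.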

◆ getModules()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
const std::vector< std::shared_ptr< Module< TDeviceType, TDataType, TDataType > > > & Mila::Dnn::CompositeModule< TDeviceType, TDataType >::getModules ( ) const
inline

Get all sub-modules contained in this module.

Returns
const std::vector<std::shared_ptr<Module<TDeviceType, TDataType, TDataType>>>& Vector of child module pointers.

◆ getNamedModules()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
const std::unordered_map< std::string, std::shared_ptr< Module< TDeviceType, TDataType, TDataType > > > & Mila::Dnn::CompositeModule< TDeviceType, TDataType >::getNamedModules ( ) const
inline

Get all named sub-modules contained in this module.

Returns
const std::unordered_map<std::string, std::shared_ptr<Module<TDeviceType, TDataType, TDataType>>>& Map of child module names to pointers.

◆ hasModule()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
bool Mila::Dnn::CompositeModule< TDeviceType, TDataType >::hasModule ( const std::string &  name) const
inline

Check if a sub-module with the given name exists.

Parameters
name: The name to check.
Returns
bool True if a sub-module with the given name exists.

◆ load()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
void Mila::Dnn::CompositeModule< TDeviceType, TDataType >::load ( ModelArchive &  archive)
inline, override, virtual

Default load implementation for container modules.

Loads the state of every child module. Override this method if the container has parameters of its own.

Parameters
archive: The model archive to load the module state from.

Implements Mila::Dnn::Module< TDeviceType, TInput, TOutput >.

Reimplemented in Mila::Dnn::MLP< TDeviceType, TDataType >, and Mila::Dnn::TransformerBlock< TDeviceType, TDataType >.

◆ parameterCount()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
size_t Mila::Dnn::CompositeModule< TDeviceType, TDataType >::parameterCount ( ) const
inline, override, virtual

Count the total number of parameters in this module and all sub-modules.

Returns
size_t Total number of parameters.

Implements Mila::Dnn::Module< TDeviceType, TInput, TOutput >.

Reimplemented in Mila::Dnn::MLP< TDeviceType, TDataType >, and Mila::Dnn::TransformerBlock< TDeviceType, TDataType >.
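The recursive count can be sketched with a simple composite pattern. `Node`, `Leaf`, and `Group` are hypothetical; in this sketch a group has no parameters of its own and just sums its children, each of which recurses into its own children. The actual base implementation and the `MLP` / `TransformerBlock` reimplementations are not shown on this page, so treat this as an illustration of the idea only.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

struct Node {
    virtual ~Node() = default;
    virtual std::size_t parameterCount() const = 0;
};

// Hypothetical leaf with a fixed number of parameters (e.g. weights plus bias).
struct Leaf : Node {
    explicit Leaf(std::size_t n) : n_(n) {}
    std::size_t parameterCount() const override { return n_; }
    std::size_t n_;
};

// Hypothetical container: sums its children; nested groups recurse naturally.
struct Group : Node {
    void add(std::shared_ptr<Node> child) { children_.push_back(std::move(child)); }
    std::size_t parameterCount() const override {
        std::size_t total = 0;
        for (const auto& child : children_) total += child->parameterCount();
        return total;
    }
    std::vector<std::shared_ptr<Node>> children_;
};
```

For instance, a 4-to-10 linear layer (40 weights + 10 biases = 50) inside a group alongside a 4-parameter leaf yields 54.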

◆ removeModule()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
bool Mila::Dnn::CompositeModule< TDeviceType, TDataType >::removeModule ( const std::string &  name)
inline

Remove a sub-module by name.

Parameters
name: The name of the sub-module to remove.
Returns
bool True if a module was removed, false if no module with that name existed.
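Because a child is tracked in two places, removal has to update both the name map and the ordered list. A hedged sketch, assuming the ordered list holds the same `shared_ptr` as the map so the entry can be located by pointer equality; `Child`, `Container`, and `add` are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

struct Child {};

class Container {
public:
    void add(const std::string& name, std::shared_ptr<Child> module) {
        named_.emplace(name, module);
        ordered_.push_back(std::move(module));
    }

    // Returns false instead of throwing when no child has this name.
    bool removeModule(const std::string& name) {
        auto it = named_.find(name);
        if (it == named_.end()) {
            return false;
        }
        // Drop the matching entry from the ordered list so both views stay consistent.
        auto pos = std::find(ordered_.begin(), ordered_.end(), it->second);
        if (pos != ordered_.end()) {
            ordered_.erase(pos);
        }
        named_.erase(it);
        return true;
    }

    bool hasModule(const std::string& name) const { return named_.count(name) != 0; }
    std::size_t size() const { return ordered_.size(); }

private:
    std::unordered_map<std::string, std::shared_ptr<Child>> named_;
    std::vector<std::shared_ptr<Child>> ordered_;
};
```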

◆ replaceModule()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
bool Mila::Dnn::CompositeModule< TDeviceType, TDataType >::replaceModule ( const std::string &  name,
std::shared_ptr< Module< TDeviceType, TDataType, TDataType > >  module 
)
inline

Replace an existing sub-module with a new one.

Parameters
name: The name of the sub-module to replace.
module: The new module to use as the replacement.
Returns
bool True if a module was replaced, false if no module with that name existed.
Exceptions
std::invalid_argument: If the replacement module pointer is null.
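A sketch of the `replaceModule` contract, assuming replacement keeps the child's position in the ordered list so execution order is unchanged. The null check runs first here; whether Mila validates the pointer or looks up the name first is not stated on this page. `Child`, `Container`, `add`, and `ordered` are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

struct Child {
    explicit Child(int id) : id(id) {}
    int id;
};

class Container {
public:
    void add(const std::string& name, std::shared_ptr<Child> module) {
        named_.emplace(name, module);
        ordered_.push_back(std::move(module));
    }

    bool replaceModule(const std::string& name, std::shared_ptr<Child> module) {
        if (!module) {
            throw std::invalid_argument("Replacement module cannot be null");
        }
        auto it = named_.find(name);
        if (it == named_.end()) {
            return false;
        }
        // Swap the pointer in place so the child keeps its slot in the ordered list.
        auto pos = std::find(ordered_.begin(), ordered_.end(), it->second);
        if (pos != ordered_.end()) {
            *pos = module;
        }
        it->second = std::move(module);
        return true;
    }

    const std::vector<std::shared_ptr<Child>>& ordered() const { return ordered_; }

private:
    std::unordered_map<std::string, std::shared_ptr<Child>> named_;
    std::vector<std::shared_ptr<Child>> ordered_;
};
```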

◆ save()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
void Mila::Dnn::CompositeModule< TDeviceType, TDataType >::save ( ModelArchive &  archive) const
inline, override, virtual

Default save implementation for container modules.

Saves the state of every child module. Override this method if the container has parameters of its own.

Parameters
archive: The model archive to save the module state to.

Implements Mila::Dnn::Module< TDeviceType, TInput, TOutput >.

Reimplemented in Mila::Dnn::MLP< TDeviceType, TDataType >, and Mila::Dnn::TransformerBlock< TDeviceType, TDataType >.
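The default `save` / `load` behaviour (delegate to every child) can be sketched as a recursive walk. Here `Archive` is a hypothetical `std::map` standing in for `ModelArchive`, and the dotted key prefix (`fc1.weights`) is an assumed naming scheme; the real methods take only the archive, and this page does not say how child entries are keyed. `Node`, `Leaf`, and `Group` are hypothetical.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for ModelArchive: a flat key/value store of float arrays.
using Archive = std::map<std::string, std::vector<float>>;

struct Node {
    virtual ~Node() = default;
    virtual void save(Archive& ar, const std::string& prefix) const = 0;
    virtual void load(const Archive& ar, const std::string& prefix) = 0;
};

// Hypothetical leaf that owns a single weight vector.
struct Leaf : Node {
    std::vector<float> weights;
    void save(Archive& ar, const std::string& prefix) const override {
        ar[prefix + "weights"] = weights;
    }
    void load(const Archive& ar, const std::string& prefix) override {
        weights = ar.at(prefix + "weights");  // throws std::out_of_range if the key is missing
    }
};

// Container with no state of its own: forwards save/load to each named child.
struct Group : Node {
    std::vector<std::pair<std::string, std::shared_ptr<Node>>> children;
    void save(Archive& ar, const std::string& prefix) const override {
        for (const auto& [name, child] : children) child->save(ar, prefix + name + ".");
    }
    void load(const Archive& ar, const std::string& prefix) override {
        for (const auto& [name, child] : children) child->load(ar, prefix + name + ".");
    }
};
```

A container that does have its own parameters would write them first and then delegate, which is the case the "override" advice above refers to.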

◆ setTraining()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
void Mila::Dnn::CompositeModule< TDeviceType, TDataType >::setTraining ( bool  is_training)
inline, override, virtual

Set the training mode for this module and all its sub-modules.

Parameters
is_training: True to put the module in training mode, false for inference mode.

Reimplemented from Mila::Dnn::Module< TDeviceType, TInput, TOutput >.
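Propagating the training flag is a recursive broadcast: set this module's own flag, then forward the call to every child. A minimal sketch with hypothetical `Node` and `Group` types; the default of `false` for a freshly constructed node is an assumption of the sketch, not documented behaviour.

```cpp
#include <cassert>
#include <memory>
#include <vector>

struct Node {
    virtual ~Node() = default;
    virtual void setTraining(bool is_training) { training_ = is_training; }
    bool isTraining() const { return training_; }

private:
    bool training_ = false;  // assumed default for this sketch
};

struct Group : Node {
    std::vector<std::shared_ptr<Node>> children;
    void setTraining(bool is_training) override {
        Node::setTraining(is_training);  // update this module's own flag
        for (auto& child : children) {
            child->setTraining(is_training);  // nested groups recurse further
        }
    }
};
```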


◆ toString()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
std::string Mila::Dnn::CompositeModule< TDeviceType, TDataType >::toString ( ) const
inline, override, virtual

Default toString implementation for container modules.

Lists all child modules. Override this method for a custom string representation.

Returns
std::string A string representation of the module information.

Implements Mila::Dnn::Module< TDeviceType, TInput, TOutput >.

Reimplemented in Mila::Dnn::MLP< TDeviceType, TDataType >, and Mila::Dnn::TransformerBlock< TDeviceType, TDataType >.
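A sketch of a container `toString` that lists its named children in order. The output format here is invented for illustration; Mila's actual text layout is not documented on this page. `Node`, `Leaf`, and `Group` are hypothetical.

```cpp
#include <cassert>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

struct Node {
    virtual ~Node() = default;
    virtual std::string toString() const = 0;
};

// Hypothetical leaf that describes itself with a fixed string.
struct Leaf : Node {
    explicit Leaf(std::string desc) : desc_(std::move(desc)) {}
    std::string toString() const override { return desc_; }
    std::string desc_;
};

// Container: a header line, then one indented line per child in insertion order.
struct Group : Node {
    std::vector<std::pair<std::string, std::shared_ptr<Node>>> children;
    std::string toString() const override {
        std::ostringstream oss;
        oss << "Composite with " << children.size() << " child module(s)\n";
        for (const auto& [name, child] : children) {
            oss << "  " << name << ": " << child->toString() << "\n";
        }
        return oss.str();
    }
};
```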


Member Data Documentation

◆ child_module_map_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
std::unordered_map<std::string, std::shared_ptr<Module<TDeviceType, TDataType, TDataType> > > Mila::Dnn::CompositeModule< TDeviceType, TDataType >::child_module_map_
private

Named child modules for efficient lookup by name.

◆ child_modules_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
std::vector<std::shared_ptr<Module<TDeviceType, TDataType, TDataType> > > Mila::Dnn::CompositeModule< TDeviceType, TDataType >::child_modules_
private

Child modules, in the order they were added.


The documentation for this class was generated from the following file: