Mila
Deep Neural Network Library
A module class that can contain and manage child modules.
Public Types
using | ModuleBase = Module< TDeviceType, TDataType, TDataType >
Base class type for the module.
using | MR = std::conditional_t< TDeviceType==DeviceType::Cuda, CudaMemoryResource, HostMemoryResource >
Memory resource type based on device type.
Public Types inherited from Mila::Dnn::Module< TDeviceType, TDataType, TDataType >
using | MR = std::conditional_t< TDeviceType==DeviceType::Cuda, CudaMemoryResource, CpuMemoryResource >
Public Member Functions
CompositeModule ()
Default constructor.
CompositeModule (const std::string &device_name, const ComponentConfig &config)
Constructor with device name.
CompositeModule (std::shared_ptr< DeviceContext > context, const ComponentConfig &config)
Constructor with device context.
virtual | ~CompositeModule ()=default
Virtual destructor.
CompositeModule & | addModule (const std::string &name, std::shared_ptr< Module< TDeviceType, TDataType, TDataType > > module)
Add a named child module to this module.
CompositeModule & | addModule (std::shared_ptr< Module< TDeviceType, TDataType, TDataType > > module)
Add an unnamed child module to this module.
std::shared_ptr< Module< TDeviceType, TDataType, TDataType > > | getModule (const std::string &name) const
Get a specific sub-module by name.
const std::vector< std::shared_ptr< Module< TDeviceType, TDataType, TDataType > > > & | getModules () const
Get all sub-modules contained in this module.
const std::unordered_map< std::string, std::shared_ptr< Module< TDeviceType, TDataType, TDataType > > > & | getNamedModules () const
Get all named sub-modules contained in this module.
bool | hasModule (const std::string &name) const
Check if a sub-module with the given name exists.
void | load (ModelArchive &archive) override
Default load implementation for container modules.
size_t | parameterCount () const override
Count the total number of parameters in this module and all sub-modules.
bool | removeModule (const std::string &name)
Remove a sub-module by name.
bool | replaceModule (const std::string &name, std::shared_ptr< Module< TDeviceType, TDataType, TDataType > > module)
Replace an existing sub-module with a new one.
void | save (ModelArchive &archive) const override
Default save implementation for container modules.
void | setTraining (bool is_training) override
Set the training mode for this module and all its sub-modules.
std::string | toString () const override
Default toString implementation for container modules.
Public Member Functions inherited from Mila::Dnn::Module< TDeviceType, TDataType, TDataType >
Module (const std::string &device_name, const ComponentConfig &config)
Constructor with device name.
Module (std::shared_ptr< DeviceContext > context, const ComponentConfig &config)
Constructor with a specific device context.
virtual | ~Module ()=default
Virtual destructor for proper cleanup in derived classes.
std::shared_ptr< Compute::DeviceContext > | getDeviceContext () const
Get the device context for this module.
Compute::DeviceType | getDeviceType () const
Get the device type of the current device context.
std::string | getName () const
Get the name of the module.
const auto & | getParameterTensors () const
Get the parameter tensors of this module.
const ComputePrecision::Policy & | getPrecision () const
const auto & | getStateTensors () const
Get the state tensors of this module.
bool | isTraining () const
Check if the module is in training mode.
Private Attributes
std::unordered_map< std::string, std::shared_ptr< Module< TDeviceType, TDataType, TDataType > > > | child_module_map_
Named child modules for efficient lookup by name.
std::vector< std::shared_ptr< Module< TDeviceType, TDataType, TDataType > > > | child_modules_
Child modules, in the order they were added.
Additional Inherited Members
Member functions inherited from Mila::Dnn::Module< TDeviceType, TDataType, TDataType >
const std::string | parametersToString () const
Helper method to convert parameters to a string representation.
const std::string | stateToString () const
Helper method to convert state tensors to a string representation.
Attributes inherited from Mila::Dnn::Module< TDeviceType, TDataType, TDataType >
std::unordered_map< std::string, std::shared_ptr< Tensor< TOutput, MR > > > | parameter_map_ = {}
Map of parameter names to parameter tensors.
std::unordered_map< std::string, std::shared_ptr< Tensor< TOutput, MR > > > | state_map_ = {}
Map of state names to state tensors.
Detailed Description
CompositeModule extends the base Module class with functionality to add, remove, and manage child modules. It is the basis for composite neural network components, such as MLPs and transformer blocks, that are built by composing simpler modules.
A single data type parameter, TDataType, is used for both input and output so that data types stay consistent across the module: in a feed-forward composite architecture, the output of one layer becomes the input of the next.
Template Parameters
TDeviceType | The device type (CPU or CUDA) on which the module will operate. |
TDataType | Data type used for both input and output tensor elements. |
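The reason for a single data type can be illustrated with a minimal, self-contained sketch. ToyLayer, Scale, AddBias and runChain are hypothetical names, not part of Mila's API, and std::vector<T> stands in for Mila's tensors. Because every child maps values of type T to values of type T, each child's output can be fed directly into the next child.

```cpp
#include <cassert>
#include <memory>
#include <vector>

// Hypothetical stand-in for a layer: input and output share one element
// type T, mirroring CompositeModule's single TDataType parameter.
template <typename T>
struct ToyLayer {
    virtual ~ToyLayer() = default;
    virtual std::vector<T> forward(const std::vector<T>& input) const = 0;
};

// Multiplies every element by a constant factor.
template <typename T>
struct Scale : ToyLayer<T> {
    T factor;
    explicit Scale(T f) : factor(f) {}
    std::vector<T> forward(const std::vector<T>& input) const override {
        std::vector<T> out(input);
        for (auto& x : out) x *= factor;
        return out;
    }
};

// Adds a constant bias to every element.
template <typename T>
struct AddBias : ToyLayer<T> {
    T bias;
    explicit AddBias(T b) : bias(b) {}
    std::vector<T> forward(const std::vector<T>& input) const override {
        std::vector<T> out(input);
        for (auto& x : out) x += bias;
        return out;
    }
};

// Every child maps std::vector<T> to std::vector<T>, so the output of one
// child can be passed unchanged as the input of the next.
template <typename T>
std::vector<T> runChain(const std::vector<std::shared_ptr<ToyLayer<T>>>& layers,
                        std::vector<T> x) {
    for (const auto& layer : layers) {
        x = layer->forward(x);
    }
    return x;
}
```

With a Scale(2) followed by an AddBias(1), the input {1, 2} becomes {3, 5}.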
using Mila::Dnn::CompositeModule< TDeviceType, TDataType >::ModuleBase = Module<TDeviceType, TDataType, TDataType>
Base class type for the module.
using Mila::Dnn::CompositeModule< TDeviceType, TDataType >::MR = std::conditional_t<TDeviceType == DeviceType::Cuda, CudaMemoryResource, HostMemoryResource>
Memory resource type based on device type.
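The MR alias selects a memory resource type at compile time. The sketch below reproduces the std::conditional_t pattern with empty stand-in types; the DeviceType::Cpu enumerator and the struct definitions are placeholders assumed for illustration, not Mila's actual definitions.

```cpp
#include <cassert>
#include <type_traits>

// Hypothetical stand-ins for Mila's device enum and memory resource types.
enum class DeviceType { Cpu, Cuda };
struct CudaMemoryResource {};
struct HostMemoryResource {};

// Same selection logic as the MR alias: CUDA devices get the CUDA resource,
// every other device type gets the host resource.
template <DeviceType TDeviceType>
using MR = std::conditional_t<TDeviceType == DeviceType::Cuda,
                              CudaMemoryResource, HostMemoryResource>;

// The choice is made entirely at compile time; there is no runtime branch.
static_assert(std::is_same_v<MR<DeviceType::Cuda>, CudaMemoryResource>);
static_assert(std::is_same_v<MR<DeviceType::Cpu>, HostMemoryResource>);
```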
Mila::Dnn::CompositeModule< TDeviceType, TDataType >::CompositeModule ()
[inline]
Default constructor.
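Mila::Dnn::CompositeModule< TDeviceType, TDataType >::CompositeModule (const std::string &device_name, const ComponentConfig &config)
[inline], [explicit]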
Constructor with device name.
device_name | The device name to use for this module. |
config | The component configuration to use for this module (the compute precision policy defaults to Auto). |
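Mila::Dnn::CompositeModule< TDeviceType, TDataType >::CompositeModule (std::shared_ptr< DeviceContext > context, const ComponentConfig &config)
[inline], [explicit]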
Constructor with device context.
context | The device context to use for this module. |
config | The component configuration to use for this module (the compute precision policy defaults to Auto). |
virtual Mila::Dnn::CompositeModule< TDeviceType, TDataType >::~CompositeModule ()
[virtual], [default]
Virtual destructor.
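CompositeModule & Mila::Dnn::CompositeModule< TDeviceType, TDataType >::addModule (const std::string &name, std::shared_ptr< Module< TDeviceType, TDataType, TDataType > > module)
[inline]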
Add a named child module to this module.
name | The name to identify the child module. |
module | The child module to register. |
std::invalid_argument | If the name is empty or already exists. |
std::invalid_argument | If the module pointer is null. |
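A minimal sketch of how the two addModule overloads could maintain the ordered list and the name map described under Private Attributes. ToyModule and ToyComposite are hypothetical stand-ins, and the order of the validation checks is an assumption; only the exception type and the conditions that trigger it come from this page.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical minimal base class standing in for Mila's Module.
struct ToyModule {
    virtual ~ToyModule() = default;
};

// Keeps an ordered vector of all children plus a name-to-module map.
class ToyComposite {
public:
    ToyComposite& addModule(const std::string& name, std::shared_ptr<ToyModule> module) {
        if (name.empty()) throw std::invalid_argument("module name must not be empty");
        if (child_module_map_.count(name) != 0) throw std::invalid_argument("duplicate module name: " + name);
        if (!module) throw std::invalid_argument("module must not be null");
        child_module_map_.emplace(name, module);    // lookup by name
        child_modules_.push_back(std::move(module)); // insertion order
        return *this;                                // enables chained calls
    }

    // Unnamed children are stored only in the ordered vector.
    ToyComposite& addModule(std::shared_ptr<ToyModule> module) {
        if (!module) throw std::invalid_argument("module must not be null");
        child_modules_.push_back(std::move(module));
        return *this;
    }

    std::size_t size() const { return child_modules_.size(); }
    std::size_t namedCount() const { return child_module_map_.size(); }

private:
    std::unordered_map<std::string, std::shared_ptr<ToyModule>> child_module_map_;
    std::vector<std::shared_ptr<ToyModule>> child_modules_;
};
```

Returning a reference to the container (CompositeModule & in the real signature) is what allows calls such as `c.addModule("fc1", a).addModule(b)` to be chained.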
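CompositeModule & Mila::Dnn::CompositeModule< TDeviceType, TDataType >::addModule (std::shared_ptr< Module< TDeviceType, TDataType, TDataType > > module)
[inline]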
Add an unnamed child module to this module.
module | The child module to register. |
std::invalid_argument | If the module pointer is null. |
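std::shared_ptr< Module< TDeviceType, TDataType, TDataType > > Mila::Dnn::CompositeModule< TDeviceType, TDataType >::getModule (const std::string &name) const
[inline]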
Get a specific sub-module by name.
name | The name of the sub-module to retrieve. |
std::out_of_range | If no module with the given name exists. |
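const std::vector< std::shared_ptr< Module< TDeviceType, TDataType, TDataType > > > & Mila::Dnn::CompositeModule< TDeviceType, TDataType >::getModules () const
[inline]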
Get all sub-modules contained in this module.
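const std::unordered_map< std::string, std::shared_ptr< Module< TDeviceType, TDataType, TDataType > > > & Mila::Dnn::CompositeModule< TDeviceType, TDataType >::getNamedModules () const
[inline]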
Get all named sub-modules contained in this module.
bool Mila::Dnn::CompositeModule< TDeviceType, TDataType >::hasModule (const std::string &name) const
[inline]
Check if a sub-module with the given name exists.
name | The name to check. |
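hasModule and getModule differ mainly in how a missing name is reported. One plausible way to obtain the documented getModule behavior is std::unordered_map::at, which throws std::out_of_range for a missing key. The free functions and the NamedModules alias below are hypothetical and are not claimed to mirror Mila's implementation.

```cpp
#include <cassert>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for Mila's Module.
struct ToyModule {
    virtual ~ToyModule() = default;
};

using NamedModules = std::unordered_map<std::string, std::shared_ptr<ToyModule>>;

// Non-throwing membership test, analogous to hasModule.
bool hasModule(const NamedModules& modules, const std::string& name) {
    return modules.find(name) != modules.end();
}

// unordered_map::at throws std::out_of_range for a missing key, matching
// the exception documented for getModule.
std::shared_ptr<ToyModule> getModule(const NamedModules& modules, const std::string& name) {
    return modules.at(name);
}
```

Checking hasModule first avoids relying on the exception for ordinary control flow.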
void Mila::Dnn::CompositeModule< TDeviceType, TDataType >::load (ModelArchive &archive)
[inline], [override], [virtual]
Default load implementation for container modules.
Loads the state of all child modules. Override this method if the container has parameters of its own.
archive | The model archive to load the module state from. |
Implements Mila::Dnn::Module< TDeviceType, TInput, TOutput >.
Reimplemented in Mila::Dnn::MLP< TDeviceType, TDataType >, and Mila::Dnn::TransformerBlock< TDeviceType, TDataType >.
size_t Mila::Dnn::CompositeModule< TDeviceType, TDataType >::parameterCount () const
[inline], [override], [virtual]
Count the total number of parameters in this module and all sub-modules.
Implements Mila::Dnn::Module< TDeviceType, TInput, TOutput >.
Reimplemented in Mila::Dnn::MLP< TDeviceType, TDataType >, and Mila::Dnn::TransformerBlock< TDeviceType, TDataType >.
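The recursive count can be sketched with a composite that sums the counts of its children. ToyLinear is a hypothetical leaf whose parameter count (a weight matrix plus a bias vector) is an assumption made for the example. Nested composites are handled automatically because a composite is itself a module.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

// Hypothetical stand-in for Mila's Module.
struct ToyModule {
    virtual ~ToyModule() = default;
    virtual std::size_t parameterCount() const = 0;
};

// Hypothetical leaf: a weight matrix (in x out) plus a bias vector (out).
struct ToyLinear : ToyModule {
    std::size_t in, out;
    ToyLinear(std::size_t i, std::size_t o) : in(i), out(o) {}
    std::size_t parameterCount() const override { return in * out + out; }
};

// A container with no parameters of its own: the total is the sum over its
// children, recursing through any nested composites.
struct ToyComposite : ToyModule {
    std::vector<std::shared_ptr<ToyModule>> children;
    std::size_t parameterCount() const override {
        std::size_t total = 0;
        for (const auto& child : children) total += child->parameterCount();
        return total;
    }
};
```

For a 4-to-8 layer nested one level down, next to an 8-to-2 layer, the total is (4*8 + 8) + (8*2 + 2) = 40 + 18 = 58.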
bool Mila::Dnn::CompositeModule< TDeviceType, TDataType >::removeModule (const std::string &name)
[inline]
Remove a sub-module by name.
name | The name of the sub-module to remove. |
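Because a named child is stored in both child_module_map_ and child_modules_, removal has to update both containers. The sketch below assumes that removeModule returns false when the name is not found, which this page does not state; ToyChildren and ToyModule are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for Mila's Module.
struct ToyModule {
    virtual ~ToyModule() = default;
};

struct ToyChildren {
    std::unordered_map<std::string, std::shared_ptr<ToyModule>> named;
    std::vector<std::shared_ptr<ToyModule>> ordered;

    // Removes the named child from the map and from the ordered list.
    // Returns false if no child has that name (an assumed convention).
    bool removeModule(const std::string& name) {
        auto it = named.find(name);
        if (it == named.end()) return false;
        const auto target = it->second;
        named.erase(it);
        // Erase the first matching entry; the remaining children keep their order.
        auto pos = std::find(ordered.begin(), ordered.end(), target);
        if (pos != ordered.end()) ordered.erase(pos);
        return true;
    }
};
```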
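bool Mila::Dnn::CompositeModule< TDeviceType, TDataType >::replaceModule (const std::string &name, std::shared_ptr< Module< TDeviceType, TDataType, TDataType > > module)
[inline]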
Replace an existing sub-module with a new one.
name | The name of the sub-module to replace. |
module | The new module to use as replacement. |
std::invalid_argument | If the replacement module pointer is null. |
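A sketch of replaceModule that keeps the replaced child at its original position in the ordered list, so the execution order of a feed-forward composite is unchanged. Returning false for an unknown name is an assumption, as are the ToyChildren and ToyModule names; the null check corresponds to the documented std::invalid_argument.

```cpp
#include <algorithm>
#include <cassert>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical stand-in for Mila's Module.
struct ToyModule {
    virtual ~ToyModule() = default;
};

struct ToyChildren {
    std::unordered_map<std::string, std::shared_ptr<ToyModule>> named;
    std::vector<std::shared_ptr<ToyModule>> ordered;

    bool replaceModule(const std::string& name, std::shared_ptr<ToyModule> module) {
        if (!module) throw std::invalid_argument("replacement module must not be null");
        auto it = named.find(name);
        if (it == named.end()) return false;  // assumed convention for unknown names
        // Overwrite every occurrence of the old pointer in place, so positions are kept.
        std::replace(ordered.begin(), ordered.end(), it->second, module);
        it->second = std::move(module);
        return true;
    }
};
```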
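void Mila::Dnn::CompositeModule< TDeviceType, TDataType >::save (ModelArchive &archive) const
[inline], [override], [virtual]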
Default save implementation for container modules.
Saves the state of all child modules. Override this method if the container has parameters of its own.
archive | The model archive to save the module state to. |
Implements Mila::Dnn::Module< TDeviceType, TInput, TOutput >.
Reimplemented in Mila::Dnn::MLP< TDeviceType, TDataType >, and Mila::Dnn::TransformerBlock< TDeviceType, TDataType >.
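The default load and save behavior, delegating to every child in order, can be sketched with a hypothetical ToyArchive (a std::map from names to floats) standing in for ModelArchive, whose real interface is not shown on this page. The key naming and the single-float "parameter" per leaf are simplifications for illustration.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical in-memory archive standing in for ModelArchive.
using ToyArchive = std::map<std::string, float>;

// Hypothetical stand-in for Mila's Module.
struct ToyModule {
    virtual ~ToyModule() = default;
    virtual void save(ToyArchive& archive) const = 0;
    virtual void load(const ToyArchive& archive) = 0;
};

// A leaf with one named value; it writes and reads its own entry.
struct ToyLeaf : ToyModule {
    std::string name;
    float weight;
    ToyLeaf(std::string n, float w) : name(std::move(n)), weight(w) {}
    void save(ToyArchive& archive) const override { archive[name] = weight; }
    void load(const ToyArchive& archive) override { weight = archive.at(name); }
};

// Default container behavior: forward save/load to every child, in order.
struct ToyComposite : ToyModule {
    std::vector<std::shared_ptr<ToyModule>> children;
    void save(ToyArchive& archive) const override {
        for (const auto& child : children) child->save(archive);
    }
    void load(const ToyArchive& archive) override {
        for (const auto& child : children) child->load(archive);
    }
};
```

A container that owns parameters of its own would override both methods to write and read those parameters in addition to calling its children.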
void Mila::Dnn::CompositeModule< TDeviceType, TDataType >::setTraining (bool is_training)
[inline], [override], [virtual]
Set the training mode for this module and all its sub-modules.
is_training | True if the module is in training mode, false for inference mode. |
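Training-mode propagation can be sketched as an override that updates the module's own flag and then forwards the call to each child. ToyModule and ToyComposite are hypothetical. Because setTraining is virtual, a nested composite forwards the flag further down in turn.

```cpp
#include <cassert>
#include <memory>
#include <vector>

// Hypothetical stand-in for Mila's Module, holding only the training flag.
struct ToyModule {
    virtual ~ToyModule() = default;
    virtual void setTraining(bool is_training) { training_ = is_training; }
    bool isTraining() const { return training_; }

private:
    bool training_ = false;
};

struct ToyComposite : ToyModule {
    std::vector<std::shared_ptr<ToyModule>> children;
    // Update this module's own flag, then forward the call to every child.
    void setTraining(bool is_training) override {
        ToyModule::setTraining(is_training);
        for (const auto& child : children) child->setTraining(is_training);
    }
};
```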
Reimplemented from Mila::Dnn::Module< TDeviceType, TInput, TOutput >.
std::string Mila::Dnn::CompositeModule< TDeviceType, TDataType >::toString () const
[inline], [override], [virtual]
Default toString implementation for container modules.
Lists all child modules. Override this method to provide a custom string representation.
Implements Mila::Dnn::Module< TDeviceType, TInput, TOutput >.
Reimplemented in Mila::Dnn::MLP< TDeviceType, TDataType >, and Mila::Dnn::TransformerBlock< TDeviceType, TDataType >.
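A sketch of a container toString that lists its children. The output format (a "Composite: name (N children)" header followed by one line per child) is invented for illustration; Mila's actual format is not documented on this page. ToyModule, ToyLeaf and ToyComposite are hypothetical.

```cpp
#include <cassert>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for Mila's Module.
struct ToyModule {
    virtual ~ToyModule() = default;
    virtual std::string toString() const = 0;
};

struct ToyLeaf : ToyModule {
    std::string name;
    explicit ToyLeaf(std::string n) : name(std::move(n)) {}
    std::string toString() const override { return "Leaf: " + name + "\n"; }
};

// Prints a header line, then appends each child's own representation.
struct ToyComposite : ToyModule {
    std::string name;
    std::vector<std::shared_ptr<ToyModule>> children;
    explicit ToyComposite(std::string n) : name(std::move(n)) {}
    std::string toString() const override {
        std::ostringstream oss;
        oss << "Composite: " << name << " (" << children.size() << " children)\n";
        for (const auto& child : children) oss << child->toString();
        return oss.str();
    }
};
```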
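std::unordered_map< std::string, std::shared_ptr< Module< TDeviceType, TDataType, TDataType > > > Mila::Dnn::CompositeModule< TDeviceType, TDataType >::child_module_map_
[private]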
Named child modules for efficient lookup by name.
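std::vector< std::shared_ptr< Module< TDeviceType, TDataType, TDataType > > > Mila::Dnn::CompositeModule< TDeviceType, TDataType >::child_modules_
[private]
Child modules, in the order they were added.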