Mila
Deep Neural Network Library

Abstract base class for all modules in the Mila DNN framework.


Public Types

    using MR = std::conditional_t<TDeviceType == DeviceType::Cuda, CudaMemoryResource, CpuMemoryResource>

Public Member Functions

    Module(const std::string &device_name, const ComponentConfig &config)
        Constructor with device name.
    Module(std::shared_ptr<DeviceContext> context, const ComponentConfig &config)
        Constructor with a specific device context.
    virtual ~Module() = default
        Virtual destructor for proper cleanup in derived classes.
    std::shared_ptr<Compute::DeviceContext> getDeviceContext() const
        Get the device context for this module.
    Compute::DeviceType getDeviceType() const
        Get the device type of the current device context.
    std::string getName() const
        Get the name of the module.
    const auto & getParameterTensors() const
        Get the parameter tensors of this module.
    const ComputePrecision::Policy & getPrecision() const
    const auto & getStateTensors() const
        Get the state tensors of this module.
    bool isTraining() const
        Check if the module is in training mode.
    virtual void load(ModelArchive &archive) = 0
        Load the module state from a zip archive.
    virtual size_t parameterCount() const = 0
        Get the number of trainable parameters in the module.
    virtual void save(ModelArchive &archive) const = 0
        Save the module state to a zip archive.
    virtual void setTraining(bool is_training)
        Set the training mode of this module.
    virtual std::string toString() const = 0
        Convert the module to a string representation.

Protected Member Functions

    const std::string parametersToString() const
        Helper method to convert parameters to string representation.
    const std::string stateToString() const
        Helper method to convert state tensors to string representation.

Protected Attributes

    std::unordered_map<std::string, std::shared_ptr<Tensor<TOutput, MR>>> parameter_map_ = {}
        Map of parameter names to parameter tensors.
    std::unordered_map<std::string, std::shared_ptr<Tensor<TOutput, MR>>> state_map_ = {}
        Map of state names to state tensors.

Static Private Member Functions

    static std::shared_ptr<Compute::DeviceContext> createContext(const std::string &device_name)
        Helper method to create a DeviceContext from a device name.

Private Attributes

    ComponentConfig config_
    std::shared_ptr<Compute::DeviceContext> device_context_
        The device context used for this module's computations.
    bool training_mode_ { false }
        Whether the module is in training mode.

Friends

    std::ostream & operator<<(std::ostream &os, const Module &module)
        Overload the << operator to print the module information.
Detailed Description

Abstract base class for all modules in the Mila DNN framework.
The Module class provides a common interface for all neural network layers and components, enabling consistent handling of parameters, state, and device context. For container functionality that supports child modules, use the Block class.
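The following is a minimal, hypothetical sketch of a concrete module (an Identity layer that is not part of the library), showing which members a derived class must provide; the constructor forwarding and namespace placement are assumptions based on the interface documented above.

    // Hypothetical example only: a parameter-free Identity module.
    // The headers declaring Module, ComponentConfig and ModelArchive are omitted here.
    namespace Mila::Dnn {

        template <Compute::DeviceType TDeviceType, typename TInput, typename TOutput = TInput>
        class Identity : public Module<TDeviceType, TInput, TOutput> {
        public:
            Identity( const std::string& device_name, const ComponentConfig& config )
                : Module<TDeviceType, TInput, TOutput>( device_name, config ) {}

            // The pure virtual members every concrete module must implement:
            size_t parameterCount() const override { return 0; }               // no trainable parameters
            void save( ModelArchive& /*archive*/ ) const override {}           // nothing to serialize
            void load( ModelArchive& /*archive*/ ) override {}                 // nothing to restore
            std::string toString() const override { return "Identity: " + this->getName(); }
        };

    } // namespace Mila::Dnn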
Member Typedef Documentation

using Mila::Dnn::Module<TDeviceType, TInput, TOutput>::MR = std::conditional_t<TDeviceType == DeviceType::Cuda, CudaMemoryResource, CpuMemoryResource>

Memory resource type selected at compile time: CudaMemoryResource when TDeviceType is DeviceType::Cuda, CpuMemoryResource otherwise.
Constructor & Destructor Documentation

Module(const std::string &device_name, const ComponentConfig &config)

inline, explicit
Constructor with device name.
Creates a module with a device context for the specified device name. This allows modules to be created with a simple string identifier rather than requiring manual construction of a DeviceContext.
Parameters
    device_name    The name of the device to use (e.g., "CPU", "CUDA:0"). Must be one of the names returned by DeviceRegistry::list_devices().
    config         The component configuration for this module, including the compute precision policy (default is Auto).

Exceptions
    std::runtime_error    If the specified device name is invalid or doesn't match TDeviceType.
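A hedged usage sketch: the Identity layer comes from the example above, and both the DeviceType::Cpu enumerator spelling and the default construction of ComponentConfig are assumptions made for illustration.

    #include <memory>

    using namespace Mila::Dnn;

    ComponentConfig config;   // assumed to be default-constructible with sensible defaults

    auto cpu_module  = std::make_shared<Identity<Compute::DeviceType::Cpu,  float>>( "CPU",    config );
    auto cuda_module = std::make_shared<Identity<Compute::DeviceType::Cuda, float>>( "CUDA:0", config );

    // A device name unknown to DeviceRegistry, or one that does not match TDeviceType,
    // causes the constructor to throw std::runtime_error.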

Module(std::shared_ptr<DeviceContext> context, const ComponentConfig &config)

inline, explicit
Constructor with a specific device context.
Parameters
    context    The device context to use for this module.
    config     The component configuration for this module, including the compute precision policy (default is Auto).

Exceptions
    std::invalid_argument    If the provided context is nullptr.
    std::runtime_error       If the context device type doesn't match TDeviceType.
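Continuing the sketch above, one possible way to share a single device context between modules is to reuse the context of an existing module; the second construction call is illustrative only.

    // Reuse an existing module's context so both modules target the same device.
    auto shared_ctx = cuda_module->getDeviceContext();
    auto second     = std::make_shared<Identity<Compute::DeviceType::Cuda, float>>( shared_ctx, config );

    // Passing nullptr throws std::invalid_argument; a context whose device type does not
    // match TDeviceType throws std::runtime_error.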

virtual ~Module() = default

virtual
Virtual destructor for proper cleanup in derived classes.
Member Function Documentation

static std::shared_ptr<Compute::DeviceContext> createContext(const std::string &device_name)

inline, static, private
Helper method to create a DeviceContext from a device name.
Parameters
    device_name    Name of the device to create a context for.

Exceptions
    std::runtime_error    If the device name is invalid.
std::shared_ptr<Compute::DeviceContext> getDeviceContext() const

inline
Get the device context for this module.
Returns a shared pointer to the device context that this module is currently using.

Compute::DeviceType getDeviceType() const

inline
Get the device type of the current device context.

std::string getName() const

inline
Get the name of the module.


const auto & getParameterTensors() const

inline
Get the parameter tensors of this module.
Parameter tensors represent learnable weights that are updated during training via gradient descent or other optimization algorithms.
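A short fragment, continuing the snippets above and assuming the returned container is the name-to-tensor map described for parameter_map_:

    #include <iostream>

    // Inside a function: list the names of this module's parameter tensors.
    for ( const auto& [name, tensor] : cuda_module->getParameterTensors() ) {
        std::cout << name << '\n';   // tensor is a std::shared_ptr<Tensor<TOutput, MR>>
    }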

const ComputePrecision::Policy & getPrecision() const

inline

Get the compute precision policy for this module.

const auto & getStateTensors() const

inline
Get the state tensors of this module.
State tensors represent non-trainable tensors that may be updated during forward/backward passes (e.g., running mean and variance in batch normalization).
Template Parameters
    TMR    Memory resource type (defaults to the module's MR type).
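Another hedged fragment, continuing the snippets above: looking up a state tensor by name. The key "running_mean" is purely hypothetical; real names depend on the concrete module (for example, batch normalization state).

    // Inside a function: assumes getStateTensors() returns the map described for state_map_.
    const auto& state = cuda_module->getStateTensors();
    if ( auto it = state.find( "running_mean" ); it != state.end() ) {
        auto running_mean = it->second;   // std::shared_ptr<Tensor<TOutput, MR>>
    }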

bool isTraining() const

inline
Check if the module is in training mode.

virtual void load(ModelArchive &archive) = 0

pure virtual
Load the module state from a zip archive.
Deserializes the module's parameters and state from the provided zip archive. This enables loading pre-trained models for inference or continued training.
Parameters
    archive    The zip archive to load the state from.
Implemented in Mila::Dnn::Gelu< TDeviceType, TDataType >, Mila::Dnn::MLP< TDeviceType, TDataType >, Mila::Dnn::TransformerBlock< TDeviceType, TDataType >, Mila::Dnn::CompositeModule< TDeviceType, TDataType >, Mila::Dnn::CompositeModule< DeviceType::Cuda, float >, Mila::Dnn::Encoder< TDeviceType, TInput, TOutput >, Mila::Dnn::Linear< TDeviceType, TInput, TOutput >, Mila::Dnn::MultiHeadAttention< TDeviceType, TInput, TOutput >, Mila::Dnn::Residual< TDeviceType, TInput, TOutput >, Mila::Dnn::Softmax< TDeviceType, TInput, TOutput >, Mila::Dnn::CrossEntropy< TDeviceType, TLogits, TTargets >, Mila::Dnn::LayerNorm< TDeviceType, TInput, TOutput >, and Mila::Dnn::Dropout< TDeviceType, TInput, TOutput >.
virtual size_t parameterCount() const = 0

pure virtual
Get the number of trainable parameters in the module.
This should count only the parameters in this specific module.
Implemented in Mila::Dnn::Gelu< TDeviceType, TDataType >, Mila::Dnn::MLP< TDeviceType, TDataType >, Mila::Dnn::TransformerBlock< TDeviceType, TDataType >, Mila::Dnn::CompositeModule< TDeviceType, TDataType >, Mila::Dnn::CompositeModule< DeviceType::Cuda, float >, Mila::Dnn::Encoder< TDeviceType, TInput, TOutput >, Mila::Dnn::Linear< TDeviceType, TInput, TOutput >, Mila::Dnn::MultiHeadAttention< TDeviceType, TInput, TOutput >, Mila::Dnn::Residual< TDeviceType, TInput, TOutput >, Mila::Dnn::Softmax< TDeviceType, TInput, TOutput >, Mila::Dnn::CrossEntropy< TDeviceType, TLogits, TTargets >, Mila::Dnn::LayerNorm< TDeviceType, TInput, TOutput >, and Mila::Dnn::Dropout< TDeviceType, TInput, TOutput >.
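An illustrative helper, continuing the snippets above, that totals parameters over several modules held through the base-class interface; the container type and template arguments are assumptions.

    #include <memory>
    #include <vector>

    using ModulePtr = std::shared_ptr<Module<Compute::DeviceType::Cuda, float, float>>;

    // Total trainable parameters across a set of modules; each counts only its own.
    size_t totalParameters( const std::vector<ModulePtr>& layers ) {
        size_t total = 0;
        for ( const auto& layer : layers )
            total += layer->parameterCount();
        return total;
    }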

const std::string parametersToString() const

inline, protected
Helper method to convert parameters to string representation.

virtual void save(ModelArchive &archive) const = 0

pure virtual
Save the module state to a zip archive.
Serializes the module's parameters and state to the provided zip archive. This enables model persistence for later reuse.
Parameters
    archive    The archive to save the state to.
Implemented in Mila::Dnn::MLP< TDeviceType, TDataType >, Mila::Dnn::TransformerBlock< TDeviceType, TDataType >, Mila::Dnn::CompositeModule< TDeviceType, TDataType >, Mila::Dnn::CompositeModule< DeviceType::Cuda, float >, Mila::Dnn::Encoder< TDeviceType, TInput, TOutput >, Mila::Dnn::CrossEntropy< TDeviceType, TLogits, TTargets >, Mila::Dnn::LayerNorm< TDeviceType, TInput, TOutput >, Mila::Dnn::Dropout< TDeviceType, TInput, TOutput >, Mila::Dnn::Gelu< TDeviceType, TDataType >, Mila::Dnn::Linear< TDeviceType, TInput, TOutput >, Mila::Dnn::MultiHeadAttention< TDeviceType, TInput, TOutput >, Mila::Dnn::Residual< TDeviceType, TInput, TOutput >, and Mila::Dnn::Softmax< TDeviceType, TInput, TOutput >.
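A hedged persistence sketch: how a ModelArchive is created or opened is not covered by this class, so these helpers simply forward an existing archive to the virtual save()/load() hooks; the helper names and placement are illustrative, continuing the snippets above.

    void checkpoint( const Module<Compute::DeviceType::Cuda, float, float>& m, ModelArchive& archive ) {
        m.save( archive );   // serializes the module's parameters and state into the archive
    }

    void restore( Module<Compute::DeviceType::Cuda, float, float>& m, ModelArchive& archive ) {
        m.load( archive );   // restores the module's parameters and state from the archive
    }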
virtual void setTraining(bool is_training)

inline, virtual
Set the training mode of this module.
Parameters
    is_training    True if the module is in training mode, false for inference.
Reimplemented in Mila::Dnn::CompositeModule< TDeviceType, TDataType >, and Mila::Dnn::CompositeModule< DeviceType::Cuda, float >.
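A small fragment, continuing the earlier snippets, that toggles training mode around an update step:

    #include <cassert>

    // Inside a training loop:
    cuda_module->setTraining( true );
    assert( cuda_module->isTraining() );

    // ... forward pass, loss, backward pass, parameter update ...

    cuda_module->setTraining( false );   // back to inference behaviour (relevant for modules such as Dropout)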

const std::string stateToString() const

inline, protected
Helper method to convert state tensors to string representation.

virtual std::string toString() const = 0

pure virtual
Convert the module to a string representation.
This should include relevant information about the module structure, parameters, and configuration for debugging and logging purposes.
Implemented in Mila::Dnn::Gelu< TDeviceType, TDataType >, Mila::Dnn::MLP< TDeviceType, TDataType >, Mila::Dnn::TransformerBlock< TDeviceType, TDataType >, Mila::Dnn::CompositeModule< TDeviceType, TDataType >, Mila::Dnn::CompositeModule< DeviceType::Cuda, float >, Mila::Dnn::Encoder< TDeviceType, TInput, TOutput >, Mila::Dnn::Linear< TDeviceType, TInput, TOutput >, Mila::Dnn::MultiHeadAttention< TDeviceType, TInput, TOutput >, Mila::Dnn::Residual< TDeviceType, TInput, TOutput >, Mila::Dnn::Softmax< TDeviceType, TInput, TOutput >, Mila::Dnn::CrossEntropy< TDeviceType, TLogits, TTargets >, Mila::Dnn::LayerNorm< TDeviceType, TInput, TOutput >, and Mila::Dnn::Dropout< TDeviceType, TInput, TOutput >.
std::ostream & operator<<(std::ostream &os, const Module &module)

friend
Overload the << operator to print the module information.
Parameters
    os        Output stream.
    module    Module to print.
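Printing a module, continuing the earlier snippets: the first line calls toString() directly, the second goes through the friend operator<< overload.

    #include <iostream>

    // Inside a function:
    std::cout << cuda_module->toString() << '\n';
    std::cout << *cuda_module << '\n';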
Member Data Documentation

ComponentConfig config_

private

The component configuration for this module.
std::shared_ptr<Compute::DeviceContext> device_context_

private
The device context used for this module's computations.
std::unordered_map<std::string, std::shared_ptr<Tensor<TOutput, MR>>> parameter_map_ = {}

protected
Map of parameter names to parameter tensors.
std::unordered_map<std::string, std::shared_ptr<Tensor<TOutput, MR>>> state_map_ = {}

protected
Map of state names to state tensors.
bool training_mode_ { false }

private
Whether the module is in training mode.
Defaults to false.