Mila
Deep Neural Network Library
Abstract base class for all modules in the Mila DNN framework.
Public Types
using | MR = std::conditional_t< TDeviceType==DeviceType::Cuda, CudaMemoryResource, CpuMemoryResource > | Memory resource type selected from the device type. |
Public Member Functions
explicit | Module (const std::string &device_name, const ComponentConfig &config) | Constructor with device name. |
explicit | Module (std::shared_ptr< DeviceContext > context, const ComponentConfig &config) | Constructor with a specific device context. |
virtual | ~Module ()=default | Virtual destructor for proper cleanup in derived classes. |
std::shared_ptr< Compute::DeviceContext > | getDeviceContext () const | Get the device context for this module. |
Compute::DeviceType | getDeviceType () const | Get the device type of the current device context. |
std::string | getName () const | Get the name of the module. |
const auto & | getParameterTensors () const | Get the parameter tensors of this module. |
const ComputePrecision::Policy & | getPrecision () const | Get the compute precision policy of this module. |
const auto & | getStateTensors () const | Get the state tensors of this module. |
bool | isTraining () const | Check if the module is in training mode. |
virtual void | load (ModelArchive &archive)=0 | Load the module state from a zip archive. |
virtual size_t | parameterCount () const =0 | Get the number of trainable parameters in the module. |
virtual void | save (ModelArchive &archive) const =0 | Save the module state to a zip archive. |
virtual void | setTraining (bool is_training) | Set the training mode of this module. |
virtual std::string | toString () const =0 | Convert the module to a string representation. |
Protected Member Functions
const std::string | parametersToString () const | Helper method to convert parameters to string representation. |
const std::string | stateToString () const | Helper method to convert state tensors to string representation. |
Protected Attributes
std::unordered_map< std::string, std::shared_ptr< Tensor< TOutput, MR > > > | parameter_map_ = {} | Map of parameter names to parameter tensors. |
std::unordered_map< std::string, std::shared_ptr< Tensor< TOutput, MR > > > | state_map_ = {} | Map of state names to state tensors. |
Static Private Member Functions
static std::shared_ptr< Compute::DeviceContext > | createContext (const std::string &device_name) | Helper method to create a DeviceContext from a device name. |
Private Attributes
ComponentConfig | config_ | Component configuration supplied to the constructor. |
std::shared_ptr< Compute::DeviceContext > | device_context_ | The device context used for this module's computations. |
bool | training_mode_ { false } | Whether the module is in training mode. |
Friends
std::ostream & | operator<< (std::ostream &os, const Module &module) | Overload the << operator to print the module information. |
Abstract base class for all modules in the Mila DNN framework.
The Module class provides a common interface for all neural network layers and components, enabling consistent handling of parameters, state, and device context. For container functionality that supports child modules, use the Block class.
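using Mila::Dnn::Module< TDeviceType, TInput, TOutput >::MR = std::conditional_t<TDeviceType == DeviceType::Cuda, CudaMemoryResource, CpuMemoryResource>
Memory resource type for this module's tensors: CudaMemoryResource when TDeviceType is DeviceType::Cuda, otherwise CpuMemoryResource.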
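The `MR` alias resolves the memory resource at compile time from the device template argument. The sketch below reproduces the same `std::conditional_t` selection in isolation; `DeviceType`, `CpuMemoryResource`, and `CudaMemoryResource` are empty stand-ins here, not Mila's definitions.

```cpp
#include <type_traits>

// Stand-ins for the Mila types (assumption: only their identity matters here).
enum class DeviceType { Cpu, Cuda };
struct CpuMemoryResource {};
struct CudaMemoryResource {};

// Same selection rule as Module::MR: CUDA modules get CudaMemoryResource,
// every other device type falls back to CpuMemoryResource.
template <DeviceType TDeviceType>
using MR = std::conditional_t<TDeviceType == DeviceType::Cuda,
                              CudaMemoryResource, CpuMemoryResource>;

static_assert(std::is_same_v<MR<DeviceType::Cuda>, CudaMemoryResource>);
static_assert(std::is_same_v<MR<DeviceType::Cpu>, CpuMemoryResource>);
```

Because the choice is made in the type system, a CPU module and a CUDA module of otherwise identical configuration hold differently typed `Tensor< TOutput, MR >` objects in their parameter and state maps.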
Module (const std::string &device_name, const ComponentConfig &config) [inline, explicit]
Constructor with device name.
Creates a module with a device context for the specified device name. This allows modules to be created with a simple string identifier rather than requiring manual construction of a DeviceContext.
Parameters:
device_name | The name of the device to use (e.g., "CPU", "CUDA:0"). Must be one of the names returned by DeviceRegistry::list_devices(). |
config | The component configuration for this module, including its compute precision policy (default is Auto). |
Exceptions:
std::runtime_error | If the specified device name is invalid or doesn't match TDeviceType. |
Module (std::shared_ptr< DeviceContext > context, const ComponentConfig &config) [inline, explicit]
Constructor with a specific device context.
Parameters:
context | The device context to use for this module. |
config | The component configuration for this module, including its compute precision policy (default is Auto). |
Exceptions:
std::invalid_argument | If the provided context is nullptr. |
std::runtime_error | If the context device type doesn't match TDeviceType. |
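The two checks documented for this constructor can be sketched as a validation step that runs before the context is stored. `ModuleSketch`, its one-argument constructor, and `DeviceContext::getDeviceType()` are hypothetical simplifications; the real constructor also takes a `ComponentConfig`.

```cpp
#include <memory>
#include <stdexcept>
#include <utility>

// Hypothetical stand-ins: Mila's real DeviceType and DeviceContext are richer.
enum class DeviceType { Cpu, Cuda };

struct DeviceContext {
    DeviceType type;
    DeviceType getDeviceType() const { return type; }
};

template <DeviceType TDeviceType>
class ModuleSketch {
public:
    explicit ModuleSketch(std::shared_ptr<DeviceContext> context)
        : device_context_(validate(std::move(context))) {}

    std::shared_ptr<DeviceContext> getDeviceContext() const { return device_context_; }

private:
    // Documented checks: a null context is an invalid argument, and a context
    // for a different device type is a runtime error.
    static std::shared_ptr<DeviceContext> validate(std::shared_ptr<DeviceContext> context) {
        if (!context) {
            throw std::invalid_argument("Device context must not be null.");
        }
        if (context->getDeviceType() != TDeviceType) {
            throw std::runtime_error("Device context type does not match TDeviceType.");
        }
        return context;
    }

    std::shared_ptr<DeviceContext> device_context_;
};
```

Validating in the member initializer means a module object never exists with an unusable context.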
~Module () [virtual, default]
Virtual destructor for proper cleanup in derived classes.
static std::shared_ptr< Compute::DeviceContext > createContext (const std::string &device_name) [inline, static, private]
Helper method to create a DeviceContext from a device name.
Parameters:
device_name | Name of the device to create a context for. |
Exceptions:
std::runtime_error | If the device name is invalid. |
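A hedged sketch of the name parsing a helper like `createContext` needs, using the two name forms quoted in the constructor documentation ("CPU", "CUDA:0"). `ParsedDevice` and `parseDeviceName` are hypothetical; the sketch does not consult `DeviceRegistry::list_devices()` and does not build a real `DeviceContext`.

```cpp
#include <stdexcept>
#include <string>

// Hypothetical stand-ins (not the Mila API): a device type plus an ordinal.
enum class DeviceType { Cpu, Cuda };

struct ParsedDevice {
    DeviceType type;
    int index;  // -1 for the CPU, otherwise the CUDA device ordinal
};

// Accepts "CPU" or "CUDA:<n>" and throws std::runtime_error otherwise.
// The documented constructor also requires the name to be registered in
// DeviceRegistry; that registry lookup is not modeled here.
ParsedDevice parseDeviceName(const std::string& name) {
    if (name == "CPU") {
        return {DeviceType::Cpu, -1};
    }
    const std::string prefix = "CUDA:";
    if (name.rfind(prefix, 0) == 0 && name.size() > prefix.size()) {
        const std::string digits = name.substr(prefix.size());
        if (digits.find_first_not_of("0123456789") == std::string::npos) {
            try {
                return {DeviceType::Cuda, std::stoi(digits)};
            } catch (const std::out_of_range&) {
                // Ordinal too large for int: report it as an invalid name below.
            }
        }
    }
    throw std::runtime_error("Invalid device name: " + name);
}
```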
std::shared_ptr< Compute::DeviceContext > getDeviceContext () const [inline]
Get the device context for this module.
Returns a shared pointer to the device context that this module is currently using.
Compute::DeviceType getDeviceType () const [inline]
Get the device type of the current device context.
std::string getName () const [inline]
Get the name of the module.
const auto & getParameterTensors () const [inline]
Get the parameter tensors of this module.
Parameter tensors represent learnable weights that are updated during training via gradient descent or other optimization algorithms.
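To make the role of parameter tensors concrete, here is a minimal plain SGD step (w = w - lr * dL/dw) applied to every entry of a name-to-tensor map shaped like `parameter_map_`. Tensors are modeled as flat `std::vector<float>` buffers and the `grads` map is a hypothetical input; Mila's own optimizer interface is not shown in this reference.

```cpp
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

// Stand-in for Tensor<TOutput, MR>: a flat buffer of values (assumption).
using Tensor = std::vector<float>;
using ParameterMap = std::unordered_map<std::string, std::shared_ptr<Tensor>>;

// Plain SGD: w = w - lr * dL/dw for every named parameter.
// `grads` is a hypothetical map of gradients keyed by the same names.
void sgdStep(ParameterMap& params, const ParameterMap& grads, float lr) {
    for (auto& [name, weight] : params) {
        const auto it = grads.find(name);
        if (it == grads.end()) {
            throw std::runtime_error("Missing gradient for parameter: " + name);
        }
        const Tensor& grad = *it->second;
        if (grad.size() != weight->size()) {
            throw std::runtime_error("Gradient size mismatch for parameter: " + name);
        }
        for (std::size_t i = 0; i < weight->size(); ++i) {
            (*weight)[i] -= lr * grad[i];
        }
    }
}
```

Because the map holds shared pointers, the update writes through to the same tensors the module uses in its forward pass.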
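const ComputePrecision::Policy & getPrecision () const [inline]
Get the compute precision policy of this module.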
const auto & getStateTensors () const [inline]
Get the state tensors of this module.
State tensors represent non-trainable tensors that may be updated during forward/backward passes (e.g., running mean and variance in batch normalization).
Template Parameters:
TMR | Memory resource type (defaults to the module's MR type). |
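State tensors such as batch-normalization running statistics are commonly maintained as an exponential moving average. The sketch assumes the PyTorch-style update `running = (1 - momentum) * running + momentum * batch_stat` with a default momentum of 0.1; Mila's convention may differ. It also ties the update to the training flag, since running statistics are normally frozen during inference.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Exponential moving average for running statistics, using the
// PyTorch-style convention (an assumption, not necessarily Mila's):
//   running = (1 - momentum) * running + momentum * batch_stat
// Statistics change only in training mode; inference leaves them frozen.
void updateRunningStat(std::vector<float>& running,
                       const std::vector<float>& batch_stat,
                       bool is_training,
                       float momentum = 0.1f) {
    if (!is_training) {
        return;
    }
    if (running.size() != batch_stat.size()) {
        throw std::invalid_argument("Running and batch statistics differ in size.");
    }
    for (std::size_t c = 0; c < running.size(); ++c) {
        running[c] = (1.0f - momentum) * running[c] + momentum * batch_stat[c];
    }
}
```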
bool isTraining () const [inline]
Check if the module is in training mode.
virtual void load (ModelArchive &archive) [pure virtual]
Load the module state from a zip archive.
Deserializes the module's parameters and state from the provided zip archive. This enables loading pre-trained models for inference or continued training.
Parameters:
archive | The archive to load the state from. |
Implemented in Mila::Dnn::Gelu< TDeviceType, TDataType >, Mila::Dnn::MLP< TDeviceType, TDataType >, Mila::Dnn::TransformerBlock< TDeviceType, TDataType >, Mila::Dnn::CompositeModule< TDeviceType, TDataType >, Mila::Dnn::CompositeModule< DeviceType::Cuda, float >, Mila::Dnn::Encoder< TDeviceType, TInput, TOutput >, Mila::Dnn::Linear< TDeviceType, TInput, TOutput >, Mila::Dnn::MultiHeadAttention< TDeviceType, TInput, TOutput >, Mila::Dnn::Residual< TDeviceType, TInput, TOutput >, Mila::Dnn::Softmax< TDeviceType, TInput, TOutput >, Mila::Dnn::CrossEntropy< TDeviceType, TLogits, TTargets >, Mila::Dnn::LayerNorm< TDeviceType, TInput, TOutput >, and Mila::Dnn::Dropout< TDeviceType, TInput, TOutput >.
virtual size_t parameterCount () const [pure virtual]
Get the number of trainable parameters in the module.
Implementations should count only the parameters held by this module.
Implemented in Mila::Dnn::Gelu< TDeviceType, TDataType >, Mila::Dnn::MLP< TDeviceType, TDataType >, Mila::Dnn::TransformerBlock< TDeviceType, TDataType >, Mila::Dnn::CompositeModule< TDeviceType, TDataType >, Mila::Dnn::CompositeModule< DeviceType::Cuda, float >, Mila::Dnn::Encoder< TDeviceType, TInput, TOutput >, Mila::Dnn::Linear< TDeviceType, TInput, TOutput >, Mila::Dnn::MultiHeadAttention< TDeviceType, TInput, TOutput >, Mila::Dnn::Residual< TDeviceType, TInput, TOutput >, Mila::Dnn::Softmax< TDeviceType, TInput, TOutput >, Mila::Dnn::CrossEntropy< TDeviceType, TLogits, TTargets >, Mila::Dnn::LayerNorm< TDeviceType, TInput, TOutput >, and Mila::Dnn::Dropout< TDeviceType, TInput, TOutput >.
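A sketch of what a `parameterCount()` implementation might compute: the sum, over this module's own parameter tensors, of the product of each tensor's dimensions. `TensorStub` is a hypothetical shape-only stand-in for `Tensor`.

```cpp
#include <cstddef>
#include <functional>
#include <memory>
#include <numeric>
#include <string>
#include <unordered_map>
#include <vector>

// Stand-in tensor that only records its shape (not Mila's Tensor).
struct TensorStub {
    std::vector<std::size_t> shape;

    // Number of elements: product of the dimensions (1 for a scalar).
    std::size_t size() const {
        return std::accumulate(shape.begin(), shape.end(), std::size_t{1},
                               std::multiplies<std::size_t>());
    }
};

using ParameterMap = std::unordered_map<std::string, std::shared_ptr<TensorStub>>;

// Sums the element counts of this module's own parameter tensors.
// Child modules, if any, are deliberately not visited.
std::size_t parameterCount(const ParameterMap& parameter_map) {
    std::size_t count = 0;
    for (const auto& entry : parameter_map) {
        count += entry.second->size();
    }
    return count;
}
```

For example, a hypothetical linear layer with a 4 x 3 weight and a 4-element bias would report 12 + 4 = 16 parameters.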
const std::string parametersToString () const [inline, protected]
Helper method to convert parameters to string representation.
virtual void save (ModelArchive &archive) const [pure virtual]
Save the module state to a zip archive.
Serializes the module's parameters and state to the provided zip archive. This enables model persistence for later reuse.
Parameters:
archive | The archive to save the state to. |
Implemented in Mila::Dnn::MLP< TDeviceType, TDataType >, Mila::Dnn::TransformerBlock< TDeviceType, TDataType >, Mila::Dnn::CompositeModule< TDeviceType, TDataType >, Mila::Dnn::CompositeModule< DeviceType::Cuda, float >, Mila::Dnn::Encoder< TDeviceType, TInput, TOutput >, Mila::Dnn::CrossEntropy< TDeviceType, TLogits, TTargets >, Mila::Dnn::LayerNorm< TDeviceType, TInput, TOutput >, Mila::Dnn::Dropout< TDeviceType, TInput, TOutput >, Mila::Dnn::Gelu< TDeviceType, TDataType >, Mila::Dnn::Linear< TDeviceType, TInput, TOutput >, Mila::Dnn::MultiHeadAttention< TDeviceType, TInput, TOutput >, Mila::Dnn::Residual< TDeviceType, TInput, TOutput >, and Mila::Dnn::Softmax< TDeviceType, TInput, TOutput >.
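A minimal round-trip sketch of `save()` and `load()`, assuming an in-memory `std::map` in place of the zip-backed `ModelArchive` and a hypothetical `"<module>/<tensor>"` key scheme. A real implementation would persist both `parameter_map_` and `state_map_`; the functions below accept either map.

```cpp
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical in-memory stand-in for ModelArchive (the real one is zip-backed).
using ArchiveStub = std::map<std::string, std::vector<float>>;
using Tensor = std::vector<float>;
using TensorMap = std::unordered_map<std::string, std::shared_ptr<Tensor>>;

// Writes every tensor under "<module>/<tensor>" (the key scheme is an assumption).
void saveTensors(ArchiveStub& archive, const std::string& module_name, const TensorMap& tensors) {
    for (const auto& [name, tensor] : tensors) {
        archive[module_name + "/" + name] = *tensor;
    }
}

// Reads values back into existing, already-sized tensors and checks sizes,
// so a checkpoint from a differently configured module is rejected.
void loadTensors(const ArchiveStub& archive, const std::string& module_name, TensorMap& tensors) {
    for (auto& [name, tensor] : tensors) {
        const auto it = archive.find(module_name + "/" + name);
        if (it == archive.end()) {
            throw std::runtime_error("Archive is missing tensor: " + name);
        }
        if (it->second.size() != tensor->size()) {
            throw std::runtime_error("Size mismatch for tensor: " + name);
        }
        *tensor = it->second;
    }
}
```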
virtual void setTraining (bool is_training) [inline, virtual]
Set the training mode of this module.
Parameters:
is_training | True to put the module in training mode, false for inference mode. |
Reimplemented in Mila::Dnn::CompositeModule< TDeviceType, TDataType >, and Mila::Dnn::CompositeModule< DeviceType::Cuda, float >.
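`CompositeModule` overrides `setTraining`, presumably so that switching a container between training and inference also switches its children. The sketch below assumes exactly that behavior; `ModuleBase` and `CompositeSketch` are hypothetical simplifications, not Mila classes.

```cpp
#include <memory>
#include <utility>
#include <vector>

// Simplified stand-in for the Module training flag (assumption).
class ModuleBase {
public:
    virtual ~ModuleBase() = default;
    virtual void setTraining(bool is_training) { training_mode_ = is_training; }
    bool isTraining() const { return training_mode_; }

private:
    bool training_mode_{false};  // inference mode by default, as in Module
};

// Hypothetical composite: updates its own flag, then forwards to every child.
class CompositeSketch : public ModuleBase {
public:
    void addChild(std::shared_ptr<ModuleBase> child) { children_.push_back(std::move(child)); }

    void setTraining(bool is_training) override {
        ModuleBase::setTraining(is_training);
        for (auto& child : children_) {
            child->setTraining(is_training);
        }
    }

private:
    std::vector<std::shared_ptr<ModuleBase>> children_;
};
```

Because `setTraining` is virtual, nested composites propagate the flag recursively without extra code.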
const std::string stateToString () const [inline, protected]
Helper method to convert state tensors to string representation.
virtual std::string toString () const [pure virtual]
Convert the module to a string representation.
This should include relevant information about the module structure, parameters, and configuration for debugging and logging purposes.
Implemented in Mila::Dnn::Gelu< TDeviceType, TDataType >, Mila::Dnn::MLP< TDeviceType, TDataType >, Mila::Dnn::TransformerBlock< TDeviceType, TDataType >, Mila::Dnn::CompositeModule< TDeviceType, TDataType >, Mila::Dnn::CompositeModule< DeviceType::Cuda, float >, Mila::Dnn::Encoder< TDeviceType, TInput, TOutput >, Mila::Dnn::Linear< TDeviceType, TInput, TOutput >, Mila::Dnn::MultiHeadAttention< TDeviceType, TInput, TOutput >, Mila::Dnn::Residual< TDeviceType, TInput, TOutput >, Mila::Dnn::Softmax< TDeviceType, TInput, TOutput >, Mila::Dnn::CrossEntropy< TDeviceType, TLogits, TTargets >, Mila::Dnn::LayerNorm< TDeviceType, TInput, TOutput >, and Mila::Dnn::Dropout< TDeviceType, TInput, TOutput >.
std::ostream & operator<< (std::ostream &os, const Module &module) [friend]
Overload the << operator to print the module information.
Parameters:
os | Output stream. |
module | Module to print. |
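Since `toString()` is pure virtual, a natural way to implement the friend `operator<<` is to forward to it, so every derived module prints through the same stream syntax. This is an assumption about the implementation; `ModuleBase` and `ScaleModule` are hypothetical stand-ins.

```cpp
#include <ostream>
#include <sstream>
#include <string>

// Minimal stand-in for the Module printing interface (assumption).
class ModuleBase {
public:
    virtual ~ModuleBase() = default;
    virtual std::string toString() const = 0;

    // Friend operator that forwards to the virtual toString(), so any
    // derived module can be written to a stream with the same syntax.
    friend std::ostream& operator<<(std::ostream& os, const ModuleBase& module) {
        return os << module.toString();
    }
};

// Hypothetical derived module used only to exercise the operator.
class ScaleModule : public ModuleBase {
public:
    explicit ScaleModule(float scale) : scale_(scale) {}

    std::string toString() const override {
        std::ostringstream oss;
        oss << "ScaleModule(scale=" << scale_ << ")";
        return oss.str();
    }

private:
    float scale_;
};
```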
ComponentConfig config_ [private]
std::shared_ptr< Compute::DeviceContext > device_context_ [private]
The device context used for this module's computations.
std::unordered_map< std::string, std::shared_ptr< Tensor< TOutput, MR > > > parameter_map_ = {} [protected]
Map of parameter names to parameter tensors.
std::unordered_map< std::string, std::shared_ptr< Tensor< TOutput, MR > > > state_map_ = {} [protected]
Map of state names to state tensors.
bool training_mode_ { false } [private]
Whether the module is in training mode.
Default is false.