Mila
Deep Neural Network Library
Mila::Dnn::Module< TDeviceType, TInput, TOutput > Class Template Reference  [abstract] [export]
module Dnn.Module

Abstract base class for all modules in the Mila DNN framework. More...

Inheritance diagram for Mila::Dnn::Module< TDeviceType, TInput, TOutput >:
Collaboration diagram for Mila::Dnn::Module< TDeviceType, TInput, TOutput >:

Public Types

using MR = std::conditional_t< TDeviceType==DeviceType::Cuda, CudaMemoryResource, CpuMemoryResource >
 

Public Member Functions

 Module (const std::string &device_name, const ComponentConfig &config)
 Constructor with device name.
 
 Module (std::shared_ptr< DeviceContext > context, const ComponentConfig &config)
 Constructor with a specific device context.
 
virtual ~Module ()=default
 Virtual destructor for proper cleanup in derived classes.
 
std::shared_ptr< Compute::DeviceContext > getDeviceContext () const
 Get the device context for this module.
 
Compute::DeviceType getDeviceType () const
 Get the device type of the current device context.
 
std::string getName () const
 Get the name of the module.
 
const auto & getParameterTensors () const
 Get the parameter tensors of this module.
 
const ComputePrecision::Policy & getPrecision () const
 Get the compute precision policy for this module.
 
const auto & getStateTensors () const
 Get the state tensors of this module.
 
bool isTraining () const
 Check if the module is in training mode.
 
virtual void load (ModelArchive &archive)=0
 Load the module state from a zip archive.
 
virtual size_t parameterCount () const =0
 Get the number of trainable parameters in the module.
 
virtual void save (ModelArchive &archive) const =0
 Save the module state to a zip archive.
 
virtual void setTraining (bool is_training)
 Set the training mode of this module.
 
virtual std::string toString () const =0
 Convert the module to a string representation.
 

Protected Member Functions

const std::string parametersToString () const
 Helper method to convert parameters to string representation.
 
const std::string stateToString () const
 Helper method to convert state tensors to string representation.
 

Protected Attributes

std::unordered_map< std::string, std::shared_ptr< Tensor< TOutput, MR > > > parameter_map_ = {}
 Map of parameter names to parameter tensors.
 
std::unordered_map< std::string, std::shared_ptr< Tensor< TOutput, MR > > > state_map_ = {}
 Map of state names to state tensors.
 

Static Private Member Functions

static std::shared_ptr< Compute::DeviceContext > createContext (const std::string &device_name)
 Helper method to create a DeviceContext from a device name.
 

Private Attributes

ComponentConfig config_
 
std::shared_ptr< Compute::DeviceContext > device_context_
 The device context used for this module's computations.
 
bool training_mode_ { false }
 Whether the module is in training mode.
 

Friends

std::ostream & operator<< (std::ostream &os, const Module &module)
 Overload the << operator to print the module information.
 

Detailed Description

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
requires ValidTensorType<TInput> && ValidFloatTensorType<TOutput>
class Mila::Dnn::Module< TDeviceType, TInput, TOutput >

Abstract base class for all modules in the Mila DNN framework.

The Module class provides a common interface for all neural network layers and components, enabling consistent handling of parameters, state, and device context. For container functionality that supports child modules, use the Block class.

Template Parameters
TDeviceType  The device type (CPU or CUDA) on which the module will operate.
TInput  Data type of the input tensor elements.
TOutput  Data type of the output tensor elements, defaults to TInput.

Member Typedef Documentation

◆ MR

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
using Mila::Dnn::Module< TDeviceType, TInput, TOutput >::MR = std::conditional_t<TDeviceType == DeviceType::Cuda, CudaMemoryResource, CpuMemoryResource>

Constructor & Destructor Documentation

◆ Module() [1/2]

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
Mila::Dnn::Module< TDeviceType, TInput, TOutput >::Module ( const std::string &  device_name,
const ComponentConfig &  config )
inline explicit

Constructor with device name.

Creates a module with a device context for the specified device name. This allows modules to be created with a simple string identifier rather than requiring manual construction of a DeviceContext.

Parameters
device_name  The name of the device to use (e.g., "CPU", "CUDA:0"). Must be one of the names returned by DeviceRegistry::list_devices().
config  The component configuration for this module.
Exceptions
std::runtime_error  If the specified device name is invalid or doesn't match TDeviceType.

◆ Module() [2/2]

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
Mila::Dnn::Module< TDeviceType, TInput, TOutput >::Module ( std::shared_ptr< DeviceContext >  context,
const ComponentConfig &  config )
inline explicit

Constructor with a specific device context.

Parameters
context  The device context to use for this module.
config  The component configuration for this module.
Exceptions
std::invalid_argument  If the provided context is nullptr.
std::runtime_error  If the context device type doesn't match TDeviceType.

◆ ~Module()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
virtual Mila::Dnn::Module< TDeviceType, TInput, TOutput >::~Module ( )
virtual default

Virtual destructor for proper cleanup in derived classes.

Member Function Documentation

◆ createContext()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
static std::shared_ptr< Compute::DeviceContext > Mila::Dnn::Module< TDeviceType, TInput, TOutput >::createContext ( const std::string &  device_name)
inline static private

Helper method to create a DeviceContext from a device name.

Parameters
device_name  Name of the device to create a context for.
Returns
std::shared_ptr<DeviceContext>  The created device context.
Exceptions
std::runtime_error  If the device name is invalid.

◆ getDeviceContext()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
std::shared_ptr< Compute::DeviceContext > Mila::Dnn::Module< TDeviceType, TInput, TOutput >::getDeviceContext ( ) const
inline

Get the device context for this module.

Returns a shared pointer to the device context that this module is currently using.

Returns
std::shared_ptr<Compute::DeviceContext> The device context.

◆ getDeviceType()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
Compute::DeviceType Mila::Dnn::Module< TDeviceType, TInput, TOutput >::getDeviceType ( ) const
inline

Get the device type of the current device context.

Returns
Compute::DeviceType The device type (CPU or CUDA).

◆ getName()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
std::string Mila::Dnn::Module< TDeviceType, TInput, TOutput >::getName ( ) const
inline

Get the name of the module.

Returns
std::string Name of the module.

◆ getParameterTensors()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
const auto & Mila::Dnn::Module< TDeviceType, TInput, TOutput >::getParameterTensors ( ) const
inline

Get the parameter tensors of this module.

Parameter tensors represent learnable weights that are updated during training via gradient descent or other optimization algorithms.

Returns
const std::unordered_map<std::string, std::shared_ptr<Tensor<TOutput, MR>>>&  Map of parameter names to tensor pointers.

◆ getPrecision()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
const ComputePrecision::Policy & Mila::Dnn::Module< TDeviceType, TInput, TOutput >::getPrecision ( ) const
inline

Get the compute precision policy for this module.

◆ getStateTensors()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
const auto & Mila::Dnn::Module< TDeviceType, TInput, TOutput >::getStateTensors ( ) const
inline

Get the state tensors of this module.

State tensors represent non-trainable tensors that may be updated during forward/backward passes (e.g., running mean and variance in batch normalization).

Returns
const std::unordered_map<std::string, std::shared_ptr<Tensor<TOutput, MR>>>&  Map of state names to tensor pointers.

◆ isTraining()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
bool Mila::Dnn::Module< TDeviceType, TInput, TOutput >::isTraining ( ) const
inline

Check if the module is in training mode.

Returns
true If the module is in training mode.
false If the module is in inference mode.

◆ load()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
virtual void Mila::Dnn::Module< TDeviceType, TInput, TOutput >::load ( ModelArchive archive)
pure virtual

Load the module state from a zip archive.

◆ parameterCount()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
virtual size_t Mila::Dnn::Module< TDeviceType, TInput, TOutput >::parameterCount ( ) const
pure virtual

Get the number of trainable parameters in the module.

◆ parametersToString()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
const std::string Mila::Dnn::Module< TDeviceType, TInput, TOutput >::parametersToString ( ) const
inlineprotected

Helper method to convert parameters to string representation.

Returns
std::string String representation of all parameters.

◆ save()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
virtual void Mila::Dnn::Module< TDeviceType, TInput, TOutput >::save ( ModelArchive &  archive) const
pure virtual

Save the module state to a zip archive.

◆ setTraining()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
virtual void Mila::Dnn::Module< TDeviceType, TInput, TOutput >::setTraining ( bool  is_training)
inline virtual

Set the training mode of this module.

Parameters
is_training  True if the module is in training mode, false for inference.

Reimplemented in Mila::Dnn::CompositeModule< TDeviceType, TDataType >, and Mila::Dnn::CompositeModule< DeviceType::Cuda, float >.


◆ stateToString()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
const std::string Mila::Dnn::Module< TDeviceType, TInput, TOutput >::stateToString ( ) const
inlineprotected

Helper method to convert state tensors to string representation.

Returns
std::string String representation of all state tensors.

◆ toString()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
virtual std::string Mila::Dnn::Module< TDeviceType, TInput, TOutput >::toString ( ) const
pure virtual

Convert the module to a string representation.

Friends And Related Symbol Documentation

◆ operator<<

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
std::ostream & operator<< ( std::ostream &  os,
const Module< TDeviceType, TInput, TOutput > &  module 
)
friend

Overload the << operator to print the module information.

Parameters
os  Output stream.
module  Module to print.
Returns
std::ostream& Reference to the output stream.

Member Data Documentation

◆ config_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
ComponentConfig Mila::Dnn::Module< TDeviceType, TInput, TOutput >::config_
private

Configuration for this module component.

◆ device_context_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
std::shared_ptr<Compute::DeviceContext> Mila::Dnn::Module< TDeviceType, TInput, TOutput >::device_context_
private

The device context used for this module's computations.

◆ parameter_map_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
std::unordered_map<std::string, std::shared_ptr<Tensor<TOutput, MR> > > Mila::Dnn::Module< TDeviceType, TInput, TOutput >::parameter_map_ = {}
protected

Map of parameter names to parameter tensors.

◆ state_map_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
std::unordered_map<std::string, std::shared_ptr<Tensor<TOutput, MR> > > Mila::Dnn::Module< TDeviceType, TInput, TOutput >::state_map_ = {}
protected

Map of state names to state tensors.

◆ training_mode_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
bool Mila::Dnn::Module< TDeviceType, TInput, TOutput >::training_mode_ { false }
private

Whether the module is in training mode.

Default is false


The documentation for this class was generated from the following file: