Mila
Deep Neural Network Library
Mila::Dnn::MLP< TDeviceType, TDataType > Class Template Reference
export

Multi-Layer Perceptron (MLP) block for neural networks.


Public Types

using CompositeModuleBase = CompositeModule< TDeviceType, TDataType >
 Alias for base module type.
 
using MR = std::conditional_t< TDeviceType==DeviceType::Cuda, CudaMemoryResource, CpuMemoryResource >
 Memory resource type used for tensors, selected based on device type.
 
- Public Types inherited from Mila::Dnn::CompositeModule< TDeviceType, TDataType >
using ModuleBase = Module< TDeviceType, TDataType, TDataType >
 Base class type for the module.
 
using MR = std::conditional_t< TDeviceType==DeviceType::Cuda, CudaMemoryResource, HostMemoryResource >
 Memory resource type based on device type.
 
- Public Types inherited from Mila::Dnn::Module< TDeviceType, TInput, TOutput >
using MR = std::conditional_t< TDeviceType==DeviceType::Cuda, CudaMemoryResource, CpuMemoryResource >
 

Public Member Functions

 MLP (const std::string &device_name, const MLPConfig &config)
 Constructs a new MLP module with a device name.
 
 MLP (std::shared_ptr< DeviceContext > device_context, const MLPConfig &config)
 Constructs a new MLP module with a provided device context.
 
void backward (const Tensor< TDataType, MR > &input, const Tensor< TDataType, MR > &output_grad, Tensor< TDataType, MR > &input_grad)
 Performs the backward pass of the MLP block.
 
void forward (const Tensor< TDataType, MR > &input, Tensor< TDataType, MR > &output)
 Performs the forward pass of the MLP block.
 
void load (ModelArchive &archive) override
 Deserializes the module state from a ZIP archive.
 
size_t parameterCount () const override
 Gets the number of trainable parameters in this module.
 
void save (ModelArchive &archive) const override
 Serializes the module state to a ZIP archive.
 
std::string toString () const override
 Generates a string representation of this module's configuration.
 
- Public Member Functions inherited from Mila::Dnn::CompositeModule< TDeviceType, TDataType >
 CompositeModule ()
 Default constructor.
 
 CompositeModule (const std::string &device_name, const ComponentConfig &config)
 Constructor with device name.
 
 CompositeModule (std::shared_ptr< DeviceContext > context, const ComponentConfig &config)
 Constructor with device context.
 
virtual ~CompositeModule ()=default
 Virtual destructor.
 
CompositeModule & addModule (const std::string &name, std::shared_ptr< Module< TDeviceType, TDataType, TDataType > > module)
 Add a named child module to this module.
 
CompositeModule & addModule (std::shared_ptr< Module< TDeviceType, TDataType, TDataType > > module)
 Add an unnamed child module to this module.
 
std::shared_ptr< Module< TDeviceType, TDataType, TDataType > > getModule (const std::string &name) const
 Get a specific sub-module by name.
 
const std::vector< std::shared_ptr< Module< TDeviceType, TDataType, TDataType > > > & getModules () const
 Get all sub-modules contained in this module.
 
const std::unordered_map< std::string, std::shared_ptr< Module< TDeviceType, TDataType, TDataType > > > & getNamedModules () const
 Get all named sub-modules contained in this module.
 
bool hasModule (const std::string &name) const
 Check if a sub-module with the given name exists.
 
bool removeModule (const std::string &name)
 Remove a sub-module by name.
 
bool replaceModule (const std::string &name, std::shared_ptr< Module< TDeviceType, TDataType, TDataType > > module)
 Replace an existing sub-module with a new one.
 
void setTraining (bool is_training) override
 Set the training mode for this module and all its sub-modules.
 
- Public Member Functions inherited from Mila::Dnn::Module< TDeviceType, TInput, TOutput >
 Module (const std::string &device_name, const ComponentConfig &config)
 Constructor with device name.
 
 Module (std::shared_ptr< DeviceContext > context, const ComponentConfig &config)
 Constructor with a specific device context.
 
virtual ~Module ()=default
 Virtual destructor for proper cleanup in derived classes.
 
std::shared_ptr< Compute::DeviceContext > getDeviceContext () const
 Get the device context for this module.
 
Compute::DeviceType getDeviceType () const
 Get the device type of the current device context.
 
std::string getName () const
 Get the name of the module.
 
const auto & getParameterTensors () const
 Get the parameter tensors of this module.
 
const ComputePrecision::Policy & getPrecision () const
 
const auto & getStateTensors () const
 Get the state tensors of this module.
 
bool isTraining () const
 Check if the module is in training mode.
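The composite members listed above (addModule, getModule, hasModule, removeModule, setTraining, isTraining) follow the classic composite pattern. The following is a self-contained sketch of that pattern, not Mila's implementation. `ToyModule` and `ToyComposite` are hypothetical stand-ins that track only the training flag. That is enough to show a named child registry and training-mode propagation.

```cpp
#include <algorithm>
#include <cassert>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for Module: only tracks the training flag.
class ToyModule {
public:
    virtual ~ToyModule() = default;
    virtual void setTraining(bool is_training) { training_ = is_training; }
    bool isTraining() const { return training_; }
private:
    bool training_ = false;
};

// Hypothetical stand-in for CompositeModule: children kept in insertion
// order, plus a name -> child map for lookups.
class ToyComposite : public ToyModule {
public:
    ToyComposite& addModule(const std::string& name, std::shared_ptr<ToyModule> module) {
        modules_.push_back(module);
        named_[name] = module;
        return *this;  // returning *this allows chained addModule calls
    }

    std::shared_ptr<ToyModule> getModule(const std::string& name) const {
        auto it = named_.find(name);
        if (it == named_.end()) return nullptr;
        return it->second;
    }

    bool hasModule(const std::string& name) const { return named_.count(name) != 0; }

    bool removeModule(const std::string& name) {
        auto it = named_.find(name);
        if (it == named_.end()) return false;
        // Erase-remove idiom: drop the child from the ordered list as well.
        modules_.erase(std::remove(modules_.begin(), modules_.end(), it->second),
                       modules_.end());
        named_.erase(it);
        return true;
    }

    // Set the flag on this module, then recurse into every child.
    void setTraining(bool is_training) override {
        ToyModule::setTraining(is_training);
        for (auto& m : modules_) m->setTraining(is_training);
    }

private:
    std::vector<std::shared_ptr<ToyModule>> modules_;
    std::unordered_map<std::string, std::shared_ptr<ToyModule>> named_;
};
```

In this sketch, a child removed with removeModule no longer receives setTraining updates from its former parent.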
 

Private Member Functions

void initializeModules ()
 Initializes all submodules for the MLP.
 

Private Attributes

Tensor< TDataType, MR > act_output_
 Output tensor from activation function.
 
std::shared_ptr< Module< TDeviceType, TDataType > > activation_ { nullptr }
 Activation function module.
 
MLPConfig config_
 Configuration for the MLP module.
 
std::shared_ptr< Dropout< TDeviceType, TDataType > > dropout1_ { nullptr }
 Optional dropout module.
 
Tensor< TDataType, MR > dropout1_output_
 Output tensor from dropout.
 
std::shared_ptr< Linear< TDeviceType, TDataType > > fc1_ { nullptr }
 First linear layer (input_features -> hidden_size).
 
Tensor< TDataType, MR > fc1_output_
 Output tensor from first linear layer.
 
std::shared_ptr< Linear< TDeviceType, TDataType > > fc2_ { nullptr }
 Second linear layer (hidden_size -> input_features).
 
Tensor< TDataType, MR > fc2_output_
 Output tensor from second linear layer.
 
std::shared_ptr< LayerNorm< TDeviceType, TDataType > > norm1_ { nullptr }
 Optional layer normalization module.
 
Tensor< TDataType, MR > norm1_output_
 Output tensor from layer normalization.
 
Tensor< TDataType, MR > residual_input_
 Cached input tensor for residual connection.
 

Additional Inherited Members

- Protected Member Functions inherited from Mila::Dnn::Module< TDeviceType, TInput, TOutput >
const std::string parametersToString () const
 Helper method to convert parameters to string representation.
 
const std::string stateToString () const
 Helper method to convert state tensors to string representation.
 
- Protected Attributes inherited from Mila::Dnn::Module< TDeviceType, TInput, TOutput >
std::unordered_map< std::string, std::shared_ptr< Tensor< TOutput, MR > > > parameter_map_ = {}
 Map of parameter names to parameter tensors.
 
std::unordered_map< std::string, std::shared_ptr< Tensor< TOutput, MR > > > state_map_ = {}
 Map of state names to state tensors.
 

Detailed Description

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
requires ValidFloatTensorType<TDataType>
class Mila::Dnn::MLP< TDeviceType, TDataType >

Multi-Layer Perceptron (MLP) block for neural networks.

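This module implements a two-layer MLP with an activation function in between: input -> Linear -> Activation -> Linear -> output

It can optionally include:

  - layer normalization applied to the output of the first linear layer,
  - dropout applied after the activation function, and
  - a residual connection that adds the block input to the output of the second linear layer.

MLP blocks are fundamental components of many network architectures. In transformers, for example, they typically follow the attention layers and transform each token representation independently.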
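Written per token, the computation described above looks roughly as follows. Here $x \in \mathbb{R}^{C}$ is a token vector, $C$ is the number of input features and $H$ is the hidden size. This is a sketch under assumptions. Both linear layers are assumed to have biases. GELU is shown as the activation because it is the one `initializeModules()` is documented to create. The annotated lines are the optional stages.

```latex
\begin{aligned}
h &= W_1 x + b_1, && W_1 \in \mathbb{R}^{H \times C},\ b_1 \in \mathbb{R}^{H} \\
\tilde{h} &= \mathrm{LayerNorm}(h) && \text{(optional; otherwise } \tilde{h} = h\text{)} \\
a &= \mathrm{GELU}(\tilde{h}) \\
d &= \mathrm{Dropout}(a) && \text{(optional; identity at inference)} \\
y &= W_2 d + b_2, && W_2 \in \mathbb{R}^{C \times H},\ b_2 \in \mathbb{R}^{C} \\
\mathrm{output} &= x + y && \text{(optional residual; otherwise } \mathrm{output} = y\text{)}
\end{aligned}
```

The residual sum is well defined because the second linear layer maps the hidden size back to the input feature count.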

Template Parameters
TDeviceType: The device type (CPU or CUDA) on which to perform computations.
TDataType: The data type used for tensor elements throughout the network.

Member Typedef Documentation

◆ CompositeModuleBase

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
using Mila::Dnn::MLP< TDeviceType, TDataType >::CompositeModuleBase = CompositeModule<TDeviceType, TDataType>
export

Alias for base module type.

◆ MR

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
using Mila::Dnn::MLP< TDeviceType, TDataType >::MR = std::conditional_t<TDeviceType == DeviceType::Cuda, CudaMemoryResource, CpuMemoryResource>
export

Memory resource type used for tensors, selected based on device type.
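The alias is resolved entirely at compile time. The sketch below reproduces the same `std::conditional_t` rule with stand-in types. The enumerator `DeviceType::Cpu` and the empty resource structs are hypothetical placeholders for Mila's real types; only the selection rule mirrors the documented alias.

```cpp
#include <type_traits>

// Hypothetical stand-ins for Mila's device enum and memory resources.
enum class DeviceType { Cpu, Cuda };
struct CpuMemoryResource {};
struct CudaMemoryResource {};

// Same selection rule as the MR alias: CUDA devices get the CUDA resource,
// every other device type falls back to the CPU resource.
template <DeviceType TDeviceType>
using MR = std::conditional_t<TDeviceType == DeviceType::Cuda,
                              CudaMemoryResource, CpuMemoryResource>;

static_assert(std::is_same_v<MR<DeviceType::Cuda>, CudaMemoryResource>);
static_assert(std::is_same_v<MR<DeviceType::Cpu>, CpuMemoryResource>);
```

Because the choice is a type rather than a runtime flag, a tensor declared as `Tensor< TDataType, MR >` cannot accidentally be allocated with the wrong resource for the module's device.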

Constructor & Destructor Documentation

◆ MLP() [1/2]

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
Mila::Dnn::MLP< TDeviceType, TDataType >::MLP ( const std::string & device_name,
const MLPConfig & config
)
inline, explicit, export

Constructs a new MLP module with a device name.

Creates a new DeviceContext internally using the provided device name. This constructor is useful for creating standalone modules without pre-existing device contexts.

Parameters
device_name: The name of the device to use (e.g., "CPU", "CUDA:0").
config: Configuration parameters for the MLP module.
Exceptions
std::invalid_argument: If the device name or the configuration is invalid.
std::runtime_error: If the device type does not match the template parameter TDeviceType.
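The two failure modes above can be illustrated with a small parser. This is a hypothetical sketch, not Mila's DeviceContext logic. It assumes device names have the form "CPU" or "CUDA:<index>", as in the examples above. The helpers `parseDeviceName` and `checkDevice` are invented for illustration.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

enum class DeviceType { Cpu, Cuda };  // hypothetical stand-in

struct ParsedDevice {
    DeviceType type;
    int index;  // -1 for the CPU
};

// Hypothetical parser for names such as "CPU" or "CUDA:0".
ParsedDevice parseDeviceName(const std::string& name) {
    if (name == "CPU") return {DeviceType::Cpu, -1};
    const std::string prefix = "CUDA:";
    if (name.rfind(prefix, 0) == 0 && name.size() > prefix.size()) {
        const std::string digits = name.substr(prefix.size());
        if (digits.find_first_not_of("0123456789") == std::string::npos)
            return {DeviceType::Cuda, std::stoi(digits)};
    }
    // Malformed name: mirrors the documented std::invalid_argument.
    throw std::invalid_argument("invalid device name: " + name);
}

// A well-formed name for the wrong device type mirrors the documented
// std::runtime_error.
template <DeviceType TDeviceType>
ParsedDevice checkDevice(const std::string& name) {
    ParsedDevice d = parseDeviceName(name);
    if (d.type != TDeviceType)
        throw std::runtime_error("device type does not match TDeviceType: " + name);
    return d;
}
```

Note that `std::invalid_argument` derives from `std::logic_error`, not `std::runtime_error`. Callers therefore need separate handlers to tell the two cases apart.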

◆ MLP() [2/2]

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
Mila::Dnn::MLP< TDeviceType, TDataType >::MLP ( std::shared_ptr< DeviceContext > device_context,
const MLPConfig & config
)
inline, explicit, export

Constructs a new MLP module with a provided device context.

Uses a pre-existing DeviceContext instance. This constructor is useful when integrating the module into a larger network that shares device contexts across modules.

Parameters
device_context: The device context to use for this module.
config: Configuration parameters for the MLP module.
Exceptions
std::invalid_argument: If device_context is null or the configuration is invalid.
std::runtime_error: If the device context type does not match the template parameter TDeviceType.

Member Function Documentation

◆ backward()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
void Mila::Dnn::MLP< TDeviceType, TDataType >::backward ( const Tensor< TDataType, MR > &  input,
const Tensor< TDataType, MR > &  output_grad,
Tensor< TDataType, MR > &  input_grad 
)
inline, export

Performs the backward pass of the MLP block.

Computes gradients for all components in the network by working backwards from the output gradient. Handles residual connections, dropout, layer normalization, and activation functions.

Parameters
input: The input tensor from the forward pass.
output_grad: The gradient of the loss with respect to the output.
input_grad: The tensor in which to store the gradient of the loss with respect to the input.
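To make the gradient flow concrete, the sketch below differentiates a tiny dense version of the block by hand, and its test checks the result against central finite differences. It is a self-contained illustration, not Mila's kernels. It uses `std::vector` instead of `Tensor` and assumes biases on both layers. It uses the tanh approximation of GELU and omits the optional LayerNorm and dropout stages. The residual connection is included to show that its gradient simply adds `output_grad` to `input_grad`.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Vec = std::vector<double>;
using Mat = std::vector<Vec>;  // row-major: Mat[row][col]

const double kSqrt2OverPi = std::sqrt(2.0 / std::acos(-1.0));

// GELU (tanh approximation) and its derivative.
double gelu(double x) {
    double t = std::tanh(kSqrt2OverPi * (x + 0.044715 * x * x * x));
    return 0.5 * x * (1.0 + t);
}

double geluGrad(double x) {
    double t = std::tanh(kSqrt2OverPi * (x + 0.044715 * x * x * x));
    double du = kSqrt2OverPi * (1.0 + 3.0 * 0.044715 * x * x);
    return 0.5 * (1.0 + t) + 0.5 * x * (1.0 - t * t) * du;
}

struct ToyMLP {
    Mat w1; Vec b1;  // fc1: C -> H
    Mat w2; Vec b2;  // fc2: H -> C

    // y = W x + b
    static Vec affine(const Mat& w, const Vec& b, const Vec& x) {
        Vec y = b;
        for (std::size_t i = 0; i < w.size(); ++i)
            for (std::size_t j = 0; j < x.size(); ++j) y[i] += w[i][j] * x[j];
        return y;
    }

    // out = x + W2 * gelu(W1 x + b1) + b2
    Vec forward(const Vec& x) const {
        Vec h = affine(w1, b1, x);
        for (double& v : h) v = gelu(v);
        Vec out = affine(w2, b2, h);
        for (std::size_t i = 0; i < x.size(); ++i) out[i] += x[i];  // residual
        return out;
    }

    // Returns dL/dx given dL/dout (parameter gradients omitted for brevity).
    Vec backward(const Vec& x, const Vec& out_grad) const {
        // Recomputed here for self-containment; a real implementation would
        // typically reuse activations cached during the forward pass.
        Vec h = affine(w1, b1, x);
        Vec dh(h.size(), 0.0);
        for (std::size_t k = 0; k < h.size(); ++k) {
            double da = 0.0;  // (W2^T g)_k: gradient w.r.t. the activation output
            for (std::size_t i = 0; i < out_grad.size(); ++i) da += w2[i][k] * out_grad[i];
            dh[k] = da * geluGrad(h[k]);  // chain rule through GELU
        }
        Vec dx = out_grad;  // residual path passes the gradient through unchanged
        for (std::size_t j = 0; j < x.size(); ++j)
            for (std::size_t k = 0; k < h.size(); ++k) dx[j] += w1[k][j] * dh[k];
        return dx;
    }
};
```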

◆ forward()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
void Mila::Dnn::MLP< TDeviceType, TDataType >::forward ( const Tensor< TDataType, MR > &  input,
Tensor< TDataType, MR > &  output 
)
inline, export

Performs the forward pass of the MLP block.

Processes the input through the full network: Linear -> (LayerNorm) -> Activation -> (Dropout) -> Linear -> (Residual)

When the module is in inference mode and fused operations are enabled, the forward pass uses an optimized execution path.

Parameters
input: The input tensor to be processed.
output: The output tensor in which the results are stored.
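The stage ordering of forward() can be sketched as a pipeline over a single token vector. Everything here is a hypothetical stand-in. `ToyMLPOptions` and its flags are invented, since MLPConfig's fields are not documented on this page. The linear layers are passed in as callables, and LayerNorm has no learnable scale or shift. Dropout is shown in inference mode, where it passes values through unchanged.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <functional>
#include <vector>

using Vec = std::vector<double>;
using Stage = std::function<Vec(const Vec&)>;

// Hypothetical switches standing in for the optional parts of MLPConfig.
struct ToyMLPOptions {
    bool use_layer_norm = false;
    bool use_residual = false;
};

// LayerNorm without learnable gamma/beta, for brevity.
Vec layerNorm(const Vec& v, double eps = 1e-5) {
    double mean = 0.0, var = 0.0;
    for (double x : v) mean += x;
    mean /= v.size();
    for (double x : v) var += (x - mean) * (x - mean);
    var /= v.size();
    Vec out(v.size());
    for (std::size_t i = 0; i < v.size(); ++i) out[i] = (v[i] - mean) / std::sqrt(var + eps);
    return out;
}

// GELU, tanh approximation.
double gelu(double x) {
    const double k = std::sqrt(2.0 / std::acos(-1.0));
    return 0.5 * x * (1.0 + std::tanh(k * (x + 0.044715 * x * x * x)));
}

// Linear -> (LayerNorm) -> GELU -> (Dropout) -> Linear -> (Residual)
Vec mlpForward(const Vec& input, const Stage& fc1, const Stage& fc2,
               const ToyMLPOptions& opt) {
    Vec h = fc1(input);
    if (opt.use_layer_norm) h = layerNorm(h);
    for (double& v : h) v = gelu(v);
    // Dropout would go here; in inference mode it is the identity, so the
    // sketch has nothing to do for that stage.
    Vec out = fc2(h);
    if (opt.use_residual)  // requires fc2 to map back to the input size
        for (std::size_t i = 0; i < out.size(); ++i) out[i] += input[i];
    return out;
}
```

The residual addition only type-checks because fc2 maps hidden_size back to input_features. That size correspondence is what makes the output shape match the input shape.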

◆ initializeModules()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
void Mila::Dnn::MLP< TDeviceType, TDataType >::initializeModules ( )
inline, export, private

Initializes all submodules for the MLP.

Creates and configures:

  1. Two linear layers with a configurable hidden size
  2. The activation function (currently GELU; other activations may be supported in the future)
  3. Optional layer normalization
  4. Optional dropout

It also prepares the intermediate tensors used during computation.
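The shapes of those intermediate tensors follow from the layer sizes documented for fc1_ and fc2_. Below is a minimal sketch that assumes an input of shape [batch, sequence, input_features] and linear layers acting on the last dimension. `ToyShape` and `intermediateShapes` are hypothetical helpers, not part of Mila; the map keys simply reuse the member names from the attribute list.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <map>
#include <string>

using ToyShape = std::array<std::size_t, 3>;  // [batch, sequence, features]

// Shapes of the buffers named in the member list (fc1_output_, etc.).
std::map<std::string, ToyShape> intermediateShapes(const ToyShape& input,
                                                   std::size_t hidden_size,
                                                   bool has_layer_norm,
                                                   bool has_dropout,
                                                   bool has_residual) {
    const std::size_t B = input[0], T = input[1], C = input[2];
    std::map<std::string, ToyShape> shapes;
    shapes["fc1_output_"] = {B, T, hidden_size};   // input_features -> hidden_size
    if (has_layer_norm) shapes["norm1_output_"] = {B, T, hidden_size};
    shapes["act_output_"] = {B, T, hidden_size};   // activation keeps the shape
    if (has_dropout) shapes["dropout1_output_"] = {B, T, hidden_size};
    shapes["fc2_output_"] = {B, T, C};             // hidden_size -> input_features
    if (has_residual) shapes["residual_input_"] = input;  // cached copy of the input
    return shapes;
}
```

Only fc1_output_ through dropout1_output_ scale with hidden_size. With the common choice of a hidden size several times the input width, those buffers dominate the block's activation memory.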

◆ load()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
void Mila::Dnn::MLP< TDeviceType, TDataType >::load ( ModelArchive & archive)
inline, override, export, virtual

Deserializes the module state from a ZIP archive.

Loads the state of all submodules from the provided ZIP archive.

Parameters
archive: The model archive (ZIP archive) to load the module state from.

Reimplemented from Mila::Dnn::CompositeModule< TDeviceType, TDataType >.


◆ parameterCount()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
size_t Mila::Dnn::MLP< TDeviceType, TDataType >::parameterCount ( ) const
inline, override, export, virtual

Gets the number of trainable parameters in this module.

Counts the total number of parameters across all submodules.

Returns
size_t: The total number of trainable parameters.

Reimplemented from Mila::Dnn::CompositeModule< TDeviceType, TDataType >.
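As a worked example of what parameterCount() adds up, the helpers below count the parameters of the two linear layers plus an optional LayerNorm. The sketch rests on assumptions this page does not confirm. Both linear layers are assumed to have biases. The LayerNorm is assumed to carry one scale and one shift per hidden feature, since it normalizes the fc1 output. The sizes 768 and 3072 are illustrative GPT-2-small-like values.

```cpp
#include <cstddef>

// Parameters of a Linear layer: weight (out x in) plus bias (out).
constexpr std::size_t linearParams(std::size_t in, std::size_t out) {
    return in * out + out;
}

// fc1: C -> H, fc2: H -> C, optional LayerNorm over the H hidden features.
constexpr std::size_t mlpParams(std::size_t C, std::size_t H, bool has_layer_norm) {
    return linearParams(C, H) + linearParams(H, C) + (has_layer_norm ? 2 * H : 0);
}

// (768*3072 + 3072) + (3072*768 + 768) = 2,362,368 + 2,360,064 = 4,722,432
static_assert(mlpParams(768, 3072, false) == 4'722'432);
// Adding gamma and beta for 3072 hidden features contributes 6,144 more.
static_assert(mlpParams(768, 3072, true) == 4'728'576);
```

Nearly all of the count comes from the two weight matrices, each with C × H entries. Biases and normalization parameters grow only linearly in C and H.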


◆ save()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
void Mila::Dnn::MLP< TDeviceType, TDataType >::save ( ModelArchive & archive) const
inline, override, export, virtual

Serializes the module state to a ZIP archive.

Saves the state of all submodules to the provided ZIP archive.

Parameters
archive: The model archive (ZIP archive) to save the module state to.

Reimplemented from Mila::Dnn::CompositeModule< TDeviceType, TDataType >.
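One common way for a composite module to save and load its children is to namespace each child's entries under a prefix. The sketch below shows that idea, with a plain `std::map` standing in for ModelArchive. `ToyArchive`, the "fc1." and "fc2." key scheme, and the flat parameter vectors are all hypothetical. They do not describe Mila's archive format.

```cpp
#include <cassert>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

using ToyArchive = std::map<std::string, std::vector<float>>;  // stand-in for ModelArchive

struct ToyLinear {
    std::vector<float> weight, bias;

    void save(ToyArchive& ar, const std::string& prefix) const {
        ar[prefix + "weight"] = weight;
        ar[prefix + "bias"] = bias;
    }
    void load(const ToyArchive& ar, const std::string& prefix) {
        weight = ar.at(prefix + "weight");  // throws std::out_of_range if missing
        bias = ar.at(prefix + "bias");
    }
};

struct ToyMLPState {
    ToyLinear fc1, fc2;

    // Each submodule writes under its own name, so keys cannot collide.
    void save(ToyArchive& ar) const {
        fc1.save(ar, "fc1.");
        fc2.save(ar, "fc2.");
    }
    void load(const ToyArchive& ar) {
        fc1.load(ar, "fc1.");
        fc2.load(ar, "fc2.");
    }
};
```

Loading by key rather than by position means a missing or renamed entry fails loudly instead of silently shifting weights into the wrong layer.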


◆ toString()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
std::string Mila::Dnn::MLP< TDeviceType, TDataType >::toString ( ) const
inline, override, export, virtual

Generates a string representation of this module's configuration.

Returns
std::string: A formatted string describing the module's configuration.

Reimplemented from Mila::Dnn::CompositeModule< TDeviceType, TDataType >.


Member Data Documentation

◆ act_output_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
Tensor<TDataType, MR> Mila::Dnn::MLP< TDeviceType, TDataType >::act_output_
export, private

Output tensor from activation function.

◆ activation_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
std::shared_ptr<Module<TDeviceType, TDataType> > Mila::Dnn::MLP< TDeviceType, TDataType >::activation_ { nullptr }
export, private

Activation function module.

◆ config_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
MLPConfig Mila::Dnn::MLP< TDeviceType, TDataType >::config_
export, private

Configuration for the MLP module.

◆ dropout1_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
std::shared_ptr<Dropout<TDeviceType, TDataType> > Mila::Dnn::MLP< TDeviceType, TDataType >::dropout1_ { nullptr }
export, private

Optional dropout module.

◆ dropout1_output_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
Tensor<TDataType, MR> Mila::Dnn::MLP< TDeviceType, TDataType >::dropout1_output_
export, private

Output tensor from dropout.

◆ fc1_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
std::shared_ptr<Linear<TDeviceType, TDataType> > Mila::Dnn::MLP< TDeviceType, TDataType >::fc1_ { nullptr }
export, private

First linear layer (input_features -> hidden_size).

◆ fc1_output_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
Tensor<TDataType, MR> Mila::Dnn::MLP< TDeviceType, TDataType >::fc1_output_
export, private

Output tensor from first linear layer.

◆ fc2_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
std::shared_ptr<Linear<TDeviceType, TDataType> > Mila::Dnn::MLP< TDeviceType, TDataType >::fc2_ { nullptr }
export, private

Second linear layer (hidden_size -> input_features).

◆ fc2_output_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
Tensor<TDataType, MR> Mila::Dnn::MLP< TDeviceType, TDataType >::fc2_output_
export, private

Output tensor from second linear layer.

◆ norm1_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
std::shared_ptr<LayerNorm<TDeviceType, TDataType> > Mila::Dnn::MLP< TDeviceType, TDataType >::norm1_ { nullptr }
export, private

Optional layer normalization module.

◆ norm1_output_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
Tensor<TDataType, MR> Mila::Dnn::MLP< TDeviceType, TDataType >::norm1_output_
export, private

Output tensor from layer normalization.

◆ residual_input_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
Tensor<TDataType, MR> Mila::Dnn::MLP< TDeviceType, TDataType >::residual_input_
export, private

Cached input tensor for residual connection.


The documentation for this class was generated from the following file: