Mila
Deep Neural Network Library
Mila::Dnn::TransformerBlock< TDeviceType, TDataType > Class Template Reference
export

TransformerBlock implements a standard transformer encoder block. More...


Public Types

using CompositeModuleBase = CompositeModule< TDeviceType, TDataType >
 Alias for base module type.
 
using MR = std::conditional_t< TDeviceType==DeviceType::Cuda, CudaMemoryResource, CpuMemoryResource >
 Memory resource type used for tensors, selected based on device type.
 
- Public Types inherited from Mila::Dnn::CompositeModule< TDeviceType, TDataType >
using ModuleBase = Module< TDeviceType, TDataType, TDataType >
 Base class type for the module.
 
using MR = std::conditional_t< TDeviceType==DeviceType::Cuda, CudaMemoryResource, HostMemoryResource >
 Memory resource type based on device type.
 
- Public Types inherited from Mila::Dnn::Module< TDeviceType, TInput, TOutput >
using MR = std::conditional_t< TDeviceType==DeviceType::Cuda, CudaMemoryResource, CpuMemoryResource >
 

Public Member Functions

 TransformerBlock (const std::string &device_name, const TransformerBlockConfig &config)
 Constructs a new TransformerBlock module with a device name.
 
 TransformerBlock (std::shared_ptr< DeviceContext > device_context, const TransformerBlockConfig &config)
 Constructs a new TransformerBlock module with a provided device context.
 
void forward (const Tensor< TDataType, MR > &input, Tensor< TDataType, MR > &output)
 Performs the forward pass of the TransformerBlock.
 
void load (ModelArchive &archive) override
 Deserializes the module state from a ZIP archive.
 
size_t parameterCount () const override
 Gets the number of trainable parameters in this module.
 
void save (ModelArchive &archive) const override
 Serializes the module state to a ZIP archive.
 
std::string toString () const override
 Generates a string representation of this module's configuration.
 
- Public Member Functions inherited from Mila::Dnn::CompositeModule< TDeviceType, TDataType >
 CompositeModule ()
 Default constructor.
 
 CompositeModule (const std::string &device_name, const ComponentConfig &config)
 Constructor with device name.
 
 CompositeModule (std::shared_ptr< DeviceContext > context, const ComponentConfig &config)
 Constructor with device context.
 
virtual ~CompositeModule ()=default
 Virtual destructor.
 
CompositeModule & addModule (const std::string &name, std::shared_ptr< Module< TDeviceType, TDataType, TDataType > > module)
 Add a named child module to this module.
 
CompositeModule & addModule (std::shared_ptr< Module< TDeviceType, TDataType, TDataType > > module)
 Add an unnamed child module to this module.
 
std::shared_ptr< Module< TDeviceType, TDataType, TDataType > > getModule (const std::string &name) const
 Get a specific sub-module by name.
 
const std::vector< std::shared_ptr< Module< TDeviceType, TDataType, TDataType > > > & getModules () const
 Get all sub-modules contained in this module.
 
const std::unordered_map< std::string, std::shared_ptr< Module< TDeviceType, TDataType, TDataType > > > & getNamedModules () const
 Get all named sub-modules contained in this module.
 
bool hasModule (const std::string &name) const
 Check if a sub-module with the given name exists.
 
bool removeModule (const std::string &name)
 Remove a sub-module by name.
 
bool replaceModule (const std::string &name, std::shared_ptr< Module< TDeviceType, TDataType, TDataType > > module)
 Replace an existing sub-module with a new one.
 
void setTraining (bool is_training) override
 Set the training mode for this module and all its sub-modules.
 
- Public Member Functions inherited from Mila::Dnn::Module< TDeviceType, TInput, TOutput >
 Module (const std::string &device_name, const ComponentConfig &config)
 Constructor with device name.
 
 Module (std::shared_ptr< DeviceContext > context, const ComponentConfig &config)
 Constructor with a specific device context.
 
virtual ~Module ()=default
 Virtual destructor for proper cleanup in derived classes.
 
std::shared_ptr< Compute::DeviceContext > getDeviceContext () const
 Get the device context for this module.
 
Compute::DeviceType getDeviceType () const
 Get the device type of the current device context.
 
std::string getName () const
 Get the name of the module.
 
const auto & getParameterTensors () const
 Get the parameter tensors of this module.
 
const ComputePrecision::Policy & getPrecision () const
 
const auto & getStateTensors () const
 Get the state tensors of this module.
 
bool isTraining () const
 Check if the module is in training mode.
 

Private Member Functions

void initializeModules ()
 Initializes the sub-modules and output tensors for the transformer block.
 

Private Attributes

std::shared_ptr< MultiHeadAttention< TDeviceType, TDataType > > attn_block_ { nullptr }
 Multi-head self-attention block including projections.
 
Tensor< TDataType, MR > attn_output_
 Output tensor from attention block.
 
TransformerBlockConfig config_
 Configuration for the TransformerBlock module.
 
std::shared_ptr< Dropout< TDeviceType, TDataType > > dropout_ { nullptr }
 Optional dropout module.
 
std::shared_ptr< LayerNorm< TDeviceType, TDataType > > ln_1_ { nullptr }
 First layer normalization module.
 
Tensor< TDataType, MR > ln_1_output_
 Output tensor from first layer normalization.
 
std::shared_ptr< LayerNorm< TDeviceType, TDataType > > ln_2_ { nullptr }
 Second layer normalization module.
 
Tensor< TDataType, MR > ln_2_output_
 Output tensor from second layer normalization.
 
std::shared_ptr< MLP< TDeviceType, TDataType > > mlp_ { nullptr }
 Feed-forward network (MLP).
 
Tensor< TDataType, MR > mlp_output_
 Output tensor from MLP.
 
Tensor< TDataType, MR > res_1_output_
 Output tensor from first residual connection.
 
Tensor< TDataType, MR > res_2_output_
 Output tensor from second residual connection.
 

Additional Inherited Members

- Protected Member Functions inherited from Mila::Dnn::Module< TDeviceType, TInput, TOutput >
const std::string parametersToString () const
 Helper method to convert parameters to string representation.
 
const std::string stateToString () const
 Helper method to convert state tensors to string representation.
 
- Protected Attributes inherited from Mila::Dnn::Module< TDeviceType, TInput, TOutput >
std::unordered_map< std::string, std::shared_ptr< Tensor< TOutput, MR > > > parameter_map_ = {}
 Map of parameter names to parameter tensors.
 
std::unordered_map< std::string, std::shared_ptr< Tensor< TOutput, MR > > > state_map_ = {}
 Map of state names to state tensors.
 

Detailed Description

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
requires ValidFloatTensorType<TDataType>
class Mila::Dnn::TransformerBlock< TDeviceType, TDataType >

TransformerBlock implements a standard transformer encoder block.

The transformer block consists of:

  • Multi-head self-attention mechanism with residual connection
  • Feed-forward network (MLP) with residual connection
  • Layer normalization before or after each sub-block (configurable)

This is the fundamental building block of transformer architectures like BERT and GPT. The implementation supports both pre-LN (more stable) and post-LN (original) architectures, configurable dropout rates, and other hyperparameters.

Template Parameters
TDeviceType: The device type (CPU or CUDA) on which to perform computations.
TDataType: The data type used for tensor elements throughout the network.
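A minimal usage sketch follows. The include path and the contents of TransformerBlockConfig are not shown in this reference, so both are hypothetical here; only the constructor, forward, toString, and parameterCount calls are taken from the documented interface:

```cpp
#include <iostream>

#include "Mila/Dnn/TransformerBlock.h"  // hypothetical include path

using namespace Mila::Dnn;

int main() {
    // Hypothetical: the TransformerBlockConfig fields (embedding size,
    // head count, pre/post-LN flag, dropout rate) are documented separately.
    TransformerBlockConfig config;

    // Constructor [1/2]: creates a DeviceContext from a device name.
    // Throws std::invalid_argument for a bad name or config, and
    // std::runtime_error if the device doesn't match TDeviceType.
    TransformerBlock<DeviceType::Cuda, float> block( "CUDA:0", config );

    using Block = TransformerBlock<DeviceType::Cuda, float>;
    Tensor<float, Block::MR> input;   // shape setup omitted
    Tensor<float, Block::MR> output;
    block.forward( input, output );

    std::cout << block.toString() << '\n'
              << "parameters: " << block.parameterCount() << '\n';
}
```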

Member Typedef Documentation

◆ CompositeModuleBase

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
using Mila::Dnn::TransformerBlock< TDeviceType, TDataType >::CompositeModuleBase = CompositeModule<TDeviceType, TDataType>
export

Alias for base module type.

◆ MR

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
using Mila::Dnn::TransformerBlock< TDeviceType, TDataType >::MR = std::conditional_t<TDeviceType == DeviceType::Cuda, CudaMemoryResource, CpuMemoryResource>
export

Memory resource type used for tensors, selected based on device type.

Constructor & Destructor Documentation

◆ TransformerBlock() [1/2]

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
Mila::Dnn::TransformerBlock< TDeviceType, TDataType >::TransformerBlock ( const std::string & device_name,
const TransformerBlockConfig & config )
inline explicit export

Constructs a new TransformerBlock module with a device name.

Creates a new DeviceContext internally using the provided device name. This constructor is useful for creating standalone modules without pre-existing device contexts.

Parameters
device_name: The name of the device to use (e.g., "CPU", "CUDA:0").
config: Configuration parameters for the TransformerBlock module.
Exceptions
std::invalid_argument: If the device name is invalid or the configuration is invalid.
std::runtime_error: If the device type doesn't match the template parameter TDeviceType.

◆ TransformerBlock() [2/2]

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
Mila::Dnn::TransformerBlock< TDeviceType, TDataType >::TransformerBlock ( std::shared_ptr< DeviceContext > device_context,
const TransformerBlockConfig & config )
inline explicit export

Constructs a new TransformerBlock module with a provided device context.

Uses a pre-existing DeviceContext instance. This constructor is useful when integrating the module into a larger network that shares device contexts across modules.

Parameters
device_context: The device context to use for this module.
config: Configuration parameters for the TransformerBlock module.
Exceptions
std::invalid_argument: If device_context is null or the configuration is invalid.
std::runtime_error: If the device context type doesn't match the template parameter TDeviceType.

Member Function Documentation

◆ forward()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
void Mila::Dnn::TransformerBlock< TDeviceType, TDataType >::forward ( const Tensor< TDataType, MR > &  input,
Tensor< TDataType, MR > &  output 
)
inline export

Performs the forward pass of the TransformerBlock.

The forward pass follows either pre-LN or post-LN architecture based on configuration:

Pre-LN (default):

  1. Layer normalization 1
  2. Self-attention
  3. Residual connection
  4. Layer normalization 2
  5. Feed-forward network
  6. Residual connection

Post-LN:

  1. Self-attention
  2. Residual connection
  3. Layer normalization 1
  4. Feed-forward network
  5. Residual connection
  6. Layer normalization 2
Parameters
input: The input tensor to be processed.
output: The output tensor where the results will be stored.

◆ initializeModules()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
void Mila::Dnn::TransformerBlock< TDeviceType, TDataType >::initializeModules ( )
inline export private

Initializes the sub-modules and output tensors for the transformer block.

Creates and configures all components of the transformer block according to the configuration, including layer norm, attention, and feed-forward network.


◆ load()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
void Mila::Dnn::TransformerBlock< TDeviceType, TDataType >::load ( ModelArchive & archive)
inline override export virtual

Deserializes the module state from a ZIP archive.

Loads the state of all sub-modules from the provided ZIP archive.

Parameters
archive: The ZIP archive to load the module state from.

Reimplemented from Mila::Dnn::CompositeModule< TDeviceType, TDataType >.


◆ parameterCount()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
size_t Mila::Dnn::TransformerBlock< TDeviceType, TDataType >::parameterCount ( ) const
inline override export virtual

Gets the number of trainable parameters in this module.

Counts the total number of parameters in all sub-modules.

Returns
size_t The total number of parameters.

Reimplemented from Mila::Dnn::CompositeModule< TDeviceType, TDataType >.


◆ save()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
void Mila::Dnn::TransformerBlock< TDeviceType, TDataType >::save ( ModelArchive & archive) const
inline override export virtual

Serializes the module state to a ZIP archive.

Saves the state of all sub-modules to the provided ZIP archive.

Parameters
archive: The ZIP archive to save the module state to.

Reimplemented from Mila::Dnn::CompositeModule< TDeviceType, TDataType >.


◆ toString()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
std::string Mila::Dnn::TransformerBlock< TDeviceType, TDataType >::toString ( ) const
inline override export virtual

Generates a string representation of this module's configuration.

Returns
std::string A formatted string with module information

Reimplemented from Mila::Dnn::CompositeModule< TDeviceType, TDataType >.


Member Data Documentation

◆ attn_block_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
std::shared_ptr<MultiHeadAttention<TDeviceType, TDataType> > Mila::Dnn::TransformerBlock< TDeviceType, TDataType >::attn_block_ { nullptr }
export private

Multi-head self-attention block including projections.

◆ attn_output_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
Tensor<TDataType, MR> Mila::Dnn::TransformerBlock< TDeviceType, TDataType >::attn_output_
export private

Output tensor from attention block.

◆ config_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
TransformerBlockConfig Mila::Dnn::TransformerBlock< TDeviceType, TDataType >::config_
export private

Configuration for the TransformerBlock module.

◆ dropout_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
std::shared_ptr<Dropout<TDeviceType, TDataType> > Mila::Dnn::TransformerBlock< TDeviceType, TDataType >::dropout_ { nullptr }
export private

Optional dropout module.

◆ ln_1_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
std::shared_ptr<LayerNorm<TDeviceType, TDataType> > Mila::Dnn::TransformerBlock< TDeviceType, TDataType >::ln_1_ { nullptr }
export private

First layer normalization module.

In pre-LN architecture, applied before attention. In post-LN architecture, applied after attention and residual connection.

◆ ln_1_output_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
Tensor<TDataType, MR> Mila::Dnn::TransformerBlock< TDeviceType, TDataType >::ln_1_output_
export private

Output tensor from first layer normalization.

◆ ln_2_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
std::shared_ptr<LayerNorm<TDeviceType, TDataType> > Mila::Dnn::TransformerBlock< TDeviceType, TDataType >::ln_2_ { nullptr }
export private

Second layer normalization module.

In pre-LN architecture, applied before MLP. In post-LN architecture, applied after MLP and residual connection.

◆ ln_2_output_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
Tensor<TDataType, MR> Mila::Dnn::TransformerBlock< TDeviceType, TDataType >::ln_2_output_
export private

Output tensor from second layer normalization.

◆ mlp_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
std::shared_ptr<MLP<TDeviceType, TDataType> > Mila::Dnn::TransformerBlock< TDeviceType, TDataType >::mlp_ { nullptr }
export private

Feed-forward network (MLP).

◆ mlp_output_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
Tensor<TDataType, MR> Mila::Dnn::TransformerBlock< TDeviceType, TDataType >::mlp_output_
export private

Output tensor from MLP.

◆ res_1_output_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
Tensor<TDataType, MR> Mila::Dnn::TransformerBlock< TDeviceType, TDataType >::res_1_output_
export private

Output tensor from first residual connection.

◆ res_2_output_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
Tensor<TDataType, MR> Mila::Dnn::TransformerBlock< TDeviceType, TDataType >::res_2_output_
export private

Output tensor from second residual connection.


The documentation for this class was generated from the following file: