Mila Deep Neural Network Library
TransformerBlock implements a standard transformer encoder block.


Public Types
- using CompositeModuleBase = CompositeModule<TDeviceType, TDataType>
  Alias for base module type.
- using MR = std::conditional_t<TDeviceType == DeviceType::Cuda, CudaMemoryResource, CpuMemoryResource>
  Memory resource type used for tensors, selected based on device type.

Public Types inherited from Mila::Dnn::CompositeModule<TDeviceType, TDataType>
- using ModuleBase = Module<TDeviceType, TDataType, TDataType>
  Base class type for the module.
- using MR = std::conditional_t<TDeviceType == DeviceType::Cuda, CudaMemoryResource, HostMemoryResource>
  Memory resource type based on device type.

Public Types inherited from Mila::Dnn::Module<TDeviceType, TInput, TOutput>
- using MR = std::conditional_t<TDeviceType == DeviceType::Cuda, CudaMemoryResource, CpuMemoryResource>
Public Member Functions
- TransformerBlock(const std::string& device_name, const TransformerBlockConfig& config)
  Constructs a new TransformerBlock module with a device name.
- TransformerBlock(std::shared_ptr<DeviceContext> device_context, const TransformerBlockConfig& config)
  Constructs a new TransformerBlock module with a provided device context.
- void forward(const Tensor<TDataType, MR>& input, Tensor<TDataType, MR>& output)
  Performs the forward pass of the TransformerBlock.
- void load(ModelArchive& archive) override
  Deserializes the module state from a ZIP archive.
- size_t parameterCount() const override
  Gets the number of trainable parameters in this module.
- void save(ModelArchive& archive) const override
  Serializes the module state to a ZIP archive.
- std::string toString() const override
  Generates a string representation of this module's configuration.
Public Member Functions inherited from Mila::Dnn::CompositeModule<TDeviceType, TDataType>
- CompositeModule()
  Default constructor.
- CompositeModule(const std::string& device_name, const ComponentConfig& config)
  Constructor with device name.
- CompositeModule(std::shared_ptr<DeviceContext> context, const ComponentConfig& config)
  Constructor with device context.
- virtual ~CompositeModule() = default
  Virtual destructor.
- CompositeModule& addModule(const std::string& name, std::shared_ptr<Module<TDeviceType, TDataType, TDataType>> module)
  Add a named child module to this module.
- CompositeModule& addModule(std::shared_ptr<Module<TDeviceType, TDataType, TDataType>> module)
  Add an unnamed child module to this module.
- std::shared_ptr<Module<TDeviceType, TDataType, TDataType>> getModule(const std::string& name) const
  Get a specific sub-module by name.
- const std::vector<std::shared_ptr<Module<TDeviceType, TDataType, TDataType>>>& getModules() const
  Get all sub-modules contained in this module.
- const std::unordered_map<std::string, std::shared_ptr<Module<TDeviceType, TDataType, TDataType>>>& getNamedModules() const
  Get all named sub-modules contained in this module.
- bool hasModule(const std::string& name) const
  Check if a sub-module with the given name exists.
- bool removeModule(const std::string& name)
  Remove a sub-module by name.
- bool replaceModule(const std::string& name, std::shared_ptr<Module<TDeviceType, TDataType, TDataType>> module)
  Replace an existing sub-module with a new one.
- void setTraining(bool is_training) override
  Set the training mode for this module and all its sub-modules.
Public Member Functions inherited from Mila::Dnn::Module<TDeviceType, TInput, TOutput>
- Module(const std::string& device_name, const ComponentConfig& config)
  Constructor with device name.
- Module(std::shared_ptr<DeviceContext> context, const ComponentConfig& config)
  Constructor with a specific device context.
- virtual ~Module() = default
  Virtual destructor for proper cleanup in derived classes.
- std::shared_ptr<Compute::DeviceContext> getDeviceContext() const
  Get the device context for this module.
- Compute::DeviceType getDeviceType() const
  Get the device type of the current device context.
- std::string getName() const
  Get the name of the module.
- const auto& getParameterTensors() const
  Get the parameter tensors of this module.
- const ComputePrecision::Policy& getPrecision() const
- const auto& getStateTensors() const
  Get the state tensors of this module.
- bool isTraining() const
  Check if the module is in training mode.
Private Member Functions
- void initializeModules()
  Initializes the sub-modules and output tensors for the transformer block.
Private Attributes
- std::shared_ptr<MultiHeadAttention<TDeviceType, TDataType>> attn_block_ { nullptr }
  Multi-head self-attention block including projections.
- Tensor<TDataType, MR> attn_output_
  Output tensor from attention block.
- TransformerBlockConfig config_
  Configuration for the TransformerBlock module.
- std::shared_ptr<Dropout<TDeviceType, TDataType>> dropout_ { nullptr }
  Optional dropout module.
- std::shared_ptr<LayerNorm<TDeviceType, TDataType>> ln_1_ { nullptr }
  First layer normalization module.
- Tensor<TDataType, MR> ln_1_output_
  Output tensor from first layer normalization.
- std::shared_ptr<LayerNorm<TDeviceType, TDataType>> ln_2_ { nullptr }
  Second layer normalization module.
- Tensor<TDataType, MR> ln_2_output_
  Output tensor from second layer normalization.
- std::shared_ptr<MLP<TDeviceType, TDataType>> mlp_ { nullptr }
  Feed-forward network (MLP).
- Tensor<TDataType, MR> mlp_output_
  Output tensor from MLP.
- Tensor<TDataType, MR> res_1_output_
  Output tensor from first residual connection.
- Tensor<TDataType, MR> res_2_output_
  Output tensor from second residual connection.
Additional Inherited Members

Protected Member Functions inherited from Mila::Dnn::Module<TDeviceType, TInput, TOutput>
- const std::string parametersToString() const
  Helper method to convert parameters to string representation.
- const std::string stateToString() const
  Helper method to convert state tensors to string representation.

Protected Attributes inherited from Mila::Dnn::Module<TDeviceType, TInput, TOutput>
- std::unordered_map<std::string, std::shared_ptr<Tensor<TOutput, MR>>> parameter_map_ = {}
  Map of parameter names to parameter tensors.
- std::unordered_map<std::string, std::shared_ptr<Tensor<TOutput, MR>>> state_map_ = {}
  Map of state names to state tensors.
TransformerBlock implements a standard transformer encoder block.
The transformer block consists of:
- a first layer normalization (ln_1_),
- a multi-head self-attention block including projections (attn_block_),
- a second layer normalization (ln_2_),
- a feed-forward network (mlp_),
- residual connections around the attention and feed-forward sub-blocks, and
- an optional dropout module (dropout_).

This is the fundamental building block of transformer architectures such as BERT and GPT. The implementation supports both pre-LN (more stable) and post-LN (original) architectures, configurable dropout rates, and other hyperparameters.
Template Parameters
- TDeviceType: The device type (CPU or CUDA) on which to perform computations.
- TDataType: The data type used for tensor elements throughout the network.
CompositeModuleBase
export

Alias for base module type.

MR
export

Memory resource type used for tensors, selected based on device type.
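The device-conditional alias can be reproduced in isolation. The sketch below uses stand-in resource types (the real ones live in Mila's compute layer) to show how std::conditional_t selects the memory resource at compile time:

```cpp
#include <cassert>
#include <type_traits>

enum class DeviceType { Cpu, Cuda };

// Stand-in memory resource types for illustration only.
struct CpuMemoryResource  {};
struct CudaMemoryResource {};

// Same selection mechanism as the MR alias: CUDA devices get the CUDA
// resource, everything else falls back to the CPU resource.
template <DeviceType TDeviceType>
using MR = std::conditional_t<TDeviceType == DeviceType::Cuda,
                              CudaMemoryResource, CpuMemoryResource>;

// The choice is resolved entirely at compile time.
static_assert(std::is_same_v<MR<DeviceType::Cuda>, CudaMemoryResource>);
static_assert(std::is_same_v<MR<DeviceType::Cpu>,  CpuMemoryResource>);
```

Because the alias is resolved at compile time, tensors parameterized on MR carry no runtime dispatch cost for device selection.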
TransformerBlock(const std::string& device_name, const TransformerBlockConfig& config)
inline explicit export

Constructs a new TransformerBlock module with a device name.
Creates a new DeviceContext internally using the provided device name. This constructor is useful for creating standalone modules without pre-existing device contexts.

Parameters
- device_name: The name of the device to use (e.g., "CPU", "CUDA:0").
- config: Configuration parameters for the TransformerBlock module.

Exceptions
- std::invalid_argument: If the device name is invalid or the configuration is invalid.
- std::runtime_error: If the device type doesn't match template parameter TDeviceType.
TransformerBlock(std::shared_ptr<DeviceContext> device_context, const TransformerBlockConfig& config)
inline explicit export

Constructs a new TransformerBlock module with a provided device context.
Uses a pre-existing DeviceContext instance. This constructor is useful when integrating the module into a larger network that shares device contexts across modules.

Parameters
- device_context: The device context to use for this module.
- config: Configuration parameters for the TransformerBlock module.

Exceptions
- std::invalid_argument: If device_context is null or the configuration is invalid.
- std::runtime_error: If the device context type doesn't match template parameter TDeviceType.
forward(const Tensor<TDataType, MR>& input, Tensor<TDataType, MR>& output)
inline export

Performs the forward pass of the TransformerBlock.
The forward pass follows either pre-LN or post-LN architecture based on configuration:
- Pre-LN (default): each sub-block normalizes its input before applying it, and the residual path stays unnormalized: x = x + attn(ln_1(x)), then x = x + mlp(ln_2(x)).
- Post-LN: normalization is applied after each residual connection: x = ln_1(x + attn(x)), then x = ln_2(x + mlp(x)).

Parameters
- input: The input tensor to be processed.
- output: The output tensor where the results will be stored.
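Independent of the Mila API, the two orderings can be sketched numerically. In the sketch below, attn and mlp are arbitrary stand-in sublayers and layerNorm is unparameterized (no learned scale or shift); only the dataflow matches the description above:

```cpp
#include <cassert>
#include <cmath>
#include <functional>
#include <vector>

using Vec = std::vector<float>;
using Sublayer = std::function<Vec(const Vec&)>;

// Unparameterized layer norm: standardize to zero mean, unit variance.
Vec layerNorm(const Vec& x, float eps = 1e-5f) {
    float mean = 0.f, var = 0.f;
    for (float v : x) mean += v;
    mean /= x.size();
    for (float v : x) var += (v - mean) * (v - mean);
    var /= x.size();
    Vec out(x.size());
    for (size_t i = 0; i < x.size(); ++i)
        out[i] = (x[i] - mean) / std::sqrt(var + eps);
    return out;
}

// Elementwise residual addition.
Vec add(const Vec& a, const Vec& b) {
    Vec out(a.size());
    for (size_t i = 0; i < a.size(); ++i) out[i] = a[i] + b[i];
    return out;
}

// Pre-LN: normalize before each sublayer; residual path stays unnormalized.
Vec preLN(const Vec& x, const Sublayer& attn, const Sublayer& mlp) {
    Vec r1 = add(x, attn(layerNorm(x)));
    return add(r1, mlp(layerNorm(r1)));
}

// Post-LN: normalize after each residual connection.
Vec postLN(const Vec& x, const Sublayer& attn, const Sublayer& mlp) {
    Vec r1 = layerNorm(add(x, attn(x)));
    return layerNorm(add(r1, mlp(r1)));
}
```

One observable difference: post-LN output is always standardized (its last operation is a layer norm), whereas pre-LN output retains the unnormalized residual stream.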

initializeModules()
inline export private

Initializes the sub-modules and output tensors for the transformer block.
Creates and configures all components of the transformer block according to the configuration, including layer norm, attention, and feed-forward network.

load(ModelArchive& archive)
inline override export virtual

Deserializes the module state from a ZIP archive.
Loads the state of all sub-modules from the provided ZIP archive.

Parameters
- archive: The ZIP archive to load the module state from.

Reimplemented from Mila::Dnn::CompositeModule<TDeviceType, TDataType>.

parameterCount() const
inline override export virtual

Gets the number of trainable parameters in this module.
Counts the total number of parameters in all sub-modules.

Reimplemented from Mila::Dnn::CompositeModule<TDeviceType, TDataType>.
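Counting parameters across sub-modules is a recursive sum over the composite hierarchy. A Mila-independent sketch of the pattern (ModuleLike, Leaf, and CompositeLike are illustrative stand-ins):

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

// Minimal module interface: leaves report their own parameter count,
// composites sum over their children.
struct ModuleLike {
    virtual ~ModuleLike() = default;
    virtual std::size_t parameterCount() const = 0;
};

// A leaf layer with a fixed number of trainable parameters.
struct Leaf : ModuleLike {
    std::size_t count;
    explicit Leaf(std::size_t c) : count(c) {}
    std::size_t parameterCount() const override { return count; }
};

// A composite module: its parameter count is the sum of its children's.
struct CompositeLike : ModuleLike {
    std::vector<std::shared_ptr<ModuleLike>> children;
    std::size_t parameterCount() const override {
        std::size_t total = 0;
        for (const auto& c : children) total += c->parameterCount();
        return total;
    }
};
```

Because nested composites recurse through the same virtual call, a TransformerBlock's count naturally includes the parameters of its attention block, layer norms, and MLP.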


save(ModelArchive& archive) const
inline override export virtual

Serializes the module state to a ZIP archive.
Saves the state of all sub-modules to the provided ZIP archive.

Parameters
- archive: The ZIP archive to save the module state to.

Reimplemented from Mila::Dnn::CompositeModule<TDeviceType, TDataType>.
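The save/load pair follows a common round-trip pattern for composite modules: namespace each child's tensors by name, write them all, then read them back in the same order. The sketch below is Mila-independent; a plain map stands in for the ZIP-backed ModelArchive, and the key scheme is an assumption for illustration:

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Stand-in for ModelArchive: maps tensor names to raw values.
using Archive = std::map<std::string, std::vector<float>>;

// A leaf sub-module with one weight tensor.
struct LeafModule {
    std::string name;
    std::vector<float> weights;
    void save(Archive& a, const std::string& prefix) const {
        a[prefix + name + ".weight"] = weights;  // namespaced key per sub-module
    }
    void load(const Archive& a, const std::string& prefix) {
        weights = a.at(prefix + name + ".weight");  // throws if the entry is missing
    }
};

// A composite that delegates serialization to its children.
struct BlockModule {
    std::vector<LeafModule> children;
    void save(Archive& a, const std::string& prefix) const {
        for (const auto& c : children) c.save(a, prefix);
    }
    void load(const Archive& a, const std::string& prefix) {
        for (auto& c : children) c.load(a, prefix);
    }
};
```

Prefixing keys with the parent's name keeps identically named children of different blocks (e.g. every block's "ln_1") from colliding in a single flat archive.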

toString() const
inline override export virtual

Generates a string representation of this module's configuration.

Reimplemented from Mila::Dnn::CompositeModule<TDeviceType, TDataType>.

attn_block_
export private

Multi-head self-attention block including projections.

attn_output_
export private

Output tensor from attention block.

config_
export private

Configuration for the TransformerBlock module.

dropout_
export private

Optional dropout module.

ln_1_
export private

First layer normalization module.
In pre-LN architecture, applied before attention. In post-LN architecture, applied after attention and residual connection.

ln_1_output_
export private

Output tensor from first layer normalization.

ln_2_
export private

Second layer normalization module.

ln_2_output_
export private

Output tensor from second layer normalization.

mlp_
export private

Feed-forward network (MLP).

mlp_output_
export private

Output tensor from MLP.

res_1_output_
export private

Output tensor from first residual connection.

res_2_output_
export private

Output tensor from second residual connection.