| Return type | Member | Description |
|---|---|---|
| | TransformerBlock (const std::string &device_name, const TransformerBlockConfig &config) | Constructs a new TransformerBlock module with a device name. |
| | TransformerBlock (std::shared_ptr< DeviceContext > device_context, const TransformerBlockConfig &config) | Constructs a new TransformerBlock module with a provided device context. |
| void | forward (const Tensor< TDataType, MR > &input, Tensor< TDataType, MR > &output) | Performs the forward pass of the TransformerBlock. |
| void | load (ModelArchive &archive) override | Deserializes the module state from a ZIP archive. |
| size_t | parameterCount () const override | Gets the number of trainable parameters in this module. |
| void | save (ModelArchive &archive) const override | Serializes the module state to a ZIP archive. |
| std::string | toString () const override | Generates a string representation of this module's configuration. |
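The members above cover the typical lifecycle of a block: construct it against a device, run forward(), and persist it through a ModelArchive. The sketch below strings those calls together; the include paths, the device-name string, the TransformerBlockConfig setup, and the ModelArchive constructor are assumptions rather than part of this reference.

```cpp
#include <iostream>
// Header names are assumptions; use the actual Mila headers for your build.
// #include "Mila/Dnn/TransformerBlock.h"

using namespace Mila::Dnn;

void transformer_block_lifecycle()
{
    // Hypothetical configuration; the real TransformerBlockConfig fields
    // (embedding width, head count, dropout rate, pre-/post-LN choice, ...)
    // are set elsewhere.
    TransformerBlockConfig config;

    // Default template arguments apply (DeviceType::Cuda, float);
    // "CUDA:0" is an assumed device-name string.
    TransformerBlock<> block( "CUDA:0", config );

    std::cout << block.toString() << '\n';
    std::cout << "trainable parameters: " << block.parameterCount() << '\n';

    // Forward pass: input and output share the element type and a
    // device-specific memory resource (MR), elided here.
    // Tensor<float, MR> input;   // e.g. [batch, sequence, embedding]
    // Tensor<float, MR> output;  // same shape as the input
    // block.forward( input, output );

    // Persist and restore the module state through a ZIP-backed archive.
    // ModelArchive archive( "block.zip" );  // constructor signature assumed
    // block.save( archive );
    // block.load( archive );
}
```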
| Return type | Member | Description |
|---|---|---|
| | CompositeModule () | Default constructor. |
| | CompositeModule (const std::string &device_name, const ComponentConfig &config) | Constructor with device name. |
| | CompositeModule (std::shared_ptr< DeviceContext > context, const ComponentConfig &config) | Constructor with device context. |
| virtual | ~CompositeModule ()=default | Virtual destructor. |
| CompositeModule & | addModule (const std::string &name, std::shared_ptr< Module< TDeviceType, TDataType, TDataType > > module) | Add a named child module to this module. |
| CompositeModule & | addModule (std::shared_ptr< Module< TDeviceType, TDataType, TDataType > > module) | Add an unnamed child module to this module. |
| std::shared_ptr< Module< TDeviceType, TDataType, TDataType > > | getModule (const std::string &name) const | Get a specific sub-module by name. |
| const std::vector< std::shared_ptr< Module< TDeviceType, TDataType, TDataType > > > & | getModules () const | Get all sub-modules contained in this module. |
| const std::unordered_map< std::string, std::shared_ptr< Module< TDeviceType, TDataType, TDataType > > > & | getNamedModules () const | Get all named sub-modules contained in this module. |
| bool | hasModule (const std::string &name) const | Check if a sub-module with the given name exists. |
| bool | removeModule (const std::string &name) | Remove a sub-module by name. |
| bool | replaceModule (const std::string &name, std::shared_ptr< Module< TDeviceType, TDataType, TDataType > > module) | Replace an existing sub-module with a new one. |
| void | setTraining (bool is_training) override | Set the training mode for this module and all its sub-modules. |
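Because these members are inherited from CompositeModule, the child-module registry is available directly on a TransformerBlock instance. The snippet below is a minimal sketch of that registry, assuming the default CUDA/float template arguments (so children are Module< DeviceType::Cuda, float, float >); the "extra_dropout" name and the child passed in are purely illustrative.

```cpp
#include <memory>

using namespace Mila::Dnn;

void manage_children( TransformerBlock<>& block,
                      std::shared_ptr< Module< DeviceType::Cuda, float, float > > child )
{
    // Register a child under a name, or anonymously.
    block.addModule( "extra_dropout", child );
    block.addModule( child );

    // Query the registry.
    if ( block.hasModule( "extra_dropout" ) ) {
        auto found = block.getModule( "extra_dropout" );
        (void)found;
    }
    for ( const auto& [name, sub] : block.getNamedModules() ) {
        // iterate named children only
        (void)name; (void)sub;
    }
    const auto& all = block.getModules();   // named and unnamed children
    (void)all;

    // Swap out or drop a child by name; each reports success through its bool return value.
    block.replaceModule( "extra_dropout", child );
    block.removeModule( "extra_dropout" );

    // Training mode propagates to all registered children.
    block.setTraining( true );
}
```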
| Return type | Member | Description |
|---|---|---|
| | Module (const std::string &device_name, const ComponentConfig &config) | Constructor with device name. |
| | Module (std::shared_ptr< DeviceContext > context, const ComponentConfig &config) | Constructor with a specific device context. |
| virtual | ~Module ()=default | Virtual destructor for proper cleanup in derived classes. |
| std::shared_ptr< Compute::DeviceContext > | getDeviceContext () const | Get the device context for this module. |
| Compute::DeviceType | getDeviceType () const | Get the device type of the current device context. |
| std::string | getName () const | Get the name of the module. |
| const auto & | getParameterTensors () const | Get the parameter tensors of this module. |
| const ComputePrecision::Policy & | getPrecision () const | Get the compute precision policy of this module. |
| const auto & | getStateTensors () const | Get the state tensors of this module. |
| bool | isTraining () const | Check if the module is in training mode. |
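The Module base interface above is what generic code sees when it traverses a model. A small sketch using only the accessors documented here (tensor contents are not inspected, since the Tensor API is outside this page):

```cpp
#include <iostream>
#include <memory>

using namespace Mila::Dnn;

// Works for any Module-derived type, including TransformerBlock.
template <typename TModule>
void describe( const TModule& module )
{
    std::cout << module.getName() << '\n';

    // Device placement and execution context.
    std::shared_ptr< Compute::DeviceContext > context = module.getDeviceContext();
    Compute::DeviceType device = module.getDeviceType();

    // Learnable parameter tensors and additional state tensors.
    const auto& parameters = module.getParameterTensors();
    const auto& state = module.getStateTensors();

    // Compute precision policy and training flag.
    const ComputePrecision::Policy& policy = module.getPrecision();
    bool training = module.isTraining();

    (void)context; (void)device; (void)parameters; (void)state; (void)policy; (void)training;
}
```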
| Type | Member | Description |
|---|---|---|
| std::shared_ptr< MultiHeadAttention< TDeviceType, TDataType > > | attn_block_ { nullptr } | Multi-head self-attention block including projections. |
| Tensor< TDataType, MR > | attn_output_ | Output tensor from attention block. |
| TransformerBlockConfig | config_ | Configuration for the TransformerBlock module. |
| std::shared_ptr< Dropout< TDeviceType, TDataType > > | dropout_ { nullptr } | Optional dropout module. |
| std::shared_ptr< LayerNorm< TDeviceType, TDataType > > | ln_1_ { nullptr } | First layer normalization module. |
| Tensor< TDataType, MR > | ln_1_output_ | Output tensor from first layer normalization. |
| std::shared_ptr< LayerNorm< TDeviceType, TDataType > > | ln_2_ { nullptr } | Second layer normalization module. |
| Tensor< TDataType, MR > | ln_2_output_ | Output tensor from second layer normalization. |
| std::shared_ptr< MLP< TDeviceType, TDataType > > | mlp_ { nullptr } | Feed-forward network (MLP). |
| Tensor< TDataType, MR > | mlp_output_ | Output tensor from MLP. |
| Tensor< TDataType, MR > | res_1_output_ | Output tensor from first residual connection. |
| Tensor< TDataType, MR > | res_2_output_ | Output tensor from second residual connection. |
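The intermediate tensors above trace the data flow through the block. The comment sketch below shows one plausible pre-LN ordering; it is an assumption for illustration only, and the actual forward() implementation, the dropout placement, and the post-LN variant may differ.

```cpp
// Pre-LN sketch (assumed ordering): each sub-block sees a normalized input
// and its result is added back onto the running residual stream.
//
//   ln_1_->forward( input,          ln_1_output_ );    // normalize the input
//   attn_block_->forward( ln_1_output_, attn_output_ );
//   res_1_output_ = input + attn_output_;              // first residual add
//
//   ln_2_->forward( res_1_output_,  ln_2_output_ );    // normalize the residual
//   mlp_->forward( ln_2_output_,    mlp_output_ );
//   res_2_output_ = res_1_output_ + mlp_output_;       // second residual add
//
//   output = res_2_output_;   // dropout_, when configured, is applied along the way
```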
template<DeviceType TDeviceType = DeviceType::Cuda, typename TDataType = float>
requires ValidFloatTensorType<TDataType>
class Mila::Dnn::TransformerBlock< TDeviceType, TDataType >

TransformerBlock implements a standard transformer encoder block.

The transformer block consists of:
- a multi-head self-attention mechanism with a residual connection,
- a feed-forward network (MLP) with a residual connection, and
- layer normalization applied before (pre-LN) or after (post-LN) each sub-block, as configured.

This is the fundamental building block of transformer architectures such as BERT and GPT. The implementation supports both the pre-LN (more numerically stable) and post-LN (original) arrangements, configurable dropout rates, and other hyperparameters.

Template Parameters

| Parameter | Description |
|---|---|
| TDeviceType | The device type (CPU or CUDA) on which to perform computations. |
| TDataType | The data type used for tensor elements throughout the network. |
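Both template parameters have defaults, so the common case needs no explicit arguments; overriding them selects the device and the tensor element type at compile time. A brief sketch, in which the device-name strings and the DeviceType::Cpu enumerator spelling are assumptions:

```cpp
using namespace Mila::Dnn;

void choose_instantiation()
{
    TransformerBlockConfig config;  // hypothetical configuration, set up elsewhere

    // Defaults: TDeviceType = DeviceType::Cuda, TDataType = float.
    TransformerBlock<> gpu_block( "CUDA:0", config );

    // Explicit CPU placement; the element type must satisfy ValidFloatTensorType.
    // The DeviceType::Cpu spelling and the "CPU" device name are assumptions.
    TransformerBlock< DeviceType::Cpu, float > cpu_block( "CPU", config );
}
```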