Mila Deep Neural Network Library

MultiHeadAttention

Multi-head attention module for transformer architectures.


Public Types | |
| using | ModuleBase = Module< TDeviceType, TInput, TOutput > |
| Alias for base module type. | |
| using | MR = std::conditional_t< TDeviceType==DeviceType::Cuda, CudaMemoryResource, CpuMemoryResource > |
| Memory resource type used for tensors, selected based on device type. | |
Public Types inherited from Mila::Dnn::Module< TDeviceType, TInput, TOutput > | |
| using | MR = std::conditional_t< TDeviceType==DeviceType::Cuda, CudaMemoryResource, CpuMemoryResource > |
Public Member Functions

    MultiHeadAttention (const std::string &device_name, const MultiHeadAttentionConfig &config)
        Constructs a new MultiHeadAttention module with a device name.

    MultiHeadAttention (std::shared_ptr< DeviceContext > device_context, const MultiHeadAttentionConfig &config)
        Constructs a new MultiHeadAttention module with a provided device context.

    void backward (const Tensor< TInput, MR > &input, const Tensor< TOutput, MR > &output_grad, Tensor< TInput, MR > &input_grad)
        Performs the backward pass of the MultiHeadAttention operation.

    void forward (const Tensor< TInput, MR > &input, const Tensor< TInput, MR > &mask, Tensor< TOutput, MR > &output)
        Performs the forward pass with an explicit attention mask.

    void forward (const Tensor< TInput, MR > &input, Tensor< TOutput, MR > &output)
        Performs the forward pass of the MultiHeadAttention operation.

    float getDropout () const
        Gets the dropout rate.

    size_t getEmbeddingDim () const
        Gets the embedding dimension.

    const std::vector< size_t > & getInputShape () const
        Gets the input shape.

    size_t getNumHeads () const
        Gets the number of attention heads.

    void load (ModelArchive &archive) override
        Loads the module state from a ZIP archive.

    size_t parameterCount () const override
        Gets the number of trainable parameters in this module.

    void save (ModelArchive &zip) const override
        Saves the module state to a ZIP archive.

    std::string toString () const override
        Generates a string representation of this module's configuration.

    bool useCausalMask () const
        Checks if causal masking is enabled.
Public Member Functions inherited from Mila::Dnn::Module< TDeviceType, TInput, TOutput >

    Module (const std::string &device_name, const ComponentConfig &config)
        Constructor with device name.

    Module (std::shared_ptr< DeviceContext > context, const ComponentConfig &config)
        Constructor with a specific device context.

    virtual ~Module ()=default
        Virtual destructor for proper cleanup in derived classes.

    std::shared_ptr< Compute::DeviceContext > getDeviceContext () const
        Get the device context for this module.

    Compute::DeviceType getDeviceType () const
        Get the device type of the current device context.

    std::string getName () const
        Get the name of the module.

    const auto & getParameterTensors () const
        Get the parameter tensors of this module.

    const ComputePrecision::Policy & getPrecision () const

    const auto & getStateTensors () const
        Get the state tensors of this module.

    bool isTraining () const
        Check if the module is in training mode.

    virtual void setTraining (bool is_training)
        Set the training mode of this module.
Private Member Functions

    void createOperation ()
        Creates the appropriate attention operation for the current device.

    void initializeTensors ()
        Initializes the tensors needed for attention computation.
Private Attributes

    std::shared_ptr< Tensor< TOutput, MR > > attn_ { nullptr }
        Attention weight tensor from the forward pass.

    MultiHeadAttentionConfig config_
        Configuration for the MultiHeadAttention module.

    std::shared_ptr< UnaryOperation< TDeviceType, TInput, TOutput > > operation_ { nullptr }
        The operation that implements the attention mechanism.

    std::vector< std::shared_ptr< Tensor< TOutput, MR > > > output_state_
        Collection of output state tensors for caching.

    std::vector< std::shared_ptr< Tensor< TOutput, MR > > > parameters_
        Collection of parameters for this module.

    std::shared_ptr< Tensor< TOutput, MR > > pre_attn_ { nullptr }
        Pre-softmax attention scores from the forward pass.

    OperationAttributes properties_
        Operation attributes and configuration.
Additional Inherited Members

Protected Member Functions inherited from Mila::Dnn::Module< TDeviceType, TInput, TOutput >

    const std::string parametersToString () const
        Helper method to convert parameters to string representation.

    const std::string stateToString () const
        Helper method to convert state tensors to string representation.

Protected Attributes inherited from Mila::Dnn::Module< TDeviceType, TInput, TOutput >

    std::unordered_map< std::string, std::shared_ptr< Tensor< TOutput, MR > > > parameter_map_ = {}
        Map of parameter names to parameter tensors.

    std::unordered_map< std::string, std::shared_ptr< Tensor< TOutput, MR > > > state_map_ = {}
        Map of state names to state tensors.
Detailed Description

Multi-head attention module for transformer architectures.

This module implements the multi-head attention mechanism, which lets each position in the input attend to other positions in the sequence. It is a core component of transformer architectures.

The attention mechanism computes

    Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V

where Q, K, and V are the query, key, and value projections of the input and d_k is the per-head key dimension.

Multi-head attention projects the input into multiple subspaces, computes attention independently in each subspace, then concatenates the results to form the output.
Template Parameters

    TDeviceType   The device type (CPU or CUDA) on which to perform computations.
    TInput        The data type of the input tensor elements.
    TOutput       The data type of the output tensor elements; defaults to TInput.
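A minimal usage sketch follows. Only the constructors, forward() overloads, and getters documented on this page are guaranteed; the header path, the MultiHeadAttentionConfig setup, and the shape-taking Tensor constructor are assumptions for illustration.

    #include <cstddef>
    #include <memory>
    #include "Mila/Dnn/MultiHeadAttention.h"   // hypothetical header path

    using namespace Mila::Dnn;

    void runAttention()
    {
        MultiHeadAttentionConfig config;   // assumed default-constructible; real
                                           // setup options are not shown on this page

        // Device-name constructor (documented below). For a CUDA device the MR
        // alias above resolves to CudaMemoryResource.
        MultiHeadAttention<Compute::DeviceType::Cuda, float> attn( "CUDA:0", config );

        const size_t B = 2, T = 128;              // batch and sequence sizes
        const size_t C = attn.getEmbeddingDim();  // documented getter

        // Shape-taking Tensor constructor is an assumption for illustration.
        Tensor<float, CudaMemoryResource> input( { B, T, C } );
        Tensor<float, CudaMemoryResource> output( { B, T, C } );

        attn.forward( input, output );            // documented overload
    }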
Member Typedef Documentation

ModuleBase    [export]

    using ModuleBase = Module< TDeviceType, TInput, TOutput >

Alias for base module type.

MR    [export]

    using MR = std::conditional_t< TDeviceType == DeviceType::Cuda, CudaMemoryResource, CpuMemoryResource >

Memory resource type used for tensors, selected based on device type.
Constructor & Destructor Documentation

MultiHeadAttention (const std::string &device_name, const MultiHeadAttentionConfig &config)    [inline, explicit, export]

Constructs a new MultiHeadAttention module with a device name.

Creates a new DeviceContext internally using the provided device name. This constructor is useful for creating standalone modules without pre-existing device contexts.

Parameters
    device_name   The name of the device to use (e.g., "CPU", "CUDA:0").
    config        Configuration parameters for the MultiHeadAttention module.

Exceptions
    std::invalid_argument   If the device name or the configuration is invalid.
    std::runtime_error      If the device type doesn't match the template parameter TDeviceType.

MultiHeadAttention (std::shared_ptr< DeviceContext > device_context, const MultiHeadAttentionConfig &config)    [inline, explicit, export]

Constructs a new MultiHeadAttention module with a provided device context.

Uses a pre-existing DeviceContext instance. This constructor is useful when integrating the module into a larger network that shares device contexts across modules.

Parameters
    device_context   The device context to use for this module.
    config           Configuration parameters for the MultiHeadAttention module.

Exceptions
    std::invalid_argument   If device_context is null or the configuration is invalid.
    std::runtime_error      If the device context type doesn't match the template parameter TDeviceType.
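A sketch of sharing one DeviceContext across modules, as the description above suggests. How a DeviceContext is created is not documented on this page, so the construction below is hypothetical; includes as in the earlier sketch.

    void buildShared()
    {
        // Hedged: DeviceContext creation is assumed; only the constructor
        // signature above is taken from this page.
        auto context = std::make_shared<Compute::DeviceContext>( "CUDA:0" );  // hypothetical ctor

        MultiHeadAttentionConfig config;   // assumed default-constructible
        MultiHeadAttention<Compute::DeviceType::Cuda, float> attn( context, config );

        // A sibling module could reuse the same context so both share the device:
        // MultiHeadAttention<Compute::DeviceType::Cuda, float> attn2( context, config );
    }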

Member Function Documentation

backward()    [inline, export]

    void backward (const Tensor< TInput, MR > &input, const Tensor< TOutput, MR > &output_grad, Tensor< TInput, MR > &input_grad)

Performs the backward pass of the MultiHeadAttention operation.

Parameters
    input         The input tensor from the forward pass.
    output_grad   The gradient of the loss with respect to the output.
    input_grad    The tensor to store gradients with respect to the input.
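A sketch of one training step around backward(). Only forward(), backward(), and the inherited setTraining() come from this page; the shape() accessor and shape-taking Tensor constructor are assumptions.

    void trainStep( MultiHeadAttention<Compute::DeviceType::Cuda, float>& attn,
                    const Tensor<float, CudaMemoryResource>& input )
    {
        attn.setTraining( true );          // enable training mode

        Tensor<float, CudaMemoryResource> output( input.shape() );       // shape() assumed
        Tensor<float, CudaMemoryResource> output_grad( input.shape() );
        Tensor<float, CudaMemoryResource> input_grad( input.shape() );

        attn.forward( input, output );
        // ... compute the loss downstream and fill output_grad ...
        attn.backward( input, output_grad, input_grad );
    }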
createOperation()    [inline, export, private]

Creates the appropriate attention operation for the current device.

Instantiates either a CPU or CUDA attention operation based on the device type.


forward()    [inline, export]

    void forward (const Tensor< TInput, MR > &input, const Tensor< TInput, MR > &mask, Tensor< TOutput, MR > &output)

Performs the forward pass with an explicit attention mask.

Parameters
    input    The input tensor to be processed.
    mask     The attention mask tensor (0s for masked positions).
    output   The output tensor where the results will be stored.
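A sketch of the masked overload. The 0-means-masked convention comes from the parameter table above; the mask's expected shape is not documented here, so the [B, 1, T, T] layout is an assumption.

    void maskedForward( MultiHeadAttention<Compute::DeviceType::Cuda, float>& attn,
                        const Tensor<float, CudaMemoryResource>& input,
                        const Tensor<float, CudaMemoryResource>& mask )
    {
        // mask holds 1 for visible positions and 0 for masked positions.
        Tensor<float, CudaMemoryResource> output( input.shape() );  // shape() assumed
        attn.forward( input, mask, output );
    }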

forward()    [inline, export]

    void forward (const Tensor< TInput, MR > &input, Tensor< TOutput, MR > &output)

Performs the forward pass of the MultiHeadAttention operation.

Parameters
    input    The input tensor to be processed.
    output   The output tensor where the results will be stored.
getDropout()    [inline, export]

Gets the dropout rate.

getEmbeddingDim()    [inline, export]

Gets the embedding dimension.

getInputShape()    [inline, export]

Gets the input shape.

getNumHeads()    [inline, export]

Gets the number of attention heads.

initializeTensors()    [inline, export, private]

Initializes the tensors needed for attention computation.

Creates and initializes intermediate tensors used during the attention computation, including attention weights and pre-softmax scores.


load()    [inline, override, export, virtual]

Loads the module state from a ZIP archive.

Implementation of the Module interface for deserialization. Currently a no-op in the base implementation as there are no parameters to load.

Parameters
    archive   ZIP archive for deserialization.

Implements Mila::Dnn::Module< TDeviceType, TInput, TOutput >.
parameterCount()    [inline, override, export, virtual]

Gets the number of trainable parameters in this module.

Implements Mila::Dnn::Module< TDeviceType, TInput, TOutput >.
save()    [inline, override, export, virtual]

Saves the module state to a ZIP archive.

Implementation of the Module interface for serialization. Currently a no-op in the base implementation as there are no parameters to save.

Parameters
    zip   ZIP archive for serialization.

Implements Mila::Dnn::Module< TDeviceType, TInput, TOutput >.
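A save/load round-trip sketch. Only save(ModelArchive&) and load(ModelArchive&) are documented, and both are noted above as currently no-ops; how a ModelArchive is opened is an assumption.

    void checkpoint( MultiHeadAttention<Compute::DeviceType::Cuda, float>& attn )
    {
        ModelArchive archive( "mha_checkpoint.zip" );  // hypothetical path-taking ctor
        attn.save( archive );   // currently a no-op per the documentation above
        attn.load( archive );   // restore into a module built with the same config
    }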
toString()    [inline, override, export, virtual]

Generates a string representation of this module's configuration.

Implements Mila::Dnn::Module< TDeviceType, TInput, TOutput >.

useCausalMask()    [inline, export]

Checks if causal masking is enabled.

Member Data Documentation

attn_    [export, private]

Attention weight tensor from the forward pass.

Shape: [batch_size, num_heads, sequence_length, sequence_length]. Stores the attention weights between all token pairs.

config_    [export, private]

Configuration for the MultiHeadAttention module.

operation_    [export, private]

The operation that implements the attention mechanism.

output_state_    [export, private]

Collection of output state tensors for caching.

parameters_    [export, private]

Collection of parameters for this module.

pre_attn_    [export, private]

Pre-softmax attention scores from the forward pass.

Shape: [batch_size, num_heads, sequence_length, sequence_length]. Stores the raw attention scores before softmax normalization.

properties_    [export, private]

Operation attributes and configuration.