Mila Deep Neural Network Library

MultiHeadAttention< TDeviceType, TInput, TOutput > Class Template Reference

Multi-head attention module for transformer architectures.
Public Types

using ModuleBase = Module< TDeviceType, TInput, TOutput >
    Alias for base module type.
using MR = std::conditional_t< TDeviceType == DeviceType::Cuda, CudaMemoryResource, CpuMemoryResource >
    Memory resource type used for tensors, selected based on device type.
Public Types inherited from Module< TDeviceType, TInput, TOutput >

using MR = std::conditional_t< TDeviceType == DeviceType::Cuda, CudaMemoryResource, CpuMemoryResource >
Public Member Functions

MultiHeadAttention (const std::string &device_name, const MultiHeadAttentionConfig &config)
    Constructs a new MultiHeadAttention module with a device name.
MultiHeadAttention (std::shared_ptr< DeviceContext > device_context, const MultiHeadAttentionConfig &config)
    Constructs a new MultiHeadAttention module with a provided device context.
void backward (const Tensor< TInput, MR > &input, const Tensor< TOutput, MR > &output_grad, Tensor< TInput, MR > &input_grad)
    Performs the backward pass of the MultiHeadAttention operation.
void forward (const Tensor< TInput, MR > &input, const Tensor< TInput, MR > &mask, Tensor< TOutput, MR > &output)
    Performs the forward pass with an explicit attention mask.
void forward (const Tensor< TInput, MR > &input, Tensor< TOutput, MR > &output)
    Performs the forward pass of the MultiHeadAttention operation.
float getDropout () const
    Gets the dropout rate.
size_t getEmbeddingDim () const
    Gets the embedding dimension.
const std::vector< size_t > & getInputShape () const
    Gets the input shape.
size_t getNumHeads () const
    Gets the number of attention heads.
void load (ModelArchive &archive) override
    Loads the module state from a ZIP archive.
size_t parameterCount () const override
    Gets the number of trainable parameters in this module.
void save (ModelArchive &zip) const override
    Saves the module state to a ZIP archive.
std::string toString () const override
    Generates a string representation of this module's configuration.
bool useCausalMask () const
    Checks if causal masking is enabled.
Public Member Functions inherited from Module< TDeviceType, TInput, TOutput >

Module (const std::string &device_name, const ComponentConfig &config)
    Constructor with device name.
Module (std::shared_ptr< DeviceContext > context, const ComponentConfig &config)
    Constructor with a specific device context.
virtual ~Module ()=default
    Virtual destructor for proper cleanup in derived classes.
std::shared_ptr< Compute::DeviceContext > getDeviceContext () const
    Get the device context for this module.
Compute::DeviceType getDeviceType () const
    Get the device type of the current device context.
std::string getName () const
    Get the name of the module.
const auto & getParameterTensors () const
    Get the parameter tensors of this module.
const ComputePrecision::Policy & getPrecision () const
const auto & getStateTensors () const
    Get the state tensors of this module.
bool isTraining () const
    Check if the module is in training mode.
virtual void setTraining (bool is_training)
    Set the training mode of this module.
Private Member Functions

void createOperation ()
    Creates the appropriate attention operation for the current device.
void initializeTensors ()
    Initializes the tensors needed for attention computation.
Private Attributes

std::shared_ptr< Tensor< TOutput, MR > > attn_ { nullptr }
    Attention weight tensor from the forward pass.
MultiHeadAttentionConfig config_
    Configuration for the MultiHeadAttention module.
std::shared_ptr< UnaryOperation< TDeviceType, TInput, TOutput > > operation_ { nullptr }
    The operation that implements the attention mechanism.
std::vector< std::shared_ptr< Tensor< TOutput, MR > > > output_state_
    Collection of output state tensors for caching.
std::vector< std::shared_ptr< Tensor< TOutput, MR > > > parameters_
    Collection of parameters for this module.
std::shared_ptr< Tensor< TOutput, MR > > pre_attn_ { nullptr }
    Pre-softmax attention scores from the forward pass.
OperationAttributes properties_
    Operation attributes and configuration.
Additional Inherited Members

Protected Member Functions inherited from Module< TDeviceType, TInput, TOutput >

const std::string parametersToString () const
    Helper method to convert parameters to string representation.
const std::string stateToString () const
    Helper method to convert state tensors to string representation.

Protected Attributes inherited from Module< TDeviceType, TInput, TOutput >

std::unordered_map< std::string, std::shared_ptr< Tensor< TOutput, MR > > > parameter_map_ = {}
    Map of parameter names to parameter tensors.
std::unordered_map< std::string, std::shared_ptr< Tensor< TOutput, MR > > > state_map_ = {}
    Map of state names to state tensors.
Detailed Description

Multi-head attention module for transformer architectures.

This module implements the multi-head attention mechanism, which allows each position in the input to attend to every position in the sequence. It is a core component of transformer architectures.

The attention mechanism computes Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V, where Q, K, and V are the query, key, and value projections of the input and d_k is the per-head dimension (see the sketch after the template parameters).

Multi-head attention projects the input into multiple subspaces, computes attention independently in each subspace, and then concatenates the results to form the output.
Template Parameters
    TDeviceType  The device type (CPU or CUDA) on which to perform computations.
    TInput       The data type of the input tensor elements.
    TOutput      The data type of the output tensor elements; defaults to TInput.
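The formula above can be made concrete with a short, self-contained sketch of scaled dot-product attention for a single head. This illustrates the math only and is not the Mila implementation; the function name, the row-major std::vector layout, and the omission of masking and dropout are assumptions made for the example.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Computes softmax(Q K^T / sqrt(d_k)) V for a single head.
// Q, K, V are row-major [S, d_k] matrices; the result is row-major [S, d_k].
std::vector<float> scaledDotProductAttention(
    const std::vector<float>& Q, const std::vector<float>& K,
    const std::vector<float>& V, std::size_t S, std::size_t d_k )
{
    const float scale = 1.0f / std::sqrt( static_cast<float>( d_k ) );

    // Raw scores: scores[i*S + j] = (Q_i . K_j) / sqrt(d_k)
    std::vector<float> scores( S * S );
    for ( std::size_t i = 0; i < S; ++i )
        for ( std::size_t j = 0; j < S; ++j ) {
            float dot = 0.0f;
            for ( std::size_t k = 0; k < d_k; ++k )
                dot += Q[i * d_k + k] * K[j * d_k + k];
            scores[i * S + j] = dot * scale;
        }

    // Row-wise softmax, subtracting the row max for numerical stability.
    for ( std::size_t i = 0; i < S; ++i ) {
        const float max_v = *std::max_element( scores.begin() + i * S,
                                               scores.begin() + ( i + 1 ) * S );
        float sum = 0.0f;
        for ( std::size_t j = 0; j < S; ++j ) {
            scores[i * S + j] = std::exp( scores[i * S + j] - max_v );
            sum += scores[i * S + j];
        }
        for ( std::size_t j = 0; j < S; ++j )
            scores[i * S + j] /= sum;
    }

    // Weighted sum of value rows: output = scores * V.
    std::vector<float> out( S * d_k, 0.0f );
    for ( std::size_t i = 0; i < S; ++i )
        for ( std::size_t j = 0; j < S; ++j )
            for ( std::size_t k = 0; k < d_k; ++k )
                out[i * d_k + k] += scores[i * S + j] * V[j * d_k + k];
    return out;
}
```

In the full module this computation runs once per head on separate projections of the input, and the per-head outputs are concatenated, as described above.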
Member Typedef Documentation

using ModuleBase = Module< TDeviceType, TInput, TOutput >   [export]
Alias for base module type.
using MR = std::conditional_t< TDeviceType == DeviceType::Cuda, CudaMemoryResource, CpuMemoryResource >   [export]
Memory resource type used for tensors, selected based on device type.
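As a minimal sketch of how this alias resolves, the standalone snippet below reproduces the std::conditional_t selection with stub types; the stub definitions of DeviceType and the two memory resources are assumptions for illustration only.

```cpp
#include <type_traits>

// Stub stand-ins for the real Mila types (assumptions for the example).
enum class DeviceType { Cpu, Cuda };
struct CpuMemoryResource {};
struct CudaMemoryResource {};

// Same selection pattern as the MR alias documented above.
template <DeviceType TDeviceType>
using MR = std::conditional_t<TDeviceType == DeviceType::Cuda,
                              CudaMemoryResource,
                              CpuMemoryResource>;

static_assert( std::is_same_v<MR<DeviceType::Cuda>, CudaMemoryResource> );
static_assert( std::is_same_v<MR<DeviceType::Cpu>, CpuMemoryResource> );
```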
Constructor & Destructor Documentation

MultiHeadAttention (const std::string &device_name, const MultiHeadAttentionConfig &config)   [inline, explicit, export]
Constructs a new MultiHeadAttention module with a device name.
Creates a new DeviceContext internally using the provided device name. This constructor is useful for creating standalone modules without pre-existing device contexts.
Parameters
    device_name  The name of the device to use (e.g., "CPU", "CUDA:0").
    config       Configuration parameters for the MultiHeadAttention module.
Exceptions
    std::invalid_argument  If the device name is invalid or the configuration is invalid.
    std::runtime_error     If the device type doesn't match the template parameter TDeviceType.
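A hedged usage sketch of this constructor follows. Only the signature and the device-name format ("CPU", "CUDA:0") come from this page; the template arguments and the default construction of MultiHeadAttentionConfig are assumptions.

```cpp
// Hypothetical usage; MultiHeadAttentionConfig's API is not documented here,
// so default construction is assumed for the sake of the example.
MultiHeadAttentionConfig config;

// Instantiated for CUDA with float tensors; throws std::runtime_error if the
// named device does not match TDeviceType, per the documentation above.
Mila::Dnn::MultiHeadAttention<Compute::DeviceType::Cuda, float> attn(
    "CUDA:0", config );
```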
MultiHeadAttention (std::shared_ptr< DeviceContext > device_context, const MultiHeadAttentionConfig &config)   [inline, explicit, export]
Constructs a new MultiHeadAttention module with a provided device context.
Uses a pre-existing DeviceContext instance. This constructor is useful when integrating the module into a larger network that shares device contexts across modules.
Parameters
    device_context  The device context to use for this module.
    config          Configuration parameters for the MultiHeadAttention module.
Exceptions
    std::invalid_argument  If device_context is null or the configuration is invalid.
    std::runtime_error     If the device context type doesn't match the template parameter TDeviceType.
Member Function Documentation

void backward (const Tensor< TInput, MR > &input, const Tensor< TOutput, MR > &output_grad, Tensor< TInput, MR > &input_grad)   [inline, export]
Performs the backward pass of the MultiHeadAttention operation; a usage sketch follows the parameter list.
Parameters
    input        The input tensor from the forward pass.
    output_grad  The gradient of the loss with respect to the output.
    input_grad   The tensor that receives the gradients with respect to the input.
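A minimal sketch of a forward/backward pair with the signatures documented on this page; tensor construction, shapes, and the upstream gradient computation are not documented here and are elided as assumptions.

```cpp
// Hedged sketch; Tensor allocation/shape APIs are not shown on this page,
// so construction is elided. Assumed element type: float on both sides.
Tensor<float, MR> input = /* [batch, seq_len, embed_dim], assumed layout */;
Tensor<float, MR> output;
Tensor<float, MR> output_grad;   // filled by downstream layers / the loss
Tensor<float, MR> input_grad;

attn.forward( input, output );                     // forward pass
/* ... compute the loss, backpropagate into output_grad ... */
attn.backward( input, output_grad, input_grad );   // gradients w.r.t. the input
```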
void createOperation ()   [inline, export, private]
Creates the appropriate attention operation for the current device.
Instantiates either a CPU or CUDA attention operation based on the device type.
void forward (const Tensor< TInput, MR > &input, const Tensor< TInput, MR > &mask, Tensor< TOutput, MR > &output)   [inline, export]
Performs the forward pass with an explicit attention mask.
Parameters
    input   The input tensor to be processed.
    mask    The attention mask tensor (0s mark masked positions); see the sketch below.
    output  The output tensor where the results will be stored.
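The snippet below builds a causal mask consistent with the convention above (0 marks a masked position, 1 an attendable one). The helper name is hypothetical, and copying the values into the mask Tensor is not shown.

```cpp
#include <cstddef>
#include <vector>

// Row i of the mask lets query position i attend only to key positions j <= i.
std::vector<float> makeCausalMask( std::size_t seq_len )
{
    std::vector<float> mask( seq_len * seq_len, 0.0f );  // 0 = masked
    for ( std::size_t i = 0; i < seq_len; ++i )
        for ( std::size_t j = 0; j <= i; ++j )
            mask[i * seq_len + j] = 1.0f;                // 1 = attendable
    return mask;
}
```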
void forward (const Tensor< TInput, MR > &input, Tensor< TOutput, MR > &output)   [inline, export]
Performs the forward pass of the MultiHeadAttention operation.
Parameters
    input   The input tensor to be processed.
    output  The output tensor where the results will be stored.
float getDropout () const   [inline, export]
Gets the dropout rate.
size_t getEmbeddingDim () const   [inline, export]
Gets the embedding dimension.
const std::vector< size_t > & getInputShape () const   [inline, export]
Gets the input shape.
size_t getNumHeads () const   [inline, export]
Gets the number of attention heads.
void initializeTensors ()   [inline, export, private]
Initializes the tensors needed for attention computation.
Creates and initializes intermediate tensors used during the attention computation, including attention weights and pre-softmax scores.
void load (ModelArchive &archive)   [inline, override, export, virtual]
Loads the module state from a ZIP archive.
Implementation of the Module interface for deserialization. Currently a no-op in the base implementation as there are no parameters to load.
Parameters
    archive  ZIP archive for deserialization.
Implements Mila::Dnn::Module< TDeviceType, TInput, TOutput >.
size_t parameterCount () const   [inline, override, export, virtual]
Gets the number of trainable parameters in this module.
Implements Mila::Dnn::Module< TDeviceType, TInput, TOutput >.
void save (ModelArchive &zip) const   [inline, override, export, virtual]
Saves the module state to a ZIP archive.
Implementation of the Module interface for serialization. Currently a no-op in the base implementation as there are no parameters to save.
Parameters
    zip  ZIP archive for serialization.
Implements Mila::Dnn::Module< TDeviceType, TInput, TOutput >.
std::string toString () const   [inline, override, export, virtual]
Generates a string representation of this module's configuration.
Implements Mila::Dnn::Module< TDeviceType, TInput, TOutput >.
bool useCausalMask () const   [inline, export]
Checks if causal masking is enabled.
Member Data Documentation

std::shared_ptr< Tensor< TOutput, MR > > attn_ { nullptr }   [export, private]
Attention weight tensor from the forward pass.
Shape: [batch_size, num_heads, sequence_length, sequence_length]. Stores the attention weights between all token pairs.
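Assuming the conventional row-major layout for this shape, the flat offset of one attention weight can be computed as below; the helper and the layout are illustrative assumptions, not taken from the Tensor class.

```cpp
#include <cstddef>

// Flat offset of attention weight (b, h, i, j) in a row-major buffer of
// shape [batch_size, num_heads, sequence_length, sequence_length].
inline std::size_t attnOffset( std::size_t b, std::size_t h,
                               std::size_t i, std::size_t j,
                               std::size_t num_heads, std::size_t seq_len )
{
    return ( ( b * num_heads + h ) * seq_len + i ) * seq_len + j;
}
```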
MultiHeadAttentionConfig config_   [export, private]
Configuration for the MultiHeadAttention module.
std::shared_ptr< UnaryOperation< TDeviceType, TInput, TOutput > > operation_ { nullptr }   [export, private]
The operation that implements the attention mechanism.
std::vector< std::shared_ptr< Tensor< TOutput, MR > > > output_state_   [export, private]
Collection of output state tensors for caching.
std::vector< std::shared_ptr< Tensor< TOutput, MR > > > parameters_   [export, private]
Collection of parameters for this module.
std::shared_ptr< Tensor< TOutput, MR > > pre_attn_ { nullptr }   [export, private]
Pre-softmax attention scores from the forward pass.
Shape: [batch_size, num_heads, sequence_length, sequence_length]. Stores the raw attention scores before softmax normalization.
OperationAttributes properties_   [export, private]
Operation attributes and configuration.