Mila
Deep Neural Network Library
Mila::Dnn::MultiHeadAttention< TDeviceType, TInput, TOutput > Class Template Reference export

Multi-head attention module for transformer architectures. More...


Public Types

using ModuleBase = Module< TDeviceType, TInput, TOutput >
 Alias for base module type.
 
using MR = std::conditional_t< TDeviceType==DeviceType::Cuda, CudaMemoryResource, CpuMemoryResource >
 Memory resource type used for tensors, selected based on device type.
 
- Public Types inherited from Mila::Dnn::Module< TDeviceType, TInput, TOutput >
using MR = std::conditional_t< TDeviceType==DeviceType::Cuda, CudaMemoryResource, CpuMemoryResource >
 

Public Member Functions

 MultiHeadAttention (const std::string &device_name, const MultiHeadAttentionConfig &config)
 Constructs a new MultiHeadAttention module with a device name.
 
 MultiHeadAttention (std::shared_ptr< DeviceContext > device_context, const MultiHeadAttentionConfig &config)
 Constructs a new MultiHeadAttention module with a provided device context.
 
void backward (const Tensor< TInput, MR > &input, const Tensor< TOutput, MR > &output_grad, Tensor< TInput, MR > &input_grad)
 Performs the backward pass of the MultiHeadAttention operation.
 
void forward (const Tensor< TInput, MR > &input, const Tensor< TInput, MR > &mask, Tensor< TOutput, MR > &output)
 Performs the forward pass with an explicit attention mask.
 
void forward (const Tensor< TInput, MR > &input, Tensor< TOutput, MR > &output)
 Performs the forward pass of the MultiHeadAttention operation.
 
float getDropout () const
 Gets the dropout rate.
 
size_t getEmbeddingDim () const
 Gets the embedding dimension.
 
const std::vector< size_t > & getInputShape () const
 Gets the input shape.
 
size_t getNumHeads () const
 Gets the number of attention heads.
 
void load (ModelArchive &archive) override
 Loads the module state from a ZIP archive.
 
size_t parameterCount () const override
 Gets the number of trainable parameters in this module.
 
void save (ModelArchive &zip) const override
 Saves the module state to a ZIP archive.
 
std::string toString () const override
 Generates a string representation of this module's configuration.
 
bool useCausalMask () const
 Checks if causal masking is enabled.
 
- Public Member Functions inherited from Mila::Dnn::Module< TDeviceType, TInput, TOutput >
 Module (const std::string &device_name, const ComponentConfig &config)
 Constructor with device name.
 
 Module (std::shared_ptr< DeviceContext > context, const ComponentConfig &config)
 Constructor with a specific device context.
 
virtual ~Module ()=default
 Virtual destructor for proper cleanup in derived classes.
 
std::shared_ptr< Compute::DeviceContext > getDeviceContext () const
 Get the device context for this module.
 
Compute::DeviceType getDeviceType () const
 Get the device type of the current device context.
 
std::string getName () const
 Get the name of the module.
 
const auto & getParameterTensors () const
 Get the parameter tensors of this module.
 
const ComputePrecision::Policy & getPrecision () const
 
const auto & getStateTensors () const
 Get the state tensors of this module.
 
bool isTraining () const
 Check if the module is in training mode.
 
virtual void setTraining (bool is_training)
 Set the training mode of this module.
 

Private Member Functions

void createOperation ()
 Creates the appropriate attention operation for the current device.
 
void initializeTensors ()
 Initializes the tensors needed for attention computation.
 

Private Attributes

std::shared_ptr< Tensor< TOutput, MR > > attn_ { nullptr }
 Attention weight tensor from the forward pass.
 
MultiHeadAttentionConfig config_
 Configuration for the MultiHeadAttention module.
 
std::shared_ptr< UnaryOperation< TDeviceType, TInput, TOutput > > operation_ { nullptr }
 The operation that implements the attention mechanism.
 
std::vector< std::shared_ptr< Tensor< TOutput, MR > > > output_state_
 Collection of output state tensors for caching.
 
std::vector< std::shared_ptr< Tensor< TOutput, MR > > > parameters_
 Collection of parameters for this module.
 
std::shared_ptr< Tensor< TOutput, MR > > pre_attn_ { nullptr }
 Pre-softmax attention scores from the forward pass.
 
OperationAttributes properties_
 Operation attributes and configuration.
 

Additional Inherited Members

- Protected Member Functions inherited from Mila::Dnn::Module< TDeviceType, TInput, TOutput >
const std::string parametersToString () const
 Helper method to convert parameters to string representation.
 
const std::string stateToString () const
 Helper method to convert state tensors to string representation.
 
- Protected Attributes inherited from Mila::Dnn::Module< TDeviceType, TInput, TOutput >
std::unordered_map< std::string, std::shared_ptr< Tensor< TOutput, MR > > > parameter_map_ = {}
 Map of parameter names to parameter tensors.
 
std::unordered_map< std::string, std::shared_ptr< Tensor< TOutput, MR > > > state_map_ = {}
 Map of state names to state tensors.
 

Detailed Description

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
requires ValidTensorTypes<TInput, TOutput>
class Mila::Dnn::MultiHeadAttention< TDeviceType, TInput, TOutput >

Multi-head attention module for transformer architectures.

This module implements the multi-head attention mechanism, which allows different parts of the input to attend to different parts of the sequence. This is a core component of transformer architectures.

The attention mechanism computes: Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V, where Q, K, and V are the query, key, and value projections of the input and d_k is the dimensionality of each head's key vectors.

Multi-head attention projects the input into multiple subspaces, computes attention independently in each subspace, then concatenates the results to form the output.
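
A minimal usage sketch follows. It is illustrative only: how MultiHeadAttentionConfig is populated and the Tensor shape constructor are assumptions; only the MultiHeadAttention constructor, the MR alias, and forward() follow the signatures documented on this page.

    using namespace Mila::Dnn;

    // Hypothetical configuration object; how the embedding dimension, head count,
    // dropout and causal-mask flags are set is defined by MultiHeadAttentionConfig.
    MultiHeadAttentionConfig config;

    // Construct a CUDA attention module from a device name (constructor [1/2]).
    using Attn = MultiHeadAttention<DeviceType::Cuda, float>;
    Attn attn( "CUDA:0", config );

    // Shapes are illustrative: [batch_size, sequence_length, embedding_dim].
    Tensor<float, Attn::MR> input( { 8, 128, 768 } );
    Tensor<float, Attn::MR> output( { 8, 128, 768 } );

    attn.forward( input, output );   // forward pass without an explicit mask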

Template Parameters
    TDeviceType  The device type (CPU or CUDA) on which to perform computations.
    TInput       The data type of the input tensor elements.
    TOutput      The data type of the output tensor elements, defaults to TInput.

Member Typedef Documentation

◆ ModuleBase

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
using Mila::Dnn::MultiHeadAttention< TDeviceType, TInput, TOutput >::ModuleBase = Module<TDeviceType, TInput, TOutput>
export

Alias for base module type.

◆ MR

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
using Mila::Dnn::MultiHeadAttention< TDeviceType, TInput, TOutput >::MR = std::conditional_t<TDeviceType == DeviceType::Cuda, CudaMemoryResource, CpuMemoryResource>
export

Memory resource type used for tensors, selected based on device type.

Constructor & Destructor Documentation

◆ MultiHeadAttention() [1/2]

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
Mila::Dnn::MultiHeadAttention< TDeviceType, TInput, TOutput >::MultiHeadAttention ( const std::string &  device_name,
const MultiHeadAttentionConfig &  config
)
inline explicit export

Constructs a new MultiHeadAttention module with a device name.

Creates a new DeviceContext internally using the provided device name. This constructor is useful for creating standalone modules without pre-existing device contexts.

Parameters
    device_name  The name of the device to use (e.g., "CPU", "CUDA:0").
    config       Configuration parameters for the MultiHeadAttention module.
Exceptions
    std::invalid_argument  If the device name is invalid or the configuration is invalid.
    std::runtime_error     If device type doesn't match template parameter TDeviceType.
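
A short sketch of this constructor, assuming config is an already-populated MultiHeadAttentionConfig; the DeviceType::Cpu enumerator name is an assumption, while DeviceType::Cuda appears elsewhere on this page.

    // The device name must agree with TDeviceType; otherwise std::runtime_error is thrown.
    MultiHeadAttention<DeviceType::Cpu, float>  cpu_attn( "CPU", config );      // Cpu enumerator assumed
    MultiHeadAttention<DeviceType::Cuda, float> cuda_attn( "CUDA:0", config );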

◆ MultiHeadAttention() [2/2]

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
Mila::Dnn::MultiHeadAttention< TDeviceType, TInput, TOutput >::MultiHeadAttention ( std::shared_ptr< DeviceContext >  device_context,
const MultiHeadAttentionConfig &  config
)
inline explicit export

Constructs a new MultiHeadAttention module with a provided device context.

Uses a pre-existing DeviceContext instance. This constructor is useful when integrating the module into a larger network that shares device contexts across modules.

Parameters
    device_context  The device context to use for this module.
    config          Configuration parameters for the MultiHeadAttention module.
Exceptions
    std::invalid_argument  If device_context is null or configuration is invalid.
    std::runtime_error     If device context type doesn't match template parameter TDeviceType.
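
A sketch of sharing one device context across modules; the DeviceContext constructor argument shown here is an assumption, while the module constructor itself matches the signature above.

    // One context, reused by every module of the same network.
    auto ctx = std::make_shared<Compute::DeviceContext>( "CUDA:0" );  // constructor signature assumed

    MultiHeadAttention<DeviceType::Cuda, float> attn( ctx, config );
    // Further modules constructed with the same ctx share the device and its resources.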

Member Function Documentation

◆ backward()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
void Mila::Dnn::MultiHeadAttention< TDeviceType, TInput, TOutput >::backward ( const Tensor< TInput, MR > &  input,
const Tensor< TOutput, MR > &  output_grad,
Tensor< TInput, MR > &  input_grad 
)
inline export

Performs the backward pass of the MultiHeadAttention operation.

Parameters
    input        The input tensor from the forward pass.
    output_grad  The gradient of loss with respect to the output.
    input_grad   The tensor to store gradients with respect to input.
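
A sketch of a training-style call sequence, reusing attn, input, and output from the usage sketch in the Detailed Description; the gradient tensor shapes mirror their forward counterparts and the shape values are illustrative.

    Tensor<float, Attn::MR> output_grad( { 8, 128, 768 } );  // dL/d(output), provided by the next layer
    Tensor<float, Attn::MR> input_grad( { 8, 128, 768 } );   // dL/d(input), filled by backward()

    attn.setTraining( true );                          // inherited from Module
    attn.forward( input, output );                     // forward pass caches attention state
    attn.backward( input, output_grad, input_grad );   // propagate gradients to the input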

◆ createOperation()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
void Mila::Dnn::MultiHeadAttention< TDeviceType, TInput, TOutput >::createOperation ( )
inline export private

Creates the appropriate attention operation for the current device.

Instantiates either a CPU or CUDA attention operation based on the device type.


◆ forward() [1/2]

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
void Mila::Dnn::MultiHeadAttention< TDeviceType, TInput, TOutput >::forward ( const Tensor< TInput, MR > &  input,
const Tensor< TInput, MR > &  mask,
Tensor< TOutput, MR > &  output 
)
inline export

Performs the forward pass with an explicit attention mask.

Parameters
    input   The input tensor to be processed.
    mask    The attention mask tensor (0s for masked positions).
    output  The output tensor where the results will be stored.
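
A sketch of the masked forward pass, reusing attn, input, and output from the earlier usage sketch. The mask layout and fill convention shown here are assumptions; only the rule that 0 marks a masked position comes from the parameter description above.

    // Assumed layout: [batch_size, 1, sequence_length, sequence_length].
    Tensor<float, Attn::MR> mask( { 8, 1, 128, 128 } );
    // ... set entries to 0 for positions that must not be attended to,
    //     and to a non-zero value for positions that remain visible ...

    attn.forward( input, mask, output );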

◆ forward() [2/2]

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
void Mila::Dnn::MultiHeadAttention< TDeviceType, TInput, TOutput >::forward ( const Tensor< TInput, MR > &  input,
Tensor< TOutput, MR > &  output 
)
inline export

Performs the forward pass of the MultiHeadAttention operation.

Parameters
    input   The input tensor to be processed.
    output  The output tensor where the results will be stored.

◆ getDropout()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
float Mila::Dnn::MultiHeadAttention< TDeviceType, TInput, TOutput >::getDropout ( ) const
inline export

Gets the dropout rate.

Returns
float The dropout rate.

◆ getEmbeddingDim()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
size_t Mila::Dnn::MultiHeadAttention< TDeviceType, TInput, TOutput >::getEmbeddingDim ( ) const
inline export

Gets the embedding dimension.

Returns
size_t The embedding dimension.

◆ getInputShape()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
const std::vector< size_t > & Mila::Dnn::MultiHeadAttention< TDeviceType, TInput, TOutput >::getInputShape ( ) const
inline export

Gets the input shape.

Returns
const std::vector<size_t>& The input shape.

◆ getNumHeads()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
size_t Mila::Dnn::MultiHeadAttention< TDeviceType, TInput, TOutput >::getNumHeads ( ) const
inline export

Gets the number of attention heads.

Returns
size_t The number of attention heads.

◆ initializeTensors()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
void Mila::Dnn::MultiHeadAttention< TDeviceType, TInput, TOutput >::initializeTensors ( )
inline export private

Initializes the tensors needed for attention computation.

Creates and initializes intermediate tensors used during the attention computation, including attention weights and pre-softmax scores.


◆ load()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
void Mila::Dnn::MultiHeadAttention< TDeviceType, TInput, TOutput >::load ( ModelArchive &  archive )
inline override export virtual

Loads the module state from a ZIP archive.

Implementation of the Module interface for deserialization. Currently a no-op, as this module has no parameters to load.

Parameters
    archive  ZIP archive for deserialization

Implements Mila::Dnn::Module< TDeviceType, TInput, TOutput >.

◆ parameterCount()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
size_t Mila::Dnn::MultiHeadAttention< TDeviceType, TInput, TOutput >::parameterCount ( ) const
inline override export virtual

Gets the number of trainable parameters in this module.

Returns
size_t The total number of parameters.

Implements Mila::Dnn::Module< TDeviceType, TInput, TOutput >.

◆ save()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
void Mila::Dnn::MultiHeadAttention< TDeviceType, TInput, TOutput >::save ( ModelArchive &  zip ) const
inline override export virtual

Saves the module state to a ZIP archive.

Implementation of the Module interface for serialization. Currently a no-op, as this module has no parameters to save.

Parameters
    zip  ZIP archive for serialization

Implements Mila::Dnn::Module< TDeviceType, TInput, TOutput >.

◆ toString()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
std::string Mila::Dnn::MultiHeadAttention< TDeviceType, TInput, TOutput >::toString ( ) const
inline override export virtual

Generates a string representation of this module's configuration.

Returns
std::string A formatted string with module information

Implements Mila::Dnn::Module< TDeviceType, TInput, TOutput >.


◆ useCausalMask()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
bool Mila::Dnn::MultiHeadAttention< TDeviceType, TInput, TOutput >::useCausalMask ( ) const
inline export

Checks if causal masking is enabled.

Returns
bool True if causal masking is enabled, false otherwise.

Member Data Documentation

◆ attn_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
std::shared_ptr<Tensor<TOutput, MR> > Mila::Dnn::MultiHeadAttention< TDeviceType, TInput, TOutput >::attn_ { nullptr }
export private

Attention weight tensor from the forward pass.

Shape: [batch_size, num_heads, sequence_length, sequence_length]. Stores the attention weights between all token pairs.

◆ config_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
MultiHeadAttentionConfig Mila::Dnn::MultiHeadAttention< TDeviceType, TInput, TOutput >::config_
export private

Configuration for the MultiHeadAttention module.

◆ operation_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
std::shared_ptr<UnaryOperation<TDeviceType, TInput, TOutput> > Mila::Dnn::MultiHeadAttention< TDeviceType, TInput, TOutput >::operation_ { nullptr }
export private

The operation that implements the attention mechanism.

◆ output_state_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
std::vector<std::shared_ptr<Tensor<TOutput, MR> > > Mila::Dnn::MultiHeadAttention< TDeviceType, TInput, TOutput >::output_state_
export private

Collection of output state tensors for caching.

◆ parameters_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
std::vector<std::shared_ptr<Tensor<TOutput, MR> > > Mila::Dnn::MultiHeadAttention< TDeviceType, TInput, TOutput >::parameters_
export private

Collection of parameters for this module.

◆ pre_attn_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
std::shared_ptr<Tensor<TOutput, MR> > Mila::Dnn::MultiHeadAttention< TDeviceType, TInput, TOutput >::pre_attn_ { nullptr }
export private

Pre-softmax attention scores from the forward pass.

Shape: [batch_size, num_heads, sequence_length, sequence_length]. Stores the raw attention scores before softmax normalization.

◆ properties_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
OperationAttributes Mila::Dnn::MultiHeadAttention< TDeviceType, TInput, TOutput >::properties_
export private

Operation attributes and configuration.


The documentation for this class was generated from the following file: