Mila
Deep Neural Network Library
Mila::Dnn::Compute::CudaMultiHeadAttentionOp< TInput, TOutput > Class Template Reference

CUDA implementation of the Multi-Head Attention operation for transformer models.

Inheritance diagram for Mila::Dnn::Compute::CudaMultiHeadAttentionOp< TInput, TOutput >:
Collaboration diagram for Mila::Dnn::Compute::CudaMultiHeadAttentionOp< TInput, TOutput >:

Public Types

using MR = typename CudaDevice::MR
 
using UnaryOperationBase = UnaryOperation< DeviceType::Cuda, TInput, TOutput >
 
- Public Types inherited from Mila::Dnn::Compute::UnaryOperation< TDeviceType, TInput, TOutput >
using MR = std::conditional_t< TDeviceType==DeviceType::Cuda, CudaMemoryResource, HostMemoryResource >
 Memory resource type based on device type.
 

Public Member Functions

 CudaMultiHeadAttentionOp (const MultiHeadAttentionConfig &config)
 Constructs a new CUDA Multi-Head Attention operation with the default device context.
 
 CudaMultiHeadAttentionOp (std::shared_ptr< DeviceContext > context, const MultiHeadAttentionConfig &config)
 Constructs a new CUDA Multi-Head Attention operation with a specific device context.
 
void backward (const Tensor< TInput, MR > &input, const Tensor< TInput, MR > &output, const Tensor< TInput, MR > &output_gradient, const std::vector< std::shared_ptr< Tensor< TInput, MR > > > &parameters, std::vector< std::shared_ptr< Tensor< TInput, MR > > > &parameter_gradients, Tensor< TInput, MR > &input_gradient, const OperationAttributes &properties, const std::vector< std::shared_ptr< Tensor< TInput, MR > > > &output_state) const
 Performs the backward pass of the Multi-Head Attention operation.
 
void forward (const Tensor< TInput, MR > &input, const std::vector< std::shared_ptr< Tensor< TOutput, MR > > > &parameters, const OperationAttributes &properties, Tensor< TOutput, MR > &output, std::vector< std::shared_ptr< Tensor< TOutput, MR > > > &output_state) const override
 Performs the forward pass of the Multi-Head Attention operation on CUDA.
 
std::string getName () const override
 Gets the name of this operation.
 
- Public Member Functions inherited from Mila::Dnn::Compute::UnaryOperation< TDeviceType, TInput, TOutput >
 UnaryOperation (OperationType operation_type)
 Constructs a UnaryOperation with the specified operation type.
 
 UnaryOperation (OperationType operation_type, std::shared_ptr< DeviceContext > context)
 Constructs a UnaryOperation with the specified operation type and device context.
 
virtual ~UnaryOperation ()=default
 Virtual destructor for proper cleanup of derived classes.
 
virtual void backward (const Tensor< TInput, MR > &grad, const std::vector< std::shared_ptr< Tensor< TOutput, MR > > > &parameters, std::vector< std::shared_ptr< Tensor< TOutput, MR > > > &output_grads) const
 Executes the backward pass of a unary operation.
 
virtual void backward (const Tensor< TInput, MR > &input, const Tensor< TOutput, MR > &output_grad, const std::vector< std::shared_ptr< Tensor< TOutput, MR > > > &parameters, std::vector< std::shared_ptr< Tensor< TOutput, MR > > > &parameter_grads, Tensor< TInput, MR > &input_grad, const OperationAttributes &properties, const std::vector< std::shared_ptr< Tensor< TOutput, MR > > > &output_state) const
 Executes the comprehensive backward pass of a unary operation.
 
virtual void forward (const Tensor< TInput, MR > &input, const std::vector< std::shared_ptr< Tensor< TOutput, MR > > > &parameters, const OperationAttributes &properties, Tensor< TOutput, MR > &output, std::vector< std::shared_ptr< Tensor< TOutput, MR > > > &output_state) const =0
 Executes the forward pass of a unary operation.
 
- Public Member Functions inherited from Mila::Dnn::Compute::OperationBase< TDeviceType, TInput1, TInput2, TOutput >
 OperationBase (OperationType operation_type, std::shared_ptr< DeviceContext > context)
 Constructs an OperationBase object with a specific device context and compute precision.
 
virtual ~OperationBase ()=default
 Virtual destructor for the OperationBase class.
 
std::shared_ptr< DeviceContext > getDeviceContext () const
 Gets the device context associated with this operation.
 
DeviceType getDeviceType () const
 Gets the device type for this operation.
 
OperationType getOperationType () const
 Gets the operation type enumeration value.
 

Private Attributes

MultiHeadAttentionConfig config_
 Configuration for the MHA operation.
 

Detailed Description

template<typename TInput = float, typename TOutput = TInput>
requires ValidFloatTensorTypes<TInput, TOutput>
class Mila::Dnn::Compute::CudaMultiHeadAttentionOp< TInput, TOutput >

CUDA implementation of the Multi-Head Attention operation for transformer models.

This class provides a CUDA-based implementation of the Multi-Head Attention operation, which is a key component of transformer architectures. The operation allows the model to jointly attend to information from different representation subspaces at different positions.

Multi-Head Attention consists of several attention mechanisms operating in parallel:

  1. Linear projections of the input into query, key, and value vectors
  2. Scaled dot-product attention computation between queries and keys
  3. Applying attention weights to values
  4. Concatenation of attention outputs from different heads

The implementation is optimized for NVIDIA GPUs using CUDA for high-performance computation.
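
For orientation, here is a minimal CPU sketch of steps 2 and 3 above (scaled dot-product attention for a single head). It illustrates the math only and is not Mila's CUDA implementation; step 1 amounts to multiplying the input by the query/key/value projection weights, and step 4 to concatenating the per-head outputs and applying the output projection, both plain matrix products like matmul below.

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <limits>
    #include <vector>

    using Mat = std::vector<std::vector<float>>; // row-major [rows][cols]

    // Generic matrix product; steps 1 and 4 of Multi-Head Attention are
    // products of this form with the projection weight matrices.
    Mat matmul(const Mat& a, const Mat& b) {
        std::size_t n = a.size(), k = b.size(), m = b[0].size();
        Mat c(n, std::vector<float>(m, 0.0f));
        for (std::size_t i = 0; i < n; ++i)
            for (std::size_t p = 0; p < k; ++p)
                for (std::size_t j = 0; j < m; ++j)
                    c[i][j] += a[i][p] * b[p][j];
        return c;
    }

    // Steps 2 and 3: scaled dot-product attention for one head.
    // Q, K, V are [T, d] for sequence length T and head dimension d.
    Mat attention(const Mat& Q, const Mat& K, const Mat& V) {
        const std::size_t T = Q.size(), d = Q[0].size();
        const float scale = 1.0f / std::sqrt(static_cast<float>(d));
        Mat out(T, std::vector<float>(d, 0.0f));
        for (std::size_t i = 0; i < T; ++i) {
            // Attention scores: s_j = (q_i . k_j) / sqrt(d)
            std::vector<float> s(T);
            float mx = -std::numeric_limits<float>::infinity();
            for (std::size_t j = 0; j < T; ++j) {
                float dot = 0.0f;
                for (std::size_t c = 0; c < d; ++c) dot += Q[i][c] * K[j][c];
                s[j] = dot * scale;
                mx = std::max(mx, s[j]);
            }
            // Numerically stable softmax over the scores
            float denom = 0.0f;
            for (std::size_t j = 0; j < T; ++j) { s[j] = std::exp(s[j] - mx); denom += s[j]; }
            // Weighted sum of value rows using the attention weights
            for (std::size_t j = 0; j < T; ++j)
                for (std::size_t c = 0; c < d; ++c)
                    out[i][c] += (s[j] / denom) * V[j][c];
        }
        return out;
    }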

Template Parameters
TInput: The data type of the input tensor elements.
TOutput: The data type of the output tensor elements (defaults to the input type).

Member Typedef Documentation

◆ MR

template<typename TInput = float, typename TOutput = TInput>
using Mila::Dnn::Compute::CudaMultiHeadAttentionOp< TInput, TOutput >::MR = typename CudaDevice::MR
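
In the base class, MR is chosen from the device type with std::conditional_t; this derived class simply pins it to CudaDevice::MR. A stand-alone sketch of the selection mechanism (the two resource structs here are placeholders, not Mila's definitions):

    #include <type_traits>

    enum class DeviceType { Cpu, Cuda };
    struct CudaMemoryResource {};  // placeholder for Mila's type
    struct HostMemoryResource {};  // placeholder for Mila's type

    // Mirrors UnaryOperation's alias: pick the memory resource at compile time.
    template<DeviceType TDeviceType>
    using MR = std::conditional_t<TDeviceType == DeviceType::Cuda,
                                  CudaMemoryResource, HostMemoryResource>;

    static_assert(std::is_same_v<MR<DeviceType::Cuda>, CudaMemoryResource>);
    static_assert(std::is_same_v<MR<DeviceType::Cpu>, HostMemoryResource>);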

◆ UnaryOperationBase

template<typename TInput = float, typename TOutput = TInput>
using Mila::Dnn::Compute::CudaMultiHeadAttentionOp< TInput, TOutput >::UnaryOperationBase = UnaryOperation<DeviceType::Cuda, TInput, TOutput>

Constructor & Destructor Documentation

◆ CudaMultiHeadAttentionOp() [1/2]

template<typename TInput = float, typename TOutput = TInput>
Mila::Dnn::Compute::CudaMultiHeadAttentionOp< TInput, TOutput >::CudaMultiHeadAttentionOp ( const MultiHeadAttentionConfig &  config )
inline

Constructs a new CUDA Multi-Head Attention operation with the default device context.

Initializes the operation with a CUDA device context (defaults to CUDA:0).

◆ CudaMultiHeadAttentionOp() [2/2]

template<typename TInput = float, typename TOutput = TInput>
Mila::Dnn::Compute::CudaMultiHeadAttentionOp< TInput, TOutput >::CudaMultiHeadAttentionOp ( std::shared_ptr< DeviceContext >  context,
const MultiHeadAttentionConfig &  config 
)
inline

Constructs a new CUDA Multi-Head Attention operation with a specific device context.

Parameters
context: The device context to use for this operation.
config: Configuration for the Multi-Head Attention operation.
Exceptions
std::runtime_error: If the context is not for a CUDA device.
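
A hypothetical construction sketch using the two documented constructors (how MultiHeadAttentionConfig and the DeviceContext are obtained is assumed, not part of this page):

    #include <memory>

    // Hypothetical usage; only the two constructor signatures above are
    // taken from the documentation. Namespace qualification of
    // MultiHeadAttentionConfig is assumed.
    using Mila::Dnn::Compute::CudaMultiHeadAttentionOp;
    using Mila::Dnn::Compute::DeviceContext;

    void construct(const MultiHeadAttentionConfig& config,
                   std::shared_ptr<DeviceContext> cudaContext) {
        // Default device context (CUDA:0, per the note above).
        CudaMultiHeadAttentionOp<float> op1(config);

        // Explicit context; throws std::runtime_error if it is not a CUDA device.
        CudaMultiHeadAttentionOp<float> op2(cudaContext, config);
    }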

Member Function Documentation

◆ backward()

template<typename TInput = float, typename TOutput = TInput>
void Mila::Dnn::Compute::CudaMultiHeadAttentionOp< TInput, TOutput >::backward ( const Tensor< TInput, MR > &  input,
const Tensor< TInput, MR > &  output,
const Tensor< TInput, MR > &  output_gradient,
const std::vector< std::shared_ptr< Tensor< TInput, MR > > > &  parameters,
std::vector< std::shared_ptr< Tensor< TInput, MR > > > &  parameter_gradients,
Tensor< TInput, MR > &  input_gradient,
const OperationAttributes &  properties,
const std::vector< std::shared_ptr< Tensor< TInput, MR > > > &  output_state 
) const
inline

Performs the backward pass of the Multi-Head Attention operation.

Computes gradients with respect to inputs, weights, and biases.

Parameters
input: Input tensor from the forward pass.
output: Output tensor from the forward pass.
output_gradient: Gradient of the loss with respect to the output.
parameters: Parameter tensors from the forward pass [weight, bias].
parameter_gradients: Gradients for the parameters [d_weight, d_bias].
input_gradient: Gradient of the loss with respect to the input.
properties: Additional attributes for the operation.
output_state: Cached tensors from the forward pass (attention scores and weights).
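For reference, the textbook per-head gradients of scaled dot-product attention; this is the general derivation, not a description of this kernel's internal schedule. With S = Q K^T / sqrt(d), P = softmax(S) applied row-wise, and O = P V, writing dX for the gradient of the loss with respect to X:

    % Textbook gradients of scaled dot-product attention (per head)
    \begin{aligned}
      dV &= P^{\top}\, dO \\
      dP &= dO\, V^{\top} \\
      dS_{ij} &= P_{ij}\Bigl(dP_{ij} - \sum_{k} dP_{ik}\, P_{ik}\Bigr) \\
      dQ &= \frac{1}{\sqrt{d}}\, dS\, K \\
      dK &= \frac{1}{\sqrt{d}}\, dS^{\top} Q
    \end{aligned}

The cached output_state (attention scores and weights) presumably lets the backward pass reuse P rather than recompute the softmax.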

◆ forward()

template<typename TInput = float, typename TOutput = TInput>
void Mila::Dnn::Compute::CudaMultiHeadAttentionOp< TInput, TOutput >::forward ( const Tensor< TInput, MR > &  input,
const std::vector< std::shared_ptr< Tensor< TOutput, MR > > > &  parameters,
const OperationAttributes &  properties,
Tensor< TOutput, MR > &  output,
std::vector< std::shared_ptr< Tensor< TOutput, MR > > > &  output_state 
) const
inline override

Performs the forward pass of the Multi-Head Attention operation on CUDA.

Computes attention scores, applies softmax to get attention weights, and uses these weights to compute a weighted sum of value vectors. This process is performed in parallel for multiple attention heads, then outputs are concatenated and projected.

The computation is performed on the GPU using CUDA kernels for optimal performance.

Parameters
input: Input tensor of shape [B, T, C] containing the input sequence, where B is the batch size, T is the sequence length, and C is the input feature dimension.
parameters: Vector of parameter tensors [weight, bias], where weight contains the query, key, and value projections and the output projection, and bias contains the corresponding biases.
properties: Additional attributes for the operation, such as the number of attention heads.
output: Output tensor of shape [B, T, OC] containing the attention output, where OC is the output feature dimension.
output_state: Intermediate results such as attention scores and weights, cached for the backward pass or for visualization.
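A hypothetical invocation sketch; only the forward() signature itself comes from this page, and how the tensors, parameters, and attributes are created is assumed:

    #include <memory>
    #include <vector>

    // Hypothetical call sketch; Tensor's namespace and construction, and the
    // contents of OperationAttributes, are assumptions for illustration.
    using Op = Mila::Dnn::Compute::CudaMultiHeadAttentionOp<float>;
    using MR = Op::MR;

    void runForward(const Op& op,
                    const Mila::Dnn::Tensor<float, MR>& input,  // [B, T, C]
                    const std::vector<std::shared_ptr<Mila::Dnn::Tensor<float, MR>>>& params, // [weight, bias]
                    Mila::Dnn::Tensor<float, MR>& output) {     // [B, T, OC]
        Mila::Dnn::Compute::OperationAttributes props; // e.g. head count; exact fields assumed
        std::vector<std::shared_ptr<Mila::Dnn::Tensor<float, MR>>> state;
        op.forward(input, params, props, output, state); // state receives attention scores/weights
    }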

◆ getName()

template<typename TInput = float, typename TOutput = TInput>
std::string Mila::Dnn::Compute::CudaMultiHeadAttentionOp< TInput, TOutput >::getName ( ) const
inline override virtual

Gets the name of this operation.

Returns
std::string The name of the operation ("Cuda::MultiHeadAttentionOp").

Implements Mila::Dnn::Compute::OperationBase< TDeviceType, TInput1, TInput2, TOutput >.

Member Data Documentation

◆ config_

template<typename TInput = float, typename TOutput = TInput>
MultiHeadAttentionConfig Mila::Dnn::Compute::CudaMultiHeadAttentionOp< TInput, TOutput >::config_
private

Configuration for the MHA operation.

