CUDA implementation of the Multi-Head Attention operation for transformer models.
CudaMultiHeadAttentionOp(const MultiHeadAttentionConfig &config)
    Constructs a new CUDA Multi-Head Attention operation with the default device context.

CudaMultiHeadAttentionOp(std::shared_ptr<DeviceContext> context, const MultiHeadAttentionConfig &config)
    Constructs a new CUDA Multi-Head Attention operation with a specific device context.

void backward(const Tensor<TInput, MR> &input, const Tensor<TInput, MR> &output, const Tensor<TInput, MR> &output_gradient, const std::vector<std::shared_ptr<Tensor<TInput, MR>>> &parameters, std::vector<std::shared_ptr<Tensor<TInput, MR>>> &parameter_gradients, Tensor<TInput, MR> &input_gradient, const OperationAttributes &properties, const std::vector<std::shared_ptr<Tensor<TInput, MR>>> &output_state) const
    Performs the backward pass of the Multi-Head Attention operation.

void forward(const Tensor<TInput, MR> &input, const std::vector<std::shared_ptr<Tensor<TOutput, MR>>> &parameters, const OperationAttributes &properties, Tensor<TOutput, MR> &output, std::vector<std::shared_ptr<Tensor<TOutput, MR>>> &output_state) const override
    Performs the forward pass of the Multi-Head Attention operation on CUDA.

std::string getName() const override
    Gets the name of this operation.

UnaryOperation(OperationType operation_type)
    Constructs a UnaryOperation with the specified operation type.

UnaryOperation(OperationType operation_type, std::shared_ptr<DeviceContext> context)
    Constructs a UnaryOperation with the specified operation type and device context.

virtual ~UnaryOperation() = default
    Virtual destructor for proper cleanup of derived classes.

virtual void backward(const Tensor<TInput, MR> &grad, const std::vector<std::shared_ptr<Tensor<TOutput, MR>>> &parameters, std::vector<std::shared_ptr<Tensor<TOutput, MR>>> &output_grads) const
    Executes the backward pass of a unary operation.

virtual void backward(const Tensor<TInput, MR> &input, const Tensor<TOutput, MR> &output_grad, const std::vector<std::shared_ptr<Tensor<TOutput, MR>>> &parameters, std::vector<std::shared_ptr<Tensor<TOutput, MR>>> &parameter_grads, Tensor<TInput, MR> &input_grad, const OperationAttributes &properties, const std::vector<std::shared_ptr<Tensor<TOutput, MR>>> &output_state) const
    Executes the comprehensive backward pass of a unary operation.

virtual void forward(const Tensor<TInput, MR> &input, const std::vector<std::shared_ptr<Tensor<TOutput, MR>>> &parameters, const OperationAttributes &properties, Tensor<TOutput, MR> &output, std::vector<std::shared_ptr<Tensor<TOutput, MR>>> &output_state) const = 0
    Executes the forward pass of a unary operation.

OperationBase(OperationType operation_type, std::shared_ptr<DeviceContext> context)
    Constructs an OperationBase object with a specific device context and compute precision.

virtual ~OperationBase() = default
    Virtual destructor for the OperationBase class.

std::shared_ptr<DeviceContext> getDeviceContext() const
    Gets the device context associated with this operation.

DeviceType getDeviceType() const
    Gets the device type for this operation.

OperationType getOperationType() const
    Gets the operation type enumeration value.
template<typename TInput = float, typename TOutput = TInput>
requires ValidFloatTensorTypes<TInput, TOutput>
class Mila::Dnn::Compute::CudaMultiHeadAttentionOp<TInput, TOutput>
CUDA implementation of the Multi-Head Attention operation for transformer models.
This class provides a CUDA-based implementation of the Multi-Head Attention operation, which is a key component of transformer architectures. The operation allows the model to jointly attend to information from different representation subspaces at different positions.
Multi-Head Attention consists of several attention mechanisms operating in parallel:
- Linear projections of the input into query, key, and value vectors
- Scaled dot-product attention computation between queries and keys
- Application of attention weights to values
- Concatenation of attention outputs from the different heads, followed by a final output projection
The implementation is optimized for NVIDIA GPUs using CUDA for high-performance computation.
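For reference, the per-head computation described above is standard scaled dot-product attention, softmax(Q K^T / sqrt(d)) V. The following host-side C++ sketch shows that math for a single head. It is illustrative only: the function name, plain-vector interface, and row-major layout are assumptions for this sketch, not this class's API or its CUDA kernels.

    #include <algorithm>
    #include <cmath>
    #include <limits>
    #include <vector>

    // Illustrative single-head scaled dot-product attention on the host.
    // Q, K, V are row-major [T x d] matrices stored as flat vectors; the
    // result is softmax(Q K^T / sqrt(d)) V, also [T x d]. Reference math
    // only; this is not the class's interface or kernel code.
    std::vector<float> scaledDotProductAttention(
        const std::vector<float>& Q,
        const std::vector<float>& K,
        const std::vector<float>& V,
        int T, int d )
    {
        std::vector<float> out( static_cast<size_t>( T ) * d, 0.0f );
        const float scale = 1.0f / std::sqrt( static_cast<float>( d ) );

        for ( int i = 0; i < T; ++i ) {
            // Scores for query i against every key, scaled by 1/sqrt(d).
            std::vector<float> scores( T );
            float maxScore = -std::numeric_limits<float>::infinity();
            for ( int j = 0; j < T; ++j ) {
                float s = 0.0f;
                for ( int k = 0; k < d; ++k )
                    s += Q[ i * d + k ] * K[ j * d + k ];
                scores[ j ] = s * scale;
                maxScore = std::max( maxScore, scores[ j ] );
            }

            // Numerically stable softmax over the scores.
            float sum = 0.0f;
            for ( int j = 0; j < T; ++j ) {
                scores[ j ] = std::exp( scores[ j ] - maxScore );
                sum += scores[ j ];
            }

            // Weighted sum of the value vectors.
            for ( int j = 0; j < T; ++j ) {
                const float w = scores[ j ] / sum;
                for ( int k = 0; k < d; ++k )
                    out[ i * d + k ] += w * V[ j * d + k ];
            }
        }
        return out;
    }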
Template Parameters:
    TInput    The data type of the input tensor elements.
    TOutput   The data type of the output tensor elements (defaults to the input type).
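A minimal construction sketch follows, based only on the two constructor signatures listed above. How MultiHeadAttentionConfig is populated and how a DeviceContext is created are assumptions and will depend on the actual Mila API.

    #include <memory>

    using namespace Mila::Dnn::Compute;

    // Hypothetical setup: configuration fields and DeviceContext creation
    // are assumptions, not verified API.
    MultiHeadAttentionConfig config;
    auto context = std::make_shared<DeviceContext>();

    // Construct with an explicit device context...
    CudaMultiHeadAttentionOp<float> opWithContext( context, config );

    // ...or with the default device context.
    CudaMultiHeadAttentionOp<float> op( config );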
template<typename TInput = float, typename TOutput = TInput>
void forward(const Tensor<TInput, MR> &input, const std::vector<std::shared_ptr<Tensor<TOutput, MR>>> &parameters, const OperationAttributes &properties, Tensor<TOutput, MR> &output, std::vector<std::shared_ptr<Tensor<TOutput, MR>>> &output_state) const override
Performs the forward pass of the Multi-Head Attention operation on CUDA.
Computes attention scores, applies softmax to get attention weights, and uses these weights to compute a weighted sum of value vectors. This process is performed in parallel for multiple attention heads, then outputs are concatenated and projected.
The computation is performed on the GPU using CUDA kernels for optimal performance.
Parameters:
    input         Input tensor of shape [B, T, C] containing the input sequence, where B is the batch size, T is the sequence length, and C is the input feature dimension.
    parameters    Vector of parameter tensors [weight, bias], where weight contains the query, key, and value projections along with the output projection, and bias contains the corresponding biases.
    properties    Additional attributes for the operation, such as the number of attention heads.
    output        Output tensor of shape [B, T, OC] containing the attention output, where OC is the output feature dimension.
    output_state  Intermediate results, such as attention scores and weights, for potential use in the backward pass or for visualization.
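Putting the parameters together, a forward call might look like the sketch below. Only the forward signature itself comes from this reference; the tensor constructors, the example dimensions, and the way the head count is attached to properties are assumptions for illustration.

    #include <memory>
    #include <vector>

    // Hypothetical forward-pass sketch. Shapes follow the documentation
    // above ([B, T, C] in, [B, T, OC] out); constructor signatures and
    // attribute names are assumptions, not verified API.
    const size_t B = 2, T = 128, C = 768, OC = 768;   // example dimensions

    Tensor<float, MR> input( { B, T, C } );
    Tensor<float, MR> output( { B, T, OC } );

    // weight packs the Q/K/V projections and the output projection;
    // bias holds the corresponding biases (per the parameter docs above).
    auto weight = std::make_shared<Tensor<float, MR>>( /* ... */ );
    auto bias   = std::make_shared<Tensor<float, MR>>( /* ... */ );
    std::vector<std::shared_ptr<Tensor<float, MR>>> parameters { weight, bias };
    std::vector<std::shared_ptr<Tensor<float, MR>>> output_state;

    OperationAttributes properties;   // e.g. carries the number of attention heads

    op.forward( input, parameters, properties, output, output_state );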