Mila
Deep Neural Network Library
Mila::Dnn::Compute::CpuMultiHeadAttentionOp Class Reference

CPU implementation of the Multi-Head Attention operation for neural networks. More...


Public Types

using MR = typename CpuDevice::MR
 
using OperationBase = UnaryOperation< DeviceType::Cpu, float >
 
- Public Types inherited from Mila::Dnn::Compute::UnaryOperation< DeviceType::Cpu, float >
using MR = std::conditional_t< TDeviceType==DeviceType::Cuda, CudaMemoryResource, HostMemoryResource >
 Memory resource type based on device type.
 

Public Member Functions

 CpuMultiHeadAttentionOp (const MultiHeadAttentionConfig &config)
 Constructs a new CPU Attention operation with the default device context.
 
 CpuMultiHeadAttentionOp (std::shared_ptr< DeviceContext > context, const MultiHeadAttentionConfig &config)
 Constructs a new CPU Attention operation with a specific device context.
 
void backward (const Tensor< float, MR > &input, const Tensor< float, MR > &output, const Tensor< float, MR > &output_gradient, const std::vector< std::shared_ptr< Tensor< float, MR > > > &parameters, std::vector< std::shared_ptr< Tensor< float, MR > > > &parameter_gradients, Tensor< float, MR > &input_gradient, const OperationAttributes &properties, const std::vector< std::shared_ptr< Tensor< float, MR > > > &output_state) const
 Performs the backward pass of the Multi-Head Attention operation.
 
void backward_impl (float *dinp, float *dpreatt, float *datt, float *dout, float *inp, float *att, int B, int T, int C, int NH) const
 Helper method for backward pass implementation.
 
void forward (const Tensor< float, MR > &input, const std::vector< std::shared_ptr< Tensor< float, MR > > > &parameters, const OperationAttributes &properties, Tensor< float, MR > &output, std::vector< std::shared_ptr< Tensor< float, MR > > > &output_state) const override
 Performs the forward pass of the Multi-Head Attention operation.
 
std::string getName () const override
 Gets the name of this operation.
 
- Public Member Functions inherited from Mila::Dnn::Compute::UnaryOperation< DeviceType::Cpu, float >
 UnaryOperation (OperationType operation_type)
 Constructs a UnaryOperation with the specified operation type.
 
 UnaryOperation (OperationType operation_type, std::shared_ptr< DeviceContext > context)
 Constructs a UnaryOperation with the specified operation type and device context.
 
virtual ~UnaryOperation ()=default
 Virtual destructor for proper cleanup of derived classes.
 
virtual void backward (const Tensor< float, MR > &grad, const std::vector< std::shared_ptr< Tensor< float, MR > > > &parameters, std::vector< std::shared_ptr< Tensor< float, MR > > > &output_grads) const
 Executes the backward pass of a unary operation.
 
virtual void backward (const Tensor< float, MR > &input, const Tensor< float, MR > &output_grad, const std::vector< std::shared_ptr< Tensor< float, MR > > > &parameters, std::vector< std::shared_ptr< Tensor< float, MR > > > &parameter_grads, Tensor< float, MR > &input_grad, const OperationAttributes &properties, const std::vector< std::shared_ptr< Tensor< float, MR > > > &output_state) const
 Executes the comprehensive backward pass of a unary operation.
 
virtual void forward (const Tensor< float, MR > &input, const std::vector< std::shared_ptr< Tensor< float, MR > > > &parameters, const OperationAttributes &properties, Tensor< float, MR > &output, std::vector< std::shared_ptr< Tensor< float, MR > > > &output_state) const=0
 Executes the forward pass of a unary operation.
 
- Public Member Functions inherited from Mila::Dnn::Compute::OperationBase< TDeviceType, TInput1, TInput2, TOutput >
 OperationBase (OperationType operation_type, std::shared_ptr< DeviceContext > context)
 Constructs an OperationBase object with a specific device context and compute precision.
 
virtual ~OperationBase ()=default
 Virtual destructor for the OperationBase class.
 
std::shared_ptr< DeviceContext > getDeviceContext () const
 Gets the device context associated with this operation.
 
DeviceType getDeviceType () const
 Gets the device type for this operation.
 
OperationType getOperationType () const
 Gets the operation type enumeration value.
 

Private Attributes

MultiHeadAttentionConfig config_
 Configuration for the Multi-Head Attention operation.
 

Detailed Description

CPU implementation of the Multi-Head Attention operation for neural networks.

This class provides a CPU-based implementation of the Multi-Head Attention operation, which is a key component of transformer architectures. The operation performs scaled dot-product attention with multiple attention heads operating in parallel, allowing the model to jointly attend to information from different representation subspaces at different positions.

The implementation handles the full attention process (summarized by the formula after this list):

  • Query-Key dot products
  • Scaling
  • Softmax computation
  • Attention weighting of values
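
For orientation, each head computes the standard scaled dot-product attention. This is the textbook formulation, not quoted from the Mila sources:

\mathrm{head}_i = \mathrm{softmax}\!\left( \frac{Q_i K_i^{\top}}{\sqrt{d_k}} \right) V_i,
\qquad
\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_{NH})

where d_k = C / NH is the per-head feature dimension, given C feature channels and NH heads.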
Template Parameters
float: The data type of the input tensor elements.
TDataType: The data type used for computation and output (defaults to the input type).

Member Typedef Documentation

◆ MR

using MR = typename CpuDevice::MR

Memory resource type for the CPU device.

◆ OperationBase

using OperationBase = UnaryOperation< DeviceType::Cpu, float >

Alias for the base unary operation type.
Constructor & Destructor Documentation

◆ CpuMultiHeadAttentionOp() [1/2]

Mila::Dnn::Compute::CpuMultiHeadAttentionOp::CpuMultiHeadAttentionOp ( const MultiHeadAttentionConfig & config )
inline

Constructs a new CPU Attention operation with the default device context.

Initializes the operation with a CPU device context.

◆ CpuMultiHeadAttentionOp() [2/2]

Mila::Dnn::Compute::CpuMultiHeadAttentionOp::CpuMultiHeadAttentionOp ( std::shared_ptr< DeviceContext > context,
const MultiHeadAttentionConfig & config
) 
inline

Constructs a new CPU Attention operation with a specific device context.

Parameters
context: The device context to use for this operation.
config: Configuration for the Multi-Head Attention operation.
Exceptions
std::runtime_error: If the context is not for a CPU device.
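
A minimal construction sketch for both constructors. The default-constructibility of MultiHeadAttentionConfig and DeviceContext is an assumption for illustration; consult those classes for the actual setup:

#include <memory>
// plus the relevant Mila headers (paths omitted)

using namespace Mila::Dnn;
using namespace Mila::Dnn::Compute;

void buildAttentionOps() {
    MultiHeadAttentionConfig config;   // populate heads, dimensions, etc. (API assumed)

    // Default CPU device context:
    CpuMultiHeadAttentionOp op{ config };

    // Explicit device context; throws std::runtime_error if the
    // context does not describe a CPU device:
    auto context = std::make_shared<DeviceContext>();   // assumed default-constructible
    CpuMultiHeadAttentionOp op2{ context, config };
}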

Member Function Documentation

◆ backward()

void Mila::Dnn::Compute::CpuMultiHeadAttentionOp::backward ( const Tensor< float, MR > & input,
const Tensor< float, MR > & output,
const Tensor< float, MR > & output_gradient,
const std::vector< std::shared_ptr< Tensor< float, MR > > > & parameters,
std::vector< std::shared_ptr< Tensor< float, MR > > > & parameter_gradients,
Tensor< float, MR > & input_gradient,
const OperationAttributes & properties,
const std::vector< std::shared_ptr< Tensor< float, MR > > > & output_state
) const
inline

Performs the backward pass of the Multi-Head Attention operation.

Computes gradients with respect to inputs (query, key, value vectors), pre-softmax attention scores, and attention weights based on the output gradient.

Parameters
input: Input tensor from the forward pass.
output: Output tensor from the forward pass.
output_gradient: Gradient of the loss with respect to the output.
parameters: Parameters used in the forward pass (not used by this operation).
parameter_gradients: Gradients for parameters (not used by this operation).
input_gradient: Gradient of the loss with respect to the input.
properties: Additional attributes for the operation.
output_state: Cache tensors [preatt, att] from the forward pass.
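
A hedged sketch of a backward() call, assuming forward() has already populated output_state with the [preatt, att] cache. Default construction of Tensor and OperationAttributes is an assumption, not documented behavior:

#include <memory>
#include <vector>
// plus the relevant Mila headers (paths omitted)

using MR = Mila::Dnn::Compute::CpuMultiHeadAttentionOp::MR;

void backwardSketch( const Mila::Dnn::Compute::CpuMultiHeadAttentionOp & op,
                     const Mila::Dnn::Tensor< float, MR > & input,
                     const Mila::Dnn::Tensor< float, MR > & output,
                     const Mila::Dnn::Tensor< float, MR > & output_grad,
                     const std::vector< std::shared_ptr< Mila::Dnn::Tensor< float, MR > > > & output_state ) {
    std::vector< std::shared_ptr< Mila::Dnn::Tensor< float, MR > > > params;       // unused by this op
    std::vector< std::shared_ptr< Mila::Dnn::Tensor< float, MR > > > param_grads;  // unused by this op
    Mila::Dnn::Tensor< float, MR > input_grad;        // receives dL/dinput, shape [B, T, 3*C]
    Mila::Dnn::Compute::OperationAttributes props;    // assumed default-constructible

    op.backward( input, output, output_grad,
                 params, param_grads, input_grad,
                 props, output_state );
}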

◆ backward_impl()

void Mila::Dnn::Compute::CpuMultiHeadAttentionOp::backward_impl ( float *  dinp,
float *  dpreatt,
float *  datt,
float *  dout,
float *  inp,
float *  att,
int  B,
int  T,
int  C,
int  NH 
) const
inline

Helper method for backward pass implementation.

Parameters
dinp: Pointer to the gradient buffer for the input (query, key, value).
dpreatt: Pointer to the gradient buffer for pre-softmax attention scores.
datt: Pointer to the gradient buffer for attention weights.
dout: Pointer to the gradient buffer from the output.
inp: Pointer to the original input values (query, key, value).
att: Pointer to the attention weights computed during the forward pass.
B: Batch size.
T: Sequence length.
C: Feature dimension; the input's 3*C channels split evenly into query, key, and value.
NH: Number of attention heads.
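
As a size sanity check, a sketch of the buffer extents (in floats) these raw pointers are conventionally expected to cover. The B x NH x T x T shape for att, datt, and dpreatt is an assumption based on the per-head score matrices described above, not a documented guarantee:

#include <cstddef>

void bufferExtents( int B, int T, int C, int NH ) {
    std::size_t inpLen = std::size_t( B ) * T * 3 * C;   // inp / dinp: concatenated Q, K, V
    std::size_t outLen = std::size_t( B ) * T * C;       // dout: upstream gradient
    std::size_t attLen = std::size_t( B ) * NH * T * T;  // att / datt / dpreatt (assumed shape)
    (void)inpLen; (void)outLen; (void)attLen;
}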

◆ forward()

void Mila::Dnn::Compute::CpuMultiHeadAttentionOp::forward ( const Tensor< float, MR > & input,
const std::vector< std::shared_ptr< Tensor< float, MR > > > & parameters,
const OperationAttributes & properties,
Tensor< float, MR > & output,
std::vector< std::shared_ptr< Tensor< float, MR > > > & output_state
) const
inline override

Performs the forward pass of the Multi-Head Attention operation.

Computes attention scores between queries and keys, applies softmax to get attention weights, and uses these weights to compute a weighted sum of value vectors.

Parameters
input: Input tensor of shape [B, T, 3*C] containing concatenated query, key, and value vectors.
parameters: Additional parameters (not used by this operation).
properties: Additional attributes for the operation.
output: Output tensor of shape [B, T, C] containing the attention output.
output_state: Cache for intermediate results [preatt, att] used in the backward pass.
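
A hedged usage sketch of forward(); Tensor construction and default-constructibility are assumptions, not the documented Mila interface:

#include <memory>
#include <vector>
// plus the relevant Mila headers (paths omitted)

using MR = Mila::Dnn::Compute::CpuMultiHeadAttentionOp::MR;

void forwardSketch( const Mila::Dnn::Compute::CpuMultiHeadAttentionOp & op,
                    const Mila::Dnn::Tensor< float, MR > & qkv ) {       // shape [B, T, 3*C]
    std::vector< std::shared_ptr< Mila::Dnn::Tensor< float, MR > > > params;        // unused by this op
    std::vector< std::shared_ptr< Mila::Dnn::Tensor< float, MR > > > output_state;  // receives [preatt, att]
    Mila::Dnn::Tensor< float, MR > output;                        // shape [B, T, C]
    Mila::Dnn::Compute::OperationAttributes props;                // assumed default-constructible

    op.forward( qkv, params, props, output, output_state );
    // output_state now caches preatt and att for the backward pass.
}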

◆ getName()

std::string Mila::Dnn::Compute::CpuMultiHeadAttentionOp::getName ( ) const
inline override virtual

Gets the name of this operation.

Returns
std::string The name of the operation ("Cpu::AttentionOp").

Implements Mila::Dnn::Compute::OperationBase< TDeviceType, TInput1, TInput2, TOutput >.

Member Data Documentation

◆ config_

MultiHeadAttentionConfig Mila::Dnn::Compute::CpuMultiHeadAttentionOp::config_
private

Configuration for the Multi-Head Attention operation.

