Mila 0.13.48
Deep Neural Network Library
Loading...
Searching...
No Matches
Mila::Dnn::Compute::CpuAttentionOp Class Reference export

CPU implementation of Multi-Head Attention operation. More...

Inheritance diagram for Mila::Dnn::Compute::CpuAttentionOp:
Collaboration diagram for Mila::Dnn::Compute::CpuAttentionOp:

Public Types

using CpuExecutionContext = ExecutionContext<DeviceType::Cpu>
using MR = CpuMemoryResource
using TensorType = Tensor<TensorDataType::FP32, MR>
Public Types inherited from Mila::Dnn::Compute::UnaryOperation< DeviceType::Cpu, TensorDataType::FP32 >
using MR
using TensorInputType
using TensorOutputType
Public Types inherited from Mila::Dnn::Compute::Operation< TDeviceType, TInput >
using DataTypeTraits

Public Member Functions

 CpuAttentionOp (IExecutionContext *context, const MultiHeadAttentionConfig &config)
 ~CpuAttentionOp () override=default
void backward (const ITensor &input, const ITensor &output_grad, ITensor &input_grad) const override
 Backward pass: compute gradient wrt input given output gradient.
void build (const BuildContext &config) override
 Prepare the operation for a concrete input shape.
void forward (const ITensor &input, ITensor &output) const override
 Forward pass: compute output = f(input).
std::string getName () const override
 Human-readable operation name.
OperationType getOperationType () const override
 Operation type identifier.
void setGradients (ITensor *, ITensor *) override
 Bind module-owned gradient tensors to the operation.
void setParameters (ITensor *, ITensor *) override
 Bind module-owned parameter tensors to the operation.
Public Member Functions inherited from Mila::Dnn::Compute::UnaryOperation< DeviceType::Cpu, TensorDataType::FP32 >
virtual ~UnaryOperation ()=default
Public Member Functions inherited from Mila::Dnn::Compute::Operation< TDeviceType, TInput >
virtual ~Operation ()=default
virtual void clearGradients () noexcept
 Clear any cached gradient pointers held by the operation.
virtual TensorDataType getDataType () const
 Tensor data type for this operation.
virtual DeviceType getDeviceType () const
 Device type for this operation.
virtual std::size_t getStateMemorySize () const
 Returns the number of bytes of state memory allocated by this operation.
virtual bool isBuilt () const
 Whether build() completed successfully for a concrete input shape.
virtual bool isEvalMode () const
 Query whether the operation is in evaluation (non-training) mode.
virtual void setTrainingMode (TrainingMode training_mode)
 Configure operation training-mode behavior.

Private Member Functions

void allocateStateTensors ()
void applySoftmax () const
void computeAttentionScores (float scale) const
void computeGradientAtt () const
void computeGradientK () const
void computeGradientPreatt (float scale) const
void computeGradientQ () const
void computeGradientV () const
void computeOutputValues () const
void permute_backward (float *dX) const
void permuteQKV (const float *X) const
void unpermute (float *Y) const
void unpermute_backward (const float *dY) const
void validateInputShape (const shape_t &input_shape) const

Private Attributes

float * att_ { nullptr }
std::shared_ptr< TensorType > att_tensor_
int B_ { 0 }
MultiHeadAttentionConfig config_
IExecutionContext * context_
float * datt_ { nullptr }
std::shared_ptr< TensorType > datt_tensor_
float * dk_ { nullptr }
std::shared_ptr< TensorType > dk_tensor_
float * dpreatt_ { nullptr }
std::shared_ptr< TensorType > dpreatt_tensor_
float * dq_ { nullptr }
std::shared_ptr< TensorType > dq_tensor_
float * dv_ { nullptr }
std::shared_ptr< TensorType > dv_tensor_
float * dvout_ { nullptr }
std::shared_ptr< TensorType > dvout_tensor_
int embedding_dim_ { 0 }
int HS_ { 0 }
bool is_built_ { false }
float * k_ { nullptr }
std::shared_ptr< TensorType > k_tensor_
int NH_ { 0 }
float * preatt_ { nullptr }
std::shared_ptr< TensorType > preatt_tensor_
float * q_ { nullptr }
std::shared_ptr< TensorType > q_tensor_
int qkv_dim_ { 0 }
int T_ { 0 }
float * v_ { nullptr }
float * v_out_ { nullptr }
std::shared_ptr< TensorType > v_out_tensor_
std::shared_ptr< TensorType > v_tensor_

Additional Inherited Members

Static Public Attributes inherited from Mila::Dnn::Compute::Operation< TDeviceType, TInput >
static constexpr TensorDataType data_type
static constexpr DeviceType device_type
Static Protected Member Functions inherited from Mila::Dnn::Compute::UnaryOperation< DeviceType::Cpu, TensorDataType::FP32 >
static const TensorInputType & asInputTensor (const ITensor &t)
static TensorOutputType & asOutputTensor (ITensor &t)
Protected Attributes inherited from Mila::Dnn::Compute::Operation< TDeviceType, TInput >
bool is_built_
TrainingMode training_mode_

Detailed Description

CPU implementation of Multi-Head Attention operation.

Design philosophy:

  • Two-phase initialization: build() creates all required tensors, forward()/backward() use them
  • All dimension computation and tensor allocation happens once in build()
  • Forward/backward are hot-path methods with zero setup overhead

Forward pass:

  1. Permute QKV from [B, T, 3*C] to separate Q, K, V [B, NH, T, HS]
  2. Compute scaled attention scores: preatt = scale * (Q @ K^T), where scale = 1 / sqrt(HS)
  3. Apply softmax with causal masking: att = softmax(preatt)
  4. Compute values: v_out = Att @ V
  5. Unpermute output from [B, NH, T, HS] to [B, T, C]

Backward pass:

  1. Unpermute output gradient to [B, NH, T, HS]
  2. Compute dV = Att^T @ dVout
  3. Compute dAtt = dVout @ V^T
  4. Softmax backward: dPreatt = softmax_backward(dAtt, Att) * scale
  5. Compute dQ = dPreatt @ K
  6. Compute dK = dPreatt^T @ Q
  7. Permute gradients back to concatenated QKV format

Member Typedef Documentation

◆ CpuExecutionContext

◆ MR

◆ TensorType

Constructor & Destructor Documentation

◆ CpuAttentionOp()

Mila::Dnn::Compute::CpuAttentionOp::CpuAttentionOp ( IExecutionContext * context,
const MultiHeadAttentionConfig & config )
inline explicit

◆ ~CpuAttentionOp()

Mila::Dnn::Compute::CpuAttentionOp::~CpuAttentionOp ( )
override default

Member Function Documentation

◆ allocateStateTensors()

void Mila::Dnn::Compute::CpuAttentionOp::allocateStateTensors ( )
inline private
Here is the caller graph for this function:

◆ applySoftmax()

void Mila::Dnn::Compute::CpuAttentionOp::applySoftmax ( ) const
inline private
Here is the caller graph for this function:

◆ backward()

void Mila::Dnn::Compute::CpuAttentionOp::backward ( const ITensor & input,
const ITensor & output_grad,
ITensor & input_grad ) const
inline override virtual

Backward pass: compute gradient wrt input given output gradient.

Signature ordered as (input, output_grad, input_grad) to match module and operation implementations across the codebase.

Implements Mila::Dnn::Compute::UnaryOperation< DeviceType::Cpu, TensorDataType::FP32 >.

Here is the call graph for this function:

◆ build()

void Mila::Dnn::Compute::CpuAttentionOp::build ( const BuildContext & build_context)
inline override virtual

Prepare the operation for a concrete input shape.

Default implementation is a no-op. Operations requiring shape-dependent setup should override this method.

Reimplemented from Mila::Dnn::Compute::Operation< TDeviceType, TInput >.

Here is the call graph for this function:

◆ computeAttentionScores()

void Mila::Dnn::Compute::CpuAttentionOp::computeAttentionScores ( float scale) const
inline private
Here is the call graph for this function:
Here is the caller graph for this function:

◆ computeGradientAtt()

void Mila::Dnn::Compute::CpuAttentionOp::computeGradientAtt ( ) const
inline private
Here is the call graph for this function:
Here is the caller graph for this function:

◆ computeGradientK()

void Mila::Dnn::Compute::CpuAttentionOp::computeGradientK ( ) const
inline private
Here is the call graph for this function:
Here is the caller graph for this function:

◆ computeGradientPreatt()

void Mila::Dnn::Compute::CpuAttentionOp::computeGradientPreatt ( float scale) const
inline private
Here is the caller graph for this function:

◆ computeGradientQ()

void Mila::Dnn::Compute::CpuAttentionOp::computeGradientQ ( ) const
inline private
Here is the call graph for this function:
Here is the caller graph for this function:

◆ computeGradientV()

void Mila::Dnn::Compute::CpuAttentionOp::computeGradientV ( ) const
inline private
Here is the call graph for this function:
Here is the caller graph for this function:

◆ computeOutputValues()

void Mila::Dnn::Compute::CpuAttentionOp::computeOutputValues ( ) const
inline private
Here is the call graph for this function:
Here is the caller graph for this function:

◆ forward()

void Mila::Dnn::Compute::CpuAttentionOp::forward ( const ITensor & input,
ITensor & output ) const
inline override virtual

Forward pass: compute output = f(input).

Implementations should accept polymorphic ITensor references and may use the typed aliases / helpers to obtain typed tensor references.

Implements Mila::Dnn::Compute::UnaryOperation< DeviceType::Cpu, TensorDataType::FP32 >.

Here is the call graph for this function:

◆ getName()

std::string Mila::Dnn::Compute::CpuAttentionOp::getName ( ) const
inline override virtual

Human-readable operation name.

Implements Mila::Dnn::Compute::Operation< TDeviceType, TInput >.

◆ getOperationType()

OperationType Mila::Dnn::Compute::CpuAttentionOp::getOperationType ( ) const
inline override virtual

◆ permute_backward()

void Mila::Dnn::Compute::CpuAttentionOp::permute_backward ( float * dX) const
inline private
Here is the caller graph for this function:

◆ permuteQKV()

void Mila::Dnn::Compute::CpuAttentionOp::permuteQKV ( const float * X) const
inline private
Here is the caller graph for this function:

◆ setGradients()

void Mila::Dnn::Compute::CpuAttentionOp::setGradients ( ITensor * weight_grad,
ITensor * bias_grad )
inline override virtual

Bind module-owned gradient tensors to the operation.

New canonical API for binding gradient buffers. Mirrors semantics of setParameters() but for gradients used during backward().

The operation MUST NOT take ownership of the provided pointers. Implementations may cache rawData() pointers for hot-path writes.

Default: no-op for stateless operations.

Reimplemented from Mila::Dnn::Compute::Operation< TDeviceType, TInput >.

◆ setParameters()

void Mila::Dnn::Compute::CpuAttentionOp::setParameters ( ITensor * weight,
ITensor * bias )
inline override virtual

Bind module-owned parameter tensors to the operation.

The module retains ownership of the provided ITensor objects. Implementations may cache rawData() pointers for hot-path access but MUST NOT free the provided pointers.

Default: no-op for stateless operations.

Reimplemented from Mila::Dnn::Compute::Operation< TDeviceType, TInput >.

◆ unpermute()

void Mila::Dnn::Compute::CpuAttentionOp::unpermute ( float * Y) const
inline private
Here is the caller graph for this function:

◆ unpermute_backward()

void Mila::Dnn::Compute::CpuAttentionOp::unpermute_backward ( const float * dY) const
inline private
Here is the caller graph for this function:

◆ validateInputShape()

void Mila::Dnn::Compute::CpuAttentionOp::validateInputShape ( const shape_t & input_shape) const
inline private
Here is the call graph for this function:
Here is the caller graph for this function:

Member Data Documentation

◆ att_

float* Mila::Dnn::Compute::CpuAttentionOp::att_ { nullptr }
private

◆ att_tensor_

std::shared_ptr<TensorType> Mila::Dnn::Compute::CpuAttentionOp::att_tensor_
private

◆ B_

int Mila::Dnn::Compute::CpuAttentionOp::B_ { 0 }
private

◆ config_

MultiHeadAttentionConfig Mila::Dnn::Compute::CpuAttentionOp::config_
private

◆ context_

IExecutionContext* Mila::Dnn::Compute::CpuAttentionOp::context_
private

◆ datt_

float* Mila::Dnn::Compute::CpuAttentionOp::datt_ { nullptr }
private

◆ datt_tensor_

std::shared_ptr<TensorType> Mila::Dnn::Compute::CpuAttentionOp::datt_tensor_
private

◆ dk_

float* Mila::Dnn::Compute::CpuAttentionOp::dk_ { nullptr }
private

◆ dk_tensor_

std::shared_ptr<TensorType> Mila::Dnn::Compute::CpuAttentionOp::dk_tensor_
private

◆ dpreatt_

float* Mila::Dnn::Compute::CpuAttentionOp::dpreatt_ { nullptr }
private

◆ dpreatt_tensor_

std::shared_ptr<TensorType> Mila::Dnn::Compute::CpuAttentionOp::dpreatt_tensor_
private

◆ dq_

float* Mila::Dnn::Compute::CpuAttentionOp::dq_ { nullptr }
private

◆ dq_tensor_

std::shared_ptr<TensorType> Mila::Dnn::Compute::CpuAttentionOp::dq_tensor_
private

◆ dv_

float* Mila::Dnn::Compute::CpuAttentionOp::dv_ { nullptr }
private

◆ dv_tensor_

std::shared_ptr<TensorType> Mila::Dnn::Compute::CpuAttentionOp::dv_tensor_
private

◆ dvout_

float* Mila::Dnn::Compute::CpuAttentionOp::dvout_ { nullptr }
private

◆ dvout_tensor_

std::shared_ptr<TensorType> Mila::Dnn::Compute::CpuAttentionOp::dvout_tensor_
private

◆ embedding_dim_

int Mila::Dnn::Compute::CpuAttentionOp::embedding_dim_ { 0 }
private

◆ HS_

int Mila::Dnn::Compute::CpuAttentionOp::HS_ { 0 }
private

◆ is_built_

bool Mila::Dnn::Compute::CpuAttentionOp::is_built_ { false }
private

◆ k_

float* Mila::Dnn::Compute::CpuAttentionOp::k_ { nullptr }
private

◆ k_tensor_

std::shared_ptr<TensorType> Mila::Dnn::Compute::CpuAttentionOp::k_tensor_
private

◆ NH_

int Mila::Dnn::Compute::CpuAttentionOp::NH_ { 0 }
private

◆ preatt_

float* Mila::Dnn::Compute::CpuAttentionOp::preatt_ { nullptr }
private

◆ preatt_tensor_

std::shared_ptr<TensorType> Mila::Dnn::Compute::CpuAttentionOp::preatt_tensor_
private

◆ q_

float* Mila::Dnn::Compute::CpuAttentionOp::q_ { nullptr }
private

◆ q_tensor_

std::shared_ptr<TensorType> Mila::Dnn::Compute::CpuAttentionOp::q_tensor_
private

◆ qkv_dim_

int Mila::Dnn::Compute::CpuAttentionOp::qkv_dim_ { 0 }
private

◆ T_

int Mila::Dnn::Compute::CpuAttentionOp::T_ { 0 }
private

◆ v_

float* Mila::Dnn::Compute::CpuAttentionOp::v_ { nullptr }
private

◆ v_out_

float* Mila::Dnn::Compute::CpuAttentionOp::v_out_ { nullptr }
private

◆ v_out_tensor_

std::shared_ptr<TensorType> Mila::Dnn::Compute::CpuAttentionOp::v_out_tensor_
private

◆ v_tensor_

std::shared_ptr<TensorType> Mila::Dnn::Compute::CpuAttentionOp::v_tensor_
private

The documentation for this class was generated from the following file: