CPU implementation of the Multi-Head Attention operation for neural networks.
|
| CpuMultiHeadAttentionOp (const MultiHeadAttentionConfig &config) |
| Constructs a new CPU Multi-Head Attention operation with the default device context.
|
|
| CpuMultiHeadAttentionOp (std::shared_ptr< DeviceContext > context, const MultiHeadAttentionConfig &config) |
| Constructs a new CPU Multi-Head Attention operation with a specific device context.
|
|
void | backward (const Tensor< float, MR > &input, const Tensor< float, MR > &output, const Tensor< float, MR > &output_gradient, const std::vector< std::shared_ptr< Tensor< float, MR > > > &parameters, std::vector< std::shared_ptr< Tensor< float, MR > > > &parameter_gradients, Tensor< float, MR > &input_gradient, const OperationAttributes &properties, const std::vector< std::shared_ptr< Tensor< float, MR > > > &output_state) const |
| Performs the backward pass of the Multi-Head Attention operation.
|
|
void | backward_impl (float *dinp, float *dpreatt, float *datt, float *dout, float *inp, float *att, int B, int T, int C, int NH) const |
| Helper method for backward pass implementation.
|
|
void | forward (const Tensor< float, MR > &input, const std::vector< std::shared_ptr< Tensor< float, MR > > > &parameters, const OperationAttributes &properties, Tensor< float, MR > &output, std::vector< std::shared_ptr< Tensor< float, MR > > > &output_state) const override |
| Performs the forward pass of the Multi-Head Attention operation.
|
|
std::string | getName () const override |
| Gets the name of this operation.
|
|
| UnaryOperation (OperationType operation_type) |
| Constructs a UnaryOperation with the specified operation type.
|
|
| UnaryOperation (OperationType operation_type, std::shared_ptr< DeviceContext > context) |
| Constructs a UnaryOperation with the specified operation type and device context.
|
|
virtual | ~UnaryOperation ()=default |
| Virtual destructor for proper cleanup of derived classes.
|
|
virtual void | backward (const Tensor< float, MR > &grad, const std::vector< std::shared_ptr< Tensor< float, MR > > > &parameters, std::vector< std::shared_ptr< Tensor< float, MR > > > &output_grads) const |
| Executes the backward pass of a unary operation.
|
|
virtual void | backward (const Tensor< float, MR > &input, const Tensor< float, MR > &output_grad, const std::vector< std::shared_ptr< Tensor< float, MR > > > &parameters, std::vector< std::shared_ptr< Tensor< float, MR > > > &parameter_grads, Tensor< float, MR > &input_grad, const OperationAttributes &properties, const std::vector< std::shared_ptr< Tensor< float, MR > > > &output_state) const |
| Executes the comprehensive backward pass of a unary operation.
|
|
virtual void | forward (const Tensor< float, MR > &input, const std::vector< std::shared_ptr< Tensor< float, MR > > > &parameters, const OperationAttributes &properties, Tensor< float, MR > &output, std::vector< std::shared_ptr< Tensor< float, MR > > > &output_state) const=0 |
| Executes the forward pass of a unary operation.
|
|
| OperationBase (OperationType operation_type, std::shared_ptr< DeviceContext > context) |
| Constructs an OperationBase object with a specific device context and compute precision.
|
|
virtual | ~OperationBase ()=default |
| Virtual destructor for the OperationBase class.
|
|
std::shared_ptr< DeviceContext > | getDeviceContext () const |
| Gets the device context associated with this operation.
|
|
DeviceType | getDeviceType () const |
| Gets the device type for this operation.
|
|
OperationType | getOperationType () const |
| Gets the operation type enumeration value.
|
|
CPU implementation of the Multi-Head Attention operation for neural networks.
This class provides a CPU-based implementation of the Multi-Head Attention operation, which is a key component of transformer architectures. The operation performs scaled dot-product attention with multiple attention heads operating in parallel, allowing the model to jointly attend to information from different representation subspaces at different positions.
The implementation handles the full attention process:
- Query-Key dot products
- Scaling
- Softmax computation
- Attention weighting of values
Template Parameters
float | The data type of the input tensor elements. |
TDataType | The data type used for computation and output (defaults to the input type). |