Mila
Deep Neural Network Library
Dnn.Modules.Attention Module Reference

Exported Modules

module  Compute.OperationAttributes
 
module  Compute.CudaDevice
 
module  Serialization.ModelArchive
 
module  Compute.OperationRegistry
 
module  Compute.Precision
 
module  Dnn.TensorTraits
 
module  Compute.MemoryResource
 
module  Compute.DeviceType
 
module  Dnn.Module
 
module  Dnn.Tensor
 
module  Dnn.TensorHelpers
 
module  Compute.CpuMemoryResource
 
module  Compute.OperationBase
 
module  Compute.UnaryOperation
 
module  Compute.CpuDevice
 
module  Compute.CudaMemoryResource
 
module  Compute.DeviceContext
 

Classes

class  Mila::Dnn::MultiHeadAttention< TDeviceType, TInput, TOutput >
 Multi-head attention module for transformer architectures.
 
class  Mila::Dnn::MultiHeadAttentionConfig
 Configuration class for the MultiHeadAttention module.
 

Typedefs

template<typename TInput = float, typename TOutput = TInput>
using Mila::Dnn::CpuMultiHeadAttention = MultiHeadAttention< DeviceType::Cpu, TInput, TOutput >
 Type alias for CPU-based multi-head attention module with customizable tensor types.
 
template<typename TInput = float, typename TOutput = TInput>
using Mila::Dnn::CudaMultiHeadAttention = MultiHeadAttention< DeviceType::Cuda, TInput, TOutput >
 Type alias for CUDA-based multi-head attention module with customizable tensor types; see the usage sketch at the end of this section.
 
using ModuleBase = Module< TDeviceType, TInput, TOutput >
 Alias for base module type.
 
using MR = std::conditional_t< TDeviceType==DeviceType::Cuda, CudaMemoryResource, CpuMemoryResource >
 Memory resource type used for tensors, selected based on device type.
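
The two device aliases above are thin shorthands for the corresponding MultiHeadAttention specializations, and the MR alias then selects the matching memory resource. A minimal sketch, assuming an umbrella `import Mila;` and that DeviceType is visible directly under Mila::Dnn (neither detail is documented on this page):

    #include <type_traits>
    import Mila;                                   // assumed umbrella module import

    using namespace Mila::Dnn;

    // CpuMultiHeadAttention<float> names exactly
    // MultiHeadAttention<DeviceType::Cpu, float, float>; the class in turn
    // selects CpuMemoryResource for its tensors. The CUDA alias behaves the
    // same way with DeviceType::Cuda and CudaMemoryResource.
    static_assert( std::is_same_v<
        CpuMultiHeadAttention<float>,
        MultiHeadAttention<DeviceType::Cpu, float, float> > );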
 

Functions

 MultiHeadAttention (const std::string &device_name, const MultiHeadAttentionConfig &config)
 Constructs a new MultiHeadAttention module with a device name.
 
 MultiHeadAttention (std::shared_ptr< DeviceContext > device_context, const MultiHeadAttentionConfig &config)
 Constructs a new MultiHeadAttention module with a provided device context.
 
void backward (const Tensor< TInput, MR > &input, const Tensor< TOutput, MR > &output_grad, Tensor< TInput, MR > &input_grad)
 Performs the backward pass of the MultiHeadAttention operation.
 
void createOperation ()
 Creates the appropriate attention operation for the current device.
 
void forward (const Tensor< TInput, MR > &input, const Tensor< TInput, MR > &mask, Tensor< TOutput, MR > &output)
 Performs the forward pass with an explicit attention mask.
 
void forward (const Tensor< TInput, MR > &input, Tensor< TOutput, MR > &output)
 Performs the forward pass of the MultiHeadAttention operation; see the usage sketch at the end of this section.
 
float getDropout () const
 Gets the dropout rate.
 
size_t getEmbeddingDim () const
 Gets the embedding dimension.
 
const std::vector< size_t > & getInputShape () const
 Gets the input shape.
 
size_t getNumHeads () const
 Gets the number of attention heads.
 
void initializeTensors ()
 Initializes the tensors needed for attention computation.
 
void load (ModelArchive &archive) override
 Loads the module state from a ZIP archive.
 
size_t parameterCount () const override
 Gets the number of trainable parameters in this module.
 
void save (ModelArchive &zip) const override
 Saves the module state to a ZIP archive.
 
std::string toString () const override
 Generates a string representation of this module's configuration.
 
bool useCausalMask () const
 Checks if causal masking is enabled.
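
A minimal end-to-end sketch of the constructor, forward, and backward calls listed above. The device name "CUDA:0", the default-constructed MultiHeadAttentionConfig, the shape-taking Tensor constructor, and the Compute namespace placement of CudaMemoryResource are assumptions; only the member signatures themselves come from this page:

    import Mila;                                            // assumed umbrella module import

    using namespace Mila::Dnn;

    void multiHeadAttentionSketch() {
        MultiHeadAttentionConfig config;                    // assumed default-constructible;
                                                            // heads, embedding dim, dropout set elsewhere
        CudaMultiHeadAttention<float> attn( "CUDA:0", config );   // device-name constructor

        // Shapes are illustrative: [batch, sequence, embedding].
        Tensor<float, Compute::CudaMemoryResource> input( { 4, 64, 512 } );
        Tensor<float, Compute::CudaMemoryResource> output( { 4, 64, 512 } );

        attn.forward( input, output );                      // unmasked forward pass
        // attn.forward( input, mask, output );             // variant with an explicit attention mask

        // Backward: consumes the upstream gradient and fills the input gradient.
        Tensor<float, Compute::CudaMemoryResource> output_grad( { 4, 64, 512 } );
        Tensor<float, Compute::CudaMemoryResource> input_grad( { 4, 64, 512 } );
        attn.backward( input, output_grad, input_grad );
    }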
 

Variables

std::shared_ptr< Tensor< TOutput, MR > > attn_ { nullptr }
 Attention weight tensor from the forward pass.
 
MultiHeadAttentionConfig config_
 Configuration for the MultiHeadAttention module.
 
std::shared_ptr< UnaryOperation< TDeviceType, TInput, TOutput > > operation_ { nullptr }
 The operation that implements the attention mechanism.
 
std::vector< std::shared_ptr< Tensor< TOutput, MR > > > output_state_
 Collection of output state tensors for caching.
 
std::vector< std::shared_ptr< Tensor< TOutput, MR > > > parameters_
 Collection of parameters for this module.
 
std::shared_ptr< Tensor< TOutput, MR > > pre_attn_ { nullptr }
 Pre-softmax attention scores from the forward pass; see the formula at the end of this section.
 
OperationAttributes properties_
 Operation attributes and configuration.
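
For orientation, pre_attn_ and attn_ correspond to the standard scaled dot-product attention computed per head during the forward pass; this is the conventional formulation, not a line-by-line description of the device-specific kernels:

    \[
    \text{pre\_attn} = \frac{Q K^{\top}}{\sqrt{d_k}}, \qquad
    \text{attn} = \operatorname{softmax}(\text{pre\_attn}), \qquad
    \text{output} = \text{attn}\, V
    \]

where Q, K, and V are the query, key, and value projections of the input, and d_k is the per-head dimension (the embedding dimension divided by the number of heads). When causal masking or an explicit mask is used, the masked positions are suppressed in the pre-softmax scores before the softmax is applied.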
 

Files

file  Mila/Src/Dnn/Modules/Layers/MultiHeadAttention.ixx
 Implementation of multi-head attention mechanism for transformer architectures.
 
file  Mila/Src/Dnn/Modules/Layers/MultiHeadAttentionConfig.ixx
 Configuration interface for the MultiHeadAttention module in the Mila DNN framework.