Mila 0.13.48
Deep Neural Network Library
Mila::Dnn::MultiHeadAttentionConfig Class Reference (export)

Configuration class for the multi-head attention module.

Public Member Functions

 MultiHeadAttentionConfig (dim_t model_dim, dim_t num_heads)
 Constructor with required parameters.
void fromMetadata (const SerializationMetadata &meta)
 Populate configuration from serialization metadata.
dim_t getModelDim () const noexcept
 Get the model dimension.
dim_t getNumHeads () const noexcept
 Get the number of attention heads.
SerializationMetadata toMetadata () const
 Convert configuration to serialization metadata.
std::string toString () const override
 String representation of the configuration.
void validate () const override
 Validate configuration parameters.
template<typename Self>
decltype(auto) withModelDim (this Self &&self, dim_t model_dim)
 C++23-style fluent setter for model dimension.
template<typename Self>
decltype(auto) withNumHeads (this Self &&self, dim_t num_heads)
 C++23-style fluent setter for number of heads.
Public Member Functions inherited from Mila::Dnn::ComponentConfig
virtual ~ComponentConfig ()=default
 Virtual destructor for polymorphic base.

Private Attributes

dim_t model_dim_
dim_t num_heads_

Detailed Description

Configuration class for the multi-head attention module.

Note: Some configuration options are currently disabled and marked for future implementation. The base implementation provides core multi-head attention functionality with:

  • Fixed causal masking (enabled by default for autoregressive models)
  • Automatic scale factor (1/sqrt(head_dim))
  • No dropout (to be added in future versions)
  • Unified Q/K/V input (separate projections to be added in future versions)
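The scale factor and causal mask listed above can be illustrated with a small standalone sketch. It does not use Mila itself: `dim_t` is a stand-in alias, `headDim`, `attentionScale` and `causallyVisible` are hypothetical helper names, and the even split `head_dim = model_dim / num_heads` is the usual multi-head convention, assumed here because this page does not state it.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <stdexcept>

// Stand-in for Mila's dim_t; the real typedef is not shown on this page.
using dim_t = std::int64_t;

// Per-head width, ASSUMING the model dimension is split evenly across heads.
inline dim_t headDim(dim_t model_dim, dim_t num_heads) {
    if (model_dim <= 0 || num_heads <= 0 || model_dim % num_heads != 0) {
        throw std::invalid_argument("model_dim must be a positive multiple of num_heads");
    }
    return model_dim / num_heads;
}

// Automatic scale factor applied to attention scores: 1 / sqrt(head_dim).
inline float attentionScale(dim_t model_dim, dim_t num_heads) {
    return 1.0f / std::sqrt(static_cast<float>(headDim(model_dim, num_heads)));
}

// Causal masking: a query position may attend only to itself and earlier keys.
inline bool causallyVisible(dim_t query_pos, dim_t key_pos) {
    return key_pos <= query_pos;
}

// Small usage example (hypothetical sizes).
inline void scaleExample() {
    assert(headDim(768, 12) == 64);
    assert(std::fabs(attentionScale(768, 12) - 0.125f) < 1e-6f);
    assert(causallyVisible(3, 2) && !causallyVisible(2, 3));
}
```

With 768 model dimensions and 12 heads, each head is 64 wide, so the scores would be multiplied by 1/8 before the softmax.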

Constructor & Destructor Documentation

◆ MultiHeadAttentionConfig()

Mila::Dnn::MultiHeadAttentionConfig::MultiHeadAttentionConfig ( dim_t model_dim,
dim_t num_heads )
inline

Constructor with required parameters.

Parameters
model_dim    The model dimension size
num_heads    The number of attention heads

Member Function Documentation

◆ fromMetadata()

void Mila::Dnn::MultiHeadAttentionConfig::fromMetadata ( const SerializationMetadata & meta)
inline virtual

Populate configuration from serialization metadata.

Reads available fields from the provided metadata and updates the configuration object accordingly.

Implements Mila::Dnn::ComponentConfig.

◆ getModelDim()

dim_t Mila::Dnn::MultiHeadAttentionConfig::getModelDim ( ) const
inline noexcept

Get the model dimension.

◆ getNumHeads()

dim_t Mila::Dnn::MultiHeadAttentionConfig::getNumHeads ( ) const
inline noexcept

Get the number of attention heads.

◆ toMetadata()

SerializationMetadata Mila::Dnn::MultiHeadAttentionConfig::toMetadata ( ) const
inline virtual

Convert configuration to serialization metadata.

Produces a SerializationMetadata object containing the configuration fields suitable for writing into an archive by the caller.

Implements Mila::Dnn::ComponentConfig.
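A round trip through toMetadata() and fromMetadata() might look like the sketch below. The interface of SerializationMetadata is not documented on this page, so the sketch uses a plain `std::map` as a stand-in; `SketchMetadata`, `SketchAttentionConfig` and the key names `"model_dim"` and `"num_heads"` are all assumptions made for illustration, not Mila's actual API.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <string>

using dim_t = std::int64_t;  // stand-in for Mila's dim_t

// Hypothetical stand-in for SerializationMetadata: a flat key/value map.
using SketchMetadata = std::map<std::string, std::int64_t>;

// Hypothetical stand-in for the configuration class.
struct SketchAttentionConfig {
    dim_t model_dim;
    dim_t num_heads;

    // Export every configuration field under an assumed key name.
    SketchMetadata toMetadata() const {
        return {{"model_dim", model_dim}, {"num_heads", num_heads}};
    }

    // Update only the fields that are present; missing keys leave the
    // current value untouched ("reads available fields").
    void fromMetadata(const SketchMetadata& meta) {
        const auto dim_it = meta.find("model_dim");
        if (dim_it != meta.end()) model_dim = dim_it->second;
        const auto heads_it = meta.find("num_heads");
        if (heads_it != meta.end()) num_heads = heads_it->second;
    }
};

// Small usage example: export one config and load it into another.
inline void roundTripExample() {
    const SketchAttentionConfig saved{768, 12};
    SketchAttentionConfig loaded{1, 1};
    loaded.fromMetadata(saved.toMetadata());
    assert(loaded.model_dim == 768 && loaded.num_heads == 12);
}
```

Treating absent keys as "keep the current value" is one reading of the fromMetadata() description; the real implementation may handle missing fields differently.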

◆ toString()

std::string Mila::Dnn::MultiHeadAttentionConfig::toString ( ) const
inline override virtual

String representation of the configuration.

Returns
std::string Human-readable description of the configuration.

Implements Mila::Dnn::ComponentConfig.

◆ validate()

void Mila::Dnn::MultiHeadAttentionConfig::validate ( ) const
inline override virtual

Validate configuration parameters.

Exceptions
std::invalid_argument    If validation fails

Implements Mila::Dnn::ComponentConfig.
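This page does not list the individual checks that validate() performs. The sketch below shows one plausible set, assuming both dimensions must be positive and that model_dim must divide evenly by num_heads. `validateAttentionDims` and `isValidAttentionDims` are hypothetical free functions written for this example; only the exception type comes from the documentation above.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>

using dim_t = std::int64_t;  // stand-in for Mila's dim_t

// Hypothetical checks; the divisibility rule is an ASSUMPTION.
// Throws std::invalid_argument on the first failed check.
inline void validateAttentionDims(dim_t model_dim, dim_t num_heads) {
    if (model_dim <= 0) {
        throw std::invalid_argument("model_dim must be positive");
    }
    if (num_heads <= 0) {
        throw std::invalid_argument("num_heads must be positive");
    }
    if (model_dim % num_heads != 0) {
        throw std::invalid_argument("model_dim must be divisible by num_heads");
    }
}

// Convenience wrapper: true when the pair passes, false when validation throws.
inline bool isValidAttentionDims(dim_t model_dim, dim_t num_heads) {
    try {
        validateAttentionDims(model_dim, num_heads);
        return true;
    } catch (const std::invalid_argument&) {
        return false;
    }
}

// Small usage example (hypothetical sizes).
inline void validateExample() {
    assert(isValidAttentionDims(768, 12));   // 768 / 12 = 64 per head
    assert(!isValidAttentionDims(768, 10));  // 768 is not a multiple of 10
}
```

Callers that build a configuration from user input would typically wrap the real validate() call in a similar try/catch for std::invalid_argument.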

◆ withModelDim()

template<typename Self>
decltype(auto) Mila::Dnn::MultiHeadAttentionConfig::withModelDim ( this Self && self,
dim_t model_dim )
inline

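C++23-style fluent setter for the model dimension. The explicit object parameter (this Self &&self) lets the call return the configuration it was invoked on, so setters can be chained.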

◆ withNumHeads()

template<typename Self>
decltype(auto) Mila::Dnn::MultiHeadAttentionConfig::withNumHeads ( this Self && self,
dim_t num_heads )
inline

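C++23-style fluent setter for the number of heads. The explicit object parameter (this Self &&self) lets the call return the configuration it was invoked on, so setters can be chained.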
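The two fluent setters share the same explicit-object-parameter signature. The sketch below reproduces that pattern on a hypothetical stand-in class, `SketchAttentionConfig`. The setter bodies are an assumption, since this page shows only the signatures. The `#else` branch is a conventional fallback so that the example also builds with compilers that lack C++23 "deducing this" support.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>

using dim_t = std::int64_t;  // stand-in for Mila's dim_t

// Hypothetical stand-in class; not Mila's implementation.
class SketchAttentionConfig {
public:
    SketchAttentionConfig(dim_t model_dim, dim_t num_heads)
        : model_dim_(model_dim), num_heads_(num_heads) {}

#if defined(__cpp_explicit_this_parameter)
    // Self deduces to the caller's type and value category, so the same
    // object is forwarded back as an lvalue or rvalue reference.
    template <typename Self>
    decltype(auto) withModelDim(this Self&& self, dim_t model_dim) {
        self.model_dim_ = model_dim;
        return std::forward<Self>(self);
    }

    template <typename Self>
    decltype(auto) withNumHeads(this Self&& self, dim_t num_heads) {
        self.num_heads_ = num_heads;
        return std::forward<Self>(self);
    }
#else
    // Pre-C++23 fallback: classic fluent setters returning *this.
    SketchAttentionConfig& withModelDim(dim_t model_dim) {
        model_dim_ = model_dim;
        return *this;
    }

    SketchAttentionConfig& withNumHeads(dim_t num_heads) {
        num_heads_ = num_heads;
        return *this;
    }
#endif

    dim_t getModelDim() const noexcept { return model_dim_; }
    dim_t getNumHeads() const noexcept { return num_heads_; }

private:
    dim_t model_dim_;
    dim_t num_heads_;
};

// Small usage example: chain both setters on a temporary.
inline void fluentExample() {
    auto cfg = SketchAttentionConfig(512, 8).withModelDim(768).withNumHeads(12);
    assert(cfg.getModelDim() == 768 && cfg.getNumHeads() == 12);
}
```

One benefit of the deduced `Self` is that a setter called through a derived configuration type returns that derived type, so chaining does not lose access to the derived class's own setters.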

Member Data Documentation

◆ model_dim_

dim_t Mila::Dnn::MultiHeadAttentionConfig::model_dim_
private

◆ num_heads_

dim_t Mila::Dnn::MultiHeadAttentionConfig::num_heads_
private

The documentation for this class was generated from the following file: