Mila 0.13.48
Deep Neural Network Library
Mila::Dnn::GqaConfig Class Reference [export]

Configuration class for the Grouped-Query Attention module.


Public Member Functions

 GqaConfig (dim_t model_dim, dim_t num_heads, dim_t num_kv_heads)
 Constructor with all required parameters.
void fromMetadata (const SerializationMetadata &meta)
 Populate configuration from serialization metadata.
dim_t getGroupSize () const noexcept
 Number of Q heads sharing each K/V head (num_heads / num_kv_heads).
dim_t getHeadDim () const noexcept
 Per-head feature dimension (model_dim / num_heads).
dim_t getModelDim () const noexcept
 Total Q-projection output width (= num_heads * head_dim).
dim_t getNumHeads () const noexcept
 Number of Q attention heads.
dim_t getNumKvHeads () const noexcept
 Number of K/V attention heads.
SerializationMetadata toMetadata () const
 Convert configuration to serialization metadata.
std::string toString () const override
 Human-readable description of the configuration.
void validate () const override
 Validate all configuration parameters.
template<typename Self>
decltype(auto) withModelDim (this Self &&self, dim_t model_dim)
 Fluent setter for model dimension.
template<typename Self>
decltype(auto) withNumHeads (this Self &&self, dim_t num_heads)
 Fluent setter for number of Q heads.
template<typename Self>
decltype(auto) withNumKvHeads (this Self &&self, dim_t num_kv_heads)
 Fluent setter for number of K/V heads.
Public Member Functions inherited from Mila::Dnn::ComponentConfig
virtual ~ComponentConfig ()=default
 Virtual destructor for polymorphic base.

Private Attributes

dim_t model_dim_
dim_t num_heads_
dim_t num_kv_heads_

Detailed Description

Configuration class for the Grouped-Query Attention module.

Carries the three parameters that uniquely define a GQA layer:

  • model_dim : total Q-projection output width (= num_heads * head_dim)
  • num_heads : number of Q attention heads
  • num_kv_heads: number of K/V attention heads (must divide num_heads)

The derived head_dim = model_dim / num_heads and group_size = num_heads / num_kv_heads are computed on demand rather than stored, so they always stay consistent with the three primary fields.

Fluent setters follow the C++23 explicit-object-parameter pattern used throughout the codebase, enabling value-category-preserving method chaining on both lvalue and rvalue configs.

Constructor & Destructor Documentation

◆ GqaConfig()

Mila::Dnn::GqaConfig::GqaConfig ( dim_t model_dim, dim_t num_heads, dim_t num_kv_heads )
inline

Constructor with all required parameters.

Parameters
model_dim     Total Q-projection output width.
num_heads     Number of Q attention heads.
num_kv_heads  Number of K/V attention heads.

Member Function Documentation

◆ fromMetadata()

void Mila::Dnn::GqaConfig::fromMetadata ( const SerializationMetadata & meta)
inline, virtual

Populate configuration from serialization metadata.

Reads the fields present in the provided metadata and updates the configuration object in place. Missing keys are silently ignored, so an older checkpoint that lacks num_kv_heads keeps the value supplied to the constructor.
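The "ignore missing keys" behaviour can be illustrated with a minimal sketch. The real SerializationMetadata interface is not shown on this page, so a plain std::unordered_map stands in for it here; the key strings, the struct, and the helper names are likewise assumptions.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <unordered_map>

using dim_t = std::int64_t;  // assumption: Mila defines its own dim_t alias

// Hypothetical stand-in for SerializationMetadata (its real API is not shown here).
using MetadataSketch = std::unordered_map<std::string, dim_t>;

// Hypothetical stand-in for the three stored GqaConfig fields.
struct GqaFieldsSketch {
    dim_t model_dim;
    dim_t num_heads;
    dim_t num_kv_heads;
};

// Overwrite a field only when its key is present; otherwise keep the current value.
inline void assignIfPresent(const MetadataSketch& meta, const std::string& key, dim_t& field) {
    if (auto it = meta.find(key); it != meta.end()) {
        field = it->second;
    }
}

// Key names below are assumed for illustration.
inline void fromMetadataSketch(GqaFieldsSketch& cfg, const MetadataSketch& meta) {
    assignIfPresent(meta, "model_dim", cfg.model_dim);
    assignIfPresent(meta, "num_heads", cfg.num_heads);
    assignIfPresent(meta, "num_kv_heads", cfg.num_kv_heads);
}

// Usage: a checkpoint without num_kv_heads leaves that field untouched.
inline void oldCheckpointExample() {
    GqaFieldsSketch cfg{768, 12, 12};
    const MetadataSketch old_checkpoint{{"model_dim", 1024}, {"num_heads", 16}};
    fromMetadataSketch(cfg, old_checkpoint);
    assert(cfg.num_kv_heads == 12);  // constructor-supplied value survives
}
```

Since loading does not itself check consistency in this sketch, calling validate() after populating from metadata would be a sensible follow-up.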

Implements Mila::Dnn::ComponentConfig.

◆ getGroupSize()

dim_t Mila::Dnn::GqaConfig::getGroupSize ( ) const
inline, noexcept

Number of Q heads sharing each K/V head (num_heads / num_kv_heads).

Equivalent to the GQA "group size". A value of 1 (num_kv_heads == num_heads) recovers standard multi-head attention (MHA); a value equal to num_heads (num_kv_heads == 1) recovers Multi-Query Attention (MQA).
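The two boundary cases can be checked with a small sketch. The enum and the free functions below are hypothetical helpers written for illustration; they are not part of the documented class.

```cpp
#include <cstdint>

using dim_t = std::int64_t;  // assumption: Mila defines its own dim_t alias

// Hypothetical labels for the three attention regimes.
enum class AttentionKind { MultiHead, GroupedQuery, MultiQuery };

// Same quotient that getGroupSize() is documented to return.
constexpr dim_t groupSize(dim_t num_heads, dim_t num_kv_heads) {
    return num_heads / num_kv_heads;
}

// Assumes the inputs already satisfy the validate() constraints.
constexpr AttentionKind classify(dim_t num_heads, dim_t num_kv_heads) {
    const dim_t g = groupSize(num_heads, num_kv_heads);
    if (g == 1) return AttentionKind::MultiHead;           // one K/V head per Q head
    if (g == num_heads) return AttentionKind::MultiQuery;  // a single shared K/V head
    return AttentionKind::GroupedQuery;                    // anything in between
}

static_assert(classify(32, 32) == AttentionKind::MultiHead);
static_assert(classify(32, 8) == AttentionKind::GroupedQuery);
static_assert(classify(32, 1) == AttentionKind::MultiQuery);
```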

◆ getHeadDim()

dim_t Mila::Dnn::GqaConfig::getHeadDim ( ) const
inline, noexcept

Per-head feature dimension (model_dim / num_heads).

Derived quantity, recomputed on each call, so it always reflects the current model_dim and num_heads.

◆ getModelDim()

dim_t Mila::Dnn::GqaConfig::getModelDim ( ) const
inline, noexcept

Total Q-projection output width (= num_heads * head_dim).

◆ getNumHeads()

dim_t Mila::Dnn::GqaConfig::getNumHeads ( ) const
inline, noexcept

Number of Q attention heads.

◆ getNumKvHeads()

dim_t Mila::Dnn::GqaConfig::getNumKvHeads ( ) const
inline, noexcept

Number of K/V attention heads.

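The KV cache size scales linearly with this value, so using fewer K/V heads than Q heads shrinks the cache and the memory traffic needed to read it. This is the source of GQA's memory-bandwidth advantage over full MHA.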
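A back-of-the-envelope calculation shows the scale of the saving. The formula below is the usual per-sequence estimate (one K and one V tensor per layer); the function name, the extra parameters, and the example sizes are assumptions, and Mila's actual cache layout may differ.

```cpp
#include <cstdint>

using dim_t = std::int64_t;  // assumption: Mila defines its own dim_t alias

// Estimated KV cache bytes for one sequence:
// 2 tensors (K and V) per layer, each seq_len x num_kv_heads x head_dim elements.
constexpr std::int64_t kvCacheBytes(dim_t num_kv_heads, dim_t head_dim, dim_t seq_len,
                                    dim_t num_layers, dim_t bytes_per_element) {
    return 2 * num_layers * seq_len * num_kv_heads * head_dim * bytes_per_element;
}

// Illustrative sizes: 32 Q heads, head_dim 128, 4096 tokens, 32 layers, 2-byte elements.
constexpr std::int64_t mha_bytes = kvCacheBytes(32, 128, 4096, 32, 2);  // num_kv_heads == num_heads
constexpr std::int64_t gqa_bytes = kvCacheBytes(8, 128, 4096, 32, 2);   // group size 4

static_assert(mha_bytes == 2147483648LL);    // 2 GiB
static_assert(gqa_bytes == 536870912LL);     // 512 MiB
static_assert(mha_bytes == 4 * gqa_bytes);   // saving equals the group size
```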

◆ toMetadata()

SerializationMetadata Mila::Dnn::GqaConfig::toMetadata ( ) const
inline, virtual

Convert configuration to serialization metadata.

Produces a SerializationMetadata object containing all configuration fields, ready for the caller to write into an archive.

Implements Mila::Dnn::ComponentConfig.

◆ toString()

std::string Mila::Dnn::GqaConfig::toString ( ) const
inline, override, virtual

Human-readable description of the configuration.

Implements Mila::Dnn::ComponentConfig.

◆ validate()

void Mila::Dnn::GqaConfig::validate ( ) const
inline, override, virtual

Validate all configuration parameters.

Checks:

  1. model_dim > 0
  2. num_heads >= 2
  3. model_dim % num_heads == 0 (integer head_dim)
  4. num_kv_heads >= 1
  5. num_kv_heads <= num_heads (KV heads cannot exceed Q heads)
  6. num_heads % num_kv_heads == 0 (integer group size)
Exceptions
std::invalid_argument  If any constraint is violated.
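The six checks, in the order listed, could look roughly like this. The free functions, their names, and the error messages are assumptions made for illustration; the real member function may report failures differently.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>

using dim_t = std::int64_t;  // assumption: Mila defines its own dim_t alias

// Hypothetical free-function version of the documented checks.
// Range checks come before each modulo so the divisor is never zero.
inline void validateSketch(dim_t model_dim, dim_t num_heads, dim_t num_kv_heads) {
    if (model_dim <= 0)
        throw std::invalid_argument("model_dim must be > 0");
    if (num_heads < 2)
        throw std::invalid_argument("num_heads must be >= 2");
    if (model_dim % num_heads != 0)
        throw std::invalid_argument("model_dim must be divisible by num_heads");
    if (num_kv_heads < 1)
        throw std::invalid_argument("num_kv_heads must be >= 1");
    if (num_kv_heads > num_heads)
        throw std::invalid_argument("num_kv_heads must not exceed num_heads");
    if (num_heads % num_kv_heads != 0)
        throw std::invalid_argument("num_heads must be divisible by num_kv_heads");
}

// Convenience wrapper that turns the exception into a boolean.
inline bool isValid(dim_t model_dim, dim_t num_heads, dim_t num_kv_heads) {
    try {
        validateSketch(model_dim, num_heads, num_kv_heads);
        return true;
    } catch (const std::invalid_argument&) {
        return false;
    }
}

// Usage: one accepted and one rejected configuration.
inline void validateExamples() {
    assert(isValid(4096, 32, 8));   // head_dim 128, group size 4
    assert(!isValid(4096, 32, 5));  // 32 is not divisible by 5
}
```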

Implements Mila::Dnn::ComponentConfig.

◆ withModelDim()

template<typename Self>
decltype(auto) Mila::Dnn::GqaConfig::withModelDim ( this Self && self, dim_t model_dim )
inline

Fluent setter for model dimension.

◆ withNumHeads()

template<typename Self>
decltype(auto) Mila::Dnn::GqaConfig::withNumHeads ( this Self && self, dim_t num_heads )
inline

Fluent setter for number of Q heads.

◆ withNumKvHeads()

template<typename Self>
decltype(auto) Mila::Dnn::GqaConfig::withNumKvHeads ( this Self && self, dim_t num_kv_heads )
inline

Fluent setter for number of K/V heads.

Member Data Documentation

◆ model_dim_

dim_t Mila::Dnn::GqaConfig::model_dim_
private

◆ num_heads_

dim_t Mila::Dnn::GqaConfig::num_heads_
private

◆ num_kv_heads_

dim_t Mila::Dnn::GqaConfig::num_kv_heads_
private

The documentation for this class was generated from the following file: