Mila 0.13.48
Deep Neural Network Library
Configuration class for the Grouped-Query Attention module.


Public Member Functions
| Return type | Member | Description |
|---|---|---|
| | GqaConfig (dim_t model_dim, dim_t num_heads, dim_t num_kv_heads) | Constructor with all required parameters. |
| void | fromMetadata (const SerializationMetadata &meta) | Populate configuration from serialization metadata. |
| dim_t | getGroupSize () const noexcept | Number of Q heads sharing each K/V head (num_heads / num_kv_heads). |
| dim_t | getHeadDim () const noexcept | Per-head feature dimension (model_dim / num_heads). |
| dim_t | getModelDim () const noexcept | Total Q-projection output width (= num_heads * head_dim). |
| dim_t | getNumHeads () const noexcept | Number of Q attention heads. |
| dim_t | getNumKvHeads () const noexcept | Number of K/V attention heads. |
| SerializationMetadata | toMetadata () const | Convert configuration to serialization metadata. |
| std::string | toString () const override | Human-readable description of the configuration. |
| void | validate () const override | Validate all configuration parameters. |
| template<typename Self> decltype(auto) | withModelDim (this Self &&self, dim_t model_dim) | Fluent setter for model dimension. |
| template<typename Self> decltype(auto) | withNumHeads (this Self &&self, dim_t num_heads) | Fluent setter for number of Q heads. |
| template<typename Self> decltype(auto) | withNumKvHeads (this Self &&self, dim_t num_kv_heads) | Fluent setter for number of K/V heads. |
Public Member Functions inherited from Mila::Dnn::ComponentConfig
| Specifiers | Member | Description |
|---|---|---|
| virtual | ~ComponentConfig ()=default | Virtual destructor for polymorphic base. |
Private Attributes
| Type | Member |
|---|---|
| dim_t | model_dim_ |
| dim_t | num_heads_ |
| dim_t | num_kv_heads_ |
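Detailed Description
Configuration class for the Grouped-Query Attention module.
Carries the three parameters that uniquely define a GQA layer: model_dim (total Q-projection output width), num_heads (number of Q attention heads), and num_kv_heads (number of K/V attention heads).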
The derived head_dim = model_dim / num_heads and group_size = num_heads / num_kv_heads are computed on demand rather than stored, so they always stay consistent with the three primary fields.
Fluent setters follow the C++23 explicit-object-parameter pattern used throughout the codebase, enabling value-category-preserving method chaining on both lvalue and rvalue configs.
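Constructor & Destructor Documentation
GqaConfig (dim_t model_dim, dim_t num_heads, dim_t num_kv_heads)
Constructor with all required parameters.
Parameters
| model_dim | Total Q-projection output width. |
| num_heads | Number of Q attention heads. |
| num_kv_heads | Number of K/V attention heads. |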
Member Function Documentation
void fromMetadata (const SerializationMetadata &meta) [inline, virtual]
Populate configuration from serialization metadata.
Reads available fields from the provided metadata and updates the configuration object in place. Missing keys are silently ignored so that older checkpoints without num_kv_heads fall back to the constructor-supplied default.
Implements Mila::Dnn::ComponentConfig.
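The missing-key fallback described for fromMetadata can be sketched with a plain map. Here a std::map<std::string, dim_t> is a hypothetical stand-in for SerializationMetadata, the key strings are assumed, and GqaFieldsSketch is illustrative rather than the real class.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <string>

// Assumptions: dim_t is int64_t, and a string->integer map stands in for
// SerializationMetadata. The key names below are hypothetical.
using dim_t = std::int64_t;
using MetadataSketch = std::map<std::string, dim_t>;

struct GqaFieldsSketch {
    dim_t model_dim;
    dim_t num_heads;
    dim_t num_kv_heads;
};

// Overwrite `field` only when `key` exists; a missing key leaves the value
// the object was constructed with.
inline void readIfPresent(const MetadataSketch& meta, const std::string& key, dim_t& field) {
    if (auto it = meta.find(key); it != meta.end()) {
        field = it->second;
    }
}

inline void fromMetadataSketch(GqaFieldsSketch& cfg, const MetadataSketch& meta) {
    readIfPresent(meta, "model_dim", cfg.model_dim);
    readIfPresent(meta, "num_heads", cfg.num_heads);
    readIfPresent(meta, "num_kv_heads", cfg.num_kv_heads);
}

// Usage: an older checkpoint without "num_kv_heads" keeps the constructed value.
inline void demoLegacyCheckpoint() {
    GqaFieldsSketch cfg{4096, 32, 32};  // num_kv_heads == num_heads, i.e. MHA
    fromMetadataSketch(cfg, MetadataSketch{{"model_dim", 4096}, {"num_heads", 32}});
    assert(cfg.num_kv_heads == 32);
}
```

Under this sketch, a caller loading pre-GQA checkpoints could construct the config with num_kv_heads equal to num_heads, so the fallback reproduces plain MHA.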

dim_t getGroupSize () const [inline, noexcept]
Number of Q heads sharing each K/V head (num_heads / num_kv_heads).
Equivalent to the GQA "group size". A value of 1 (num_kv_heads == num_heads) recovers standard multi-head attention (MHA); a value equal to num_heads (num_kv_heads == 1) recovers multi-query attention (MQA).
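The group-size arithmetic and the three regimes can be checked at compile time. The mapping from a query head to its K/V head below assumes contiguous grouping (query heads g*G through g*G + G - 1 share K/V head g); that layout is a common convention and an assumption here, not something this page states for Mila.

```cpp
#include <cstdint>

using dim_t = std::int64_t; // assumption

// group_size = num_heads / num_kv_heads, as documented for getGroupSize().
constexpr dim_t groupSize(dim_t num_heads, dim_t num_kv_heads) {
    return num_heads / num_kv_heads;
}

// Assumed contiguous grouping: query heads [g * G, (g + 1) * G) all read
// K/V head g, where G is the group size.
constexpr dim_t kvHeadForQueryHead(dim_t q_head, dim_t num_heads, dim_t num_kv_heads) {
    return q_head / groupSize(num_heads, num_kv_heads);
}

static_assert(groupSize(32, 32) == 1, "num_kv_heads == num_heads: standard MHA");
static_assert(groupSize(32, 8) == 4, "GQA: four Q heads per K/V head");
static_assert(groupSize(32, 1) == 32, "num_kv_heads == 1: MQA");
```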

dim_t getHeadDim () const [inline, noexcept]
Per-head feature dimension (model_dim / num_heads).
Derived quantity, computed on each call, so it always stays consistent with model_dim and num_heads.
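A worked example of the head-dimension arithmetic, assuming dim_t is an integer type so that model_dim / num_heads truncates. The helper names are hypothetical. The relation model_dim = num_heads * head_dim, stated for getModelDim(), holds exactly only when num_heads divides model_dim.

```cpp
#include <cstdint>

using dim_t = std::int64_t; // assumption: integral dim_t

// head_dim = model_dim / num_heads, as documented for getHeadDim().
constexpr dim_t headDim(dim_t model_dim, dim_t num_heads) {
    return model_dim / num_heads;
}

// With integer division, num_heads * head_dim == model_dim holds only when
// num_heads divides model_dim exactly.
constexpr bool splitsEvenly(dim_t model_dim, dim_t num_heads) {
    return num_heads > 0 && model_dim % num_heads == 0;
}

static_assert(headDim(4096, 32) == 128, "4096 = 32 heads x 128");
static_assert(32 * headDim(4096, 32) == 4096, "exact split");
static_assert(48 * headDim(4096, 48) != 4096, "4096 / 48 truncates to 85");
```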

dim_t getModelDim () const [inline, noexcept]
Total Q-projection output width (= num_heads * head_dim).
dim_t getNumHeads () const [inline, noexcept]
Number of Q attention heads.
dim_t getNumKvHeads () const [inline, noexcept]
Number of K/V attention heads.
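The KV cache size scales linearly with this value, so relative to full MHA (num_kv_heads == num_heads), GQA caches num_heads / num_kv_heads times fewer K/V entries; this is the source of its memory and bandwidth advantage.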
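The proportionality can be made concrete with the usual KV-cache size formula: one K and one V vector per layer, batch entry, cached token, and K/V head. The function name and the model shape below (32 layers, 4096 tokens, head_dim 128, 2-byte elements) are illustrative assumptions, not Mila defaults.

```cpp
#include <cstdint>

using dim_t = std::int64_t; // assumption

// Bytes held by a KV cache: factor 2 for K and V, times layers, batch,
// cached tokens, K/V heads, per-head width, and element size. For a fixed
// head_dim the result is linear in num_kv_heads.
constexpr std::int64_t kvCacheBytes(dim_t num_layers, dim_t batch, dim_t seq_len,
                                    dim_t num_kv_heads, dim_t head_dim,
                                    std::int64_t bytes_per_element) {
    return 2 * num_layers * batch * seq_len * num_kv_heads * head_dim * bytes_per_element;
}

// Illustrative shape: 32 layers, batch 1, 4096 tokens, head_dim 128, 2-byte elements.
constexpr std::int64_t mha_bytes = kvCacheBytes(32, 1, 4096, 32, 128, 2); // num_kv_heads == num_heads
constexpr std::int64_t gqa_bytes = kvCacheBytes(32, 1, 4096, 8, 128, 2);  // group size 4

static_assert(mha_bytes == 2147483648LL, "2 GiB");
static_assert(mha_bytes / gqa_bytes == 4, "cache shrinks by the group size");
```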
SerializationMetadata toMetadata () const [inline, virtual]
Convert configuration to serialization metadata.
Produces a SerializationMetadata object containing all configuration fields, which the caller can then write into an archive.
Implements Mila::Dnn::ComponentConfig.

std::string toString () const [inline, override, virtual]
Human-readable description of the configuration.
Implements Mila::Dnn::ComponentConfig.

void validate () const [inline, override, virtual]
Validate all configuration parameters.
Checks the constraints on model_dim, num_heads, and num_kv_heads.
Exceptions
| std::invalid_argument | If any constraint is violated. |
Implements Mila::Dnn::ComponentConfig.
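The specific checks are not listed here, so the following is only a plausible sketch of constraints implied by the derived quantities (positive sizes, num_heads dividing model_dim, num_kv_heads dividing num_heads). The function names are hypothetical, and Mila's actual validate() may check more or different conditions.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>

using dim_t = std::int64_t; // assumption

// Hypothetical constraints implied by head_dim and group_size being exact.
inline void validateGqaSketch(dim_t model_dim, dim_t num_heads, dim_t num_kv_heads) {
    if (model_dim <= 0 || num_heads <= 0 || num_kv_heads <= 0)
        throw std::invalid_argument("model_dim, num_heads and num_kv_heads must be positive");
    if (model_dim % num_heads != 0)
        throw std::invalid_argument("model_dim must be divisible by num_heads");
    // For positive values, divisibility also implies num_kv_heads <= num_heads.
    if (num_heads % num_kv_heads != 0)
        throw std::invalid_argument("num_heads must be divisible by num_kv_heads");
}

// Convenience wrapper: true if validateGqaSketch accepts the parameters.
inline bool isValidGqaSketch(dim_t model_dim, dim_t num_heads, dim_t num_kv_heads) {
    try {
        validateGqaSketch(model_dim, num_heads, num_kv_heads);
        return true;
    } catch (const std::invalid_argument&) {
        return false;
    }
}

// Usage
inline void demoValidation() {
    assert(isValidGqaSketch(4096, 32, 8));   // group size 4
    assert(!isValidGqaSketch(4096, 32, 12)); // 32 % 12 != 0
}
```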
template<typename Self> decltype(auto) withModelDim (this Self &&self, dim_t model_dim) [inline]
Fluent setter for model dimension.
template<typename Self> decltype(auto) withNumHeads (this Self &&self, dim_t num_heads) [inline]
Fluent setter for number of Q heads.
template<typename Self> decltype(auto) withNumKvHeads (this Self &&self, dim_t num_kv_heads) [inline]
Fluent setter for number of K/V heads.
Member Data Documentation
dim_t model_dim_ [private]
dim_t num_heads_ [private]
dim_t num_kv_heads_ [private]