Mila 0.13.48
Deep Neural Network Library
Mila::Dnn::RopeConfig Class Reference [export]

Public Member Functions

 RopeConfig (size_t channels, size_t n_heads, size_t n_kv_heads, size_t max_seq_len)
 Construct with all structurally required parameters.
void fromMetadata (const SerializationMetadata &meta) override
 Populate configuration from provided metadata.
float getBase () const noexcept
size_t getEmbeddingDim () const noexcept
size_t getHeadDim () const noexcept
 Per-head dimension, derived as channels / n_heads.
size_t getMaxSequenceLength () const noexcept
 Returns the training maximum sequence length.
size_t getNumHeads () const noexcept
size_t getNumKVHeads () const noexcept
size_t getRotaryDim () const noexcept
SerializationMetadata toMetadata () const override
 Convert configuration into a SerializationMetadata object.
std::string toString () const override
 Produce a short, human-readable summary of the configuration.
void validate () const override
 Validate configuration.
template<typename Self>
decltype(auto) withBase (this Self &&self, float base)
 Set frequency base for rotary angle computation.
template<typename Self>
decltype(auto) withRotaryDim (this Self &&self, size_t rotary_dim)
 Set rotary sub-dimension per head (number of channels to rotate).
Public Member Functions inherited from Mila::Dnn::ComponentConfig
virtual ~ComponentConfig ()=default
 Virtual destructor for polymorphic base.

Private Attributes

float base_ { 10000.0f }
size_t channels_ { 0 }
size_t max_seq_len_ { 0 }
size_t n_heads_ { 0 }
size_t n_kv_heads_ { 0 }
size_t rotary_dim_ { 0 }
 0 = use full head_dim

Constructor & Destructor Documentation

◆ RopeConfig()

Mila::Dnn::RopeConfig::RopeConfig ( size_t channels,
size_t n_heads,
size_t n_kv_heads,
size_t max_seq_len )
inline

Construct with all structurally required parameters.

Parameters
channels: Total Q embedding width (n_heads * head_dim).
n_heads: Number of query heads.
n_kv_heads: Number of key/value heads (GQA: <= n_heads).
max_seq_len: Maximum sequence length for cos/sin cache precomputation.
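A minimal sketch of how the four constructor arguments relate to one another. RopeShape is a hypothetical helper, not part of Mila, and it assumes that key/value heads share head_dim with the query heads, as in standard grouped-query attention (GQA).

```cpp
#include <cstddef>

// Hypothetical helper, not part of Mila: derived quantities implied by the
// RopeConfig constructor arguments.
struct RopeShape {
    std::size_t channels;     // total Q width = n_heads * head_dim
    std::size_t n_heads;      // number of query heads
    std::size_t n_kv_heads;   // number of key/value heads (GQA: <= n_heads)
    std::size_t max_seq_len;  // positions covered by the cos/sin cache

    // Per-head width; assumes channels is divisible by n_heads.
    std::size_t headDim() const { return channels / n_heads; }

    // Width of the K (or V) projection, assuming KV heads use the same head_dim.
    std::size_t kvChannels() const { return n_kv_heads * headDim(); }

    // Query heads sharing each KV head; assumes n_heads is a multiple of n_kv_heads.
    std::size_t queriesPerKvHead() const { return n_heads / n_kv_heads; }
};
```

With illustrative sizes, RopeShape{4096, 32, 8, 8192} gives headDim() == 128, kvChannels() == 1024 and queriesPerKvHead() == 4.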

Member Function Documentation

◆ fromMetadata()

void Mila::Dnn::RopeConfig::fromMetadata ( const SerializationMetadata & meta)
inline override virtual

Populate configuration from provided metadata.

Implementations should read available keys and leave missing keys at their current/default values to preserve forward/backward compatibility.

Parameters
meta: Metadata to read configuration values from.

Implements Mila::Dnn::ComponentConfig.
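A minimal sketch of the "read what is present, keep the rest" rule described above. Metadata (a flat string map standing in for SerializationMetadata), RopeFields, and the key names are all hypothetical; Mila's actual metadata API and schema may differ.

```cpp
#include <cstddef>
#include <map>
#include <string>

// Hypothetical stand-in for SerializationMetadata: a flat key -> string map.
// The key names below are assumptions, not Mila's actual schema.
using Metadata = std::map<std::string, std::string>;

// Hypothetical mirror of RopeConfig's private fields and their defaults.
struct RopeFields {
    std::size_t channels = 0;
    std::size_t n_heads = 0;
    std::size_t n_kv_heads = 0;
    std::size_t max_seq_len = 0;
    std::size_t rotary_dim = 0;   // 0 = use full head_dim
    float base = 10000.0f;
};

// Overwrite only the fields whose keys are present; a missing key leaves
// the field at its current value, so older or newer files still load.
void fromMetadata(RopeFields& f, const Metadata& meta) {
    auto readSize = [&meta](const std::string& key, std::size_t& out) {
        if (auto it = meta.find(key); it != meta.end())
            out = static_cast<std::size_t>(std::stoull(it->second));
    };
    readSize("channels", f.channels);
    readSize("n_heads", f.n_heads);
    readSize("n_kv_heads", f.n_kv_heads);
    readSize("max_seq_len", f.max_seq_len);
    readSize("rotary_dim", f.rotary_dim);
    if (auto it = meta.find("base"); it != meta.end())
        f.base = std::stof(it->second);
}
```

Reading metadata that contains only channels, n_heads and base leaves n_kv_heads, max_seq_len and rotary_dim at whatever values they held before the call.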


◆ getBase()

float Mila::Dnn::RopeConfig::getBase ( ) const
inline noexcept

◆ getEmbeddingDim()

size_t Mila::Dnn::RopeConfig::getEmbeddingDim ( ) const
inline noexcept

◆ getHeadDim()

size_t Mila::Dnn::RopeConfig::getHeadDim ( ) const
inline noexcept

Per-head dimension, derived as channels / n_heads.

Valid only after validate() has confirmed consistency.
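Why the result is only meaningful after validation: integer division truncates silently. The free function headDim below is a hypothetical illustration of the documented formula, not Mila's code.

```cpp
#include <cstddef>

// channels / n_heads, as documented for getHeadDim(). Integer division
// truncates, so the result is only meaningful once divisibility is checked.
constexpr std::size_t headDim(std::size_t channels, std::size_t n_heads) {
    return channels / n_heads;
}

static_assert(headDim(768, 12) == 64);  // 12 * 64 == 768: consistent
static_assert(headDim(770, 12) == 64);  // 770 % 12 != 0: truncated, 2 channels unaccounted for
```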


◆ getMaxSequenceLength()

size_t Mila::Dnn::RopeConfig::getMaxSequenceLength ( ) const
inline noexcept

Returns the training maximum sequence length.

Returns
The maximum sequence length.
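The constructor documents max_seq_len as the length used for cos/sin cache precomputation. A minimal sketch of such a cache, assuming the standard RoPE frequencies base^(-2i / rotary_dim) and a row-major [position][pair] layout; RopeCache and buildRopeCache are hypothetical names, and rotary_dim here means the effective value (head_dim when the configured rotary_dim is 0).

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical cos/sin cache, row-major [max_seq_len][rotary_dim / 2]:
// one entry per (position, rotated channel pair).
struct RopeCache {
    std::size_t half = 0;           // rotary_dim / 2 frequency pairs
    std::vector<float> cos_table;
    std::vector<float> sin_table;
};

RopeCache buildRopeCache(std::size_t max_seq_len, std::size_t rotary_dim, float base) {
    RopeCache c;
    c.half = rotary_dim / 2;
    c.cos_table.resize(max_seq_len * c.half);
    c.sin_table.resize(max_seq_len * c.half);
    for (std::size_t i = 0; i < c.half; ++i) {
        // Frequency of pair i: base^(-2i / rotary_dim).
        const double inv_freq = std::pow(static_cast<double>(base),
            -2.0 * static_cast<double>(i) / static_cast<double>(rotary_dim));
        for (std::size_t pos = 0; pos < max_seq_len; ++pos) {
            const double angle = static_cast<double>(pos) * inv_freq;
            c.cos_table[pos * c.half + i] = static_cast<float>(std::cos(angle));
            c.sin_table[pos * c.half + i] = static_cast<float>(std::sin(angle));
        }
    }
    return c;
}
```

A cache built this way covers positions 0 through max_seq_len - 1 only. With max_seq_len = 8192 and rotary_dim = 128 it holds 2 × 8192 × 64 floats, about 4 MiB.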

◆ getNumHeads()

size_t Mila::Dnn::RopeConfig::getNumHeads ( ) const
inline noexcept

◆ getNumKVHeads()

size_t Mila::Dnn::RopeConfig::getNumKVHeads ( ) const
inline noexcept

◆ getRotaryDim()

size_t Mila::Dnn::RopeConfig::getRotaryDim ( ) const
inline noexcept

◆ toMetadata()

SerializationMetadata Mila::Dnn::RopeConfig::toMetadata ( ) const
inline override virtual

Convert configuration into a SerializationMetadata object.

Implementations should include any fields required to fully reconstruct the configuration via fromMetadata.

Returns
SerializationMetadata Metadata representation of the config.

Implements Mila::Dnn::ComponentConfig.
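A minimal sketch of a toMetadata that writes every field, including those still at their defaults. The flat string map, RopeFields and the key names are hypothetical stand-ins, not Mila's API. Writing defaults explicitly is one way to keep a saved config stable even if a later version changes a default.

```cpp
#include <cstddef>
#include <map>
#include <string>

// Hypothetical stand-in for SerializationMetadata; key names are assumptions.
using Metadata = std::map<std::string, std::string>;

// Hypothetical mirror of RopeConfig's private fields and their defaults.
struct RopeFields {
    std::size_t channels = 0;
    std::size_t n_heads = 0;
    std::size_t n_kv_heads = 0;
    std::size_t max_seq_len = 0;
    std::size_t rotary_dim = 0;   // 0 = use full head_dim
    float base = 10000.0f;
};

// Emit all six fields so a reader can reconstruct the config exactly,
// without relying on its own defaults for anything.
Metadata toMetadata(const RopeFields& f) {
    return Metadata{
        {"channels", std::to_string(f.channels)},
        {"n_heads", std::to_string(f.n_heads)},
        {"n_kv_heads", std::to_string(f.n_kv_heads)},
        {"max_seq_len", std::to_string(f.max_seq_len)},
        {"rotary_dim", std::to_string(f.rotary_dim)},
        {"base", std::to_string(f.base)},
    };
}
```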


◆ toString()

std::string Mila::Dnn::RopeConfig::toString ( ) const
inline override virtual

Produce a short, human-readable summary of the configuration.

Implementations should return a compact, single-line description suitable for logging and debugging.

Returns
std::string Human-readable summary of the configuration.

Implements Mila::Dnn::ComponentConfig.
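A minimal sketch of a compact, single-line summary in the spirit described above. RopeFields is a hypothetical mirror of the private fields, and the key=value output format is an assumption; Mila's actual toString() output is not documented here.

```cpp
#include <cstddef>
#include <sstream>
#include <string>

// Hypothetical mirror of RopeConfig's private fields and their defaults.
struct RopeFields {
    std::size_t channels = 0;
    std::size_t n_heads = 0;
    std::size_t n_kv_heads = 0;
    std::size_t max_seq_len = 0;
    std::size_t rotary_dim = 0;   // 0 = use full head_dim
    float base = 10000.0f;
};

// One line of key=value pairs, suitable for a log record.
std::string toString(const RopeFields& f) {
    std::ostringstream os;
    os << "RopeConfig(channels=" << f.channels
       << ", n_heads=" << f.n_heads
       << ", n_kv_heads=" << f.n_kv_heads
       << ", max_seq_len=" << f.max_seq_len
       << ", rotary_dim=" << f.rotary_dim
       << ", base=" << f.base << ")";
    return os.str();
}
```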


◆ validate()

void Mila::Dnn::RopeConfig::validate ( ) const
inline override virtual

Validate configuration.

Enforces: required fields are positive, channels is divisible by n_heads, head_dim is even (RoPE requires paired dimensions), n_kv_heads <= n_heads, and rotary_dim (if set) does not exceed head_dim.

Exceptions
std::invalid_argument: thrown if any constraint is violated.

Implements Mila::Dnn::ComponentConfig.
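A minimal sketch of the documented checks as a free function. validateRope is a hypothetical name and the error messages are illustrative; only the constraints listed above are enforced, with rotary_dim == 0 passing because it selects the full head_dim.

```cpp
#include <cstddef>
#include <stdexcept>

// Hypothetical mirror of RopeConfig::validate(); messages are illustrative.
void validateRope(std::size_t channels, std::size_t n_heads, std::size_t n_kv_heads,
                  std::size_t max_seq_len, std::size_t rotary_dim) {
    if (channels == 0 || n_heads == 0 || n_kv_heads == 0 || max_seq_len == 0)
        throw std::invalid_argument("RopeConfig: required fields must be positive");
    if (channels % n_heads != 0)
        throw std::invalid_argument("RopeConfig: channels must be divisible by n_heads");
    const std::size_t head_dim = channels / n_heads;
    if (head_dim % 2 != 0)  // RoPE rotates channels in pairs
        throw std::invalid_argument("RopeConfig: head_dim must be even");
    if (n_kv_heads > n_heads)
        throw std::invalid_argument("RopeConfig: n_kv_heads must not exceed n_heads");
    if (rotary_dim > head_dim)  // 0 (full head_dim) always passes
        throw std::invalid_argument("RopeConfig: rotary_dim must not exceed head_dim");
}
```

For example, channels = 96 with n_heads = 32 divides evenly but gives head_dim = 3, which fails the even-dimension check.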

◆ withBase()

template<typename Self>
decltype(auto) Mila::Dnn::RopeConfig::withBase ( this Self && self,
float base )
inline

Set frequency base for rotary angle computation.

Standard RoPE uses a base of 10000.0f, which is also the default here; Llama 3 uses 500000.0f.
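A minimal sketch of what the base controls, assuming the standard RoPE frequencies theta_i = base^(-2i / rotary_dim). ropeFrequency and ropeWavelength are hypothetical helpers, not Mila functions.

```cpp
#include <cmath>
#include <cstddef>

// Angular frequency (radians per position) of rotary pair i:
// theta_i = base^(-2i / rotary_dim).
double ropeFrequency(float base, std::size_t i, std::size_t rotary_dim) {
    return std::pow(static_cast<double>(base),
                    -2.0 * static_cast<double>(i) / static_cast<double>(rotary_dim));
}

// Positions needed for pair i to complete one full turn: 2*pi / theta_i.
double ropeWavelength(float base, std::size_t i, std::size_t rotary_dim) {
    const double pi = 3.14159265358979323846;
    return 2.0 * pi / ropeFrequency(base, i, rotary_dim);
}
```

Pair 0 always turns at one radian per position regardless of the base; raising the base from 10000 to 500000 leaves it unchanged and lengthens the wavelengths of the higher-index pairs.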


◆ withRotaryDim()

template<typename Self>
decltype(auto) Mila::Dnn::RopeConfig::withRotaryDim ( this Self && self,
size_t rotary_dim )
inline

Set rotary sub-dimension per head (number of channels to rotate).

Default: 0, which rotates the full head_dim.
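A minimal sketch of partial rotation on a single head vector: only the first rotary_dim channels are rotated and the rest pass through unchanged. applyRope is hypothetical, and it assumes interleaved pairs (x[2i], x[2i+1]) with frequencies computed over rotary_dim; this page does not state which pairing convention Mila uses.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Rotate one head vector x (length head_dim) for position pos.
// rotary_dim == 0 means "rotate the whole head", matching the documented default.
// Assumes interleaved pairs (x[2i], x[2i+1]).
void applyRope(std::vector<float>& x, std::size_t pos, std::size_t rotary_dim, float base) {
    const std::size_t dim = (rotary_dim == 0) ? x.size() : rotary_dim;
    assert(dim <= x.size() && dim % 2 == 0);
    for (std::size_t i = 0; i < dim / 2; ++i) {
        const double theta = std::pow(static_cast<double>(base),
            -2.0 * static_cast<double>(i) / static_cast<double>(dim));
        const double angle = static_cast<double>(pos) * theta;
        const double c = std::cos(angle), s = std::sin(angle);
        const double a = x[2 * i], b = x[2 * i + 1];
        // 2-D rotation of the pair (a, b) by angle.
        x[2 * i]     = static_cast<float>(a * c - b * s);
        x[2 * i + 1] = static_cast<float>(a * s + b * c);
    }
    // Channels [dim, head_dim) are left untouched.
}
```

Position 0 leaves every channel unchanged, and because each pair is rotated rather than scaled, the norm of each rotated pair is preserved.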

Member Data Documentation

◆ base_

float Mila::Dnn::RopeConfig::base_ { 10000.0f }
private

◆ channels_

size_t Mila::Dnn::RopeConfig::channels_ { 0 }
private

◆ max_seq_len_

size_t Mila::Dnn::RopeConfig::max_seq_len_ { 0 }
private

◆ n_heads_

size_t Mila::Dnn::RopeConfig::n_heads_ { 0 }
private

◆ n_kv_heads_

size_t Mila::Dnn::RopeConfig::n_kv_heads_ { 0 }
private

◆ rotary_dim_

size_t Mila::Dnn::RopeConfig::rotary_dim_ { 0 }
private

0 = use full head_dim

