Mila 0.13.48
Deep Neural Network Library
Mila::Dnn::RmsNormConfig Class Reference [export]

Public Member Functions

 RmsNormConfig (int64_t axis)
 Construct in axis mode.
 RmsNormConfig (shape_t normalized_shape)
 Construct in shape mode.
void fromMetadata (const SerializationMetadata &meta) override
 Populate configuration from provided metadata.
std::optional< int64_t > getAxis () const noexcept
float getEpsilon () const noexcept
const shape_t & getNormalizedShape () const noexcept
bool hasBias () const noexcept
bool hasNormalizedShape () const noexcept
SerializationMetadata toMetadata () const override
 Convert configuration into a SerializationMetadata object.
std::string toString () const override
 Produce a short, human-readable summary of the configuration.
void validate () const override
 Validate configuration parameters.
template<typename Self>
decltype(auto) withBias (this Self &&self, bool has_bias)
 Enable or disable learnable bias.
template<typename Self>
decltype(auto) withEpsilon (this Self &&self, float epsilon)
 Set epsilon for numerical stability.
Public Member Functions inherited from Mila::Dnn::ComponentConfig
virtual ~ComponentConfig ()=default
 Virtual destructor for polymorphic base.

Private Attributes

std::optional< dim_t > axis_ { std::nullopt }
float epsilon_ { 1e-5f }
bool has_bias_ { true }
shape_t normalized_shape_ {}

Constructor & Destructor Documentation

◆ RmsNormConfig() [1/2]

Mila::Dnn::RmsNormConfig::RmsNormConfig ( shape_t normalized_shape)
inline, explicit

Construct in shape mode.

Normalizes over the trailing dimensions described by normalized_shape.

Parameters
normalized_shape  Trailing dimensions to normalize over (e.g. shape_t{ model_dim }).
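As a sketch of what shape mode implies, assuming the standard RMSNorm definition y = x / sqrt(mean(x^2) + epsilon) * w (+ b) used by Llama-style models, the trailing dimensions named by normalized_shape are flattened into one row and every row is normalized independently. The free function rms_norm_trailing, its weight/bias arguments, and the shape_t alias below are hypothetical stand-ins, not Mila's API.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical stand-in for Mila's shape_t; the real type may differ.
using shape_t = std::vector<std::int64_t>;

// Sketch of RMSNorm in "shape mode" (assumed semantics): the input is viewed as
// [outer, inner], where inner = product of normalized_shape (the trailing dims).
// Each row is scaled by 1 / sqrt(mean(x^2) + epsilon), multiplied element-wise by
// a learnable weight and, when a bias is supplied, offset by that bias.
// Assumes x.size() is a multiple of inner.
std::vector<float> rms_norm_trailing(const std::vector<float>& x,
                                     const shape_t& normalized_shape,
                                     const std::vector<float>& weight,
                                     const std::vector<float>* bias,  // nullptr == no bias
                                     float epsilon = 1e-5f) {
    const auto inner = static_cast<std::size_t>(std::accumulate(
        normalized_shape.begin(), normalized_shape.end(), std::int64_t{1},
        std::multiplies<std::int64_t>{}));
    const std::size_t outer = x.size() / inner;

    std::vector<float> y(x.size());
    for (std::size_t r = 0; r < outer; ++r) {
        const float* row = x.data() + r * inner;
        float sum_sq = 0.0f;
        for (std::size_t i = 0; i < inner; ++i) sum_sq += row[i] * row[i];
        const float inv_rms = 1.0f / std::sqrt(sum_sq / static_cast<float>(inner) + epsilon);
        for (std::size_t i = 0; i < inner; ++i) {
            float v = row[i] * inv_rms * weight[i];
            if (bias) v += (*bias)[i];
            y[r * inner + i] = v;
        }
    }
    return y;
}
```

With normalized_shape = shape_t{ 2 }, the rows [3, 4] and [6, 8] normalize to the same values, because RMSNorm removes per-row scale but, unlike LayerNorm, does not subtract the mean.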

◆ RmsNormConfig() [2/2]

Mila::Dnn::RmsNormConfig::RmsNormConfig ( int64_t axis)
inline, explicit

Construct in axis mode.

Normalizes over a single axis.

Parameters
axis  Axis along which to normalize (negative indexing supported, e.g. -1 for the last dimension).
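Negative axes presumably follow the usual convention in which -1 names the last dimension. The hypothetical helper resolve_axis below sketches that mapping under this assumption; Mila's own resolution logic and the exception it raises are not documented on this page.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>

// Hypothetical helper (not part of Mila): maps a possibly negative axis onto
// [0, rank), assuming -1 means the last dimension, -2 the one before it, etc.
std::int64_t resolve_axis(std::int64_t axis, std::int64_t rank) {
    const std::int64_t resolved = axis < 0 ? axis + rank : axis;
    if (resolved < 0 || resolved >= rank) {
        throw std::invalid_argument("axis out of range for tensor rank");
    }
    return resolved;
}
```

For a rank-3 tensor of shape [batch, seq, model_dim], axis -1 and axis 2 both select model_dim.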

Member Function Documentation

◆ fromMetadata()

void Mila::Dnn::RmsNormConfig::fromMetadata ( const SerializationMetadata & meta)
inline, override, virtual

Populate configuration from provided metadata.

Implementations should read the keys that are present and leave any missing keys at their current (default) values, preserving forward and backward compatibility.

Parameters
meta  Metadata to read configuration values from.

Implements Mila::Dnn::ComponentConfig.
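The compatibility rule above can be sketched with a plain string map standing in for SerializationMetadata. The Metadata alias, the RmsNormSettings struct, and the key names "epsilon" and "has_bias" are all hypothetical; only the pattern (read keys that are present, keep current values otherwise, ignore unknown keys) comes from this page.

```cpp
#include <cassert>
#include <map>
#include <string>

// Hypothetical stand-in for SerializationMetadata: a flat string map.
// Key names ("epsilon", "has_bias") are illustrative only.
using Metadata = std::map<std::string, std::string>;

struct RmsNormSettings {
    float epsilon = 1e-5f;  // mirrors the documented default of epsilon_
    bool has_bias = true;   // mirrors the documented default of has_bias_

    // Read only the keys that are present; anything missing keeps its current
    // value, and unknown keys are ignored, so older or newer metadata still loads.
    void fromMetadata(const Metadata& meta) {
        if (auto it = meta.find("epsilon"); it != meta.end()) {
            epsilon = std::stof(it->second);
        }
        if (auto it = meta.find("has_bias"); it != meta.end()) {
            has_bias = (it->second == "true");
        }
    }
};
```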


◆ getAxis()

std::optional< int64_t > Mila::Dnn::RmsNormConfig::getAxis ( ) const
inline, noexcept

Returns the normalization axis, or std::nullopt if none was set (the default of axis_).

◆ getEpsilon()

float Mila::Dnn::RmsNormConfig::getEpsilon ( ) const
inline, noexcept

Returns the epsilon used for numerical stability (default 1e-5f).

◆ getNormalizedShape()

const shape_t & Mila::Dnn::RmsNormConfig::getNormalizedShape ( ) const
inline, noexcept

Returns the trailing dimensions to normalize over (empty by default).

◆ hasBias()

bool Mila::Dnn::RmsNormConfig::hasBias ( ) const
inline, noexcept

Returns whether the learnable bias is enabled (default true).

◆ hasNormalizedShape()

bool Mila::Dnn::RmsNormConfig::hasNormalizedShape ( ) const
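inline, noexcept

Returns whether a normalized shape has been configured, i.e. whether the config is in shape mode.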

◆ toMetadata()

SerializationMetadata Mila::Dnn::RmsNormConfig::toMetadata ( ) const
inline, override, virtual

Convert configuration into a SerializationMetadata object.

Implementations should include any fields required to fully reconstruct the configuration via fromMetadata.

Returns
SerializationMetadata Metadata representation of the config.

Implements Mila::Dnn::ComponentConfig.
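One practical detail when writing floats into text-based metadata is precision: std::to_string uses fixed notation with six decimal places, so an epsilon such as 1e-7f would be written as "0.000000" and lost. The sketch below, again using a hypothetical string-map Metadata and illustrative key names, writes epsilon with max_digits10 significant digits so that a reader can recover the exact float.

```cpp
#include <cassert>
#include <iomanip>
#include <limits>
#include <map>
#include <sstream>
#include <string>

// Hypothetical stand-in for SerializationMetadata; key names are illustrative.
using Metadata = std::map<std::string, std::string>;

struct RmsNormSettings {
    float epsilon = 1e-5f;
    bool has_bias = true;

    // Emit every field needed to rebuild the object from the metadata alone.
    Metadata toMetadata() const {
        std::ostringstream eps;
        // max_digits10 significant digits guarantee an exact float text round trip.
        eps << std::setprecision(std::numeric_limits<float>::max_digits10) << epsilon;
        Metadata meta;
        meta["epsilon"] = eps.str();
        meta["has_bias"] = has_bias ? "true" : "false";
        return meta;
    }
};
```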


◆ toString()

std::string Mila::Dnn::RmsNormConfig::toString ( ) const
inline, override, virtual

Produce a short, human-readable summary of the configuration.

Implementations should return a compact, single-line description suitable for logging and debugging.

Returns
std::string Human-readable summary of the configuration.

Implements Mila::Dnn::ComponentConfig.
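A compact, single-line summary might be assembled as in the sketch below. The summarize function and its output format are hypothetical; the exact string returned by RmsNormConfig::toString() is not documented on this page.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical summary format; the real toString() output may differ.
// Produces one line, e.g. "RmsNormConfig(axis=-1, epsilon=1e-05, bias=false)".
std::string summarize(std::optional<std::int64_t> axis,
                      const std::vector<std::int64_t>& normalized_shape,
                      float epsilon, bool has_bias) {
    std::ostringstream out;
    out << "RmsNormConfig(";
    if (axis) {
        out << "axis=" << *axis;  // axis mode
    } else {
        out << "normalized_shape=[";  // shape mode
        for (std::size_t i = 0; i < normalized_shape.size(); ++i) {
            out << (i ? ", " : "") << normalized_shape[i];
        }
        out << "]";
    }
    out << ", epsilon=" << epsilon << ", bias=" << (has_bias ? "true" : "false") << ")";
    return out.str();
}
```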

◆ validate()

void Mila::Dnn::RmsNormConfig::validate ( ) const
inline, override, virtual

Validate configuration parameters.

Call this to ensure the configuration describes a valid, constructible component. Implementations must throw std::invalid_argument (or a derived exception) when validation fails.

Exceptions
std::invalid_argument  If the configuration is invalid.

Implements Mila::Dnn::ComponentConfig.
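The specific checks RmsNormConfig::validate() performs are not listed here. The sketch below shows plausible rules for this config (a positive, finite epsilon; exactly one of axis mode or shape mode; positive shape dimensions). The function validate_rms_norm and all three rules are assumptions; only the use of std::invalid_argument comes from the contract above.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <optional>
#include <stdexcept>
#include <vector>

// Hypothetical validation rules (not Mila's documented checks). Failures are
// reported with std::invalid_argument, as the ComponentConfig contract requires.
void validate_rms_norm(std::optional<std::int64_t> axis,
                       const std::vector<std::int64_t>& normalized_shape,
                       float epsilon) {
    // !(epsilon > 0) also rejects NaN; isfinite rejects infinity.
    if (!(epsilon > 0.0f) || !std::isfinite(epsilon)) {
        throw std::invalid_argument("RmsNormConfig: epsilon must be positive and finite");
    }
    // Exactly one mode must be active: axis mode or shape mode.
    if (axis.has_value() == !normalized_shape.empty()) {
        throw std::invalid_argument("RmsNormConfig: specify either an axis or a normalized_shape");
    }
    for (std::int64_t dim : normalized_shape) {
        if (dim <= 0) {
            throw std::invalid_argument("RmsNormConfig: normalized_shape dims must be positive");
        }
    }
}
```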

◆ withBias()

template<typename Self>
decltype(auto) Mila::Dnn::RmsNormConfig::withBias ( this Self && self,
bool has_bias )
inline

Enable or disable learnable bias.

Default: true. Llama 3 uses false.


◆ withEpsilon()

template<typename Self>
decltype(auto) Mila::Dnn::RmsNormConfig::withEpsilon ( this Self && self,
float epsilon )
inline

Set epsilon for numerical stability.

Default: 1e-5f. Llama 3 uses 1e-5f; some models use 1e-6f.
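Assuming epsilon is added to the mean square before the square root (the placement used in Llama's reference RMSNorm), it caps the scale factor applied to an all-zero or near-zero row at 1/sqrt(epsilon): roughly 316 for 1e-5f and 1000 for 1e-6f. The helper inv_rms below is hypothetical and only illustrates that bound.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical helper: inverse RMS of one row, 1 / sqrt(mean(x^2) + epsilon).
// With epsilon == 0 an all-zero row would divide by zero (inf on IEEE-754
// platforms), and 0 * inf would then produce NaN in the normalized output.
float inv_rms(const std::vector<float>& row, float epsilon) {
    float sum_sq = 0.0f;
    for (float v : row) sum_sq += v * v;
    return 1.0f / std::sqrt(sum_sq / static_cast<float>(row.size()) + epsilon);
}
```

For rows whose mean square is large compared with epsilon, the choice between 1e-5f and 1e-6f has negligible effect; it matters mainly for rows with near-zero activations.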

Member Data Documentation

◆ axis_

std::optional<dim_t> Mila::Dnn::RmsNormConfig::axis_ { std::nullopt }
private

◆ epsilon_

float Mila::Dnn::RmsNormConfig::epsilon_ { 1e-5f }
private

◆ has_bias_

bool Mila::Dnn::RmsNormConfig::has_bias_ { true }
private

◆ normalized_shape_

shape_t Mila::Dnn::RmsNormConfig::normalized_shape_ {}
private
