Mila 0.13.48
Deep Neural Network Library
Mila::Dnn::GptConfig Class Reference [export]

Network-level configuration for GPT-style transformer networks.

Inherits from Mila::Dnn::ComponentConfig.

Public Member Functions

 GptConfig (dim_t embedding_size, dim_t num_layers)
void fromMetadata (const SerializationMetadata &meta)
 Populate configuration from provided metadata.
dim_t getEmbeddingSize () const noexcept
dim_t getHiddenSize () const noexcept
dim_t getMaxSequenceLength () const noexcept
dim_t getNumHeads () const noexcept
dim_t getNumLayers () const noexcept
bool getUseBias () const noexcept
dim_t getVocabSize () const noexcept
SerializationMetadata toMetadata () const
 Convert configuration into a SerializationMetadata object.
std::string toString () const override
 Produce a short, human-readable summary of the configuration.
void validate () const override
 Validate configuration parameters.
template<typename Self>
decltype(auto) withBias (this Self &&self, bool use_bias)
template<typename Self>
decltype(auto) withHiddenSize (this Self &&self, dim_t hidden_size)
template<typename Self>
decltype(auto) withMaxSequenceLength (this Self &&self, dim_t max_seq_len)
template<typename Self>
decltype(auto) withNumHeads (this Self &&self, dim_t num_heads)
template<typename Self>
decltype(auto) withNumLayers (this Self &&self, dim_t num_layers)
template<typename Self>
decltype(auto) withVocabSize (this Self &&self, dim_t vocab_size)
Public Member Functions inherited from Mila::Dnn::ComponentConfig
virtual ~ComponentConfig ()=default
 Virtual destructor for polymorphic base.

Private Attributes

dim_t embedding_size_ = 768
dim_t hidden_size_ = 768
dim_t max_seq_len_ = 1024
dim_t num_heads_ = 12
dim_t num_layers_ = 12
bool use_bias_ = true
dim_t vocab_size_ = 50257

Detailed Description

Network-level configuration for GPT-style transformer networks.

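Holds only the network-level settings a GPT network needs: embedding size, hidden size, number of layers, number of attention heads, vocabulary size, maximum sequence length, and a bias flag. The defaults (768-dimensional embeddings, 12 layers, 12 heads, a 50,257-token vocabulary, and a 1,024-token maximum sequence length) match the GPT-2 small configuration.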

Constructor & Destructor Documentation

◆ GptConfig()

Mila::Dnn::GptConfig::GptConfig ( dim_t embedding_size,
dim_t num_layers )
inline
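The constructor takes only the embedding size and the layer count. A minimal sketch of the likely result, assuming every other field keeps the in-class default listed under Private Attributes; `GptConfigSketch` and the `dim_t` alias are hypothetical stand-ins, not Mila's own definitions:

```cpp
#include <cstdint>

// Hypothetical stand-in for Mila's dim_t; the real alias may differ.
using dim_t = std::int64_t;

// Hypothetical mirror of GptConfig's fields and documented defaults.
class GptConfigSketch {
public:
    // Only the two required values are set here; all other fields keep their
    // in-class defaults (assumption: the real constructor behaves the same way).
    GptConfigSketch(dim_t embedding_size, dim_t num_layers)
        : embedding_size_(embedding_size), num_layers_(num_layers) {}

    dim_t getEmbeddingSize() const noexcept { return embedding_size_; }
    dim_t getHiddenSize() const noexcept { return hidden_size_; }
    dim_t getMaxSequenceLength() const noexcept { return max_seq_len_; }
    dim_t getNumHeads() const noexcept { return num_heads_; }
    dim_t getNumLayers() const noexcept { return num_layers_; }
    bool getUseBias() const noexcept { return use_bias_; }
    dim_t getVocabSize() const noexcept { return vocab_size_; }

private:
    dim_t embedding_size_ = 768;
    dim_t hidden_size_ = 768;
    dim_t max_seq_len_ = 1024;
    dim_t num_heads_ = 12;
    dim_t num_layers_ = 12;
    bool use_bias_ = true;
    dim_t vocab_size_ = 50257;
};
```

Under that assumption, `GptConfigSketch cfg(1024, 24)` describes 24 layers of 1024-wide embeddings while still using 12 heads and a 1,024-token context, so any other field that should differ from the defaults has to be set separately.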

Member Function Documentation

◆ fromMetadata()

void Mila::Dnn::GptConfig::fromMetadata ( const SerializationMetadata & meta)
inline virtual

Populate configuration from provided metadata.

Implementations should read the keys that are present and leave any missing keys at their current (default) values, which preserves forward and backward compatibility.

Parameters
meta: Metadata to read configuration values from.

Implements Mila::Dnn::ComponentConfig.
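The contract above (read what is present, keep everything else) can be sketched as follows. `MetadataSketch` is a hypothetical flat key-to-integer map standing in for SerializationMetadata, whose real interface is not shown on this page, and the key names are illustrative only:

```cpp
#include <cstdint>
#include <map>
#include <string>

// Hypothetical stand-ins: Mila's dim_t and SerializationMetadata may differ.
using dim_t = std::int64_t;
using MetadataSketch = std::map<std::string, std::int64_t>;

struct GptFieldsSketch {
    dim_t embedding_size = 768;
    dim_t hidden_size = 768;
    dim_t max_seq_len = 1024;
    dim_t num_heads = 12;
    dim_t num_layers = 12;
    dim_t vocab_size = 50257;
    bool use_bias = true;

    void fromMetadata(const MetadataSketch& meta) {
        // Overwrite a field only when its key is present; otherwise the
        // current value (initially the default) is left untouched.
        auto readIfPresent = [&meta](const std::string& key, dim_t& field) {
            if (auto it = meta.find(key); it != meta.end()) {
                field = it->second;
            }
        };
        readIfPresent("embedding_size", embedding_size);
        readIfPresent("hidden_size", hidden_size);
        readIfPresent("max_seq_len", max_seq_len);
        readIfPresent("num_heads", num_heads);
        readIfPresent("num_layers", num_layers);
        readIfPresent("vocab_size", vocab_size);
        if (auto it = meta.find("use_bias"); it != meta.end()) {
            use_bias = it->second != 0;  // assumption: flag stored as 0/1
        }
        // Keys this version does not recognise are simply ignored.
    }
};
```

Ignoring unknown keys is the forward-compatibility half of the contract; keeping current values for missing keys is the backward-compatibility half.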

◆ getEmbeddingSize()

dim_t Mila::Dnn::GptConfig::getEmbeddingSize ( ) const
inline noexcept

◆ getHiddenSize()

dim_t Mila::Dnn::GptConfig::getHiddenSize ( ) const
inline noexcept

◆ getMaxSequenceLength()

dim_t Mila::Dnn::GptConfig::getMaxSequenceLength ( ) const
inline noexcept

◆ getNumHeads()

dim_t Mila::Dnn::GptConfig::getNumHeads ( ) const
inline noexcept

◆ getNumLayers()

dim_t Mila::Dnn::GptConfig::getNumLayers ( ) const
inline noexcept

◆ getUseBias()

bool Mila::Dnn::GptConfig::getUseBias ( ) const
inline noexcept

◆ getVocabSize()

dim_t Mila::Dnn::GptConfig::getVocabSize ( ) const
inline noexcept

◆ toMetadata()

SerializationMetadata Mila::Dnn::GptConfig::toMetadata ( ) const
inline virtual

Convert configuration into a SerializationMetadata object.

Implementations should include any fields required to fully reconstruct the configuration via fromMetadata.

Returns
SerializationMetadata Metadata representation of the config.

Implements Mila::Dnn::ComponentConfig.
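A round-trip sketch of this contract: every field is written, so feeding the result back through a fromMetadata-style reader reproduces the configuration. `GptRoundTripSketch`, `MetadataSketch` (a plain key-to-integer map standing in for SerializationMetadata), and the key names are hypothetical, and encoding the bias flag as 0/1 is an assumption of the sketch:

```cpp
#include <cstdint>
#include <map>
#include <string>

// Hypothetical stand-ins: Mila's dim_t and SerializationMetadata may differ.
using dim_t = std::int64_t;
using MetadataSketch = std::map<std::string, std::int64_t>;

struct GptRoundTripSketch {
    dim_t embedding_size = 768, hidden_size = 768, max_seq_len = 1024;
    dim_t num_heads = 12, num_layers = 12, vocab_size = 50257;
    bool use_bias = true;

    // Write every field so the configuration can be rebuilt in full.
    MetadataSketch toMetadata() const {
        return {
            {"embedding_size", embedding_size}, {"hidden_size", hidden_size},
            {"max_seq_len", max_seq_len},       {"num_heads", num_heads},
            {"num_layers", num_layers},         {"vocab_size", vocab_size},
            {"use_bias", use_bias ? 1 : 0},
        };
    }

    // Read back whichever keys are present; missing keys keep current values.
    void fromMetadata(const MetadataSketch& meta) {
        auto read = [&meta](const char* key, dim_t& field) {
            if (auto it = meta.find(key); it != meta.end()) field = it->second;
        };
        read("embedding_size", embedding_size);
        read("hidden_size", hidden_size);
        read("max_seq_len", max_seq_len);
        read("num_heads", num_heads);
        read("num_layers", num_layers);
        read("vocab_size", vocab_size);
        if (auto it = meta.find("use_bias"); it != meta.end()) use_bias = it->second != 0;
    }
};
```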

◆ toString()

std::string Mila::Dnn::GptConfig::toString ( ) const
inline override virtual

Produce a short, human-readable summary of the configuration.

Implementations should return a compact, single-line description suitable for logging and debugging.

Returns
std::string Human-readable summary of the configuration.

Implements Mila::Dnn::ComponentConfig.
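One possible shape for the single-line summary, as a sketch only; the `GptConfig(key=value, ...)` format and the `GptSummarySketch` type are assumptions, not Mila's actual output:

```cpp
#include <cstdint>
#include <sstream>
#include <string>

using dim_t = std::int64_t;  // hypothetical stand-in for Mila's dim_t

struct GptSummarySketch {
    dim_t embedding_size = 768, hidden_size = 768, max_seq_len = 1024;
    dim_t num_heads = 12, num_layers = 12, vocab_size = 50257;
    bool use_bias = true;

    // One line with no trailing newline, so it drops straight into a log message.
    std::string toString() const {
        std::ostringstream os;
        os << "GptConfig(embedding_size=" << embedding_size
           << ", hidden_size=" << hidden_size
           << ", num_layers=" << num_layers
           << ", num_heads=" << num_heads
           << ", vocab_size=" << vocab_size
           << ", max_seq_len=" << max_seq_len
           << ", use_bias=" << std::boolalpha << use_bias << ')';
        return os.str();
    }
};
```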

◆ validate()

void Mila::Dnn::GptConfig::validate ( ) const
inline override virtual

Validate configuration parameters.

Callers invoke this to confirm that the configuration describes a valid, constructible component. Implementations must throw std::invalid_argument (or a derived exception) when validation fails.

Exceptions
std::invalid_argument: Thrown if the configuration is invalid.

Implements Mila::Dnn::ComponentConfig.
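A sketch of the kind of checks validate() might perform. The specific rules (every size positive, and the embedding size divisible by the number of heads so each attention head gets an equal slice) are assumptions typical of multi-head attention, not rules confirmed by this page; `GptValidationSketch` is hypothetical:

```cpp
#include <cstdint>
#include <stdexcept>
#include <string>

using dim_t = std::int64_t;  // hypothetical stand-in for Mila's dim_t

struct GptValidationSketch {
    dim_t embedding_size = 768, hidden_size = 768, max_seq_len = 1024;
    dim_t num_heads = 12, num_layers = 12, vocab_size = 50257;

    void validate() const {
        auto requirePositive = [](dim_t value, const char* name) {
            if (value <= 0) {
                throw std::invalid_argument(std::string(name) + " must be positive");
            }
        };
        requirePositive(embedding_size, "embedding_size");
        requirePositive(hidden_size, "hidden_size");
        requirePositive(max_seq_len, "max_seq_len");
        requirePositive(num_heads, "num_heads");  // checked before the modulo below
        requirePositive(num_layers, "num_layers");
        requirePositive(vocab_size, "vocab_size");

        // Assumed constraint: heads split the embedding evenly
        // (with the defaults, 768 / 12 = 64 dimensions per head).
        if (embedding_size % num_heads != 0) {
            throw std::invalid_argument("embedding_size must be divisible by num_heads");
        }
    }
};
```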

◆ withBias()

template<typename Self>
decltype(auto) Mila::Dnn::GptConfig::withBias ( this Self && self,
bool use_bias )
inline

◆ withHiddenSize()

template<typename Self>
decltype(auto) Mila::Dnn::GptConfig::withHiddenSize ( this Self && self,
dim_t hidden_size )
inline

◆ withMaxSequenceLength()

template<typename Self>
decltype(auto) Mila::Dnn::GptConfig::withMaxSequenceLength ( this Self && self,
dim_t max_seq_len )
inline

◆ withNumHeads()

template<typename Self>
decltype(auto) Mila::Dnn::GptConfig::withNumHeads ( this Self && self,
dim_t num_heads )
inline

◆ withNumLayers()

template<typename Self>
decltype(auto) Mila::Dnn::GptConfig::withNumLayers ( this Self && self,
dim_t num_layers )
inline

◆ withVocabSize()

template<typename Self>
decltype(auto) Mila::Dnn::GptConfig::withVocabSize ( this Self && self,
dim_t vocab_size )
inline
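The with* setters take an explicit object parameter (`this Self&& self`, the C++23 "deducing this" feature), so one template serves both named configurations and temporaries. A minimal sketch of that pattern, assuming the setters return `std::forward<Self>(self)`; `FluentConfigSketch` is hypothetical, and the `#else` branch shows the pair of ref-qualified overloads per setter that compilers without deducing this would need instead:

```cpp
#include <cstdint>
#include <utility>

using dim_t = std::int64_t;  // hypothetical stand-in for Mila's dim_t

class FluentConfigSketch {
public:
#if defined(__cpp_explicit_this_parameter)
    // Self deduces to FluentConfigSketch& for an lvalue object and to
    // FluentConfigSketch for a temporary; decltype(auto) then returns a
    // reference of the matching value category, so chaining never copies.
    template <typename Self>
    decltype(auto) withNumHeads(this Self&& self, dim_t num_heads) {
        self.num_heads_ = num_heads;
        return std::forward<Self>(self);
    }
    template <typename Self>
    decltype(auto) withBias(this Self&& self, bool use_bias) {
        self.use_bias_ = use_bias;
        return std::forward<Self>(self);
    }
#else
    // Pre-C++23 equivalent: two ref-qualified overloads per setter.
    FluentConfigSketch& withNumHeads(dim_t n) & { num_heads_ = n; return *this; }
    FluentConfigSketch&& withNumHeads(dim_t n) && { num_heads_ = n; return std::move(*this); }
    FluentConfigSketch& withBias(bool b) & { use_bias_ = b; return *this; }
    FluentConfigSketch&& withBias(bool b) && { use_bias_ = b; return std::move(*this); }
#endif

    dim_t getNumHeads() const noexcept { return num_heads_; }
    bool getUseBias() const noexcept { return use_bias_; }

private:
    dim_t num_heads_ = 12;
    bool use_bias_ = true;
};
```

Because a chain on a temporary returns an rvalue reference into that temporary, bind the result by value (for example `FluentConfigSketch built = FluentConfigSketch{}.withNumHeads(8);`); binding it to `auto&&` would dangle once the full expression ends.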

Member Data Documentation

◆ embedding_size_

dim_t Mila::Dnn::GptConfig::embedding_size_ = 768
private

◆ hidden_size_

dim_t Mila::Dnn::GptConfig::hidden_size_ = 768
private

◆ max_seq_len_

dim_t Mila::Dnn::GptConfig::max_seq_len_ = 1024
private

◆ num_heads_

dim_t Mila::Dnn::GptConfig::num_heads_ = 12
private

◆ num_layers_

dim_t Mila::Dnn::GptConfig::num_layers_ = 12
private

◆ use_bias_

bool Mila::Dnn::GptConfig::use_bias_ = true
private

◆ vocab_size_

dim_t Mila::Dnn::GptConfig::vocab_size_ = 50257
private

The documentation for this class was generated from the following file:
  • /__w/Mila/Mila/Mila/Src/Dnn/Components/Transformers/Gpt/Gpt.Config.ixx