Mila 0.13.48
Deep Neural Network Library
Loading...
Searching...
No Matches
Mila::Dnn::GptModel< TDeviceType, TPrecision > Class Template Reference [export]

GPT inference model. More...

Inheritance diagram for Mila::Dnn::GptModel< TDeviceType, TPrecision >:
Collaboration diagram for Mila::Dnn::GptModel< TDeviceType, TPrecision >:

Public Types

using GptTransformerType = GptTransformer<TDeviceType, TPrecision>
using ModelBase = LanguageModel<TDeviceType, TPrecision>
using MR = typename DeviceTypeTraits<TDeviceType>::memory_resource
using TensorType = Tensor<TPrecision, MR>
using TokenIndexType = Tensor<dtype_t::INT32, MR>
Public Types inherited from Mila::Dnn::LanguageModel< TDeviceType, TPrecision >
using Base = Model<TDeviceType, TPrecision>
Public Types inherited from Mila::Dnn::Model< TDeviceType, TPrecision >
using NetworkType = Network<TDeviceType, TPrecision>

Public Member Functions

 GptModel (const GptModel &)=delete
 GptModel (GptModel &&)=default
 ~GptModel ()=default
const GptConfig & getConfig () const noexcept
GptModel & operator= (const GptModel &)=delete
GptModel & operator= (GptModel &&)=default
std::string toString () const override
 Human-readable summary of this model's configuration.
Public Member Functions inherited from Mila::Dnn::LanguageModel< TDeviceType, TPrecision >
 LanguageModel (const LanguageModel &)=delete
 LanguageModel (LanguageModel &&)=default
virtual ~LanguageModel ()=default
std::vector< int32_t > generate (const std::vector< int32_t > &prompt_tokens, size_t max_new_tokens=64, float temperature=1.0f, int top_k=0)
 Blocking generation.
void generateStreaming (const std::vector< int32_t > &prompt_tokens, std::function< void(int32_t)> on_token, size_t max_new_tokens=64, float temperature=1.0f, int top_k=0, std::stop_token stop={})
 Synchronous per-token streaming.
const GenerationStatistics & getLastGenerationStatistics () const noexcept
 Returns statistics from the most recent generateStreaming() call.
LanguageModel & operator= (const LanguageModel &)=delete
LanguageModel & operator= (LanguageModel &&)=default
Public Member Functions inherited from Mila::Dnn::Model< TDeviceType, TPrecision >
 Model (const Model &)=delete
 Model (Model &&)=default
virtual ~Model ()=default
DeviceId getDeviceId () const noexcept
 The device this model runs on.
MemoryStats getMemoryStats () const
 Current memory allocation breakdown for this model.
RuntimeMode getRuntimeMode () const noexcept
 The runtime mode this model was constructed for.
bool isEval () const noexcept
 True if this model is currently in eval sub-state.
bool isInferenceMode () const noexcept
 True if this model was constructed for inference.
bool isTrainingMode () const noexcept
 True if this model was constructed for training.
Model & operator= (const Model &)=delete
Model & operator= (Model &&)=default
void setEval (bool eval)
 Toggle eval sub-state for this model.
void train ()
 Run the training loop for this model.

Static Public Member Functions

static std::unique_ptr< GptModel > fromCheckpoint (const std::filesystem::path &path, DeviceId device_id=DeviceId{ TDeviceType, 0 })
 Load from a Mila-native serialized artifact.
static std::unique_ptr< GptModel > fromPretrained (const std::filesystem::path &path, dim_t context_length, DeviceId device_id=DeviceId{ TDeviceType, 0 }, bool strict=true)
 Load from third-party pretrained weights.

Protected Member Functions

int32_t eosToken () const noexcept override
 GPT-2 end-of-text token id.
int64_t maxSequenceLength () const noexcept override
 Maximum sequence length from GPT config.
void onGenerating (const std::vector< int32_t > &prompt_tokens, const std::function< void(int32_t)> &on_token, size_t max_new_tokens, float temperature, int top_k, std::stop_token stop) override
 Prefill + KV-cache decode loop with per-token streaming.
void onTraining () override
 Training loop — not yet implemented for GptModel.
int64_t vocabSize () const noexcept override
 Vocabulary size from GPT config.
Protected Member Functions inherited from Mila::Dnn::LanguageModel< TDeviceType, TPrecision >
 LanguageModel (std::unique_ptr< LanguageNetwork< TDeviceType, TPrecision > > network, RuntimeMode runtime_mode)
const LanguageNetwork< TDeviceType, TPrecision > & getLanguageNetwork () const noexcept
LanguageNetwork< TDeviceType, TPrecision > & getLanguageNetwork () noexcept
virtual std::unordered_set< int32_t > stopTokens () const
Protected Member Functions inherited from Mila::Dnn::Model< TDeviceType, TPrecision >
 Model (std::unique_ptr< NetworkType > network, RuntimeMode runtime_mode)
 Construct with a fully built network and runtime mode.

Private Member Functions

 GptModel (std::unique_ptr< GptTransformerType > network, const GptConfig &config)
 GptModel (std::unique_ptr< GptTransformerType > network, const GptConfig &config, RuntimeMode runtime_mode)
TokenIndexType makeTokenTensor (const std::vector< int32_t > &token_ids) const
int32_t sampleFromLogits (const TensorType &logits, int64_t position, float temperature, int top_k, std::mt19937 &rng) const
void truncateIfNeeded (std::vector< int32_t > &tokens) const

Static Private Member Functions

static GptConfig configFromMetadata (const PretrainedMetadata &metadata)
static int32_t sampleToken (const float *logits, size_t vocab_size, float temperature, int top_k, std::mt19937 &rng)

Private Attributes

GptConfig config_

Static Private Attributes

static constexpr int32_t eos_token_ = 50256

Additional Inherited Members

Protected Attributes inherited from Mila::Dnn::LanguageModel< TDeviceType, TPrecision >
GenerationStatistics last_generation_statistics_ {}
 Statistics populated by onGenerating() for each completed generation run.
Protected Attributes inherited from Mila::Dnn::Model< TDeviceType, TPrecision >
std::unique_ptr< NetworkType > network_
 The owned Network instance.

Detailed Description

template<DeviceType TDeviceType, TensorDataType TPrecision>
requires PrecisionSupportedOnDevice<TPrecision, TDeviceType>
class Mila::Dnn::GptModel< TDeviceType, TPrecision >

GPT inference model.

Owns a loaded, built GptTransformer and exposes generateStreaming() for autoregressive text generation.

Construction is only possible via fromPretrained() or fromCheckpoint(). The network is always in a built, weights-loaded, inference-mode state when generation is called.

Member Typedef Documentation

◆ GptTransformerType

template<DeviceType TDeviceType, TensorDataType TPrecision>
using Mila::Dnn::GptModel< TDeviceType, TPrecision >::GptTransformerType = GptTransformer<TDeviceType, TPrecision>

◆ ModelBase

template<DeviceType TDeviceType, TensorDataType TPrecision>
using Mila::Dnn::GptModel< TDeviceType, TPrecision >::ModelBase = LanguageModel<TDeviceType, TPrecision>

◆ MR

template<DeviceType TDeviceType, TensorDataType TPrecision>
using Mila::Dnn::GptModel< TDeviceType, TPrecision >::MR = typename DeviceTypeTraits<TDeviceType>::memory_resource

◆ TensorType

template<DeviceType TDeviceType, TensorDataType TPrecision>
using Mila::Dnn::GptModel< TDeviceType, TPrecision >::TensorType = Tensor<TPrecision, MR>

◆ TokenIndexType

template<DeviceType TDeviceType, TensorDataType TPrecision>
using Mila::Dnn::GptModel< TDeviceType, TPrecision >::TokenIndexType = Tensor<dtype_t::INT32, MR>

Constructor & Destructor Documentation

◆ GptModel() [1/4]

template<DeviceType TDeviceType, TensorDataType TPrecision>
Mila::Dnn::GptModel< TDeviceType, TPrecision >::GptModel ( const GptModel< TDeviceType, TPrecision > & )
delete
Here is the call graph for this function:
Here is the caller graph for this function:

◆ GptModel() [2/4]

template<DeviceType TDeviceType, TensorDataType TPrecision>
Mila::Dnn::GptModel< TDeviceType, TPrecision >::GptModel ( GptModel< TDeviceType, TPrecision > && )
default
Here is the call graph for this function:

◆ ~GptModel()

template<DeviceType TDeviceType, TensorDataType TPrecision>
Mila::Dnn::GptModel< TDeviceType, TPrecision >::~GptModel ( )
default

◆ GptModel() [3/4]

template<DeviceType TDeviceType, TensorDataType TPrecision>
Mila::Dnn::GptModel< TDeviceType, TPrecision >::GptModel ( std::unique_ptr< GptTransformerType > network,
const GptConfig & config,
RuntimeMode runtime_mode )
[inline] [explicit] [private]

◆ GptModel() [4/4]

template<DeviceType TDeviceType, TensorDataType TPrecision>
Mila::Dnn::GptModel< TDeviceType, TPrecision >::GptModel ( std::unique_ptr< GptTransformerType > network,
const GptConfig & config )
[inline] [explicit] [private]

Member Function Documentation

◆ configFromMetadata()

template<DeviceType TDeviceType, TensorDataType TPrecision>
GptConfig Mila::Dnn::GptModel< TDeviceType, TPrecision >::configFromMetadata ( const PretrainedMetadata & metadata)
[inline] [static] [private]
Here is the call graph for this function:

◆ eosToken()

template<DeviceType TDeviceType, TensorDataType TPrecision>
int32_t Mila::Dnn::GptModel< TDeviceType, TPrecision >::eosToken ( ) const
[inline] [override] [protected] [virtual] [noexcept]

GPT-2 end-of-text token id.

Implements Mila::Dnn::LanguageModel< TDeviceType, TPrecision >.

◆ fromCheckpoint()

template<DeviceType TDeviceType, TensorDataType TPrecision>
std::unique_ptr< GptModel > Mila::Dnn::GptModel< TDeviceType, TPrecision >::fromCheckpoint ( const std::filesystem::path & path,
DeviceId device_id = DeviceId{ TDeviceType, 0 } )
[inline] [static]

Load from a Mila-native serialized artifact.

Reads a checkpoint or weights-only artifact produced by GptTransformer::save() via ModelArchive.

Parameters
path: Path to the Mila archive.
device_id: Target device.
Returns
Inference-ready GptModel.

◆ fromPretrained()

template<DeviceType TDeviceType, TensorDataType TPrecision>
std::unique_ptr< GptModel > Mila::Dnn::GptModel< TDeviceType, TPrecision >::fromPretrained ( const std::filesystem::path & path,
dim_t context_length,
DeviceId device_id = DeviceId{ TDeviceType, 0 },
bool strict = true )
[inline] [static]

Load from third-party pretrained weights.

Reads weights from a Mila-compatible pretrained artifact produced by converting third-party checkpoints (e.g. HuggingFace GPT-2) via PretrainedModelReader.

Parameters
path: Path to the pretrained artifact.
context_length: Maximum sequence length to build for.
device_id: Target device.
strict: If true, throws on unknown parameter names.
Returns
Inference-ready GptModel.

◆ getConfig()

template<DeviceType TDeviceType, TensorDataType TPrecision>
const GptConfig & Mila::Dnn::GptModel< TDeviceType, TPrecision >::getConfig ( ) const
[inline] [noexcept]

◆ makeTokenTensor()

template<DeviceType TDeviceType, TensorDataType TPrecision>
TokenIndexType Mila::Dnn::GptModel< TDeviceType, TPrecision >::makeTokenTensor ( const std::vector< int32_t > & token_ids) const
[inline] [private]
Here is the call graph for this function:
Here is the caller graph for this function:

◆ maxSequenceLength()

template<DeviceType TDeviceType, TensorDataType TPrecision>
int64_t Mila::Dnn::GptModel< TDeviceType, TPrecision >::maxSequenceLength ( ) const
[inline] [override] [protected] [virtual] [noexcept]

Maximum sequence length from GPT config.

Implements Mila::Dnn::LanguageModel< TDeviceType, TPrecision >.

◆ onGenerating()

template<DeviceType TDeviceType, TensorDataType TPrecision>
void Mila::Dnn::GptModel< TDeviceType, TPrecision >::onGenerating ( const std::vector< int32_t > & prompt_tokens,
const std::function< void(int32_t)> & on_token,
size_t max_new_tokens,
float temperature,
int top_k,
std::stop_token stop )
[inline] [override] [protected] [virtual]

Prefill + KV-cache decode loop with per-token streaming.

Implements Mila::Dnn::LanguageModel< TDeviceType, TPrecision >.

Here is the call graph for this function:

◆ onTraining()

template<DeviceType TDeviceType, TensorDataType TPrecision>
void Mila::Dnn::GptModel< TDeviceType, TPrecision >::onTraining ( )
[inline] [override] [protected] [virtual]

Training loop — not yet implemented for GptModel.

Exceptions
std::runtime_error: Always thrown.

Implements Mila::Dnn::Model< TDeviceType, TPrecision >.

◆ operator=() [1/2]

template<DeviceType TDeviceType, TensorDataType TPrecision>
GptModel & Mila::Dnn::GptModel< TDeviceType, TPrecision >::operator= ( const GptModel< TDeviceType, TPrecision > & )
delete
Here is the call graph for this function:

◆ operator=() [2/2]

template<DeviceType TDeviceType, TensorDataType TPrecision>
GptModel & Mila::Dnn::GptModel< TDeviceType, TPrecision >::operator= ( GptModel< TDeviceType, TPrecision > && )
default
Here is the call graph for this function:

◆ sampleFromLogits()

template<DeviceType TDeviceType, TensorDataType TPrecision>
int32_t Mila::Dnn::GptModel< TDeviceType, TPrecision >::sampleFromLogits ( const TensorType & logits,
int64_t position,
float temperature,
int top_k,
std::mt19937 & rng ) const
[inline] [private]
Here is the call graph for this function:
Here is the caller graph for this function:

◆ sampleToken()

template<DeviceType TDeviceType, TensorDataType TPrecision>
int32_t Mila::Dnn::GptModel< TDeviceType, TPrecision >::sampleToken ( const float * logits,
size_t vocab_size,
float temperature,
int top_k,
std::mt19937 & rng )
[inline] [static] [private]
Here is the call graph for this function:
Here is the caller graph for this function:

◆ toString()

template<DeviceType TDeviceType, TensorDataType TPrecision>
std::string Mila::Dnn::GptModel< TDeviceType, TPrecision >::toString ( ) const
[inline] [override] [virtual]

Human-readable summary of this model's configuration.

Implements Mila::Dnn::Model< TDeviceType, TPrecision >.

◆ truncateIfNeeded()

template<DeviceType TDeviceType, TensorDataType TPrecision>
void Mila::Dnn::GptModel< TDeviceType, TPrecision >::truncateIfNeeded ( std::vector< int32_t > & tokens) const
[inline] [private]
Here is the call graph for this function:
Here is the caller graph for this function:

◆ vocabSize()

template<DeviceType TDeviceType, TensorDataType TPrecision>
int64_t Mila::Dnn::GptModel< TDeviceType, TPrecision >::vocabSize ( ) const
[inline] [override] [protected] [virtual] [noexcept]

Vocabulary size from GPT config.

Implements Mila::Dnn::LanguageModel< TDeviceType, TPrecision >.

Member Data Documentation

◆ config_

template<DeviceType TDeviceType, TensorDataType TPrecision>
GptConfig Mila::Dnn::GptModel< TDeviceType, TPrecision >::config_
private

◆ eos_token_

template<DeviceType TDeviceType, TensorDataType TPrecision>
int32_t Mila::Dnn::GptModel< TDeviceType, TPrecision >::eos_token_ = 50256
[static] [constexpr] [private]

The documentation for this class was generated from the following file: