Mila 0.13.48
Deep Neural Network Library
Mila::Dnn::LanguageModel< TDeviceType, TPrecision > Class Template Reference (abstract, export)

Public Types

using Base = Model<TDeviceType, TPrecision>
Public Types inherited from Mila::Dnn::Model< TDeviceType, TPrecision >
using NetworkType = Network<TDeviceType, TPrecision>

Public Member Functions

 LanguageModel (const LanguageModel &)=delete
 LanguageModel (LanguageModel &&)=default
virtual ~LanguageModel ()=default
std::vector< int32_t > generate (const std::vector< int32_t > &prompt_tokens, size_t max_new_tokens=64, float temperature=1.0f, int top_k=0)
 Blocking generation.
void generateStreaming (const std::vector< int32_t > &prompt_tokens, std::function< void(int32_t)> on_token, size_t max_new_tokens=64, float temperature=1.0f, int top_k=0, std::stop_token stop={})
 Synchronous per-token streaming.
const GenerationStatistics & getLastGenerationStatistics () const noexcept
 Returns statistics from the most recent generateStreaming() call.
LanguageModel & operator= (const LanguageModel &)=delete
LanguageModel & operator= (LanguageModel &&)=default
Public Member Functions inherited from Mila::Dnn::Model< TDeviceType, TPrecision >
 Model (const Model &)=delete
 Model (Model &&)=default
virtual ~Model ()=default
DeviceId getDeviceId () const noexcept
 The device this model runs on.
MemoryStats getMemoryStats () const
 Current memory allocation breakdown for this model.
RuntimeMode getRuntimeMode () const noexcept
 The runtime mode this model was constructed for.
bool isEval () const noexcept
 True if this model is currently in eval sub-state.
bool isInferenceMode () const noexcept
 True if this model was constructed for inference.
bool isTrainingMode () const noexcept
 True if this model was constructed for training.
Model & operator= (const Model &)=delete
Model & operator= (Model &&)=default
void setEval (bool eval)
 Toggle eval sub-state for this model.
virtual std::string toString () const =0
 Human-readable summary of this model's configuration.
void train ()
 Run the training loop for this model.

Protected Member Functions

 LanguageModel (std::unique_ptr< LanguageNetwork< TDeviceType, TPrecision > > network, RuntimeMode runtime_mode)
virtual int32_t eosToken () const noexcept=0
const LanguageNetwork< TDeviceType, TPrecision > & getLanguageNetwork () const noexcept
LanguageNetwork< TDeviceType, TPrecision > & getLanguageNetwork () noexcept
virtual int64_t maxSequenceLength () const noexcept=0
virtual void onGenerating (const std::vector< int32_t > &prompt_tokens, const std::function< void(int32_t)> &on_token, size_t max_new_tokens, float temperature, int top_k, std::stop_token stop)=0
 Prefill + decode implementation hook.
virtual std::unordered_set< int32_t > stopTokens () const
virtual int64_t vocabSize () const noexcept=0
Protected Member Functions inherited from Mila::Dnn::Model< TDeviceType, TPrecision >
 Model (std::unique_ptr< NetworkType > network, RuntimeMode runtime_mode)
 Construct with a fully built network and runtime mode.
virtual void onTraining ()=0
 Training loop hook — derived class owns the implementation.

Protected Attributes

GenerationStatistics last_generation_statistics_ {}
 Statistics populated by onGenerating() for each completed generation run.
Protected Attributes inherited from Mila::Dnn::Model< TDeviceType, TPrecision >
std::unique_ptr< NetworkType > network_
 The owned Network instance.

Member Typedef Documentation

◆ Base

template<DeviceType TDeviceType, TensorDataType TPrecision>
using Mila::Dnn::LanguageModel< TDeviceType, TPrecision >::Base = Model<TDeviceType, TPrecision>

Constructor & Destructor Documentation

◆ LanguageModel() [1/3]

template<DeviceType TDeviceType, TensorDataType TPrecision>
Mila::Dnn::LanguageModel< TDeviceType, TPrecision >::LanguageModel ( const LanguageModel< TDeviceType, TPrecision > & )
delete

◆ LanguageModel() [2/3]

template<DeviceType TDeviceType, TensorDataType TPrecision>
Mila::Dnn::LanguageModel< TDeviceType, TPrecision >::LanguageModel ( LanguageModel< TDeviceType, TPrecision > && )
default

◆ ~LanguageModel()

template<DeviceType TDeviceType, TensorDataType TPrecision>
virtual Mila::Dnn::LanguageModel< TDeviceType, TPrecision >::~LanguageModel ( )
virtual default

◆ LanguageModel() [3/3]

template<DeviceType TDeviceType, TensorDataType TPrecision>
Mila::Dnn::LanguageModel< TDeviceType, TPrecision >::LanguageModel ( std::unique_ptr< LanguageNetwork< TDeviceType, TPrecision > > network,
RuntimeMode runtime_mode )
inline explicit protected

Member Function Documentation

◆ eosToken()

template<DeviceType TDeviceType, TensorDataType TPrecision>
virtual int32_t Mila::Dnn::LanguageModel< TDeviceType, TPrecision >::eosToken ( ) const
protected pure virtual noexcept

Implemented in Mila::Dnn::GptModel< TDeviceType, TPrecision >, and Mila::Dnn::LlamaModel< TDeviceType, TPrecision >.


◆ generate()

template<DeviceType TDeviceType, TensorDataType TPrecision>
std::vector< int32_t > Mila::Dnn::LanguageModel< TDeviceType, TPrecision >::generate ( const std::vector< int32_t > & prompt_tokens,
size_t max_new_tokens = 64,
float temperature = 1.0f,
int top_k = 0 )
inline

Blocking generation.

Returns the prompt tokens followed by all generated tokens (EOS excluded).

Parameters
prompt_tokens    Input token ids.
max_new_tokens   Maximum tokens to generate beyond the prompt.
temperature      Sampling temperature; <= 0 selects argmax.
top_k            Top-k filter; 0 disables.
Returns
Full token sequence including the prompt.
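
The temperature and top_k parameters follow the usual sampling conventions. The sketch below illustrates those semantics on a plain vector of logits; it is not Mila's sampler, and the helper name sample_token and its signature are assumptions made for this example only.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <random>
#include <vector>

// Hypothetical helper, not part of Mila: picks one token id from raw logits.
// temperature <= 0 selects the argmax; top_k <= 0 disables top-k filtering.
inline int32_t sample_token(const std::vector<float>& logits,
                            float temperature, int top_k,
                            std::mt19937& rng)
{
    assert(!logits.empty() && "logits must contain at least one entry");

    // Candidate token ids 0 .. vocab-1.
    std::vector<int32_t> ids(logits.size());
    std::iota(ids.begin(), ids.end(), 0);

    // Greedy decoding: return the id with the highest logit.
    if (temperature <= 0.0f) {
        return *std::max_element(ids.begin(), ids.end(),
            [&](int32_t a, int32_t b) { return logits[a] < logits[b]; });
    }

    // Keep only the top_k highest-scoring candidates when filtering is enabled.
    if (top_k > 0 && static_cast<std::size_t>(top_k) < ids.size()) {
        std::partial_sort(ids.begin(), ids.begin() + top_k, ids.end(),
            [&](int32_t a, int32_t b) { return logits[a] > logits[b]; });
        ids.resize(static_cast<std::size_t>(top_k));
    }

    // Temperature-scaled softmax weights over the remaining candidates.
    // Subtracting the maximum logit keeps exp() from overflowing.
    float max_logit = logits[ids[0]];
    for (int32_t id : ids) max_logit = std::max(max_logit, logits[id]);

    std::vector<double> weights;
    weights.reserve(ids.size());
    for (int32_t id : ids) {
        weights.push_back(std::exp((logits[id] - max_logit) / temperature));
    }

    // discrete_distribution normalises the weights internally.
    std::discrete_distribution<std::size_t> pick(weights.begin(), weights.end());
    return ids[pick(rng)];
}
```

With temperature <= 0 the result is deterministic, which is convenient for regression tests; with top_k = 1 the sampled result is also always the highest-scoring token.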

◆ generateStreaming()

template<DeviceType TDeviceType, TensorDataType TPrecision>
void Mila::Dnn::LanguageModel< TDeviceType, TPrecision >::generateStreaming ( const std::vector< int32_t > & prompt_tokens,
std::function< void(int32_t)> on_token,
size_t max_new_tokens = 64,
float temperature = 1.0f,
int top_k = 0,
std::stop_token stop = {} )
inline

Synchronous per-token streaming.

Blocks on the caller's thread until generation completes or stop is requested.

on_token is invoked on the caller's thread for every generated token (EOS excluded). Callers that manage their own threading, such as the Python ModelWorker's single-thread executor, should call this directly.

Parameters
prompt_tokens    Input token ids.
on_token         Per-token callback invoked on the caller's thread.
max_new_tokens   Maximum tokens to generate beyond the prompt.
temperature      Sampling temperature; <= 0 selects argmax.
top_k            Top-k filter; 0 disables.
stop             Stop token for cooperative cancellation.

◆ getLanguageNetwork() [1/2]

template<DeviceType TDeviceType, TensorDataType TPrecision>
const LanguageNetwork< TDeviceType, TPrecision > & Mila::Dnn::LanguageModel< TDeviceType, TPrecision >::getLanguageNetwork ( ) const
inline protected noexcept

◆ getLanguageNetwork() [2/2]

template<DeviceType TDeviceType, TensorDataType TPrecision>
LanguageNetwork< TDeviceType, TPrecision > & Mila::Dnn::LanguageModel< TDeviceType, TPrecision >::getLanguageNetwork ( )
inline protected noexcept

◆ getLastGenerationStatistics()

template<DeviceType TDeviceType, TensorDataType TPrecision>
const GenerationStatistics & Mila::Dnn::LanguageModel< TDeviceType, TPrecision >::getLastGenerationStatistics ( ) const
inline nodiscard noexcept

Returns statistics from the most recent generateStreaming() call.

Only valid after at least one generateStreaming() call has returned. Check GenerationStatistics::valid() before using the values.

Returns
Reference to the last captured generation statistics.

◆ maxSequenceLength()

template<DeviceType TDeviceType, TensorDataType TPrecision>
virtual int64_t Mila::Dnn::LanguageModel< TDeviceType, TPrecision >::maxSequenceLength ( ) const
protected pure virtual noexcept

◆ onGenerating()

template<DeviceType TDeviceType, TensorDataType TPrecision>
virtual void Mila::Dnn::LanguageModel< TDeviceType, TPrecision >::onGenerating ( const std::vector< int32_t > & prompt_tokens,
const std::function< void(int32_t)> & on_token,
size_t max_new_tokens,
float temperature,
int top_k,
std::stop_token stop )
protected pure virtual

Prefill + decode implementation hook.

Derived classes own the full autoregressive generation loop. on_token must be called for every generated token except EOS. stop.stop_requested() must be checked on each decode step and generation must abort early when signalled.

Parameters
prompt_tokens    Input token ids.
on_token         Per-token callback.
max_new_tokens   Maximum tokens to generate beyond the prompt.
temperature      Sampling temperature; <= 0 selects argmax.
top_k            Top-k filter; 0 disables.
stop             Stop token for cooperative cancellation.

Implemented in Mila::Dnn::GptModel< TDeviceType, TPrecision >, and Mila::Dnn::LlamaModel< TDeviceType, TPrecision >.

◆ operator=() [1/2]

template<DeviceType TDeviceType, TensorDataType TPrecision>
LanguageModel & Mila::Dnn::LanguageModel< TDeviceType, TPrecision >::operator= ( const LanguageModel< TDeviceType, TPrecision > & )
delete

◆ operator=() [2/2]

template<DeviceType TDeviceType, TensorDataType TPrecision>
LanguageModel & Mila::Dnn::LanguageModel< TDeviceType, TPrecision >::operator= ( LanguageModel< TDeviceType, TPrecision > && )
default

◆ stopTokens()

template<DeviceType TDeviceType, TensorDataType TPrecision>
virtual std::unordered_set< int32_t > Mila::Dnn::LanguageModel< TDeviceType, TPrecision >::stopTokens ( ) const
inline protected virtual

Reimplemented in Mila::Dnn::LlamaModel< TDeviceType, TPrecision >.


◆ vocabSize()

template<DeviceType TDeviceType, TensorDataType TPrecision>
virtual int64_t Mila::Dnn::LanguageModel< TDeviceType, TPrecision >::vocabSize ( ) const
protected pure virtual noexcept

Implemented in Mila::Dnn::GptModel< TDeviceType, TPrecision >, and Mila::Dnn::LlamaModel< TDeviceType, TPrecision >.


Member Data Documentation

◆ last_generation_statistics_

template<DeviceType TDeviceType, TensorDataType TPrecision>
GenerationStatistics Mila::Dnn::LanguageModel< TDeviceType, TPrecision >::last_generation_statistics_ {}
protected

Statistics populated by onGenerating() for each completed generation run.


The documentation for this class was generated from the following file: