Mila 0.13.48
Deep Neural Network Library
Mila::Dnn::LlamaModel< TDeviceType, TPrecision > Class Template Reference
export

LLaMA 3 compatible inference model.


Public Types

using ModelBase = LanguageModel<TDeviceType, TPrecision>
using MR = typename DeviceTypeTraits<TDeviceType>::memory_resource
using StagingMR = CpuMemoryResource
using TensorType = Tensor<TPrecision, MR>
using TokenIndexType = Tensor<dtype_t::INT32, MR>
Public Types inherited from Mila::Dnn::LanguageModel< TDeviceType, TPrecision >
using Base = Model<TDeviceType, TPrecision>
Public Types inherited from Mila::Dnn::Model< TDeviceType, TPrecision >
using NetworkType = Network<TDeviceType, TPrecision>

Public Member Functions

 LlamaModel (const LlamaModel &)=delete
 LlamaModel (LlamaModel &&)=default
 ~LlamaModel ()=default
const LlamaConfig & getConfig () const noexcept
LlamaModel & operator= (const LlamaModel &)=delete
LlamaModel & operator= (LlamaModel &&)=default
void profilePrefill (const std::vector< int32_t > &token_ids)
std::string toString () const override
 Human-readable summary of this model's configuration.
Public Member Functions inherited from Mila::Dnn::LanguageModel< TDeviceType, TPrecision >
 LanguageModel (const LanguageModel &)=delete
 LanguageModel (LanguageModel &&)=default
virtual ~LanguageModel ()=default
std::vector< int32_t > generate (const std::vector< int32_t > &prompt_tokens, size_t max_new_tokens=64, float temperature=1.0f, int top_k=0)
 Blocking generation.
void generateStreaming (const std::vector< int32_t > &prompt_tokens, std::function< void(int32_t)> on_token, size_t max_new_tokens=64, float temperature=1.0f, int top_k=0, std::stop_token stop={})
 Synchronous per-token streaming.
const GenerationStatistics & getLastGenerationStatistics () const noexcept
 Returns statistics from the most recent generateStreaming() call.
LanguageModel & operator= (const LanguageModel &)=delete
LanguageModel & operator= (LanguageModel &&)=default
Public Member Functions inherited from Mila::Dnn::Model< TDeviceType, TPrecision >
 Model (const Model &)=delete
 Model (Model &&)=default
virtual ~Model ()=default
DeviceId getDeviceId () const noexcept
 The device this model runs on.
MemoryStats getMemoryStats () const
 Current memory allocation breakdown for this model.
RuntimeMode getRuntimeMode () const noexcept
 The runtime mode this model was constructed for.
bool isEval () const noexcept
 True if this model is currently in eval sub-state.
bool isInferenceMode () const noexcept
 True if this model was constructed for inference.
bool isTrainingMode () const noexcept
 True if this model was constructed for training.
Model & operator= (const Model &)=delete
Model & operator= (Model &&)=default
void setEval (bool eval)
 Toggle eval sub-state for this model.
void train ()
 Run the training loop for this model.

Static Public Member Functions

static std::unique_ptr< LlamaModel< TDeviceType, TPrecision > > fromPretrained (const std::filesystem::path &path, const LlamaModelConfig &model_config, DeviceId device_id=DeviceId{ TDeviceType, 0 })
 Load from third-party pretrained weights.

Protected Member Functions

int64_t maxSequenceLength () const noexcept override
 Maximum sequence length from LLaMA config.
void onGenerating (const std::vector< int32_t > &prompt_tokens, const std::function< void(int32_t)> &on_token, size_t max_new_tokens, float temperature, int top_k, std::stop_token stop) override
 Prefill + KV-cache decode loop with per-token streaming.
void onTraining () override
 Training loop — not yet implemented for LlamaModel.
int64_t vocabSize () const noexcept override
 Vocabulary size from LLaMA config.
Protected Member Functions inherited from Mila::Dnn::LanguageModel< TDeviceType, TPrecision >
 LanguageModel (std::unique_ptr< LanguageNetwork< TDeviceType, TPrecision > > network, RuntimeMode runtime_mode)
const LanguageNetwork< TDeviceType, TPrecision > & getLanguageNetwork () const noexcept
LanguageNetwork< TDeviceType, TPrecision > & getLanguageNetwork () noexcept
Protected Member Functions inherited from Mila::Dnn::Model< TDeviceType, TPrecision >
 Model (std::unique_ptr< NetworkType > network, RuntimeMode runtime_mode)
 Construct with a fully built network and runtime mode.

Private Member Functions

 LlamaModel (std::unique_ptr< LanguageNetwork< TDeviceType, TPrecision > > network, const LlamaConfig &config, RuntimeMode runtime_mode)
int32_t eosToken () const noexcept override
 LLaMA 3.x end-of-sequence token.
TokenIndexType makeTokenTensor (const std::vector< int32_t > &token_ids) const
int32_t sampleFromLogits (const TensorType &logits, int64_t position, float temperature, int top_k, std::mt19937 &rng)
std::unordered_set< int32_t > stopTokens () const override
 Llama 3.x generation stop tokens.
void truncateIfNeeded (std::vector< int32_t > &tokens) const

Static Private Member Functions

static LlamaConfig configFromMetadata (const PretrainedMetadata &metadata)
template<WeightQuantPolicy TWeightQuantization, KvCachePolicy TKvCachePolicy>
static std::unique_ptr< LlamaModel< TDeviceType, TPrecision > > fromPretrainedImpl (const std::filesystem::path &path, const LlamaModelConfig &model_config, DeviceId device_id)
static int32_t sampleToken (const float *logits, size_t vocab_size, float temperature, int top_k, std::mt19937 &rng)

Private Attributes

LlamaConfig config_
TokenIndexType decode_token_device_
Tensor< dtype_t::INT32, StagingMR > decode_token_staging_
Tensor< TensorDataType::FP32, StagingMR > logits_staging_

Additional Inherited Members

Protected Attributes inherited from Mila::Dnn::LanguageModel< TDeviceType, TPrecision >
GenerationStatistics last_generation_statistics_ {}
 Statistics populated by onGenerating() for each completed generation run.
Protected Attributes inherited from Mila::Dnn::Model< TDeviceType, TPrecision >
std::unique_ptr< NetworkType > network_
 The owned Network instance.

Detailed Description

template<DeviceType TDeviceType, TensorDataType TPrecision>
requires PrecisionSupportedOnDevice<TPrecision, TDeviceType>
class Mila::Dnn::LlamaModel< TDeviceType, TPrecision >

LLaMA 3 compatible inference model.

Owns a loaded, built LlamaTransformer and exposes generateStreaming() for autoregressive text generation. Supports the prefill + KV-cache decode two-phase generation loop.

Construction is only possible via fromPretrained(). The network is always in a built, weights-loaded, inference-mode state when generation is called.

Thread safety: not thread-safe; external synchronization required if shared.
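
The construction rule above (private constructor, public static factory, deleted copy, defaulted move) is a common move-only factory pattern. The following is a minimal sketch of that pattern only. `ToyModel` and its members are hypothetical stand-ins, not part of Mila, and the validation shown is an assumption for illustration.

```cpp
#include <cassert>
#include <memory>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>

// Hypothetical stand-in for illustration; not the Mila class.
class ToyModel {
public:
    // The static factory is the only public way to obtain an instance.
    static std::unique_ptr<ToyModel> fromPretrained(const std::string& path,
                                                    long context_length) {
        if (context_length <= 0) {
            throw std::invalid_argument("context_length must be positive");
        }
        // std::make_unique cannot reach a private constructor, so use new here.
        return std::unique_ptr<ToyModel>(new ToyModel(path, context_length));
    }

    ToyModel(const ToyModel&) = delete;             // copying is disabled
    ToyModel& operator=(const ToyModel&) = delete;
    ToyModel(ToyModel&&) = default;                 // moving is allowed
    ToyModel& operator=(ToyModel&&) = default;

    long contextLength() const noexcept { return context_length_; }

private:
    ToyModel(std::string path, long context_length)
        : path_(std::move(path)), context_length_(context_length) {}

    std::string path_;
    long context_length_;
};

static_assert(!std::is_copy_constructible<ToyModel>::value, "copy is deleted");
static_assert(std::is_move_constructible<ToyModel>::value, "move is allowed");
```

A caller would write `auto model = ToyModel::fromPretrained("weights.bin", 2048);` and keep the returned `std::unique_ptr` as the sole owner.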

Member Typedef Documentation

◆ ModelBase

template<DeviceType TDeviceType, TensorDataType TPrecision>
using Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::ModelBase = LanguageModel<TDeviceType, TPrecision>

◆ MR

template<DeviceType TDeviceType, TensorDataType TPrecision>
using Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::MR = typename DeviceTypeTraits<TDeviceType>::memory_resource

◆ StagingMR

template<DeviceType TDeviceType, TensorDataType TPrecision>
using Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::StagingMR = CpuMemoryResource

◆ TensorType

template<DeviceType TDeviceType, TensorDataType TPrecision>
using Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::TensorType = Tensor<TPrecision, MR>

◆ TokenIndexType

template<DeviceType TDeviceType, TensorDataType TPrecision>
using Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::TokenIndexType = Tensor<dtype_t::INT32, MR>

Constructor & Destructor Documentation

◆ LlamaModel() [1/3]

template<DeviceType TDeviceType, TensorDataType TPrecision>
Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::LlamaModel ( const LlamaModel< TDeviceType, TPrecision > & )
delete

◆ LlamaModel() [2/3]

template<DeviceType TDeviceType, TensorDataType TPrecision>
Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::LlamaModel ( LlamaModel< TDeviceType, TPrecision > && )
default

◆ ~LlamaModel()

template<DeviceType TDeviceType, TensorDataType TPrecision>
Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::~LlamaModel ( )
default

◆ LlamaModel() [3/3]

template<DeviceType TDeviceType, TensorDataType TPrecision>
Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::LlamaModel ( std::unique_ptr< LanguageNetwork< TDeviceType, TPrecision > > network,
const LlamaConfig & config,
RuntimeMode runtime_mode )
inline explicit private

Member Function Documentation

◆ configFromMetadata()

template<DeviceType TDeviceType, TensorDataType TPrecision>
LlamaConfig Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::configFromMetadata ( const PretrainedMetadata & metadata)
inline static private

◆ eosToken()

template<DeviceType TDeviceType, TensorDataType TPrecision>
int32_t Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::eosToken ( ) const
inline override private virtual noexcept

LLaMA 3.x end-of-sequence token.

<|end_of_text|> = 128001.

Implements Mila::Dnn::LanguageModel< TDeviceType, TPrecision >.

◆ fromPretrained()

template<DeviceType TDeviceType, TensorDataType TPrecision>
std::unique_ptr< LlamaModel< TDeviceType, TPrecision > > Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::fromPretrained ( const std::filesystem::path & path,
const LlamaModelConfig & model_config,
DeviceId device_id = DeviceId{ TDeviceType, 0 } )
inline static

Load from third-party pretrained weights.

Reads a Mila-compatible pretrained artifact (e.g. converted from a HuggingFace LLaMA checkpoint) via PretrainedModelReader. The network is built at the context length specified in model_config so RoPE embeddings and KV cache buffers cover the full range.

The model_config carries all deployment decisions:

  • context_length — maximum sequence length to build for
  • weight_quantization — compile-time dispatch to quantized or BF16 path
  • kv_cache_compression — compile-time dispatch to KV cache policy

Parameters
  path          Path to the pretrained Llama model artifact.
  model_config  Deployment configuration for this load.
  device_id     Target device; must match TDeviceType.
Returns
  Inference-ready LlamaModel.
Exceptions
  std::invalid_argument  on device type mismatch or zero context length.
  std::runtime_error     on load or parameter binding failure.
  std::runtime_error     if model_config requests unsupported quantization (e.g. FP4).
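
The "compile-time dispatch" in the bullets above means a runtime configuration value selects which template instantiation runs. A minimal sketch of that idea follows. The enum `ToyQuant`, the struct `ToyModelConfig`, and the functions `loadImpl` and `load` are hypothetical; the real `WeightQuantPolicy` values and `LlamaModelConfig` layout are not shown on this page.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

// Hypothetical policy values; the real WeightQuantPolicy is not documented here.
enum class ToyQuant { None, Int8, Fp4 };

// Hypothetical stand-in for a deployment configuration.
struct ToyModelConfig {
    long context_length = 0;
    ToyQuant weight_quantization = ToyQuant::None;
};

// The code path is fixed at compile time by the template argument.
template <ToyQuant Q>
std::string loadImpl(const ToyModelConfig& cfg) {
    const char* path_name = (Q == ToyQuant::Int8) ? "int8" : "bf16";
    return std::string(path_name) + "@" + std::to_string(cfg.context_length);
}

// Runtime entry point: validates, then maps the runtime value to an instantiation.
std::string load(const ToyModelConfig& cfg) {
    if (cfg.context_length <= 0) {
        throw std::invalid_argument("context_length must be positive");
    }
    switch (cfg.weight_quantization) {
        case ToyQuant::None: return loadImpl<ToyQuant::None>(cfg);
        case ToyQuant::Int8: return loadImpl<ToyQuant::Int8>(cfg);
        case ToyQuant::Fp4:  break;  // no instantiation exists for this policy
    }
    throw std::runtime_error("unsupported weight quantization");
}
```

Only the `switch` runs per load; the body selected by it is fully specialized, so the hot path carries no per-call branching on the policy.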

◆ fromPretrainedImpl()

template<DeviceType TDeviceType, TensorDataType TPrecision>
template<WeightQuantPolicy TWeightQuantization, KvCachePolicy TKvCachePolicy>
std::unique_ptr< LlamaModel< TDeviceType, TPrecision > > Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::fromPretrainedImpl ( const std::filesystem::path & path,
const LlamaModelConfig & model_config,
DeviceId device_id )
inline static private

◆ getConfig()

template<DeviceType TDeviceType, TensorDataType TPrecision>
const LlamaConfig & Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::getConfig ( ) const
inline noexcept

◆ makeTokenTensor()

template<DeviceType TDeviceType, TensorDataType TPrecision>
TokenIndexType Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::makeTokenTensor ( const std::vector< int32_t > & token_ids) const
inline private

◆ maxSequenceLength()

template<DeviceType TDeviceType, TensorDataType TPrecision>
int64_t Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::maxSequenceLength ( ) const
inline override protected virtual noexcept

Maximum sequence length from LLaMA config.

Implements Mila::Dnn::LanguageModel< TDeviceType, TPrecision >.

◆ onGenerating()

template<DeviceType TDeviceType, TensorDataType TPrecision>
void Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::onGenerating ( const std::vector< int32_t > & prompt_tokens,
const std::function< void(int32_t)> & on_token,
size_t max_new_tokens,
float temperature,
int top_k,
std::stop_token stop )
inline override protected virtual

Prefill + KV-cache decode loop with per-token streaming.

Phase 1 (prefill): runs the full prompt through prefill() to populate the KV cache and samples the first new token from the last position. Phase 2 (decode): iterates one token at a time until max_new_tokens is reached, EOS is emitted, or stop is requested.

on_token is called for every generated token except EOS.

Parameters
  prompt_tokens   Input token ids; truncated from the start if they exceed the model's max sequence length.
  on_token        Callback invoked once per generated token (not EOS).
  max_new_tokens  Maximum number of tokens to generate beyond the prompt.
  temperature     Sampling temperature; <= 0 selects the argmax.
  top_k           Restrict sampling to the top-k logits; 0 disables.
  stop            Stop token for cooperative cancellation.

Implements Mila::Dnn::LanguageModel< TDeviceType, TPrecision >.
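
The two-phase loop described above can be sketched with a toy network. Everything here is hypothetical: `ToyNetwork` simply returns the previous token plus one, `cache_len` stands in for the KV cache fill level, and cancellation through `std::stop_token` is left out to keep the sketch short.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <vector>

// Hypothetical toy network: the "next token" is always (last token + 1).
struct ToyNetwork {
    std::size_t cache_len = 0;  // stands in for the number of cached positions

    // Prefill: consume the whole prompt and return the first new token.
    int32_t prefill(const std::vector<int32_t>& prompt) {
        cache_len = prompt.size();
        return prompt.back() + 1;
    }

    // Decode: consume one token and extend the cache by one position.
    int32_t decode(int32_t token) {
        ++cache_len;
        return token + 1;
    }
};

void generateToy(ToyNetwork& net,
                 const std::vector<int32_t>& prompt,
                 const std::function<void(int32_t)>& on_token,
                 std::size_t max_new_tokens,
                 int32_t eos_token) {
    if (prompt.empty() || max_new_tokens == 0) return;

    // Phase 1: one pass over the full prompt.
    int32_t next = net.prefill(prompt);

    // Phase 2: one token per step until the budget is spent or EOS appears.
    for (std::size_t produced = 0; produced < max_new_tokens; ++produced) {
        if (next == eos_token) break;  // EOS is not forwarded to the callback
        on_token(next);
        next = net.decode(next);
    }
}
```

With prompt `{5, 6, 7}`, an EOS id of 11, and a budget of 10, the callback receives 8, 9, and 10; the loop then sees 11 and stops without reporting it.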


◆ onTraining()

template<DeviceType TDeviceType, TensorDataType TPrecision>
void Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::onTraining ( )
inline override protected virtual

Training loop — not yet implemented for LlamaModel.

Exceptions
  std::runtime_error  always.

Implements Mila::Dnn::Model< TDeviceType, TPrecision >.

◆ operator=() [1/2]

template<DeviceType TDeviceType, TensorDataType TPrecision>
LlamaModel & Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::operator= ( const LlamaModel< TDeviceType, TPrecision > & )
delete

◆ operator=() [2/2]

template<DeviceType TDeviceType, TensorDataType TPrecision>
LlamaModel & Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::operator= ( LlamaModel< TDeviceType, TPrecision > && )
default

◆ profilePrefill()

template<DeviceType TDeviceType, TensorDataType TPrecision>
void Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::profilePrefill ( const std::vector< int32_t > & token_ids)
inline

◆ sampleFromLogits()

template<DeviceType TDeviceType, TensorDataType TPrecision>
int32_t Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::sampleFromLogits ( const TensorType & logits,
int64_t position,
float temperature,
int top_k,
std::mt19937 & rng )
inline private

◆ sampleToken()

template<DeviceType TDeviceType, TensorDataType TPrecision>
int32_t Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::sampleToken ( const float * logits,
size_t vocab_size,
float temperature,
int top_k,
std::mt19937 & rng )
inline static private
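
This page does not document the body of sampleToken(). Based only on the parameter descriptions given for onGenerating() (temperature <= 0 selects the argmax; top_k of 0 disables the restriction), one plausible shape for temperature and top-k sampling is sketched below. `sampleTokenSketch` is a hypothetical name, and the full sort is chosen for brevity rather than speed.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <random>
#include <stdexcept>
#include <vector>

// Hypothetical sketch; not the library's implementation.
int32_t sampleTokenSketch(const float* logits, std::size_t vocab_size,
                          float temperature, int top_k, std::mt19937& rng) {
    if (vocab_size == 0) throw std::invalid_argument("empty vocabulary");

    // Candidate token ids, ordered so the highest logit comes first.
    std::vector<int32_t> ids(vocab_size);
    std::iota(ids.begin(), ids.end(), 0);
    std::sort(ids.begin(), ids.end(),
              [logits](int32_t a, int32_t b) { return logits[a] > logits[b]; });

    // Greedy path: a non-positive temperature selects the argmax.
    if (temperature <= 0.0f) return ids.front();

    // top_k == 0 disables the restriction; otherwise keep the k best ids.
    if (top_k > 0 && static_cast<std::size_t>(top_k) < ids.size()) {
        ids.resize(static_cast<std::size_t>(top_k));
    }

    // Softmax weights at the given temperature, shifted by the max for stability.
    const float max_logit = logits[ids.front()];
    std::vector<double> weights;
    weights.reserve(ids.size());
    for (int32_t id : ids) {
        weights.push_back(std::exp((logits[id] - max_logit) / temperature));
    }

    // discrete_distribution normalizes the weights internally.
    std::discrete_distribution<int> pick(weights.begin(), weights.end());
    return ids[static_cast<std::size_t>(pick(rng))];
}
```

For large vocabularies, `std::partial_sort` or `std::nth_element` over the top k entries would avoid sorting every logit.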

◆ stopTokens()

template<DeviceType TDeviceType, TensorDataType TPrecision>
std::unordered_set< int32_t > Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::stopTokens ( ) const
inline override private virtual

Llama 3.x generation stop tokens.

Halts on <|end_of_text|> (128001), <|eot_id|> (128009), and <|eom_id|> (128008). The latter two are the primary turn and tool-call boundaries in instruct-format generation.

Reimplemented from Mila::Dnn::LanguageModel< TDeviceType, TPrecision >.
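
As an illustration of how a stop set of this kind is typically consulted, here is a small sketch using the three token ids listed above. The names `stopTokensSketch` and `shouldStop` are hypothetical and are not part of the documented interface.

```cpp
#include <cassert>
#include <cstdint>
#include <unordered_set>

// Token ids as listed above for Llama 3.x.
constexpr int32_t kEndOfText = 128001;  // <|end_of_text|>
constexpr int32_t kEomId     = 128008;  // <|eom_id|>
constexpr int32_t kEotId     = 128009;  // <|eot_id|>

// Hypothetical helper returning the same three ids as a set.
std::unordered_set<int32_t> stopTokensSketch() {
    return {kEndOfText, kEotId, kEomId};
}

// True when generation should halt on this token.
bool shouldStop(int32_t token, const std::unordered_set<int32_t>& stops) {
    return stops.count(token) != 0;
}
```

A decode loop would build the set once before the loop and call `shouldStop` on each sampled token, since the function returns the set by value.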


◆ toString()

template<DeviceType TDeviceType, TensorDataType TPrecision>
std::string Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::toString ( ) const
inline override virtual

Human-readable summary of this model's configuration.

Implements Mila::Dnn::Model< TDeviceType, TPrecision >.


◆ truncateIfNeeded()

template<DeviceType TDeviceType, TensorDataType TPrecision>
void Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::truncateIfNeeded ( std::vector< int32_t > & tokens) const
inline private
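
No description is given for truncateIfNeeded() itself. The onGenerating() documentation states that prompt tokens are "truncated from the start" when they exceed the maximum sequence length, which suggests dropping the oldest tokens and keeping the tail. A minimal sketch under that assumption follows. `truncateFront` is a hypothetical free function, and whether the real method also reserves room for newly generated tokens is not stated here.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical sketch: keep only the last max_len tokens.
void truncateFront(std::vector<int32_t>& tokens, std::size_t max_len) {
    if (tokens.size() <= max_len) return;  // already fits, nothing to do

    // Erase the oldest tokens so the most recent context survives.
    const std::size_t excess = tokens.size() - max_len;
    tokens.erase(tokens.begin(),
                 tokens.begin() + static_cast<std::ptrdiff_t>(excess));
}
```

Applied to `{1, 2, 3, 4, 5}` with `max_len` of 3, the vector becomes `{3, 4, 5}`.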

◆ vocabSize()

template<DeviceType TDeviceType, TensorDataType TPrecision>
int64_t Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::vocabSize ( ) const
inline override protected virtual noexcept

Vocabulary size from LLaMA config.

Implements Mila::Dnn::LanguageModel< TDeviceType, TPrecision >.

Member Data Documentation

◆ config_

template<DeviceType TDeviceType, TensorDataType TPrecision>
LlamaConfig Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::config_
private

◆ decode_token_device_

template<DeviceType TDeviceType, TensorDataType TPrecision>
TokenIndexType Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::decode_token_device_
private

◆ decode_token_staging_

template<DeviceType TDeviceType, TensorDataType TPrecision>
Tensor<dtype_t::INT32, StagingMR> Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::decode_token_staging_
private

◆ logits_staging_

template<DeviceType TDeviceType, TensorDataType TPrecision>
Tensor<TensorDataType::FP32, StagingMR> Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::logits_staging_
private

The documentation for this class was generated from the following file: