Mila 0.13.48
Deep Neural Network Library
LLaMA 3 compatible inference model.


Public Types | |
| using | ModelBase = LanguageModel<TDeviceType, TPrecision> |
| using | MR = typename DeviceTypeTraits<TDeviceType>::memory_resource |
| using | StagingMR = CpuMemoryResource |
| using | TensorType = Tensor<TPrecision, MR> |
| using | TokenIndexType = Tensor<dtype_t::INT32, MR> |
| Public Types inherited from Mila::Dnn::LanguageModel< TDeviceType, TPrecision > | |
| using | Base = Model<TDeviceType, TPrecision> |
| Public Types inherited from Mila::Dnn::Model< TDeviceType, TPrecision > | |
| using | NetworkType = Network<TDeviceType, TPrecision> |
Public Member Functions | |
| LlamaModel (const LlamaModel &)=delete | |
| LlamaModel (LlamaModel &&)=default | |
| ~LlamaModel ()=default | |
| const LlamaConfig & | getConfig () const noexcept |
| LlamaModel & | operator= (const LlamaModel &)=delete |
| LlamaModel & | operator= (LlamaModel &&)=default |
| void | profilePrefill (const std::vector< int32_t > &token_ids) |
| std::string | toString () const override |
| Human-readable summary of this model's configuration. | |
| Public Member Functions inherited from Mila::Dnn::LanguageModel< TDeviceType, TPrecision > | |
| LanguageModel (const LanguageModel &)=delete | |
| LanguageModel (LanguageModel &&)=default | |
| virtual | ~LanguageModel ()=default |
| std::vector< int32_t > | generate (const std::vector< int32_t > &prompt_tokens, size_t max_new_tokens=64, float temperature=1.0f, int top_k=0) |
| Blocking generation. | |
| void | generateStreaming (const std::vector< int32_t > &prompt_tokens, std::function< void(int32_t)> on_token, size_t max_new_tokens=64, float temperature=1.0f, int top_k=0, std::stop_token stop={}) |
| Synchronous per-token streaming. | |
| const GenerationStatistics & | getLastGenerationStatistics () const noexcept |
| Returns statistics from the most recent generateStreaming() call. | |
| LanguageModel & | operator= (const LanguageModel &)=delete |
| LanguageModel & | operator= (LanguageModel &&)=default |
| Public Member Functions inherited from Mila::Dnn::Model< TDeviceType, TPrecision > | |
| Model (const Model &)=delete | |
| Model (Model &&)=default | |
| virtual | ~Model ()=default |
| DeviceId | getDeviceId () const noexcept |
| The device this model runs on. | |
| MemoryStats | getMemoryStats () const |
| Current memory allocation breakdown for this model. | |
| RuntimeMode | getRuntimeMode () const noexcept |
| The runtime mode this model was constructed for. | |
| bool | isEval () const noexcept |
| True if this model is currently in eval sub-state. | |
| bool | isInferenceMode () const noexcept |
| True if this model was constructed for inference. | |
| bool | isTrainingMode () const noexcept |
| True if this model was constructed for training. | |
| Model & | operator= (const Model &)=delete |
| Model & | operator= (Model &&)=default |
| void | setEval (bool eval) |
| Toggle eval sub-state for this model. | |
| void | train () |
| Run the training loop for this model. | |
Static Public Member Functions | |
| static std::unique_ptr< LlamaModel< TDeviceType, TPrecision > > | fromPretrained (const std::filesystem::path &path, const LlamaModelConfig &model_config, DeviceId device_id=DeviceId{ TDeviceType, 0 }) |
| Load from third-party pretrained weights. | |
Protected Member Functions | |
| int64_t | maxSequenceLength () const noexcept override |
| Maximum sequence length from LLaMA config. | |
| void | onGenerating (const std::vector< int32_t > &prompt_tokens, const std::function< void(int32_t)> &on_token, size_t max_new_tokens, float temperature, int top_k, std::stop_token stop) override |
| Prefill + KV-cache decode loop with per-token streaming. | |
| void | onTraining () override |
| Training loop — not yet implemented for LlamaModel. | |
| int64_t | vocabSize () const noexcept override |
| Vocabulary size from LLaMA config. | |
| Protected Member Functions inherited from Mila::Dnn::LanguageModel< TDeviceType, TPrecision > | |
| LanguageModel (std::unique_ptr< LanguageNetwork< TDeviceType, TPrecision > > network, RuntimeMode runtime_mode) | |
| const LanguageNetwork< TDeviceType, TPrecision > & | getLanguageNetwork () const noexcept |
| LanguageNetwork< TDeviceType, TPrecision > & | getLanguageNetwork () noexcept |
| Protected Member Functions inherited from Mila::Dnn::Model< TDeviceType, TPrecision > | |
| Model (std::unique_ptr< NetworkType > network, RuntimeMode runtime_mode) | |
| Construct with a fully built network and runtime mode. | |
Private Member Functions | |
| LlamaModel (std::unique_ptr< LanguageNetwork< TDeviceType, TPrecision > > network, const LlamaConfig &config, RuntimeMode runtime_mode) | |
| int32_t | eosToken () const noexcept override |
| LLaMA 3.x end-of-sequence token. | |
| TokenIndexType | makeTokenTensor (const std::vector< int32_t > &token_ids) const |
| int32_t | sampleFromLogits (const TensorType &logits, int64_t position, float temperature, int top_k, std::mt19937 &rng) |
| std::unordered_set< int32_t > | stopTokens () const override |
| Llama 3.x generation stop tokens. | |
| void | truncateIfNeeded (std::vector< int32_t > &tokens) const |
Static Private Member Functions | |
| static LlamaConfig | configFromMetadata (const PretrainedMetadata &metadata) |
| template<WeightQuantPolicy TWeightQuantization, KvCachePolicy TKvCachePolicy> | |
| static std::unique_ptr< LlamaModel< TDeviceType, TPrecision > > | fromPretrainedImpl (const std::filesystem::path &path, const LlamaModelConfig &model_config, DeviceId device_id) |
| static int32_t | sampleToken (const float *logits, size_t vocab_size, float temperature, int top_k, std::mt19937 &rng) |
Private Attributes | |
| LlamaConfig | config_ |
| TokenIndexType | decode_token_device_ |
| Tensor< dtype_t::INT32, StagingMR > | decode_token_staging_ |
| Tensor< TensorDataType::FP32, StagingMR > | logits_staging_ |
Additional Inherited Members | |
| Protected Attributes inherited from Mila::Dnn::LanguageModel< TDeviceType, TPrecision > | |
| GenerationStatistics | last_generation_statistics_ {} |
| Statistics populated by onGenerating() for each completed generation run. | |
| Protected Attributes inherited from Mila::Dnn::Model< TDeviceType, TPrecision > | |
| std::unique_ptr< NetworkType > | network_ |
| The owned Network instance. | |
LLaMA 3 compatible inference model.
Owns a loaded, built LlamaTransformer and exposes generateStreaming() for autoregressive text generation. Supports the prefill + KV-cache decode two-phase generation loop.
Construction is only possible via fromPretrained(). The network is always in a built, weights-loaded, inference-mode state when generation is called.
Thread safety: not thread-safe; external synchronization required if shared.
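One way to provide the external synchronization mentioned above is to route every call through a mutex-guarded wrapper. The following is a minimal sketch under stated assumptions: `SynchronizedSketch` and `withLock` are hypothetical names that are not part of Mila, and the wrapper works with any type, so it is shown here without Mila headers.

```cpp
#include <mutex>
#include <utility>

// Hypothetical helper (not part of Mila): serializes every call made
// through it by holding a mutex for the duration of the call.
template <typename Model>
class SynchronizedSketch {
public:
    explicit SynchronizedSketch(Model& model) : model_(model) {}

    // Runs fn(model) while holding the lock and returns its result.
    template <typename Fn>
    decltype(auto) withLock(Fn&& fn) {
        std::lock_guard<std::mutex> guard(mutex_);
        return std::forward<Fn>(fn)(model_);
    }

private:
    Model& model_;      // Not owned; must outlive the wrapper.
    std::mutex mutex_;  // Guards all access made via withLock().
};
```

With a wrapper like this, two threads sharing one model would each call `withLock` with a lambda that invokes the model, so that generation calls never overlap. Whether coarse locking of this kind is acceptable depends on how long each generation runs.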
| using Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::ModelBase = LanguageModel<TDeviceType, TPrecision> |
| using Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::MR = typename DeviceTypeTraits<TDeviceType>::memory_resource |
| using Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::StagingMR = CpuMemoryResource |
| using Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::TensorType = Tensor<TPrecision, MR> |
| using Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::TokenIndexType = Tensor<dtype_t::INT32, MR> |
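| Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::LlamaModel (const LlamaModel &) | delete |

| Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::LlamaModel (LlamaModel &&) | default |

| Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::~LlamaModel () | default |

| Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::LlamaModel (std::unique_ptr< LanguageNetwork< TDeviceType, TPrecision > > network, const LlamaConfig &config, RuntimeMode runtime_mode) | inline, explicit, private |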

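| static LlamaConfig Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::configFromMetadata (const PretrainedMetadata &metadata) | inline, static, private |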


| int32_t Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::eosToken () const | inline, override, private, virtual, noexcept |
LLaMA 3.x end-of-sequence token.
<|end_of_text|> = 128001.
Implements Mila::Dnn::LanguageModel< TDeviceType, TPrecision >.
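| static std::unique_ptr< LlamaModel< TDeviceType, TPrecision > > Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::fromPretrained (const std::filesystem::path &path, const LlamaModelConfig &model_config, DeviceId device_id=DeviceId{ TDeviceType, 0 }) | inline, static |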
Load from third-party pretrained weights.
Reads a Mila-compatible pretrained artifact (e.g. converted from a HuggingFace LLaMA checkpoint) via PretrainedModelReader. The network is built at the context length specified in model_config so RoPE embeddings and KV cache buffers cover the full range.
The model_config carries all deployment decisions:
Parameters
| path | Path to the pretrained Llama model artifact. |
| model_config | Deployment configuration for this load. |
| device_id | Target device; must match TDeviceType. |
Exceptions
| std::invalid_argument | on device type mismatch or zero context length. |
| std::runtime_error | on load or parameter binding failure. |
| std::runtime_error | if model_config requests unsupported quantization (e.g. FP4). |
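| template<WeightQuantPolicy TWeightQuantization, KvCachePolicy TKvCachePolicy> static std::unique_ptr< LlamaModel< TDeviceType, TPrecision > > Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::fromPretrainedImpl (const std::filesystem::path &path, const LlamaModelConfig &model_config, DeviceId device_id) | inline, static, private |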

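| const LlamaConfig & Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::getConfig () const | inline, noexcept |

| TokenIndexType Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::makeTokenTensor (const std::vector< int32_t > &token_ids) const | inline, private |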


| int64_t Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::maxSequenceLength () const | inline, override, protected, virtual, noexcept |
Maximum sequence length from LLaMA config.
Implements Mila::Dnn::LanguageModel< TDeviceType, TPrecision >.
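| void Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::onGenerating (const std::vector< int32_t > &prompt_tokens, const std::function< void(int32_t)> &on_token, size_t max_new_tokens, float temperature, int top_k, std::stop_token stop) | inline, override, protected, virtual |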
Prefill + KV-cache decode loop with per-token streaming.
Phase 1 (prefill): runs the full prompt through prefill() to populate the KV cache and samples the first new token from the last position. Phase 2 (decode): iterates one token at a time until max_new_tokens is reached, EOS is emitted, or stop is requested.
on_token is called for every generated token except EOS.
Parameters
| prompt_tokens | Input token ids; truncated from the start if they exceed the model's max sequence length. |
| on_token | Callback invoked once per generated token (not EOS). |
| max_new_tokens | Maximum number of tokens to generate beyond the prompt. |
| temperature | Sampling temperature; <= 0 selects the argmax. |
| top_k | Restrict sampling to the top-k logits; 0 disables. |
| stop | Stop token for cooperative cancellation. |
Implements Mila::Dnn::LanguageModel< TDeviceType, TPrecision >.

| void Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::onTraining () | inline, override, protected, virtual |
Training loop — not yet implemented for LlamaModel.
Exceptions
| std::runtime_error | always. |
Implements Mila::Dnn::Model< TDeviceType, TPrecision >.
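| LlamaModel & Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::operator= (const LlamaModel &) | delete |

| LlamaModel & Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::operator= (LlamaModel &&) | default |

| void Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::profilePrefill (const std::vector< int32_t > &token_ids) | inline |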

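| int32_t Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::sampleFromLogits (const TensorType &logits, int64_t position, float temperature, int top_k, std::mt19937 &rng) | inline, private |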


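| static int32_t Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::sampleToken (const float *logits, size_t vocab_size, float temperature, int top_k, std::mt19937 &rng) | inline, static, private |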
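The sampleToken() signature in the member list takes raw logits, a temperature, a top_k value, and a random engine, and the onGenerating() parameter notes say that a temperature <= 0 selects the argmax and that top_k = 0 disables the restriction. The sketch below shows one common way to implement those semantics. It is an assumption about the general technique, not a copy of Mila's code; `sampleTokenSketch` is a hypothetical name, and it assumes `vocab_size > 0`.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <random>
#include <vector>

// Illustrative temperature / top-k sampler (not Mila's implementation).
// Precondition: vocab_size > 0.
int32_t sampleTokenSketch(const float* logits, std::size_t vocab_size,
                          float temperature, int top_k, std::mt19937& rng) {
    // temperature <= 0: greedy decoding, return the index of the largest logit.
    if (temperature <= 0.0f) {
        return static_cast<int32_t>(
            std::max_element(logits, logits + vocab_size) - logits);
    }

    // Candidate token ids; by default every token is a candidate.
    std::vector<int32_t> order(vocab_size);
    std::iota(order.begin(), order.end(), 0);

    // top_k > 0: move the k largest logits to the front and keep only those.
    std::size_t keep = vocab_size;
    if (top_k > 0 && static_cast<std::size_t>(top_k) < vocab_size) {
        keep = static_cast<std::size_t>(top_k);
        std::partial_sort(order.begin(), order.begin() + keep, order.end(),
                          [logits](int32_t a, int32_t b) {
                              return logits[a] > logits[b];
                          });
    }

    // Softmax weights over the candidates, shifted by the maximum logit
    // for numerical stability and scaled by the temperature.
    float max_logit = logits[order[0]];
    for (std::size_t i = 1; i < keep; ++i) {
        max_logit = std::max(max_logit, logits[order[i]]);
    }
    std::vector<double> weights(keep);
    for (std::size_t i = 0; i < keep; ++i) {
        weights[i] = std::exp((logits[order[i]] - max_logit) / temperature);
    }

    // discrete_distribution normalizes the weights internally.
    std::discrete_distribution<int> pick(weights.begin(), weights.end());
    return order[static_cast<std::size_t>(pick(rng))];
}
```

Lower temperatures concentrate probability on the largest logits, and `top_k = 1` is equivalent to greedy decoding regardless of temperature.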


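| std::unordered_set< int32_t > Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::stopTokens () const | inline, override, private, virtual |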
Llama 3.x generation stop tokens.
Halts on <|end_of_text|> (128001), <|eot_id|> (128009), and <|eom_id|> (128008). The latter two are the primary turn and tool-call boundaries in instruct-format generation.
Reimplemented from Mila::Dnn::LanguageModel< TDeviceType, TPrecision >.
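As a small illustration of the stop set described above, the three token ids can be collected into a `std::unordered_set` and checked per generated token. The ids come from this page; the constant and function names (`kEndOfText`, `kEomId`, `kEotId`, `isStopTokenSketch`) are hypothetical and not part of Mila.

```cpp
#include <cstdint>
#include <unordered_set>

// Token ids as listed in the stopTokens() description.
constexpr int32_t kEndOfText = 128001;  // <|end_of_text|>
constexpr int32_t kEomId = 128008;      // <|eom_id|>
constexpr int32_t kEotId = 128009;      // <|eot_id|>

// Returns true if generation should halt on this token.
inline bool isStopTokenSketch(int32_t token) {
    static const std::unordered_set<int32_t> stops{kEndOfText, kEomId, kEotId};
    return stops.count(token) != 0;
}
```

Checking against a set rather than a single EOS id is what lets instruct-format generation end at a turn or tool-call boundary instead of running on to `<|end_of_text|>`.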

| std::string Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::toString () const | inline, override, virtual |
Human-readable summary of this model's configuration.
Implements Mila::Dnn::Model< TDeviceType, TPrecision >.

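| void Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::truncateIfNeeded (std::vector< int32_t > &tokens) const | inline, private |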
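The onGenerating() parameter notes state that prompt tokens are truncated from the start when they exceed the model's maximum sequence length, which keeps the most recent context. A minimal sketch of that behaviour follows. `truncateFrontSketch` is a hypothetical free function, and it is an assumption that the limit is applied to the prompt alone; the real truncateIfNeeded() may, for example, also reserve room for new tokens.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Keeps only the most recent max_len tokens, dropping the oldest ones
// from the front of the sequence. Sequences within the limit are untouched.
void truncateFrontSketch(std::vector<int32_t>& tokens, std::size_t max_len) {
    if (tokens.size() > max_len) {
        const auto excess =
            static_cast<std::ptrdiff_t>(tokens.size() - max_len);
        tokens.erase(tokens.begin(), tokens.begin() + excess);
    }
}
```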


| int64_t Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::vocabSize () const | inline, override, protected, virtual, noexcept |
Vocabulary size from LLaMA config.
Implements Mila::Dnn::LanguageModel< TDeviceType, TPrecision >.
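| LlamaConfig Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::config_ | private |
| TokenIndexType Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::decode_token_device_ | private |
| Tensor< dtype_t::INT32, StagingMR > Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::decode_token_staging_ | private |
| Tensor< TensorDataType::FP32, StagingMR > Mila::Dnn::LlamaModel< TDeviceType, TPrecision >::logits_staging_ | private |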