|
Mila 0.13.48
Deep Neural Network Library
|


Public Types | |
| using | Base = Model<TDeviceType, TPrecision> |
| Public Types inherited from Mila::Dnn::Model< TDeviceType, TPrecision > | |
| using | NetworkType = Network<TDeviceType, TPrecision> |
Public Member Functions | |
| LanguageModel (const LanguageModel &)=delete | |
| LanguageModel (LanguageModel &&)=default | |
| virtual | ~LanguageModel ()=default |
| std::vector< int32_t > | generate (const std::vector< int32_t > &prompt_tokens, size_t max_new_tokens=64, float temperature=1.0f, int top_k=0) |
| Blocking generation. | |
| void | generateStreaming (const std::vector< int32_t > &prompt_tokens, std::function< void(int32_t)> on_token, size_t max_new_tokens=64, float temperature=1.0f, int top_k=0, std::stop_token stop={}) |
| Synchronous per-token streaming. | |
| const GenerationStatistics & | getLastGenerationStatistics () const noexcept |
| Returns statistics from the most recent generateStreaming() call. | |
| LanguageModel & | operator= (const LanguageModel &)=delete |
| LanguageModel & | operator= (LanguageModel &&)=default |
| Public Member Functions inherited from Mila::Dnn::Model< TDeviceType, TPrecision > | |
| Model (const Model &)=delete | |
| Model (Model &&)=default | |
| virtual | ~Model ()=default |
| DeviceId | getDeviceId () const noexcept |
| The device this model runs on. | |
| MemoryStats | getMemoryStats () const |
| Current memory allocation breakdown for this model. | |
| RuntimeMode | getRuntimeMode () const noexcept |
| The runtime mode this model was constructed for. | |
| bool | isEval () const noexcept |
| True if this model is currently in eval sub-state. | |
| bool | isInferenceMode () const noexcept |
| True if this model was constructed for inference. | |
| bool | isTrainingMode () const noexcept |
| True if this model was constructed for training. | |
| Model & | operator= (const Model &)=delete |
| Model & | operator= (Model &&)=default |
| void | setEval (bool eval) |
| Toggle eval sub-state for this model. | |
| virtual std::string | toString () const =0 |
| Human-readable summary of this model's configuration. | |
| void | train () |
| Run the training loop for this model. | |
Protected Member Functions | |
| LanguageModel (std::unique_ptr< LanguageNetwork< TDeviceType, TPrecision > > network, RuntimeMode runtime_mode) | |
| virtual int32_t | eosToken () const noexcept=0 |
| const LanguageNetwork< TDeviceType, TPrecision > & | getLanguageNetwork () const noexcept |
| LanguageNetwork< TDeviceType, TPrecision > & | getLanguageNetwork () noexcept |
| virtual int64_t | maxSequenceLength () const noexcept=0 |
| virtual void | onGenerating (const std::vector< int32_t > &prompt_tokens, const std::function< void(int32_t)> &on_token, size_t max_new_tokens, float temperature, int top_k, std::stop_token stop)=0 |
| Prefill + decode implementation hook. | |
| virtual std::unordered_set< int32_t > | stopTokens () const |
| virtual int64_t | vocabSize () const noexcept=0 |
| Protected Member Functions inherited from Mila::Dnn::Model< TDeviceType, TPrecision > | |
| Model (std::unique_ptr< NetworkType > network, RuntimeMode runtime_mode) | |
| Construct with a fully built network and runtime mode. | |
| virtual void | onTraining ()=0 |
| Training loop hook; the derived class owns the implementation. | |
Protected Attributes | |
| GenerationStatistics | last_generation_statistics_ {} |
| Statistics populated by onGenerating() for each completed generation run. | |
| Protected Attributes inherited from Mila::Dnn::Model< TDeviceType, TPrecision > | |
| std::unique_ptr< NetworkType > | network_ |
| The owned Network instance. | |
| using Mila::Dnn::LanguageModel< TDeviceType, TPrecision >::Base = Model<TDeviceType, TPrecision> |
| LanguageModel (const LanguageModel &) | delete |


| LanguageModel (LanguageModel &&) | default |

| ~LanguageModel () | virtual default |
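| LanguageModel (std::unique_ptr< LanguageNetwork< TDeviceType, TPrecision > > network, RuntimeMode runtime_mode) | inline explicit protected |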
| virtual int32_t eosToken () const noexcept=0 | protected pure virtual noexcept |
Implemented in Mila::Dnn::GptModel< TDeviceType, TPrecision >, and Mila::Dnn::LlamaModel< TDeviceType, TPrecision >.

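| std::vector< int32_t > generate (const std::vector< int32_t > &prompt_tokens, size_t max_new_tokens=64, float temperature=1.0f, int top_k=0) | inline |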
Blocking generation.
Returns the prompt tokens followed by all generated tokens (EOS excluded).
| prompt_tokens | Input token ids. |
| max_new_tokens | Maximum tokens to generate beyond the prompt. |
| temperature | Sampling temperature; <= 0 selects argmax. |
| top_k | Top-k filter; 0 disables. |
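
The temperature and top_k rules in the table above can be illustrated with a small standalone sampler. This is only a sketch of one common way to apply those rules to a vector of logits; `sampleToken` is a hypothetical helper and is not part of Mila, whose actual sampling code is not shown on this page.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <random>
#include <vector>

// Hypothetical helper (not a Mila API): picks the next token id from raw logits.
// temperature <= 0 selects argmax; top_k == 0 disables the top-k filter.
inline int32_t sampleToken(const std::vector<float>& logits, float temperature,
                           int top_k, std::mt19937& rng)
{
    assert(!logits.empty());

    // Greedy path: return the index of the largest logit.
    if (temperature <= 0.0f) {
        return static_cast<int32_t>(
            std::max_element(logits.begin(), logits.end()) - logits.begin());
    }

    // Candidate ids, sorted so the highest logits come first.
    std::vector<int32_t> ids(logits.size());
    std::iota(ids.begin(), ids.end(), 0);
    std::sort(ids.begin(), ids.end(),
              [&](int32_t a, int32_t b) { return logits[a] > logits[b]; });

    // Keep only the k best candidates when the filter is enabled.
    if (top_k > 0 && static_cast<std::size_t>(top_k) < ids.size()) {
        ids.resize(static_cast<std::size_t>(top_k));
    }

    // Unnormalized softmax weights over the survivors, scaled by temperature.
    // Subtracting the maximum logit keeps exp() numerically stable.
    const float max_logit = logits[ids.front()];
    std::vector<double> weights;
    weights.reserve(ids.size());
    for (int32_t id : ids) {
        weights.push_back(std::exp((logits[id] - max_logit) / temperature));
    }

    // discrete_distribution normalizes the weights internally.
    std::discrete_distribution<int> pick(weights.begin(), weights.end());
    return ids[static_cast<std::size_t>(pick(rng))];
}
```

With `temperature = 0` the random generator is never consulted, so repeated calls on the same logits return the same id. With `top_k = 2`, only the two highest-scoring ids can ever be returned.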

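| void generateStreaming (const std::vector< int32_t > &prompt_tokens, std::function< void(int32_t)> on_token, size_t max_new_tokens=64, float temperature=1.0f, int top_k=0, std::stop_token stop={}) | inline |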
Synchronous per-token streaming.
Blocks on the caller's thread until generation completes or stop is requested.
on_token is invoked on the caller's thread for every generated token (EOS excluded). Callers that manage their own threading, such as the Python ModelWorker's single-thread executor, should call this function directly.
| prompt_tokens | Input token ids. |
| on_token | Per-token callback invoked on the caller's thread. |
| max_new_tokens | Maximum tokens to generate beyond the prompt. |
| temperature | Sampling temperature; <= 0 selects argmax. |
| top_k | Top-k filter; 0 disables. |
| stop | Stop token for cooperative cancellation. |

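| const LanguageNetwork< TDeviceType, TPrecision > & getLanguageNetwork () const noexcept | inline protected noexcept |
| LanguageNetwork< TDeviceType, TPrecision > & getLanguageNetwork () noexcept | inline protected noexcept |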

| const GenerationStatistics & getLastGenerationStatistics () const noexcept | inline nodiscard noexcept |
Returns statistics from the most recent generateStreaming() call.
Only valid after at least one generateStreaming() call has returned. Check GenerationStatistics::valid() before using the values.
| virtual int64_t maxSequenceLength () const noexcept=0 | protected pure virtual noexcept |
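| virtual void onGenerating (const std::vector< int32_t > &prompt_tokens, const std::function< void(int32_t)> &on_token, size_t max_new_tokens, float temperature, int top_k, std::stop_token stop)=0 | protected pure virtual |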
Prefill + decode implementation hook.
Derived classes own the full autoregressive generation loop. on_token must be called for every generated token except EOS. stop.stop_requested() must be checked on each decode step and generation must abort early when signalled.
| prompt_tokens | Input token ids. |
| on_token | Per-token callback. |
| max_new_tokens | Maximum tokens to generate beyond the prompt. |
| temperature | Sampling temperature; <= 0 selects argmax. |
| top_k | Top-k filter; 0 disables. |
| stop | Stop token for cooperative cancellation. |
Implemented in Mila::Dnn::GptModel< TDeviceType, TPrecision >, and Mila::Dnn::LlamaModel< TDeviceType, TPrecision >.
| LanguageModel & operator= (const LanguageModel &) | delete |

| LanguageModel & operator= (LanguageModel &&) | default |

| virtual std::unordered_set< int32_t > stopTokens () const | inline protected virtual |
Reimplemented in Mila::Dnn::LlamaModel< TDeviceType, TPrecision >.


| virtual int64_t vocabSize () const noexcept=0 | protected pure virtual noexcept |
Implemented in Mila::Dnn::GptModel< TDeviceType, TPrecision >, and Mila::Dnn::LlamaModel< TDeviceType, TPrecision >.


| GenerationStatistics last_generation_statistics_ {} | protected |
Statistics populated by onGenerating() for each completed generation run.