Mila 0.13.48
Deep Neural Network Library
GPT inference model.


Public Types
| using | GptTransformerType = GptTransformer<TDeviceType, TPrecision> | |
| using | ModelBase = LanguageModel<TDeviceType, TPrecision> | |
| using | MR = typename DeviceTypeTraits<TDeviceType>::memory_resource | |
| using | TensorType = Tensor<TPrecision, MR> | |
| using | TokenIndexType = Tensor<dtype_t::INT32, MR> | |
Public Types inherited from Mila::Dnn::LanguageModel< TDeviceType, TPrecision >
| using | Base = Model<TDeviceType, TPrecision> | |
Public Types inherited from Mila::Dnn::Model< TDeviceType, TPrecision >
| using | NetworkType = Network<TDeviceType, TPrecision> | |
Public Member Functions
| | GptModel (const GptModel &)=delete | |
| | GptModel (GptModel &&)=default | |
| | ~GptModel ()=default | |
| const GptConfig & | getConfig () const noexcept | |
| GptModel & | operator= (const GptModel &)=delete | |
| GptModel & | operator= (GptModel &&)=default | |
| std::string | toString () const override | Human-readable summary of this model's configuration. |
Public Member Functions inherited from Mila::Dnn::LanguageModel< TDeviceType, TPrecision >
| | LanguageModel (const LanguageModel &)=delete | |
| | LanguageModel (LanguageModel &&)=default | |
| virtual | ~LanguageModel ()=default | |
| std::vector< int32_t > | generate (const std::vector< int32_t > &prompt_tokens, size_t max_new_tokens=64, float temperature=1.0f, int top_k=0) | Blocking generation. |
| void | generateStreaming (const std::vector< int32_t > &prompt_tokens, std::function< void(int32_t)> on_token, size_t max_new_tokens=64, float temperature=1.0f, int top_k=0, std::stop_token stop={}) | Synchronous per-token streaming. |
| const GenerationStatistics & | getLastGenerationStatistics () const noexcept | Returns statistics from the most recent generateStreaming() call. |
| LanguageModel & | operator= (const LanguageModel &)=delete | |
| LanguageModel & | operator= (LanguageModel &&)=default | |
Public Member Functions inherited from Mila::Dnn::Model< TDeviceType, TPrecision >
| | Model (const Model &)=delete | |
| | Model (Model &&)=default | |
| virtual | ~Model ()=default | |
| DeviceId | getDeviceId () const noexcept | The device this model runs on. |
| MemoryStats | getMemoryStats () const | Current memory allocation breakdown for this model. |
| RuntimeMode | getRuntimeMode () const noexcept | The runtime mode this model was constructed for. |
| bool | isEval () const noexcept | True if this model is currently in the eval sub-state. |
| bool | isInferenceMode () const noexcept | True if this model was constructed for inference. |
| bool | isTrainingMode () const noexcept | True if this model was constructed for training. |
| Model & | operator= (const Model &)=delete | |
| Model & | operator= (Model &&)=default | |
| void | setEval (bool eval) | Toggle the eval sub-state for this model. |
| void | train () | Run the training loop for this model. |
Static Public Member Functions
| static std::unique_ptr< GptModel > | fromCheckpoint (const std::filesystem::path &path, DeviceId device_id=DeviceId{ TDeviceType, 0 }) | Load from a Mila-native serialized artifact. |
| static std::unique_ptr< GptModel > | fromPretrained (const std::filesystem::path &path, dim_t context_length, DeviceId device_id=DeviceId{ TDeviceType, 0 }, bool strict=true) | Load from third-party pretrained weights. |
Protected Member Functions
| int32_t | eosToken () const noexcept override | GPT-2 end-of-text token id. |
| int64_t | maxSequenceLength () const noexcept override | Maximum sequence length from GPT config. |
| void | onGenerating (const std::vector< int32_t > &prompt_tokens, const std::function< void(int32_t)> &on_token, size_t max_new_tokens, float temperature, int top_k, std::stop_token stop) override | Prefill + KV-cache decode loop with per-token streaming. |
| void | onTraining () override | Training loop (not yet implemented for GptModel). |
| int64_t | vocabSize () const noexcept override | Vocabulary size from GPT config. |
Protected Member Functions inherited from Mila::Dnn::LanguageModel< TDeviceType, TPrecision >
| | LanguageModel (std::unique_ptr< LanguageNetwork< TDeviceType, TPrecision > > network, RuntimeMode runtime_mode) | |
| const LanguageNetwork< TDeviceType, TPrecision > & | getLanguageNetwork () const noexcept | |
| LanguageNetwork< TDeviceType, TPrecision > & | getLanguageNetwork () noexcept | |
| virtual std::unordered_set< int32_t > | stopTokens () const | |
Protected Member Functions inherited from Mila::Dnn::Model< TDeviceType, TPrecision >
| | Model (std::unique_ptr< NetworkType > network, RuntimeMode runtime_mode) | Construct with a fully built network and runtime mode. |
Private Member Functions
| | GptModel (std::unique_ptr< GptTransformerType > network, const GptConfig &config) | |
| | GptModel (std::unique_ptr< GptTransformerType > network, const GptConfig &config, RuntimeMode runtime_mode) | |
| TokenIndexType | makeTokenTensor (const std::vector< int32_t > &token_ids) const | |
| int32_t | sampleFromLogits (const TensorType &logits, int64_t position, float temperature, int top_k, std::mt19937 &rng) const | |
| void | truncateIfNeeded (std::vector< int32_t > &tokens) const | |
Static Private Member Functions
| static GptConfig | configFromMetadata (const PretrainedMetadata &metadata) | |
| static int32_t | sampleToken (const float *logits, size_t vocab_size, float temperature, int top_k, std::mt19937 &rng) | |
Private Attributes
| GptConfig | config_ | |
Static Private Attributes
| static constexpr int32_t | eos_token_ = 50256 | |
Additional Inherited Members
Protected Attributes inherited from Mila::Dnn::LanguageModel< TDeviceType, TPrecision >
| GenerationStatistics | last_generation_statistics_ {} | Statistics populated by onGenerating() for each completed generation run. |
Protected Attributes inherited from Mila::Dnn::Model< TDeviceType, TPrecision >
| std::unique_ptr< NetworkType > | network_ | The owned Network instance. |
Detailed Description
GPT inference model.
Owns a loaded, built GptTransformer and exposes generate() and generateStreaming() for autoregressive text generation.
Construction is only possible via fromPretrained() or fromCheckpoint(); the constructors that take a network are private. Whenever generation is called, the network is built, its weights are loaded, and it is in inference mode.
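generate() is documented as blocking and generateStreaming() as synchronous per-token streaming. One way those two contracts relate is that a blocking call can be layered on a streaming one by collecting each token the callback receives. The sketch below shows that pattern in isolation; StreamingFn and generateBlocking are hypothetical names, and this is not a claim about how Mila implements generate().

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <vector>

// Hypothetical stand-in for generateStreaming(): invokes the callback once per
// generated token, synchronously, on the caller's thread.
using TokenCallback = std::function<void(int32_t)>;
using StreamingFn =
    std::function<void(const std::vector<int32_t>&, TokenCallback, size_t)>;

// Blocking generation layered on streaming: the callback appends each token,
// and the vector is returned once the streaming call has finished.
std::vector<int32_t> generateBlocking(const StreamingFn& stream,
                                      const std::vector<int32_t>& prompt,
                                      size_t max_new_tokens) {
    assert(stream);  // the streaming function must be set
    std::vector<int32_t> out;
    out.reserve(max_new_tokens);
    stream(prompt, [&out](int32_t tok) { out.push_back(tok); }, max_new_tokens);
    return out;
}
```

With Mila itself, the documented entry points are used directly, for example calling generate(prompt_tokens, 32) on the std::unique_ptr<GptModel> returned by fromPretrained(); generateStreaming() is the variant to use when each token should be handled as soon as it is produced.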
Member Typedef Documentation
| using Mila::Dnn::GptModel< TDeviceType, TPrecision >::GptTransformerType = GptTransformer<TDeviceType, TPrecision> |
| using Mila::Dnn::GptModel< TDeviceType, TPrecision >::ModelBase = LanguageModel<TDeviceType, TPrecision> |
| using Mila::Dnn::GptModel< TDeviceType, TPrecision >::MR = typename DeviceTypeTraits<TDeviceType>::memory_resource |
| using Mila::Dnn::GptModel< TDeviceType, TPrecision >::TensorType = Tensor<TPrecision, MR> |
| using Mila::Dnn::GptModel< TDeviceType, TPrecision >::TokenIndexType = Tensor<dtype_t::INT32, MR> |
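Constructor & Destructor Documentation
| Mila::Dnn::GptModel< TDeviceType, TPrecision >::GptModel (const GptModel &) | delete |
| Mila::Dnn::GptModel< TDeviceType, TPrecision >::GptModel (GptModel &&) | default |
| Mila::Dnn::GptModel< TDeviceType, TPrecision >::~GptModel () | default |
| Mila::Dnn::GptModel< TDeviceType, TPrecision >::GptModel (std::unique_ptr< GptTransformerType > network, const GptConfig &config) | inline explicit private |
| Mila::Dnn::GptModel< TDeviceType, TPrecision >::GptModel (std::unique_ptr< GptTransformerType > network, const GptConfig &config, RuntimeMode runtime_mode) | inline explicit private |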
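Member Function Documentation
| GptConfig Mila::Dnn::GptModel< TDeviceType, TPrecision >::configFromMetadata (const PretrainedMetadata &metadata) | inline static private |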
| int32_t Mila::Dnn::GptModel< TDeviceType, TPrecision >::eosToken () const | inline override protected virtual noexcept |
GPT-2 end-of-text token id.
Implements Mila::Dnn::LanguageModel< TDeviceType, TPrecision >.
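| std::unique_ptr< GptModel > Mila::Dnn::GptModel< TDeviceType, TPrecision >::fromCheckpoint (const std::filesystem::path &path, DeviceId device_id = DeviceId{ TDeviceType, 0 }) | inline static |
Load from a Mila-native serialized artifact.
Reads a checkpoint or weights-only artifact produced by GptTransformer::save() via ModelArchive.
Parameters
| path | Path to the Mila archive. |
| device_id | Target device (defaults to device 0 of TDeviceType). |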
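| std::unique_ptr< GptModel > Mila::Dnn::GptModel< TDeviceType, TPrecision >::fromPretrained (const std::filesystem::path &path, dim_t context_length, DeviceId device_id = DeviceId{ TDeviceType, 0 }, bool strict = true) | inline static |
Load from third-party pretrained weights.
Reads weights from a Mila-compatible pretrained artifact produced by converting third-party checkpoints (e.g. HuggingFace GPT-2) via PretrainedModelReader.
Parameters
| path | Path to the pretrained artifact. |
| context_length | Maximum sequence length to build the network for. |
| device_id | Target device (defaults to device 0 of TDeviceType). |
| strict | If true (the default), throws on unknown parameter names. |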
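The strict flag decides what happens when the artifact contains a parameter name the network does not expect. A minimal sketch of that policy follows, assuming plain std::map buffers in place of Mila tensors; matchParameters is a hypothetical helper, and how fromPretrained treats unknown names when strict is false (skip silently, log, or report) is not documented here.

```cpp
#include <map>
#include <set>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical sketch: copy every loaded tensor whose name the network expects.
// An unknown name throws when strict is true; otherwise it is skipped and
// returned to the caller so it can be reported.
std::vector<std::string> matchParameters(
        const std::set<std::string>& expected,
        const std::map<std::string, std::vector<float>>& loaded,
        std::map<std::string, std::vector<float>>& params,
        bool strict) {
    std::vector<std::string> skipped;
    for (const auto& [name, values] : loaded) {
        if (expected.count(name) == 0) {
            if (strict) {
                throw std::runtime_error("unknown parameter: " + name);
            }
            skipped.push_back(name);
            continue;
        }
        params[name] = values;
    }
    return skipped;
}
```

Failing fast by default catches mismatched or mis-converted artifacts early; a lenient mode is mainly useful when a checkpoint carries extra buffers the network does not need.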
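| const GptConfig & Mila::Dnn::GptModel< TDeviceType, TPrecision >::getConfig () const | inline noexcept |
| TokenIndexType Mila::Dnn::GptModel< TDeviceType, TPrecision >::makeTokenTensor (const std::vector< int32_t > &token_ids) const | inline private |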
| int64_t Mila::Dnn::GptModel< TDeviceType, TPrecision >::maxSequenceLength () const | inline override protected virtual noexcept |
Maximum sequence length from GPT config.
Implements Mila::Dnn::LanguageModel< TDeviceType, TPrecision >.
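| void Mila::Dnn::GptModel< TDeviceType, TPrecision >::onGenerating (const std::vector< int32_t > &prompt_tokens, const std::function< void(int32_t)> &on_token, size_t max_new_tokens, float temperature, int top_k, std::stop_token stop) | inline override protected virtual |
Prefill + KV-cache decode loop with per-token streaming.
Implements Mila::Dnn::LanguageModel< TDeviceType, TPrecision >.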
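The control flow named above, one prefill pass over the prompt followed by one single-token decode step per generated token, can be sketched with a toy model. Everything here is an assumption for illustration: ToyLm, greedyPick and decodeLoop are hypothetical; greedy argmax is swapped in for temperature/top-k sampling; a std::function<bool()> stands in for std::stop_token so the sketch builds as C++17; and stop tokens are assumed to end generation without being passed to on_token.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for the network. prefill() runs the whole prompt once
// (filling a KV cache in a real model) and returns logits for the next token;
// decodeStep() feeds one new token and returns logits for the token after it.
struct ToyLm {
    virtual ~ToyLm() = default;
    virtual std::vector<float> prefill(const std::vector<int32_t>& prompt) = 0;
    virtual std::vector<float> decodeStep(int32_t token) = 0;
};

// Greedy argmax, used here in place of temperature/top-k sampling.
int32_t greedyPick(const std::vector<float>& logits) {
    assert(!logits.empty());
    return static_cast<int32_t>(
        std::max_element(logits.begin(), logits.end()) - logits.begin());
}

// One prefill pass, then one decode step per generated token.
// Returns the number of tokens passed to on_token.
size_t decodeLoop(ToyLm& lm, const std::vector<int32_t>& prompt,
                  const std::function<void(int32_t)>& on_token,
                  size_t max_new_tokens,
                  const std::unordered_set<int32_t>& stop_tokens,
                  const std::function<bool()>& stop_requested) {
    if (max_new_tokens == 0 || prompt.empty()) return 0;
    std::vector<float> logits = lm.prefill(prompt);
    size_t produced = 0;
    while (produced < max_new_tokens && !stop_requested()) {
        const int32_t next = greedyPick(logits);
        if (stop_tokens.count(next) != 0) break;  // e.g. GPT-2 end-of-text (50256)
        on_token(next);                           // streamed synchronously
        ++produced;
        if (produced == max_new_tokens) break;    // skip a forward pass whose logits go unused
        logits = lm.decodeStep(next);
    }
    return produced;
}
```

The point of the KV cache is visible in the interface: after prefill, each step processes only the newest token instead of re-running the full sequence.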
| void Mila::Dnn::GptModel< TDeviceType, TPrecision >::onTraining () | inline override protected virtual |
Training loop (not yet implemented for GptModel).
Exceptions
| std::runtime_error | Always thrown. |
Implements Mila::Dnn::Model< TDeviceType, TPrecision >.
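| GptModel & Mila::Dnn::GptModel< TDeviceType, TPrecision >::operator= (const GptModel &) | delete |
| GptModel & Mila::Dnn::GptModel< TDeviceType, TPrecision >::operator= (GptModel &&) | default |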
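| int32_t Mila::Dnn::GptModel< TDeviceType, TPrecision >::sampleFromLogits (const TensorType &logits, int64_t position, float temperature, int top_k, std::mt19937 &rng) const | inline private |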
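sampleFromLogits takes a position, which suggests the logits tensor holds one row of vocabulary scores per sequence position; after prefill, the next token would be drawn from the row of the last prompt position. A minimal sketch of selecting that row, assuming batch size 1 and a row-major [seq_len, vocab_size] float buffer; logitsRow is a hypothetical helper, not part of Mila.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: logits are laid out row-major as [seq_len, vocab_size];
// return a pointer to the vocab_size scores predicted at `position`.
const float* logitsRow(const std::vector<float>& logits, int64_t position,
                       size_t vocab_size) {
    assert(vocab_size > 0 && logits.size() % vocab_size == 0);
    const size_t seq_len = logits.size() / vocab_size;
    assert(position >= 0 && static_cast<size_t>(position) < seq_len);
    return logits.data() + static_cast<size_t>(position) * vocab_size;
}
```

The returned pointer and vocab_size are exactly the first two arguments that a raw-buffer sampler such as sampleToken (const float *logits, size_t vocab_size, ...) expects.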
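| int32_t Mila::Dnn::GptModel< TDeviceType, TPrecision >::sampleToken (const float *logits, size_t vocab_size, float temperature, int top_k, std::mt19937 &rng) | inline static private |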
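sampleToken is undocumented beyond its signature. The sketch below is one standard way to implement temperature plus top-k sampling with that signature; it is not Mila's code. Its conventions are assumptions: temperature <= 0 falls back to greedy argmax, and top_k <= 0 (or >= vocab_size) disables top-k filtering, which would be consistent with generate() defaulting to top_k=0 but is not confirmed by the source.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <random>
#include <vector>

// Hypothetical temperature + top-k sampler with the documented signature.
int32_t sampleToken(const float* logits, size_t vocab_size, float temperature,
                    int top_k, std::mt19937& rng) {
    assert(logits != nullptr && vocab_size > 0);
    if (temperature <= 0.0f) {  // assumed convention: greedy decoding
        return static_cast<int32_t>(
            std::max_element(logits, logits + vocab_size) - logits);
    }
    // Candidate token ids, optionally restricted to the k highest logits.
    std::vector<int32_t> ids(vocab_size);
    std::iota(ids.begin(), ids.end(), 0);
    size_t k = vocab_size;
    if (top_k > 0 && static_cast<size_t>(top_k) < vocab_size) {
        k = static_cast<size_t>(top_k);
        const auto mid = ids.begin() + static_cast<std::ptrdiff_t>(k);
        std::partial_sort(ids.begin(), mid, ids.end(),
                          [logits](int32_t a, int32_t b) { return logits[a] > logits[b]; });
        ids.resize(k);
    }
    // Softmax over the candidates at the given temperature, subtracting the
    // largest logit first so exp() cannot overflow.
    float max_logit = logits[ids[0]];
    for (int32_t id : ids) max_logit = std::max(max_logit, logits[id]);
    std::vector<double> weights(k);
    for (size_t i = 0; i < k; ++i) {
        weights[i] = std::exp((logits[ids[i]] - max_logit) / temperature);
    }
    std::discrete_distribution<int> pick(weights.begin(), weights.end());
    return ids[static_cast<size_t>(pick(rng))];
}
```

Lower temperatures sharpen the distribution toward the top logit; top_k = 1 makes sampling equivalent to greedy decoding.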
| std::string Mila::Dnn::GptModel< TDeviceType, TPrecision >::toString () const | inline override virtual |
Human-readable summary of this model's configuration.
Implements Mila::Dnn::Model< TDeviceType, TPrecision >.
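| void Mila::Dnn::GptModel< TDeviceType, TPrecision >::truncateIfNeeded (std::vector< int32_t > &tokens) const | inline private |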
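truncateIfNeeded is undocumented; its name and the presence of maxSequenceLength() suggest it keeps a token sequence within the context window. The sketch below shows one common policy, keeping only the most recent tokens. Whether Mila drops from the front, and whether it reserves room for max_new_tokens, is an assumption not confirmed by the source; truncateToContext is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical sketch: if the sequence is longer than the context window,
// erase the oldest tokens so that only the most recent max_len remain.
void truncateToContext(std::vector<int32_t>& tokens, int64_t max_len) {
    assert(max_len > 0);
    if (static_cast<int64_t>(tokens.size()) > max_len) {
        tokens.erase(tokens.begin(),
                     tokens.end() - static_cast<std::ptrdiff_t>(max_len));
    }
}
```

Dropping from the front preserves the tokens closest to the generation point, which usually matter most for predicting the next token.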
| int64_t Mila::Dnn::GptModel< TDeviceType, TPrecision >::vocabSize () const | inline override protected virtual noexcept |
Vocabulary size from GPT config.
Implements Mila::Dnn::LanguageModel< TDeviceType, TPrecision >.
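Member Data Documentation
| GptConfig Mila::Dnn::GptModel< TDeviceType, TPrecision >::config_ | private |
| constexpr int32_t Mila::Dnn::GptModel< TDeviceType, TPrecision >::eos_token_ = 50256 | static constexpr private |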