Mila 0.13.48
Deep Neural Network Library
GPT-2 style transformer (decoder-only) for autoregressive token prediction.


Public Types | |
| using | ComponentPtr = typename NetworkBase::ComponentPtr |
| using | EncoderType = Lpe<TDeviceType, dtype_t::INT32, TPrecision> |
| using | LayerNormType = LayerNorm<TDeviceType, TPrecision> |
| using | LinearType = Linear<TDeviceType, TPrecision> |
| using | MR = typename DeviceTypeTraits<TDeviceType>::memory_resource |
| using | NetworkBase = LanguageNetwork<TDeviceType, TPrecision> |
| using | TensorType = Tensor<TPrecision, MR> |
| using | TokenIndexType = Tensor<dtype_t::INT32, MR> |
| using | TransformerBlockType = GptBlock<TDeviceType, TPrecision> |
| Public Types inherited from Mila::Dnn::LanguageNetwork< TDeviceType, TPrecision > | |
| using | MR = typename DeviceTypeTraits<TDeviceType>::memory_resource |
| using | NetworkBase = Network<TDeviceType, TPrecision> |
| using | TensorType = Tensor<TPrecision, MR> |
| using | TokenIndexType = Tensor<TensorDataType::INT32, MR> |
| Public Types inherited from Mila::Dnn::Network< TDeviceType, TPrecision > | |
| using | ComponentPtr = typename CompositeBase::ComponentPtr |
| using | CompositeBase = CompositeComponent<TDeviceType, TPrecision> |
| using | MR = typename DeviceTypeTraits<TDeviceType>::memory_resource |
| Public Types inherited from Mila::Dnn::CompositeComponent< TDeviceType, TPrecision > | |
| using | ComponentBase = Component<TDeviceType, TPrecision> |
| using | ComponentPtr = std::shared_ptr<Component<TDeviceType, TPrecision>> |
Public Member Functions | |
| GptTransformer (const std::string &name, const GptConfig &config, DeviceId device_id) | |
| Construct a GPT-style transformer. | |
| ~GptTransformer () override=default | |
| TokenIndexType & | backward (const TokenIndexType &input, const TensorType &output_grad) override |
| TensorType & | decode (const TokenIndexType &input, int position) override |
| Inference-only single-token decode pass. | |
| TensorType & | forward (const TokenIndexType &input) override |
| Full-sequence forward pass. | |
| IExecutionContext * | getExecutionContext () const |
| MemoryStats | getMemoryStats () const override |
| Return the current memory allocation breakdown for this component. | |
| const ComponentType | getType () const override |
| Get the component type identifier. | |
| void | loadParameters (PretrainedModelReader &reader, bool strict) |
| Initialize this transformer's components from a GPT-2 checkpoint. | |
| TensorType & | prefill (const TokenIndexType &input) override |
| Inference prefill — process full prompt and return last-token logits. | |
| std::string | toString () const override |
| Generate a human-readable description. | |
| void | zeroGradients () override |
| Clear all model-owned gradients for this component. | |
| Public Member Functions inherited from Mila::Dnn::LanguageNetwork< TDeviceType, TPrecision > | |
| LanguageNetwork (const std::string &name) | |
| ~LanguageNetwork () override=default | |
| virtual TokenIndexType & | backward (const TokenIndexType &input, const TensorType &output_grad)=0 |
| Full backward pass (training). | |
| virtual TensorType & | decode (const TokenIndexType &input, int position)=0 |
| Inference decode — single-token autoregressive step. | |
| virtual TensorType & | forward (const TokenIndexType &input)=0 |
| Full-sequence forward pass. | |
| virtual TensorType & | prefill (const TokenIndexType &input)=0 |
| Inference prefill — process full prompt and populate the KV cache. | |
| Public Member Functions inherited from Mila::Dnn::Network< TDeviceType, TPrecision > | |
| Network (const std::string &name) | |
| Construct network (context managed by derived class). | |
| ~Network () override=default | |
| template<typename TOptimizer, typename TConfig> | |
| std::shared_ptr< TOptimizer > | createOptimizer (const TConfig &config) |
| Create and configure an optimizer for this network's parameters. | |
| DeviceId | getDeviceId () const noexcept |
| Get the compute device for this composite. | |
| const ComponentType | getType () const override |
| Get the component type identifier. | |
| void | save (ModelArchive &archive, SerializationMode mode) const |
| Save network to archive. | |
| void | synchronize () override |
| Synchronize all child components. | |
| std::string | toString () const override |
| Generate a human-readable description. | |
| Public Member Functions inherited from Mila::Dnn::CompositeComponent< TDeviceType, TPrecision > | |
| CompositeComponent (CompositeComponent &&) noexcept=default | |
| CompositeComponent (const CompositeComponent &)=delete | |
| CompositeComponent (const std::string &name) | |
| Construct composite component with name. | |
| virtual | ~CompositeComponent ()=default |
| CompositeComponent & | addComponent (ComponentPtr component) |
| Add a pre-constructed child component (chainable). | |
| size_t | childCount () const noexcept |
| Get the number of direct children. | |
| void | clearComponents () |
| Clear all child components. | |
| ComponentPtr | findComponent (const std::string &path) const |
| Resolve a dot-separated component path within this composite. | |
| ComponentPtr | getComponent (const std::string &name) const |
| Retrieve a direct child component by name. | |
| const std::vector< ComponentPtr > & | getComponents () const |
| Get all child components in insertion order. | |
| std::vector< ITensor * > | getGradients () const override |
| Get all parameter gradients from all children. | |
| std::vector< ITensor * > | getParameters () const override |
| Get all parameters from all children. | |
| bool | hasChildren () const noexcept |
| Check if this composite has any children. | |
| bool | hasComponent (const std::string &name) const |
| Check if a named child component exists. | |
| CompositeComponent & | operator= (CompositeComponent &&) noexcept=default |
| CompositeComponent & | operator= (const CompositeComponent &)=delete |
| size_t | parameterCount () const override |
| Count parameters across all children. | |
| bool | removeComponent (const std::string &name) |
| Remove a direct child component by name. | |
| ComponentPtr | tryFindComponent (const std::string &path) const |
| Try to resolve a dot-separated component path within this composite. | |
| Public Member Functions inherited from Mila::Dnn::Component< TDeviceType, TPrecision > | |
| Component (const std::string &name) | |
| Construct component with required name identifier. | |
| virtual | ~Component ()=default |
| virtual void | build (const BuildContext &context) final |
| Build the component with the provided BuildContext (canonical overload). | |
| const std::string | getName () const |
| Get the component's name identifier. | |
| virtual std::vector< std::string > | getParameterNames () const |
| List all available parameter names for this component. | |
| RuntimeMode | getRuntimeMode () const noexcept |
| Get the component's runtime mode (Inference or Training). | |
| TrainingMode | getTrainingMode () const noexcept |
| The current runtime behavioral mode of this Component. | |
| virtual bool | isBuilt () const final |
| Returns true if build() has completed successfully. | |
| bool | isInferenceMode () const noexcept |
| bool | isTrainingMode () const noexcept |
| virtual void | loadParameter (const std::string &name, const Serialization::ITensorBlob &blob) |
| Load a parameter from serialized tensor data. | |
| void | setTrainingMode (TrainingMode mode) |
| Set the runtime behavioral mode for this Component. | |
Static Public Member Functions | |
| static std::unique_ptr< GptTransformer< TDeviceType, TPrecision > > | fromPretrained (const std::filesystem::path &model_path, std::size_t batch_size, std::size_t seq_length, DeviceId device_id=DeviceId{ TDeviceType, 0 }, bool strict=true) |
| Static Public Member Functions inherited from Mila::Dnn::Component< TDeviceType, TPrecision > | |
| static constexpr DeviceType | getDeviceType () |
| Compile-time device type for this component instance. | |
| static constexpr TensorDataType | getPrecision () noexcept |
| Compile-time tensor precision for this component instance. | |
Protected Member Functions | |
| void | onBuilding (const BuildContext &context) override |
| Hook invoked by build() to allocate component buffers. | |
| void | onTrainingModeChanging (TrainingMode training_mode) override |
| Hook invoked when training mode is about to change. | |
| void | save_ (ModelArchive &archive, SerializationMode) const override |
| Hook for concrete classes to save type-specific state. | |
| Protected Member Functions inherited from Mila::Dnn::Network< TDeviceType, TPrecision > | |
| void | verifyArchitectureCompatibility (const PretrainedMetadata &metadata) |
| Verify that imported model is compatible with network architecture. | |
| Protected Member Functions inherited from Mila::Dnn::CompositeComponent< TDeviceType, TPrecision > | |
| template<typename TComponent> | |
| std::shared_ptr< TComponent > | getComponentAs (const std::string &name) const |
| Retrieve a typed child component by name. | |
| void | onExecutionContextSet () override |
| Hook invoked after ExecutionContext is set. | |
| virtual void | optimize () |
| Virtual hook for graph optimization after construction. | |
| Protected Member Functions inherited from Mila::Dnn::Component< TDeviceType, TPrecision > | |
| IExecutionContext * | getExecutionContext () const |
| Get the shared execution context. | |
| bool | hasExecutionContext () const noexcept |
| Check if execution context has been set. | |
| template<TensorDataType TParameterPrecision, typename TMemoryResource> | |
| void | loadParameterFromBlob (const std::string &param_name, const Serialization::ITensorBlob &blob, Tensor< TParameterPrecision, TMemoryResource > &target, const shape_t &expected_shape) |
| Load a tensor blob into a parameter tensor with validation. | |
| void | setExecutionContext (IExecutionContext *context) |
| Set the execution context for this component. | |
Private Member Functions | |
| void | createGraph () |
| std::pair< std::string, std::string > | parseParameterPath (const std::string &full_name) const |
| void | validateBuildContext (const BuildContext &context) const |
| void | validateInputShape (const shape_t &input_shape) const |
Static Private Member Functions | |
| static auto | createConfigFromMetadata (const PretrainedMetadata &metadata) -> GptConfig |
| Create GptConfig from Mila metadata. | |
Private Attributes | |
| int64_t | batch_size_ { 0 } |
| std::vector< TensorType * > | block_input_ptrs_ |
| std::vector< TensorType * > | block_output_ptrs_ |
| GptConfig | config_ |
| shape_t | embedding_shape_ |
| std::shared_ptr< EncoderType > | encoder_ { nullptr } |
| TensorType * | encoder_out_ptr_ { nullptr } |
| std::shared_ptr< LayerNormType > | final_layernorm_ { nullptr } |
| shape_t | leading_shape_ |
| std::shared_ptr< LinearType > | lm_head_ { nullptr } |
| TensorType * | logits_ptr_ { nullptr } |
| TensorType * | normalized_ptr_ { nullptr } |
| shape_t | output_shape_ |
| std::unique_ptr< IExecutionContext > | owned_context_ { nullptr } |
| int64_t | seq_length_ { 0 } |
| std::vector< std::shared_ptr< TransformerBlockType > > | transformer_blocks_ |
Additional Inherited Members | |
| Protected Attributes inherited from Mila::Dnn::Component< TDeviceType, TPrecision > | |
| BuildContext | build_context_ { shape_t{ 1 }, RuntimeMode::Training } |
| The BuildContext stored at build time. | |
GPT-2 style transformer (decoder-only) for autoregressive token prediction.
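Template parameters:
| TDeviceType | Compile-time compute device type. |
| TPrecision | Compile-time tensor precision (TensorDataType) for this component's tensors. |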
GptTransformer (const std::string &name, const GptConfig &config, DeviceId device_id) [inline, explicit, export]
~GptTransformer () [override, export, default]
TokenIndexType & backward (const TokenIndexType &input, const TensorType &output_grad) [inline, override, export]


static auto createConfigFromMetadata (const PretrainedMetadata &metadata) -> GptConfig [inline, static, export, private]
void createGraph () [inline, export, private]


TensorType & decode (const TokenIndexType &input, int position) [inline, override, export]
Inference-only single-token decode pass.
Mirrors forward() exactly except each transformer block is driven via decode() rather than forward(). Each block's decode() delegates to attn_->decode() for the attention step — Attention decides internally whether to use the fast KV cache path or fall back to forward(). All other components in each block use forward() unchanged.
The encoder (token + position embeddings) and final LayerNorm + LM head are identical to forward() — only the block traversal differs.
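As a rough illustration of what the shared encoder step computes, the sketch below sums a learned token-embedding row and a learned position-embedding row for each token, assuming GPT-2's learned absolute position embeddings. The function `embedTokens` and the tables `wte` and `wpe` are hypothetical names and do not reflect Mila's Lpe implementation; the point is only that a decode step embeds one token at row `position`, while a full pass starts at position 0.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical sketch (not Mila's Lpe API): learned token + position embeddings.
// wte is [vocab, C] and wpe is [max_seq, C], both stored row-major.
// Token i of the input sits at sequence position (start_pos + i) and maps to
// wte[token] + wpe[start_pos + i].
std::vector<float> embedTokens(const std::vector<int>& tokens,
                               std::size_t start_pos,
                               const std::vector<float>& wte,
                               const std::vector<float>& wpe,
                               std::size_t channels) {
    std::vector<float> out(tokens.size() * channels);
    for (std::size_t i = 0; i < tokens.size(); ++i) {
        const std::size_t tok = static_cast<std::size_t>(tokens[i]);
        const std::size_t pos = start_pos + i;
        assert((tok + 1) * channels <= wte.size());  // token id within vocabulary
        assert((pos + 1) * channels <= wpe.size());  // position within max sequence
        for (std::size_t c = 0; c < channels; ++c) {
            out[i * channels + c] = wte[tok * channels + c] + wpe[pos * channels + c];
        }
    }
    return out;
}
```

A full pass would call `embedTokens(prompt, 0, ...)`, while a single decode step would call `embedTokens({next_token}, position, ...)`; because positions are plain table lookups, no rotary offset bookkeeping is required.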
Precondition: forward() must have been called at least once (prefill) before decode() is called. Attention internally manages cache state — no explicit initializeKVCache / resetKVCache needed here.
Calling forward() again after decode() steps automatically resets the KV cache and begins a new prefill session.
Parameters:
| input | Single-token input [B, 1] token indices. |
| position | Current sequence position (0-based). |
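The prefill/decode contract described above can be sketched with a toy model. `ToyLm` and `generateGreedy` are hypothetical stand-ins, not Mila's API: the "KV cache" is simply the list of tokens seen so far, the scoring rule is arbitrary, and only the call pattern is meant to mirror the documentation (one full-prompt pass, then one single-token step per generated token at the next 0-based position, with a new prefill resetting the cache).

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for a decoder-only LM; illustrates the call contract only.
class ToyLm {
public:
    explicit ToyLm(int vocab) : vocab_(vocab) {}

    // Processes the whole prompt, restarting the cache, and returns logits
    // for the last prompt token only.
    std::vector<float> prefill(const std::vector<int>& prompt) {
        cache_.assign(prompt.begin(), prompt.end());  // a new prefill resets the cache
        return logits();
    }

    // Single-token step: requires a prior prefill() and the next position.
    std::vector<float> decode(int token, int position) {
        if (cache_.empty()) throw std::logic_error("decode() before prefill()");
        if (position != static_cast<int>(cache_.size()))
            throw std::invalid_argument("unexpected position");
        cache_.push_back(token);  // append this step to the "KV cache"
        return logits();
    }

private:
    // Toy scoring rule (assumes non-negative token ids): favour (sum of cached tokens) % vocab.
    std::vector<float> logits() const {
        int sum = 0;
        for (int t : cache_) sum += t;
        std::vector<float> out(static_cast<std::size_t>(vocab_), 0.0f);
        out[static_cast<std::size_t>(sum % vocab_)] = 1.0f;
        return out;
    }

    int vocab_;
    std::vector<int> cache_;
};

// Greedy generation: one prefill over the prompt, then one decode per extra token.
std::vector<int> generateGreedy(ToyLm& lm, const std::vector<int>& prompt, int steps) {
    assert(!prompt.empty() && "prefill needs at least one token");
    auto argmax = [](const std::vector<float>& v) {
        return static_cast<int>(std::max_element(v.begin(), v.end()) - v.begin());
    };
    std::vector<int> generated;
    if (steps <= 0) return generated;
    int next = argmax(lm.prefill(prompt));
    generated.push_back(next);
    // `position` is the 0-based slot of the token being fed to decode().
    for (int position = static_cast<int>(prompt.size());
         static_cast<int>(generated.size()) < steps; ++position) {
        next = argmax(lm.decode(next, position));
        generated.push_back(next);
    }
    return generated;
}
```

With a vocabulary of 10 and prompt `{1, 2}`, three greedy steps produce `{3, 6, 2}` under this toy rule.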

TensorType & forward (const TokenIndexType &input) [inline, override, export]
Full-sequence forward pass.
Runs the encoder (token + position embeddings), every transformer block via forward(), the final LayerNorm, and the LM head over the full [B, T] input.

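static std::unique_ptr< GptTransformer< TDeviceType, TPrecision > > fromPretrained (const std::filesystem::path &model_path, std::size_t batch_size, std::size_t seq_length, DeviceId device_id=DeviceId{ TDeviceType, 0 }, bool strict=true) [inline, static, export]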
IExecutionContext * getExecutionContext () const [inline, export]


MemoryStats getMemoryStats () const [inline, override, export, virtual]
Return the current memory allocation breakdown for this component.
Reflects allocations at the moment of the call. The returned stats naturally track the component lifecycle:
- After construction: parameters only.
- After build( Inference ): parameters + T=1 state buffers.
- After build( Training ): parameters + full-T state buffers.
- After setEvaluation( false ): parameters + state + gradients.
For CompositeComponent and Network, the returned stats are the recursive aggregate of all child components.
May be called at any time — no lifecycle preconditions.
Implements Mila::Dnn::Component< TDeviceType, TPrecision >.
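The recursive aggregation for composites can be sketched as a tree walk. The `MemoryStats` fields below (`parameter_bytes`, `state_bytes`, `gradient_bytes`) and the `Node` type are hypothetical, since the real structure's members are not documented here; the sketch only shows a composite reporting its own allocations plus the sum over its children.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

// Hypothetical breakdown; the real MemoryStats fields may differ.
struct MemoryStats {
    std::size_t parameter_bytes = 0;
    std::size_t state_bytes = 0;
    std::size_t gradient_bytes = 0;

    MemoryStats& operator+=(const MemoryStats& other) {
        parameter_bytes += other.parameter_bytes;
        state_bytes += other.state_bytes;
        gradient_bytes += other.gradient_bytes;
        return *this;
    }
    std::size_t total() const { return parameter_bytes + state_bytes + gradient_bytes; }
};

// Minimal component tree: leaves report their own allocations; composites
// report what they own directly plus the recursive sum over their children.
struct Node {
    MemoryStats own;
    std::vector<std::shared_ptr<Node>> children;

    MemoryStats getMemoryStats() const {
        MemoryStats total = own;
        for (const auto& child : children) total += child->getMemoryStats();
        return total;
    }
};
```

Because the walk reads current allocations on every call, it naturally reflects whichever lifecycle stage the tree is in.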

const ComponentType getType () const [inline, override, export, virtual]
Get the component type identifier.
Used for serialization and runtime type identification.
Implements Mila::Dnn::Component< TDeviceType, TPrecision >.
void loadParameters (PretrainedModelReader &reader, bool strict) [inline, export]
Initialize this transformer's components from a GPT-2 checkpoint.
Loads parameters (weights and biases) from an already-opened PretrainedModelReader by delegating to small helpers that read checkpoint blobs and apply them to the encoder, the per-layer transformer blocks, and the final LayerNorm.
It is kept separate from fromPretrained() to allow flexibility in how weights are loaded.
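One way to picture the dispatch step is to split each checkpoint name into a component path and a parameter name, then route the path to the encoder, a numbered block, or the final LayerNorm. This is a sketch under stated assumptions: `splitParameterPath`, `routeComponent`, and `Route` are hypothetical names, the split only mirrors the `std::pair< std::string, std::string >` shape of parseParameterPath(), and the Hugging Face style GPT-2 names (`wte`, `wpe`, `h.<i>.*`, `ln_f`) are assumed rather than taken from Mila's checkpoint format.

```cpp
#include <cassert>
#include <string>
#include <utility>

// Hypothetical: split "h.3.attn.c_attn.weight" into {"h.3.attn.c_attn", "weight"}.
std::pair<std::string, std::string> splitParameterPath(const std::string& full_name) {
    const auto dot = full_name.rfind('.');
    if (dot == std::string::npos) return {"", full_name};
    return {full_name.substr(0, dot), full_name.substr(dot + 1)};
}

enum class Target { Encoder, Block, FinalLayerNorm, Unknown };

struct Route {
    Target target;
    int block_index;  // meaningful only when target == Target::Block
};

// Route a component path to its owner, assuming Hugging Face style GPT-2 names.
Route routeComponent(const std::string& component_path) {
    if (component_path == "wte" || component_path == "wpe") return {Target::Encoder, -1};
    if (component_path == "ln_f") return {Target::FinalLayerNorm, -1};
    if (component_path.rfind("h.", 0) == 0) {  // starts with "h."
        const auto end = component_path.find('.', 2);
        const std::string index = (end == std::string::npos)
                                      ? component_path.substr(2)
                                      : component_path.substr(2, end - 2);
        if (!index.empty() && index.find_first_not_of("0123456789") == std::string::npos)
            return {Target::Block, std::stoi(index)};
    }
    return {Target::Unknown, -1};
}
```

A strict load could then reject any name that routes to `Target::Unknown`, while a lenient load might skip it.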

void onBuilding (const BuildContext &context) [inline, override, export, protected, virtual]
Hook invoked by build() to allocate component buffers.
Receives the stored BuildContext. Implementations must use context.allocationSeqLen() when sizing output buffers; this single call makes Inference and Training builds allocate correctly sized buffers automatically, without per-component logic.
The default implementation forwards to the legacy onBuilding( const shape_t& ) overload for backwards compatibility. New components should override this overload directly.
Parameters:
| context | Build-time context. Use context.allocationSeqLen() to obtain the correct output buffer sequence dimension. |
Reimplemented from Mila::Dnn::Component< TDeviceType, TPrecision >.
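The sizing rule can be sketched as follows, assuming (as the getMemoryStats() notes suggest) that an Inference build sizes state buffers for T=1 and a Training build sizes them for the full sequence. `BuildContextSketch`, `RuntimeModeSketch`, and `allocateOutput` are hypothetical stand-ins; Mila's actual BuildContext holds a shape and a RuntimeMode rather than these fields.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

enum class RuntimeModeSketch { Inference, Training };

// Hypothetical stand-in for BuildContext; not Mila's actual type.
struct BuildContextSketch {
    std::size_t batch_size;
    std::size_t seq_length;
    RuntimeModeSketch mode;

    // Assumed rule: Inference buffers hold one token per row,
    // Training buffers hold the full sequence.
    std::size_t allocationSeqLen() const {
        return mode == RuntimeModeSketch::Inference ? 1 : seq_length;
    }
};

// An onBuilding-style hook: size the [B, T_alloc, C] output buffer from
// allocationSeqLen() instead of branching on the mode itself.
std::vector<float> allocateOutput(const BuildContextSketch& context, std::size_t channels) {
    return std::vector<float>(context.batch_size * context.allocationSeqLen() * channels);
}
```

Centralizing the decision in one accessor keeps each component's allocation code identical across modes.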

void onTrainingModeChanging (TrainingMode training_mode) [inline, override, export, protected, virtual]
Hook invoked when training mode is about to change.
Propagates the new mode to all child components. The hook runs with the Component's training mutex held; it MUST NOT call setTrainingMode() on this component.
Parameters:
| training_mode | New training mode to apply. |
Reimplemented from Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >.
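The locking rule can be illustrated with a small sketch. `ComponentSketch`, `CompositeSketch`, and `TrainingModeSketch` are hypothetical types, not Mila's: the setter locks a non-recursive `std::mutex` and then runs the hook, so the hook may call setTrainingMode() on children (each child locks its own mutex) but re-entering the setter on the same object would lock the same mutex twice.

```cpp
#include <cassert>
#include <memory>
#include <mutex>
#include <utility>
#include <vector>

enum class TrainingModeSketch { Training, Eval };

// Hypothetical base: std::mutex is not recursive, so locking it again from the
// same thread inside the hook would be undefined behaviour (usually a deadlock).
class ComponentSketch {
public:
    virtual ~ComponentSketch() = default;

    void setTrainingMode(TrainingModeSketch mode) {
        std::lock_guard<std::mutex> lock(mutex_);
        onTrainingModeChanging(mode);  // hook runs with mutex_ held
        mode_ = mode;
    }

    TrainingModeSketch trainingMode() const {
        std::lock_guard<std::mutex> lock(mutex_);
        return mode_;
    }

protected:
    virtual void onTrainingModeChanging(TrainingModeSketch) {}

private:
    mutable std::mutex mutex_;
    TrainingModeSketch mode_ = TrainingModeSketch::Eval;
};

class CompositeSketch : public ComponentSketch {
public:
    void add(std::shared_ptr<ComponentSketch> child) { children_.push_back(std::move(child)); }

protected:
    // Safe: each child locks its own mutex, never this component's.
    void onTrainingModeChanging(TrainingModeSketch mode) override {
        for (auto& child : children_) child->setTrainingMode(mode);
    }

private:
    std::vector<std::shared_ptr<ComponentSketch>> children_;
};
```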

std::pair< std::string, std::string > parseParameterPath (const std::string &full_name) const [inline, export, private]

TensorType & prefill (const TokenIndexType &input) [inline, override, export]
Inference prefill — process full prompt and return last-token logits.
Populates the KV cache across all transformer blocks by running the full prompt through encoder + blocks via forward(). Then extracts only the last token's representation for the final LayerNorm + LM head, avoiding the T=1 output buffer overflow that forward() would cause on those components.
Unlike LlamaTransformer::prefill(), GPT does not need chunked prefill or explicit position offsets (no RoPE). The full sequence is processed in a single pass.
Parameters:
| input | Full prompt token indices [B, T]. |
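The last-token extraction step amounts to a strided copy from a [B, T, C] buffer into a [B, 1, C] buffer. The helper below is a hypothetical host-side sketch (row-major layout assumed), not the library's implementation, which operates on device tensors.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: copy the last time step of a row-major [B, T, C]
// hidden-state buffer into a [B, 1, C] buffer, so only B rows reach the
// final LayerNorm and LM head.
std::vector<float> sliceLastToken(const std::vector<float>& hidden,
                                  std::size_t batch, std::size_t seq_len,
                                  std::size_t channels) {
    assert(seq_len > 0 && hidden.size() == batch * seq_len * channels);
    std::vector<float> last(batch * channels);
    for (std::size_t b = 0; b < batch; ++b) {
        // Offset of row (b, T - 1) in the flattened [B, T, C] buffer.
        const std::size_t src = (b * seq_len + (seq_len - 1)) * channels;
        for (std::size_t c = 0; c < channels; ++c) {
            last[b * channels + c] = hidden[src + c];
        }
    }
    return last;
}
```

For B=2, T=3, C=2 and a buffer holding 0 through 11, the result is `{4, 5, 10, 11}`: one C-wide row per batch element, which fits a T=1 output buffer.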

void save_ (ModelArchive &archive, SerializationMode mode) const [inline, override, export, protected, virtual]
Hook for concrete classes to save type-specific state.
REQUIRED override for concrete networks. Must write:
This metadata enables the concrete class's Load() method to reconstruct the network.
Example implementation:
Parameters:
| archive | Archive to write to. |
| mode | Serialization mode (passed from save()) |
Implements Mila::Dnn::Network< TDeviceType, TPrecision >.

std::string toString () const [inline, override, export, virtual]
Generate a human-readable description.
Reimplemented from Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >.

void validateBuildContext (const BuildContext &context) const [inline, export, private]


void validateInputShape (const shape_t &input_shape) const [inline, export, private]

void zeroGradients () [inline, override, export, virtual]
Clear all model-owned gradients for this component.
Default implementation is a no-op. Composite components should override to recurse to children. Leaf components should override to zero their parameter and activation gradients using device-aware helpers.
Reimplemented from Mila::Dnn::Component< TDeviceType, TPrecision >.
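The three roles described above (no-op default, recursing composite, zeroing leaf) can be sketched with a hypothetical hierarchy. `Module`, `LeafModule`, and `CompositeModule` are illustrative names only, and the leaf clears a host `std::vector` where a real leaf would use a device-aware helper.

```cpp
#include <algorithm>
#include <cassert>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical hierarchy illustrating the default / leaf / composite split.
struct Module {
    virtual ~Module() = default;
    virtual void zeroGradients() {}  // default: nothing to clear
};

// Leaf: owns a gradient buffer and clears it in place (host memory here).
struct LeafModule : Module {
    std::vector<float> weight_grad;
    explicit LeafModule(std::vector<float> g) : weight_grad(std::move(g)) {}
    void zeroGradients() override {
        std::fill(weight_grad.begin(), weight_grad.end(), 0.0f);
    }
};

// Composite: owns no gradients itself and recurses into its children.
struct CompositeModule : Module {
    std::vector<std::shared_ptr<Module>> children;
    void zeroGradients() override {
        for (auto& child : children) child->zeroGradients();
    }
};
```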
