Mila 0.13.48
Deep Neural Network Library
Mila::Dnn::GptTransformer< TDeviceType, TPrecision > Class Template Reference [export]

GPT-2 style transformer (decoder-only) for autoregressive token prediction.


Public Types

using ComponentPtr = typename NetworkBase::ComponentPtr
using EncoderType = Lpe<TDeviceType, dtype_t::INT32, TPrecision>
using LayerNormType = LayerNorm<TDeviceType, TPrecision>
using LinearType = Linear<TDeviceType, TPrecision>
using MR = typename DeviceTypeTraits<TDeviceType>::memory_resource
using NetworkBase = LanguageNetwork<TDeviceType, TPrecision>
using TensorType = Tensor<TPrecision, MR>
using TokenIndexType = Tensor<dtype_t::INT32, MR>
using TransformerBlockType = GptBlock<TDeviceType, TPrecision>
Public Types inherited from Mila::Dnn::LanguageNetwork< TDeviceType, TPrecision >
using MR = typename DeviceTypeTraits<TDeviceType>::memory_resource
using NetworkBase = Network<TDeviceType, TPrecision>
using TensorType = Tensor<TPrecision, MR>
using TokenIndexType = Tensor<TensorDataType::INT32, MR>
Public Types inherited from Mila::Dnn::Network< TDeviceType, TPrecision >
using ComponentPtr = typename CompositeBase::ComponentPtr
using CompositeBase = CompositeComponent<TDeviceType, TPrecision>
using MR = typename DeviceTypeTraits<TDeviceType>::memory_resource
Public Types inherited from Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >
using ComponentBase = Component<TDeviceType, TPrecision>
using ComponentPtr = std::shared_ptr<Component<TDeviceType, TPrecision>>

Public Member Functions

 GptTransformer (const std::string &name, const GptConfig &config, DeviceId device_id)
 Construct a GPT-style transformer.
 ~GptTransformer () override=default
TokenIndexType & backward (const TokenIndexType &input, const TensorType &output_grad) override
TensorType & decode (const TokenIndexType &input, int position) override
 Inference-only single-token decode pass.
TensorType & forward (const TokenIndexType &input) override
 Full-sequence forward pass.
IExecutionContext * getExecutionContext () const
MemoryStats getMemoryStats () const override
 Return the current memory allocation breakdown for this component.
const ComponentType getType () const override
 Get the component type identifier.
void loadParameters (PretrainedModelReader &reader, bool strict)
 Initialize this transformer's components from a GPT-2 checkpoint.
TensorType & prefill (const TokenIndexType &input) override
 Inference prefill — process full prompt and return last-token logits.
std::string toString () const override
 Generate a human-readable description.
void zeroGradients () override
 Clear all model-owned gradients for this component.
Public Member Functions inherited from Mila::Dnn::LanguageNetwork< TDeviceType, TPrecision >
 LanguageNetwork (const std::string &name)
 ~LanguageNetwork () override=default
virtual TokenIndexType & backward (const TokenIndexType &input, const TensorType &output_grad)=0
 Full backward pass (training).
virtual TensorType & decode (const TokenIndexType &input, int position)=0
 Inference decode: single-token autoregressive step.
virtual TensorType & forward (const TokenIndexType &input)=0
 Full-sequence forward pass.
virtual TensorType & prefill (const TokenIndexType &input)=0
 Inference prefill — process full prompt and populate the KV cache.
Public Member Functions inherited from Mila::Dnn::Network< TDeviceType, TPrecision >
 Network (const std::string &name)
 Construct network (context managed by derived class).
 ~Network () override=default
template<typename TOptimizer, typename TConfig>
std::shared_ptr< TOptimizer > createOptimizer (const TConfig &config)
 Create and configure an optimizer for this network's parameters.
DeviceId getDeviceId () const noexcept
 Get the compute device for this composite.
const ComponentType getType () const override
 Get the component type identifier.
void save (ModelArchive &archive, SerializationMode mode) const
 Save network to archive.
void synchronize () override
 Synchronize all child components.
std::string toString () const override
 Generate a human-readable description.
Public Member Functions inherited from Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >
 CompositeComponent (CompositeComponent &&) noexcept=default
 CompositeComponent (const CompositeComponent &)=delete
 CompositeComponent (const std::string &name)
 Construct composite component with name.
virtual ~CompositeComponent ()=default
CompositeComponent & addComponent (ComponentPtr component)
 Add a pre-constructed child component (chainable).
size_t childCount () const noexcept
 Get the number of direct children.
void clearComponents ()
 Clear all child components.
ComponentPtr findComponent (const std::string &path) const
 Resolve a dot-separated component path within this composite.
ComponentPtr getComponent (const std::string &name) const
 Retrieve a direct child component by name.
const std::vector< ComponentPtr > & getComponents () const
 Get all child components in insertion order.
std::vector< ITensor * > getGradients () const override
 Get all parameter gradients from all children.
std::vector< ITensor * > getParameters () const override
 Get all parameters from all children.
bool hasChildren () const noexcept
 Check if this composite has any children.
bool hasComponent (const std::string &name) const
 Check if a named child component exists.
CompositeComponent & operator= (CompositeComponent &&) noexcept=default
CompositeComponent & operator= (const CompositeComponent &)=delete
size_t parameterCount () const override
 Count parameters across all children.
bool removeComponent (const std::string &name)
 Remove a named direct child component.
ComponentPtr tryFindComponent (const std::string &path) const
 Try to resolve a dot-separated component path within this composite.
Public Member Functions inherited from Mila::Dnn::Component< TDeviceType, TPrecision >
 Component (const std::string &name)
 Construct component with required name identifier.
virtual ~Component ()=default
virtual void build (const BuildContext &context) final
 Build the component with the provided BuildContext (canonical overload).
const std::string getName () const
 Get the component's name identifier.
virtual std::vector< std::string > getParameterNames () const
 List all available parameter names for this component.
RuntimeMode getRuntimeMode () const noexcept
 Get the current RuntimeMode of this Component.
TrainingMode getTrainingMode () const noexcept
 The current runtime behavioral mode of this Component.
virtual bool isBuilt () const final
 Returns true if build() has completed successfully.
bool isInferenceMode () const noexcept
bool isTrainingMode () const noexcept
virtual void loadParameter (const std::string &name, const Serialization::ITensorBlob &blob)
 Load a parameter from serialized tensor data.
void setTrainingMode (TrainingMode mode)
 Set the runtime behavioral mode for this Component.
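findComponent() and tryFindComponent() resolve dot-separated paths within a composite. As a rough sketch only, such a walk can proceed one segment at a time through name-keyed children. NamedNode, tryResolve() and resolveOrThrow() are hypothetical names, and the split between a nullptr-returning "try" variant and a throwing variant is an assumption based on the member names, not documented behavior.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>

// Hypothetical component-tree node; children are keyed by name.
struct NamedNode {
    std::map<std::string, std::shared_ptr<NamedNode>> children;
};

// Resolve "a.b.c" one segment at a time; nullptr if any segment is missing.
std::shared_ptr<NamedNode> tryResolve(const std::shared_ptr<NamedNode>& root,
                                      const std::string& path) {
    std::shared_ptr<NamedNode> node = root;
    std::size_t start = 0;
    while (node) {
        const std::size_t dot = path.find('.', start);
        const std::string segment =
            path.substr(start, dot == std::string::npos ? std::string::npos : dot - start);
        const auto it = node->children.find(segment);
        if (it == node->children.end()) {
            node = nullptr;  // unknown segment: resolution fails
        } else {
            node = it->second;
        }
        if (dot == std::string::npos) break;  // last segment consumed
        start = dot + 1;
    }
    return node;
}

// Throwing variant for callers that treat a missing path as an error.
std::shared_ptr<NamedNode> resolveOrThrow(const std::shared_ptr<NamedNode>& root,
                                          const std::string& path) {
    std::shared_ptr<NamedNode> node = tryResolve(root, path);
    if (!node) throw std::out_of_range("no component at path: " + path);
    return node;
}
```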

Static Public Member Functions

static std::unique_ptr< GptTransformer< TDeviceType, TPrecision > > fromPretrained (const std::filesystem::path &model_path, std::size_t batch_size, std::size_t seq_length, DeviceId device_id=DeviceId{ TDeviceType, 0 }, bool strict=true)
Static Public Member Functions inherited from Mila::Dnn::Component< TDeviceType, TPrecision >
static constexpr DeviceType getDeviceType ()
 Compile-time device type for this component instance.
static constexpr TensorDataType getPrecision () noexcept
 Compile-time tensor precision for this component instance.

Protected Member Functions

void onBuilding (const BuildContext &context) override
 Hook invoked by build() to allocate component buffers.
void onTrainingModeChanging (TrainingMode training_mode) override
 Hook invoked when training mode is about to change.
void save_ (ModelArchive &archive, SerializationMode) const override
 Hook for concrete classes to save type-specific state.
Protected Member Functions inherited from Mila::Dnn::Network< TDeviceType, TPrecision >
void verifyArchitectureCompatibility (const PretrainedMetadata &metadata)
 Verify that imported model is compatible with network architecture.
Protected Member Functions inherited from Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >
template<typename TComponent>
std::shared_ptr< TComponent > getComponentAs (const std::string &name) const
 Retrieve a typed child component by name.
void onExecutionContextSet () override
 Hook invoked after ExecutionContext is set.
virtual void optimize ()
 Virtual hook for graph optimization after construction.
Protected Member Functions inherited from Mila::Dnn::Component< TDeviceType, TPrecision >
IExecutionContext * getExecutionContext () const
 Get the shared execution context.
bool hasExecutionContext () const noexcept
 Check if execution context has been set.
template<TensorDataType TParameterPrecision, typename TMemoryResource>
void loadParameterFromBlob (const std::string &param_name, const Serialization::ITensorBlob &blob, Tensor< TParameterPrecision, TMemoryResource > &target, const shape_t &expected_shape)
 Load a tensor blob into a parameter tensor with validation.
void setExecutionContext (IExecutionContext *context)
 Set the execution context for this component.

Private Member Functions

void createGraph ()
std::pair< std::string, std::string > parseParameterPath (const std::string &full_name) const
void validateBuildContext (const BuildContext &context) const
void validateInputShape (const shape_t &input_shape) const

Static Private Member Functions

static auto createConfigFromMetadata (const PretrainedMetadata &metadata) -> GptConfig
 Create GptConfig from Mila metadata.

Private Attributes

int64_t batch_size_ { 0 }
std::vector< TensorType * > block_input_ptrs_
std::vector< TensorType * > block_output_ptrs_
GptConfig config_
shape_t embedding_shape_
std::shared_ptr< EncoderType > encoder_ { nullptr }
TensorType * encoder_out_ptr_ { nullptr }
std::shared_ptr< LayerNormType > final_layernorm_ { nullptr }
shape_t leading_shape_
std::shared_ptr< LinearType > lm_head_ { nullptr }
TensorType * logits_ptr_ { nullptr }
TensorType * normalized_ptr_ { nullptr }
shape_t output_shape_
std::unique_ptr< IExecutionContext > owned_context_ { nullptr }
int64_t seq_length_ { 0 }
std::vector< std::shared_ptr< TransformerBlockType > > transformer_blocks_

Additional Inherited Members

Protected Attributes inherited from Mila::Dnn::Component< TDeviceType, TPrecision >
BuildContext build_context_ { shape_t{ 1 }, RuntimeMode::Training }
 The BuildContext stored at build time.

Detailed Description

template<DeviceType TDeviceType, TensorDataType TPrecision>
requires PrecisionSupportedOnDevice<TPrecision, TDeviceType>
class Mila::Dnn::GptTransformer< TDeviceType, TPrecision >

GPT-2 style transformer (decoder-only) for autoregressive token prediction.

Template parameters:

  • TDeviceType: device type (Cpu/Cuda)
  • TPrecision: tensor precision
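A typical inference loop pairs prefill() with repeated decode() calls. The sketch below shows greedy generation for one sequence against a hypothetical model type whose prefill(prompt) and decode(token, position) members mirror that pair but use plain std::vector values instead of Mila tensors. It assumes decode()'s position argument is the 0-based index of the token being fed in, so for a T-token prompt the first generated token is decoded at position T. ToyCounterModel exists only to exercise the loop.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Index of the largest logit (greedy decoding).
inline int argmax(const std::vector<float>& logits) {
    assert(!logits.empty());
    return static_cast<int>(std::max_element(logits.begin(), logits.end()) - logits.begin());
}

// Greedy generation over a hypothetical model exposing
//   std::vector<float> prefill(const std::vector<int>& prompt);  // last-token logits
//   std::vector<float> decode(int token, int position);          // next-token logits
template <typename TModel>
std::vector<int> generateGreedy(TModel& model, const std::vector<int>& prompt,
                                std::size_t max_new_tokens) {
    assert(!prompt.empty());
    std::vector<int> generated;
    // Prefill covers positions 0..T-1 and yields the logits that predict position T.
    std::vector<float> logits = model.prefill(prompt);
    int position = static_cast<int>(prompt.size());
    for (std::size_t i = 0; i < max_new_tokens; ++i) {
        const int next = argmax(logits);
        generated.push_back(next);
        if (i + 1 == max_new_tokens) break;     // no need to decode past the last token
        logits = model.decode(next, position);  // feed the new token at its own position
        ++position;
    }
    return generated;
}

// Illustrative toy model: always predicts (last token + 1) mod vocab.
struct ToyCounterModel {
    int vocab = 8;
    std::vector<int> decode_positions;  // recorded for inspection

    std::vector<float> oneHot(int index) const {
        std::vector<float> v(static_cast<std::size_t>(vocab), 0.0f);
        v[static_cast<std::size_t>(index % vocab)] = 1.0f;
        return v;
    }
    std::vector<float> prefill(const std::vector<int>& prompt) { return oneHot(prompt.back() + 1); }
    std::vector<float> decode(int token, int position) {
        decode_positions.push_back(position);
        return oneHot(token + 1);
    }
};
```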

Constructor & Destructor Documentation

◆ GptTransformer()

template<DeviceType TDeviceType, TensorDataType TPrecision>
Mila::Dnn::GptTransformer< TDeviceType, TPrecision >::GptTransformer ( const std::string & name,
const GptConfig & config,
DeviceId device_id )
[inline], [explicit], [export]

Construct a GPT-style transformer.

Parameters
name: Network name
config: GPT transformer configuration
device_id: Device identifier for execution
Exceptions
std::invalid_argument: on invalid config or device mismatch

◆ ~GptTransformer()

template<DeviceType TDeviceType, TensorDataType TPrecision>
Mila::Dnn::GptTransformer< TDeviceType, TPrecision >::~GptTransformer ( )
[override], [export], [default]

Member Function Documentation

◆ backward()

template<DeviceType TDeviceType, TensorDataType TPrecision>
TokenIndexType & Mila::Dnn::GptTransformer< TDeviceType, TPrecision >::backward ( const TokenIndexType & input,
const TensorType & output_grad )
[inline], [override], [export]

◆ createConfigFromMetadata()

template<DeviceType TDeviceType, TensorDataType TPrecision>
auto Mila::Dnn::GptTransformer< TDeviceType, TPrecision >::createConfigFromMetadata ( const PretrainedMetadata & metadata) -> GptConfig
[inline], [static], [export], [private]

Create GptConfig from Mila metadata.


◆ createGraph()

template<DeviceType TDeviceType, TensorDataType TPrecision>
void Mila::Dnn::GptTransformer< TDeviceType, TPrecision >::createGraph ( )
[inline], [export], [private]

◆ decode()

template<DeviceType TDeviceType, TensorDataType TPrecision>
TensorType & Mila::Dnn::GptTransformer< TDeviceType, TPrecision >::decode ( const TokenIndexType & input,
int position )
[inline], [override], [export]

Inference-only single-token decode pass.

Mirrors forward() except that each transformer block is driven through decode() instead of forward(). Each block's decode() delegates the attention step to attn_->decode(); Attention decides internally whether to use the fast KV-cache path or fall back to forward(). All other components in each block use forward() unchanged.

The encoder (token + position embeddings) and final LayerNorm + LM head are identical to forward() — only the block traversal differs.
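The reason only the block traversal has to change is that the encoder output for a token depends only on its id and its position. A minimal sketch of GPT-2-style learned token + position embeddings, with hypothetical row-per-id tables wte and wpe, shows that encoding one token at an explicit position yields exactly the row the full-sequence encoding produces for that position. This is an illustration, not the Lpe component itself.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical embedding tables: wte has one row per token id,
// wpe one row per position; all rows have the same width C.
using Table = std::vector<std::vector<float>>;

// Encode one token at an explicit position: wte[token] + wpe[position].
std::vector<float> encodeToken(const Table& wte, const Table& wpe,
                               int token, std::size_t position) {
    const std::vector<float>& tok = wte.at(static_cast<std::size_t>(token));
    const std::vector<float>& pos = wpe.at(position);
    assert(tok.size() == pos.size());
    std::vector<float> out(tok.size());
    for (std::size_t c = 0; c < tok.size(); ++c) out[c] = tok[c] + pos[c];
    return out;
}

// Full-sequence encoding: row t is the token at t embedded at position t.
Table encodeSequence(const Table& wte, const Table& wpe, const std::vector<int>& tokens) {
    Table out;
    out.reserve(tokens.size());
    for (std::size_t t = 0; t < tokens.size(); ++t)
        out.push_back(encodeToken(wte, wpe, tokens[t], t));
    return out;
}
```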

Precondition: a prefill pass, i.e. forward() or prefill() (which drives the blocks through forward()), must have run at least once before decode() is called. Attention manages its cache state internally, so no explicit initializeKVCache / resetKVCache call is needed here.

Calling forward() again after decode() steps automatically resets the KV cache and begins a new prefill session.

Parameters
input: Single-token input, [B, 1] token indices.
position: Current sequence position (0-based).
Returns
Reference to logits tensor [B, 1, vocab_size].
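The cache behavior described above can be illustrated with a minimal single-head KV cache. This is a hypothetical sketch, not Mila's Attention component: projections, multiple heads and batching are omitted, and callers pass already-projected q, k and v vectors. Prefill corresponds to calling step() once per prompt position; each decode step is one more call, and reset() plays the role of starting a new prefill session.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Vec = std::vector<float>;

// Minimal single-head KV cache for causal self-attention (illustrative only).
class KvCacheSketch {
public:
    // Drop all cached positions, as starting a new prefill session implies.
    void reset() { keys_.clear(); values_.clear(); }
    std::size_t size() const { return keys_.size(); }

    // Cache this position's key/value, then attend from q over every cached
    // position 0..t. Because the newest token is always last, attending over
    // the whole cache is exactly the causal mask for that token.
    Vec step(const Vec& q, const Vec& k, const Vec& v) {
        assert(q.size() == k.size());
        keys_.push_back(k);
        values_.push_back(v);
        const float scale = 1.0f / std::sqrt(static_cast<float>(q.size()));
        std::vector<float> w(keys_.size());
        for (std::size_t s = 0; s < keys_.size(); ++s) {
            float dot = 0.0f;
            for (std::size_t c = 0; c < q.size(); ++c) dot += q[c] * keys_[s][c];
            w[s] = dot * scale;
        }
        // Numerically stable softmax over the cached positions.
        float max_w = w[0];
        for (float x : w) if (x > max_w) max_w = x;
        float denom = 0.0f;
        for (float& x : w) { x = std::exp(x - max_w); denom += x; }
        Vec out(v.size(), 0.0f);
        for (std::size_t s = 0; s < w.size(); ++s)
            for (std::size_t c = 0; c < out.size(); ++c)
                out[c] += (w[s] / denom) * values_[s][c];
        return out;
    }

private:
    std::vector<Vec> keys_;
    std::vector<Vec> values_;
};
```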

◆ forward()

template<DeviceType TDeviceType, TensorDataType TPrecision>
TensorType & Mila::Dnn::GptTransformer< TDeviceType, TPrecision >::forward ( const TokenIndexType & input)
[inline], [override], [export]

Full-sequence forward pass.

Runs the encoder (token + position embeddings), every transformer block via forward(), and the final LayerNorm + LM head over the entire input sequence.


◆ fromPretrained()

template<DeviceType TDeviceType, TensorDataType TPrecision>
std::unique_ptr< GptTransformer< TDeviceType, TPrecision > > Mila::Dnn::GptTransformer< TDeviceType, TPrecision >::fromPretrained ( const std::filesystem::path & model_path,
std::size_t batch_size,
std::size_t seq_length,
DeviceId device_id = DeviceId{ TDeviceType, 0 },
bool strict = true )
[inline], [static], [export]

◆ getExecutionContext()

template<DeviceType TDeviceType, TensorDataType TPrecision>
IExecutionContext * Mila::Dnn::GptTransformer< TDeviceType, TPrecision >::getExecutionContext ( ) const
[inline], [export]

◆ getMemoryStats()

template<DeviceType TDeviceType, TensorDataType TPrecision>
MemoryStats Mila::Dnn::GptTransformer< TDeviceType, TPrecision >::getMemoryStats ( ) const
[inline], [override], [export], [virtual]

Return the current memory allocation breakdown for this component.

Reflects allocations at the moment of the call. The returned stats naturally track the component lifecycle:

  • After construction: parameters only.
  • After build( Inference ): parameters + T=1 state buffers.
  • After build( Training ): parameters + full-T state buffers.
  • After switching to training mode with setTrainingMode(): parameters + state + gradients.

For CompositeComponent and Network, the returned stats are the recursive aggregate of all child components.

May be called at any time — no lifecycle preconditions.

Returns
MemoryStats reflecting current allocations.

Implements Mila::Dnn::Component< TDeviceType, TPrecision >.
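The recursive aggregation can be sketched with a hypothetical byte-counter struct and tree node. The real MemoryStats fields are not documented on this page, so the three counters below are assumptions chosen to match the lifecycle list (parameters, state buffers, gradients).

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

// Hypothetical byte counters; Mila's actual MemoryStats fields are not shown here.
struct StatsSketch {
    std::size_t parameter_bytes = 0;
    std::size_t state_bytes = 0;
    std::size_t gradient_bytes = 0;

    std::size_t total() const { return parameter_bytes + state_bytes + gradient_bytes; }
    StatsSketch& operator+=(const StatsSketch& other) {
        parameter_bytes += other.parameter_bytes;
        state_bytes += other.state_bytes;
        gradient_bytes += other.gradient_bytes;
        return *this;
    }
};

// Hypothetical tree node: its reported stats are its own allocations plus
// the recursive aggregate of all of its children.
struct StatsNode {
    StatsSketch own;
    std::vector<std::shared_ptr<StatsNode>> children;

    StatsSketch stats() const {
        StatsSketch sum = own;
        for (const auto& child : children) sum += child->stats();
        return sum;
    }
};
```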


◆ getType()

template<DeviceType TDeviceType, TensorDataType TPrecision>
const ComponentType Mila::Dnn::GptTransformer< TDeviceType, TPrecision >::getType ( ) const
[inline], [override], [export], [virtual]

Get the component type identifier.

Used for serialization and runtime type identification.

Returns
Component type enum value.

Implements Mila::Dnn::Component< TDeviceType, TPrecision >.

◆ loadParameters()

template<DeviceType TDeviceType, TensorDataType TPrecision>
void Mila::Dnn::GptTransformer< TDeviceType, TPrecision >::loadParameters ( PretrainedModelReader & reader,
bool strict )
[inline], [export]

Initialize this transformer's components from a GPT-2 checkpoint.

Loads parameters (weights and biases) from an already-opened PretrainedModelReader, delegating to small helpers that read checkpoint blobs and apply them to the encoder, the per-layer transformer blocks, and the final LayerNorm.

This is kept separate from fromPretrained() to allow flexibility in weight loading.
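The meaning of the strict flag is not spelled out here. Assuming it follows the convention of PyTorch's load_state_dict (an assumption, not confirmed for Mila), strict loading rejects a checkpoint whose parameter names do not match the model exactly, while non-strict loading proceeds and reports what was missing or unexpected. LoadReport, checkParameterNames() and the parameter names in the usage are hypothetical.

```cpp
#include <cassert>
#include <set>
#include <stdexcept>
#include <string>
#include <vector>

// Names that did not line up between model and checkpoint.
struct LoadReport {
    std::vector<std::string> missing;     // expected by the model, absent from the checkpoint
    std::vector<std::string> unexpected;  // present in the checkpoint, unknown to the model
};

// Assumed strict semantics (borrowed from PyTorch, not confirmed for Mila):
// strict = true throws on any mismatch; strict = false only reports it.
LoadReport checkParameterNames(const std::set<std::string>& expected,
                               const std::set<std::string>& in_checkpoint, bool strict) {
    LoadReport report;
    for (const auto& name : expected)
        if (!in_checkpoint.count(name)) report.missing.push_back(name);
    for (const auto& name : in_checkpoint)
        if (!expected.count(name)) report.unexpected.push_back(name);
    if (strict && (!report.missing.empty() || !report.unexpected.empty()))
        throw std::runtime_error("checkpoint does not match model parameters");
    return report;
}
```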


◆ onBuilding()

template<DeviceType TDeviceType, TensorDataType TPrecision>
void Mila::Dnn::GptTransformer< TDeviceType, TPrecision >::onBuilding ( const BuildContext & config)
[inline], [override], [export], [protected], [virtual]

Hook invoked by build() to allocate component buffers.

Receives the stored BuildContext. Implementations must size output buffers with config.allocationSeqLen(); this single call is what makes Inference and Training builds allocate correctly sized buffers automatically, without per-component logic.

// Example: Linear component
shape_t out_shape =
{
    config.batchSize(),
    config.allocationSeqLen(),  // 1 for Inference, T for Training
    config_.getOutputFeatures()
};

// `device` is assumed to be in scope; it is not defined in this excerpt.
output_ = std::make_unique<TensorType>( device, out_shape,
    this->getName() + ".output" );

The default implementation forwards to the legacy onBuilding( const shape_t& ) overload for backwards compatibility. New components should override this overload directly.

Note
Do not call build() or onBuilding() from within this hook.
Implementations should either succeed fully or leave no partial state, as a failed build() may be retried.
Parameters
config: Build-time configuration. Use config.allocationSeqLen() to obtain the correct output buffer sequence dimension.

Reimplemented from Mila::Dnn::Component< TDeviceType, TPrecision >.
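To make the allocationSeqLen() rule concrete, here is a self-contained sketch with a hypothetical BuildContextSketch that models only batchSize() and allocationSeqLen(). It sizes a Linear-style output buffer the same way the example above does and shows that a Training build needs T times the bytes of an Inference build for the same batch and width, which is the T=1 versus full-T difference getMemoryStats() reports. The 4-byte element size (float32) is an assumption.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

enum class RuntimeModeSketch { Inference, Training };

// Hypothetical stand-in for BuildContext; only the two accessors used in the
// onBuilding() example are modelled.
class BuildContextSketch {
public:
    BuildContextSketch(std::int64_t batch, std::int64_t seq_len, RuntimeModeSketch mode)
        : batch_(batch), seq_len_(seq_len), mode_(mode) { assert(batch > 0 && seq_len > 0); }
    std::int64_t batchSize() const { return batch_; }
    // 1 for Inference (one token per decode step), T for Training.
    std::int64_t allocationSeqLen() const {
        return mode_ == RuntimeModeSketch::Inference ? 1 : seq_len_;
    }
private:
    std::int64_t batch_;
    std::int64_t seq_len_;
    RuntimeModeSketch mode_;
};

// Output shape of a Linear-like layer: [B, allocationSeqLen, out_features].
std::vector<std::int64_t> linearOutputShape(const BuildContextSketch& ctx, std::int64_t out_features) {
    return { ctx.batchSize(), ctx.allocationSeqLen(), out_features };
}

// Bytes needed for a dense buffer of the given shape.
std::size_t bufferBytes(const std::vector<std::int64_t>& shape, std::size_t element_size) {
    std::size_t n = 1;
    for (auto d : shape) n *= static_cast<std::size_t>(d);
    return n * element_size;
}
```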


◆ onTrainingModeChanging()

template<DeviceType TDeviceType, TensorDataType TPrecision>
void Mila::Dnn::GptTransformer< TDeviceType, TPrecision >::onTrainingModeChanging ( TrainingMode training_mode)
[inline], [override], [export], [protected], [virtual]

Hook invoked when training mode is about to change.

Propagates the new mode to all child components. The hook runs with the Component's training mutex held, so it MUST NOT call setTrainingMode() on this Component.

Parameters
training_mode: The new TrainingMode about to take effect.

Reimplemented from Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >.
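The locking rule can be illustrated with a hypothetical base class whose setTrainingMode() holds a std::mutex while the hook runs. A composite hook that forwards the mode to its children is safe because each child locks its own mutex; re-entering setTrainingMode() on the same component from inside the hook is what the rule forbids. ComponentSketch, CompositeSketch and ModeSketch are illustrative stand-ins, not Mila types.

```cpp
#include <cassert>
#include <memory>
#include <mutex>
#include <vector>

enum class ModeSketch { Eval, Training };

// Hypothetical base: setTrainingMode() holds the mutex while the hook runs.
class ComponentSketch {
public:
    virtual ~ComponentSketch() = default;

    void setTrainingMode(ModeSketch mode) {
        std::lock_guard<std::mutex> lock(mutex_);
        onTrainingModeChanging(mode);  // hook runs with mutex_ held
        mode_ = mode;
    }
    ModeSketch trainingMode() const {
        std::lock_guard<std::mutex> lock(mutex_);
        return mode_;
    }

protected:
    virtual void onTrainingModeChanging(ModeSketch) {}

private:
    mutable std::mutex mutex_;
    ModeSketch mode_ = ModeSketch::Eval;
};

// Composite: the hook forwards the new mode to every child. Each child locks
// its own mutex, so this is safe. Calling this->setTrainingMode() from here
// would lock mutex_ a second time on the same thread, which is undefined
// behaviour for std::mutex (in practice usually a deadlock).
class CompositeSketch : public ComponentSketch {
public:
    std::vector<std::shared_ptr<ComponentSketch>> children;

protected:
    void onTrainingModeChanging(ModeSketch mode) override {
        for (const auto& child : children) child->setTrainingMode(mode);
    }
};
```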


◆ parseParameterPath()

template<DeviceType TDeviceType, TensorDataType TPrecision>
std::pair< std::string, std::string > Mila::Dnn::GptTransformer< TDeviceType, TPrecision >::parseParameterPath ( const std::string & full_name) const
[inline], [export], [private]

◆ prefill()

template<DeviceType TDeviceType, TensorDataType TPrecision>
TensorType & Mila::Dnn::GptTransformer< TDeviceType, TPrecision >::prefill ( const TokenIndexType & input)
[inline], [override], [export]

Inference prefill — process full prompt and return last-token logits.

Populates the KV cache across all transformer blocks by running the full prompt through the encoder and blocks via forward(). It then extracts only the last token's representation for the final LayerNorm + LM head. Those components have T=1 output buffers, so calling forward() on them with the full sequence would overflow them.

Unlike LlamaTransformer::prefill(), GPT does not need chunked prefill or explicit position offsets (no RoPE). The full sequence is processed in a single pass.

Parameters
input: Full prompt token indices, [B, T].
Returns
Logits for the last token [B, 1, vocab_size].
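The last-token extraction amounts to a strided copy. Assuming a contiguous row-major [B, T, C] hidden-state buffer (C being the embedding width), the sketch below copies position T-1 of every sequence into a [B, 1, C] buffer. lastTokenSlice() is a hypothetical helper; the actual implementation may use a view or a device kernel instead.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Copy the hidden state at position T-1 of each sequence out of a row-major
// [B, T, C] buffer into a [B, 1, C] buffer (stored as B * C floats).
std::vector<float> lastTokenSlice(const std::vector<float>& hidden,
                                  std::size_t B, std::size_t T, std::size_t C) {
    assert(T > 0 && hidden.size() == B * T * C);
    std::vector<float> out(B * C);
    for (std::size_t b = 0; b < B; ++b) {
        // Offset of element [b, T-1, 0] in row-major order.
        const std::size_t src = (b * T + (T - 1)) * C;
        for (std::size_t c = 0; c < C; ++c) out[b * C + c] = hidden[src + c];
    }
    return out;
}
```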

◆ save_()

template<DeviceType TDeviceType, TensorDataType TPrecision>
void Mila::Dnn::GptTransformer< TDeviceType, TPrecision >::save_ ( ModelArchive & archive,
SerializationMode mode ) const
[inline], [override], [export], [protected], [virtual]

Hook for concrete classes to save type-specific state.

REQUIRED override for concrete networks. Must write:

  • Type identifier (e.g., "type": "MnistClassifier")
  • Configuration parameters (batch_size, architecture constants)
  • Shape metadata (for validation during Load())

This metadata enables the concrete class's Load() method to reconstruct the network.

Example implementation:

void save_(ModelArchive& archive, SerializationMode mode) const override
{
    json meta;  // json is an alias for nlohmann::json
    meta["type"] = "MnistClassifier";  // Type identifier for runtime dispatch
    meta["batch_size"] = batch_size_;
    meta["input_shape"] = leading_shape_;
    // ... other configuration
    archive.writeJson("network/classifier_meta.json", meta);
}
Parameters
archive: Archive to write to
mode: Serialization mode (passed from save())

Implements Mila::Dnn::Network< TDeviceType, TPrecision >.


◆ toString()

template<DeviceType TDeviceType, TensorDataType TPrecision>
std::string Mila::Dnn::GptTransformer< TDeviceType, TPrecision >::toString ( ) const
[inline], [override], [export], [virtual]

Generate a human-readable description.

Returns
String representation showing children

Reimplemented from Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >.


◆ validateBuildContext()

template<DeviceType TDeviceType, TensorDataType TPrecision>
void Mila::Dnn::GptTransformer< TDeviceType, TPrecision >::validateBuildContext ( const BuildContext & context) const
[inline], [export], [private]

◆ validateInputShape()

template<DeviceType TDeviceType, TensorDataType TPrecision>
void Mila::Dnn::GptTransformer< TDeviceType, TPrecision >::validateInputShape ( const shape_t & input_shape) const
[inline], [export], [private]

◆ zeroGradients()

template<DeviceType TDeviceType, TensorDataType TPrecision>
void Mila::Dnn::GptTransformer< TDeviceType, TPrecision >::zeroGradients ( )
[inline], [override], [export], [virtual]

Clear all model-owned gradients for this component.

Default implementation is a no-op. Composite components should override to recurse to children. Leaf components should override to zero their parameter and activation gradients using device-aware helpers.

Reimplemented from Mila::Dnn::Component< TDeviceType, TPrecision >.
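The documented contract (no-op by default, composites recurse, leaves zero their own buffers) maps naturally onto a small virtual hierarchy. GradNode, GradLeaf and GradComposite are hypothetical; a host std::vector stands in for a device tensor, where real code would use a device-aware fill.

```cpp
#include <algorithm>
#include <cassert>
#include <memory>
#include <vector>

// Hypothetical hierarchy mirroring the documented contract.
struct GradNode {
    virtual ~GradNode() = default;
    virtual void zeroGradients() {}  // default: nothing to clear
};

// Leaf: owns a gradient buffer and zeroes it.
struct GradLeaf : GradNode {
    std::vector<float> grad;
    void zeroGradients() override { std::fill(grad.begin(), grad.end(), 0.0f); }
};

// Composite: owns no gradients itself, just recurses into its children.
struct GradComposite : GradNode {
    std::vector<std::shared_ptr<GradNode>> children;
    void zeroGradients() override {
        for (const auto& child : children) child->zeroGradients();
    }
};
```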


The documentation for this class was generated from the following file: