Mila 0.13.48
Deep Neural Network Library
Loading...
Searching...
No Matches
Dnn.Components.GptTransformer Module Reference

Exported Modules

module  Dnn.Components.LayerNorm
module  Compute.DeviceTypeTraits
module  Dnn.TensorDataType
module  Dnn.RuntimeMode
module  Dnn.TensorTypes
module  Dnn.Components.GptBlock
module  Serialization.PretrainedReader
module  Logging.Logger
module  Serialization.Tensor
module  Dnn.TensorDataTypeTraits
module  Dnn.Components.Linear
module  Serialization.ModelArchive
module  Compute.ExecutionContextFactory
module  Dnn.Component
module  Compute.CpuMemoryResource
module  Compute.DeviceId
module  Dnn.Tensor
module  Compute.Device
module  Dnn.ComponentType
module  Dnn.ActivationType
module  Dnn.LanguageNetwork
module  Compute.ExecutionContext
module  Dnn.Components.Lpe
module  Serialization.Mode
module  Compute.DeviceType
module  Dnn.ITensor

Classes

class  Mila::Dnn::GptConfig
 Network-level configuration for GPT-style transformer networks. More...
class  Mila::Dnn::GptTransformer< TDeviceType, TPrecision >
 GPT-2 style transformer (decoder-only) for autoregressive token prediction. More...

Typedefs

using ComponentPtr = typename NetworkBase::ComponentPtr
using EncoderType = Lpe<TDeviceType, dtype_t::INT32, TPrecision>
using LayerNormType = LayerNorm<TDeviceType, TPrecision>
using LinearType = Linear<TDeviceType, TPrecision>
using MR = typename DeviceTypeTraits<TDeviceType>::memory_resource
using NetworkBase = LanguageNetwork<TDeviceType, TPrecision>
using TensorType = Tensor<TPrecision, MR>
using TokenIndexType = Tensor<dtype_t::INT32, MR>
using TransformerBlockType = GptBlock<TDeviceType, TPrecision>

Functions

 GptTransformer (const std::string &name, const GptConfig &config, DeviceId device_id)
 Construct a GPT-style transformer.
 ~GptTransformer () override=default
TokenIndexType backward (const TokenIndexType &input, const TensorType &output_grad) override
static auto createConfigFromMetadata (const PretrainedMetadata &metadata) -> GptConfig
 Create GptConfig from Mila metadata.
void createGraph ()
TensorType decode (const TokenIndexType &input, int position) override
 Inference-only single-token decode pass.
TensorType forward (const TokenIndexType &input) override
static std::unique_ptr< GptTransformer< TDeviceType, TPrecision > > fromPretrained (const std::filesystem::path &model_path, std::size_t batch_size, std::size_t seq_length, DeviceId device_id=DeviceId{ TDeviceType, 0 }, bool strict=true)
 Load GptTransformer from archive.
IExecutionContext getExecutionContext () const
MemoryStats getMemoryStats () const override
 Return the current memory allocation breakdown for this component.
const ComponentType getType () const override
 Get the component type identifier.
GptConfig Mila::Dnn::GPT2_Large ()
 GPT-2 Large (774M parameters).
GptConfig Mila::Dnn::GPT2_Medium ()
 GPT-2 Medium (345M parameters).
GptConfig Mila::Dnn::GPT2_Small ()
 GPT-2 Small preset (see the detailed description for usage examples).
GptConfig Mila::Dnn::GPT2_XL ()
 GPT-2 XL (1.5B parameters).
void loadParameters (PretrainedModelReader &reader, bool strict)
 Initialize this transformer's components from a GPT-2 checkpoint.
void onBuilding (const BuildContext &context) override
 Hook invoked by build() to allocate component buffers.
void onTrainingModeChanging (TrainingMode training_mode) override
 Hook invoked when training mode is about to change.
std::pair< std::string, std::string > parseParameterPath (const std::string &full_name) const
TensorType prefill (const TokenIndexType &input) override
 Inference prefill — process full prompt and return last-token logits.
void save_ (ModelArchive &archive, SerializationMode) const override
 Hook for concrete classes to save type-specific state.
std::string toString () const override
 Generate a human-readable description.
void validateBuildContext (const BuildContext &context) const
void validateInputShape (const shape_t &input_shape) const
void zeroGradients () override
 Clear all model-owned gradients for this component.

Variables

int64_t batch_size_ { 0 }
std::vector< TensorType * > block_input_ptrs_
std::vector< TensorType * > block_output_ptrs_
GptConfig config_
shape_t embedding_shape_
std::shared_ptr< EncoderType > encoder_ { nullptr }
TensorType * encoder_out_ptr_ { nullptr }
std::shared_ptr< LayerNormType > final_layernorm_ { nullptr }
shape_t leading_shape_
std::shared_ptr< LinearType > lm_head_ { nullptr }
TensorType * logits_ptr_ { nullptr }
TensorType * normalized_ptr_ { nullptr }
shape_t output_shape_
std::unique_ptr< IExecutionContext > owned_context_ { nullptr }
int64_t seq_length_ { 0 }
std::vector< std::shared_ptr< TransformerBlockType > > transformer_blocks_

Files

file  /__w/Mila/Mila/Mila/Src/Dnn/Components/Transformers/Gpt/GptTransformer.ixx
file  /__w/Mila/Mila/Mila/Src/Dnn/Components/Transformers/Gpt/Gpt.Config.ixx
 Network-level configuration for GPT-style transformer networks.
file  /__w/Mila/Mila/Mila/Src/Dnn/Components/Transformers/Gpt/Gpt.Presets.ixx