Mila 0.13.48
Deep Neural Network Library
Loading...
Searching...
No Matches
Mila::Dnn::LlamaTransformer< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy > Class Template Reference [export]

LLaMA-style transformer (decoder-only) for autoregressive token prediction. More...

Inheritance diagram for Mila::Dnn::LlamaTransformer< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >:
Collaboration diagram for Mila::Dnn::LlamaTransformer< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >:

Public Types

using ComponentPtr = typename NetworkBase::ComponentPtr
using LinearType = Linear<TDeviceType, TPrecision, TWeightQuantization>
using LmHeadLinearType = Linear<TDeviceType, TPrecision>
using MR = typename DeviceTypeTraits<TDeviceType>::memory_resource
using NetworkBase = LanguageNetwork<TDeviceType, TPrecision>
using RmsNormType = RmsNorm<TDeviceType, TPrecision>
using TensorType = Tensor<TPrecision, MR>
using TokenEmbeddingType = TokenEmbedding<TDeviceType, dtype_t::INT32, TPrecision>
using TokenIndexType = Tensor<dtype_t::INT32, MR>
using TransformerBlockType = LlamaBlock<TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy>
Public Types inherited from Mila::Dnn::LanguageNetwork< TDeviceType, TPrecision >
using MR = typename DeviceTypeTraits<TDeviceType>::memory_resource
using NetworkBase = Network<TDeviceType, TPrecision>
using TensorType = Tensor<TPrecision, MR>
using TokenIndexType = Tensor<TensorDataType::INT32, MR>
Public Types inherited from Mila::Dnn::Network< TDeviceType, TPrecision >
using ComponentPtr = typename CompositeBase::ComponentPtr
using CompositeBase = CompositeComponent<TDeviceType, TPrecision>
using MR = typename DeviceTypeTraits<TDeviceType>::memory_resource
Public Types inherited from Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >
using ComponentBase = Component<TDeviceType, TPrecision>
using ComponentPtr = std::shared_ptr<Component<TDeviceType, TPrecision>>

Public Member Functions

 LlamaTransformer (const std::string &name, const LlamaConfig &config, DeviceId device_id)
 ~LlamaTransformer () override=default
TokenIndexType & backward (const TokenIndexType &input, const TensorType &output_grad) override
TensorType & decode (const TokenIndexType &input, int position) override
TensorType & forward (const TokenIndexType &input) override
IExecutionContext * getExecutionContext () const
MemoryStats getMemoryStats () const override
 Return the current memory allocation breakdown for this component.
const ComponentType getType () const override
 Get the component type identifier.
void loadParameters (PretrainedModelReader &reader)
TensorType & prefill (const TokenIndexType &input) override
std::string toString () const override
 Generate a human-readable description.
void zeroGradients () override
 Clear all model-owned gradients for this component.
Public Member Functions inherited from Mila::Dnn::LanguageNetwork< TDeviceType, TPrecision >
 LanguageNetwork (const std::string &name)
 ~LanguageNetwork () override=default
virtual TokenIndexType & backward (const TokenIndexType &input, const TensorType &output_grad)=0
 Full backward pass (training).
virtual TensorType & decode (const TokenIndexType &input, int position)=0
 Inference decode — single-token autoregressive step.
virtual TensorType & forward (const TokenIndexType &input)=0
 Full-sequence forward pass.
virtual TensorType & prefill (const TokenIndexType &input)=0
 Inference prefill — process full prompt and populate the KV cache.
Public Member Functions inherited from Mila::Dnn::Network< TDeviceType, TPrecision >
 Network (const std::string &name)
 Construct network (context managed by derived class).
 ~Network () override=default
template<typename TOptimizer, typename TConfig>
std::shared_ptr< TOptimizer > createOptimizer (const TConfig &config)
 Create and configure an optimizer for this network's parameters.
DeviceId getDeviceId () const noexcept
 Get the compute device for this composite.
const ComponentType getType () const override
 Get the component type identifier.
void save (ModelArchive &archive, SerializationMode mode) const
 Save network to archive.
void synchronize () override
 Synchronize all child components.
std::string toString () const override
 Generate a human-readable description.
Public Member Functions inherited from Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >
 CompositeComponent (CompositeComponent &&) noexcept=default
 CompositeComponent (const CompositeComponent &)=delete
 CompositeComponent (const std::string &name)
 Construct composite component with name.
virtual ~CompositeComponent ()=default
CompositeComponent & addComponent (ComponentPtr component)
 Add a pre-constructed child component (chainable).
size_t childCount () const noexcept
 Get the number of direct children.
void clearComponents ()
 Clear all child components.
ComponentPtr findComponent (const std::string &path) const
 Resolve a dot-separated component path within this composite.
ComponentPtr getComponent (const std::string &name) const
 Retrieve a direct child component by name.
const std::vector< ComponentPtr > & getComponents () const
 Get all child components in insertion order.
std::vector< ITensor * > getGradients () const override
 Get all parameter gradients from all children.
std::vector< ITensor * > getParameters () const override
 Get all parameters from all children.
bool hasChildren () const noexcept
 Check if this composite has any children.
bool hasComponent (const std::string &name) const
 Check if a named child component exists.
CompositeComponent & operator= (CompositeComponent &&) noexcept=default
CompositeComponent & operator= (const CompositeComponent &)=delete
size_t parameterCount () const override
 Count parameters across all children.
bool removeComponent (const std::string &name)
 Remove a direct child component by name.
ComponentPtr tryFindComponent (const std::string &path) const
 Try to resolve a dot-separated component path within this composite.
Public Member Functions inherited from Mila::Dnn::Component< TDeviceType, TPrecision >
 Component (const std::string &name)
 Construct component with required name identifier.
virtual ~Component ()=default
virtual void build (const BuildContext &context) final
 Build the component with the provided BuildContext (canonical overload).
const std::string getName () const
 Get the component's name identifier.
virtual std::vector< std::string > getParameterNames () const
 List all available parameter names for this component.
RuntimeMode getRuntimeMode () const noexcept
 Convenience accessor — true if currently in Eval mode.
TrainingMode getTrainingMode () const noexcept
 The current runtime behavioral mode of this Component.
virtual bool isBuilt () const final
 Returns true if build() has completed successfully.
bool isInferenceMode () const noexcept
bool isTrainingMode () const noexcept
virtual void loadParameter (const std::string &name, const Serialization::ITensorBlob &blob)
 Load a parameter from serialized tensor data.
void setTrainingMode (TrainingMode mode)
 Set the runtime behavioral mode for this Component.

Protected Member Functions

void onBuilding (const BuildContext &context) override
 Hook invoked by build() to allocate component buffers.
void onTrainingModeChanging (TrainingMode training_mode) override
 Hook invoked when training mode is about to change.
void save_ (ModelArchive &archive, SerializationMode) const override
 Hook for concrete classes to save type-specific state.
Protected Member Functions inherited from Mila::Dnn::Network< TDeviceType, TPrecision >
void verifyArchitectureCompatibility (const PretrainedMetadata &metadata)
 Verify that imported model is compatible with network architecture.
Protected Member Functions inherited from Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >
template<typename TComponent>
std::shared_ptr< TComponent > getComponentAs (const std::string &name) const
 Retrieve a typed child component by name.
void onExecutionContextSet () override
 Hook invoked after ExecutionContext is set.
virtual void optimize ()
 Virtual hook for graph optimization after construction.
Protected Member Functions inherited from Mila::Dnn::Component< TDeviceType, TPrecision >
IExecutionContext * getExecutionContext () const
 Get the shared execution context.
bool hasExecutionContext () const noexcept
 Check if execution context has been set.
template<TensorDataType TParameterPrecision, typename TMemoryResource>
void loadParameterFromBlob (const std::string &param_name, const Serialization::ITensorBlob &blob, Tensor< TParameterPrecision, TMemoryResource > &target, const shape_t &expected_shape)
 Load a tensor blob into a parameter tensor with validation.
void setExecutionContext (IExecutionContext *context)
 Set the execution context for this component.

Private Member Functions

void createGraph ()
std::pair< std::string, std::string > parseParameterPath (const std::string &full_name) const
void validateBuildContext (const BuildContext &context) const
void validateLeadingShape (const shape_t &leading_shape) const

Static Private Member Functions

static LlamaConfig createConfigFromMetadata (const PretrainedMetadata &metadata)

Private Attributes

int64_t batch_size_ { 0 }
std::vector< TensorType * > block_input_ptrs_
std::vector< TensorType * > block_output_ptrs_
LlamaConfig config_
shape_t embedding_shape_
std::unique_ptr< IExecutionContext > exec_context_ { nullptr }
std::shared_ptr< RmsNormType > final_rmsnorm_ { nullptr }
std::unique_ptr< TensorType > gqa_att_ { nullptr }
std::unique_ptr< TensorType > gqa_att_decode_ { nullptr }
std::unique_ptr< TensorType > gqa_preatt_ { nullptr }
std::unique_ptr< TensorType > gqa_preatt_decode_ { nullptr }
std::unique_ptr< TensorType > gqa_q_permute_ { nullptr }
std::unique_ptr< TensorType > gqa_v_out_ { nullptr }
std::unique_ptr< TensorType > gqa_v_out_decode_ { nullptr }
shape_t input_shape_
std::shared_ptr< LmHeadLinearType > lm_head_ { nullptr }
TensorType * logits_ptr_ { nullptr }
TensorType * normalized_ptr_ { nullptr }
shape_t output_shape_
std::unique_ptr< TensorType > prefill_ { nullptr }
int64_t prefill_chunk_size_ { 0 }
int64_t seq_length_ { 0 }
TensorType * token_embed_out_ptr_ { nullptr }
std::shared_ptr< TokenEmbeddingType > token_embedding_ { nullptr }
std::vector< std::shared_ptr< TransformerBlockType > > transformer_blocks_

Additional Inherited Members

Static Public Member Functions inherited from Mila::Dnn::Component< TDeviceType, TPrecision >
static constexpr DeviceType getDeviceType ()
 Compile-time device type for this component instance.
static constexpr TensorDataType getPrecision () noexcept
 Compile-time tensor precision for this component instance.
Protected Attributes inherited from Mila::Dnn::Component< TDeviceType, TPrecision >
BuildContext build_context_ { shape_t{ 1 }, RuntimeMode::Training }
 The BuildContext stored at build time.

Detailed Description

template<DeviceType TDeviceType, TensorDataType TPrecision, WeightQuantPolicy TWeightQuantization = NoWeightQuant, KvCachePolicy TKvCachePolicy = NoKvCompression>
requires PrecisionSupportedOnDevice<TPrecision, TDeviceType>
class Mila::Dnn::LlamaTransformer< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >

LLaMA-style transformer (decoder-only) for autoregressive token prediction.

Graph: TokenEmbedding → RoPE → LlamaBlock × N → RmsNorm → Linear (lm_head). RoPE is applied to the full embedding stream after the token lookup; each LlamaBlock receives rotary-encoded embeddings as input.

Template parameters:

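  • TDeviceType: device type (Cpu/Cuda)
  • TPrecision: tensor precision
  • TWeightQuantization: weight quantization policy (defaults to NoWeightQuant)
  • TKvCachePolicy: KV-cache policy (defaults to NoKvCompression)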

Constructor & Destructor Documentation

◆ LlamaTransformer()

template<DeviceType TDeviceType, TensorDataType TPrecision, WeightQuantPolicy TWeightQuantization = NoWeightQuant, KvCachePolicy TKvCachePolicy = NoKvCompression>
Mila::Dnn::LlamaTransformer< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >::LlamaTransformer ( const std::string & name,
const LlamaConfig & config,
DeviceId device_id )
inline explicit export
Here is the call graph for this function:

◆ ~LlamaTransformer()

template<DeviceType TDeviceType, TensorDataType TPrecision, WeightQuantPolicy TWeightQuantization = NoWeightQuant, KvCachePolicy TKvCachePolicy = NoKvCompression>
Mila::Dnn::LlamaTransformer< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >::~LlamaTransformer ( )
override export default

Member Function Documentation

◆ backward()

template<DeviceType TDeviceType, TensorDataType TPrecision, WeightQuantPolicy TWeightQuantization = NoWeightQuant, KvCachePolicy TKvCachePolicy = NoKvCompression>
TokenIndexType & Mila::Dnn::LlamaTransformer< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >::backward ( const TokenIndexType & input,
const TensorType & output_grad )
inline override export
Here is the call graph for this function:
Here is the caller graph for this function:

◆ createConfigFromMetadata()

template<DeviceType TDeviceType, TensorDataType TPrecision, WeightQuantPolicy TWeightQuantization = NoWeightQuant, KvCachePolicy TKvCachePolicy = NoKvCompression>
LlamaConfig Mila::Dnn::LlamaTransformer< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >::createConfigFromMetadata ( const PretrainedMetadata & metadata)
inline static export private
Here is the call graph for this function:

◆ createGraph()

template<DeviceType TDeviceType, TensorDataType TPrecision, WeightQuantPolicy TWeightQuantization = NoWeightQuant, KvCachePolicy TKvCachePolicy = NoKvCompression>
void Mila::Dnn::LlamaTransformer< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >::createGraph ( )
inline export private
Here is the call graph for this function:
Here is the caller graph for this function:

◆ decode()

template<DeviceType TDeviceType, TensorDataType TPrecision, WeightQuantPolicy TWeightQuantization = NoWeightQuant, KvCachePolicy TKvCachePolicy = NoKvCompression>
TensorType & Mila::Dnn::LlamaTransformer< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >::decode ( const TokenIndexType & input,
int position )
inline override export

◆ forward()

template<DeviceType TDeviceType, TensorDataType TPrecision, WeightQuantPolicy TWeightQuantization = NoWeightQuant, KvCachePolicy TKvCachePolicy = NoKvCompression>
TensorType & Mila::Dnn::LlamaTransformer< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >::forward ( const TokenIndexType & input)
inline override export
Here is the call graph for this function:

◆ getExecutionContext()

template<DeviceType TDeviceType, TensorDataType TPrecision, WeightQuantPolicy TWeightQuantization = NoWeightQuant, KvCachePolicy TKvCachePolicy = NoKvCompression>
IExecutionContext * Mila::Dnn::LlamaTransformer< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >::getExecutionContext ( ) const
inline export
Here is the call graph for this function:
Here is the caller graph for this function:

◆ getMemoryStats()

template<DeviceType TDeviceType, TensorDataType TPrecision, WeightQuantPolicy TWeightQuantization = NoWeightQuant, KvCachePolicy TKvCachePolicy = NoKvCompression>
MemoryStats Mila::Dnn::LlamaTransformer< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >::getMemoryStats ( ) const
inline override export virtual

Return the current memory allocation breakdown for this component.

Reflects allocations at the moment of the call. The returned stats naturally track the component lifecycle:

  • After construction — parameters only
  • After build( Inference ) — parameters + T=1 state buffers
  • After build( Training ) — parameters + T=full state buffers
  • After setEvaluation( false ) — parameters + state + gradients

For CompositeComponent and Network, the returned stats are the recursive aggregate of all child components.

May be called at any time — no lifecycle preconditions.

Returns
MemoryStats reflecting current allocations.

Implements Mila::Dnn::Component< TDeviceType, TPrecision >.

Here is the call graph for this function:

◆ getType()

template<DeviceType TDeviceType, TensorDataType TPrecision, WeightQuantPolicy TWeightQuantization = NoWeightQuant, KvCachePolicy TKvCachePolicy = NoKvCompression>
const ComponentType Mila::Dnn::LlamaTransformer< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >::getType ( ) const
inline override export virtual

Get the component type identifier.

Used for serialization and runtime type identification.

Returns
Component type enum value.

Implements Mila::Dnn::Component< TDeviceType, TPrecision >.

◆ loadParameters()

template<DeviceType TDeviceType, TensorDataType TPrecision, WeightQuantPolicy TWeightQuantization = NoWeightQuant, KvCachePolicy TKvCachePolicy = NoKvCompression>
void Mila::Dnn::LlamaTransformer< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >::loadParameters ( PretrainedModelReader & reader)
inline export
Here is the call graph for this function:

◆ onBuilding()

template<DeviceType TDeviceType, TensorDataType TPrecision, WeightQuantPolicy TWeightQuantization = NoWeightQuant, KvCachePolicy TKvCachePolicy = NoKvCompression>
void Mila::Dnn::LlamaTransformer< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >::onBuilding ( const BuildContext & config)
inline override export protected virtual

Hook invoked by build() to allocate component buffers.

Receives the stored BuildContext. Implementations must use config.allocationSeqLen() when sizing output buffers — this is the single call that makes Inference and Training allocate the correct buffer sizes automatically without per-component logic.

// Example — Linear component:
shape_t out_shape =
{
config.batchSize(),
config.allocationSeqLen(), // 1 for Inference, T for Training
config_.getOutputFeatures()
};
output_ = std::make_unique<TensorType>( device, out_shape,
this->getName() + ".output" );
const std::string getName() const
Get the component's name identifier.
Definition Component.ixx:410
LlamaConfig config_
Definition Llama.ixx:595
TensorShape shape_t
Row-major shape descriptor for tensor dimensional sizes.
Definition Tensor.Types.ixx:143

The default implementation forwards to the legacy onBuilding( const shape_t& ) overload for backwards compatibility. New components should override this overload directly.

Note
Do not call build() or onBuilding() from within this hook.
Implementations should either succeed fully or leave no partial state, as a failed build() may be retried.
Parameters
config — Build-time configuration. Use config.allocationSeqLen() to obtain the correct output buffer sequence dimension.

Reimplemented from Mila::Dnn::Component< TDeviceType, TPrecision >.

Here is the call graph for this function:

◆ onTrainingModeChanging()

template<DeviceType TDeviceType, TensorDataType TPrecision, WeightQuantPolicy TWeightQuantization = NoWeightQuant, KvCachePolicy TKvCachePolicy = NoKvCompression>
void Mila::Dnn::LlamaTransformer< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >::onTrainingModeChanging ( TrainingMode training_mode)
inline override export protected virtual

Hook invoked when training mode is about to change.

Propagates the new mode to all child components. The hook runs with the Component's training mutex held; it MUST NOT call setTraining().

Parameters
training_mode (documented upstream as is_training) — New training mode (true = training, false = eval)

Reimplemented from Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >.

Here is the call graph for this function:

◆ parseParameterPath()

template<DeviceType TDeviceType, TensorDataType TPrecision, WeightQuantPolicy TWeightQuantization = NoWeightQuant, KvCachePolicy TKvCachePolicy = NoKvCompression>
std::pair< std::string, std::string > Mila::Dnn::LlamaTransformer< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >::parseParameterPath ( const std::string & full_name) const
inline export private
Here is the caller graph for this function:

◆ prefill()

template<DeviceType TDeviceType, TensorDataType TPrecision, WeightQuantPolicy TWeightQuantization = NoWeightQuant, KvCachePolicy TKvCachePolicy = NoKvCompression>
TensorType & Mila::Dnn::LlamaTransformer< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >::prefill ( const TokenIndexType & input)
inline override export
Here is the call graph for this function:

◆ save_()

template<DeviceType TDeviceType, TensorDataType TPrecision, WeightQuantPolicy TWeightQuantization = NoWeightQuant, KvCachePolicy TKvCachePolicy = NoKvCompression>
void Mila::Dnn::LlamaTransformer< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >::save_ ( ModelArchive & archive,
SerializationMode mode ) const
inline override export protected virtual

Hook for concrete classes to save type-specific state.

REQUIRED override for concrete networks. Must write:

  • Type identifier (e.g., "type": "MnistClassifier")
  • Configuration parameters (batch_size, architecture constants)
  • Shape metadata (for validation during Load())

This metadata enables the concrete class's Load() method to reconstruct the network.

Example implementation:

void save_(ModelArchive& archive, SerializationMode mode) const override
{
json meta;
meta["type"] = "MnistClassifier"; // Type identifier for runtime dispatch
meta["batch_size"] = batch_size_;
meta["input_shape"] = leading_shape_;
// ... other configuration
archive.writeJson("network/classifier_meta.json", meta);
}
void save_(ModelArchive &archive, SerializationMode) const override
Hook for concrete classes to save type-specific state.
Definition Llama.ixx:564
int64_t batch_size_
Definition Llama.ixx:600
ModelArchive provides high-level helpers for component serialization.
Definition ModelArchive.ixx:47
SerializationMode
Modes for serialization and deserialization.
Definition SerializationMode.ixx:17
nlohmann::json json
Definition Linear.ixx:57
Parameters
archive — Archive to write to
mode — Serialization mode (passed from save())

Implements Mila::Dnn::Network< TDeviceType, TPrecision >.

Here is the call graph for this function:

◆ toString()

template<DeviceType TDeviceType, TensorDataType TPrecision, WeightQuantPolicy TWeightQuantization = NoWeightQuant, KvCachePolicy TKvCachePolicy = NoKvCompression>
std::string Mila::Dnn::LlamaTransformer< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >::toString ( ) const
inline override export virtual

Generate a human-readable description.

Returns
String representation showing children

Reimplemented from Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >.

Here is the call graph for this function:

◆ validateBuildContext()

template<DeviceType TDeviceType, TensorDataType TPrecision, WeightQuantPolicy TWeightQuantization = NoWeightQuant, KvCachePolicy TKvCachePolicy = NoKvCompression>
void Mila::Dnn::LlamaTransformer< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >::validateBuildContext ( const BuildContext & context) const
inline export private
Here is the call graph for this function:
Here is the caller graph for this function:

◆ validateLeadingShape()

template<DeviceType TDeviceType, TensorDataType TPrecision, WeightQuantPolicy TWeightQuantization = NoWeightQuant, KvCachePolicy TKvCachePolicy = NoKvCompression>
void Mila::Dnn::LlamaTransformer< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >::validateLeadingShape ( const shape_t & leading_shape) const
inline export private
Here is the call graph for this function:

◆ zeroGradients()

template<DeviceType TDeviceType, TensorDataType TPrecision, WeightQuantPolicy TWeightQuantization = NoWeightQuant, KvCachePolicy TKvCachePolicy = NoKvCompression>
void Mila::Dnn::LlamaTransformer< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >::zeroGradients ( )
inline override export virtual

Clear all model-owned gradients for this component.

Default implementation is a no-op. Composite components should override to recurse to children. Leaf components should override to zero their parameter and activation gradients using device-aware helpers.

Reimplemented from Mila::Dnn::Component< TDeviceType, TPrecision >.

Here is the call graph for this function:

The documentation for this class was generated from the following file:
  • /__w/Mila/Mila/Mila/Src/Dnn/Components/Transformers/LlaMa/Llama.ixx