Mila 0.13.48
Deep Neural Network Library
Loading...
Searching...
No Matches
Dnn.Components.LlamaTransformer Module Reference

Exported Modules

module  Serialization.Tensor
module  Dnn.Components.TokenEmbedding
module  Dnn.TensorTypes
module  Dnn.ComponentType
module  Dnn.TensorDataType
module  Compute.GqaState
module  Dnn.Tensor
module  Compute.DeviceId
module  Compute.Device
module  Compute.ExecutionContextFactory
module  Dnn.Components.Rope
module  Dnn.Component
module  Compute.CpuMemoryResource
module  Dnn.Quantization.KvCache.Policy
module  Compute.DeviceTypeTraits
module  Dnn.Components.Linear
module  Dnn.TensorDataTypeTraits
module  Serialization.ModelArchive
module  Dnn.Quantization.Weight.Policies
module  Serialization.PretrainedReader
module  Dnn.ActivationType
module  Dnn.Components.RmsNorm
module  Dnn.LanguageNetwork
module  Compute.DeviceType
module  Compute.ExecutionContext
module  Dnn.ITensor

Classes

class  Mila::Dnn::LlamaBlock< TDeviceType, TPrecision, TWeightQuant, TKvPolicy >
class  Mila::Dnn::LlamaConfig
 Network-level configuration for LLaMA-style transformer networks. More...
class  Mila::Dnn::LlamaTransformer< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >
 LLaMA-style transformer (decoder-only) for autoregressive token prediction. More...

Typedefs

using ComponentPtr = typename NetworkBase::ComponentPtr
using LinearType = Linear<TDeviceType, TPrecision, TWeightQuantization>
using LmHeadLinearType = Linear<TDeviceType, TPrecision>
using MR = typename DeviceTypeTraits<TDeviceType>::memory_resource
using NetworkBase = LanguageNetwork<TDeviceType, TPrecision>
using RmsNormType = RmsNorm<TDeviceType, TPrecision>
using TensorType = Tensor<TPrecision, MR>
using TokenEmbeddingType = TokenEmbedding<TDeviceType, dtype_t::INT32, TPrecision>
using TokenIndexType = Tensor<dtype_t::INT32, MR>
using TransformerBlockType = LlamaBlock<TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy>

Functions

 LlamaTransformer (const std::string &name, const LlamaConfig &config, DeviceId device_id)
 ~LlamaTransformer () override=default
TokenIndexType & backward (const TokenIndexType &input, const TensorType &output_grad) override
int64_t Mila::Dnn::computePrefillChunkSize (int64_t batch, int64_t num_heads, int64_t head_dim, int64_t context_length, int64_t precision_bytes)
static LlamaConfig createConfigFromMetadata (const PretrainedMetadata &metadata)
void createGraph ()
TensorType & decode (const TokenIndexType &input, int position) override
TensorType & forward (const TokenIndexType &input) override
IExecutionContext * getExecutionContext () const
MemoryStats getMemoryStats () const override
 Return the current memory allocation breakdown for this component.
const ComponentType getType () const override
 Get the component type identifier.
LlamaConfig Mila::Dnn::Llama2_13B ()
 Llama 2 13B.
LlamaConfig Mila::Dnn::Llama2_70B ()
 Llama 2 70B.
LlamaConfig Mila::Dnn::Llama2_7B ()
 Llama 2 7B.
LlamaConfig Mila::Dnn::Llama3_1_405B ()
 Llama 3.1 405B.
LlamaConfig Mila::Dnn::Llama3_1_70B ()
 Llama 3.1 70B.
LlamaConfig Mila::Dnn::Llama3_1_8B ()
 Llama 3.1 8B.
LlamaConfig Mila::Dnn::Llama3_2_1B ()
 Llama 3.2 1B.
LlamaConfig Mila::Dnn::Llama3_2_3B ()
 Llama 3.2 3B.
LlamaConfig Mila::Dnn::Llama3_70B ()
 Llama 3 70B (Original release).
LlamaConfig Mila::Dnn::Llama3_8B ()
 Llama 3 8B (Original release).
void loadParameters (PretrainedModelReader &reader)
void onBuilding (const BuildContext &context) override
 Hook invoked by build() to allocate component buffers.
void onTrainingModeChanging (TrainingMode training_mode) override
 Hook invoked when training mode is about to change.
std::pair< std::string, std::string > parseParameterPath (const std::string &full_name) const
TensorType & prefill (const TokenIndexType &input) override
void save_ (ModelArchive &archive, SerializationMode) const override
 Hook for concrete classes to save type-specific state.
std::string toString () const override
 Generate a human-readable description.
void validateBuildContext (const BuildContext &context) const
void validateLeadingShape (const shape_t &leading_shape) const
void zeroGradients () override
 Clear all model-owned gradients for this component.

Variables

int64_t batch_size_ { 0 }
std::vector< TensorType * > block_input_ptrs_
std::vector< TensorType * > block_output_ptrs_
LlamaConfig config_
shape_t embedding_shape_
std::unique_ptr< IExecutionContext > exec_context_ { nullptr }
std::shared_ptr< RmsNormType > final_rmsnorm_ { nullptr }
std::unique_ptr< TensorType > gqa_att_ { nullptr }
std::unique_ptr< TensorType > gqa_att_decode_ { nullptr }
std::unique_ptr< TensorType > gqa_preatt_ { nullptr }
std::unique_ptr< TensorType > gqa_preatt_decode_ { nullptr }
std::unique_ptr< TensorType > gqa_q_permute_ { nullptr }
std::unique_ptr< TensorType > gqa_v_out_ { nullptr }
std::unique_ptr< TensorType > gqa_v_out_decode_ { nullptr }
shape_t input_shape_
constexpr int64_t Mila::Dnn::kPrefillScratchByteCap = int64_t{ 1536 } * 1024 * 1024
std::shared_ptr< LmHeadLinearType > lm_head_ { nullptr }
TensorType * logits_ptr_ { nullptr }
TensorType * normalized_ptr_ { nullptr }
shape_t output_shape_
std::unique_ptr< TensorType > prefill_ { nullptr }
int64_t prefill_chunk_size_ { 0 }
int64_t seq_length_ { 0 }
TensorType * token_embed_out_ptr_ { nullptr }
std::shared_ptr< TokenEmbeddingType > token_embedding_ { nullptr }
std::vector< std::shared_ptr< TransformerBlockType > > transformer_blocks_

Files

file  /__w/Mila/Mila/Mila/Src/Dnn/Components/Transformers/LlaMa/Llama.ixx
 LLaMA-style decoder-only transformer network.
file  /__w/Mila/Mila/Mila/Src/Dnn/Components/Transformers/LlaMa/Llama.Block.ixx
 LLaMA transformer block — module partition of LlamaTransformer.
file  /__w/Mila/Mila/Mila/Src/Dnn/Components/Transformers/LlaMa/Llama.Config.ixx
 LLaMA network-level configuration.
file  /__w/Mila/Mila/Mila/Src/Dnn/Components/Transformers/LlaMa/Llama.Presets.ixx