Mila 0.13.48
Deep Neural Network Library
Exported Modules | |
| module | Serialization.Tensor |
| module | Dnn.Components.TokenEmbedding |
| module | Dnn.TensorTypes |
| module | Dnn.ComponentType |
| module | Dnn.TensorDataType |
| module | Compute.GqaState |
| module | Dnn.Tensor |
| module | Compute.DeviceId |
| module | Compute.Device |
| module | Compute.ExecutionContextFactory |
| module | Dnn.Components.Rope |
| module | Dnn.Component |
| module | Compute.CpuMemoryResource |
| module | Dnn.Quantization.KvCache.Policy |
| module | Compute.DeviceTypeTraits |
| module | Dnn.Components.Linear |
| module | Dnn.TensorDataTypeTraits |
| module | Serialization.ModelArchive |
| module | Dnn.Quantization.Weight.Policies |
| module | Serialization.PretrainedReader |
| module | Dnn.ActivationType |
| module | Dnn.Components.RmsNorm |
| module | Dnn.LanguageNetwork |
| module | Compute.DeviceType |
| module | Compute.ExecutionContext |
| module | Dnn.ITensor |
Classes | |
| class | Mila::Dnn::LlamaBlock< TDeviceType, TPrecision, TWeightQuant, TKvPolicy > |
| class | Mila::Dnn::LlamaConfig |
| Network-level configuration for LLaMA-style transformer networks. | |
| class | Mila::Dnn::LlamaTransformer< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy > |
| LLaMA-style transformer (decoder-only) for autoregressive token prediction. | |
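`LlamaTransformer` takes a KV-cache policy as a template parameter, and the presets listed below differ in how many key/value heads they use. A rough way to see why that matters is to estimate the KV-cache footprint from a configuration. The sketch below is not Mila code: `ConfigSketch`, `kvCacheBytes` and the two constants are hypothetical, and the formula assumes an unquantized cache that stores one key vector and one value vector per layer, KV head and position. The layer and head counts are the published Llama 3 8B and Llama 2 7B shapes.

```cpp
#include <cstdint>

// Hypothetical subset of a LLaMA-style network configuration.
// Field names are illustrative; they are not Mila's LlamaConfig members.
struct ConfigSketch {
    int64_t num_layers;
    int64_t num_heads;     // query heads
    int64_t num_kv_heads;  // key/value heads (grouped-query attention when < num_heads)
    int64_t model_dim;     // hidden size
};

// Published shapes: Llama 3 8B uses GQA with 8 KV heads; Llama 2 7B uses full MHA.
constexpr ConfigSketch kLlama3_8B{ 32, 32, 8, 4096 };
constexpr ConfigSketch kLlama2_7B{ 32, 32, 32, 4096 };

// Bytes needed to cache K and V for every layer, KV head and position
// (assumes no KV-cache quantization).
constexpr int64_t kvCacheBytes(const ConfigSketch& c, int64_t batch,
                               int64_t context_length, int64_t precision_bytes) {
    const int64_t head_dim = c.model_dim / c.num_heads;
    // Leading factor 2: one cache tensor for keys, one for values.
    return 2 * c.num_layers * batch * c.num_kv_heads * context_length * head_dim * precision_bytes;
}
```

With 2-byte elements, Llama 3 8B at an 8192-token context needs 1 GiB of cache, while Llama 2 7B needs 2 GiB at only 4096 tokens because it keeps all 32 KV heads. A quantizing KV-cache policy would presumably lower the effective bytes per element further.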
Typedefs | |
| using | ComponentPtr = typename NetworkBase::ComponentPtr |
| using | LinearType = Linear<TDeviceType, TPrecision, TWeightQuantization> |
| using | LmHeadLinearType = Linear<TDeviceType, TPrecision> |
| using | MR = typename DeviceTypeTraits<TDeviceType>::memory_resource |
| using | NetworkBase = LanguageNetwork<TDeviceType, TPrecision> |
| using | RmsNormType = RmsNorm<TDeviceType, TPrecision> |
| using | TensorType = Tensor<TPrecision, MR> |
| using | TokenEmbeddingType = TokenEmbedding<TDeviceType, dtype_t::INT32, TPrecision> |
| using | TokenIndexType = Tensor<dtype_t::INT32, MR> |
| using | TransformerBlockType = LlamaBlock<TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy> |
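Every alias above is derived from the class's template parameters, with `MR` looked up through `DeviceTypeTraits<TDeviceType>::memory_resource`. The self-contained sketch below reproduces that trait-based wiring using stub types; `DeviceType`, `dtype_t`, the memory-resource structs, `Tensor` and `NetworkAliases` here are hypothetical stand-ins, not Mila's actual definitions.

```cpp
#include <type_traits>

// Hypothetical stand-ins for Mila's device, data-type and tensor types.
enum class DeviceType { Cpu, Cuda };
enum class dtype_t { INT32, FP32, FP16 };

struct CpuMemoryResourceStub {};
struct CudaMemoryResourceStub {};

// Primary template left undefined so an unsupported device fails to compile.
template <DeviceType TDevice>
struct DeviceTypeTraits;

template <>
struct DeviceTypeTraits<DeviceType::Cpu> {
    using memory_resource = CpuMemoryResourceStub;
};

template <>
struct DeviceTypeTraits<DeviceType::Cuda> {
    using memory_resource = CudaMemoryResourceStub;
};

// Empty tensor stub: only the template arguments matter for this illustration.
template <dtype_t TDataType, typename TMemoryResource>
struct Tensor {};

// Same pattern as the typedef table: the device parameter fixes the memory
// resource once, and every tensor alias reuses it.
template <DeviceType TDeviceType, dtype_t TPrecision>
struct NetworkAliases {
    using MR = typename DeviceTypeTraits<TDeviceType>::memory_resource;
    using TensorType = Tensor<TPrecision, MR>;
    using TokenIndexType = Tensor<dtype_t::INT32, MR>;
};

static_assert(std::is_same_v<NetworkAliases<DeviceType::Cuda, dtype_t::FP16>::TensorType,
                             Tensor<dtype_t::FP16, CudaMemoryResourceStub>>);
```

Because the choice is made at compile time, activations and token indices for one network instantiation can never end up on different memory resources.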
Functions | |
| LlamaTransformer (const std::string &name, const LlamaConfig &config, DeviceId device_id) | |
| ~LlamaTransformer () override=default | |
| TokenIndexType & | backward (const TokenIndexType &input, const TensorType &output_grad) override |
| int64_t | Mila::Dnn::computePrefillChunkSize (int64_t batch, int64_t num_heads, int64_t head_dim, int64_t context_length, int64_t precision_bytes) |
| static LlamaConfig | createConfigFromMetadata (const PretrainedMetadata &metadata) |
| void | createGraph () |
| TensorType & | decode (const TokenIndexType &input, int position) override |
| TensorType & | forward (const TokenIndexType &input) override |
| IExecutionContext * | getExecutionContext () const |
| MemoryStats | getMemoryStats () const override |
| Return the current memory allocation breakdown for this component. | |
| const ComponentType | getType () const override |
| Get the component type identifier. | |
| LlamaConfig | Mila::Dnn::Llama2_13B () |
| Llama 2 13B. | |
| LlamaConfig | Mila::Dnn::Llama2_70B () |
| Llama 2 70B. | |
| LlamaConfig | Mila::Dnn::Llama2_7B () |
| Llama 2 7B. | |
| LlamaConfig | Mila::Dnn::Llama3_1_405B () |
| Llama 3.1 405B. | |
| LlamaConfig | Mila::Dnn::Llama3_1_70B () |
| Llama 3.1 70B. | |
| LlamaConfig | Mila::Dnn::Llama3_1_8B () |
| Llama 3.1 8B. | |
| LlamaConfig | Mila::Dnn::Llama3_2_1B () |
| Llama 3.2 1B. | |
| LlamaConfig | Mila::Dnn::Llama3_2_3B () |
| Llama 3.2 3B. | |
| LlamaConfig | Mila::Dnn::Llama3_70B () |
| Llama 3 70B (Original release). | |
| LlamaConfig | Mila::Dnn::Llama3_8B () |
| Llama 3 8B (Original release). | |
| void | loadParameters (PretrainedModelReader &reader) |
| void | onBuilding (const BuildContext &context) override |
| Hook invoked by build() to allocate component buffers. | |
| void | onTrainingModeChanging (TrainingMode training_mode) override |
| Hook invoked when training mode is about to change. | |
| std::pair< std::string, std::string > | parseParameterPath (const std::string &full_name) const |
| TensorType & | prefill (const TokenIndexType &input) override |
| void | save_ (ModelArchive &archive, SerializationMode) const override |
| Hook for concrete classes to save type-specific state. | |
| std::string | toString () const override |
| Generate a human-readable description. | |
| void | validateBuildContext (const BuildContext &context) const |
| void | validateLeadingShape (const shape_t &leading_shape) const |
| void | zeroGradients () override |
| Clear all model-owned gradients for this component. | |
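The member list separates `prefill` (the whole prompt) from `decode` (one token at a given `position`), which is the usual split for autoregressive generation with a KV cache. The greedy loop below illustrates that call pattern against a hypothetical `StubModel` that uses `std::vector` instead of Mila tensors and returns fake logits. Treating `position` as the sequence index of the token being decoded is an assumption, not documented behavior.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a language network. A real network would run the
// transformer; this stub simply favors (last_token + 1) % vocab_size.
struct StubModel {
    int vocab_size;

    // Process the whole prompt; return logits for the token after the last one.
    std::vector<float> prefill(const std::vector<int32_t>& prompt) {
        assert(!prompt.empty());
        return logitsAfter(prompt.back());
    }

    // Process one token at sequence index `position`. A real model would write
    // this token's keys/values into its KV cache at that index.
    std::vector<float> decode(int32_t token, int position) {
        (void)position;  // unused by the stub
        return logitsAfter(token);
    }

    std::vector<float> logitsAfter(int32_t token) const {
        std::vector<float> logits(static_cast<std::size_t>(vocab_size), 0.0f);
        logits[static_cast<std::size_t>((token + 1) % vocab_size)] = 1.0f;
        return logits;
    }
};

// Index of the largest logit (greedy sampling).
int32_t argmax(const std::vector<float>& v) {
    int32_t best = 0;
    for (int32_t i = 1; i < static_cast<int32_t>(v.size()); ++i) {
        if (v[i] > v[best]) best = i;
    }
    return best;
}

// Greedy decoding: one prefill over the prompt, then one decode call per new token.
std::vector<int32_t> generateGreedy(StubModel& model, std::vector<int32_t> tokens, int max_new_tokens) {
    std::vector<float> logits = model.prefill(tokens);
    for (int step = 0; step < max_new_tokens; ++step) {
        const int32_t next = argmax(logits);
        tokens.push_back(next);
        if (step + 1 < max_new_tokens) {
            // The new token now sits at index tokens.size() - 1.
            logits = model.decode(next, static_cast<int>(tokens.size()) - 1);
        }
    }
    return tokens;
}
```

With a vocabulary of 10, the prompt `{3, 4}` and three new tokens, the stub produces `{3, 4, 5, 6, 7}`: one `prefill` call followed by two `decode` calls, since the last sampled token needs no further forward pass.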
Variables | |
| int64_t | batch_size_ { 0 } |
| std::vector< TensorType * > | block_input_ptrs_ |
| std::vector< TensorType * > | block_output_ptrs_ |
| LlamaConfig | config_ |
| shape_t | embedding_shape_ |
| std::unique_ptr< IExecutionContext > | exec_context_ { nullptr } |
| std::shared_ptr< RmsNormType > | final_rmsnorm_ { nullptr } |
| std::unique_ptr< TensorType > | gqa_att_ { nullptr } |
| std::unique_ptr< TensorType > | gqa_att_decode_ { nullptr } |
| std::unique_ptr< TensorType > | gqa_preatt_ { nullptr } |
| std::unique_ptr< TensorType > | gqa_preatt_decode_ { nullptr } |
| std::unique_ptr< TensorType > | gqa_q_permute_ { nullptr } |
| std::unique_ptr< TensorType > | gqa_v_out_ { nullptr } |
| std::unique_ptr< TensorType > | gqa_v_out_decode_ { nullptr } |
| shape_t | input_shape_ |
| constexpr int64_t | Mila::Dnn::kPrefillScratchByteCap = int64_t{ 1536 } * 1024 * 1024 |
| std::shared_ptr< LmHeadLinearType > | lm_head_ { nullptr } |
| TensorType * | logits_ptr_ { nullptr } |
| TensorType * | normalized_ptr_ { nullptr } |
| shape_t | output_shape_ |
| std::unique_ptr< TensorType > | prefill_ { nullptr } |
| int64_t | prefill_chunk_size_ { 0 } |
| int64_t | seq_length_ { 0 } |
| TensorType * | token_embed_out_ptr_ { nullptr } |
| std::shared_ptr< TokenEmbeddingType > | token_embedding_ { nullptr } |
| std::vector< std::shared_ptr< TransformerBlockType > > | transformer_blocks_ |
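`computePrefillChunkSize` and `kPrefillScratchByteCap` (1536 MiB, i.e. 1.5 GiB) suggest that long prompts are prefilled in chunks so attention scratch memory stays bounded. The actual formula is not documented on this page, so the sketch below uses a hypothetical per-row cost loosely modeled on the `gqa_preatt_`, `gqa_att_`, `gqa_q_permute_` and `gqa_v_out_` buffers. Only the cap value is taken from this page; `computePrefillChunkSizeSketch` and its cost model are assumptions.

```cpp
#include <algorithm>
#include <cstdint>

// Same value as the documented cap: 1536 MiB of prefill scratch.
constexpr int64_t kPrefillScratchByteCap = int64_t{ 1536 } * 1024 * 1024;

// Hypothetical heuristic, NOT Mila's formula: the largest number of query rows
// per prefill chunk whose scratch fits under the cap. Assumed cost per row:
//   2 score buffers (pre- and post-softmax), context_length entries per head,
//   2 head_dim-wide buffers per head (permuted Q and attention output).
constexpr int64_t computePrefillChunkSizeSketch(int64_t batch, int64_t num_heads, int64_t head_dim,
                                                int64_t context_length, int64_t precision_bytes) {
    const int64_t bytes_per_row =
        batch * num_heads * (2 * context_length + 2 * head_dim) * precision_bytes;
    if (context_length <= 0 || bytes_per_row <= 0) return 0;  // degenerate input
    const int64_t rows = kPrefillScratchByteCap / bytes_per_row;
    // At least one row so prefill always makes progress; never more than the context.
    return std::clamp<int64_t>(rows, 1, context_length);
}
```

Under these assumptions, one sequence with 32 heads, `head_dim` 128, an 8192-token context and 4-byte elements gets 756 rows per chunk, while a 256-token context fits entirely and is capped at 256.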
Files | |
| file | /__w/Mila/Mila/Mila/Src/Dnn/Components/Transformers/LlaMa/Llama.ixx |
| LLaMA-style decoder-only transformer network. | |
| file | /__w/Mila/Mila/Mila/Src/Dnn/Components/Transformers/LlaMa/Llama.Block.ixx |
| LLaMA transformer block — module partition of LlamaTransformer. | |
| file | /__w/Mila/Mila/Mila/Src/Dnn/Components/Transformers/LlaMa/Llama.Config.ixx |
| LLaMA network-level configuration. | |
| file | /__w/Mila/Mila/Mila/Src/Dnn/Components/Transformers/LlaMa/Llama.Presets.ixx |