Mila 0.13.48
Deep Neural Network Library
Llama.Block.ixx File Reference

LLaMA transformer block: module partition of LlamaTransformer.

#include <memory>
#include <vector>
#include <string>
#include <format>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <optional>
import Dnn.Quantization.Weight.Policies;
import Serialization.ModelArchive;
import Dnn.Components.Linear;
import Dnn.Components.Residual;
import Dnn.Components.Rope;
import Dnn.Components.RmsNorm;
import Compute.ExecutionContextFactory;
import Compute.DeviceTypeTraits;
import Dnn.Components.LlamaTransformer:Config;
import Dnn.ITensor;
import Compute.IExecutionContext;
import Compute.ExecutionContext;
import Dnn.ComponentType;
import Dnn.Components.Swiglu;
import Compute.GqaState;
import Dnn.TensorTypes;
import Dnn.TensorDataType;
import Dnn.Quantization.KvCache.Policy;
import Serialization.Mode;
import Compute.DeviceId;
import Dnn.Tensor;
import Dnn.TensorOps;
import Dnn.TensorDataTypeTraits;
import Dnn.Component;
import Dnn.Components.Gqa;
import Dnn.CompositeComponent;
import Compute.Device;
import Compute.DeviceType;

Classes

class Mila::Dnn::LlamaBlock<TDeviceType, TPrecision, TWeightQuant, TKvPolicy>

Namespaces

namespace Mila
    Mila main API namespace.
namespace Mila::Dnn

Detailed Description


Implements the Llama 3.x transformer block (a GQA-correct attention sub-graph followed by a SwiGLU feed-forward sub-graph) using fused projections, zero-copy tensor views, and in-place RoPE rotation:

    input [B, T, model_dim]
     └─ RMSNorm (ln_1)
     └─ fc_qkv_proj [model_dim → (n_heads + 2*n_kv_heads) * head_dim]    1 GEMM
         └─ view Q [B, T, n_heads * head_dim]     ──┐
         └─ view K [B, T, n_kv_heads * head_dim]  ──┤ RoPE in-place
         └─ view V [B, T, n_kv_heads * head_dim]    │ (V untouched)
     └─ GroupedQueryAttention (packed QKV)
     └─ fc_out_proj [model_dim → model_dim]
     └─ Residual (input + out_proj)                                      res_1
     └─ RMSNorm (ln_2)
     └─ fc_gate_up [model_dim → 2*hidden_dim]                            1 GEMM
     └─ SwiGLU → [B, T, hidden_dim]
     └─ fc_down [hidden_dim → model_dim]
     └─ Residual (res_1 + ffn)                                           res_2
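Because fc_qkv_proj writes Q, K and V side by side in each row, the three views reduce to column offsets into one buffer. The sketch below shows how Q and K could be rotated in place through those offsets while V is left alone. It assumes a row-major [T, (n_heads + 2*n_kv_heads) * head_dim] buffer for one batch element, positions 0..T-1 (no KV-cache offset), and the HuggingFace-style "rotate-half" RoPE pairing (dimension i paired with i + head_dim/2). `QkvLayout`, `rope_head` and `rope_inplace` are hypothetical names, not Mila's API, and Mila's kernel may use a different pairing or frequency base. The `kv_head_for` helper illustrates the usual GQA grouping, in which consecutive query heads share one K/V head.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical description of one packed QKV row: [Q heads | K heads | V heads].
struct QkvLayout {
    std::size_t n_heads, n_kv_heads, head_dim;

    std::size_t q_offset() const { return 0; }
    std::size_t k_offset() const { return n_heads * head_dim; }
    std::size_t v_offset() const { return (n_heads + n_kv_heads) * head_dim; }
    std::size_t row_width() const { return (n_heads + 2 * n_kv_heads) * head_dim; }

    // GQA grouping (assumed): query head h reads K/V head h / (n_heads / n_kv_heads).
    std::size_t kv_head_for(std::size_t q_head) const { return q_head / (n_heads / n_kv_heads); }
};

// Rotate one head of length head_dim (assumed even) in place, rotate-half pairing.
void rope_head(float* x, std::size_t head_dim, std::size_t pos, float base) {
    const std::size_t half = head_dim / 2;
    for (std::size_t i = 0; i < half; ++i) {
        // inv_freq_i = base^(-2i / head_dim)
        const float inv_freq = std::pow(base, -2.0f * static_cast<float>(i) / static_cast<float>(head_dim));
        const float angle = static_cast<float>(pos) * inv_freq;
        const float c = std::cos(angle);
        const float s = std::sin(angle);
        const float a = x[i];
        const float b = x[i + half];
        x[i] = a * c - b * s;
        x[i + half] = b * c + a * s;
    }
}

// Apply RoPE in place to the Q and K column ranges of every row; V columns are never touched.
void rope_inplace(std::vector<float>& qkv, std::size_t T, const QkvLayout& L, float base) {
    for (std::size_t t = 0; t < T; ++t) {
        float* row = qkv.data() + t * L.row_width();  // zero-copy "view" of row t
        for (std::size_t h = 0; h < L.n_heads; ++h)
            rope_head(row + L.q_offset() + h * L.head_dim, L.head_dim, t, base);
        for (std::size_t h = 0; h < L.n_kv_heads; ++h)
            rope_head(row + L.k_offset() + h * L.head_dim, L.head_dim, t, base);
    }
}
```

Since the rotation only rewrites the Q and K column ranges of each row, the packed buffer can be passed to attention afterward without any split or concat.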

Key design points:

  • Single GEMM for QKV, with the GQA-correct output dimension (n_heads + 2*n_kv_heads) * head_dim.
  • RoPE is applied in place through zero-copy views of the QKV output, so no concat/split ops are needed.
  • Single GEMM for gate+up (the SwiGLU kernel expects a [gate | up] layout).
  • The FFN is composed directly from Linear and SwiGLU primitives rather than an MLP composite.
  • Component names match the tensor-name mapping in convert_llama32.py.
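With gate and up fused into one GEMM, each row of the fc_gate_up output holds hidden_dim gate values followed by hidden_dim up values. A minimal sketch of the elementwise SwiGLU step on that layout follows, assuming row-major float storage and the standard Llama form silu(gate) * up with silu(x) = x * sigmoid(x). `swiglu_rows` is a hypothetical helper for illustration, not the Mila Swiglu component.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// SiLU (swish) activation: x * sigmoid(x) = x / (1 + e^-x).
inline float silu(float x) { return x / (1.0f + std::exp(-x)); }

// in:  [rows, 2*hidden_dim], each row laid out as [gate | up]
// out: [rows, hidden_dim],   out = silu(gate) * up
std::vector<float> swiglu_rows(const std::vector<float>& in, std::size_t rows, std::size_t hidden_dim) {
    std::vector<float> out(rows * hidden_dim);
    for (std::size_t r = 0; r < rows; ++r) {
        const float* gate = in.data() + r * 2 * hidden_dim;  // first half of the row
        const float* up = gate + hidden_dim;                  // second half of the row
        for (std::size_t j = 0; j < hidden_dim; ++j)
            out[r * hidden_dim + j] = silu(gate[j]) * up[j];
    }
    return out;
}
```

If the two halves were swapped, the same loop would compute silu(up) * gate, which is why the [gate | up] order has to be fixed when the weights are fused.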

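Weight loader note: HuggingFace checkpoints store the gate and up projections (fc_gate and fc_up) as separate tensors. When loading into fc_gate_up, the weight loader must concatenate them along dim 0, gate first, into a single [2*hidden_dim, model_dim] matrix so that the GEMM output matches the [gate | up] layout the SwiGLU kernel expects.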
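For row-major [out_features, in_features] weights, concatenating along dim 0 simply means appending the up rows after the gate rows. A minimal sketch follows, assuming each source tensor arrives as a flat row-major float buffer of shape [hidden_dim, model_dim]. `concat_gate_up` is a hypothetical helper, not part of Mila's loader.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Concatenate gate [hidden_dim, model_dim] and up [hidden_dim, model_dim]
// along dim 0 into fused [2*hidden_dim, model_dim], gate rows first.
std::vector<float> concat_gate_up(const std::vector<float>& gate,
                                  const std::vector<float>& up,
                                  std::size_t hidden_dim, std::size_t model_dim) {
    const std::size_t n = hidden_dim * model_dim;
    if (gate.size() != n || up.size() != n)
        throw std::invalid_argument("gate/up shape mismatch");
    std::vector<float> fused;
    fused.reserve(2 * n);
    fused.insert(fused.end(), gate.begin(), gate.end());  // rows [0, hidden_dim)
    fused.insert(fused.end(), up.begin(), up.end());      // rows [hidden_dim, 2*hidden_dim)
    return fused;
}
```

Row r < hidden_dim of the fused matrix then produces gate output r, and row hidden_dim + r produces up output r, which matches the [gate | up] split consumed by SwiGLU.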