Mila 0.13.48
Deep Neural Network Library
LLaMA transformer block — module partition of LlamaTransformer.
#include <memory>
#include <vector>
#include <string>
#include <format>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <optional>
import Dnn.Quantization.Weight.Policies;
import Serialization.ModelArchive;
import Dnn.Components.Linear;
import Dnn.Components.Residual;
import Dnn.Components.Rope;
import Dnn.Components.RmsNorm;
import Compute.ExecutionContextFactory;
import Compute.DeviceTypeTraits;
import Dnn.Components.LlamaTransformer:Config;
import Dnn.ITensor;
import Compute.IExecutionContext;
import Compute.ExecutionContext;
import Dnn.ComponentType;
import Dnn.Components.Swiglu;
import Compute.GqaState;
import Dnn.TensorTypes;
import Dnn.TensorDataType;
import Dnn.Quantization.KvCache.Policy;
import Serialization.Mode;
import Compute.DeviceId;
import Dnn.Tensor;
import Dnn.TensorOps;
import Dnn.TensorDataTypeTraits;
import Dnn.Component;
import Dnn.Components.Gqa;
import Dnn.CompositeComponent;
import Compute.Device;
import Compute.DeviceType;

Classes
class Mila::Dnn::LlamaBlock< TDeviceType, TPrecision, TWeightQuant, TKvPolicy >

Namespaces
namespace Mila
    Mila main API namespace.
namespace Mila::Dnn
LLaMA transformer block — module partition of LlamaTransformer.
Implements the correct Llama 3.x attention sub-graph using fused projections, zero-copy tensor views, and in-place RoPE rotation:
input [B, T, model_dim]
 └─ RMSNorm (ln_1)
 └─ fc_qkv_proj [model_dim → (n_heads + 2*n_kv_heads) * head_dim]   1 GEMM
 └─ view Q [B, T, n_heads * head_dim]    ──┐
 └─ view K [B, T, n_kv_heads * head_dim] ──┤ RoPE in-place
 └─ view V [B, T, n_kv_heads * head_dim]   │ (V untouched)
 └─ GroupedQueryAttention (packed QKV)
 └─ fc_out_proj [model_dim → model_dim]
 └─ Residual (input + out_proj)   res_1
 └─ RMSNorm (ln_2)
 └─ fc_gate_up [model_dim → 2*hidden_dim]   1 GEMM
 └─ SwiGLU → [B, T, hidden_dim]
 └─ fc_down [hidden_dim → model_dim]
 └─ Residual (res_1 + ffn)   res_2
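The shape arithmetic behind the fused QKV projection can be sketched as follows. This is a minimal illustration, not Mila's API: `QkvLayout` and `make_qkv_layout` are hypothetical names. It also assumes that Q, K and V are packed in that order along the last dimension of the `fc_qkv_proj` output, which the diagram suggests but does not state.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helper (not part of Mila): element offsets and widths of the
// Q, K and V views inside one row of the fused fc_qkv_proj output.
// Assumption: the three slices are packed contiguously in Q, K, V order.
struct QkvLayout {
    std::size_t q_offset;   // start of Q within a row
    std::size_t q_width;    // n_heads * head_dim
    std::size_t k_offset;   // start of K within a row
    std::size_t k_width;    // n_kv_heads * head_dim
    std::size_t v_offset;   // start of V within a row
    std::size_t v_width;    // n_kv_heads * head_dim
    std::size_t row_width;  // (n_heads + 2 * n_kv_heads) * head_dim
};

inline QkvLayout make_qkv_layout(std::size_t n_heads,
                                 std::size_t n_kv_heads,
                                 std::size_t head_dim) {
    // Grouped-query attention shares each KV head among a whole number
    // of query heads.
    assert(n_kv_heads > 0 && n_heads % n_kv_heads == 0);

    const std::size_t q_width = n_heads * head_dim;
    const std::size_t kv_width = n_kv_heads * head_dim;

    return QkvLayout{
        0,        q_width,   // Q occupies [0, q_width)
        q_width,  kv_width,  // K follows Q
        q_width + kv_width, kv_width,  // V follows K
        q_width + 2 * kv_width};
}
```

With the illustrative values n_heads = 32, n_kv_heads = 8 and head_dim = 128, a row of the fused output is 6144 elements wide: Q covers the first 4096, K starts at offset 4096, and V starts at offset 5120. A zero-copy view then only needs a base pointer plus one of these offsets and the shared row stride.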
Key design points:
Weight loader note: HuggingFace stores fc_gate and fc_up as separate tensors. The weight loader must concatenate them along dim 0 into a single [2*hidden_dim, model_dim] matrix when loading into fc_gate_up.
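The concatenation described in the weight loader note can be sketched as below. This is a minimal sketch on plain `std::vector<float>` buffers, not Mila's loader: `concat_gate_up` and `matvec` are hypothetical helpers. It assumes row-major storage and that the gate rows come first, as the name fc_gate_up suggests.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of Mila): stack a row-major gate matrix
// [hidden_dim, model_dim] on top of a row-major up matrix of the same shape,
// giving a fused [2 * hidden_dim, model_dim] matrix.
// Assumption: gate rows first, then up rows.
std::vector<float> concat_gate_up(const std::vector<float>& gate,
                                  const std::vector<float>& up,
                                  std::size_t hidden_dim,
                                  std::size_t model_dim) {
    assert(gate.size() == hidden_dim * model_dim);
    assert(up.size() == hidden_dim * model_dim);

    // In row-major layout, concatenating along dim 0 is a plain append.
    std::vector<float> fused;
    fused.reserve(2 * hidden_dim * model_dim);
    fused.insert(fused.end(), gate.begin(), gate.end());
    fused.insert(fused.end(), up.begin(), up.end());
    return fused;
}

// Reference y = W * x for a row-major W of shape [rows, cols]; used only to
// check that one fused product equals the two separate products stacked.
std::vector<float> matvec(const std::vector<float>& w,
                          const std::vector<float>& x,
                          std::size_t rows,
                          std::size_t cols) {
    assert(w.size() == rows * cols);
    assert(x.size() == cols);

    std::vector<float> y(rows, 0.0f);
    for (std::size_t r = 0; r < rows; ++r) {
        for (std::size_t c = 0; c < cols; ++c) {
            y[r] += w[r * cols + c] * x[c];
        }
    }
    return y;
}
```

Because each output row depends only on its own weight row, multiplying by the fused matrix yields the gate outputs in the first hidden_dim entries and the up outputs in the last hidden_dim entries. That is what lets a single GEMM feed SwiGLU.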