Mila 0.13.48
Deep Neural Network Library
Grouped-Query Attention module that accepts concatenated QKV input.


Public Types | |
| using | ComponentBase = Component<TDeviceType, TComputePrecision> |
| using | KvCacheTensorType = Tensor<kCacheDtype, MR> |
| using | MR = typename DeviceTypeTraits<TDeviceType>::memory_resource |
| using | OpType = typename Compute::OperationTraits<Compute::OperationType::GroupedQueryAttentionOp, TDeviceType, TComputePrecision, TKvPolicy>::type |
| using | TensorType = Tensor<TComputePrecision, MR> |
Public Member Functions | |
| GroupedQueryAttention (const std::string &name, const GqaConfig &config, std::optional< DeviceId > device_id=std::nullopt) | |
| Construct a GroupedQueryAttention component. | |
| ~GroupedQueryAttention () override=default | |
| TensorType & | backward (const TensorType &input, const TensorType &output_grad) |
| Run backward pass and return the component-owned input-gradient tensor. | |
| TensorType & | decode (const TensorType &q, const TensorType &k, const TensorType &v, int position_offset) |
| Inference-only single-token decode pass. | |
| TensorType & | forward (const TensorType &input) |
| Standard forward pass. | |
| const GqaConfig & | getConfig () const noexcept |
| DeviceId | getDeviceId () const override |
| Get the compute device id associated with this component. | |
| std::vector< ITensor * > | getGradients () const override |
| Return non-owning pointers to parameter gradient tensors. | |
| MemoryStats | getMemoryStats () const override |
| Return the current memory allocation breakdown for this component. | |
| int64_t | getModelDim () const noexcept |
| int64_t | getNumHeads () const noexcept |
| int64_t | getNumKvHeads () const noexcept |
| std::vector< ITensor * > | getParameters () const override |
| Return non-owning pointers to parameter tensors. | |
| const ComponentType | getType () const override |
| Get the component type identifier. | |
| size_t | parameterCount () const override |
| Return number of trainable parameters. | |
| TensorType & | prefill (const TensorType &q, const TensorType &k, const TensorType &v, int position_offset) |
| Chunked prefill pass with explicit position offset. | |
| void | save_ (ModelArchive &archive, SerializationMode mode) const override |
| void | setState (const GqaState &state) |
| Forward the shared transient workspace to the underlying operation. | |
| bool | supportsKVCache () const noexcept |
| Returns true when the underlying operation implements both IPositionalUnaryOp and IKvCacheLifecycle. | |
| void | synchronize () override |
| Wait for outstanding device work submitted by this component. | |
| std::string | toString () const override |
| Produce a short, human-readable description of the component. | |
| Public Member Functions inherited from Mila::Dnn::Component< TDeviceType, TComputePrecision > | |
| Component (const std::string &name) | |
| Construct component with required name identifier. | |
| virtual | ~Component ()=default |
| virtual void | build (const BuildContext &context) final |
| Build the component with the provided BuildContext (canonical overload). | |
| const std::string | getName () const |
| Get the component's name identifier. | |
| virtual std::vector< std::string > | getParameterNames () const |
| List all available parameter names for this component. | |
| RuntimeMode | getRuntimeMode () const noexcept |
| Get the current RuntimeMode of this Component. | |
| TrainingMode | getTrainingMode () const noexcept |
| Get the current TrainingMode of this Component. | |
| virtual bool | isBuilt () const final |
| Returns true if build() has completed successfully. | |
| bool | isInferenceMode () const noexcept |
| bool | isTrainingMode () const noexcept |
| virtual void | loadParameter (const std::string &name, const Serialization::ITensorBlob &blob) |
| Load a parameter from serialized tensor data. | |
| void | setTrainingMode (TrainingMode mode) |
| Set the runtime behavioral mode for this Component. | |
| virtual void | zeroGradients () |
| Clear all model-owned gradients for this component. | |
Static Public Attributes | |
| static constexpr TensorDataType | kCacheDtype |
| static constexpr bool | kKvCompressed = TKvPolicy::kIsActive |
Protected Member Functions | |
| void | onBuilding (const BuildContext &context) override |
| Hook invoked by build() to allocate component buffers. | |
| void | onExecutionContextSet () override |
| Lifecycle hook: Called immediately after ExecutionContext is set. | |
| void | onTrainingModeChanging (TrainingMode training_mode) override |
| Hook called before TrainingMode transitions. | |
| Protected Member Functions inherited from Mila::Dnn::Component< TDeviceType, TComputePrecision > | |
| IExecutionContext * | getExecutionContext () const |
| Get the shared execution context. | |
| bool | hasExecutionContext () const noexcept |
| Check if execution context has been set. | |
| void | loadParameterFromBlob (const std::string &param_name, const Serialization::ITensorBlob &blob, Tensor< TParameterPrecision, TMemoryResource > &target, const shape_t &expected_shape) |
| Load a tensor blob into a parameter tensor with validation. | |
| void | setExecutionContext (IExecutionContext *context) |
| Set the execution context for this component. | |
Private Member Functions | |
| void | createOperation () |
| void | validateConcatenatedQKVShape (const shape_t &shape) const |
| Validate that the input tensor has the expected GQA-packed QKV shape. | |
Private Attributes | |
| bool | cache_initialized_ { false } |
| GqaConfig | config_ |
| std::unique_ptr< IExecutionContext > | context_ { nullptr } |
| bool | decode_active_ { false } |
| std::unique_ptr< TensorType > | decode_output_ { nullptr } |
| std::unique_ptr< TensorType > | input_grad_ { nullptr } |
| IKvCacheLifecycle * | kv_cache_op_ { nullptr } |
| shape_t | max_input_shape_ |
| std::shared_ptr< OpType > | operation_ { nullptr } |
| std::unique_ptr< TensorType > | output_ { nullptr } |
| std::optional< TensorType > | output_view_ |
| IKvInference * | positional_op_ { nullptr } |
Additional Inherited Members | |
| Static Public Member Functions inherited from Mila::Dnn::Component< TDeviceType, TComputePrecision > | |
| static constexpr DeviceType | getDeviceType () |
| Compile-time device type for this component instance. | |
| static constexpr TensorDataType | getPrecision () noexcept |
| Compile-time tensor precision for this component instance. | |
| Protected Attributes inherited from Mila::Dnn::Component< TDeviceType, TComputePrecision > | |
| BuildContext | build_context_ |
| The BuildContext stored at build time. | |
Grouped-Query Attention module that accepts concatenated QKV input.
GQA generalises MHA by allowing num_kv_heads < num_heads. Each K/V head is shared by a group of (num_heads / num_kv_heads) Q heads, reducing KV cache memory and bandwidth during inference.
The module requires a single input tensor in model-layout containing concatenated Q, K and V along the feature axis:
input shape == [B, T, (num_heads + 2 * num_kv_heads) * head_dim]
output shape == [B, T, model_dim] (model_dim = num_heads * head_dim)
The backend compute implementation (registered as "GroupedQueryAttentionOp") must accept this layout and produce the output above.
KV-cache inference is an optional backend capability. After build(), supportsKVCache() indicates whether the underlying operation implements both IPositionalUnaryOp (prefill/decode dispatch) and IKvCacheLifecycle (cache init/reset). Both pointers are resolved once at build time.
The KV cache lifecycle (initializeKVCache / resetKVCache) is intended to be driven exclusively by the owning transformer's generate() method.
REVIEW: initializeKVCache() and resetKVCache() are currently public. When TransformerBase<> is introduced as the common base for GptTransformer, LlamaTransformer, MistralTransformer etc., revisit whether these should become private with 'friend class TransformerBase<TDeviceType, TPrecision>' to enforce that only the generate() orchestration path may manage the KV cache lifecycle.
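| GroupedQueryAttention (const std::string &name, const GqaConfig &config, std::optional< DeviceId > device_id=std::nullopt) | inline, explicit, export |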
Construct a GroupedQueryAttention component.
| name | Component name identifier (mandatory). |
| config | GQA configuration (model_dim, num_heads, num_kv_heads). |
| device_id | Optional DeviceId to create an owned ExecutionContext (standalone / unit-test mode). |
| ~GroupedQueryAttention () | override, export, default |
| TensorType & backward (const TensorType &input, const TensorType &output_grad) | inline, export |
Run backward pass and return the component-owned input-gradient tensor.
| input | Concatenated QKV input tensor used in forward. |
| output_grad | Gradient w.r.t. the module output. |
| void createOperation () | inline, export, private |
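| TensorType & decode (const TensorType &q, const TensorType &k, const TensorType &v, int position_offset) | inline, export |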
Inference-only single-token decode pass.
When the backend implements IPositionalUnaryOp and the cache has been populated by a prior forward() call, decode() uses the fast KV cache path, whose per-token cost is O(n) in the number of cached positions. When the backend does not support positional dispatch, decode() falls back to forward(). The caller never needs to know which path was taken.
Precondition: forward() must have been called at least once to populate the KV cache before decode() is called.
| q, k, v | Query, key and value tensors for the single token being decoded (T = 1). |
| position_offset | Current sequence position (0-based). |
| TensorType & forward (const TensorType &input) | inline, export |
Standard forward pass.
Always available regardless of backend. When the backend supports KV caching, the first forward() call initialises and populates the cache (prefill with position_offset=0). When called again after decode() steps, forward() automatically resets the cache and begins a new prefill session, so callers need no explicit session management.
| input | Concatenated QKV input [B, T, (Q + 2*KV) * head_dim]. |
| const GqaConfig & getConfig () const noexcept | inline, export, noexcept |
| DeviceId getDeviceId () const | inline, override, export, virtual |
Get the compute device id associated with this component.
Must return the device on which parameters and operations execute.
Implements Mila::Dnn::Component< TDeviceType, TComputePrecision >.
| std::vector< ITensor * > getGradients () const | inline, override, export, virtual |
Return non-owning pointers to parameter gradient tensors.
Only valid when isTrainingMode() is true.
| std::runtime_error | if called when not in training mode or before the component has been built. |
Implements Mila::Dnn::Component< TDeviceType, TComputePrecision >.
| MemoryStats getMemoryStats () const | inline, override, export, virtual |
Return the current memory allocation breakdown for this component.
Reflects allocations at the moment of the call. The returned stats naturally track the component lifecycle:
- After construction: parameters only
- After build( Inference ): parameters + T=1 state buffers
- After build( Training ): parameters + T=full state buffers
- After setEvaluation( false ): parameters + state + gradients
For CompositeComponent and Network, the returned stats are the recursive aggregate of all child components.
May be called at any time — no lifecycle preconditions.
Implements Mila::Dnn::Component< TDeviceType, TComputePrecision >.
| int64_t getModelDim () const noexcept | inline, export, noexcept |
| int64_t getNumHeads () const noexcept | inline, export, noexcept |
| int64_t getNumKvHeads () const noexcept | inline, export, noexcept |
| std::vector< ITensor * > getParameters () const | inline, override, export, virtual |
Return non-owning pointers to parameter tensors.
The returned tensor pointers remain valid for the lifetime of the component. Order should be canonical (weights before biases).
Implements Mila::Dnn::Component< TDeviceType, TComputePrecision >.
| const ComponentType getType () const | inline, override, export, virtual |
Get the component type identifier.
Used for serialization and runtime type identification.
Implements Mila::Dnn::Component< TDeviceType, TComputePrecision >.
| void onBuilding (const BuildContext &context) | inline, override, export, protected, virtual |
Hook invoked by build() to allocate component buffers.
Receives the stored BuildContext. Implementations must use context.allocationSeqLen() when sizing output buffers; this single call lets Inference and Training builds allocate correctly sized buffers automatically, without per-component logic.
The default implementation forwards to the legacy onBuilding( const shape_t& ) overload for backwards compatibility. New components should override this overload directly.
| context | Build-time configuration. Use context.allocationSeqLen() to obtain the correct output buffer sequence dimension. |
Reimplemented from Mila::Dnn::Component< TDeviceType, TComputePrecision >.
| void onExecutionContextSet () | inline, override, export, protected, virtual |
Lifecycle hook: Called immediately after ExecutionContext is set.
Override this to perform initialization that requires a valid ExecutionContext. At the time this is called, getExecutionContext() is guaranteed to return a valid context.
Common uses:
Default implementation does nothing.
| Any exception | thrown from this hook causes setExecutionContext() to fail and restores the component to a "context not set" state. |
Reimplemented from Mila::Dnn::Component< TDeviceType, TComputePrecision >.
| void onTrainingModeChanging (TrainingMode training_mode) | inline, override, export, protected, virtual |
Hook called before TrainingMode transitions.
Called by setTrainingMode() after validation and lock acquisition, before the internal state is updated. Derived classes override to respond to the transition — e.g. zeroing gradient buffers on transition to Eval, or re-enabling dropout on transition to Training.
The default implementation is a no-op.
| training_mode | The incoming TrainingMode. |
Reimplemented from Mila::Dnn::Component< TDeviceType, TComputePrecision >.
| size_t parameterCount () const | inline, override, export, virtual |
Return number of trainable parameters.
For leaf components this is the element count of owned parameter tensors. CompositeComponent and Network implementations should return the recursive aggregate across all children.
Implements Mila::Dnn::Component< TDeviceType, TComputePrecision >.
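| TensorType & prefill (const TensorType &q, const TensorType &k, const TensorType &v, int position_offset) | inline, export |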
Chunked prefill pass with explicit position offset.
Called by the transformer block during chunked prefill. The KV cache must already be initialized (via onBuilding or forward()).
| q, k, v | Query, key and value tensors for the current chunk of T_chunk tokens. |
| position_offset | Absolute position of the first token in this chunk. |
| void save_ (ModelArchive &archive, SerializationMode mode) const | inline, override, export, virtual |
| void setState (const GqaState &state) | inline, export |
| bool supportsKVCache () const noexcept | inline, export, noexcept |
Returns true when the underlying operation implements both IPositionalUnaryOp and IKvCacheLifecycle.
Resolved once at build time. CPU backends return false; CUDA backends return true when CudaGroupedQueryAttentionOp is in use.
| void synchronize () | inline, override, export, virtual |
Wait for outstanding device work submitted by this component.
On CPU this may be a no-op. Use to ensure results are visible to the host or to measure synchronous timings.
Implements Mila::Dnn::Component< TDeviceType, TComputePrecision >.
| std::string toString () const | inline, override, export, virtual |
Produce a short, human-readable description of the component.
Implementations should keep output concise and avoid throwing.
Implements Mila::Dnn::Component< TDeviceType, TComputePrecision >.
| void validateConcatenatedQKVShape (const shape_t &shape) const | inline, export, private |
Validate that the input tensor has the expected GQA-packed QKV shape.
Expected trailing dimension: (num_heads + 2 * num_kv_heads) * head_dim