Mila 0.13.48
Deep Neural Network Library
CUDA Grouped-Query Attention operation.


Public Types | |
| using | ConfigType = GqaConfig |
| using | CudaExecutionContext = ExecutionContext<DeviceType::Cuda> |
| using | MR = CudaDeviceMemoryResource |
| using | NativeType = typename Mila::Dnn::Compute::Cuda::TensorDataTypeMap<TPrecision>::device_type |
| using | TensorType = Tensor<TPrecision, MR> |
| using | UnaryOperationBase = UnaryOperation<DeviceType::Cuda, TPrecision> |
| Public Types inherited from Mila::Dnn::Compute::Operation< DeviceType::Cuda, TPrecision > | |
| using | DataTypeTraits |
Public Member Functions | |
| CudaGqaOp (IExecutionContext *context, const GqaConfig &config) | |
| void | backward (const ITensor &input, const ITensor &output_grad, ITensor &input_grad) const |
| void | build (const BuildContext &context) override |
| Prepare the operation for a concrete input shape. | |
| void | decode (const ITensor &q, const ITensor &k, const ITensor &v, ITensor &output, int position) override |
| Process a single token at an explicit KV cache position. | |
| void | forward (const ITensor &input, ITensor &output) const |
| Standard (non-cached) forward pass used during training. | |
| const GqaConfig & | getConfig () const |
| std::string | getName () const override |
| Human-readable operation name. | |
| OperationType | getOperationType () const override |
| Operation type identifier. | |
| std::size_t | getStateMemorySize () const override |
| Returns the number of bytes of state memory allocated by this operation. | |
| void | initializeKvCache (int batch_size, int max_seq_length) override |
| Allocate the KV cache for a given batch size and maximum sequence length. | |
| void | prefill (const ITensor &q, const ITensor &k, const ITensor &v, ITensor &output, int position_offset) override |
| Populate the KV cache and compute attention output for a token chunk. | |
| void | resetKvCache () override |
| Reset the KV cache to an empty state, preserving the allocation. | |
| void | setGradients (ITensor *, ITensor *) override |
| Bind module-owned gradient tensors to the operation. | |
| void | setParameters (ITensor *, ITensor *) override |
| Bind module-owned parameter tensors to the operation. | |
| void | setState (const GqaState &state) |
| Wire the shared transient scratch buffers for the optimized inference path. | |
| Public Member Functions inherited from Mila::Dnn::Compute::Operation< DeviceType::Cuda, TPrecision > | |
| virtual | ~Operation ()=default |
| virtual void | clearGradients () noexcept |
| Clear any cached gradient pointers held by the operation. | |
| virtual TensorDataType | getDataType () const |
| Tensor data type for this operation. | |
| virtual DeviceType | getDeviceType () const |
| Device type for this operation. | |
| virtual bool | isBuilt () const |
| Whether build() completed successfully for a concrete input shape. | |
| virtual bool | isEvalMode () const |
| Query whether the operation is in evaluation (non-training) mode. | |
| virtual void | setTrainingMode (TrainingMode training_mode) |
| Configure operation training-mode behavior. | |
| Public Member Functions inherited from Mila::Dnn::Compute::IKvInference | |
| ~IKvInference () override=default | |
| Public Member Functions inherited from Mila::Dnn::Compute::IKvCacheLifecycle | |
| virtual | ~IKvCacheLifecycle ()=default |
Private Member Functions | |
| void | buildCublasLtPlans () |
| void | buildCublasLtPlans_optimized () |
| void | decode_optimized (const ITensor &q, const ITensor &k, const ITensor &v, ITensor &output, int position) |
| void | decodeImpl (const ITensor &q, const ITensor &k, const ITensor &v, ITensor &output, int position) |
| void | ensureKVCacheEnabled () const |
| void | getComputeTypes (cublasComputeType_t &compute_type, cudaDataType_t &scale_type) const |
| cudaDataType_t | getCudaDataType () const |
| const Detail::CublasLtMatMulPlan< NativeType > & | getOrBuildPartialAVPlan (int chunk_len) |
| const Detail::CublasLtMatMulPlan< NativeType > & | getOrBuildPartialAVPlan_optimized (int chunk_len) |
| const Detail::CublasLtMatMulPlan< NativeType > & | getOrBuildPartialQKPlan (int chunk_len) |
| const Detail::CublasLtMatMulPlan< NativeType > & | getOrBuildPartialQKPlan_optimized (int chunk_len) |
| void | initializeState (const BuildContext &build_context) |
| void | initializeState_optimized (const BuildContext &build_context) |
| void | prefill_optimized (const ITensor &q, const ITensor &k, const ITensor &v, ITensor &output, int position_offset) |
| void | prefillImpl (const ITensor &q, const ITensor &k, const ITensor &v, ITensor &output, int position_offset) |
| void | validateDecodeInputShape (const shape_t &s) const |
| void | validateInputShape (const shape_t &s) const |
| void | validatePrefillInputShape (const shape_t &s) const |
Static Private Member Functions | |
| static NativeType * | raw (const std::shared_ptr< TensorType > &t) |
Additional Inherited Members | |
| Static Public Attributes inherited from Mila::Dnn::Compute::Operation< DeviceType::Cuda, TPrecision > | |
| static constexpr TensorDataType | data_type |
| static constexpr DeviceType | device_type |
| Protected Attributes inherited from Mila::Dnn::Compute::Operation< DeviceType::Cuda, TPrecision > | |
| bool | is_built_ |
| TrainingMode | training_mode_ |
CUDA Grouped-Query Attention operation.
GQA generalises multi-head attention (MHA) by allowing num_kv_heads < num_heads. Every group of (num_heads / num_kv_heads) Q heads shares a single K/V head, which cuts KV cache memory and bandwidth by a factor equal to the group size.
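As a rough illustration of the grouping arithmetic, the sketch below maps a Q head to its shared K/V head and compares cache element counts. The helper names `kvHeadForQueryHead` and `cacheElements` are hypothetical and not part of Mila, and the mapping assumes that consecutive Q heads form one group, which this page does not state explicitly.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helper: index of the K/V head shared by query head `q_head`.
// Assumes consecutive query heads belong to the same group.
inline int kvHeadForQueryHead(int q_head, int num_heads, int num_kv_heads) {
    assert(num_kv_heads > 0 && num_heads % num_kv_heads == 0);
    assert(q_head >= 0 && q_head < num_heads);
    const int group_size = num_heads / num_kv_heads;
    return q_head / group_size;
}

// Hypothetical helper: number of elements in one cached tensor (K or V)
// that stores `heads` heads per sequence.
inline std::size_t cacheElements(std::size_t batch, std::size_t heads,
                                 std::size_t seq_len, std::size_t head_dim) {
    return batch * heads * seq_len * head_dim;
}
```

With num_heads = 32 and num_kv_heads = 8 the group size is 4, so Q heads 4 to 7 all read K/V head 1, and a cache holding 8 heads needs a quarter of the elements of one holding 32.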
The legacy path uses cuBLASLt batched matmuls on an expanded layout: K and V are stored compactly in [B, NKV, T, HS] and expanded to [B, NH, T, HS] before the matmuls so every cuBLASLt plan operates at batch_count = B * NH.
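The expansion step can be pictured with a small CPU reference. This is only a sketch of the layout transformation, not the CUDA kernel Mila uses; the function name `expandKvHeads` is hypothetical, `float` stands in for the native element type, and the head mapping assumes consecutive Q heads share a K/V head.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference: copy a compact [B, NKV, T, HS] tensor into an
// expanded [B, NH, T, HS] tensor by repeating each K/V head NH / NKV times.
inline std::vector<float> expandKvHeads(const std::vector<float>& compact,
                                        int B, int NKV, int NH, int T, int HS) {
    assert(NKV > 0 && NH % NKV == 0);
    const std::size_t head_elems = static_cast<std::size_t>(T) * HS;
    assert(compact.size() == static_cast<std::size_t>(B) * NKV * head_elems);

    const int group = NH / NKV;
    std::vector<float> expanded(static_cast<std::size_t>(B) * NH * head_elems);
    for (int b = 0; b < B; ++b) {
        for (int h = 0; h < NH; ++h) {
            const int kv = h / group;  // shared K/V head for query head h
            const float* src = compact.data()
                + (static_cast<std::size_t>(b) * NKV + kv) * head_elems;
            float* dst = expanded.data()
                + (static_cast<std::size_t>(b) * NH + h) * head_elems;
            std::copy_n(src, head_elems, dst);
        }
    }
    return expanded;
}
```

The expanded buffer is NH / NKV times larger than the compact one, which is the cost the optimized path is described as avoiding.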
The optimized path (kUseOptimizedPath) eliminates the expansion buffers and q_tensor_ by rebuilding cuBLASLt plans against the compact NKV layout with grouped head strides. See GqaMemory.md Phase 1 and Phase 2.
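One way to read "grouped head strides" is that each Q head addresses its shared K/V head directly in the compact buffer instead of reading a copy. The sketch below computes that element offset; `compactHeadOffset` is a hypothetical name, and how the actual cuBLASLt plans encode this addressing is not described on this page.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helper: element offset of the K (or V) head read by query head
// `q_head` of batch item `b`, in a compact [B, NKV, T, HS] buffer.
// Assumes consecutive query heads share one K/V head.
inline std::size_t compactHeadOffset(int b, int q_head,
                                     int NH, int NKV, int T, int HS) {
    assert(NKV > 0 && NH % NKV == 0);
    assert(b >= 0 && q_head >= 0 && q_head < NH);
    const int group = NH / NKV;
    const std::size_t kv_head = static_cast<std::size_t>(q_head / group);
    return (static_cast<std::size_t>(b) * NKV + kv_head)
           * static_cast<std::size_t>(T) * static_cast<std::size_t>(HS);
}
```

All Q heads in a group resolve to the same offset, so no [B, NH, T, HS] copy of K or V is needed for the lookup itself.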
Forward pass (training):
Prefill pass (inference only, with KV cache):
Decode pass (decode / KV-cache):
Backward pass (training only):
| TPrecision | Tensor element type and cuBLASLt data/compute type. |
| using Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::ConfigType = GqaConfig |
| using Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::CudaExecutionContext = ExecutionContext<DeviceType::Cuda> |
| using Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::MR = CudaDeviceMemoryResource |
| using Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::NativeType = typename Mila::Dnn::Compute::Cuda::TensorDataTypeMap<TPrecision>::device_type |
| using Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::TensorType = Tensor<TPrecision, MR> |
| using Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::UnaryOperationBase = UnaryOperation<DeviceType::Cuda, TPrecision> |
| void Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::build (const BuildContext &context) | inline override virtual |
Prepare the operation for a concrete input shape.
Default implementation is a no-op. Operations requiring shape-dependent setup should override this method.
Reimplemented from Mila::Dnn::Compute::Operation< DeviceType::Cuda, TPrecision >.
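| void Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::decode (const ITensor &q, const ITensor &k, const ITensor &v, ITensor &output, int position) | inline override virtual |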
Process a single token at an explicit KV cache position.
| q | Query [B, 1, n_heads * head_dim]. |
| k | Key [B, 1, n_kv_heads * head_dim]. |
| v | Value [B, 1, n_kv_heads * head_dim]. |
| output | Pre-allocated output [B, 1, model_dim]. |
| position | Zero-based absolute sequence position into the KV cache. |
Implements Mila::Dnn::Compute::IKvInference.
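The shape contract in the parameter table above can be expressed as a small check. This is a sketch only: `Shape`, `GqaDims`, `hasShape`, and `isValidDecodeCall` are hypothetical stand-ins rather than Mila types, and the requirement that `position` lies below the cache's maximum sequence length is an assumption, not something this page states.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

using Shape = std::vector<std::int64_t>;  // stand-in for Mila's shape_t (assumption)

struct GqaDims {
    std::int64_t n_heads;
    std::int64_t n_kv_heads;
    std::int64_t head_dim;
};

// True when `s` is exactly [batch, tokens, width].
inline bool hasShape(const Shape& s, std::int64_t batch,
                     std::int64_t tokens, std::int64_t width) {
    return s.size() == 3 && s[0] == batch && s[1] == tokens && s[2] == width;
}

// Checks q = [B, 1, n_heads * head_dim] and k, v = [B, 1, n_kv_heads * head_dim],
// plus an assumed bound 0 <= position < max_seq_length.
inline bool isValidDecodeCall(const Shape& q, const Shape& k, const Shape& v,
                              const GqaDims& d, int position, int max_seq_length) {
    assert(d.n_kv_heads > 0 && d.n_heads % d.n_kv_heads == 0);
    if (q.size() != 3) return false;
    const std::int64_t B = q[0];
    return hasShape(q, B, 1, d.n_heads * d.head_dim)
        && hasShape(k, B, 1, d.n_kv_heads * d.head_dim)
        && hasShape(v, B, 1, d.n_kv_heads * d.head_dim)
        && position >= 0 && position < max_seq_length;
}
```

For 8 Q heads, 2 KV heads, and head_dim 64, a valid call has q of width 512 and k, v of width 128, each with a sequence dimension of exactly 1.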
| void Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::forward (const ITensor &input, ITensor &output) const | inline |
Standard (non-cached) forward pass used during training.
| std::string Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::getName () const | inline override virtual |
Human-readable operation name.
Implements Mila::Dnn::Compute::Operation< DeviceType::Cuda, TPrecision >.
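| OperationType Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::getOperationType () const | inline override virtual |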
Operation type identifier.
Implements Mila::Dnn::Compute::Operation< DeviceType::Cuda, TPrecision >.
| std::size_t Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::getStateMemorySize () const | inline override virtual |
Returns the number of bytes of state memory allocated by this operation.
State memory includes build-time buffers such as caches and scratch allocations. Parameters and gradients are owned at the component level and are not included.
Override in derived operations that allocate device or host state during build().
Reimplemented from Mila::Dnn::Compute::Operation< DeviceType::Cuda, TPrecision >.
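| void Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::initializeKvCache (int batch_size, int max_seq_length) | inline override virtual |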
Allocate the KV cache for a given batch size and maximum sequence length.
| batch_size | Number of sequences in the batch. |
| max_seq_length | Maximum number of tokens the cache must hold. |
Implements Mila::Dnn::Compute::IKvCacheLifecycle.
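A back-of-the-envelope estimate of the allocation can be sketched as follows. It assumes the cache holds one K and one V tensor in the compact [B, NKV, T, HS] layout described in the class overview, with T = max_seq_length; `kvCacheBytes` is a hypothetical helper, and the real allocation may add padding or other buffers.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical estimate: bytes for a K cache plus a V cache, each stored as
// [batch_size, num_kv_heads, max_seq_length, head_dim].
inline std::size_t kvCacheBytes(std::size_t batch_size, std::size_t num_kv_heads,
                                std::size_t max_seq_length, std::size_t head_dim,
                                std::size_t bytes_per_element) {
    assert(bytes_per_element > 0);
    const std::size_t per_tensor =
        batch_size * num_kv_heads * max_seq_length * head_dim;
    return 2 * per_tensor * bytes_per_element;  // factor 2: K and V
}
```

For example, batch_size 1, 8 KV heads, max_seq_length 4096, head_dim 128, and 2-byte elements give 16,777,216 bytes (16 MiB) per layer under these assumptions.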
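| void Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::prefill (const ITensor &q, const ITensor &k, const ITensor &v, ITensor &output, int position_offset) | inline override virtual |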
Populate the KV cache and compute attention output for a token chunk.
| q | Query [B, T_chunk, n_heads * head_dim]. |
| k | Key [B, T_chunk, n_kv_heads * head_dim]. |
| v | Value [B, T_chunk, n_kv_heads * head_dim]. |
| output | Pre-allocated output [B, T_chunk, model_dim]. |
| position_offset | Absolute position of the first token in this chunk. |
Implements Mila::Dnn::Compute::IKvInference.
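Because `position_offset` is the absolute position of a chunk's first token, a caller that splits a long prompt would advance it by the length of each chunk. The sketch below plans such a split; `PrefillChunk` and `planPrefillChunks` are hypothetical names, and the chunking policy is an assumption rather than something this page prescribes.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical description of one prefill() call.
struct PrefillChunk {
    int position_offset;  // absolute position of the chunk's first token
    int length;           // T_chunk for this call
};

// Split a prompt of `prompt_length` tokens into chunks of at most
// `max_chunk_length` tokens, in order.
inline std::vector<PrefillChunk> planPrefillChunks(int prompt_length,
                                                   int max_chunk_length) {
    assert(prompt_length >= 0 && max_chunk_length > 0);
    std::vector<PrefillChunk> chunks;
    for (int offset = 0; offset < prompt_length; offset += max_chunk_length) {
        const int length = std::min(max_chunk_length, prompt_length - offset);
        chunks.push_back({offset, length});
    }
    return chunks;
}
```

A 10-token prompt with chunks of at most 4 tokens yields offsets 0, 4, and 8 with lengths 4, 4, and 2. Under the same assumption, the first decode() call after prefill would use position 10.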
| void Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::resetKvCache () | inline override virtual |
Reset the KV cache to an empty state, preserving the allocation.
Implements Mila::Dnn::Compute::IKvCacheLifecycle.
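| void Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::setGradients (ITensor *, ITensor *) | inline override virtual |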
Bind module-owned gradient tensors to the operation.
New canonical API for binding gradient buffers. Mirrors semantics of setParameters() but for gradients used during backward().
The operation MUST NOT take ownership of the provided pointers. Implementations may cache rawData() pointers for hot-path writes.
Default: no-op for stateless operations.
Reimplemented from Mila::Dnn::Compute::Operation< DeviceType::Cuda, TPrecision >.
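| void Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::setParameters (ITensor *, ITensor *) | inline override virtual |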
Bind module-owned parameter tensors to the operation.
The module retains ownership of the provided ITensor objects. Implementations may cache rawData() pointers for hot-path access but MUST NOT free the provided pointers.
Default: no-op for stateless operations.
Reimplemented from Mila::Dnn::Compute::Operation< DeviceType::Cuda, TPrecision >.
| void Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::setState (const GqaState &state) | inline |
Wire the shared transient scratch buffers for the optimized inference path.
Called once per build by LlamaTransformer after all blocks are built. The tensors are owned by LlamaTransformer and shared across all GQA layers sequentially. Must be called before prefill() or decode() when use_optimized_path_ is true.
| state | Non-owning pointers to the shared workspace tensors. All slots must be non-null for the optimized inference path. |
| private | Batch size. |
| private | Model dim = NH * HS. |
| private | Group size = NH / NKV. |
| private | Head dim = C / NH. |
| private | Number of Q heads. |
| private | Number of KV heads. |
| private | Max sequence length. |