Mila 0.13.48
Deep Neural Network Library
Compute interface for attention operations that maintain a KV cache.


Public Member Functions

| Type | Member | Description |
|---|---|---|
| | ~IKvInference () override=default | |
| virtual void | decode (const ITensor &q, const ITensor &k, const ITensor &v, ITensor &output, int position)=0 | Process a single token at an explicit KV cache position. |
| virtual void | prefill (const ITensor &q, const ITensor &k, const ITensor &v, ITensor &output, int position_offset)=0 | Populate the KV cache and compute attention output for a token chunk. |

Public Member Functions inherited from Mila::Dnn::Compute::IKvCacheLifecycle

| Type | Member | Description |
|---|---|---|
| virtual | ~IKvCacheLifecycle ()=default | |
| virtual void | initializeKvCache (int batch_size, int max_sequence_length)=0 | Allocate the KV cache for a given batch size and maximum sequence length. |
| virtual void | resetKvCache ()=0 | Reset the KV cache to an empty state, preserving the allocation. |
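To make the lifecycle contract concrete, here is a minimal CPU sketch of what initializeKvCache and resetKvCache promise: storage sized once for the worst case, and a reset that empties the cache without freeing it. KvCacheSketch and its methods are hypothetical names, not Mila API, and the single fill length shared by all batch entries is a simplifying assumption.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CPU stand-in illustrating the lifecycle contract; not a Mila class.
// K and V storage are each laid out as [batch][max_sequence_length][kv_dim].
class KvCacheSketch {
public:
    explicit KvCacheSketch(int kv_dim) : kv_dim_(kv_dim) {}

    // Analogue of initializeKvCache(): size storage for the worst case once.
    void initialize(int batch_size, int max_sequence_length) {
        assert(batch_size > 0 && max_sequence_length > 0);
        max_seq_ = max_sequence_length;
        const std::size_t n = static_cast<std::size_t>(batch_size) *
                              static_cast<std::size_t>(max_sequence_length) *
                              static_cast<std::size_t>(kv_dim_);
        keys_.assign(n, 0.0f);
        values_.assign(n, 0.0f);
        length_ = 0;
    }

    // Analogue of resetKvCache(): mark the cache empty but keep the storage.
    void reset() { length_ = 0; }

    // Record that `count` more positions were filled by prefill or decode
    // (assumption: all batch entries advance together).
    void advance(int count) {
        assert(count > 0 && length_ + count <= max_seq_);
        length_ += count;
    }

    int length() const { return length_; }
    std::size_t storedFloats() const { return keys_.size() + values_.size(); }

private:
    int kv_dim_;
    int max_seq_ = 0;
    int length_ = 0;
    std::vector<float> keys_;
    std::vector<float> values_;
};
```

Because reset leaves the storage untouched, one cache can be reused for many prompts with the same batch size and maximum length.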
Extends IKvCacheLifecycle with the two-phase inference compute contract used by transformer models with grouped-query attention (GQA), multi-query attention (MQA), and derived variants:
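prefill: populate the KV cache from a (possibly chunked) prompt sequence and compute its attention output. Q, K, and V are passed as separate tensors so that rotary position embeddings (RoPE) can be applied to Q and K in place upstream, before K and V are inserted into the cache. An explicit position_offset supports chunked prefill of sequences too long to process in a single pass.
decode: process a single autoregressive token against the accumulated KV cache. Q, K, and V are again passed separately, each holding one token per batch entry: Q is [B, 1, n_heads * head_dim], and K and V are [B, 1, n_kv_heads * head_dim].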
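The two phases imply a simple position bookkeeping rule for callers. The sketch below plans the calls for one prompt, assuming fixed-size prefill chunks followed by a number of single-token decode steps. KvCall and planKvCalls are hypothetical helpers, and the capacity check assumes the cache was initialized with max_sequence_length equal to max_seq.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical record of one call against an IKvInference-like op.
struct KvCall {
    enum class Phase { Prefill, Decode };
    Phase phase;
    int position;  // position_offset for prefill, position for decode
    int tokens;    // tokens processed by this call
};

// Plan chunked prefill over the prompt, then one decode call per generated token.
std::vector<KvCall> planKvCalls(int prompt_len, int chunk_len, int decode_steps, int max_seq) {
    assert(prompt_len > 0 && chunk_len > 0 && decode_steps >= 0);
    assert(prompt_len + decode_steps <= max_seq);  // must fit the initialized cache
    std::vector<KvCall> calls;
    for (int offset = 0; offset < prompt_len; offset += chunk_len) {
        const int n = std::min(chunk_len, prompt_len - offset);  // last chunk may be short
        calls.push_back({KvCall::Phase::Prefill, offset, n});
    }
    for (int i = 0; i < decode_steps; ++i) {
        // Each decode token lands at the next absolute position after the prompt.
        calls.push_back({KvCall::Phase::Decode, prompt_len + i, 1});
    }
    return calls;
}
```

For a 10-token prompt with 4-token chunks and 3 decode steps, this yields prefill calls at position_offset 0, 4, and 8 (the last covering 2 tokens), followed by decode calls at positions 10, 11, and 12.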
| ~IKvInference () | override, default |
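| virtual void decode (const ITensor &q, const ITensor &k, const ITensor &v, ITensor &output, int position) | pure virtual |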
Process a single token at an explicit KV cache position.
| q | Query [B, 1, n_heads * head_dim]. |
| k | Key [B, 1, n_kv_heads * head_dim]. |
| v | Value [B, 1, n_kv_heads * head_dim]. |
| output | Pre-allocated output [B, 1, model_dim]. |
| position | Zero-based absolute sequence position into the KV cache. |
Implemented in Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >, Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TensorDataType::BF16 >, and Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TensorDataType::FP32 >.
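For reference, here is a CPU sketch of the standard grouped-query attention computation a decode step performs for one batch entry: write the new K and V at position, then let each query head attend over cache positions 0 through position using the KV head its group shares. It is not the CudaGqaOp kernel. The function name, the flat row-major cache layout [max_seq][n_kv_heads * head_dim], and the assumption that the output width is n_heads * head_dim (before any output projection) are illustrative.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical single-token GQA reference for one batch entry (not Mila API).
std::vector<float> gqaDecodeStep(const std::vector<float>& q,
                                 const std::vector<float>& k,
                                 const std::vector<float>& v,
                                 std::vector<float>& k_cache,
                                 std::vector<float>& v_cache,
                                 int position, int n_heads, int n_kv_heads, int head_dim) {
    assert(n_kv_heads > 0 && n_heads % n_kv_heads == 0);
    const int kv_dim = n_kv_heads * head_dim;
    assert(q.size() == static_cast<std::size_t>(n_heads * head_dim));
    assert(k.size() == static_cast<std::size_t>(kv_dim) && v.size() == k.size());
    assert(static_cast<std::size_t>(position + 1) * kv_dim <= k_cache.size());
    assert(v_cache.size() == k_cache.size());

    // Store this token's K and V in the cache slot for `position`.
    const std::ptrdiff_t slot = static_cast<std::ptrdiff_t>(position) * kv_dim;
    std::copy(k.begin(), k.end(), k_cache.begin() + slot);
    std::copy(v.begin(), v.end(), v_cache.begin() + slot);

    const int group = n_heads / n_kv_heads;  // query heads sharing one KV head
    const float scale = 1.0f / std::sqrt(static_cast<float>(head_dim));
    std::vector<float> out(q.size(), 0.0f);
    std::vector<float> weights(static_cast<std::size_t>(position) + 1);

    for (int h = 0; h < n_heads; ++h) {
        const int kv_head = h / group;
        const float* q_h = q.data() + h * head_dim;

        // Scaled dot-product scores against cached positions 0..position.
        float max_score = -std::numeric_limits<float>::infinity();
        for (int t = 0; t <= position; ++t) {
            const float* k_t = k_cache.data() + t * kv_dim + kv_head * head_dim;
            float s = 0.0f;
            for (int d = 0; d < head_dim; ++d) s += q_h[d] * k_t[d];
            weights[t] = s * scale;
            max_score = std::max(max_score, weights[t]);
        }

        // Numerically stable softmax over the scores.
        float denom = 0.0f;
        for (int t = 0; t <= position; ++t) {
            weights[t] = std::exp(weights[t] - max_score);
            denom += weights[t];
        }

        // Softmax-weighted sum of the cached values for this head.
        float* out_h = out.data() + h * head_dim;
        for (int t = 0; t <= position; ++t) {
            const float* v_t = v_cache.data() + t * kv_dim + kv_head * head_dim;
            const float w = weights[t] / denom;
            for (int d = 0; d < head_dim; ++d) out_h[d] += w * v_t[d];
        }
    }
    return out;
}
```

With n_kv_heads equal to n_heads this reduces to standard multi-head attention, and with n_kv_heads equal to 1 it becomes MQA.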
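| virtual void prefill (const ITensor &q, const ITensor &k, const ITensor &v, ITensor &output, int position_offset) | pure virtual |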
Populate the KV cache and compute attention output for a token chunk.
| q | Query [B, T_chunk, n_heads * head_dim]. |
| k | Key [B, T_chunk, n_kv_heads * head_dim]. |
| v | Value [B, T_chunk, n_kv_heads * head_dim]. |
| output | Pre-allocated output [B, T_chunk, model_dim]. |
| position_offset | Absolute position of the first token in this chunk. |
Implemented in Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >, Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TensorDataType::BF16 >, and Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TensorDataType::FP32 >.
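Chunked prefill matches a single full pass as long as each query sees exactly the keys a full causal pass would show it. Under standard causal attention (an assumption; the CUDA kernel need not materialize a mask), query t of a chunk sits at absolute position position_offset + t and may attend to cache positions 0 through position_offset + t. The hypothetical helper chunkCausalMask below builds that visibility mask.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper: visibility mask for one prefill chunk under causal attention.
// Row t is the query at absolute position position_offset + t; column j is cache position j.
std::vector<std::vector<bool>> chunkCausalMask(int position_offset, int chunk_len) {
    assert(position_offset >= 0 && chunk_len > 0);
    const int total = position_offset + chunk_len;  // cache length once this chunk is written
    std::vector<std::vector<bool>> mask(chunk_len, std::vector<bool>(total, false));
    for (int t = 0; t < chunk_len; ++t) {
        // Visible: every earlier chunk, plus this chunk's tokens up to and including t.
        for (int j = 0; j <= position_offset + t; ++j) {
            mask[t][j] = true;
        }
    }
    return mask;
}
```

For position_offset = 2 and a 2-token chunk, the first query sees positions 0 to 2 and the second sees positions 0 to 3.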