Mila 0.13.48
Deep Neural Network Library
Mila::Dnn::Compute::IKvInference Struct Reference (abstract, export)

Compute interface for attention operations that maintain a KV cache.


Public Member Functions

 ~IKvInference () override=default
virtual void decode (const ITensor &q, const ITensor &k, const ITensor &v, ITensor &output, int position)=0
 Process a single token at an explicit KV cache position.
virtual void prefill (const ITensor &q, const ITensor &k, const ITensor &v, ITensor &output, int position_offset)=0
 Populate the KV cache and compute attention output for a token chunk.
Public Member Functions inherited from Mila::Dnn::Compute::IKvCacheLifecycle
virtual ~IKvCacheLifecycle ()=default
virtual void initializeKvCache (int batch_size, int max_sequence_length)=0
 Allocate the KV cache for a given batch size and maximum sequence length.
virtual void resetKvCache ()=0
 Reset the KV cache to an empty state, preserving the allocation.

Detailed Description

Compute interface for attention operations that maintain a KV cache.

Extends IKvCacheLifecycle with the two-phase inference compute contract used by modern transformer architectures (GQA, MQA, and derived models):

prefill: populate the KV cache from a (possibly chunked) prompt sequence. Q, K, and V are passed separately so that RoPE rotation can be applied in place upstream, before cache insertion. An explicit position_offset supports chunked prefill of prompts that are too long to process in a single pass.

decode: process a single autoregressive token against the accumulated KV cache. As in prefill, Q, K, and V are passed separately, with shapes [B, 1, n_heads * head_dim] for Q and [B, 1, n_kv_heads * head_dim] for K and V. An explicit position gives the token's absolute slot in the cache.
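The call order implied by the two phases can be sketched as follows. This is a minimal illustration, not Mila code: Tensor, KvInference, RecordingOp, and generate are hypothetical stand-ins that only mirror the documented method signatures, and the sketch assumes model_dim equals n_heads * head_dim.

```cpp
#include <cassert>
#include <string>
#include <vector>

namespace sketch {

// Hypothetical stand-in for Mila's ITensor; it only carries a [B, T, C] shape.
struct Tensor {
    int batch;
    int tokens;
    int channels;
};

// Hypothetical mirror of the documented method signatures, with Tensor in
// place of ITensor. This is not Mila's actual header.
struct KvInference {
    virtual ~KvInference() = default;
    virtual void initializeKvCache(int batch_size, int max_sequence_length) = 0;
    virtual void resetKvCache() = 0;
    virtual void prefill(const Tensor& q, const Tensor& k, const Tensor& v,
                         Tensor& output, int position_offset) = 0;
    virtual void decode(const Tensor& q, const Tensor& k, const Tensor& v,
                        Tensor& output, int position) = 0;
};

// Test double that records the call order instead of computing attention.
struct RecordingOp : KvInference {
    std::vector<std::string> calls;

    void initializeKvCache(int, int) override { calls.push_back("init"); }
    void resetKvCache() override { calls.push_back("reset"); }
    void prefill(const Tensor&, const Tensor&, const Tensor&, Tensor&,
                 int position_offset) override {
        calls.push_back("prefill@" + std::to_string(position_offset));
    }
    void decode(const Tensor& q, const Tensor&, const Tensor&, Tensor&,
                int position) override {
        assert(q.tokens == 1);  // decode handles exactly one token per call
        calls.push_back("decode@" + std::to_string(position));
    }
};

// Runs one sequence: prefill the whole prompt in a single pass, then decode
// new_tokens tokens one at a time. Returns the next unused cache position.
inline int generate(KvInference& op, int batch, int prompt_len, int new_tokens,
                    int n_heads, int n_kv_heads, int head_dim) {
    const int q_dim = n_heads * head_dim;
    const int kv_dim = n_kv_heads * head_dim;
    const int model_dim = q_dim;  // assumption: model_dim == n_heads * head_dim

    Tensor q{batch, prompt_len, q_dim};
    Tensor k{batch, prompt_len, kv_dim};
    Tensor v{batch, prompt_len, kv_dim};
    Tensor out{batch, prompt_len, model_dim};
    op.prefill(q, k, v, out, /*position_offset=*/0);

    int position = prompt_len;  // the first decoded token follows the prompt
    for (int i = 0; i < new_tokens; ++i, ++position) {
        Tensor q1{batch, 1, q_dim};
        Tensor k1{batch, 1, kv_dim};
        Tensor v1{batch, 1, kv_dim};
        Tensor out1{batch, 1, model_dim};
        op.decode(q1, k1, v1, out1, position);
    }
    return position;
}

}  // namespace sketch
```

A caller would typically invoke initializeKvCache once, run generate for a sequence, and call resetKvCache before starting the next one, which keeps the allocation in place.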

Constructor & Destructor Documentation

◆ ~IKvInference()

Mila::Dnn::Compute::IKvInference::~IKvInference ( )
override, default

Member Function Documentation

◆ decode()

virtual void Mila::Dnn::Compute::IKvInference::decode ( const ITensor & q,
const ITensor & k,
const ITensor & v,
ITensor & output,
int position )
pure virtual

Process a single token at an explicit KV cache position.

Parameters
q: Query [B, 1, n_heads * head_dim].
k: Key [B, 1, n_kv_heads * head_dim].
v: Value [B, 1, n_kv_heads * head_dim].
output: Pre-allocated output [B, 1, model_dim].
position: Zero-based absolute sequence position into the KV cache.

Implemented in Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >, Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TensorDataType::BF16 >, and Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TensorDataType::FP32 >.
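The parameter shapes above can be checked before calling decode(). The following is a sketch only: Shape3, GqaConfig, and decode_args_valid are hypothetical helpers, not part of Mila, and the upper bound on position (less than max_sequence_length, as passed to initializeKvCache) is an assumption about what an implementation would accept.

```cpp
#include <cassert>

namespace sketch {

// Hypothetical [B, T, C] shape triple; not Mila's ITensor.
struct Shape3 {
    int batch;
    int tokens;
    int channels;
};

inline bool operator==(const Shape3& a, const Shape3& b) {
    return a.batch == b.batch && a.tokens == b.tokens &&
           a.channels == b.channels;
}

// Hypothetical attention configuration used only by this sketch.
struct GqaConfig {
    int n_heads;
    int n_kv_heads;
    int head_dim;
    int model_dim;
    int max_sequence_length;
};

// Returns true when the arguments match the shapes documented for decode()
// and position lies inside the cache (assumed bound: max_sequence_length).
inline bool decode_args_valid(const GqaConfig& cfg, int batch,
                              const Shape3& q, const Shape3& k,
                              const Shape3& v, const Shape3& output,
                              int position) {
    const Shape3 expected_q{batch, 1, cfg.n_heads * cfg.head_dim};
    const Shape3 expected_kv{batch, 1, cfg.n_kv_heads * cfg.head_dim};
    const Shape3 expected_out{batch, 1, cfg.model_dim};
    const bool position_ok =
        position >= 0 && position < cfg.max_sequence_length;
    return q == expected_q && k == expected_kv && v == expected_kv &&
           output == expected_out && position_ok;
}

}  // namespace sketch
```

With 8 query heads, 2 KV heads, and head_dim 64, the check expects 512 channels for Q and 128 for K and V, which is where grouped-query attention saves cache space.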

◆ prefill()

virtual void Mila::Dnn::Compute::IKvInference::prefill ( const ITensor & q,
const ITensor & k,
const ITensor & v,
ITensor & output,
int position_offset )
pure virtual

Populate the KV cache and compute attention output for a token chunk.

Parameters
q: Query [B, T_chunk, n_heads * head_dim].
k: Key [B, T_chunk, n_kv_heads * head_dim].
v: Value [B, T_chunk, n_kv_heads * head_dim].
output: Pre-allocated output [B, T_chunk, model_dim].
position_offset: Absolute position of the first token in this chunk.

Implemented in Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >, Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TensorDataType::BF16 >, and Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TensorDataType::FP32 >.
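For chunked prefill, the caller has to derive a position_offset and a T_chunk for each call. One possible way to plan those calls is sketched below. Chunk and plan_prefill_chunks are hypothetical names, not Mila API, and the sketch assumes chunks are submitted in order and without gaps.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

namespace sketch {

// One planned prefill() call: the position_offset argument and T_chunk.
struct Chunk {
    int position_offset;
    int length;
};

// Splits a prompt of prompt_len tokens into consecutive chunks of at most
// max_chunk_len tokens. Only the last chunk may be shorter than the rest.
inline std::vector<Chunk> plan_prefill_chunks(int prompt_len,
                                              int max_chunk_len) {
    assert(prompt_len >= 0 && max_chunk_len > 0);
    std::vector<Chunk> chunks;
    for (int offset = 0; offset < prompt_len; offset += max_chunk_len) {
        chunks.push_back({offset, std::min(max_chunk_len, prompt_len - offset)});
    }
    return chunks;
}

}  // namespace sketch
```

Each planned chunk corresponds to one prefill() call whose q, k, and v tensors hold chunk.length tokens. The first decode() call would then use position equal to prompt_len.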


The documentation for this struct was generated from the following file: