Mila 0.13.48
Deep Neural Network Library
Mila::Dnn::Compute::IPackedKvInference Struct Reference (abstract, export)

KV-cache inference interface for packed-QKV MHA backends.

Inheritance diagram for Mila::Dnn::Compute::IPackedKvInference:
Collaboration diagram for Mila::Dnn::Compute::IPackedKvInference:

Public Member Functions

 ~IPackedKvInference () override=default
virtual void decode (const ITensor &input, ITensor &output, int position)=0
 Process a single autoregressive token against the KV cache.
virtual void prefill (const ITensor &qkv, ITensor &output)=0
 Populate the KV cache from a packed QKV sequence and compute the attention output.
Public Member Functions inherited from Mila::Dnn::Compute::IKvCacheLifecycle
virtual ~IKvCacheLifecycle ()=default
virtual void initializeKvCache (int batch_size, int max_sequence_length)=0
 Allocate the KV cache for a given batch size and maximum sequence length.
virtual void resetKvCache ()=0
 Reset the KV cache to an empty state, preserving the allocation.
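The two inherited lifecycle calls can be pictured with a small host-side cache. This is a minimal sketch, not Mila's storage: `HostKvCache`, `setLength`, `allocatedFloats`, and the `embedding_dim` constructor argument are hypothetical, and a real backend would hold device memory (typically per layer and head). It illustrates the stated contract: initialization allocates for the full batch and maximum sequence length, while reset marks the cache empty without releasing or reallocating the buffers.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical host-side cache used only to illustrate the
// IKvCacheLifecycle contract; not Mila's actual implementation.
class HostKvCache {
public:
    // Assumption: the backend already knows embedding_dim from its config.
    explicit HostKvCache(int embedding_dim) : embedding_dim_(embedding_dim) {}

    // Mirrors initializeKvCache(): allocate K and V for every
    // (batch, position) slot up front.
    void initialize(int batch_size, int max_sequence_length) {
        if (batch_size <= 0 || max_sequence_length <= 0)
            throw std::invalid_argument("batch_size and max_sequence_length must be positive");
        max_seq_len_ = max_sequence_length;
        const std::size_t n = static_cast<std::size_t>(batch_size) *
                              max_sequence_length * embedding_dim_;
        keys_.assign(n, 0.0f);
        values_.assign(n, 0.0f);
        length_ = 0;
    }

    // Mirrors resetKvCache(): mark the cache empty but keep the buffers,
    // so the next prompt reuses the same allocation.
    void reset() { length_ = 0; }

    // Record that positions [0, new_length) now hold valid entries.
    void setLength(int new_length) {
        if (new_length < 0 || new_length > max_seq_len_)
            throw std::out_of_range("length exceeds max_sequence_length");
        length_ = new_length;
    }

    int length() const { return length_; }
    std::size_t allocatedFloats() const { return keys_.size() + values_.size(); }
    const float* keysData() const { return keys_.data(); }

private:
    int embedding_dim_;
    int max_seq_len_ = 0;
    int length_ = 0;
    std::vector<float> keys_;
    std::vector<float> values_;
};
```

Separating allocation from reset lets a server size the cache once for its largest batch and sequence length, then serve many prompts without repeated allocations.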

Detailed Description

KV-cache inference interface for packed-QKV MHA backends.

Implemented by GPT-style MHA backends (e.g., CudaMultiHeadAttentionOp). Both phases take fused QKV input: Q, K, and V are concatenated along the feature axis (hence the 3 * embedding_dim last dimension) and are split internally by the backend kernel.
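To make the packed layout concrete, the sketch below computes flat offsets into a packed tensor and splits it into separate Q, K, and V buffers. It assumes a contiguous row-major [B, T, 3 * E] layout with Q, K, and V in that order along the last axis; the order and contiguity are assumptions, not confirmed by this page. `QkvPart`, `packedQkvIndex`, and `splitPackedQkv` are hypothetical helpers, not part of Mila.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Which third of the packed feature axis to address.
// Assumption: Q, K, V appear in that order along the last axis.
enum class QkvPart { Q = 0, K = 1, V = 2 };

// Flat index of element (b, t, part, e) in a contiguous row-major
// packed tensor of shape [B, T, 3 * E]. B itself is not needed.
constexpr std::size_t packedQkvIndex(std::size_t b, std::size_t t, QkvPart part,
                                     std::size_t e, std::size_t T, std::size_t E) {
    return (b * T + t) * (3 * E) + static_cast<std::size_t>(part) * E + e;
}

struct SplitQkv {
    std::vector<float> q, k, v;  // each [B, T, E]
};

// Hypothetical host-side split; a backend kernel would instead read the
// packed buffer in place using the same offsets.
inline SplitQkv splitPackedQkv(const std::vector<float>& packed,
                               std::size_t B, std::size_t T, std::size_t E) {
    if (packed.size() != B * T * 3 * E)
        throw std::invalid_argument("packed buffer must have B * T * 3 * E elements");
    SplitQkv out{std::vector<float>(B * T * E), std::vector<float>(B * T * E),
                 std::vector<float>(B * T * E)};
    for (std::size_t b = 0; b < B; ++b)
        for (std::size_t t = 0; t < T; ++t)
            for (std::size_t e = 0; e < E; ++e) {
                const std::size_t dst = (b * T + t) * E + e;
                out.q[dst] = packed[packedQkvIndex(b, t, QkvPart::Q, e, T, E)];
                out.k[dst] = packed[packedQkvIndex(b, t, QkvPart::K, e, T, E)];
                out.v[dst] = packed[packedQkvIndex(b, t, QkvPart::V, e, T, E)];
            }
    return out;
}
```

Packing lets the preceding projection produce Q, K, and V with a single matrix multiply whose output width is 3 * E.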

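Prefill position is implicit: GPT-style MHA always begins prefill at position 0, so prefill() takes no position argument, while decode() receives each new token's absolute position explicitly. Absolute positional encoding is applied upstream by Lpe, not inside attention.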
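Because attention never receives positions as embeddings, the caller must keep the upstream positional encoding and the KV-cache position in step. This sketch assumes Lpe behaves like a GPT-2 style learned absolute positional embedding (a token vector plus a per-position vector); `LearnedPositionalEmbedding` and `decodePosition` are hypothetical names used only for illustration, not Mila's Lpe.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for an upstream learned positional encoding
// (assumed GPT-2 style): x = token_table[id] + position_table[p].
// Attention only ever sees the resulting x, never p itself.
struct LearnedPositionalEmbedding {
    std::size_t embedding_dim;
    std::vector<float> token_table;     // [vocab_size, embedding_dim]
    std::vector<float> position_table;  // [max_sequence_length, embedding_dim]

    std::vector<float> embed(std::size_t token_id, std::size_t position) const {
        if (token_id >= token_table.size() / embedding_dim)
            throw std::out_of_range("token_id out of range");
        if (position >= position_table.size() / embedding_dim)
            throw std::out_of_range("position exceeds max_sequence_length");
        std::vector<float> x(embedding_dim);
        for (std::size_t e = 0; e < embedding_dim; ++e)
            x[e] = token_table[token_id * embedding_dim + e] +
                   position_table[position * embedding_dim + e];
        return x;
    }
};

// Prefill places a prompt of length prompt_len at positions [0, prompt_len).
// The k-th generated token (k = 0, 1, ...) therefore sits at prompt_len + k;
// the same value should be used both for embed() and as decode()'s position.
constexpr std::size_t decodePosition(std::size_t prompt_len, std::size_t k) {
    return prompt_len + k;
}
```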

Two-phase inference protocol:
1. prefill: populate the KV cache from the full prompt sequence.
2. decode: process one autoregressive token against the accumulated cache.
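The protocol can be sketched as a driver templated on any backend that exposes these member names. `FakeTensor`, `RecordingBackend`, and `runRequest` are hypothetical; allocating once and resetting at the start of each request are assumptions about typical use, not documented requirements. Producing each next packed-QKV input from the previous output (sampling, embedding, QKV projection) is outside this interface and is elided.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Placeholder for Mila's ITensor; only used to make the sketch compile.
struct FakeTensor {};

// Hypothetical backend that records calls instead of running attention,
// using the same member names as IKvCacheLifecycle + IPackedKvInference.
struct RecordingBackend {
    std::vector<std::string> calls;
    std::vector<int> decode_positions;

    void initializeKvCache(int /*batch_size*/, int /*max_sequence_length*/) { calls.push_back("init"); }
    void resetKvCache() { calls.push_back("reset"); }
    void prefill(const FakeTensor& /*qkv*/, FakeTensor& /*output*/) { calls.push_back("prefill"); }
    void decode(const FakeTensor& /*input*/, FakeTensor& /*output*/, int position) {
        calls.push_back("decode");
        decode_positions.push_back(position);
    }
};

// One request: prefill the prompt (positions [0, prompt_len)), then decode
// one token at a time at absolute positions prompt_len, prompt_len + 1, ...
// In real use step_qkv would be refilled before every decode call.
template <typename Backend, typename Tensor>
void runRequest(Backend& backend, const Tensor& prompt_qkv, Tensor& prompt_out,
                const Tensor& step_qkv, Tensor& step_out, int prompt_len, int new_tokens) {
    backend.resetKvCache();  // start empty, keep the existing allocation
    backend.prefill(prompt_qkv, prompt_out);
    for (int k = 0; k < new_tokens; ++k) {
        backend.decode(step_qkv, step_out, prompt_len + k);
    }
}
```

A typical caller would invoke `initializeKvCache` once, then `runRequest` per prompt.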

Constructor & Destructor Documentation

◆ ~IPackedKvInference()

Mila::Dnn::Compute::IPackedKvInference::~IPackedKvInference ( )
override, default

Member Function Documentation

◆ decode()

virtual void Mila::Dnn::Compute::IPackedKvInference::decode (const ITensor & input, ITensor & output, int position)
pure virtual

Process a single autoregressive token against the KV cache.

Parameters
input: Packed QKV single-token input [B, 1, 3 * embedding_dim].
output: Pre-allocated attention output [B, 1, embedding_dim].
position: Zero-based absolute sequence position of the token, used as its index into the KV cache.

Implemented in Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >, Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TensorDataType::BF16 >, and Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TensorDataType::FP32 >.
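As a reference for what one decode step computes, here is a CPU sketch under simplifying assumptions: batch size 1, a single head of width E, and a [q | k | v] packed layout. `SingleHeadKvCache` and `decodeStep` are hypothetical; the CUDA implementations split heads and may differ in detail. The step writes the new K and V into slot `position`, then attends over slots 0..position.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <stdexcept>
#include <vector>

// Hypothetical single-head cache: keys and values are [max_len, E].
struct SingleHeadKvCache {
    std::size_t E;
    std::size_t max_len;
    std::vector<float> keys;
    std::vector<float> values;

    SingleHeadKvCache(std::size_t embedding_dim, std::size_t max_sequence_length)
        : E(embedding_dim), max_len(max_sequence_length),
          keys(embedding_dim * max_sequence_length, 0.0f),
          values(embedding_dim * max_sequence_length, 0.0f) {}
};

// packed_qkv: [3 * E] for the one new token; returns the [E] output.
// Assumes slots [0, position) were filled by prefill or earlier steps.
inline std::vector<float> decodeStep(SingleHeadKvCache& cache,
                                     const std::vector<float>& packed_qkv,
                                     std::size_t position) {
    const std::size_t E = cache.E;
    if (packed_qkv.size() != 3 * E) throw std::invalid_argument("expected [3 * E] input");
    if (position >= cache.max_len) throw std::out_of_range("position >= max_sequence_length");

    const float* q = packed_qkv.data();
    // Append this token's K and V at cache slot `position`.
    for (std::size_t e = 0; e < E; ++e) {
        cache.keys[position * E + e] = packed_qkv[E + e];
        cache.values[position * E + e] = packed_qkv[2 * E + e];
    }

    // Scaled dot-product scores against every cached key in [0, position].
    const float scale = 1.0f / std::sqrt(static_cast<float>(E));
    std::vector<float> scores(position + 1);
    float max_score = -std::numeric_limits<float>::infinity();
    for (std::size_t j = 0; j <= position; ++j) {
        float dot = 0.0f;
        for (std::size_t e = 0; e < E; ++e) dot += q[e] * cache.keys[j * E + e];
        scores[j] = dot * scale;
        if (scores[j] > max_score) max_score = scores[j];
    }

    // Numerically stable softmax, then a weighted sum of cached values.
    float denom = 0.0f;
    for (float& s : scores) { s = std::exp(s - max_score); denom += s; }
    std::vector<float> out(E, 0.0f);
    for (std::size_t j = 0; j <= position; ++j)
        for (std::size_t e = 0; e < E; ++e)
            out[e] += (scores[j] / denom) * cache.values[j * E + e];
    return out;
}
```

Each step costs O(position * E) instead of recomputing attention over the whole sequence, which is the point of keeping the cache.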

◆ prefill()

virtual void Mila::Dnn::Compute::IPackedKvInference::prefill (const ITensor & qkv, ITensor & output)
pure virtual

Populate the KV cache from a packed QKV sequence and compute the attention output.

Parameters
qkv: Packed QKV input [B, T, 3 * embedding_dim].
output: Pre-allocated attention output [B, T, embedding_dim].

Implemented in Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >, Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TensorDataType::BF16 >, and Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TensorDataType::FP32 >.
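A CPU reference for prefill under simplifying assumptions (batch size 1, one head of width E, rows laid out as [q | k | v]): causal scaled dot-product attention over the prompt, returning the output plus the K and V rows that would populate cache slots 0..T-1. `PrefillResult` and `prefillReference` are hypothetical and are not Mila's kernel.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <stdexcept>
#include <vector>

struct PrefillResult {
    std::vector<float> output;  // [T, E] attention output
    std::vector<float> keys;    // [T, E], destined for cache slots [0, T)
    std::vector<float> values;  // [T, E]
};

// packed: [T, 3 * E] row-major, each row assumed to be [q | k | v].
inline PrefillResult prefillReference(const std::vector<float>& packed,
                                      std::size_t T, std::size_t E) {
    if (packed.size() != T * 3 * E) throw std::invalid_argument("expected [T, 3 * E] input");
    PrefillResult r{std::vector<float>(T * E, 0.0f), std::vector<float>(T * E),
                    std::vector<float>(T * E)};

    // Split K and V out of the packed rows; these become cache entries 0..T-1.
    for (std::size_t t = 0; t < T; ++t)
        for (std::size_t e = 0; e < E; ++e) {
            r.keys[t * E + e] = packed[t * 3 * E + E + e];
            r.values[t * E + e] = packed[t * 3 * E + 2 * E + e];
        }

    const float scale = 1.0f / std::sqrt(static_cast<float>(E));
    std::vector<float> scores(T);
    for (std::size_t i = 0; i < T; ++i) {
        const float* q = &packed[i * 3 * E];
        // Causal mask: query i attends only to keys 0..i.
        float max_score = -std::numeric_limits<float>::infinity();
        for (std::size_t j = 0; j <= i; ++j) {
            float dot = 0.0f;
            for (std::size_t e = 0; e < E; ++e) dot += q[e] * r.keys[j * E + e];
            scores[j] = dot * scale;
            if (scores[j] > max_score) max_score = scores[j];
        }
        // Stable softmax over the visible keys, then weighted sum of values.
        float denom = 0.0f;
        for (std::size_t j = 0; j <= i; ++j) {
            scores[j] = std::exp(scores[j] - max_score);
            denom += scores[j];
        }
        for (std::size_t j = 0; j <= i; ++j)
            for (std::size_t e = 0; e < E; ++e)
                r.output[i * E + e] += (scores[j] / denom) * r.values[j * E + e];
    }
    return r;
}
```

Under these assumptions, row i of the output equals what a single-token step at position i would compute against cache entries 0..i, which is what lets decode continue directly after prefill.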


The documentation for this struct was generated from the following file: