Mila 0.13.48
Deep Neural Network Library
Loading...
Searching...
No Matches
Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision > Class Template Reference [export]

CUDA Grouped-Query Attention operation. More...

Inheritance diagram for Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >:
Collaboration diagram for Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >:

Public Types

using ConfigType = GqaConfig
using CudaExecutionContext = ExecutionContext<DeviceType::Cuda>
using MR = CudaDeviceMemoryResource
using NativeType = typename Mila::Dnn::Compute::Cuda::TensorDataTypeMap<TPrecision>::device_type
using TensorType = Tensor<TPrecision, MR>
using UnaryOperationBase = UnaryOperation<DeviceType::Cuda, TPrecision>
Public Types inherited from Mila::Dnn::Compute::Operation< DeviceType::Cuda, TPrecision >
using DataTypeTraits

Public Member Functions

 CudaGqaOp (IExecutionContext *context, const GqaConfig &config)
void backward (const ITensor &input, const ITensor &output_grad, ITensor &input_grad) const
void build (const BuildContext &context) override
 Prepare the operation for a concrete input shape.
void decode (const ITensor &q, const ITensor &k, const ITensor &v, ITensor &output, int position) override
 Process a single token at an explicit KV cache position.
void forward (const ITensor &input, ITensor &output) const
 Standard (non-cached) forward pass used during training.
const GqaConfig & getConfig () const
std::string getName () const override
 Human-readable operation name.
OperationType getOperationType () const override
 Operation type identifier.
std::size_t getStateMemorySize () const override
 Returns the number of bytes of state memory allocated by this operation.
void initializeKvCache (int batch_size, int max_seq_length) override
 Allocate the KV cache for a given batch size and maximum sequence length.
void prefill (const ITensor &q, const ITensor &k, const ITensor &v, ITensor &output, int position_offset) override
 Populate the KV cache and compute attention output for a token chunk.
void resetKvCache () override
 Reset the KV cache to an empty state, preserving the allocation.
void setGradients (ITensor *, ITensor *) override
 Bind module-owned gradient tensors to the operation.
void setParameters (ITensor *, ITensor *) override
 Bind module-owned parameter tensors to the operation.
void setState (const GqaState &state)
 Wire the shared transient scratch buffers for the optimized inference path.
Public Member Functions inherited from Mila::Dnn::Compute::Operation< DeviceType::Cuda, TPrecision >
virtual ~Operation ()=default
virtual void clearGradients () noexcept
 Clear any cached gradient pointers held by the operation.
virtual TensorDataType getDataType () const
 Tensor data type for this operation.
virtual DeviceType getDeviceType () const
 Device type for this operation.
virtual bool isBuilt () const
 Whether build() completed successfully for a concrete input shape.
virtual bool isEvalMode () const
 Query whether operation is configured for training.
virtual void setTrainingMode (TrainingMode training_mode)
 Configure operation training-mode behavior.
Public Member Functions inherited from Mila::Dnn::Compute::IKvInference
 ~IKvInference () override=default
Public Member Functions inherited from Mila::Dnn::Compute::IKvCacheLifecycle
virtual ~IKvCacheLifecycle ()=default

Private Member Functions

void buildCublasLtPlans ()
void buildCublasLtPlans_optimized ()
void decode_optimized (const ITensor &q, const ITensor &k, const ITensor &v, ITensor &output, int position)
void decodeImpl (const ITensor &q, const ITensor &k, const ITensor &v, ITensor &output, int position)
void ensureKVCacheEnabled () const
void getComputeTypes (cublasComputeType_t &compute_type, cudaDataType_t &scale_type) const
cudaDataType_t getCudaDataType () const
const Detail::CublasLtMatMulPlan< NativeType > & getOrBuildPartialAVPlan (int chunk_len)
const Detail::CublasLtMatMulPlan< NativeType > & getOrBuildPartialAVPlan_optimized (int chunk_len)
const Detail::CublasLtMatMulPlan< NativeType > & getOrBuildPartialQKPlan (int chunk_len)
const Detail::CublasLtMatMulPlan< NativeType > & getOrBuildPartialQKPlan_optimized (int chunk_len)
void initializeState (const BuildContext &build_context)
void initializeState_optimized (const BuildContext &build_context)
void prefill_optimized (const ITensor &q, const ITensor &k, const ITensor &v, ITensor &output, int position_offset)
void prefillImpl (const ITensor &q, const ITensor &k, const ITensor &v, ITensor &output, int position_offset)
void validateDecodeInputShape (const shape_t &s) const
void validateInputShape (const shape_t &s) const
void validatePrefillInputShape (const shape_t &s) const

Static Private Member Functions

static NativeType * raw (const std::shared_ptr< TensorType > &t)

Private Attributes

int active_max_seq_len_ { 0 }
NativeType * att_ { nullptr }
NativeType * att_decode_ { nullptr }
NativeType * att_decode_opt_ { nullptr }
std::shared_ptr< TensorType > att_decode_tensor_
NativeType * att_opt_ { nullptr }
std::shared_ptr< TensorType > att_tensor_
std::shared_ptr< TensorType > att_tensor_optimized_
Detail::CublasLtMatMulPlan< NativeType > att_value_decode_plan_
Detail::CublasLtMatMulPlan< NativeType > att_value_decode_plan_optimized_
std::unordered_map< int, Detail::CublasLtMatMulPlan< NativeType > > att_value_partial_prefill_plan_cache_
std::unordered_map< int, Detail::CublasLtMatMulPlan< NativeType > > att_value_partial_prefill_plan_cache_optimized_
Detail::CublasLtMatMulPlan< NativeType > att_value_plan_
Detail::CublasLtMatMulPlan< NativeType > att_value_prefill_plan_
Detail::CublasLtMatMulPlan< NativeType > att_value_prefill_plan_optimized_
int B_ { 0 }
 Batch size.
Detail::CublasLtMatMulPlan< NativeType > backward_att_plan_
Detail::CublasLtMatMulPlan< NativeType > backward_k_plan_
Detail::CublasLtMatMulPlan< NativeType > backward_q_plan_
Detail::CublasLtMatMulPlan< NativeType > backward_v_plan_
int C_ { 0 }
 Model dim = NH * HS.
int cached_seq_len_ { 0 }
GqaConfig config_
CudaExecutionContext * context_
cublasLtHandle_t cublaslt_handle_ { nullptr }
NativeType * datt_ { nullptr }
std::shared_ptr< TensorType > datt_tensor_
NativeType * dK_ { nullptr }
NativeType * dK_exp_ { nullptr }
std::shared_ptr< TensorType > dK_exp_tensor_
std::shared_ptr< TensorType > dK_tensor_
NativeType * dpreatt_ { nullptr }
std::shared_ptr< TensorType > dpreatt_tensor_
NativeType * dq_ { nullptr }
std::shared_ptr< TensorType > dq_tensor_
NativeType * dV_ { nullptr }
NativeType * dV_exp_ { nullptr }
std::shared_ptr< TensorType > dV_exp_tensor_
std::shared_ptr< TensorType > dV_tensor_
NativeType * dVout_ { nullptr }
std::shared_ptr< TensorType > dVout_tensor_
int GS_ { 0 }
 Group size = NH / NKV.
int HS_ { 0 }
 Head dim = C / NH.
NativeType * k_ { nullptr }
NativeType * k_exp_ { nullptr }
std::shared_ptr< TensorType > k_exp_tensor_
NativeType * k_opt_ { nullptr }
std::shared_ptr< TensorType > k_tensor_
bool kv_cache_enabled_ { false }
int NH_ { 0 }
 Number of Q heads.
int NKV_ { 0 }
 Number of KV heads.
NativeType * preatt_ { nullptr }
NativeType * preatt_decode_ { nullptr }
NativeType * preatt_decode_opt_ { nullptr }
std::shared_ptr< TensorType > preatt_decode_tensor_
NativeType * preatt_opt_ { nullptr }
std::shared_ptr< TensorType > preatt_tensor_
std::shared_ptr< TensorType > preatt_tensor_optimized_
int prefill_chunk_size_ { 0 }
NativeType * q_ { nullptr }
NativeType * q_permute_opt_ { nullptr }
std::shared_ptr< TensorType > q_permute_tensor_optimized_
std::shared_ptr< TensorType > q_tensor_
Detail::CublasLtMatMulPlan< NativeType > qk_decode_plan_
Detail::CublasLtMatMulPlan< NativeType > qk_decode_plan_optimized_
std::unordered_map< int, Detail::CublasLtMatMulPlan< NativeType > > qk_partial_prefill_plan_cache_
std::unordered_map< int, Detail::CublasLtMatMulPlan< NativeType > > qk_partial_prefill_plan_cache_optimized_
Detail::CublasLtMatMulPlan< NativeType > qk_prefill_plan_
Detail::CublasLtMatMulPlan< NativeType > qk_prefill_plan_optimized_
Detail::CublasLtMatMulPlan< NativeType > qk_score_plan_
std::size_t state_memory_size_ { 0 }
int T_ { 0 }
 Max sequence length.
bool use_optimized_path_ { false }
NativeType * v_ { nullptr }
NativeType * v_exp_ { nullptr }
std::shared_ptr< TensorType > v_exp_tensor_
NativeType * v_opt_ { nullptr }
NativeType * v_out_ { nullptr }
NativeType * v_out_decode_ { nullptr }
NativeType * v_out_decode_opt_ { nullptr }
std::shared_ptr< TensorType > v_out_decode_tensor_
NativeType * v_out_opt_ { nullptr }
std::shared_ptr< TensorType > v_out_tensor_
std::shared_ptr< TensorType > v_out_tensor_optimized_
std::shared_ptr< TensorType > v_tensor_

Additional Inherited Members

Static Public Attributes inherited from Mila::Dnn::Compute::Operation< DeviceType::Cuda, TPrecision >
static constexpr TensorDataType data_type
static constexpr DeviceType device_type
Protected Attributes inherited from Mila::Dnn::Compute::Operation< DeviceType::Cuda, TPrecision >
bool is_built_
TrainingMode training_mode_

Detailed Description

template<TensorDataType TPrecision>
requires PrecisionSupportedOnDevice<TPrecision, DeviceType::Cuda>
class Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >

CUDA Grouped-Query Attention operation.

GQA generalises MHA by allowing num_kv_heads < num_heads. Every group of (num_heads / num_kv_heads) Q heads shares a single K/V head, reducing KV cache memory and bandwidth proportionally to the group size.

The legacy path uses cuBLASLt batched matmuls on an expanded layout: K and V are stored compactly in [B, NKV, T, HS] and expanded to [B, NH, T, HS] before the matmuls so every cuBLASLt plan operates at batch_count = B * NH.

The optimized path (kUseOptimizedPath) eliminates the expansion buffers and q_tensor_ by rebuilding cuBLASLt plans against the compact NKV layout with grouped head strides. See GqaMemory.md Phase 1 and Phase 2.

Forward pass (training):

  1. permute_qkv -> Q[B,NH,T,HS], K[B,NKV,T,HS], V[B,NKV,T,HS]
  2. expand_kv -> k_exp/v_exp [B,NH,T,HS]
  3. qk_score_plan -> preatt [B,NH,T,T]
  4. softmax_forward -> att [B,NH,T,T]
  5. att_value_plan -> v_out [B,NH,T,HS]
  6. unpermute_output -> Y [B,T,NH*HS]

Prefill pass (inference only, with KV cache):

  1. prefill_permute_qkv -> Q[B,NH,chunk,HS], K/V[B,NKV,chunk,HS] (padded to T)
  2. prefill_expand_kv -> k_exp/v_exp [B,NH,chunk,HS] (padded to T)
  3. prefill_qk_plan -> preatt [B,NH,chunk,T]
  4. prefill_softmax -> att [B,NH,chunk,T]
  5. prefill_att_value_plan -> v_out [B,NH,chunk,HS]
  6. prefill_unpermute_output -> Y [B,chunk,C]

Decode pass (inference only, single token with KV cache):

  1. permute_qkv_decode -> single Q token; append K/V to cache
  2. expand_kv -> expand cache slice up to current position
  3. qk_decode_plan -> preatt_decode [B,NH,1,T]
  4. softmax_decode -> att_decode
  5. att_value_decode -> v_out_decode [B,NH,1,HS]
  6. unpermute_output -> Y [B,1,C]

Backward pass (training only):

  1. unpermute_backward -> dVout [B,NH,T,HS]
  2. backward_v_plan -> dV_exp (expanded)
  3. backward_att_plan -> dAtt
  4. softmax_backward -> dPreatt
  5. backward_q_plan -> dQ
  6. backward_k_plan -> dK_exp (expanded)
  7. reduce_kv_grad -> dK/dV [B,NKV,T,HS] (sum over group)
  8. permute_backward -> dX [B,T,(NH+2*NKV)*HS]
Template Parameters
TPrecision: Tensor element type and cuBLASLt data/compute type.

Member Typedef Documentation

◆ ConfigType

template<TensorDataType TPrecision>
using Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::ConfigType = GqaConfig

◆ CudaExecutionContext

template<TensorDataType TPrecision>
using Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::CudaExecutionContext = ExecutionContext<DeviceType::Cuda>

◆ MR

◆ NativeType

template<TensorDataType TPrecision>
using Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::NativeType = typename Mila::Dnn::Compute::Cuda::TensorDataTypeMap<TPrecision>::device_type

◆ TensorType

template<TensorDataType TPrecision>
using Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::TensorType = Tensor<TPrecision, MR>

◆ UnaryOperationBase

template<TensorDataType TPrecision>
using Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::UnaryOperationBase = UnaryOperation<DeviceType::Cuda, TPrecision>

Constructor & Destructor Documentation

◆ CudaGqaOp()

template<TensorDataType TPrecision>
Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::CudaGqaOp ( IExecutionContext * context,
const GqaConfig & config )
inline

Member Function Documentation

◆ backward()

template<TensorDataType TPrecision>
void Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::backward ( const ITensor & input,
const ITensor & output_grad,
ITensor & input_grad ) const
inline

◆ build()

template<TensorDataType TPrecision>
void Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::build ( const BuildContext & build_context)
inlineoverridevirtual

Prepare the operation for a concrete input shape.

Default implementation is a no-op. Operations requiring shape-dependent setup should override this method.

Reimplemented from Mila::Dnn::Compute::Operation< DeviceType::Cuda, TPrecision >.

◆ buildCublasLtPlans()

template<TensorDataType TPrecision>
void Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::buildCublasLtPlans ( )
inlineprivate
Here is the caller graph for this function:

◆ buildCublasLtPlans_optimized()

template<TensorDataType TPrecision>
void Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::buildCublasLtPlans_optimized ( )
inlineprivate
Here is the caller graph for this function:

◆ decode()

template<TensorDataType TPrecision>
void Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::decode ( const ITensor & q,
const ITensor & k,
const ITensor & v,
ITensor & output,
int position )
inlineoverridevirtual

Process a single token at an explicit KV cache position.

Parameters
q: Query [B, 1, n_heads * head_dim].
k: Key [B, 1, n_kv_heads * head_dim].
v: Value [B, 1, n_kv_heads * head_dim].
output: Pre-allocated output [B, 1, model_dim].
position: Zero-based absolute sequence position into the KV cache.

Implements Mila::Dnn::Compute::IKvInference.

◆ decode_optimized()

template<TensorDataType TPrecision>
void Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::decode_optimized ( const ITensor & q,
const ITensor & k,
const ITensor & v,
ITensor & output,
int position )
inlineprivate
Here is the caller graph for this function:

◆ decodeImpl()

template<TensorDataType TPrecision>
void Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::decodeImpl ( const ITensor & q,
const ITensor & k,
const ITensor & v,
ITensor & output,
int position )
inlineprivate
Here is the caller graph for this function:

◆ ensureKVCacheEnabled()

template<TensorDataType TPrecision>
void Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::ensureKVCacheEnabled ( ) const
inlineprivate
Here is the caller graph for this function:

◆ forward()

template<TensorDataType TPrecision>
void Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::forward ( const ITensor & input,
ITensor & output ) const
inline

Standard (non-cached) forward pass used during training.

◆ getComputeTypes()

template<TensorDataType TPrecision>
void Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::getComputeTypes ( cublasComputeType_t & compute_type,
cudaDataType_t & scale_type ) const
inlineprivate
Here is the caller graph for this function:

◆ getConfig()

template<TensorDataType TPrecision>
const GqaConfig & Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::getConfig ( ) const
inline

◆ getCudaDataType()

template<TensorDataType TPrecision>
cudaDataType_t Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::getCudaDataType ( ) const
inlineprivate
Here is the caller graph for this function:

◆ getName()

template<TensorDataType TPrecision>
std::string Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::getName ( ) const
inlineoverridevirtual

Human-readable operation name.

Implements Mila::Dnn::Compute::Operation< DeviceType::Cuda, TPrecision >.

◆ getOperationType()

template<TensorDataType TPrecision>
OperationType Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::getOperationType ( ) const
inlineoverridevirtual

◆ getOrBuildPartialAVPlan()

template<TensorDataType TPrecision>
const Detail::CublasLtMatMulPlan< NativeType > & Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::getOrBuildPartialAVPlan ( int chunk_len)
inlineprivate
Here is the caller graph for this function:

◆ getOrBuildPartialAVPlan_optimized()

template<TensorDataType TPrecision>
const Detail::CublasLtMatMulPlan< NativeType > & Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::getOrBuildPartialAVPlan_optimized ( int chunk_len)
inlineprivate
Here is the caller graph for this function:

◆ getOrBuildPartialQKPlan()

template<TensorDataType TPrecision>
const Detail::CublasLtMatMulPlan< NativeType > & Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::getOrBuildPartialQKPlan ( int chunk_len)
inlineprivate
Here is the caller graph for this function:

◆ getOrBuildPartialQKPlan_optimized()

template<TensorDataType TPrecision>
const Detail::CublasLtMatMulPlan< NativeType > & Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::getOrBuildPartialQKPlan_optimized ( int chunk_len)
inlineprivate
Here is the caller graph for this function:

◆ getStateMemorySize()

template<TensorDataType TPrecision>
std::size_t Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::getStateMemorySize ( ) const
inlineoverridevirtual

Returns the number of bytes of state memory allocated by this operation.

State memory includes build-time buffers such as caches and scratch allocations. Parameters and gradients are owned at the component level and are not included.

Override in derived operations that allocate device or host state during build().

Reimplemented from Mila::Dnn::Compute::Operation< DeviceType::Cuda, TPrecision >.

◆ initializeKvCache()

template<TensorDataType TPrecision>
void Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::initializeKvCache ( int batch_size,
int max_sequence_length )
inlineoverridevirtual

Allocate the KV cache for a given batch size and maximum sequence length.

Parameters
batch_size: Number of sequences in the batch.
max_sequence_length: Maximum number of tokens the cache must hold.

Implements Mila::Dnn::Compute::IKvCacheLifecycle.

◆ initializeState()

template<TensorDataType TPrecision>
void Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::initializeState ( const BuildContext & build_context)
inlineprivate
Here is the caller graph for this function:

◆ initializeState_optimized()

template<TensorDataType TPrecision>
void Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::initializeState_optimized ( const BuildContext & build_context)
inlineprivate
Here is the caller graph for this function:

◆ prefill()

template<TensorDataType TPrecision>
void Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::prefill ( const ITensor & q,
const ITensor & k,
const ITensor & v,
ITensor & output,
int position_offset )
inlineoverridevirtual

Populate the KV cache and compute attention output for a token chunk.

Parameters
q: Query [B, T_chunk, n_heads * head_dim].
k: Key [B, T_chunk, n_kv_heads * head_dim].
v: Value [B, T_chunk, n_kv_heads * head_dim].
output: Pre-allocated output [B, T_chunk, model_dim].
position_offset: Absolute position of the first token in this chunk.

Implements Mila::Dnn::Compute::IKvInference.

◆ prefill_optimized()

template<TensorDataType TPrecision>
void Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::prefill_optimized ( const ITensor & q,
const ITensor & k,
const ITensor & v,
ITensor & output,
int position_offset )
inlineprivate
Here is the caller graph for this function:

◆ prefillImpl()

template<TensorDataType TPrecision>
void Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::prefillImpl ( const ITensor & q,
const ITensor & k,
const ITensor & v,
ITensor & output,
int position_offset )
inlineprivate
Here is the caller graph for this function:

◆ raw()

template<TensorDataType TPrecision>
NativeType * Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::raw ( const std::shared_ptr< TensorType > & t)
inlinestaticprivate
Here is the caller graph for this function:

◆ resetKvCache()

template<TensorDataType TPrecision>
void Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::resetKvCache ( )
inlineoverridevirtual

Reset the KV cache to an empty state, preserving the allocation.

Implements Mila::Dnn::Compute::IKvCacheLifecycle.

◆ setGradients()

template<TensorDataType TPrecision>
void Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::setGradients ( ITensor * weight_grad,
ITensor * bias_grad )
inlineoverridevirtual

Bind module-owned gradient tensors to the operation.

New canonical API for binding gradient buffers. Mirrors semantics of setParameters() but for gradients used during backward().

The operation MUST NOT take ownership of the provided pointers. Implementations may cache rawData() pointers for hot-path writes.

Default: no-op for stateless operations.

Reimplemented from Mila::Dnn::Compute::Operation< DeviceType::Cuda, TPrecision >.

◆ setParameters()

template<TensorDataType TPrecision>
void Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::setParameters ( ITensor * weight,
ITensor * bias )
inlineoverridevirtual

Bind module-owned parameter tensors to the operation.

The module retains ownership of the provided ITensor objects. Implementations may cache rawData() pointers for hot-path access but MUST NOT free the provided pointers.

Default: no-op for stateless operations.

Reimplemented from Mila::Dnn::Compute::Operation< DeviceType::Cuda, TPrecision >.

◆ setState()

template<TensorDataType TPrecision>
void Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::setState ( const GqaState & state)
inline

Wire the shared transient scratch buffers for the optimized inference path.

Called once per build by LlamaTransformer after all blocks are built. The tensors are owned by LlamaTransformer and shared across all GQA layers sequentially. Must be called before prefill() or decode() when use_optimized_path_ is true.

Parameters
state: Non-owning pointers to the shared workspace tensors. All slots must be non-null for the optimized inference path.

◆ validateDecodeInputShape()

template<TensorDataType TPrecision>
void Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::validateDecodeInputShape ( const shape_t & s) const
inlineprivate

◆ validateInputShape()

template<TensorDataType TPrecision>
void Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::validateInputShape ( const shape_t & s) const
inlineprivate
Here is the caller graph for this function:

◆ validatePrefillInputShape()

template<TensorDataType TPrecision>
void Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::validatePrefillInputShape ( const shape_t & s) const
inlineprivate

Member Data Documentation

◆ active_max_seq_len_

template<TensorDataType TPrecision>
int Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::active_max_seq_len_ { 0 }
private

◆ att_

template<TensorDataType TPrecision>
NativeType* Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::att_ { nullptr }
private

◆ att_decode_

template<TensorDataType TPrecision>
NativeType* Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::att_decode_ { nullptr }
private

◆ att_decode_opt_

template<TensorDataType TPrecision>
NativeType* Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::att_decode_opt_ { nullptr }
private

◆ att_decode_tensor_

template<TensorDataType TPrecision>
std::shared_ptr<TensorType> Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::att_decode_tensor_
private

◆ att_opt_

template<TensorDataType TPrecision>
NativeType* Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::att_opt_ { nullptr }
private

◆ att_tensor_

template<TensorDataType TPrecision>
std::shared_ptr<TensorType> Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::att_tensor_
private

◆ att_tensor_optimized_

template<TensorDataType TPrecision>
std::shared_ptr<TensorType> Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::att_tensor_optimized_
private

◆ att_value_decode_plan_

template<TensorDataType TPrecision>
Detail::CublasLtMatMulPlan<NativeType> Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::att_value_decode_plan_
private

◆ att_value_decode_plan_optimized_

template<TensorDataType TPrecision>
Detail::CublasLtMatMulPlan<NativeType> Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::att_value_decode_plan_optimized_
private

◆ att_value_partial_prefill_plan_cache_

template<TensorDataType TPrecision>
std::unordered_map<int, Detail::CublasLtMatMulPlan<NativeType> > Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::att_value_partial_prefill_plan_cache_
private

◆ att_value_partial_prefill_plan_cache_optimized_

template<TensorDataType TPrecision>
std::unordered_map<int, Detail::CublasLtMatMulPlan<NativeType> > Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::att_value_partial_prefill_plan_cache_optimized_
private

◆ att_value_plan_

template<TensorDataType TPrecision>
Detail::CublasLtMatMulPlan<NativeType> Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::att_value_plan_
private

◆ att_value_prefill_plan_

template<TensorDataType TPrecision>
Detail::CublasLtMatMulPlan<NativeType> Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::att_value_prefill_plan_
private

◆ att_value_prefill_plan_optimized_

template<TensorDataType TPrecision>
Detail::CublasLtMatMulPlan<NativeType> Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::att_value_prefill_plan_optimized_
private

◆ B_

template<TensorDataType TPrecision>
int Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::B_ { 0 }
private

Batch size.

◆ backward_att_plan_

template<TensorDataType TPrecision>
Detail::CublasLtMatMulPlan<NativeType> Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::backward_att_plan_
private

◆ backward_k_plan_

template<TensorDataType TPrecision>
Detail::CublasLtMatMulPlan<NativeType> Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::backward_k_plan_
private

◆ backward_q_plan_

template<TensorDataType TPrecision>
Detail::CublasLtMatMulPlan<NativeType> Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::backward_q_plan_
private

◆ backward_v_plan_

template<TensorDataType TPrecision>
Detail::CublasLtMatMulPlan<NativeType> Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::backward_v_plan_
private

◆ C_

template<TensorDataType TPrecision>
int Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::C_ { 0 }
private

Model dim = NH * HS.

◆ cached_seq_len_

template<TensorDataType TPrecision>
int Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::cached_seq_len_ { 0 }
private

◆ config_

template<TensorDataType TPrecision>
GqaConfig Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::config_
private

◆ context_

template<TensorDataType TPrecision>
CudaExecutionContext* Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::context_
private

◆ cublaslt_handle_

template<TensorDataType TPrecision>
cublasLtHandle_t Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::cublaslt_handle_ { nullptr }
private

◆ datt_

template<TensorDataType TPrecision>
NativeType* Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::datt_ { nullptr }
private

◆ datt_tensor_

template<TensorDataType TPrecision>
std::shared_ptr<TensorType> Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::datt_tensor_
private

◆ dK_

template<TensorDataType TPrecision>
NativeType* Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::dK_ { nullptr }
private

◆ dK_exp_

template<TensorDataType TPrecision>
NativeType* Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::dK_exp_ { nullptr }
private

◆ dK_exp_tensor_

template<TensorDataType TPrecision>
std::shared_ptr<TensorType> Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::dK_exp_tensor_
private

◆ dK_tensor_

template<TensorDataType TPrecision>
std::shared_ptr<TensorType> Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::dK_tensor_
private

◆ dpreatt_

template<TensorDataType TPrecision>
NativeType* Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::dpreatt_ { nullptr }
private

◆ dpreatt_tensor_

template<TensorDataType TPrecision>
std::shared_ptr<TensorType> Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::dpreatt_tensor_
private

◆ dq_

template<TensorDataType TPrecision>
NativeType* Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::dq_ { nullptr }
private

◆ dq_tensor_

template<TensorDataType TPrecision>
std::shared_ptr<TensorType> Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::dq_tensor_
private

◆ dV_

template<TensorDataType TPrecision>
NativeType* Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::dV_ { nullptr }
private

◆ dV_exp_

template<TensorDataType TPrecision>
NativeType* Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::dV_exp_ { nullptr }
private

◆ dV_exp_tensor_

template<TensorDataType TPrecision>
std::shared_ptr<TensorType> Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::dV_exp_tensor_
private

◆ dV_tensor_

template<TensorDataType TPrecision>
std::shared_ptr<TensorType> Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::dV_tensor_
private

◆ dVout_

template<TensorDataType TPrecision>
NativeType* Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::dVout_ { nullptr }
private

◆ dVout_tensor_

template<TensorDataType TPrecision>
std::shared_ptr<TensorType> Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::dVout_tensor_
private

◆ GS_

template<TensorDataType TPrecision>
int Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::GS_ { 0 }
private

Group size = NH / NKV.

◆ HS_

template<TensorDataType TPrecision>
int Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::HS_ { 0 }
private

Head dim = C / NH.

◆ k_

template<TensorDataType TPrecision>
NativeType* Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::k_ { nullptr }
private

◆ k_exp_

template<TensorDataType TPrecision>
NativeType* Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::k_exp_ { nullptr }
private

◆ k_exp_tensor_

template<TensorDataType TPrecision>
std::shared_ptr<TensorType> Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::k_exp_tensor_
private

◆ k_opt_

template<TensorDataType TPrecision>
NativeType* Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::k_opt_ { nullptr }
private

◆ k_tensor_

template<TensorDataType TPrecision>
std::shared_ptr<TensorType> Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::k_tensor_
private

◆ kv_cache_enabled_

template<TensorDataType TPrecision>
bool Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::kv_cache_enabled_ { false }
private

◆ NH_

template<TensorDataType TPrecision>
int Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::NH_ { 0 }
private

Number of Q heads.

◆ NKV_

template<TensorDataType TPrecision>
int Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::NKV_ { 0 }
private

Number of KV heads.

◆ preatt_

template<TensorDataType TPrecision>
NativeType* Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::preatt_ { nullptr }
private

◆ preatt_decode_

template<TensorDataType TPrecision>
NativeType* Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::preatt_decode_ { nullptr }
private

◆ preatt_decode_opt_

template<TensorDataType TPrecision>
NativeType* Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::preatt_decode_opt_ { nullptr }
private

◆ preatt_decode_tensor_

template<TensorDataType TPrecision>
std::shared_ptr<TensorType> Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::preatt_decode_tensor_
private

◆ preatt_opt_

template<TensorDataType TPrecision>
NativeType* Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::preatt_opt_ { nullptr }
private

◆ preatt_tensor_

template<TensorDataType TPrecision>
std::shared_ptr<TensorType> Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::preatt_tensor_
private

◆ preatt_tensor_optimized_

template<TensorDataType TPrecision>
std::shared_ptr<TensorType> Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::preatt_tensor_optimized_
private

◆ prefill_chunk_size_

template<TensorDataType TPrecision>
int Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::prefill_chunk_size_ { 0 }
private

◆ q_

template<TensorDataType TPrecision>
NativeType* Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::q_ { nullptr }
private

◆ q_permute_opt_

template<TensorDataType TPrecision>
NativeType* Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::q_permute_opt_ { nullptr }
private

◆ q_permute_tensor_optimized_

template<TensorDataType TPrecision>
std::shared_ptr<TensorType> Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::q_permute_tensor_optimized_
private

◆ q_tensor_

template<TensorDataType TPrecision>
std::shared_ptr<TensorType> Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::q_tensor_
private

◆ qk_decode_plan_

template<TensorDataType TPrecision>
Detail::CublasLtMatMulPlan<NativeType> Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::qk_decode_plan_
private

◆ qk_decode_plan_optimized_

template<TensorDataType TPrecision>
Detail::CublasLtMatMulPlan<NativeType> Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::qk_decode_plan_optimized_
private

◆ qk_partial_prefill_plan_cache_

template<TensorDataType TPrecision>
std::unordered_map<int, Detail::CublasLtMatMulPlan<NativeType> > Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::qk_partial_prefill_plan_cache_
private

◆ qk_partial_prefill_plan_cache_optimized_

template<TensorDataType TPrecision>
std::unordered_map<int, Detail::CublasLtMatMulPlan<NativeType> > Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::qk_partial_prefill_plan_cache_optimized_
private

◆ qk_prefill_plan_

template<TensorDataType TPrecision>
Detail::CublasLtMatMulPlan<NativeType> Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::qk_prefill_plan_
private

◆ qk_prefill_plan_optimized_

template<TensorDataType TPrecision>
Detail::CublasLtMatMulPlan<NativeType> Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::qk_prefill_plan_optimized_
private

◆ qk_score_plan_

template<TensorDataType TPrecision>
Detail::CublasLtMatMulPlan<NativeType> Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::qk_score_plan_
private

◆ state_memory_size_

template<TensorDataType TPrecision>
std::size_t Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::state_memory_size_ { 0 }
private

◆ T_

template<TensorDataType TPrecision>
int Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::T_ { 0 }
private

Maximum sequence length.

◆ use_optimized_path_

template<TensorDataType TPrecision>
bool Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::use_optimized_path_ { false }
private

◆ v_

template<TensorDataType TPrecision>
NativeType* Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::v_ { nullptr }
private

◆ v_exp_

template<TensorDataType TPrecision>
NativeType* Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::v_exp_ { nullptr }
private

◆ v_exp_tensor_

template<TensorDataType TPrecision>
std::shared_ptr<TensorType> Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::v_exp_tensor_
private

◆ v_opt_

template<TensorDataType TPrecision>
NativeType* Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::v_opt_ { nullptr }
private

◆ v_out_

template<TensorDataType TPrecision>
NativeType* Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::v_out_ { nullptr }
private

◆ v_out_decode_

template<TensorDataType TPrecision>
NativeType* Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::v_out_decode_ { nullptr }
private

◆ v_out_decode_opt_

template<TensorDataType TPrecision>
NativeType* Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::v_out_decode_opt_ { nullptr }
private

◆ v_out_decode_tensor_

template<TensorDataType TPrecision>
std::shared_ptr<TensorType> Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::v_out_decode_tensor_
private

◆ v_out_opt_

template<TensorDataType TPrecision>
NativeType* Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::v_out_opt_ { nullptr }
private

◆ v_out_tensor_

template<TensorDataType TPrecision>
std::shared_ptr<TensorType> Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::v_out_tensor_
private

◆ v_out_tensor_optimized_

template<TensorDataType TPrecision>
std::shared_ptr<TensorType> Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::v_out_tensor_optimized_
private

◆ v_tensor_

template<TensorDataType TPrecision>
std::shared_ptr<TensorType> Mila::Dnn::Compute::Cuda::Gqa::CudaGqaOp< TPrecision >::v_tensor_
private

The documentation for this class was generated from the following file:
  • /__w/Mila/Mila/Mila/Src/Dnn/Compute/Devices/Cuda/Operations/Attention/GQA/CudaGqaOp.ixx