Mila 0.13.48
Deep Neural Network Library
Loading...
Searching...
No Matches
Mila::Dnn::Compute::Cuda::Gqa::Detail Namespace Reference

Classes

struct  cuda_gqa_kernels
struct  cuda_gqa_kernels< float >
struct  cuda_gqa_kernels< nv_bfloat16 >

Typedefs

template<typename TNative>
using CublasLtMatMulPlan = CublasLtMatMulPlan<TNative>

Functions

template<typename TNative>
CublasLtMatMulPlan< TNative > build_att_value_decode_plan (cublasLtHandle_t handle, int batch_size, int num_heads, int max_seq_length, int head_size, cudaDataType_t cuda_data_type, cublasComputeType_t compute_type, cudaDataType_t scale_type)
 Single-token Att @ V decode plan.
template<typename TNative>
CublasLtMatMulPlan< TNative > build_att_value_decode_plan_optimized (cublasLtHandle_t handle, int batch_size, int num_kv_heads, int group_size, int max_seq_length, int head_size, cudaDataType_t cuda_data_type, cublasComputeType_t compute_type, cudaDataType_t scale_type)
 Single-token Att @ V decode plan reading V from [B, NKV, T, HS] cache.
template<typename TNative>
CublasLtMatMulPlan< TNative > build_att_value_plan (cublasLtHandle_t handle, int batch_size, int num_heads, int seq_length, int head_size, cudaDataType_t cuda_data_type, cublasComputeType_t compute_type, cudaDataType_t scale_type)
 Att @ V weighted-sum plan (training, full sequence length).
template<typename TNative>
CublasLtMatMulPlan< TNative > build_att_value_prefill_plan (cublasLtHandle_t handle, int batch_size, int num_heads, int chunk_rows, int max_seq_length, int head_size, int prefill_window_size, cudaDataType_t cuda_data_type, cublasComputeType_t compute_type, cudaDataType_t scale_type)
template<typename TNative>
CublasLtMatMulPlan< TNative > build_att_value_prefill_plan_optimized (cublasLtHandle_t handle, int batch_size, int num_kv_heads, int group_size, int chunk_rows, int max_seq_length, int head_size, cudaDataType_t cuda_data_type, cublasComputeType_t compute_type, cudaDataType_t scale_type)
 Att @ V prefill plan reading V directly from [B, NKV, T, HS] cache.
template<typename TNative>
CublasLtMatMulPlan< TNative > build_backward_att_plan (cublasLtHandle_t handle, int batch_size, int num_heads, int seq_length, int head_size, cudaDataType_t cuda_data_type, cublasComputeType_t compute_type, cudaDataType_t scale_type)
 dAtt = dVout @ V^T.
template<typename TNative>
CublasLtMatMulPlan< TNative > build_backward_k_plan (cublasLtHandle_t handle, int batch_size, int num_heads, int seq_length, int head_size, cudaDataType_t cuda_data_type, cublasComputeType_t compute_type, cudaDataType_t scale_type)
 dK = dPreatt^T @ Q.
template<typename TNative>
CublasLtMatMulPlan< TNative > build_backward_q_plan (cublasLtHandle_t handle, int batch_size, int num_heads, int seq_length, int head_size, cudaDataType_t cuda_data_type, cublasComputeType_t compute_type, cudaDataType_t scale_type)
 dQ = dPreatt @ K.
template<typename TNative>
CublasLtMatMulPlan< TNative > build_backward_v_plan (cublasLtHandle_t handle, int batch_size, int num_heads, int seq_length, int head_size, cudaDataType_t cuda_data_type, cublasComputeType_t compute_type, cudaDataType_t scale_type)
 dV = Att^T @ dVout.
template<typename TNative>
CublasLtMatMulPlan< TNative > build_qk_decode_plan (cublasLtHandle_t handle, int batch_size, int num_heads, int max_seq_length, int head_size, cudaDataType_t cuda_data_type, cublasComputeType_t compute_type, cudaDataType_t scale_type)
 Single-token Q @ K^T decode plan.
template<typename TNative>
CublasLtMatMulPlan< TNative > build_qk_decode_plan_optimized (cublasLtHandle_t handle, int batch_size, int num_kv_heads, int group_size, int max_seq_length, int head_size, cudaDataType_t cuda_data_type, cublasComputeType_t compute_type, cudaDataType_t scale_type)
 Single-token Q @ K^T decode plan reading K from [B, NKV, T, HS] cache.
template<typename TNative>
CublasLtMatMulPlan< TNative > build_qk_prefill_plan (cublasLtHandle_t handle, int batch_size, int num_heads, int chunk_rows, int max_seq_length, int head_size, int prefill_window_size, cudaDataType_t cuda_data_type, cublasComputeType_t compute_type, cudaDataType_t scale_type)
 Q @ K^T attention score plan for chunked prefill (inference).
template<typename TNative>
CublasLtMatMulPlan< TNative > build_qk_prefill_plan_optimized (cublasLtHandle_t handle, int batch_size, int num_kv_heads, int group_size, int chunk_rows, int max_seq_length, int head_size, cudaDataType_t cuda_data_type, cublasComputeType_t compute_type, cudaDataType_t scale_type)
 Q @ K^T prefill plan reading K directly from [B, NKV, T, HS] cache.
template<typename TNative>
CublasLtMatMulPlan< TNative > build_qk_score_plan (cublasLtHandle_t handle, int batch_size, int num_heads, int seq_length, int head_size, cudaDataType_t cuda_data_type, cublasComputeType_t compute_type, cudaDataType_t scale_type)
 Q @ K^T attention score plan (training, full sequence length).

Typedef Documentation

◆ CublasLtMatMulPlan

Function Documentation

◆ build_att_value_decode_plan()

template<typename TNative>
CublasLtMatMulPlan< TNative > Mila::Dnn::Compute::Cuda::Gqa::Detail::build_att_value_decode_plan ( cublasLtHandle_t handle,
int batch_size,
int num_heads,
int max_seq_length,
int head_size,
cudaDataType_t cuda_data_type,
cublasComputeType_t compute_type,
cudaDataType_t scale_type )

Single-token Att @ V decode plan.

Att is [1, T] (softmax weights over cache); V is [T, HS]. batch_count = B * NH (after KV expansion).

Here is the call graph for this function:
Here is the caller graph for this function:

◆ build_att_value_decode_plan_optimized()

template<typename TNative>
CublasLtMatMulPlan< TNative > Mila::Dnn::Compute::Cuda::Gqa::Detail::build_att_value_decode_plan_optimized ( cublasLtHandle_t handle,
int batch_size,
int num_kv_heads,
int group_size,
int max_seq_length,
int head_size,
cudaDataType_t cuda_data_type,
cublasComputeType_t compute_type,
cudaDataType_t scale_type )

Single-token Att @ V decode plan reading V from [B, NKV, T, HS] cache.

att_decode is [B, NH, 1, T] = [B*NKV, GS, T]. V cache is [B, NKV, T, HS] = [B*NKV, T, HS]. v_out_decode is [B, NH, 1, HS] = [B*NKV, GS, HS].

strideA = GS * T (att_decode head stride)
strideB = T * HS (V cache head stride)
strideC = GS * HS (v_out_decode head stride)

Parameters
group_size — NH / NKV.
num_kv_heads — Number of KV heads (NKV).
Here is the call graph for this function:
Here is the caller graph for this function:

◆ build_att_value_plan()

template<typename TNative>
CublasLtMatMulPlan< TNative > Mila::Dnn::Compute::Cuda::Gqa::Detail::build_att_value_plan ( cublasLtHandle_t handle,
int batch_size,
int num_heads,
int seq_length,
int head_size,
cudaDataType_t cuda_data_type,
cublasComputeType_t compute_type,
cudaDataType_t scale_type )

Att @ V weighted-sum plan (training, full sequence length).

After KV expansion: V is [B, NH, T, HS]. batch_count = B * NH.

Here is the call graph for this function:
Here is the caller graph for this function:

◆ build_att_value_prefill_plan()

template<typename TNative>
CublasLtMatMulPlan< TNative > Mila::Dnn::Compute::Cuda::Gqa::Detail::build_att_value_prefill_plan ( cublasLtHandle_t handle,
int batch_size,
int num_heads,
int chunk_rows,
int max_seq_length,
int head_size,
int prefill_window_size,
cudaDataType_t cuda_data_type,
cublasComputeType_t compute_type,
cudaDataType_t scale_type )
Here is the call graph for this function:
Here is the caller graph for this function:

◆ build_att_value_prefill_plan_optimized()

template<typename TNative>
CublasLtMatMulPlan< TNative > Mila::Dnn::Compute::Cuda::Gqa::Detail::build_att_value_prefill_plan_optimized ( cublasLtHandle_t handle,
int batch_size,
int num_kv_heads,
int group_size,
int chunk_rows,
int max_seq_length,
int head_size,
cudaDataType_t cuda_data_type,
cublasComputeType_t compute_type,
cudaDataType_t scale_type )

Att @ V prefill plan reading V directly from [B, NKV, T, HS] cache.

att is [B, NH, chunk, T] = [B*NKV, GS*chunk, T]. V is [B, NKV, T, HS] = [B*NKV, T, HS]. v_out is [B, NH, chunk, HS] = [B*NKV, GS*chunk, HS].

strideA = GS * chunk * T (att head stride)
strideB = T * HS (V cache head stride)
strideC = GS * chunk * HS (v_out head stride)

Parameters
group_size — NH / NKV.
chunk_rows — Number of Q rows per chunk (kPrefillChunkSize for full chunks).
num_kv_heads — Number of KV heads (NKV).
Here is the call graph for this function:
Here is the caller graph for this function:

◆ build_backward_att_plan()

template<typename TNative>
CublasLtMatMulPlan< TNative > Mila::Dnn::Compute::Cuda::Gqa::Detail::build_backward_att_plan ( cublasLtHandle_t handle,
int batch_size,
int num_heads,
int seq_length,
int head_size,
cudaDataType_t cuda_data_type,
cublasComputeType_t compute_type,
cudaDataType_t scale_type )

dAtt = dVout @ V^T.

batch_count = B * NH.

Here is the call graph for this function:
Here is the caller graph for this function:

◆ build_backward_k_plan()

template<typename TNative>
CublasLtMatMulPlan< TNative > Mila::Dnn::Compute::Cuda::Gqa::Detail::build_backward_k_plan ( cublasLtHandle_t handle,
int batch_size,
int num_heads,
int seq_length,
int head_size,
cudaDataType_t cuda_data_type,
cublasComputeType_t compute_type,
cudaDataType_t scale_type )

dK = dPreatt^T @ Q.

dPreatt is [B, NH, T, T]; Q is [B, NH, T, HS]. The result dK is [B, NH, T, HS] in the expanded layout; the permute_backward kernel subsequently reduces it to [B, NKV, T, HS]. batch_count = B * NH.

Here is the call graph for this function:
Here is the caller graph for this function:

◆ build_backward_q_plan()

template<typename TNative>
CublasLtMatMulPlan< TNative > Mila::Dnn::Compute::Cuda::Gqa::Detail::build_backward_q_plan ( cublasLtHandle_t handle,
int batch_size,
int num_heads,
int seq_length,
int head_size,
cudaDataType_t cuda_data_type,
cublasComputeType_t compute_type,
cudaDataType_t scale_type )

dQ = dPreatt @ K.

K is the expanded KV buffer; batch_count = B * NH.

Here is the call graph for this function:
Here is the caller graph for this function:

◆ build_backward_v_plan()

template<typename TNative>
CublasLtMatMulPlan< TNative > Mila::Dnn::Compute::Cuda::Gqa::Detail::build_backward_v_plan ( cublasLtHandle_t handle,
int batch_size,
int num_heads,
int seq_length,
int head_size,
cudaDataType_t cuda_data_type,
cublasComputeType_t compute_type,
cudaDataType_t scale_type )

dV = Att^T @ dVout.

dV is accumulated across the Q group by the permute_backward kernel; the cuBLASLt op itself works on the expanded NH dimension. batch_count = B * NH.

Here is the call graph for this function:
Here is the caller graph for this function:

◆ build_qk_decode_plan()

template<typename TNative>
CublasLtMatMulPlan< TNative > Mila::Dnn::Compute::Cuda::Gqa::Detail::build_qk_decode_plan ( cublasLtHandle_t handle,
int batch_size,
int num_heads,
int max_seq_length,
int head_size,
cudaDataType_t cuda_data_type,
cublasComputeType_t compute_type,
cudaDataType_t scale_type )

Single-token Q @ K^T decode plan.

Q has one row (the current token); K is the full cached sequence. After KV expansion: K is [B, NH, T, HS], batch_count = B * NH.

Here is the call graph for this function:
Here is the caller graph for this function:

◆ build_qk_decode_plan_optimized()

template<typename TNative>
CublasLtMatMulPlan< TNative > Mila::Dnn::Compute::Cuda::Gqa::Detail::build_qk_decode_plan_optimized ( cublasLtHandle_t handle,
int batch_size,
int num_kv_heads,
int group_size,
int max_seq_length,
int head_size,
cudaDataType_t cuda_data_type,
cublasComputeType_t compute_type,
cudaDataType_t scale_type )

Single-token Q @ K^T decode plan reading K from [B, NKV, T, HS] cache.

Q compact is [B, NH, 1, HS] = [B*NKV, GS, HS]. K cache is [B, NKV, T, HS] = [B*NKV, T, HS]. preatt_decode is [B, NH, 1, T] = [B*NKV, GS, T].

strideA = GS * HS (compact Q decode stride)
strideB = T * HS (K cache head stride)
strideC = GS * T (preatt_decode head stride)

Parameters
group_size — NH / NKV.
num_kv_heads — Number of KV heads (NKV).
Here is the call graph for this function:
Here is the caller graph for this function:

◆ build_qk_prefill_plan()

template<typename TNative>
CublasLtMatMulPlan< TNative > Mila::Dnn::Compute::Cuda::Gqa::Detail::build_qk_prefill_plan ( cublasLtHandle_t handle,
int batch_size,
int num_heads,
int chunk_rows,
int max_seq_length,
int head_size,
int prefill_window_size,
cudaDataType_t cuda_data_type,
cublasComputeType_t compute_type,
cudaDataType_t scale_type )

Q @ K^T attention score plan for chunked prefill (inference).

Computes preatt[B*NH, chunk_rows, max_seq_length] = Q[B*NH, chunk_rows, HS] @ (K_exp[B*NH, max_seq_length, HS])^T, i.e. each batch slice of K_exp is transposed to [HS, max_seq_length] before the multiply.

strideA covers the full [T, HS] Q buffer so batch slices are spaced correctly even though only chunk_rows rows are read per slice.

Parameters
chunk_rows — Number of Q-token rows (kPrefillChunkSize).
max_seq_length — Full context / KV sequence length (T_).
Here is the call graph for this function:
Here is the caller graph for this function:

◆ build_qk_prefill_plan_optimized()

template<typename TNative>
CublasLtMatMulPlan< TNative > Mila::Dnn::Compute::Cuda::Gqa::Detail::build_qk_prefill_plan_optimized ( cublasLtHandle_t handle,
int batch_size,
int num_kv_heads,
int group_size,
int chunk_rows,
int max_seq_length,
int head_size,
cudaDataType_t cuda_data_type,
cublasComputeType_t compute_type,
cudaDataType_t scale_type )

Q @ K^T prefill plan reading K directly from [B, NKV, T, HS] cache.

Q compact buffer is [B, NH, chunk, HS] = [B*NKV, GS*chunk, HS]. K cache is [B, NKV, T, HS] = [B*NKV, T, HS]. Output preatt is [B, NH, chunk, T] = [B*NKV, GS*chunk, T].

strideA = GS * chunk * HS (compact Q head stride)
strideB = T * HS (K cache head stride)
strideC = GS * chunk * T (preatt head stride)

Parameters
group_size — NH / NKV, i.e. the number of Q heads per KV head.
chunk_rows — Number of Q rows per chunk (kPrefillChunkSize for full chunks).
num_kv_heads — Number of KV heads (NKV).
Here is the call graph for this function:
Here is the caller graph for this function:

◆ build_qk_score_plan()

template<typename TNative>
CublasLtMatMulPlan< TNative > Mila::Dnn::Compute::Cuda::Gqa::Detail::build_qk_score_plan ( cublasLtHandle_t handle,
int batch_size,
int num_heads,
int seq_length,
int head_size,
cudaDataType_t cuda_data_type,
cublasComputeType_t compute_type,
cudaDataType_t scale_type )

Q @ K^T attention score plan (training, full sequence length).

After KV expansion: K is [B, NH, T, HS] — same layout as MHA. batch_count = B * NH.

Here is the call graph for this function:
Here is the caller graph for this function: