Mila 0.13.48
Deep Neural Network Library
Loading...
Searching...
No Matches
Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision > Class Template Reference [export]

CUDA implementation of Multi-Head Attention using column-major cuBLASLt optimization. More...

Inheritance diagram for Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >:
Collaboration diagram for Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >:

Public Types

using ConfigType = MultiHeadAttentionConfig
using CudaExecutionContext = ExecutionContext<DeviceType::Cuda>
using MR = CudaDeviceMemoryResource
using NativeType = typename Mila::Dnn::Compute::Cuda::TensorDataTypeMap<TPrecision>::device_type
using TensorType = Tensor<TPrecision, MR>
using UnaryOperationBase = UnaryOperation<DeviceType::Cuda, TPrecision>
Public Types inherited from Mila::Dnn::Compute::UnaryOperation< DeviceType::Cuda, TPrecision >
using MR
using TensorInputType
using TensorOutputType
Public Types inherited from Mila::Dnn::Compute::Operation< TDeviceType, TInput >
using DataTypeTraits

Public Member Functions

 CudaMultiHeadAttentionOp (IExecutionContext *context, const MultiHeadAttentionConfig &config)
void backward (const ITensor &input, const ITensor &output_grad, ITensor &input_grad) const override
 Backward pass: compute gradient wrt input given output gradient.
void build (const BuildContext &config) override
 Prepare the operation for a concrete input shape.
void decode (const ITensor &input, ITensor &output, int position) override
 Process a single autoregressive token against the KV cache.
void forward (const ITensor &input, ITensor &output) const override
 Forward pass: compute output = f(input).
const MultiHeadAttentionConfig & getConfig () const
std::string getName () const override
 Human-readable operation name.
OperationType getOperationType () const override
 Operation type identifier.
void initializeKvCache (int batch_size, int max_seq_length) override
 Allocate the KV cache for a given batch size and maximum sequence length.
void prefill (const ITensor &input, ITensor &output) override
 Populate the KV cache from a packed QKV sequence and compute output.
void resetKvCache () override
 Reset the KV cache to an empty state, preserving the allocation.
void setGradients (ITensor *, ITensor *) override
 Bind module-owned gradient tensors to the operation.
void setParameters (ITensor *, ITensor *) override
 Bind module-owned parameter tensors to the operation.
Public Member Functions inherited from Mila::Dnn::Compute::UnaryOperation< DeviceType::Cuda, TPrecision >
virtual ~UnaryOperation ()=default
Public Member Functions inherited from Mila::Dnn::Compute::Operation< TDeviceType, TInput >
virtual ~Operation ()=default
virtual void clearGradients () noexcept
 Clear any cached gradient pointers held by the operation.
virtual TensorDataType getDataType () const
 Tensor data type for this operation.
virtual DeviceType getDeviceType () const
 Device type for this operation.
virtual std::size_t getStateMemorySize () const
 Returns the number of bytes of state memory allocated by this operation.
virtual bool isBuilt () const
 Whether build() completed successfully for a concrete input shape.
virtual bool isEvalMode () const
 Query whether the operation is configured for evaluation (inference) mode rather than training.
virtual void setTrainingMode (TrainingMode training_mode)
 Configure operation training-mode behavior.
Public Member Functions inherited from Mila::Dnn::Compute::IPackedKvInference
 ~IPackedKvInference () override=default
Public Member Functions inherited from Mila::Dnn::Compute::IKvCacheLifecycle
virtual ~IKvCacheLifecycle ()=default

Private Member Functions

void allocateStateTensors ()
void buildCublasLtPlans ()
void ensureKVCacheEnabled () const
void getComputeTypes (cublasComputeType_t &compute_type, cudaDataType_t &scale_type) const
cudaDataType_t getCudaDataType () const
void validateDecodeInputShape (const shape_t &input_shape) const
void validateInputShape (const shape_t &input_shape) const
void validatePrefillInputShape (const shape_t &input_shape) const

Private Attributes

int active_max_seq_len_ { 0 }
NativeType * att_ { nullptr }
NativeType * att_decode_ { nullptr }
std::shared_ptr< TensorType > att_decode_tensor_
std::shared_ptr< TensorType > att_tensor_
Detail::CublasLtMatMulPlan< NativeType > att_value_decode_plan_
Detail::CublasLtMatMulPlan< NativeType > att_value_plan_
int B_ { 0 }
Detail::CublasLtMatMulPlan< NativeType > backward_att_plan_
Detail::CublasLtMatMulPlan< NativeType > backward_k_plan_
Detail::CublasLtMatMulPlan< NativeType > backward_q_plan_
Detail::CublasLtMatMulPlan< NativeType > backward_v_plan_
int cached_seq_len_ { 0 }
MultiHeadAttentionConfig config_
CudaExecutionContextcontext_
cublasLtHandle_t cublaslt_handle_ { nullptr }
NativeType * datt_ { nullptr }
std::shared_ptr< TensorType > datt_tensor_
NativeType * dk_ { nullptr }
std::shared_ptr< TensorType > dk_tensor_
NativeType * dpreatt_ { nullptr }
std::shared_ptr< TensorType > dpreatt_tensor_
NativeType * dq_ { nullptr }
std::shared_ptr< TensorType > dq_tensor_
NativeType * dV_ { nullptr }
std::shared_ptr< TensorType > dV_tensor_
NativeType * dVout_ { nullptr }
std::shared_ptr< TensorType > dVout_tensor_
int embedding_dim_ { 0 }
int HS_ { 0 }
NativeType * k_ { nullptr }
std::shared_ptr< TensorType > k_tensor_
bool kv_cache_enabled_ { false }
int NH_ { 0 }
NativeType * preatt_ { nullptr }
NativeType * preatt_decode_ { nullptr }
std::shared_ptr< TensorType > preatt_decode_tensor_
std::shared_ptr< TensorType > preatt_tensor_
NativeType * q_ { nullptr }
std::shared_ptr< TensorType > q_tensor_
Detail::CublasLtMatMulPlan< NativeType > qk_decode_plan_
Detail::CublasLtMatMulPlan< NativeType > qk_score_plan_
int qkv_dim_ { 0 }
int T_ { 0 }
NativeType * v_ { nullptr }
NativeType * v_out_ { nullptr }
NativeType * v_out_decode_ { nullptr }
std::shared_ptr< TensorType > v_out_decode_tensor_
std::shared_ptr< TensorType > v_out_tensor_
std::shared_ptr< TensorType > v_tensor_

Additional Inherited Members

Static Public Attributes inherited from Mila::Dnn::Compute::Operation< TDeviceType, TInput >
static constexpr TensorDataType data_type
static constexpr DeviceType device_type
Static Protected Member Functions inherited from Mila::Dnn::Compute::UnaryOperation< DeviceType::Cuda, TPrecision >
static const TensorInputTypeasInputTensor (const ITensor &t)
static TensorOutputTypeasOutputTensor (ITensor &t)
Protected Attributes inherited from Mila::Dnn::Compute::Operation< TDeviceType, TInput >
bool is_built_
TrainingMode training_mode_

Detailed Description

template<TensorDataType TPrecision>
requires PrecisionSupportedOnDevice<TPrecision, DeviceType::Cuda>
class Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >

CUDA implementation of Multi-Head Attention using column-major cuBLASLt optimization.

Design philosophy:

  • Two-phase initialization: build() creates cuBLASLt plans, forward()/backward() execute them
  • Column-major layout eliminates most transpose operations in cuBLASLt
  • All dimension computation and algorithm selection happens once in build()
  • Forward/backward are hot-path methods with zero setup overhead
  • Custom CUDA kernels handle permute/unpermute and softmax operations

Forward pass:

  1. Permute QKV from [B, T, 3*C] to separate Q, K, V [B, NH, HS, T] (column-major)
  2. Compute attention scores: preatt = Q^T @ K (exploiting column-major layout)
  3. Apply softmax with causal masking: att = softmax(preatt / sqrt(HS))
  4. Compute values: vaccum = Att @ V^T
  5. Unpermute output from [B, NH, HS, T] to [B, T, C]

Backward pass:

  1. Unpermute output gradient to [B, NH, HS, T]
  2. Compute dV = Att^T @ dvaccum^T
  3. Compute dAtt = dvaccum^T @ V
  4. Softmax backward: dPreatt = softmax_backward(dAtt, Att)
  5. Compute dQ = dPreatt @ K^T
  6. Compute dK = dPreatt^T @ Q^T
  7. Permute gradients back to concatenated QKV format

Member Typedef Documentation

◆ ConfigType

◆ CudaExecutionContext

◆ MR

◆ NativeType

◆ TensorType

template<TensorDataType TPrecision>
using Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::TensorType = Tensor<TPrecision, MR>

◆ UnaryOperationBase

template<TensorDataType TPrecision>
using Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::UnaryOperationBase = UnaryOperation<DeviceType::Cuda, TPrecision>

Constructor & Destructor Documentation

◆ CudaMultiHeadAttentionOp()

template<TensorDataType TPrecision>
Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::CudaMultiHeadAttentionOp ( IExecutionContext * context,
const MultiHeadAttentionConfig & config )
inline

Member Function Documentation

◆ allocateStateTensors()

template<TensorDataType TPrecision>
void Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::allocateStateTensors ( )
inlineprivate
Here is the caller graph for this function:

◆ backward()

template<TensorDataType TPrecision>
void Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::backward ( const ITensor & input,
const ITensor & output_grad,
ITensor & input_grad ) const
inlineoverridevirtual

Backward pass: compute gradient wrt input given output gradient.

Signature ordered as (input, output_grad, input_grad) to match module and operation implementations across the codebase.

Implements Mila::Dnn::Compute::UnaryOperation< DeviceType::Cuda, TPrecision >.

◆ build()

template<TensorDataType TPrecision>
void Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::build ( const BuildContext & build_context)
inlineoverridevirtual

Prepare the operation for a concrete input shape.

Default implementation is a no-op. Operations requiring shape-dependent setup should override this method.

Reimplemented from Mila::Dnn::Compute::Operation< TDeviceType, TInput >.

◆ buildCublasLtPlans()

template<TensorDataType TPrecision>
void Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::buildCublasLtPlans ( )
inlineprivate
Here is the caller graph for this function:

◆ decode()

template<TensorDataType TPrecision>
void Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::decode ( const ITensor & input,
ITensor & output,
int position )
inlineoverridevirtual

Process a single autoregressive token against the KV cache.

Parameters
input: Packed QKV single-token input [B, 1, 3 * embedding_dim].
output: Pre-allocated output [B, 1, embedding_dim].
position: Zero-based absolute sequence position into the KV cache.

Implements Mila::Dnn::Compute::IPackedKvInference.

◆ ensureKVCacheEnabled()

template<TensorDataType TPrecision>
void Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::ensureKVCacheEnabled ( ) const
inlineprivate
Here is the caller graph for this function:

◆ forward()

template<TensorDataType TPrecision>
void Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::forward ( const ITensor & input,
ITensor & output ) const
inlineoverridevirtual

Forward pass: compute output = f(input).

Implementations should accept polymorphic ITensor references and may use the typed aliases / helpers to obtain typed tensor references.

Implements Mila::Dnn::Compute::UnaryOperation< DeviceType::Cuda, TPrecision >.

◆ getComputeTypes()

template<TensorDataType TPrecision>
void Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::getComputeTypes ( cublasComputeType_t & compute_type,
cudaDataType_t & scale_type ) const
inlineprivate
Here is the caller graph for this function:

◆ getConfig()

template<TensorDataType TPrecision>
const MultiHeadAttentionConfig & Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::getConfig ( ) const
inline

◆ getCudaDataType()

template<TensorDataType TPrecision>
cudaDataType_t Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::getCudaDataType ( ) const
inlineprivate
Here is the caller graph for this function:

◆ getName()

template<TensorDataType TPrecision>
std::string Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::getName ( ) const
inlineoverridevirtual

Human-readable operation name.

Implements Mila::Dnn::Compute::Operation< TDeviceType, TInput >.

◆ getOperationType()

template<TensorDataType TPrecision>
OperationType Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::getOperationType ( ) const
inlineoverridevirtual

◆ initializeKvCache()

template<TensorDataType TPrecision>
void Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::initializeKvCache ( int batch_size,
int max_sequence_length )
inlineoverridevirtual

Allocate the KV cache for a given batch size and maximum sequence length.

Parameters
batch_size: Number of sequences in the batch.
max_sequence_length: Maximum number of tokens the cache must hold.

Implements Mila::Dnn::Compute::IKvCacheLifecycle.

◆ prefill()

template<TensorDataType TPrecision>
void Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::prefill ( const ITensor & qkv,
ITensor & output )
inlineoverridevirtual

Populate the KV cache from a packed QKV sequence and compute output.

Parameters
qkv: Packed QKV input [B, T, 3 * embedding_dim].
output: Pre-allocated attention output [B, T, embedding_dim].

Implements Mila::Dnn::Compute::IPackedKvInference.

◆ resetKvCache()

template<TensorDataType TPrecision>
void Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::resetKvCache ( )
inlineoverridevirtual

Reset the KV cache to an empty state, preserving the allocation.

Implements Mila::Dnn::Compute::IKvCacheLifecycle.

◆ setGradients()

template<TensorDataType TPrecision>
void Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::setGradients ( ITensor * weight_grad,
ITensor * bias_grad )
inlineoverridevirtual

Bind module-owned gradient tensors to the operation.

New canonical API for binding gradient buffers. Mirrors semantics of setParameters() but for gradients used during backward().

The operation MUST NOT take ownership of the provided pointers. Implementations may cache rawData() pointers for hot-path writes.

Default: no-op for stateless operations.

Reimplemented from Mila::Dnn::Compute::Operation< TDeviceType, TInput >.

◆ setParameters()

template<TensorDataType TPrecision>
void Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::setParameters ( ITensor * weight,
ITensor * bias )
inlineoverridevirtual

Bind module-owned parameter tensors to the operation.

The module retains ownership of the provided ITensor objects. Implementations may cache rawData() pointers for hot-path access but MUST NOT free the provided pointers.

Default: no-op for stateless operations.

Reimplemented from Mila::Dnn::Compute::Operation< TDeviceType, TInput >.

◆ validateDecodeInputShape()

template<TensorDataType TPrecision>
void Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::validateDecodeInputShape ( const shape_t & input_shape) const
inlineprivate
Here is the caller graph for this function:

◆ validateInputShape()

template<TensorDataType TPrecision>
void Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::validateInputShape ( const shape_t & input_shape) const
inlineprivate
Here is the caller graph for this function:

◆ validatePrefillInputShape()

template<TensorDataType TPrecision>
void Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::validatePrefillInputShape ( const shape_t & input_shape) const
inlineprivate
Here is the caller graph for this function:

Member Data Documentation

◆ active_max_seq_len_

template<TensorDataType TPrecision>
int Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::active_max_seq_len_ { 0 }
private

◆ att_

template<TensorDataType TPrecision>
NativeType* Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::att_ { nullptr }
private

◆ att_decode_

template<TensorDataType TPrecision>
NativeType* Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::att_decode_ { nullptr }
private

◆ att_decode_tensor_

template<TensorDataType TPrecision>
std::shared_ptr<TensorType> Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::att_decode_tensor_
private

◆ att_tensor_

template<TensorDataType TPrecision>
std::shared_ptr<TensorType> Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::att_tensor_
private

◆ att_value_decode_plan_

◆ att_value_plan_

◆ B_

template<TensorDataType TPrecision>
int Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::B_ { 0 }
private

◆ backward_att_plan_

◆ backward_k_plan_

◆ backward_q_plan_

◆ backward_v_plan_

◆ cached_seq_len_

template<TensorDataType TPrecision>
int Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::cached_seq_len_ { 0 }
private

◆ config_

◆ context_

◆ cublaslt_handle_

template<TensorDataType TPrecision>
cublasLtHandle_t Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::cublaslt_handle_ { nullptr }
private

◆ datt_

template<TensorDataType TPrecision>
NativeType* Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::datt_ { nullptr }
private

◆ datt_tensor_

template<TensorDataType TPrecision>
std::shared_ptr<TensorType> Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::datt_tensor_
private

◆ dk_

template<TensorDataType TPrecision>
NativeType* Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::dk_ { nullptr }
private

◆ dk_tensor_

template<TensorDataType TPrecision>
std::shared_ptr<TensorType> Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::dk_tensor_
private

◆ dpreatt_

template<TensorDataType TPrecision>
NativeType* Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::dpreatt_ { nullptr }
private

◆ dpreatt_tensor_

template<TensorDataType TPrecision>
std::shared_ptr<TensorType> Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::dpreatt_tensor_
private

◆ dq_

template<TensorDataType TPrecision>
NativeType* Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::dq_ { nullptr }
private

◆ dq_tensor_

template<TensorDataType TPrecision>
std::shared_ptr<TensorType> Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::dq_tensor_
private

◆ dV_

template<TensorDataType TPrecision>
NativeType* Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::dV_ { nullptr }
private

◆ dV_tensor_

template<TensorDataType TPrecision>
std::shared_ptr<TensorType> Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::dV_tensor_
private

◆ dVout_

template<TensorDataType TPrecision>
NativeType* Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::dVout_ { nullptr }
private

◆ dVout_tensor_

template<TensorDataType TPrecision>
std::shared_ptr<TensorType> Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::dVout_tensor_
private

◆ embedding_dim_

template<TensorDataType TPrecision>
int Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::embedding_dim_ { 0 }
private

◆ HS_

template<TensorDataType TPrecision>
int Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::HS_ { 0 }
private

◆ k_

template<TensorDataType TPrecision>
NativeType* Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::k_ { nullptr }
private

◆ k_tensor_

template<TensorDataType TPrecision>
std::shared_ptr<TensorType> Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::k_tensor_
private

◆ kv_cache_enabled_

template<TensorDataType TPrecision>
bool Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::kv_cache_enabled_ { false }
private

◆ NH_

template<TensorDataType TPrecision>
int Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::NH_ { 0 }
private

◆ preatt_

template<TensorDataType TPrecision>
NativeType* Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::preatt_ { nullptr }
private

◆ preatt_decode_

template<TensorDataType TPrecision>
NativeType* Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::preatt_decode_ { nullptr }
private

◆ preatt_decode_tensor_

template<TensorDataType TPrecision>
std::shared_ptr<TensorType> Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::preatt_decode_tensor_
private

◆ preatt_tensor_

template<TensorDataType TPrecision>
std::shared_ptr<TensorType> Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::preatt_tensor_
private

◆ q_

template<TensorDataType TPrecision>
NativeType* Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::q_ { nullptr }
private

◆ q_tensor_

template<TensorDataType TPrecision>
std::shared_ptr<TensorType> Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::q_tensor_
private

◆ qk_decode_plan_

◆ qk_score_plan_

◆ qkv_dim_

template<TensorDataType TPrecision>
int Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::qkv_dim_ { 0 }
private

◆ T_

template<TensorDataType TPrecision>
int Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::T_ { 0 }
private

◆ v_

template<TensorDataType TPrecision>
NativeType* Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::v_ { nullptr }
private

◆ v_out_

template<TensorDataType TPrecision>
NativeType* Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::v_out_ { nullptr }
private

◆ v_out_decode_

template<TensorDataType TPrecision>
NativeType* Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::v_out_decode_ { nullptr }
private

◆ v_out_decode_tensor_

template<TensorDataType TPrecision>
std::shared_ptr<TensorType> Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::v_out_decode_tensor_
private

◆ v_out_tensor_

template<TensorDataType TPrecision>
std::shared_ptr<TensorType> Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::v_out_tensor_
private

◆ v_tensor_

template<TensorDataType TPrecision>
std::shared_ptr<TensorType> Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::v_tensor_
private

The documentation for this class was generated from the following file:
  • /__w/Mila/Mila/Mila/Src/Dnn/Compute/Devices/Cuda/Operations/Attention/MHA/CudaMhaOp.ixx