Mila 0.13.48
Deep Neural Network Library
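Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision > Class Template Reference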
CUDA implementation of Multi-Head Attention using column-major cuBLASLt optimization.


Public Types
| using | ConfigType = MultiHeadAttentionConfig |
| using | CudaExecutionContext = ExecutionContext<DeviceType::Cuda> |
| using | MR = CudaDeviceMemoryResource |
| using | NativeType = typename Mila::Dnn::Compute::Cuda::TensorDataTypeMap<TPrecision>::device_type |
| using | TensorType = Tensor<TPrecision, MR> |
| using | UnaryOperationBase = UnaryOperation<DeviceType::Cuda, TPrecision> |
| Public Types inherited from Mila::Dnn::Compute::UnaryOperation< DeviceType::Cuda, TPrecision > | |
| using | MR |
| using | TensorInputType |
| using | TensorOutputType |
| Public Types inherited from Mila::Dnn::Compute::Operation< TDeviceType, TInput > | |
| using | DataTypeTraits |
Public Member Functions
| | CudaMultiHeadAttentionOp (IExecutionContext *context, const MultiHeadAttentionConfig &config) |
| void | backward (const ITensor &input, const ITensor &output_grad, ITensor &input_grad) const override |
| | Backward pass: compute the gradient with respect to the input, given the output gradient. |
| void | build (const BuildContext &config) override |
| | Prepare the operation for a concrete input shape. |
| void | decode (const ITensor &input, ITensor &output, int position) override |
| | Process a single autoregressive token against the KV cache. |
| void | forward (const ITensor &input, ITensor &output) const override |
| | Forward pass: compute output = f(input). |
| const MultiHeadAttentionConfig & | getConfig () const |
| std::string | getName () const override |
| | Human-readable operation name. |
| OperationType | getOperationType () const override |
| | Operation type identifier. |
| void | initializeKvCache (int batch_size, int max_seq_length) override |
| | Allocate the KV cache for a given batch size and maximum sequence length. |
| void | prefill (const ITensor &input, ITensor &output) override |
| | Populate the KV cache from a packed QKV sequence and compute the attention output. |
| void | resetKvCache () override |
| | Reset the KV cache to an empty state, preserving the allocation. |
| void | setGradients (ITensor *, ITensor *) override |
| | Bind module-owned gradient tensors to the operation. |
| void | setParameters (ITensor *, ITensor *) override |
| | Bind module-owned parameter tensors to the operation. |
| Public Member Functions inherited from Mila::Dnn::Compute::UnaryOperation< DeviceType::Cuda, TPrecision > | |
| virtual | ~UnaryOperation ()=default |
| Public Member Functions inherited from Mila::Dnn::Compute::Operation< TDeviceType, TInput > | |
| virtual | ~Operation ()=default |
| virtual void | clearGradients () noexcept |
| | Clear any cached gradient pointers held by the operation. |
| virtual TensorDataType | getDataType () const |
| | Tensor data type for this operation. |
| virtual DeviceType | getDeviceType () const |
| | Device type for this operation. |
| virtual std::size_t | getStateMemorySize () const |
| | Returns the number of bytes of state memory allocated by this operation. |
| virtual bool | isBuilt () const |
| | Whether build() completed successfully for a concrete input shape. |
| virtual bool | isEvalMode () const |
| | Query whether the operation is configured for evaluation mode. |
| virtual void | setTrainingMode (TrainingMode training_mode) |
| | Configure operation training-mode behavior. |
| Public Member Functions inherited from Mila::Dnn::Compute::IPackedKvInference | |
| | ~IPackedKvInference () override=default |
| Public Member Functions inherited from Mila::Dnn::Compute::IKvCacheLifecycle | |
| virtual | ~IKvCacheLifecycle ()=default |
Private Member Functions
| void | allocateStateTensors () |
| void | buildCublasLtPlans () |
| void | ensureKVCacheEnabled () const |
| void | getComputeTypes (cublasComputeType_t &compute_type, cudaDataType_t &scale_type) const |
| cudaDataType_t | getCudaDataType () const |
| void | validateDecodeInputShape (const shape_t &input_shape) const |
| void | validateInputShape (const shape_t &input_shape) const |
| void | validatePrefillInputShape (const shape_t &input_shape) const |
Additional Inherited Members
| Static Public Attributes inherited from Mila::Dnn::Compute::Operation< TDeviceType, TInput > | |
| static constexpr TensorDataType | data_type |
| static constexpr DeviceType | device_type |
| Static Protected Member Functions inherited from Mila::Dnn::Compute::UnaryOperation< DeviceType::Cuda, TPrecision > | |
| static const TensorInputType & | asInputTensor (const ITensor &t) |
| static TensorOutputType & | asOutputTensor (ITensor &t) |
| Protected Attributes inherited from Mila::Dnn::Compute::Operation< TDeviceType, TInput > | |
| bool | is_built_ |
| TrainingMode | training_mode_ |
CUDA implementation of Multi-Head Attention using column-major cuBLASLt optimization.
Design philosophy:
Forward pass:
Backward pass:
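cuBLAS and cuBLASLt interpret matrices as column-major by default, whereas tensors such as [B, T, C] are typically stored row-major. One common way to reconcile the two relies on the fact that a row-major M x N buffer is the same memory as a column-major N x M matrix (its transpose). A row-major product C = A * B can therefore be obtained by asking a column-major GEMM for C^T = B^T * A^T with the operands swapped. The CPU sketch below shows that identity. `gemmColMajor` is a naive stand-in, not a cuBLASLt call, and whether CudaMultiHeadAttentionOp arranges its matmuls exactly this way is an assumption.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Naive column-major GEMM: C(m x n) = A(m x k) * B(k x n).
// Element (r, c) of a column-major matrix with leading dimension ld is data[r + c * ld].
// Stand-in for a BLAS-style call; this is NOT the cuBLASLt API.
void gemmColMajor(int m, int n, int k,
                  const float* A, int lda,
                  const float* B, int ldb,
                  float* C, int ldc) {
    for (int c = 0; c < n; ++c) {
        for (int r = 0; r < m; ++r) {
            float acc = 0.0f;
            for (int p = 0; p < k; ++p) {
                acc += A[r + p * lda] * B[p + c * ldb];
            }
            C[r + c * ldc] = acc;
        }
    }
}

// Row-major C(M x N) = A(M x K) * B(K x N) via the column-major kernel.
// Viewed column-major, row-major B is B^T (N x K, ld = N), row-major A is
// A^T (K x M, ld = K), and the row-major C buffer is C^T (N x M, ld = N).
std::vector<float> matmulRowMajor(const std::vector<float>& A,
                                  const std::vector<float>& B,
                                  int M, int N, int K) {
    assert(A.size() == static_cast<std::size_t>(M) * K);
    assert(B.size() == static_cast<std::size_t>(K) * N);
    std::vector<float> C(static_cast<std::size_t>(M) * N);
    // Compute C^T = B^T * A^T: operands swapped, no explicit transposes needed.
    gemmColMajor(N, M, K, B.data(), N, A.data(), K, C.data(), N);
    return C;
}
```

The swap costs nothing at runtime: no data is transposed, only the roles of the operands and the reported dimensions change.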
| using Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::ConfigType = MultiHeadAttentionConfig |
| using Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::CudaExecutionContext = ExecutionContext<DeviceType::Cuda> |
| using Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::MR = CudaDeviceMemoryResource |
| using Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::NativeType = typename Mila::Dnn::Compute::Cuda::TensorDataTypeMap<TPrecision>::device_type |
| using Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::TensorType = Tensor<TPrecision, MR> |
| using Mila::Dnn::Compute::Cuda::MultiHeadAttention::CudaMultiHeadAttentionOp< TPrecision >::UnaryOperationBase = UnaryOperation<DeviceType::Cuda, TPrecision> |
| CudaMultiHeadAttentionOp (IExecutionContext *context, const MultiHeadAttentionConfig &config) | inline |
| void allocateStateTensors () | inline, private |

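| void backward (const ITensor &input, const ITensor &output_grad, ITensor &input_grad) const | inline, override, virtual |
Backward pass: compute the gradient with respect to the input, given the output gradient.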
Signature ordered as (input, output_grad, input_grad) to match module and operation implementations across the codebase.
Implements Mila::Dnn::Compute::UnaryOperation< DeviceType::Cuda, TPrecision >.
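For standard scaled dot-product attention, with S = QK^T / sqrt(d), P = softmax(S) applied row-wise, and O = PV, the backward pass computes dV = P^T dO and dP = dO V^T, maps dP to dS through the softmax Jacobian, and then forms dQ = dS K / sqrt(d) and dK = dS^T Q / sqrt(d). The sketch below covers only the softmax step for a single row, dS_j = P_j (dP_j - sum_k P_k dP_k). It assumes this textbook formulation and says nothing about how the CUDA kernels divide the work; `softmaxBackwardRow` is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Softmax backward for one attention row (textbook formulation, assumed here).
// Given probabilities p = softmax(s) and upstream gradient dp = dL/dp,
// returns ds = dL/ds with ds_j = p_j * (dp_j - sum_k p_k * dp_k).
std::vector<float> softmaxBackwardRow(const std::vector<float>& p,
                                      const std::vector<float>& dp) {
    assert(p.size() == dp.size());
    float dot = 0.0f;  // sum_k p_k * dp_k
    for (std::size_t k = 0; k < p.size(); ++k) {
        dot += p[k] * dp[k];
    }
    std::vector<float> ds(p.size());
    for (std::size_t j = 0; j < p.size(); ++j) {
        ds[j] = p[j] * (dp[j] - dot);
    }
    return ds;
}
```

Because softmax outputs always sum to one, the entries of each dS row sum to zero, which is a cheap sanity check for a GPU implementation.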
| void build (const BuildContext &config) | inline, override, virtual |
Prepare the operation for a concrete input shape.
Default implementation is a no-op. Operations requiring shape-dependent setup should override this method.
Reimplemented from Mila::Dnn::Compute::Operation< TDeviceType, TInput >.
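The decode() and prefill() parameter descriptions give the packed input layout as [B, T, 3 * embedding_dim]. A shape-dependent setup step might validate that layout and derive per-head sizes along the lines of the sketch below. It is hypothetical: `AttentionDims`, `computeAttentionDims`, and the requirement that the head count divide embedding_dim (the usual multi-head convention) are assumptions, not the documented behavior of build() or validateInputShape().

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical shape bookkeeping for a packed-QKV attention op.
struct AttentionDims {
    std::int64_t batch = 0;      // B
    std::int64_t seq_len = 0;    // T
    std::int64_t embed_dim = 0;  // C (embedding_dim)
    std::int64_t num_heads = 0;  // NH
    std::int64_t head_dim = 0;   // HS = C / NH
};

// Validates an input shape assumed to be [B, T, 3 * C] and derives head sizes.
AttentionDims computeAttentionDims(const std::vector<std::int64_t>& input_shape,
                                   std::int64_t num_heads) {
    if (input_shape.size() != 3)
        throw std::invalid_argument("expected rank-3 input [B, T, 3 * embedding_dim]");
    if (input_shape[2] % 3 != 0)
        throw std::invalid_argument("last dimension must be 3 * embedding_dim");
    AttentionDims d;
    d.batch = input_shape[0];
    d.seq_len = input_shape[1];
    d.embed_dim = input_shape[2] / 3;
    if (num_heads <= 0 || d.embed_dim % num_heads != 0)
        throw std::invalid_argument("embedding_dim must be divisible by num_heads");
    d.num_heads = num_heads;
    d.head_dim = d.embed_dim / num_heads;
    assert(d.head_dim * d.num_heads == d.embed_dim);  // sanity check on the split
    return d;
}
```

Doing this work once per concrete shape keeps the per-call forward() path free of shape arithmetic and error handling.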
| void buildCublasLtPlans () | inline, private |

| void decode (const ITensor &input, ITensor &output, int position) | inline, override, virtual |
Process a single autoregressive token against the KV cache.
| input | Packed QKV single-token input [B, 1, 3 * embedding_dim]. |
| output | Pre-allocated output [B, 1, embedding_dim]. |
| position | Zero-based absolute sequence position into the KV cache. |
Implements Mila::Dnn::Compute::IPackedKvInference.
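With the shapes documented above, a single decode step appends the new token's K and V at `position` and attends from its Q over cache positions 0 through `position`. The CPU sketch below handles one sequence (one batch element). It assumes the packed vector is laid out as [Q | K | V] with each third split into contiguous heads, and that the cache for that sequence is [num_heads, max_T, head_dim]. `decodeStep` and both layouts are illustrative assumptions, not the kernel's actual memory format.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference for one decode step of one sequence.
// qkv: packed [3 * C] for the new token, assumed layout [Q | K | V], heads contiguous.
// k_cache / v_cache: assumed layout [NH, max_T, HS]. out: [C].
void decodeStep(const std::vector<float>& qkv,
                std::vector<float>& k_cache, std::vector<float>& v_cache,
                std::vector<float>& out,
                int num_heads, int head_dim, int max_T, int position) {
    const int C = num_heads * head_dim;
    assert(static_cast<int>(qkv.size()) == 3 * C);
    assert(position >= 0 && position < max_T);
    assert(k_cache.size() >= static_cast<std::size_t>(num_heads) * max_T * head_dim);
    assert(v_cache.size() == k_cache.size());
    out.assign(C, 0.0f);
    const float scale = 1.0f / std::sqrt(static_cast<float>(head_dim));
    for (int h = 0; h < num_heads; ++h) {
        const float* q = &qkv[h * head_dim];
        float* k_head = &k_cache[static_cast<std::size_t>(h) * max_T * head_dim];
        float* v_head = &v_cache[static_cast<std::size_t>(h) * max_T * head_dim];
        // 1. Append this token's K and V at the absolute position.
        for (int d = 0; d < head_dim; ++d) {
            k_head[position * head_dim + d] = qkv[C + h * head_dim + d];
            v_head[position * head_dim + d] = qkv[2 * C + h * head_dim + d];
        }
        // 2. Scaled scores against cached positions 0..position (causal by construction).
        std::vector<float> w(position + 1);
        float max_s = -INFINITY;
        for (int t = 0; t <= position; ++t) {
            float s = 0.0f;
            for (int d = 0; d < head_dim; ++d) s += q[d] * k_head[t * head_dim + d];
            w[t] = s * scale;
            max_s = std::fmax(max_s, w[t]);
        }
        // 3. Numerically stable softmax, then a weighted sum of cached V rows.
        float sum = 0.0f;
        for (int t = 0; t <= position; ++t) {
            w[t] = std::exp(w[t] - max_s);
            sum += w[t];
        }
        for (int t = 0; t <= position; ++t)
            for (int d = 0; d < head_dim; ++d)
                out[h * head_dim + d] += (w[t] / sum) * v_head[t * head_dim + d];
    }
}
```

Each step costs O(position * C) work instead of recomputing attention over the whole prefix, which is the point of keeping the cache.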
| void ensureKVCacheEnabled () const | inline, private |

| void forward (const ITensor &input, ITensor &output) const | inline, override, virtual |
Forward pass: compute output = f(input).
Implementations should accept polymorphic ITensor references and may use the typed aliases / helpers to obtain typed tensor references.
Implements Mila::Dnn::Compute::UnaryOperation< DeviceType::Cuda, TPrecision >.
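As a correctness reference for forward(), causal multi-head attention over a packed QKV tensor can be written on the CPU as below. The causal mask is an assumption (it is what keeps forward() consistent with KV-cache decoding, but this page does not state it for forward()), as is the [Q | K | V] packing; `attentionForward` is a hypothetical helper. A cuBLASLt implementation would typically express QK^T and PV as batched GEMMs rather than these loops.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference: causal multi-head attention over packed QKV.
// qkv: [B, T, 3 * C] row-major; last dim assumed to be [Q | K | V], heads contiguous.
// Returns out: [B, T, C].
std::vector<float> attentionForward(const std::vector<float>& qkv,
                                    int B, int T, int num_heads, int head_dim) {
    const int C = num_heads * head_dim;
    assert(qkv.size() == static_cast<std::size_t>(B) * T * 3 * C);
    std::vector<float> out(static_cast<std::size_t>(B) * T * C, 0.0f);
    const float scale = 1.0f / std::sqrt(static_cast<float>(head_dim));
    std::vector<float> w(T);
    for (int b = 0; b < B; ++b)
        for (int h = 0; h < num_heads; ++h)
            for (int t = 0; t < T; ++t) {
                const std::size_t row_t = (static_cast<std::size_t>(b) * T + t) * 3 * C;
                const float* q = &qkv[row_t + h * head_dim];
                float max_s = -INFINITY;
                for (int u = 0; u <= t; ++u) {  // causal mask: only positions u <= t
                    const std::size_t row_u = (static_cast<std::size_t>(b) * T + u) * 3 * C;
                    const float* k = &qkv[row_u + C + h * head_dim];
                    float s = 0.0f;
                    for (int d = 0; d < head_dim; ++d) s += q[d] * k[d];
                    w[u] = s * scale;
                    max_s = std::fmax(max_s, w[u]);
                }
                float sum = 0.0f;
                for (int u = 0; u <= t; ++u) {
                    w[u] = std::exp(w[u] - max_s);  // stable softmax numerator
                    sum += w[u];
                }
                float* o = &out[(static_cast<std::size_t>(b) * T + t) * C + h * head_dim];
                for (int u = 0; u <= t; ++u) {
                    const std::size_t row_u = (static_cast<std::size_t>(b) * T + u) * 3 * C;
                    const float* v = &qkv[row_u + 2 * C + h * head_dim];
                    for (int d = 0; d < head_dim; ++d) o[d] += (w[u] / sum) * v[d];
                }
            }
    return out;
}
```

A reference like this is mainly useful in tests: run the CUDA op and the CPU loop on the same small input and compare within a precision-dependent tolerance.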
| void getComputeTypes (cublasComputeType_t &compute_type, cudaDataType_t &scale_type) const | inline, private |

| const MultiHeadAttentionConfig & getConfig () const | inline |
| cudaDataType_t getCudaDataType () const | inline, private |

| std::string getName () const | inline, override, virtual |
Human-readable operation name.
Implements Mila::Dnn::Compute::Operation< TDeviceType, TInput >.
| OperationType getOperationType () const | inline, override, virtual |
Operation type identifier.
Implements Mila::Dnn::Compute::Operation< TDeviceType, TInput >.
| void initializeKvCache (int batch_size, int max_seq_length) | inline, override, virtual |
Allocate the KV cache for a given batch size and maximum sequence length.
| batch_size | Number of sequences in the batch. |
| max_seq_length | Maximum number of tokens the cache must hold. |
Implements Mila::Dnn::Compute::IKvCacheLifecycle.
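Because the cache holds K and V for every sequence, head, and position up to the maximum sequence length, its footprint for standard multi-head attention is 2 x B x num_heads x max_seq_length x head_dim x sizeof(element). The host-side stand-in below shows one possible layout, [B, num_heads, max_T, head_dim], plus a reset() that empties the cache while keeping the buffers, which is the behavior resetKvCache() documents. `KvCache` and its members are hypothetical; a CUDA implementation would allocate device memory instead of std::vector storage.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical host-side KV cache: K and V buffers shaped [B, NH, max_T, HS].
struct KvCache {
    int batch = 0, num_heads = 0, max_seq_len = 0, head_dim = 0;
    int length = 0;  // tokens currently cached per sequence
    std::vector<float> k, v;

    void initialize(int batch_size, int max_seq_length, int heads, int hs) {
        if (batch_size <= 0 || max_seq_length <= 0 || heads <= 0 || hs <= 0)
            throw std::invalid_argument("KV cache dimensions must be positive");
        batch = batch_size;
        max_seq_len = max_seq_length;
        num_heads = heads;
        head_dim = hs;
        const std::size_t n =
            static_cast<std::size_t>(batch) * num_heads * max_seq_len * head_dim;
        k.assign(n, 0.0f);
        v.assign(n, 0.0f);
        length = 0;
    }

    // Empty the cache without releasing the buffers.
    void reset() { length = 0; }

    // Flat offset of element (b, h, t, d) in either buffer.
    std::size_t offset(int b, int h, int t, int d) const {
        assert(t < max_seq_len);
        return ((static_cast<std::size_t>(b) * num_heads + h) * max_seq_len + t) * head_dim + d;
    }

    std::size_t bytes() const { return (k.size() + v.size()) * sizeof(float); }
};
```

Sizing by the maximum length up front means decoding never reallocates, at the cost of reserving memory for tokens that may never arrive.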
| void prefill (const ITensor &input, ITensor &output) | inline, override, virtual |
Populate the KV cache from a packed QKV sequence and compute the attention output.
| input | Packed QKV input [B, T, 3 * embedding_dim]. |
| output | Pre-allocated attention output [B, T, embedding_dim]. |
Implements Mila::Dnn::Compute::IPackedKvInference.
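Prefill can be viewed as two parts: copy the K and V slices of every input position into the cache, then compute causal attention over the whole sequence as a full forward pass would. The sketch below covers only the copy for one sequence, under an assumed [Q | K | V] packing and an assumed [num_heads, max_T, head_dim] cache layout; `prefillKv` is hypothetical. After it returns, decoding would continue at position T.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical prefill copy for one sequence: scatter the K and V slices of a
// packed [T, 3 * C] QKV tensor into caches laid out as [NH, max_T, HS].
// Returns the new cache length (T).
int prefillKv(const std::vector<float>& qkv, int T, int num_heads, int head_dim, int max_T,
              std::vector<float>& k_cache, std::vector<float>& v_cache) {
    const int C = num_heads * head_dim;
    assert(qkv.size() == static_cast<std::size_t>(T) * 3 * C);
    assert(T <= max_T);
    assert(k_cache.size() == static_cast<std::size_t>(num_heads) * max_T * head_dim);
    assert(v_cache.size() == k_cache.size());
    for (int t = 0; t < T; ++t)
        for (int h = 0; h < num_heads; ++h)
            for (int d = 0; d < head_dim; ++d) {
                // Offset of element (t, h, d) within the Q slice of the packed row.
                const std::size_t src = static_cast<std::size_t>(t) * 3 * C + h * head_dim + d;
                const std::size_t dst = (static_cast<std::size_t>(h) * max_T + t) * head_dim + d;
                k_cache[dst] = qkv[src + C];      // K slice follows Q
                v_cache[dst] = qkv[src + 2 * C];  // V slice follows K
            }
    return T;
}
```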
| void resetKvCache () | inline, override, virtual |
Reset the KV cache to an empty state, preserving the allocation.
Implements Mila::Dnn::Compute::IKvCacheLifecycle.
| void setGradients (ITensor *, ITensor *) | inline, override, virtual |
Bind module-owned gradient tensors to the operation.
New canonical API for binding gradient buffers. Mirrors the semantics of setParameters(), but for the gradients used during backward().
The operation MUST NOT take ownership of the provided pointers. Implementations may cache rawData() pointers for hot-path writes.
Default: no-op for stateless operations.
Reimplemented from Mila::Dnn::Compute::Operation< TDeviceType, TInput >.
| void setParameters (ITensor *, ITensor *) | inline, override, virtual |
Bind module-owned parameter tensors to the operation.
The module retains ownership of the provided ITensor objects. Implementations may cache rawData() pointers for hot-path access but MUST NOT free the provided pointers.
Default: no-op for stateless operations.
Reimplemented from Mila::Dnn::Compute::Operation< TDeviceType, TInput >.
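setParameters() and setGradients() bind tensors that the module owns: the operation may cache raw data pointers for hot-path access but must never free them. In this class both parameters are unnamed, which suggests the attention op (working on an already-projected QKV input) binds nothing. The sketch below therefore illustrates the borrowing contract for an operation that does hold weights. `HostTensor` and `BorrowingOp` are hypothetical stand-ins, not Mila's ITensor or Operation types.

```cpp
#include <cassert>
#include <vector>

// Hypothetical tensor that owns its storage (stands in for a module-owned ITensor).
struct HostTensor {
    std::vector<float> storage;
    void* rawData() { return storage.data(); }
};

// Hypothetical operation that only borrows module-owned tensors.
class BorrowingOp {
public:
    // Cache raw pointers for hot-path reads; never delete them (the module owns them).
    void setParameters(HostTensor* weight, HostTensor* bias) {
        weight_ = weight ? static_cast<float*>(weight->rawData()) : nullptr;
        bias_ = bias ? static_cast<float*>(bias->rawData()) : nullptr;
    }
    // Same contract for gradient buffers that backward() writes into.
    void setGradients(HostTensor* weight_grad, HostTensor* bias_grad) {
        weight_grad_ = weight_grad ? static_cast<float*>(weight_grad->rawData()) : nullptr;
        bias_grad_ = bias_grad ? static_cast<float*>(bias_grad->rawData()) : nullptr;
    }
    const float* weight() const { return weight_; }
    const float* bias() const { return bias_; }
    float* weightGrad() const { return weight_grad_; }
    float* biasGrad() const { return bias_grad_; }

private:
    float* weight_ = nullptr;       // borrowed, not owned
    float* bias_ = nullptr;         // borrowed, not owned
    float* weight_grad_ = nullptr;  // borrowed, not owned
    float* bias_grad_ = nullptr;    // borrowed, not owned
};
```

Caching raw pointers keeps the hot path free of virtual calls, but it also means the module must rebind if it ever reallocates a tensor's storage.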
| void validateDecodeInputShape (const shape_t &input_shape) const | inline, private |

| void validateInputShape (const shape_t &input_shape) const | inline, private |

| void validatePrefillInputShape (const shape_t &input_shape) const | inline, private |
