Mila 0.13.48
Deep Neural Network Library
Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > Class Template Reference

Grouped-Query Attention module that accepts concatenated QKV input.


Public Types

using ComponentBase = Component<TDeviceType, TComputePrecision>
using KvCacheTensorType = Tensor<kCacheDtype, MR>
using MR = typename DeviceTypeTraits<TDeviceType>::memory_resource
using OpType = typename Compute::OperationTraits<Compute::OperationType::GroupedQueryAttentionOp, TDeviceType, TComputePrecision, TKvPolicy>::type
using TensorType = Tensor<TComputePrecision, MR>

Public Member Functions

 GroupedQueryAttention (const std::string &name, const GqaConfig &config, std::optional< DeviceId > device_id=std::nullopt)
 Construct a GroupedQueryAttention component.
 ~GroupedQueryAttention () override=default
TensorType & backward (const TensorType &input, const TensorType &output_grad)
 Run backward pass and return the component-owned input-gradient tensor.
TensorType & decode (const TensorType &q, const TensorType &k, const TensorType &v, int position_offset)
 Inference-only single-token decode pass.
TensorType & forward (const TensorType &input)
 Standard forward pass.
const GqaConfig & getConfig () const noexcept
DeviceId getDeviceId () const override
 Get the compute device id associated with this component.
std::vector< ITensor * > getGradients () const override
 Return non-owning pointers to parameter gradient tensors.
MemoryStats getMemoryStats () const override
 Return the current memory allocation breakdown for this component.
int64_t getModelDim () const noexcept
int64_t getNumHeads () const noexcept
int64_t getNumKvHeads () const noexcept
std::vector< ITensor * > getParameters () const override
 Return non-owning pointers to parameter tensors.
const ComponentType getType () const override
 Get the component type identifier.
size_t parameterCount () const override
 Return number of trainable parameters.
TensorType & prefill (const TensorType &q, const TensorType &k, const TensorType &v, int position_offset)
 Chunked prefill pass with explicit position offset.
void save_ (ModelArchive &archive, SerializationMode mode) const override
void setState (const GqaState &state)
 Forward the shared transient workspace to the underlying operation.
bool supportsKVCache () const noexcept
 Returns true when the underlying operation implements both IPositionalUnaryOp and IKVCacheLifecycle.
void synchronize () override
 Wait for outstanding device work submitted by this component.
std::string toString () const override
 Produce a short, human-readable description of the component.
Public Member Functions inherited from Mila::Dnn::Component< TDeviceType, TComputePrecision >
 Component (const std::string &name)
 Construct component with required name identifier.
virtual ~Component ()=default
virtual void build (const BuildContext &context) final
 Build the component with the provided BuildContext (canonical overload).
const std::string getName () const
 Get the component's name identifier.
virtual std::vector< std::string > getParameterNames () const
 List all available parameter names for this component.
RuntimeMode getRuntimeMode () const noexcept
 Get the current RuntimeMode of this Component.
TrainingMode getTrainingMode () const noexcept
 The current runtime behavioral mode of this Component.
virtual bool isBuilt () const final
 Returns true if build() has completed successfully.
bool isInferenceMode () const noexcept
bool isTrainingMode () const noexcept
virtual void loadParameter (const std::string &name, const Serialization::ITensorBlob &blob)
 Load a parameter from serialized tensor data.
void setTrainingMode (TrainingMode mode)
 Set the runtime behavioral mode for this Component.
virtual void zeroGradients ()
 Clear all model-owned gradients for this component.

Static Public Attributes

static constexpr TensorDataType kCacheDtype
static constexpr bool kKvCompressed = TKvPolicy::kIsActive

Protected Member Functions

void onBuilding (const BuildContext &context) override
 Hook invoked by build() to allocate component buffers.
void onExecutionContextSet () override
 Lifecycle hook: Called immediately after ExecutionContext is set.
void onTrainingModeChanging (TrainingMode training_mode) override
 Hook called before TrainingMode transitions.
Protected Member Functions inherited from Mila::Dnn::Component< TDeviceType, TComputePrecision >
IExecutionContext * getExecutionContext () const
 Get the shared execution context.
bool hasExecutionContext () const noexcept
 Check if execution context has been set.
void loadParameterFromBlob (const std::string &param_name, const Serialization::ITensorBlob &blob, Tensor< TParameterPrecision, TMemoryResource > &target, const shape_t &expected_shape)
 Load a tensor blob into a parameter tensor with validation.
void setExecutionContext (IExecutionContext *context)
 Set the execution context for this component.

Private Member Functions

void createOperation ()
void validateConcatenatedQKVShape (const shape_t &shape) const
 Validate that the input tensor has the expected GQA-packed QKV shape.

Private Attributes

bool cache_initialized_ { false }
GqaConfig config_
std::unique_ptr< IExecutionContext > context_ { nullptr }
bool decode_active_ { false }
std::unique_ptr< TensorType > decode_output_ { nullptr }
std::unique_ptr< TensorType > input_grad_ { nullptr }
IKvCacheLifecycle * kv_cache_op_ { nullptr }
shape_t max_input_shape_
std::shared_ptr< OpType > operation_ { nullptr }
std::unique_ptr< TensorType > output_ { nullptr }
std::optional< TensorType > output_view_
IKvInference * positional_op_ { nullptr }

Additional Inherited Members

Static Public Member Functions inherited from Mila::Dnn::Component< TDeviceType, TComputePrecision >
static constexpr DeviceType getDeviceType ()
 Compile-time device type for this component instance.
static constexpr TensorDataType getPrecision () noexcept
 Compile-time tensor precision for this component instance.
Protected Attributes inherited from Mila::Dnn::Component< TDeviceType, TComputePrecision >
BuildContext build_context_
 The BuildContext stored at build time.

Detailed Description

template<DeviceType TDeviceType, TensorDataType TComputePrecision, KvCachePolicy TKvPolicy = NoKvCompression>
requires PrecisionSupportedOnDevice<TComputePrecision, TDeviceType>
class Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >

Grouped-Query Attention module that accepts concatenated QKV input.

GQA generalises multi-head attention (MHA) by allowing num_kv_heads < num_heads. Each K/V head is shared by a group of (num_heads / num_kv_heads) Q heads, which shrinks the KV cache, and the memory bandwidth needed to read it during inference, by that same factor relative to MHA.
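The grouping rule reduces to a few lines of integer arithmetic. This is a minimal sketch, not Mila code: gqaGroupSize, kvHeadForQueryHead and kvCacheElemsPerToken are hypothetical helpers, and the sketch assumes num_heads is an exact multiple of num_kv_heads and that consecutive query heads share one K/V head.

```cpp
#include <cassert>
#include <cstdint>

// Number of query heads that share one K/V head.
// Assumes num_heads is an exact multiple of num_kv_heads.
inline std::int64_t gqaGroupSize(std::int64_t num_heads, std::int64_t num_kv_heads) {
    assert(num_kv_heads > 0 && num_heads % num_kv_heads == 0);
    return num_heads / num_kv_heads;
}

// K/V head read by query head `q_head`, assuming consecutive query heads
// form one group (heads 0..g-1 -> K/V head 0, g..2g-1 -> K/V head 1, ...).
inline std::int64_t kvHeadForQueryHead(std::int64_t q_head, std::int64_t num_heads,
                                       std::int64_t num_kv_heads) {
    assert(q_head >= 0 && q_head < num_heads);
    return q_head / gqaGroupSize(num_heads, num_kv_heads);
}

// K plus V elements cached per token for one attention layer.
// Standard MHA is the special case num_kv_heads == num_heads.
inline std::int64_t kvCacheElemsPerToken(std::int64_t num_kv_heads, std::int64_t head_dim) {
    return 2 * num_kv_heads * head_dim;
}
```

For example, with num_heads = 32 and num_kv_heads = 8, query heads 0 to 3 read K/V head 0, query heads 4 to 7 read K/V head 1, and the cached K/V elements per token are a quarter of what MHA with 32 K/V heads would store.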

The module requires a single input tensor in model layout, with Q, K and V concatenated along the feature (last) axis:

input shape  == [B, T, (num_heads + 2 * num_kv_heads) * head_dim]
output shape == [B, T, model_dim]    (model_dim = num_heads * head_dim)
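The packed feature axis can be split into its three segments with simple offsets. The sketch assumes the segments are contiguous and stored in Q, K, V order (not interleaved per group); QkvLayout and packedQkvLayout are hypothetical names used only for illustration.

```cpp
#include <cassert>
#include <cstdint>

// Column ranges of the Q, K and V segments inside the packed feature axis.
// Assumption: contiguous segments stored in Q, K, V order.
struct QkvLayout {
    std::int64_t q_begin;  // Q columns: [q_begin, k_begin)
    std::int64_t k_begin;  // K columns: [k_begin, v_begin)
    std::int64_t v_begin;  // V columns: [v_begin, total)
    std::int64_t total;    // packed width == (num_heads + 2 * num_kv_heads) * head_dim
};

inline QkvLayout packedQkvLayout(std::int64_t num_heads, std::int64_t num_kv_heads,
                                 std::int64_t head_dim) {
    assert(num_heads > 0 && num_kv_heads > 0 && head_dim > 0);
    const std::int64_t q_width  = num_heads * head_dim;     // equals model_dim
    const std::int64_t kv_width = num_kv_heads * head_dim;  // width of K, and of V
    return QkvLayout{ 0, q_width, q_width + kv_width, q_width + 2 * kv_width };
}
```

With num_heads = 32, num_kv_heads = 8 and head_dim = 128, Q occupies columns [0, 4096), K [4096, 5120) and V [5120, 6144); 6144 = (32 + 2 * 8) * 128 matches the input shape formula.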

The backend compute implementation (registered as "GroupedQueryAttentionOp") must accept this layout and produce the output above.

KV-cache inference is an optional backend capability. After build(), supportsKVCache() indicates whether the underlying operation implements both IPositionalUnaryOp (prefill/decode dispatch) and IKVCacheLifecycle (cache init/reset). Both interface pointers are resolved once, at build time.
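One common way to resolve optional capabilities once is to dynamic_cast the operation to each capability interface and keep the resulting pointers. The sketch below uses empty stand-in interfaces inside a sketch namespace; it illustrates the technique under that assumption and is not Mila's actual build code. The real IPositionalUnaryOp and IKVCacheLifecycle declare methods omitted here, and CacheCapableOp and PlainOp are hypothetical backends.

```cpp
#include <cassert>

namespace sketch {

// Empty stand-ins for the capability interfaces; the real Mila interfaces
// declare methods that are omitted here.
struct IOperation         { virtual ~IOperation() = default; };
struct IPositionalUnaryOp { virtual ~IPositionalUnaryOp() = default; };
struct IKVCacheLifecycle  { virtual ~IKVCacheLifecycle() = default; };

// Hypothetical backends: one with both capabilities, one with neither.
struct CacheCapableOp : IOperation, IPositionalUnaryOp, IKVCacheLifecycle {};
struct PlainOp : IOperation {};

// Capability pointers resolved once and then reused on every call.
struct Capabilities {
    IPositionalUnaryOp* positional = nullptr;
    IKVCacheLifecycle*  lifecycle  = nullptr;

    bool supportsKVCache() const noexcept {
        return positional != nullptr && lifecycle != nullptr;
    }
};

// dynamic_cast cross-casts to each interface and yields nullptr when the
// concrete operation does not implement it.
inline Capabilities resolveCapabilities(IOperation& op) {
    return Capabilities{ dynamic_cast<IPositionalUnaryOp*>(&op),
                         dynamic_cast<IKVCacheLifecycle*>(&op) };
}

}  // namespace sketch
```

A backend that implements only one of the two interfaces also reports false, which matches the "both" requirement described for supportsKVCache().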

The KV cache lifecycle (initializeKVCache / resetKVCache) is intended to be driven exclusively by the owning transformer's generate() method.

REVIEW: initializeKVCache() and resetKVCache() are currently public. When TransformerBase<> is introduced as the common base for GptTransformer, LlamaTransformer, MistralTransformer etc., revisit whether these should become private with 'friend class TransformerBase<TDeviceType, TPrecision>' to enforce that only the generate() orchestration path may manage the KV cache lifecycle.

Constructor & Destructor Documentation

◆ GroupedQueryAttention()

template<DeviceType TDeviceType, TensorDataType TComputePrecision, KvCachePolicy TKvPolicy = NoKvCompression>
Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >::GroupedQueryAttention ( const std::string & name,
const GqaConfig & config,
std::optional< DeviceId > device_id = std::nullopt )
inline explicit export

Construct a GroupedQueryAttention component.

Parameters
name: Component name identifier (mandatory).
config: GQA configuration (model_dim, num_heads, num_kv_heads).
device_id: Optional DeviceId; when provided, the component creates and owns its own ExecutionContext (standalone / unit-test mode).

◆ ~GroupedQueryAttention()

template<DeviceType TDeviceType, TensorDataType TComputePrecision, KvCachePolicy TKvPolicy = NoKvCompression>
Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >::~GroupedQueryAttention ( )
override export default

Member Function Documentation

◆ backward()

template<DeviceType TDeviceType, TensorDataType TComputePrecision, KvCachePolicy TKvPolicy = NoKvCompression>
TensorType & Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >::backward ( const TensorType & input,
const TensorType & output_grad )
inline export

Run backward pass and return the component-owned input-gradient tensor.

Parameters
input: Concatenated QKV input tensor used in the forward pass.
output_grad: Gradient with respect to the module output.
Returns
Reference to component-owned TensorType containing the input gradient.

◆ createOperation()

template<DeviceType TDeviceType, TensorDataType TComputePrecision, KvCachePolicy TKvPolicy = NoKvCompression>
void Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >::createOperation ( )
inline export private

◆ decode()

template<DeviceType TDeviceType, TensorDataType TComputePrecision, KvCachePolicy TKvPolicy = NoKvCompression>
TensorType & Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >::decode ( const TensorType & q,
const TensorType & k,
const TensorType & v,
int position_offset )
inline export

Inference-only single-token decode pass.

When the backend implements IPositionalUnaryOp and the cache has been populated by a prior forward() call, decode() takes the KV-cache fast path, which does O(n) work per token in the current sequence length n. When the backend does not support positional dispatch, it falls back to forward(). Callers never need to know which path was taken.

Precondition: forward() must have been called at least once to populate the KV cache before decode() is called.

Parameters
q, k, v: Query, key and value tensors for the single new token (T = 1).
position_offset: Current sequence position (0-based).
Returns
Reference to component-owned single-token output tensor.
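The 0-based position argument follows directly from how many tokens are already cached. The helper below is hypothetical and models only that bookkeeping, assuming forward() has prefilled prompt tokens at positions 0 to prompt_len - 1 and each decode() call then handles exactly one new token.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Absolute 0-based positions a generation loop would pass to decode(),
// assuming forward() already prefilled prompt tokens 0..prompt_len-1
// and each decode() call handles exactly one new token.
inline std::vector<int> decodePositions(int prompt_len, int max_new_tokens) {
    assert(prompt_len > 0 && max_new_tokens >= 0);
    std::vector<int> positions;
    positions.reserve(static_cast<std::size_t>(max_new_tokens));
    for (int step = 0; step < max_new_tokens; ++step) {
        // e.g. attn.decode(q, k, v, prompt_len + step) in a hypothetical loop
        positions.push_back(prompt_len + step);
    }
    return positions;
}
```

For a 5-token prompt and 3 generated tokens, the decode positions are 5, 6 and 7.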

◆ forward()

template<DeviceType TDeviceType, TensorDataType TComputePrecision, KvCachePolicy TKvPolicy = NoKvCompression>
TensorType & Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >::forward ( const TensorType & input)
inline export

Standard forward pass.

Always available, regardless of backend. When the backend supports KV caching, the first forward() call initialises and populates the cache (prefill with position_offset = 0). When forward() is called again after decode() steps, it automatically resets the cache and begins a new prefill session, so callers need no explicit session management.

Parameters
input: Concatenated QKV input [B, T, (num_heads + 2 * num_kv_heads) * head_dim].
Returns
Reference to component-owned output tensor [B, T, model_dim].
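The session rules for forward() and decode() can be modelled as a small state machine. This toy class is hypothetical: it records the KV-cache actions it would trigger instead of touching tensors. The flag names mirror the private members cache_initialized_ and decode_active_, but their exact semantics in Mila are assumed, as is the behaviour when forward() is called twice with no decode() in between (here: prefill again without a reset).

```cpp
#include <cassert>
#include <string>
#include <vector>

// Toy model of the forward()/decode() session rules; no tensors involved.
class KvSessionModel {
public:
    explicit KvSessionModel(bool supports_kv_cache) : supports_kv_(supports_kv_cache) {}

    void forward() {
        if (!supports_kv_) { log_.push_back("full-forward"); return; }
        if (!cache_initialized_) {
            log_.push_back("init-cache");   // first forward(): initialise the cache
            cache_initialized_ = true;
        } else if (decode_active_) {
            log_.push_back("reset-cache");  // forward() after decode(): new session
        }
        decode_active_ = false;
        log_.push_back("prefill@0");        // prefill always starts at position 0
    }

    void decode(int position) {
        if (!supports_kv_) { log_.push_back("fallback-forward"); return; }
        assert(cache_initialized_ && "forward() must run before decode()");
        decode_active_ = true;
        log_.push_back("decode@" + std::to_string(position));
    }

    const std::vector<std::string>& log() const { return log_; }

private:
    bool supports_kv_;
    bool cache_initialized_ = false;
    bool decode_active_ = false;
    std::vector<std::string> log_;
};
```

Calling forward(), decode(4), decode(5), forward() on a cache-capable model logs init-cache, prefill@0, decode@4, decode@5, reset-cache, prefill@0.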

◆ getConfig()

template<DeviceType TDeviceType, TensorDataType TComputePrecision, KvCachePolicy TKvPolicy = NoKvCompression>
const GqaConfig & Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >::getConfig ( ) const
inline export noexcept

◆ getDeviceId()

template<DeviceType TDeviceType, TensorDataType TComputePrecision, KvCachePolicy TKvPolicy = NoKvCompression>
DeviceId Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >::getDeviceId ( ) const
inline override export virtual

Get the compute device id associated with this component.

Must return the device on which parameters and operations execute.

Implements Mila::Dnn::Component< TDeviceType, TComputePrecision >.

◆ getGradients()

template<DeviceType TDeviceType, TensorDataType TComputePrecision, KvCachePolicy TKvPolicy = NoKvCompression>
std::vector< ITensor * > Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >::getGradients ( ) const
inline override export virtual

Return non-owning pointers to parameter gradient tensors.

Only valid when isTrainingMode() is true.

Exceptions
std::runtime_error: if called when not in training mode or before the component has been built.

Implements Mila::Dnn::Component< TDeviceType, TComputePrecision >.

◆ getMemoryStats()

template<DeviceType TDeviceType, TensorDataType TComputePrecision, KvCachePolicy TKvPolicy = NoKvCompression>
MemoryStats Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >::getMemoryStats ( ) const
inline override export virtual

Return the current memory allocation breakdown for this component.

Reflects allocations at the moment of the call. The returned stats naturally track the component lifecycle:

  • After construction: parameters only.
  • After build( Inference ): parameters + T=1 state buffers.
  • After build( Training ): parameters + T=full state buffers.
  • After setEvaluation( false ): parameters + state + gradients.
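The gap between the two build modes is easy to quantify for the output buffer alone. A minimal sketch, assuming a [B, T, model_dim] output and ignoring the other state buffers and gradients a real MemoryStats would include; outputBufferBytes is a hypothetical helper.

```cpp
#include <cassert>
#include <cstdint>

// Bytes for a single [B, T, model_dim] output buffer.
// build(Inference) sizes T as 1; build(Training) sizes T as the full sequence length.
// Other state buffers and gradients are deliberately ignored.
inline std::int64_t outputBufferBytes(std::int64_t batch, std::int64_t seq_len,
                                      std::int64_t model_dim, std::int64_t bytes_per_elem) {
    assert(batch > 0 && seq_len > 0 && model_dim > 0 && bytes_per_elem > 0);
    return batch * seq_len * model_dim * bytes_per_elem;
}
```

With B = 8, model_dim = 4096 and 2-byte elements, the T = 1 inference buffer takes 65,536 bytes, while a training build with T = 1024 needs 1024 times as much (64 MiB).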

For CompositeComponent and Network, the returned stats are the recursive aggregate of all child components.

May be called at any time — no lifecycle preconditions.

Returns
MemoryStats reflecting current allocations.

Implements Mila::Dnn::Component< TDeviceType, TComputePrecision >.

◆ getModelDim()

template<DeviceType TDeviceType, TensorDataType TComputePrecision, KvCachePolicy TKvPolicy = NoKvCompression>
int64_t Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >::getModelDim ( ) const
inline export noexcept

◆ getNumHeads()

template<DeviceType TDeviceType, TensorDataType TComputePrecision, KvCachePolicy TKvPolicy = NoKvCompression>
int64_t Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >::getNumHeads ( ) const
inline export noexcept

◆ getNumKvHeads()

template<DeviceType TDeviceType, TensorDataType TComputePrecision, KvCachePolicy TKvPolicy = NoKvCompression>
int64_t Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >::getNumKvHeads ( ) const
inline export noexcept

◆ getParameters()

template<DeviceType TDeviceType, TensorDataType TComputePrecision, KvCachePolicy TKvPolicy = NoKvCompression>
std::vector< ITensor * > Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >::getParameters ( ) const
inline override export virtual

Return non-owning pointers to parameter tensors.

The returned tensor pointers remain valid for the lifetime of the component. Order should be canonical (weights before biases).

Implements Mila::Dnn::Component< TDeviceType, TComputePrecision >.

◆ getType()

template<DeviceType TDeviceType, TensorDataType TComputePrecision, KvCachePolicy TKvPolicy = NoKvCompression>
const ComponentType Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >::getType ( ) const
inline override export virtual

Get the component type identifier.

Used for serialization and runtime type identification.

Returns
Component type enum value.

Implements Mila::Dnn::Component< TDeviceType, TComputePrecision >.

◆ onBuilding()

template<DeviceType TDeviceType, TensorDataType TComputePrecision, KvCachePolicy TKvPolicy = NoKvCompression>
void Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >::onBuilding ( const BuildContext & config)
inline override export protected virtual

Hook invoked by build() to allocate component buffers.

Receives the stored BuildContext. Implementations must use config.allocationSeqLen() when sizing output buffers; this single call lets Inference and Training builds allocate correctly sized buffers without any per-component logic.

// Example: Linear component
shape_t out_shape =
{
    config.batchSize(),
    config.allocationSeqLen(), // 1 for Inference, T for Training
    config_.getOutputFeatures()
};
output_ = std::make_unique<TensorType>( device, out_shape,
    this->getName() + ".output" );

The default implementation forwards to the legacy onBuilding( const shape_t& ) overload for backwards compatibility. New components should override this overload directly.

Note
Do not call build() or onBuilding() from within this hook.
Implementations should either succeed fully or leave no partial state, as a failed build() may be retried.
Parameters
config: Build-time configuration. Use config.allocationSeqLen() to obtain the correct output buffer sequence dimension.

Reimplemented from Mila::Dnn::Component< TDeviceType, TComputePrecision >.

◆ onExecutionContextSet()

template<DeviceType TDeviceType, TensorDataType TComputePrecision, KvCachePolicy TKvPolicy = NoKvCompression>
void Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >::onExecutionContextSet ( )
inline override export protected virtual

Lifecycle hook: Called immediately after ExecutionContext is set.

Override this to perform initialization that requires a valid ExecutionContext. At the time this is called, getExecutionContext() is guaranteed to return a valid context.

Common uses:

  • Composite components: Create and configure child components.
  • Device resource allocation: Query device capabilities.

Default implementation does nothing.

Exceptions
Any exception thrown will cause setExecutionContext() to fail and restore the component to a "context not set" state.

Reimplemented from Mila::Dnn::Component< TDeviceType, TComputePrecision >.

◆ onTrainingModeChanging()

template<DeviceType TDeviceType, TensorDataType TComputePrecision, KvCachePolicy TKvPolicy = NoKvCompression>
void Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >::onTrainingModeChanging ( TrainingMode mode)
inline override export protected virtual

Hook called before TrainingMode transitions.

Called by setTrainingMode() after validation and lock acquisition, before the internal state is updated. Derived classes override to respond to the transition — e.g. zeroing gradient buffers on transition to Eval, or re-enabling dropout on transition to Training.

The default implementation is a no-op.

Parameters
mode: The incoming TrainingMode.

Reimplemented from Mila::Dnn::Component< TDeviceType, TComputePrecision >.

◆ parameterCount()

template<DeviceType TDeviceType, TensorDataType TComputePrecision, KvCachePolicy TKvPolicy = NoKvCompression>
size_t Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >::parameterCount ( ) const
inline override export virtual

Return number of trainable parameters.

For leaf components this is the element count of owned parameter tensors. CompositeComponent and Network implementations should return the recursive aggregate across all children.

Implements Mila::Dnn::Component< TDeviceType, TComputePrecision >.

◆ prefill()

template<DeviceType TDeviceType, TensorDataType TComputePrecision, KvCachePolicy TKvPolicy = NoKvCompression>
TensorType & Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >::prefill ( const TensorType & q,
const TensorType & k,
const TensorType & v,
int position_offset )
inline export

Chunked prefill pass with explicit position offset.

Called by the transformer block during chunked prefill. The KV cache must already be initialized (via onBuilding or forward()).

Parameters
q, k, v: Query, key and value tensors for this chunk of T_chunk tokens.
position_offset: Absolute position of the first token in this chunk.
Returns
Reference to component-owned output tensor.
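A long prompt can be planned into chunks whose position_offset is the absolute index of each chunk's first token. The sketch is hypothetical (PrefillChunk and planPrefillChunks are not Mila names) and assumes the KV cache has already been initialized, as the precondition above requires.

```cpp
#include <cassert>
#include <vector>

// One chunk of a long prompt: tokens [position_offset, position_offset + length).
struct PrefillChunk {
    int position_offset;  // absolute position of the chunk's first token
    int length;           // T_chunk
};

// Split a prompt of `prompt_len` tokens into chunks of at most `chunk_len`.
// Each chunk would then be passed to prefill(q, k, v, chunk.position_offset).
inline std::vector<PrefillChunk> planPrefillChunks(int prompt_len, int chunk_len) {
    assert(prompt_len >= 0 && chunk_len > 0);
    std::vector<PrefillChunk> chunks;
    for (int start = 0; start < prompt_len; start += chunk_len) {
        // The final chunk may be shorter than chunk_len.
        const int len = (prompt_len - start < chunk_len) ? prompt_len - start : chunk_len;
        chunks.push_back(PrefillChunk{ start, len });
    }
    return chunks;
}
```

A 10-token prompt with chunk_len = 4 yields chunks at offsets 0, 4 and 8 with lengths 4, 4 and 2.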

◆ save_()

template<DeviceType TDeviceType, TensorDataType TComputePrecision, KvCachePolicy TKvPolicy = NoKvCompression>
void Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >::save_ ( ModelArchive & archive,
SerializationMode mode ) const
inline override export virtual

◆ setState()

template<DeviceType TDeviceType, TensorDataType TComputePrecision, KvCachePolicy TKvPolicy = NoKvCompression>
void Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >::setState ( const GqaState & state)
inline export

Forward the shared transient workspace to the underlying operation.

Must be called after build() and before prefill() or decode(). A no-op on backends that do not implement setState (e.g. CPU stub).

Parameters
state: Non-owning pointers to the shared GQA scratch tensors.

◆ supportsKVCache()

template<DeviceType TDeviceType, TensorDataType TComputePrecision, KvCachePolicy TKvPolicy = NoKvCompression>
bool Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >::supportsKVCache ( ) const
inline export noexcept

Returns true when the underlying operation implements both IPositionalUnaryOp and IKVCacheLifecycle.

Resolved once at build time. CPU backends return false; CUDA backends return true when CudaGroupedQueryAttentionOp is in use.

◆ synchronize()

template<DeviceType TDeviceType, TensorDataType TComputePrecision, KvCachePolicy TKvPolicy = NoKvCompression>
void Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >::synchronize ( )
inline override export virtual

Wait for outstanding device work submitted by this component.

On CPU this may be a no-op. Call it to ensure results are visible to the host or to obtain synchronous timings.

Implements Mila::Dnn::Component< TDeviceType, TComputePrecision >.

◆ toString()

template<DeviceType TDeviceType, TensorDataType TComputePrecision, KvCachePolicy TKvPolicy = NoKvCompression>
std::string Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >::toString ( ) const
inline override export virtual

Produce a short, human-readable description of the component.

Implementations should keep output concise and avoid throwing.

Implements Mila::Dnn::Component< TDeviceType, TComputePrecision >.

◆ validateConcatenatedQKVShape()

template<DeviceType TDeviceType, TensorDataType TComputePrecision, KvCachePolicy TKvPolicy = NoKvCompression>
void Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >::validateConcatenatedQKVShape ( const shape_t & shape) const
inline export private

Validate that the input tensor has the expected GQA-packed QKV shape.

Expected trailing dimension: (num_heads + 2 * num_kv_heads) * head_dim
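A minimal sketch of this check, using std::vector<std::int64_t> as a stand-in for shape_t. The rank-3 requirement and the use of std::invalid_argument are assumptions; the real validateConcatenatedQKVShape may inspect only the trailing dimension and may report errors differently. validatePackedQkvShape is a hypothetical name.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

// Stand-in for Mila's shape_t (a TensorShape in the real library).
using Shape = std::vector<std::int64_t>;

// Rejects anything that is not [B, T, (num_heads + 2 * num_kv_heads) * head_dim].
inline void validatePackedQkvShape(const Shape& shape, std::int64_t num_heads,
                                   std::int64_t num_kv_heads, std::int64_t head_dim) {
    assert(num_heads > 0 && num_kv_heads > 0 && head_dim > 0);
    if (shape.size() != 3) {
        throw std::invalid_argument("expected rank-3 [B, T, F] input");
    }
    const std::int64_t expected = (num_heads + 2 * num_kv_heads) * head_dim;
    if (shape[2] != expected) {
        throw std::invalid_argument("packed QKV width " + std::to_string(shape[2]) +
                                    " != expected " + std::to_string(expected));
    }
}
```

For num_heads = 32, num_kv_heads = 8 and head_dim = 128, a [2, 16, 6144] input passes, while [2, 16, 4096] (Q only) is rejected.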

