Mila 0.13.48
Deep Neural Network Library
Loading...
Searching...
No Matches
Dnn.Components.Gqa Module Reference

Exported Modules

module  Compute.OperationTraits
module  Compute.DeviceTypeTraits
module  Compute.IKvCacheLifecycle
module  Dnn.TensorTypes
module  Dnn.TensorDataTypeTraits
module  Serialization.ModelArchive
module  Compute.CpuMemoryResource
module  Dnn.Component
module  Compute.ExecutionContextFactory
module  Compute.IKvInference
module  Compute.MemoryResource
module  Dnn.TensorDataType
module  Compute.GqaState
module  Dnn.Quantization.KvCache.Policy
module  Dnn.Components.GqaConfig
module  Compute.DeviceId
module  Compute.Device
module  Dnn.Tensor
module  Dnn.ComponentType
module  Compute.DeviceType
module  Serialization.Mode
module  Compute.ExecutionContext
module  Dnn.ITensor

Classes

class  Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >
 Grouped-Query Attention module that accepts concatenated QKV input. More...

Typedefs

using ComponentBase = Component<TDeviceType, TComputePrecision>
using KvCacheTensorType = Tensor<kCacheDtype, MR>
using MR = typename DeviceTypeTraits<TDeviceType>::memory_resource
using OpType = typename Compute::OperationTraits<Compute::OperationType::GroupedQueryAttentionOp, TDeviceType, TComputePrecision, TKvPolicy>::type
using TensorType = Tensor<TComputePrecision, MR>

Functions

 GroupedQueryAttention (const std::string &name, const GqaConfig &config, std::optional< DeviceId > device_id=std::nullopt)
 Construct a GroupedQueryAttention component.
 ~GroupedQueryAttention () override=default
TensorType backward (const TensorType &input, const TensorType &output_grad)
 Run backward pass and return the component-owned input-gradient tensor.
void createOperation ()
TensorType decode (const TensorType &q, const TensorType &k, const TensorType &v, int position_offset)
 Inference-only single-token decode pass.
TensorType forward (const TensorType &input)
 Standard forward pass.
const GqaConfig getConfig () const noexcept
DeviceId getDeviceId () const override
 Get the compute device id associated with this component.
std::vector< ITensor * > getGradients () const override
 Return non-owning pointers to parameter gradient tensors.
MemoryStats getMemoryStats () const override
 Return the current memory allocation breakdown for this component.
int64_t getModelDim () const noexcept
int64_t getNumHeads () const noexcept
int64_t getNumKvHeads () const noexcept
std::vector< ITensor * > getParameters () const override
 Return non-owning pointers to parameter tensors.
const ComponentType getType () const override
 Get the component type identifier.
void onBuilding (const BuildContext &context) override
 Hook invoked by build() to allocate component buffers.
void onExecutionContextSet () override
 Lifecycle hook: Called immediately after ExecutionContext is set.
void onTrainingModeChanging (TrainingMode training_mode) override
 Hook called before TrainingMode transitions.
size_t parameterCount () const override
 Return number of trainable parameters.
TensorType prefill (const TensorType &q, const TensorType &k, const TensorType &v, int position_offset)
 Chunked prefill pass with explicit position offset.
void save_ (ModelArchive &archive, SerializationMode mode) const override
void setState (const GqaState &state)
 Forward the shared transient workspace to the underlying operation.
bool supportsKVCache () const noexcept
 Returns true when the underlying operation implements both IPositionalUnaryOp and IKvCacheLifecycle.
void synchronize () override
 Wait for outstanding device work submitted by this component.
std::string toString () const override
 Produce a short, human-readable description of the component.
void validateConcatenatedQKVShape (const shape_t &shape) const
 Validate that the input tensor has the expected GQA-packed QKV shape.

Variables

bool cache_initialized_ { false }
GqaConfig config_
std::unique_ptr< IExecutionContext > context_ { nullptr }
bool decode_active_ { false }
std::unique_ptr< TensorType > decode_output_ { nullptr }
std::unique_ptr< TensorType > input_grad_ { nullptr }
static constexpr TensorDataType kCacheDtype
static constexpr bool kKvCompressed = TKvPolicy::kIsActive
IKvCacheLifecycle * kv_cache_op_ { nullptr }
shape_t max_input_shape_
std::shared_ptr< OpType > operation_ { nullptr }
std::unique_ptr< TensorType > output_ { nullptr }
std::optional< TensorType > output_view_
IKvInference * positional_op_ { nullptr }

Files

file  /__w/Mila/Mila/Mila/Src/Dnn/Components/Attention/GQA/GroupedQueryAttention.ixx
 Grouped-Query Attention module (concatenated QKV input).