|
| | GroupedQueryAttention (const std::string &name, const GqaConfig &config, std::optional< DeviceId > device_id=std::nullopt) |
| | Construct a GroupedQueryAttention component.
|
| | ~GroupedQueryAttention () override=default |
| TensorType & | backward (const TensorType &input, const TensorType &output_grad) |
| | Run backward pass and return the component-owned input-gradient tensor.
|
| void | createOperation () |
| TensorType & | decode (const TensorType &q, const TensorType &k, const TensorType &v, int position_offset) |
| | Inference-only single-token decode pass.
|
| TensorType & | forward (const TensorType &input) |
| | Standard forward pass.
|
| const GqaConfig & | getConfig () const noexcept |
| DeviceId | getDeviceId () const override |
| | Get the compute device id associated with this component.
|
| std::vector< ITensor * > | getGradients () const override |
| | Return non-owning pointers to parameter gradient tensors.
|
| MemoryStats | getMemoryStats () const override |
| | Return the current memory allocation breakdown for this component.
|
| int64_t | getModelDim () const noexcept |
| int64_t | getNumHeads () const noexcept |
| int64_t | getNumKvHeads () const noexcept |
| std::vector< ITensor * > | getParameters () const override |
| | Return non-owning pointers to parameter tensors.
|
| const ComponentType | getType () const override |
| | Get the component type identifier.
|
| void | onBuilding (const BuildContext &context) override |
| | Hook invoked by build() to allocate component buffers.
|
| void | onExecutionContextSet () override |
| | Lifecycle hook called immediately after the ExecutionContext is set.
|
| void | onTrainingModeChanging (TrainingMode training_mode) override |
| | Hook called before TrainingMode transitions.
|
| size_t | parameterCount () const override |
| | Return the number of trainable parameters.
|
| TensorType & | prefill (const TensorType &q, const TensorType &k, const TensorType &v, int position_offset) |
| | Chunked prefill pass with explicit position offset.
|
| void | save_ (ModelArchive &archive, SerializationMode mode) const override |
| void | setState (const GqaState &state) |
| | Forward the shared transient workspace to the underlying operation.
|
| bool | supportsKVCache () const noexcept |
| | Return true when the underlying operation implements both IPositionalUnaryOp and IKVCacheLifecycle.
|
| void | synchronize () override |
| | Wait for outstanding device work submitted by this component.
|
| std::string | toString () const override |
| | Produce a short, human-readable description of the component.
|
| void | validateConcatenatedQKVShape (const shape_t &shape) const |
| | Validate that the input tensor has the expected GQA-packed QKV shape.
|