| backward(const TensorType &input, const TensorType &output_grad) | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | inline |
| build(const BuildContext &context) final | Mila::Dnn::Component< TDeviceType, TComputePrecision > | inlinevirtual |
| build_context_ | Mila::Dnn::Component< TDeviceType, TComputePrecision > | protected |
| built_ | Mila::Dnn::Component< TDeviceType, TComputePrecision > | private |
| cache_initialized_ | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | private |
| Component(const std::string &name) | Mila::Dnn::Component< TDeviceType, TComputePrecision > | inlineexplicit |
| ComponentBase typedef | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | |
| config_ | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | private |
| context_ | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | private |
| createOperation() | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | inlineprivate |
| decode(const TensorType &q, const TensorType &k, const TensorType &v, int position_offset) | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | inline |
| decode_active_ | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | private |
| decode_output_ | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | private |
| ensureBuilt(const char *method) const | Mila::Dnn::Component< TDeviceType, TComputePrecision > | inlineprivate |
| exec_context_ | Mila::Dnn::Component< TDeviceType, TComputePrecision > | private |
| forward(const TensorType &input) | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | inline |
| getConfig() const noexcept | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | inline |
| getDeviceId() const override | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | inlinevirtual |
| getDeviceType() | Mila::Dnn::Component< TDeviceType, TComputePrecision > | inlinestatic |
| getExecutionContext() const | Mila::Dnn::Component< TDeviceType, TComputePrecision > | inlineprotected |
| getGradients() const override | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | inlinevirtual |
| getMemoryStats() const override | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | inlinevirtual |
| getModelDim() const noexcept | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | inline |
| getName() const | Mila::Dnn::Component< TDeviceType, TComputePrecision > | inline |
| getNumHeads() const noexcept | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | inline |
| getNumKvHeads() const noexcept | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | inline |
| getParameterNames() const | Mila::Dnn::Component< TDeviceType, TComputePrecision > | inlinevirtual |
| getParameters() const override | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | inlinevirtual |
| getPrecision() noexcept | Mila::Dnn::Component< TDeviceType, TComputePrecision > | inlinestatic |
| getRuntimeMode() const noexcept | Mila::Dnn::Component< TDeviceType, TComputePrecision > | inline |
| getTrainingMode() const noexcept | Mila::Dnn::Component< TDeviceType, TComputePrecision > | inline |
| getType() const override | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | inlinevirtual |
| GroupedQueryAttention(const std::string &name, const GqaConfig &config, std::optional< DeviceId > device_id=std::nullopt) | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | inlineexplicit |
| hasExecutionContext() const noexcept | Mila::Dnn::Component< TDeviceType, TComputePrecision > | inlineprotected |
| input_grad_ | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | private |
| isBuilt() const final | Mila::Dnn::Component< TDeviceType, TComputePrecision > | inlinevirtual |
| isIdentifier(const std::string &s) noexcept | Mila::Dnn::Component< TDeviceType, TComputePrecision > | inlineprivatestatic |
| isInferenceMode() const noexcept | Mila::Dnn::Component< TDeviceType, TComputePrecision > | inline |
| isTrainingMode() const noexcept | Mila::Dnn::Component< TDeviceType, TComputePrecision > | inline |
| kCacheDtype | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | static |
| kKvCompressed | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | static |
| kv_cache_op_ | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | private |
| KvCacheTensorType typedef | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | |
| loadParameter(const std::string &name, const Serialization::ITensorBlob &blob) | Mila::Dnn::Component< TDeviceType, TComputePrecision > | inlinevirtual |
| loadParameterFromBlob(const std::string &param_name, const Serialization::ITensorBlob &blob, Tensor< TParameterPrecision, TMemoryResource > &target, const shape_t &expected_shape) | Mila::Dnn::Component< TDeviceType, TComputePrecision > | inlineprotected |
| max_input_shape_ | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | private |
| MR typedef | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | |
| name_ | Mila::Dnn::Component< TDeviceType, TComputePrecision > | private |
| onBuilding(const BuildContext &context) override | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | inlineprotectedvirtual |
| onExecutionContextSet() override | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | inlineprotectedvirtual |
| onTrainingModeChanging(TrainingMode training_mode) override | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | inlineprotectedvirtual |
| operation_ | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | private |
| OpType typedef | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | |
| output_ | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | private |
| output_view_ | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | private |
| parameterCount() const override | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | inlinevirtual |
| positional_op_ | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | private |
| prefill(const TensorType &q, const TensorType &k, const TensorType &v, int position_offset) | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | inline |
| save_(ModelArchive &archive, SerializationMode mode) const override | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | inlinevirtual |
| setExecutionContext(IExecutionContext *context) | Mila::Dnn::Component< TDeviceType, TComputePrecision > | inlineprotected |
| setState(const GqaState &state) | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | inline |
| setTrainingMode(TrainingMode mode) | Mila::Dnn::Component< TDeviceType, TComputePrecision > | inline |
| supportsKVCache() const noexcept | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | inline |
| synchronize() override | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | inlinevirtual |
| TensorType typedef | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | |
| toString() const override | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | inlinevirtual |
| training_mode_ | Mila::Dnn::Component< TDeviceType, TComputePrecision > | private |
| training_mode_mutex_ | Mila::Dnn::Component< TDeviceType, TComputePrecision > | private |
| validateConcatenatedQKVShape(const shape_t &shape) const | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | inlineprivate |
| validateName(const std::string &name) | Mila::Dnn::Component< TDeviceType, TComputePrecision > | inlineprivatestatic |
| zeroGradients() | Mila::Dnn::Component< TDeviceType, TComputePrecision > | inlinevirtual |
| ~Component()=default | Mila::Dnn::Component< TDeviceType, TComputePrecision > | virtual |
| ~GroupedQueryAttention() override=default | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | |