Mila 0.13.48
Deep Neural Network Library
Loading...
Searching...
No Matches
Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > Member List

This is the complete list of members for Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >, including all inherited members.

backward(const TensorType &input, const TensorType &output_grad) | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | inline
build(const BuildContext &context) final | Mila::Dnn::Component< TDeviceType, TComputePrecision > | inline virtual
build_context_ | Mila::Dnn::Component< TDeviceType, TComputePrecision > | protected
built_ | Mila::Dnn::Component< TDeviceType, TComputePrecision > | private
cache_initialized_ | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | private
Component(const std::string &name) | Mila::Dnn::Component< TDeviceType, TComputePrecision > | inline explicit
ComponentBase typedef | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >
config_ | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | private
context_ | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | private
createOperation() | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | inline private
decode(const TensorType &q, const TensorType &k, const TensorType &v, int position_offset) | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | inline
decode_active_ | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | private
decode_output_ | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | private
ensureBuilt(const char *method) const | Mila::Dnn::Component< TDeviceType, TComputePrecision > | inline private
exec_context_ | Mila::Dnn::Component< TDeviceType, TComputePrecision > | private
forward(const TensorType &input) | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | inline
getConfig() const noexcept | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | inline
getDeviceId() const override | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | inline virtual
getDeviceType() | Mila::Dnn::Component< TDeviceType, TComputePrecision > | inline static
getExecutionContext() const | Mila::Dnn::Component< TDeviceType, TComputePrecision > | inline protected
getGradients() const override | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | inline virtual
getMemoryStats() const override | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | inline virtual
getModelDim() const noexcept | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | inline
getName() const | Mila::Dnn::Component< TDeviceType, TComputePrecision > | inline
getNumHeads() const noexcept | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | inline
getNumKvHeads() const noexcept | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | inline
getParameterNames() const | Mila::Dnn::Component< TDeviceType, TComputePrecision > | inline virtual
getParameters() const override | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | inline virtual
getPrecision() noexcept | Mila::Dnn::Component< TDeviceType, TComputePrecision > | inline static
getRuntimeMode() const noexcept | Mila::Dnn::Component< TDeviceType, TComputePrecision > | inline
getTrainingMode() const noexcept | Mila::Dnn::Component< TDeviceType, TComputePrecision > | inline
getType() const override | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | inline virtual
GroupedQueryAttention(const std::string &name, const GqaConfig &config, std::optional< DeviceId > device_id=std::nullopt) | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | inline explicit
hasExecutionContext() const noexcept | Mila::Dnn::Component< TDeviceType, TComputePrecision > | inline protected
input_grad_ | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | private
isBuilt() const final | Mila::Dnn::Component< TDeviceType, TComputePrecision > | inline virtual
isIdentifier(const std::string &s) noexcept | Mila::Dnn::Component< TDeviceType, TComputePrecision > | inline private static
isInferenceMode() const noexcept | Mila::Dnn::Component< TDeviceType, TComputePrecision > | inline
isTrainingMode() const noexcept | Mila::Dnn::Component< TDeviceType, TComputePrecision > | inline
kCacheDtype | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | static
kKvCompressed | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | static
kv_cache_op_ | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | private
KvCacheTensorType typedef | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >
loadParameter(const std::string &name, const Serialization::ITensorBlob &blob) | Mila::Dnn::Component< TDeviceType, TComputePrecision > | inline virtual
loadParameterFromBlob(const std::string &param_name, const Serialization::ITensorBlob &blob, Tensor< TParameterPrecision, TMemoryResource > &target, const shape_t &expected_shape) | Mila::Dnn::Component< TDeviceType, TComputePrecision > | inline protected
max_input_shape_ | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | private
MR typedef | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >
name_ | Mila::Dnn::Component< TDeviceType, TComputePrecision > | private
onBuilding(const BuildContext &context) override | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | inline protected virtual
onExecutionContextSet() override | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | inline protected virtual
onTrainingModeChanging(TrainingMode training_mode) override | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | inline protected virtual
operation_ | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | private
OpType typedef | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >
output_ | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | private
output_view_ | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | private
parameterCount() const override | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | inline virtual
positional_op_ | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | private
prefill(const TensorType &q, const TensorType &k, const TensorType &v, int position_offset) | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | inline
save_(ModelArchive &archive, SerializationMode mode) const override | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | inline virtual
setExecutionContext(IExecutionContext *context) | Mila::Dnn::Component< TDeviceType, TComputePrecision > | inline protected
setState(const GqaState &state) | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | inline
setTrainingMode(TrainingMode mode) | Mila::Dnn::Component< TDeviceType, TComputePrecision > | inline
supportsKVCache() const noexcept | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | inline
synchronize() override | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | inline virtual
TensorType typedef | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >
toString() const override | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | inline virtual
training_mode_ | Mila::Dnn::Component< TDeviceType, TComputePrecision > | private
training_mode_mutex_ | Mila::Dnn::Component< TDeviceType, TComputePrecision > | private
validateConcatenatedQKVShape(const shape_t &shape) const | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy > | inline private
validateName(const std::string &name) | Mila::Dnn::Component< TDeviceType, TComputePrecision > | inline private static
zeroGradients() | Mila::Dnn::Component< TDeviceType, TComputePrecision > | inline virtual
~Component()=default | Mila::Dnn::Component< TDeviceType, TComputePrecision > | virtual
~GroupedQueryAttention() override=default | Mila::Dnn::GroupedQueryAttention< TDeviceType, TComputePrecision, TKvPolicy >