Mila 0.13.48
Deep Neural Network Library
Loading...
Searching...
No Matches
Dnn.Components.Rope Module Reference

Exported Modules

module  Compute.DeviceTypeTraits
module  Dnn.TensorOps
module  Compute.OperationTraits
module  Dnn.TensorTypes
module  Compute.IPositionalPairedOp
module  Dnn.TensorDataType
module  Compute.PairedOperation
module  Compute.DeviceType
module  Dnn.ITensor
module  Serialization.Mode
module  Compute.ExecutionContext
module  Serialization.ModelArchive
module  Dnn.TensorDataTypeTraits
module  Compute.ExecutionContextFactory
module  Dnn.Components.RopeConfig
module  Dnn.Component
module  Compute.CpuMemoryResource
module  Dnn.ComponentType
module  Compute.Device
module  Dnn.Tensor
module  Compute.DeviceId

Classes

class  Mila::Dnn::Rope< TDeviceType, TPrecision >
 Device-templated RoPE component. More...

Typedefs

using ComponentBase = Component<TDeviceType, TPrecision>
using MR = typename DeviceTypeTraits<TDeviceType>::memory_resource
using OpType = typename OperationTraits<OperationType::RopeOp, TDeviceType, TPrecision>::type
using TensorType = Tensor<TPrecision, MR>

Functions

 Rope (const std::string &name, const RopeConfig &config, std::optional< DeviceId > device_id=std::nullopt)
 ~Rope () override=default
std::pair< TensorType &, TensorType & > backward (TensorType &grad_Q, TensorType &grad_K)
 Backpropagate gradients through RoPE.
void createOperation ()
void decode (TensorType &Q, TensorType &K, int position)
 Single-token decode with explicit position.
void forward (TensorType &Q, TensorType &K)
 Apply rotary position embeddings to Q and K in-place.
DeviceId getDeviceId () const override
 Get the compute device id associated with this component.
std::vector< ITensor * > getGradients () const override
 Return non-owning pointers to parameter gradient tensors.
MemoryStats getMemoryStats () const override
 Return the current memory allocation breakdown for this component.
std::vector< ITensor * > getParameters () const override
 Return non-owning pointers to parameter tensors.
const ComponentType getType () const override
 Get the component type identifier.
void onBuilding (const BuildContext &build_context) override
 Hook invoked by build() to allocate component buffers.
void onExecutionContextSet () override
 Lifecycle hook: called immediately after the ExecutionContext is set.
void onTrainingModeChanging (TrainingMode training_mode) override
 Hook called before TrainingMode transitions.
size_t parameterCount () const override
 Return number of trainable parameters.
void prefill (TensorType &Q, TensorType &K, int position_offset)
 Apply rotary position embeddings with an explicit position offset.
void save_ (ModelArchive &archive, SerializationMode mode) const override
void synchronize () override
 Wait for outstanding device work submitted by this component.
std::string toString () const override
 Produce a short, human-readable description of the component.
void validateBuildContext (const BuildContext &context) const
void zeroGradients () override
 Clear all model-owned gradients for this component.

Variables

RopeConfig config_
shape_t k_shape_
std::shared_ptr< OpType > operation_ { nullptr }
std::unique_ptr< IExecutionContext > owned_exec_context_ { nullptr }
std::unique_ptr< TensorType > owned_K_grad_ { nullptr }
std::unique_ptr< TensorType > owned_Q_grad_ { nullptr }
IPositionalPairedOp * positional_op_ { nullptr }
shape_t q_shape_

Files

file  /__w/Mila/Mila/Mila/Src/Dnn/Components/Encodings/Rope/Rope.ixx
 Rotary positional embedding (RoPE) component.