Mila 0.13.48
Deep Neural Network Library
CUDA implementation of the RoPE (rotary positional embedding) operation.


Public Types | |
| using | CacheKey = RopeCacheRegistry::CacheKey |
| using | ComputeType = typename Mila::Dnn::Compute::Cuda::TensorDataTypeMap<TComputePrecision>::device_type |
| using | ConfigType = RopeConfig |
| using | CudaExecutionContext = ExecutionContext<DeviceType::Cuda> |
| using | MR = CudaDeviceMemoryResource |
| using | TensorType = Tensor<TComputePrecision, MR> |
| Public Types inherited from Mila::Dnn::Compute::PairedOperation< DeviceType::Cuda, TComputePrecision > | |
| using | MR |
| using | TensorInputAType |
| using | TensorInputBType |
| using | TensorOutputType |
| Public Types inherited from Mila::Dnn::Compute::Operation< TDeviceType, TPrecision > | |
| using | DataTypeTraits |
Public Member Functions | |
| CudaRopeOp (const CudaRopeOp &)=delete | |
| CudaRopeOp (CudaRopeOp &&other) noexcept | |
| CudaRopeOp (IExecutionContext *context, const RopeConfig &config) | |
| ~CudaRopeOp () | |
| void | backward (const ITensor &dQ_out, const ITensor &dK_out, ITensor &dQ_in, ITensor &dK_in) const override |
| Backward pass (hot path). | |
| void | build (const BuildContext &build_context) override |
| Prepare the operation for a concrete input shape (cold path). | |
| void | decode (const ITensor &Q_in, const ITensor &K_in, ITensor &Q_out, ITensor &K_out, int position) override |
| Single-token decode with explicit position. | |
| void | forward (const ITensor &Q_in, const ITensor &K_in, ITensor &Q_out, ITensor &K_out) const override |
| Full-sequence forward pass. | |
| std::string | getName () const override |
| Human-readable operation name. | |
| OperationType | getOperationType () const override |
| Operation type identifier. | |
| std::size_t | getStateMemorySize () const override |
| Returns the number of bytes of state memory allocated by this operation. | |
| CudaRopeOp & | operator= (const CudaRopeOp &)=delete |
| CudaRopeOp & | operator= (CudaRopeOp &&other) noexcept |
| void | prefill (const ITensor &Q_in, const ITensor &K_in, ITensor &Q_out, ITensor &K_out, int position_offset) override |
| Chunked prefill with explicit position offset. | |
| Public Member Functions inherited from Mila::Dnn::Compute::PairedOperation< DeviceType::Cuda, TComputePrecision > | |
| virtual | ~PairedOperation ()=default |
| Public Member Functions inherited from Mila::Dnn::Compute::Operation< TDeviceType, TPrecision > | |
| virtual | ~Operation ()=default |
| virtual void | clearGradients () noexcept |
| Clear any cached gradient pointers held by the operation. | |
| virtual TensorDataType | getDataType () const |
| Tensor data type for this operation. | |
| virtual DeviceType | getDeviceType () const |
| Device type for this operation. | |
| virtual bool | isBuilt () const |
| Whether build() completed successfully for a concrete input shape. | |
| virtual bool | isEvalMode () const |
Query whether the operation is configured for evaluation (inference) mode. | |
| virtual void | setGradients (ITensor *weight_grad, ITensor *bias_grad) |
| Bind module-owned gradient tensors to the operation. | |
| virtual void | setParameters (ITensor *weight, ITensor *bias) |
| Bind module-owned parameter tensors to the operation. | |
| virtual void | setTrainingMode (TrainingMode training_mode) |
| Configure operation training-mode behavior. | |
| Public Member Functions inherited from Mila::Dnn::Compute::IPositionalPairedOp | |
| virtual | ~IPositionalPairedOp ()=default |
Private Member Functions | |
| void | dispatchForward (const ITensor &Q_in, const ITensor &K_in, ITensor &Q_out, ITensor &K_out, int B, int T, int position_offset) const |
| void | ensureBuilt () const |
| CacheKey | makeCacheKey () const noexcept |
| void | releaseCache () noexcept |
| void | validateRuntimeShape (int B, int T) const |
Private Attributes | |
| int | batch_size_ { 0 } |
| CacheKey | cache_key_ {} |
| RopeConfig | config_ |
| CudaExecutionContext * | context_ |
| float * | cos_cache_ { nullptr } |
| bool | owns_cache_ { false } |
| int | seq_length_ { 0 } |
| float * | sin_cache_ { nullptr } |
Additional Inherited Members | |
| Static Public Attributes inherited from Mila::Dnn::Compute::Operation< TDeviceType, TPrecision > | |
| static constexpr TensorDataType | data_type |
| static constexpr DeviceType | device_type |
| Static Protected Member Functions inherited from Mila::Dnn::Compute::PairedOperation< DeviceType::Cuda, TComputePrecision > | |
| static const TensorInputAType & | asInputA (const ITensor &t) |
| static const TensorInputBType & | asInputB (const ITensor &t) |
| static TensorOutputType & | asOutputTensor (ITensor &t) |
| Protected Attributes inherited from Mila::Dnn::Compute::Operation< TDeviceType, TPrecision > | |
| bool | is_built_ |
| TrainingMode | training_mode_ |
CUDA implementation of the RoPE (rotary positional embedding) operation.
Takes the projected Q and K tensors produced by linear layers and applies position-dependent rotations so that attention scores encode relative position implicitly through the inner product.
Design:
Input/output shapes:
  Q: [B, T, n_heads, head_dim]
  K: [B, T, n_kv_heads, head_dim]
  Q', K': same shapes as their inputs.
Decode shapes (T = 1, explicit position):
  Q: [B, 1, n_heads, head_dim]
  K: [B, 1, n_kv_heads, head_dim]
Template Parameters
| TComputePrecision | Precision of Q/K tensors (FP32 or FP16). |
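A minimal host-side sketch of the rotation and the relative-position property described above, for a single head vector. It assumes the interleaved pairing of the original RoPE formulation (dimensions 2i and 2i+1 rotate together) and a frequency base of 10000; neither is confirmed by this page, and the helper names rope_rotate and rope_score are hypothetical, not part of the Mila API.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Rotate one head vector (length head_dim) to absolute position pos.
// Assumptions: interleaved pairs (2i, 2i+1) and frequency base 10000;
// CudaRopeOp's kernel may pair dimensions differently.
std::vector<float> rope_rotate(std::vector<float> v, int pos) {
    const int d = static_cast<int>(v.size());
    assert(d % 2 == 0);
    for (int i = 0; i < d / 2; ++i) {
        const float theta = std::pow(10000.0f, -2.0f * i / d);  // per-pair frequency
        const float a = static_cast<float>(pos) * theta;
        const float c = std::cos(a), s = std::sin(a);
        const float x0 = v[2 * i], x1 = v[2 * i + 1];
        v[2 * i]     = x0 * c - x1 * s;  // 2x2 rotation by angle a
        v[2 * i + 1] = x0 * s + x1 * c;
    }
    return v;
}

// Unscaled attention logit between a query at position m and a key at position n.
float rope_score(const std::vector<float>& q, const std::vector<float>& k, int m, int n) {
    const std::vector<float> qr = rope_rotate(q, m);
    const std::vector<float> kr = rope_rotate(k, n);
    float dot = 0.0f;
    for (std::size_t j = 0; j < qr.size(); ++j) dot += qr[j] * kr[j];
    return dot;
}
```

Because each 2x2 block satisfies R(m)^T R(n) = R(n - m), the score depends only on n - m: rope_score(q, k, 3, 5) and rope_score(q, k, 10, 12) agree up to float rounding.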
| using Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TComputePrecision >::CacheKey = RopeCacheRegistry::CacheKey |
| using Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TComputePrecision >::ComputeType = typename Mila::Dnn::Compute::Cuda::TensorDataTypeMap<TComputePrecision>::device_type |
| using Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TComputePrecision >::ConfigType = RopeConfig |
| using Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TComputePrecision >::CudaExecutionContext = ExecutionContext<DeviceType::Cuda> |
| using Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TComputePrecision >::MR = CudaDeviceMemoryResource |
| using Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TComputePrecision >::TensorType = Tensor<TComputePrecision, MR> |
|
inline |
|
inline |
| Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TComputePrecision >::CudaRopeOp (const CudaRopeOp &) |
delete |
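| Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TComputePrecision >::CudaRopeOp (CudaRopeOp &&other) |
inline noexcept |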
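| void Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TComputePrecision >::backward (const ITensor &dQ_out, const ITensor &dK_out, ITensor &dQ_in, ITensor &dK_in) const |
inline override virtual |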
Backward pass (hot path).
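RoPE applies a fixed orthogonal rotation R (R^T R = I) that does not depend on the input values, so the Jacobian of the forward map is R itself and the input gradient is dX = R^T dY = R^-1 dY. The backward pass is therefore the inverse rotation: rotate the upstream gradients by -theta (negate the sin terms). RoPE has no learnable parameters, so no parameter gradients are accumulated.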
Implements Mila::Dnn::Compute::PairedOperation< DeviceType::Cuda, TComputePrecision >.
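The backward rule can be checked on the host. This sketch uses hypothetical helpers with the same assumed interleaved pairing and base 10000; it implements dX = R^T dY by negating the sin terms, then verifies the adjoint identity <g, R x> = <R^T g, x> that any correct vector-Jacobian product must satisfy.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Rotate one head vector to absolute position pos. sign = +1 applies R(pos);
// sign = -1 negates the sin terms, which applies R(pos)^T = R(pos)^-1.
// Assumptions: interleaved pairs (2i, 2i+1) and frequency base 10000.
std::vector<float> rotate_head(std::vector<float> v, int pos, float sign) {
    const int d = static_cast<int>(v.size());
    assert(d % 2 == 0);
    for (int i = 0; i < d / 2; ++i) {
        const float a = pos * std::pow(10000.0f, -2.0f * i / d);
        const float c = std::cos(a), s = sign * std::sin(a);
        const float x0 = v[2 * i], x1 = v[2 * i + 1];
        v[2 * i]     = x0 * c - x1 * s;
        v[2 * i + 1] = x0 * s + x1 * c;
    }
    return v;
}

// Hypothetical per-head forward (Y = R X) and backward (dX = R^T dY).
std::vector<float> rope_forward_head(const std::vector<float>& x, int pos) {
    return rotate_head(x, pos, +1.0f);
}
std::vector<float> rope_backward_head(const std::vector<float>& dy, int pos) {
    return rotate_head(dy, pos, -1.0f);
}

float inner(const std::vector<float>& a, const std::vector<float>& b) {
    float s = 0.0f;
    for (std::size_t j = 0; j < a.size(); ++j) s += a[j] * b[j];
    return s;
}
```

Since R is orthogonal, the same helper also shows that the backward rotation undoes the forward one: rope_backward_head(rope_forward_head(x, p), p) returns x up to rounding.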
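| void Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TComputePrecision >::build (const BuildContext &build_context) |
inline override virtual |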
Prepare the operation for a concrete input shape (cold path).
On the first call, acquires a shared cos/sin cache from RopeCacheRegistry and fills it if this is the first op with this configuration. Subsequent calls on the same instance update the runtime shape limits only; the shared cache is not re-acquired.
Parameters
| build_context | Build context carrying the Q/K input shape [B, T, ...]. |
Reimplemented from Mila::Dnn::Compute::Operation< TDeviceType, TPrecision >.
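A host-side sketch of the acquire-or-fill behavior described above. RopeTableRegistry, its key fields (max_seq_len, head_dim, base), and the [max_seq_len, head_dim / 2] table layout are assumptions for illustration; the real RopeCacheRegistry::CacheKey contents and its device allocation are not documented on this page. Holding weak_ptr entries lets the last user free the tables, which loosely resembles the releaseCache() / owns_cache_ pair but is not claimed to match it.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <map>
#include <memory>
#include <mutex>
#include <tuple>
#include <vector>

// Hypothetical cos/sin tables, row p holds cos/sin(p * inv_freq[i]).
struct RopeTables {
    int max_seq_len = 0;
    int half_dim = 0;
    std::vector<float> cos_table;
    std::vector<float> sin_table;
};

// Hypothetical stand-in for RopeCacheRegistry (host memory, not device memory).
class RopeTableRegistry {
public:
    using Key = std::tuple<int, int, float>;  // (max_seq_len, head_dim, base)

    std::shared_ptr<const RopeTables> acquire(int max_seq_len, int head_dim, float base) {
        assert(head_dim % 2 == 0);
        std::lock_guard<std::mutex> lock(mutex_);
        const Key key{max_seq_len, head_dim, base};
        if (auto existing = cache_[key].lock()) return existing;  // shared: no refill

        auto tables = std::make_shared<RopeTables>();
        tables->max_seq_len = max_seq_len;
        tables->half_dim = head_dim / 2;
        const std::size_t n = static_cast<std::size_t>(max_seq_len) * tables->half_dim;
        tables->cos_table.resize(n);
        tables->sin_table.resize(n);
        for (int p = 0; p < max_seq_len; ++p)
            for (int i = 0; i < tables->half_dim; ++i) {
                const float a = p * std::pow(base, -2.0f * i / head_dim);
                const std::size_t idx = static_cast<std::size_t>(p) * tables->half_dim + i;
                tables->cos_table[idx] = std::cos(a);
                tables->sin_table[idx] = std::sin(a);
            }
        ++fill_count_;
        cache_[key] = tables;  // weak entry: the last holder frees the tables
        return tables;
    }

    int fillCount() const { return fill_count_; }

private:
    std::mutex mutex_;
    std::map<Key, std::weak_ptr<RopeTables>> cache_;
    int fill_count_ = 0;
};
```

Two ops built with the same configuration receive the same tables and trigger a single fill; a different configuration gets its own tables.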
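| void Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TComputePrecision >::decode (const ITensor &Q_in, const ITensor &K_in, ITensor &Q_out, ITensor &K_out, int position) |
inline override virtual |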
Single-token decode with explicit position.
Reads only the cos/sin cache row at position. Used for KV-cache autoregressive generation, where T = 1.
Parameters
| Q_in | Input Q [B, 1, n_heads, head_dim]. |
| K_in | Input K [B, 1, n_kv_heads, head_dim]. |
| Q_out | Output Q [B, 1, n_heads, head_dim]. |
| K_out | Output K [B, 1, n_kv_heads, head_dim]. |
| position | Zero-based absolute sequence position. |
Implements Mila::Dnn::Compute::IPositionalPairedOp.
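The decode contract can be illustrated on the host: a T = 1 step at position p should produce exactly what the full-sequence pass produces for token p. The helpers below are hypothetical, assume interleaved pairing with base 10000, and add a bounds check against a hypothetical max_seq_len using std::out_of_range; how CudaRopeOp actually reports an out-of-range position is not documented here.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Rotate one head vector in place to absolute position pos
// (interleaved pairs, base 10000: both assumptions of this sketch).
void rotate_at(float* v, int head_dim, int pos) {
    for (int i = 0; i < head_dim / 2; ++i) {
        const float a = pos * std::pow(10000.0f, -2.0f * i / head_dim);
        const float c = std::cos(a), s = std::sin(a);
        const float x0 = v[2 * i], x1 = v[2 * i + 1];
        v[2 * i]     = x0 * c - x1 * s;
        v[2 * i + 1] = x0 * s + x1 * c;
    }
}

// Hypothetical decode step on a [B, 1, n_heads, head_dim] tensor: every batch
// row and head uses the single cache row for `position`.
void decode_step(std::vector<float>& x, int B, int n_heads, int head_dim,
                 int position, int max_seq_len) {
    if (position < 0 || position >= max_seq_len)
        throw std::out_of_range("decode position outside the cos/sin cache");
    assert(x.size() == static_cast<std::size_t>(B) * n_heads * head_dim);
    for (int b = 0; b < B; ++b)
        for (int h = 0; h < n_heads; ++h)
            rotate_at(&x[(static_cast<std::size_t>(b) * n_heads + h) * head_dim], head_dim, position);
}

// Full-sequence reference on a [B, T, n_heads, head_dim] tensor (offset 0).
void forward_full(std::vector<float>& x, int B, int T, int n_heads, int head_dim) {
    for (int b = 0; b < B; ++b)
        for (int t = 0; t < T; ++t)
            for (int h = 0; h < n_heads; ++h)
                rotate_at(&x[((static_cast<std::size_t>(b) * T + t) * n_heads + h) * head_dim],
                          head_dim, t);
}
```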
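| void Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TComputePrecision >::dispatchForward (const ITensor &Q_in, const ITensor &K_in, ITensor &Q_out, ITensor &K_out, int B, int T, int position_offset) const |
inline private |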

| void Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TComputePrecision >::ensureBuilt () const |
inline private |

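| void Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TComputePrecision >::forward (const ITensor &Q_in, const ITensor &K_in, ITensor &Q_out, ITensor &K_out) const |
inline override virtual |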
Full-sequence forward pass.
Applies RoPE to Q and K across the full sequence with position_offset = 0. Used for training forward passes.
Implements Mila::Dnn::Compute::PairedOperation< DeviceType::Cuda, TComputePrecision >.
| std::string Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TComputePrecision >::getName () const |
inline override virtual |
Human-readable operation name.
Implements Mila::Dnn::Compute::Operation< TDeviceType, TPrecision >.
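| OperationType Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TComputePrecision >::getOperationType () const |
inline override virtual |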
Operation type identifier.
Implements Mila::Dnn::Compute::Operation< TDeviceType, TPrecision >.
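| std::size_t Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TComputePrecision >::getStateMemorySize () const |
inline override virtual |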
Returns the number of bytes of state memory allocated by this operation.
State memory includes build-time buffers such as caches and scratch allocations. Parameters and gradients are owned at the component level and are not included.
Override in derived operations that allocate device or host state during build().
Reimplemented from Mila::Dnn::Compute::Operation< TDeviceType, TPrecision >.
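A rough sizing sketch, assuming cos_cache_ and sin_cache_ (both float*) each hold a [max_seq_len, head_dim / 2] table. That layout, and whether a cache shared through RopeCacheRegistry is counted against every op or only the owning one, are assumptions not stated on this page; ropeCacheBytes is a hypothetical helper.

```cpp
#include <cstddef>

// Hypothetical estimate of the cos/sin cache footprint in bytes:
// two float tables of shape [max_seq_len, head_dim / 2].
constexpr std::size_t ropeCacheBytes(std::size_t max_seq_len, std::size_t head_dim) {
    return 2 * max_seq_len * (head_dim / 2) * sizeof(float);
}

// 2 tables x 4096 rows x 64 pairs x 4 bytes = 2 MiB.
static_assert(ropeCacheBytes(4096, 128) == 2u * 1024 * 1024, "4096 x 128 -> 2 MiB");
```

The footprint grows linearly with both max_seq_len and head_dim and is independent of n_heads and batch size, since every head and batch row reads the same rows.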
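| CacheKey Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TComputePrecision >::makeCacheKey () const |
inline private noexcept |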

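| CudaRopeOp & Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TComputePrecision >::operator= (const CudaRopeOp &) |
delete |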
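| CudaRopeOp & Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TComputePrecision >::operator= (CudaRopeOp &&other) |
inline noexcept |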
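| void Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TComputePrecision >::prefill (const ITensor &Q_in, const ITensor &K_in, ITensor &Q_out, ITensor &K_out, int position_offset) |
inline override virtual |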
Chunked prefill with explicit position offset.
Applies RoPE to Q and K using absolute positions [position_offset .. position_offset + T - 1] for the cos/sin cache lookup.
Parameters
| Q_in | Input Q [B, T, n_heads, head_dim]. |
| K_in | Input K [B, T, n_kv_heads, head_dim]. |
| Q_out | Output Q [B, T, n_heads, head_dim]. |
| K_out | Output K [B, T, n_kv_heads, head_dim]. |
| position_offset | Absolute position of the first token in this chunk. |
Implements Mila::Dnn::Compute::IPositionalPairedOp.
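Chunked prefill should be equivalent to one full-sequence pass as long as each chunk passes the absolute position of its first token as position_offset. The sketch below checks that on the host for B = 1 (so chunks are contiguous slices); rope_prefill is a hypothetical helper using the same assumed interleaved pairing and base 10000.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Apply RoPE in place to a row-major [B, T, n_heads, head_dim] tensor, using
// absolute positions position_offset .. position_offset + T - 1.
// Assumptions: interleaved pairs (2i, 2i+1) and frequency base 10000.
void rope_prefill(std::vector<float>& x, int B, int T, int n_heads, int head_dim,
                  int position_offset) {
    assert(head_dim % 2 == 0);
    assert(x.size() == static_cast<std::size_t>(B) * T * n_heads * head_dim);
    for (int b = 0; b < B; ++b)
        for (int t = 0; t < T; ++t) {
            const int pos = position_offset + t;  // absolute position of token t
            for (int h = 0; h < n_heads; ++h) {
                float* v = &x[((static_cast<std::size_t>(b) * T + t) * n_heads + h) * head_dim];
                for (int i = 0; i < head_dim / 2; ++i) {
                    const float a = pos * std::pow(10000.0f, -2.0f * i / head_dim);
                    const float c = std::cos(a), s = std::sin(a);
                    const float x0 = v[2 * i], x1 = v[2 * i + 1];
                    v[2 * i]     = x0 * c - x1 * s;
                    v[2 * i + 1] = x0 * s + x1 * c;
                }
            }
        }
}
```

Splitting a 7-token prompt into chunks of 4 and 3 tokens with offsets 0 and 4 reproduces the full forward result; using offset 0 for the second chunk would instead re-rotate those tokens as if they started the sequence.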
| void Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TComputePrecision >::releaseCache () |
inline private noexcept |

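| void Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TComputePrecision >::validateRuntimeShape (int B, int T) const |
inline private |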

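| int Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TComputePrecision >::batch_size_ { 0 } |
private |
| CacheKey Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TComputePrecision >::cache_key_ {} |
private |
| RopeConfig Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TComputePrecision >::config_ |
private |
| CudaExecutionContext* Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TComputePrecision >::context_ |
private |
| float* Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TComputePrecision >::cos_cache_ { nullptr } |
private |
| bool Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TComputePrecision >::owns_cache_ { false } |
private |
| int Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TComputePrecision >::seq_length_ { 0 } |
private |
| float* Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TComputePrecision >::sin_cache_ { nullptr } |
private |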