Mila 0.13.48
Deep Neural Network Library
Capability interface for position-dependent paired operations.

Public Member Functions
| virtual | ~IPositionalPairedOp ()=default |
| virtual void | decode (const ITensor &inputA, const ITensor &inputB, ITensor &outputA, ITensor &outputB, int position)=0 |
| | Process a single token at an explicit sequence position. |
| virtual void | prefill (const ITensor &inputA, const ITensor &inputB, ITensor &outputA, ITensor &outputB, int position_offset)=0 |
| | Process a chunk of tokens starting at a given position. |
Capability interface for position-dependent paired operations.
Implemented by operations that accept two input/output tensor pairs and whose mathematical output changes based on absolute token position — e.g. RoPE, which rotates both Q and K embeddings by position-dependent angles.
Parameter order follows the PairedOperation convention: all inputs first, then all outputs, followed by the position argument.
Operations that are position-agnostic do not implement this interface — they use forward() for all modes.
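The contract described above can be sketched as follows. This is a minimal illustration, not Mila's actual source: the `ITensor` struct here is a hypothetical stand-in (a flat `[tokens, dim]` float buffer with the batch dimension omitted), and `AddPositionOp` is an invented toy operation whose only purpose is to make the position dependence visible. Only the two method signatures are taken from the documentation.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

namespace sketch {

// Hypothetical stand-in for Mila's ITensor: row-major [tokens, dim] floats.
struct ITensor {
    std::size_t tokens;
    std::size_t dim;
    std::vector<float> data;
    ITensor(std::size_t t, std::size_t d) : tokens(t), dim(d), data(t * d, 0.0f) {}
};

// Same parameter order as documented: inputs, then outputs, then position.
class IPositionalPairedOp {
public:
    virtual ~IPositionalPairedOp() = default;
    virtual void decode(const ITensor& inputA, const ITensor& inputB,
                        ITensor& outputA, ITensor& outputB, int position) = 0;
    virtual void prefill(const ITensor& inputA, const ITensor& inputB,
                         ITensor& outputA, ITensor& outputB, int position_offset) = 0;
};

// Toy operation (an assumption for illustration): adds each token's absolute
// position to its values, so the result changes with position.
class AddPositionOp final : public IPositionalPairedOp {
public:
    void decode(const ITensor& inputA, const ITensor& inputB,
                ITensor& outputA, ITensor& outputB, int position) override {
        assert(inputA.tokens == 1 && inputB.tokens == 1);
        apply(inputA, outputA, position);
        apply(inputB, outputB, position);
    }

    void prefill(const ITensor& inputA, const ITensor& inputB,
                 ITensor& outputA, ITensor& outputB, int position_offset) override {
        apply(inputA, outputA, position_offset);
        apply(inputB, outputB, position_offset);
    }

private:
    // Token t of the chunk sits at absolute position first_position + t.
    static void apply(const ITensor& in, ITensor& out, int first_position) {
        assert(in.tokens == out.tokens && in.dim == out.dim);
        for (std::size_t t = 0; t < in.tokens; ++t) {
            const float p = static_cast<float>(first_position + static_cast<int>(t));
            for (std::size_t d = 0; d < in.dim; ++d) {
                out.data[t * in.dim + d] = in.data[t * in.dim + d] + p;
            }
        }
    }
};

}  // namespace sketch
```

In this sketch a caller would hold an `IPositionalPairedOp&` and choose `prefill` for a multi-token chunk or `decode` for a single generated token; a position-agnostic operation would simply not derive from the interface.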
virtual ~IPositionalPairedOp ()=default
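virtual void decode (const ITensor &inputA, const ITensor &inputB, ITensor &outputA, ITensor &outputB, int position)=0 [pure virtual]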
Process a single token at an explicit sequence position.
| inputA | Single-token first input [B, 1, ...]. |
| inputB | Single-token second input [B, 1, ...]. |
| outputA | Single-token first output [B, 1, ...]. |
| outputB | Single-token second output [B, 1, ...]. |
| position | Zero-based absolute sequence position. |
Implemented in Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TComputePrecision >, Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TensorDataType::BF16 >, and Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TensorDataType::FP32 >.
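To illustrate what a single-token `decode` might compute for RoPE, here is a CPU sketch on plain `std::vector<float>` buffers. It is not the `CudaRopeOp` kernel. The function names, the interleaved pairing of elements `(2i, 2i+1)`, and the frequency base of 10000 are all assumptions made for the example; the source does not state which convention Mila uses.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper: rotate each interleaved pair (x[2i], x[2i+1]) by the
// angle position * base^(-2i/dim). The pairing layout is an assumption.
inline void rope_rotate_pairs(const std::vector<float>& in, std::vector<float>& out,
                              int position, float base) {
    const std::size_t dim = in.size();
    assert(dim > 0 && dim % 2 == 0);
    out.resize(dim);
    for (std::size_t i = 0; i < dim / 2; ++i) {
        const float theta =
            std::pow(base, -2.0f * static_cast<float>(i) / static_cast<float>(dim));
        const float angle = static_cast<float>(position) * theta;
        const float c = std::cos(angle);
        const float s = std::sin(angle);
        out[2 * i]     = in[2 * i] * c - in[2 * i + 1] * s;
        out[2 * i + 1] = in[2 * i] * s + in[2 * i + 1] * c;
    }
}

// Single-token decode sketch: Q and K for one token are rotated by the same
// position-dependent angles. Argument order follows the documented convention
// (inputs, outputs, position); the base parameter is an added assumption.
inline void rope_decode_sketch(const std::vector<float>& q, const std::vector<float>& k,
                               std::vector<float>& q_out, std::vector<float>& k_out,
                               int position, float base = 10000.0f) {
    assert(q.size() == k.size());
    rope_rotate_pairs(q, q_out, position, base);
    rope_rotate_pairs(k, k_out, position, base);
}
```

Because `position` is passed explicitly, the caller rather than the operation tracks how many tokens have already been processed, which is what a token-by-token generation loop needs.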
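virtual void prefill (const ITensor &inputA, const ITensor &inputB, ITensor &outputA, ITensor &outputB, int position_offset)=0 [pure virtual]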
Process a chunk of tokens starting at a given position.
| inputA | First input tensor (e.g. Q) for this chunk. |
| inputB | Second input tensor (e.g. K) for this chunk. |
| outputA | First output tensor for this chunk. |
| outputB | Second output tensor for this chunk. |
| position_offset | Absolute position of the first token in this chunk. |
Implemented in Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TComputePrecision >, Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TensorDataType::BF16 >, and Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TensorDataType::FP32 >.
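The role of `position_offset` can be sketched with a CPU loop in which token `t` of the chunk is treated as sitting at absolute position `position_offset + t`. As before, this is an illustration under stated assumptions rather than the library's implementation: the `[tokens, dim]` row-major layout without a batch dimension, the interleaved pairing, the base of 10000, and both function names are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper: rotate one token's interleaved pairs by angles that
// depend on its absolute position.
inline void rotate_token(const float* in, float* out, std::size_t dim,
                         int position, float base) {
    for (std::size_t i = 0; i < dim / 2; ++i) {
        const float theta =
            std::pow(base, -2.0f * static_cast<float>(i) / static_cast<float>(dim));
        const float angle = static_cast<float>(position) * theta;
        const float c = std::cos(angle);
        const float s = std::sin(angle);
        out[2 * i]     = in[2 * i] * c - in[2 * i + 1] * s;
        out[2 * i + 1] = in[2 * i] * s + in[2 * i + 1] * c;
    }
}

// Chunked prefill sketch over [tokens, dim] buffers (layout assumed).
inline void rope_prefill_sketch(const std::vector<float>& q, const std::vector<float>& k,
                                std::vector<float>& q_out, std::vector<float>& k_out,
                                std::size_t dim, int position_offset,
                                float base = 10000.0f) {
    assert(dim > 0 && dim % 2 == 0);
    assert(q.size() % dim == 0 && k.size() == q.size());
    q_out.resize(q.size());
    k_out.resize(k.size());
    const std::size_t tokens = q.size() / dim;
    for (std::size_t t = 0; t < tokens; ++t) {
        // Absolute position of token t within the full sequence.
        const int position = position_offset + static_cast<int>(t);
        rotate_token(q.data() + t * dim, q_out.data() + t * dim, dim, position, base);
        rotate_token(k.data() + t * dim, k_out.data() + t * dim, dim, position, base);
    }
}
```

Under these assumptions, processing a sequence in one call with offset 0 gives the same values as processing it in two chunks, provided the second call passes the length of the first chunk as its `position_offset`.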