Mila 0.13.48
Deep Neural Network Library
Mila::Dnn::Compute::PairedOperation< TDeviceType, TPrecision, TInputA, TInputB > Class Template Reference (abstract, export)

Abstract base for paired operations: two inputs -> two outputs.

Inheritance diagram for Mila::Dnn::Compute::PairedOperation< TDeviceType, TPrecision, TInputA, TInputB >:
Collaboration diagram for Mila::Dnn::Compute::PairedOperation< TDeviceType, TPrecision, TInputA, TInputB >:

Public Types

using MR = CpuMemoryResource
using TensorInputAType = Tensor<TInputA, MR>
using TensorInputBType = Tensor<TInputB, MR>
using TensorOutputType = Tensor<TPrecision, MR>
Public Types inherited from Mila::Dnn::Compute::Operation< TDeviceType, TPrecision >
using DataTypeTraits

Public Member Functions

virtual ~PairedOperation ()=default
virtual void backward (const ITensor &grad_out_a, const ITensor &grad_out_b, ITensor &grad_in_a, ITensor &grad_in_b) const =0
 Backward pass: propagate upstream gradients to input gradients.
virtual void forward (const ITensor &in_a, const ITensor &in_b, ITensor &out_a, ITensor &out_b) const =0
 Forward pass: (out_a, out_b) = f(in_a, in_b).
Public Member Functions inherited from Mila::Dnn::Compute::Operation< TDeviceType, TPrecision >
virtual ~Operation ()=default
virtual void build (const BuildContext &build_context)
 Prepare the operation for a concrete input shape.
virtual void clearGradients () noexcept
 Clear any cached gradient pointers held by the operation.
virtual TensorDataType getDataType () const
 Tensor data type for this operation.
virtual DeviceType getDeviceType () const
 Device type for this operation.
virtual std::string getName () const=0
 Human-readable operation name.
virtual OperationType getOperationType () const=0
 Operation type identifier.
virtual std::size_t getStateMemorySize () const
 Returns the number of bytes of state memory allocated by this operation.
virtual bool isBuilt () const
 Whether build() completed successfully for a concrete input shape.
virtual bool isEvalMode () const
 Query whether the operation is configured for evaluation (as opposed to training).
virtual void setGradients (ITensor *weight_grad, ITensor *bias_grad)
 Bind module-owned gradient tensors to the operation.
virtual void setParameters (ITensor *weight, ITensor *bias)
 Bind module-owned parameter tensors to the operation.
virtual void setTrainingMode (TrainingMode training_mode)
 Configure operation training-mode behavior.

Static Protected Member Functions

static const TensorInputAType & asInputA (const ITensor &t)
static const TensorInputBType & asInputB (const ITensor &t)
static TensorOutputType & asOutputTensor (ITensor &t)

Additional Inherited Members

Static Public Attributes inherited from Mila::Dnn::Compute::Operation< TDeviceType, TPrecision >
static constexpr TensorDataType data_type
static constexpr DeviceType device_type
Protected Attributes inherited from Mila::Dnn::Compute::Operation< TDeviceType, TPrecision >
bool is_built_
TrainingMode training_mode_

Detailed Description

template<DeviceType TDeviceType, TensorDataType TPrecision, TensorDataType TInputA = TPrecision, TensorDataType TInputB = TInputA>
requires PrecisionSupportedOnDevice<TPrecision, TDeviceType>
class Mila::Dnn::Compute::PairedOperation< TDeviceType, TPrecision, TInputA, TInputB >

Abstract base for paired operations: two inputs -> two outputs.

Models symmetric transforms where two tensors are processed jointly and each produces an independent output (e.g. Q and K in RoPE).

The backward signature is symmetric with forward: 2-in / 2-out. Implementations that require saved forward activations for their backward pass must cache them internally during forward().
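
The caching requirement above can be illustrated with a minimal, self-contained sketch. It does not use Mila's types: `Vec`, `PairedOpSketch`, and `SquarePairSketch` are hypothetical stand-ins, and the squaring operation is chosen only because its backward pass needs the forward inputs. Because forward() is declared const in the documented interface, the sketch assumes the cache is held in `mutable` members.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a tensor; Mila's ITensor/Tensor types are not used here.
using Vec = std::vector<float>;

// Hypothetical interface mirroring the 2-in / 2-out shape of forward() and backward().
struct PairedOpSketch {
    virtual ~PairedOpSketch() = default;
    virtual void forward(const Vec& in_a, const Vec& in_b, Vec& out_a, Vec& out_b) const = 0;
    virtual void backward(const Vec& grad_out_a, const Vec& grad_out_b,
                          Vec& grad_in_a, Vec& grad_in_b) const = 0;
};

// Toy op: each input is squared independently (out = in * in).
// Its backward pass needs the forward inputs, so forward() caches them.
class SquarePairSketch final : public PairedOpSketch {
public:
    void forward(const Vec& in_a, const Vec& in_b, Vec& out_a, Vec& out_b) const override {
        saved_a_ = in_a;  // cache activations for backward()
        saved_b_ = in_b;
        square(in_a, out_a);
        square(in_b, out_b);
    }

    void backward(const Vec& grad_out_a, const Vec& grad_out_b,
                  Vec& grad_in_a, Vec& grad_in_b) const override {
        chain(saved_a_, grad_out_a, grad_in_a);
        chain(saved_b_, grad_out_b, grad_in_b);
    }

private:
    static void square(const Vec& in, Vec& out) {
        out.resize(in.size());
        for (std::size_t i = 0; i < in.size(); ++i) out[i] = in[i] * in[i];
    }

    // d(in^2)/d(in) = 2 * in, so grad_in = 2 * in * grad_out.
    static void chain(const Vec& saved, const Vec& grad_out, Vec& grad_in) {
        assert(saved.size() == grad_out.size());
        grad_in.resize(saved.size());
        for (std::size_t i = 0; i < saved.size(); ++i) {
            grad_in[i] = 2.0f * saved[i] * grad_out[i];
        }
    }

    // mutable because forward() is const in this interface.
    mutable Vec saved_a_;
    mutable Vec saved_b_;
};
```

In this sketch backward() is only meaningful after a matching forward() call, since it reads the cached inputs. Operations such as a pure rotation may not need a cache at all, in which case the `mutable` members can be dropped.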

Template Parameters
TDeviceType  Device target (DeviceType::Cpu, DeviceType::Cuda, ...)
TPrecision   Canonical element precision for inputs and outputs (e.g. FP32)
TInputA      Element type for the first input tensor (defaults to TPrecision)
TInputB      Element type for the second input tensor (defaults to TInputA)
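
The default chain of the template parameters (TInputB falls back to TInputA, which falls back to TPrecision) can be checked at compile time with a small sketch. The enums `DeviceKind` and `DataKind` and the struct `PairedParamsSketch` are hypothetical stand-ins, not Mila's DeviceType, TensorDataType, or PairedOperation; only the defaulting pattern is taken from the declaration above.

```cpp
// Hypothetical stand-ins; the real enumerators in Mila may differ.
enum class DeviceKind { Cpu, Cuda };
enum class DataKind { FP32, BF16, INT32 };

// Same default chain as the documented template:
// TInputA defaults to TPrecision, TInputB defaults to TInputA.
template <DeviceKind TDevice, DataKind TPrecision,
          DataKind TInputA = TPrecision, DataKind TInputB = TInputA>
struct PairedParamsSketch {
    static constexpr DataKind precision = TPrecision;
    static constexpr DataKind input_a = TInputA;
    static constexpr DataKind input_b = TInputB;
};

// Only the precision is given: both inputs inherit it.
using AllFp32 = PairedParamsSketch<DeviceKind::Cuda, DataKind::FP32>;
static_assert(AllFp32::input_a == DataKind::FP32, "TInputA defaults to TPrecision");
static_assert(AllFp32::input_b == DataKind::FP32, "TInputB defaults to TInputA");

// TInputA is overridden: TInputB follows TInputA, not TPrecision.
using Bf16Inputs = PairedParamsSketch<DeviceKind::Cuda, DataKind::FP32, DataKind::BF16>;
static_assert(Bf16Inputs::input_a == DataKind::BF16, "explicit TInputA is kept");
static_assert(Bf16Inputs::input_b == DataKind::BF16, "TInputB follows TInputA");
static_assert(Bf16Inputs::precision == DataKind::FP32, "precision is unchanged");
```

The second alias shows the practical consequence: overriding only TInputA changes both input element types, so a differing second input has to be spelled out explicitly.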

Member Typedef Documentation

◆ MR

template<DeviceType TDeviceType, TensorDataType TPrecision, TensorDataType TInputA = TPrecision, TensorDataType TInputB = TInputA>
using Mila::Dnn::Compute::PairedOperation< TDeviceType, TPrecision, TInputA, TInputB >::MR = CpuMemoryResource

◆ TensorInputAType

template<DeviceType TDeviceType, TensorDataType TPrecision, TensorDataType TInputA = TPrecision, TensorDataType TInputB = TInputA>
using Mila::Dnn::Compute::PairedOperation< TDeviceType, TPrecision, TInputA, TInputB >::TensorInputAType = Tensor<TInputA, MR>

◆ TensorInputBType

template<DeviceType TDeviceType, TensorDataType TPrecision, TensorDataType TInputA = TPrecision, TensorDataType TInputB = TInputA>
using Mila::Dnn::Compute::PairedOperation< TDeviceType, TPrecision, TInputA, TInputB >::TensorInputBType = Tensor<TInputB, MR>

◆ TensorOutputType

template<DeviceType TDeviceType, TensorDataType TPrecision, TensorDataType TInputA = TPrecision, TensorDataType TInputB = TInputA>
using Mila::Dnn::Compute::PairedOperation< TDeviceType, TPrecision, TInputA, TInputB >::TensorOutputType = Tensor<TPrecision, MR>

Constructor & Destructor Documentation

◆ ~PairedOperation()

template<DeviceType TDeviceType, TensorDataType TPrecision, TensorDataType TInputA = TPrecision, TensorDataType TInputB = TInputA>
virtual Mila::Dnn::Compute::PairedOperation< TDeviceType, TPrecision, TInputA, TInputB >::~PairedOperation ( )
virtual, default

Member Function Documentation

◆ asInputA()

template<DeviceType TDeviceType, TensorDataType TPrecision, TensorDataType TInputA = TPrecision, TensorDataType TInputB = TInputA>
const TensorInputAType & Mila::Dnn::Compute::PairedOperation< TDeviceType, TPrecision, TInputA, TInputB >::asInputA ( const ITensor & t)
inline, static, protected

◆ asInputB()

template<DeviceType TDeviceType, TensorDataType TPrecision, TensorDataType TInputA = TPrecision, TensorDataType TInputB = TInputA>
const TensorInputBType & Mila::Dnn::Compute::PairedOperation< TDeviceType, TPrecision, TInputA, TInputB >::asInputB ( const ITensor & t)
inline, static, protected

◆ asOutputTensor()

template<DeviceType TDeviceType, TensorDataType TPrecision, TensorDataType TInputA = TPrecision, TensorDataType TInputB = TInputA>
TensorOutputType & Mila::Dnn::Compute::PairedOperation< TDeviceType, TPrecision, TInputA, TInputB >::asOutputTensor ( ITensor & t)
inline, static, protected
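
This page does not show how asInputA(), asInputB(), and asOutputTensor() are implemented. Their signatures suggest helpers that recover a concrete tensor type from a type-erased ITensor reference. One plausible shape, offered only as an assumption, is a checked downcast. `ITensorSketch`, `TensorSketch`, `asInputSketch`, and `asOutputSketch` below are hypothetical names, and the real helpers may use a different cast or different error handling.

```cpp
#include <typeinfo>
#include <vector>

// Hypothetical type-erased base, standing in for ITensor.
struct ITensorSketch {
    virtual ~ITensorSketch() = default;
};

// Hypothetical concrete tensor, standing in for Tensor<T, MR>.
template <typename T>
struct TensorSketch : ITensorSketch {
    std::vector<T> data;
};

// Checked downcast for inputs: dynamic_cast on a reference
// throws std::bad_cast when the dynamic type does not match.
template <typename T>
const TensorSketch<T>& asInputSketch(const ITensorSketch& t) {
    return dynamic_cast<const TensorSketch<T>&>(t);
}

// Mutable variant for outputs.
template <typename T>
TensorSketch<T>& asOutputSketch(ITensorSketch& t) {
    return dynamic_cast<TensorSketch<T>&>(t);
}
```

A derived operation would call such helpers at the top of forward() and backward() to obtain typed views of its arguments before doing any element-wise work.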

◆ backward()

template<DeviceType TDeviceType, TensorDataType TPrecision, TensorDataType TInputA = TPrecision, TensorDataType TInputB = TInputA>
virtual void Mila::Dnn::Compute::PairedOperation< TDeviceType, TPrecision, TInputA, TInputB >::backward ( const ITensor & grad_out_a,
const ITensor & grad_out_b,
ITensor & grad_in_a,
ITensor & grad_in_b ) const
pure virtual

Backward pass: propagate upstream gradients to input gradients.

Parameters
grad_out_a  Upstream gradient w.r.t. out_a (dL/dout_a).
grad_out_b  Upstream gradient w.r.t. out_b (dL/dout_b).
grad_in_a   Output gradient w.r.t. in_a (dL/din_a).
grad_in_b   Output gradient w.r.t. in_b (dL/din_b).

Implemented in Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TComputePrecision >, Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TensorDataType::BF16 >, and Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TensorDataType::FP32 >.

◆ forward()

template<DeviceType TDeviceType, TensorDataType TPrecision, TensorDataType TInputA = TPrecision, TensorDataType TInputB = TInputA>
virtual void Mila::Dnn::Compute::PairedOperation< TDeviceType, TPrecision, TInputA, TInputB >::forward ( const ITensor & in_a,
const ITensor & in_b,
ITensor & out_a,
ITensor & out_b ) const
pure virtual

Forward pass: (out_a, out_b) = f(in_a, in_b).

Parameters
in_a   First input tensor.
in_b   Second input tensor.
out_a  First output tensor (same shape as in_a).
out_b  Second output tensor (same shape as in_b).

Implemented in Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TComputePrecision >, Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TensorDataType::BF16 >, and Mila::Dnn::Compute::Cuda::Rope::CudaRopeOp< TensorDataType::FP32 >.


The documentation for this class was generated from the following file: