Mila
Deep Neural Network Library
Mila::Dnn::Compute::BinaryOperation< TDeviceType, TInput1, TInput2, TOutput > Class Template Reference [abstract] [export]

Abstract class for binary operations in the neural network framework. More...

Inheritance diagram for Mila::Dnn::Compute::BinaryOperation< TDeviceType, TInput1, TInput2, TOutput >:
Collaboration diagram for Mila::Dnn::Compute::BinaryOperation< TDeviceType, TInput1, TInput2, TOutput >:

Public Types

using MR = std::conditional_t< TDeviceType==DeviceType::Cuda, CudaMemoryResource, HostMemoryResource >
 Memory resource type based on device type.
 

Public Member Functions

 BinaryOperation (OperationType operation_type)
 Constructs a BinaryOperation with the specified operation type.
 
 BinaryOperation (OperationType operation_type, std::shared_ptr< DeviceContext > context)
 Constructs a BinaryOperation with the specified operation type and device context.
 
virtual ~BinaryOperation ()=default
 Virtual destructor for proper cleanup of derived classes.
 
virtual void backward (const Tensor< TInput1, MR > &input1, const Tensor< TInput2, MR > &input2, const Tensor< TOutput, MR > &output, const Tensor< TOutput, MR > &output_gradient, const std::vector< std::shared_ptr< Tensor< TInput1, MR > > > &parameters, std::vector< std::shared_ptr< Tensor< TOutput, MR > > > &parameter_gradients, Tensor< TInput1, MR > &input1_gradient, Tensor< TInput2, MR > &input2_gradient, const OperationAttributes &attributes, const std::vector< std::shared_ptr< Tensor< TOutput, MR > > > &output_state) const
 Executes the backward pass of a binary operation.
 
virtual void forward (const Tensor< TInput1, MR > &input1, const Tensor< TInput2, MR > &input2, const std::vector< std::shared_ptr< Tensor< TInput1, MR > > > &parameters, const OperationAttributes &attributes, Tensor< TOutput, MR > &output, std::vector< std::shared_ptr< Tensor< TOutput, MR > > > &output_state) const =0
 Executes the forward pass of a binary operation.
 
- Public Member Functions inherited from Mila::Dnn::Compute::OperationBase< TDeviceType, TInput1, TInput2, TOutput >
 OperationBase (OperationType operation_type, std::shared_ptr< DeviceContext > context)
 Constructs an OperationBase object with a specific device context and compute precision.
 
virtual ~OperationBase ()=default
 Virtual destructor for the OperationBase class.
 
std::shared_ptr< DeviceContext > getDeviceContext () const
 Gets the device context associated with this operation.
 
DeviceType getDeviceType () const
 Gets the device type for this operation.
 
virtual std::string getName () const =0
 Gets the name of the operation.
 
OperationType getOperationType () const
 Gets the operation type enumeration value.
 

Detailed Description

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput1 = float, typename TInput2 = TInput1, typename TOutput = TInput1>
requires ValidTensorTypes<TInput1, TInput2> && ValidFloatTensorType<TOutput>
class Mila::Dnn::Compute::BinaryOperation< TDeviceType, TInput1, TInput2, TOutput >

Abstract class for binary operations in the neural network framework.

This class extends OperationBase to provide specialized functionality for operations that take two input tensors and produce a single output tensor. Derived classes must implement the forward() method to define the specific computation for the operation. Examples include element-wise operations (add, multiply), matrix operations (matmul), and more complex operations like convolution.

The class supports configurable compute precision policies which control how operations handle mixed precision computations. This allows for optimizing between performance and accuracy based on the specific requirements of the application.

Template Parameters
TDeviceType  The target device type for the operation; defaults to DeviceType::Cuda.
TInput1  The data type of the first input tensor elements. Must satisfy the ValidTensorType constraint.
TInput2  The data type of the second input tensor elements; defaults to TInput1. Must satisfy the ValidTensorType constraint.
TOutput  The data type of the output tensor elements; defaults to TInput1. Must satisfy the ValidFloatTensorType constraint.

Member Typedef Documentation

◆ MR

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput1 = float, typename TInput2 = TInput1, typename TOutput = TInput1>
using Mila::Dnn::Compute::BinaryOperation< TDeviceType, TInput1, TInput2, TOutput >::MR = std::conditional_t<TDeviceType == DeviceType::Cuda, CudaMemoryResource, HostMemoryResource>

Memory resource type based on device type.

This type alias automatically selects the appropriate memory resource type based on the template device type. For CUDA devices, it uses CudaMemoryResource; for CPU devices, it uses HostMemoryResource. This ensures memory allocation is performed correctly for the target device.

Constructor & Destructor Documentation

◆ BinaryOperation() [1/2]

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput1 = float, typename TInput2 = TInput1, typename TOutput = TInput1>
Mila::Dnn::Compute::BinaryOperation< TDeviceType, TInput1, TInput2, TOutput >::BinaryOperation ( OperationType  operation_type)
inline

Constructs a BinaryOperation with the specified operation type.

Creates a device context that matches the TDeviceType template parameter using the CreateCompatibleContext helper function. This constructor simplifies creation when a custom device context is not needed.

Parameters
operation_type  The type of the operation, from the OperationType enumeration.

◆ BinaryOperation() [2/2]

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput1 = float, typename TInput2 = TInput1, typename TOutput = TInput1>
Mila::Dnn::Compute::BinaryOperation< TDeviceType, TInput1, TInput2, TOutput >::BinaryOperation ( OperationType  operation_type,
std::shared_ptr< DeviceContext >  context 
)
inline

Constructs a BinaryOperation with the specified operation type and device context.

Validates that the provided context is compatible with the TDeviceType template parameter. This allows for more control over the execution environment by providing a pre-configured device context.

Parameters
operation_type  The type of the operation, from the OperationType enumeration.
context  The device context to use for this operation. Must be compatible with TDeviceType.
Exceptions
std::runtime_error  If the provided context is incompatible with TDeviceType.

◆ ~BinaryOperation()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput1 = float, typename TInput2 = TInput1, typename TOutput = TInput1>
virtual Mila::Dnn::Compute::BinaryOperation< TDeviceType, TInput1, TInput2, TOutput >::~BinaryOperation ( )
virtual default

Virtual destructor for proper cleanup of derived classes.

Ensures proper cleanup of derived class resources when destroyed through a base class pointer. Default implementation is sufficient for this base class.

Member Function Documentation

◆ backward()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput1 = float, typename TInput2 = TInput1, typename TOutput = TInput1>
virtual void Mila::Dnn::Compute::BinaryOperation< TDeviceType, TInput1, TInput2, TOutput >::backward ( const Tensor< TInput1, MR > &  input1,
const Tensor< TInput2, MR > &  input2,
const Tensor< TOutput, MR > &  output,
const Tensor< TOutput, MR > &  output_gradient,
const std::vector< std::shared_ptr< Tensor< TInput1, MR > > > &  parameters,
std::vector< std::shared_ptr< Tensor< TOutput, MR > > > &  parameter_gradients,
Tensor< TInput1, MR > &  input1_gradient,
Tensor< TInput2, MR > &  input2_gradient,
const OperationAttributes &  attributes,
const std::vector< std::shared_ptr< Tensor< TOutput, MR > > > &  output_state 
) const
inline virtual

Executes the backward pass of a binary operation.

Computes gradients with respect to both inputs and parameters by propagating the output gradient backward through the operation. Derived classes may override this method to define their specific backward computation.

The operation's precision policy may affect how the gradient computation is performed, balancing between performance and accuracy based on the policy setting.

The default implementation throws an exception indicating that the operation does not support a backward pass.

Parameters
input1  First input tensor from the forward pass.
input2  Second input tensor from the forward pass.
output  Output tensor from the forward pass.
output_gradient  Gradient of the loss with respect to the output.
parameters  Parameter tensors from the forward pass.
parameter_gradients  Output vector where parameter gradients will be stored.
input1_gradient  Output tensor where gradients for the first input will be stored.
input2_gradient  Output tensor where gradients for the second input will be stored.
attributes  Configuration settings that control the operation's behavior.
output_state  Cached tensors from the forward pass.
Exceptions
std::runtime_error  If the operation does not support a backward pass.

◆ forward()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput1 = float, typename TInput2 = TInput1, typename TOutput = TInput1>
virtual void Mila::Dnn::Compute::BinaryOperation< TDeviceType, TInput1, TInput2, TOutput >::forward ( const Tensor< TInput1, MR > &  input1,
const Tensor< TInput2, MR > &  input2,
const std::vector< std::shared_ptr< Tensor< TInput1, MR > > > &  parameters,
const OperationAttributes &  attributes,
Tensor< TOutput, MR > &  output,
std::vector< std::shared_ptr< Tensor< TOutput, MR > > > &  output_state 
) const
pure virtual

Executes the forward pass of a binary operation.

Performs the computation defined by the specific binary operation, transforming the two input tensors into an output tensor according to the operation's rules. Derived classes must implement this method to define their specific computation. The method also supports additional parameters and operation-specific attributes.

The operation's precision policy may affect how the computation is performed, balancing between performance and accuracy based on the policy setting.

Parameters
input1  The first input tensor to the operation.
input2  The second input tensor to the operation.
parameters  Optional operation-specific learnable parameters (e.g., weights, biases).
attributes  Configuration settings that control the operation's behavior.
output  Pre-allocated tensor where the operation results will be stored.
output_state  Optional cache for intermediate values needed during the backward pass.
