Mila
Deep Neural Network Library
Abstract class for binary operations in the neural network framework.
Public Types

using MR = std::conditional_t< TDeviceType==DeviceType::Cuda, CudaMemoryResource, HostMemoryResource >
 Memory resource type based on device type.

Public Member Functions

BinaryOperation (OperationType operation_type)
 Constructs a BinaryOperation with the specified operation type and precision policy.

BinaryOperation (OperationType operation_type, std::shared_ptr< DeviceContext > context)
 Constructs a BinaryOperation with the specified operation type, device context, and precision policy.

virtual ~BinaryOperation ()=default
 Virtual destructor for proper cleanup of derived classes.

virtual void backward (const Tensor< TInput1, MR > &input1, const Tensor< TInput2, MR > &input2, const Tensor< TOutput, MR > &output, const Tensor< TOutput, MR > &output_gradient, const std::vector< std::shared_ptr< Tensor< TInput1, MR > > > &parameters, std::vector< std::shared_ptr< Tensor< TOutput, MR > > > &parameter_gradients, Tensor< TInput1, MR > &input1_gradient, Tensor< TInput2, MR > &input2_gradient, const OperationAttributes &attributes, const std::vector< std::shared_ptr< Tensor< TOutput, MR > > > &output_state) const
 Executes the backward pass of a binary operation.

virtual void forward (const Tensor< TInput1, MR > &input1, const Tensor< TInput2, MR > &input2, const std::vector< std::shared_ptr< Tensor< TInput1, MR > > > &parameters, const OperationAttributes &attributes, Tensor< TOutput, MR > &output, std::vector< std::shared_ptr< Tensor< TOutput, MR > > > &output_state) const =0
 Executes the forward pass of a binary operation.
Public Member Functions inherited from OperationBase

OperationBase (OperationType operation_type, std::shared_ptr< DeviceContext > context)
 Constructs an OperationBase object with a specific device context and compute precision.

virtual ~OperationBase ()=default
 Virtual destructor for the OperationBase class.

std::shared_ptr< DeviceContext > getDeviceContext () const
 Gets the device context associated with this operation.

DeviceType getDeviceType () const
 Gets the device type for this operation.

virtual std::string getName () const =0
 Gets the name of the operation.

OperationType getOperationType () const
 Gets the operation type enumeration value.
Abstract class for binary operations in the neural network framework.
This class extends OperationBase to provide specialized functionality for operations that take two input tensors and produce a single output tensor. Derived classes must implement the forward() method to define the specific computation for the operation. Examples include element-wise operations (add, multiply), matrix operations (matmul), and more complex operations like convolution.
The class supports configurable compute precision policies which control how operations handle mixed precision computations. This allows for optimizing between performance and accuracy based on the specific requirements of the application.
Template Parameters
TDeviceType: The target device type for the operation. Defaults to DeviceType::Cuda.
TInput1: The data type of the first input tensor elements. Must satisfy the ValidTensorType constraint.
TInput2: The data type of the second input tensor elements. Defaults to TInput1. Must satisfy the ValidTensorType constraint.
TOutput: The data type of the output tensor elements. Defaults to TInput1. Must satisfy the ValidFloatTensorType constraint.
using Mila::Dnn::Compute::BinaryOperation< TDeviceType, TInput1, TInput2, TOutput >::MR = std::conditional_t<TDeviceType == DeviceType::Cuda, CudaMemoryResource, HostMemoryResource> |
Memory resource type based on device type.
This type alias automatically selects the appropriate memory resource type based on the template device type. For CUDA devices, it uses CudaMemoryResource; for CPU devices, it uses HostMemoryResource. This ensures memory allocation is performed correctly for the target device.
inline
Constructs a BinaryOperation with the specified operation type and precision policy.
Creates a device context that matches the TDeviceType template parameter using the CreateCompatibleContext helper function. This constructor simplifies creation when a custom device context is not needed. The precision policy controls how the operation handles mixed precision computations.
Parameters
operation_type: The type of the operation from the OperationType enumeration.
precision_policy: The compute precision policy to use. Controls the balance between performance and accuracy for mixed precision operations. Defaults to Auto, which lets the implementation decide based on hardware.
inline
Constructs a BinaryOperation with the specified operation type, device context, and precision policy.
Validates that the provided context is compatible with the TDeviceType template parameter. This allows for more control over the execution environment by providing a pre-configured device context. The precision policy controls how the operation handles mixed precision computations.
Parameters
operation_type: The type of the operation from the OperationType enumeration.
context: The device context to use for this operation. Must be compatible with TDeviceType.
precision_policy: The compute precision policy to use. Controls the balance between performance and accuracy for mixed precision operations. Defaults to Auto, which lets the implementation decide based on hardware.

Exceptions
std::runtime_error: Thrown if the provided context is incompatible with TDeviceType.
virtual default
Virtual destructor for proper cleanup of derived classes.
Ensures proper cleanup of derived class resources when destroyed through a base class pointer. Default implementation is sufficient for this base class.
inline virtual
Executes the backward pass of a binary operation.
Computes gradients with respect to both inputs and parameters by propagating the output gradient backward through the operation. Derived classes may override this method to define their specific backward computation.
The operation's precision policy may affect how the gradient computation is performed, balancing between performance and accuracy based on the policy setting.
The default implementation throws an exception indicating that the operation does not support a backward pass.
Parameters
input1: First input tensor from the forward pass.
input2: Second input tensor from the forward pass.
output: Output tensor from the forward pass.
output_gradient: Gradient of the loss with respect to the output.
parameters: Parameter tensors from the forward pass.
parameter_gradients: Output vector where parameter gradients will be stored.
input1_gradient: Output tensor where gradients for the first input will be stored.
input2_gradient: Output tensor where gradients for the second input will be stored.
attributes: Configuration settings that control the operation's behavior.
output_state: Cached tensors from the forward pass.

Exceptions
std::runtime_error: Thrown if the operation does not support a backward pass.
pure virtual
Executes the forward pass of a binary operation.
Performs the computation defined by the specific binary operation, transforming the two input tensors into an output tensor according to the operation's rules. Derived classes must implement this method to define their specific computation. The method also supports additional parameters and operation-specific attributes.
The operation's precision policy may affect how the computation is performed, balancing between performance and accuracy based on the policy setting.
Parameters
input1: The first input tensor to the operation.
input2: The second input tensor to the operation.
parameters: Optional operation-specific learnable parameters (e.g., weights, biases).
attributes: Configuration settings that control the operation's behavior.
output: Pre-allocated tensor where the operation results will be stored.
output_state: Optional cache for intermediate values needed during the backward pass.