|
| UnaryOperation (OperationType operation_type) |
| Constructs a UnaryOperation with the specified operation type.
|
|
| UnaryOperation (OperationType operation_type, std::shared_ptr< DeviceContext > context) |
| Constructs a UnaryOperation with the specified operation type and device context.
|
|
virtual | ~UnaryOperation ()=default |
| Virtual destructor for proper cleanup of derived classes.
|
|
virtual void | backward (const Tensor< TInput, MR > &grad, const std::vector< std::shared_ptr< Tensor< TOutput, MR > > > &parameters, std::vector< std::shared_ptr< Tensor< TOutput, MR > > > &output_grads) const |
| Executes the backward pass of a unary operation.
|
|
virtual void | backward (const Tensor< TInput, MR > &input, const Tensor< TOutput, MR > &output_grad, const std::vector< std::shared_ptr< Tensor< TOutput, MR > > > &parameters, std::vector< std::shared_ptr< Tensor< TOutput, MR > > > &parameter_grads, Tensor< TInput, MR > &input_grad, const OperationAttributes &properties, const std::vector< std::shared_ptr< Tensor< TOutput, MR > > > &output_state) const |
| Executes the full backward pass of a unary operation, computing gradients with respect to both the input and the parameters.
|
|
virtual void | forward (const Tensor< TInput, MR > &input, const std::vector< std::shared_ptr< Tensor< TOutput, MR > > > &parameters, const OperationAttributes &properties, Tensor< TOutput, MR > &output, std::vector< std::shared_ptr< Tensor< TOutput, MR > > > &output_state) const =0 |
| Executes the forward pass of a unary operation.
|
|
| OperationBase (OperationType operation_type, std::shared_ptr< DeviceContext > context) |
| Constructs an OperationBase object with the specified operation type and device context.
|
|
virtual | ~OperationBase ()=default |
| Virtual destructor for the OperationBase class.
|
|
std::shared_ptr< DeviceContext > | getDeviceContext () const |
| Gets the device context associated with this operation.
|
|
DeviceType | getDeviceType () const |
| Gets the device type for this operation.
|
|
virtual std::string | getName () const =0 |
| Gets the name of the operation.
|
|
OperationType | getOperationType () const |
| Gets the operation type enumeration value.
|
|
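To make the difference between the two backward() overloads concrete, here is a hypothetical element-wise scale operation, y[i] = w * x[i], whose full backward pass fills both the input gradient and the parameter gradients. It is a simplified sketch, not the framework's code: `Tensor` is a plain `std::vector<float>`, the `OperationAttributes` and `output_state` arguments are omitted, and `ScaleOp` does not derive from the real `UnaryOperation`.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <vector>

// Hypothetical stand-ins; the documented overload also takes
// OperationAttributes and output_state, which this sketch omits.
using Tensor = std::vector<float>;
using ParamList = std::vector<std::shared_ptr<Tensor>>;

// Hypothetical element-wise scale, y[i] = w * x[i], with one scalar parameter w.
struct ScaleOp {
    void forward(const Tensor& input, const ParamList& parameters, Tensor& output) const {
        const float w = parameters.at(0)->at(0);
        output.resize(input.size());
        for (std::size_t i = 0; i < input.size(); ++i) output[i] = w * input[i];
    }

    // Full backward: fills both the input gradient and the parameter gradients.
    void backward(const Tensor& input, const Tensor& output_grad,
                  const ParamList& parameters, ParamList& parameter_grads,
                  Tensor& input_grad) const {
        if (output_grad.size() != input.size())
            throw std::invalid_argument("output_grad and input sizes differ");
        const float w = parameters.at(0)->at(0);
        float dw = 0.0f;
        input_grad.resize(input.size());
        for (std::size_t i = 0; i < input.size(); ++i) {
            input_grad[i] = w * output_grad[i];  // dL/dx[i] = w * dL/dy[i]
            dw += input[i] * output_grad[i];     // dL/dw = sum_i x[i] * dL/dy[i]
        }
        parameter_grads.assign(1, std::make_shared<Tensor>(Tensor{dw}));
    }
};
```

The simplified overload could not compute dL/dw for this operation, because it does not receive the input x; that is why the full overload exists.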
template<
DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
requires ValidTensorType<TInput> && ValidFloatTensorType<TOutput>
class Mila::Dnn::Compute::UnaryOperation< TDeviceType, TInput, TOutput >
Abstract base class for unary operations in the compute framework.
The UnaryOperation class defines the interface for operations that take a single input tensor and produce a single output tensor. Derived classes must implement the forward() method for the forward pass and may optionally override the backward() method for the backward pass.
Additional methods for shape validation and parameter initialization are provided to ensure correctness and flexibility in derived classes.
- Template Parameters
TDeviceType | The device type (e.g., CPU, CUDA) on which the operation is executed. Defaults to DeviceType::Cuda. |
TInput | The data type of the input tensor elements. Defaults to float; must satisfy ValidTensorType. |
TOutput | The data type of the output tensor elements. Defaults to TInput; must satisfy ValidFloatTensorType. |
template<
DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
Executes the backward pass of a unary operation.
This is the simplified version of the backward pass that only computes gradients for the parameters, not for the input.
Derived classes may override this method to define the backward computation. The default implementation throws std::runtime_error to signal that the operation does not support a backward pass.
- Parameters
grad | The gradient tensor from the next layer in the network. |
parameters | The parameters used during the forward pass. |
output_grads | Output vector where parameter gradients will be stored. |
- Exceptions
std::runtime_error | If the operation does not support a backward pass. |
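Because this overload receives only the upstream gradient and the parameters, it suits operations whose parameter gradients do not depend on the input. The sketch below assumes a hypothetical bias-add operation on row-major [N, C] data, where each bias gradient is the sum of the upstream gradient over the batch, next to a hypothetical operation that keeps the throwing default. `Tensor`, `ParamList`, `UnaryOperationSketch`, `BiasAddOp`, and `ArgMaxOp` are illustrative names, not part of the documented API.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-ins: a flat float buffer instead of Tensor<T, MR>.
using Tensor = std::vector<float>;
using ParamList = std::vector<std::shared_ptr<Tensor>>;

class UnaryOperationSketch {
public:
    virtual ~UnaryOperationSketch() = default;
    virtual std::string getName() const = 0;

    // Mirrors the documented default: an operation without a backward pass throws.
    virtual void backward(const Tensor& /*grad*/, const ParamList& /*parameters*/,
                          ParamList& /*output_grads*/) const {
        throw std::runtime_error(getName() + " does not support a backward pass");
    }
};

// Hypothetical bias-add, y[n][c] = x[n][c] + b[c], on row-major [N, C] data.
// Its parameter gradient needs only the upstream gradient: dL/db[c] = sum_n grad[n][c].
class BiasAddOp : public UnaryOperationSketch {
public:
    std::string getName() const override { return "BiasAddOp"; }

    void backward(const Tensor& grad, const ParamList& parameters,
                  ParamList& output_grads) const override {
        const Tensor& bias = *parameters.at(0);
        const std::size_t channels = bias.size();
        if (channels == 0 || grad.size() % channels != 0)
            throw std::invalid_argument("grad size is not a multiple of the bias size");

        auto bias_grad = std::make_shared<Tensor>(channels, 0.0f);
        for (std::size_t i = 0; i < grad.size(); ++i)
            (*bias_grad)[i % channels] += grad[i];  // accumulate over the batch dimension
        output_grads.assign(1, bias_grad);
    }
};

// Hypothetical operation that keeps the default, throwing backward().
class ArgMaxOp : public UnaryOperationSketch {
public:
    std::string getName() const override { return "ArgMaxOp"; }
};
```

A caller that drives training generically can catch `std::runtime_error` from backward() to detect operations such as `ArgMaxOp` that have no gradient.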