Mila
Deep Neural Network Library
Mila::Dnn::Compute::OperationBase< TDeviceType, TInput1, TInput2, TOutput > Class Template Reference (abstract, export)

Base class for all compute operations in the Mila neural network framework.

Public Member Functions

 OperationBase (OperationType operation_type, std::shared_ptr< DeviceContext > context)
 Constructs an OperationBase object with the given operation type and device context.
 
virtual ~OperationBase ()=default
 Virtual destructor for the OperationBase class.
 
std::shared_ptr< DeviceContext > getDeviceContext () const
 Gets the device context associated with this operation.
 
DeviceType getDeviceType () const
 Gets the device type for this operation.
 
virtual std::string getName () const =0
 Gets the name of the operation.
 
OperationType getOperationType () const
 Gets the operation type enumeration value.
 

Private Attributes

std::shared_ptr< DeviceContext > device_context_
 The device context for execution.
 
OperationType operation_type_
 The operation type identifier.
 

Detailed Description

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput1 = float, typename TInput2 = TInput1, typename TOutput = TInput1>
requires ValidTensorTypes<TInput1, TInput2> && ValidFloatTensorType<TOutput>
class Mila::Dnn::Compute::OperationBase< TDeviceType, TInput1, TInput2, TOutput >

Base class for all compute operations in the Mila neural network framework.

This abstract base class defines the common interface for all operations in the neural network computation graph, regardless of the target device type (CPU, CUDA, etc.). Concrete operations inherit from this class and implement their specialized behavior while exposing it through this consistent interface.

Template Parameters
TDeviceType The target device type for the operation; defaults to DeviceType::Cuda.
TInput1 The data type of the first input tensor elements; defaults to float. Together with TInput2, must satisfy the ValidTensorTypes constraint.
TInput2 The data type of the second input tensor elements; defaults to TInput1. Together with TInput1, must satisfy the ValidTensorTypes constraint.
TOutput The data type of the output tensor elements; defaults to TInput1. Must satisfy the ValidFloatTensorType constraint.

Constructor & Destructor Documentation

◆ OperationBase()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput1 = float, typename TInput2 = TInput1, typename TOutput = TInput1>
Mila::Dnn::Compute::OperationBase< TDeviceType, TInput1, TInput2, TOutput >::OperationBase ( OperationType operation_type, std::shared_ptr< DeviceContext > context )
inline

Constructs an OperationBase object with the given operation type and device context.

Initializes the operation with the specified operation type and device context. The compute precision is not a constructor argument; it is determined by the template parameters (TInput1, TInput2, TOutput).

Parameters
operation_type The type of the operation (from the OperationType enum).
context The device context to use for this operation. Must not be null.
Exceptions
std::invalid_argument May throw if context is null (implementation dependent).
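
The null-context behavior is documented as implementation dependent. A minimal sketch, assuming an implementation that validates the context and throws std::invalid_argument; OperationBaseSketch, DeviceContext, and constructionThrows are hypothetical, and getName() is omitted so the class can be instantiated directly.

```cpp
#include <cassert>
#include <memory>
#include <stdexcept>
#include <utility>

// Hypothetical stand-ins; the real Mila types differ.
enum class OperationType { Add };
struct DeviceContext {};

class OperationBaseSketch {
public:
    OperationBaseSketch(OperationType operation_type, std::shared_ptr<DeviceContext> context)
        : device_context_(std::move(context)), operation_type_(operation_type) {
        // Assumption: this implementation rejects a null context eagerly.
        if (!device_context_) {
            throw std::invalid_argument("OperationBase: device context must not be null");
        }
    }
    virtual ~OperationBaseSketch() = default;
    OperationType getOperationType() const { return operation_type_; }

private:
    std::shared_ptr<DeviceContext> device_context_;
    OperationType operation_type_;
};

// Returns true if constructing an operation with the given context throws.
inline bool constructionThrows(std::shared_ptr<DeviceContext> ctx) {
    try {
        OperationBaseSketch op(OperationType::Add, std::move(ctx));
        (void)op;
        return false;
    } catch (const std::invalid_argument&) {
        return true;
    }
}
```

Failing at construction keeps every later accessor free of null checks, since a successfully constructed operation always holds a context.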

◆ ~OperationBase()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput1 = float, typename TInput2 = TInput1, typename TOutput = TInput1>
virtual Mila::Dnn::Compute::OperationBase< TDeviceType, TInput1, TInput2, TOutput >::~OperationBase ( )
virtual, default

Virtual destructor for the OperationBase class.

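Ensures that derived-class resources are released correctly when an operation is destroyed through a base-class pointer. The defaulted implementation is sufficient because the base class's own members (a std::shared_ptr and an enumeration value) clean up automatically.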
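
A short sketch of why the destructor is virtual: destroying a derived operation through a base-class pointer still runs the derived destructor. OperationBaseSketch, ScratchBufferOp, and derivedDestructorRuns are hypothetical names used only for illustration.

```cpp
#include <cassert>
#include <memory>
#include <string>

// Hypothetical, simplified base: only the virtual destructor and getName matter here.
class OperationBaseSketch {
public:
    virtual ~OperationBaseSketch() = default;
    virtual std::string getName() const = 0;
};

// Hypothetical derived operation that owns a buffer and records its destruction.
class ScratchBufferOp : public OperationBaseSketch {
public:
    explicit ScratchBufferOp(bool* destroyed)
        : destroyed_(destroyed), scratch_(std::make_unique<float[]>(256)) {}
    ~ScratchBufferOp() override { *destroyed_ = true; }
    std::string getName() const override { return "ScratchBufferOp"; }

private:
    bool* destroyed_;
    std::unique_ptr<float[]> scratch_;  // released when the derived destructor runs
};

// Destroys the operation through a base-class pointer and reports
// whether the derived destructor ran.
inline bool derivedDestructorRuns() {
    bool destroyed = false;
    {
        std::unique_ptr<OperationBaseSketch> op = std::make_unique<ScratchBufferOp>(&destroyed);
    }  // op goes out of scope here
    return destroyed;
}
```

Without the virtual destructor, deleting through the base pointer would be undefined behavior and the derived buffer could leak.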

Member Function Documentation

◆ getDeviceContext()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput1 = float, typename TInput2 = TInput1, typename TOutput = TInput1>
std::shared_ptr< DeviceContext > Mila::Dnn::Compute::OperationBase< TDeviceType, TInput1, TInput2, TOutput >::getDeviceContext ( ) const
inline

Gets the device context associated with this operation.

The device context contains information about the execution environment, including the device, streams, and memory resources. This context is used for all device interactions performed by this operation.

Returns
std::shared_ptr<DeviceContext> The device context for this operation.
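
Because the context is held by std::shared_ptr, getDeviceContext() hands out shared ownership of the same context object rather than a copy. A minimal sketch with a hypothetical DeviceContext stand-in (its stream_id field is invented) and a simplified, non-abstract OperationBaseSketch.

```cpp
#include <cassert>
#include <memory>
#include <utility>

// Hypothetical stand-in for Mila's DeviceContext (device, streams, memory resources).
struct DeviceContext {
    int stream_id = 0;  // hypothetical field standing in for a stream handle
};

// Simplified base mirroring the documented accessor.
class OperationBaseSketch {
public:
    explicit OperationBaseSketch(std::shared_ptr<DeviceContext> context)
        : device_context_(std::move(context)) {}
    virtual ~OperationBaseSketch() = default;
    std::shared_ptr<DeviceContext> getDeviceContext() const { return device_context_; }

private:
    std::shared_ptr<DeviceContext> device_context_;
};

// Two operations built from one context share the same execution environment.
inline bool operationsShareContext() {
    auto ctx = std::make_shared<DeviceContext>();
    OperationBaseSketch a(ctx), b(ctx);
    return a.getDeviceContext() == b.getDeviceContext()
        && a.getDeviceContext().get() == ctx.get();
}

// Use count while the caller and one operation both hold the context.
inline long useCountWithOneOperation() {
    auto ctx = std::make_shared<DeviceContext>();
    OperationBaseSketch op(ctx);
    return ctx.use_count();  // caller's ctx + op's member
}
```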

◆ getDeviceType()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput1 = float, typename TInput2 = TInput1, typename TOutput = TInput1>
DeviceType Mila::Dnn::Compute::OperationBase< TDeviceType, TInput1, TInput2, TOutput >::getDeviceType ( ) const
inline

Gets the device type for this operation.

This convenience method retrieves the device type from the associated device context, which in turn queries its device to determine the actual hardware target.

Returns
DeviceType The type of device (CPU, CUDA, etc.) for this operation.
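
The delegation described above (operation to context to device) can be sketched as follows. ComputeDevice, DeviceContext::getDevice(), and OperationBaseSketch are hypothetical stand-ins; the real Mila class and accessor names may differ.

```cpp
#include <cassert>
#include <memory>
#include <utility>

enum class DeviceType { Cpu, Cuda };

// Hypothetical device stand-in that knows its own hardware type.
class ComputeDevice {
public:
    explicit ComputeDevice(DeviceType type) : type_(type) {}
    DeviceType getDeviceType() const { return type_; }
private:
    DeviceType type_;
};

// Hypothetical context stand-in that owns a device.
class DeviceContext {
public:
    explicit DeviceContext(std::shared_ptr<ComputeDevice> device) : device_(std::move(device)) {}
    std::shared_ptr<ComputeDevice> getDevice() const { return device_; }
private:
    std::shared_ptr<ComputeDevice> device_;
};

class OperationBaseSketch {
public:
    explicit OperationBaseSketch(std::shared_ptr<DeviceContext> context)
        : device_context_(std::move(context)) {}
    virtual ~OperationBaseSketch() = default;

    // Convenience accessor: asks the context's device for its type.
    DeviceType getDeviceType() const {
        return device_context_->getDevice()->getDeviceType();
    }

private:
    std::shared_ptr<DeviceContext> device_context_;
};
```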

◆ getName()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput1 = float, typename TInput2 = TInput1, typename TOutput = TInput1>
virtual std::string Mila::Dnn::Compute::OperationBase< TDeviceType, TInput1, TInput2, TOutput >::getName ( ) const
pure virtual

◆ getOperationType()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput1 = float, typename TInput2 = TInput1, typename TOutput = TInput1>
OperationType Mila::Dnn::Compute::OperationBase< TDeviceType, TInput1, TInput2, TOutput >::getOperationType ( ) const
inline

Gets the operation type enumeration value.

Returns the operation type that was specified during construction. This identifies the category of neural network operation being performed.

Returns
OperationType The enumeration value identifying this operation's category.
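
Since the operation type is fixed at construction, callers can group or filter operations by category without knowing their concrete classes. A minimal sketch; OperationBaseSketch, countOfType, and the enumerators Add, Gelu, and LayerNorm are hypothetical and not taken from Mila's actual OperationType enum.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

// Hypothetical enumerators; Mila's real OperationType values may differ.
enum class OperationType { Add, Gelu, LayerNorm };

class OperationBaseSketch {
public:
    explicit OperationBaseSketch(OperationType type) : operation_type_(type) {}
    virtual ~OperationBaseSketch() = default;
    OperationType getOperationType() const { return operation_type_; }
private:
    OperationType operation_type_;
};

// Counts how many operations in a list belong to a given category,
// using only the type recorded at construction.
inline std::size_t countOfType(const std::vector<std::unique_ptr<OperationBaseSketch>>& ops,
                               OperationType wanted) {
    std::size_t n = 0;
    for (const auto& op : ops) {
        if (op->getOperationType() == wanted) {
            ++n;
        }
    }
    return n;
}
```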

Member Data Documentation

◆ device_context_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput1 = float, typename TInput2 = TInput1, typename TOutput = TInput1>
std::shared_ptr<DeviceContext> Mila::Dnn::Compute::OperationBase< TDeviceType, TInput1, TInput2, TOutput >::device_context_
private

The device context for execution.

◆ operation_type_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput1 = float, typename TInput2 = TInput1, typename TOutput = TInput1>
OperationType Mila::Dnn::Compute::OperationBase< TDeviceType, TInput1, TInput2, TOutput >::operation_type_
private

The operation type identifier.


The documentation for this class was generated from the following file: