Mila
Deep Neural Network Library
Mila::Dnn::Compute::OperationRegistry Class Reference
export

A registry for operations that can be created based on operation names, type information, and device type.

Classes

struct  TypeID
 Type ID structure to uniquely identify operations based on input types and device type.
 
struct  TypeIDHash
 Hash function for TypeID to use in unordered_map.
 

Public Member Functions

template<DeviceType TDeviceType, typename TInput1 , typename TInput2 , typename TOutput >
requires ValidTensorTypes<TInput1, TInput2> && ValidFloatTensorType<TOutput>
std::shared_ptr< BinaryOperation< TDeviceType, TInput1, TInput2, TOutput > > createBinaryOperation (const std::string &operation_name, std::shared_ptr< DeviceContext > context, const ComponentConfig &config) const
 Create a binary operation based on the type information, device type, and operation name.
 
template<DeviceType TDeviceType, typename TInput , typename TOutput >
std::shared_ptr< UnaryOperation< TDeviceType, TInput, TOutput > > createUnaryOperation (const std::string &operation_name, std::shared_ptr< DeviceContext > context, const ComponentConfig &config) const
 Create a unary operation based on the type information, device type, and operation name.
 
std::optional< FusedOpMeta > findFusedMatch (const std::vector< std::string > &child_types, DeviceType device_type, std::type_index precision_type, const std::string &variant="Default")
 Find a fused operation match for a sequence of module types.
 
template<DeviceType TDeviceType, typename TInput1 , typename TInput2 , typename TOutput >
requires ValidTensorTypes<TInput1, TInput2> && ValidFloatTensorType<TOutput>
void registerBinaryOperation (const std::string &operation_name, std::function< std::shared_ptr< BinaryOperation< TDeviceType, TInput1, TInput2, TOutput > >(std::shared_ptr< DeviceContext >, const ComponentConfig &)> creator)
 Register a binary operation creator for a specific device type.
 
template<typename TOutput >
void registerFusedOperation (const std::vector< OperationType > &operation_types, const std::string &fused_op_name, const std::string &variant="Default")
 Register a fused operation.
 
template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
requires ValidTensorType<TInput> && ValidFloatTensorType<TOutput>
void registerUnaryOperation (const std::string &operation_name, std::function< std::shared_ptr< UnaryOperation< TDeviceType, TInput, TOutput > >(std::shared_ptr< DeviceContext >, const ComponentConfig &)> creator)
 Register a unary operation creator for a specific device type.
 

Static Public Member Functions

static OperationRegistry & instance ()
 Get the singleton instance of the OperationRegistry.
 

Private Types

using GenericCreator = std::function< std::shared_ptr< void >(std::shared_ptr< DeviceContext >, const ComponentConfig &)>
 Type alias for a generic operation creator function.
 

Private Member Functions

 OperationRegistry ()=default
 Default constructor.
 

Private Attributes

std::vector< FusedOpMeta > fused_ops_
 List of registered fused operations.
 
std::unordered_map< TypeID, std::unordered_map< std::string, GenericCreator >, TypeIDHash > registry_
 Registry mapping type info to operations.
 

Detailed Description

A registry for operations that can be created based on operation names, type information, and device type.

This singleton class manages the registration and creation of neural network operations in the Mila framework. It provides a unified interface to access different operation implementations across various device types (CPU, CUDA) and with different data types.

The registry supports:

  • Type-safe registration and lookup of operations
  • Operation variants for specialized implementations
  • Fused operations for performance optimization
  • Automatic device-specific operation selection

Member Typedef Documentation

◆ GenericCreator

using Mila::Dnn::Compute::OperationRegistry::GenericCreator = std::function<std::shared_ptr<void>( std::shared_ptr<DeviceContext>, const ComponentConfig& )>
private

Type alias for a generic operation creator function.
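Because the alias returns std::shared_ptr<void>, typed creators can be stored side by side regardless of the operation type they produce. The following is a minimal sketch of that type-erasure idea, not the library's implementation: DeviceContext and ComponentConfig are reduced to stand-ins, and Relu, eraseCreator, restore, and roundTripName are hypothetical names introduced only for illustration.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <string>
#include <utility>

// Hypothetical stand-ins for the real Mila types, kept minimal for illustration.
struct DeviceContext {};
struct ComponentConfig { std::string name; };

// Same shape as the alias documented above.
using GenericCreator =
    std::function<std::shared_ptr<void>(std::shared_ptr<DeviceContext>, const ComponentConfig&)>;

// Hypothetical operation type used only in this sketch.
struct Relu {
    std::string name;
};

// Wrap a typed creator so it can be stored alongside creators for other types.
template <typename TOp>
GenericCreator eraseCreator(
    std::function<std::shared_ptr<TOp>(std::shared_ptr<DeviceContext>, const ComponentConfig&)> typed) {
    return [typed = std::move(typed)](std::shared_ptr<DeviceContext> ctx,
                                      const ComponentConfig& cfg) -> std::shared_ptr<void> {
        return typed(std::move(ctx), cfg);  // shared_ptr<TOp> converts to shared_ptr<void>
    };
}

// Recover the typed pointer; only safe when TOp matches the type that was erased.
template <typename TOp>
std::shared_ptr<TOp> restore(const std::shared_ptr<void>& erased) {
    return std::static_pointer_cast<TOp>(erased);
}

// Round trip: erase a Relu creator, invoke it, and cast the result back.
inline std::string roundTripName() {
    GenericCreator creator = eraseCreator<Relu>(
        [](std::shared_ptr<DeviceContext>, const ComponentConfig& cfg) {
            return std::make_shared<Relu>(Relu{cfg.name});
        });
    auto erased = creator(std::make_shared<DeviceContext>(), ComponentConfig{"relu"});
    return restore<Relu>(erased)->name;
}
```

The cast back is unchecked, which is why a registry of this kind typically keys its storage by type information: the key guarantees that the creator found under it produces the type the caller casts to.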

Constructor & Destructor Documentation

◆ OperationRegistry()

Mila::Dnn::Compute::OperationRegistry::OperationRegistry ( )
private default

Default constructor.

Member Function Documentation

◆ createBinaryOperation()

template<DeviceType TDeviceType, typename TInput1 , typename TInput2 , typename TOutput >
requires ValidTensorTypes<TInput1, TInput2> && ValidFloatTensorType<TOutput>
std::shared_ptr< BinaryOperation< TDeviceType, TInput1, TInput2, TOutput > > Mila::Dnn::Compute::OperationRegistry::createBinaryOperation ( const std::string &  operation_name,
std::shared_ptr< DeviceContext >  context,
const ComponentConfig &  config 
) const
inline

Create a binary operation based on the type information, device type, and operation name.

Template Parameters
TDeviceType  The device type for the operation.
TInput1  The first input tensor element type.
TInput2  The second input tensor element type.
TOutput  The output tensor element type.
Parameters
operation_name  The name of the operation.
context  The device context to use for the operation.
config  The component configuration for the operation.
Returns
std::shared_ptr<BinaryOperation<TDeviceType, TInput1, TInput2, TOutput>>  The created binary operation.
Exceptions
std::runtime_error  If the type combination, device type, or operation name is invalid.
std::invalid_argument  If the context is null.

◆ createUnaryOperation()

template<DeviceType TDeviceType, typename TInput , typename TOutput >
std::shared_ptr< UnaryOperation< TDeviceType, TInput, TOutput > > Mila::Dnn::Compute::OperationRegistry::createUnaryOperation ( const std::string &  operation_name,
std::shared_ptr< DeviceContext >  context,
const ComponentConfig &  config 
) const
inline

Create a unary operation based on the type information, device type, and operation name.

Template Parameters
TDeviceType  The device type for the operation.
TInput  The input tensor element type.
TOutput  The output tensor element type.
Parameters
operation_name  The name of the operation.
context  The device context to use for the operation.
config  The component configuration for the operation.
Returns
std::shared_ptr<UnaryOperation<TDeviceType, TInput, TOutput>>  The created unary operation.
Exceptions
std::runtime_error  If the type combination, device type, or operation name is invalid.
std::invalid_argument  If the context is null.
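The two documented failure modes can be illustrated with a toy registry. This is a sketch under stated assumptions rather than Mila's code: ToyUnaryRegistry, Square, and tryCreate are hypothetical, the stand-in types carry no real state, and the order in which the checks run (null context first, then name lookup) is an assumption.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>

// Hypothetical stand-ins; the real Mila types are richer than this.
enum class DeviceType { Cpu, Cuda };
struct DeviceContext {};
struct ComponentConfig {};

template <DeviceType TDeviceType, typename TInput, typename TOutput>
struct UnaryOperation {
    virtual ~UnaryOperation() = default;
    virtual TOutput apply(TInput x) const = 0;
};

// Toy registry for one fixed type combination, enough to show the error paths.
template <DeviceType TDeviceType, typename TInput, typename TOutput>
class ToyUnaryRegistry {
public:
    using Op = UnaryOperation<TDeviceType, TInput, TOutput>;
    using Creator =
        std::function<std::shared_ptr<Op>(std::shared_ptr<DeviceContext>, const ComponentConfig&)>;

    void registerUnaryOperation(const std::string& name, Creator creator) {
        creators_[name] = std::move(creator);
    }

    std::shared_ptr<Op> createUnaryOperation(const std::string& name,
                                             std::shared_ptr<DeviceContext> context,
                                             const ComponentConfig& config) const {
        if (!context) throw std::invalid_argument("context is null");
        auto it = creators_.find(name);
        if (it == creators_.end()) throw std::runtime_error("unknown operation: " + name);
        return it->second(std::move(context), config);
    }

private:
    std::unordered_map<std::string, Creator> creators_;
};

// Hypothetical operation: squares its input.
struct Square : UnaryOperation<DeviceType::Cpu, float, float> {
    float apply(float x) const override { return x * x; }
};

// Returns 0 on success, 1 for a null context, 2 for an unknown name.
inline int tryCreate(const std::string& name, std::shared_ptr<DeviceContext> ctx, float& out) {
    ToyUnaryRegistry<DeviceType::Cpu, float, float> registry;
    registry.registerUnaryOperation("square",
        [](std::shared_ptr<DeviceContext>, const ComponentConfig&) {
            return std::make_shared<Square>();
        });
    try {
        out = registry.createUnaryOperation(name, std::move(ctx), ComponentConfig{})->apply(3.0f);
        return 0;
    } catch (const std::invalid_argument&) {
        return 1;
    } catch (const std::runtime_error&) {
        return 2;
    }
}
```

Since std::invalid_argument derives from std::logic_error and not from std::runtime_error, callers need a separate handler for each, as tryCreate does.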

◆ findFusedMatch()

std::optional< FusedOpMeta > Mila::Dnn::Compute::OperationRegistry::findFusedMatch ( const std::vector< std::string > &  child_types,
DeviceType  device_type,
std::type_index  precision_type,
const std::string &  variant = "Default" 
)
inline

Find a fused operation match for a sequence of module types.

Parameters
child_types  The sequence of module types to match.
device_type  The device type.
precision_type  The precision type.
variant  The variant of the operation (defaults to "Default").
Returns
std::optional<FusedOpMeta>  The matched fused operation metadata, or std::nullopt if no match is found.

◆ instance()

static OperationRegistry & Mila::Dnn::Compute::OperationRegistry::instance ( )
inline static

Get the singleton instance of the OperationRegistry.

Returns
OperationRegistry& The singleton instance.
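A private default constructor combined with a static accessor is the usual shape of a function-local-static singleton. The sketch below shows that pattern in miniature. ToyRegistry and its members are hypothetical, and both the deleted copy operations and the function-local static are assumptions about how such a class is commonly written, not statements about the Mila source.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical miniature; it mirrors only the singleton shape described above.
class ToyRegistry {
public:
    // Function-local static: constructed on first use, with thread-safe
    // initialization guaranteed since C++11.
    static ToyRegistry& instance() {
        static ToyRegistry registry;
        return registry;
    }

    // Assumption: copying is disabled so that only one registry can exist.
    ToyRegistry(const ToyRegistry&) = delete;
    ToyRegistry& operator=(const ToyRegistry&) = delete;

    void add(const std::string& name) { names_.push_back(name); }
    std::size_t size() const { return names_.size(); }

private:
    ToyRegistry() = default;  // private, as in the documented class
    std::vector<std::string> names_;
};
```

Every call to instance() returns a reference to the same object, so registrations made in one translation unit are visible to lookups made in another.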

◆ registerBinaryOperation()

template<DeviceType TDeviceType, typename TInput1 , typename TInput2 , typename TOutput >
requires ValidTensorTypes<TInput1, TInput2> && ValidFloatTensorType<TOutput>
void Mila::Dnn::Compute::OperationRegistry::registerBinaryOperation ( const std::string &  operation_name,
std::function< std::shared_ptr< BinaryOperation< TDeviceType, TInput1, TInput2, TOutput > >(std::shared_ptr< DeviceContext >, const ComponentConfig &)>  creator 
)
inline

Register a binary operation creator for a specific device type.

Template Parameters
TDeviceType  The device type for the operation.
TInput1  The first input tensor element type.
TInput2  The second input tensor element type.
TOutput  The output tensor element type.
Parameters
operation_name  The name of the operation.
creator  The function that creates the binary operation.

◆ registerFusedOperation()

template<typename TOutput >
void Mila::Dnn::Compute::OperationRegistry::registerFusedOperation ( const std::vector< OperationType > &  operation_types,
const std::string &  fused_op_name,
const std::string &  variant = "Default" 
)
inline

Register a fused operation.

Template Parameters
TOutput  The precision type of the operation.
Parameters
operation_types  The sequence of operation types to fuse.
fused_op_name  The name of the fused operation.
variant  The variant of the operation (defaults to "Default").
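Registration and lookup of fused operations can be pictured together as an append followed by a linear scan. The sketch below is illustrative only: the FusedOpMeta layout shown here is hypothetical (its real fields are not listed on this page), operation types are simplified to strings instead of OperationType, device matching is left out because the documented registerFusedOperation signature takes no device argument, and the names "Linear", "Gelu", and "FusedLinearGelu" are invented.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <typeindex>
#include <typeinfo>
#include <utility>
#include <vector>

// Hypothetical metadata layout; the real FusedOpMeta fields are not shown on this page.
struct FusedOpMeta {
    std::vector<std::string> sequence;  // operation types, in order
    std::string fused_op_name;
    std::type_index precision_type;
    std::string variant;
};

class ToyFusedRegistry {
public:
    // TOutput is recorded as a std::type_index so it can be compared at run time.
    template <typename TOutput>
    void registerFusedOperation(std::vector<std::string> sequence,
                                std::string fused_op_name,
                                std::string variant = "Default") {
        fused_ops_.push_back(FusedOpMeta{std::move(sequence), std::move(fused_op_name),
                                         std::type_index(typeid(TOutput)), std::move(variant)});
    }

    // Linear scan: the first entry whose sequence, precision, and variant all match wins.
    std::optional<FusedOpMeta> findFusedMatch(const std::vector<std::string>& child_types,
                                              std::type_index precision_type,
                                              const std::string& variant = "Default") const {
        for (const auto& meta : fused_ops_) {
            if (meta.sequence == child_types && meta.precision_type == precision_type &&
                meta.variant == variant) {
                return meta;
            }
        }
        return std::nullopt;
    }

private:
    std::vector<FusedOpMeta> fused_ops_;
};

// Builds a registry holding one fused entry for float precision.
inline ToyFusedRegistry makeRegistry() {
    ToyFusedRegistry registry;
    registry.registerFusedOperation<float>({"Linear", "Gelu"}, "FusedLinearGelu");
    return registry;
}
```

In this sketch the sequence comparison is order-sensitive, so {"Gelu", "Linear"} does not match an entry registered as {"Linear", "Gelu"}, and a lookup with a different precision type returns std::nullopt.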

◆ registerUnaryOperation()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
requires ValidTensorType<TInput> && ValidFloatTensorType<TOutput>
void Mila::Dnn::Compute::OperationRegistry::registerUnaryOperation ( const std::string &  operation_name,
std::function< std::shared_ptr< UnaryOperation< TDeviceType, TInput, TOutput > >(std::shared_ptr< DeviceContext >, const ComponentConfig &)>  creator 
)
inline

Register a unary operation creator for a specific device type.

Template Parameters
TDeviceType  The device type for the operation (defaults to DeviceType::Cuda).
TInput  The input tensor element type (defaults to float).
TOutput  The output tensor element type (defaults to TInput).
Parameters
operation_name  The name of the operation.
creator  The function that creates the unary operation.
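The default template arguments mean a caller may omit trailing parameters and still land on a well-defined type combination. The sketch below mirrors only the template header of registerUnaryOperation to show how those defaults resolve. describeRegistration is a hypothetical helper and DeviceType is a stand-in enum.

```cpp
#include <cassert>
#include <string>
#include <type_traits>

enum class DeviceType { Cpu, Cuda };  // hypothetical stand-in

// Same template header as the documented function; the body only reports
// which combination a call resolved to.
template <DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput>
std::string describeRegistration() {
    std::string device = (TDeviceType == DeviceType::Cuda) ? "Cuda" : "Cpu";
    std::string in = std::is_same<TInput, float>::value ? "float" : "other";
    std::string out = std::is_same<TOutput, TInput>::value ? "same as input" : "differs from input";
    return device + ", " + in + ", " + out;
}
```

Calling describeRegistration() with no arguments resolves to Cuda with float input and output, while describeRegistration<DeviceType::Cpu>() overrides only the device and keeps the float defaults.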

Member Data Documentation

◆ fused_ops_

std::vector<FusedOpMeta> Mila::Dnn::Compute::OperationRegistry::fused_ops_
private

List of registered fused operations.

◆ registry_

std::unordered_map<TypeID, std::unordered_map<std::string, GenericCreator>, TypeIDHash> Mila::Dnn::Compute::OperationRegistry::registry_
private

Registry mapping type info to operations.
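The two-level map can be read as "type information first, operation name second". A minimal sketch of such a key and its hash follows. The TypeID members shown are hypothetical (the real members are not listed on this page), makeKey and makeMap are illustrative helpers, and int stands in for GenericCreator to keep the example short.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>

enum class DeviceType { Cpu, Cuda };  // hypothetical stand-in

// Hypothetical key layout: the real TypeID members are not listed on this page.
struct TypeID {
    std::type_index input1;
    std::type_index input2;
    DeviceType device;

    bool operator==(const TypeID& other) const {
        return input1 == other.input1 && input2 == other.input2 && device == other.device;
    }
};

// Combines the member hashes; the mixing step follows the common hash_combine idiom.
struct TypeIDHash {
    std::size_t operator()(const TypeID& id) const {
        std::size_t seed = std::hash<std::type_index>{}(id.input1);
        auto mix = [&seed](std::size_t h) {
            seed ^= h + 0x9e3779b9 + (seed << 6) + (seed >> 2);
        };
        mix(std::hash<std::type_index>{}(id.input2));
        mix(std::hash<int>{}(static_cast<int>(id.device)));
        return seed;
    }
};

// Builds a key from compile-time types plus a run-time device value.
template <typename TInput1, typename TInput2>
TypeID makeKey(DeviceType device) {
    return TypeID{std::type_index(typeid(TInput1)), std::type_index(typeid(TInput2)), device};
}

// Outer map keyed by type information, inner map keyed by operation name.
using ToyRegistryMap =
    std::unordered_map<TypeID, std::unordered_map<std::string, int>, TypeIDHash>;

// The same operation name registered under two different device keys.
inline ToyRegistryMap makeMap() {
    ToyRegistryMap map;
    map[makeKey<float, float>(DeviceType::Cpu)]["add"] = 1;
    map[makeKey<float, float>(DeviceType::Cuda)]["add"] = 2;
    return map;
}
```

With this layout the same operation name can coexist for several type and device combinations, and a lookup for a combination that was never registered fails at the outer map before any name comparison takes place.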


The documentation for this class was generated from the following file: