Mila
Deep Neural Network Library
A registry for operations that can be created based on operation names, type information, and device type.
Classes
struct TypeID
Type ID structure to uniquely identify operations based on input types and device type.
struct TypeIDHash
Hash function for TypeID to use in unordered_map.
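The Mila headers are not reproduced on this page, so the following is only a minimal sketch of how a TypeID key and its TypeIDHash could be written. The field names (input1_type, input2_type, output_type, device_type) and the stand-in DeviceType enum are assumptions; the hash mixing step is the common boost::hash_combine idiom, not necessarily what Mila uses.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>

// Stand-in for Mila's DeviceType enum (assumption for this sketch).
enum class DeviceType { Cpu, Cuda };

// Hypothetical TypeID layout: element types of the inputs and output, plus the device.
struct TypeID {
    std::type_index input1_type;
    std::type_index input2_type;  // for a unary op this could simply repeat input1_type
    std::type_index output_type;
    DeviceType device_type;

    bool operator==(const TypeID& other) const {
        return input1_type == other.input1_type &&
               input2_type == other.input2_type &&
               output_type == other.output_type &&
               device_type == other.device_type;
    }
};

// Hypothetical TypeIDHash: folds the member hashes together, boost::hash_combine style.
struct TypeIDHash {
    std::size_t operator()(const TypeID& id) const {
        std::size_t seed = 0;
        auto combine = [&seed](std::size_t h) {
            seed ^= h + 0x9e3779b9 + (seed << 6) + (seed >> 2);
        };
        combine(std::hash<std::type_index>{}(id.input1_type));
        combine(std::hash<std::type_index>{}(id.input2_type));
        combine(std::hash<std::type_index>{}(id.output_type));
        combine(std::hash<int>{}(static_cast<int>(id.device_type)));
        return seed;
    }
};
```

With equality and a hash in place, std::unordered_map<TypeID, V, TypeIDHash> works as a lookup table, which matches the shape of the registry_ attribute.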
Public Member Functions
template<DeviceType TDeviceType, typename TInput1, typename TInput2, typename TOutput> requires ValidTensorTypes<TInput1, TInput2> && ValidFloatTensorType<TOutput>
std::shared_ptr<BinaryOperation<TDeviceType, TInput1, TInput2, TOutput>> createBinaryOperation(const std::string& operation_name, std::shared_ptr<DeviceContext> context, const ComponentConfig& config) const
Create a binary operation based on the type information, device type, and operation name.
template<DeviceType TDeviceType, typename TInput, typename TOutput>
std::shared_ptr<UnaryOperation<TDeviceType, TInput, TOutput>> createUnaryOperation(const std::string& operation_name, std::shared_ptr<DeviceContext> context, const ComponentConfig& config) const
Create a unary operation based on the type information, device type, and operation name.
std::optional<FusedOpMeta> findFusedMatch(const std::vector<std::string>& child_types, DeviceType device_type, std::type_index precision_type, const std::string& variant = "Default")
Find a fused operation match for a sequence of module types.
template<DeviceType TDeviceType, typename TInput1, typename TInput2, typename TOutput> requires ValidTensorTypes<TInput1, TInput2> && ValidFloatTensorType<TOutput>
void registerBinaryOperation(const std::string& operation_name, std::function<std::shared_ptr<BinaryOperation<TDeviceType, TInput1, TInput2, TOutput>>(std::shared_ptr<DeviceContext>, const ComponentConfig&)> creator)
Register a binary operation creator for a specific device type.
template<typename TOutput>
void registerFusedOperation(const std::vector<OperationType>& operation_types, const std::string& fused_op_name, const std::string& variant = "Default")
Register a fused operation.
template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = float, typename TOutput = TInput> requires ValidTensorType<TInput> && ValidFloatTensorType<TOutput>
void registerUnaryOperation(const std::string& operation_name, std::function<std::shared_ptr<UnaryOperation<TDeviceType, TInput, TOutput>>(std::shared_ptr<DeviceContext>, const ComponentConfig&)> creator)
Register a unary operation creator for a specific device type.
Static Public Member Functions
static OperationRegistry& instance()
Get the singleton instance of the OperationRegistry.
Private Types
using GenericCreator = std::function<std::shared_ptr<void>(std::shared_ptr<DeviceContext>, const ComponentConfig&)>
Type alias for a generic operation creator function.
Private Member Functions
OperationRegistry() = default
Default constructor.
Private Attributes
std::vector<FusedOpMeta> fused_ops_
List of registered fused operations.
std::unordered_map<TypeID, std::unordered_map<std::string, GenericCreator>, TypeIDHash> registry_
Registry mapping each TypeID to a map from operation name to creator function.
Detailed Description
This singleton class manages the registration and creation of neural network operations in the Mila framework. It provides a single interface for obtaining operation implementations across device types (CPU, CUDA) and tensor data types.
The registry supports registering and creating unary and binary operations, keyed by device type, tensor element types, and operation name, as well as registering fused operations and finding a fused match for a sequence of module types.
Member Typedef Documentation
GenericCreator [private]
Type alias for a generic operation creator function.
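Because GenericCreator returns std::shared_ptr<void>, creators for many different operation types can live in one map. A minimal sketch of that type erasure, assuming stand-in DeviceContext and ComponentConfig types and two hypothetical helpers, eraseCreator and invokeAs, which are illustrations and not Mila APIs:

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <string>
#include <utility>

// Stand-ins for Mila types (assumptions for this sketch).
struct DeviceContext {};
struct ComponentConfig { std::string name; };

// Hypothetical concrete operation type.
struct MyUnaryOp {
    explicit MyUnaryOp(std::string n) : name(std::move(n)) {}
    std::string name;
};

// Same shape as the documented alias: the created type is erased to void.
using GenericCreator = std::function<std::shared_ptr<void>(
    std::shared_ptr<DeviceContext>, const ComponentConfig&)>;

template <typename TOp>
using TypedCreator = std::function<std::shared_ptr<TOp>(
    std::shared_ptr<DeviceContext>, const ComponentConfig&)>;

// Hypothetical helper: wrap a typed creator so it can be stored as a GenericCreator.
template <typename TOp>
GenericCreator eraseCreator(TypedCreator<TOp> typed) {
    return [typed](std::shared_ptr<DeviceContext> ctx,
                   const ComponentConfig& cfg) -> std::shared_ptr<void> {
        return typed(std::move(ctx), cfg);  // shared_ptr<TOp> converts to shared_ptr<void>
    };
}

// Hypothetical helper: call the erased creator and cast the result back.
// Only valid when TOp is exactly the type the wrapped creator returned.
template <typename TOp>
std::shared_ptr<TOp> invokeAs(const GenericCreator& creator,
                              std::shared_ptr<DeviceContext> ctx,
                              const ComponentConfig& cfg) {
    assert(static_cast<bool>(creator) && "creator must not be empty");
    return std::static_pointer_cast<TOp>(creator(std::move(ctx), cfg));
}
```

The static_pointer_cast back from void is only safe when the requested type matches the one the creator produced, which is why a registry of this kind keys its creators by type information.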
Constructor & Destructor Documentation
OperationRegistry() [private, default]
Default constructor.
Member Function Documentation
createBinaryOperation() [inline]
Create a binary operation based on the type information, device type, and operation name.
Template Parameters:
TDeviceType: The device type for the operation.
TInput1: The first input tensor element type.
TInput2: The second input tensor element type.
TOutput: The output tensor element type.
Parameters:
operation_name: The name of the operation.
context: The device context to use for the operation.
config: The component configuration for the operation.
Returns:
A shared pointer to the created binary operation.
Exceptions:
std::runtime_error: If the type combination, device type, or operation name is invalid.
std::invalid_argument: If the context is null.
createUnaryOperation() [inline]
Create a unary operation based on the type information, device type, and operation name.
Template Parameters:
TDeviceType: The device type for the operation.
TInput: The input tensor element type.
TOutput: The output tensor element type.
Parameters:
operation_name: The name of the operation.
context: The device context to use for the operation.
config: The component configuration for the operation.
Returns:
A shared pointer to the created unary operation.
Exceptions:
std::runtime_error: If the type combination, device type, or operation name is invalid.
std::invalid_argument: If the context is null.
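A self-contained sketch of how registering and creating unary operations could fit together, including the two documented exceptions. It is not Mila's implementation: MiniRegistry, the stand-in types, and the simplified (device, input type, output type) key stored in a std::map are assumptions chosen to keep the example short.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <tuple>
#include <typeindex>
#include <typeinfo>
#include <utility>

// Stand-ins for Mila types (assumptions for this sketch).
enum class DeviceType { Cpu, Cuda };
struct DeviceContext {};
struct ComponentConfig {};

template <DeviceType D, typename TInput, typename TOutput>
struct UnaryOperation {
    virtual ~UnaryOperation() = default;
    virtual std::string name() const = 0;
};

// Hypothetical, simplified registry: not Mila's OperationRegistry.
class MiniRegistry {
public:
    using GenericCreator = std::function<std::shared_ptr<void>(
        std::shared_ptr<DeviceContext>, const ComponentConfig&)>;

    template <DeviceType D, typename TInput, typename TOutput>
    using UnaryCreator = std::function<std::shared_ptr<UnaryOperation<D, TInput, TOutput>>(
        std::shared_ptr<DeviceContext>, const ComponentConfig&)>;

    template <DeviceType D, typename TInput, typename TOutput>
    void registerUnaryOperation(const std::string& name, UnaryCreator<D, TInput, TOutput> creator) {
        assert(static_cast<bool>(creator) && "creator must not be empty");
        // Store the creator type-erased, under its (device, input, output) key.
        registry_[key<D, TInput, TOutput>()][name] =
            [creator](std::shared_ptr<DeviceContext> ctx,
                      const ComponentConfig& cfg) -> std::shared_ptr<void> {
                return creator(std::move(ctx), cfg);
            };
    }

    template <DeviceType D, typename TInput, typename TOutput>
    std::shared_ptr<UnaryOperation<D, TInput, TOutput>> createUnaryOperation(
        const std::string& name, std::shared_ptr<DeviceContext> context,
        const ComponentConfig& config) const {
        if (!context) throw std::invalid_argument("context is null");
        auto by_type = registry_.find(key<D, TInput, TOutput>());
        if (by_type == registry_.end())
            throw std::runtime_error("no operations registered for this type/device combination");
        auto it = by_type->second.find(name);
        if (it == by_type->second.end())
            throw std::runtime_error("unknown operation: " + name);
        // Safe: the creator stored under this key returns exactly this base type.
        return std::static_pointer_cast<UnaryOperation<D, TInput, TOutput>>(
            it->second(std::move(context), config));
    }

private:
    // Simplified key; std::map avoids the need for a custom hash like TypeIDHash.
    using Key = std::tuple<DeviceType, std::type_index, std::type_index>;

    template <DeviceType D, typename TInput, typename TOutput>
    static Key key() {
        return Key{D, std::type_index(typeid(TInput)), std::type_index(typeid(TOutput))};
    }

    std::map<Key, std::map<std::string, GenericCreator>> registry_;
};
```

The two-level lookup mirrors the documented registry_ shape (type key first, then operation name), so an unknown type combination and an unknown name are reported separately.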
findFusedMatch() [inline]
Find a fused operation match for a sequence of module types.
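Parameters:
child_types: The sequence of module types to match.
device_type: The device type.
precision_type: The precision type, given as a std::type_index.
variant: The variant of the operation (defaults to "Default").
Returns:
The matching FusedOpMeta, or an empty std::optional if no registered fused operation matches.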
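One plausible way to implement the matching is a linear scan over the registered fused operations. The sketch below assumes a hypothetical FusedOpMeta layout (its real fields are not documented on this page) and takes the list as a parameter instead of reading fused_ops_; the module type names in any usage are illustrative only.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <typeindex>
#include <typeinfo>
#include <vector>

// Stand-in for Mila's DeviceType enum (assumption for this sketch).
enum class DeviceType { Cpu, Cuda };

// Hypothetical FusedOpMeta layout.
struct FusedOpMeta {
    std::vector<std::string> module_types;  // sequence of child module types to match
    std::string fused_op_name;
    DeviceType device_type;
    std::type_index precision_type;
    std::string variant;
};

// Return the first registered fused op whose module sequence, device,
// precision, and variant all match the query exactly; otherwise nullopt.
std::optional<FusedOpMeta> findFusedMatch(const std::vector<FusedOpMeta>& fused_ops,
                                          const std::vector<std::string>& child_types,
                                          DeviceType device_type,
                                          std::type_index precision_type,
                                          const std::string& variant = "Default") {
    for (const auto& meta : fused_ops) {
        if (meta.module_types == child_types &&
            meta.device_type == device_type &&
            meta.precision_type == precision_type &&
            meta.variant == variant) {
            return meta;
        }
    }
    return std::nullopt;
}
```

A linear scan is reasonable when only a handful of fused patterns are registered; a larger table could index by device and precision first.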
instance() [inline, static]
Get the singleton instance of the OperationRegistry.
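A private, defaulted constructor combined with a static instance() accessor is the usual C++ singleton idiom. A minimal sketch with a hypothetical Registry class, using a function-local static whose initialization C++11 guarantees to be thread-safe:

```cpp
#include <cassert>

// Hypothetical class illustrating the singleton idiom; not Mila's code.
class Registry {
public:
    static Registry& instance() {
        static Registry registry;  // constructed on first call, destroyed at program exit
        return registry;
    }

    // Copying would create a second instance, so it is disabled.
    Registry(const Registry&) = delete;
    Registry& operator=(const Registry&) = delete;

    int registered_count = 0;  // hypothetical state, for illustration only

private:
    Registry() = default;  // only instance() can construct the object
};
```

Every call to instance() returns a reference to the same object, so registrations made through one reference are visible through all others.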
registerBinaryOperation() [inline]
Register a binary operation creator for a specific device type.
Template Parameters:
TDeviceType: The device type for the operation.
TInput1: The first input tensor element type.
TInput2: The second input tensor element type.
TOutput: The output tensor element type.
Parameters:
operation_name: The name of the operation.
creator: The function that creates the binary operation from a device context and a component configuration.
registerFusedOperation() [inline]
Register a fused operation.
Template Parameters:
TOutput: The precision type of the operation.
Parameters:
operation_types: The sequence of operation types to fuse.
fused_op_name: The name of the fused operation.
variant: The variant of the operation (defaults to "Default").
registerUnaryOperation() [inline]
Register a unary operation creator for a specific device type.
Template Parameters:
TDeviceType: The device type for the operation (defaults to DeviceType::Cuda).
TInput: The input tensor element type (defaults to float).
TOutput: The output tensor element type (defaults to TInput).
Parameters:
operation_name: The name of the operation.
creator: The function that creates the unary operation from a device context and a component configuration.
Member Data Documentation
fused_ops_ [private]
List of registered fused operations.
registry_ [private]
Registry mapping each TypeID to a map from operation name to creator function.