Mila 0.13.48
Deep Neural Network Library
Mila::Dnn::NetworkFactory Class Reference [export]

Factory registry for Network deserialization.

Public Types

template<DeviceType TDeviceType, TensorDataType TPrecision>
using NetworkFactoryFunc

Static Public Member Functions

template<Compute::DeviceType TDeviceType, TensorDataType TPrecision>
static std::unique_ptr< Network< TDeviceType, TPrecision > > create (ModelArchive &archive, std::shared_ptr< ExecutionContext< TDeviceType > > exec_context)
template<DeviceType TDeviceType, TensorDataType TPrecision>
static void registerNetwork (const std::string &network_type, NetworkFactoryFunc< TDeviceType, TPrecision > factory)

Static Private Member Functions

template<DeviceType TDeviceType, TensorDataType TPrecision>
static std::unordered_map< std::string, NetworkFactoryFunc< TDeviceType, TPrecision > > & getRegistry ()
 Get the registry for a specific device type and precision.

Detailed Description

Factory registry for Network deserialization.

Provides type-safe network reconstruction from archives using registered factory functions. Each concrete network type registers a factory function, typically a lambda that forwards to its static Load() method, so that archived networks can be deserialized polymorphically.

Design Pattern:

  • Registration: Concrete networks register factory functions at startup
  • Dispatch: Factory reads metadata to determine network type and precision
  • Construction: Invokes appropriate registered factory function
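The registration, dispatch, and construction steps can be sketched in self-contained C++. This is a minimal illustration under stated assumptions, not Mila's implementation: DeviceType, TensorDataType, ExecutionContext, ModelArchive, and Network are simplified stand-ins; SketchFactory and TinyClassifier are hypothetical names; and the "network_type" metadata key and the exception thrown for unregistered types are assumptions of the sketch.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>

// Hypothetical stand-ins for Mila types (not the real Mila API).
enum class DeviceType { Cpu, Cuda };
enum class TensorDataType { FP32, FP16 };

template <DeviceType D>
struct ExecutionContext { int device_id = 0; };

// Hypothetical archive: only a string metadata map.
struct ModelArchive {
    std::unordered_map<std::string, std::string> metadata;
};

template <DeviceType D, TensorDataType P>
struct Network {
    virtual ~Network() = default;
    virtual std::string typeName() const = 0;
};

template <DeviceType D, TensorDataType P>
using NetworkFactoryFunc = std::function<std::unique_ptr<Network<D, P>>(
    ModelArchive&, std::shared_ptr<ExecutionContext<D>>)>;

struct SketchFactory {
    // Registration: store a factory under its network type name.
    template <DeviceType D, TensorDataType P>
    static void registerNetwork(const std::string& type, NetworkFactoryFunc<D, P> f) {
        registry<D, P>()[type] = std::move(f);
    }

    // Dispatch: read the type name from archive metadata (the "network_type"
    // key is an assumption). Construction: invoke the matching factory.
    template <DeviceType D, TensorDataType P>
    static std::unique_ptr<Network<D, P>> create(
        ModelArchive& archive, std::shared_ptr<ExecutionContext<D>> ctx) {
        auto& reg = registry<D, P>();
        auto it = reg.find(archive.metadata.at("network_type"));
        if (it == reg.end())
            throw std::runtime_error("no factory registered for this network type");
        return it->second(archive, std::move(ctx));
    }

private:
    // One map per <D, P> instantiation.
    template <DeviceType D, TensorDataType P>
    static std::unordered_map<std::string, NetworkFactoryFunc<D, P>>& registry() {
        static std::unordered_map<std::string, NetworkFactoryFunc<D, P>> map;
        return map;
    }
};

// Hypothetical concrete network with a static Load(), mirroring the usage example.
struct TinyClassifier : Network<DeviceType::Cpu, TensorDataType::FP32> {
    int device_id;
    explicit TinyClassifier(int id) : device_id(id) {}
    std::string typeName() const override { return "TinyClassifier"; }
    static std::unique_ptr<TinyClassifier> Load(ModelArchive&, int id) {
        return std::make_unique<TinyClassifier>(id);
    }
};
```

Keying the registry by a string read from the archive keeps the factory independent of concrete network classes; only the registration sites need to know about them.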

Usage:

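```cpp
// Registration (typically in the network's implementation file)
NetworkFactory::registerNetwork<TDeviceType, TPrecision>(
    "MnistClassifier",
    [](ModelArchive& archive, auto exec_ctx) {
        return MnistClassifier::Load(archive, exec_ctx->getDeviceId());
    });

// Deserialization
auto network = NetworkFactory::create<TDeviceType, TPrecision>(
    archive, exec_context);
```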
Examples
/__w/Mila/Mila/Mila/Src/Dnn/Core/NetworkFactory.ixx.

Member Typedef Documentation

◆ NetworkFactoryFunc

Initial value:
std::function<std::unique_ptr<Network<TDeviceType, TPrecision>>(
    ModelArchive&,
    std::shared_ptr<ExecutionContext<TDeviceType>>
)>
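Because NetworkFactoryFunc is a std::function alias, any callable with a matching signature can be registered. The sketch below assumes the alias takes a ModelArchive& and an execution context, as the registration lambda in the usage example does, and uses simplified hypothetical stand-ins (DemoNet, loadDemo, and the stub types are invented for illustration). It shows that both a lambda and a plain function are accepted even when Load() returns a pointer to the derived class, because std::unique_ptr<Derived> converts implicitly to std::unique_ptr<Base>.

```cpp
#include <cassert>
#include <functional>
#include <memory>

// Hypothetical stand-ins; the real Mila types are richer.
enum class DeviceType { Cpu, Cuda };
enum class TensorDataType { FP32 };
template <DeviceType D> struct ExecutionContext { int device_id = 0; };
struct ModelArchive {};

template <DeviceType D, TensorDataType P>
struct Network {
    virtual ~Network() = default;
    virtual int deviceId() const = 0;
};

// Alias shaped like NetworkFactoryFunc (assumed signature: archive + context).
template <DeviceType D, TensorDataType P>
using NetworkFactoryFunc = std::function<std::unique_ptr<Network<D, P>>(
    ModelArchive&, std::shared_ptr<ExecutionContext<D>>)>;

using CpuNet = Network<DeviceType::Cpu, TensorDataType::FP32>;

// Hypothetical concrete network; Load() returns the *derived* pointer type.
struct DemoNet : CpuNet {
    int id;
    explicit DemoNet(int i) : id(i) {}
    int deviceId() const override { return id; }
    static std::unique_ptr<DemoNet> Load(ModelArchive&, int device_id) {
        return std::make_unique<DemoNet>(device_id);
    }
};

// A plain function with the exact signature also fits the alias.
std::unique_ptr<CpuNet> loadDemo(ModelArchive& a,
                                 std::shared_ptr<ExecutionContext<DeviceType::Cpu>> ctx) {
    return DemoNet::Load(a, ctx->device_id);  // derived -> base conversion
}

// A generic lambda is accepted because its result, std::unique_ptr<DemoNet>,
// converts implicitly to std::unique_ptr<CpuNet>.
NetworkFactoryFunc<DeviceType::Cpu, TensorDataType::FP32> fromLambda =
    [](ModelArchive& a, auto ctx) { return DemoNet::Load(a, ctx->device_id); };

NetworkFactoryFunc<DeviceType::Cpu, TensorDataType::FP32> fromFunction = &loadDemo;
```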

Member Function Documentation

◆ create()

template<Compute::DeviceType TDeviceType, TensorDataType TPrecision>
std::unique_ptr< Network< TDeviceType, TPrecision > > Mila::Dnn::NetworkFactory::create ( ModelArchive & archive,
std::shared_ptr< ExecutionContext< TDeviceType > > exec_context )
inline static

◆ getRegistry()

template<DeviceType TDeviceType, TensorDataType TPrecision>
std::unordered_map< std::string, NetworkFactoryFunc< TDeviceType, TPrecision > > & Mila::Dnn::NetworkFactory::getRegistry ( )
inline static private

Get the registry for a specific device type and precision.

Returns a reference to the static registry map for the given device and precision combination. Each instantiation has its own independent registry.

Template Parameters
TDeviceType: Device type
TPrecision: Precision type
Returns
Reference to the static registry map

◆ registerNetwork()

template<DeviceType TDeviceType, TensorDataType TPrecision>
void Mila::Dnn::NetworkFactory::registerNetwork ( const std::string & network_type,
NetworkFactoryFunc< TDeviceType, TPrecision > factory )
inline static
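Registering "at startup" is commonly done with a namespace-scope registrar object whose constructor calls the registration function. This is a hypothetical sketch of that idiom, not code from Mila: template parameters, the archive, and the execution context are dropped for brevity, and Factory, NetworkRegistrar, and MnistLike are invented names. Overwriting an existing entry on re-registration is also an assumption, since this reference does not say how duplicate names are handled.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>

// Hypothetical minimal stand-ins; not the Mila API.
struct Network { virtual ~Network() = default; };
using FactoryFunc = std::function<std::unique_ptr<Network>()>;

struct Factory {
    static void registerNetwork(const std::string& type, FactoryFunc f) {
        registry()[type] = std::move(f);  // assumption: later registration overwrites
    }
    static bool isRegistered(const std::string& type) {
        return registry().count(type) != 0;
    }
    static std::unordered_map<std::string, FactoryFunc>& registry() {
        static std::unordered_map<std::string, FactoryFunc> map;  // built on first use
        return map;
    }
};

// Hypothetical helper: constructing a namespace-scope instance runs
// registerNetwork() during static initialization.
struct NetworkRegistrar {
    NetworkRegistrar(const std::string& type, FactoryFunc f) {
        Factory::registerNetwork(type, std::move(f));
    }
};

struct MnistLike : Network {};

// In a real code base this would sit in the network's implementation file.
namespace {
const NetworkRegistrar mnist_registrar{
    "MnistLike", [] { return std::make_unique<MnistLike>(); }};
}
```

Because the registry is a function-local static, registrars in any translation unit can safely reach it during static initialization. One caveat: when networks are built into a static library, the linker may discard object files that nothing else references, taking their registrars with them.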

The documentation for this class was generated from the following file: NetworkFactory.ixx