Mila 0.13.48
Deep Neural Network Library
Loading...
Searching...
No Matches
Mila::Dnn::Network< TDeviceType, TPrecision > Class Template Reference (abstract, export)

Root composite network container. More...

Inheritance diagram for Mila::Dnn::Network< TDeviceType, TPrecision >:
Collaboration diagram for Mila::Dnn::Network< TDeviceType, TPrecision >:

Public Types

using ComponentPtr = typename CompositeBase::ComponentPtr
using CompositeBase = CompositeComponent<TDeviceType, TPrecision>
using MR = typename DeviceTypeTraits<TDeviceType>::memory_resource
Public Types inherited from Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >
using ComponentBase = Component<TDeviceType, TPrecision>
using ComponentPtr = std::shared_ptr<Component<TDeviceType, TPrecision>>

Public Member Functions

 Network (const std::string &name)
 Construct network (context managed by derived class).
 ~Network () override=default
template<typename TOptimizer, typename TConfig>
std::shared_ptr< TOptimizer > createOptimizer (const TConfig &config)
 Create and configure an optimizer for this network's parameters.
DeviceId getDeviceId () const noexcept
 Get the compute device for this composite.
const ComponentType getType () const override
 Get the component type identifier.
void save (ModelArchive &archive, SerializationMode mode) const
 Save network to archive.
void synchronize () override
 Synchronize all child components.
std::string toString () const override
 Generate a human-readable description.
Public Member Functions inherited from Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >
 CompositeComponent (CompositeComponent &&) noexcept=default
 CompositeComponent (const CompositeComponent &)=delete
 CompositeComponent (const std::string &name)
 Construct composite component with name.
virtual ~CompositeComponent ()=default
CompositeComponent & addComponent (ComponentPtr component)
 Add a pre-constructed child component (chainable).
size_t childCount () const noexcept
 Get the number of direct children.
void clearComponents ()
 Clear all child components.
ComponentPtr findComponent (const std::string &path) const
 Resolve a dot-separated component path within this composite.
ComponentPtr getComponent (const std::string &name) const
 Retrieve a direct child component by name.
const std::vector< ComponentPtr > & getComponents () const
 Get all child components in insertion order.
std::vector< ITensor * > getGradients () const override
 Get all parameter gradients from all children.
std::vector< ITensor * > getParameters () const override
 Get all parameters from all children.
bool hasChildren () const noexcept
 Check if this composite has any children.
bool hasComponent (const std::string &name) const
 Check if a named child component exists.
CompositeComponent & operator= (CompositeComponent &&) noexcept=default
CompositeComponent & operator= (const CompositeComponent &)=delete
size_t parameterCount () const override
 Count parameters across all children.
bool removeComponent (const std::string &name)
 Remove a named child component.
ComponentPtr tryFindComponent (const std::string &path) const
 Try to resolve a dot-separated component path within this composite.
Public Member Functions inherited from Mila::Dnn::Component< TDeviceType, TPrecision >
 Component (const std::string &name)
 Construct component with required name identifier.
virtual ~Component ()=default
virtual void build (const BuildContext &context) final
 Build the component with the provided BuildContext (canonical overload).
virtual MemoryStats getMemoryStats () const =0
 Return the current memory allocation breakdown for this component.
const std::string getName () const
 Get the component's name identifier.
virtual std::vector< std::string > getParameterNames () const
 List all available parameter names for this component.
RuntimeMode getRuntimeMode () const noexcept
 Get the RuntimeMode of this Component.
TrainingMode getTrainingMode () const noexcept
 The current runtime behavioral mode of this Component.
virtual bool isBuilt () const final
 Returns true if build() has completed successfully.
bool isInferenceMode () const noexcept
 Convenience accessor — true if currently in Eval mode.
bool isTrainingMode () const noexcept
virtual void loadParameter (const std::string &name, const Serialization::ITensorBlob &blob)
 Load a parameter from serialized tensor data.
void setTrainingMode (TrainingMode mode)
 Set the runtime behavioral mode for this Component.
virtual void zeroGradients ()
 Clear all model-owned gradients for this component.

Protected Member Functions

virtual void save_ (ModelArchive &archive, SerializationMode mode) const =0
 Hook for concrete classes to save type-specific state.
void verifyArchitectureCompatibility (const PretrainedMetadata &metadata)
 Verify that imported model is compatible with network architecture.
Protected Member Functions inherited from Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >
template<typename TComponent>
std::shared_ptr< TComponent > getComponentAs (const std::string &name) const
 Retrieve a typed child component by name.
void onExecutionContextSet () override
 Hook invoked after ExecutionContext is set.
void onTrainingModeChanging (TrainingMode training_mode) override
 Hook invoked when training mode is about to change.
virtual void optimize ()
 Virtual hook for graph optimization after construction.
Protected Member Functions inherited from Mila::Dnn::Component< TDeviceType, TPrecision >
IExecutionContext * getExecutionContext () const
 Get the shared execution context.
bool hasExecutionContext () const noexcept
 Check if execution context has been set.
template<TensorDataType TParameterPrecision, typename TMemoryResource>
void loadParameterFromBlob (const std::string &param_name, const Serialization::ITensorBlob &blob, Tensor< TParameterPrecision, TMemoryResource > &target, const shape_t &expected_shape)
 Load a tensor blob into a parameter tensor with validation.
virtual void onBuilding (const BuildContext &config)
 Hook invoked by build() to allocate component buffers.
void setExecutionContext (IExecutionContext *context)
 Set the execution context for this component.

Private Member Functions

size_t parseLayerIndex (const std::string &name)
 Parse layer index from tensor name.
std::pair< std::vector< std::string >, std::string > parseTensorName (const std::string &tensor_name)
 Parse tensor name into component path and parameter name.
void saveComponentGraph (ModelArchive &archive, SerializationMode mode) const
 Save component graph topology.
void saveNetworkMetadata (ModelArchive &archive, SerializationMode mode) const
 Save base network metadata.

Additional Inherited Members

Static Public Member Functions inherited from Mila::Dnn::Component< TDeviceType, TPrecision >
static constexpr DeviceType getDeviceType ()
 Compile-time device type for this component instance.
static constexpr TensorDataType getPrecision () noexcept
 Compile-time tensor precision for this component instance.
Protected Attributes inherited from Mila::Dnn::Component< TDeviceType, TPrecision >
BuildContext build_context_ { shape_t{ 1 }, RuntimeMode::Training }
 The BuildContext stored at build time.

Detailed Description

template<DeviceType TDeviceType, TensorDataType TPrecision>
class Mila::Dnn::Network< TDeviceType, TPrecision >

Root composite network container.

Network is a specialized CompositeComponent that represents a complete neural network model and serves as the top-level entry point. It provides high-level serialization semantics while delegating lifecycle management to concrete subclasses.

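Ownership Model:

The concrete subclass owns the ExecutionContext (for example through a std::unique_ptr member) and hands it to the network with setExecutionContext(), which takes a raw IExecutionContext* pointer; the network does not take ownership of the context. Child components are held through ComponentPtr (std::shared_ptr).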

Construction Pattern (Concrete Subclass):

class MnistClassifier : public Network<DeviceType::Cpu, TensorDataType::FP32>
{
public:
    explicit MnistClassifier(const std::string& name, int64_t batch_size, DeviceId device_id)
        : Network(name),
          owned_context_(createExecutionContext(device_id)),
          batch_size_(batch_size)
    {
        // 1. Create component graph (context-independent)
        createGraph();
        // 2. Propagate context to self and children
        this->setExecutionContext(owned_context_.get());
    }

    // save_() override and static Load() omitted here (see Serialization Contract)

private:
    void createGraph();  // Adds child components (definition omitted)

    std::unique_ptr<IExecutionContext> owned_context_; // Concrete class owns context
    int64_t batch_size_;
};
Symbols referenced in this example:
  • void setExecutionContext(IExecutionContext *context) — Set the execution context for this component. (Definition: Component.ixx:595)
  • Network(const std::string &name) — Construct network (context managed by derived class). (Definition: Network.ixx:136)
  • std::unique_ptr< IExecutionContext > createExecutionContext(DeviceId device_id) — Create execution context for specified device. (Definition: ExecutionContextFactory.ixx:23)
  • DeviceId — Lightweight identifier for a compute device. (Definition: DeviceId.ixx:38)

Serialization Contract:

  • Base class (Network): Saves component graph topology and generic metadata
  • Concrete class: MUST override save_() to write type identifier and configuration
  • Concrete class: MUST provide static Load() factory method for deserialization

Deserialization Pattern (Concrete Subclass):

// REQUIRED: Static factory method for type-safe deserialization
static std::unique_ptr<MnistClassifier> Load(ModelArchive& archive, DeviceId device_id)
{
// 1. Read concrete-specific metadata
json meta = archive.readJson("network/classifier_meta.json");
std::string name = meta.at("name");
int64_t batch_size = meta.at("batch_size");
// 2. Construct via normal constructor path
auto classifier = std::make_unique<MnistClassifier>(name, batch_size, device_id);
// 3. Build with saved input shape
shape_t input_shape = meta.at("input_shape");
classifier->build(input_shape);
// 4. Load component weights
// (Base class handles graph traversal; weights loaded into already-built components)
return classifier;
}
Symbols referenced in this example:
  • ModelArchive — Provides high-level helpers for component serialization. (Definition: ModelArchive.ixx:47)
  • TensorShape shape_t — Row-major shape descriptor for tensor dimensional sizes. (Definition: Tensor.Types.ixx:143)
  • nlohmann::json json (Definition: Linear.ixx:57)

Design Rationale:

  • Concrete classes control infrastructure (context) lifecycle
  • Network base class focuses on container semantics and serialization
  • Clear initialization order: create context → build graph → propagate context to self and children via setExecutionContext()
  • Enables future flexibility (custom contexts, multi-device, etc.)
  • Type-safe deserialization via concrete class Load() methods

Member Typedef Documentation

◆ ComponentPtr

template<DeviceType TDeviceType, TensorDataType TPrecision>
using Mila::Dnn::Network< TDeviceType, TPrecision >::ComponentPtr = typename CompositeBase::ComponentPtr

◆ CompositeBase

template<DeviceType TDeviceType, TensorDataType TPrecision>
using Mila::Dnn::Network< TDeviceType, TPrecision >::CompositeBase = CompositeComponent<TDeviceType, TPrecision>

◆ MR

template<DeviceType TDeviceType, TensorDataType TPrecision>
using Mila::Dnn::Network< TDeviceType, TPrecision >::MR = typename DeviceTypeTraits<TDeviceType>::memory_resource

Constructor & Destructor Documentation

◆ Network()

template<DeviceType TDeviceType, TensorDataType TPrecision>
Mila::Dnn::Network< TDeviceType, TPrecision >::Network ( const std::string & name)
inline, explicit

Construct network (context managed by derived class).

Base constructor for concrete network classes. Derived classes are responsible for:

  1. Creating and owning the ExecutionContext
  2. Building the component graph via createGraph()
  3. Calling setExecutionContext() to propagate context to children
Parameters
name — Network name for identification and serialization
Exceptions
std::invalid_argument — if name is not a valid identifier

◆ ~Network()

template<DeviceType TDeviceType, TensorDataType TPrecision>
Mila::Dnn::Network< TDeviceType, TPrecision >::~Network ( )
override, default

Member Function Documentation

◆ createOptimizer()

template<DeviceType TDeviceType, TensorDataType TPrecision>
template<typename TOptimizer, typename TConfig>
std::shared_ptr< TOptimizer > Mila::Dnn::Network< TDeviceType, TPrecision >::createOptimizer ( const TConfig & config)
inline

Create and configure an optimizer for this network's parameters.

Factory method that creates an optimizer, enables training mode on the network, and registers all network parameters and gradients in a single call.

Lifecycle:

  1. Enables training mode (allocates gradients for all components)
  2. Creates optimizer using network's ExecutionContext
  3. Registers all network parameters and gradients
  4. Returns ready-to-use optimizer

Usage Pattern:

// Build network
mnist_net->build(input_shape);
// Create optimizer in one step
auto optimizer = mnist_net->createOptimizer<AdamWOptimizer<DeviceType::Cuda, TensorDataType::FP32>>(
    AdamWConfig()
        .withLearningRate(1e-3f)   // example value
        .withWeightDecay(0.01f)
);
// Optimizer is ready to use
optimizer->step();
Symbols referenced in this example:
  • decltype(auto) withLearningRate(this Self &&self, float learning_rate) noexcept (Definition: AdamWConfig.ixx:36)
  • Mila::Dnn::Optimizers::AdamWConfig AdamWConfig (Definition: CpuAdamWOptimizer.ixx:37)
Template Parameters
TOptimizer — Optimizer type (e.g., AdamWOptimizer, SGD)
TConfig — Optimizer configuration type
Parameters
config — Optimizer configuration
Returns
Shared pointer to configured and ready-to-use optimizer
Exceptions
std::runtime_error — if the network is not built
std::runtime_error — if the parameter and gradient counts do not match
Note
This method enables training mode itself (see setTrainingMode()), so explicit training mode activation is not required.
Here is the call graph for this function:

◆ getDeviceId()

template<DeviceType TDeviceType, TensorDataType TPrecision>
DeviceId Mila::Dnn::Network< TDeviceType, TPrecision >::getDeviceId ( ) const
inline, virtual, noexcept

Get the compute device for this composite.

Returns the device from the shared execution context.

Returns
DeviceId for this composite and its children

Reimplemented from Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >.

Here is the call graph for this function:
Here is the caller graph for this function:

◆ getType()

template<DeviceType TDeviceType, TensorDataType TPrecision>
const ComponentType Mila::Dnn::Network< TDeviceType, TPrecision >::getType ( ) const
inline, override, virtual

Get the component type identifier.

Used for serialization and runtime type identification.

Returns
Component type enum value.

Implements Mila::Dnn::Component< TDeviceType, TPrecision >.

◆ parseLayerIndex()

template<DeviceType TDeviceType, TensorDataType TPrecision>
size_t Mila::Dnn::Network< TDeviceType, TPrecision >::parseLayerIndex ( const std::string & name)
inline, private

Parse layer index from tensor name.

Extracts the layer number from names like "layers.5.attention.q_proj.weight"

◆ parseTensorName()

template<DeviceType TDeviceType, TensorDataType TPrecision>
std::pair< std::vector< std::string >, std::string > Mila::Dnn::Network< TDeviceType, TPrecision >::parseTensorName ( const std::string & tensor_name)
inline, private

Parse tensor name into component path and parameter name.

Examples:
  • "wte.weight" -> ([], "wte.weight")
  • "tf.layer_0.ln_1.weight" -> (["tf", "layer_0", "ln_1"], "weight")
  • "tf.layer_0.fc_qkv_proj.bias" -> (["tf", "layer_0", "fc_qkv_proj"], "bias")

◆ save()

template<DeviceType TDeviceType, TensorDataType TPrecision>
void Mila::Dnn::Network< TDeviceType, TPrecision >::save ( ModelArchive & archive,
SerializationMode mode ) const
inline

Save network to archive.

Saves component graph structure and delegates to concrete class via save_() hook for type-specific configuration.

Archive structure produced:

  • network/meta.json: Base metadata (format_version, name, num_components, mode, export_time)
  • network/architecture.json: Component topology (names, paths, ordering)
  • components/<name>/...: Child component state (recursive)
  • Concrete class writes additional files via save_() override
Parameters
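archive — Archive to write to
mode — Serialization mode (Checkpoint, WeightsOnly, Architecture)
Exceptions
std::runtime_error — if save_() is not overridden by the concrete class (save_() is pure virtual, so a missing override normally prevents the subclass from being instantiated at all)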
Here is the call graph for this function:

◆ save_()

template<DeviceType TDeviceType, TensorDataType TPrecision>
virtual void Mila::Dnn::Network< TDeviceType, TPrecision >::save_ ( ModelArchive & archive,
SerializationMode mode ) const
protected, pure virtual

Hook for concrete classes to save type-specific state.

REQUIRED override for concrete networks. Must write:

  • Type identifier (e.g., "type": "MnistClassifier")
  • Configuration parameters (batch_size, architecture constants)
  • Shape metadata (for validation during Load())

This metadata enables the concrete class's Load() method to reconstruct the network.

Example implementation:

void save_(ModelArchive& archive, SerializationMode mode) const override
{
json meta;
meta["type"] = "MnistClassifier"; // Type identifier for runtime dispatch
meta["batch_size"] = batch_size_;
meta["input_shape"] = leading_shape_;
// ... other configuration
archive.writeJson("network/classifier_meta.json", meta);
}
Symbols referenced in this example:
  • virtual void save_(ModelArchive &archive, SerializationMode mode) const =0 — Hook for concrete classes to save type-specific state.
  • SerializationMode — Modes for serialization and deserialization. (Definition: SerializationMode.ixx:17)
Parameters
archive — Archive to write to
mode — Serialization mode (passed from save())

Reimplemented from Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >.

Implemented in Mila::Dnn::GptTransformer< TDeviceType, TPrecision >, and Mila::Dnn::LlamaTransformer< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >.

Here is the caller graph for this function:

◆ saveComponentGraph()

template<DeviceType TDeviceType, TensorDataType TPrecision>
void Mila::Dnn::Network< TDeviceType, TPrecision >::saveComponentGraph ( ModelArchive & archive,
SerializationMode mode ) const
inline, private

Save component graph topology.

Writes the component manifest (list of child components) and recursively saves each component's state with scoped namespacing.

Archive structure:

  • network/architecture.json: Component manifest metadata
  • network/components_list.json: Array of component names (for ordering)
  • network/component_<name>.json: Individual component descriptor
  • components/<name>/...: Component state (via recursive save_)

Components are saved in deterministic (sorted by name) order for reproducible archives.

Parameters
archive — Archive to write to
mode — Serialization mode (passed to children)
Here is the call graph for this function:
Here is the caller graph for this function:

◆ saveNetworkMetadata()

template<DeviceType TDeviceType, TensorDataType TPrecision>
void Mila::Dnn::Network< TDeviceType, TPrecision >::saveNetworkMetadata ( ModelArchive & archive,
SerializationMode mode ) const
inline, private

Save base network metadata.

Writes generic metadata that applies to all networks:

  • format_version: Archive format version (for compatibility checking)
  • name: Network name
  • num_components: Component count (for validation)
  • mode: Serialization mode (Checkpoint/WeightsOnly/Architecture)
  • export_time: Unix timestamp of serialization
Parameters
archive — Archive to write to
mode — Serialization mode
Here is the call graph for this function:
Here is the caller graph for this function:

◆ synchronize()

template<DeviceType TDeviceType, TensorDataType TPrecision>
void Mila::Dnn::Network< TDeviceType, TPrecision >::synchronize ( )
inline, override, virtual

Synchronize all child components.

Waits for outstanding device operations on all children.

Reimplemented from Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >.

Here is the call graph for this function:

◆ toString()

template<DeviceType TDeviceType, TensorDataType TPrecision>
std::string Mila::Dnn::Network< TDeviceType, TPrecision >::toString ( ) const
inline, override, virtual

Generate a human-readable description.

Returns
String representation showing network name and children

Reimplemented from Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >.

Here is the call graph for this function:

◆ verifyArchitectureCompatibility()

template<DeviceType TDeviceType, TensorDataType TPrecision>
void Mila::Dnn::Network< TDeviceType, TPrecision >::verifyArchitectureCompatibility ( const PretrainedMetadata & metadata)
inline, protected

Verify that imported model is compatible with network architecture.


The documentation for this class was generated from the following file: Network.ixx