Mila 0.13.48
Deep Neural Network Library
Root composite network container.


Public Types | |
| using | ComponentPtr = typename CompositeBase::ComponentPtr |
| using | CompositeBase = CompositeComponent<TDeviceType, TPrecision> |
| using | MR = typename DeviceTypeTraits<TDeviceType>::memory_resource |
| Public Types inherited from Mila::Dnn::CompositeComponent< TDeviceType, TPrecision > | |
| using | ComponentBase = Component<TDeviceType, TPrecision> |
| using | ComponentPtr = std::shared_ptr<Component<TDeviceType, TPrecision>> |
Public Member Functions | |
| Network (const std::string &name) | |
| Construct network (context managed by derived class). | |
| ~Network () override=default | |
| template<typename TOptimizer, typename TConfig> | |
| std::shared_ptr< TOptimizer > | createOptimizer (const TConfig &config) |
| Create and configure an optimizer for this network's parameters. | |
| DeviceId | getDeviceId () const noexcept |
| Get the compute device for this composite. | |
| const ComponentType | getType () const override |
| Get the component type identifier. | |
| void | save (ModelArchive &archive, SerializationMode mode) const |
| Save network to archive. | |
| void | synchronize () override |
| Synchronize all child components. | |
| std::string | toString () const override |
| Generate a human-readable description. | |
| Public Member Functions inherited from Mila::Dnn::CompositeComponent< TDeviceType, TPrecision > | |
| CompositeComponent (CompositeComponent &&) noexcept=default | |
| CompositeComponent (const CompositeComponent &)=delete | |
| CompositeComponent (const std::string &name) | |
| Construct composite component with name. | |
| virtual | ~CompositeComponent ()=default |
| CompositeComponent & | addComponent (ComponentPtr component) |
| Add a pre-constructed child component (chainable). | |
| size_t | childCount () const noexcept |
| Get the number of direct children. | |
| void | clearComponents () |
| Clear all child components. | |
| ComponentPtr | findComponent (const std::string &path) const |
| Resolve a dot-separated component path within this composite. | |
| ComponentPtr | getComponent (const std::string &name) const |
| Retrieve a direct child component by name. | |
| const std::vector< ComponentPtr > & | getComponents () const |
| Get all child components in insertion order. | |
| std::vector< ITensor * > | getGradients () const override |
| Get all parameter gradients from all children. | |
| std::vector< ITensor * > | getParameters () const override |
| Get all parameters from all children. | |
| bool | hasChildren () const noexcept |
| Check if this composite has any children. | |
| bool | hasComponent (const std::string &name) const |
| Check if a named child component exists. | |
| CompositeComponent & | operator= (CompositeComponent &&) noexcept=default |
| CompositeComponent & | operator= (const CompositeComponent &)=delete |
| size_t | parameterCount () const override |
| Count parameters across all children. | |
| bool | removeComponent (const std::string &name) |
| Remove a direct child component by name. | |
| ComponentPtr | tryFindComponent (const std::string &path) const |
| Try to resolve a dot-separated component path within this composite. | |
| Public Member Functions inherited from Mila::Dnn::Component< TDeviceType, TPrecision > | |
| Component (const std::string &name) | |
| Construct component with required name identifier. | |
| virtual | ~Component ()=default |
| virtual void | build (const BuildContext &context) final |
| Build the component with the provided BuildContext (canonical overload). | |
| virtual MemoryStats | getMemoryStats () const =0 |
| Return the current memory allocation breakdown for this component. | |
| const std::string | getName () const |
| Get the component's name identifier. | |
| virtual std::vector< std::string > | getParameterNames () const |
| List all available parameter names for this component. | |
| RuntimeMode | getRuntimeMode () const noexcept |
| Get the current runtime mode of this Component. | |
| TrainingMode | getTrainingMode () const noexcept |
| The current runtime behavioral mode of this Component. | |
| virtual bool | isBuilt () const final |
| Returns true if build() has completed successfully. | |
| bool | isInferenceMode () const noexcept |
| bool | isTrainingMode () const noexcept |
| virtual void | loadParameter (const std::string &name, const Serialization::ITensorBlob &blob) |
| Load a parameter from serialized tensor data. | |
| void | setTrainingMode (TrainingMode mode) |
| Set the runtime behavioral mode for this Component. | |
| virtual void | zeroGradients () |
| Clear all model-owned gradients for this component. | |
Protected Member Functions | |
| virtual void | save_ (ModelArchive &archive, SerializationMode mode) const =0 |
| Hook for concrete classes to save type-specific state. | |
| void | verifyArchitectureCompatibility (const PretrainedMetadata &metadata) |
| Verify that imported model is compatible with network architecture. | |
| Protected Member Functions inherited from Mila::Dnn::CompositeComponent< TDeviceType, TPrecision > | |
| template<typename TComponent> | |
| std::shared_ptr< TComponent > | getComponentAs (const std::string &name) const |
| Retrieve a typed child component by name. | |
| void | onExecutionContextSet () override |
| Hook invoked after ExecutionContext is set. | |
| void | onTrainingModeChanging (TrainingMode training_mode) override |
| Hook invoked when training mode is about to change. | |
| virtual void | optimize () |
| Virtual hook for graph optimization after construction. | |
| Protected Member Functions inherited from Mila::Dnn::Component< TDeviceType, TPrecision > | |
| IExecutionContext * | getExecutionContext () const |
| Get the shared execution context. | |
| bool | hasExecutionContext () const noexcept |
| Check if execution context has been set. | |
| template<TensorDataType TParameterPrecision, typename TMemoryResource> | |
| void | loadParameterFromBlob (const std::string &param_name, const Serialization::ITensorBlob &blob, Tensor< TParameterPrecision, TMemoryResource > &target, const shape_t &expected_shape) |
| Load a tensor blob into a parameter tensor with validation. | |
| virtual void | onBuilding (const BuildContext &config) |
| Hook invoked by build() to allocate component buffers. | |
| void | setExecutionContext (IExecutionContext *context) |
| Set the execution context for this component. | |
Private Member Functions | |
| size_t | parseLayerIndex (const std::string &name) |
| Parse layer index from tensor name. | |
| std::pair< std::vector< std::string >, std::string > | parseTensorName (const std::string &tensor_name) |
| Parse tensor name into component path and parameter name. | |
| void | saveComponentGraph (ModelArchive &archive, SerializationMode mode) const |
| Save component graph topology. | |
| void | saveNetworkMetadata (ModelArchive &archive, SerializationMode mode) const |
| Save base network metadata. | |
Additional Inherited Members | |
| Static Public Member Functions inherited from Mila::Dnn::Component< TDeviceType, TPrecision > | |
| static constexpr DeviceType | getDeviceType () |
| Compile-time device type for this component instance. | |
| static constexpr TensorDataType | getPrecision () noexcept |
| Compile-time tensor precision for this component instance. | |
| Protected Attributes inherited from Mila::Dnn::Component< TDeviceType, TPrecision > | |
| BuildContext | build_context_ { shape_t{ 1 }, RuntimeMode::Training } |
| The BuildContext stored at build time. | |
Root composite network container.
Network is a specialized CompositeComponent that represents a complete neural network model and serves as the top-level entry point. It provides high-level serialization semantics while delegating lifecycle management to concrete subclasses.
Ownership Model:
Construction Pattern (Concrete Subclass):
Serialization Contract:
Deserialization Pattern (Concrete Subclass):
Design Rationale:
| using Mila::Dnn::Network< TDeviceType, TPrecision >::ComponentPtr = typename CompositeBase::ComponentPtr |
| using Mila::Dnn::Network< TDeviceType, TPrecision >::CompositeBase = CompositeComponent<TDeviceType, TPrecision> |
| using Mila::Dnn::Network< TDeviceType, TPrecision >::MR = typename DeviceTypeTraits<TDeviceType>::memory_resource |
|
inline explicit |
Construct network (context managed by derived class).
Base constructor for concrete network classes. Derived classes are responsible for:
| name | Network name for identification and serialization |
| std::invalid_argument | if name is not a valid identifier |
|
override default |
|
inline |
Create and configure an optimizer for this network's parameters.
Factory method that creates an optimizer, enables training mode on the network, and registers all network parameters and gradients in a single atomic operation.
Lifecycle:
Usage Pattern:
| TOptimizer | Optimizer type (e.g., AdamWOptimizer, SGD) |
| TConfig | Optimizer configuration type |
| config | Optimizer configuration |
| std::runtime_error | if network is not built |
| std::runtime_error | if parameter/gradient count mismatch |
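
The factory behaviour described above can be illustrated with a minimal sketch. Every type here (`TensorSketch`, `SgdConfigSketch`, `SgdSketch`, `OptimizableNetworkSketch`) and the `addParameter` registration call are hypothetical stand-ins, not Mila classes; only the order of checks and the two `std::runtime_error` conditions follow the description.

```cpp
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <vector>

// Hypothetical stand-ins; none of these are Mila types.
struct TensorSketch { float value = 0.0f; };
struct SgdConfigSketch { float learning_rate = 0.01f; };

class SgdSketch {
public:
    explicit SgdSketch(const SgdConfigSketch& config) : config_(config) {}

    // Assumed registration call: pairs one parameter with its gradient.
    void addParameter(TensorSketch* param, TensorSketch* grad) {
        slots_.push_back({param, grad});
    }

    std::size_t slotCount() const noexcept { return slots_.size(); }
    float learningRate() const noexcept { return config_.learning_rate; }

private:
    struct Slot { TensorSketch* param; TensorSketch* grad; };
    SgdConfigSketch config_;
    std::vector<Slot> slots_;
};

class OptimizableNetworkSketch {
public:
    std::vector<TensorSketch> parameters;
    std::vector<TensorSketch> gradients;
    bool built = false;
    bool training = false;

    template <typename TOptimizer, typename TConfig>
    std::shared_ptr<TOptimizer> createOptimizer(const TConfig& config) {
        if (!built) {
            throw std::runtime_error("network is not built");
        }
        training = true;  // enable training mode before collecting gradients
        if (parameters.size() != gradients.size()) {
            throw std::runtime_error("parameter/gradient count mismatch");
        }
        auto optimizer = std::make_shared<TOptimizer>(config);
        for (std::size_t i = 0; i < parameters.size(); ++i) {
            optimizer->addParameter(&parameters[i], &gradients[i]);
        }
        return optimizer;
    }
};
```

With these stand-ins, a call would look like `net.createOptimizer<SgdSketch>(SgdConfigSketch{})`; the configuration type is deduced from the argument, so only the optimizer type needs to be spelled out.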

|
inline virtual noexcept |
Get the compute device for this composite.
Returns the device from the shared execution context.
Reimplemented from Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >.


|
inline override virtual |
Get the component type identifier.
Used for serialization and runtime type identification.
Implements Mila::Dnn::Component< TDeviceType, TPrecision >.
|
inline private |
Parse layer index from tensor name.
Extracts the layer number from names such as "layers.5.attention.q_proj.weight" (here, 5).
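
One way such a lookup might work is sketched below. `parseLayerIndexSketch` is a hypothetical free function, not the private member itself, and both the `"layers."` marker and the choice to throw `std::invalid_argument` on malformed names are assumptions made for illustration.

```cpp
#include <cstddef>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for the private parseLayerIndex member; the real
// parsing rules and error handling are not shown in this reference.
inline std::size_t parseLayerIndexSketch(const std::string& name) {
    const std::string marker = "layers.";
    const std::size_t start = name.find(marker);
    if (start == std::string::npos) {
        throw std::invalid_argument("no 'layers.' segment in: " + name);
    }

    // The index is the run of digits between "layers." and the next dot.
    const std::size_t first = start + marker.size();
    const std::size_t last = name.find('.', first);
    const std::string digits = name.substr(
        first, last == std::string::npos ? std::string::npos : last - first);

    if (digits.empty() ||
        digits.find_first_not_of("0123456789") != std::string::npos) {
        throw std::invalid_argument("layer index is not a number in: " + name);
    }
    return static_cast<std::size_t>(std::stoull(digits));
}
```

Validating the digit run before calling `std::stoull` keeps partially numeric segments such as `5a` from being silently accepted.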
|
inline private |
Parse tensor name into component path and parameter name.
Examples:
"wte.weight" -> ([], "wte.weight")
"tf.layer_0.ln_1.weight" -> (["tf", "layer_0", "ln_1"], "weight")
"tf.layer_0.fc_qkv_proj.bias" -> (["tf", "layer_0", "fc_qkv_proj"], "bias")
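
The mapping in these examples can be reproduced with the sketch below. `parseTensorNameSketch` is a hypothetical free function written only to match the three documented cases. In particular, the rule that a name with at most two segments is returned whole with an empty path is inferred from the "wte.weight" example and may not reflect the actual implementation.

```cpp
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

using ParsedTensorNameSketch =
    std::pair<std::vector<std::string>, std::string>;

// Hypothetical stand-in for the private parseTensorName member.
inline ParsedTensorNameSketch parseTensorNameSketch(
    const std::string& tensor_name) {
    // Split the name on dots.
    std::vector<std::string> segments;
    std::size_t begin = 0;
    while (true) {
        const std::size_t dot = tensor_name.find('.', begin);
        if (dot == std::string::npos) {
            segments.push_back(tensor_name.substr(begin));
            break;
        }
        segments.push_back(tensor_name.substr(begin, dot - begin));
        begin = dot + 1;
    }

    // Assumption drawn from the "wte.weight" example: a name with at most
    // two segments has an empty component path and is returned whole.
    if (segments.size() <= 2) {
        return { {}, tensor_name };
    }

    // Otherwise the last segment is the parameter name and the rest is
    // the component path.
    std::string parameter = segments.back();
    segments.pop_back();
    return { std::move(segments), std::move(parameter) };
}
```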
|
inline |
Save network to archive.
Saves component graph structure and delegates to concrete class via save_() hook for type-specific configuration.
Archive structure produced:
| archive | Archive to write to |
| mode | Serialization mode (Checkpoint, WeightsOnly, Architecture) |
| std::runtime_error | if save_() is not overridden by concrete class |

|
protected pure virtual |
Hook for concrete classes to save type-specific state.
REQUIRED override for concrete networks. Must write:
This metadata enables the concrete class's Load() method to reconstruct the network.
Example implementation:
| archive | Archive to write to |
| mode | Serialization mode (passed from save()) |
Reimplemented from Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >.
Implemented in Mila::Dnn::GptTransformer< TDeviceType, TPrecision >, and Mila::Dnn::LlamaTransformer< TDeviceType, TPrecision, TWeightQuantization, TKvCachePolicy >.
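
The relationship between save() and the save_() hook follows the template-method pattern, which can be sketched as below. `ArchiveSketch`, `ModeSketch`, `SavableNetworkSketch`, `TinyGptSketch`, and the archive keys are all hypothetical; the real ModelArchive interface and the fields a concrete network must write are not shown in this reference.

```cpp
#include <map>
#include <string>
#include <utility>

// Hypothetical stand-ins: the archive is modelled as a string map and the
// serialization mode as a plain enum.
using ArchiveSketch = std::map<std::string, std::string>;
enum class ModeSketch { Checkpoint, WeightsOnly, Architecture };

class SavableNetworkSketch {
public:
    explicit SavableNetworkSketch(std::string name) : name_(std::move(name)) {}
    virtual ~SavableNetworkSketch() = default;

    // Non-virtual entry point: writes shared data, then calls the hook.
    void save(ArchiveSketch& archive, ModeSketch mode) const {
        archive["network/name"] = name_;
        save_(archive, mode);
    }

protected:
    // Required override: each concrete network writes its own metadata.
    virtual void save_(ArchiveSketch& archive, ModeSketch mode) const = 0;

private:
    std::string name_;
};

class TinyGptSketch final : public SavableNetworkSketch {
public:
    TinyGptSketch(std::string name, int num_layers)
        : SavableNetworkSketch(std::move(name)), num_layers_(num_layers) {}

protected:
    void save_(ArchiveSketch& archive, ModeSketch /*mode*/) const override {
        // Illustrative keys only; a loader would read these back to
        // reconstruct the network.
        archive["network/type"] = "TinyGptSketch";
        archive["network/num_layers"] = std::to_string(num_layers_);
    }

private:
    int num_layers_;
};
```

Keeping save() non-virtual means every concrete network gets the shared portion of the archive for free and only supplies what is specific to its type.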

|
inline private |
Save component graph topology.
Writes the component manifest (list of child components) and recursively saves each component's state with scoped namespacing.
Archive structure:
Components are saved in deterministic (sorted by name) order for reproducible archives.
| archive | Archive to write to |
| mode | Serialization mode (passed to children) |
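
A minimal sketch of deterministic, name-sorted saving with scoped keys follows. `ArchiveSketch`, `ComponentSketch`, `saveComponentGraphSketch`, and the `components/<name>/...` and `network/manifest` key layout are assumptions for illustration; the actual archive structure is not reproduced in this reference.

```cpp
#include <algorithm>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-ins for the archive and for a child component.
using ArchiveSketch = std::map<std::string, std::string>;

struct ComponentSketch {
    std::string name;
    std::string state;
};

inline void saveComponentGraphSketch(ArchiveSketch& archive,
                                     std::vector<ComponentSketch> children) {
    // Sort a copy by name so the output does not depend on insertion order.
    std::sort(children.begin(), children.end(),
              [](const ComponentSketch& a, const ComponentSketch& b) {
                  return a.name < b.name;
              });

    std::string manifest;
    for (const ComponentSketch& child : children) {
        if (!manifest.empty()) {
            manifest += ',';
        }
        manifest += child.name;
        // Each child writes under its own scope to avoid key collisions.
        archive["components/" + child.name + "/state"] = child.state;
    }
    archive["network/manifest"] = manifest;
}
```

Taking the children by value lets the function sort freely without disturbing the caller's insertion order.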


|
inline private |
Save base network metadata.
Writes generic metadata that applies to all networks:
| archive | Archive to write to |
| mode | Serialization mode |


|
inline override virtual |
Synchronize all child components.
Waits for outstanding device operations on all children.
Reimplemented from Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >.

|
inline override virtual |
Generate a human-readable description.
Reimplemented from Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >.

|
inline protected |
Verify that imported model is compatible with network architecture.