|
Mila 0.13.48
Deep Neural Network Library
|
Multi-Layer Perceptron (MLP) composite component.


Public Types | |
| using | ComponentPtr = typename CompositeComponentBase::ComponentPtr |
| using | CompositeComponentBase = CompositeComponent<TDeviceType, TPrecision> |
| using | GeluType = Gelu<TDeviceType, TPrecision> |
| using | LayerNormType = LayerNorm<TDeviceType, TPrecision> |
| using | LinearType = Linear<TDeviceType, TPrecision> |
| using | MR = typename DeviceTypeTraits<TDeviceType>::memory_resource |
| using | SwigluType = Swiglu<TDeviceType, TPrecision> |
| using | TensorType = Tensor<TPrecision, MR> |
| Public Types inherited from Mila::Dnn::CompositeComponent< TDeviceType, TPrecision > | |
| using | ComponentBase = Component<TDeviceType, TPrecision> |
| using | ComponentPtr = std::shared_ptr<Component<TDeviceType, TPrecision>> |
Public Member Functions | |
| MLP (const std::string &name, const MLPConfig &config, std::optional< DeviceId > device_id=std::nullopt) | |
| Construct an MLP component. | |
| ~MLP () override=default | |
| Default destructor. | |
| TensorType & | backward (const TensorType &input, const TensorType &output_grad) |
| Backward pass using captured forward intermediates. | |
| TensorType & | decode (const TensorType &input) const |
| TensorType & | forward (const TensorType &input) |
| Forward pass. | |
| MemoryStats | getMemoryStats () const override |
| Return the current memory allocation breakdown for this component. | |
| const ComponentType | getType () const override |
| Get the component type identifier. | |
| void | save_ (ModelArchive &archive, SerializationMode mode) const override |
| Serialize parameters to a model archive. | |
| std::string | toString () const override |
| Human-readable status and configuration summary. | |
| void | zeroGradients () override |
| Zero gradients for all child components. | |
| Public Member Functions inherited from Mila::Dnn::CompositeComponent< TDeviceType, TPrecision > | |
| CompositeComponent (CompositeComponent &&) noexcept=default | |
| CompositeComponent (const CompositeComponent &)=delete | |
| CompositeComponent (const std::string &name) | |
| Construct composite component with name. | |
| virtual | ~CompositeComponent ()=default |
| CompositeComponent & | addComponent (ComponentPtr component) |
| Add a pre-constructed child component (chainable). | |
| size_t | childCount () const noexcept |
| Get the number of direct children. | |
| void | clearComponents () |
| Clear all child components. | |
| ComponentPtr | findComponent (const std::string &path) const |
| Resolve a dot-separated component path within this composite. | |
| ComponentPtr | getComponent (const std::string &name) const |
| Retrieve a direct child component by name. | |
| const std::vector< ComponentPtr > & | getComponents () const |
| Get all child components in insertion order. | |
| DeviceId | getDeviceId () const override |
| Get the compute device for this composite. | |
| std::vector< ITensor * > | getGradients () const override |
| Get all parameter gradients from all children. | |
| std::vector< ITensor * > | getParameters () const override |
| Get all parameters from all children. | |
| bool | hasChildren () const noexcept |
| Check if this composite has any children. | |
| bool | hasComponent (const std::string &name) const |
| Check if a named child component exists. | |
| CompositeComponent & | operator= (CompositeComponent &&) noexcept=default |
| CompositeComponent & | operator= (const CompositeComponent &)=delete |
| size_t | parameterCount () const override |
| Count parameters across all children. | |
| bool | removeComponent (const std::string &name) |
| Remove a direct child component by name. | |
| void | synchronize () override |
| Synchronize all child components. | |
| ComponentPtr | tryFindComponent (const std::string &path) const |
| Try to resolve a dot-separated component path within this composite. | |
| Public Member Functions inherited from Mila::Dnn::Component< TDeviceType, TPrecision > | |
| Component (const std::string &name) | |
| Construct component with required name identifier. | |
| virtual | ~Component ()=default |
| virtual void | build (const BuildContext &context) final |
| Build the component with the provided BuildContext (canonical overload). | |
| const std::string | getName () const |
| Get the component's name identifier. | |
| virtual std::vector< std::string > | getParameterNames () const |
| List all available parameter names for this component. | |
| RuntimeMode | getRuntimeMode () const noexcept |
| Get the current runtime mode of this Component. | |
| TrainingMode | getTrainingMode () const noexcept |
| The current runtime behavioral mode of this Component. | |
| virtual bool | isBuilt () const final |
| Returns true if build() has completed successfully. | |
| bool | isInferenceMode () const noexcept |
| bool | isTrainingMode () const noexcept |
| virtual void | loadParameter (const std::string &name, const Serialization::ITensorBlob &blob) |
| Load a parameter from serialized tensor data. | |
| void | setTrainingMode (TrainingMode mode) |
| Set the runtime behavioral mode for this Component. | |
Protected Member Functions | |
| void | onBuilding (const BuildContext &context) override |
| Build-time callback invoked by the CompositeComponent framework. | |
| void | onTrainingModeChanging (TrainingMode training_mode) override |
| Called when the training/inference mode changes. | |
| Protected Member Functions inherited from Mila::Dnn::CompositeComponent< TDeviceType, TPrecision > | |
| template<typename TComponent> | |
| std::shared_ptr< TComponent > | getComponentAs (const std::string &name) const |
| Retrieve a typed child component by name. | |
| void | onExecutionContextSet () override |
| Hook invoked after ExecutionContext is set. | |
| virtual void | optimize () |
| Virtual hook for graph optimization after construction. | |
| Protected Member Functions inherited from Mila::Dnn::Component< TDeviceType, TPrecision > | |
| IExecutionContext * | getExecutionContext () const |
| Get the shared execution context. | |
| bool | hasExecutionContext () const noexcept |
| Check if execution context has been set. | |
| template<TensorDataType TParameterPrecision, typename TMemoryResource> | |
| void | loadParameterFromBlob (const std::string &param_name, const Serialization::ITensorBlob &blob, Tensor< TParameterPrecision, TMemoryResource > &target, const shape_t &expected_shape) |
| Load a tensor blob into a parameter tensor with validation. | |
| void | setExecutionContext (IExecutionContext *context) |
| Set the execution context for this component. | |
Private Types | |
| using | ActivationBackwardFn = std::function<TensorType&(const TensorType&, const TensorType&)> |
| using | ActivationBase = Component<TDeviceType, TPrecision> |
| using | ActivationForwardFn = std::function<TensorType& (const TensorType&)> |
Private Member Functions | |
| std::string | activationChildName () const |
| Returns the child component name suffix for the configured activation type. | |
| void | addActivation (const std::string &suffix) |
| Create and register the activation child component for the configured type. | |
| void | addLinear (const std::string &suffix, dim_t in_features, dim_t out_features) |
| Helper to create and add a Linear child component. | |
| void | clearForwardCache () noexcept |
| Clear cached non-owning forward pointers. | |
| void | createGraph () |
| Build the internal component graph according to config_. | |
| void | validateInputShape (const shape_t &input_shape) const |
| Validate input shape against the MLP configuration. | |
Private Attributes | |
| std::shared_ptr< ActivationBase > | activation_ { nullptr } |
| ActivationBackwardFn | activation_backward_ |
| ActivationForwardFn | activation_forward_ |
| shape_t | cached_hidden_shape_ |
| shape_t | cached_input_shape_ |
| MLPConfig | config_ |
| std::shared_ptr< LinearType > | fc1_ { nullptr } |
| std::shared_ptr< LinearType > | fc2_ { nullptr } |
| TensorType * | last_act_out_ { nullptr } |
| TensorType * | last_fc1_out_ { nullptr } |
| TensorType * | last_final_out_ { nullptr } |
| std::unique_ptr< IExecutionContext > | owned_exec_context_ { nullptr } |
Additional Inherited Members | |
| Static Public Member Functions inherited from Mila::Dnn::Component< TDeviceType, TPrecision > | |
| static constexpr DeviceType | getDeviceType () |
| Compile-time device type for this component instance. | |
| static constexpr TensorDataType | getPrecision () noexcept |
| Compile-time tensor precision for this component instance. | |
| Protected Attributes inherited from Mila::Dnn::Component< TDeviceType, TPrecision > | |
| BuildContext | build_context_ { shape_t{ 1 }, RuntimeMode::Training } |
| The BuildContext stored at build time. | |
Multi-Layer Perceptron (MLP) composite component.
Device-templated composite component that implements a standard MLP structure:
Input -> Linear(in_features, hidden_size) -> [LayerNorm] -> Activation -> Linear(hidden_size, in_features) -> Output
When the configured activation is SwiGLU, fc1 projects to 2*hidden_size; the Swiglu component splits that into two halves and computes x1 * GELU(x2), producing hidden_size output fed into fc2. For all other activations fc1 projects to hidden_size directly.
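The split-and-gate step described above can be sketched on a single row of fc1 output. This is a minimal illustration, not Mila's Swiglu implementation: `gated_activation` and `gelu_tanh` are hypothetical names, the row is a plain `std::vector<float>`, the first half is assumed to be x1 and the second half x2, and the tanh approximation of GELU is an assumption because the page does not say which GELU variant is used.

```cpp
#include <cmath>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Tanh approximation of GELU (assumed variant, for illustration only).
inline float gelu_tanh(float x) {
    const float k = 0.7978845608f;  // sqrt(2 / pi)
    return 0.5f * x * (1.0f + std::tanh(k * (x + 0.044715f * x * x * x)));
}

// Hypothetical helper: takes one fc1 output row of width 2*hidden_size and
// returns a row of width hidden_size computed as x1 * GELU(x2).
// Assumes x1 is the first half of the row and x2 is the second half.
inline std::vector<float> gated_activation(const std::vector<float>& fc1_row) {
    if (fc1_row.size() % 2 != 0) {
        throw std::invalid_argument("row width must be 2*hidden_size");
    }
    const std::size_t hidden = fc1_row.size() / 2;
    std::vector<float> out(hidden);
    for (std::size_t i = 0; i < hidden; ++i) {
        out[i] = fc1_row[i] * gelu_tanh(fc1_row[hidden + i]);
    }
    return out;
}
```

The output row is half the width of the input row, which is why fc2 still receives hidden_size features when SwiGLU is configured.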
The component composes child components (Linear, optional LayerNorm, Activation) and delegates forward/backward calls to them. Child components own intermediate tensors; MLP stores non-owning pointers to those tensors after forward() to chain backward().
Threading: call sites must ensure that forward/backward/zeroGradients are invoked in a thread-safe manner relative to one another; this class does not provide internal synchronization.
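| MLP (const std::string &name, const MLPConfig &config, std::optional< DeviceId > device_id=std::nullopt) | inline explicit export |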
Construct an MLP component.
The constructor validates the provided config, constructs the internal child component graph, and optionally creates and assigns an execution context when device_id is provided.
| name | Component name used to name child subcomponents. |
| config | MLP configuration (input features, hidden size, activation, bias, layer-norm flag). |
| device_id | Optional device identifier; when present the MLP creates an owned execution context bound to that device and sets it on the component. If the provided device_id type does not match the template TDeviceType, an exception is thrown. |
| std::invalid_argument | if device_id is present but has a mismatched device type. |

| ~MLP () | override export default |
Default destructor.
Child components are stored as shared_ptr and will be destroyed automatically.
| std::string | activationChildName () const | inline export private |
Returns the child component name suffix for the configured activation type.

| void | addActivation (const std::string &suffix) | inline export private |
Create and register the activation child component for the configured type.
Uses mlp_activation_impl to construct the concrete activation, bind type-erased forward/backward lambdas, and register the component with the composite framework. Supports ActivationType::Gelu and ActivationType::Swiglu.
| suffix | Suffix appended to parent name for the activation child component. |
| std::invalid_argument | if the activation type in config_ is unsupported. |
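The type-erasure idea mentioned above (binding a concrete activation behind `std::function` callables) can be sketched as follows. This is an assumption-laden illustration and does not reproduce mlp_activation_impl: `ToyRelu`, `ToySquare`, `ToyActivation`, `BoundActivation`, `bind_activation`, and `make_activation` are all hypothetical stand-ins, `Vec` stands in for TensorType, and only the forward callable is shown.

```cpp
#include <cstddef>
#include <functional>
#include <memory>
#include <stdexcept>
#include <vector>

using Vec = std::vector<float>;  // stand-in for TensorType
using ForwardFn = std::function<Vec&(const Vec&)>;

// Toy activations; each one owns its output buffer, like a child component.
struct ToyRelu {
    Vec out;
    Vec& forward(const Vec& x) {
        out.resize(x.size());
        for (std::size_t i = 0; i < x.size(); ++i) out[i] = x[i] > 0.0f ? x[i] : 0.0f;
        return out;
    }
};

struct ToySquare {
    Vec out;
    Vec& forward(const Vec& x) {
        out.resize(x.size());
        for (std::size_t i = 0; i < x.size(); ++i) out[i] = x[i] * x[i];
        return out;
    }
};

enum class ToyActivation { Relu, Square, Unsupported };

struct BoundActivation {
    std::shared_ptr<void> owner;  // keeps the concrete activation alive
    ForwardFn forward;            // type-erased call into the concrete type
};

// Construct the concrete type T and hide it behind a std::function.
template <typename T>
BoundActivation bind_activation() {
    auto act = std::make_shared<T>();
    return { act, [act](const Vec& x) -> Vec& { return act->forward(x); } };
}

inline BoundActivation make_activation(ToyActivation kind) {
    switch (kind) {
        case ToyActivation::Relu:   return bind_activation<ToyRelu>();
        case ToyActivation::Square: return bind_activation<ToySquare>();
        default: throw std::invalid_argument("unsupported activation type");
    }
}
```

After binding, the caller only needs the `forward` callable and never names the concrete activation type again, which is what lets one member such as activation_forward_ serve every supported activation.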


| void | addLinear (const std::string &suffix, dim_t in_features, dim_t out_features) | inline export private |
Helper to create and add a Linear child component.
The created Linear component uses the parent's name plus the provided suffix.
| suffix | Suffix appended to parent name for the child component. |
| in_features | Number of input features for the linear layer. |
| out_features | Number of output features for the linear layer. |


| TensorType & | backward (const TensorType &input, const TensorType &output_grad) | inline export |
Backward pass using captured forward intermediates.
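Uses the child-owned tensors captured by the most recent forward() invocation to chain the child backward calls in reverse order, without recomputing the forward pass.
The method clears the cached forward pointers before returning to avoid accidental reuse.
Preconditions: the component must be built, and forward() must have been called since the last backward().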
| input | The original input tensor passed to forward(); required by fc1_->backward. |
| output_grad | Gradient tensor w.r.t. the MLP output. |
| std::runtime_error | if the component is not built or if forward() was not called. |
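The cache-then-clear pattern described for backward() can be sketched with two toy layers. This is a minimal sketch under stated assumptions: `ToyScale` and `ToyChain` are hypothetical, `Vec` stands in for TensorType, each toy layer just multiplies by a constant, and the real MLP also chains through the activation and checks build state.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

using Vec = std::vector<float>;  // stand-in for TensorType

// Toy child: y = w * x. It owns its output and input-gradient buffers.
struct ToyScale {
    explicit ToyScale(float weight) : w(weight) {}
    Vec& forward(const Vec& x) {
        out.resize(x.size());
        for (std::size_t i = 0; i < x.size(); ++i) out[i] = w * x[i];
        return out;
    }
    Vec& backward(const Vec& /*x*/, const Vec& out_grad) {
        in_grad.resize(out_grad.size());
        for (std::size_t i = 0; i < out_grad.size(); ++i) in_grad[i] = w * out_grad[i];
        return in_grad;
    }
    float w;
    Vec out, in_grad;
};

class ToyChain {
public:
    Vec& forward(const Vec& x) {
        // Store non-owning pointers to the child-owned intermediates.
        last_a_out_ = &a_.forward(x);
        last_b_out_ = &b_.forward(*last_a_out_);
        return *last_b_out_;
    }
    Vec& backward(const Vec& x, const Vec& out_grad) {
        if (!last_a_out_ || !last_b_out_) {
            throw std::runtime_error("forward() must be called before backward()");
        }
        // Walk the children in reverse order, reusing the cached forward outputs.
        Vec& grad_a_out = b_.backward(*last_a_out_, out_grad);
        Vec& grad_input = a_.backward(x, grad_a_out);
        clearForwardCache();
        return grad_input;
    }
private:
    void clearForwardCache() noexcept { last_a_out_ = nullptr; last_b_out_ = nullptr; }
    ToyScale a_{2.0f};
    ToyScale b_{3.0f};
    Vec* last_a_out_ = nullptr;
    Vec* last_b_out_ = nullptr;
};
```

Because the cache is cleared on the way out, a second backward() without a fresh forward() fails loudly instead of silently reusing stale tensors.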

| void | clearForwardCache () noexcept | inline export private noexcept |
Clear cached non-owning forward pointers.
Safe to call at any time; used to avoid accidental reuse of child-owned tensors across forward/backward cycles. No side effects beyond pointer reset.

| void | createGraph () | inline export private |
Build the internal component graph according to config_.
For SwiGLU, fc1 projects to 2*hidden_size so that the Swiglu activation can split the output into gate and up halves. For all other activations fc1 projects to hidden_size. Called from the constructor; does not perform shape-dependent build calls.


| TensorType & | decode (const TensorType &input) const | inline export |

| TensorType & | forward (const TensorType &input) | inline export |
Forward pass.
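Chains the child component forward calls in the order given in the class description: fc1, optional LayerNorm, activation, fc2.
The function stores non-owning pointers to child-owned intermediate tensors produced during the forward call; these pointers are used by backward() to chain gradients.
Preconditions: the component must be built before forward() is called.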
| input | Input tensor bound to this component's device/context. |
| std::runtime_error | if the component is not built prior to calling forward. |

| MemoryStats | getMemoryStats () const override | inline override export virtual |
Return the current memory allocation breakdown for this component.
Reflects allocations at the moment of the call. The returned stats naturally track the component lifecycle:
After construction: parameters only
After build( Inference ): parameters + T=1 state buffers
After build( Training ): parameters + T=full state buffers
After setEvaluation( false ): parameters + state + gradients
For CompositeComponent and Network, the returned stats are the recursive aggregate of all child components.
May be called at any time — no lifecycle preconditions.
Implements Mila::Dnn::Component< TDeviceType, TPrecision >.
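The recursive aggregation described above can be sketched with a toy component tree. The field names of the real MemoryStats type are not shown on this page, so `ToyMemoryStats`, its three byte counters, and `ToyNode` are hypothetical stand-ins chosen to mirror the parameters/state/gradients breakdown in the lifecycle list.

```cpp
#include <cstddef>
#include <memory>
#include <vector>

// Hypothetical stats record; the real MemoryStats layout is an assumption.
struct ToyMemoryStats {
    std::size_t parameter_bytes = 0;
    std::size_t state_bytes = 0;
    std::size_t gradient_bytes = 0;

    ToyMemoryStats& operator+=(const ToyMemoryStats& other) {
        parameter_bytes += other.parameter_bytes;
        state_bytes += other.state_bytes;
        gradient_bytes += other.gradient_bytes;
        return *this;
    }
    std::size_t total() const { return parameter_bytes + state_bytes + gradient_bytes; }
};

// Toy component: its stats are its own allocations plus those of all children.
struct ToyNode {
    ToyMemoryStats own;
    std::vector<std::shared_ptr<ToyNode>> children;

    ToyMemoryStats memoryStats() const {
        ToyMemoryStats sum = own;
        for (const auto& child : children) {
            sum += child->memoryStats();  // recurse into nested composites
        }
        return sum;
    }
};
```

Since the sum is recomputed on every call, the result reflects whatever the children have allocated at that moment, consistent with the absence of lifecycle preconditions.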

| const ComponentType | getType () const override | inline override export virtual |
Get the component type identifier.
Used for serialization and runtime type identification.
Implements Mila::Dnn::Component< TDeviceType, TPrecision >.
| void | onBuilding (const BuildContext &context) override | inline override export protected virtual |
Build-time callback invoked by the CompositeComponent framework.
Validates the provided input_shape, computes the hidden shape, and builds each child component with the appropriate shape. For SwiGLU, the activation is built with 2*hidden_size along the feature axis; fc2 always receives hidden_size. After building, any cached forward pointers are cleared.
| input_shape | Shape of the input tensor. The last dimension must equal config_.getInputFeatures(). |
| std::invalid_argument | if input_shape rank < 1 or last dimension mismatches config. |
Reimplemented from Mila::Dnn::Component< TDeviceType, TPrecision >.
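The shape bookkeeping described above can be sketched as a small planning function. This is an illustration only: `Shape`, `MlpShapes`, and `plan_shapes` are hypothetical names, shapes are plain vectors of 64-bit integers rather than shape_t, and the sketch assumes only the last (feature) dimension changes between children.

```cpp
#include <cstdint>
#include <stdexcept>
#include <vector>

using Shape = std::vector<std::int64_t>;  // stand-in for shape_t

// Hypothetical record of the shape each stage produces.
struct MlpShapes {
    Shape fc1_out;  // also the activation's input shape
    Shape act_out;  // also fc2's input shape
    Shape fc2_out;  // same as the MLP input shape
};

inline MlpShapes plan_shapes(const Shape& input, std::int64_t hidden_size, bool swiglu) {
    if (input.empty()) {
        throw std::invalid_argument("input rank must be >= 1");
    }
    MlpShapes shapes{input, input, input};
    // fc1 widens to 2*hidden_size for SwiGLU so the activation can split it.
    shapes.fc1_out.back() = swiglu ? 2 * hidden_size : hidden_size;
    // The activation always hands hidden_size features to fc2.
    shapes.act_out.back() = hidden_size;
    // fc2 projects back to in_features, so fc2_out keeps the input's last dimension.
    return shapes;
}
```

For an input of shape {4, 128, 768} with hidden_size 3072, the sketch yields a last dimension of 6144 after fc1 under SwiGLU and 3072 otherwise, with 3072 entering fc2 in both cases.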

| void | onTrainingModeChanging (TrainingMode training_mode) override | inline override export protected virtual |
Called when the training/inference mode changes.
Propagates the training mode to child components so they can adjust behavior (for example dropout or batch statistics) as needed.
| training_mode | The mode being switched to (training or evaluation). |
Reimplemented from Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >.
| void | save_ (ModelArchive &archive, SerializationMode mode) const override | inline override export virtual |
Serialize parameters to a model archive.
Saves child component parameters and state into the provided archive in a deterministic order.
| archive | Serialization archive to write to. |
| mode | Serialization mode (enum driven by Serialization::Mode). |
Reimplemented from Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >.
| std::string | toString () const override | inline override export virtual |
Human-readable status and configuration summary.
Produces a multi-line string describing the component name, shapes, parameter counts, activation and layer-norm usage, device assignment (if set), and child component names.
Reimplemented from Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >.

| void | validateInputShape (const shape_t &input_shape) const | inline export private |
Validate input shape against the MLP configuration.
Ensures the input tensor has rank >= 1 and that its last dimension matches config_.getInputFeatures().
| input_shape | Shape to validate. |
| std::invalid_argument | when rank < 1 or last-dimension mismatch. |
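The two checks described above are small enough to sketch directly. This is a minimal stand-alone version, not the member function itself: `validate_input_shape` and `Shape` are hypothetical names, and the expected feature count is passed in as a plain argument instead of being read from config_.getInputFeatures().

```cpp
#include <cstdint>
#include <stdexcept>
#include <vector>

using Shape = std::vector<std::int64_t>;  // stand-in for shape_t

// Throws std::invalid_argument when the rank is below 1 or when the last
// dimension differs from the expected number of input features.
inline void validate_input_shape(const Shape& input_shape, std::int64_t in_features) {
    if (input_shape.empty()) {
        throw std::invalid_argument("MLP input must have rank >= 1");
    }
    if (input_shape.back() != in_features) {
        throw std::invalid_argument("MLP input last dimension must equal input features");
    }
}
```

Only the last dimension is constrained, so any number of leading batch or sequence dimensions passes the check.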


| void | zeroGradients () override | inline override export virtual |
Zero gradients for all child components.
Recursively zeroes optimizer/parameter gradients in children. Safe to call regardless of build state; child pointers are checked before use.
Reimplemented from Mila::Dnn::Component< TDeviceType, TPrecision >.