Mila 0.13.48
Deep Neural Network Library
Mila::Dnn::MLP< TDeviceType, TPrecision > Class Template Reference (export)

Multi-Layer Perceptron (MLP) composite component.

Inheritance diagram for Mila::Dnn::MLP< TDeviceType, TPrecision >:
Collaboration diagram for Mila::Dnn::MLP< TDeviceType, TPrecision >:

Public Types

using ComponentPtr = typename CompositeComponentBase::ComponentPtr
using CompositeComponentBase = CompositeComponent<TDeviceType, TPrecision>
using GeluType = Gelu<TDeviceType, TPrecision>
using LayerNormType = LayerNorm<TDeviceType, TPrecision>
using LinearType = Linear<TDeviceType, TPrecision>
using MR = typename DeviceTypeTraits<TDeviceType>::memory_resource
using SwigluType = Swiglu<TDeviceType, TPrecision>
using TensorType = Tensor<TPrecision, MR>
Public Types inherited from Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >
using ComponentBase = Component<TDeviceType, TPrecision>
using ComponentPtr = std::shared_ptr<Component<TDeviceType, TPrecision>>

Public Member Functions

 MLP (const std::string &name, const MLPConfig &config, std::optional< DeviceId > device_id=std::nullopt)
 Construct an MLP component.
 ~MLP () override=default
 Default destructor.
TensorType & backward (const TensorType &input, const TensorType &output_grad)
 Backward pass using captured forward intermediates.
TensorType & decode (const TensorType &input) const
TensorType & forward (const TensorType &input)
 Forward pass.
MemoryStats getMemoryStats () const override
 Return the current memory allocation breakdown for this component.
const ComponentType getType () const override
 Get the component type identifier.
void save_ (ModelArchive &archive, SerializationMode mode) const override
 Serialize parameters to a model archive.
std::string toString () const override
 Human-readable status and configuration summary.
void zeroGradients () override
 Zero gradients for all child components.
Public Member Functions inherited from Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >
 CompositeComponent (CompositeComponent &&) noexcept=default
 CompositeComponent (const CompositeComponent &)=delete
 CompositeComponent (const std::string &name)
 Construct composite component with name.
virtual ~CompositeComponent ()=default
CompositeComponent & addComponent (ComponentPtr component)
 Add a pre-constructed child component (chainable).
size_t childCount () const noexcept
 Get the number of direct children.
void clearComponents ()
 Clear all child components.
ComponentPtr findComponent (const std::string &path) const
 Resolve a dot-separated component path within this composite.
ComponentPtr getComponent (const std::string &name) const
 Retrieve a direct child component by name.
const std::vector< ComponentPtr > & getComponents () const
 Get all child components in insertion order.
DeviceId getDeviceId () const override
 Get the compute device for this composite.
std::vector< ITensor * > getGradients () const override
 Get all parameter gradients from all children.
std::vector< ITensor * > getParameters () const override
 Get all parameters from all children.
bool hasChildren () const noexcept
 Check if this composite has any children.
bool hasComponent (const std::string &name) const
 Check if a named child component exists.
CompositeComponent & operator= (CompositeComponent &&) noexcept=default
CompositeComponent & operator= (const CompositeComponent &)=delete
size_t parameterCount () const override
 Count parameters across all children.
bool removeComponent (const std::string &name)
 Remove a direct child component by name.
void synchronize () override
 Synchronize all child components.
ComponentPtr tryFindComponent (const std::string &path) const
 Try to resolve a dot-separated component path within this composite.
Public Member Functions inherited from Mila::Dnn::Component< TDeviceType, TPrecision >
 Component (const std::string &name)
 Construct component with required name identifier.
virtual ~Component ()=default
virtual void build (const BuildContext &context) final
 Build the component with the provided BuildContext (canonical overload).
const std::string getName () const
 Get the component's name identifier.
virtual std::vector< std::string > getParameterNames () const
 List all available parameter names for this component.
RuntimeMode getRuntimeMode () const noexcept
 Get the current runtime mode of this Component.
TrainingMode getTrainingMode () const noexcept
 Get the current training/inference behavioral mode of this Component.
virtual bool isBuilt () const final
 Returns true if build() has completed successfully.
bool isInferenceMode () const noexcept
 Convenience accessor: true if currently in inference (eval) mode.
bool isTrainingMode () const noexcept
 Convenience accessor: true if currently in training mode.
virtual void loadParameter (const std::string &name, const Serialization::ITensorBlob &blob)
 Load a parameter from serialized tensor data.
void setTrainingMode (TrainingMode mode)
 Set the runtime behavioral mode for this Component.

Protected Member Functions

void onBuilding (const BuildContext &context) override
 Build-time callback invoked by the CompositeComponent framework.
void onTrainingModeChanging (TrainingMode training_mode) override
 Called when the training/inference mode changes.
Protected Member Functions inherited from Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >
template<typename TComponent>
std::shared_ptr< TComponent > getComponentAs (const std::string &name) const
 Retrieve a typed child component by name.
void onExecutionContextSet () override
 Hook invoked after ExecutionContext is set.
virtual void optimize ()
 Virtual hook for graph optimization after construction.
Protected Member Functions inherited from Mila::Dnn::Component< TDeviceType, TPrecision >
IExecutionContext * getExecutionContext () const
 Get the shared execution context.
bool hasExecutionContext () const noexcept
 Check if execution context has been set.
template<TensorDataType TParameterPrecision, typename TMemoryResource>
void loadParameterFromBlob (const std::string &param_name, const Serialization::ITensorBlob &blob, Tensor< TParameterPrecision, TMemoryResource > &target, const shape_t &expected_shape)
 Load a tensor blob into a parameter tensor with validation.
void setExecutionContext (IExecutionContext *context)
 Set the execution context for this component.

Private Types

using ActivationBackwardFn = std::function<TensorType&(const TensorType&, const TensorType&)>
using ActivationBase = Component<TDeviceType, TPrecision>
using ActivationForwardFn = std::function<TensorType& (const TensorType&)>

Private Member Functions

std::string activationChildName () const
 Returns the child component name suffix for the configured activation type.
void addActivation (const std::string &suffix)
 Create and register the activation child component for the configured type.
void addLinear (const std::string &suffix, dim_t in_features, dim_t out_features)
 Helper to create and add a Linear child component.
void clearForwardCache () noexcept
 Clear cached non-owning forward pointers.
void createGraph ()
 Build the internal component graph according to config_.
void validateInputShape (const shape_t &input_shape) const
 Validate input shape against the MLP configuration.

Private Attributes

std::shared_ptr< ActivationBase > activation_ { nullptr }
ActivationBackwardFn activation_backward_
ActivationForwardFn activation_forward_
shape_t cached_hidden_shape_
shape_t cached_input_shape_
MLPConfig config_
std::shared_ptr< LinearType > fc1_ { nullptr }
std::shared_ptr< LinearType > fc2_ { nullptr }
TensorType * last_act_out_ { nullptr }
TensorType * last_fc1_out_ { nullptr }
TensorType * last_final_out_ { nullptr }
std::unique_ptr< IExecutionContext > owned_exec_context_ { nullptr }

Additional Inherited Members

Static Public Member Functions inherited from Mila::Dnn::Component< TDeviceType, TPrecision >
static constexpr DeviceType getDeviceType ()
 Compile-time device type for this component instance.
static constexpr TensorDataType getPrecision () noexcept
 Compile-time tensor precision for this component instance.
Protected Attributes inherited from Mila::Dnn::Component< TDeviceType, TPrecision >
BuildContext build_context_ { shape_t{ 1 }, RuntimeMode::Training }
 The BuildContext stored at build time.

Detailed Description

template<DeviceType TDeviceType, TensorDataType TPrecision>
requires PrecisionSupportedOnDevice<TPrecision, TDeviceType>
class Mila::Dnn::MLP< TDeviceType, TPrecision >

Multi-Layer Perceptron (MLP) composite component.

Device-templated composite component that implements a standard MLP structure: Input -> Linear(in_features, hidden_size) -> [LayerNorm] -> Activation -> Linear(hidden_size, in_features) -> Output

When the configured activation is SwiGLU, fc1 projects to 2*hidden_size; the Swiglu component splits that into two halves and computes x1 * GELU(x2), producing hidden_size output fed into fc2. For all other activations fc1 projects to hidden_size directly.
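The width arithmetic can be sketched as follows. This is not Mila's actual createGraph(). The Linear feature sizes follow directly from in_features, hidden_size, and the activation choice. `ActivationKind`, `MlpDims`, and `computeMlpDims` are hypothetical names.

```cpp
#include <cstdint>
#include <stdexcept>

// Hypothetical stand-in for the configured activation type (not a Mila type).
enum class ActivationKind { Gelu, Swiglu };

// Hypothetical record of the two Linear layers' feature widths.
struct MlpDims {
    std::int64_t fc1_in;
    std::int64_t fc1_out;  // 2 * hidden for SwiGLU, hidden otherwise
    std::int64_t fc2_in;   // always hidden
    std::int64_t fc2_out;  // projects back to in_features
};

// Derive the Linear widths from in_features, hidden_size and the activation.
MlpDims computeMlpDims(std::int64_t in_features, std::int64_t hidden_size,
                       ActivationKind act) {
    if (in_features <= 0 || hidden_size <= 0) {
        throw std::invalid_argument("feature sizes must be positive");
    }
    const std::int64_t fc1_out =
        (act == ActivationKind::Swiglu) ? 2 * hidden_size : hidden_size;
    return MlpDims{in_features, fc1_out, hidden_size, in_features};
}
```

For example, take in_features = 768 and hidden_size = 3072. Under GELU, fc1 produces 3072 features; under SwiGLU, it produces 6144. In both cases fc2 maps 3072 back to 768.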

The component composes child components (Linear, optional LayerNorm, Activation) and delegates forward/backward calls to them. Child components own intermediate tensors; MLP stores non-owning pointers to those tensors after forward() to chain backward().
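Here is a minimal sketch of this ownership pattern. It uses hypothetical `ScaleChild` and `Parent` types, with `std::vector<float>` in place of Mila tensors. Each child keeps its output buffer alive. The parent caches only a raw, non-owning pointer, which it resets when the cache is cleared.

```cpp
#include <cstddef>
#include <vector>

using Buffer = std::vector<float>;  // hypothetical stand-in for TensorType

// Hypothetical child: owns its output buffer and returns a reference to it.
class ScaleChild {
public:
    explicit ScaleChild(float s) : scale_(s) {}
    const Buffer& forward(const Buffer& in) {
        out_.resize(in.size());
        for (std::size_t i = 0; i < in.size(); ++i) out_[i] = in[i] * scale_;
        return out_;
    }
private:
    float scale_;
    Buffer out_;  // owned by the child, reused across calls
};

// Hypothetical parent: chains two children and caches a non-owning pointer.
class Parent {
public:
    const Buffer& forward(const Buffer& in) {
        const Buffer& hidden = a_.forward(in);
        last_hidden_ = &hidden;  // valid only while a_ lives and is not rerun
        return b_.forward(hidden);
    }
    const Buffer* lastHidden() const noexcept { return last_hidden_; }
    void clearForwardCache() noexcept { last_hidden_ = nullptr; }
private:
    ScaleChild a_{2.0f};
    ScaleChild b_{3.0f};
    const Buffer* last_hidden_ = nullptr;
};
```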

Threading: callers must ensure that forward(), backward(), and zeroGradients() are never invoked concurrently with one another; this class provides no internal synchronization.
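One way a caller might meet this requirement is sketched below with a hypothetical `Serialized<T>` wrapper, which is not part of Mila. The wrapper holds a reference to the model and funnels every call through a single `std::mutex`. With the real class, forward() returns a `TensorType&` into child-owned storage, so any use of that result should also happen inside the locked callback.

```cpp
#include <mutex>
#include <utility>

// Hypothetical wrapper: serializes all calls made through it with one mutex.
// TModel stands in for any object exposing forward/backward/zeroGradients.
template <typename TModel>
class Serialized {
public:
    explicit Serialized(TModel& model) : model_(model) {}

    // Run fn(model) while holding the lock. `auto` deduction returns by value,
    // so no reference into the model escapes the critical section.
    template <typename Fn>
    auto with(Fn&& fn) {
        std::lock_guard<std::mutex> lock(mutex_);
        return std::forward<Fn>(fn)(model_);
    }

private:
    std::mutex mutex_;
    TModel& model_;  // non-owning; the model must outlive the wrapper
};
```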

Template Parameters
TDeviceType: Device type for execution (CPU, CUDA, ...).
TPrecision: Tensor data precision (Fp32, Fp16, etc.). Must be supported on the device.

Constructor & Destructor Documentation

◆ MLP()

template<DeviceType TDeviceType, TensorDataType TPrecision>
Mila::Dnn::MLP< TDeviceType, TPrecision >::MLP ( const std::string & name,
const MLPConfig & config,
std::optional< DeviceId > device_id = std::nullopt )
inline explicit export

Construct an MLP component.

The constructor validates the provided config, constructs the internal child component graph, and optionally creates and assigns an execution context when device_id is provided.

Parameters
name: Component name used to name child subcomponents.
config: MLP configuration (input features, hidden size, activation, bias, layer-norm flag).
device_id: Optional device identifier. When present, the MLP creates an owned execution context bound to that device and sets it on the component. If the device type of device_id does not match the template parameter TDeviceType, an exception is thrown.
Exceptions
std::invalid_argument: if device_id is present but has a mismatched device type.
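Here is a minimal sketch of the documented device check. It assumes simplified stand-ins for DeviceType and DeviceId; Mila's real types may carry more fields. `checkDeviceId` is a hypothetical free function, not the constructor's actual code.

```cpp
#include <optional>
#include <stdexcept>

// Hypothetical stand-ins for Mila's DeviceType and DeviceId.
enum class DeviceType { Cpu, Cuda };

struct DeviceId {
    DeviceType type;
    int index;
};

// Mirrors the documented rule: an absent id is accepted; a present id must
// match the component's compile-time device type.
template <DeviceType TDeviceType>
void checkDeviceId(const std::optional<DeviceId>& device_id) {
    if (device_id && device_id->type != TDeviceType) {
        throw std::invalid_argument("device_id type does not match TDeviceType");
    }
}
```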

◆ ~MLP()

template<DeviceType TDeviceType, TensorDataType TPrecision>
Mila::Dnn::MLP< TDeviceType, TPrecision >::~MLP ( )
override export default

Default destructor.

Child components are stored as shared_ptr and will be destroyed automatically.

Member Function Documentation

◆ activationChildName()

template<DeviceType TDeviceType, TensorDataType TPrecision>
std::string Mila::Dnn::MLP< TDeviceType, TPrecision >::activationChildName ( ) const
inline export private

Returns the child component name suffix for the configured activation type.

Returns
Suffix string used to name the activation child component.

◆ addActivation()

template<DeviceType TDeviceType, TensorDataType TPrecision>
void Mila::Dnn::MLP< TDeviceType, TPrecision >::addActivation ( const std::string & suffix)
inline export private

Create and register the activation child component for the configured type.

Uses mlp_activation_impl to construct the concrete activation, bind type-erased forward/backward lambdas, and register the component with the composite framework. Supports ActivationType::Gelu and ActivationType::Swiglu.

Parameters
suffix: Suffix appended to parent name for the activation child component.
Exceptions
std::invalid_argument: if the activation type in config_ is unsupported.
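The type-erasure step can be sketched as follows. The sketch assumes a `std::vector<float>` tensor stand-in and substitutes a hypothetical ReLU activation for Mila's Gelu and Swiglu. The factory builds the concrete object and wraps its member functions in `std::function`s. These have the same call signatures as the documented ActivationForwardFn and ActivationBackwardFn aliases. `makeActivation` and `BoundActivation` are illustrative names, not Mila API.

```cpp
#include <cstddef>
#include <functional>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

using Tensor = std::vector<float>;  // hypothetical stand-in for TensorType

// Same call signatures as the documented ActivationForwardFn/BackwardFn.
using ForwardFn = std::function<Tensor&(const Tensor&)>;
using BackwardFn = std::function<Tensor&(const Tensor&, const Tensor&)>;

struct ActivationBase {  // hypothetical common base for activations
    virtual ~ActivationBase() = default;
};

// Hypothetical ReLU standing in for Gelu/Swiglu; owns its output buffers.
struct Relu : ActivationBase {
    Tensor out, in_grad;
    Tensor& forward(const Tensor& x) {
        out.resize(x.size());
        for (std::size_t i = 0; i < x.size(); ++i) out[i] = x[i] > 0.0f ? x[i] : 0.0f;
        return out;
    }
    Tensor& backward(const Tensor& x, const Tensor& dy) {
        in_grad.resize(x.size());
        for (std::size_t i = 0; i < x.size(); ++i) in_grad[i] = x[i] > 0.0f ? dy[i] : 0.0f;
        return in_grad;
    }
};

struct BoundActivation {
    std::shared_ptr<ActivationBase> component;  // type-erased owner
    ForwardFn forward;
    BackwardFn backward;
};

// Construct the concrete activation and bind type-erased call wrappers.
BoundActivation makeActivation(const std::string& kind) {
    if (kind == "relu") {
        auto act = std::make_shared<Relu>();
        return BoundActivation{
            act,
            [act](const Tensor& x) -> Tensor& { return act->forward(x); },
            [act](const Tensor& x, const Tensor& dy) -> Tensor& { return act->backward(x, dy); }};
    }
    throw std::invalid_argument("unsupported activation: " + kind);
}
```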

◆ addLinear()

template<DeviceType TDeviceType, TensorDataType TPrecision>
void Mila::Dnn::MLP< TDeviceType, TPrecision >::addLinear ( const std::string & suffix,
dim_t in_features,
dim_t out_features )
inline export private

Helper to create and add a Linear child component.

The created Linear component uses the parent's name plus the provided suffix.

Parameters
suffix: Suffix appended to parent name for the child component.
in_features: Number of input features for the linear layer.
out_features: Number of output features for the linear layer.

◆ backward()

template<DeviceType TDeviceType, TensorDataType TPrecision>
TensorType & Mila::Dnn::MLP< TDeviceType, TPrecision >::backward ( const TensorType & input,
const TensorType & output_grad )
inline export

Backward pass using captured forward intermediates.

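Uses the child-owned tensors captured by the most recent forward() invocation to chain backward calls through the children in reverse order of forward() (fc2, the activation, [LayerNorm], then fc1) without recomputing the forward pass.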

The method clears the cached forward pointers before returning to avoid accidental reuse.

Preconditions:

  • Component must be built.
  • forward() must have been called previously to populate internal forward caches.
Parameters
input: The original input tensor passed to forward(); required by fc1_->backward.
output_grad: Gradient tensor w.r.t. the MLP output.
Returns
Reference to the input-gradient tensor (owned by the fc1 child).
Exceptions
std::runtime_error: if the component is not built or if forward() was not called.
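The preconditions and cache-clearing behavior can be sketched on a hypothetical two-stage chain (h = a * x, then y = h * h) over `std::vector<float>`. `TwoStage` is illustrative only. It shows why backward() needs the cached forward intermediate, and why a second backward() without a fresh forward() is rejected.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

using Tensor = std::vector<float>;  // hypothetical stand-in for TensorType

// Hypothetical two-stage chain: h = a * x, then y = h * h.
class TwoStage {
public:
    void build() { built_ = true; }

    Tensor& forward(const Tensor& x) {
        if (!built_) throw std::runtime_error("forward: component not built");
        h_.resize(x.size());
        y_.resize(x.size());
        for (std::size_t i = 0; i < x.size(); ++i) {
            h_[i] = a_ * x[i];
            y_[i] = h_[i] * h_[i];
        }
        last_h_ = &h_;  // non-owning pointer consumed by backward()
        return y_;
    }

    // Chains gradients in reverse order: stage 2 first, then stage 1.
    Tensor& backward(const Tensor& x, const Tensor& dy) {
        if (!built_) throw std::runtime_error("backward: component not built");
        if (last_h_ == nullptr) throw std::runtime_error("backward: forward() not called");
        const Tensor& h = *last_h_;
        dx_.resize(x.size());
        for (std::size_t i = 0; i < x.size(); ++i) {
            const float dh = 2.0f * h[i] * dy[i];  // d(h^2)/dh
            dx_[i] = a_ * dh;                      // d(a*x)/dx
        }
        last_h_ = nullptr;  // clear the cache so a stale backward() fails loudly
        return dx_;
    }

private:
    bool built_ = false;
    float a_ = 2.0f;
    Tensor h_, y_, dx_;
    const Tensor* last_h_ = nullptr;
};
```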

◆ clearForwardCache()

template<DeviceType TDeviceType, TensorDataType TPrecision>
void Mila::Dnn::MLP< TDeviceType, TPrecision >::clearForwardCache ( )
inline export private noexcept

Clear cached non-owning forward pointers.

Safe to call at any time; used to avoid accidental reuse of child-owned tensors across forward/backward cycles. No side effects beyond pointer reset.


◆ createGraph()

template<DeviceType TDeviceType, TensorDataType TPrecision>
void Mila::Dnn::MLP< TDeviceType, TPrecision >::createGraph ( )
inline export private

Build the internal component graph according to config_.

For SwiGLU, fc1 projects to 2*hidden_size so that the Swiglu activation can split the output into gate and up halves. For all other activations fc1 projects to hidden_size. Called from the constructor; does not perform shape-dependent build calls.


◆ decode()

template<DeviceType TDeviceType, TensorDataType TPrecision>
TensorType & Mila::Dnn::MLP< TDeviceType, TPrecision >::decode ( const TensorType & input) const
inline export

◆ forward()

template<DeviceType TDeviceType, TensorDataType TPrecision>
TensorType & Mila::Dnn::MLP< TDeviceType, TPrecision >::forward ( const TensorType & input)
inline export

Forward pass.

Chains child component forward calls in the order shown in the class description: fc1, [LayerNorm], activation, fc2.

The function stores non-owning pointers to child-owned intermediate tensors produced during the forward call; these pointers are used by backward() to chain gradients.

Preconditions:

  • Component must be built (onBuilding called).
  • Input tensor must be bound to the same device/context as the component.
Parameters
input: Input tensor bound to this component's device/context.
Returns
Reference to the output tensor produced by the final Linear child (owned by that child).
Exceptions
std::runtime_error: if the component is not built prior to calling forward.

◆ getMemoryStats()

template<DeviceType TDeviceType, TensorDataType TPrecision>
MemoryStats Mila::Dnn::MLP< TDeviceType, TPrecision >::getMemoryStats ( ) const
inline override export virtual

Return the current memory allocation breakdown for this component.

Reflects allocations at the moment of the call. The returned stats naturally track the component lifecycle:

  • After construction: parameters only.
  • After build( Inference ): parameters + T=1 state buffers.
  • After build( Training ): parameters + T=full state buffers.
  • After setEvaluation( false ): parameters + state + gradients.

For CompositeComponent and Network, the returned stats are the recursive aggregate of all child components.

May be called at any time — no lifecycle preconditions.

Returns
MemoryStats reflecting current allocations.

Implements Mila::Dnn::Component< TDeviceType, TPrecision >.
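As a rough cross-check of the "after construction" figure, the parameter footprint of the fc1/fc2 pair can be estimated by hand. This sketch assumes each Linear holds an out x in weight matrix and an optional bias of length out. It ignores LayerNorm. `linearParams`, `mlpParams`, and `paramBytes` are hypothetical helpers. The real MemoryStats may also include padding or other bookkeeping.

```cpp
#include <cstdint>

// Assumed Linear layout: out x in weight plus an optional length-out bias.
std::int64_t linearParams(std::int64_t in, std::int64_t out, bool bias) {
    return in * out + (bias ? out : 0);
}

// Parameter count of the fc1 -> activation -> fc2 stack (no LayerNorm).
std::int64_t mlpParams(std::int64_t in_features, std::int64_t hidden,
                       bool swiglu, bool bias) {
    const std::int64_t fc1_out = swiglu ? 2 * hidden : hidden;  // SwiGLU widens fc1
    return linearParams(in_features, fc1_out, bias) +
           linearParams(hidden, in_features, bias);
}

// Bytes for parameters alone, e.g. 4 bytes per element for Fp32.
std::int64_t paramBytes(std::int64_t params, std::int64_t bytes_per_element) {
    return params * bytes_per_element;
}
```

With in_features = 768, hidden = 3072, GELU, and biases enabled, this gives 2,362,368 + 2,360,064 = 4,722,432 parameters. At Fp32, that is 18,889,728 bytes.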


◆ getType()

template<DeviceType TDeviceType, TensorDataType TPrecision>
const ComponentType Mila::Dnn::MLP< TDeviceType, TPrecision >::getType ( ) const
inline override export virtual

Get the component type identifier.

Used for serialization and runtime type identification.

Returns
Component type enum value.

Implements Mila::Dnn::Component< TDeviceType, TPrecision >.

◆ onBuilding()

template<DeviceType TDeviceType, TensorDataType TPrecision>
void Mila::Dnn::MLP< TDeviceType, TPrecision >::onBuilding ( const BuildContext & context)
inline override export protected virtual

Build-time callback invoked by the CompositeComponent framework.

Validates the input shape carried by context, computes the hidden shape, and builds each child component with the appropriate shape. For SwiGLU, the activation is built with 2*hidden_size along the feature axis; fc2 always receives hidden_size. After building, any cached forward pointers are cleared.

Parameters
context: Build context carrying the input shape. The last dimension of that shape must equal config_.getInputFeatures().
Exceptions
std::invalid_argument: if the input shape has rank < 1 or its last dimension does not match the configuration.

Reimplemented from Mila::Dnn::Component< TDeviceType, TPrecision >.
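The shape bookkeeping can be sketched with a plain `std::vector<std::int64_t>` standing in for shape_t. `withLastDim` and `childShapes` are hypothetical helpers. The real method additionally builds each child and clears the forward cache.

```cpp
#include <cstdint>
#include <stdexcept>
#include <vector>

using shape_t = std::vector<std::int64_t>;  // hypothetical stand-in

// Replace the feature (last) axis of a shape with a new width.
shape_t withLastDim(shape_t shape, std::int64_t features) {
    if (shape.empty()) throw std::invalid_argument("shape rank must be >= 1");
    shape.back() = features;
    return shape;
}

// Shapes the children would be built with, per the documented rules:
// the activation input is 2*hidden wide for SwiGLU; fc2 always gets hidden.
struct ChildShapes {
    shape_t fc1_in;
    shape_t activation_in;
    shape_t fc2_in;
};

ChildShapes childShapes(const shape_t& input, std::int64_t hidden, bool swiglu) {
    return ChildShapes{
        input,
        withLastDim(input, swiglu ? 2 * hidden : hidden),
        withLastDim(input, hidden)};
}
```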


◆ onTrainingModeChanging()

template<DeviceType TDeviceType, TensorDataType TPrecision>
void Mila::Dnn::MLP< TDeviceType, TPrecision >::onTrainingModeChanging ( TrainingMode training_mode)
inline override export protected virtual

Called when the training/inference mode changes.

Propagates the training flag to child components so they can adjust behavior (dropout, batch/statistics, etc.) as needed.

Parameters
training_mode: The TrainingMode being applied (training or inference/evaluation).

Reimplemented from Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >.
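Here is a sketch of recursive mode propagation. It assumes a hypothetical `Node` tree and a stand-in `TrainingMode` enum. The order in which MLP updates itself relative to its children is not specified above, so the order used here is arbitrary.

```cpp
#include <memory>
#include <utility>
#include <vector>

enum class TrainingMode { Training, Inference };  // hypothetical stand-in

// Hypothetical component tree: a mode set on a node reaches every descendant.
class Node {
public:
    void addChild(std::shared_ptr<Node> child) { children_.push_back(std::move(child)); }

    void setTrainingMode(TrainingMode mode) {
        for (const auto& child : children_) {
            if (child) child->setTrainingMode(mode);  // skip empty slots
        }
        mode_ = mode;  // then record locally (order chosen for illustration)
    }

    TrainingMode mode() const noexcept { return mode_; }

private:
    std::vector<std::shared_ptr<Node>> children_;
    TrainingMode mode_ = TrainingMode::Training;
};
```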

◆ save_()

template<DeviceType TDeviceType, TensorDataType TPrecision>
void Mila::Dnn::MLP< TDeviceType, TPrecision >::save_ ( ModelArchive & archive,
SerializationMode mode ) const
inline override export virtual

Serialize parameters to a model archive.

Saves child component parameters and state into the provided archive in a deterministic order.

Parameters
archive: Serialization archive to write to.
mode: Serialization mode (enum driven by Serialization::Mode).

Reimplemented from Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >.
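One way to obtain a deterministic order is to walk children in insertion order and each child's parameters in a fixed order. This sketch models a hypothetical archive as an ordered list of (key, values) pairs and uses a dotted `child.param` key scheme. Mila's actual ModelArchive format and key names are not shown here.

```cpp
#include <string>
#include <utility>
#include <vector>

// Hypothetical archive: an ordered list of (key, values) records.
using Archive = std::vector<std::pair<std::string, std::vector<float>>>;

struct Param {
    std::string name;
    std::vector<float> values;
};

struct Child {
    std::string name;
    std::vector<Param> params;  // kept in a fixed declaration order
};

// Walk children in insertion order, then each child's parameters in order,
// so the same model always produces the same key sequence.
void saveInOrder(const std::vector<Child>& children, Archive& archive) {
    for (const Child& child : children) {
        for (const Param& p : child.params) {
            archive.emplace_back(child.name + "." + p.name, p.values);
        }
    }
}
```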

◆ toString()

template<DeviceType TDeviceType, TensorDataType TPrecision>
std::string Mila::Dnn::MLP< TDeviceType, TPrecision >::toString ( ) const
inline override export virtual

Human-readable status and configuration summary.

Produces a multi-line string describing the component name, shapes, parameter counts, activation and layer-norm usage, device assignment (if set), and child component names.

Returns
String containing component introspection information suitable for logging.

Reimplemented from Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >.


◆ validateInputShape()

template<DeviceType TDeviceType, TensorDataType TPrecision>
void Mila::Dnn::MLP< TDeviceType, TPrecision >::validateInputShape ( const shape_t & input_shape) const
inline export private

Validate input shape against the MLP configuration.

Ensures the input tensor has rank >= 1 and that its last dimension matches config_.getInputFeatures().

Parameters
input_shape: Shape to validate.
Exceptions
std::invalid_argument: when rank < 1 or the last dimension mismatches.
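Here is a standalone sketch of these checks. It uses a `std::vector<std::int64_t>` in place of shape_t and passes the expected feature count explicitly instead of reading it from config_. The error messages are illustrative.

```cpp
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

using shape_t = std::vector<std::int64_t>;  // hypothetical stand-in

// Sketch of the documented checks: rank >= 1 and last dim == input features.
void validateInputShape(const shape_t& input_shape, std::int64_t input_features) {
    if (input_shape.empty()) {
        throw std::invalid_argument("MLP: input must have rank >= 1");
    }
    if (input_shape.back() != input_features) {
        throw std::invalid_argument(
            "MLP: expected last dimension " + std::to_string(input_features) +
            ", got " + std::to_string(input_shape.back()));
    }
}
```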

◆ zeroGradients()

template<DeviceType TDeviceType, TensorDataType TPrecision>
void Mila::Dnn::MLP< TDeviceType, TPrecision >::zeroGradients ( )
inline override export virtual

Zero gradients for all child components.

Recursively zeroes optimizer/parameter gradients in children. Safe to call regardless of build state; child pointers are checked before use.

Reimplemented from Mila::Dnn::Component< TDeviceType, TPrecision >.
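Here is a sketch of one level of the null-checked fan-out. It uses a hypothetical `GradChild` whose gradient buffer may still be empty before build. `zeroAll` is an illustrative name.

```cpp
#include <algorithm>
#include <memory>
#include <vector>

// Hypothetical child whose gradient buffer may not be allocated yet.
struct GradChild {
    std::vector<float> grad;
    void zeroGradients() { std::fill(grad.begin(), grad.end(), 0.0f); }
};

// Null-checked fan-out, safe to call at any point in the lifecycle.
void zeroAll(const std::vector<std::shared_ptr<GradChild>>& children) {
    for (const auto& child : children) {
        if (child) child->zeroGradients();
    }
}
```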


The documentation for this class was generated from the following file:
  • /__w/Mila/Mila/Mila/Src/Dnn/Components/FFN/MLP.ixx