Mila 0.13.48
Deep Neural Network Library
Loading...
Searching...
No Matches
Dnn.Components.MLP Module Reference

Exported Modules

module  Dnn.Components.Gelu
module  Dnn.ComponentType
module  Dnn.TensorDataType
module  Dnn.TensorTypes
module  Compute.MemoryResource
module  Dnn.Components.Swiglu
module  Compute.DeviceTypeTraits
module  Dnn.ActivationType
module  Serialization.ModelArchive
module  Dnn.TensorDataTypeTraits
module  Dnn.Components.Linear
module  Compute.ExecutionContextFactory
module  Dnn.Components.LayerNorm
module  Compute.OperationRegistry
module  Compute.DeviceType
module  Serialization.Mode
module  Dnn.ITensor
module  Compute.ExecutionContext
module  Dnn.CompositeComponent
module  Compute.IExecutionContext
module  Dnn.Component
module  Compute.CpuMemoryResource
module  Compute.Device
module  Dnn.Tensor
module  Compute.DeviceId

Classes

class  Mila::Dnn::MLP< TDeviceType, TPrecision >
 Multi-Layer Perceptron (MLP) composite component. More...
class  Mila::Dnn::MLPConfig
 Configuration class for the Multi-Layer Perceptron (MLP) block. More...

Typedefs

using ActivationBackwardFn = std::function<TensorType&(const TensorType&, const TensorType&)>
using ActivationBase = Component<TDeviceType, TPrecision>
using ActivationForwardFn = std::function<TensorType& (const TensorType&)>
using ComponentPtr = typename CompositeComponentBase::ComponentPtr
using CompositeComponentBase = CompositeComponent<TDeviceType, TPrecision>
using GeluType = Gelu<TDeviceType, TPrecision>
using LayerNormType = LayerNorm<TDeviceType, TPrecision>
using LinearType = Linear<TDeviceType, TPrecision>
using MR = typename DeviceTypeTraits<TDeviceType>::memory_resource
using SwigluType = Swiglu<TDeviceType, TPrecision>
using TensorType = Tensor<TPrecision, MR>

Functions

 MLP (const std::string &name, const MLPConfig &config, std::optional< DeviceId > device_id=std::nullopt)
 Construct an MLP component.
 ~MLP () override=default
 Default destructor.
std::string activationChildName () const
 Returns the child component name suffix for the configured activation type.
void addActivation (const std::string &suffix)
 Helper to create and add the activation child component (named with the given suffix, e.g. from activationChildName()).
void addLinear (const std::string &suffix, dim_t in_features, dim_t out_features)
 Helper to create and add a Linear child component.
TensorType & backward (const TensorType &input, const TensorType &output_grad)
 Backward pass using captured forward intermediates.
void clearForwardCache () noexcept
 Clear cached non-owning forward pointers.
void createGraph ()
 Build the internal component graph according to config_.
TensorType & decode (const TensorType &input) const
TensorType & forward (const TensorType &input)
 Forward pass.
MemoryStats getMemoryStats () const override
 Return the current memory allocation breakdown for this component.
const ComponentType getType () const override
 Get the component type identifier.
void onBuilding (const BuildContext &context) override
 Build-time callback invoked by the CompositeComponent framework.
void onTrainingModeChanging (TrainingMode training_mode) override
 Called when the training/inference mode changes.
void save_ (ModelArchive &archive, SerializationMode mode) const override
 Serialize parameters to a model archive.
std::string toString () const override
 Human-readable status and configuration summary.
void validateInputShape (const shape_t &input_shape) const
 Validate input shape against the MLP configuration.
void zeroGradients () override
 Zero gradients for all child components.

Variables

std::shared_ptr< ActivationBase > activation_ { nullptr }
ActivationBackwardFn activation_backward_
ActivationForwardFn activation_forward_
shape_t cached_hidden_shape_
shape_t cached_input_shape_
MLPConfig config_
std::shared_ptr< LinearType > fc1_ { nullptr }
std::shared_ptr< LinearType > fc2_ { nullptr }
TensorType * last_act_out_ { nullptr }
TensorType * last_fc1_out_ { nullptr }
TensorType * last_final_out_ { nullptr }
std::unique_ptr< IExecutionContext > owned_exec_context_ { nullptr }

Files

file  /__w/Mila/Mila/Mila/Src/Dnn/Components/FFN/MLP.ixx
 Multi-Layer Perceptron (MLP) block for neural networks.
file  /__w/Mila/Mila/Mila/Src/Dnn/Components/FFN/MLP.Config.ixx
file  /__w/Mila/Mila/Mila/Src/Dnn/Components/FFN/MLP.Dispatch.ixx
 Activation dispatch helpers for MLP.