Mila 0.13.48
Deep Neural Network Library
Mila::Dnn::Swiglu< TDeviceType, TPrecision > Class Template Reference [export]

SwiGLU activation component.


Public Types

using ComponentBase = Component<TDeviceType, TPrecision>
using MR = typename DeviceTypeTraits<TDeviceType>::memory_resource
using TensorType = Tensor<TPrecision, MR>

Public Member Functions

 Swiglu (const std::string &name, const SwigluConfig &config, std::optional< DeviceId > device_id=std::nullopt)
 ~Swiglu () override=default
TensorType & backward (const TensorType &input, const TensorType &output_grad)
TensorType & forward (const TensorType &input)
DeviceId getDeviceId () const override
 Get the compute device id associated with this component.
std::vector< ITensor * > getGradients () const override
 Return non-owning pointers to parameter gradient tensors.
MemoryStats getMemoryStats () const override
 Return memory allocation breakdown.
std::vector< ITensor * > getParameters () const override
 Return non-owning pointers to parameter tensors.
const ComponentType getType () const override
 Get the component type identifier.
size_t parameterCount () const override
 Return number of trainable parameters.
void save_ (ModelArchive &archive, SerializationMode mode) const override
void synchronize () override
 Wait for outstanding device work submitted by this component.
std::string toString () const override
 Produce a short, human-readable description of the component.
Public Member Functions inherited from Mila::Dnn::Component< TDeviceType, TPrecision >
 Component (const std::string &name)
 Construct component with required name identifier.
virtual ~Component ()=default
virtual void build (const BuildContext &context) final
 Build the component with the provided BuildContext (canonical overload).
const std::string getName () const
 Get the component's name identifier.
virtual std::vector< std::string > getParameterNames () const
 List all available parameter names for this component.
RuntimeMode getRuntimeMode () const noexcept
 Get the component's RuntimeMode (e.g. Training or Inference).
TrainingMode getTrainingMode () const noexcept
 Get the current runtime behavioral mode (TrainingMode) of this Component.
virtual bool isBuilt () const final
 Returns true if build() has completed successfully.
bool isInferenceMode () const noexcept
bool isTrainingMode () const noexcept
virtual void loadParameter (const std::string &name, const Serialization::ITensorBlob &blob)
 Load a parameter from serialized tensor data.
void setTrainingMode (TrainingMode mode)
 Set the runtime behavioral mode for this Component.
virtual void zeroGradients ()
 Clear all model-owned gradients for this component.

Static Public Member Functions

static std::unique_ptr< Swiglu > fromArchive_ (ModelArchive &archive, const std::string &component_name, IExecutionContext *exec_context)
Static Public Member Functions inherited from Mila::Dnn::Component< TDeviceType, TPrecision >
static constexpr DeviceType getDeviceType ()
 Compile-time device type for this component instance.
static constexpr TensorDataType getPrecision () noexcept
 Compile-time tensor precision for this component instance.

Protected Member Functions

void onBuilding (const BuildContext &build_context) override
 Hook invoked by build() to allocate component buffers.
void onExecutionContextSet () override
 Lifecycle hook: Called immediately after ExecutionContext is set.
void onTrainingModeChanging (TrainingMode training_mode) override
 Hook called before TrainingMode transitions.
Protected Member Functions inherited from Mila::Dnn::Component< TDeviceType, TPrecision >
IExecutionContext * getExecutionContext () const
 Get the shared execution context.
bool hasExecutionContext () const noexcept
 Check if execution context has been set.
template<TensorDataType TParameterPrecision, typename TMemoryResource>
void loadParameterFromBlob (const std::string &param_name, const Serialization::ITensorBlob &blob, Tensor< TParameterPrecision, TMemoryResource > &target, const shape_t &expected_shape)
 Load a tensor blob into a parameter tensor with validation.
void setExecutionContext (IExecutionContext *context)
 Set the execution context for this component.

Private Types

using OpType = typename OperationTraits<OperationType::SwigluOp, TDeviceType, TPrecision>::type

Private Member Functions

void createOperation ()
void validateBuildContext (const BuildContext &build_context) const

Static Private Member Functions

static void validateMetadata_ (const SerializationMetadata &meta, const std::string &component_name)

Private Attributes

SwigluConfig config_
std::unique_ptr< TensorType > input_grad_ { nullptr }
shape_t input_shape_
std::shared_ptr< OpType > operation_ { nullptr }
std::unique_ptr< TensorType > output_ { nullptr }
std::optional< TensorType > output_view_
std::unique_ptr< IExecutionContext > owned_exec_context_ { nullptr }

Additional Inherited Members

Protected Attributes inherited from Mila::Dnn::Component< TDeviceType, TPrecision >
BuildContext build_context_ { shape_t{ 1 }, RuntimeMode::Training }
 The BuildContext stored at build time.

Detailed Description

template<DeviceType TDeviceType, TensorDataType TPrecision>
requires PrecisionSupportedOnDevice<TPrecision, TDeviceType>
class Mila::Dnn::Swiglu< TDeviceType, TPrecision >

SwiGLU activation component.

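SwiGLU splits the input along the feature axis into two halves, x1 and x2, and computes out = x1 * Swish(x2), where Swish(z) = z * sigmoid(z) (also known as SiLU) and * is element-wise multiplication. The output feature dimension is therefore half the input feature dimension.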

Delegates work to a device-specific UnaryOperation named "SwigluOp".
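As an illustration of the gated computation, here is a minimal CPU reference sketch. It assumes a row-major [rows, 2*H] float buffer in which the first H features of each row are x1 and the last H are x2, and it uses SiLU (Swish) as the gate activation, following the usual SwiGLU definition. The names `silu` and `swiglu_forward` are hypothetical helpers, not part of Mila's API, and the device-specific SwigluOp kernels may split or lay out the halves differently.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// SiLU (Swish with beta = 1): z * sigmoid(z), written as z / (1 + e^-z).
inline float silu(float z) {
    return z / (1.0f + std::exp(-z));
}

// Hypothetical reference SwiGLU forward pass.
// input: row-major [rows, in_features] with in_features = 2 * H.
// Assumed split: x1 = first H features of a row, x2 = last H features.
// Returns a row-major [rows, H] buffer with out = x1 * silu(x2).
std::vector<float> swiglu_forward(const std::vector<float>& input,
                                  std::size_t rows,
                                  std::size_t in_features) {
    assert(in_features % 2 == 0 && "feature axis must split into two halves");
    assert(input.size() == rows * in_features);

    const std::size_t h = in_features / 2;
    std::vector<float> out(rows * h);
    for (std::size_t r = 0; r < rows; ++r) {
        const float* row = input.data() + r * in_features;
        for (std::size_t j = 0; j < h; ++j) {
            out[r * h + j] = row[j] * silu(row[h + j]);  // x1 * silu(x2)
        }
    }
    return out;
}
```

With this layout, a [B, T, 2H] input flattened to rows = B * T produces a [B, T, H] output.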

Constructor & Destructor Documentation

◆ Swiglu()

template<DeviceType TDeviceType, TensorDataType TPrecision>
Mila::Dnn::Swiglu< TDeviceType, TPrecision >::Swiglu ( const std::string & name,
const SwigluConfig & config,
std::optional< DeviceId > device_id = std::nullopt )
[inline, explicit, export]

◆ ~Swiglu()

template<DeviceType TDeviceType, TensorDataType TPrecision>
Mila::Dnn::Swiglu< TDeviceType, TPrecision >::~Swiglu ( )
[override, export, default]

Member Function Documentation

◆ backward()

template<DeviceType TDeviceType, TensorDataType TPrecision>
TensorType & Mila::Dnn::Swiglu< TDeviceType, TPrecision >::backward ( const TensorType & input,
const TensorType & output_grad )
[inline, export]
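backward() has no further description here. What it must compute follows from the product rule: for out = x1 * silu(x2) with s = sigmoid(x2), d(out)/d(x1) = silu(x2) = x2 * s and d(out)/d(x2) = x1 * s * (1 + x2 * (1 - s)). The sketch below applies those formulas on the CPU, assuming a row-major [rows, 2*H] input whose first H features per row are x1 and last H are x2, with a SiLU gate. `sigmoid` and `swiglu_backward` are hypothetical names; Mila's actual SwigluOp backward kernel is not shown in this reference.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

inline float sigmoid(float z) { return 1.0f / (1.0f + std::exp(-z)); }

// Hypothetical reference backward pass for out = x1 * silu(x2).
// input:       row-major [rows, 2*H]; x1 = first H features, x2 = last H (assumed split)
// output_grad: row-major [rows, H]
// returns:     row-major [rows, 2*H] gradient with respect to the input
std::vector<float> swiglu_backward(const std::vector<float>& input,
                                   const std::vector<float>& output_grad,
                                   std::size_t rows,
                                   std::size_t in_features) {
    assert(in_features % 2 == 0);
    const std::size_t h = in_features / 2;
    assert(input.size() == rows * in_features);
    assert(output_grad.size() == rows * h);

    std::vector<float> input_grad(rows * in_features);
    for (std::size_t r = 0; r < rows; ++r) {
        for (std::size_t j = 0; j < h; ++j) {
            const float x1 = input[r * in_features + j];
            const float x2 = input[r * in_features + h + j];
            const float g  = output_grad[r * h + j];
            const float s  = sigmoid(x2);
            // d(out)/d(x1) = silu(x2) = x2 * s
            input_grad[r * in_features + j] = g * x2 * s;
            // d(out)/d(x2) = x1 * silu'(x2), where silu'(z) = s * (1 + z * (1 - s))
            input_grad[r * in_features + h + j] = g * x1 * s * (1.0f + x2 * (1.0f - s));
        }
    }
    return input_grad;
}
```

A central finite difference of x1 * x2 * sigmoid(x2) is a quick way to check the analytic gradient.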

◆ createOperation()

template<DeviceType TDeviceType, TensorDataType TPrecision>
void Mila::Dnn::Swiglu< TDeviceType, TPrecision >::createOperation ( )
[inline, export, private]

◆ forward()

template<DeviceType TDeviceType, TensorDataType TPrecision>
TensorType & Mila::Dnn::Swiglu< TDeviceType, TPrecision >::forward ( const TensorType & input)
[inline, export]

◆ fromArchive_()

template<DeviceType TDeviceType, TensorDataType TPrecision>
std::unique_ptr< Swiglu > Mila::Dnn::Swiglu< TDeviceType, TPrecision >::fromArchive_ ( ModelArchive & archive,
const std::string & component_name,
IExecutionContext * exec_context )
[inline, static, export]

◆ getDeviceId()

template<DeviceType TDeviceType, TensorDataType TPrecision>
DeviceId Mila::Dnn::Swiglu< TDeviceType, TPrecision >::getDeviceId ( ) const
[inline, override, export, virtual]

Get the compute device id associated with this component.

Must return the device on which parameters and operations execute.

Implements Mila::Dnn::Component< TDeviceType, TPrecision >.


◆ getGradients()

template<DeviceType TDeviceType, TensorDataType TPrecision>
std::vector< ITensor * > Mila::Dnn::Swiglu< TDeviceType, TPrecision >::getGradients ( ) const
[inline, override, export, virtual]

Return non-owning pointers to parameter gradient tensors.

Only valid when isTrainingMode() is true.

Exceptions
std::runtime_error: if called when not in training mode, or before the component has been built.

Implements Mila::Dnn::Component< TDeviceType, TPrecision >.
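The documented precondition amounts to two checks before any gradient pointers are returned. This is a minimal sketch under stated assumptions: `GradientGuardSketch` and the local `ITensor` and `TrainingMode` definitions are hypothetical stand-ins, not Mila's types, and an initial mode of Training is assumed.

```cpp
#include <cassert>
#include <stdexcept>
#include <vector>

// Hypothetical stand-ins; Mila's real ITensor and TrainingMode are richer.
struct ITensor {};
enum class TrainingMode { Training, Eval };

// Hypothetical component illustrating the documented precondition:
// gradients may be requested only after build() and only in training mode.
class GradientGuardSketch {
public:
    void build() { built_ = true; }
    void setTrainingMode(TrainingMode mode) { mode_ = mode; }

    std::vector<ITensor*> getGradients() const {
        if (!built_) {
            throw std::runtime_error("getGradients: component has not been built");
        }
        if (mode_ != TrainingMode::Training) {
            throw std::runtime_error("getGradients: component is not in training mode");
        }
        // This sketch owns no parameter tensors, so there are no gradients to expose.
        return {};
    }

private:
    bool built_ = false;
    TrainingMode mode_ = TrainingMode::Training;  // assumed default
};
```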

◆ getMemoryStats()

template<DeviceType TDeviceType, TensorDataType TPrecision>
MemoryStats Mila::Dnn::Swiglu< TDeviceType, TPrecision >::getMemoryStats ( ) const
[inline, override, export, virtual]

Return memory allocation breakdown.

Implements Mila::Dnn::Component< TDeviceType, TPrecision >.

◆ getParameters()

template<DeviceType TDeviceType, TensorDataType TPrecision>
std::vector< ITensor * > Mila::Dnn::Swiglu< TDeviceType, TPrecision >::getParameters ( ) const
[inline, override, export, virtual]

Return non-owning pointers to parameter tensors.

The returned tensor pointers remain valid for the lifetime of the component. Order should be canonical (weights before biases).

Implements Mila::Dnn::Component< TDeviceType, TPrecision >.

◆ getType()

template<DeviceType TDeviceType, TensorDataType TPrecision>
const ComponentType Mila::Dnn::Swiglu< TDeviceType, TPrecision >::getType ( ) const
[inline, override, export, virtual]

Get the component type identifier.

Used for serialization and runtime type identification.

Returns
Component type enum value.

Implements Mila::Dnn::Component< TDeviceType, TPrecision >.

◆ onBuilding()

template<DeviceType TDeviceType, TensorDataType TPrecision>
void Mila::Dnn::Swiglu< TDeviceType, TPrecision >::onBuilding ( const BuildContext & config)
[inline, override, export, protected, virtual]

Hook invoked by build() to allocate component buffers.

Receives the stored BuildContext. Implementations must use config.allocationSeqLen() when sizing output buffers; this single call is what lets Inference and Training allocate correctly sized buffers automatically, without per-component logic.

// Example (Linear component):
shape_t out_shape =
{
    config.batchSize(),
    config.allocationSeqLen(),    // 1 for Inference, T for Training
    config_.getOutputFeatures()
};
output_ = std::make_unique<TensorType>( device, out_shape,
    this->getName() + ".output" );

The default implementation forwards to the legacy onBuilding( const shape_t& ) overload for backwards compatibility. New components should override this overload directly.

Note
Do not call build() or onBuilding() from within this hook.
Implementations should either succeed fully or leave no partial state, as a failed build() may be retried.
Parameters
config: Build-time configuration. Use config.allocationSeqLen() to obtain the correct output buffer sequence dimension.

Reimplemented from Mila::Dnn::Component< TDeviceType, TPrecision >.
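For SwiGLU, the same sizing rule would give an output buffer of shape [batchSize(), allocationSeqLen(), in_features / 2], since the feature axis is split into two halves. The sketch assumes a 3-D [B, T, C] input; `BuildContextSketch`, `swigluOutputShape`, and the `in_features` argument are hypothetical and only model the two BuildContext accessors named above.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

using shape_t = std::vector<std::size_t>;  // stand-in for Mila's TensorShape

enum class RuntimeMode { Training, Inference };

// Hypothetical stand-in for BuildContext modeling only the two accessors
// used in the Linear example: batchSize() and allocationSeqLen().
struct BuildContextSketch {
    std::size_t batch;
    std::size_t seq_len;
    RuntimeMode mode;

    std::size_t batchSize() const { return batch; }
    // 1 for Inference, the full sequence length T for Training.
    std::size_t allocationSeqLen() const {
        return mode == RuntimeMode::Inference ? 1 : seq_len;
    }
};

// Hypothetical helper: output buffer shape for a SwiGLU-style component,
// whose feature axis is halved by the x1/x2 split.
shape_t swigluOutputShape(const BuildContextSketch& config, std::size_t in_features) {
    if (in_features == 0 || in_features % 2 != 0) {
        throw std::invalid_argument("SwiGLU input feature dimension must be even and non-zero");
    }
    return { config.batchSize(), config.allocationSeqLen(), in_features / 2 };
}
```

Under Inference the sequence dimension collapses to 1, so the same component allocates a much smaller output buffer without any mode-specific code.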


◆ onExecutionContextSet()

template<DeviceType TDeviceType, TensorDataType TPrecision>
void Mila::Dnn::Swiglu< TDeviceType, TPrecision >::onExecutionContextSet ( )
[inline, override, export, protected, virtual]

Lifecycle hook: Called immediately after ExecutionContext is set.

Override this to perform initialization that requires a valid ExecutionContext. At the time this is called, getExecutionContext() is guaranteed to return a valid context.

Common uses:

  • Composite components: Create and configure child components.
  • Device resource allocation: Query device capabilities.

Default implementation does nothing.

Exceptions
Any exception thrown will cause setExecutionContext() to fail and restore the component to a "context not set" state.

Reimplemented from Mila::Dnn::Component< TDeviceType, TPrecision >.
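The exception guarantee can be illustrated with a small base class in which setExecutionContext() stores the context, runs the hook, and rolls back if the hook throws. `ContextHolderSketch`, `FailingHookSketch`, the null check, and the local `IExecutionContext` struct are hypothetical; this is not Mila's implementation.

```cpp
#include <cassert>
#include <stdexcept>

struct IExecutionContext {};  // hypothetical stand-in

// Hypothetical base class sketching the documented contract: the hook runs
// after the context is stored, and any exception rolls back to "not set".
class ContextHolderSketch {
public:
    virtual ~ContextHolderSketch() = default;

    void setExecutionContext(IExecutionContext* context) {
        if (context == nullptr) {  // assumed precondition
            throw std::invalid_argument("execution context must not be null");
        }
        context_ = context;
        try {
            onExecutionContextSet();  // hook may rely on a valid context here
        } catch (...) {
            context_ = nullptr;       // restore the "context not set" state
            throw;                    // propagate the original exception
        }
    }

    bool hasExecutionContext() const noexcept { return context_ != nullptr; }

protected:
    virtual void onExecutionContextSet() {}  // default: does nothing

private:
    IExecutionContext* context_ = nullptr;
};

// A derived component whose hook fails, e.g. during a device capability query.
class FailingHookSketch : public ContextHolderSketch {
protected:
    void onExecutionContextSet() override {
        throw std::runtime_error("device capability query failed");
    }
};
```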


◆ onTrainingModeChanging()

template<DeviceType TDeviceType, TensorDataType TPrecision>
void Mila::Dnn::Swiglu< TDeviceType, TPrecision >::onTrainingModeChanging ( TrainingMode mode)
[inline, override, export, protected, virtual]

Hook called before TrainingMode transitions.

Called by setTrainingMode() after validation and lock acquisition, before the internal state is updated. Derived classes override to respond to the transition — e.g. zeroing gradient buffers on transition to Eval, or re-enabling dropout on transition to Training.

The default implementation is a no-op.

Parameters
mode: The incoming TrainingMode.

Reimplemented from Mila::Dnn::Component< TDeviceType, TPrecision >.
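The ordering described for setTrainingMode() (validation and locking first, then the hook, then the state update) can be sketched as follows. `TrainingModeSketch`, `RecordingSketch`, the redundant-transition check, and the assumed initial mode of Training are hypothetical, not Mila's code.

```cpp
#include <cassert>
#include <mutex>

enum class TrainingMode { Training, Eval };  // hypothetical stand-in

// Hypothetical base class: the hook runs after validation and lock
// acquisition, but before the stored mode is updated.
class TrainingModeSketch {
public:
    virtual ~TrainingModeSketch() = default;

    void setTrainingMode(TrainingMode mode) {
        std::lock_guard<std::mutex> lock(mutex_);
        if (mode == mode_) {
            return;  // assumed validation: ignore redundant transitions
        }
        onTrainingModeChanging(mode);  // state still holds the old mode here
        mode_ = mode;
    }

    TrainingMode getTrainingMode() const {
        std::lock_guard<std::mutex> lock(mutex_);
        return mode_;
    }

protected:
    virtual void onTrainingModeChanging(TrainingMode /*incoming*/) {}  // default: no-op

    // Unlocked read for use inside the hook, where the mutex is already held.
    TrainingMode currentModeUnlocked() const { return mode_; }

private:
    mutable std::mutex mutex_;
    TrainingMode mode_ = TrainingMode::Training;  // assumed initial mode
};

// Derived component that records what the hook observed.
class RecordingSketch : public TrainingModeSketch {
public:
    int calls = 0;
    TrainingMode observed_old = TrainingMode::Training;
    TrainingMode observed_new = TrainingMode::Training;

protected:
    void onTrainingModeChanging(TrainingMode incoming) override {
        ++calls;
        observed_old = currentModeUnlocked();
        observed_new = incoming;
    }
};
```

Because the hook runs while the lock is held, in this sketch it reads the old mode through an unlocked accessor; calling getTrainingMode() from inside the hook would deadlock on the non-recursive std::mutex.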

◆ parameterCount()

template<DeviceType TDeviceType, TensorDataType TPrecision>
size_t Mila::Dnn::Swiglu< TDeviceType, TPrecision >::parameterCount ( ) const
[inline, override, export, virtual]

Return number of trainable parameters.

For leaf components this is the element count of owned parameter tensors. CompositeComponent and Network implementations should return the recursive aggregate across all children.

Implements Mila::Dnn::Component< TDeviceType, TPrecision >.
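The leaf/composite rule is a simple recursive sum. The classes below are hypothetical stand-ins rather than Mila's Component hierarchy: a leaf sums the element counts of the tensors it owns, and a composite sums its children, recursing into nested composites.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <numeric>
#include <utility>
#include <vector>

// Hypothetical minimal hierarchy illustrating the documented rule for
// parameterCount(); these are not Mila's classes.
class ComponentNode {
public:
    virtual ~ComponentNode() = default;
    virtual std::size_t parameterCount() const = 0;
};

// Leaf: sum of element counts of the parameter tensors it owns.
class LeafSketch : public ComponentNode {
public:
    explicit LeafSketch(std::vector<std::size_t> tensor_element_counts)
        : counts_(std::move(tensor_element_counts)) {}

    std::size_t parameterCount() const override {
        return std::accumulate(counts_.begin(), counts_.end(), std::size_t{0});
    }

private:
    std::vector<std::size_t> counts_;
};

// Composite: recursive aggregate across all children.
class CompositeSketch : public ComponentNode {
public:
    void add(std::unique_ptr<ComponentNode> child) {
        children_.push_back(std::move(child));
    }

    std::size_t parameterCount() const override {
        std::size_t total = 0;
        for (const auto& child : children_) {
            total += child->parameterCount();  // recurses into nested composites
        }
        return total;
    }

private:
    std::vector<std::unique_ptr<ComponentNode>> children_;
};
```

For example, a composite holding a leaf with a 32-element weight and a 4-element bias plus a parameter-free activation, nested beside a 6-element leaf, reports 42.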

◆ save_()

template<DeviceType TDeviceType, TensorDataType TPrecision>
void Mila::Dnn::Swiglu< TDeviceType, TPrecision >::save_ ( ModelArchive & archive,
SerializationMode mode ) const
[inline, override, export, virtual]

Implements Mila::Dnn::Component< TDeviceType, TPrecision >.


◆ synchronize()

template<DeviceType TDeviceType, TensorDataType TPrecision>
void Mila::Dnn::Swiglu< TDeviceType, TPrecision >::synchronize ( )
[inline, override, export, virtual]

Wait for outstanding device work submitted by this component.

On CPU this may be a no-op. Use to ensure results are visible to the host or to measure synchronous timings.

Implements Mila::Dnn::Component< TDeviceType, TPrecision >.


◆ toString()

template<DeviceType TDeviceType, TensorDataType TPrecision>
std::string Mila::Dnn::Swiglu< TDeviceType, TPrecision >::toString ( ) const
[inline, override, export, virtual]

Produce a short, human-readable description of the component.

Implementations should keep output concise and avoid throwing.

Implements Mila::Dnn::Component< TDeviceType, TPrecision >.


◆ validateBuildContext()

template<DeviceType TDeviceType, TensorDataType TPrecision>
void Mila::Dnn::Swiglu< TDeviceType, TPrecision >::validateBuildContext ( const BuildContext & build_context) const
[inline, export, private]

◆ validateMetadata_()

template<DeviceType TDeviceType, TensorDataType TPrecision>
void Mila::Dnn::Swiglu< TDeviceType, TPrecision >::validateMetadata_ ( const SerializationMetadata & meta,
const std::string & component_name )
[inline, static, export, private]

The documentation for this class was generated from the following file:
  • /__w/Mila/Mila/Mila/Src/Dnn/Components/Activations/Swiglu/Swiglu.ixx