Mila
Deep Neural Network Library
Mila::Dnn::ResidualConfig Class Reference

Configuration class for the Residual connection module.

Inherits Mila::Dnn::ComponentConfig.

Public Types

enum class  ConnectionType { Addition , ScaledAddition , Gated }
 

Public Member Functions

 ResidualConfig ()=default
 
ConnectionType getConnectionType () const
 
template<DeviceType TDeviceType, typename TInnerInput , typename TInnerOutput >
std::shared_ptr< Module< TDeviceType, TInnerInput, TInnerOutput > > getInnerModule () const
 
float getScalingFactor () const
 Gets the scaling factor for the residual connection.
 
bool hasInnerModule () const
 
bool useProjection () const
 
void validate () const
 Validates the configuration.
 
ResidualConfig & withConnectionType (ConnectionType type)
 Set the connection type for the residual.
 
ResidualConfig & withProjection (bool use_projection)
 Configure whether to include a projection for dimension matching.
 
ResidualConfig & withScalingFactor (float scale)
 Set the scaling factor for the residual connection.
 
- Public Member Functions inherited from Mila::Dnn::ComponentConfig
virtual ~ComponentConfig ()=default
 Virtual destructor to support proper polymorphic destruction.
 
const std::string & getName () const
 Gets the configured component name.
 
ComputePrecision::Policy getPrecision () const
 Gets the configured precision policy.
 
bool isTraining () const
 Gets the configured training mode.
 
template<typename Self >
auto & withName (this Self &&self, std::string name)
 Sets the name of the component with fluent interface.
 
template<typename Self >
auto & withPrecision (this Self &&self, ComputePrecision::Policy policy)
 Sets the compute precision policy with fluent interface.
 
template<typename Self >
auto & withTraining (this Self &&self, bool is_training)
 Sets the training mode with fluent interface.
 

Private Attributes

ConnectionType connection_type_ = ConnectionType::Addition
 
std::shared_ptr< void > inner_module_ptr_ = nullptr
 
float scaling_factor_ = 1.0f
 
bool use_projection_ = true
 

Additional Inherited Members

- Protected Attributes inherited from Mila::Dnn::ComponentConfig
bool is_training_ = false
 Training mode flag, defaults to false (inference mode)
 
std::string name_ = "unnamed"
 Component name, defaults to "unnamed" if not explicitly set.
 
ComputePrecision::Policy precision_ = ComputePrecision::Policy::Auto
 Precision policy for computation, defaults to Auto.
 

Detailed Description

Configuration class for the Residual connection module.

Member Enumeration Documentation

◆ ConnectionType

Enumerator
Addition 

Simple addition (x + F(x))

ScaledAddition 

Scaled addition (x + alpha*F(x))

Gated 

Gated connection using learnable parameters.
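
The two formulas above can be sketched as a small element-wise function. This is only an illustration, not Mila's implementation: `combineResidual` and the local `ConnectionType` copy are hypothetical names, and the Gated case is left unimplemented because this page does not give its formula.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Local stand-in for ResidualConfig::ConnectionType (same enumerator names).
enum class ConnectionType { Addition, ScaledAddition, Gated };

// Combines the skip input x with the inner-module output fx, element-wise.
// alpha is only read for ScaledAddition. (Hypothetical helper, not a Mila API.)
std::vector<float> combineResidual(ConnectionType type,
                                   const std::vector<float>& x,
                                   const std::vector<float>& fx,
                                   float alpha = 1.0f) {
    assert(x.size() == fx.size() && "x and F(x) must have the same length");
    std::vector<float> y(x.size());
    switch (type) {
    case ConnectionType::Addition:
        // y = x + F(x)
        for (std::size_t i = 0; i < x.size(); ++i) y[i] = x[i] + fx[i];
        break;
    case ConnectionType::ScaledAddition:
        // y = x + alpha * F(x)
        for (std::size_t i = 0; i < x.size(); ++i) y[i] = x[i] + alpha * fx[i];
        break;
    case ConnectionType::Gated:
        // The gate formula is not given on this page, so it is left out here.
        throw std::logic_error("Gated formula is not specified in this reference");
    }
    return y;
}
```

With `alpha` set to 1, ScaledAddition reduces to plain Addition, which matches the default `scaling_factor_` of 1.0f listed on this page.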

Constructor & Destructor Documentation

◆ ResidualConfig()

Mila::Dnn::ResidualConfig::ResidualConfig ( )
default

Member Function Documentation

◆ getConnectionType()

ConnectionType Mila::Dnn::ResidualConfig::getConnectionType ( ) const
inline

◆ getInnerModule()

template<DeviceType TDeviceType, typename TInnerInput , typename TInnerOutput >
std::shared_ptr< Module< TDeviceType, TInnerInput, TInnerOutput > > Mila::Dnn::ResidualConfig::getInnerModule ( ) const
inline

Gets the inner module for the residual connection.

The inner module defines the transformation F(x) in the residual formula x + F(x).

Returns
Shared pointer to the inner module, typed as Module< TDeviceType, TInnerInput, TInnerOutput >

◆ getScalingFactor()

float Mila::Dnn::ResidualConfig::getScalingFactor ( ) const
inline

Gets the scaling factor for the residual connection.

Returns
Scaling factor (only used for ScaledAddition type)

◆ hasInnerModule()

bool Mila::Dnn::ResidualConfig::hasInnerModule ( ) const
inline

◆ useProjection()

bool Mila::Dnn::ResidualConfig::useProjection ( ) const
inline

◆ validate()

void Mila::Dnn::ResidualConfig::validate ( ) const
inline virtual

Validates the configuration.

Base implementation validates that the component name is not empty. Derived classes should call this base implementation and add their own validation logic.

Exceptions
std::invalid_argument  If the configuration is invalid

Reimplemented from Mila::Dnn::ComponentConfig.
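
The "call the base implementation, then add your own checks" pattern described above might look like the following. This is a self-contained sketch: `ComponentConfigSketch` and `ResidualValidateSketch` are stand-ins rather than the Mila classes, and the finite-scaling-factor rule is a hypothetical check chosen for illustration, since this page does not list what `ResidualConfig::validate()` actually verifies.

```cpp
#include <cmath>
#include <stdexcept>
#include <string>

// Stand-in for the base class: rejects an empty component name.
class ComponentConfigSketch {
public:
    virtual ~ComponentConfigSketch() = default;

    virtual void validate() const {
        if (name_.empty()) {
            throw std::invalid_argument("component name must not be empty");
        }
    }

    std::string name_ = "unnamed";  // public only to keep the sketch short
};

// Stand-in for a derived config that adds its own check after the base one.
class ResidualValidateSketch : public ComponentConfigSketch {
public:
    void validate() const override {
        ComponentConfigSketch::validate();  // run the base checks first

        // Hypothetical rule, not taken from the Mila source.
        if (!std::isfinite(scaling_factor_)) {
            throw std::invalid_argument("scaling factor must be finite");
        }
    }

    float scaling_factor_ = 1.0f;
};
```

Calling the base implementation first means a missing name is reported before any derived-class rule runs, so callers see the most general error first.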


◆ withConnectionType()

ResidualConfig & Mila::Dnn::ResidualConfig::withConnectionType ( ConnectionType  type)
inline

Set the connection type for the residual.

Parameters
type  Connection type
Returns
ResidualConfig& Reference to this for method chaining

◆ withProjection()

ResidualConfig & Mila::Dnn::ResidualConfig::withProjection ( bool  use_projection)
inline

Configure whether to include a projection for dimension matching.

When the input and the inner module's output dimensions don't match, a linear projection layer can be added automatically to make the dimensions compatible.

Parameters
use_projection  Whether to use projection when needed
Returns
ResidualConfig& Reference to this for method chaining
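
To illustrate what dimension matching involves, here is a minimal sketch that projects the skip input with a bias-free linear map before the addition. It assumes the projection is applied to x (the skip path), which is a common convention but is not confirmed by this page; `project`, `residualWithProjection`, and the row-major `Mat` type are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

using Vec = std::vector<float>;
using Mat = std::vector<Vec>;  // row-major: W[row][col]

// Applies a bias-free linear map W (out_dim x in_dim) to x.
Vec project(const Mat& W, const Vec& x) {
    Vec y(W.size(), 0.0f);
    for (std::size_t r = 0; r < W.size(); ++r) {
        assert(W[r].size() == x.size());
        for (std::size_t c = 0; c < x.size(); ++c) y[r] += W[r][c] * x[c];
    }
    return y;
}

// Computes skip + F(x). The skip path is x itself when the sizes agree,
// and W * x when they differ and use_projection is true.
// (Assumption: the projection is applied to x, not to F(x).)
Vec residualWithProjection(const Vec& x, const Vec& fx, const Mat& W,
                           bool use_projection) {
    Vec skip = x;
    if (x.size() != fx.size()) {
        if (!use_projection) {
            throw std::invalid_argument("dimension mismatch and projection disabled");
        }
        skip = project(W, x);
    }
    assert(skip.size() == fx.size());
    for (std::size_t i = 0; i < skip.size(); ++i) skip[i] += fx[i];
    return skip;
}
```

When the sizes already agree, `W` is never read, which mirrors the "when needed" wording of the `use_projection` parameter.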

◆ withScalingFactor()

ResidualConfig & Mila::Dnn::ResidualConfig::withScalingFactor ( float  scale)
inline

Set the scaling factor for the residual connection.

Parameters
scale  Scaling factor (only used for ScaledAddition type)
Returns
ResidualConfig& Reference to this for method chaining
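
Because each `with...` setter returns a reference to the object, calls can be chained. The sketch below uses a stand-in class, `ResidualConfigSketch`, that mirrors the setter and getter names and the default values listed on this page; it is not the Mila class, and `makeScaledConfig` is a hypothetical helper.

```cpp
// Local stand-in for ResidualConfig::ConnectionType (same enumerator names).
enum class ConnectionType { Addition, ScaledAddition, Gated };

// Stand-in mirroring the documented fluent setters and their defaults.
class ResidualConfigSketch {
public:
    ResidualConfigSketch& withConnectionType(ConnectionType type) {
        connection_type_ = type;
        return *this;  // returning *this is what enables chaining
    }
    ResidualConfigSketch& withScalingFactor(float scale) {
        scaling_factor_ = scale;
        return *this;
    }
    ResidualConfigSketch& withProjection(bool use_projection) {
        use_projection_ = use_projection;
        return *this;
    }

    ConnectionType getConnectionType() const { return connection_type_; }
    float getScalingFactor() const { return scaling_factor_; }
    bool useProjection() const { return use_projection_; }

private:
    ConnectionType connection_type_ = ConnectionType::Addition;
    float scaling_factor_ = 1.0f;
    bool use_projection_ = true;
};

// Builds a config for x + 0.5 * F(x) with projection disabled.
inline ResidualConfigSketch makeScaledConfig() {
    return ResidualConfigSketch()
        .withConnectionType(ConnectionType::ScaledAddition)
        .withScalingFactor(0.5f)
        .withProjection(false);
}
```

Since the scaling factor is documented as used only for ScaledAddition, setting it alongside the connection type in one chain keeps the two related choices together.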

Member Data Documentation

◆ connection_type_

ConnectionType Mila::Dnn::ResidualConfig::connection_type_ = ConnectionType::Addition
private

◆ inner_module_ptr_

std::shared_ptr<void> Mila::Dnn::ResidualConfig::inner_module_ptr_ = nullptr
private
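
Storing the module as `std::shared_ptr<void>` lets a non-template class hold a pointer to a templated type. One plausible way to recover the typed pointer is `std::static_pointer_cast`, sketched below. This is a guess at the general technique, not the actual body of `getInnerModule()`; `ModuleSketch` and `InnerModuleHolder` are hypothetical names.

```cpp
#include <memory>
#include <utility>

// Hypothetical module type used only for this sketch.
template <typename TInput, typename TOutput>
struct ModuleSketch {
    int id = 0;
};

// Non-template holder that erases the module's template arguments.
class InnerModuleHolder {
public:
    template <typename TInput, typename TOutput>
    void set(std::shared_ptr<ModuleSketch<TInput, TOutput>> module) {
        // Implicit conversion to shared_ptr<void>; the original deleter is kept.
        inner_module_ptr_ = std::move(module);
    }

    // The caller must request the same types that were stored;
    // a mismatched cast is undefined behavior.
    template <typename TInput, typename TOutput>
    std::shared_ptr<ModuleSketch<TInput, TOutput>> get() const {
        return std::static_pointer_cast<ModuleSketch<TInput, TOutput>>(inner_module_ptr_);
    }

    bool has() const { return inner_module_ptr_ != nullptr; }

private:
    std::shared_ptr<void> inner_module_ptr_ = nullptr;
};
```

The trade-off of this design is that the stored type is no longer checked by the compiler, so the template arguments passed on retrieval have to match those used when the module was stored.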

◆ scaling_factor_

float Mila::Dnn::ResidualConfig::scaling_factor_ = 1.0f
private

◆ use_projection_

bool Mila::Dnn::ResidualConfig::use_projection_ = true
private

The documentation for this class was generated from the following file: