Configuration class for Residual connection module.
◆ ConnectionType
| Enumerator | Description |
|---|---|
| Addition | Simple addition (x + F(x)) |
| ScaledAddition | Scaled addition (x + alpha*F(x)) |
| Gated | Gated connection using learnable parameters |
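The three enumerators differ only in how the shortcut x and the inner-module output F(x) are combined. The following is a minimal element-wise sketch, not Mila's implementation: `combineResidual` and its `gate_logit` parameter are hypothetical, and the highway-style gate, y = (1 - g)*x + g*F(x) with g = sigmoid(gate_logit), is an assumption, since this page only says that Gated uses learnable parameters.

```cpp
#include <cmath>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical mirror of ResidualConfig::ConnectionType (not Mila's enum).
enum class ConnectionType { Addition, ScaledAddition, Gated };

// Combines shortcut x with inner-module output fx element-wise.
// alpha is used only for ScaledAddition. gate_logit is a hypothetical
// learnable scalar for Gated; the highway-style formula is an assumption.
std::vector<float> combineResidual(const std::vector<float>& x,
                                   const std::vector<float>& fx,
                                   ConnectionType type,
                                   float alpha = 1.0f,
                                   float gate_logit = 0.0f) {
    if (x.size() != fx.size()) {
        throw std::invalid_argument("x and F(x) must have the same size");
    }
    const float g = 1.0f / (1.0f + std::exp(-gate_logit));  // sigmoid gate in (0, 1)
    std::vector<float> y(x.size());
    for (std::size_t i = 0; i < x.size(); ++i) {
        switch (type) {
            case ConnectionType::Addition:       y[i] = x[i] + fx[i]; break;
            case ConnectionType::ScaledAddition: y[i] = x[i] + alpha * fx[i]; break;
            case ConnectionType::Gated:          y[i] = (1.0f - g) * x[i] + g * fx[i]; break;
        }
    }
    return y;
}
```

With alpha = 1 ScaledAddition reduces to Addition, which is consistent with the documented default `scaling_factor_ = 1.0f`.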
◆ ResidualConfig()
Mila::Dnn::ResidualConfig::ResidualConfig() [default]
◆ getConnectionType()
◆ getInnerModule()
template<DeviceType TDeviceType, typename TInnerInput, typename TInnerOutput>
std::shared_ptr<Module<TDeviceType, TInnerInput, TInnerOutput>> Mila::Dnn::ResidualConfig::getInnerModule() const [inline]
◆ getScalingFactor()
float Mila::Dnn::ResidualConfig::getScalingFactor() const [inline]
Set the inner module for the residual connection.
The inner module defines the transformation F(x) in the residual formula x + F(x).
- Parameters
  | inner_module | Shared pointer to the inner module |
- Returns
  - ResidualConfig& Reference to this for method chaining
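The member list on this page shows the inner module stored as `std::shared_ptr<void>` (`inner_module_ptr_`), while `getInnerModule()` is a template returning a typed `Module` pointer. That suggests type erasure on store and a cast on retrieval. A minimal sketch of the pattern, assuming a simplified stand-in `ModuleStub` for Mila's `Module` template and a hypothetical setter named `withInnerModule` (the setter's real name is not shown here):

```cpp
#include <memory>
#include <utility>

// Hypothetical stand-in for Mila's Module<TDeviceType, TInput, TOutput>;
// the DeviceType parameter is omitted for brevity.
template <typename TInput, typename TOutput>
struct ModuleStub {
    virtual ~ModuleStub() = default;
};

// Sketch of a config that stores its inner module type-erased as
// std::shared_ptr<void>, mirroring the documented inner_module_ptr_ member.
class InnerModuleHolder {
public:
    // Hypothetical setter name; returns *this to allow method chaining.
    template <typename TInput, typename TOutput>
    InnerModuleHolder& withInnerModule(std::shared_ptr<ModuleStub<TInput, TOutput>> m) {
        inner_module_ptr_ = std::move(m);  // shared_ptr<T> converts to shared_ptr<void>
        return *this;
    }

    // The caller must request the same template arguments used when storing:
    // static_pointer_cast performs no runtime type check.
    template <typename TInput, typename TOutput>
    std::shared_ptr<ModuleStub<TInput, TOutput>> getInnerModule() const {
        return std::static_pointer_cast<ModuleStub<TInput, TOutput>>(inner_module_ptr_);
    }

    bool hasInnerModule() const { return inner_module_ptr_ != nullptr; }

private:
    std::shared_ptr<void> inner_module_ptr_ = nullptr;
};
```

The trade-off of storing `void` is that the config class itself need not be a template, at the cost of an unchecked cast on retrieval.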
◆ hasInnerModule()
bool Mila::Dnn::ResidualConfig::hasInnerModule() const [inline]
◆ useProjection()
bool Mila::Dnn::ResidualConfig::useProjection() const [inline]
◆ validate()
void Mila::Dnn::ResidualConfig::validate() const [inline], [virtual]
Validates the configuration.
Base implementation validates that the component name is not empty. Derived classes should call this base implementation and add their own validation logic.
- Exceptions
  | std::invalid_argument | If the configuration is invalid |
Reimplemented from Mila::Dnn::ComponentConfig.
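The override pattern described above (call the base implementation, then add class-specific checks) can be sketched as follows. `ComponentConfigSketch` and `ResidualConfigSketch` are hypothetical stand-ins, and the residual-specific checks (inner module present, positive scaling factor) are illustrative assumptions rather than the documented rules.

```cpp
#include <stdexcept>
#include <string>
#include <utility>

// Hypothetical stand-in for Mila::Dnn::ComponentConfig, performing the
// documented base check: the component name must not be empty.
class ComponentConfigSketch {
public:
    virtual ~ComponentConfigSketch() = default;
    void setName(std::string name) { name_ = std::move(name); }
    virtual void validate() const {
        if (name_.empty()) {
            throw std::invalid_argument("Component name cannot be empty");
        }
    }
private:
    std::string name_;
};

// Derived config: runs the base validation first, then its own checks.
// The checks below are assumptions chosen for illustration.
class ResidualConfigSketch : public ComponentConfigSketch {
public:
    void setScalingFactor(float s) { scaling_factor_ = s; }
    void setHasInnerModule(bool b) { has_inner_ = b; }
    void validate() const override {
        ComponentConfigSketch::validate();  // base validation first
        if (!has_inner_) {
            throw std::invalid_argument("Residual requires an inner module");
        }
        if (!(scaling_factor_ > 0.0f)) {  // negated comparison also rejects NaN
            throw std::invalid_argument("Scaling factor must be positive");
        }
    }
private:
    float scaling_factor_ = 1.0f;
    bool has_inner_ = false;
};
```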
◆ withConnectionType()
Set the connection type for the residual connection.
- Returns
  - ResidualConfig& Reference to this for method chaining
◆ withProjection()
ResidualConfig & Mila::Dnn::ResidualConfig::withProjection(bool use_projection) [inline]
Configure whether to include a projection for dimension matching.
When the input dimensions differ from the inner module's output dimensions, a linear projection layer can be added automatically so that the two can be summed.
- Parameters
  | use_projection | Whether to use a projection when dimensions differ |
- Returns
  - ResidualConfig& Reference to this for method chaining
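One way to realize the dimension matching described above, assuming a bias-free projection W applied to the shortcut path, so that y = W*x + F(x). The page does not say which path is projected or whether the layer has a bias, and `residualWithProjection` is a hypothetical helper, not part of Mila.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Sketch: sums shortcut x and inner output fx, projecting x with W
// (out_dim x in_dim, row-major) only when the sizes differ.
// Placement of the projection on the shortcut is an assumption.
std::vector<float> residualWithProjection(const std::vector<float>& x,
                                          const std::vector<float>& fx,
                                          const std::vector<float>& W,
                                          bool use_projection) {
    const std::size_t in_dim = x.size();
    const std::size_t out_dim = fx.size();
    std::vector<float> shortcut = x;
    if (in_dim != out_dim) {
        if (!use_projection) {
            throw std::invalid_argument("dimension mismatch and projection disabled");
        }
        if (W.size() != out_dim * in_dim) {
            throw std::invalid_argument("projection weight has wrong shape");
        }
        shortcut.assign(out_dim, 0.0f);
        for (std::size_t r = 0; r < out_dim; ++r) {
            for (std::size_t c = 0; c < in_dim; ++c) {
                shortcut[r] += W[r * in_dim + c] * x[c];  // shortcut = W * x
            }
        }
    }
    std::vector<float> y(out_dim);
    for (std::size_t i = 0; i < out_dim; ++i) {
        y[i] = shortcut[i] + fx[i];
    }
    return y;
}
```

When sizes already match, W is ignored and the result is the plain sum; with `use_projection` false and mismatched sizes, the sketch throws rather than guessing a shape.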
◆ withScalingFactor()
ResidualConfig & Mila::Dnn::ResidualConfig::withScalingFactor(float scale) [inline]
Set the scaling factor for the residual connection.
- Parameters
  | scale | Scaling factor alpha in x + alpha*F(x) (used only by the ScaledAddition type) |
- Returns
  - ResidualConfig& Reference to this for method chaining
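Because each setter returns `ResidualConfig&`, calls can be chained. The hypothetical mirror class `ResidualConfigBuilder` below reproduces the defaults listed in the member documentation (`scaling_factor_ = 1.0f`, `use_projection_ = true`); the default connection type is not documented, so `Addition` is assumed. It is a sketch of the fluent-setter pattern, not Mila's class.

```cpp
// Hypothetical mirror of ResidualConfig's fluent setters and getters.
class ResidualConfigBuilder {
public:
    enum class ConnectionType { Addition, ScaledAddition, Gated };

    // Each setter stores its value and returns *this so calls can chain.
    ResidualConfigBuilder& withConnectionType(ConnectionType t) { connection_type_ = t; return *this; }
    ResidualConfigBuilder& withScalingFactor(float scale) { scaling_factor_ = scale; return *this; }
    ResidualConfigBuilder& withProjection(bool use_projection) { use_projection_ = use_projection; return *this; }

    ConnectionType getConnectionType() const { return connection_type_; }
    float getScalingFactor() const { return scaling_factor_; }
    bool useProjection() const { return use_projection_; }

    // Per the documentation, the scale only takes effect for ScaledAddition.
    bool scaleIsUsed() const { return connection_type_ == ConnectionType::ScaledAddition; }

private:
    ConnectionType connection_type_ = ConnectionType::Addition;  // assumed default
    float scaling_factor_ = 1.0f;   // documented default
    bool use_projection_ = true;    // documented default
};
```

Usage: `cfg.withConnectionType(ResidualConfigBuilder::ConnectionType::ScaledAddition).withScalingFactor(0.5f).withProjection(false);` configures all three fields in one expression.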
◆ connection_type_
◆ inner_module_ptr_
std::shared_ptr<void> Mila::Dnn::ResidualConfig::inner_module_ptr_ = nullptr [private]
◆ scaling_factor_
float Mila::Dnn::ResidualConfig::scaling_factor_ = 1.0f [private]
◆ use_projection_
bool Mila::Dnn::ResidualConfig::use_projection_ = true [private]