Configuration class for the Residual connection module.
◆ ConnectionType
Enumerator | Description
---|---
Addition | Simple addition (x + F(x))
ScaledAddition | Scaled addition (x + alpha*F(x))
Gated | Gated connection using learnable parameters
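As a rough illustration of the three connection types, the sketch below combines an input `x` with an inner-module output `fx` element-wise. The `combine` function and its `alpha` and `gate` parameters are hypothetical and not part of the Mila API. The `Gated` formula in particular is an assumption: the documentation only says the gate uses learnable parameters, so a single scalar gate blending the two paths is used here as one common form.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the documented enumeration.
enum class ConnectionType { Addition, ScaledAddition, Gated };

// Combines the input x with the inner module output fx, element-wise.
// Assumption: the Gated branch uses one scalar gate in [0, 1]; the
// library's actual gating may be parameterized differently.
std::vector<float> combine(ConnectionType type,
                           const std::vector<float>& x,
                           const std::vector<float>& fx,
                           float alpha = 1.0f,
                           float gate = 0.5f) {
    assert(x.size() == fx.size());  // this sketch does not handle projection
    std::vector<float> y(x.size());
    for (std::size_t i = 0; i < x.size(); ++i) {
        switch (type) {
            case ConnectionType::Addition:
                y[i] = x[i] + fx[i];
                break;
            case ConnectionType::ScaledAddition:
                y[i] = x[i] + alpha * fx[i];
                break;
            case ConnectionType::Gated:
                y[i] = gate * x[i] + (1.0f - gate) * fx[i];
                break;
        }
    }
    return y;
}
```

With `alpha` left at 1, `ScaledAddition` reduces to plain `Addition`, which matches the documented default scaling factor of `1.0f`.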
◆ ResidualConfig()
Mila::Dnn::ResidualConfig::ResidualConfig() [default]
◆ getConnectionType()
◆ getInnerModule()
template<DeviceType TDeviceType, typename TInnerInput, typename TInnerOutput>
std::shared_ptr<Module<TDeviceType, TInnerInput, TInnerOutput>> Mila::Dnn::ResidualConfig::getInnerModule() const [inline]
◆ getScalingFactor()
float Mila::Dnn::ResidualConfig::getScalingFactor() const [inline]
Set the inner module for the residual connection.
The inner module defines the transformation F(x) in the residual formula x + F(x).
- Parameters
  inner_module | Shared pointer to the inner module
- Returns
  ResidualConfig& Reference to this for method chaining
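The private member `inner_module_ptr_` is a `std::shared_ptr<void>`, while `getInnerModule()` is templated on the module type, which suggests type-erased storage with a typed cast on retrieval. The sketch below illustrates that pattern only. `InnerModuleHolder`, `DoubleModule`, and the setter name `withInnerModule` are hypothetical, and the real getter is templated on device and tensor types rather than on the module type directly.

```cpp
#include <cassert>
#include <memory>
#include <utility>

// Hypothetical module used only for this illustration.
struct DoubleModule {
    float forward(float x) const { return 2.0f * x; }
};

// Hypothetical holder mirroring the documented storage: the inner module
// is kept as shared_ptr<void> and cast back to a concrete type on access.
class InnerModuleHolder {
public:
    // Assumed setter name; stores the module with its type erased.
    template <typename TModule>
    InnerModuleHolder& withInnerModule(std::shared_ptr<TModule> inner_module) {
        inner_module_ptr_ = std::move(inner_module);
        return *this;
    }

    // The caller must request the same type that was stored; a mismatched
    // type is undefined behavior because shared_ptr<void> keeps no type tag.
    template <typename TModule>
    std::shared_ptr<TModule> getInnerModule() const {
        return std::static_pointer_cast<TModule>(inner_module_ptr_);
    }

    bool hasInnerModule() const { return inner_module_ptr_ != nullptr; }

private:
    std::shared_ptr<void> inner_module_ptr_ = nullptr;
};

// Usage: check hasInnerModule() before retrieving the typed pointer.
inline float applyInner(const InnerModuleHolder& holder, float x) {
    assert(holder.hasInnerModule());
    return holder.getInnerModule<DoubleModule>()->forward(x);
}
```

Because no type information survives in a `std::shared_ptr<void>`, checking `hasInnerModule()` guards against a null pointer but not against requesting the wrong module type.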
◆ hasInnerModule()
bool Mila::Dnn::ResidualConfig::hasInnerModule() const [inline]
◆ useProjection()
bool Mila::Dnn::ResidualConfig::useProjection() const [inline]
◆ validate()
void Mila::Dnn::ResidualConfig::validate() const [inline, virtual]
Validates the configuration.
Base implementation validates that the component name is not empty. Derived classes should call this base implementation and add their own validation logic.
- Exceptions
  std::invalid_argument | If the configuration is invalid
Reimplemented from Mila::Dnn::ComponentConfig.
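The validation contract described above (the base class checks the name, derived classes call the base and then add their own checks) might look like the following sketch. `ComponentConfigSketch` and `ValidatedResidualConfig` are hypothetical stand-ins, their constructors are invented for the example, and the finite-scaling-factor rule is an assumption; the checks the real `ResidualConfig::validate()` performs are not listed in this documentation.

```cpp
#include <cassert>
#include <cmath>
#include <stdexcept>
#include <string>
#include <utility>

// Hypothetical base class: rejects an empty component name.
class ComponentConfigSketch {
public:
    explicit ComponentConfigSketch(std::string name) : name_(std::move(name)) {}
    virtual ~ComponentConfigSketch() = default;

    virtual void validate() const {
        if (name_.empty()) {
            throw std::invalid_argument("Component name must not be empty");
        }
    }

private:
    std::string name_;
};

// Hypothetical derived config: calls the base check first, then its own.
class ValidatedResidualConfig : public ComponentConfigSketch {
public:
    ValidatedResidualConfig(std::string name, float scaling_factor)
        : ComponentConfigSketch(std::move(name)), scaling_factor_(scaling_factor) {}

    void validate() const override {
        ComponentConfigSketch::validate();
        // Assumed rule, for illustration only.
        if (!std::isfinite(scaling_factor_)) {
            throw std::invalid_argument("Scaling factor must be finite");
        }
    }

private:
    float scaling_factor_ = 1.0f;
};

// Returns true when validate() throws std::invalid_argument.
template <typename TConfig>
bool isInvalid(const TConfig& config) {
    try {
        config.validate();
    } catch (const std::invalid_argument&) {
        return true;
    }
    return false;
}

// Usage: a named config with a finite scale passes validation.
inline void usageExample() {
    assert(!isInvalid(ValidatedResidualConfig("residual_1", 0.5f)));
}
```

Calling the base implementation first keeps the name check in one place, so every derived configuration reports an empty name the same way.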
◆ withConnectionType()
Set the connection type for the residual.
- Parameters
-
- Returns
- ResidualConfig& Reference to this for method chaining
◆ withProjection()
ResidualConfig& Mila::Dnn::ResidualConfig::withProjection(bool use_projection) [inline]
Configure whether to include a projection for dimension matching.
When the input and the inner module's output dimensions don't match, a linear projection layer can be added automatically to make them compatible.
- Parameters
  use_projection | Whether to use projection when needed
- Returns
  ResidualConfig& Reference to this for method chaining
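To make the dimension-matching idea concrete, here is a minimal sketch using plain vectors. The names `project` and `residualAdd` and the weight matrix `w` are hypothetical. Projecting the skip path `x`, rather than the module output, is an assumption, since the documentation does not say which side is projected or how the layer is initialized.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Row-major weights: rows = output dimension, columns = input dimension.
using Matrix = std::vector<std::vector<float>>;

// Bias-free linear projection: out = w * x.
std::vector<float> project(const Matrix& w, const std::vector<float>& x) {
    std::vector<float> out(w.size(), 0.0f);
    for (std::size_t r = 0; r < w.size(); ++r) {
        if (w[r].size() != x.size()) {
            throw std::invalid_argument("Projection width must match input size");
        }
        for (std::size_t c = 0; c < x.size(); ++c) {
            out[r] += w[r][c] * x[c];
        }
    }
    return out;
}

// Computes x + fx, projecting x first when the sizes differ.
std::vector<float> residualAdd(const std::vector<float>& x,
                               const std::vector<float>& fx,
                               bool use_projection,
                               const Matrix& w) {
    std::vector<float> skip = x;
    if (x.size() != fx.size()) {
        if (!use_projection) {
            throw std::invalid_argument("Dimension mismatch with projection disabled");
        }
        skip = project(w, x);
        if (skip.size() != fx.size()) {
            throw std::invalid_argument("Projection output must match module output");
        }
    }
    for (std::size_t i = 0; i < skip.size(); ++i) {
        skip[i] += fx[i];
    }
    return skip;
}

// Usage: a 2-element input is projected to 3 elements before the addition.
inline void usageExample() {
    const Matrix w = {{1.0f, 0.0f}, {0.0f, 1.0f}, {1.0f, 1.0f}};
    const std::vector<float> y = residualAdd({1.0f, 2.0f}, {10.0f, 20.0f, 30.0f}, true, w);
    assert(y.size() == 3);
}
```

When the sizes already agree, the weights are ignored and the function reduces to the plain `x + F(x)` form.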
◆ withScalingFactor()
ResidualConfig& Mila::Dnn::ResidualConfig::withScalingFactor(float scale) [inline]
Set the scaling factor for the residual connection.
- Parameters
  scale | Scaling factor (only used for ScaledAddition type)
- Returns
  ResidualConfig& Reference to this for method chaining
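Since each `with...` setter returns a reference to the configuration, calls can be chained. The sketch below mirrors the documented setters and getters on a hypothetical `FluentResidualConfig` class so that it compiles on its own. The parameter name `type` and the `Addition` default for the connection type are assumptions; the `1.0f` and `true` defaults follow the member documentation.

```cpp
#include <cassert>

// Hypothetical stand-in for the documented enumeration.
enum class ConnectionType { Addition, ScaledAddition, Gated };

// Hypothetical class mirroring the documented fluent interface.
class FluentResidualConfig {
public:
    FluentResidualConfig& withConnectionType(ConnectionType type) {
        connection_type_ = type;
        return *this;
    }

    FluentResidualConfig& withScalingFactor(float scale) {
        scaling_factor_ = scale;
        return *this;
    }

    FluentResidualConfig& withProjection(bool use_projection) {
        use_projection_ = use_projection;
        return *this;
    }

    ConnectionType getConnectionType() const { return connection_type_; }
    float getScalingFactor() const { return scaling_factor_; }
    bool useProjection() const { return use_projection_; }

private:
    ConnectionType connection_type_ = ConnectionType::Addition;  // assumed default
    float scaling_factor_ = 1.0f;                                // documented default
    bool use_projection_ = true;                                 // documented default
};

// Usage: build a scaled-addition configuration in one expression.
inline FluentResidualConfig makeScaledConfig() {
    return FluentResidualConfig()
        .withConnectionType(ConnectionType::ScaledAddition)
        .withScalingFactor(0.5f)
        .withProjection(false);
}

inline void usageExample() {
    const FluentResidualConfig config = makeScaledConfig();
    assert(config.getConnectionType() == ConnectionType::ScaledAddition);
}
```

Setting the connection type together with the scaling factor matters because, per the parameter description, the scale is only used for `ScaledAddition`.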
◆ connection_type_
◆ inner_module_ptr_
std::shared_ptr<void> Mila::Dnn::ResidualConfig::inner_module_ptr_ = nullptr [private]
◆ scaling_factor_
float Mila::Dnn::ResidualConfig::scaling_factor_ = 1.0f [private]
◆ use_projection_
bool Mila::Dnn::ResidualConfig::use_projection_ = true [private]