Mila
Deep Neural Network Library
A class implementing a residual connection module.
Public Types | |
using | ModuleBase = Module< TDeviceType, TInput, TOutput > |
Alias for base module type. | |
using | MR = std::conditional_t< TDeviceType==DeviceType::Cuda, CudaMemoryResource, CpuMemoryResource > |
Memory resource type used for tensors, selected based on device type. | |
Public Member Functions | |
Residual (const std::string &device_name, const ResidualConfig &config) | |
Constructs a new Residual module with a device name. | |
Residual (std::shared_ptr< DeviceContext > device_context, const ResidualConfig &config) | |
Constructs a new Residual module with a provided device context. | |
void | backward (const Tensor< TInput, MR > &input, const Tensor< TOutput, MR > &output_grad, Tensor< TInput, MR > &input_grad) |
Performs the backward pass of the Residual connection. | |
void | forward (const Tensor< TInput, MR > &input, Tensor< TOutput, MR > &output) |
Performs the forward pass of the Residual connection. | |
std::shared_ptr< Module< TDeviceType, TInput, TOutput > > | getInnerModule () |
Gets the inner module. | |
void | load (ModelArchive &archive) override |
Deserializes the module state from a ZIP archive. | |
size_t | parameterCount () const override |
Gets the number of trainable parameters in this module. | |
void | save (ModelArchive &zip) const override |
Serializes the module state to a ZIP archive. | |
std::string | toString () const override |
Converts the module information to a human-readable string. | |
Public Member Functions inherited from Module< TDeviceType, TInput, TOutput > | |
Module (const std::string &device_name, const ComponentConfig &config) | |
Constructor with device name. | |
Module (std::shared_ptr< DeviceContext > context, const ComponentConfig &config) | |
Constructor with a specific device context. | |
virtual | ~Module ()=default |
Virtual destructor for proper cleanup in derived classes. | |
std::shared_ptr< Compute::DeviceContext > | getDeviceContext () const |
Get the device context for this module. | |
Compute::DeviceType | getDeviceType () const |
Get the device type of the current device context. | |
std::string | getName () const |
Get the name of the module. | |
const auto & | getParameterTensors () const |
Get the parameter tensors of this module. | |
const ComputePrecision::Policy & | getPrecision () const |
const auto & | getStateTensors () const |
Get the state tensors of this module. | |
bool | isTraining () const |
Check if the module is in training mode. | |
virtual void | setTraining (bool is_training) |
Set the training mode of this module. | |
Private Member Functions | |
void | addTensors (const Tensor< TInput, MR > &a, const Tensor< TInput, MR > &b, Tensor< TInput, MR > &result) |
Adds two tensors element-wise. | |
void | createOperation () |
Creates an appropriate operation based on the connection type. | |
void | createProjection (const std::vector< size_t > &input_shape, const std::vector< size_t > &output_shape) |
Creates a projection layer when input and output dimensions don't match. | |
void | initializeGateParameters (const std::vector< size_t > &shape) |
Initializes parameters for gated connections. | |
bool | tensorShapesMatch (const Tensor< TInput, MR > &a, const Tensor< TOutput, MR > &b) |
Checks if two tensor shapes match for residual connection. | |
Static Private Member Functions | |
static std::string | connectionTypeToString (ResidualConfig::ConnectionType type) |
Converts connection type enum to string for display purposes. | |
Private Attributes | |
ResidualConfig | config_ |
Configuration for the Residual module. | |
std::shared_ptr< Tensor< TOutput, MR > > | gate_weights_ |
Learnable gate weights for gated residual connections. | |
std::shared_ptr< BinaryOperation< TDeviceType, TInput, TOutput, TOutput > > | gated_operation_ |
Binary operation for gated residual connections. | |
Tensor< TInput, MR > | inner_input_grad_ {} |
Temporary tensor to store inner module gradients during backward pass. | |
std::shared_ptr< Module< TDeviceType, TInput, TOutput > > | inner_module_ |
The inner module implementing the transformation F(x). | |
Tensor< TOutput, MR > | inner_output_ {} |
Temporary tensor to store inner module output during forward pass. | |
std::vector< std::shared_ptr< Tensor< TOutput, MR > > > | inner_parameter_grads_ |
Gradients for inner parameters. | |
std::shared_ptr< BinaryOperation< TDeviceType, TInput, TOutput, TOutput > > | operation_ |
Binary operation for standard and scaled residual connections. | |
std::vector< std::shared_ptr< Tensor< TOutput, MR > > > | output_state_ |
Output state tensors for backward pass. | |
std::vector< std::shared_ptr< Tensor< TOutput, MR > > > | parameter_grads_ |
Gradients for trainable parameters. | |
std::vector< std::shared_ptr< Tensor< TOutput, MR > > > | parameters_ |
Collection of trainable parameters. | |
std::shared_ptr< Linear< TDeviceType, TInput, TOutput > > | projection_ |
Optional projection layer for dimension matching. | |
Tensor< TOutput, MR > | projection_output_ {} |
Temporary tensor to store projection output during forward pass. | |
OperationAttributes | properties_ |
Operation-specific attributes. | |
Tensor< TInput, MR > | temp_grad_ {} |
Temporary tensor for gradient accumulation. | |
Additional Inherited Members | |
const std::string | parametersToString () const |
Helper method to convert parameters to string representation. | |
const std::string | stateToString () const |
Helper method to convert state tensors to string representation. | |
std::unordered_map< std::string, std::shared_ptr< Tensor< TOutput, MR > > > | parameter_map_ = {} |
Map of parameter names to parameter tensors. | |
std::unordered_map< std::string, std::shared_ptr< Tensor< TOutput, MR > > > | state_map_ = {} |
Map of state names to state tensors. | |
A class implementing a residual connection module.
Residual connections help deep neural networks mitigate vanishing gradients by providing shortcut connections. The basic formula is y = x + F(x), where F is a differentiable function (usually a sequence of neural network layers).
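This implementation supports three types of residual connections: Addition (y = x + F(x)), ScaledAddition (a scaled variant of the addition), and Gated (learnable gate weights determine how much of the input versus the transformed output is used).
When the input dimensions and the inner module's output dimensions don't match, an optional projection layer can be added automatically to make them compatible.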
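As a minimal sketch of the Addition case, assuming a shape-preserving F and a plain std::vector<float> in place of Mila's Tensor type, the formula y = x + F(x) reduces to an element-wise sum of the input and the inner module's output. The names Vec and residualForward are hypothetical and not part of Mila's API.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical stand-in for a 1-D tensor; Mila's Tensor<T, MR> is not used here.
using Vec = std::vector<float>;

// Computes y = x + F(x) for a shape-preserving inner transformation F.
Vec residualForward(const Vec& x, const std::function<Vec(const Vec&)>& F) {
    Vec fx = F(x);                  // transformed branch F(x)
    assert(fx.size() == x.size());  // the shortcut requires matching shapes
    Vec y(x.size());
    for (std::size_t i = 0; i < x.size(); ++i) {
        y[i] = x[i] + fx[i];        // identity shortcut plus transformed branch
    }
    return y;
}
```

For example, with F(v) = 2v and x = {1, 2, 3}, residualForward returns {3, 6, 9}.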
TDeviceType | The device type (CPU or CUDA) on which to perform computations. |
TInput | The data type of the input tensor elements. |
TOutput | The data type of the output tensor elements, defaults to TInput. |
using ModuleBase = Module< TDeviceType, TInput, TOutput > | export |
Alias for base module type.
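using MR = std::conditional_t< TDeviceType==DeviceType::Cuda, CudaMemoryResource, CpuMemoryResource > | export |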
Memory resource type used for tensors, selected based on device type.
Residual (const std::string &device_name, const ResidualConfig &config) | inline explicit export |
Constructs a new Residual module with a device name.
Creates a new DeviceContext internally using the provided device name. This constructor is useful for creating standalone modules without pre-existing device contexts.
device_name | The name of the device to use (e.g., "CPU", "CUDA:0"). |
config | Configuration parameters for the Residual module. |
std::invalid_argument | If the device name or the configuration is invalid. |
std::runtime_error | If the device type doesn't match the template parameter TDeviceType, or if there is an inner module type mismatch. |
Residual (std::shared_ptr< DeviceContext > device_context, const ResidualConfig &config) | inline explicit export |
Constructs a new Residual module with a provided device context.
Uses a pre-existing DeviceContext instance. This constructor is useful when integrating the module into a larger network that shares device contexts across modules.
device_context | The device context to use for this module. |
config | Configuration parameters for the Residual module. |
std::invalid_argument | If device_context is null or the configuration is invalid. |
std::runtime_error | If the device context type doesn't match the template parameter TDeviceType, or if there is an inner module type mismatch. |
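void addTensors (const Tensor< TInput, MR > &a, const Tensor< TInput, MR > &b, Tensor< TInput, MR > &result) | inline export private |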
Adds two tensors element-wise.
Helper method to add gradients from different paths during backpropagation.
a | First input tensor |
b | Second input tensor |
result | Output tensor for the sum |
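The helper's contract can be illustrated with a small stand-alone sketch. FlatTensor and addTensorsSketch are hypothetical names; Mila's Tensor< TInput, MR > and its memory resources are not modeled, and the shape check shown here is an assumption about how mismatches might be handled.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical flat tensor: a shape plus contiguous row-major data.
struct FlatTensor {
    std::vector<std::size_t> shape;
    std::vector<float> data;
};

// Sketch of element-wise addition, e.g. to sum the gradients arriving from
// the shortcut path and the inner-module path during backpropagation.
void addTensorsSketch(const FlatTensor& a, const FlatTensor& b, FlatTensor& result) {
    if (a.shape != b.shape || a.data.size() != b.data.size()) {
        throw std::invalid_argument("addTensorsSketch: shape mismatch");
    }
    result.shape = a.shape;
    result.data.resize(a.data.size());
    for (std::size_t i = 0; i < a.data.size(); ++i) {
        result.data[i] = a.data[i] + b.data[i];
    }
}
```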
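void backward (const Tensor< TInput, MR > &input, const Tensor< TOutput, MR > &output_grad, Tensor< TInput, MR > &input_grad) | inline export |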
Performs the backward pass of the Residual connection.
Computes gradients for the input tensor and parameters based on the output gradients. Handles backpropagation through the inner module and projection layer (if present).
input | The input tensor from the forward pass. |
output_grad | The gradient of loss with respect to the output. |
input_grad | The tensor to store gradients with respect to input. |
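The gradient rule behind backward() follows from y = x + F(x): the output gradient reaches the input unchanged through the shortcut and is added to the gradient that flows back through F, so dL/dx = dL/dy + J_F^T dL/dy. The sketch below assumes a hypothetical element-wise inner module F(x) = w * x (ScaleModule) to make J_F concrete; it does not model the projection layer or Mila's actual backward implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Vec = std::vector<float>;

// Hypothetical inner module F(x) = w * x (element-wise scale), used only to
// make the residual gradient concrete. Its vector-Jacobian product is w * g.
struct ScaleModule {
    Vec w;
    Vec weight_grad;

    Vec backward(const Vec& x, const Vec& output_grad) {
        Vec input_grad(x.size());
        weight_grad.assign(x.size(), 0.0f);
        for (std::size_t i = 0; i < x.size(); ++i) {
            input_grad[i] = w[i] * output_grad[i];   // J_F^T * g
            weight_grad[i] = x[i] * output_grad[i];  // dL/dw
        }
        return input_grad;
    }
};

// For y = x + F(x): dL/dx = dL/dy (shortcut) + J_F^T dL/dy (inner path).
Vec residualBackward(const Vec& x, const Vec& output_grad, ScaleModule& inner) {
    Vec inner_input_grad = inner.backward(x, output_grad);
    assert(inner_input_grad.size() == output_grad.size());
    Vec input_grad(x.size());
    for (std::size_t i = 0; i < x.size(); ++i) {
        input_grad[i] = output_grad[i] + inner_input_grad[i];
    }
    return input_grad;
}
```

With x = {1, 2}, w = {3, -1} and dL/dy = {1, 0.5}, the inner path contributes {3, -0.5}, so input_grad is {4, 0} and the weight gradient is {1, 1}.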
static std::string connectionTypeToString (ResidualConfig::ConnectionType type) | inline static export private |
Converts connection type enum to string for display purposes.
type | The connection type enum value |
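A conversion like this is typically a switch over the enumerators. In the sketch below the enum is a hypothetical mirror of ResidualConfig::ConnectionType using the three connection types named on this page, and the returned strings are assumptions rather than Mila's actual output.

```cpp
#include <cassert>
#include <string>

// Hypothetical mirror of ResidualConfig::ConnectionType; the enumerator names
// follow the connection types listed for createOperation().
enum class ConnectionType { Addition, ScaledAddition, Gated };

// Sketch of an enum-to-string helper for display (e.g. in toString()).
// The exact strings Mila returns are not documented here; these are assumptions.
std::string connectionTypeToStringSketch(ConnectionType type) {
    switch (type) {
        case ConnectionType::Addition:       return "Addition";
        case ConnectionType::ScaledAddition: return "ScaledAddition";
        case ConnectionType::Gated:          return "Gated";
    }
    return "Unknown";  // unreachable for valid enumerators; avoids -Wreturn-type
}
```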
void createOperation () | inline export private |
Creates an appropriate operation based on the connection type.
Instantiates the correct operation implementation based on the configured connection type (Addition, ScaledAddition, or Gated) and device type.
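void createProjection (const std::vector< size_t > &input_shape, const std::vector< size_t > &output_shape) | inline export private |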
Creates a projection layer when input and output dimensions don't match.
Instantiates a Linear layer to project the input to the correct dimensions to match the output of the inner module.
input_shape | Shape of the input tensor |
output_shape | Shape of the output tensor from inner module |
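Assuming the projection behaves like an ordinary dense layer (this page states that a Linear layer is used, but not its exact configuration), it maps the input's feature dimension to the inner module's output dimension so that the projected input can be added to F(x). ProjectionSketch is a hypothetical plain-C++ illustration, and the bias term is an assumption.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical dense projection y = W x + b mapping in_features -> out_features.
// W is stored row-major with shape [out_features, in_features]; this is a
// plain-C++ illustration, not Mila's Linear module.
struct ProjectionSketch {
    std::size_t in_features;
    std::size_t out_features;
    std::vector<float> weight;  // out_features * in_features values
    std::vector<float> bias;    // out_features values

    std::vector<float> forward(const std::vector<float>& x) const {
        assert(x.size() == in_features);
        assert(weight.size() == out_features * in_features);
        assert(bias.size() == out_features);
        std::vector<float> y(bias);  // start from the bias
        for (std::size_t o = 0; o < out_features; ++o) {
            for (std::size_t i = 0; i < in_features; ++i) {
                y[o] += weight[o * in_features + i] * x[i];
            }
        }
        return y;
    }
};
```

Here W = [[1, 0, 0], [0, 1, 1]] and b = [0, 1] map x = [2, 3, 4] to [2, 8]; in the residual sum, this projected input takes the place of x.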
void forward (const Tensor< TInput, MR > &input, Tensor< TOutput, MR > &output) | inline export |
Performs the forward pass of the Residual connection.
Applies the residual transformation selected by the configured connection type (Addition, ScaledAddition, or Gated).
Handles projection when the input dimensions and the inner module's output dimensions don't match.
input | The input tensor to be processed. |
output | The output tensor where the results will be stored. |
std::runtime_error | If the input and inner-module output dimensions don't match and projection is disabled. |
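The per-element combination step can be sketched as follows. Only the Addition formula, y = x + F(x), is stated on this page; the ScaledAddition and Gated formulas in the code (a scale factor alpha on F(x), and a mix g * F(x) + (1 - g) * x) are assumptions chosen to fit the descriptions, and combineResidual is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

enum class ConnectionType { Addition, ScaledAddition, Gated };

using Vec = std::vector<float>;

// Combines the shortcut x with the inner output fx.
// ASSUMED formulas (not confirmed by the Mila docs):
//   Addition:        y = x + fx
//   ScaledAddition:  y = x + alpha * fx
//   Gated:           y = g * fx + (1 - g) * x   (element-wise gate g)
Vec combineResidual(ConnectionType type, const Vec& x, const Vec& fx,
                    float alpha, const Vec& gate) {
    if (x.size() != fx.size()) {
        // Mirrors the documented error when shapes differ and projection is disabled.
        throw std::runtime_error("shape mismatch and projection disabled");
    }
    if (type == ConnectionType::Gated && gate.size() != x.size()) {
        throw std::invalid_argument("gate shape mismatch");
    }
    Vec y(x.size());
    for (std::size_t i = 0; i < x.size(); ++i) {
        switch (type) {
            case ConnectionType::Addition:
                y[i] = x[i] + fx[i];
                break;
            case ConnectionType::ScaledAddition:
                y[i] = x[i] + alpha * fx[i];
                break;
            case ConnectionType::Gated:
                y[i] = gate[i] * fx[i] + (1.0f - gate[i]) * x[i];
                break;
        }
    }
    return y;
}
```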
std::shared_ptr< Module< TDeviceType, TInput, TOutput > > getInnerModule () | inline export |
Gets the inner module.
Returns the inner module that implements the transformation F(x) in the residual connection formula y = x + F(x).
void initializeGateParameters (const std::vector< size_t > &shape) | inline export private |
Initializes parameters for gated connections.
Creates and initializes the learnable gate weights for gated residual connections. The gate weights determine how much of the input vs. transformed output to use.
shape | Shape of the tensor for gate weights |
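Allocating gate weights from a shape amounts to multiplying the dimensions to get the element count and filling that many values. The constant initial value of 0.5 and the names GateWeightsSketch and initializeGateSketch are assumptions for illustration; this page does not say how Mila initializes the gate.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical container for gate weights: a shape plus flat data.
struct GateWeightsSketch {
    std::vector<std::size_t> shape;
    std::vector<float> data;
};

// Allocates gate weights for the given shape and fills them with a constant.
// The default 0.5 is an ASSUMPTION: under y = g*F(x) + (1-g)*x it gives an
// even mix of the input and the transformed output at the start of training.
GateWeightsSketch initializeGateSketch(const std::vector<std::size_t>& shape,
                                       float init_value = 0.5f) {
    // Element count is the product of all dimensions (1 for an empty shape).
    std::size_t count = std::accumulate(shape.begin(), shape.end(),
                                        std::size_t{1}, std::multiplies<std::size_t>());
    return GateWeightsSketch{shape, std::vector<float>(count, init_value)};
}
```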
void load (ModelArchive &archive) override | inline override export virtual |
Deserializes the module state from a ZIP archive.
Loads the state of the inner module, projection layer (if present), and gating parameters (if used) from the provided ZIP archive.
archive | The ModelArchive (ZIP archive) to load the module state from. |
Implements Mila::Dnn::Module< TDeviceType, TInput, TOutput >.
size_t parameterCount () const override | inline override export virtual |
Gets the number of trainable parameters in this module.
Counts the total number of trainable parameters in the residual module, including the inner module, projection layer (if present), and gating parameters (if using gated connections).
Implements Mila::Dnn::Module< TDeviceType, TInput, TOutput >.
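The count is a sum over the three parameter sources listed above. The sketch assumes the projection is a dense layer with in * out weights plus an optional bias of size out; ResidualParamSketch and parameterCountSketch are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical breakdown of a residual block's trainable parameters.
struct ResidualParamSketch {
    std::size_t inner_params;       // parameters owned by the inner module F
    bool has_projection;            // true when input/output dims differ
    std::size_t proj_in, proj_out;  // projection dimensions (if present)
    bool projection_has_bias;       // ASSUMPTION: depends on the Linear config
    std::size_t gate_params;        // 0 unless the connection type is Gated
};

// Total = inner module + projection (weights [+ bias]) + gate weights.
std::size_t parameterCountSketch(const ResidualParamSketch& p) {
    std::size_t total = p.inner_params + p.gate_params;
    if (p.has_projection) {
        total += p.proj_in * p.proj_out;
        if (p.projection_has_bias) {
            total += p.proj_out;
        }
    }
    return total;
}
```

For instance, an inner module with 1,000 parameters, a 512-to-768 projection with bias (393,984 parameters), and 768 gate weights gives 395,752 in total.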
void save (ModelArchive &zip) const override | inline override export virtual |
Serializes the module state to a ZIP archive.
Saves the state of the inner module, projection layer (if present), and gating parameters (if used) to the provided ZIP archive.
zip | The ZIP archive to save the module state to. |
Implements Mila::Dnn::Module< TDeviceType, TInput, TOutput >.
bool tensorShapesMatch (const Tensor< TInput, MR > &a, const Tensor< TOutput, MR > &b) | inline export private |
Checks if two tensor shapes match for residual connection.
a | First tensor to compare |
b | Second tensor to compare |
std::string toString () const override | inline override export virtual |
Converts the module information to a human-readable string.
Includes detailed information about the module configuration.
Implements Mila::Dnn::Module< TDeviceType, TInput, TOutput >.
ResidualConfig config_ | export private |
Configuration for the Residual module.
std::shared_ptr< Tensor< TOutput, MR > > gate_weights_ | export private |
Learnable gate weights for gated residual connections.
std::shared_ptr< BinaryOperation< TDeviceType, TInput, TOutput, TOutput > > gated_operation_ | export private |
Binary operation for gated residual connections.
Tensor< TInput, MR > inner_input_grad_ {} | export private |
Temporary tensor to store inner module gradients during backward pass.
std::shared_ptr< Module< TDeviceType, TInput, TOutput > > inner_module_ | export private |
The inner module implementing the transformation F(x).
Tensor< TOutput, MR > inner_output_ {} | export private |
Temporary tensor to store inner module output during forward pass.
std::vector< std::shared_ptr< Tensor< TOutput, MR > > > inner_parameter_grads_ | export private |
Gradients for inner parameters.
std::shared_ptr< BinaryOperation< TDeviceType, TInput, TOutput, TOutput > > operation_ | export private |
Binary operation for standard and scaled residual connections.
std::vector< std::shared_ptr< Tensor< TOutput, MR > > > output_state_ | export private |
Output state tensors for backward pass.
std::vector< std::shared_ptr< Tensor< TOutput, MR > > > parameter_grads_ | export private |
Gradients for trainable parameters.
std::vector< std::shared_ptr< Tensor< TOutput, MR > > > parameters_ | export private |
Collection of trainable parameters.
std::shared_ptr< Linear< TDeviceType, TInput, TOutput > > projection_ | export private |
Optional projection layer for dimension matching.
Tensor< TOutput, MR > projection_output_ {} | export private |
Temporary tensor to store projection output during forward pass.
OperationAttributes properties_ | export private |
Operation-specific attributes.
Tensor< TInput, MR > temp_grad_ {} | export private |
Temporary tensor for gradient accumulation.