Mila
Deep Neural Network Library
Loading...
Searching...
No Matches
Dnn.Modules.Linear Module Reference

Exported Modules

module  Compute.UnaryOperation
 
module  Compute.DeviceContext
 
module  Compute.CudaMemoryResource
 
module  Compute.MemoryResource
 
module  Compute.DeviceType
 
module  Dnn.Module
 
module  Dnn.TensorTraits
 
module  Compute.OperationRegistry
 
module  Compute.Precision
 
module  Dnn.TensorHelpers
 
module  Dnn.Tensor
 
module  Serialization.ModelArchive
 
module  Compute.ComputeDevice
 
module  Compute.CpuMemoryResource
 
module  Compute.OperationBase
 
module  Compute.OperationAttributes
 

Classes

class  Mila::Dnn::Linear< TDeviceType, TInput, TOutput >
 A class representing a linear transformation module. More...
 
class  Mila::Dnn::LinearConfig
 Configuration class for Linear module. More...
 

Typedefs

template<typename TInput = float, typename TOutput = TInput>
using Mila::Dnn::CpuLinear = Linear< DeviceType::Cpu, TInput, TOutput >
 Type alias for a CPU-based linear module with customizable tensor types.
 
template<typename TInput = float, typename TOutput = TInput>
using Mila::Dnn::CudaLinear = Linear< DeviceType::Cuda, TInput, TOutput >
 Type alias for a CUDA-based linear module with customizable tensor types.
 
using ModuleBase = Module< TDeviceType, TInput, TOutput >
 Alias for the base module type.
 
using MR = std::conditional_t< TDeviceType==DeviceType::Cuda, CudaMemoryResource, CpuMemoryResource >
 Memory resource type used for tensors, selected based on the device type.
 

Functions

 Linear (const std::string &device_name, const LinearConfig &config)
 Constructs a new Linear module with a device name.
 
 Linear (std::shared_ptr< DeviceContext > device_context, const LinearConfig &config)
 Constructs a new Linear module with a provided device context.
 
void backward (const Tensor< TInput, MR > &input, const Tensor< TOutput, MR > &output_grad, Tensor< TInput, MR > &input_grad)
 Performs the backward pass of the Linear operation.
 
void createOperation ()
 Creates the appropriate Linear operation based on the current device context.
 
void forward (const Tensor< TInput, MR > &input, Tensor< TOutput, MR > &output)
 Performs the forward pass of the Linear operation.
 
std::optional< std::shared_ptr< Tensor< TOutput, MR > > > getBias ()
 Retrieves the bias tensor if present.
 
std::shared_ptr< Tensor< TOutput, MR > > getWeight ()
 Retrieves the weight tensor for this linear layer.
 
bool hasBias () const
 Checks whether the module has a bias tensor.
 
void initializeParameterGradients ()
 Initializes gradient tensors for parameters.
 
void initializeParameters ()
 Initializes the tensors needed for the Linear operation.
 
void load (ModelArchive &archive) override
 Deserializes the module state from a ZIP archive.
 
size_t parameterCount () const override
 Gets the number of trainable parameters in this module.
 
void save (ModelArchive &zip) const override
 Serializes the module state to a ZIP archive.
 
std::string toString () const override
 Converts the module information to a human-readable string.
 

Variables

std::shared_ptr< Tensor< TOutput, MR > > bias_ { nullptr }
 The bias tensor added after the matrix multiplication.
 
LinearConfig config_
 Configuration for the Linear module.
 
std::shared_ptr< UnaryOperation< TDeviceType, TInput, TOutput > > operation_ { nullptr }
 The underlying operation that implements the Linear transformation.
 
std::vector< std::shared_ptr< Tensor< TOutput, MR > > > output_state_
 Cache of intermediate tensors needed for the backward pass.
 
std::vector< std::shared_ptr< Tensor< TOutput, MR > > > parameter_grads_
 Gradients for the parameters of this module.
 
std::vector< std::shared_ptr< Tensor< TOutput, MR > > > parameters_
 Collection of trainable parameters for this module.
 
OperationAttributes properties_
 Additional configuration options for the linear operation.
 
std::shared_ptr< Tensor< TOutput, MR > > weight_ { nullptr }
 The weight tensor for the linear transformation.
 

Files

file  /home/runner/work/Mila/Mila/Mila/Src/Dnn/Modules/Layers/Linear.ixx
 Implementation of the Linear (fully connected) module for neural networks.
 
file  /home/runner/work/Mila/Mila/Mila/Src/Dnn/Modules/Layers/LinearConfig.ixx
 Configuration interface for the Linear module in the Mila DNN framework.