Mila
Deep Neural Network Library
Loading...
Searching...
No Matches
Dnn.Modules.Dropout Module Reference

Exported Modules

module  Compute.DeviceType
 
module  Serialization.ModelArchive
 
module  Compute.ComputeDevice
 
module  Compute.CudaMemoryResource
 
module  Compute.DeviceContext
 
module  Dnn.Module
 
module  Compute.MemoryResource
 
module  Dnn.Tensor
 
module  Compute.OperationRegistry
 
module  Compute.Precision
 
module  Dnn.TensorTraits
 
module  Compute.OperationAttributes
 
module  Compute.UnaryOperation
 
module  Compute.OperationBase
 
module  Compute.CpuMemoryResource
 

Classes

class  Mila::Dnn::Dropout< TDeviceType, TInput, TOutput >
 Dropout regularization module for neural networks. More...
 
class  Mila::Dnn::DropoutConfig
 Configuration class for Dropout module. More...
 

Typedefs

template<typename TInput = float, typename TOutput = TInput>
using Mila::Dnn::CpuDropout = Dropout< DeviceType::Cpu, TInput, TOutput >
 Type alias for CPU-based dropout module with customizable tensor types.
 
template<typename TInput = float, typename TOutput = TInput>
using Mila::Dnn::CudaDropout = Dropout< DeviceType::Cuda, TInput, TOutput >
 Type alias for CUDA-based dropout module with customizable tensor types.
 
using ModuleBase = Module< TDeviceType, TInput, TOutput >
 Alias for the base module type.
 
using MR = std::conditional_t< TDeviceType==DeviceType::Cuda, CudaMemoryResource, CpuMemoryResource >
 Memory resource type used for tensors, selected based on device type.
 

Functions

 Dropout (const std::string &device_name, const DropoutConfig &config)
 Constructs a new Dropout module on the device identified by name.
 
 Dropout (std::shared_ptr< DeviceContext > device_context, const DropoutConfig &config)
 Constructs a new Dropout module with a provided device context.
 
void backward (const Tensor< TInput, MR > &input, const Tensor< TOutput, MR > &output_grad, Tensor< TInput, MR > &input_grad)
 Performs the backward pass of the Dropout operation.
 
void createOperation ()
 Creates the appropriate Dropout operation based on the current device context.
 
void forward (const Tensor< TInput, MR > &input, Tensor< TOutput, MR > &output)
 Performs the forward pass of the Dropout operation.
 
void generateMask (Tensor< TOutput, MR > &mask, const std::vector< size_t > &shape)
 Generates a new dropout mask for the given shape.
 
float getProbability () const
 Gets the dropout probability used by this module.
 
unsigned int getSeed () const
 Gets the current random seed.
 
void load (ModelArchive &archive) override
 Deserializes the module state from a ZIP archive.
 
size_t parameterCount () const override
 Gets the number of trainable parameters in this module.
 
void save (ModelArchive &archive) const override
 Serializes the module state to a ZIP archive.
 
void setSeed (unsigned int seed)
 Sets the random seed for dropout mask generation.
 
std::string toString () const override
 Generates a string representation of this module's configuration.
 

Variables

DropoutConfig config_
 Configuration for the Dropout module.
 
std::shared_ptr< Tensor< TOutput, MR > > mask_ { nullptr }
 The binary mask tensor for element selection.
 
std::shared_ptr< UnaryOperation< TDeviceType, TInput, TOutput > > operation_ { nullptr }
 The operation that implements the dropout calculation.
 
std::vector< std::shared_ptr< Tensor< TOutput, MR > > > output_state_
 Collection of output state tensors for caching.
 
std::vector< std::shared_ptr< Tensor< TOutput, MR > > > parameters_
 Collection of parameters for this module (empty for Dropout).
 
OperationAttributes properties_
 Operation attributes and configuration.
 
std::mt19937 rng_
 Random number generator for mask generation.
 
unsigned int seed_ { 0 }
 Random seed for reproducible dropout patterns.
 

Files

file  /home/runner/work/Mila/Mila/Mila/Src/Dnn/Modules/Regularization/Dropout.ixx
 Implementation of the Dropout regularization module for neural networks.
 
file  /home/runner/work/Mila/Mila/Mila/Src/Dnn/Modules/Regularization/DropoutConfig.ixx
 Configuration interface for the Dropout regularization module in the Mila DNN framework.