Mila
Deep Neural Network Library
Loading...
Searching...
No Matches
Dnn.Modules.CrossEntropy Module Reference

Exported Modules

module  Compute.OperationAttributes
 
module  Compute.DeviceContext
 
module  Compute.CudaMemoryResource
 
module  Compute.MemoryResource
 
module  Compute.DeviceType
 
module  Dnn.Module
 
module  Dnn.Tensor
 
module  Dnn.TensorHelpers
 
module  Dnn.TensorTraits
 
module  Compute.OperationRegistry
 
module  Compute.Precision
 
module  Compute.CpuMemoryResource
 
module  Compute.OperationBase
 
module  Compute.UnaryOperation
 
module  Serialization.ModelArchive
 
module  Compute.ComputeDevice
 

Classes

class  Mila::Dnn::CrossEntropy< TDeviceType, TLogits, TTargets >
 CrossEntropy loss module for neural networks. More...
 
class  Mila::Dnn::CrossEntropyConfig
 Configuration class for CrossEntropy module. More...
 

Typedefs

template<typename TLogits = float, typename TTargets = int>
using Mila::Dnn::CpuCrossEntropy = CrossEntropy< DeviceType::Cpu, TLogits, TTargets >
 Type alias for the CPU-based cross-entropy module with customizable logits and targets tensor types.
 
template<typename TLogits = float, typename TTargets = int>
using Mila::Dnn::CudaCrossEntropy = CrossEntropy< DeviceType::Cuda, TLogits, TTargets >
 Type alias for the CUDA-based cross-entropy module with customizable logits and targets tensor types.
 
using ModuleBase = Module< TDeviceType, TLogits, TTargets >
 Alias for the base module type.
 
using MR = std::conditional_t< TDeviceType==DeviceType::Cuda, CudaMemoryResource, CpuMemoryResource >
 Memory resource type used for tensors, selected based on device type.
 

Functions

 CrossEntropy (const std::string &device_name, const CrossEntropyConfig &config)
 Constructs a new CrossEntropy module with a device name.
 
 CrossEntropy (std::shared_ptr< DeviceContext > device_context, const CrossEntropyConfig &config)
 Constructs a new CrossEntropy module with a provided device context.
 
void backward (const Tensor< TLogits, MR > &input, const Tensor< TTargets, MR > &targets, const Tensor< TLogits, MR > &output_grad, Tensor< TLogits, MR > &input_grad)
 Calculates gradients for the backward pass.
 
void createOperation ()
 Creates the appropriate cross-entropy operation for the current device.
 
void forward (const Tensor< TLogits, MR > &input, const Tensor< TTargets, MR > &targets, Tensor< TLogits, MR > &output)
 Performs the forward pass of the cross-entropy operation.
 
std::shared_ptr< Tensor< TLogits, MR > > getClassWeights () const
 Gets the class weights tensor.
 
float getLabelSmoothing () const
 Gets the label smoothing factor.
 
int64_t getPaddingIndex () const
 Gets the padding index.
 
int64_t getVocabSize () const
 Gets the vocabulary size.
 
bool ignorePadding () const
 Checks if padding is ignored.
 
void initializeClassWeights (const std::vector< float > &weights)
 Initializes the class weights tensor from a vector of weights.
 
bool isReduced () const
 Checks whether the loss is reduced.
 
void load (ModelArchive &archive) override
 Deserializes the module state from a ZIP archive.
 
size_t parameterCount () const override
 Gets the number of trainable parameters in this module.
 
void save (ModelArchive &archive) const override
 Serializes the module state to a ZIP archive.
 
std::string toString () const override
 Generates a string representation of this module's configuration.
 

Variables

OperationAttributes attributes_
 Operation attributes and configuration.
 
std::shared_ptr< Tensor< TLogits, MR > > class_weights_ { nullptr }
 Optional tensor containing weights for each class.
 
CrossEntropyConfig config_
 Configuration for the CrossEntropy module.
 
std::shared_ptr< UnaryOperation< TDeviceType, TLogits, TTargets > > operation_ { nullptr }
 The operation that implements the cross-entropy calculation.
 
std::vector< std::shared_ptr< Tensor< TLogits, MR > > > output_state_
 Collection of output state tensors for caching.
 
std::vector< std::shared_ptr< Tensor< TLogits, MR > > > parameters_
 Collection of parameters for this module.
 

Files

file  /home/runner/work/Mila/Mila/Mila/Src/Dnn/Modules/Losses/CrossEntropy.ixx
 Implementation of the CrossEntropy loss function module.
 
file  /home/runner/work/Mila/Mila/Mila/Src/Dnn/Modules/Losses/CrossEntropyConfig.ixx
 Configuration interface for the CrossEntropy module in the Mila DNN framework.