Mila
Deep Neural Network Library
Mila::Dnn::CrossEntropy< TDeviceType, TLogits, TTargets > Class Template Reference export

CrossEntropy loss module for neural networks.


Public Types

using ModuleBase = Module< TDeviceType, TLogits, TTargets >
 Alias for base module type.
 
using MR = std::conditional_t< TDeviceType==DeviceType::Cuda, CudaMemoryResource, CpuMemoryResource >
 Memory resource type used for tensors, selected based on device type.
 
- Public Types inherited from Mila::Dnn::Module< TDeviceType, TInput, TOutput >
using MR = std::conditional_t< TDeviceType==DeviceType::Cuda, CudaMemoryResource, CpuMemoryResource >
 

Public Member Functions

 CrossEntropy (const std::string &device_name, const CrossEntropyConfig &config)
 Constructs a new CrossEntropy module with a device name.
 
 CrossEntropy (std::shared_ptr< DeviceContext > device_context, const CrossEntropyConfig &config)
 Constructs a new CrossEntropy module with a provided device context.
 
void backward (const Tensor< TLogits, MR > &input, const Tensor< TTargets, MR > &targets, const Tensor< TLogits, MR > &output_grad, Tensor< TLogits, MR > &input_grad)
 Calculates gradients for the backward pass.
 
void forward (const Tensor< TLogits, MR > &input, const Tensor< TTargets, MR > &targets, Tensor< TLogits, MR > &output)
 Performs the forward pass of the cross entropy operation.
 
std::shared_ptr< Tensor< TLogits, MR > > getClassWeights () const
 Gets the class weights tensor.
 
float getLabelSmoothing () const
 Gets the label smoothing factor.
 
int64_t getPaddingIndex () const
 Gets the padding index.
 
int64_t getVocabSize () const
 Gets the vocabulary size.
 
bool ignorePadding () const
 Checks if padding is ignored.
 
bool isReduced () const
 Checks if loss is reduced.
 
void load (ModelArchive &archive) override
 Deserializes the module state from a ZIP archive.
 
size_t parameterCount () const override
 Gets the number of trainable parameters in this module.
 
void save (ModelArchive &archive) const override
 Serializes the module state to a ZIP archive.
 
std::string toString () const override
 Generates a string representation of this module's configuration.
 
- Public Member Functions inherited from Mila::Dnn::Module< TDeviceType, TInput, TOutput >
 Module (const std::string &device_name, const ComponentConfig &config)
 Constructor with device name.
 
 Module (std::shared_ptr< DeviceContext > context, const ComponentConfig &config)
 Constructor with a specific device context.
 
virtual ~Module ()=default
 Virtual destructor for proper cleanup in derived classes.
 
std::shared_ptr< Compute::DeviceContext > getDeviceContext () const
 Get the device context for this module.
 
Compute::DeviceType getDeviceType () const
 Get the device type of the current device context.
 
std::string getName () const
 Get the name of the module.
 
const auto & getParameterTensors () const
 Get the parameter tensors of this module.
 
const ComputePrecision::Policy getPrecision () const
 
const auto & getStateTensors () const
 Get the state tensors of this module.
 
bool isTraining () const
 Check if the module is in training mode.
 
virtual void setTraining (bool is_training)
 Set the training mode of this module.
 

Private Member Functions

void createOperation ()
 Creates the appropriate cross entropy operation for the current device.
 
void initializeClassWeights (const std::vector< float > &weights)
 Initializes the class weights tensor from a vector of weights.
 

Private Attributes

OperationAttributes attributes_
 Operation attributes and configuration.
 
std::shared_ptr< Tensor< TLogits, MR > > class_weights_ { nullptr }
 Optional tensor containing weights for each class.
 
CrossEntropyConfig config_
 Configuration for the CrossEntropy module.
 
std::shared_ptr< UnaryOperation< TDeviceType, TLogits, TTargets > > operation_ { nullptr }
 The operation that implements the cross entropy calculation.
 
std::vector< std::shared_ptr< Tensor< TLogits, MR > > > output_state_
 Collection of output state tensors for caching.
 
std::vector< std::shared_ptr< Tensor< TLogits, MR > > > parameters_
 Collection of parameters for this module.
 

Additional Inherited Members

- Protected Member Functions inherited from Mila::Dnn::Module< TDeviceType, TInput, TOutput >
const std::string parametersToString () const
 Helper method to convert parameters to string representation.
 
const std::string stateToString () const
 Helper method to convert state tensors to string representation.
 
- Protected Attributes inherited from Mila::Dnn::Module< TDeviceType, TInput, TOutput >
std::unordered_map< std::string, std::shared_ptr< Tensor< TOutput, MR > > > parameter_map_ = {}
 Map of parameter names to parameter tensors.
 
std::unordered_map< std::string, std::shared_ptr< Tensor< TOutput, MR > > > state_map_ = {}
 Map of state names to state tensors.
 

Detailed Description

template<DeviceType TDeviceType = DeviceType::Cuda, typename TLogits = float, typename TTargets = int>
requires ValidFloatTensorType<TLogits>
class Mila::Dnn::CrossEntropy< TDeviceType, TLogits, TTargets >

CrossEntropy loss module for neural networks.

This class implements the cross entropy loss function, which is commonly used in classification tasks. It computes the negative log likelihood of the correct class given the predicted probabilities.

The cross entropy loss for a single example is calculated as: -log(p_i) where p_i is the predicted probability for the correct class i.

For multi-class problems with K classes, the formula is: L(y, ŷ) = -Σ(y_k * log(ŷ_k)) for k=1 to K, where y is the ground truth (one-hot) and ŷ is the predicted probability distribution.
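With a one-hot y, the sum above collapses to -log(ŷ_i) for the correct class i. The following is a minimal, standalone sketch of that computation starting from raw logits; it does not use the Mila API, and the function name crossEntropyFromLogits is hypothetical. It applies the usual log-sum-exp trick for numerical stability, which is an assumption about good practice rather than a statement about how Mila's kernels are written.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Cross entropy for one example, computed directly from raw logits z.
// -log softmax(z)[t] = log(sum_k exp(z_k - m)) + m - z_t, with m = max_k z_k.
// Subtracting m keeps exp() from overflowing for large logits.
inline float crossEntropyFromLogits(const std::vector<float>& logits, std::size_t target) {
    assert(!logits.empty() && target < logits.size());
    const float max_logit = *std::max_element(logits.begin(), logits.end());
    float sum_exp = 0.0f;
    for (float z : logits) {
        sum_exp += std::exp(z - max_logit);
    }
    return std::log(sum_exp) + max_logit - logits[target];
}
```

For K equal logits the predicted distribution is uniform, so the loss is log(K) whichever class is correct.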

Features supported by this implementation:

  • Class weighting for imbalanced datasets
  • Padding index ignoring for variable-length sequences
  • Label smoothing for regularization
  • Optional reduction (mean or none)
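How class weighting, padding ignoring, and reduction can interact is sketched below in standalone C++. The LossOptions struct and applyLossOptions function are hypothetical names for illustration and are not Mila's CrossEntropyConfig or its operation. The mean is normalized by the sum of the weights of the non-padding targets, a convention used by some other frameworks; whether Mila normalizes the same way is an assumption here.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical options for this sketch only; not Mila's CrossEntropyConfig.
struct LossOptions {
    std::vector<float> class_weights;  // empty means every class has weight 1
    std::int64_t padding_index = -1;
    bool ignore_padding = false;
    bool reduce_mean = true;
};

// per_token_nll[i] is -log(p) of the correct class for token i.
// Returns the per-token weighted losses (skipped padding positions stay 0),
// or a single-element vector holding the weighted mean when reduce_mean is set.
inline std::vector<float> applyLossOptions(const std::vector<float>& per_token_nll,
                                           const std::vector<std::int64_t>& targets,
                                           const LossOptions& opts) {
    assert(per_token_nll.size() == targets.size());
    std::vector<float> losses(targets.size(), 0.0f);
    float loss_sum = 0.0f;
    float weight_sum = 0.0f;
    for (std::size_t i = 0; i < targets.size(); ++i) {
        if (opts.ignore_padding && targets[i] == opts.padding_index) {
            continue;  // padding contributes neither loss nor weight
        }
        float w = 1.0f;
        if (!opts.class_weights.empty()) {
            assert(targets[i] >= 0 &&
                   static_cast<std::size_t>(targets[i]) < opts.class_weights.size());
            w = opts.class_weights[static_cast<std::size_t>(targets[i])];
        }
        losses[i] = w * per_token_nll[i];
        loss_sum += losses[i];
        weight_sum += w;
    }
    if (!opts.reduce_mean) {
        return losses;
    }
    // Assumed convention: divide by the total weight of the counted targets.
    return { weight_sum > 0.0f ? loss_sum / weight_sum : 0.0f };
}
```

With losses {1, 2, 3}, targets {0, 1, 2}, weights {1, 2, 1}, and class 2 treated as padding, the weighted sum is 1 + 4 = 5 over a total weight of 3.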
Template Parameters
TDeviceType  The device type (CPU or CUDA) on which to perform computations.
TLogits  The data type of the predicted logits (typically float).
TTargets  The data type of the target class indices (typically int).

Member Typedef Documentation

◆ ModuleBase

template<DeviceType TDeviceType = DeviceType::Cuda, typename TLogits = float, typename TTargets = int>
using Mila::Dnn::CrossEntropy< TDeviceType, TLogits, TTargets >::ModuleBase = Module<TDeviceType, TLogits, TTargets>
export

Alias for base module type.

◆ MR

template<DeviceType TDeviceType = DeviceType::Cuda, typename TLogits = float, typename TTargets = int>
using Mila::Dnn::CrossEntropy< TDeviceType, TLogits, TTargets >::MR = std::conditional_t<TDeviceType == DeviceType::Cuda, CudaMemoryResource, CpuMemoryResource>
export

Memory resource type used for tensors, selected based on device type.
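The compile-time selection behind MR can be illustrated with stand-in types. In this sketch the enum and the two empty structs are placeholders declared locally (the real definitions live in the Mila library), the alias name MemoryResourceFor is hypothetical, and the enumerator DeviceType::Cpu is an assumption, since only DeviceType::Cuda appears in this reference.

```cpp
#include <type_traits>

// Stand-in declarations for this sketch; not the library's definitions.
enum class DeviceType { Cpu, Cuda };
struct CpuMemoryResource {};
struct CudaMemoryResource {};

// Same pattern as the MR alias: the memory resource type is chosen at
// compile time from the device type template argument.
template <DeviceType TDeviceType>
using MemoryResourceFor = std::conditional_t<TDeviceType == DeviceType::Cuda,
                                             CudaMemoryResource,
                                             CpuMemoryResource>;

static_assert(std::is_same_v<MemoryResourceFor<DeviceType::Cuda>, CudaMemoryResource>,
              "CUDA selects the CUDA memory resource");
static_assert(std::is_same_v<MemoryResourceFor<DeviceType::Cpu>, CpuMemoryResource>,
              "any non-CUDA device falls back to the CPU memory resource");
```

Because the choice is made at compile time, a tensor typed with the CPU resource cannot be passed to a CUDA-instantiated module by accident; the mismatch is a type error.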

Constructor & Destructor Documentation

◆ CrossEntropy() [1/2]

template<DeviceType TDeviceType = DeviceType::Cuda, typename TLogits = float, typename TTargets = int>
Mila::Dnn::CrossEntropy< TDeviceType, TLogits, TTargets >::CrossEntropy ( const std::string &  device_name,
const CrossEntropyConfig &  config 
)
inline explicit export

Constructs a new CrossEntropy module with a device name.

Creates a new DeviceContext internally using the provided device name. This constructor is useful for creating standalone modules without pre-existing device contexts.

Parameters
device_name  The name of the device to use (e.g., "CPU", "CUDA:0").
config  Configuration parameters for the CrossEntropy module.
Exceptions
std::invalid_argument  If the device name or the configuration is invalid.
std::runtime_error  If the device type doesn't match the template parameter TDeviceType.

◆ CrossEntropy() [2/2]

template<DeviceType TDeviceType = DeviceType::Cuda, typename TLogits = float, typename TTargets = int>
Mila::Dnn::CrossEntropy< TDeviceType, TLogits, TTargets >::CrossEntropy ( std::shared_ptr< DeviceContext >  device_context,
const CrossEntropyConfig &  config 
)
inline explicit export

Constructs a new CrossEntropy module with a provided device context.

Uses a pre-existing DeviceContext instance. This constructor is useful when integrating the module into a larger network that shares device contexts across modules.

Parameters
device_context  The device context to use for this module.
config  Configuration parameters for the CrossEntropy module.
Exceptions
std::invalid_argument  If device_context is null or the configuration is invalid.
std::runtime_error  If the device context type doesn't match the template parameter TDeviceType.

Member Function Documentation

◆ backward()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TLogits = float, typename TTargets = int>
void Mila::Dnn::CrossEntropy< TDeviceType, TLogits, TTargets >::backward ( const Tensor< TLogits, MR > &  input,
const Tensor< TTargets, MR > &  targets,
const Tensor< TLogits, MR > &  output_grad,
Tensor< TLogits, MR > &  input_grad 
)
inline export

Calculates gradients for the backward pass.

Computes the gradient of the cross entropy loss with respect to the input logits. This gradient is used during backpropagation to update network weights.

Parameters
input  The input tensor from the forward pass (logits).
targets  The target tensor containing class indices.
output_grad  The gradient of the loss with respect to the output.
input_grad  The tensor to store gradients with respect to the input.
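For a single example without weighting or smoothing, the gradient of softmax cross entropy with respect to logit k is softmax(z)_k minus 1 for the correct class (and minus 0 otherwise), scaled by the upstream gradient. The standalone sketch below shows that textbook result; the function name crossEntropyGrad is hypothetical, and it is not Mila's backward kernel.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Gradient of -log softmax(z)[target] with respect to each logit z_k:
//   dL/dz_k = output_grad * (softmax(z)_k - [k == target])
inline std::vector<float> crossEntropyGrad(const std::vector<float>& logits,
                                           std::size_t target,
                                           float output_grad) {
    assert(!logits.empty() && target < logits.size());
    const float max_logit = *std::max_element(logits.begin(), logits.end());
    std::vector<float> grad(logits.size());
    float sum_exp = 0.0f;
    for (std::size_t k = 0; k < logits.size(); ++k) {
        grad[k] = std::exp(logits[k] - max_logit);  // unnormalized softmax
        sum_exp += grad[k];
    }
    for (std::size_t k = 0; k < logits.size(); ++k) {
        const float one_hot = (k == target) ? 1.0f : 0.0f;
        grad[k] = output_grad * (grad[k] / sum_exp - one_hot);
    }
    return grad;
}
```

The entries of the result sum to zero, because both the softmax output and the one-hot target sum to one.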

◆ createOperation()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TLogits = float, typename TTargets = int>
void Mila::Dnn::CrossEntropy< TDeviceType, TLogits, TTargets >::createOperation ( )
inline export private

Creates the appropriate cross entropy operation for the current device.

Instantiates either a CPU or CUDA cross entropy operation based on the device type. Sets operation attributes from the configuration object.


◆ forward()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TLogits = float, typename TTargets = int>
void Mila::Dnn::CrossEntropy< TDeviceType, TLogits, TTargets >::forward ( const Tensor< TLogits, MR > &  input,
const Tensor< TTargets, MR > &  targets,
Tensor< TLogits, MR > &  output 
)
inline export

Performs the forward pass of the cross entropy operation.

Computes the cross entropy loss between the predicted logits and target indices. The operation applies any configured options such as class weighting, padding index ignoring, and label smoothing.

Parameters
input  The input tensor containing predicted logits.
targets  The target tensor containing class indices.
output  The output tensor that will contain the loss value(s).

◆ getClassWeights()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TLogits = float, typename TTargets = int>
std::shared_ptr< Tensor< TLogits, MR > > Mila::Dnn::CrossEntropy< TDeviceType, TLogits, TTargets >::getClassWeights ( ) const
inline export

Gets the class weights tensor.

Returns
std::shared_ptr<Tensor<TLogits, MR>> The class weights tensor or nullptr if not used.

◆ getLabelSmoothing()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TLogits = float, typename TTargets = int>
float Mila::Dnn::CrossEntropy< TDeviceType, TLogits, TTargets >::getLabelSmoothing ( ) const
inline export

Gets the label smoothing factor.

Returns
float The label smoothing factor (between 0 and 1).
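One common definition of label smoothing with factor ε over K classes replaces the one-hot target with q_k = (1 - ε) * [k == target] + ε / K. The standalone sketch below uses that definition; whether Mila applies exactly this variant is an assumption, and the function names smoothedTargets and smoothedLoss are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Smoothed target distribution: q_k = (1 - eps) * onehot_k + eps / K.
inline std::vector<float> smoothedTargets(std::size_t num_classes,
                                          std::size_t target,
                                          float smoothing) {
    assert(num_classes > 0 && target < num_classes);
    assert(smoothing >= 0.0f && smoothing <= 1.0f);
    std::vector<float> q(num_classes, smoothing / static_cast<float>(num_classes));
    q[target] += 1.0f - smoothing;
    return q;
}

// Loss against the smoothed targets: -sum_k q_k * log(p_k),
// where log_probs[k] holds log(p_k).
inline float smoothedLoss(const std::vector<float>& log_probs,
                          std::size_t target,
                          float smoothing) {
    const std::vector<float> q = smoothedTargets(log_probs.size(), target, smoothing);
    float loss = 0.0f;
    for (std::size_t k = 0; k < log_probs.size(); ++k) {
        loss -= q[k] * log_probs[k];
    }
    return loss;
}
```

With a factor of 0 this reduces to the plain -log(p_i) loss; larger factors move probability mass from the correct class to the others, which penalizes over-confident predictions.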

◆ getPaddingIndex()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TLogits = float, typename TTargets = int>
int64_t Mila::Dnn::CrossEntropy< TDeviceType, TLogits, TTargets >::getPaddingIndex ( ) const
inline export

Gets the padding index.

Returns
int64_t The index value that represents padding.

◆ getVocabSize()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TLogits = float, typename TTargets = int>
int64_t Mila::Dnn::CrossEntropy< TDeviceType, TLogits, TTargets >::getVocabSize ( ) const
inline export

Gets the vocabulary size.

Returns
int64_t The number of possible classes.

◆ ignorePadding()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TLogits = float, typename TTargets = int>
bool Mila::Dnn::CrossEntropy< TDeviceType, TLogits, TTargets >::ignorePadding ( ) const
inline export

Checks if padding is ignored.

Returns
bool True if padding indices are ignored, false otherwise.

◆ initializeClassWeights()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TLogits = float, typename TTargets = int>
void Mila::Dnn::CrossEntropy< TDeviceType, TLogits, TTargets >::initializeClassWeights ( const std::vector< float > &  weights)
inline export private

Initializes the class weights tensor from a vector of weights.

Creates and populates a tensor with class weights for weighted cross entropy loss. This is useful for handling imbalanced datasets where some classes are underrepresented.

Parameters
weights  Vector of weight values, one per class.

◆ isReduced()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TLogits = float, typename TTargets = int>
bool Mila::Dnn::CrossEntropy< TDeviceType, TLogits, TTargets >::isReduced ( ) const
inline export

Checks if loss is reduced.

Returns
bool True if loss is reduced (mean), false if per-sample losses are returned.

◆ load()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TLogits = float, typename TTargets = int>
void Mila::Dnn::CrossEntropy< TDeviceType, TLogits, TTargets >::load ( ModelArchive &  archive)
inline override export virtual

Deserializes the module state from a ZIP archive.

Loads the class weights tensor (if expected) from the provided ZIP archive.

Parameters
archive  The ZIP archive to load the module state from.

Implements Mila::Dnn::Module< TDeviceType, TInput, TOutput >.

◆ parameterCount()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TLogits = float, typename TTargets = int>
size_t Mila::Dnn::CrossEntropy< TDeviceType, TLogits, TTargets >::parameterCount ( ) const
inline override export virtual

Gets the number of trainable parameters in this module.

Returns the number of elements in the class weights tensor if present; otherwise returns 0, because CrossEntropy has no other parameters.

Returns
size_t The total number of parameters.

Implements Mila::Dnn::Module< TDeviceType, TInput, TOutput >.

◆ save()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TLogits = float, typename TTargets = int>
void Mila::Dnn::CrossEntropy< TDeviceType, TLogits, TTargets >::save ( ModelArchive &  archive) const
inline override export virtual

Serializes the module state to a ZIP archive.

Saves the class weights tensor (if present) to the provided ZIP archive.

Parameters
archive  The ZIP archive to save the module state to.

Implements Mila::Dnn::Module< TDeviceType, TInput, TOutput >.

◆ toString()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TLogits = float, typename TTargets = int>
std::string Mila::Dnn::CrossEntropy< TDeviceType, TLogits, TTargets >::toString ( ) const
inline override export virtual

Generates a string representation of this module's configuration.

Returns
std::string A formatted string with module information.

Implements Mila::Dnn::Module< TDeviceType, TInput, TOutput >.


Member Data Documentation

◆ attributes_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TLogits = float, typename TTargets = int>
OperationAttributes Mila::Dnn::CrossEntropy< TDeviceType, TLogits, TTargets >::attributes_
export private

Operation attributes and configuration.

Contains settings for the CrossEntropy operation like padding index, reduction mode, etc.

◆ class_weights_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TLogits = float, typename TTargets = int>
std::shared_ptr<Tensor<TLogits, MR> > Mila::Dnn::CrossEntropy< TDeviceType, TLogits, TTargets >::class_weights_ { nullptr }
export private

Optional tensor containing weights for each class.

Used to address class imbalance by giving different importance to different classes.

◆ config_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TLogits = float, typename TTargets = int>
CrossEntropyConfig Mila::Dnn::CrossEntropy< TDeviceType, TLogits, TTargets >::config_
export private

Configuration for the CrossEntropy module.

◆ operation_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TLogits = float, typename TTargets = int>
std::shared_ptr<UnaryOperation<TDeviceType, TLogits, TTargets> > Mila::Dnn::CrossEntropy< TDeviceType, TLogits, TTargets >::operation_ { nullptr }
export private

The operation that implements the cross entropy calculation.

◆ output_state_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TLogits = float, typename TTargets = int>
std::vector<std::shared_ptr<Tensor<TLogits, MR> > > Mila::Dnn::CrossEntropy< TDeviceType, TLogits, TTargets >::output_state_
export private

Collection of output state tensors for caching.

Stores intermediate results from forward pass needed for backward pass.

◆ parameters_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TLogits = float, typename TTargets = int>
std::vector<std::shared_ptr<Tensor<TLogits, MR> > > Mila::Dnn::CrossEntropy< TDeviceType, TLogits, TTargets >::parameters_
export private

Collection of parameters for this module.

Only contains class_weights_ if present, otherwise empty.


The documentation for this class was generated from the following file: