Mila
Deep Neural Network Library
CrossEntropy loss module for neural networks.
Public Types | |
using | ModuleBase = Module< TDeviceType, TLogits, TTargets > |
Alias for base module type. | |
using | MR = std::conditional_t< TDeviceType==DeviceType::Cuda, CudaMemoryResource, CpuMemoryResource > |
Memory resource type used for tensors, selected based on device type. | |
Public Member Functions | |
CrossEntropy (const std::string &device_name, const CrossEntropyConfig &config) | |
Constructs a new CrossEntropy module with a device name. | |
CrossEntropy (std::shared_ptr< DeviceContext > device_context, const CrossEntropyConfig &config) | |
Constructs a new CrossEntropy module with a provided device context. | |
void | backward (const Tensor< TLogits, MR > &input, const Tensor< TTargets, MR > &targets, const Tensor< TLogits, MR > &output_grad, Tensor< TLogits, MR > &input_grad) |
Calculates gradients for the backward pass. | |
void | forward (const Tensor< TLogits, MR > &input, const Tensor< TTargets, MR > &targets, Tensor< TLogits, MR > &output) |
Performs the forward pass of the cross entropy operation. | |
std::shared_ptr< Tensor< TLogits, MR > > | getClassWeights () const |
Gets the class weights tensor. | |
float | getLabelSmoothing () const |
Gets the label smoothing factor. | |
int64_t | getPaddingIndex () const |
Gets the padding index. | |
int64_t | getVocabSize () const |
Gets the vocabulary size. | |
bool | ignorePadding () const |
Checks if padding is ignored. | |
bool | isReduced () const |
Checks if loss is reduced. | |
void | load (ModelArchive &archive) override |
Deserializes the module state from a ZIP archive. | |
size_t | parameterCount () const override |
Gets the number of trainable parameters in this module. | |
void | save (ModelArchive &archive) const override |
Serializes the module state to a ZIP archive. | |
std::string | toString () const override |
Generates a string representation of this module's configuration. | |
Public Member Functions inherited from Mila::Dnn::Module< TDeviceType, TInput, TOutput > | |
Module (const std::string &device_name, const ComponentConfig &config) | |
Constructor with device name. | |
Module (std::shared_ptr< DeviceContext > context, const ComponentConfig &config) | |
Constructor with a specific device context. | |
virtual | ~Module ()=default |
Virtual destructor for proper cleanup in derived classes. | |
std::shared_ptr< Compute::DeviceContext > | getDeviceContext () const |
Get the device context for this module. | |
Compute::DeviceType | getDeviceType () const |
Get the device type of the current device context. | |
std::string | getName () const |
Get the name of the module. | |
const auto & | getParameterTensors () const |
Get the parameter tensors of this module. | |
const ComputePrecision::Policy & | getPrecision () const |
const auto & | getStateTensors () const |
Get the state tensors of this module. | |
bool | isTraining () const |
Check if the module is in training mode. | |
virtual void | setTraining (bool is_training) |
Set the training mode of this module. | |
Private Member Functions | |
void | createOperation () |
Creates the appropriate cross entropy operation for the current device. | |
void | initializeClassWeights (const std::vector< float > &weights) |
Initializes the class weights tensor from a vector of weights. | |
Private Attributes | |
OperationAttributes | attributes_ |
Operation attributes and configuration. | |
std::shared_ptr< Tensor< TLogits, MR > > | class_weights_ { nullptr } |
Optional tensor containing weights for each class. | |
CrossEntropyConfig | config_ |
Configuration for the CrossEntropy module. | |
std::shared_ptr< UnaryOperation< TDeviceType, TLogits, TTargets > > | operation_ { nullptr } |
The operation that implements the cross entropy calculation. | |
std::vector< std::shared_ptr< Tensor< TLogits, MR > > > | output_state_ |
Collection of output state tensors for caching. | |
std::vector< std::shared_ptr< Tensor< TLogits, MR > > > | parameters_ |
Collection of parameters for this module. | |
Additional Inherited Members | |
const std::string | parametersToString () const |
Helper method to convert parameters to string representation. | |
const std::string | stateToString () const |
Helper method to convert state tensors to string representation. | |
std::unordered_map< std::string, std::shared_ptr< Tensor< TOutput, MR > > > | parameter_map_ = {} |
Map of parameter names to parameter tensors. | |
std::unordered_map< std::string, std::shared_ptr< Tensor< TOutput, MR > > > | state_map_ = {} |
Map of state names to state tensors. | |
CrossEntropy loss module for neural networks.
This class implements the cross entropy loss function, which is commonly used in classification tasks. It computes the negative log likelihood of the correct class given the predicted probabilities.
The cross entropy loss for a single example is $-\log(p_i)$, where $p_i$ is the predicted probability for the correct class $i$.
For multi-class problems with $K$ classes, the formula is $L(y, \hat{y}) = -\sum_{k=1}^{K} y_k \log(\hat{y}_k)$, where $y$ is the ground truth (one-hot) and $\hat{y}$ is the predicted probability distribution.
Features supported by this implementation include optional class weights, ignoring a padding index, label smoothing, and optional loss reduction.
Template Parameters
TDeviceType | The device type (CPU or CUDA) on which to perform computations. |
TLogits | The data type of the input logits (typically float). |
TTargets | The data type of the target class indices (typically int). |
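The single-example formula above can be illustrated with a short standalone sketch. It does not use the Mila API: `logSoftmax` and `crossEntropy` are hypothetical helper names, and the sketch assumes the probabilities are obtained by applying a softmax to the logits, which this page does not state explicitly.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Numerically stable log-softmax:
// log p_k = z_k - max(z) - log(sum_j exp(z_j - max(z))).
std::vector<double> logSoftmax(const std::vector<double>& logits) {
    assert(!logits.empty());
    const double max_logit = *std::max_element(logits.begin(), logits.end());
    double sum_exp = 0.0;
    for (double z : logits) sum_exp += std::exp(z - max_logit);
    const double log_sum = std::log(sum_exp);

    std::vector<double> out;
    out.reserve(logits.size());
    for (double z : logits) out.push_back(z - max_logit - log_sum);
    return out;
}

// Loss for one example with a class-index target: -log(p_target).
double crossEntropy(const std::vector<double>& logits, std::size_t target) {
    assert(target < logits.size());
    return -logSoftmax(logits)[target];
}
```

With four equal logits every class has probability 1/4, so the loss for any target is log(4). Subtracting the maximum logit before exponentiating keeps the computation finite even for very large logits.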
using ModuleBase = Module< TDeviceType, TLogits, TTargets > | export |
Alias for base module type.
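using MR = std::conditional_t< TDeviceType==DeviceType::Cuda, CudaMemoryResource, CpuMemoryResource > | export |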
Memory resource type used for tensors, selected based on device type.
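The compile-time selection performed by the MR alias can be sketched in isolation. The types below are hypothetical stand-ins that only mirror the names used on this page; their bodies, and the `MemoryResourceFor` and `memoryResourceName` helpers, are invented for illustration.

```cpp
#include <string>
#include <type_traits>

// Hypothetical stand-ins for the library's types (illustration only).
enum class DeviceType { Cpu, Cuda };

struct CpuMemoryResource  { static constexpr const char* name = "cpu";  };
struct CudaMemoryResource { static constexpr const char* name = "cuda"; };

// Same pattern as the MR alias: the memory resource type is chosen at
// compile time from the device type template argument.
template <DeviceType TDeviceType>
using MemoryResourceFor = std::conditional_t<
    TDeviceType == DeviceType::Cuda, CudaMemoryResource, CpuMemoryResource>;

// The selection costs nothing at run time and can be checked statically.
static_assert(std::is_same_v<MemoryResourceFor<DeviceType::Cpu>, CpuMemoryResource>);
static_assert(std::is_same_v<MemoryResourceFor<DeviceType::Cuda>, CudaMemoryResource>);

// Returns the name of the memory resource selected for a device type.
template <DeviceType TDeviceType>
std::string memoryResourceName() {
    return MemoryResourceFor<TDeviceType>::name;
}
```

Because the choice is made by the type system, a tensor parameterized on the CPU resource cannot be passed where a CUDA tensor is expected without a compile error.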
CrossEntropy (const std::string &device_name, const CrossEntropyConfig &config) | inline explicit export |
Constructs a new CrossEntropy module with a device name.
Creates a new DeviceContext internally using the provided device name. This constructor is useful for creating standalone modules without pre-existing device contexts.
Parameters
device_name | The name of the device to use (e.g., "CPU", "CUDA:0"). |
config | Configuration parameters for the CrossEntropy module. |
Exceptions
std::invalid_argument | If the device name is invalid or the configuration is invalid. |
std::runtime_error | If the device type doesn't match the template parameter TDeviceType. |
CrossEntropy (std::shared_ptr< DeviceContext > device_context, const CrossEntropyConfig &config) | inline explicit export |
Constructs a new CrossEntropy module with a provided device context.
Uses a pre-existing DeviceContext instance. This constructor is useful when integrating the module into a larger network that shares device contexts across modules.
Parameters
device_context | The device context to use for this module. |
config | Configuration parameters for the CrossEntropy module. |
Exceptions
std::invalid_argument | If device_context is null or the configuration is invalid. |
std::runtime_error | If the device context type doesn't match the template parameter TDeviceType. |
void backward (const Tensor< TLogits, MR > &input, const Tensor< TTargets, MR > &targets, const Tensor< TLogits, MR > &output_grad, Tensor< TLogits, MR > &input_grad) | inline export |
Calculates gradients for the backward pass.
Computes the gradient of the cross entropy loss with respect to the input logits. This gradient is used during backpropagation to update network weights.
Parameters
input | The input tensor from the forward pass (logits). |
targets | The target tensor containing class indices. |
output_grad | The gradient of the loss with respect to the output. |
input_grad | The tensor to store gradients with respect to the input. |
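For reference, the gradient of the unweighted, unsmoothed loss with respect to the logits has a simple closed form: softmax(logits) minus the one-hot target, scaled by the upstream gradient. The standalone sketch below shows that formula for a single example. It does not use the Mila API, `softmax` and `crossEntropyGrad` are hypothetical names, and the module's actual kernels may differ once class weights, padding, or label smoothing are enabled.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Numerically stable softmax over one row of logits.
std::vector<double> softmax(const std::vector<double>& logits) {
    assert(!logits.empty());
    const double max_logit = *std::max_element(logits.begin(), logits.end());
    std::vector<double> probs;
    probs.reserve(logits.size());
    double sum_exp = 0.0;
    for (double z : logits) {
        probs.push_back(std::exp(z - max_logit));
        sum_exp += probs.back();
    }
    for (double& p : probs) p /= sum_exp;
    return probs;
}

// d(loss)/d(logit_k) = output_grad * (p_k - [k == target]).
std::vector<double> crossEntropyGrad(const std::vector<double>& logits,
                                     std::size_t target,
                                     double output_grad) {
    assert(target < logits.size());
    std::vector<double> grad = softmax(logits);
    grad[target] -= 1.0;                      // subtract the one-hot target
    for (double& g : grad) g *= output_grad;  // chain rule with upstream gradient
    return grad;
}
```

The components of this gradient always sum to zero, which is a convenient sanity check when validating a custom kernel.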
void createOperation () | inline export private |
Creates the appropriate cross entropy operation for the current device.
Instantiates either a CPU or CUDA cross entropy operation based on the device type. Sets operation attributes from the configuration object.
void forward (const Tensor< TLogits, MR > &input, const Tensor< TTargets, MR > &targets, Tensor< TLogits, MR > &output) | inline export |
Performs the forward pass of the cross entropy operation.
Computes the cross entropy loss between the predicted logits and target indices. The operation applies any configured options such as class weighting, padding index ignoring, and label smoothing.
Parameters
input | The input tensor containing predicted logits. |
targets | The target tensor containing class indices. |
output | The output tensor that will contain the loss value(s). |
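How these options might combine can be sketched on plain vectors. Everything below is an assumption made for illustration: `LossOptions` and `crossEntropyForward` are hypothetical names rather than part of CrossEntropyConfig, label smoothing is modeled as mixing the one-hot target with a uniform distribution, and the mean reduction divides by the number of non-ignored tokens. The module's actual conventions may differ.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical option bundle (illustrative field names).
struct LossOptions {
    std::vector<double> class_weights;  // empty means every class weighs 1
    std::int64_t padding_index = -1;
    bool ignore_padding = false;
    double label_smoothing = 0.0;       // epsilon in [0, 1)
    bool reduce_mean = true;
};

// Numerically stable log-softmax over one row of logits.
std::vector<double> logSoftmax(const std::vector<double>& logits) {
    assert(!logits.empty());
    const double max_logit = *std::max_element(logits.begin(), logits.end());
    double sum_exp = 0.0;
    for (double z : logits) sum_exp += std::exp(z - max_logit);
    std::vector<double> out;
    out.reserve(logits.size());
    for (double z : logits) out.push_back(z - max_logit - std::log(sum_exp));
    return out;
}

// logits holds one row per token. Returns one loss per token, or a single
// mean over non-ignored tokens when reduce_mean is set.
std::vector<double> crossEntropyForward(
    const std::vector<std::vector<double>>& logits,
    const std::vector<std::int64_t>& targets,
    const LossOptions& opt) {
    assert(logits.size() == targets.size());
    std::vector<double> losses(targets.size(), 0.0);
    std::size_t counted = 0;
    for (std::size_t i = 0; i < targets.size(); ++i) {
        const std::int64_t t = targets[i];
        if (opt.ignore_padding && t == opt.padding_index) continue;  // loss stays 0
        assert(t >= 0 && static_cast<std::size_t>(t) < logits[i].size());
        const std::size_t k = static_cast<std::size_t>(t);
        const std::vector<double> logp = logSoftmax(logits[i]);

        // Loss against the uniform distribution, used by label smoothing.
        double uniform_term = 0.0;
        for (double lp : logp) uniform_term -= lp;
        uniform_term /= static_cast<double>(logp.size());

        const double eps = opt.label_smoothing;
        const double w = opt.class_weights.empty() ? 1.0 : opt.class_weights[k];
        losses[i] = w * ((1.0 - eps) * (-logp[k]) + eps * uniform_term);
        ++counted;
    }
    if (!opt.reduce_mean) return losses;
    double total = 0.0;
    for (double l : losses) total += l;
    return {counted > 0 ? total / static_cast<double>(counted) : 0.0};
}
```

Under these assumptions a padded token contributes neither to the sum nor to the divisor, so padding a batch does not change the reduced loss.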
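std::shared_ptr< Tensor< TLogits, MR > > getClassWeights () const | inline export |
Gets the class weights tensor.
float getLabelSmoothing () const | inline export |
Gets the label smoothing factor.
int64_t getPaddingIndex () const | inline export |
Gets the padding index.
int64_t getVocabSize () const | inline export |
Gets the vocabulary size.
bool ignorePadding () const | inline export |
Checks if padding is ignored.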
void initializeClassWeights (const std::vector< float > &weights) | inline export private |
Initializes the class weights tensor from a vector of weights.
Creates and populates a tensor with class weights for weighted cross entropy loss. This is useful for handling imbalanced datasets where some classes are underrepresented.
Parameters
weights | Vector of weight values, one per class. |
bool isReduced () const | inline export |
Checks if loss is reduced.
void load (ModelArchive &archive) override | inline override export virtual |
Deserializes the module state from a ZIP archive.
Loads the class weights tensor (if expected) from the provided ZIP archive.
Parameters
archive | The ZIP archive to load the module state from. |
Implements Mila::Dnn::Module< TDeviceType, TInput, TOutput >.
size_t parameterCount () const override | inline override export virtual |
Gets the number of trainable parameters in this module.
Returns the number of elements in the class weights tensor if present; otherwise returns 0, because CrossEntropy has no other parameters.
Implements Mila::Dnn::Module< TDeviceType, TInput, TOutput >.
void save (ModelArchive &archive) const override | inline override export virtual |
Serializes the module state to a ZIP archive.
Saves the class weights tensor (if present) to the provided ZIP archive.
Parameters
archive | The ZIP archive to save the module state to. |
Implements Mila::Dnn::Module< TDeviceType, TInput, TOutput >.
std::string toString () const override | inline override export virtual |
Generates a string representation of this module's configuration.
Implements Mila::Dnn::Module< TDeviceType, TInput, TOutput >.
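OperationAttributes attributes_ | export private |
Operation attributes and configuration.
Contains settings for the CrossEntropy operation, such as the padding index and reduction mode.
std::shared_ptr< Tensor< TLogits, MR > > class_weights_ { nullptr } | export private |
Optional tensor containing weights for each class.
Used to address class imbalance by giving different importance to different classes.
CrossEntropyConfig config_ | export private |
Configuration for the CrossEntropy module.
std::shared_ptr< UnaryOperation< TDeviceType, TLogits, TTargets > > operation_ { nullptr } | export private |
The operation that implements the cross entropy calculation.
std::vector< std::shared_ptr< Tensor< TLogits, MR > > > output_state_ | export private |
Collection of output state tensors for caching.
Stores intermediate results from the forward pass that are needed for the backward pass.
std::vector< std::shared_ptr< Tensor< TLogits, MR > > > parameters_ | export private |
Collection of parameters for this module.
Contains only class_weights_ if present; otherwise empty.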