Mila
Deep Neural Network Library
Mila::Dnn::CrossEntropyConfig Class Reference (exported)

Configuration class for the CrossEntropy module.


Public Member Functions

 CrossEntropyConfig (int64_t vocab_size)
 Constructor with required vocabulary size parameter.
 
const std::vector< float > & getClassWeights () const
 Get the class weights.
 
float getLabelSmoothing () const
 Get the label smoothing factor.
 
int64_t getPaddingIndex () const
 Get the padding index.
 
bool getReduction () const
 Check if loss should be reduced.
 
int64_t getVocabSize () const
 Get the vocabulary size.
 
bool ignorePadding () const
 Check if padding should be ignored.
 
void validate () const
 Validate configuration parameters.
 
CrossEntropyConfig & withClassWeights (const std::vector< float > &weights)
 Set class weights for weighted cross entropy.
 
CrossEntropyConfig & withIgnorePadding (bool ignore_pad)
 Configure whether to ignore targets equal to the padding index.
 
CrossEntropyConfig & withLabelSmoothing (float smoothing)
 Set the label smoothing factor.
 
CrossEntropyConfig & withPaddingIndex (int64_t pad_idx)
 Set the padding index to ignore.
 
CrossEntropyConfig & withReduction (bool reduce)
 Configure whether to reduce the loss.
 
- Public Member Functions inherited from Mila::Dnn::ComponentConfig
virtual ~ComponentConfig ()=default
 Virtual destructor to support proper polymorphic destruction.
 
const std::string & getName () const
 Gets the configured component name.
 
ComputePrecision::Policy getPrecision () const
 Gets the configured precision policy.
 
bool isTraining () const
 Gets the configured training mode.
 
template<typename Self >
auto & withName (this Self &&self, std::string name)
 Sets the name of the component with fluent interface.
 
template<typename Self >
auto & withPrecision (this Self &&self, ComputePrecision::Policy policy)
 Sets the compute precision policy with fluent interface.
 
template<typename Self >
auto & withTraining (this Self &&self, bool is_training)
 Sets the training mode with fluent interface.
 

Private Attributes

std::vector< float > class_weights_
 
bool ignore_padding_ = false
 
float label_smoothing_ = 0.0f
 
int64_t padding_idx_ = -1
 
bool reduce_ = true
 
int64_t vocab_size_
 

Additional Inherited Members

- Protected Attributes inherited from Mila::Dnn::ComponentConfig
bool is_training_ = false
 Training mode flag; defaults to false (inference mode).
 
std::string name_ = "unnamed"
 Component name, defaults to "unnamed" if not explicitly set.
 
ComputePrecision::Policy precision_ = ComputePrecision::Policy::Auto
 Precision policy for computation, defaults to Auto.
 

Detailed Description

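Configuration class for the CrossEntropy module. It holds the required vocabulary size together with optional class weights, a label smoothing factor, padding handling, and the loss reduction mode.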

Constructor & Destructor Documentation

◆ CrossEntropyConfig()

Mila::Dnn::CrossEntropyConfig::CrossEntropyConfig ( int64_t  vocab_size)
inline, explicit

Constructor with required vocabulary size parameter.

Parameters
vocab_size: The size of the vocabulary (number of possible classes)
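To illustrate the fluent interface, here is a minimal, self-contained stand-in that mirrors the documented constructor, setters, and getters, using the default values listed under Private Attributes. The class name `CrossEntropyConfigSketch` is hypothetical and this is not Mila's implementation; with the real library you would construct `Mila::Dnn::CrossEntropyConfig` and chain the same `with…` calls.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical stand-in mirroring the documented CrossEntropyConfig interface.
// Defaults match the private attributes listed in this class reference.
class CrossEntropyConfigSketch {
public:
    explicit CrossEntropyConfigSketch(int64_t vocab_size) : vocab_size_(vocab_size) {}

    // Each setter stores its value and returns *this so calls can be chained.
    CrossEntropyConfigSketch& withClassWeights(const std::vector<float>& weights) {
        class_weights_ = weights;
        return *this;
    }
    CrossEntropyConfigSketch& withIgnorePadding(bool ignore_pad) {
        ignore_padding_ = ignore_pad;
        return *this;
    }
    CrossEntropyConfigSketch& withLabelSmoothing(float smoothing) {
        label_smoothing_ = smoothing;
        return *this;
    }
    CrossEntropyConfigSketch& withPaddingIndex(int64_t pad_idx) {
        padding_idx_ = pad_idx;
        return *this;
    }
    CrossEntropyConfigSketch& withReduction(bool reduce) {
        reduce_ = reduce;
        return *this;
    }

    // Read-only accessors, named after the documented getters.
    const std::vector<float>& getClassWeights() const { return class_weights_; }
    float getLabelSmoothing() const { return label_smoothing_; }
    int64_t getPaddingIndex() const { return padding_idx_; }
    bool getReduction() const { return reduce_; }
    int64_t getVocabSize() const { return vocab_size_; }
    bool ignorePadding() const { return ignore_padding_; }

private:
    std::vector<float> class_weights_;
    bool ignore_padding_ = false;
    float label_smoothing_ = 0.0f;
    int64_t padding_idx_ = -1;
    bool reduce_ = true;
    int64_t vocab_size_;
};
```

For example, `CrossEntropyConfigSketch cfg(50257); cfg.withPaddingIndex(0).withIgnorePadding(true).withLabelSmoothing(0.1f);` sets up a padding-aware, smoothed loss in one statement. Returning a reference to `*this` from every setter is what makes this chaining possible.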

Member Function Documentation

◆ getClassWeights()

const std::vector< float > & Mila::Dnn::CrossEntropyConfig::getClassWeights ( ) const
inline

Get the class weights.


◆ getLabelSmoothing()

float Mila::Dnn::CrossEntropyConfig::getLabelSmoothing ( ) const
inline

Get the label smoothing factor.


◆ getPaddingIndex()

int64_t Mila::Dnn::CrossEntropyConfig::getPaddingIndex ( ) const
inline

Get the padding index.


◆ getReduction()

bool Mila::Dnn::CrossEntropyConfig::getReduction ( ) const
inline

Check if loss should be reduced.


◆ getVocabSize()

int64_t Mila::Dnn::CrossEntropyConfig::getVocabSize ( ) const
inline

Get the vocabulary size.


◆ ignorePadding()

bool Mila::Dnn::CrossEntropyConfig::ignorePadding ( ) const
inline

Check if padding should be ignored.


◆ validate()

void Mila::Dnn::CrossEntropyConfig::validate ( ) const
inline, virtual

Validate configuration parameters.

Exceptions
std::invalid_argument: Thrown if validation fails

Reimplemented from Mila::Dnn::ComponentConfig.
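The reference only states that `validate()` throws `std::invalid_argument` on failure; it does not list the rules. The sketch below shows plausible checks under stated assumptions: a positive vocabulary size, a smoothing factor in [0.0, 1.0] (the range given for `withLabelSmoothing`), one non-negative weight per class when weights are supplied, and an in-range padding index when padding is ignored. The free function `validateCrossEntropySettings` is hypothetical.

```cpp
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical validation rules; only the exception type is documented.
void validateCrossEntropySettings(int64_t vocab_size,
                                  const std::vector<float>& class_weights,
                                  float label_smoothing,
                                  bool ignore_padding,
                                  int64_t padding_idx) {
    if (vocab_size <= 0) {
        throw std::invalid_argument("vocab_size must be positive");
    }
    // Documented range for the smoothing factor.
    if (label_smoothing < 0.0f || label_smoothing > 1.0f) {
        throw std::invalid_argument("label_smoothing must lie in [0.0, 1.0]");
    }
    // Assumption: weights, if given, cover every class exactly once.
    if (!class_weights.empty() &&
        static_cast<int64_t>(class_weights.size()) != vocab_size) {
        throw std::invalid_argument("class_weights must have one entry per class");
    }
    // Assumption: negative weights are rejected.
    for (float w : class_weights) {
        if (w < 0.0f) {
            throw std::invalid_argument("class weights must be non-negative");
        }
    }
    // Assumption: the padding index must be a valid class id when it is used.
    if (ignore_padding && (padding_idx < 0 || padding_idx >= vocab_size)) {
        throw std::invalid_argument("padding_idx must be a valid class index when ignoring padding");
    }
}
```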


◆ withClassWeights()

CrossEntropyConfig & Mila::Dnn::CrossEntropyConfig::withClassWeights ( const std::vector< float > &  weights)
inline

Set class weights for weighted cross entropy.

Parameters
weights: Vector of weights, one per class
Returns
CrossEntropyConfig& Reference to this for method chaining
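Weighted cross entropy is commonly defined per sample as `-w[y] * log softmax(z)[y]`, where `z` are the logits, `y` is the target class, and `w` holds the class weights. A sketch of that formulation follows, assuming an empty weight vector means a uniform weight of 1; the helper names are hypothetical, and this reference does not state that Mila computes the loss exactly this way.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Numerically stable log-softmax probability of class `target`:
// log p_t = z_t - (max + log(sum_k exp(z_k - max))).
float logSoftmaxAt(const std::vector<float>& logits, std::size_t target) {
    assert(!logits.empty() && target < logits.size());
    const float max_logit = *std::max_element(logits.begin(), logits.end());
    float sum = 0.0f;
    for (float z : logits) {
        sum += std::exp(z - max_logit);
    }
    return logits[target] - (max_logit + std::log(sum));
}

// Weighted cross entropy for one sample: -w[target] * log p_target.
// Assumption: an empty weight vector means every class has weight 1.
float weightedCrossEntropy(const std::vector<float>& logits,
                           std::size_t target,
                           const std::vector<float>& class_weights) {
    const float weight = class_weights.empty() ? 1.0f : class_weights.at(target);
    return -weight * logSoftmaxAt(logits, target);
}
```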

◆ withIgnorePadding()

CrossEntropyConfig & Mila::Dnn::CrossEntropyConfig::withIgnorePadding ( bool  ignore_pad)
inline

Configure whether to ignore padding index.

When true, targets with the specified padding index will not contribute to the loss.

Parameters
ignore_pad: Whether to ignore targets equal to the padding index
Returns
CrossEntropyConfig& Reference to this for method chaining
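The documented behaviour is that, with padding ignoring enabled, targets equal to the padding index do not contribute to the loss. A minimal sketch of the resulting per-target mask, using a hypothetical `contributionMask` helper, is shown below. Because the default padding index is -1, enabling `withIgnorePadding(true)` without also calling `withPaddingIndex` would, under this sketch, mask nothing for non-negative class ids.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Returns true for each target that should contribute to the loss.
// Targets equal to padding_idx are excluded only when ignore_padding is set.
std::vector<bool> contributionMask(const std::vector<int64_t>& targets,
                                   bool ignore_padding,
                                   int64_t padding_idx) {
    std::vector<bool> mask(targets.size(), true);
    if (!ignore_padding) {
        return mask;  // every target counts
    }
    for (std::size_t i = 0; i < targets.size(); ++i) {
        mask[i] = targets[i] != padding_idx;
    }
    return mask;
}
```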

◆ withLabelSmoothing()

CrossEntropyConfig & Mila::Dnn::CrossEntropyConfig::withLabelSmoothing ( float  smoothing)
inline

Set the label smoothing factor. A factor of 0.0 (the default) disables smoothing.

Parameters
smoothing: Label smoothing factor (0.0 to 1.0)
Returns
CrossEntropyConfig& Reference to this for method chaining
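Label smoothing replaces the one-hot target with a mix of the one-hot vector and a uniform distribution over the K classes: q_k = (1 - eps) * [k == y] + eps / K, and the loss becomes -sum_k q_k * log p_k. The sketch below implements that common convention; some implementations instead spread eps over the K - 1 non-target classes, and which variant Mila uses is an assumption here. The function name is hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Cross entropy against a smoothed target distribution:
// q_k = (1 - eps) * [k == target] + eps / K, loss = -sum_k q_k * log p_k.
float labelSmoothedCrossEntropy(const std::vector<float>& logits,
                                std::size_t target,
                                float smoothing) {
    assert(!logits.empty() && target < logits.size());
    assert(smoothing >= 0.0f && smoothing <= 1.0f);

    // log-sum-exp, shifted by the max logit for numerical stability.
    const float max_logit = *std::max_element(logits.begin(), logits.end());
    float sum = 0.0f;
    for (float z : logits) {
        sum += std::exp(z - max_logit);
    }
    const float log_norm = max_logit + std::log(sum);

    const float num_classes = static_cast<float>(logits.size());
    float loss = 0.0f;
    for (std::size_t k = 0; k < logits.size(); ++k) {
        const float log_p = logits[k] - log_norm;
        const float q = (k == target ? 1.0f - smoothing : 0.0f) + smoothing / num_classes;
        loss -= q * log_p;
    }
    return loss;
}
```

With a smoothing factor of 0.0 this reduces to ordinary cross entropy; a positive factor penalizes over-confident predictions, so the loss on a correct but confident prediction rises slightly.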

◆ withPaddingIndex()

CrossEntropyConfig & Mila::Dnn::CrossEntropyConfig::withPaddingIndex ( int64_t  pad_idx)
inline

Set the padding index to ignore.

Parameters
pad_idx: The padding index value to ignore in the loss calculation
Returns
CrossEntropyConfig& Reference to this for method chaining

◆ withReduction()

CrossEntropyConfig & Mila::Dnn::CrossEntropyConfig::withReduction ( bool  reduce)
inline

Configure whether to reduce the loss.

When true, the per-sample losses are averaged into a single scalar (mean). When false, the per-sample losses are returned unreduced.

Parameters
reduce: Whether to average the loss
Returns
CrossEntropyConfig& Reference to this for method chaining
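The difference between the two reduction modes can be sketched as follows, using a hypothetical `applyReduction` helper. It assumes ignored (padding) samples are excluded from the mean's denominator, which is PyTorch's convention for `ignore_index` with mean reduction, and it does not model the weight-normalized mean PyTorch uses when class weights are set; neither detail is confirmed for Mila.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// reduce == true:  returns one value, the mean over contributing samples.
// reduce == false: returns per-sample losses, with ignored samples set to 0.
std::vector<float> applyReduction(const std::vector<float>& per_sample,
                                  const std::vector<bool>& contributes,
                                  bool reduce) {
    assert(per_sample.size() == contributes.size());
    std::vector<float> out(per_sample.size(), 0.0f);
    float total = 0.0f;
    std::size_t count = 0;
    for (std::size_t i = 0; i < per_sample.size(); ++i) {
        if (contributes[i]) {
            out[i] = per_sample[i];
            total += per_sample[i];
            ++count;
        }
    }
    if (!reduce) {
        return out;
    }
    // Guard against a batch in which every target is ignored.
    return {count > 0 ? total / static_cast<float>(count) : 0.0f};
}
```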

Member Data Documentation

◆ class_weights_

std::vector<float> Mila::Dnn::CrossEntropyConfig::class_weights_
private

◆ ignore_padding_

bool Mila::Dnn::CrossEntropyConfig::ignore_padding_ = false
private

◆ label_smoothing_

float Mila::Dnn::CrossEntropyConfig::label_smoothing_ = 0.0f
private

◆ padding_idx_

int64_t Mila::Dnn::CrossEntropyConfig::padding_idx_ = -1
private

◆ reduce_

bool Mila::Dnn::CrossEntropyConfig::reduce_ = true
private

◆ vocab_size_

int64_t Mila::Dnn::CrossEntropyConfig::vocab_size_
private

The documentation for this class was generated from the following file: