Mila 0.13.48
Deep Neural Network Library
Mila::Dnn::SoftmaxCrossEntropy< TDeviceType, TLogits, TTargets, TPrecision > Class Template Reference
export

Fused SoftmaxCrossEntropy loss module (device-templated).


Public Types

using ExecutionContextType = ExecutionContext<TDeviceType>
using MR = typename DeviceTypeTraits<TDeviceType>::memory_resource
using TargetTensorType = Tensor<TTargets, MR>
using TensorType = Tensor<TPrecision, MR>

Public Member Functions

 SoftmaxCrossEntropy (IExecutionContext *exec_context, const CrossEntropyConfig &config)
 Construct with an existing execution context.
 ~SoftmaxCrossEntropy () override=default
void backward (const ITensor &logits, const ITensor &targets, const ITensor &output_grad, ITensor &logits_grad)
 Backward pass - delegates to backend operation.
void forward (const ITensor &logits, const ITensor &targets, ITensor &output)
 Forward pass - delegates to backend operation.
const CrossEntropyConfig & getConfig () const noexcept
DeviceId getDeviceId () const override
 Get the compute device id associated with this component.
std::vector< ITensor * > getGradients () const override
 Return non-owning pointers to parameter gradient tensors.
std::vector< ITensor * > getParameters () const override
 Return non-owning pointers to parameter tensors.
int64_t getVocabSize () const
void onBuilding (const shape_t &input_shape) override
 Build the module using an input shape.
size_t parameterCount () const override
 Return number of trainable parameters.
void save_ (ModelArchive &archive, SerializationMode mode) const override
void synchronize () override
 Wait for outstanding device work submitted by this component.
std::string toString () const override
 Produce a short, human-readable description of the component.
Public Member Functions inherited from Mila::Dnn::Component< TDeviceType, TLogits >
 Component (const std::string &name)
 Construct component with required name identifier.
virtual ~Component ()=default
virtual void build (const BuildContext &context) final
 Build the component with the provided BuildContext (canonical overload).
virtual MemoryStats getMemoryStats () const=0
 Return the current memory allocation breakdown for this component.
const std::string getName () const
 Get the component's name identifier.
virtual std::vector< std::string > getParameterNames () const
 List all available parameter names for this component.
RuntimeMode getRuntimeMode () const noexcept
 Convenience accessor — true if currently in Eval mode.
TrainingMode getTrainingMode () const noexcept
 The current runtime behavioral mode of this Component.
virtual const ComponentType getType () const=0
 Get the component type identifier.
virtual bool isBuilt () const final
 Returns true if build() has completed successfully.
bool isInferenceMode () const noexcept
bool isTrainingMode () const noexcept
virtual void loadParameter (const std::string &name, const Serialization::ITensorBlob &blob)
 Load a parameter from serialized tensor data.
void setTrainingMode (TrainingMode mode)
 Set the runtime behavioral mode for this Component.
virtual void zeroGradients ()
 Clear all model-owned gradients for this component.

Protected Member Functions

void onTrainingChanging (bool newMode) override
 Hook invoked when training mode is about to change.
Protected Member Functions inherited from Mila::Dnn::Component< TDeviceType, TLogits >
IExecutionContext * getExecutionContext () const
 Get the shared execution context.
bool hasExecutionContext () const noexcept
 Check if execution context has been set.
void loadParameterFromBlob (const std::string &param_name, const Serialization::ITensorBlob &blob, Tensor< TParameterPrecision, TMemoryResource > &target, const shape_t &expected_shape)
 Load a tensor blob into a parameter tensor with validation.
virtual void onBuilding (const BuildContext &config)
 Hook invoked by build() to allocate component buffers.
virtual void onExecutionContextSet ()
 Lifecycle hook: Called immediately after ExecutionContext is set.
virtual void onTrainingModeChanging (TrainingMode mode)
 Hook called before TrainingMode transitions.
void setExecutionContext (IExecutionContext *context)
 Set the execution context for this component.

Private Member Functions

void createOperation ()
 Create the backend compute operation.
void validateInputShape (const ITensor &input) const
 Validate input shape for fused softmax+cross-entropy operation.
void validateInputShape (const shape_t &input_shape) const
 Validate input shape for fused softmax+cross-entropy operation.

Private Attributes

CrossEntropyConfig config_
std::shared_ptr< TargetTensorType > dummy_target_grad_ { nullptr }
IExecutionContext * exec_context_ { nullptr }
std::unique_ptr< BinaryOperation< TDeviceType, TLogits, TTargets, TPrecision > > operation_ { nullptr }

Additional Inherited Members

Static Public Member Functions inherited from Mila::Dnn::Component< TDeviceType, TLogits >
static constexpr DeviceType getDeviceType ()
 Compile-time device type for this component instance.
static constexpr TensorDataType getPrecision () noexcept
 Compile-time tensor precision for this component instance.
Protected Attributes inherited from Mila::Dnn::Component< TDeviceType, TLogits >
BuildContext build_context_
 The BuildContext stored at build time.

Detailed Description

template<DeviceType TDeviceType, TensorDataType TLogits, TensorDataType TTargets = dtype_t::INT32, TensorDataType TPrecision = TLogits>
requires PrecisionSupportedOnDevice<TPrecision, TDeviceType>
class Mila::Dnn::SoftmaxCrossEntropy< TDeviceType, TLogits, TTargets, TPrecision >

Fused SoftmaxCrossEntropy loss module (device-templated).

Delegates computation to a device-specific BinaryOperation implementation (logits and targets as inputs) registered in the OperationRegistry.

Constructor & Destructor Documentation

◆ SoftmaxCrossEntropy()

template<DeviceType TDeviceType, TensorDataType TLogits, TensorDataType TTargets = dtype_t::INT32, TensorDataType TPrecision = TLogits>
Mila::Dnn::SoftmaxCrossEntropy< TDeviceType, TLogits, TTargets, TPrecision >::SoftmaxCrossEntropy ( IExecutionContext * exec_context,
const CrossEntropyConfig & config )
inline explicit export

Construct with an existing execution context.

Parameters
exec_context    Shared execution context for device resources.
config          CrossEntropy configuration (vocab_size required).

◆ ~SoftmaxCrossEntropy()

template<DeviceType TDeviceType, TensorDataType TLogits, TensorDataType TTargets = dtype_t::INT32, TensorDataType TPrecision = TLogits>
Mila::Dnn::SoftmaxCrossEntropy< TDeviceType, TLogits, TTargets, TPrecision >::~SoftmaxCrossEntropy ( )
override export default

Member Function Documentation

◆ backward()

template<DeviceType TDeviceType, TensorDataType TLogits, TensorDataType TTargets = dtype_t::INT32, TensorDataType TPrecision = TLogits>
void Mila::Dnn::SoftmaxCrossEntropy< TDeviceType, TLogits, TTargets, TPrecision >::backward ( const ITensor & logits,
const ITensor & targets,
const ITensor & output_grad,
ITensor & logits_grad )
inline export

Backward pass - delegates to backend operation.

Computes fused gradient: dL/dlogits = softmax(logits) - one_hot(targets)
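
The gradient formula above can be illustrated with a small host-side reference. This is a minimal sketch, not Mila's backend code: the function name `softmaxCrossEntropyBackward`, the flat row-major `std::vector` layout, and the scaling by a per-row upstream gradient (standing in for `output_grad`) are all assumptions made for illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical reference helper, not part of Mila.
// logits:      row-major [rows, vocab]
// targets:     [rows], class index per row
// output_grad: [rows], upstream gradient of each row's loss (assumed layout)
std::vector<float> softmaxCrossEntropyBackward(const std::vector<float>& logits,
                                               const std::vector<std::int32_t>& targets,
                                               const std::vector<float>& output_grad,
                                               std::size_t vocab)
{
    assert(vocab > 0 && logits.size() == targets.size() * vocab);
    assert(output_grad.size() == targets.size());

    std::vector<float> logits_grad(logits.size());
    for (std::size_t r = 0; r < targets.size(); ++r) {
        const float* row = logits.data() + r * vocab;
        float* grad_row = logits_grad.data() + r * vocab;
        assert(targets[r] >= 0 && static_cast<std::size_t>(targets[r]) < vocab);

        // Numerically stable softmax: shift by the row maximum before exp().
        const float max_logit = *std::max_element(row, row + vocab);
        float sum_exp = 0.0f;
        for (std::size_t v = 0; v < vocab; ++v) {
            grad_row[v] = std::exp(row[v] - max_logit);
            sum_exp += grad_row[v];
        }

        // dL/dlogits = softmax(logits) - one_hot(target), scaled by the upstream gradient.
        for (std::size_t v = 0; v < vocab; ++v) {
            const float prob = grad_row[v] / sum_exp;
            const float one_hot = (v == static_cast<std::size_t>(targets[r])) ? 1.0f : 0.0f;
            grad_row[v] = output_grad[r] * (prob - one_hot);
        }
    }
    return logits_grad;
}
```

For a single row of three equal logits with target index 1 and an upstream gradient of 1, this yields approximately {1/3, -2/3, 1/3}. Each row of the gradient sums to zero, which is a convenient sanity check for any backend.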


◆ createOperation()

template<DeviceType TDeviceType, TensorDataType TLogits, TensorDataType TTargets = dtype_t::INT32, TensorDataType TPrecision = TLogits>
void Mila::Dnn::SoftmaxCrossEntropy< TDeviceType, TLogits, TTargets, TPrecision >::createOperation ( )
inline export private

Create the backend compute operation.


◆ forward()

template<DeviceType TDeviceType, TensorDataType TLogits, TensorDataType TTargets = dtype_t::INT32, TensorDataType TPrecision = TLogits>
void Mila::Dnn::SoftmaxCrossEntropy< TDeviceType, TLogits, TTargets, TPrecision >::forward ( const ITensor & logits,
const ITensor & targets,
ITensor & output )
inline export

Forward pass - delegates to backend operation.

Computes fused softmax + cross-entropy loss.
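
What "fused" means here can be sketched on the host: the loss is computed directly from the logits through a log-sum-exp, without materializing the softmax probabilities first. The helper below is a hypothetical reference, not Mila's backend. Its name, the flat row-major layout, and the choice to return one unreduced loss per row are assumptions; this page does not state how the real operation reduces its output.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical reference helper, not part of Mila.
// logits:  row-major [rows, vocab]
// targets: [rows], class index per row
// returns: [rows], loss per row (no reduction applied; an assumption)
std::vector<float> softmaxCrossEntropyForward(const std::vector<float>& logits,
                                              const std::vector<std::int32_t>& targets,
                                              std::size_t vocab)
{
    assert(vocab > 0 && logits.size() == targets.size() * vocab);

    std::vector<float> losses(targets.size());
    for (std::size_t r = 0; r < targets.size(); ++r) {
        const float* row = logits.data() + r * vocab;
        assert(targets[r] >= 0 && static_cast<std::size_t>(targets[r]) < vocab);

        // Subtract the row maximum so exp() cannot overflow.
        const float max_logit = *std::max_element(row, row + vocab);
        float sum_exp = 0.0f;
        for (std::size_t v = 0; v < vocab; ++v) {
            sum_exp += std::exp(row[v] - max_logit);
        }

        // loss = -log(softmax(row)[target])
        //      = log(sum_exp) + max_logit - row[target]
        losses[r] = std::log(sum_exp) + max_logit - row[targets[r]];
    }
    return losses;
}
```

With uniform logits over a vocabulary of size V the per-row loss is log(V), which is a quick check that a backend is wired up correctly.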


◆ getConfig()

template<DeviceType TDeviceType, TensorDataType TLogits, TensorDataType TTargets = dtype_t::INT32, TensorDataType TPrecision = TLogits>
const CrossEntropyConfig & Mila::Dnn::SoftmaxCrossEntropy< TDeviceType, TLogits, TTargets, TPrecision >::getConfig ( ) const
inline export noexcept

◆ getDeviceId()

template<DeviceType TDeviceType, TensorDataType TLogits, TensorDataType TTargets = dtype_t::INT32, TensorDataType TPrecision = TLogits>
DeviceId Mila::Dnn::SoftmaxCrossEntropy< TDeviceType, TLogits, TTargets, TPrecision >::getDeviceId ( ) const
inline override export virtual

Get the compute device id associated with this component.

Must return the device on which parameters and operations execute.

Implements Mila::Dnn::Component< TDeviceType, TLogits >.

◆ getGradients()

template<DeviceType TDeviceType, TensorDataType TLogits, TensorDataType TTargets = dtype_t::INT32, TensorDataType TPrecision = TLogits>
std::vector< ITensor * > Mila::Dnn::SoftmaxCrossEntropy< TDeviceType, TLogits, TTargets, TPrecision >::getGradients ( ) const
inline override export virtual

Return non-owning pointers to parameter gradient tensors.

Only valid when isTrainingMode() is true.

Exceptions
std::runtime_error    if called when not in training mode or before the component has been built.

Implements Mila::Dnn::Component< TDeviceType, TLogits >.

◆ getParameters()

template<DeviceType TDeviceType, TensorDataType TLogits, TensorDataType TTargets = dtype_t::INT32, TensorDataType TPrecision = TLogits>
std::vector< ITensor * > Mila::Dnn::SoftmaxCrossEntropy< TDeviceType, TLogits, TTargets, TPrecision >::getParameters ( ) const
inline override export virtual

Return non-owning pointers to parameter tensors.

The returned tensor pointers remain valid for the lifetime of the component. Order should be canonical (weights before biases).

Implements Mila::Dnn::Component< TDeviceType, TLogits >.

◆ getVocabSize()

template<DeviceType TDeviceType, TensorDataType TLogits, TensorDataType TTargets = dtype_t::INT32, TensorDataType TPrecision = TLogits>
int64_t Mila::Dnn::SoftmaxCrossEntropy< TDeviceType, TLogits, TTargets, TPrecision >::getVocabSize ( ) const
inline export

◆ onBuilding()

template<DeviceType TDeviceType, TensorDataType TLogits, TensorDataType TTargets = dtype_t::INT32, TensorDataType TPrecision = TLogits>
void Mila::Dnn::SoftmaxCrossEntropy< TDeviceType, TLogits, TTargets, TPrecision >::onBuilding ( const shape_t & input_shape)
inline override export

Build the module using an input shape.

Validates input shape and triggers backend-specific setup. The fused operation has no trainable parameters.


◆ onTrainingChanging()

template<DeviceType TDeviceType, TensorDataType TLogits, TensorDataType TTargets = dtype_t::INT32, TensorDataType TPrecision = TLogits>
void Mila::Dnn::SoftmaxCrossEntropy< TDeviceType, TLogits, TTargets, TPrecision >::onTrainingChanging ( bool newMode)
inline override export protected

Hook invoked when training mode is about to change.

Propagates the new training mode to the backend fused operation. Called with Module's training mutex held; do not call setTrainingMode() here.

◆ parameterCount()

template<DeviceType TDeviceType, TensorDataType TLogits, TensorDataType TTargets = dtype_t::INT32, TensorDataType TPrecision = TLogits>
size_t Mila::Dnn::SoftmaxCrossEntropy< TDeviceType, TLogits, TTargets, TPrecision >::parameterCount ( ) const
inline override export virtual

Return number of trainable parameters.

For leaf components this is the element count of owned parameter tensors. CompositeComponent and Network implementations should return the recursive aggregate across all children.

Implements Mila::Dnn::Component< TDeviceType, TLogits >.

◆ save_()

template<DeviceType TDeviceType, TensorDataType TLogits, TensorDataType TTargets = dtype_t::INT32, TensorDataType TPrecision = TLogits>
void Mila::Dnn::SoftmaxCrossEntropy< TDeviceType, TLogits, TTargets, TPrecision >::save_ ( ModelArchive & archive,
SerializationMode mode ) const
inline override export virtual

◆ synchronize()

template<DeviceType TDeviceType, TensorDataType TLogits, TensorDataType TTargets = dtype_t::INT32, TensorDataType TPrecision = TLogits>
void Mila::Dnn::SoftmaxCrossEntropy< TDeviceType, TLogits, TTargets, TPrecision >::synchronize ( )
inline override export virtual

Wait for outstanding device work submitted by this component.

On CPU this may be a no-op. Use to ensure results are visible to the host or to measure synchronous timings.

Implements Mila::Dnn::Component< TDeviceType, TLogits >.

◆ toString()

template<DeviceType TDeviceType, TensorDataType TLogits, TensorDataType TTargets = dtype_t::INT32, TensorDataType TPrecision = TLogits>
std::string Mila::Dnn::SoftmaxCrossEntropy< TDeviceType, TLogits, TTargets, TPrecision >::toString ( ) const
inline override export virtual

Produce a short, human-readable description of the component.

Implementations should keep output concise and avoid throwing.

Implements Mila::Dnn::Component< TDeviceType, TLogits >.


◆ validateInputShape() [1/2]

template<DeviceType TDeviceType, TensorDataType TLogits, TensorDataType TTargets = dtype_t::INT32, TensorDataType TPrecision = TLogits>
void Mila::Dnn::SoftmaxCrossEntropy< TDeviceType, TLogits, TTargets, TPrecision >::validateInputShape ( const ITensor & input) const
inline export private

Validate input shape for fused softmax+cross-entropy operation.


◆ validateInputShape() [2/2]

template<DeviceType TDeviceType, TensorDataType TLogits, TensorDataType TTargets = dtype_t::INT32, TensorDataType TPrecision = TLogits>
void Mila::Dnn::SoftmaxCrossEntropy< TDeviceType, TLogits, TTargets, TPrecision >::validateInputShape ( const shape_t & input_shape) const
inline export private

Validate input shape for fused softmax+cross-entropy operation.

Expected shapes:

  • Input logits: [B, S, V] or [B, V]
  • Targets: [B, S] or [B]
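
The expected shapes listed above could be checked along the following lines. This is a sketch under stated assumptions rather than the actual validateInputShape() body: `Shape` is a stand-in alias for `shape_t`, the helper name `checkSoftmaxCrossEntropyShapes` is hypothetical, and the choice of `std::invalid_argument` as the exception type is not confirmed by this page. The sketch also checks logits and targets together, whereas the documented overloads take only the input shape.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Stand-in for Mila's shape_t; the real definition is not shown on this page.
using Shape = std::vector<std::int64_t>;

// Hypothetical helper: accepts logits [B, S, V] with targets [B, S],
// or logits [B, V] with targets [B]. Throws on any mismatch.
void checkSoftmaxCrossEntropyShapes(const Shape& logits,
                                    const Shape& targets,
                                    std::int64_t vocab_size)
{
    if (logits.size() != 2 && logits.size() != 3) {
        throw std::invalid_argument("logits must have shape [B, V] or [B, S, V]");
    }
    if (logits.back() != vocab_size) {
        throw std::invalid_argument("last logits dimension must equal vocab_size");
    }
    // Targets drop the vocabulary axis, so their rank is one less.
    if (targets.size() + 1 != logits.size()) {
        throw std::invalid_argument("targets must have shape [B] or [B, S]");
    }
    // All leading (batch / sequence) dimensions must agree.
    for (std::size_t i = 0; i < targets.size(); ++i) {
        if (targets[i] != logits[i]) {
            throw std::invalid_argument("targets and logits disagree on a leading dimension");
        }
    }
}
```

Checking the vocabulary axis against the configured vocab_size mirrors the constructor note that vocab_size is required, although whether the library performs exactly this comparison is an assumption.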

The documentation for this class was generated from the following file: