Mila 0.13.48
Deep Neural Network Library
Mila::Dnn::Loss< TDeviceType, TPrecision > Class Template Reference
export module Dnn.Loss

Abstract base class for neural network loss functions.

Inherits from Mila::Dnn::Component< TDeviceType, TPrecision >.

Public Types

using ExecutionContextType = ExecutionContext<TDeviceType>
using MR = typename DeviceTypeTraits<TDeviceType>::memory_resource
using TensorType = Tensor<TPrecision, MR>

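The Public Types above derive the memory resource from the device type at compile time, and the tensor type from both the precision and that memory resource. The following is a minimal sketch of that alias pattern, assuming simplified stand-ins: `DeviceType`, `TensorDataType`, `CpuMemoryResource`, `CudaMemoryResource`, `DeviceTypeTraits`, `Tensor`, and `LossAliases` here are hypothetical placeholders that only mirror the names on this page, not Mila's definitions.

```cpp
#include <cassert>
#include <type_traits>

// Hypothetical stand-ins; they only mirror the names used on this page.
enum class DeviceType { Cpu, Cuda };
enum class TensorDataType { FP32, FP16 };

struct CpuMemoryResource {};   // placeholder host memory resource
struct CudaMemoryResource {};  // placeholder device memory resource

// Traits map a device type to its memory resource type.
template <DeviceType D> struct DeviceTypeTraits;
template <> struct DeviceTypeTraits<DeviceType::Cpu>  { using memory_resource = CpuMemoryResource; };
template <> struct DeviceTypeTraits<DeviceType::Cuda> { using memory_resource = CudaMemoryResource; };

// Placeholder tensor parameterized by precision and memory resource.
template <TensorDataType P, typename MR> struct Tensor {};

// Same alias pattern as the Public Types of Loss.
template <DeviceType TDeviceType, TensorDataType TPrecision>
struct LossAliases {
    using MR = typename DeviceTypeTraits<TDeviceType>::memory_resource;
    using TensorType = Tensor<TPrecision, MR>;
};

// The device type alone selects the memory resource.
static_assert(std::is_same_v<LossAliases<DeviceType::Cuda, TensorDataType::FP16>::MR, CudaMemoryResource>);
static_assert(std::is_same_v<LossAliases<DeviceType::Cpu, TensorDataType::FP32>::TensorType,
                             Tensor<TensorDataType::FP32, CpuMemoryResource>>);
```

Because the aliases are resolved entirely through templates, a device/precision mismatch in a derived loss surfaces as a compile error rather than at run time.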
Public Member Functions

virtual ~Loss ()=default
Public Member Functions inherited from Mila::Dnn::Component< TDeviceType, TPrecision >
 Component (const std::string &name)
 Construct component with required name identifier.
virtual ~Component ()=default
virtual void build (const BuildContext &context) final
 Build the component with the provided BuildContext (canonical overload).
virtual DeviceId getDeviceId () const =0
 Get the compute device id associated with this component.
virtual std::vector< ITensor * > getGradients () const =0
 Return non-owning pointers to parameter gradient tensors.
virtual MemoryStats getMemoryStats () const =0
 Return the current memory allocation breakdown for this component.
const std::string getName () const
 Get the component's name identifier.
virtual std::vector< std::string > getParameterNames () const
 List all available parameter names for this component.
virtual std::vector< ITensor * > getParameters () const =0
 Return non-owning pointers to parameter tensors.
RuntimeMode getRuntimeMode () const noexcept
 Get the current RuntimeMode of this Component.
TrainingMode getTrainingMode () const noexcept
 Get the current runtime behavioral mode (TrainingMode) of this Component.
virtual const ComponentType getType () const =0
 Get the component type identifier.
virtual bool isBuilt () const final
 Returns true if build() has completed successfully.
bool isInferenceMode () const noexcept
bool isTrainingMode () const noexcept
virtual void loadParameter (const std::string &name, const Serialization::ITensorBlob &blob)
 Load a parameter from serialized tensor data.
virtual size_t parameterCount () const =0
 Return number of trainable parameters.
virtual void save_ (ModelArchive &archive, SerializationMode mode) const =0
void setTrainingMode (TrainingMode mode)
 Set the runtime behavioral mode for this Component.
virtual void synchronize ()=0
 Wait for outstanding device work submitted by this component.
virtual std::string toString () const =0
 Produce a short, human-readable description of the component.
virtual void zeroGradients ()
 Clear all model-owned gradients for this component.

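build() is declared final, isBuilt() reports completion, and setTrainingMode() changes the behavioral mode; the protected hooks onBuilding() and onTrainingModeChanging(), listed under the inherited members, are where derived classes customize these steps. The following is a minimal sketch of that template-method structure, assuming simplified stand-ins: `ComponentSketch`, `BufferedLoss`, the one-field `BuildContext`, and the `TrainingMode` values are hypothetical, and the early return when the mode is unchanged is an assumption, not documented Mila behavior.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>

// Hypothetical simplified stand-ins; Mila's real BuildContext and TrainingMode differ.
enum class TrainingMode { Training, Inference };
struct BuildContext { std::size_t batch_size; };

class ComponentSketch {
public:
    explicit ComponentSketch(std::string name) : name_(std::move(name)) {}
    virtual ~ComponentSketch() = default;

    // Non-overridable entry point: derived classes customize via onBuilding().
    virtual void build(const BuildContext& context) final {
        build_context_ = context;
        onBuilding(context);
        built_ = true;  // only reached if the hook did not throw
    }
    virtual bool isBuilt() const final { return built_; }

    void setTrainingMode(TrainingMode mode) {
        if (mode == mode_) return;     // assumption: no-op when the mode is unchanged
        onTrainingModeChanging(mode);  // hook runs before the transition
        mode_ = mode;
    }
    TrainingMode getTrainingMode() const noexcept { return mode_; }
    const std::string& getName() const { return name_; }

protected:
    virtual void onBuilding(const BuildContext&) {}
    virtual void onTrainingModeChanging(TrainingMode) {}
    BuildContext build_context_{ 1 };

private:
    std::string name_;
    TrainingMode mode_ = TrainingMode::Training;
    bool built_ = false;
};

// Hypothetical derived component that sizes a buffer at build time.
class BufferedLoss : public ComponentSketch {
public:
    using ComponentSketch::ComponentSketch;
    std::size_t allocated = 0;
    int mode_changes = 0;

protected:
    void onBuilding(const BuildContext& ctx) override { allocated = ctx.batch_size; }
    void onTrainingModeChanging(TrainingMode) override { ++mode_changes; }
};
```

Making build() final keeps the bookkeeping (storing the context, setting the built flag) in one place, while derived classes only supply the allocation step.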
Additional Inherited Members

Static Public Member Functions inherited from Mila::Dnn::Component< TDeviceType, TPrecision >
static constexpr DeviceType getDeviceType ()
 Compile-time device type for this component instance.
static constexpr TensorDataType getPrecision () noexcept
 Compile-time tensor precision for this component instance.
Protected Member Functions inherited from Mila::Dnn::Component< TDeviceType, TPrecision >
IExecutionContext getExecutionContext () const
 Get the shared execution context.
bool hasExecutionContext () const noexcept
 Check if execution context has been set.
template<TensorDataType TParameterPrecision, typename TMemoryResource>
void loadParameterFromBlob (const std::string &param_name, const Serialization::ITensorBlob &blob, Tensor< TParameterPrecision, TMemoryResource > &target, const shape_t &expected_shape)
 Load a tensor blob into a parameter tensor with validation.
virtual void onBuilding (const BuildContext &config)
 Hook invoked by build() to allocate component buffers.
virtual void onExecutionContextSet ()
 Lifecycle hook: Called immediately after ExecutionContext is set.
virtual void onTrainingModeChanging (TrainingMode mode)
 Hook called before TrainingMode transitions.
void setExecutionContext (IExecutionContext *context)
 Set the execution context for this component.
Protected Attributes inherited from Mila::Dnn::Component< TDeviceType, TPrecision >
BuildContext build_context_ { shape_t{ 1 }, RuntimeMode::Training }
 The BuildContext stored at build time.
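loadParameterFromBlob() loads a serialized blob into a parameter tensor "with validation" against an expected_shape. The following is a minimal sketch of such a shape-checked load, assuming simplified stand-ins: `TensorBlobSketch`, `ParamSketch`, and `loadParameterFromBlobSketch` are hypothetical, host-only, and FP32-only, and the exact checks Mila performs (and the exception type it throws) are not documented here.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

using shape_t = std::vector<std::int64_t>;

// Hypothetical stand-in for a serialized tensor blob.
struct TensorBlobSketch {
    shape_t shape;
    std::vector<float> data;
};

// Hypothetical parameter tensor: a shape plus contiguous host storage.
struct ParamSketch {
    shape_t shape;
    std::vector<float> data;
};

// Copies the blob into target only after checking shape and element count.
void loadParameterFromBlobSketch(const std::string& param_name,
                                 const TensorBlobSketch& blob,
                                 ParamSketch& target,
                                 const shape_t& expected_shape) {
    if (blob.shape != expected_shape) {
        throw std::invalid_argument("shape mismatch for parameter '" + param_name + "'");
    }
    std::int64_t count = 1;
    for (std::int64_t dim : expected_shape) count *= dim;
    if (static_cast<std::int64_t>(blob.data.size()) != count) {
        throw std::invalid_argument("element count mismatch for parameter '" + param_name + "'");
    }
    target.shape = expected_shape;
    target.data = blob.data;
}
```

Validating before copying leaves the target untouched when a checkpoint does not match the component's configuration.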

Detailed Description

template<Compute::DeviceType TDeviceType, TensorDataType TPrecision>
requires PrecisionSupportedOnDevice<TPrecision, TDeviceType>
class Mila::Dnn::Loss< TDeviceType, TPrecision >

Abstract base class for neural network loss functions.

Template Parameters
TDeviceType  Compile-time device identifier for this loss.
TPrecision  Tensor data type (TensorDataType) used for computations.

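Loss functions compute a scalar loss value from model predictions and target values. Derived loss classes may also provide hooks for optimizing the network graph and configuring reduction modes; this base class itself declares only a virtual destructor and inherits the rest of its interface from Component.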
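The Loss base documented here declares no compute method of its own, so the following is a minimal sketch of the idea under stated assumptions: a hypothetical `LossFunctionSketch` interface reduces per-element losses over `std::vector<float>` inputs to one scalar, with a hypothetical `Reduction` mode (`Mean` or `Sum`). `SquaredErrorSketch` uses squared error as a swapped-in concrete example; none of these names come from Mila's API.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical reduction modes; not taken from Mila's API.
enum class Reduction { Mean, Sum };

// Hypothetical abstract interface: predictions and targets in, one scalar out.
class LossFunctionSketch {
public:
    explicit LossFunctionSketch(Reduction reduction) : reduction_(reduction) {}
    virtual ~LossFunctionSketch() = default;

    float compute(const std::vector<float>& predictions,
                  const std::vector<float>& targets) const {
        if (predictions.size() != targets.size() || predictions.empty()) {
            throw std::invalid_argument("predictions and targets must be non-empty and equal in size");
        }
        float total = 0.0f;
        for (std::size_t i = 0; i < predictions.size(); ++i) {
            total += elementLoss(predictions[i], targets[i]);
        }
        // Reduce the per-element losses to a single scalar.
        return reduction_ == Reduction::Mean
            ? total / static_cast<float>(predictions.size())
            : total;
    }

protected:
    // Derived losses define only the per-element term.
    virtual float elementLoss(float prediction, float target) const = 0;

private:
    Reduction reduction_;
};

// Squared error as one concrete per-element loss.
class SquaredErrorSketch : public LossFunctionSketch {
public:
    using LossFunctionSketch::LossFunctionSketch;

protected:
    float elementLoss(float prediction, float target) const override {
        const float diff = prediction - target;
        return diff * diff;
    }
};
```

For predictions {1, 2, 3} and targets {1, 0, 1}, the per-element squared errors are 0, 4, and 4, so the Sum reduction gives 8 and the Mean reduction gives 8/3.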

Member Typedef Documentation

◆ ExecutionContextType

template<Compute::DeviceType TDeviceType, TensorDataType TPrecision>
using Mila::Dnn::Loss< TDeviceType, TPrecision >::ExecutionContextType = ExecutionContext<TDeviceType>

◆ MR

template<Compute::DeviceType TDeviceType, TensorDataType TPrecision>
using Mila::Dnn::Loss< TDeviceType, TPrecision >::MR = typename DeviceTypeTraits<TDeviceType>::memory_resource

◆ TensorType

template<Compute::DeviceType TDeviceType, TensorDataType TPrecision>
using Mila::Dnn::Loss< TDeviceType, TPrecision >::TensorType = Tensor<TPrecision, MR>

Constructor & Destructor Documentation

◆ ~Loss()

template<Compute::DeviceType TDeviceType, TensorDataType TPrecision>
virtual Mila::Dnn::Loss< TDeviceType, TPrecision >::~Loss ( )
virtual, default
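Because ~Loss() is virtual (and defaulted), destroying a derived loss through a pointer to the base runs the derived destructor first. The following minimal sketch makes that observable with a destruction counter; `LossBaseSketch`, `DerivedLossSketch`, `derived_destructions`, and `destroyThroughBase` are hypothetical names used only for illustration.

```cpp
#include <cassert>
#include <memory>

// Hypothetical stand-in for a base with a virtual, defaulted destructor.
class LossBaseSketch {
public:
    virtual ~LossBaseSketch() = default;
};

// Counts derived destructor calls so the virtual dispatch is observable.
inline int derived_destructions = 0;

class DerivedLossSketch : public LossBaseSketch {
public:
    ~DerivedLossSketch() override { ++derived_destructions; }
};

// Destroys a derived object through a base-class pointer.
inline void destroyThroughBase() {
    std::unique_ptr<LossBaseSketch> loss = std::make_unique<DerivedLossSketch>();
    // When loss goes out of scope, unique_ptr deletes through the base pointer;
    // the virtual destructor dispatches to ~DerivedLossSketch first.
}
```

Without the virtual destructor, deleting a derived object through a base pointer is undefined behavior, which is why polymorphic bases such as Loss declare one even when it is defaulted.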

The documentation for this class was generated from the following file:
  • /__w/Mila/Mila/Mila/Src/Dnn/Core/Loss.ixx