Mila 0.13.48
Deep Neural Network Library
Dnn.Components.SoftmaxCrossEntropy Module Reference

Exported Modules

module  Compute.BinaryOperation
module  Dnn.TensorDataType
module  Compute.DeviceTypeTraits
module  Dnn.Component
module  Compute.CpuMemoryResource
module  Compute.MemoryResource
module  Dnn.Components.CrossEntropyConfig
module  Compute.Device
module  Compute.DeviceId
module  Dnn.Tensor
module  Compute.DeviceType
module  Serialization.Mode
module  Compute.ExecutionContext
module  Dnn.ITensor
module  Compute.OperationRegistry
module  Serialization.ModelArchive
module  Dnn.TensorDataTypeTraits
module  Dnn.TensorTypes

Classes

class  Mila::Dnn::SoftmaxCrossEntropy< TDeviceType, TLogits, TTargets, TPrecision >
 Fused SoftmaxCrossEntropy loss module (device-templated).
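
For reference, the textbook form of a fused softmax + cross-entropy loss for one row of logits z of length V and an integer target class t is given below. Mila's backend kernel is not shown on this page, and neither is its reduction over rows (mean or sum), so the reduction is left out.

```latex
\begin{aligned}
m &= \max_{j} z_j \\
\ell(z, t) &= -\log \frac{e^{z_t}}{\sum_{j=1}^{V} e^{z_j}}
            = m + \log \sum_{j=1}^{V} e^{z_j - m} - z_t \\
\frac{\partial \ell}{\partial z_j} &= \frac{e^{z_j - m}}{\sum_{k=1}^{V} e^{z_k - m}} - \mathbb{1}[j = t]
\end{aligned}
```

Subtracting m keeps every exponent non-positive, so no exp term can overflow. The gradient reduces to the softmax probabilities minus a one-hot vector, which is the usual motivation for fusing the two steps instead of differentiating softmax and log-loss separately.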

Typedefs

using ExecutionContextType = ExecutionContext<TDeviceType>
using MR = typename DeviceTypeTraits<TDeviceType>::memory_resource
using TargetTensorType = Tensor<TTargets, MR>
using TensorType = Tensor<TPrecision, MR>
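
The typedefs above derive the memory resource from the device type through a traits class (DeviceTypeTraits<TDeviceType>::memory_resource). The sketch below shows that pattern in isolation. DeviceType, CpuMemoryResource, CudaMemoryResource, DeviceTypeTraits, Tensor and LossTypedefs here are hypothetical stand-ins written for illustration and are not Mila's actual definitions.

```cpp
#include <cstdint>
#include <type_traits>

// Hypothetical stand-ins: NOT Mila's real device or memory-resource types.
enum class DeviceType { Cpu, Cuda };

struct CpuMemoryResource {};
struct CudaMemoryResource {};

// Primary template is declared but never defined, so an unsupported
// device type is rejected at compile time.
template <DeviceType D>
struct DeviceTypeTraits;

template <>
struct DeviceTypeTraits<DeviceType::Cpu> {
    using memory_resource = CpuMemoryResource;
};

template <>
struct DeviceTypeTraits<DeviceType::Cuda> {
    using memory_resource = CudaMemoryResource;
};

// Minimal tensor placeholder parameterized by element type and memory resource.
template <typename T, typename MR>
struct Tensor {
    using value_type = T;
    using memory_resource = MR;
};

// Mirrors the shape of the typedefs listed for the loss module.
template <DeviceType TDeviceType, typename TTargets, typename TPrecision>
struct LossTypedefs {
    using MR = typename DeviceTypeTraits<TDeviceType>::memory_resource;
    using TargetTensorType = Tensor<TTargets, MR>;
    using TensorType = Tensor<TPrecision, MR>;
};

// Usage: a CPU instantiation with int32 targets and float precision.
static_assert(std::is_same_v<LossTypedefs<DeviceType::Cpu, std::int32_t, float>::MR,
                             CpuMemoryResource>);
```

Because every tensor alias is built from the same MR, logits, targets and outputs of one instantiation are tied to a single device's memory resource by construction.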

Functions

 SoftmaxCrossEntropy (IExecutionContext *exec_context, const CrossEntropyConfig &config)
 Construct with an existing execution context.
 ~SoftmaxCrossEntropy () override=default
void backward (const ITensor &logits, const ITensor &targets, const ITensor &output_grad, ITensor &logits_grad)
 Backward pass - delegates to backend operation.
void createOperation ()
 Create the backend compute operation.
void forward (const ITensor &logits, const ITensor &targets, ITensor &output)
 Forward pass - delegates to backend operation.
const CrossEntropyConfig & getConfig () const noexcept
DeviceId getDeviceId () const override
 Get the compute device id associated with this component.
std::vector< ITensor * > getGradients () const override
 Return non-owning pointers to parameter gradient tensors.
std::vector< ITensor * > getParameters () const override
 Return non-owning pointers to parameter tensors.
int64_t getVocabSize () const
void onBuilding (const shape_t &input_shape) override
 Build the module using an input shape.
void onTrainingChanging (bool newMode) override
 Hook invoked when training mode is about to change.
size_t parameterCount () const override
 Return number of trainable parameters.
void save_ (ModelArchive &archive, SerializationMode mode) const override
void synchronize () override
 Wait for outstanding device work submitted by this component.
std::string toString () const override
 Produce a short, human-readable description of the component.
void validateInputShape (const ITensor &input) const
 Validate input shape for fused softmax+cross-entropy operation.
void validateInputShape (const shape_t &input_shape) const
 Validate input shape for fused softmax+cross-entropy operation.
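
forward and backward are documented only as delegating to a backend operation, so the arithmetic they dispatch is not shown on this page. The following is a minimal CPU sketch of what a fused softmax-cross-entropy kernel typically computes for a single row, assuming a float logits row and an int32 class index. softmaxCrossEntropyRow and softmaxCrossEntropyRowGrad are hypothetical helpers, not Mila's backend code.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: forward loss for one row, -log(softmax(logits)[target]),
// computed as logsumexp(logits) - logits[target] with max subtraction for stability.
float softmaxCrossEntropyRow(const std::vector<float>& logits, std::int32_t target) {
    assert(!logits.empty());
    assert(target >= 0 && static_cast<std::size_t>(target) < logits.size());

    const float m = *std::max_element(logits.begin(), logits.end());
    float sum = 0.0f;
    for (float z : logits) {
        sum += std::exp(z - m);  // every exponent is <= 0, so no overflow
    }
    return m + std::log(sum) - logits[static_cast<std::size_t>(target)];
}

// Hypothetical helper: gradient of the row loss with respect to the logits,
// output_grad * (softmax(logits)_j - [j == target]).
std::vector<float> softmaxCrossEntropyRowGrad(const std::vector<float>& logits,
                                              std::int32_t target,
                                              float output_grad) {
    assert(!logits.empty());
    assert(target >= 0 && static_cast<std::size_t>(target) < logits.size());

    const float m = *std::max_element(logits.begin(), logits.end());
    std::vector<float> grad(logits.size());
    float sum = 0.0f;
    for (std::size_t j = 0; j < logits.size(); ++j) {
        grad[j] = std::exp(logits[j] - m);  // unnormalized softmax
        sum += grad[j];
    }
    const auto t = static_cast<std::size_t>(target);
    for (std::size_t j = 0; j < grad.size(); ++j) {
        const float p = grad[j] / sum;  // normalized softmax probability
        grad[j] = output_grad * (p - (j == t ? 1.0f : 0.0f));
    }
    return grad;
}
```

A real backend would apply these per row of a logits tensor whose last dimension is the vocabulary. Under a mean reduction, each row's output_grad would carry a 1/N factor. That is an assumption, since the reduction is not specified here.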

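The two validateInputShape overloads are documented only as validating the input shape for the fused operation. The sketch below assumes that rule is "the class dimension is last and must equal getVocabSize()". The rank >= 2 requirement, the shape_t alias, and the names validateLogitsShape and isValidLogitsShape are all hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical: shape_t is assumed to be a vector of 64-bit extents.
using shape_t = std::vector<std::int64_t>;

// Hypothetical rule: logits must have rank >= 2 (e.g. [N, V] or [B, T, V])
// and the last (class) dimension must match the configured vocabulary size.
void validateLogitsShape(const shape_t& input_shape, std::int64_t vocab_size) {
    assert(vocab_size > 0);
    if (input_shape.size() < 2) {
        throw std::invalid_argument("logits must have rank >= 2, got rank " +
                                    std::to_string(input_shape.size()));
    }
    const std::int64_t last = input_shape.back();
    if (last != vocab_size) {
        throw std::invalid_argument("last logits dimension " + std::to_string(last) +
                                    " does not match vocab size " +
                                    std::to_string(vocab_size));
    }
}

// Convenience wrapper: true when validateLogitsShape accepts the shape.
bool isValidLogitsShape(const shape_t& input_shape, std::int64_t vocab_size) {
    try {
        validateLogitsShape(input_shape, vocab_size);
        return true;
    } catch (const std::invalid_argument&) {
        return false;
    }
}
```

Throwing with a message that names the mismatched dimension makes a wrong vocabulary size easy to diagnose. Whether Mila throws, and which exception type it would use, is not stated on this page.
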
Variables

CrossEntropyConfig config_
std::shared_ptr< TargetTensorType > dummy_target_grad_ { nullptr }
IExecutionContext * exec_context_ { nullptr }
std::unique_ptr< BinaryOperation< TDeviceType, TLogits, TTargets, TPrecision > > operation_ { nullptr }
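
operation_ holds the backend BinaryOperation through a std::unique_ptr, and forward/backward are documented as delegating to it once createOperation() has built it. The ownership-and-delegation pattern can be sketched with simplified, hypothetical types. In the sketch, BinaryOp, DiffOp and DelegatingComponent are hypothetical, and std::vector stands in for tensors. The toy elementwise difference only stands in for a real kernel. Mila's actual interface and error handling are not reproduced.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <vector>

// Hypothetical, simplified backend interface: two inputs, one output.
struct BinaryOp {
    virtual ~BinaryOp() = default;
    virtual void forward(const std::vector<float>& a, const std::vector<float>& b,
                         std::vector<float>& out) = 0;
};

// Toy backend used only to demonstrate delegation (elementwise a - b).
struct DiffOp final : BinaryOp {
    void forward(const std::vector<float>& a, const std::vector<float>& b,
                 std::vector<float>& out) override {
        assert(a.size() == b.size());
        out.resize(a.size());
        for (std::size_t i = 0; i < a.size(); ++i) {
            out[i] = a[i] - b[i];
        }
    }
};

// Component that owns its operation via unique_ptr and forwards calls to it.
class DelegatingComponent {
public:
    void createOperation() { operation_ = std::make_unique<DiffOp>(); }

    void forward(const std::vector<float>& a, const std::vector<float>& b,
                 std::vector<float>& out) {
        if (!operation_) {
            throw std::runtime_error("operation not created; call createOperation() first");
        }
        operation_->forward(a, b, out);
    }

    bool hasOperation() const noexcept { return operation_ != nullptr; }

private:
    std::unique_ptr<BinaryOp> operation_{ nullptr };
};
```

Holding the operation through a unique_ptr to an abstract base would let a component choose a device-specific backend at run time while keeping sole ownership of it. The null check mirrors the { nullptr } default initializer of operation_.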

Files

file  /__w/Mila/Mila/Mila/Src/Dnn/Components/Losses/SoftmaxCrossEntropy.ixx
 Device-templated fused SoftmaxCrossEntropy loss module.