Mila 0.13.48
Deep Neural Network Library
Loading...
Searching...
No Matches
Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TPrecision, TLogits, TTargets > Class Template Reference [export]

Fused CUDA implementation of Softmax + CrossEntropy using abstract TensorDataType API. More...

Inheritance diagram for Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TPrecision, TLogits, TTargets >:
Collaboration diagram for Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TPrecision, TLogits, TTargets >:

Public Types

using BinaryOperationBase = BinaryOperation<DeviceType::Cuda, TPrecision, TLogits, TTargets>
using CudaExecutionContext = ExecutionContext<DeviceType::Cuda>
using LogitsTensorType = Tensor<TLogits, MR>
using MR = CudaDeviceMemoryResource
using NativeType = typename Mila::Dnn::Compute::Cuda::TensorDataTypeMap<TLogits>::device_type
using TargetsNativeType = typename Mila::Dnn::Compute::Cuda::TensorDataTypeMap<TTargets>::device_type
using TargetsTensorType = Tensor<TTargets, MR>
Public Types inherited from Mila::Dnn::Compute::BinaryOperation< DeviceType::Cuda, TPrecision, TPrecision, TensorDataType::INT32 >
using MR
using ParameterGradTensor
using ParameterTensor
using TensorLeftType
using TensorOutputType
using TensorRightType
Public Types inherited from Mila::Dnn::Compute::Operation< TDeviceType, TPrecision >
using DataTypeTraits

Public Member Functions

 CudaSoftmaxCrossEntropyOp (IExecutionContext *context, const CrossEntropyConfig &config)
 Construct fused Softmax+CrossEntropy operation with execution context.
void backward (const ITensor &logits, const ITensor &targets, const ITensor &loss_grad, ITensor &logits_grad, ITensor &targets_grad) const override
 Backward pass - HOT PATH, computes fused gradient.
void build (const BuildContext &config) override
 Build the operation for a concrete input shape.
void forward (const ITensor &logits, const ITensor &targets, ITensor &output) const override
 Forward pass - HOT PATH, computes fused softmax+cross-entropy loss.
const CrossEntropyConfig & getConfig () const
std::string getName () const override
 Human-readable operation name.
OperationType getOperationType () const override
 Operation type identifier.
void setParameters (ITensor *class_weights, ITensor *bias) override
 Bind optional class weights parameter.
Public Member Functions inherited from Mila::Dnn::Compute::BinaryOperation< DeviceType::Cuda, TPrecision, TPrecision, TensorDataType::INT32 >
virtual ~BinaryOperation ()=default
Public Member Functions inherited from Mila::Dnn::Compute::Operation< TDeviceType, TPrecision >
virtual ~Operation ()=default
virtual void clearGradients () noexcept
 Clear any cached gradient pointers held by the operation.
virtual TensorDataType getDataType () const
 Tensor data type for this operation.
virtual DeviceType getDeviceType () const
 Device type for this operation.
virtual std::size_t getStateMemorySize () const
 Returns the number of bytes of state memory allocated by this operation.
virtual bool isBuilt () const
 Whether build() completed successfully for a concrete input shape.
virtual bool isEvalMode () const
 Query whether the operation is configured for evaluation (inference) mode rather than training.
virtual void setGradients (ITensor *weight_grad, ITensor *bias_grad)
 Bind module-owned gradient tensors to the operation.
virtual void setTrainingMode (TrainingMode training_mode)
 Configure operation training-mode behavior.

Private Attributes

int cached_batch_size_ { 0 }
std::shared_ptr< LogitsTensorType > cached_probs_
int cached_seq_len_ { 0 }
cudaStream_t cached_stream_ { nullptr }
int cached_vocab_size_ { 0 }
const NativeType * class_weights_ { nullptr }
CrossEntropyConfig config_
CudaExecutionContext * context_

Additional Inherited Members

Static Public Attributes inherited from Mila::Dnn::Compute::Operation< TDeviceType, TPrecision >
static constexpr TensorDataType data_type
static constexpr DeviceType device_type
Static Protected Member Functions inherited from Mila::Dnn::Compute::BinaryOperation< DeviceType::Cuda, TPrecision, TPrecision, TensorDataType::INT32 >
static const TensorLeftType & asLeftTensor (const ITensor &t)
static TensorOutputType & asOutputTensor (ITensor &t)
static const TensorRightType & asRightTensor (const ITensor &t)
Protected Attributes inherited from Mila::Dnn::Compute::Operation< TDeviceType, TPrecision >
bool is_built_
TrainingMode training_mode_

Detailed Description

template<TensorDataType TPrecision, TensorDataType TLogits = TPrecision, TensorDataType TTargets = TensorDataType::INT32>
requires PrecisionSupportedOnDevice<TPrecision, DeviceType::Cuda> && PrecisionSupportedOnDevice<TLogits, DeviceType::Cuda>
class Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TPrecision, TLogits, TTargets >

Fused CUDA implementation of Softmax + CrossEntropy using abstract TensorDataType API.

This operation combines softmax normalization and cross-entropy loss computation into a single numerically stable binary operation (logits + targets → loss).

Key properties:

  1. Numerical Stability: Uses log-sum-exp trick to avoid overflow/underflow
  2. Performance: Single GPU kernel pass, no intermediate probability tensor in API
  3. Simplified Gradient: dL/dlogits = softmax(logits) - one_hot(targets)
  4. Memory Efficiency: Probabilities cached internally for backward pass

Design philosophy:

  • Two-phase initialization: build() does all setup, forward()/backward() are pure dispatch
  • Internal state management: Probabilities cached as mutable private member
  • Forward computes loss directly from logits
  • Backward uses cached probabilities from forward pass
  • All dimension computation happens once in build()
Template Parameters
TPrecision  Operation precision; must be supported on the CUDA device (also the default for TLogits)
TLogits  Precision for logits/gradients (FP32, FP16)
TTargets  Target indices data type (typically INT32)

Member Typedef Documentation

◆ BinaryOperationBase

template<TensorDataType TPrecision, TensorDataType TLogits = TPrecision, TensorDataType TTargets = TensorDataType::INT32>
using Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TPrecision, TLogits, TTargets >::BinaryOperationBase = BinaryOperation<DeviceType::Cuda, TPrecision, TLogits, TTargets>

◆ CudaExecutionContext

template<TensorDataType TPrecision, TensorDataType TLogits = TPrecision, TensorDataType TTargets = TensorDataType::INT32>
using Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TPrecision, TLogits, TTargets >::CudaExecutionContext = ExecutionContext<DeviceType::Cuda>

◆ LogitsTensorType

template<TensorDataType TPrecision, TensorDataType TLogits = TPrecision, TensorDataType TTargets = TensorDataType::INT32>
using Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TPrecision, TLogits, TTargets >::LogitsTensorType = Tensor<TLogits, MR>

◆ MR

template<TensorDataType TPrecision, TensorDataType TLogits = TPrecision, TensorDataType TTargets = TensorDataType::INT32>
using Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TPrecision, TLogits, TTargets >::MR = CudaDeviceMemoryResource

◆ NativeType

template<TensorDataType TPrecision, TensorDataType TLogits = TPrecision, TensorDataType TTargets = TensorDataType::INT32>
using Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TPrecision, TLogits, TTargets >::NativeType = typename Mila::Dnn::Compute::Cuda::TensorDataTypeMap<TLogits>::device_type

◆ TargetsNativeType

template<TensorDataType TPrecision, TensorDataType TLogits = TPrecision, TensorDataType TTargets = TensorDataType::INT32>
using Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TPrecision, TLogits, TTargets >::TargetsNativeType = typename Mila::Dnn::Compute::Cuda::TensorDataTypeMap<TTargets>::device_type

◆ TargetsTensorType

template<TensorDataType TPrecision, TensorDataType TLogits = TPrecision, TensorDataType TTargets = TensorDataType::INT32>
using Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TPrecision, TLogits, TTargets >::TargetsTensorType = Tensor<TTargets, MR>

Constructor & Destructor Documentation

◆ CudaSoftmaxCrossEntropyOp()

template<TensorDataType TPrecision, TensorDataType TLogits = TPrecision, TensorDataType TTargets = TensorDataType::INT32>
Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TPrecision, TLogits, TTargets >::CudaSoftmaxCrossEntropyOp ( IExecutionContext * context,
const CrossEntropyConfig & config )
inline

Construct fused Softmax+CrossEntropy operation with execution context.

Parameters
context  CUDA execution context
config  CrossEntropy configuration (vocab_size required)

Member Function Documentation

◆ backward()

template<TensorDataType TPrecision, TensorDataType TLogits = TPrecision, TensorDataType TTargets = TensorDataType::INT32>
void Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TPrecision, TLogits, TTargets >::backward ( const ITensor & logits,
const ITensor & targets,
const ITensor & loss_grad,
ITensor & logits_grad,
ITensor & targets_grad ) const
inlineoverridevirtual

Backward pass - HOT PATH, computes fused gradient.

Key property of fused softmax + cross-entropy: dL/dlogits = softmax(logits) - one_hot(targets)

Algorithm: for each sample and each class index i:
  1. dL/dlogits[i] = prob[i] - (i == target ? 1 : 0)
  2. Scale the result by the corresponding loss_grad value.

Parameters
logits  Logits tensor from forward pass
targets  Targets tensor from forward pass
loss_grad  Gradient w.r.t. loss (per-sample gradients)
logits_grad  Output: gradient w.r.t. logits
targets_grad  Unused (targets are integers, not differentiable)

Implements Mila::Dnn::Compute::BinaryOperation< DeviceType::Cuda, TPrecision, TPrecision, TensorDataType::INT32 >.

◆ build()

template<TensorDataType TPrecision, TensorDataType TLogits = TPrecision, TensorDataType TTargets = TensorDataType::INT32>
void Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TPrecision, TLogits, TTargets >::build ( const BuildContext & config)
inlineoverridevirtual

Build the operation for a concrete input shape.

This is the COLD PATH where all setup, validation, and computation happens ONCE.

Expected input shape: [batch_size, seq_length, vocab_size] or [batch_size, vocab_size]. Target shape: [batch_size, seq_length] or [batch_size].

Responsibilities:

  1. Validate input shape (rank >= 2, last dim = vocab_size)
  2. Compute and cache dimension sizes (batch, seq, vocab)
  3. Cache CUDA stream
  4. Allocate internal probability cache for backward pass

Reimplemented from Mila::Dnn::Compute::Operation< TDeviceType, TPrecision >.

◆ forward()

template<TensorDataType TPrecision, TensorDataType TLogits = TPrecision, TensorDataType TTargets = TensorDataType::INT32>
void Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TPrecision, TLogits, TTargets >::forward ( const ITensor & logits,
const ITensor & targets,
ITensor & output ) const
inlineoverridevirtual

Forward pass - HOT PATH, computes fused softmax+cross-entropy loss.

Computes: loss = -log(softmax(logits)[target])

Algorithm (numerically stable): For each sample:

  1. max_logit = max(logits)
  2. sum_exp = sum(exp(logits - max_logit))
  3. loss = -(logits[target] - max_logit - log(sum_exp))
Parameters
logits  Logits tensor [batch, seq, vocab]
targets  Targets tensor [batch, seq] (integer class indices)
output  Loss tensor (per-sample losses [batch, seq])

Implements Mila::Dnn::Compute::BinaryOperation< DeviceType::Cuda, TPrecision, TPrecision, TensorDataType::INT32 >.

◆ getConfig()

template<TensorDataType TPrecision, TensorDataType TLogits = TPrecision, TensorDataType TTargets = TensorDataType::INT32>
const CrossEntropyConfig & Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TPrecision, TLogits, TTargets >::getConfig ( ) const
inline

◆ getName()

template<TensorDataType TPrecision, TensorDataType TLogits = TPrecision, TensorDataType TTargets = TensorDataType::INT32>
std::string Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TPrecision, TLogits, TTargets >::getName ( ) const
inlineoverridevirtual

Human-readable operation name.

Implements Mila::Dnn::Compute::Operation< TDeviceType, TPrecision >.

◆ getOperationType()

template<TensorDataType TPrecision, TensorDataType TLogits = TPrecision, TensorDataType TTargets = TensorDataType::INT32>
OperationType Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TPrecision, TLogits, TTargets >::getOperationType ( ) const
inlineoverridevirtual

◆ setParameters()

template<TensorDataType TPrecision, TensorDataType TLogits = TPrecision, TensorDataType TTargets = TensorDataType::INT32>
void Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TPrecision, TLogits, TTargets >::setParameters ( ITensor * class_weights,
ITensor * bias )
inlineoverridevirtual

Bind optional class weights parameter.

Parameters
class_weights  Optional class weights tensor (may be null)
bias  Unused (must be null)

Reimplemented from Mila::Dnn::Compute::Operation< TDeviceType, TPrecision >.

Member Data Documentation

◆ cached_batch_size_

template<TensorDataType TPrecision, TensorDataType TLogits = TPrecision, TensorDataType TTargets = TensorDataType::INT32>
int Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TPrecision, TLogits, TTargets >::cached_batch_size_ { 0 }
private

◆ cached_probs_

template<TensorDataType TPrecision, TensorDataType TLogits = TPrecision, TensorDataType TTargets = TensorDataType::INT32>
std::shared_ptr<LogitsTensorType> Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TPrecision, TLogits, TTargets >::cached_probs_
mutableprivate

◆ cached_seq_len_

template<TensorDataType TPrecision, TensorDataType TLogits = TPrecision, TensorDataType TTargets = TensorDataType::INT32>
int Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TPrecision, TLogits, TTargets >::cached_seq_len_ { 0 }
private

◆ cached_stream_

template<TensorDataType TPrecision, TensorDataType TLogits = TPrecision, TensorDataType TTargets = TensorDataType::INT32>
cudaStream_t Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TPrecision, TLogits, TTargets >::cached_stream_ { nullptr }
private

◆ cached_vocab_size_

template<TensorDataType TPrecision, TensorDataType TLogits = TPrecision, TensorDataType TTargets = TensorDataType::INT32>
int Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TPrecision, TLogits, TTargets >::cached_vocab_size_ { 0 }
private

◆ class_weights_

template<TensorDataType TPrecision, TensorDataType TLogits = TPrecision, TensorDataType TTargets = TensorDataType::INT32>
const NativeType* Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TPrecision, TLogits, TTargets >::class_weights_ { nullptr }
private

◆ config_

template<TensorDataType TPrecision, TensorDataType TLogits = TPrecision, TensorDataType TTargets = TensorDataType::INT32>
CrossEntropyConfig Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TPrecision, TLogits, TTargets >::config_
private

◆ context_

template<TensorDataType TPrecision, TensorDataType TLogits = TPrecision, TensorDataType TTargets = TensorDataType::INT32>
CudaExecutionContext* Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TPrecision, TLogits, TTargets >::context_
private

The documentation for this class was generated from the following file: