Mila 0.13.48
Deep Neural Network Library
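Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TPrecision, TLogits, TTargets > Class Template Reference
Fused CUDA implementation of Softmax + CrossEntropy using abstract TensorDataType API.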


Public Types
| using | BinaryOperationBase = BinaryOperation<DeviceType::Cuda, TPrecision, TLogits, TTargets> |
| using | CudaExecutionContext = ExecutionContext<DeviceType::Cuda> |
| using | LogitsTensorType = Tensor<TLogits, MR> |
| using | MR = CudaDeviceMemoryResource |
| using | NativeType = typename Mila::Dnn::Compute::Cuda::TensorDataTypeMap<TLogits>::device_type |
| using | TargetsNativeType = typename Mila::Dnn::Compute::Cuda::TensorDataTypeMap<TTargets>::device_type |
| using | TargetsTensorType = Tensor<TTargets, MR> |
Public Types inherited from Mila::Dnn::Compute::BinaryOperation< DeviceType::Cuda, TPrecision, TPrecision, TensorDataType::INT32 >
| using | MR |
| using | ParameterGradTensor |
| using | ParameterTensor |
| using | TensorLeftType |
| using | TensorOutputType |
| using | TensorRightType |
Public Types inherited from Mila::Dnn::Compute::Operation< TDeviceType, TPrecision >
| using | DataTypeTraits |
Public Member Functions
| CudaSoftmaxCrossEntropyOp (IExecutionContext *context, const CrossEntropyConfig &config) | |
| Construct fused Softmax+CrossEntropy operation with execution context. | |
| void | backward (const ITensor &logits, const ITensor &targets, const ITensor &loss_grad, ITensor &logits_grad, ITensor &targets_grad) const override |
| Backward pass - HOT PATH, computes fused gradient. | |
| void | build (const BuildContext &config) override |
| Build the operation for a concrete input shape. | |
| void | forward (const ITensor &logits, const ITensor &targets, ITensor &output) const override |
| Forward pass - HOT PATH, computes fused softmax+cross-entropy loss. | |
| const CrossEntropyConfig & | getConfig () const |
| std::string | getName () const override |
| Human-readable operation name. | |
| OperationType | getOperationType () const override |
| Operation type identifier. | |
| void | setParameters (ITensor *class_weights, ITensor *bias) override |
| Bind optional class weights parameter. | |
Public Member Functions inherited from Mila::Dnn::Compute::BinaryOperation< DeviceType::Cuda, TPrecision, TPrecision, TensorDataType::INT32 >
| virtual | ~BinaryOperation ()=default |
Public Member Functions inherited from Mila::Dnn::Compute::Operation< TDeviceType, TPrecision >
| virtual | ~Operation ()=default |
| virtual void | clearGradients () noexcept |
| Clear any cached gradient pointers held by the operation. | |
| virtual TensorDataType | getDataType () const |
| Tensor data type for this operation. | |
| virtual DeviceType | getDeviceType () const |
| Device type for this operation. | |
| virtual std::size_t | getStateMemorySize () const |
| Returns the number of bytes of state memory allocated by this operation. | |
| virtual bool | isBuilt () const |
| Whether build() completed successfully for a concrete input shape. | |
| virtual bool | isEvalMode () const |
| Query whether the operation is configured for evaluation (inference) mode. | |
| virtual void | setGradients (ITensor *weight_grad, ITensor *bias_grad) |
| Bind module-owned gradient tensors to the operation. | |
| virtual void | setTrainingMode (TrainingMode training_mode) |
| Configure operation training-mode behavior. | |
Private Attributes
| int | cached_batch_size_ { 0 } |
| std::shared_ptr< LogitsTensorType > | cached_probs_ |
| int | cached_seq_len_ { 0 } |
| cudaStream_t | cached_stream_ { nullptr } |
| int | cached_vocab_size_ { 0 } |
| const NativeType * | class_weights_ { nullptr } |
| CrossEntropyConfig | config_ |
| CudaExecutionContext * | context_ |
Additional Inherited Members
Static Public Attributes inherited from Mila::Dnn::Compute::Operation< TDeviceType, TPrecision >
| static constexpr TensorDataType | data_type |
| static constexpr DeviceType | device_type |
Static Protected Member Functions inherited from Mila::Dnn::Compute::BinaryOperation< DeviceType::Cuda, TPrecision, TPrecision, TensorDataType::INT32 >
| static const TensorLeftType & | asLeftTensor (const ITensor &t) |
| static TensorOutputType & | asOutputTensor (ITensor &t) |
| static const TensorRightType & | asRightTensor (const ITensor &t) |
Protected Attributes inherited from Mila::Dnn::Compute::Operation< TDeviceType, TPrecision >
| bool | is_built_ |
| TrainingMode | training_mode_ |
Detailed Description
Fused CUDA implementation of Softmax + CrossEntropy using abstract TensorDataType API.
This operation combines softmax normalization and cross-entropy loss computation into a single, numerically stable binary operation (logits + targets → loss).
Key properties:
Design philosophy:
Template Parameters
| TLogitsPrecision | Precision for logits/gradients (FP32, FP16) |
| TTargets | Target indices data type (typically INT32) |
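The numerical-stability claim can be made concrete with the standard log-sum-exp rewrite of the loss for one row of logits z over V classes with target index t. The notation is introduced here, and this is the textbook identity, not an excerpt from the Mila kernels:

```latex
\begin{aligned}
\mathcal{L} &= -\log \operatorname{softmax}(z)_t
             = -\log \frac{e^{z_t}}{\sum_{j=1}^{V} e^{z_j}} \\
            &= \log \sum_{j=1}^{V} e^{z_j} - z_t \\
            &= m + \log \sum_{j=1}^{V} e^{z_j - m} - z_t,
            \qquad m = \max_j z_j .
\end{aligned}
```

Every exponent z_j - m is at most 0, so no term can overflow, and the loss is never computed as the logarithm of a probability that has underflowed to zero, which is the usual failure mode of a separate softmax followed by log.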
Member Typedef Documentation
| using Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TPrecision, TLogits, TTargets >::BinaryOperationBase = BinaryOperation<DeviceType::Cuda, TPrecision, TLogits, TTargets> |
| using Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TPrecision, TLogits, TTargets >::CudaExecutionContext = ExecutionContext<DeviceType::Cuda> |
| using Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TPrecision, TLogits, TTargets >::LogitsTensorType = Tensor<TLogits, MR> |
| using Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TPrecision, TLogits, TTargets >::MR = CudaDeviceMemoryResource |
| using Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TPrecision, TLogits, TTargets >::NativeType = typename Mila::Dnn::Compute::Cuda::TensorDataTypeMap<TLogits>::device_type |
| using Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TPrecision, TLogits, TTargets >::TargetsNativeType = typename Mila::Dnn::Compute::Cuda::TensorDataTypeMap<TTargets>::device_type |
| using Mila::Dnn::Compute::Cuda::SoftmaxCrossEntropy::CudaSoftmaxCrossEntropyOp< TPrecision, TLogits, TTargets >::TargetsTensorType = Tensor<TTargets, MR> |
Constructor & Destructor Documentation
CudaSoftmaxCrossEntropyOp() [inline]
Construct fused Softmax+CrossEntropy operation with execution context.
Parameters
| context | CUDA execution context |
| config | CrossEntropy configuration (vocab_size required) |
Member Function Documentation
backward() [inline, override, virtual]
Backward pass - HOT PATH, computes fused gradient.
Key property of fused softmax+cross-entropy: the gradient has the closed form dL/dlogits = softmax(logits) - one_hot(targets).
Algorithm: for each sample, compute dL/dlogits[i] = prob[i] - (i == target ? 1 : 0), then scale the result by that sample's entry of loss_grad.
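The closed form follows in one step from the log-sum-exp form of the loss. Here p is the softmax of one row z, t the target index, δ_it the Kronecker delta (the one-hot target), J the downstream objective, and g that sample's entry of loss_grad; these symbols are introduced for this derivation only:

```latex
\begin{aligned}
p_i &= \frac{e^{z_i}}{\sum_{j} e^{z_j}}, \qquad
\mathcal{L} = \log \sum_{j} e^{z_j} - z_t \\
\frac{\partial \mathcal{L}}{\partial z_i}
    &= \frac{e^{z_i}}{\sum_{j} e^{z_j}} - \delta_{it}
     = p_i - \delta_{it} \\
\frac{\partial J}{\partial z_i}
    &= g \,\bigl(p_i - \delta_{it}\bigr), \qquad
    g = \frac{\partial J}{\partial \mathcal{L}}
\end{aligned}
```

Because the softmax probabilities sum to 1, each row of the unscaled gradient sums to 0.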
Parameters
| logits | Logits tensor from forward pass |
| targets | Targets tensor from forward pass |
| loss_grad | Gradient w.r.t. loss (per-sample gradients) |
| logits_grad | Output: gradient w.r.t. logits |
| targets_grad | Unused (targets are integers, not differentiable) |
Implements Mila::Dnn::Compute::BinaryOperation< DeviceType::Cuda, TPrecision, TPrecision, TensorDataType::INT32 >.
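The closed-form gradient can be checked against a small CPU reference. This is a hedged sketch, not the Mila CUDA kernel: the function name softmax_xent_backward_ref, the flattened row-major [num_rows, vocab_size] layout, and the use of std::vector in place of ITensor are assumptions made for illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference for the fused backward pass:
// dL/dlogits[n][i] = loss_grad[n] * (softmax(logits[n])[i] - (i == targets[n] ? 1 : 0)).
// Probabilities are recomputed here; the CUDA operation may instead reuse
// probabilities cached during forward (an assumption, not confirmed by these docs).
std::vector<float> softmax_xent_backward_ref(const std::vector<float>& logits,
                                             const std::vector<int>& targets,
                                             const std::vector<float>& loss_grad,
                                             std::size_t num_rows,
                                             std::size_t vocab_size) {
    assert(logits.size() == num_rows * vocab_size);
    assert(targets.size() == num_rows && loss_grad.size() == num_rows);
    std::vector<float> grad(num_rows * vocab_size);
    for (std::size_t n = 0; n < num_rows; ++n) {
        const float* row = logits.data() + n * vocab_size;
        float* out = grad.data() + n * vocab_size;
        assert(targets[n] >= 0 && static_cast<std::size_t>(targets[n]) < vocab_size);
        const std::size_t t = static_cast<std::size_t>(targets[n]);

        // Numerically stable softmax: shift by the row maximum before exp().
        const float m = *std::max_element(row, row + vocab_size);
        float sum = 0.0f;
        for (std::size_t v = 0; v < vocab_size; ++v) {
            out[v] = std::exp(row[v] - m);
            sum += out[v];
        }
        // prob - one_hot, scaled by this sample's incoming loss gradient.
        for (std::size_t v = 0; v < vocab_size; ++v) {
            const float prob = out[v] / sum;
            out[v] = loss_grad[n] * (prob - (v == t ? 1.0f : 0.0f));
        }
    }
    return grad;
}
```

For four equal logits with target 2 and loss_grad 1, the sketch yields 0.25 for the non-target classes and -0.75 for the target, so the row sums to 0 as the derivation predicts.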
build() [inline, override, virtual]
Build the operation for a concrete input shape.
This is the COLD PATH: all setup, validation, and one-time precomputation happen here, ONCE.
Expected logits shape: [batch_size, seq_length, vocab_size] or [batch_size, vocab_size].
Expected targets shape: [batch_size, seq_length] or [batch_size].
Responsibilities:
Reimplemented from Mila::Dnn::Compute::Operation< TDeviceType, TPrecision >.
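Both accepted layouts can be reduced to one flat view of num_rows = batch_size * seq_length rows of vocab_size classes. The sketch below is hypothetical: FlatShape, flatten_softmax_xent_shapes, and the expected_vocab check against the configured vocab_size are illustrative names and assumptions, not Mila API, and it only mirrors the shape rules listed above.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical flat view used by a reference kernel: num_rows independent rows.
struct FlatShape {
    std::size_t batch_size;
    std::size_t seq_len;     // 1 for the 2-D logits layout [batch, vocab]
    std::size_t vocab_size;
    std::size_t num_rows() const { return batch_size * seq_len; }
};

// Accepts logits [batch, seq, vocab] with targets [batch, seq],
// or logits [batch, vocab] with targets [batch]; anything else throws.
// expected_vocab stands in for the config's vocab_size (assumed check).
FlatShape flatten_softmax_xent_shapes(const std::vector<std::size_t>& logits_shape,
                                      const std::vector<std::size_t>& targets_shape,
                                      std::size_t expected_vocab) {
    assert(expected_vocab > 0 && "vocab_size is required in the config");
    FlatShape s{};
    if (logits_shape.size() == 3 && targets_shape.size() == 2 &&
        targets_shape[0] == logits_shape[0] && targets_shape[1] == logits_shape[1]) {
        s = {logits_shape[0], logits_shape[1], logits_shape[2]};
    } else if (logits_shape.size() == 2 && targets_shape.size() == 1 &&
               targets_shape[0] == logits_shape[0]) {
        s = {logits_shape[0], 1, logits_shape[1]};
    } else {
        throw std::invalid_argument("logits/targets shapes do not match a supported layout");
    }
    if (s.vocab_size != expected_vocab) {
        throw std::invalid_argument("logits last dimension differs from configured vocab_size");
    }
    return s;
}
```

Doing this once at build time is consistent with the COLD PATH / HOT PATH split described here: forward and backward can then iterate over num_rows without re-validating shapes.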
forward() [inline, override, virtual]
Forward pass - HOT PATH, computes fused softmax+cross-entropy loss.
Computes: loss = -log(softmax(logits)[target])
Algorithm (numerically stable): For each sample:
Parameters
| logits | Logits tensor [batch, seq, vocab] |
| targets | Targets tensor [batch, seq] (integer class indices) |
| output | Loss tensor (per-sample losses [batch, seq]) |
Implements Mila::Dnn::Compute::BinaryOperation< DeviceType::Cuda, TPrecision, TPrecision, TensorDataType::INT32 >.
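A minimal CPU sketch of the numerically stable forward computation, assuming the logits are flattened to num_rows = batch * seq rows of vocab_size values. The name softmax_xent_forward_ref and the std::vector interface are hypothetical; the actual operation works on ITensor inputs on the device.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference for the fused softmax + cross-entropy forward pass.
// logits:  num_rows x vocab_size, row-major.
// targets: num_rows integer class indices in [0, vocab_size).
// Returns per-sample losses: loss[n] = -log(softmax(logits[n])[targets[n]]).
std::vector<float> softmax_xent_forward_ref(const std::vector<float>& logits,
                                            const std::vector<int>& targets,
                                            std::size_t num_rows,
                                            std::size_t vocab_size) {
    assert(logits.size() == num_rows * vocab_size);
    assert(targets.size() == num_rows);
    std::vector<float> loss(num_rows);
    for (std::size_t n = 0; n < num_rows; ++n) {
        const float* row = logits.data() + n * vocab_size;
        const int t = targets[n];
        assert(t >= 0 && static_cast<std::size_t>(t) < vocab_size);

        // 1. Row maximum, so every exponent below is <= 0.
        const float m = *std::max_element(row, row + vocab_size);

        // 2. Sum of shifted exponentials (>= 1, since one term is exp(0)).
        float sum = 0.0f;
        for (std::size_t v = 0; v < vocab_size; ++v) sum += std::exp(row[v] - m);

        // 3. loss = log(sum_j exp(z_j)) - z_t = m + log(sum) - z_t.
        loss[n] = m + std::log(sum) - row[t];
    }
    return loss;
}
```

Shifting by the row maximum means a row such as {1000, 0} yields losses of 0 (target 0) and 1000 (target 1) instead of inf or NaN; a naive exp(1000.0f) overflows in single precision.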
getConfig() [inline]
getName() [inline, override, virtual]
Human-readable operation name.
Implements Mila::Dnn::Compute::Operation< TDeviceType, TPrecision >.
getOperationType() [inline, override, virtual]
Operation type identifier.
Implements Mila::Dnn::Compute::Operation< TDeviceType, TPrecision >.
setParameters() [inline, override, virtual]
Bind optional class weights parameter.
Parameters
| class_weights | Optional class weights tensor (may be null) |
| bias | Unused (must be null) |
Reimplemented from Mila::Dnn::Compute::Operation< TDeviceType, TPrecision >.
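These docs do not say how class weights enter the loss. A common convention, assumed here, is to scale each sample's loss by the weight of its target class. apply_class_weights is a hypothetical helper, and an empty weight vector stands in for a null class_weights pointer.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical post-processing step: scale each per-sample loss by the weight of
// its target class. An empty weight vector stands in for class_weights == nullptr
// (unweighted loss). Whether Mila also rescales the gradient or normalizes by the
// summed weights is not stated in these docs, so neither is modeled here.
void apply_class_weights(std::vector<float>& per_sample_loss,
                         const std::vector<int>& targets,
                         const std::vector<float>& class_weights) {
    assert(per_sample_loss.size() == targets.size());
    if (class_weights.empty()) return;  // no weights bound: leave the loss unchanged
    for (std::size_t n = 0; n < per_sample_loss.size(); ++n) {
        const int t = targets[n];
        assert(t >= 0 && static_cast<std::size_t>(t) < class_weights.size());
        per_sample_loss[n] *= class_weights[static_cast<std::size_t>(t)];
    }
}
```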
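Member Data Documentation
cached_batch_size_ [private]
cached_probs_ [mutable, private]
cached_seq_len_ [private]
cached_stream_ [private]
cached_vocab_size_ [private]
class_weights_ [private]
config_ [private]
context_ [private]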