Mila 0.13.48
Deep Neural Network Library
Fused SoftmaxCrossEntropy loss module (device-templated).


Public Types | |
| using | ExecutionContextType = ExecutionContext<TDeviceType> |
| using | MR = typename DeviceTypeTraits<TDeviceType>::memory_resource |
| using | TargetTensorType = Tensor<TTargets, MR> |
| using | TensorType = Tensor<TPrecision, MR> |
Public Member Functions | |
| SoftmaxCrossEntropy (IExecutionContext *exec_context, const CrossEntropyConfig &config) | |
| Construct with an existing execution context. | |
| ~SoftmaxCrossEntropy () override=default | |
| void | backward (const ITensor &logits, const ITensor &targets, const ITensor &output_grad, ITensor &logits_grad) |
| Backward pass - delegates to backend operation. | |
| void | forward (const ITensor &logits, const ITensor &targets, ITensor &output) |
| Forward pass - delegates to backend operation. | |
| const CrossEntropyConfig & | getConfig () const noexcept |
| DeviceId | getDeviceId () const override |
| Get the compute device id associated with this component. | |
| std::vector< ITensor * > | getGradients () const override |
| Return non-owning pointers to parameter gradient tensors. | |
| std::vector< ITensor * > | getParameters () const override |
| Return non-owning pointers to parameter tensors. | |
| int64_t | getVocabSize () const |
| void | onBuilding (const shape_t &input_shape) override |
| Build the module using an input shape. | |
| size_t | parameterCount () const override |
| Return number of trainable parameters. | |
| void | save_ (ModelArchive &archive, SerializationMode mode) const override |
| void | synchronize () override |
| Wait for outstanding device work submitted by this component. | |
| std::string | toString () const override |
| Produce a short, human-readable description of the component. | |
| Public Member Functions inherited from Mila::Dnn::Component< TDeviceType, TLogits > | |
| Component (const std::string &name) | |
| Construct component with required name identifier. | |
| virtual | ~Component ()=default |
| virtual void | build (const BuildContext &context) final |
| Build the component with the provided BuildContext (canonical overload). | |
| virtual MemoryStats | getMemoryStats () const=0 |
| Return the current memory allocation breakdown for this component. | |
| const std::string | getName () const |
| Get the component's name identifier. | |
| virtual std::vector< std::string > | getParameterNames () const |
| List all available parameter names for this component. | |
| RuntimeMode | getRuntimeMode () const noexcept |
| Convenience accessor — true if currently in Eval mode. | |
| TrainingMode | getTrainingMode () const noexcept |
| The current runtime behavioral mode of this Component. | |
| virtual const ComponentType | getType () const=0 |
| Get the component type identifier. | |
| virtual bool | isBuilt () const final |
| Returns true if build() has completed successfully. | |
| bool | isInferenceMode () const noexcept |
| bool | isTrainingMode () const noexcept |
| virtual void | loadParameter (const std::string &name, const Serialization::ITensorBlob &blob) |
| Load a parameter from serialized tensor data. | |
| void | setTrainingMode (TrainingMode mode) |
| Set the runtime behavioral mode for this Component. | |
| virtual void | zeroGradients () |
| Clear all model-owned gradients for this component. | |
Protected Member Functions | |
| void | onTrainingChanging (bool newMode) override |
| Hook invoked when training mode is about to change. | |
| Protected Member Functions inherited from Mila::Dnn::Component< TDeviceType, TLogits > | |
| IExecutionContext * | getExecutionContext () const |
| Get the shared execution context. | |
| bool | hasExecutionContext () const noexcept |
| Check if execution context has been set. | |
| void | loadParameterFromBlob (const std::string &param_name, const Serialization::ITensorBlob &blob, Tensor< TParameterPrecision, TMemoryResource > &target, const shape_t &expected_shape) |
| Load a tensor blob into a parameter tensor with validation. | |
| virtual void | onBuilding (const BuildContext &config) |
| Hook invoked by build() to allocate component buffers. | |
| virtual void | onExecutionContextSet () |
| Lifecycle hook: Called immediately after ExecutionContext is set. | |
| virtual void | onTrainingModeChanging (TrainingMode mode) |
| Hook called before TrainingMode transitions. | |
| void | setExecutionContext (IExecutionContext *context) |
| Set the execution context for this component. | |
Private Member Functions | |
| void | createOperation () |
| Create the backend compute operation. | |
| void | validateInputShape (const ITensor &input) const |
| Validate input shape for fused softmax+cross-entropy operation. | |
| void | validateInputShape (const shape_t &input_shape) const |
| Validate input shape for fused softmax+cross-entropy operation. | |
Private Attributes | |
| CrossEntropyConfig | config_ |
| std::shared_ptr< TargetTensorType > | dummy_target_grad_ { nullptr } |
| IExecutionContext * | exec_context_ { nullptr } |
| std::unique_ptr< BinaryOperation< TDeviceType, TLogits, TTargets, TPrecision > > | operation_ { nullptr } |
Additional Inherited Members | |
| Static Public Member Functions inherited from Mila::Dnn::Component< TDeviceType, TLogits > | |
| static constexpr DeviceType | getDeviceType () |
| Compile-time device type for this component instance. | |
| static constexpr TensorDataType | getPrecision () noexcept |
| Compile-time tensor precision for this component instance. | |
| Protected Attributes inherited from Mila::Dnn::Component< TDeviceType, TLogits > | |
| BuildContext | build_context_ |
| The BuildContext stored at build time. | |
Fused SoftmaxCrossEntropy loss module (device-templated).
Delegates computation to a device-specific BinaryOperation implementation (taking logits and targets) registered in the OperationRegistry.
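The delegation pattern can be illustrated with a small, self-contained sketch. Everything in it is hypothetical: `Device`, `FusedLossOp`, `CpuFusedLossOp` and `OpRegistry` are illustrative stand-ins, not Mila's OperationRegistry API. It assumes backends are looked up by an (operation name, device) key and instantiated through a stored factory.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>

// Hypothetical stand-ins; these are NOT Mila's real types.
enum class Device { Cpu, Cuda };

struct FusedLossOp {
    virtual ~FusedLossOp() = default;
    virtual std::string name() const = 0;
};

struct CpuFusedLossOp : FusedLossOp {
    std::string name() const override { return "cpu_softmax_xent"; }
};

class OpRegistry {
public:
    using Factory = std::function<std::unique_ptr<FusedLossOp>()>;

    // Associate a factory with an (operation name, device) key.
    void add(const std::string& op, Device dev, Factory f) {
        assert(f && "factory must be callable");
        factories_[{op, dev}] = std::move(f);
    }

    // Look up and instantiate the backend registered for the requested device.
    std::unique_ptr<FusedLossOp> create(const std::string& op, Device dev) const {
        auto it = factories_.find({op, dev});
        if (it == factories_.end())
            throw std::runtime_error("no backend registered for " + op);
        return it->second();
    }

private:
    std::map<std::pair<std::string, Device>, Factory> factories_;
};
```

A module built on this pattern would call `create()` once (in Mila's case, presumably from `createOperation()`) and keep the returned object, so per-step forward and backward calls pay no lookup cost.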
| SoftmaxCrossEntropy (IExecutionContext *exec_context, const CrossEntropyConfig &config) | inline explicit export |
Construct with an existing execution context.
| exec_context | Shared execution context for device resources. |
| config | CrossEntropy configuration (vocab_size required). |

| ~SoftmaxCrossEntropy () | override export default |
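| void backward (const ITensor &logits, const ITensor &targets, const ITensor &output_grad, ITensor &logits_grad) | inline export |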
Backward pass - delegates to backend operation.
Computes the fused gradient: dL/dlogits = softmax(logits) - one_hot(targets).
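For a single row of logits z with integer target t, the gradient stated above can be computed directly, as in the following sketch. It is a plain CPU reference written for illustration, not the library's backend kernel; the function name `softmaxXentGrad` is hypothetical, and multiplying by an upstream scalar `dloss` (standing in for `output_grad`) is an assumption about how the chain rule is applied. Any batch-mean scaling the real operation may apply is omitted.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical reference: gradient of L = -log softmax(z)[t] w.r.t. ONE row z,
//   dL/dz_i = softmax(z)_i - [i == t],
// scaled by an assumed upstream scalar gradient dloss (chain rule).
std::vector<float> softmaxXentGrad(const std::vector<float>& logits,
                                   std::size_t target, float dloss = 1.0f) {
    assert(!logits.empty() && target < logits.size());
    // Subtract the max logit so exp() cannot overflow.
    const float m = *std::max_element(logits.begin(), logits.end());
    std::vector<float> grad(logits.size());
    float sum = 0.0f;
    for (std::size_t i = 0; i < logits.size(); ++i) {
        grad[i] = std::exp(logits[i] - m);
        sum += grad[i];
    }
    for (std::size_t i = 0; i < logits.size(); ++i) {
        const float p = grad[i] / sum;                     // softmax probability
        const float onehot = (i == target) ? 1.0f : 0.0f;  // one_hot(targets)
        grad[i] = dloss * (p - onehot);
    }
    return grad;
}
```

Because the softmax probabilities sum to 1 and the one-hot vector also sums to 1, each gradient row sums to (approximately) zero, which makes a convenient sanity check.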

| void createOperation () | inline export private |
Create the backend compute operation.


| void forward (const ITensor &logits, const ITensor &targets, ITensor &output) | inline export |
Forward pass - delegates to backend operation.
Computes the fused softmax + cross-entropy loss.
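Per row, the fused loss reduces to L = logsumexp(z) - z[t], which avoids materializing the softmax and stays finite for large logits. The sketch below is a hypothetical CPU reference, not the backend operation; the function names are illustrative, and the mean reduction over rows is an assumption, since this page does not state which reduction the module applies.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Cross-entropy of ONE row of logits against an integer class index,
// computed as L = logsumexp(z) - z[t] with the max logit subtracted first.
float softmaxXentLoss(const std::vector<float>& logits, std::size_t target) {
    assert(!logits.empty() && target < logits.size());
    const float m = *std::max_element(logits.begin(), logits.end());
    float sum = 0.0f;
    for (float z : logits) sum += std::exp(z - m);
    const float logSumExp = m + std::log(sum);
    return logSumExp - logits[target];
}

// Mean loss over a batch of rows (an ASSUMED reduction; the real module
// may reduce differently).
float meanSoftmaxXentLoss(const std::vector<std::vector<float>>& rows,
                          const std::vector<std::size_t>& targets) {
    assert(!rows.empty() && rows.size() == targets.size());
    float total = 0.0f;
    for (std::size_t n = 0; n < rows.size(); ++n)
        total += softmaxXentLoss(rows[n], targets[n]);
    return total / static_cast<float>(rows.size());
}
```

Fusing the two steps this way is also what makes the backward formula so simple: the derivative of logsumexp(z) is exactly softmax(z).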

| const CrossEntropyConfig & getConfig () const | inline export noexcept |
| DeviceId getDeviceId () const | inline override export virtual |
Get the compute device id associated with this component.
Must return the device on which parameters and operations execute.
Implements Mila::Dnn::Component< TDeviceType, TLogits >.
| std::vector< ITensor * > getGradients () const | inline override export virtual |
Return non-owning pointers to parameter gradient tensors.
Only valid when isTrainingMode() is true.
| std::runtime_error | if called when not in training mode or before the component has been built. |
Implements Mila::Dnn::Component< TDeviceType, TLogits >.
| std::vector< ITensor * > getParameters () const | inline override export virtual |
Return non-owning pointers to parameter tensors.
The returned tensor pointers remain valid for the lifetime of the component. Order should be canonical (weights before biases).
Implements Mila::Dnn::Component< TDeviceType, TLogits >.
| int64_t getVocabSize () const | inline export |
| void onBuilding (const shape_t &input_shape) | inline override export |
Build the module using an input shape.
Validates input shape and triggers backend-specific setup. The fused operation has no trainable parameters.
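The exact shape rules enforced during building are not spelled out on this page. The sketch below only illustrates the kind of check involved, under explicitly assumed rules: logits of rank at least 2 shaped [..., V] with V equal to the configured vocabulary size (compare `getVocabSize()`), and targets carrying the logits shape minus its class dimension. `Shape`, `checkLogitsShape` and `targetsMatch` are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

using Shape = std::vector<std::int64_t>;  // hypothetical stand-in for shape_t

// ASSUMED rule: logits are [..., V] with rank >= 2 and V == vocab_size.
void checkLogitsShape(const Shape& logits, std::int64_t vocabSize) {
    assert(vocabSize > 0);
    if (logits.size() < 2)
        throw std::invalid_argument("logits need rank >= 2, got rank " +
                                    std::to_string(logits.size()));
    if (logits.back() != vocabSize)
        throw std::invalid_argument("last logits dim " + std::to_string(logits.back()) +
                                    " != vocab_size " + std::to_string(vocabSize));
}

// ASSUMED rule: targets match the logits shape with the class dim dropped.
bool targetsMatch(const Shape& logits, const Shape& targets) {
    return targets.size() + 1 == logits.size() &&
           std::equal(targets.begin(), targets.end(), logits.begin());
}
```

Validating once at build time keeps the per-step forward and backward paths free of repeated shape checks.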

| void onTrainingChanging (bool newMode) | inline override export protected |
Hook invoked when training mode is about to change.
Propagates the training mode to the backend fused operation. Called with the component's training mutex held; do not call setTrainingMode() here.
| size_t parameterCount () const | inline override export virtual |
Return number of trainable parameters.
For leaf components this is the element count of owned parameter tensors. CompositeComponent and Network implementations should return the recursive aggregate across all children.
Implements Mila::Dnn::Component< TDeviceType, TLogits >.
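The leaf-versus-composite contract described for `parameterCount()` amounts to a recursive sum. In this hypothetical sketch (`Node`, `Leaf`, `Composite` and `numel` are illustrative, not Mila types), a leaf totals the element counts of its own parameter shapes and a composite sums over its children. A parameter-free leaf, such as this fused loss, contributes 0.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

using Shape = std::vector<std::int64_t>;  // hypothetical shape type

// Element count of one tensor shape (product of its dimensions).
std::size_t numel(const Shape& s) {
    std::size_t n = 1;
    for (std::int64_t d : s) {
        assert(d >= 0);
        n *= static_cast<std::size_t>(d);
    }
    return n;
}

// Hypothetical stand-ins for a leaf and a composite component.
struct Node {
    virtual ~Node() = default;
    virtual std::size_t parameterCount() const = 0;
};

struct Leaf : Node {
    std::vector<Shape> params;  // shapes of owned parameter tensors
    explicit Leaf(std::vector<Shape> p) : params(std::move(p)) {}
    std::size_t parameterCount() const override {
        std::size_t n = 0;
        for (const auto& s : params) n += numel(s);
        return n;
    }
};

struct Composite : Node {
    std::vector<std::unique_ptr<Node>> children;
    std::size_t parameterCount() const override {
        std::size_t n = 0;
        for (const auto& c : children) n += c->parameterCount();  // recursive aggregate
        return n;
    }
};
```

For example, a composite holding a linear leaf with a [4, 3] weight and a [3] bias plus a parameter-free loss leaf reports 15.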
| void save_ (ModelArchive &archive, SerializationMode mode) const | inline override export virtual |
Implements Mila::Dnn::Component< TDeviceType, TLogits >.
| void synchronize () | inline override export virtual |
Wait for outstanding device work submitted by this component.
On CPU this may be a no-op. Use to ensure results are visible to the host or to measure synchronous timings.
Implements Mila::Dnn::Component< TDeviceType, TLogits >.
| std::string toString () const | inline override export virtual |
Produce a short, human-readable description of the component.
Implementations should keep output concise and avoid throwing.
Implements Mila::Dnn::Component< TDeviceType, TLogits >.

| validateInputShape() [1/2] | inline export private |
Validate input shape for fused softmax+cross-entropy operation.


| validateInputShape() [2/2] | inline export private |
Validate input shape for fused softmax+cross-entropy operation.
Expected shapes:
