Mila 0.13.48
Deep Neural Network Library
Loading...
Searching...
No Matches
/__w/Mila/Mila/Mila/Src/Dnn/Components/Normalization/Softmax.ixx

Construct Softmax with optional ExecutionContext ownership.

Construct Softmax with optional ExecutionContext ownership. Supports two construction modes:

Standalone mode (device_id provided):

  • Creates and owns an ExecutionContext for the specified device.
  • Registers the owned context with the base Component class via setExecutionContext().
  • Backend operation is created immediately in onExecutionContextSet() hook.
  • Use case: Unit tests, standalone component usage.

Shared mode (device_id not provided):

  • Does not create ExecutionContext; expects parent to provide one.
  • Parent (Network/CompositeComponent) calls setExecutionContext() after construction.
  • Backend operation created when parent sets context.
  • Use case: Components added to Network via addComponent<Softmax>(...).
Parameters
name: Component name, forwarded to the base Component class.
config: Softmax configuration (axis and name).
device_id: Optional device identifier. If provided, creates an owned ExecutionContext for standalone mode. If nullopt, expects a shared context from the parent.
Exceptions
std::invalid_argument: if config is invalid (via config.validate()).
std::invalid_argument: if device_id.type does not match TDeviceType.
std::runtime_error: if ExecutionContext creation fails (standalone mode).
std::runtime_error: if backend operation creation fails in onExecutionContextSet().
Note
In standalone mode, setExecutionContext() is called to register the owned context with the base class, enabling getExecutionContext() and triggering the onExecutionContextSet() hook for operation creation.

// Standalone mode (owns context)
SoftmaxConfig config;
config.withAxis(-1);
Softmax<DeviceType::Cpu, TensorDataType::FP32> softmax("softmax", config, Device::Cpu());

module;
#include <memory>
#include <vector>
#include <string>
#include <sstream>
#include <stdexcept>
#include <cstdint>
#include <type_traits>
#include <optional>
export module Dnn.Components.Softmax;
import Dnn.Tensor;
import Dnn.ITensor;
namespace Mila::Dnn
{
using namespace Mila::Dnn::Compute;
using namespace Mila::Dnn::Serialization;
/**
 * @brief Softmax component: a stateless layer that delegates its computation
 *        to a device-specific backend operation.
 *
 * The component has no trainable parameters (parameterCount() is 0 and the
 * parameter/gradient lists are empty). The backend operation is created as
 * soon as an execution context is registered, see onExecutionContextSet().
 *
 * @tparam TDeviceType Device the component runs on.
 * @tparam TPrecision  Tensor data type; must satisfy
 *                     PrecisionSupportedOnDevice for TDeviceType.
 */
export template<DeviceType TDeviceType, TensorDataType TPrecision>
    requires PrecisionSupportedOnDevice<TPrecision, TDeviceType>
class Softmax : public Component<TDeviceType, TPrecision>
{
public:
    using ComponentBase = Component<TDeviceType, TPrecision>;
    using TensorType = Tensor<TPrecision, MR>;

    /**
     * @brief Construct a Softmax, optionally owning its execution context.
     *
     * When @p device_id holds a value (standalone mode) an execution context
     * is created for that device, kept alive by this object and registered
     * through setExecutionContext(). When it is empty (shared mode) no
     * context is created here; one has to be supplied later by the owner.
     *
     * @param name      Component name, forwarded to the base class.
     * @param config    Softmax configuration; copied and validated.
     * @param device_id Optional device to create an owned context for.
     *
     * @throws std::invalid_argument if device_id->type differs from
     *         TDeviceType. config.validate() may also throw.
     */
    explicit Softmax( const std::string& name, const SoftmaxConfig& config, std::optional<DeviceId> device_id = std::nullopt )
        : ComponentBase( name ), config_( config )
    {
        config_.validate();

        if ( device_id.has_value() )
        {
            if ( device_id->type != TDeviceType )
            {
                throw std::invalid_argument( "Softmax: device type mismatch" );
            }

            owned_exec_context_ = createExecutionContext( device_id.value() );
            this->setExecutionContext( owned_exec_context_.get() );
        }
    }

    ~Softmax() override = default;

    // ====================================================================
    // Compute operation dispatch
    // ====================================================================

    /**
     * @brief Run the forward pass through the backend operation.
     *
     * @param input  Input tensor.
     * @param output Tensor receiving the result.
     * @throws std::runtime_error if the component has not been built.
     */
    void forward( const ITensor& input, ITensor& output )
    {
        if ( !this->isBuilt() )
        {
            throw std::runtime_error( "Softmax module must be built before calling forward." );
        }

        operation_->forward( input, output );
    }

    /**
     * @brief Run the backward pass through the backend operation.
     *
     * @param input       Input tensor of the corresponding forward call.
     * @param output_grad Gradient with respect to the output.
     * @param input_grad  Tensor receiving the gradient with respect to the input.
     * @throws std::runtime_error if the component has not been built or is
     *         not in training mode.
     */
    void backward( const ITensor& input, const ITensor& output_grad, ITensor& input_grad )
    {
        if ( !this->isBuilt() )
        {
            throw std::runtime_error( "Softmax module must be built before calling backward." );
        }

        if ( !this->isTraining() )
        {
            throw std::runtime_error( "Softmax module must be in training mode to call backward. Call setTraining(true) first." );
        }

        operation_->backward( input, output_grad, input_grad );
    }

    // ====================================================================
    // Synchronization
    // ====================================================================

    /**
     * @brief Synchronize on the component's execution context.
     *
     * NOTE(review): dereferences getExecutionContext() unchecked; presumably
     * callers only invoke this after a context has been set - confirm.
     */
    void synchronize() override
    {
        this->getExecutionContext()->synchronize();
    }

    // ====================================================================
    // Serialization
    // ====================================================================

    /**
     * @brief Serialization hook; writes nothing.
     *
     * Both arguments are deliberately ignored.
     */
    void save_( ModelArchive& archive, SerializationMode mode ) const override
    {
        (void)archive;
        (void)mode;
    }

    // ====================================================================
    // Parameters and Gradients
    // ====================================================================

    /** @brief Number of trainable parameters; always 0. */
    size_t parameterCount() const override
    {
        return 0;
    }

    /** @brief Trainable parameter tensors; always empty. */
    std::vector<ITensor*> getParameters() const override
    {
        return {};
    }

    /** @brief Parameter gradient tensors; always empty. */
    std::vector<ITensor*> getGradients() const override
    {
        return {};
    }

    // ====================================================================
    // Component interface
    // ====================================================================

    /**
     * @brief Identify this component as a Softmax.
     *
     * The body was previously empty, so control flowed off the end of a
     * value-returning function (undefined behaviour).
     */
    const ComponentType getType() const override
    {
        return ComponentType::Softmax;
    }

    /**
     * @brief Device identifier reported by the execution context.
     *
     * NOTE(review): dereferences getExecutionContext() unchecked; presumably
     * callers only invoke this after a context has been set - confirm.
     */
    DeviceId getDeviceId() const override
    {
        return this->getExecutionContext()->getDeviceId();
    }

    /** @brief Memory statistics; returns a value-initialized MemoryStats. */
    MemoryStats getMemoryStats() const override
    {
        return {};
    }

    /** @brief Multi-line, human-readable summary (name, device, axis). */
    std::string toString() const override
    {
        std::ostringstream oss;

        // '\n' instead of std::endl: flushing a string stream has no purpose.
        oss << "--------------------" << '\n';
        oss << "Softmax: " << this->getName() << '\n';
        oss << "Device: " << deviceTypeToString( this->getDeviceType() ) << '\n';
        oss << "Axis: " << config_.getAxis() << '\n';
        oss << "Parameter count: 0 (stateless)" << '\n';

        return oss.str();
    }

    // ====================================================================
    // Configuration accessors
    // ====================================================================

    /** @brief Configured softmax axis (may be negative). */
    int64_t getAxis() const noexcept
    {
        return config_.getAxis();
    }

    /*const SoftmaxConfig& getConfig() const noexcept
    {
    return config_;
    }*/

protected:
    // ====================================================================
    // Lifecycle hooks
    // ====================================================================

    /** @brief Creates the backend operation once a context is available. */
    void onExecutionContextSet() override
    {
        createOperation();
    }

    /**
     * @brief Validate the input shape and build the backend operation.
     *
     * @param build_config Build context supplying the input shape.
     * @throws std::invalid_argument if the shape is empty or the configured
     *         axis is out of range for it.
     * @throws std::runtime_error if no backend operation exists yet, i.e.
     *         no execution context has been set.
     */
    void onBuilding( const BuildContext& build_config ) override
    {
        const auto& input_shape = build_config.inputShape();
        validateInputShape( input_shape );

        // In shared mode the operation only exists after the owner has
        // provided an execution context; fail clearly instead of
        // dereferencing a null pointer.
        if ( !operation_ )
        {
            throw std::runtime_error( "Softmax: execution context must be set before building." );
        }

        // Softmax has no weights, so null parameters are passed.
        // NOTE(review): presumably the operation interface expects
        // setParameters() before build() - confirm whether this is needed.
        operation_->setParameters( nullptr, nullptr );
        operation_->build( build_config );
    }

    /**
     * @brief Forward a training-mode change to the backend operation.
     *
     * NOTE(review): operation_ is null until an execution context has been
     * set; presumably the base class guarantees that ordering - confirm.
     */
    void onTrainingModeChanging( TrainingMode training_mode ) override
    {
        operation_->setTrainingMode( training_mode );
    }

private:
    SoftmaxConfig config_;

    // Set only in standalone mode (device_id passed to the constructor).
    std::unique_ptr<IExecutionContext> owned_exec_context_{ nullptr };

    // Backend operation; created in createOperation().
    std::shared_ptr<OpType> operation_{ nullptr };

    /** @brief Validate the shape of @p input; see the shape_t overload. */
    void validateInputShape( const ITensor& input ) const
    {
        const auto& input_shape = input.shape();
        validateInputShape( input_shape );
    }

    /**
     * @brief Check that the configured axis addresses a dimension of the shape.
     *
     * A negative axis counts from the end (rank is added to it).
     *
     * @throws std::invalid_argument if the shape is empty or the resolved
     *         axis is outside [0, rank).
     */
    void validateInputShape( const shape_t& input_shape ) const
    {
        if ( input_shape.empty() )
        {
            throw std::invalid_argument( "Softmax: input must have rank >= 1" );
        }

        const int64_t ndim = static_cast<int64_t>( input_shape.size() );
        const int64_t configured_axis = config_.getAxis();
        const int64_t axis = configured_axis < 0 ? configured_axis + ndim : configured_axis;

        if ( axis < 0 || axis >= ndim )
        {
            throw std::invalid_argument( "Softmax: axis out of bounds for input shape" );
        }
    }

    /**
     * @brief Instantiate the backend operation for the current context.
     *
     * std::make_shared never returns a null pointer (failure is reported by
     * an exception), so no null check is needed afterwards.
     */
    void createOperation()
    {
        operation_ = std::make_shared<OpType>( this->getExecutionContext(), config_ );
    }
};
}
Build-time context for Component::build().
Definition Component.BuildContext.ixx:56
const shape_t & inputShape() const noexcept
The full input shape this component receives.
Definition Component.BuildContext.ixx:113
Abstract base class for neural network components.
Definition Component.ixx:155
virtual void synchronize()=0
Synchronize all pending operations.
virtual DeviceId getDeviceId() const noexcept=0
Get the device identifier.
Abstract interface providing essential tensor information and data access.
Definition ITensor.ixx:40
virtual const shape_t & shape() const =0
Get the tensor dimensional structure.
ModelArchive provides high-level helpers for component serialization.
Definition ModelArchive.ixx:47
Configuration class for Softmax module.
Definition SoftmaxConfig.ixx:25
void validate() const override
Validate configuration parameters.
Definition SoftmaxConfig.ixx:53
int64_t getAxis() const noexcept
Get the configured axis value.
Definition SoftmaxConfig.ixx:46
Device-aware N-dimensional tensor.
Definition Tensor.ixx:138
Definition CudaSoftmaxOp.ixx:38
Definition Device.ixx:15
std::unique_ptr< IExecutionContext > createExecutionContext(DeviceId device_id)
Create execution context for specified device.
Definition ExecutionContextFactory.ixx:23
std::string deviceTypeToString(DeviceType device_type)
Converts a DeviceType to its string representation.
Definition DeviceType.ixx:38
Definition ArchiveSerializer.ixx:19
SerializationMode
Modes for serialization and deserialization.
Definition SerializationMode.ixx:17
Definition ActivationType.ixx:13
TrainingMode
Runtime behavioral state for Components built with RuntimeMode::Training.
Definition Comonent.TrainingMode.ixx:39
ComponentType
Canonical list of framework-known component types.
Definition ComponentType.ixx:29
@ Softmax
Definition ComponentType.ixx:40
TensorShape shape_t
Row-major shape descriptor for tensor dimensional sizes.
Definition Tensor.Types.ixx:143
std::string toString(ComponentType t) noexcept
Convert a ComponentType enum value to its canonical name.
Definition ComponentType.ixx:79
Lightweight identifier for a compute device.
Definition DeviceId.ixx:38
Definition DeviceTypeTraits.ixx:12
Global memory statistics for all TrackedMemoryResource instances.
Definition MemoryResourceTracker.ixx:19
Primary traits template for unified compile-time operation dispatch.
Definition OperationTraits.Template.ixx:45
size_t size() const
Definition Tensor.Types.ixx:109
bool empty() const
Definition Tensor.Types.ixx:110