Mila 0.13.48
Deep Neural Network Library
Loading...
Searching...
No Matches
/__w/Mila/Mila/Mila/Src/Dnn/Components/Activations/Gelu/Gelu.ixx

Construct GELU activation component with optional ExecutionContext ownership.

Construct GELU activation component with optional ExecutionContext ownership. Supports two construction modes:

Standalone mode (device_id provided):

  • Creates and owns an ExecutionContext for the specified device.
  • Registers the owned context with the base Component class via setExecutionContext().
  • Backend operation is created immediately in onExecutionContextSet() hook.
  • Use case: Unit tests, standalone component usage.

Shared mode (device_id not provided):

  • Does not create ExecutionContext; expects parent to provide one.
  • Parent (Network/CompositeComponent) calls setExecutionContext() after construction.
  • Backend operation created when parent sets context.
  • Use case: Components added to Network via addComponent<Gelu>(...).
Parameters
name — Component name identifier (mandatory).
config — GELU configuration (approximation method).
device_id — Optional device identifier. If provided, an owned ExecutionContext is created for standalone mode. If std::nullopt, a shared context is expected from the parent.
Exceptions
std::invalid_argument — if config is invalid (via config.validate()).
std::invalid_argument — if device_id->type does not match TDeviceType.
std::runtime_error — if ExecutionContext creation fails (standalone mode).
std::runtime_error — if backend operation creation fails in onExecutionContextSet().
Note
In standalone mode, setExecutionContext() is called to register the owned context with the base class, enabling getExecutionContext() and triggering the onExecutionContextSet() hook for operation creation.

// Standalone mode (owns context)
GeluConfig config;
Gelu<DeviceType::Cpu, TensorDataType::FP32> gelu( "gelu", config, Device::Cpu() );

module;
#include <memory>
#include <vector>
#include <string>
#include <sstream>
#include <type_traits>
#include <stdexcept>
#include <format>
#include <utility>
#include <optional>
#include <algorithm>
export module Dnn.Components.Gelu;
import Dnn.Tensor;
import Dnn.ITensor;
namespace Mila::Dnn
{
using namespace Mila::Dnn::Compute;
using namespace Mila::Dnn::Serialization;
export template<DeviceType TDeviceType, TensorDataType TPrecision>
    requires PrecisionSupportedOnDevice<TPrecision, TDeviceType>
class Gelu : public Component<TDeviceType, TPrecision>
{
public:
    // NOTE(review): `MR` (the tensor memory-resource type) and `OpType` (the backend
    // GELU operation type) are not declared anywhere in this listing; confirm they are
    // brought into scope (e.g. via device/operation traits aliases) in the real file.
    using TensorType = Tensor<TPrecision, MR>;
    using ComponentBase = Component<TDeviceType, TPrecision>;

    /**
     * @brief Construct a GELU activation component.
     *
     * Standalone mode (@p device_id provided): creates and owns an ExecutionContext
     * for the device and registers it via setExecutionContext(), which triggers
     * onExecutionContextSet() and therefore backend operation creation.
     *
     * Shared mode (@p device_id is std::nullopt): no context is created; a parent
     * is expected to call setExecutionContext() after construction.
     *
     * @param name      Component name identifier.
     * @param config    GELU configuration (approximation method). Validated here.
     * @param device_id Optional device for standalone mode.
     * @throws std::invalid_argument if the config is invalid or the device type
     *         does not match TDeviceType.
     */
    explicit Gelu( const std::string& name, const GeluConfig& config, std::optional<DeviceId> device_id = std::nullopt )
        : ComponentBase( name ), config_( config )
    {
        config_.validate();

        if ( device_id.has_value() )
        {
            if ( device_id->type != TDeviceType )
            {
                throw std::invalid_argument( "Gelu: device type mismatch" );
            }

            owned_exec_context_ = createExecutionContext( device_id.value() );
            this->setExecutionContext( owned_exec_context_.get() );
        }
    }

    ~Gelu() override = default;

    // ====================================================================
    // Computation Dispatch
    // ====================================================================

    /**
     * @brief Run the GELU forward pass.
     *
     * Writes the result into a view of the component-owned output buffer. The view
     * is re-created only when @p input's shape differs from the current view shape.
     *
     * @param input Input tensor.
     * @return Reference to the component-owned output view; its contents are
     *         overwritten by the next forward() call.
     * @throws std::runtime_error if the component has not been built.
     */
    TensorType& forward( const TensorType& input )
    {
        if ( !this->isBuilt() )
        {
            throw std::runtime_error( "Gelu::forward: component must be built before forward pass" );
        }

        const auto& input_shape = input.shape();

        // Re-slice the owned buffer only when the shape changes; otherwise reuse the view.
        if ( output_view_->shape() != input_shape )
        {
            output_view_.emplace( output_->view( input_shape ) );
        }

        operation_->forward( input, *output_view_ );

        return *output_view_;
    }

    /**
     * @brief Run the GELU backward pass.
     *
     * @param input       The input that was passed to forward().
     * @param output_grad Gradient with respect to the forward output.
     * @return Reference to the component-owned input-gradient buffer.
     * @throws std::runtime_error if the component has not been built, or if it was
     *         built without a training-mode build context (no gradient buffer).
     */
    TensorType& backward( const TensorType& input, const TensorType& output_grad )
    {
        if ( !this->isBuilt() )
        {
            throw std::runtime_error( "Gelu::backward: component must be built before backward pass" );
        }

        // REVIEW
        /*if ( !this->isTraining() )
        {
            throw std::runtime_error( "Gelu::backward: component must be in training mode to compute gradients" );
        }*/

        // onBuilding() allocates input_grad_ only when the build context is in training
        // mode; fail loudly instead of dereferencing a null buffer.
        if ( !input_grad_ )
        {
            throw std::runtime_error( "Gelu::backward: input gradient buffer not allocated (component must be built in training mode)" );
        }

        // Zero input gradient buffer before backward pass. No exceptions.
        // Backend ops use accumulation (atomicAdd/+=) which requires pre-zeroed buffers
        // to prevent gradient buildup across calls. Without this, gradients grow linearly
        // with each call -> explosion.
        zero( *input_grad_ );

        operation_->backward( input, output_grad, *input_grad_ );

        return *input_grad_;
    }

    /**
     * @brief Block until pending work on the execution context has completed.
     * @throws std::runtime_error if no execution context has been set.
     */
    void synchronize() override
    {
        auto exec_context = this->getExecutionContext();

        if ( !exec_context )
        {
            throw std::runtime_error( "Gelu::synchronize: no execution context set" );
        }

        exec_context->synchronize();
    }

    /** @brief Approximation method from the component configuration. */
    ApproximationMethod getApproximationMethod() const
    {
        return config_.getApproximationMethod();
    }

    // ====================================================================
    // Serialization
    // ====================================================================

    /**
     * @brief Write component metadata ("meta.json") and configuration ("config.json").
     *
     * Gelu has no parameters, so only metadata is written and @p mode is unused.
     */
    void save_( ModelArchive& archive, SerializationMode mode ) const override
    {
        (void)mode;

        SerializationMetadata meta;
        meta.set( "type", "Gelu" )
            .set( "version", int64_t( 1 ) )
            .set( "name", this->getName() )
            .set( "template_device", deviceTypeToString( TDeviceType ) )
            .set( "template_precision", static_cast<int64_t>(TPrecision) );

        archive.writeMetadata( "meta.json", meta );

        SerializationMetadata cfg;
        cfg.set( "approximation_method",
            static_cast<int64_t>(config_.getApproximationMethod()) );

        archive.writeMetadata( "config.json", cfg );
    }

    /**
     * @brief Reconstruct a Gelu from an archive written by save_().
     *
     * Validates version, type, template device and precision, then rebuilds the
     * configuration. The returned component is created in shared mode (no owned
     * context).
     *
     * NOTE(review): @p exec_context is currently not used here; confirm the caller
     * attaches the execution context after construction.
     *
     * @throws std::runtime_error wrapping any validation or read failure.
     */
    static std::unique_ptr<Gelu> fromArchive_(
        ModelArchive& archive,
        const std::string& component_name,
        IExecutionContext* exec_context )
    {
        try
        {
            SerializationMetadata meta = archive.readMetadata( "meta.json" );
            validateMetadata_( meta, component_name );

            SerializationMetadata cfg = archive.readMetadata( "config.json" );

            GeluConfig config;
            auto approx_method = static_cast<ApproximationMethod>(
                cfg.getInt( "approximation_method" ));

            config.withApproximationMethod( approx_method );
            config.validate();

            return std::make_unique<Gelu>( component_name, config );
        }
        catch ( const std::exception& e )
        {
            throw std::runtime_error(
                std::format( "Gelu::fromArchive: error for '{}': {}",
                    component_name, e.what() )
            );
        }
    }

    // ====================================================================
    // Parameters and Gradients
    // ====================================================================

    /** @brief GELU is parameter-free. */
    size_t parameterCount() const override
    {
        return 0;
    }

    /** @brief GELU has no trainable parameters. */
    std::vector<ITensor*> getParameters() const override
    {
        return {};
    }

    /** @brief GELU has no parameter gradients. */
    std::vector<ITensor*> getGradients() const override
    {
        return {};
    }

    // ====================================================================
    // Identification and Description
    // ====================================================================

    /** @brief Canonical component type identifier. */
    const ComponentType getType() const override
    {
        return ComponentType::Gelu;
    }

    // ====================================================================
    // State and Configuration
    // ====================================================================

    /**
     * @brief Device of the attached execution context.
     * @throws std::runtime_error if no execution context has been set.
     */
    DeviceId getDeviceId() const override
    {
        auto exec_context = this->getExecutionContext();

        if ( !exec_context )
        {
            throw std::runtime_error( "Gelu::getDeviceId: no execution context set" );
        }

        return exec_context->getDeviceId();
    }

    /** @brief Storage used by the owned output and input-gradient buffers (if allocated). */
    MemoryStats getMemoryStats() const override
    {
        MemoryStats stats;

        if ( output_ != nullptr )
        {
            stats.device_state_bytes += output_->getStorageSize();
        }

        if ( input_grad_ != nullptr )
        {
            stats.device_gradient_bytes += input_grad_->getStorageSize();
        }

        return stats;
    }

    /** @brief Human-readable summary of the component. */
    std::string toString() const override
    {
        std::ostringstream oss;
        oss << "--------------------" << '\n';
        oss << "Gelu: " << this->getName() << '\n';
        oss << "Device: " << deviceTypeToString( this->getDeviceType() ) << '\n';
        // FIXME: oss << "Approximation Method: " << config_.toString( config_.getApproximationMethod() ) << std::endl;

        return oss.str();
    }

protected:

    /** @brief Hook invoked when an execution context is attached; creates the backend op. */
    void onExecutionContextSet() override
    {
        createOperation();
    }

    /**
     * @brief Build the backend operation and allocate owned buffers.
     *
     * The output buffer is always allocated at the build-context input shape; the
     * input-gradient buffer only when the build context is in training mode.
     *
     * @throws std::runtime_error if no backend operation exists (no execution context set).
     */
    void onBuilding( const BuildContext& build_context ) override
    {
        if ( !operation_ )
        {
            throw std::runtime_error(
                std::format( "Gelu::build: no backend operation for component '{}'; execution context must be set before build",
                    this->getName() )
            );
        }

        operation_->build( build_context );

        const auto& input_shape = build_context.inputShape();

        // Allocate owned output and input-gradient tensors with device binding.
        // Buffers are owned by this component and reused across calls.
        DeviceId dev_id = this->getExecutionContext()->getDeviceId();

        output_ = std::make_unique<TensorType>( dev_id, input_shape, this->getName() + ".output" );
        output_view_.emplace( output_->view( input_shape ) );

        if ( build_context.isTrainingMode() )
        {
            input_grad_ = std::make_unique<TensorType>( dev_id, input_shape, this->getName() + ".input_grad" );
            zero( *input_grad_ );
        }
    }

    /** @brief Forward training-mode changes to the backend operation. */
    void onTrainingModeChanging( TrainingMode training_mode ) override
    {
        operation_->setTrainingMode( training_mode );
    }

private:
    GeluConfig config_;

    // Declared before operation_ and the buffers so it is destroyed after them.
    std::unique_ptr<IExecutionContext> owned_exec_context_{ nullptr };

    std::shared_ptr<OpType> operation_{ nullptr };

    std::unique_ptr<TensorType> output_{ nullptr };        // owned output buffer (build shape)
    std::optional<TensorType> output_view_;                // view of output_ at the current input shape
    std::unique_ptr<TensorType> input_grad_{ nullptr };    // allocated only for training-mode builds

    /**
     * @brief Validate archive metadata against this template instantiation.
     * @throws std::runtime_error on version, type, device or precision mismatch.
     */
    static void validateMetadata_( const SerializationMetadata& meta, const std::string& component_name )
    {
        int64_t version = meta.tryGetInt( "version" ).value_or( 0 );

        if ( version != 1 )
        {
            throw std::runtime_error(
                std::format( "Gelu: unsupported version {} for '{}'",
                    version, component_name )
            );
        }

        std::string type = meta.tryGetString( "type" ).value_or( "" );

        if ( type != "Gelu" )
        {
            throw std::runtime_error(
                std::format( "Gelu: type mismatch for '{}': expected 'Gelu', got '{}'",
                    component_name, type )
            );
        }

        std::string file_device = meta.tryGetString( "template_device" ).value_or( "" );
        int64_t file_precision = meta.tryGetInt( "template_precision" ).value_or( -1 );

        std::string expected_device = deviceTypeToString( TDeviceType );
        int64_t expected_precision = static_cast<int64_t>(TPrecision);

        if ( file_device != expected_device )
        {
            throw std::runtime_error(
                std::format( "Gelu: device mismatch for '{}': archive='{}', expected='{}'",
                    component_name, file_device, expected_device )
            );
        }

        if ( file_precision != expected_precision )
        {
            throw std::runtime_error(
                std::format( "Gelu: precision mismatch for '{}': archive={}, expected={}",
                    component_name, file_precision, expected_precision )
            );
        }
    }

    /**
     * @brief Create the backend GELU operation bound to the current execution context.
     * @throws std::runtime_error if the operation could not be created.
     */
    void createOperation()
    {
        operation_ = std::make_shared<OpType>( this->getExecutionContext(), config_ );

        // Defensive: std::make_shared reports failure by throwing, so this is not
        // expected to trigger.
        if ( !operation_ )
        {
            throw std::runtime_error(
                std::format( "Gelu: Failed to create compute backend operation for component '{}'",
                    this->getName() )
            );
        }
    }
};
}
Build-time context for Component::build().
Definition Component.BuildContext.ixx:56
bool isTrainingMode() const noexcept
True if output buffers should be allocated at full input shape sequence length with gradient buffers.
Definition Component.BuildContext.ixx:202
const shape_t & inputShape() const noexcept
The full input shape this component receives.
Definition Component.BuildContext.ixx:113
Abstract base class for neural network components.
Definition Component.ixx:155
Type-erased execution context interface.
Definition IExecutionContext.ixx:24
virtual void synchronize()=0
Synchronize all pending operations.
virtual DeviceId getDeviceId() const noexcept=0
Get the device identifier.
Configuration class for GELU module.
Definition Gelu.Config.ixx:31
void validate() const override
Validate configuration parameters.
Definition Gelu.Config.ixx:62
Self && withApproximationMethod(this Self &&self, ApproximationMethod method)
Configure the approximation method for GELU computation.
Definition Gelu.Config.ixx:44
ApproximationMethod getApproximationMethod() const
Get the configured approximation method.
Definition Gelu.Config.ixx:55
ModelArchive provides high-level helpers for component serialization.
Definition ModelArchive.ixx:47
SerializationMetadata readMetadata(const std::string &path) const
Definition ModelArchive.ixx:310
void writeMetadata(const std::string &path, const SerializationMetadata &metadata)
Definition ModelArchive.ixx:274
Type-safe metadata container for component serialization.
Definition SerializationMetadata.ixx:52
std::optional< int64_t > tryGetInt(const std::string &key) const noexcept
Get optional integer value.
Definition SerializationMetadata.ixx:454
int64_t getInt(const std::string &key) const
Get integer value.
Definition SerializationMetadata.ixx:287
std::optional< std::string > tryGetString(const std::string &key) const noexcept
Get optional string value.
Definition SerializationMetadata.ixx:432
SerializationMetadata & set(const std::string &key, MetadataValue value)
Set metadata value with automatic type deduction.
Definition SerializationMetadata.ixx:68
Device-aware N-dimensional tensor.
Definition Tensor.ixx:138
Definition CudaGeluOp.Dispatch.ixx:39
Definition Device.ixx:15
std::unique_ptr< IExecutionContext > createExecutionContext(DeviceId device_id)
Create execution context for specified device.
Definition ExecutionContextFactory.ixx:23
std::string deviceTypeToString(DeviceType device_type)
Converts a DeviceType to its string representation.
Definition DeviceType.ixx:38
Definition ArchiveSerializer.ixx:19
SerializationMode
Modes for serialization and deserialization.
Definition SerializationMode.ixx:17
Definition ActivationType.ixx:13
TrainingMode
Runtime behavioral state for Components built with RuntimeMode::Training.
Definition Comonent.TrainingMode.ixx:39
ApproximationMethod
Approximation methods usable by activation functions.
Definition ApproximationMethod.ixx:18
ComponentType
Canonical list of framework-known component types.
Definition ComponentType.ixx:29
@ Gelu
Definition ComponentType.ixx:36
std::string toString(ComponentType t) noexcept
Convert a ComponentType enum value to its canonical name.
Definition ComponentType.ixx:79
void zero(Tensor< TDataType, TMemoryResource > &tensor, IExecutionContext *exec_context=nullptr)
Zero a tensor using the fastest backend implementation.
Definition TensorOps.Zero.ixx:42
Lightweight identifier for a compute device.
Definition DeviceId.ixx:38
Definition DeviceTypeTraits.ixx:12
Global memory statistics for all TrackedMemoryResource instances.
Definition MemoryResourceTracker.ixx:19
Primary traits template for unified compile-time operation dispatch.
Definition OperationTraits.Template.ixx:45