Construct GELU activation component with optional ExecutionContext ownership.
Construct GELU activation component with optional ExecutionContext ownership. Supports two construction modes:
// Standalone mode (owns context):
//   GeluConfig config;
//   Gelu<DeviceType::Cpu, TensorDataType::FP32> gelu("gelu", config, Device::Cpu());
module;
#include <memory>
#include <vector>
#include <string>
#include <sstream>
#include <type_traits>
#include <stdexcept>
#include <format>
#include <utility>
#include <optional>
#include <algorithm>
{
export template<DeviceType TDeviceType, TensorDataType TPrecision>
requires PrecisionSupportedOnDevice<TPrecision, TDeviceType>
{
public:
explicit Gelu(
const std::string& name,
const GeluConfig& config, std::optional<DeviceId> device_id = std::nullopt )
: ComponentBase( name ), config_( config )
{
if ( device_id.has_value() )
{
if ( device_id->type != TDeviceType )
{
throw std::invalid_argument( "Gelu: device type mismatch" );
}
this->setExecutionContext( owned_exec_context_.get() );
}
}
/// @brief Defaulted destructor; overrides the base-class destructor.
~Gelu() override = default;
/**
 * @brief Run the forward pass of the operation on @p input.
 *
 * @param input Input tensor; its shape selects the output view.
 * @return Reference to the output view held by this component.
 *
 * @throws std::runtime_error if the component has not been built.
 */
TensorType& forward( const TensorType& input )
{
    if ( !this->isBuilt() )
    {
        throw std::runtime_error( "Gelu::forward: component must be built before forward pass" );
    }

    // Only re-create the view when the incoming shape differs from the current one.
    if ( const auto& shape = input.shape(); output_view_->shape() != shape )
    {
        output_view_.emplace( output_->view( shape ) );
    }

    auto& result = *output_view_;
    operation_->forward( input, result );

    return result;
}
/**
 * @brief Run the backward pass of the operation.
 *
 * @param input       Input tensor passed to the operation's backward().
 * @param output_grad Gradient with respect to the output.
 * @return Reference to the input-gradient tensor owned by this component.
 *
 * @throws std::runtime_error if the component has not been built, or if the
 *         input-gradient buffer has not been allocated.
 */
TensorType& backward( const TensorType& input, const TensorType& output_grad )
{
    if ( !this->isBuilt() )
    {
        throw std::runtime_error( "Gelu::backward: component must be built before backward pass" );
    }

    // input_grad_ is a nullable unique_ptr; dereferencing it unchecked would be
    // undefined behaviour, so fail with a diagnosable error instead.
    if ( input_grad_ == nullptr )
    {
        throw std::runtime_error( "Gelu::backward: input gradient buffer is not allocated" );
    }

    operation_->backward( input, output_grad, *input_grad_ );

    return *input_grad_;
}
// Override of synchronize(). The body is empty as shown here, so calling it
// performs no work.
// NOTE(review): IExecutionContext exposes synchronize(); confirm whether this
// override is meant to forward to the component's execution context.
void synchronize() override
{
}
{
}
{
(void)mode;
meta.
set(
"type",
"Gelu" )
.
set(
"version", int64_t( 1 ) )
.
set(
"name", this->getName() )
.
set(
"template_precision",
static_cast<int64_t
>(TPrecision) );
cfg.
set(
"approximation_method",
}
// Builds a Gelu instance named `component_name` from archived metadata.
// Metadata is checked with validateMetadata_() first; any std::exception
// raised inside the try block is rethrown as std::runtime_error with the
// component name and the original message.
// NOTE(review): part of the parameter list and the declarations of `meta`,
// `cfg` and `config` are not visible in this view — confirm against the full file.
static std::unique_ptr<Gelu> fromArchive_(
const std::string& component_name,
{
try
{
// Rejects archives whose version, type, device or precision do not match.
validateMetadata_( meta, component_name );
// Reads the stored "approximation_method" integer (statement is only partly visible here).
cfg.
getInt(
"approximation_method" ));
// Constructed without a device id, so the constructor's default std::nullopt applies.
return std::make_unique<Gelu>( component_name, config );
}
catch ( const std::exception& e )
{
throw std::runtime_error(
std::format( "Gelu::fromArchive: error for '{}': {}",
component_name, e.what() )
);
}
}
/// @brief Number of parameters reported by this component.
/// @return Always 0.
size_t parameterCount() const override
{
    return size_t{ 0 };
}
/// @brief Parameter tensors exposed by this component.
/// @return Always an empty vector.
std::vector<ITensor*> getParameters() const override
{
    return std::vector<ITensor*>{};
}
/// @brief Parameter-gradient tensors exposed by this component.
/// @return Always an empty vector.
std::vector<ITensor*> getGradients() const override
{
    return std::vector<ITensor*>{};
}
{
}
{
}
{
// Adds the storage size of each buffer that is currently allocated to `stats`
// and returns it.
// NOTE(review): the signature and the declaration of `stats` are not visible
// in this view — confirm against the full file.
if ( output_ != nullptr )
{
// Output buffer is counted under device_state_bytes.
stats.device_state_bytes += output_->getStorageSize();
}
if ( input_grad_ != nullptr )
{
// Input-gradient buffer is counted under device_gradient_bytes.
stats.device_gradient_bytes += input_grad_->getStorageSize();
}
return stats;
}
{
// Builds a two-line text description: a separator line followed by
// "Gelu: <name>", and returns it as a std::string.
// NOTE(review): the signature is not visible in this view.
std::ostringstream oss;
oss << "--------------------" << std::endl;
oss << "Gelu: " << this->getName() << std::endl;
return oss.str();
}
protected:
void onExecutionContextSet() override
{
createOperation();
}
/**
 * @brief Build hook: builds the operation and allocates the component's buffers.
 *
 * Calls build() on the operation, then allocates the output tensor and the
 * input-gradient tensor with the build context's input shape, on the device
 * reported by the execution context. The output view is initialised over that
 * same shape.
 *
 * @param build_context Build-time context supplying the input shape.
 */
void onBuilding( const BuildContext& build_context ) override
{
    operation_->build( build_context );

    const auto& input_shape = build_context.inputShape();

    // Tensors are created on the device of the execution context this component uses.
    const auto dev_id = this->getExecutionContext()->getDeviceId();

    output_ = std::make_unique<TensorType>( dev_id, input_shape, this->getName() + ".output" );
    output_view_.emplace( output_->view( input_shape ) );

    // NOTE(review): the original wrapped this allocation in a bare scope, which
    // suggests a condition (possibly build_context.isTrainingMode()) may belong
    // here — confirm. The allocation is kept unconditional to preserve behaviour.
    input_grad_ = std::make_unique<TensorType>( dev_id, input_shape, this->getName() + ".input_grad" );
}
/**
 * @brief Hook override that forwards the training mode to the operation.
 *
 * @param training_mode Mode passed on to the operation's setTrainingMode().
 */
void onTrainingModeChanging( TrainingMode training_mode ) override
{
    auto& op = *operation_;
    op.setTrainingMode( training_mode );
}
private:
std::unique_ptr<IExecutionContext> owned_exec_context_{ nullptr };
std::shared_ptr<OpType> operation_{ nullptr };
std::unique_ptr<TensorType> output_{ nullptr };
std::optional<TensorType> output_view_;
std::unique_ptr<TensorType> input_grad_{ nullptr };
{
// Validates archived metadata against this instantiation. Each failed check
// throws std::runtime_error naming `component_name`.
// NOTE(review): the signature is not visible in this view; the name is taken
// from the call validateMetadata_( meta, component_name ) elsewhere in the file.
// A missing "version" key is read as 0 and is therefore rejected below.
int64_t version = meta.
tryGetInt(
"version" ).value_or( 0 );
if ( version != 1 )
{
throw std::runtime_error(
std::format( "Gelu: unsupported version {} for '{}'",
version, component_name )
);
}
// A missing "type" key is read as "" and is therefore rejected below.
std::string type = meta.
tryGetString(
"type" ).value_or(
"" );
if ( type != "Gelu" )
{
throw std::runtime_error(
std::format( "Gelu: type mismatch for '{}': expected 'Gelu', got '{}'",
component_name, type )
);
}
// Template arguments recorded in the archive; defaults are "" and -1 when absent.
std::string file_device = meta.
tryGetString(
"template_device" ).value_or(
"" );
int64_t file_precision = meta.
tryGetInt(
"template_precision" ).value_or( -1 );
int64_t expected_precision = static_cast<int64_t>(TPrecision);
// NOTE(review): `expected_device` is not declared in the visible code —
// presumably derived from TDeviceType; confirm against the full file.
if ( file_device != expected_device )
{
throw std::runtime_error(
std::format( "Gelu: device mismatch for '{}': archive='{}', expected='{}'",
component_name, file_device, expected_device )
);
}
if ( file_precision != expected_precision )
{
throw std::runtime_error(
std::format( "Gelu: precision mismatch for '{}': archive={}, expected={}",
component_name, file_precision, expected_precision )
);
}
}
/**
 * @brief Create the operation from the current execution context and config_.
 *
 * The result is stored in operation_. Any exception thrown by the allocation
 * or by the OpType constructor propagates to the caller unchanged.
 */
void createOperation()
{
    // std::make_shared either returns a non-null pointer or throws, so a
    // null check on the result would be unreachable and is omitted.
    operation_ = std::make_shared<OpType>( this->getExecutionContext(), config_ );
}
};
}
Build-time context for Component::build().
Definition Component.BuildContext.ixx:56
bool isTrainingMode() const noexcept
True if output buffers should be allocated at full input shape sequence length with gradient buffers.
Definition Component.BuildContext.ixx:202
const shape_t & inputShape() const noexcept
The full input shape this component receives.
Definition Component.BuildContext.ixx:113
Abstract base class for neural network components.
Definition Component.ixx:155
Type-erased execution context interface.
Definition IExecutionContext.ixx:24
virtual void synchronize()=0
Synchronize all pending operations.
virtual DeviceId getDeviceId() const noexcept=0
Get the device identifier.
Configuration class for GELU module.
Definition Gelu.Config.ixx:31
void validate() const override
Validate configuration parameters.
Definition Gelu.Config.ixx:62
Self && withApproximationMethod(this Self &&self, ApproximationMethod method)
Configure the approximation method for GELU computation.
Definition Gelu.Config.ixx:44
ApproximationMethod getApproximationMethod() const
Get the configured approximation method.
Definition Gelu.Config.ixx:55
ModelArchive provides high-level helpers for component serialization.
Definition ModelArchive.ixx:47
SerializationMetadata readMetadata(const std::string &path) const
Definition ModelArchive.ixx:310
void writeMetadata(const std::string &path, const SerializationMetadata &metadata)
Definition ModelArchive.ixx:274
Device-aware N-dimensional tensor.
Definition Tensor.ixx:138
Definition CudaGeluOp.Dispatch.ixx:39
std::unique_ptr< IExecutionContext > createExecutionContext(DeviceId device_id)
Create execution context for specified device.
Definition ExecutionContextFactory.ixx:23
std::string deviceTypeToString(DeviceType device_type)
Converts a DeviceType to its string representation.
Definition DeviceType.ixx:38
Definition ArchiveSerializer.ixx:19
SerializationMode
Modes for serialization and deserialization.
Definition SerializationMode.ixx:17
Definition ActivationType.ixx:13
TrainingMode
Runtime behavioral state for Components built with RuntimeMode::Training.
Definition Comonent.TrainingMode.ixx:39
ApproximationMethod
Approximation methods usable by activation functions.
Definition ApproximationMethod.ixx:18
ComponentType
Canonical list of framework-known component types.
Definition ComponentType.ixx:29
@ Gelu
Definition ComponentType.ixx:36
std::string toString(ComponentType t) noexcept
Convert a ComponentType enum value to its canonical name.
Definition ComponentType.ixx:79
void zero(Tensor< TDataType, TMemoryResource > &tensor, IExecutionContext *exec_context=nullptr)
Zero a tensor using the fastest backend implementation.
Definition TensorOps.Zero.ixx:42
Lightweight identifier for a compute device.
Definition DeviceId.ixx:38
Definition DeviceTypeTraits.ixx:12
Global memory statistics for all TrackedMemoryResource instances.
Definition MemoryResourceTracker.ixx:19
Primary traits template for unified compile-time operation dispatch.
Definition OperationTraits.Template.ixx:45