Mila
Deep Neural Network Library
Gaussian Error Linear Unit (GELU) activation function module.
Public Types

using ModuleBase = Module< TDeviceType, TDataType, TDataType >
    Alias for base module type.
using MR = std::conditional_t< TDeviceType == DeviceType::Cuda, CudaMemoryResource, CpuMemoryResource >
    Memory resource type determined based on device type.
Public Types inherited from Module

using MR = std::conditional_t< TDeviceType == DeviceType::Cuda, CudaMemoryResource, CpuMemoryResource >
Public Member Functions

Gelu (const std::string &device_name, const GeluConfig &config)
    Constructs a Gelu module using device name and configuration.
Gelu (std::shared_ptr< DeviceContext > device_context, const GeluConfig &config)
    Constructs a Gelu module with an existing device context and configuration.
void backward (const Tensor< TDataType, MR > &input, const Tensor< TDataType, MR > &output_grad, Tensor< TDataType, MR > &input_grad)
    Performs backward propagation, computing gradients for the GELU activation.
void forward (const Tensor< TDataType, MR > &input, Tensor< TDataType, MR > &output)
    Performs forward propagation through the GELU activation function.
GeluConfig::ApproximationMethod getApproximationMethod () const
    Returns the current approximation method used by this GELU instance.
void load (ModelArchive &archive) override
    Deserializes module state from a ZIP archive.
size_t parameterCount () const override
    Returns the number of trainable parameters in this module.
void save (ModelArchive &zip) const override
    Serializes module state to a ZIP archive.
std::string toString () const override
    Generates a string representation of this module's configuration.
Public Member Functions inherited from Module

Module (const std::string &device_name, const ComponentConfig &config)
    Constructor with device name.
Module (std::shared_ptr< DeviceContext > context, const ComponentConfig &config)
    Constructor with a specific device context.
virtual ~Module ()=default
    Virtual destructor for proper cleanup in derived classes.
std::shared_ptr< Compute::DeviceContext > getDeviceContext () const
    Get the device context for this module.
Compute::DeviceType getDeviceType () const
    Get the device type of the current device context.
std::string getName () const
    Get the name of the module.
const auto & getParameterTensors () const
    Get the parameter tensors of this module.
const ComputePrecision::Policy & getPrecision () const
    Get the compute precision policy of this module.
const auto & getStateTensors () const
    Get the state tensors of this module.
bool isTraining () const
    Check if the module is in training mode.
virtual void setTraining (bool is_training)
    Set the training mode of this module.
Private Member Functions

void createOperation ()
    Initializes the appropriate GELU operation implementation.
Static Private Member Functions

static std::string approximationMethodToString (GeluConfig::ApproximationMethod method)
    Converts approximation method enum to human-readable string.
Private Attributes

GeluConfig config_
    Configuration for the GELU module.
std::shared_ptr< UnaryOperation< TDeviceType, TDataType, TDataType > > operation_ { nullptr }
    The underlying computational operation that implements GELU.
std::vector< std::shared_ptr< Tensor< TDataType, MR > > > output_state_
    Output state cache for backward propagation.
std::vector< std::shared_ptr< Tensor< TDataType, MR > > > parameters_
    Parameter tensors for the operation.
OperationAttributes properties_
    Additional attributes for operation customization.
Additional Inherited Members

Member functions inherited from Module:

const std::string parametersToString () const
    Helper method to convert parameters to string representation.
const std::string stateToString () const
    Helper method to convert state tensors to string representation.

Attributes inherited from Module:

std::unordered_map< std::string, std::shared_ptr< Tensor< TOutput, MR > > > parameter_map_ = {}
    Map of parameter names to parameter tensors.
std::unordered_map< std::string, std::shared_ptr< Tensor< TOutput, MR > > > state_map_ = {}
    Map of state names to state tensors.
Detailed Description

Gaussian Error Linear Unit (GELU) activation function module.
GELU is defined mathematically as: GELU(x) = x * Φ(x)
where Φ(x) is the cumulative distribution function of the standard normal distribution.
Three approximation methods are supported (configured via GeluConfig); these are typically the exact erf-based form, the tanh approximation, and the sigmoid approximation (see the formulations after the template parameters below).
Note: Currently only the Tanh approximation is fully supported in the implementation.
Template Parameters:
    TDeviceType: Computing device type (CPU or CUDA)
    TDataType: Floating-point data type for computations (e.g., float, half)
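
For reference, the standard formulations of these three variants as published in the GELU literature (the constants are the commonly cited values; whether Mila's kernels use exactly these is not confirmed by this page):

    \mathrm{GELU}(x) = x\,\Phi(x) = \frac{x}{2}\Bigl(1 + \operatorname{erf}\bigl(x/\sqrt{2}\bigr)\Bigr)

    \mathrm{GELU}_{\tanh}(x) \approx \frac{x}{2}\Bigl(1 + \tanh\bigl(\sqrt{2/\pi}\,(x + 0.044715\,x^{3})\bigr)\Bigr)

    \mathrm{GELU}_{\sigma}(x) \approx x\,\sigma(1.702\,x)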
ModuleBase [export]
Alias for base module type.
MR [export]
Memory resource type determined based on device type.
Automatically selects appropriate memory resource (CPU or CUDA) based on TDeviceType.
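
The selection is a standard std::conditional_t dispatch. A minimal self-contained sketch, using stand-in types since only the alias itself appears on this page:

    #include <type_traits>

    enum class DeviceType { Cpu, Cuda };   // stand-in for Mila's Compute::DeviceType
    struct CpuMemoryResource {};           // stand-ins for the library's resource types
    struct CudaMemoryResource {};

    template <DeviceType TDeviceType>
    using MR = std::conditional_t<TDeviceType == DeviceType::Cuda,
                                  CudaMemoryResource,
                                  CpuMemoryResource>;

    static_assert(std::is_same_v<MR<DeviceType::Cuda>, CudaMemoryResource>);
    static_assert(std::is_same_v<MR<DeviceType::Cpu>, CpuMemoryResource>);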
Gelu() [1/2] [inline, explicit, export]
Constructs a Gelu module using device name and configuration.
Creates a new DeviceContext internally using the provided device name. This constructor is useful for creating standalone modules without pre-existing device contexts.
Parameters:
    device_name: Device identifier string (e.g., "cpu", "cuda:0")
    config: Configuration parameters for the GELU module

Exceptions:
    std::invalid_argument: If the device name is invalid or the configuration is invalid
    std::runtime_error: If the device type doesn't match the template parameter TDeviceType
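
A usage sketch under stated assumptions: the class template takes <TDeviceType, TDataType> in that order and GeluConfig is default-constructible; neither is confirmed by this page.

    using namespace Mila::Dnn;

    GeluConfig config;  // assumed default-constructible; defaults not shown on this page

    // Template arguments follow the TDeviceType / TDataType parameters documented above.
    Gelu<Compute::DeviceType::Cuda, float> gelu("cuda:0", config);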
Gelu() [2/2] [inline, explicit, export]
Constructs a Gelu module with an existing device context and configuration.
Uses a pre-existing DeviceContext instance. This constructor is useful when integrating the module into a larger network that shares device contexts across modules.
Parameters:
    device_context: Shared pointer to an existing device context
    config: Configuration parameters for the GELU module

Exceptions:
    std::invalid_argument: If device_context is null or the configuration is invalid
    std::runtime_error: If the device context type doesn't match the template parameter TDeviceType
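
A sketch of the shared-context pattern this constructor enables; the DeviceContext constructor signature is an assumption:

    #include <memory>

    using namespace Mila::Dnn;

    // One context shared across modules on the same device
    // (DeviceContext constructor arguments assumed).
    auto ctx = std::make_shared<Compute::DeviceContext>("cuda:0");

    Gelu<Compute::DeviceType::Cuda, float> gelu_a(ctx, GeluConfig{});
    Gelu<Compute::DeviceType::Cuda, float> gelu_b(ctx, GeluConfig{});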
approximationMethodToString() [inline, static, export, private]
Converts approximation method enum to human-readable string.
Parameters:
    method: The approximation method to convert
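
A plausible shape for this helper; the enumerator names are assumptions consistent with the class description (only Tanh is actually implied by this page):

    #include <string>

    // Hypothetical enumerators: Exact, Tanh, Sigmoid.
    static std::string approximationMethodToString(GeluConfig::ApproximationMethod method) {
        switch (method) {
            case GeluConfig::ApproximationMethod::Exact:   return "Exact";
            case GeluConfig::ApproximationMethod::Tanh:    return "Tanh";
            case GeluConfig::ApproximationMethod::Sigmoid: return "Sigmoid";
        }
        return "Unknown";  // defensive default for values outside the enum
    }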
backward() [inline, export]
Performs backward propagation, computing gradients for GELU activation.
Computes the gradient of the GELU function with respect to its inputs, which is needed for training via backpropagation.
The GELU derivative is: d/dx GELU(x) = Φ(x) + x * Φ'(x)
where Φ'(x) is the derivative of the CDF (the PDF of the standard normal distribution).
Parameters:
    input: Original input tensor from the forward pass
    output_grad: Gradient tensor from the next layer (∂L/∂output)
    input_grad: Output tensor to store the computed gradients (∂L/∂input)
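
The derivative above is just the product rule applied to the exact definition GELU(x) = x * Φ(x):

    \frac{d}{dx}\,\mathrm{GELU}(x) = \frac{d}{dx}\bigl(x\,\Phi(x)\bigr) = \Phi(x) + x\,\Phi'(x),
    \qquad \Phi'(x) = \frac{1}{\sqrt{2\pi}}\,e^{-x^{2}/2}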
createOperation() [inline, export, private]
Initializes the appropriate GELU operation implementation.
Creates the device-specific operation implementation based on the template parameter TDeviceType and registers it with the operation registry.
The operation choice is determined at compile-time via constexpr branching.
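
A minimal sketch of what such compile-time branching typically looks like; the concrete operation class names are hypothetical, not the library's confirmed API:

    // Member-function sketch inside Gelu<TDeviceType, TDataType>;
    // CudaGeluOp / CpuGeluOp are hypothetical names for the device-specific operations.
    void createOperation() {
        if constexpr (TDeviceType == DeviceType::Cuda) {
            operation_ = std::make_shared<CudaGeluOp<TDataType>>(config_);
        } else {
            operation_ = std::make_shared<CpuGeluOp<TDataType>>(config_);
        }
    }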
forward() [inline, export]
Performs forward propagation through the GELU activation function.
Applies the GELU transformation element-wise to each value in the input tensor. The specific approximation method used is determined by the GeluConfig setting.
Parameters:
    input: Input tensor to transform
    output: Tensor where results will be stored (must be pre-allocated with matching dimensions)
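
A hedged end-to-end sketch; the Tensor shape constructor is an assumption:

    using namespace Mila::Dnn;

    Gelu<Compute::DeviceType::Cpu, float> gelu("cpu", GeluConfig{});

    Tensor<float, CpuMemoryResource> input({ 32, 768 });   // shape constructor assumed
    Tensor<float, CpuMemoryResource> output({ 32, 768 });  // pre-allocated, matching shape

    gelu.forward(input, output);  // element-wise GELU using the configured approximation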
getApproximationMethod() [inline, export]
Returns the current approximation method used by this GELU instance.
load() [inline, override, export, virtual]
Deserializes module state from a ZIP archive.
Implementation of the Module interface for deserialization. Since GELU has no learnable parameters, this is a no-op implementation.
Parameters:
    archive: ZIP archive for deserialization
Implements Mila::Dnn::Module< TDeviceType, TInput, TOutput >.
parameterCount() [inline, override, export, virtual]
Returns the number of trainable parameters in this module.
GELU is a parameterless activation function with no trainable weights.
Implements Mila::Dnn::Module< TDeviceType, TInput, TOutput >.
save() [inline, override, export, virtual]
Serializes module state to a ZIP archive.
Implementation of the Module interface for serialization. Since GELU has no learnable parameters, this is a no-op implementation.
Parameters:
    zip: ZIP archive for serialization
Implements Mila::Dnn::Module< TDeviceType, TInput, TOutput >.
toString() [inline, override, export, virtual]
Generates a string representation of this module's configuration.
Implements Mila::Dnn::Module< TDeviceType, TInput, TOutput >.
config_ [export, private]
Configuration for the GELU module.
Stores the settings that define how the GELU function should be computed, particularly which approximation method to use.
operation_ [export, private]
The underlying computational operation that implements GELU.
This pointer is initialized based on the device type and configuration, providing the device-specific implementation of the GELU function.
output_state_ [export, private]
Output state cache for backward propagation.
Stores intermediate results from the forward pass that may be needed during backward propagation to efficiently compute gradients.
parameters_ [export, private]
Parameter tensors for the operation.
Empty for GELU since it has no trainable parameters, but required by the UnaryOperation interface.
properties_ [export, private]
Additional attributes for operation customization.
Holds configuration values that might be needed by specific implementations of the GELU operation.