Mila
Deep Neural Network Library
Mila::Dnn::Compute::CudaGeluOp< TDataType > Class Template Reference

CUDA implementation of the GELU activation function for neural networks. More...

Inheritance diagram for Mila::Dnn::Compute::CudaGeluOp< TDataType >:
Collaboration diagram for Mila::Dnn::Compute::CudaGeluOp< TDataType >:

Public Types

using MR = typename CudaDevice::MR
 
using UnaryOperationBase = UnaryOperation< DeviceType::Cuda, TDataType, TDataType >
 
- Public Types inherited from Mila::Dnn::Compute::UnaryOperation< DeviceType::Cuda, TDataType, TDataType >
using MR = std::conditional_t< TDeviceType==DeviceType::Cuda, CudaMemoryResource, HostMemoryResource >
 Memory resource type based on device type.
 

Public Member Functions

 CudaGeluOp (const GeluConfig &config)
 
 CudaGeluOp (std::shared_ptr< DeviceContext > context, const GeluConfig &config)
 
void backward (const Tensor< TDataType, MR > &input, const Tensor< TDataType, MR > &output, const Tensor< TDataType, MR > &output_gradient, const std::vector< std::shared_ptr< Tensor< TDataType, MR > > > &parameters, std::vector< std::shared_ptr< Tensor< TDataType, MR > > > &parameter_gradients, Tensor< TDataType, MR > &input_gradient, const OperationAttributes &properties, const std::vector< std::shared_ptr< Tensor< TDataType, MR > > > &output_state) const
 Performs the backward pass of the GELU activation function.
 
void forward (const Tensor< TDataType, MR > &input, const std::vector< std::shared_ptr< Tensor< TDataType, MR > > > &parameters, const OperationAttributes &properties, Tensor< TDataType, MR > &output, std::vector< std::shared_ptr< Tensor< TDataType, MR > > > &output_state) const override
 Performs the forward pass of the GELU activation function on CUDA.
 
const GeluConfig & getConfig () const
 
std::string getName () const override
 Gets the name of this operation.
 
- Public Member Functions inherited from Mila::Dnn::Compute::UnaryOperation< DeviceType::Cuda, TDataType, TDataType >
 UnaryOperation (OperationType operation_type)
 Constructs a UnaryOperation with the specified operation type.
 
 UnaryOperation (OperationType operation_type, std::shared_ptr< DeviceContext > context)
 Constructs a UnaryOperation with the specified operation type and device context.
 
virtual ~UnaryOperation ()=default
 Virtual destructor for proper cleanup of derived classes.
 
virtual void backward (const Tensor< TDataType, MR > &grad, const std::vector< std::shared_ptr< Tensor< TDataType, MR > > > &parameters, std::vector< std::shared_ptr< Tensor< TDataType, MR > > > &output_grads) const
 Executes the backward pass of a unary operation.
 
virtual void backward (const Tensor< TDataType, MR > &input, const Tensor< TDataType, MR > &output_grad, const std::vector< std::shared_ptr< Tensor< TDataType, MR > > > &parameters, std::vector< std::shared_ptr< Tensor< TDataType, MR > > > &parameter_grads, Tensor< TDataType, MR > &input_grad, const OperationAttributes &properties, const std::vector< std::shared_ptr< Tensor< TDataType, MR > > > &output_state) const
 Executes the comprehensive backward pass of a unary operation.
 
virtual void forward (const Tensor< TDataType, MR > &input, const std::vector< std::shared_ptr< Tensor< TDataType, MR > > > &parameters, const OperationAttributes &properties, Tensor< TDataType, MR > &output, std::vector< std::shared_ptr< Tensor< TDataType, MR > > > &output_state) const=0
 Executes the forward pass of a unary operation.
 
- Public Member Functions inherited from Mila::Dnn::Compute::OperationBase< TDeviceType, TInput1, TInput2, TOutput >
 OperationBase (OperationType operation_type, std::shared_ptr< DeviceContext > context)
 Constructs an OperationBase object with a specific device context and compute precision.
 
virtual ~OperationBase ()=default
 Virtual destructor for the OperationBase class.
 
std::shared_ptr< DeviceContext > getDeviceContext () const
 Gets the device context associated with this operation.
 
DeviceType getDeviceType () const
 Gets the device type for this operation.
 
OperationType getOperationType () const
 Gets the operation type enumeration value.
 

Private Attributes

GeluConfig config_
 Configuration for the GELU operation.
 
Detail::cuda_gelu_impl< TDataType > impl_
 Implementation details for the GELU operation.
 

Detailed Description

template<typename TDataType>
requires ValidFloatTensorType<TDataType>
class Mila::Dnn::Compute::CudaGeluOp< TDataType >

CUDA implementation of the GELU activation function for neural networks.

This class provides a CUDA-based implementation of the Gaussian Error Linear Unit (GELU) activation function, which is commonly used in transformer architectures. GELU is a smooth approximation of the ReLU function that applies a non-linear transformation to its input.

The implementation leverages CUDA for GPU acceleration, providing efficient computation for large neural network models. It also supports different precision modes via the ComputePrecision policy.
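As a concrete reference for the math, the exact GELU scales its input by the standard normal CDF. The sketch below is a host-side illustration only (the helper name is hypothetical); the class itself launches a CUDA kernel through Detail::cuda_gelu_impl.

```cpp
#include <cmath>

// Exact GELU definition: GELU(x) = x * Phi(x), where Phi is the
// standard normal CDF, expressed here via the error function.
// Illustrative host-side helper, not library code.
float gelu_exact(float x) {
    return 0.5f * x * (1.0f + std::erf(x / std::sqrt(2.0f)));
}
```

For large positive x this approaches x (ReLU-like), while near zero it stays smooth rather than kinked, which is why GELU is described as a smooth approximation of ReLU.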

Template Parameters
TDataType   The data type of the input and output tensor elements (must satisfy ValidFloatTensorType).

Member Typedef Documentation

◆ MR

template<typename TDataType >
using Mila::Dnn::Compute::CudaGeluOp< TDataType >::MR = typename CudaDevice::MR

◆ UnaryOperationBase

template<typename TDataType >
using Mila::Dnn::Compute::CudaGeluOp< TDataType >::UnaryOperationBase = UnaryOperation<DeviceType::Cuda, TDataType, TDataType>

Constructor & Destructor Documentation

◆ CudaGeluOp() [1/2]

template<typename TDataType >
Mila::Dnn::Compute::CudaGeluOp< TDataType >::CudaGeluOp ( const GeluConfig & config )
inline

◆ CudaGeluOp() [2/2]

template<typename TDataType >
Mila::Dnn::Compute::CudaGeluOp< TDataType >::CudaGeluOp ( std::shared_ptr< DeviceContext > context,
const GeluConfig & config 
)
inline

Member Function Documentation

◆ backward()

template<typename TDataType >
void Mila::Dnn::Compute::CudaGeluOp< TDataType >::backward ( const Tensor< TDataType, MR > &  input,
const Tensor< TDataType, MR > &  output,
const Tensor< TDataType, MR > &  output_gradient,
const std::vector< std::shared_ptr< Tensor< TDataType, MR > > > &  parameters,
std::vector< std::shared_ptr< Tensor< TDataType, MR > > > &  parameter_gradients,
Tensor< TDataType, MR > &  input_gradient,
const OperationAttributes & properties,
const std::vector< std::shared_ptr< Tensor< TDataType, MR > > > &  output_state 
) const
inline

Performs the backward pass of the GELU activation function.

Computes gradients with respect to inputs for the GELU function. The precision policy affects the computation in the same way as the forward pass.

Parameters
input   Input tensor from the forward pass.
output   Output tensor from the forward pass.
output_gradient   Gradient of the loss with respect to the output.
parameters   Parameters tensor from the forward pass (not used).
parameter_gradients   Gradients for parameters (not used).
input_gradient   Gradient of the loss with respect to the input.
properties   Additional attributes for the operation.
output_state   Cache tensors from the forward pass.
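For reference, the input gradient multiplies the incoming output_gradient elementwise by the derivative of the tanh-based GELU approximation. A host-side sketch of that derivative (hypothetical helper, not the library's kernel code, which lives in Detail::cuda_gelu_impl):

```cpp
#include <cmath>

// d/dx of the tanh-based GELU approximation (illustrative only):
// GELU(x) ~= 0.5 * x * (1 + tanh(u)), u = k * (x + 0.044715 * x^3), k = sqrt(2/pi)
float gelu_tanh_grad_ref(float x) {
    const float k = std::sqrt(2.0f / 3.14159265358979f); // sqrt(2/pi)
    const float u = k * (x + 0.044715f * x * x * x);
    const float t = std::tanh(u);
    const float du = k * (1.0f + 3.0f * 0.044715f * x * x); // du/dx
    // Product rule: 0.5*(1 + tanh(u)) + 0.5*x*sech^2(u)*du
    return 0.5f * (1.0f + t) + 0.5f * x * (1.0f - t * t) * du;
}
```

Each element of input_gradient would then be output_gradient[i] times this derivative evaluated at input[i], which matches why the forward-pass input (or cached state) is needed here.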

◆ forward()

template<typename TDataType >
void Mila::Dnn::Compute::CudaGeluOp< TDataType >::forward ( const Tensor< TDataType, MR > &  input,
const std::vector< std::shared_ptr< Tensor< TDataType, MR > > > &  parameters,
const OperationAttributes & properties,
Tensor< TDataType, MR > &  output,
std::vector< std::shared_ptr< Tensor< TDataType, MR > > > &  output_state 
) const
inline override

Performs the forward pass of the GELU activation function on CUDA.

Computes the GELU transformation of the input elements: GELU(x) = 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3)))

The precision policy affects how the computation is performed:

  • Performance: May use faster but less precise algorithms
  • Accuracy: Will use the most accurate algorithm available
  • Auto: Will select an appropriate balance based on the hardware
  • Disabled: Will use the standard precision of the input/output types
Parameters
input   Input tensor containing the values to transform.
parameters   Additional parameters (not used in this operation).
properties   Additional attributes for the operation.
output   Output tensor to store the transformed values.
output_state   Cache for intermediate results (not used in this operation).
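A minimal host-side sketch of the tanh approximation above, for checking values against the kernel's output (the helper name is hypothetical; the actual computation is a CUDA kernel invoked via Detail::cuda_gelu_impl):

```cpp
#include <cmath>

// Host-side reference of the tanh-based GELU approximation
// applied elementwise by the forward pass (illustrative only).
float gelu_tanh_ref(float x) {
    const float k = std::sqrt(2.0f / 3.14159265358979f); // sqrt(2/pi)
    return 0.5f * x * (1.0f + std::tanh(k * (x + 0.044715f * x * x * x)));
}
```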

◆ getConfig()

template<typename TDataType >
const GeluConfig & Mila::Dnn::Compute::CudaGeluOp< TDataType >::getConfig ( ) const
inline

◆ getName()

template<typename TDataType >
std::string Mila::Dnn::Compute::CudaGeluOp< TDataType >::getName ( ) const
inline override virtual

Gets the name of this operation.

Returns
std::string The name of the operation ("Cuda::GeluOp").

Implements Mila::Dnn::Compute::OperationBase< TDeviceType, TInput1, TInput2, TOutput >.

Member Data Documentation

◆ config_

template<typename TDataType >
GeluConfig Mila::Dnn::Compute::CudaGeluOp< TDataType >::config_
private

Configuration for the GELU operation.

◆ impl_

template<typename TDataType >
Detail::cuda_gelu_impl<TDataType> Mila::Dnn::Compute::CudaGeluOp< TDataType >::impl_
private

Implementation details for the GELU operation.

