Mila 0.13.48
Deep Neural Network Library
Loading...
Searching...
No Matches
Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TPrecision > Class Template Reference [export]

CUDA implementation of RMS Normalization. More...

Inheritance diagram for Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TPrecision >:
Collaboration diagram for Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TPrecision >:

Public Types

using CudaExecutionContext = ExecutionContext<DeviceType::Cuda>
using MR = CudaDeviceMemoryResource
using NativeType = typename Mila::Dnn::Compute::Cuda::TensorDataTypeMap<TPrecision>::device_type
using TensorType = Tensor<TPrecision, MR>
using UnaryOperationBase = UnaryOperation<DeviceType::Cuda, TPrecision>
Public Types inherited from Mila::Dnn::Compute::UnaryOperation< DeviceType::Cuda, TPrecision >
using MR
using TensorInputType
using TensorOutputType
Public Types inherited from Mila::Dnn::Compute::Operation< TDeviceType, TInput >
using DataTypeTraits

Public Member Functions

 CudaRmsNormOp (IExecutionContext *context, const RmsNormConfig &config)
void backward (const ITensor &input, const ITensor &output_grad, ITensor &input_grad) const override
 Execute backward pass (hot path).
void build (const BuildContext &config) override
 Prepare operation for execution with concrete input shape.
void forward (const ITensor &input, ITensor &output) const override
 Execute forward pass (hot path).
const RmsNormConfig & getConfig () const
std::string getName () const override
 Human-readable operation name.
OperationType getOperationType () const override
 Operation type identifier.
void setGradients (ITensor *weight_grad, ITensor *bias_grad) override
 Bind component-owned parameter gradient tensors for training.
void setParameters (ITensor *weight, ITensor *bias) override
 Bind component-owned parameter tensors.
Public Member Functions inherited from Mila::Dnn::Compute::UnaryOperation< DeviceType::Cuda, TPrecision >
virtual ~UnaryOperation ()=default
Public Member Functions inherited from Mila::Dnn::Compute::Operation< TDeviceType, TInput >
virtual ~Operation ()=default
virtual void clearGradients () noexcept
 Clear any cached gradient pointers held by the operation.
virtual TensorDataType getDataType () const
 Tensor data type for this operation.
virtual DeviceType getDeviceType () const
 Device type for this operation.
virtual std::size_t getStateMemorySize () const
 Returns the number of bytes of state memory allocated by this operation.
virtual bool isBuilt () const
 Whether build() completed successfully for a concrete input shape.
virtual bool isEvalMode () const
 Query whether the operation is configured for evaluation (inference) mode rather than training.
virtual void setTrainingMode (TrainingMode training_mode)
 Configure operation training-mode behavior.

Private Attributes

NativeType * bias_ { nullptr }
NativeType * bias_grad_ { nullptr }
RmsNormConfig config_
CudaExecutionContext * context_
Detail::cuda_rmsnorm_impl< NativeType > impl_
int inner_size_ { 0 }
int64_t norm_axis_ { -1 }
int norm_dim_ { 0 }
int outer_size_ { 0 }
NativeType * rstd_ { nullptr }
std::shared_ptr< TensorType > rstd_tensor_
NativeType * weight_ { nullptr }
NativeType * weight_grad_ { nullptr }

Additional Inherited Members

Static Public Attributes inherited from Mila::Dnn::Compute::Operation< TDeviceType, TInput >
static constexpr TensorDataType data_type
static constexpr DeviceType device_type
Static Protected Member Functions inherited from Mila::Dnn::Compute::UnaryOperation< DeviceType::Cuda, TPrecision >
static const TensorInputType & asInputTensor (const ITensor &t)
static TensorOutputType & asOutputTensor (ITensor &t)
Protected Attributes inherited from Mila::Dnn::Compute::Operation< TDeviceType, TInput >
bool is_built_
TrainingMode training_mode_

Detailed Description

template<TensorDataType TPrecision>
requires PrecisionSupportedOnDevice<TPrecision, DeviceType::Cuda>
class Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TPrecision >

CUDA implementation of RMS Normalization.

Normalizes activations using root-mean-square across a specified axis, then applies an optional affine transform with learnable weight and bias.

This class mirrors the LayerNorm op structure but calls RMS-specific kernels.

Template Parameters
TPrecision — Abstract tensor precision (FP32, FP16, etc.)

Member Typedef Documentation

◆ CudaExecutionContext

template<TensorDataType TPrecision>
using Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TPrecision >::CudaExecutionContext = ExecutionContext<DeviceType::Cuda>

◆ MR

◆ NativeType

template<TensorDataType TPrecision>
using Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TPrecision >::NativeType = typename Mila::Dnn::Compute::Cuda::TensorDataTypeMap<TPrecision>::device_type

◆ TensorType

template<TensorDataType TPrecision>
using Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TPrecision >::TensorType = Tensor<TPrecision, MR>

◆ UnaryOperationBase

template<TensorDataType TPrecision>
using Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TPrecision >::UnaryOperationBase = UnaryOperation<DeviceType::Cuda, TPrecision>

Constructor & Destructor Documentation

◆ CudaRmsNormOp()

template<TensorDataType TPrecision>
Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TPrecision >::CudaRmsNormOp ( IExecutionContext * context,
const RmsNormConfig & config )
inline

Member Function Documentation

◆ backward()

template<TensorDataType TPrecision>
void Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TPrecision >::backward ( const ITensor & input,
const ITensor & output_grad,
ITensor & input_grad ) const
inline override virtual

Execute backward pass (hot path).

Computes input gradient and accumulates parameter gradients using forward-pass statistics cached during forward().

Implements Mila::Dnn::Compute::UnaryOperation< DeviceType::Cuda, TPrecision >.

◆ build()

template<TensorDataType TPrecision>
void Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TPrecision >::build ( const BuildContext & config)
inline override virtual

Prepare operation for execution with concrete input shape.

Computes normalization axis, partitions tensor dimensions, and allocates forward-pass statistics storage required by backward.

Reimplemented from Mila::Dnn::Compute::Operation< TDeviceType, TInput >.

◆ forward()

template<TensorDataType TPrecision>
void Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TPrecision >::forward ( const ITensor & input,
ITensor & output ) const
inline override virtual

Execute forward pass (hot path).

Computes RMS-normalized output and caches forward-pass statistics required for backward().

Implements Mila::Dnn::Compute::UnaryOperation< DeviceType::Cuda, TPrecision >.

◆ getConfig()

template<TensorDataType TPrecision>
const RmsNormConfig & Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TPrecision >::getConfig ( ) const
inline

◆ getName()

template<TensorDataType TPrecision>
std::string Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TPrecision >::getName ( ) const
inline override virtual

Human-readable operation name.

Implements Mila::Dnn::Compute::Operation< TDeviceType, TInput >.

◆ getOperationType()

template<TensorDataType TPrecision>
OperationType Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TPrecision >::getOperationType ( ) const
inline override virtual

◆ setGradients()

template<TensorDataType TPrecision>
void Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TPrecision >::setGradients ( ITensor * weight_grad,
ITensor * bias_grad )
inline override virtual

Bind component-owned parameter gradient tensors for training.

Caches native device gradient pointers for backward pass writes.

Parameters
weight_grad — Gradient accumulator for the weight parameter (required)
bias_grad — Gradient accumulator for the bias parameter (optional)

Reimplemented from Mila::Dnn::Compute::Operation< TDeviceType, TInput >.

◆ setParameters()

template<TensorDataType TPrecision>
void Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TPrecision >::setParameters ( ITensor * weight,
ITensor * bias )
inline override virtual

Bind component-owned parameter tensors.

Caches native device pointers for zero-overhead hot-path access. Weight is required; bias is optional based on configuration.

Parameters
weight — Scaling parameter applied after normalization (required)
bias — Shift parameter applied after normalization (optional)

Reimplemented from Mila::Dnn::Compute::Operation< TDeviceType, TInput >.

Member Data Documentation

◆ bias_

template<TensorDataType TPrecision>
NativeType* Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TPrecision >::bias_ { nullptr }
private

◆ bias_grad_

template<TensorDataType TPrecision>
NativeType* Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TPrecision >::bias_grad_ { nullptr }
private

◆ config_

template<TensorDataType TPrecision>
RmsNormConfig Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TPrecision >::config_
private

◆ context_

template<TensorDataType TPrecision>
CudaExecutionContext* Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TPrecision >::context_
private

◆ impl_

◆ inner_size_

template<TensorDataType TPrecision>
int Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TPrecision >::inner_size_ { 0 }
private

◆ norm_axis_

template<TensorDataType TPrecision>
int64_t Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TPrecision >::norm_axis_ { -1 }
private

◆ norm_dim_

template<TensorDataType TPrecision>
int Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TPrecision >::norm_dim_ { 0 }
private

◆ outer_size_

template<TensorDataType TPrecision>
int Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TPrecision >::outer_size_ { 0 }
private

◆ rstd_

template<TensorDataType TPrecision>
NativeType* Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TPrecision >::rstd_ { nullptr }
private

◆ rstd_tensor_

template<TensorDataType TPrecision>
std::shared_ptr<TensorType> Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TPrecision >::rstd_tensor_
private

◆ weight_

template<TensorDataType TPrecision>
NativeType* Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TPrecision >::weight_ { nullptr }
private

◆ weight_grad_

template<TensorDataType TPrecision>
NativeType* Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TPrecision >::weight_grad_ { nullptr }
private

The documentation for this class was generated from the following file:
  • /__w/Mila/Mila/Mila/Src/Dnn/Compute/Devices/Cuda/Operations/Normalizations/RmsNorm/RmsNormOp.ixx