Mila 0.13.48
Deep Neural Network Library
CUDA implementation of RMS Normalization.


Public Types | |
| using | CudaExecutionContext = ExecutionContext<DeviceType::Cuda> |
| using | MR = CudaDeviceMemoryResource |
| using | NativeType = typename Mila::Dnn::Compute::Cuda::TensorDataTypeMap<TPrecision>::device_type |
| using | TensorType = Tensor<TPrecision, MR> |
| using | UnaryOperationBase = UnaryOperation<DeviceType::Cuda, TPrecision> |
| Public Types inherited from Mila::Dnn::Compute::UnaryOperation< DeviceType::Cuda, TPrecision > | |
| using | MR |
| using | TensorInputType |
| using | TensorOutputType |
| Public Types inherited from Mila::Dnn::Compute::Operation< TDeviceType, TInput > | |
| using | DataTypeTraits |
Public Member Functions | |
| CudaRmsNormOp (IExecutionContext *context, const RmsNormConfig &config) | |
| void | backward (const ITensor &input, const ITensor &output_grad, ITensor &input_grad) const override |
| Execute backward pass (hot path). | |
| void | build (const BuildContext &config) override |
| Prepare operation for execution with concrete input shape. | |
| void | forward (const ITensor &input, ITensor &output) const override |
| Execute forward pass (hot path). | |
| const RmsNormConfig & | getConfig () const |
| std::string | getName () const override |
| Human-readable operation name. | |
| OperationType | getOperationType () const override |
| Operation type identifier. | |
| void | setGradients (ITensor *weight_grad, ITensor *bias_grad) override |
| Bind component-owned parameter gradient tensors for training. | |
| void | setParameters (ITensor *weight, ITensor *bias) override |
| Bind component-owned parameter tensors. | |
| Public Member Functions inherited from Mila::Dnn::Compute::UnaryOperation< DeviceType::Cuda, TPrecision > | |
| virtual | ~UnaryOperation ()=default |
| Public Member Functions inherited from Mila::Dnn::Compute::Operation< TDeviceType, TInput > | |
| virtual | ~Operation ()=default |
| virtual void | clearGradients () noexcept |
| Clear any cached gradient pointers held by the operation. | |
| virtual TensorDataType | getDataType () const |
| Tensor data type for this operation. | |
| virtual DeviceType | getDeviceType () const |
| Device type for this operation. | |
| virtual std::size_t | getStateMemorySize () const |
| Returns the number of bytes of state memory allocated by this operation. | |
| virtual bool | isBuilt () const |
| Whether build() completed successfully for a concrete input shape. | |
| virtual bool | isEvalMode () const |
| Query whether the operation is in evaluation mode (that is, not configured for training). | |
| virtual void | setTrainingMode (TrainingMode training_mode) |
| Configure operation training-mode behavior. | |
Private Attributes | |
| NativeType * | bias_ { nullptr } |
| NativeType * | bias_grad_ { nullptr } |
| RmsNormConfig | config_ |
| CudaExecutionContext * | context_ |
| Detail::cuda_rmsnorm_impl< NativeType > | impl_ |
| int | inner_size_ { 0 } |
| int64_t | norm_axis_ { -1 } |
| int | norm_dim_ { 0 } |
| int | outer_size_ { 0 } |
| NativeType * | rstd_ { nullptr } |
| std::shared_ptr< TensorType > | rstd_tensor_ |
| NativeType * | weight_ { nullptr } |
| NativeType * | weight_grad_ { nullptr } |
Additional Inherited Members | |
| Static Public Attributes inherited from Mila::Dnn::Compute::Operation< TDeviceType, TInput > | |
| static constexpr TensorDataType | data_type |
| static constexpr DeviceType | device_type |
| Static Protected Member Functions inherited from Mila::Dnn::Compute::UnaryOperation< DeviceType::Cuda, TPrecision > | |
| static const TensorInputType & | asInputTensor (const ITensor &t) |
| static TensorOutputType & | asOutputTensor (ITensor &t) |
| Protected Attributes inherited from Mila::Dnn::Compute::Operation< TDeviceType, TInput > | |
| bool | is_built_ |
| TrainingMode | training_mode_ |
CUDA implementation of RMS Normalization.
Normalizes activations using root-mean-square across a specified axis, then applies an optional affine transform with learnable weight and bias.
This class mirrors the LayerNorm op structure but calls RMS-specific kernels.
Template Parameters
| TPrecision | Abstract tensor precision (FP32, FP16, etc.) |
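As a reading aid, the normalization described above can be written out as follows. This is the usual RMSNorm formulation, not a transcription of the Mila kernels: the stabilizing constant $\epsilon$ and the symbol names are assumptions, $N$ stands for the extent of the normalized axis, and the bias term is dropped when no bias is bound.

```latex
r = \frac{1}{\sqrt{\frac{1}{N}\sum_{k=1}^{N} x_k^{2} + \epsilon}},
\qquad
y_k = w_k \, x_k \, r + b_k
```

Under this reading, $r$ is the reciprocal RMS that the `rstd_` attribute would hold for each normalized slice.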
| using Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TPrecision >::CudaExecutionContext = ExecutionContext<DeviceType::Cuda> |
| using Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TPrecision >::MR = CudaDeviceMemoryResource |
| using Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TPrecision >::NativeType = typename Mila::Dnn::Compute::Cuda::TensorDataTypeMap<TPrecision>::device_type |
| using Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TPrecision >::TensorType = Tensor<TPrecision, MR> |
| using Mila::Dnn::Compute::Cuda::RmsNorm::CudaRmsNormOp< TPrecision >::UnaryOperationBase = UnaryOperation<DeviceType::Cuda, TPrecision> |
| CudaRmsNormOp (IExecutionContext *context, const RmsNormConfig &config) | inline |
| void backward (const ITensor &input, const ITensor &output_grad, ITensor &input_grad) const | inline override virtual |
Execute backward pass (hot path).
Computes input gradient and accumulates parameter gradients using forward-pass statistics cached during forward().
Implements Mila::Dnn::Compute::UnaryOperation< DeviceType::Cuda, TPrecision >.
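The gradient computation described above can be illustrated with a CPU reference. This is a minimal sketch, not the Mila CUDA kernel: `rmsNormBackwardRef` is a hypothetical helper, the input is assumed to be a contiguous `rows x norm_dim` matrix normalized along its last axis, and `rstd` is assumed to hold one reciprocal RMS value per row, as cached by a forward pass. Parameter gradients are accumulated with `+=`, matching the "accumulates" wording; the input gradient is overwritten.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference for the RMSNorm backward pass (not part of Mila).
// For y_k = w_k * x_k * r + b_k with r = 1 / sqrt(mean(x^2) + eps):
//   dx_j  = r * w_j * dy_j - (r^3 * x_j / N) * sum_k(dy_k * w_k * x_k)
//   dw_k += dy_k * x_k * r        (summed over rows)
//   db_k += dy_k                  (summed over rows, only if dbias is non-null)
inline void rmsNormBackwardRef(
    const std::vector<float>& x, const std::vector<float>& weight,
    const std::vector<float>& rstd, const std::vector<float>& dy,
    int rows, int norm_dim,
    std::vector<float>& dx, std::vector<float>& dweight, std::vector<float>* dbias)
{
    const auto dim = static_cast<std::size_t>(norm_dim);
    dx.assign(x.size(), 0.0f);

    for (std::size_t row = 0; row < static_cast<std::size_t>(rows); ++row) {
        const float r = rstd[row];

        // Shared reduction term: sum_k(dy_k * w_k * x_k) for this row.
        float dot = 0.0f;
        for (std::size_t k = 0; k < dim; ++k) {
            const std::size_t idx = row * dim + k;
            dot += dy[idx] * weight[k] * x[idx];
        }

        const float scale = r * r * r * dot / static_cast<float>(norm_dim);
        for (std::size_t k = 0; k < dim; ++k) {
            const std::size_t idx = row * dim + k;
            dx[idx] = r * weight[k] * dy[idx] - scale * x[idx];
            dweight[k] += dy[idx] * x[idx] * r;      // accumulate, do not overwrite
            if (dbias != nullptr) (*dbias)[k] += dy[idx];
        }
    }
}
```

The caller is expected to size `dweight` (and `dbias`, if used) to `norm_dim` beforehand. A reference like this is mainly useful for checking a device kernel against known values on small inputs.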
| void build (const BuildContext &config) | inline override virtual |
Prepare operation for execution with concrete input shape.
Computes normalization axis, partitions tensor dimensions, and allocates forward-pass statistics storage required by backward.
Reimplemented from Mila::Dnn::Compute::Operation< TDeviceType, TInput >.
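The axis computation and dimension partitioning mentioned above might look roughly like the sketch below. It is illustrative only: `AxisPartition` and `partitionShape` are hypothetical names, and the convention that a negative axis counts from the end (so `-1` is the last axis) is an assumption suggested by the `norm_axis_ { -1 }` default. The field names mirror the `outer_size_`, `norm_dim_`, and `inner_size_` attributes.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical helper, not part of Mila: the result of splitting a shape
// around a single normalized axis.
struct AxisPartition {
    std::int64_t axis;  // resolved, non-negative axis index
    int outer_size;     // product of the dimensions before the axis
    int norm_dim;       // extent of the normalized axis
    int inner_size;     // product of the dimensions after the axis
};

inline AxisPartition partitionShape(const std::vector<std::int64_t>& shape, std::int64_t axis)
{
    const auto rank = static_cast<std::int64_t>(shape.size());
    if (axis < 0) axis += rank;  // assumed convention: negative values count from the end
    if (axis < 0 || axis >= rank) {
        throw std::out_of_range("normalization axis out of range");
    }

    std::int64_t outer = 1;
    std::int64_t inner = 1;
    for (std::int64_t i = 0; i < axis; ++i) {
        outer *= shape[static_cast<std::size_t>(i)];
    }
    for (std::int64_t i = axis + 1; i < rank; ++i) {
        inner *= shape[static_cast<std::size_t>(i)];
    }

    return AxisPartition{
        axis,
        static_cast<int>(outer),
        static_cast<int>(shape[static_cast<std::size_t>(axis)]),
        static_cast<int>(inner)
    };
}
```

For a shape of `{2, 3, 4}`, normalizing over the last axis gives 6 outer slices of length 4 with an inner size of 1, while normalizing over axis 1 gives 2 outer slices, a normalized extent of 3, and an inner size of 4.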
| void forward (const ITensor &input, ITensor &output) const | inline override virtual |
Execute forward pass (hot path).
Computes RMS-normalized output and caches forward-pass statistics required for backward().
Implements Mila::Dnn::Compute::UnaryOperation< DeviceType::Cuda, TPrecision >.
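To make the forward computation and its cached statistic concrete, here is a CPU reference sketch. It is not the Mila CUDA kernel: `rmsNormForwardRef` is a hypothetical helper, the `eps` parameter and the memory layout (element `(o, k, i)` stored at `(o * norm_dim + k) * inner_size + i`) are assumptions, and the cached statistic is assumed to be the reciprocal RMS, one value per `(outer, inner)` pair.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference for the RMSNorm forward pass (not part of Mila).
// Writes the normalized output to y and the reciprocal RMS of each slice to rstd.
// bias may be nullptr, in which case no shift is applied.
inline void rmsNormForwardRef(
    const std::vector<float>& x, const std::vector<float>& weight, const float* bias,
    int outer_size, int norm_dim, int inner_size, float eps,
    std::vector<float>& y, std::vector<float>& rstd)
{
    const auto outer = static_cast<std::size_t>(outer_size);
    const auto dim = static_cast<std::size_t>(norm_dim);
    const auto inner = static_cast<std::size_t>(inner_size);

    y.assign(x.size(), 0.0f);
    rstd.assign(outer * inner, 0.0f);

    for (std::size_t o = 0; o < outer; ++o) {
        for (std::size_t i = 0; i < inner; ++i) {
            // Mean of squares along the normalized axis.
            float sum_sq = 0.0f;
            for (std::size_t k = 0; k < dim; ++k) {
                const float v = x[(o * dim + k) * inner + i];
                sum_sq += v * v;
            }
            const float r = 1.0f / std::sqrt(sum_sq / static_cast<float>(norm_dim) + eps);
            rstd[o * inner + i] = r;  // kept for the backward pass

            // Scale by the reciprocal RMS, then apply the affine transform.
            for (std::size_t k = 0; k < dim; ++k) {
                const std::size_t idx = (o * dim + k) * inner + i;
                y[idx] = x[idx] * r * weight[k] + (bias != nullptr ? bias[k] : 0.0f);
            }
        }
    }
}
```

Storing the reciprocal rather than the RMS itself lets a backward pass reuse it with multiplications only, which is one plausible reason for an `rstd`-style cache.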
| const RmsNormConfig & getConfig () const | inline |
| std::string getName () const | inline override virtual |
Human-readable operation name.
Implements Mila::Dnn::Compute::Operation< TDeviceType, TInput >.
| OperationType getOperationType () const | inline override virtual |
Operation type identifier.
Implements Mila::Dnn::Compute::Operation< TDeviceType, TInput >.
| void setGradients (ITensor *weight_grad, ITensor *bias_grad) | inline override virtual |
Bind component-owned parameter gradient tensors for training.
Caches native device gradient pointers for backward pass writes.
Parameters
| weight_grad | Gradient accumulator for weight parameter (required) |
| bias_grad | Gradient accumulator for bias parameter (optional) |
Reimplemented from Mila::Dnn::Compute::Operation< TDeviceType, TInput >.
| void setParameters (ITensor *weight, ITensor *bias) | inline override virtual |
Bind component-owned parameter tensors.
Caches native device pointers for zero-overhead hot-path access. Weight is required; bias is optional based on configuration.
Parameters
| weight | Scaling parameter applied after normalization (required) |
| bias | Shift parameter applied after normalization (optional) |
Reimplemented from Mila::Dnn::Compute::Operation< TDeviceType, TInput >.