Mila 0.13.48
Deep Neural Network Library
CUDA-specific AdamW optimizer implementation.


Public Types
| using | CudaExecutionContext = ExecutionContext<DeviceType::Cuda> |
| using | MR = CudaDeviceMemoryResource |
| using | NativeType = typename Mila::Dnn::Compute::Cuda::TensorDataTypeMap<TPrecision>::native_type |
| using | TensorType = Tensor<TPrecision, MR> |
Public Member Functions
| | CudaAdamWOptimizer (IExecutionContext *context, const OptimizerConfig &config) | Construct CUDA AdamW optimizer. |
| | ~CudaAdamWOptimizer () override=default | |
| void | addParameter (ITensor *param, ITensor *grad) override | Register a parameter-gradient pair for optimization. |
| float | getBeta1 () const noexcept | Get beta1 parameter. |
| float | getBeta2 () const noexcept | Get beta2 parameter. |
| float | getEpsilon () const noexcept | Get epsilon parameter. |
| float | getLearningRate () const override | Get current learning rate. |
| size_t | getParameterCount () const noexcept | Get number of registered parameter groups. |
| size_t | getStepCount () const noexcept | Get current step count. |
| float | getWeightDecay () const noexcept | Get weight decay parameter. |
| void | setLearningRate (float learning_rate) override | Set learning rate for future steps. |
| void | setWeightDecay (float weight_decay) | Set weight decay coefficient. |
| void | step () override | Perform one AdamW optimization step. |
Public Member Functions inherited from Mila::Dnn::Compute::Optimizer< DeviceType::Cuda, TPrecision >
| virtual | ~Optimizer ()=default |
Private Member Functions
| void | validateHyperparameters () const | Validate optimizer hyperparameters. |
Static Private Member Functions
| static std::string | shapeToString (const shape_t &shape) | Convert shape to string for error messages. |
Private Attributes
| OptimizerConfig | config_ |
| CudaExecutionContext * | exec_context_ |
| std::vector< NativeType * > | grad_data_ |
| float | grad_scale_ { 1.0f } |
| std::vector< ITensor * > | grads_ |
| std::vector< float * > | m_data_ |
| std::vector< std::shared_ptr< Tensor< TensorDataType::FP32, MR > > > | m_states_ |
| std::vector< float * > | master_param_data_ |
| std::vector< std::shared_ptr< Tensor< TensorDataType::FP32, MR > > > | master_params_ |
| std::vector< NativeType * > | param_data_ |
| std::vector< ITensor * > | params_ |
| size_t | step_count_ |
| std::vector< float * > | v_data_ |
| std::vector< std::shared_ptr< Tensor< TensorDataType::FP32, MR > > > | v_states_ |
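The attribute types suggest a mixed-precision layout: param_data_ and grad_data_ point at TPrecision data (NativeType), while master_params_, m_states_ and v_states_ are FP32 tensors. That reading is an inference from the names and types above, not documented behavior. A common reason for keeping FP32 master weights is that small updates can round away entirely in a 16-bit format. The sketch below illustrates the effect on the host with a hand-rolled bfloat16 conversion; toBf16, fromBf16 and accumulate are hypothetical helpers, not Mila functions.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Hypothetical helper: bfloat16 stored as the upper 16 bits of an IEEE-754 float.
// Finite inputs only in this sketch (NaN/Inf are not special-cased).
std::uint16_t toBf16(float x) {
    std::uint32_t bits;
    std::memcpy(&bits, &x, sizeof bits);
    // Round to nearest, ties to even, on the 16 low bits being dropped.
    const std::uint32_t rounding = 0x7FFFu + ((bits >> 16) & 1u);
    return static_cast<std::uint16_t>((bits + rounding) >> 16);
}

// Hypothetical helper: widen bfloat16 back to float (exact).
float fromBf16(std::uint16_t h) {
    const std::uint32_t bits = static_cast<std::uint32_t>(h) << 16;
    float x;
    std::memcpy(&x, &bits, sizeof x);
    return x;
}

// Applies `steps` identical small increments to a bf16 parameter, either directly
// in bf16 or through an FP32 master copy that is rounded back to bf16 each step.
float accumulate(float start, float delta, int steps, bool use_fp32_master) {
    std::uint16_t param = toBf16(start);   // low-precision working copy
    float master = fromBf16(param);        // FP32 master copy
    for (int i = 0; i < steps; ++i) {
        if (use_fp32_master) {
            master += delta;               // update accumulates in FP32
            param = toBf16(master);        // working copy refreshed from master
        } else {
            param = toBf16(fromBf16(param) + delta);  // update rounds away in bf16
        }
    }
    return fromBf16(param);
}
```

Starting from 1.0 with an increment of 1e-4, the bf16-only path never moves (the bf16 spacing near 1.0 is 2^-7), while the FP32-master path ends at 1.0078125, the bf16 value nearest 1.01.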
Detailed Description
CUDA-specific AdamW optimizer implementation.
Implements the AdamW algorithm using optimized CUDA kernels from adamw.cuh. Maintains per-parameter state tensors (first moment, second moment) on the GPU and performs asynchronous parameter updates via CUDA streams.
AdamW algorithm:
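For reference, the standard AdamW update (Adam with decoupled weight decay) for a parameter θ with gradient g_t at step t is given below, with η the learning rate, β1 and β2 the moment decay rates, ε the epsilon term and λ the weight decay. This is the textbook form, not a transcription of adamw.cuh; in particular, scaling the decay term by η (as PyTorch's AdamW does) is an assumption about this implementation.

```latex
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1 - \beta_1)\, g_t \\
v_t &= \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2 \\
\hat{m}_t &= \frac{m_t}{1 - \beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1 - \beta_2^t} \\
\theta_t &= \theta_{t-1} - \eta \left( \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda\, \theta_{t-1} \right)
\end{aligned}
```

Starting from m_0 = v_0 = 0, which matches the zero-initialized state tensors, the bias corrections 1 - β1^t and 1 - β2^t compensate for the moments being biased toward zero in early steps.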
Features:
Template Parameters
| TPrecision | Abstract tensor precision (TensorDataType) |
Member Typedef Documentation
| using Mila::Dnn::Compute::CudaAdamWOptimizer< TPrecision >::CudaExecutionContext = ExecutionContext<DeviceType::Cuda> |
| using Mila::Dnn::Compute::CudaAdamWOptimizer< TPrecision >::MR = CudaDeviceMemoryResource |
| using Mila::Dnn::Compute::CudaAdamWOptimizer< TPrecision >::NativeType = typename Mila::Dnn::Compute::Cuda::TensorDataTypeMap<TPrecision>::native_type |
| using Mila::Dnn::Compute::CudaAdamWOptimizer< TPrecision >::TensorType = Tensor<TPrecision, MR> |
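Constructor & Destructor Documentation
CudaAdamWOptimizer (IExecutionContext *context, const OptimizerConfig &config) [inline, explicit]
Construct CUDA AdamW optimizer.
Parameters
| context | CUDA execution context for stream and device management |
| config | AdamW optimizer configuration |
Exceptions
| std::invalid_argument | if context is null |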

~CudaAdamWOptimizer () override = default
Member Function Documentation
void addParameter (ITensor *param, ITensor *grad) override [inline, virtual]
Register a parameter-gradient pair for optimization.
The optimizer does not take ownership of the parameter or gradient tensors. The caller (typically a Module) must ensure the tensors remain valid for the lifetime of the optimizer.
Allocates first-moment (momentum) and second-moment (variance) state tensors on the GPU with the same shape as the parameter. State tensors are zero-initialized.
Parameters
| param | Parameter tensor to optimize (non-owning, must be on a CUDA device) |
| grad | Gradient tensor (non-owning, must match param shape and device) |
Exceptions
| std::invalid_argument | if param or grad is null |
| std::invalid_argument | if param and grad shapes don't match |
| std::invalid_argument | if param or grad is not a CUDA tensor |
| std::invalid_argument | if param or grad data type doesn't match optimizer precision |
| std::runtime_error | if state allocation fails |
Implements Mila::Dnn::Compute::Optimizer< DeviceType::Cuda, TPrecision >.
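The registration contract above (non-owning pointers, the four documented invalid_argument cases, zero-initialized moment states) can be sketched on the host. Everything here is hypothetical: MockTensor, Device, DType and RegistrationSketch stand in for ITensor and the real optimizer, and std::vector<float> stands in for the FP32 GPU state tensors.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical stand-ins for ITensor metadata; not Mila types.
enum class Device { Cpu, Cuda };
enum class DType { FP32, FP16, BF16 };

struct MockTensor {
    std::vector<std::size_t> shape;
    Device device;
    DType dtype;
};

// Sketch of the checks and state allocation documented for addParameter().
class RegistrationSketch {
public:
    explicit RegistrationSketch(DType precision) : precision_(precision) {}

    void addParameter(MockTensor* param, MockTensor* grad) {
        if (param == nullptr || grad == nullptr)
            throw std::invalid_argument("param and grad must be non-null");
        if (param->shape != grad->shape)
            throw std::invalid_argument("param and grad shapes differ");
        if (param->device != Device::Cuda || grad->device != Device::Cuda)
            throw std::invalid_argument("param and grad must be CUDA tensors");
        if (param->dtype != precision_ || grad->dtype != precision_)
            throw std::invalid_argument("tensor dtype does not match optimizer precision");

        std::size_t n = 1;
        for (std::size_t d : param->shape) n *= d;
        // Non-owning: only the pointers are stored; the caller keeps the tensors alive.
        params_.push_back(param);
        grads_.push_back(grad);
        // Zero-initialized FP32 moment states, one element per parameter element.
        m_states_.emplace_back(n, 0.0f);
        v_states_.emplace_back(n, 0.0f);
    }

    std::size_t getParameterCount() const noexcept { return params_.size(); }
    const std::vector<float>& firstMoment(std::size_t i) const { return m_states_.at(i); }

private:
    DType precision_;
    std::vector<MockTensor*> params_;
    std::vector<MockTensor*> grads_;
    std::vector<std::vector<float>> m_states_;
    std::vector<std::vector<float>> v_states_;
};
```

A rejected registration leaves the parameter count unchanged, because every check runs before any state is stored.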

float getBeta1 () const noexcept [inline]
Get beta1 parameter.
float getBeta2 () const noexcept [inline]
Get beta2 parameter.
float getEpsilon () const noexcept [inline]
Get epsilon parameter.
float getLearningRate () const override [inline, virtual]
Get current learning rate.
Implements Mila::Dnn::Compute::Optimizer< DeviceType::Cuda, TPrecision >.
size_t getParameterCount () const noexcept [inline]
Get number of registered parameter groups.
size_t getStepCount () const noexcept [inline]
Get current step count.
float getWeightDecay () const noexcept [inline]
Get weight decay parameter.
void setLearningRate (float learning_rate) override [inline, virtual]
Set learning rate for future steps.
Parameters
| learning_rate | New learning rate (must be positive) |
Exceptions
| std::invalid_argument | if learning_rate <= 0 |
Implements Mila::Dnn::Compute::Optimizer< DeviceType::Cuda, TPrecision >.
void setWeightDecay (float weight_decay) [inline]
Set weight decay coefficient.
Parameters
| weight_decay | New weight decay (must be non-negative) |
Exceptions
| std::invalid_argument | if weight_decay < 0 |
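The preconditions documented for the two setters can be sketched as follows. MutableHyperparams is a hypothetical holder, not the optimizer class, and the default values are placeholders. The checks are written as !(x > 0) and !(x >= 0) so that NaN is rejected too; that is an assumption beyond the documented "<= 0" and "< 0" conditions.

```cpp
#include <cassert>
#include <stdexcept>

// Hypothetical holder enforcing the documented setter preconditions.
class MutableHyperparams {
public:
    void setLearningRate(float learning_rate) {
        // Documented: must be positive; throws if <= 0 (NaN also rejected here).
        if (!(learning_rate > 0.0f))
            throw std::invalid_argument("learning_rate must be positive");
        learning_rate_ = learning_rate;
    }

    void setWeightDecay(float weight_decay) {
        // Documented: must be non-negative; throws if < 0 (NaN also rejected here).
        if (!(weight_decay >= 0.0f))
            throw std::invalid_argument("weight_decay must be non-negative");
        weight_decay_ = weight_decay;
    }

    float getLearningRate() const noexcept { return learning_rate_; }
    float getWeightDecay() const noexcept { return weight_decay_; }

private:
    float learning_rate_ = 1e-3f;  // placeholder default
    float weight_decay_ = 0.0f;    // placeholder default
};
```

Because the new learning rate applies to future steps, a schedule would typically call setLearningRate between step() calls. A weight decay of 0 turns the decoupled decay term off and leaves the plain Adam update.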
static std::string shapeToString (const shape_t &shape) [inline, private]
Convert shape to string for error messages.
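A helper like this typically renders a shape for messages such as a param/grad shape mismatch. The sketch below assumes shape_t is a sequence of integer extents and picks a bracketed "[2, 3]" format; both the typedef and the output format are hypothetical, since this page shows neither.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <sstream>
#include <string>
#include <vector>

// Assumption: shape_t is a sequence of integer extents (real typedef not shown).
using shape_t = std::vector<std::int64_t>;

// Hypothetical formatting, e.g. {2, 3} -> "[2, 3]".
std::string shapeToString(const shape_t& shape) {
    std::ostringstream os;
    os << '[';
    for (std::size_t i = 0; i < shape.size(); ++i) {
        if (i != 0) os << ", ";  // separator between extents
        os << shape[i];
    }
    os << ']';
    return os.str();
}
```

With this format, a mismatch message could read "param shape [2, 3] does not match grad shape [3, 2]".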


void step () override [inline, virtual]
Perform one AdamW optimization step.
Updates all registered parameters asynchronously on the GPU using the AdamW CUDA kernel, launched on the execution context's CUDA stream.
Exceptions
| std::runtime_error | if no parameters have been registered |
| std::runtime_error | if a CUDA kernel launch fails |
Implements Mila::Dnn::Compute::Optimizer< DeviceType::Cuda, TPrecision >.
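What one step() call is expected to compute can be sketched as a host-side reference. This is a minimal sketch under assumptions: ReferenceAdamW, AdamWHyper and ParamGroup are hypothetical names, everything is FP32 on the CPU rather than TPrecision on a CUDA stream, and the update uses the standard bias-corrected AdamW form with learning-rate-scaled decoupled weight decay, which may differ in detail from adamw.cuh.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical hyperparameter bundle (placeholder defaults).
struct AdamWHyper {
    float lr = 1e-3f;
    float beta1 = 0.9f;
    float beta2 = 0.999f;
    float eps = 1e-8f;
    float weight_decay = 0.01f;
};

// One registered group: parameters, gradients, and zero-initialized moments.
struct ParamGroup {
    std::vector<float> param;
    std::vector<float> grad;
    std::vector<float> m;  // first moment
    std::vector<float> v;  // second moment
};

class ReferenceAdamW {
public:
    explicit ReferenceAdamW(AdamWHyper h) : h_(h) {}

    void addGroup(std::vector<float> param, std::vector<float> grad) {
        if (param.size() != grad.size())
            throw std::invalid_argument("param and grad sizes differ");
        const std::size_t n = param.size();
        groups_.push_back({std::move(param), std::move(grad),
                           std::vector<float>(n, 0.0f), std::vector<float>(n, 0.0f)});
    }

    void step() {
        if (groups_.empty())
            throw std::runtime_error("step() called with no registered parameters");
        ++step_count_;
        const float t = static_cast<float>(step_count_);
        const float bc1 = 1.0f - std::pow(h_.beta1, t);  // bias correction for m
        const float bc2 = 1.0f - std::pow(h_.beta2, t);  // bias correction for v
        for (ParamGroup& g : groups_) {
            for (std::size_t i = 0; i < g.param.size(); ++i) {
                const float grad = g.grad[i];
                g.m[i] = h_.beta1 * g.m[i] + (1.0f - h_.beta1) * grad;
                g.v[i] = h_.beta2 * g.v[i] + (1.0f - h_.beta2) * grad * grad;
                const float m_hat = g.m[i] / bc1;
                const float v_hat = g.v[i] / bc2;
                // Decoupled weight decay uses the pre-update parameter value.
                g.param[i] -= h_.lr * (m_hat / (std::sqrt(v_hat) + h_.eps)
                                       + h_.weight_decay * g.param[i]);
            }
        }
    }

    const ParamGroup& group(std::size_t i) const { return groups_.at(i); }
    std::size_t stepCount() const noexcept { return step_count_; }

private:
    AdamWHyper h_;
    std::vector<ParamGroup> groups_;
    std::size_t step_count_ = 0;
};
```

On the first step the bias corrections make m_hat equal g and v_hat equal g*g, so with no weight decay each parameter moves by roughly lr against the sign of its gradient. A reference like this is useful for checking GPU results element by element within a tolerance.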

void validateHyperparameters () const [inline, private]
Validate optimizer hyperparameters.
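This page does not list which checks validateHyperparameters() performs. The sketch below assumes the usual AdamW constraints: the learning-rate and weight-decay rules mirror the documented setters, while the beta and epsilon ranges are assumptions. AdamWSettings is a hypothetical struct, not Mila's OptimizerConfig.

```cpp
#include <cassert>
#include <stdexcept>

// Hypothetical mirror of the values exposed by the getters.
struct AdamWSettings {
    float learning_rate;
    float beta1;
    float beta2;
    float epsilon;
    float weight_decay;
};

// Assumed checks; the real function's rules are not documented here.
void validateHyperparameters(const AdamWSettings& s) {
    if (!(s.learning_rate > 0.0f))
        throw std::invalid_argument("learning rate must be positive");
    // beta == 1 would make the bias correction 1 - beta^t zero (division by zero).
    if (!(s.beta1 >= 0.0f && s.beta1 < 1.0f))
        throw std::invalid_argument("beta1 must be in [0, 1)");
    if (!(s.beta2 >= 0.0f && s.beta2 < 1.0f))
        throw std::invalid_argument("beta2 must be in [0, 1)");
    // A positive epsilon keeps the denominator sqrt(v_hat) + eps away from zero.
    if (!(s.epsilon > 0.0f))
        throw std::invalid_argument("epsilon must be positive");
    if (!(s.weight_decay >= 0.0f))
        throw std::invalid_argument("weight decay must be non-negative");
}
```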
