Mila 0.13.48
Deep Neural Network Library
CPU-specific AdamW optimizer implementation.


Public Types | |
| using | ExecutionContextType = ExecutionContext<DeviceType::Cpu> |
| using | HostType = typename TensorHostTypeMap<TPrecision>::host_type |
| using | MR = CpuMemoryResource |
| using | TensorType = Tensor<TPrecision, MR> |
Public Member Functions | |
| CpuAdamWOptimizer (IExecutionContext *context, const AdamWConfig &config) | |
| Construct CPU AdamW optimizer. | |
| ~CpuAdamWOptimizer () override=default | |
| void | addParameter (ITensor *param, ITensor *grad) override |
| Register a parameter-gradient pair for optimization. | |
| float | getBeta1 () const noexcept |
| Get beta1 parameter. | |
| float | getBeta2 () const noexcept |
| Get beta2 parameter. | |
| float | getEpsilon () const noexcept |
| Get epsilon parameter. | |
| float | getLearningRate () const override |
| Get current learning rate. | |
| size_t | getParameterCount () const noexcept |
| Get number of registered parameter groups. | |
| size_t | getStepCount () const noexcept |
| Get current step count. | |
| float | getWeightDecay () const noexcept |
| Get weight decay parameter. | |
| void | setLearningRate (float learning_rate) override |
| Set learning rate for future steps. | |
| void | step () override |
| Perform one AdamW optimization step. | |
| Public Member Functions inherited from Mila::Dnn::Compute::Optimizer< DeviceType::Cpu, TPrecision > | |
| virtual | ~Optimizer ()=default |
Private Member Functions | |
| void | updateParameter (HostType *param_data, const HostType *grad_data, float *m_data, float *v_data, size_t num_params, float beta1_correction, float beta2_correction) |
| Update a single parameter using AdamW algorithm. | |
Static Private Member Functions | |
| static std::string | shapeToString (const shape_t &shape) |
| Convert shape to string for error messages. | |
Private Attributes | |
| AdamWConfig | config_ |
| IExecutionContext * | context_ |
| std::vector< const HostType * > | grad_data_ |
| std::vector< ITensor * > | grads_ |
| float | learning_rate_ |
| std::vector< float * > | m_data_ |
| std::vector< std::shared_ptr< Tensor< TensorDataType::FP32, MR > > > | m_states_ |
| std::vector< HostType * > | param_data_ |
| std::vector< ITensor * > | params_ |
| size_t | step_count_ { 0 } |
| std::vector< float * > | v_data_ |
| std::vector< std::shared_ptr< Tensor< TensorDataType::FP32, MR > > > | v_states_ |
CPU-specific AdamW optimizer implementation.
Implements the AdamW algorithm using scalar CPU loops. Maintains per-parameter state tensors (first moment, second moment) and performs synchronous parameter updates.
AdamW algorithm:
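For reference, the standard AdamW formulation (Adam with decoupled weight decay) can be written as below. This is a sketch of the textbook update, not a transcription of this class's source: the symbols are assumptions mapped onto the documented parameters, with \eta as the learning rate, \lambda as the weight decay coefficient, g_t as the gradient at step t, and m_t, v_t as the first and second moment states. Whether this implementation applies the decay term before or after the adaptive step is not stated here.

```latex
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1 - \beta_1)\, g_t \\
v_t &= \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2 \\
\hat{m}_t &= \frac{m_t}{1 - \beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1 - \beta_2^t} \\
\theta_t &= \theta_{t-1} - \eta \left( \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda\, \theta_{t-1} \right)
\end{aligned}
```

The denominators 1 - \beta_1^t and 1 - \beta_2^t correspond to the beta1_correction and beta2_correction arguments documented for updateParameter.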
Features:
| TPrecision | Abstract tensor precision (TensorDataType) |
| using Mila::Dnn::Compute::CpuAdamWOptimizer< TPrecision >::ExecutionContextType = ExecutionContext<DeviceType::Cpu> |
| using Mila::Dnn::Compute::CpuAdamWOptimizer< TPrecision >::HostType = typename TensorHostTypeMap<TPrecision>::host_type |
| using Mila::Dnn::Compute::CpuAdamWOptimizer< TPrecision >::MR = CpuMemoryResource |
| using Mila::Dnn::Compute::CpuAdamWOptimizer< TPrecision >::TensorType = Tensor<TPrecision, MR> |
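| Mila::Dnn::Compute::CpuAdamWOptimizer< TPrecision >::CpuAdamWOptimizer (IExecutionContext *context, const AdamWConfig &config) | inline, explicit |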
Construct CPU AdamW optimizer.
| exec_context | CPU execution context |
| learning_rate | Initial learning rate (typical: 1e-3 to 1e-4) |
| beta1 | Exponential decay rate for first moment (typical: 0.9) |
| beta2 | Exponential decay rate for second moment (typical: 0.999) |
| epsilon | Small constant for numerical stability (typical: 1e-8) |
| weight_decay | Weight decay coefficient (typical: 0.01) |
| std::invalid_argument | if exec_context is null |
| std::invalid_argument | if learning_rate <= 0 |
| std::invalid_argument | if beta1, beta2 not in (0, 1) |
| std::invalid_argument | if epsilon <= 0 |
| Mila::Dnn::Compute::CpuAdamWOptimizer< TPrecision >::~CpuAdamWOptimizer () | override, default |
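| void Mila::Dnn::Compute::CpuAdamWOptimizer< TPrecision >::addParameter (ITensor *param, ITensor *grad) | inline, override, virtual |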
Register a parameter-gradient pair for optimization.
The optimizer does not take ownership of the parameter/gradient tensors. The caller (typically a Module) must ensure the tensors remain valid for the lifetime of the optimizer.
Allocates momentum and variance state tensors on CPU matching the parameter shape. State tensors are zero-initialized.
| param | Parameter tensor to optimize (non-owning, must be on CPU) |
| grad | Gradient tensor (non-owning, must match param shape and device) |
| std::invalid_argument | if param or grad is null |
| std::invalid_argument | if param and grad shapes don't match |
| std::invalid_argument | if param or grad is not a CPU tensor |
| std::invalid_argument | if param or grad data type doesn't match optimizer precision |
| std::runtime_error | if state allocation fails |
Implements Mila::Dnn::Compute::Optimizer< DeviceType::Cpu, TPrecision >.
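The registration contract described above (null check, shape check, zero-initialized state per parameter) might look roughly like the following. This is a minimal sketch under stated assumptions: `TensorSketch` and `RegistrySketch` are hypothetical stand-ins and not Mila types, the device and data type checks are omitted, and the error messages are invented for illustration.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical stand-in for a CPU tensor; not the Mila ITensor interface.
struct TensorSketch {
    std::vector<std::int64_t> shape;
    std::vector<float> data;
};

// Hypothetical registry mirroring the documented addParameter contract.
class RegistrySketch {
public:
    void addParameter(TensorSketch* param, TensorSketch* grad) {
        if (param == nullptr || grad == nullptr) {
            throw std::invalid_argument("param and grad must be non-null");
        }
        if (param->shape != grad->shape) {
            throw std::invalid_argument("param and grad shapes must match");
        }
        // Allocate zero-initialized moment buffers before touching any member,
        // so a failed allocation leaves the registry unchanged.
        std::vector<float> m(param->data.size(), 0.0f);
        std::vector<float> v(param->data.size(), 0.0f);

        // Non-owning pointers: the caller keeps the tensors alive.
        params_.push_back(param);
        grads_.push_back(grad);
        m_states_.push_back(std::move(m));
        v_states_.push_back(std::move(v));
    }

    std::size_t getParameterCount() const noexcept { return params_.size(); }

    const std::vector<float>& firstMoment(std::size_t i) const { return m_states_.at(i); }

private:
    std::vector<TensorSketch*> params_;
    std::vector<TensorSketch*> grads_;
    std::vector<std::vector<float>> m_states_;
    std::vector<std::vector<float>> v_states_;
};
```

Storing raw, non-owning pointers matches the ownership note above: the registry never frees the tensors, so destroying a tensor before the optimizer leaves a dangling pointer.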

| float Mila::Dnn::Compute::CpuAdamWOptimizer< TPrecision >::getBeta1 () const | inline, noexcept |
Get beta1 parameter.

| float Mila::Dnn::Compute::CpuAdamWOptimizer< TPrecision >::getBeta2 () const | inline, noexcept |
Get beta2 parameter.

| float Mila::Dnn::Compute::CpuAdamWOptimizer< TPrecision >::getEpsilon () const | inline, noexcept |
Get epsilon parameter.

| float Mila::Dnn::Compute::CpuAdamWOptimizer< TPrecision >::getLearningRate () const | inline, override, virtual |
Get current learning rate.
Implements Mila::Dnn::Compute::Optimizer< DeviceType::Cpu, TPrecision >.
| size_t Mila::Dnn::Compute::CpuAdamWOptimizer< TPrecision >::getParameterCount () const | inline, noexcept |
Get number of registered parameter groups.
| size_t Mila::Dnn::Compute::CpuAdamWOptimizer< TPrecision >::getStepCount () const | inline, noexcept |
Get current step count.
| float Mila::Dnn::Compute::CpuAdamWOptimizer< TPrecision >::getWeightDecay () const | inline, noexcept |
Get weight decay parameter.

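| void Mila::Dnn::Compute::CpuAdamWOptimizer< TPrecision >::setLearningRate (float learning_rate) | inline, override, virtual |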
Set learning rate for future steps.
| learning_rate | New learning rate (must be positive) |
| std::invalid_argument | if learning_rate <= 0 |
Implements Mila::Dnn::Compute::Optimizer< DeviceType::Cpu, TPrecision >.
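| static std::string Mila::Dnn::Compute::CpuAdamWOptimizer< TPrecision >::shapeToString (const shape_t &shape) | inline, static, private |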
Convert shape to string for error messages.
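A helper of this kind is typically a short join over the dimensions. The sketch below assumes that `shape_t` behaves like a vector of 64-bit integers and that the output uses a bracketed, comma-separated form. Both are assumptions: `shape_sketch_t` and `shapeToStringSketch` are hypothetical names, and the real output format is not documented here.

```cpp
#include <cstddef>
#include <cstdint>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical stand-in for Mila's shape_t.
using shape_sketch_t = std::vector<std::int64_t>;

// Formats a shape as "[d0, d1, ...]" for use in exception messages.
inline std::string shapeToStringSketch(const shape_sketch_t& shape) {
    std::ostringstream oss;
    oss << '[';
    for (std::size_t i = 0; i < shape.size(); ++i) {
        if (i > 0) {
            oss << ", ";  // separator between dimensions, none before the first
        }
        oss << shape[i];
    }
    oss << ']';
    return oss.str();
}
```

A shape mismatch message could then combine both sides, for example `"param " + shapeToStringSketch(a) + " vs grad " + shapeToStringSketch(b)`.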


| void Mila::Dnn::Compute::CpuAdamWOptimizer< TPrecision >::step () | inline, override, virtual |
Perform one AdamW optimization step.
Updates all registered parameters on CPU using scalar loops. Execution is synchronous.
| std::runtime_error | if no parameters have been registered |
Implements Mila::Dnn::Compute::Optimizer< DeviceType::Cpu, TPrecision >.
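Before looping over the registered parameters, a step has to derive the two bias-correction terms that updateParameter receives. The sketch below shows one way to compute them from the step count. `BiasCorrections` and `biasCorrectionsSketch` are hypothetical names. The sketch assumes the step count is incremented before the corrections are computed, so the first step uses t = 1. With t = 0 both terms would be zero and the later division would be undefined.

```cpp
#include <cmath>
#include <cstddef>

// Hypothetical pair of bias-correction denominators for one step.
struct BiasCorrections {
    float beta1_correction;  // 1 - beta1^t
    float beta2_correction;  // 1 - beta2^t
};

// Computes the corrections for step t (t >= 1 is assumed by the caller).
inline BiasCorrections biasCorrectionsSketch(float beta1, float beta2, std::size_t step_count) {
    const float t = static_cast<float>(step_count);
    return BiasCorrections{
        1.0f - std::pow(beta1, t),
        1.0f - std::pow(beta2, t),
    };
}
```

Computing these once per step, rather than once per element, keeps the inner scalar loop free of `pow` calls.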

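| void Mila::Dnn::Compute::CpuAdamWOptimizer< TPrecision >::updateParameter (HostType *param_data, const HostType *grad_data, float *m_data, float *v_data, size_t num_params, float beta1_correction, float beta2_correction) | inline, private |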
Update a single parameter using AdamW algorithm.
Performs the AdamW update for a single parameter tensor using scalar loops. Implements the complete AdamW algorithm, including the first and second moment updates, bias correction, and weight decay.
| param_data | Parameter data pointer |
| grad_data | Gradient data pointer |
| m_data | First moment state pointer |
| v_data | Second moment state pointer |
| num_params | Number of scalar parameters |
| beta1_correction | Bias correction for first moment (1 - beta1^t) |
| beta2_correction | Bias correction for second moment (1 - beta2^t) |
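The parameter list above maps onto a scalar loop such as the following. This is a sketch of the standard decoupled-weight-decay update, not the library's source. It assumes `HostType` is `float`, passes the hyperparameters through a hypothetical `AdamWHyperParams` struct instead of reading them from the optimizer's configuration, and applies the decay inside the same step as the adaptive term.

```cpp
#include <cmath>
#include <cstddef>

// Hypothetical bundle of the hyperparameters exposed by the getters.
struct AdamWHyperParams {
    float learning_rate;
    float beta1;
    float beta2;
    float epsilon;
    float weight_decay;
};

// Scalar AdamW update over num_params elements (HostType assumed to be float).
inline void adamwUpdateSketch(float* param_data, const float* grad_data,
                              float* m_data, float* v_data, std::size_t num_params,
                              float beta1_correction, float beta2_correction,
                              const AdamWHyperParams& hp) {
    for (std::size_t i = 0; i < num_params; ++i) {
        const float g = grad_data[i];

        // Exponential moving averages of the gradient and the squared gradient.
        m_data[i] = hp.beta1 * m_data[i] + (1.0f - hp.beta1) * g;
        v_data[i] = hp.beta2 * v_data[i] + (1.0f - hp.beta2) * g * g;

        // Bias-corrected moment estimates.
        const float m_hat = m_data[i] / beta1_correction;
        const float v_hat = v_data[i] / beta2_correction;

        // Adaptive step plus decoupled weight decay on the current value.
        param_data[i] -= hp.learning_rate *
            (m_hat / (std::sqrt(v_hat) + hp.epsilon) + hp.weight_decay * param_data[i]);
    }
}
```

On the first step with zero-initialized state, the bias correction makes `m_hat` equal to the gradient and `v_hat` equal to its square, so the adaptive term has magnitude close to 1 and the parameter moves by roughly the learning rate.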


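| AdamWConfig Mila::Dnn::Compute::CpuAdamWOptimizer< TPrecision >::config_ | private |
| IExecutionContext * Mila::Dnn::Compute::CpuAdamWOptimizer< TPrecision >::context_ | private |
| std::vector< const HostType * > Mila::Dnn::Compute::CpuAdamWOptimizer< TPrecision >::grad_data_ | private |
| std::vector< ITensor * > Mila::Dnn::Compute::CpuAdamWOptimizer< TPrecision >::grads_ | private |
| float Mila::Dnn::Compute::CpuAdamWOptimizer< TPrecision >::learning_rate_ | private |
| std::vector< float * > Mila::Dnn::Compute::CpuAdamWOptimizer< TPrecision >::m_data_ | private |
| std::vector< std::shared_ptr< Tensor< TensorDataType::FP32, MR > > > Mila::Dnn::Compute::CpuAdamWOptimizer< TPrecision >::m_states_ | private |
| std::vector< HostType * > Mila::Dnn::Compute::CpuAdamWOptimizer< TPrecision >::param_data_ | private |
| std::vector< ITensor * > Mila::Dnn::Compute::CpuAdamWOptimizer< TPrecision >::params_ | private |
| size_t Mila::Dnn::Compute::CpuAdamWOptimizer< TPrecision >::step_count_ { 0 } | private |
| std::vector< float * > Mila::Dnn::Compute::CpuAdamWOptimizer< TPrecision >::v_data_ | private |
| std::vector< std::shared_ptr< Tensor< TensorDataType::FP32, MR > > > Mila::Dnn::Compute::CpuAdamWOptimizer< TPrecision >::v_states_ | private |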