template< DeviceType TDeviceType, TensorDataType TPrecision >
requires PrecisionSupportedOnDevice< TPrecision, TDeviceType >
class Mila::Dnn::Optimizers::AdamWOptimizer< TDeviceType, TPrecision >
Device-agnostic AdamW optimizer.
Dispatches to the appropriate device-specific implementation (CPU or CUDA) based on the TDeviceType template parameter. Uses AdamWConfig for fluent configuration of hyperparameters.
- Template Parameters
-
Set the learning rate for future updates.
Updates the base learning rate used by the optimizer. Typically used for learning rate schedules (decay, warmup, cyclic, etc.).
- Parameters
-
| learning_rate | New learning rate (must be positive) |
- Exceptions
-
| std::invalid_argument | if learning_rate <= 0 |
- Note
- Takes effect immediately for the next step() call
- Does not affect optimizer state (momentum, variance)
- For learning rate schedules, call this at epoch or iteration boundaries
- See also
- getLearningRate()
Example with learning rate decay:
    float initial_lr = 0.001f;
    optimizer->setLearningRate(initial_lr);
    for (size_t epoch = 0; epoch < num_epochs; ++epoch) {
        // Halve the learning rate every 10 epochs.
        if (epoch > 0 && epoch % 10 == 0) {
            float new_lr = optimizer->getLearningRate() * 0.5f;
            optimizer->setLearningRate(new_lr);
            std::cout << "Learning rate: " << new_lr << std::endl;
        }
    }
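The note above mentions warmup and decay schedules. The following is a minimal sketch of one such schedule (linear warmup followed by cosine decay), written as a free function that only computes the value to pass to setLearningRate(). The function name `warmupCosineLearningRate` and all of its parameters are hypothetical and not part of Mila. The positive floor `min_lr` is an assumption chosen so the result never triggers the `std::invalid_argument` documented for `learning_rate <= 0`.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>

// Hypothetical helper (not part of Mila): linear warmup, then cosine decay.
// The result is always >= min_lr > 0, so it is a valid argument for
// setLearningRate(), which rejects learning_rate <= 0.
inline float warmupCosineLearningRate(
    std::size_t step,          // zero-based iteration index
    std::size_t warmup_steps,  // iterations spent ramping up to base_lr
    std::size_t total_steps,   // total planned iterations
    float base_lr,             // peak learning rate reached at the end of warmup
    float min_lr )             // strictly positive floor
{
    assert( base_lr > 0.0f && min_lr > 0.0f && min_lr <= base_lr );
    assert( total_steps > warmup_steps );

    if ( step < warmup_steps ) {
        // Ramp linearly from base_lr / warmup_steps up to base_lr.
        const float lr = base_lr * static_cast<float>( step + 1 )
                       / static_cast<float>( warmup_steps );
        return std::max( lr, min_lr );
    }

    // Fraction of the decay phase completed, clamped to [0, 1] so that
    // steps beyond total_steps stay at min_lr.
    float progress = static_cast<float>( step - warmup_steps )
                   / static_cast<float>( total_steps - warmup_steps );
    progress = std::min( progress, 1.0f );

    const float pi = 3.14159265358979f;
    const float cosine = 0.5f * ( 1.0f + std::cos( pi * progress ) );
    return min_lr + ( base_lr - min_lr ) * cosine;
}
```

Under these assumptions, a training loop would call `optimizer->setLearningRate( warmupCosineLearningRate( step, warmup_steps, total_steps, base_lr, min_lr ) )` before each step(). Only the learning rate changes; the optimizer's moment estimates are left untouched, as noted above.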
Implements Mila::Dnn::Compute::Optimizer< TDeviceType, TPrecision >.
Perform one optimization step.
Updates all registered parameters using their accumulated gradients according to the optimizer's update rule (SGD, Adam, AdamW, etc.). This is the hot-path method, called on every training iteration.
For algorithms with state (Adam, AdamW):
- Updates first and second moment estimates
- Applies bias correction if needed
- Computes parameter update
- Writes updated parameters back to tensors
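The four steps listed above can be sketched for AdamW on a flat array of floats. This is a CPU-only illustration of the commonly published AdamW rule (Adam with decoupled weight decay), not Mila's actual CPU or CUDA kernel. The names `AdamWHyperParams`, `AdamWState` and `adamwStep`, and the default values shown, are assumptions introduced for this example; the real AdamWConfig fields may differ.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical hyperparameter bundle; not Mila's AdamWConfig.
struct AdamWHyperParams {
    float lr = 1e-3f;
    float beta1 = 0.9f;
    float beta2 = 0.999f;
    float epsilon = 1e-8f;
    float weight_decay = 0.01f;
};

// Hypothetical per-parameter optimizer state.
struct AdamWState {
    std::vector<float> m;   // first moment estimate (momentum)
    std::vector<float> v;   // second moment estimate (variance)
    std::size_t step = 0;   // number of updates applied so far
};

inline void adamwStep( std::vector<float>& params,
                       const std::vector<float>& grads,
                       AdamWState& state,
                       const AdamWHyperParams& hp )
{
    assert( params.size() == grads.size() );

    // Lazily allocate zero-initialized state on the first call.
    if ( state.m.empty() ) {
        state.m.assign( params.size(), 0.0f );
        state.v.assign( params.size(), 0.0f );
    }
    assert( state.m.size() == params.size() && state.v.size() == params.size() );

    // The step counter drives the bias-correction terms.
    ++state.step;
    const float t = static_cast<float>( state.step );
    const float bias1 = 1.0f - std::pow( hp.beta1, t );
    const float bias2 = 1.0f - std::pow( hp.beta2, t );

    for ( std::size_t i = 0; i < params.size(); ++i ) {
        const float g = grads[i];

        // 1. Update first and second moment estimates.
        state.m[i] = hp.beta1 * state.m[i] + ( 1.0f - hp.beta1 ) * g;
        state.v[i] = hp.beta2 * state.v[i] + ( 1.0f - hp.beta2 ) * g * g;

        // 2. Apply bias correction.
        const float m_hat = state.m[i] / bias1;
        const float v_hat = state.v[i] / bias2;

        // 3. Compute the update; weight decay is applied to the parameter
        //    directly (decoupled) rather than being added to the gradient.
        const float update = m_hat / ( std::sqrt( v_hat ) + hp.epsilon )
                           + hp.weight_decay * params[i];

        // 4. Write the updated parameter back.
        params[i] -= hp.lr * update;
    }
}
```

With the defaults above, a single parameter of 1.0 and gradient 0.5 moves to roughly 0.99899 after one call: the bias-corrected Adam term is close to 1 on the first step, and the decay term contributes 0.01.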
- Exceptions
-
| std::runtime_error | if no parameters have been registered |
| std::runtime_error | if gradient data is invalid or null |
- Note
- Gradients should be computed via backward() before calling step()
- For CUDA implementations, may be asynchronous (uses device stream)
- Increments internal step counter for algorithms requiring it (Adam, AdamW)
- See also
- addParameter()
- backward()
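Typical sequence:
    model->zeroGradients();               // clear gradients from the previous iteration
    model->forward(input, output);
    loss = computeLoss(output, target);
    model->backward(input, loss_grad);    // loss_grad: presumably the gradient of the loss with respect to output
    optimizer->step();                    // apply the update using the gradients from backward()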
Implements Mila::Dnn::Compute::Optimizer< TDeviceType, TPrecision >.