|
Mila 0.13.48
Deep Neural Network Library
|
#include <cublasLt.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>

Go to the source code of this file.
Namespaces | |
| namespace | Mila |
| Mila main API namespace. | |
| namespace | Mila::Dnn |
| namespace | Mila::Dnn::Compute |
Functions | |
| template<typename Tp, typename Tg> | |
| void | Mila::Dnn::Compute::adamw_update (Tp *params_memory, float *master_params_memory, Tg *grads_memory, float *m_memory, float *v_memory, size_t num_parameters, ptrdiff_t w_stride, ptrdiff_t g_stride, ptrdiff_t s_stride, int num_slices, float learning_rate, float beta1, float beta2, int t, float eps, float weight_decay, float grad_scale, unsigned int seed, cudaStream_t stream) |