Mila 0.13.48 — Deep Neural Network Library

CudaOptimizers.h

Go to the documentation of this file.
1#pragma once
2
3#include <cublasLt.h>
4#include <cuda_runtime.h>
5#include <cuda_fp16.h>
6
7namespace Mila::Dnn::Compute
8{
9 template <typename Tp, typename Tg>
11 Tp* params_memory,
12 float* master_params_memory,
13 Tg* grads_memory,
14 float* m_memory,
15 float* v_memory,
16 size_t num_parameters,
17 ptrdiff_t w_stride,
18 ptrdiff_t g_stride,
19 ptrdiff_t s_stride,
20 int num_slices,
21 float learning_rate,
22 float beta1,
23 float beta2,
24 int t,
25 float eps,
26 float weight_decay,
27 float grad_scale,
28 unsigned int seed,
29 cudaStream_t stream );
30}
Definition Device.ixx:15
void adamw_update(Tp *params_memory, float *master_params_memory, Tg *grads_memory, float *m_memory, float *v_memory, size_t num_parameters, ptrdiff_t w_stride, ptrdiff_t g_stride, ptrdiff_t s_stride, int num_slices, float learning_rate, float beta1, float beta2, int t, float eps, float weight_decay, float grad_scale, unsigned int seed, cudaStream_t stream)