25#include <cuda_runtime.h>
35 float* qkvr,
float* att,
37 int B,
int T,
int C,
int NH,
38 cudaStream_t stream );
42 half* qkvr, half* att,
44 int B,
int T,
int C,
int NH,
45 cudaStream_t stream );
48 template <
typename TPrecision>
57 cudaStream_t stream );
59 template <
typename TPrecision>
62 const TPrecision* dY_loss,
68 cudaStream_t stream );
void cuda_softmax_crossentropy_backward(TPrecision *dX, const TPrecision *dY_loss, const TPrecision *Y, const int *targets, int batch_size, int seq_len, int vocab_size, cudaStream_t stream)
void cuda_mha_forward_fp32(float *Y, float *qkvr, float *att, const float *X, int B, int T, int C, int NH, cudaStream_t stream)
void cuda_mha_forward_fp16(half *Y, half *qkvr, half *att, const half *X, int B, int T, int C, int NH, cudaStream_t stream)
void cuda_softmax_crossentropy_forward(TPrecision *Y_loss, TPrecision *Y, const TPrecision *X, const int *targets, int batch_size, int seq_len, int vocab_size, cudaStream_t stream)