8 #include <cuda_runtime.h>
30 const float* __restrict__ src,
31 float* __restrict__ out_a,
32 float* __restrict__ out_b,
33 float* __restrict__ out_c,
35 int dim_a, int dim_b, int dim_c,
36 cudaStream_t stream );
39 const __nv_bfloat16* __restrict__ src,
40 __nv_bfloat16* __restrict__ out0,
41 __nv_bfloat16* __restrict__ out1,
42 __nv_bfloat16* __restrict__ out2,
44 int D0, int D1, int D2,
45 cudaStream_t stream );
Definition CublasLt.Utils.ixx:15
void cuda_split3_bf16(const __nv_bfloat16 *__restrict__ src, __nv_bfloat16 *__restrict__ out0, __nv_bfloat16 *__restrict__ out1, __nv_bfloat16 *__restrict__ out2, int rows, int D0, int D1, int D2, cudaStream_t stream)
void cuda_split3_fp32(const float *__restrict__ src, float *__restrict__ out_a, float *__restrict__ out_b, float *__restrict__ out_c, int src_rows, int dim_a, int dim_b, int dim_c, cudaStream_t stream)
cuda_split3_fp32: vectorized 3-way split along the last dimension (float32).