Mila 0.13.48
Deep Neural Network Library
Mila::Dnn::Compute::Cuda::StructuralOps Struct Reference [export]

Static Public Member Functions

template<TensorDataType TDataType, typename TMemoryResource>
requires isValidTensor<TDataType, TMemoryResource>
static void split (const Dnn::Tensor< TDataType, TMemoryResource > &input, Dnn::Tensor< TDataType, TMemoryResource > &out0, Dnn::Tensor< TDataType, TMemoryResource > &out1, Dnn::Tensor< TDataType, TMemoryResource > &out2, IExecutionContext *exec_context=nullptr)
 Split a tensor along its last dimension into three contiguous output tensors.

Member Function Documentation

◆ split()

template<TensorDataType TDataType, typename TMemoryResource>
requires isValidTensor<TDataType, TMemoryResource>
void Mila::Dnn::Compute::Cuda::StructuralOps::split ( const Dnn::Tensor< TDataType, TMemoryResource > & input,
Dnn::Tensor< TDataType, TMemoryResource > & out0,
Dnn::Tensor< TDataType, TMemoryResource > & out1,
Dnn::Tensor< TDataType, TMemoryResource > & out2,
IExecutionContext * exec_context = nullptr )
[inline, static]

Split a tensor along its last dimension into three contiguous output tensors.

Input shape: [B, T, D0+D1+D2]
Output shapes: out0[B, T, D0], out1[B, T, D1], out2[B, T, D2]
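
The shape relationship above can be illustrated with a host-side reference. This is only a sketch of the expected result, not the Mila CUDA implementation: `SplitResult` and `split_last_dim_reference` are hypothetical names, the buffers are plain `std::vector<float>`, and a row-major layout with the last dimension contiguous is assumed.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference, not part of Mila. Holds the three outputs,
// each laid out row-major as [B, T, Di].
struct SplitResult {
    std::vector<float> out0, out1, out2;
};

// Splits a row-major [B, T, D0+D1+D2] buffer along its last dimension.
// Assumes the last dimension is contiguous in memory.
SplitResult split_last_dim_reference(const std::vector<float>& input,
                                     std::size_t B, std::size_t T,
                                     std::size_t D0, std::size_t D1,
                                     std::size_t D2) {
    const std::size_t D = D0 + D1 + D2;
    const std::size_t rows = B * T;  // each (b, t) pair is one row of length D
    assert(input.size() == rows * D);

    SplitResult r;
    r.out0.reserve(rows * D0);
    r.out1.reserve(rows * D1);
    r.out2.reserve(rows * D2);

    for (std::size_t row = 0; row < rows; ++row) {
        const float* src = input.data() + row * D;
        // First D0 elements go to out0, the next D1 to out1, the rest to out2.
        r.out0.insert(r.out0.end(), src, src + D0);
        r.out1.insert(r.out1.end(), src + D0, src + D0 + D1);
        r.out2.insert(r.out2.end(), src + D0 + D1, src + D);
    }
    return r;
}
```

A reference like this can be useful for comparing against device results in a unit test; it does not enforce the multiple-of-4 precondition that the vectorized path requires.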

Preconditions:

  • All tensors must be on the same CUDA device.
  • Input last dim must equal sum of output last dims.
  • D0, D1, D2 must be multiples of 4 (float4 vectorization).
  • Input and outputs must be rank-3 tensors.
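
The shape-related preconditions above could be checked on the host before calling split(). The following is a minimal sketch under stated assumptions: `Shape3`, `check_split_shapes`, and `debug_require_split_shapes` are hypothetical helpers, not Mila API. The rank-3 requirement is expressed through the fixed-size array type, the check that B and T match is inferred from the documented shapes, and the same-device requirement is not covered here.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <string>

// Hypothetical stand-in for a rank-3 tensor shape [B, T, D]; not a Mila type.
using Shape3 = std::array<std::size_t, 3>;

// Returns an empty string when the shapes satisfy the documented shape
// preconditions, otherwise a short description of the first violation found.
std::string check_split_shapes(const Shape3& in, const Shape3& out0,
                               const Shape3& out1, const Shape3& out2) {
    const Shape3* outs[] = {&out0, &out1, &out2};
    std::size_t sum = 0;
    for (const Shape3* o : outs) {
        if ((*o)[0] != in[0] || (*o)[1] != in[1])
            return "output B/T dimensions differ from input";
        if ((*o)[2] % 4 != 0)
            return "output last dim is not a multiple of 4";
        sum += (*o)[2];
    }
    if (sum != in[2])
        return "input last dim != D0 + D1 + D2";
    return {};
}

// Debug-build guard a caller might place before invoking split().
inline void debug_require_split_shapes(const Shape3& in, const Shape3& out0,
                                       const Shape3& out1, const Shape3& out2) {
    assert(check_split_shapes(in, out0, out1, out2).empty());
}
```

Returning a message rather than asserting directly lets the caller decide whether a violation is a debug-only failure or a reportable error.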

Template Parameters

  TDataType        Tensor element type; drives vectorization width.
  TMemoryResource  Memory resource type backing the tensors.

The documentation for this struct was generated from the following file: