|
| CudaEncoderOp (const EncoderConfig &config) |
| Constructs a new CUDA Encoder operation with the default device context.
|
|
| CudaEncoderOp (std::shared_ptr< DeviceContext > context, const EncoderConfig &config) |
| Constructs a new CUDA Encoder operation with a specific device context.
|
|
void | backward (const Tensor< int, MR > &input, const Tensor< TOutput, MR > &output, const Tensor< TOutput, MR > &output_gradient, const std::vector< std::shared_ptr< Tensor< TOutput, MR > > > &parameters, std::vector< std::shared_ptr< Tensor< TOutput, MR > > > &parameter_gradients, Tensor< int, MR > &input_gradient, const OperationAttributes &properties, const std::vector< std::shared_ptr< Tensor< TOutput, MR > > > &output_state) const |
| Performs the backward pass of the Encoder operation.
|
|
void | forward (const Tensor< int, MR > &input, const std::vector< std::shared_ptr< Tensor< TOutput, MR > > > &parameters, const OperationAttributes &properties, Tensor< TOutput, MR > &output, std::vector< std::shared_ptr< Tensor< TOutput, MR > > > &output_state) const override |
| Performs the forward pass of the Encoder operation on CUDA.
|
|
std::string | getName () const override |
| Gets the name of this operation.
|
|
| UnaryOperation (OperationType operation_type) |
| Constructs a UnaryOperation with the specified operation type.
|
|
| UnaryOperation (OperationType operation_type, std::shared_ptr< DeviceContext > context) |
| Constructs a UnaryOperation with the specified operation type and device context.
|
|
virtual | ~UnaryOperation ()=default |
| Virtual destructor for proper cleanup of derived classes.
|
|
virtual void | backward (const Tensor< TInput, MR > &grad, const std::vector< std::shared_ptr< Tensor< TOutput, MR > > > &parameters, std::vector< std::shared_ptr< Tensor< TOutput, MR > > > &output_grads) const |
| Executes the backward pass of a unary operation.
|
|
virtual void | backward (const Tensor< TInput, MR > &input, const Tensor< TOutput, MR > &output_grad, const std::vector< std::shared_ptr< Tensor< TOutput, MR > > > &parameters, std::vector< std::shared_ptr< Tensor< TOutput, MR > > > &parameter_grads, Tensor< TInput, MR > &input_grad, const OperationAttributes &properties, const std::vector< std::shared_ptr< Tensor< TOutput, MR > > > &output_state) const |
| Executes the comprehensive backward pass of a unary operation.
|
|
virtual void | forward (const Tensor< TInput, MR > &input, const std::vector< std::shared_ptr< Tensor< TOutput, MR > > > &parameters, const OperationAttributes &properties, Tensor< TOutput, MR > &output, std::vector< std::shared_ptr< Tensor< TOutput, MR > > > &output_state) const =0 |
| Executes the forward pass of a unary operation.
|
|
| OperationBase (OperationType operation_type, std::shared_ptr< DeviceContext > context) |
| Constructs an OperationBase object with a specific operation type and device context.
|
|
virtual | ~OperationBase ()=default |
| Virtual destructor for the OperationBase class.
|
|
std::shared_ptr< DeviceContext > | getDeviceContext () const |
| Gets the device context associated with this operation.
|
|
DeviceType | getDeviceType () const |
| Gets the device type for this operation.
|
|
OperationType | getOperationType () const |
| Gets the operation type enumeration value.
|
|
template<typename TInput, typename TOutput = TInput>
requires ValidFloatTensorTypes<TInput,TOutput>
class Mila::Dnn::Compute::CudaEncoderOp< TInput, TOutput >
CUDA implementation of the Encoder operation for transformer models.
This class provides a CUDA-based implementation of the Encoder operation, which performs token embedding lookups and positional embedding additions. It transforms discrete token IDs into continuous vector representations by combining:
- Token embeddings from a learned vocabulary table (wte)
- Positional embeddings that encode sequence position information (wpe)
The implementation is optimized for NVIDIA GPUs using CUDA for high-performance computation, supporting both integer and half-precision floating-point operations.
- Template Parameters
-
TInput | The data type of the input tensor elements (typically uint16_t or int for token IDs). |
TOutput | The data type used for computation and output (typically half or float). |