Mila 0.13.48
Deep Neural Network Library
Loading...
Searching...
No Matches
Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TInput, TPrecision > Class Template Reference [export]

CUDA implementation of the Lpe (token + positional embedding) operation. More...

Inheritance diagram for Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TInput, TPrecision >:
Collaboration diagram for Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TInput, TPrecision >:

Public Types

using ConfigType = LpeConfig
using CudaExecutionContext = ExecutionContext<DeviceType::Cuda>
using MR = CudaDeviceMemoryResource
using NativeType = typename Mila::Dnn::Compute::Cuda::TensorDataTypeMap<TPrecision>::device_type
using TensorType = Tensor<TPrecision, MR>
using UnaryOperationBase = UnaryOperation<DeviceType::Cuda, TInput, TPrecision>
Public Types inherited from Mila::Dnn::Compute::UnaryOperation< DeviceType::Cuda, TInput, TInput >
using MR
using TensorInputType
using TensorOutputType
Public Types inherited from Mila::Dnn::Compute::Operation< TDeviceType, TPrecision >
using DataTypeTraits

Public Member Functions

 CudaLpeOp (IExecutionContext *context, const LpeConfig &config)
void backward (const ITensor &input, const ITensor &output_grad, ITensor &input_grad) const override
 Backward pass accumulating gradients into wte and wpe (hot path).
void build (const BuildContext &config) override
 Prepare the operation for a concrete input shape (cold path).
void decode (const ITensor &input, ITensor &output, int position) override
 Chunked prefill with explicit position offset.
void forward (const ITensor &input, ITensor &output) const override
 Full-sequence forward pass (hot path).
std::string getName () const override
 Human-readable operation name.
OperationType getOperationType () const override
 Operation type identifier.
void setGradients (ITensor *wte_grad, ITensor *wpe_grad) override
 Bind wte and wpe gradient tensors for training (module retains ownership).
void setParameters (ITensor *wte, ITensor *wpe) override
 Bind wte and wpe parameter tensors (module retains ownership).
Public Member Functions inherited from Mila::Dnn::Compute::UnaryOperation< DeviceType::Cuda, TInput, TInput >
virtual ~UnaryOperation ()=default
Public Member Functions inherited from Mila::Dnn::Compute::Operation< TDeviceType, TPrecision >
virtual ~Operation ()=default
virtual void clearGradients () noexcept
 Clear any cached gradient pointers held by the operation.
virtual TensorDataType getDataType () const
 Tensor data type for this operation.
virtual DeviceType getDeviceType () const
 Device type for this operation.
virtual std::size_t getStateMemorySize () const
 Returns the number of bytes of state memory allocated by this operation.
virtual bool isBuilt () const
 Whether build() completed successfully for a concrete input shape.
virtual bool isEvalMode () const
 Query whether the operation is configured for evaluation mode (i.e. not training).
virtual void setTrainingMode (TrainingMode training_mode)
 Configure operation training-mode behavior.
Public Member Functions inherited from Mila::Dnn::Compute::IPositionalDecode
virtual ~IPositionalDecode ()=default

Private Member Functions

void validateInputShape (const shape_t &input_shape) const

Private Attributes

int batch_size_ { 0 }
LpeConfig config_
CudaExecutionContext * context_
int embedding_dim_ { 0 }
int seq_length_ { 0 }
NativeType * wpe_ { nullptr }
int wpe_embedding_dim_ { 0 }
NativeType * wpe_grad_ { nullptr }
int wpe_max_seq_len_ { 0 }
NativeType * wte_ { nullptr }
int wte_embedding_dim_ { 0 }
NativeType * wte_grad_ { nullptr }
int wte_vocab_size_ { 0 }

Additional Inherited Members

Static Public Attributes inherited from Mila::Dnn::Compute::Operation< TDeviceType, TPrecision >
static constexpr TensorDataType data_type
static constexpr DeviceType device_type
Static Protected Member Functions inherited from Mila::Dnn::Compute::UnaryOperation< DeviceType::Cuda, TInput, TInput >
static const TensorInputTypeasInputTensor (const ITensor &t)
static TensorOutputTypeasOutputTensor (ITensor &t)
Protected Attributes inherited from Mila::Dnn::Compute::Operation< TDeviceType, TPrecision >
bool is_built_
TrainingMode training_mode_

Detailed Description

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
requires PrecisionSupportedOnDevice<TPrecision, DeviceType::Cuda>
class Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TInput, TPrecision >

CUDA implementation of the Lpe (token + positional embedding) operation.

Combines a token embedding lookup (wte) with a positional embedding lookup (wpe) on CUDA devices, supporting FP32 and FP16 precision.

Design:

  • Two-phase initialization: build() performs all setup once; forward(), backward(), and decode() are pure hot-path dispatch with no per-call overhead.
  • Parameters (wte, wpe) are bound via setParameters() before build() and owned by the calling Lpe component.
  • Token indices (INT32) are non-differentiable; no input gradient is produced.
  • Implements IPositionalDecode so the owning Lpe component can call decode() with the correct absolute sequence position during KV-cache autoregressive generation. Calling forward() with T=1 instead would always add wpe[0], because forward() indexes positions starting from 0.
Template Parameters
TInput — Data type of the token index input (typically INT32).
TPrecision — Precision of the embedding output (FP32 or FP16).

Member Typedef Documentation

◆ ConfigType

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
using Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TInput, TPrecision >::ConfigType = LpeConfig

◆ CudaExecutionContext

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
using Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TInput, TPrecision >::CudaExecutionContext = ExecutionContext<DeviceType::Cuda>

◆ MR

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
using Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TInput, TPrecision >::MR = CudaDeviceMemoryResource

◆ NativeType

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
using Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TInput, TPrecision >::NativeType = typename Mila::Dnn::Compute::Cuda::TensorDataTypeMap<TPrecision>::device_type

◆ TensorType

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
using Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TInput, TPrecision >::TensorType = Tensor<TPrecision, MR>

◆ UnaryOperationBase

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
using Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TInput, TPrecision >::UnaryOperationBase = UnaryOperation<DeviceType::Cuda, TInput, TPrecision>

Constructor & Destructor Documentation

◆ CudaLpeOp()

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TInput, TPrecision >::CudaLpeOp ( IExecutionContext * context,
const LpeConfig & config )
inline

Member Function Documentation

◆ backward()

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
void Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TInput, TPrecision >::backward ( const ITensor & input,
const ITensor & output_grad,
ITensor & input_grad ) const
inline override virtual

Backward pass accumulating gradients into wte and wpe (hot path).

Token indices are non-differentiable; input_grad is unused.

Parameters
input — Token indices used in forward [B, T] (INT32).
output_grad — Upstream embedding gradient [B, T, C].
input_grad — Unused (non-differentiable input).

Implements Mila::Dnn::Compute::UnaryOperation< DeviceType::Cuda, TInput, TInput >.

◆ build()

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
void Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TInput, TPrecision >::build ( const BuildContext & config)
inline override virtual

Prepare the operation for a concrete input shape (cold path).

Validates parameters, caches B, T, and C for hot-path dispatch, and verifies that the sequence length fits within the positional embedding table. Must be called after setParameters() and before forward(), backward(), or decode().

Parameters
config — Build context for the concrete input; the token index input shape is [B, T].
Exceptions
std::runtime_error — if parameters are not bound.
std::invalid_argument — if the input shape is invalid or the sequence length exceeds the positional embedding capacity.

Reimplemented from Mila::Dnn::Compute::Operation< TDeviceType, TPrecision >.

◆ decode()

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
void Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TInput, TPrecision >::decode ( const ITensor & input,
ITensor & output,
int position )
inline override virtual

Chunked prefill with explicit position offset.

Computes output[b,t,:] = wte[X[b,t],:] + wpe[position + t,:] by shifting the wpe base pointer to row position before calling the standard forward kernel. No dedicated prefill kernel is needed.

Parameters
input — Token indices [B, T] (INT32).
output — Pre-allocated embeddings [B, T, C].
position — Absolute position of the first token in this chunk.

Single-token decode with an explicit sequence position (hot path).

Computes output[b,0,:] = wte[X[b,0],:] + wpe[position,:] for each batch element. The dispatch implementation shifts the wpe pointer to row position and calls the forward kernel with T=1, so no dedicated decode kernel is required.

Parameters
input — Single-token indices [B, 1] (INT32).
output — Pre-allocated output buffer [B, 1, C].
position — Zero-based absolute sequence position for the wpe lookup.

Implements Mila::Dnn::Compute::IPositionalDecode.

◆ forward()

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
void Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TInput, TPrecision >::forward ( const ITensor & input,
ITensor & output ) const
inline override virtual

Full-sequence forward pass (hot path).

For each (b, t): output[b,t,:] = wte[X[b,t],:] + wpe[t,:].

Parameters
input — Token indices [B, T] (INT32).
output — Pre-allocated embeddings [B, T, C].
Exceptions
std::runtime_error — if the input shape exceeds the built maximum.

Implements Mila::Dnn::Compute::UnaryOperation< DeviceType::Cuda, TInput, TInput >.

◆ getName()

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
std::string Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TInput, TPrecision >::getName ( ) const
inline override virtual

Human-readable operation name.

Implements Mila::Dnn::Compute::Operation< TDeviceType, TPrecision >.

◆ getOperationType()

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
OperationType Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TInput, TPrecision >::getOperationType ( ) const
inline override virtual

Operation type identifier.

◆ setGradients()

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
void Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TInput, TPrecision >::setGradients ( ITensor * wte_grad,
ITensor * wpe_grad )
inline override virtual

Bind wte and wpe gradient tensors for training (module retains ownership).

Parameters
wte_grad — Gradient buffer for wte; CUDA tensor of shape [vocab_size, C].
wpe_grad — Gradient buffer for wpe; CUDA tensor of shape [max_seq_len, C].
Exceptions
std::invalid_argument — on null or non-CUDA tensors.

Reimplemented from Mila::Dnn::Compute::Operation< TDeviceType, TPrecision >.

◆ setParameters()

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
void Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TInput, TPrecision >::setParameters ( ITensor * wte,
ITensor * wpe )
inline override virtual

Bind wte and wpe parameter tensors (module retains ownership).

Caches native device pointers and validates tensor shapes against the configuration. Must be called before build().

Parameters
wte — Token embedding table; CUDA tensor of shape [vocab_size, C].
wpe — Positional embedding table; CUDA tensor of shape [max_seq_len, C].
Exceptions
std::invalid_argument — on null, non-CUDA, or shape-mismatched tensors.

Reimplemented from Mila::Dnn::Compute::Operation< TDeviceType, TPrecision >.

◆ validateInputShape()

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
void Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TInput, TPrecision >::validateInputShape ( const shape_t & input_shape) const
inline private
Here is the caller graph for this function:

Member Data Documentation

◆ batch_size_

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
int Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TInput, TPrecision >::batch_size_ { 0 }
private

◆ config_

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
LpeConfig Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TInput, TPrecision >::config_
private

◆ context_

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
CudaExecutionContext* Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TInput, TPrecision >::context_
private

◆ embedding_dim_

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
int Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TInput, TPrecision >::embedding_dim_ { 0 }
private

◆ seq_length_

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
int Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TInput, TPrecision >::seq_length_ { 0 }
private

◆ wpe_

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
NativeType* Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TInput, TPrecision >::wpe_ { nullptr }
private

◆ wpe_embedding_dim_

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
int Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TInput, TPrecision >::wpe_embedding_dim_ { 0 }
private

◆ wpe_grad_

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
NativeType* Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TInput, TPrecision >::wpe_grad_ { nullptr }
private

◆ wpe_max_seq_len_

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
int Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TInput, TPrecision >::wpe_max_seq_len_ { 0 }
private

◆ wte_

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
NativeType* Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TInput, TPrecision >::wte_ { nullptr }
private

◆ wte_embedding_dim_

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
int Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TInput, TPrecision >::wte_embedding_dim_ { 0 }
private

◆ wte_grad_

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
NativeType* Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TInput, TPrecision >::wte_grad_ { nullptr }
private

◆ wte_vocab_size_

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
int Mila::Dnn::Compute::Cuda::Lpe::CudaLpeOp< TInput, TPrecision >::wte_vocab_size_ { 0 }
private

The documentation for this class was generated from the following file:
  • /__w/Mila/Mila/Mila/Src/Dnn/Compute/Devices/Cuda/Operations/Encodings/Lpe/CudaLpeOp.ixx