Mila 0.13.48
Deep Neural Network Library
Loading...
Searching...
No Matches
Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TInput, TPrecision > Class Template Reference [export]
Inheritance diagram for Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TInput, TPrecision >:
Collaboration diagram for Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TInput, TPrecision >:

Public Types

using ConfigType = TokenEmbeddingConfig
using CudaExecutionContext = ExecutionContext<DeviceType::Cuda>
using MR = CudaDeviceMemoryResource
using NativeType = typename Mila::Dnn::Compute::Cuda::TensorDataTypeMap<TPrecision>::device_type
using TensorType = Tensor<TPrecision, MR>
using UnaryOperationBase = UnaryOperation<DeviceType::Cuda, TInput, TPrecision>
Public Types inherited from Mila::Dnn::Compute::UnaryOperation< DeviceType::Cuda, TInput, TInput >
using MR
using TensorInputType
using TensorOutputType
Public Types inherited from Mila::Dnn::Compute::Operation< TDeviceType, TPrecision >
using DataTypeTraits

Public Member Functions

 CudaTokenEmbeddingOp (IExecutionContext *context, const TokenEmbeddingConfig &config)
void backward (const ITensor &input, const ITensor &output_grad, ITensor &input_grad) const override
 Backward pass accumulating gradients into wte (hot path).
void build (const BuildContext &config) override
 Prepare the operation for a concrete input shape (cold path).
void decode (const ITensor &input, ITensor &output) const
 Single-token decode pass (hot path).
void forward (const ITensor &input, ITensor &output) const override
 Full-sequence forward pass (hot path).
std::string getName () const override
 Human-readable operation name.
OperationType getOperationType () const override
 Operation type identifier.
void setGradients (ITensor *wte_grad, ITensor *) override
 Bind the wte gradient tensor for training (module retains ownership).
void setParameters (ITensor *wte, ITensor *) override
 Bind the wte parameter tensor (module retains ownership).
Public Member Functions inherited from Mila::Dnn::Compute::UnaryOperation< DeviceType::Cuda, TInput, TInput >
virtual ~UnaryOperation ()=default
Public Member Functions inherited from Mila::Dnn::Compute::Operation< TDeviceType, TPrecision >
virtual ~Operation ()=default
virtual void clearGradients () noexcept
 Clear any cached gradient pointers held by the operation.
virtual TensorDataType getDataType () const
 Tensor data type for this operation.
virtual DeviceType getDeviceType () const
 Device type for this operation.
virtual std::size_t getStateMemorySize () const
 Returns the number of bytes of state memory allocated by this operation.
virtual bool isBuilt () const
 Whether build() completed successfully for a concrete input shape.
virtual bool isEvalMode () const
 Query whether operation is configured for training.
virtual void setTrainingMode (TrainingMode training_mode)
 Configure operation training-mode behavior.

Private Member Functions

void validateInputShape (const shape_t &shape) const
void validateRuntimeShape (int B, int T) const

Private Attributes

int batch_size_ { 0 }
TokenEmbeddingConfig config_
CudaExecutionContext * context_
int embedding_dim_ { 0 }
int seq_length_ { 0 }
int vocab_size_ { 0 }
NativeType * wte_ { nullptr }
NativeType * wte_grad_ { nullptr }

Additional Inherited Members

Static Public Attributes inherited from Mila::Dnn::Compute::Operation< TDeviceType, TPrecision >
static constexpr TensorDataType data_type
static constexpr DeviceType device_type
Static Protected Member Functions inherited from Mila::Dnn::Compute::UnaryOperation< DeviceType::Cuda, TInput, TInput >
static const TensorInputTypeasInputTensor (const ITensor &t)
static TensorOutputTypeasOutputTensor (ITensor &t)
Protected Attributes inherited from Mila::Dnn::Compute::Operation< TDeviceType, TPrecision >
bool is_built_
TrainingMode training_mode_

Member Typedef Documentation

◆ ConfigType

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
using Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TInput, TPrecision >::ConfigType = TokenEmbeddingConfig

◆ CudaExecutionContext

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
using Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TInput, TPrecision >::CudaExecutionContext = ExecutionContext<DeviceType::Cuda>

◆ MR

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
using Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TInput, TPrecision >::MR = CudaDeviceMemoryResource

◆ NativeType

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
using Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TInput, TPrecision >::NativeType = typename Mila::Dnn::Compute::Cuda::TensorDataTypeMap<TPrecision>::device_type

◆ TensorType

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
using Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TInput, TPrecision >::TensorType = Tensor<TPrecision, MR>

◆ UnaryOperationBase

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
using Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TInput, TPrecision >::UnaryOperationBase = UnaryOperation<DeviceType::Cuda, TInput, TPrecision>

Constructor & Destructor Documentation

◆ CudaTokenEmbeddingOp()

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TInput, TPrecision >::CudaTokenEmbeddingOp ( IExecutionContext * context,
const TokenEmbeddingConfig & config )
inline

Member Function Documentation

◆ backward()

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
void Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TInput, TPrecision >::backward ( const ITensor & input,
const ITensor & output_grad,
ITensor & input_grad ) const
inline override virtual

Backward pass accumulating gradients into wte (hot path).

Token indices are non-differentiable; input_grad is unused.

Parameters
input: Token indices used in forward [B, T] (INT32).
output_grad: Upstream embedding gradient [B, T, C].
input_grad: Unused (non-differentiable input).

Implements Mila::Dnn::Compute::UnaryOperation< DeviceType::Cuda, TInput, TInput >.

◆ build()

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
void Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TInput, TPrecision >::build ( const BuildContext & config)
inline override virtual

Prepare the operation for a concrete input shape (cold path).

Parameters
input_shape: Token index input shape [B, T].
Exceptions
std::runtime_error: if wte is not bound.
std::invalid_argument: if input shape is invalid.

Reimplemented from Mila::Dnn::Compute::Operation< TDeviceType, TPrecision >.

◆ decode()

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
void Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TInput, TPrecision >::decode ( const ITensor & input,
ITensor & output ) const
inline

Single-token decode pass (hot path).

Computes output[b,:] = wte[X[b,0],:] for each batch element. No position argument — positional encoding is handled downstream.

Parameters
input: Single-token indices [B, 1] (INT32).
output: Pre-allocated output buffer [B, C].

◆ forward()

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
void Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TInput, TPrecision >::forward ( const ITensor & input,
ITensor & output ) const
inline override virtual

Full-sequence forward pass (hot path).

For each (b, t): output[b,t,:] = wte[X[b,t],:].

Parameters
input: Token indices [B, T] (INT32).
output: Pre-allocated embeddings [B, T, C].

Implements Mila::Dnn::Compute::UnaryOperation< DeviceType::Cuda, TInput, TInput >.

◆ getName()

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
std::string Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TInput, TPrecision >::getName ( ) const
inline override virtual

Human-readable operation name.

Implements Mila::Dnn::Compute::Operation< TDeviceType, TPrecision >.

◆ getOperationType()

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
OperationType Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TInput, TPrecision >::getOperationType ( ) const
inline override virtual

◆ setGradients()

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
void Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TInput, TPrecision >::setGradients ( ITensor * wte_grad,
ITensor *  )
inline override virtual

Bind the wte gradient tensor for training (module retains ownership).

Parameters
wte_grad: Gradient buffer for wte — CUDA tensor of shape [vocab_size, C].
Exceptions
std::invalid_argument: on null or non-CUDA tensor.

Reimplemented from Mila::Dnn::Compute::Operation< TDeviceType, TPrecision >.

◆ setParameters()

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
void Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TInput, TPrecision >::setParameters ( ITensor * wte,
ITensor *  )
inline override virtual

Bind the wte parameter tensor (module retains ownership).

Parameters
wte: Token embedding table — CUDA tensor of shape [vocab_size, C].
Exceptions
std::invalid_argument: on null, non-CUDA, or shape-mismatched tensor.

Reimplemented from Mila::Dnn::Compute::Operation< TDeviceType, TPrecision >.

◆ validateInputShape()

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
void Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TInput, TPrecision >::validateInputShape ( const shape_t & shape) const
inline private
Here is the caller graph for this function:

◆ validateRuntimeShape()

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
void Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TInput, TPrecision >::validateRuntimeShape ( int B,
int T ) const
inline private
Here is the caller graph for this function:

Member Data Documentation

◆ batch_size_

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
int Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TInput, TPrecision >::batch_size_ { 0 }
private

◆ config_

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
TokenEmbeddingConfig Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TInput, TPrecision >::config_
private

◆ context_

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
CudaExecutionContext* Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TInput, TPrecision >::context_
private

◆ embedding_dim_

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
int Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TInput, TPrecision >::embedding_dim_ { 0 }
private

◆ seq_length_

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
int Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TInput, TPrecision >::seq_length_ { 0 }
private

◆ vocab_size_

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
int Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TInput, TPrecision >::vocab_size_ { 0 }
private

◆ wte_

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
NativeType* Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TInput, TPrecision >::wte_ { nullptr }
private

◆ wte_grad_

template<TensorDataType TInput, TensorDataType TPrecision = TInput>
NativeType* Mila::Dnn::Compute::Cuda::TokenEmbedding::CudaTokenEmbeddingOp< TInput, TPrecision >::wte_grad_ { nullptr }
private

The documentation for this class was generated from the following file: