Mila 0.13.48
Deep Neural Network Library
Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision > Class Template Reference [export]

Encoder module for token and positional embeddings (device-templated).


Public Types

using ComponentBase = Component<TDeviceType, TPrecision>
using EmbeddingsTensorType = Tensor<TPrecision, MR>
using MR = typename DeviceTypeTraits<TDeviceType>::memory_resource
using TokenIndexType = Tensor<TIndex, MR>

Public Member Functions

 Lpe (const std::string &name, const LpeConfig &config, std::optional< DeviceId > device_id=std::nullopt)
 Construct Encoder component.
 ~Lpe () override=default
TokenIndexType & backward (const TokenIndexType &input, const EmbeddingsTensorType &output_grad)
 Backward pass - compute parameter gradients and return the component-owned input-gradient tensor.
EmbeddingsTensorType & decode (const TokenIndexType &input, int position)
 Decode pass - single token embedding at a specific sequence position.
EmbeddingsTensorType & forward (const TokenIndexType &input)
 Forward pass - returns component-owned embeddings tensor.
DeviceId getDeviceId () const override
 Get the compute device id associated with this component.
int64_t getEmbeddingDim () const noexcept
std::vector< ITensor * > getGradients () const override
 Return non-owning pointers to parameter gradient tensors.
int64_t getMaxSequenceLength () const noexcept
MemoryStats getMemoryStats () const override
 Return the current memory allocation breakdown for this component.
std::vector< ITensor * > getParameters () const override
 Return non-owning pointers to parameter tensors.
const ComponentType getType () const override
 Get the component type identifier.
int64_t getVocabularyLength () const noexcept
EmbeddingsTensorType * getWpeGrad () const noexcept
EmbeddingsTensorType * getWteGrad () const noexcept
void loadParameter (const std::string &name, const ITensorBlob &blob) override
 Load a parameter from serialized tensor data.
size_t parameterCount () const override
 Return number of trainable parameters.
void save_ (ModelArchive &archive, SerializationMode mode) const override
void synchronize () override
 Wait for outstanding device work submitted by this component.
std::string toString () const override
 Produce a short, human-readable description of the component.
void zeroGradients () override
 Clear all model-owned gradients for this component.
Public Member Functions inherited from Mila::Dnn::Component< TDeviceType, dtype_t::FP32 >
 Component (const std::string &name)
 Construct component with required name identifier.
virtual ~Component ()=default
virtual void build (const BuildContext &context) final
 Build the component with the provided BuildContext (canonical overload).
const std::string getName () const
 Get the component's name identifier.
virtual std::vector< std::string > getParameterNames () const
 List all available parameter names for this component.
RuntimeMode getRuntimeMode () const noexcept
 Convenience accessor — true if currently in Eval mode.
TrainingMode getTrainingMode () const noexcept
 The current runtime behavioral mode of this Component.
virtual bool isBuilt () const final
 Returns true if build() has completed successfully.
bool isInferenceMode () const noexcept
bool isTrainingMode () const noexcept
void setTrainingMode (TrainingMode mode)
 Set the runtime behavioral mode for this Component.

Protected Member Functions

void onBuilding (const BuildContext &build_config) override
 Hook invoked by build() to allocate component buffers.
void onExecutionContextSet () override
 Called after ExecutionContext is set on the base Component.
void onTrainingModeChanging (TrainingMode training_mode) override
 Hook called before TrainingMode transitions.
Protected Member Functions inherited from Mila::Dnn::Component< TDeviceType, dtype_t::FP32 >
IExecutionContext * getExecutionContext () const
 Get the shared execution context.
bool hasExecutionContext () const noexcept
 Check if execution context has been set.
void loadParameterFromBlob (const std::string &param_name, const Serialization::ITensorBlob &blob, Tensor< TParameterPrecision, TMemoryResource > &target, const shape_t &expected_shape)
 Load a tensor blob into a parameter tensor with validation.
void setExecutionContext (IExecutionContext *context)
 Set the execution context for this component.

Private Types

using OpType = typename OperationTraits<OperationType::LpeOp, TDeviceType, TPrecision>::type

Private Member Functions

void createOperation ()
void initializeParameterGradients ()
void initializeParameters ()
void validateInputShape (const shape_t &input_shape) const
void validateInputShape (const TokenIndexType &input) const

Private Attributes

LpeConfig config_
std::unique_ptr< EmbeddingsTensorType > current_output_view_ { nullptr }
IPositionalDecode * decode_path_ { nullptr }
std::unique_ptr< TokenIndexType > input_grad_ { nullptr }
int64_t max_batch_size_ { 0 }
int64_t max_seq_len_ { 0 }
std::shared_ptr< OpType > operation_ { nullptr }
std::unique_ptr< EmbeddingsTensorType > output_ { nullptr }
std::unique_ptr< IExecutionContext > owned_exec_context_ { nullptr }
std::unique_ptr< EmbeddingsTensorType > wpe_ { nullptr }
std::unique_ptr< EmbeddingsTensorType > wpe_grad_ { nullptr }
std::unique_ptr< EmbeddingsTensorType > wte_ { nullptr }
std::unique_ptr< EmbeddingsTensorType > wte_grad_ { nullptr }

Additional Inherited Members

Static Public Member Functions inherited from Mila::Dnn::Component< TDeviceType, dtype_t::FP32 >
static constexpr DeviceType getDeviceType ()
 Compile-time device type for this component instance.
static constexpr TensorDataType getPrecision () noexcept
 Compile-time tensor precision for this component instance.
Protected Attributes inherited from Mila::Dnn::Component< TDeviceType, dtype_t::FP32 >
BuildContext build_context_
 The BuildContext stored at build time.

Detailed Description

template<DeviceType TDeviceType, TensorDataType TIndex = dtype_t::INT32, TensorDataType TPrecision = dtype_t::FP32>
requires PrecisionSupportedOnDevice<TPrecision, TDeviceType>
class Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >

Encoder module for token and positional embeddings (device-templated).

Delegates computation to a device-specific UnaryOperation implementation registered in the OperationRegistry.

The Encoder transforms input token IDs into continuous vector representations:

  1. Looks up token embeddings from the vocabulary table (wte)
  2. Adds positional embeddings (wpe) based on each token's sequence position

The module owns the trainable parameters (wte, wpe) and exposes them via accessors. The backend operation implements the embedding lookup and the positional-embedding addition.
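The lookup-and-add above can be written as a plain CPU loop. The sketch below is a minimal reference, assuming row-major layouts (tokens [B, T], wte [V, C], wpe [maxT, C], output [B, T, C]); the function name `lpeForwardReference` is hypothetical and is not part of Mila's API or its backend kernels.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical CPU reference for the Lpe forward computation (not Mila's backend).
// Assumed row-major layouts:
//   tokens: [B, T]   wte: [V, C]   wpe: [maxT, C]   result: [B, T, C]
// result[b, t, :] = wte[tokens[b, t], :] + wpe[t, :]
// The caller is responsible for ensuring 0 <= token < V and T <= maxT.
std::vector<float> lpeForwardReference(const std::vector<int32_t>& tokens,
                                       const std::vector<float>& wte,
                                       const std::vector<float>& wpe,
                                       std::size_t B, std::size_t T, std::size_t C) {
    assert(tokens.size() == B * T);
    std::vector<float> out(B * T * C);
    for (std::size_t b = 0; b < B; ++b) {
        for (std::size_t t = 0; t < T; ++t) {
            const auto tok = static_cast<std::size_t>(tokens[b * T + t]);
            const float* te = wte.data() + tok * C;   // token embedding row
            const float* pe = wpe.data() + t * C;     // positional row for position t
            float* o = out.data() + (b * T + t) * C;  // output row for (b, t)
            for (std::size_t c = 0; c < C; ++c) {
                o[c] = te[c] + pe[c];
            }
        }
    }
    return out;
}
```

Every batch row reuses the same positional rows 0..T-1, which is why a full-sequence pass needs no explicit position argument.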

Construction modes:

  • Standalone: provide a DeviceId to create and own an ExecutionContext.
  • Deferred/shared: omit DeviceId; the caller must call setExecutionContext() before build().
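The two construction modes differ only in who owns the execution context. The sketch below models that rule with hypothetical stand-in types (`ExecContextStub`, `ComponentSketch`); they are not Mila classes, and the real build() takes a BuildContext, which is omitted here.

```cpp
#include <memory>
#include <optional>
#include <stdexcept>
#include <string>
#include <utility>

// Hypothetical stand-in for an execution context bound to one device.
struct ExecContextStub {
    int device_id;
};

// Hypothetical component modelling the standalone and deferred/shared modes.
class ComponentSketch {
public:
    explicit ComponentSketch(std::string name, std::optional<int> device_id = std::nullopt)
        : name_(std::move(name)) {
        if (device_id.has_value()) {
            // Standalone mode: create and own the context.
            owned_context_ = std::make_unique<ExecContextStub>(ExecContextStub{*device_id});
            context_ = owned_context_.get();
        }
    }

    // Deferred/shared mode: the caller (e.g. a parent) supplies a context it owns.
    void setExecutionContext(ExecContextStub* context) { context_ = context; }

    void build() {
        if (context_ == nullptr) {
            throw std::runtime_error(name_ + ": setExecutionContext() must be called before build()");
        }
        built_ = true;
    }

    bool isBuilt() const { return built_; }

private:
    std::string name_;
    std::unique_ptr<ExecContextStub> owned_context_;  // set only in standalone mode
    ExecContextStub* context_ = nullptr;              // non-owning in deferred mode
    bool built_ = false;
};
```

In deferred mode the component holds only a non-owning pointer, so the parent must keep its context alive for at least as long as the child.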

Template Parameters
TDeviceType: Device type (DeviceType::Cpu or DeviceType::Cuda)
TIndex: Data type for token indices (typically INT32)
TPrecision: Abstract tensor precision (TensorDataType) for embeddings

Constructor & Destructor Documentation

◆ Lpe()

template<DeviceType TDeviceType, TensorDataType TIndex = dtype_t::INT32, TensorDataType TPrecision = dtype_t::FP32>
Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >::Lpe ( const std::string & name,
const LpeConfig & config,
std::optional< DeviceId > device_id = std::nullopt )
inline explicit export

Construct Encoder component.

Two modes:

  • Standalone mode: provide DeviceId and component will create and own an ExecutionContext.
  • Child mode: omit DeviceId and parent must call setExecutionContext() before build().

Parameters
name: Component name identifier (mandatory)
config: Encoder configuration
device_id: Optional DeviceId used to create an owned ExecutionContext (standalone mode)

◆ ~Lpe()

template<DeviceType TDeviceType, TensorDataType TIndex = dtype_t::INT32, TensorDataType TPrecision = dtype_t::FP32>
Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >::~Lpe ( )
override export default

Member Function Documentation

◆ backward()

template<DeviceType TDeviceType, TensorDataType TIndex = dtype_t::INT32, TensorDataType TPrecision = dtype_t::FP32>
TokenIndexType & Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >::backward ( const TokenIndexType & input,
const EmbeddingsTensorType & output_grad )
inline export

Backward pass - compute parameter gradients and return the component-owned input-gradient tensor.

Token indices are discrete and not differentiable; the backend may still expect an input-gradient tensor. The component owns a token-index-typed input-gradient buffer that is passed to the backend and returned.

Parameters
input: Input token indices tensor used during forward().
output_grad: Gradient w.r.t. the embeddings [B, T, C].
Returns
Reference to component-owned token-index-typed input-grad tensor.
Exceptions
std::runtime_error if the component is not built, not in training mode, or the backend/buffers are not initialized.
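Because each output row is the sum of one wte row and one wpe row, the parameter gradients reduce to scatter-adds, and nothing flows back to the integer token indices. The sketch below is a hypothetical CPU reference (`lpeBackwardReference` is not a Mila function), assuming row-major tokens [B, T], output_grad [B, T, C], dwte [V, C] and dwpe [maxT, C]. It accumulates into the gradient buffers rather than overwriting them; that accumulate semantics is an assumption here.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical CPU reference for the Lpe parameter gradients (not Mila's backend).
//   dwte[tokens[b, t], :] += dout[b, t, :]   (repeated tokens accumulate)
//   dwpe[t, :]            += dout[b, t, :]   (summed over the batch)
void lpeBackwardReference(const std::vector<int32_t>& tokens,
                          const std::vector<float>& dout,
                          std::vector<float>& dwte,
                          std::vector<float>& dwpe,
                          std::size_t B, std::size_t T, std::size_t C) {
    assert(tokens.size() == B * T && dout.size() == B * T * C);
    for (std::size_t b = 0; b < B; ++b) {
        for (std::size_t t = 0; t < T; ++t) {
            const auto tok = static_cast<std::size_t>(tokens[b * T + t]);
            const float* g = dout.data() + (b * T + t) * C;  // gradient row for (b, t)
            for (std::size_t c = 0; c < C; ++c) {
                dwte[tok * C + c] += g[c];
                dwpe[t * C + c] += g[c];
            }
        }
    }
}
```

Calling it twice without clearing the buffers doubles the result, which is why gradients are cleared between optimizer steps.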

◆ createOperation()

template<DeviceType TDeviceType, TensorDataType TIndex = dtype_t::INT32, TensorDataType TPrecision = dtype_t::FP32>
void Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >::createOperation ( )
inline export private

◆ decode()

template<DeviceType TDeviceType, TensorDataType TIndex = dtype_t::INT32, TensorDataType TPrecision = dtype_t::FP32>
EmbeddingsTensorType & Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >::decode ( const TokenIndexType & input,
int position )
inline export

Decode pass - single token embedding at a specific sequence position.

Unlike forward() which processes a full sequence [B, T] and uses positions 0..T-1, decode() processes a single token and uses the caller-supplied position for the positional embedding lookup. This is critical for correctness in KV cache autoregressive generation — without the correct position, wpe[0] would be used for every generated token, corrupting all subsequent attention computations.

Parameters
input: Single token index tensor [1, 1]
position: Actual sequence position (prefill_len + decode_step)
Returns
Reference to component-owned embedding tensor [1, 1, C]
Exceptions
std::runtime_error if the component is not built.
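The position argument only changes which wpe row is added. The hypothetical helper below (not part of Mila) computes the single-token result, assuming row-major wte [V, C] and wpe [maxT, C]. Rejecting positions at or beyond the maximum sequence length is this sketch's own choice; the reference above documents only the not-built error.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical single-token decode reference (not Mila's backend).
// Returns wte[token, :] + wpe[position, :] as a length-C vector.
std::vector<float> lpeDecodeReference(int32_t token, int position,
                                      const std::vector<float>& wte,
                                      const std::vector<float>& wpe,
                                      std::size_t C, std::size_t maxSeqLen) {
    if (position < 0 || static_cast<std::size_t>(position) >= maxSeqLen) {
        throw std::out_of_range("decode position outside [0, maxSeqLen)");
    }
    const std::size_t vocab = wte.size() / C;
    if (token < 0 || static_cast<std::size_t>(token) >= vocab) {
        throw std::out_of_range("token index outside vocabulary");
    }
    std::vector<float> out(C);
    for (std::size_t c = 0; c < C; ++c) {
        out[c] = wte[static_cast<std::size_t>(token) * C + c] +
                 wpe[static_cast<std::size_t>(position) * C + c];
    }
    return out;
}
```

With a prompt of length 5, the first generated token is decoded at position 5 (prefill_len + 0); passing 0 instead would silently add wpe[0].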

◆ forward()

template<DeviceType TDeviceType, TensorDataType TIndex = dtype_t::INT32, TensorDataType TPrecision = dtype_t::FP32>
EmbeddingsTensorType & Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >::forward ( const TokenIndexType & input)
inline export

Forward pass - returns component-owned embeddings tensor.

Parameters
input: Input token indices tensor [B, T]
Returns
Reference to component-owned embeddings tensor [B, T, C]
Exceptions
std::runtime_error if the component is not built or the backend is not initialized.
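The shape contract for forward() is [B, T] in and [B, T, C] out. Below is a minimal sketch of the kind of check the private validateInputShape() helpers might perform. The exact rules Mila enforces are not documented here, so the bounds against a maximum batch size and maximum sequence length (suggested by the max_batch_size_ and max_seq_len_ attributes) are assumptions, and `lpeOutputShape` is a hypothetical name.

```cpp
#include <cstdint>
#include <stdexcept>
#include <vector>

using Shape = std::vector<int64_t>;

// Hypothetical shape check for the documented contract [B, T] -> [B, T, C].
Shape lpeOutputShape(const Shape& input, int64_t maxBatch, int64_t maxSeqLen, int64_t embeddingDim) {
    if (input.size() != 2) {
        throw std::invalid_argument("Lpe input must be rank 2: [B, T]");
    }
    const int64_t batch = input[0];
    const int64_t seqLen = input[1];
    if (batch < 1 || batch > maxBatch) {
        throw std::invalid_argument("batch size outside [1, maxBatch]");
    }
    if (seqLen < 1 || seqLen > maxSeqLen) {
        throw std::invalid_argument("sequence length outside [1, maxSeqLen]");
    }
    return {batch, seqLen, embeddingDim};
}
```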

◆ getDeviceId()

template<DeviceType TDeviceType, TensorDataType TIndex = dtype_t::INT32, TensorDataType TPrecision = dtype_t::FP32>
DeviceId Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >::getDeviceId ( ) const
inline override export virtual

Get the compute device id associated with this component.

Must return the device on which parameters and operations execute.

Implements Mila::Dnn::Component< TDeviceType, dtype_t::FP32 >.

◆ getEmbeddingDim()

template<DeviceType TDeviceType, TensorDataType TIndex = dtype_t::INT32, TensorDataType TPrecision = dtype_t::FP32>
int64_t Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >::getEmbeddingDim ( ) const
inline export noexcept

◆ getGradients()

template<DeviceType TDeviceType, TensorDataType TIndex = dtype_t::INT32, TensorDataType TPrecision = dtype_t::FP32>
std::vector< ITensor * > Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >::getGradients ( ) const
inline override export virtual

Return non-owning pointers to parameter gradient tensors.

Only valid when isTrainingMode() is true.

Exceptions
std::runtime_error if called when not in training mode or before the component has been built.

Implements Mila::Dnn::Component< TDeviceType, dtype_t::FP32 >.

◆ getMaxSequenceLength()

template<DeviceType TDeviceType, TensorDataType TIndex = dtype_t::INT32, TensorDataType TPrecision = dtype_t::FP32>
int64_t Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >::getMaxSequenceLength ( ) const
inline export noexcept

◆ getMemoryStats()

template<DeviceType TDeviceType, TensorDataType TIndex = dtype_t::INT32, TensorDataType TPrecision = dtype_t::FP32>
MemoryStats Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >::getMemoryStats ( ) const
inline override export virtual

Return the current memory allocation breakdown for this component.

Reflects allocations at the moment of the call. The returned stats naturally track the component lifecycle:

  • After construction: parameters only
  • After build( Inference ): parameters + state buffers sized for T=1
  • After build( Training ): parameters + state buffers sized for the full sequence length T
  • After setEvaluation( false ): parameters + state + gradients

For CompositeComponent and Network, the returned stats are the recursive aggregate of all child components.

May be called at any time — no lifecycle preconditions.

Returns
MemoryStats reflecting current allocations.

Implements Mila::Dnn::Component< TDeviceType, dtype_t::FP32 >.
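To make the lifecycle stages above concrete, the sketch below estimates bytes for an Lpe under simplifying assumptions: the parameters are wte [V, C] and wpe [maxT, C], the only state buffer counted is the [B, allocation sequence length, C] output, and gradients mirror the parameters. `estimateLpeMemory` and `LifecycleStage` are hypothetical names, and the real MemoryStats will also include buffers this estimate ignores (for example the input-gradient buffer).

```cpp
#include <cstddef>

// Hypothetical lifecycle stages matching the list above.
enum class LifecycleStage { Constructed, BuiltInference, BuiltTraining, TrainingWithGradients };

struct MemoryEstimate {
    std::size_t parameters = 0;
    std::size_t state = 0;
    std::size_t gradients = 0;
    std::size_t total() const { return parameters + state + gradients; }
};

// Rough byte estimate; elemBytes is 4 for FP32, seqLen is the training build length.
MemoryEstimate estimateLpeMemory(LifecycleStage stage, std::size_t vocab, std::size_t maxSeqLen,
                                 std::size_t channels, std::size_t batch, std::size_t seqLen,
                                 std::size_t elemBytes) {
    MemoryEstimate m;
    m.parameters = (vocab + maxSeqLen) * channels * elemBytes;  // wte + wpe
    switch (stage) {
        case LifecycleStage::Constructed:
            break;
        case LifecycleStage::BuiltInference:
            m.state = batch * 1 * channels * elemBytes;  // allocation sequence length is 1
            break;
        case LifecycleStage::BuiltTraining:
            m.state = batch * seqLen * channels * elemBytes;  // full sequence length
            break;
        case LifecycleStage::TrainingWithGradients:
            m.state = batch * seqLen * channels * elemBytes;
            m.gradients = m.parameters;  // dwte + dwpe mirror wte + wpe
            break;
    }
    return m;
}
```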

◆ getParameters()

template<DeviceType TDeviceType, TensorDataType TIndex = dtype_t::INT32, TensorDataType TPrecision = dtype_t::FP32>
std::vector< ITensor * > Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >::getParameters ( ) const
inline override export virtual

Return non-owning pointers to parameter tensors.

The returned tensor pointers remain valid for the lifetime of the component. Order should be canonical (weights before biases).

Implements Mila::Dnn::Component< TDeviceType, dtype_t::FP32 >.

◆ getType()

template<DeviceType TDeviceType, TensorDataType TIndex = dtype_t::INT32, TensorDataType TPrecision = dtype_t::FP32>
const ComponentType Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >::getType ( ) const
inline override export virtual

Get the component type identifier.

Used for serialization and runtime type identification.

Returns
Component type enum value.

Implements Mila::Dnn::Component< TDeviceType, dtype_t::FP32 >.

◆ getVocabularyLength()

template<DeviceType TDeviceType, TensorDataType TIndex = dtype_t::INT32, TensorDataType TPrecision = dtype_t::FP32>
int64_t Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >::getVocabularyLength ( ) const
inline export noexcept

◆ getWpeGrad()

template<DeviceType TDeviceType, TensorDataType TIndex = dtype_t::INT32, TensorDataType TPrecision = dtype_t::FP32>
EmbeddingsTensorType * Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >::getWpeGrad ( ) const
inline export noexcept

◆ getWteGrad()

template<DeviceType TDeviceType, TensorDataType TIndex = dtype_t::INT32, TensorDataType TPrecision = dtype_t::FP32>
EmbeddingsTensorType * Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >::getWteGrad ( ) const
inline export noexcept

◆ initializeParameterGradients()

template<DeviceType TDeviceType, TensorDataType TIndex = dtype_t::INT32, TensorDataType TPrecision = dtype_t::FP32>
void Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >::initializeParameterGradients ( )
inline export private

◆ initializeParameters()

template<DeviceType TDeviceType, TensorDataType TIndex = dtype_t::INT32, TensorDataType TPrecision = dtype_t::FP32>
void Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >::initializeParameters ( )
inline export private

◆ loadParameter()

template<DeviceType TDeviceType, TensorDataType TIndex = dtype_t::INT32, TensorDataType TPrecision = dtype_t::FP32>
void Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >::loadParameter ( const std::string & name,
const ITensorBlob & blob )
inline override export virtual

Load a parameter from serialized tensor data.

Loads raw tensor bytes directly into an existing parameter tensor, handling precision conversion and device upload as needed.

The component validates that the blob's shape matches the parameter's expected shape, then delegates to the backend to perform:

  • Precision conversion (blob dtype → parameter dtype)
  • Device upload (CPU bytes → target device)

Parameters
name: Parameter name used to locate the target tensor.
blob: Serialized tensor metadata and raw bytes.
Exceptions
std::runtime_error if the component has no parameters to load.
std::runtime_error if the blob shape does not match the parameter shape.

Reimplemented from Mila::Dnn::Component< TDeviceType, dtype_t::FP32 >.
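The validate-then-convert order described above can be sketched with hypothetical stand-in types: `BlobStub` carries a shape and double-precision values, and loading converts them into a float parameter only after the shape check passes. The parameter names "wte"/"wpe" and all types here are illustrative assumptions; Mila's actual ITensorBlob and backend conversion path are not shown, and device upload is omitted.

```cpp
#include <cstddef>
#include <cstdint>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical serialized blob: shape plus values stored at a wider precision.
struct BlobStub {
    std::vector<int64_t> shape;
    std::vector<double> values;
};

// Hypothetical parameter slot: expected shape plus float storage.
struct ParamSlot {
    std::vector<int64_t> shape;
    std::vector<float> data;
};

void loadParameterSketch(std::map<std::string, ParamSlot>& params,
                         const std::string& name, const BlobStub& blob) {
    if (params.empty()) {
        throw std::runtime_error("component has no parameters to load");
    }
    auto it = params.find(name);
    if (it == params.end()) {
        throw std::runtime_error("unknown parameter: " + name);
    }
    ParamSlot& slot = it->second;
    if (blob.shape != slot.shape) {
        throw std::runtime_error("blob shape does not match parameter shape for " + name);
    }
    // Precision conversion (double -> float); a real backend would also upload to the device.
    slot.data.resize(blob.values.size());
    for (std::size_t i = 0; i < blob.values.size(); ++i) {
        slot.data[i] = static_cast<float>(blob.values[i]);
    }
}
```

Checking the shape before touching the target keeps a failed load from leaving a half-written parameter.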

◆ onBuilding()

template<DeviceType TDeviceType, TensorDataType TIndex = dtype_t::INT32, TensorDataType TPrecision = dtype_t::FP32>
void Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >::onBuilding ( const BuildContext & config)
inline override export protected virtual

Hook invoked by build() to allocate component buffers.

Receives the stored BuildContext. Implementations must use config.allocationSeqLen() when sizing output buffers — this is the single call that makes Inference and Training allocate the correct buffer sizes automatically without per-component logic.

// Example — Linear component:
shape_t out_shape =
{
config.batchSize(),
config.allocationSeqLen(), // 1 for Inference, T for Training
config_.getOutputFeatures()
};
output_ = std::make_unique<TensorType>( device, out_shape,
this->getName() + ".output" );

The default implementation forwards to the legacy onBuilding( const shape_t& ) overload for backwards compatibility. New components should override this overload directly.

Note
Do not call build() or onBuilding() from within this hook.
Implementations should either succeed fully or leave no partial state, as a failed build() may be retried.

Parameters
config: Build-time configuration. Use config.allocationSeqLen() to obtain the correct output buffer sequence dimension.

Reimplemented from Mila::Dnn::Component< TDeviceType, dtype_t::FP32 >.
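Assuming Lpe follows the rule above, its output buffer would be [batchSize, allocationSeqLen, C]. The sketch below models this with a hypothetical `BuildContextStub` whose allocationSeqLen() returns 1 for Inference and the full sequence length for Training, as the comment in the Linear example states; it is not Mila's BuildContext.

```cpp
#include <cstdint>
#include <vector>

enum class Mode { Inference, Training };

// Hypothetical stand-in for BuildContext.
struct BuildContextStub {
    int64_t batch;
    int64_t seqLen;
    Mode mode;
    int64_t batchSize() const { return batch; }
    int64_t allocationSeqLen() const { return mode == Mode::Training ? seqLen : 1; }
};

// Output buffer shape an Lpe-like component would allocate in onBuilding().
std::vector<int64_t> lpeOutputAllocationShape(const BuildContextStub& ctx, int64_t embeddingDim) {
    return {ctx.batchSize(), ctx.allocationSeqLen(), embeddingDim};
}
```

Routing the sequence dimension through one accessor means the component itself never branches on the mode when sizing buffers.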

◆ onExecutionContextSet()

template<DeviceType TDeviceType, TensorDataType TIndex = dtype_t::INT32, TensorDataType TPrecision = dtype_t::FP32>
void Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >::onExecutionContextSet ( )
inline override export protected virtual

Called after ExecutionContext is set on the base Component.

Initialize device-bound parameters and create the backend operation.

Reimplemented from Mila::Dnn::Component< TDeviceType, dtype_t::FP32 >.

◆ onTrainingModeChanging()

template<DeviceType TDeviceType, TensorDataType TIndex = dtype_t::INT32, TensorDataType TPrecision = dtype_t::FP32>
void Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >::onTrainingModeChanging ( TrainingMode mode)
inline override export protected virtual

Hook called before TrainingMode transitions.

Called by setTrainingMode() after validation and lock acquisition, before the internal state is updated. Derived classes override to respond to the transition — e.g. zeroing gradient buffers on transition to Eval, or re-enabling dropout on transition to Training.

The default implementation is a no-op.

Parameters
mode: The incoming TrainingMode.

Reimplemented from Mila::Dnn::Component< TDeviceType, dtype_t::FP32 >.
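Below is a minimal sketch of how a leaf component might use this hook. It assumes the component lazily allocates gradient buffers when entering Training and zeroes them when returning to Eval (the latter is one of the examples above). `GradientHolderSketch` and `TrainingModeStub` are hypothetical; Mila's Lpe override may behave differently.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical mode enum; Mila's TrainingMode may differ.
enum class TrainingModeStub { Eval, Training };

class GradientHolderSketch {
public:
    explicit GradientHolderSketch(std::size_t paramElements) : param_elements_(paramElements) {}

    // Called before the stored mode is updated, mirroring the documented hook order.
    void onTrainingModeChanging(TrainingModeStub incoming) {
        if (incoming == TrainingModeStub::Training && grads_.empty()) {
            grads_.assign(param_elements_, 0.0f);  // allocate zeroed gradients on first entry
        } else if (incoming == TrainingModeStub::Eval) {
            std::fill(grads_.begin(), grads_.end(), 0.0f);  // clear stale gradients
        }
    }

    void setTrainingMode(TrainingModeStub mode) {
        if (mode == mode_) return;  // sketch choice: no hook call when the mode is unchanged
        onTrainingModeChanging(mode);
        mode_ = mode;
    }

    std::vector<float>& gradients() { return grads_; }

private:
    std::size_t param_elements_;
    TrainingModeStub mode_ = TrainingModeStub::Eval;
    std::vector<float> grads_;
};
```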

◆ parameterCount()

template<DeviceType TDeviceType, TensorDataType TIndex = dtype_t::INT32, TensorDataType TPrecision = dtype_t::FP32>
size_t Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >::parameterCount ( ) const
inline override export virtual

Return number of trainable parameters.

For leaf components this is the element count of owned parameter tensors. CompositeComponent and Network implementations should return the recursive aggregate across all children.

Implements Mila::Dnn::Component< TDeviceType, dtype_t::FP32 >.
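For this leaf component the owned parameters are wte [V, C] and wpe [maxT, C], so the count should reduce to (V + maxT) x C, assuming no other parameter tensors are owned. Below is a worked check with GPT-2 small dimensions (V = 50257, maxT = 1024, C = 768); `lpeParameterCount` is a hypothetical helper.

```cpp
#include <cstddef>

// Hypothetical: element count of wte [vocab, C] plus wpe [maxSeqLen, C].
constexpr std::size_t lpeParameterCount(std::size_t vocab, std::size_t maxSeqLen, std::size_t channels) {
    return vocab * channels + maxSeqLen * channels;
}

// GPT-2 small: 50257 * 768 + 1024 * 768 = 38,597,376 + 786,432 = 39,383,808.
static_assert(lpeParameterCount(50257, 1024, 768) == 39383808, "GPT-2 small wte + wpe");
```

At FP32 (4 bytes per element) that is 157,535,232 bytes, roughly 150 MiB, almost all of it in wte.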

◆ save_()

template<DeviceType TDeviceType, TensorDataType TIndex = dtype_t::INT32, TensorDataType TPrecision = dtype_t::FP32>
void Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >::save_ ( ModelArchive & archive,
SerializationMode mode ) const
inline override export virtual

◆ synchronize()

template<DeviceType TDeviceType, TensorDataType TIndex = dtype_t::INT32, TensorDataType TPrecision = dtype_t::FP32>
void Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >::synchronize ( )
inline override export virtual

Wait for outstanding device work submitted by this component.

On CPU this may be a no-op. Use to ensure results are visible to the host or to measure synchronous timings.

Implements Mila::Dnn::Component< TDeviceType, dtype_t::FP32 >.

◆ toString()

template<DeviceType TDeviceType, TensorDataType TIndex = dtype_t::INT32, TensorDataType TPrecision = dtype_t::FP32>
std::string Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >::toString ( ) const
inline override export virtual

Produce a short, human-readable description of the component.

Implementations should keep output concise and avoid throwing.

Implements Mila::Dnn::Component< TDeviceType, dtype_t::FP32 >.

◆ validateInputShape() [1/2]

template<DeviceType TDeviceType, TensorDataType TIndex = dtype_t::INT32, TensorDataType TPrecision = dtype_t::FP32>
void Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >::validateInputShape ( const shape_t & input_shape) const
inline export private

◆ validateInputShape() [2/2]

template<DeviceType TDeviceType, TensorDataType TIndex = dtype_t::INT32, TensorDataType TPrecision = dtype_t::FP32>
void Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >::validateInputShape ( const TokenIndexType & input) const
inline export private

◆ zeroGradients()

template<DeviceType TDeviceType, TensorDataType TIndex = dtype_t::INT32, TensorDataType TPrecision = dtype_t::FP32>
void Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >::zeroGradients ( )
inline override export virtual

Clear all model-owned gradients for this component.

Default implementation is a no-op. Composite components should override to recurse to children. Leaf components should override to zero their parameter and activation gradients using device-aware helpers.

Reimplemented from Mila::Dnn::Component< TDeviceType, dtype_t::FP32 >.


The documentation for this class was generated from the following file:
  • /__w/Mila/Mila/Mila/Src/Dnn/Components/Encodings/Lpe/Lpe.ixx