Mila
Deep Neural Network Library
Mila::Dnn::Encoder< TDeviceType, TInput, TOutput > Class Template Reference export

An encoder module that provides token and positional embeddings. More...

Inheritance diagram for Mila::Dnn::Encoder< TDeviceType, TInput, TOutput >:
Collaboration diagram for Mila::Dnn::Encoder< TDeviceType, TInput, TOutput >:

Public Types

using ModuleBase = Module< TDeviceType, TInput, TOutput >
 Alias for base module type.
 
using MR = std::conditional_t< TDeviceType==DeviceType::Cuda, CudaMemoryResource, CpuMemoryResource >
 Memory resource type determined based on device type.
 
- Public Types inherited from Mila::Dnn::Module< TDeviceType, TInput, TOutput >
using MR = std::conditional_t< TDeviceType==DeviceType::Cuda, CudaMemoryResource, CpuMemoryResource >
 

Public Member Functions

 Encoder (const std::string &device_name, const EncoderConfig &config)
 Constructs a new Encoder module with a device name.
 
 Encoder (std::shared_ptr< DeviceContext > device_context, const EncoderConfig &config)
 Constructs a new Encoder module with a provided device context.
 
void forward (const Tensor< TInput, MR > &input, Tensor< TOutput, MR > &output)
 Performs the forward pass of the encoder.
 
size_t getChannels () const
 Gets the number of channels (embedding dimension).
 
size_t getMaxSequenceLength () const
 Gets the maximum sequence length.
 
size_t getVocabularyLength () const
 Gets the vocabulary length.
 
void load (ModelArchive &archive) override
 Loads the encoder parameters from a zip archive.
 
size_t parameterCount () const override
 Gets the number of parameters in the module.
 
void save (ModelArchive &archive) const override
 Saves the encoder parameters to a zip archive.
 
std::string toString () const override
 Gets the module information as a string.
 
- Public Member Functions inherited from Mila::Dnn::Module< TDeviceType, TInput, TOutput >
 Module (const std::string &device_name, const ComponentConfig &config)
 Constructor with device name.
 
 Module (std::shared_ptr< DeviceContext > context, const ComponentConfig &config)
 Constructor with a specific device context.
 
virtual ~Module ()=default
 Virtual destructor for proper cleanup in derived classes.
 
std::shared_ptr< Compute::DeviceContext > getDeviceContext () const
 Get the device context for this module.
 
Compute::DeviceType getDeviceType () const
 Get the device type of the current device context.
 
std::string getName () const
 Get the name of the module.
 
const auto & getParameterTensors () const
 Get the parameter tensors of this module.
 
const ComputePrecision::Policy & getPrecision () const
 
const auto & getStateTensors () const
 Get the state tensors of this module.
 
bool isTraining () const
 Check if the module is in training mode.
 
virtual void setTraining (bool is_training)
 Set the training mode of this module.
 

Private Member Functions

void createOperation ()
 Creates the computational operation based on current device context.
 
void initializeTensors ()
 Initializes the token and positional embedding tensors.
 

Private Attributes

OperationAttributes attributes_
 Operation-specific attributes and configuration.
 
EncoderConfig config_
 Configuration for the Encoder module.
 
std::shared_ptr< UnaryOperation< TDeviceType, TInput, TOutput > > operation_ { nullptr }
 The computational operation that implements the encoder logic.
 
std::vector< std::shared_ptr< Tensor< TOutput, MR > > > output_state_
 Output state tensors used for intermediate values.
 
std::vector< std::shared_ptr< Tensor< TOutput, MR > > > parameters_
 Vector of parameter tensors that will be used during forward/backward passes.
 
std::shared_ptr< Tensor< TOutput, MR > > wpe_ { nullptr }
 Position embedding table with shape (maxT,C), encodes token position information.
 
std::shared_ptr< Tensor< TOutput, MR > > wte_ { nullptr }
 Token embedding table with shape (V,C), maps token IDs to vector representations.
 

Additional Inherited Members

- Protected Member Functions inherited from Mila::Dnn::Module< TDeviceType, TInput, TOutput >
const std::string parametersToString () const
 Helper method to convert parameters to string representation.
 
const std::string stateToString () const
 Helper method to convert state tensors to string representation.
 
- Protected Attributes inherited from Mila::Dnn::Module< TDeviceType, TInput, TOutput >
std::unordered_map< std::string, std::shared_ptr< Tensor< TOutput, MR > > > parameter_map_ = {}
 Map of parameter names to parameter tensors.
 
std::unordered_map< std::string, std::shared_ptr< Tensor< TOutput, MR > > > state_map_ = {}
 Map of state names to state tensors.
 

Detailed Description

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = int, typename TOutput = float>
requires ValidTensorType<TInput> && ValidFloatTensorType<TOutput>
class Mila::Dnn::Encoder< TDeviceType, TInput, TOutput >

An encoder module that provides token and positional embeddings.

The Encoder transforms input token IDs into continuous vector representations by:

  1. Looking up token embeddings from a vocabulary table (wte)
  2. Adding positional embeddings (wpe) based on sequence position

This implementation supports both CPU and CUDA execution depending on the device context. The encoder is a fundamental component in transformer architectures, providing the initial representation of tokens that subsequent layers will process.

Template Parameters
TDeviceType - The device type (CPU or CUDA) on which to perform computations.
TInput - The data type of the input token IDs (typically int).
TOutput - The data type of the output embeddings (typically float).

Member Typedef Documentation

◆ ModuleBase

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = int, typename TOutput = float>
using Mila::Dnn::Encoder< TDeviceType, TInput, TOutput >::ModuleBase = Module<TDeviceType, TInput, TOutput>
export

Alias for base module type.

◆ MR

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = int, typename TOutput = float>
using Mila::Dnn::Encoder< TDeviceType, TInput, TOutput >::MR = std::conditional_t<TDeviceType == DeviceType::Cuda, CudaMemoryResource, CpuMemoryResource>
export

Memory resource type determined based on device type.
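The alias above selects the memory resource at compile time from the device type template parameter. A minimal standalone sketch of the same `std::conditional_t` pattern (the `DeviceType` enum and the two memory-resource structs below are hypothetical stand-ins, not the library's definitions):

```cpp
#include <type_traits>

// Hypothetical stand-ins for the library's device enum and memory resources.
enum class DeviceType { Cpu, Cuda };
struct CpuMemoryResource {};
struct CudaMemoryResource {};

// Mirrors the MR alias: the memory resource type is chosen at compile time
// based on the device type template parameter.
template <DeviceType TDeviceType>
using MR = std::conditional_t<TDeviceType == DeviceType::Cuda,
                              CudaMemoryResource, CpuMemoryResource>;

static_assert(std::is_same_v<MR<DeviceType::Cuda>, CudaMemoryResource>);
static_assert(std::is_same_v<MR<DeviceType::Cpu>, CpuMemoryResource>);
```

Because the selection happens at compile time, a CUDA-typed encoder can never accidentally allocate through a CPU memory resource.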

Constructor & Destructor Documentation

◆ Encoder() [1/2]

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = int, typename TOutput = float>
Mila::Dnn::Encoder< TDeviceType, TInput, TOutput >::Encoder ( const std::string &  device_name,
const EncoderConfig &  config 
)
inline explicit export

Constructs a new Encoder module with a device name.

Parameters
device_name - The name of the device to use (e.g., "CPU", "CUDA:0").
config - Configuration parameters for the Encoder module.
Exceptions
std::invalid_argument - If the device name or the configuration is invalid.
std::runtime_error - If the device type doesn't match the template parameter TDeviceType.

◆ Encoder() [2/2]

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = int, typename TOutput = float>
Mila::Dnn::Encoder< TDeviceType, TInput, TOutput >::Encoder ( std::shared_ptr< DeviceContext >  device_context,
const EncoderConfig &  config 
)
inline explicit export

Constructs a new Encoder module with a provided device context.

Parameters
device_context - The device context to use for this module.
config - Configuration parameters for the Encoder module.
Exceptions
std::invalid_argument - If device_context is null or the configuration is invalid.
std::runtime_error - If the device context type doesn't match the template parameter TDeviceType.

Member Function Documentation

◆ createOperation()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = int, typename TOutput = float>
void Mila::Dnn::Encoder< TDeviceType, TInput, TOutput >::createOperation ( )
inline export private

Creates the computational operation based on current device context.

Instantiates either a CPU or CUDA encoder operation based on the current device context. The operation implements the actual embedding lookup and addition logic during forward pass.

◆ forward()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = int, typename TOutput = float>
void Mila::Dnn::Encoder< TDeviceType, TInput, TOutput >::forward ( const Tensor< TInput, MR > &  input,
Tensor< TOutput, MR > &  output 
)
inline export

Performs the forward pass of the encoder.

Transforms input token IDs into continuous embeddings by:

  1. Looking up token embeddings from the embedding table (wte)
  2. Adding positional embeddings (wpe) based on token position
Parameters
input - The input tensor containing token IDs with shape (B,T).
output - The output tensor that will contain embeddings with shape (B,T,C).
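The two steps above can be sketched in plain C++ over flat row-major arrays (a minimal CPU-only illustration of the algorithm, not the library's Tensor API):

```cpp
#include <cstddef>
#include <vector>

// Sketch of the encoder forward pass: for each token at position t,
// output[b][t] = wte[token_id] + wpe[t]. Tensors are flat row-major
// std::vectors here; this illustrates the math, not the Mila Tensor API.
void encoder_forward(const std::vector<int>& input,    // (B,T) token IDs
                     const std::vector<float>& wte,    // (V,C) token table
                     const std::vector<float>& wpe,    // (maxT,C) position table
                     std::vector<float>& output,       // (B,T,C) embeddings
                     std::size_t B, std::size_t T, std::size_t C) {
    output.assign(B * T * C, 0.0f);
    for (std::size_t b = 0; b < B; ++b) {
        for (std::size_t t = 0; t < T; ++t) {
            const auto id = static_cast<std::size_t>(input[b * T + t]);
            for (std::size_t c = 0; c < C; ++c) {
                // Token embedding lookup plus positional embedding.
                output[(b * T + t) * C + c] = wte[id * C + c] + wpe[t * C + c];
            }
        }
    }
}
```

Note that the positional term depends only on t, so every token at the same position in a batch receives the same positional offset.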

◆ getChannels()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = int, typename TOutput = float>
size_t Mila::Dnn::Encoder< TDeviceType, TInput, TOutput >::getChannels ( ) const
inline export

Gets the number of channels (embedding dimension).

Returns
size_t The number of channels (C).

◆ getMaxSequenceLength()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = int, typename TOutput = float>
size_t Mila::Dnn::Encoder< TDeviceType, TInput, TOutput >::getMaxSequenceLength ( ) const
inline export

Gets the maximum sequence length.

Returns
size_t The maximum sequence length (maxT).

◆ getVocabularyLength()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = int, typename TOutput = float>
size_t Mila::Dnn::Encoder< TDeviceType, TInput, TOutput >::getVocabularyLength ( ) const
inline export

Gets the vocabulary length.

Returns
size_t The vocabulary length (V).

◆ initializeTensors()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = int, typename TOutput = float>
void Mila::Dnn::Encoder< TDeviceType, TInput, TOutput >::initializeTensors ( )
inline export private

Initializes the token and positional embedding tensors.

Creates and initializes:

  • wte (word token embeddings) tensor of shape (vocab_len_, channels_)
  • wpe (word position embeddings) tensor of shape (max_seq_len_, channels_)

Both tensors are initialized using Xavier initialization to ensure proper gradient flow during training. The tensors are registered as parameters in the module's parameter map for training and serialization.
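A common form of Xavier (Glorot) uniform initialization draws each weight from U(-limit, limit) with limit = sqrt(6 / (fan_in + fan_out)); the sketch below assumes this variant, and the library's exact scheme may differ in detail:

```cpp
#include <cmath>
#include <cstddef>
#include <random>
#include <vector>

// Xavier (Glorot) uniform initialization sketch: values drawn from
// U(-limit, limit) with limit = sqrt(6 / (fan_in + fan_out)). This keeps
// activation and gradient variance roughly constant across layers.
std::vector<float> xavier_uniform(std::size_t fan_in, std::size_t fan_out,
                                  unsigned seed = 42) {
    const float limit = std::sqrt(6.0f / static_cast<float>(fan_in + fan_out));
    std::mt19937 gen(seed);
    std::uniform_real_distribution<float> dist(-limit, limit);
    std::vector<float> w(fan_in * fan_out);
    for (auto& x : w) x = dist(gen);
    return w;
}
```

For the wte table, fan_in and fan_out would correspond to vocab_len_ and channels_; for wpe, to max_seq_len_ and channels_.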

◆ load()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = int, typename TOutput = float>
void Mila::Dnn::Encoder< TDeviceType, TInput, TOutput >::load ( ModelArchive &  archive)
inline override export virtual

Loads the encoder parameters from a zip archive.

Deserializes all parameter tensors (wte and wpe) from the specified zip archive. This enables loading pretrained models for inference or continued training.

Parameters
archive - The zip archive to load the parameters from.

Implements Mila::Dnn::Module< TDeviceType, TInput, TOutput >.

◆ parameterCount()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = int, typename TOutput = float>
size_t Mila::Dnn::Encoder< TDeviceType, TInput, TOutput >::parameterCount ( ) const
inline override export virtual

Gets the number of parameters in the module.

Counts all learnable parameters in the encoder, which includes all elements in the token embedding table (wte) and position embedding table (wpe).

Returns
size_t The total number of parameters.

Implements Mila::Dnn::Module< TDeviceType, TInput, TOutput >.
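Since the encoder's only learnable parameters are the two embedding tables, the count reduces to simple arithmetic, as this sketch shows (the GPT-2-small-like sizes in the comment are illustrative, not taken from this library):

```cpp
#include <cstddef>

// The encoder's learnable parameters are the two embedding tables:
// wte contributes V*C elements and wpe contributes maxT*C, so the
// total is (V + maxT) * C.
std::size_t encoder_parameter_count(std::size_t V, std::size_t maxT,
                                    std::size_t C) {
    return (V + maxT) * C;
}
```

For example, with GPT-2-small-like sizes V = 50257, maxT = 1024, and C = 768, the encoder alone holds (50257 + 1024) * 768 = 39,383,808 parameters.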

◆ save()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = int, typename TOutput = float>
void Mila::Dnn::Encoder< TDeviceType, TInput, TOutput >::save ( ModelArchive &  archive) const
inline override export virtual

Saves the encoder parameters to a zip archive.

Serializes all parameter tensors (wte and wpe) to the specified zip archive. This enables model persistence for later reuse or distribution.

Parameters
archive - The zip archive to save the parameters to.

Implements Mila::Dnn::Module< TDeviceType, TInput, TOutput >.

◆ toString()

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = int, typename TOutput = float>
std::string Mila::Dnn::Encoder< TDeviceType, TInput, TOutput >::toString ( ) const
inline override export virtual

Gets the module information as a string.

Provides a human-readable description of the encoder configuration, including dimensions, parameter counts, and tensor information.

Returns
std::string The module information.

Implements Mila::Dnn::Module< TDeviceType, TInput, TOutput >.

Member Data Documentation

◆ attributes_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = int, typename TOutput = float>
OperationAttributes Mila::Dnn::Encoder< TDeviceType, TInput, TOutput >::attributes_
export private

Operation-specific attributes and configuration.

◆ config_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = int, typename TOutput = float>
EncoderConfig Mila::Dnn::Encoder< TDeviceType, TInput, TOutput >::config_
export private

Configuration for the Encoder module.

◆ operation_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = int, typename TOutput = float>
std::shared_ptr<UnaryOperation<TDeviceType, TInput, TOutput> > Mila::Dnn::Encoder< TDeviceType, TInput, TOutput >::operation_ { nullptr }
export private

The computational operation that implements the encoder logic.

◆ output_state_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = int, typename TOutput = float>
std::vector<std::shared_ptr<Tensor<TOutput, MR> > > Mila::Dnn::Encoder< TDeviceType, TInput, TOutput >::output_state_
export private

Output state tensors used for intermediate values.

Not used in this module.

◆ parameters_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = int, typename TOutput = float>
std::vector<std::shared_ptr<Tensor<TOutput, MR> > > Mila::Dnn::Encoder< TDeviceType, TInput, TOutput >::parameters_
export private

Vector of parameter tensors that will be used during forward/backward passes.

Contains both the token embeddings (wte) and position embeddings (wpe).

◆ wpe_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = int, typename TOutput = float>
std::shared_ptr<Tensor<TOutput, MR> > Mila::Dnn::Encoder< TDeviceType, TInput, TOutput >::wpe_ { nullptr }
export private

Position embedding table with shape (maxT,C), encodes token position information.

maxT is the maximum sequence length and C is the embedding dimension.

◆ wte_

template<DeviceType TDeviceType = DeviceType::Cuda, typename TInput = int, typename TOutput = float>
std::shared_ptr<Tensor<TOutput, MR> > Mila::Dnn::Encoder< TDeviceType, TInput, TOutput >::wte_ { nullptr }
export private

Token embedding table with shape (V,C), maps token IDs to vector representations.

V is the vocabulary size and C is the embedding dimension.


The documentation for this class was generated from the following file: