Mila 0.13.48
Deep Neural Network Library
Loading...
Searching...
No Matches
Dnn.Components.Lpe Module Reference

Exported Modules

module  Dnn.TensorTypes
module  Dnn.TensorDataTypeTraits
module  Serialization.ModelArchive
module  Compute.MemoryResource
module  Dnn.ComponentType
module  Logging.Logger
module  Dnn.TensorOps
module  Compute.OperationTraits
module  Compute.UnaryOperation
module  Dnn.TensorDataType
module  Compute.ExecutionContext
module  Dnn.ITensor
module  Compute.DeviceType
module  Serialization.Mode
module  Compute.ExecutionContextFactory
module  Dnn.Components.LpeConfig
module  Compute.CpuMemoryResource
module  Dnn.Component
module  Serialization.Tensor
module  Compute.IPositionalDecode
module  Compute.DeviceId
module  Dnn.Tensor
module  Compute.Device
module  Dnn.TensorHelpers
module  Compute.DeviceTypeTraits

Classes

class  Mila::Dnn::Lpe< TDeviceType, TIndex, TPrecision >
 Encoder module for token and positional embeddings (device-templated). More...

Typedefs

using ComponentBase = Component<TDeviceType, TPrecision>
using EmbeddingsTensorType = Tensor<TPrecision, MR>
using MR = typename DeviceTypeTraits<TDeviceType>::memory_resource
using OpType = typename OperationTraits<OperationType::LpeOp, TDeviceType, TPrecision>::type
using TokenIndexType = Tensor<TIndex, MR>

Functions

 Lpe (const std::string &name, const LpeConfig &config, std::optional< DeviceId > device_id=std::nullopt)
 Construct Encoder component.
 ~Lpe () override=default
TokenIndexType & backward (const TokenIndexType &input, const EmbeddingsTensorType &output_grad)
 Backward pass - compute parameter gradients and return owned input-grad.
void createOperation ()
EmbeddingsTensorType & decode (const TokenIndexType &input, int position)
 Decode pass - single token embedding at a specific sequence position.
EmbeddingsTensorType & forward (const TokenIndexType &input)
 Forward pass - returns component-owned embeddings tensor.
DeviceId getDeviceId () const override
 Get the compute device id associated with this component.
int64_t getEmbeddingDim () const noexcept
std::vector< ITensor * > getGradients () const override
 Return non-owning pointers to parameter gradient tensors.
int64_t getMaxSequenceLength () const noexcept
MemoryStats getMemoryStats () const override
 Return the current memory allocation breakdown for this component.
std::vector< ITensor * > getParameters () const override
 Return non-owning pointers to parameter tensors.
const ComponentType getType () const override
 Get the component type identifier.
int64_t getVocabularyLength () const noexcept
EmbeddingsTensorTypegetWpeGrad () const noexcept
EmbeddingsTensorTypegetWteGrad () const noexcept
void initializeParameterGradients ()
void initializeParameters ()
void loadParameter (const std::string &name, const ITensorBlob &blob) override
 Load a parameter from serialized tensor data.
void onBuilding (const BuildContext &build_config) override
 Hook invoked by build() to allocate component buffers.
void onExecutionContextSet () override
 Called after ExecutionContext is set on the base Component.
void onTrainingModeChanging (TrainingMode training_mode) override
 Hook called before TrainingMode transitions.
size_t parameterCount () const override
 Return number of trainable parameters.
void save_ (ModelArchive &archive, SerializationMode mode) const override
void synchronize () override
 Wait for outstanding device work submitted by this component.
std::string toString () const override
 Produce a short, human-readable description of the component.
void validateInputShape (const shape_t &input_shape) const
void validateInputShape (const TokenIndexType &input) const
void zeroGradients () override
 Clear all model-owned gradients for this component.

Variables

LpeConfig config_
std::unique_ptr< EmbeddingsTensorType > current_output_view_ { nullptr }
IPositionalDecode * decode_path_ { nullptr }
std::unique_ptr< TokenIndexType > input_grad_ { nullptr }
int64_t max_batch_size_ { 0 }
int64_t max_seq_len_ { 0 }
std::shared_ptr< OpType > operation_ { nullptr }
std::unique_ptr< EmbeddingsTensorType > output_ { nullptr }
std::unique_ptr< IExecutionContext > owned_exec_context_ { nullptr }
std::unique_ptr< EmbeddingsTensorType > wpe_ { nullptr }
std::unique_ptr< EmbeddingsTensorType > wpe_grad_ { nullptr }
std::unique_ptr< EmbeddingsTensorType > wte_ { nullptr }
std::unique_ptr< EmbeddingsTensorType > wte_grad_ { nullptr }

Files

file  /__w/Mila/Mila/Mila/Src/Dnn/Components/Encodings/Lpe/Lpe.ixx