Mila 0.13.48
Deep Neural Network Library
Loading...
Searching...
No Matches
Mila::Dnn::LanguageNetwork< TDeviceType, TPrecision > Class Template Reference abstract export
Inheritance diagram for Mila::Dnn::LanguageNetwork< TDeviceType, TPrecision >:
Collaboration diagram for Mila::Dnn::LanguageNetwork< TDeviceType, TPrecision >:

Public Types

using MR = typename DeviceTypeTraits<TDeviceType>::memory_resource
using NetworkBase = Network<TDeviceType, TPrecision>
using TensorType = Tensor<TPrecision, MR>
using TokenIndexType = Tensor<TensorDataType::INT32, MR>
Public Types inherited from Mila::Dnn::Network< TDeviceType, TPrecision >
using ComponentPtr = typename CompositeBase::ComponentPtr
using CompositeBase = CompositeComponent<TDeviceType, TPrecision>
using MR = typename DeviceTypeTraits<TDeviceType>::memory_resource
Public Types inherited from Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >
using ComponentBase = Component<TDeviceType, TPrecision>
using ComponentPtr = std::shared_ptr<Component<TDeviceType, TPrecision>>

Public Member Functions

 LanguageNetwork (const std::string &name)
 ~LanguageNetwork () override=default
virtual TokenIndexType & backward (const TokenIndexType &input, const TensorType &output_grad)=0
 Full backward pass (training).
virtual TensorType & decode (const TokenIndexType &input, int position)=0
 Inference decode — single-token autoregressive step.
virtual TensorType & forward (const TokenIndexType &input)=0
 Full-sequence forward pass.
virtual TensorType & prefill (const TokenIndexType &input)=0
 Inference prefill — process full prompt and populate the KV cache.
Public Member Functions inherited from Mila::Dnn::Network< TDeviceType, TPrecision >
 Network (const std::string &name)
 Construct network (context managed by derived class).
 ~Network () override=default
template<typename TOptimizer, typename TConfig>
std::shared_ptr< TOptimizer > createOptimizer (const TConfig &config)
 Create and configure an optimizer for this network's parameters.
DeviceId getDeviceId () const noexcept
 Get the compute device for this composite.
const ComponentType getType () const override
 Get the component type identifier.
void save (ModelArchive &archive, SerializationMode mode) const
 Save network to archive.
void synchronize () override
 Synchronize all child components.
std::string toString () const override
 Generate a human-readable description.
Public Member Functions inherited from Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >
 CompositeComponent (CompositeComponent &&) noexcept=default
 CompositeComponent (const CompositeComponent &)=delete
 CompositeComponent (const std::string &name)
 Construct composite component with name.
virtual ~CompositeComponent ()=default
CompositeComponent & addComponent (ComponentPtr component)
 Add a pre-constructed child component (chainable).
size_t childCount () const noexcept
 Get the number of direct children.
void clearComponents ()
 Clear all child components.
ComponentPtr findComponent (const std::string &path) const
 Resolve a dot-separated component path within this composite.
ComponentPtr getComponent (const std::string &name) const
 Retrieve a direct child component by name.
const std::vector< ComponentPtr > & getComponents () const
 Get all child components in insertion order.
std::vector< ITensor * > getGradients () const override
 Get all parameter gradients from all children.
std::vector< ITensor * > getParameters () const override
 Get all parameters from all children.
bool hasChildren () const noexcept
 Check if this composite has any children.
bool hasComponent (const std::string &name) const
 Check if a named child component exists.
CompositeComponent & operator= (CompositeComponent &&) noexcept=default
CompositeComponent & operator= (const CompositeComponent &)=delete
size_t parameterCount () const override
 Count parameters across all children.
bool removeComponent (const std::string &name)
 Remove a direct child component by name.
ComponentPtr tryFindComponent (const std::string &path) const
 Try to resolve a dot-separated component path within this composite.
Public Member Functions inherited from Mila::Dnn::Component< TDeviceType, TPrecision >
 Component (const std::string &name)
 Construct component with required name identifier.
virtual ~Component ()=default
virtual void build (const BuildContext &context) final
 Build the component with the provided BuildContext (canonical overload).
virtual MemoryStats getMemoryStats () const =0
 Return the current memory allocation breakdown for this component.
const std::string getName () const
 Get the component's name identifier.
virtual std::vector< std::string > getParameterNames () const
 List all available parameter names for this component.
RuntimeMode getRuntimeMode () const noexcept
 Get the RuntimeMode currently associated with this Component.
TrainingMode getTrainingMode () const noexcept
 The current runtime behavioral mode of this Component.
virtual bool isBuilt () const final
 Returns true if build() has completed successfully.
bool isInferenceMode () const noexcept
 Convenience accessor — true if currently in Eval mode.
bool isTrainingMode () const noexcept
virtual void loadParameter (const std::string &name, const Serialization::ITensorBlob &blob)
 Load a parameter from serialized tensor data.
void setTrainingMode (TrainingMode mode)
 Set the runtime behavioral mode for this Component.
virtual void zeroGradients ()
 Clear all model-owned gradients for this component.

Additional Inherited Members

Static Public Member Functions inherited from Mila::Dnn::Component< TDeviceType, TPrecision >
static constexpr DeviceType getDeviceType ()
 Compile-time device type for this component instance.
static constexpr TensorDataType getPrecision () noexcept
 Compile-time tensor precision for this component instance.
Protected Member Functions inherited from Mila::Dnn::Network< TDeviceType, TPrecision >
virtual void save_ (ModelArchive &archive, SerializationMode mode) const =0
 Hook for concrete classes to save type-specific state.
void verifyArchitectureCompatibility (const PretrainedMetadata &metadata)
 Verify that imported model is compatible with network architecture.
Protected Member Functions inherited from Mila::Dnn::CompositeComponent< TDeviceType, TPrecision >
template<typename TComponent>
std::shared_ptr< TComponent > getComponentAs (const std::string &name) const
 Retrieve a typed child component by name.
void onExecutionContextSet () override
 Hook invoked after ExecutionContext is set.
void onTrainingModeChanging (TrainingMode training_mode) override
 Hook invoked when training mode is about to change.
virtual void optimize ()
 Virtual hook for graph optimization after construction.
Protected Member Functions inherited from Mila::Dnn::Component< TDeviceType, TPrecision >
IExecutionContext * getExecutionContext () const
 Get the shared execution context.
bool hasExecutionContext () const noexcept
 Check if execution context has been set.
template<TensorDataType TParameterPrecision, typename TMemoryResource>
void loadParameterFromBlob (const std::string &param_name, const Serialization::ITensorBlob &blob, Tensor< TParameterPrecision, TMemoryResource > &target, const shape_t &expected_shape)
 Load a tensor blob into a parameter tensor with validation.
virtual void onBuilding (const BuildContext &config)
 Hook invoked by build() to allocate component buffers.
void setExecutionContext (IExecutionContext *context)
 Set the execution context for this component.
Protected Attributes inherited from Mila::Dnn::Component< TDeviceType, TPrecision >
BuildContext build_context_ { shape_t{ 1 }, RuntimeMode::Training }
 The BuildContext stored at build time.

Member Typedef Documentation

◆ MR

template<DeviceType TDeviceType, TensorDataType TPrecision>
using Mila::Dnn::LanguageNetwork< TDeviceType, TPrecision >::MR = typename DeviceTypeTraits<TDeviceType>::memory_resource

◆ NetworkBase

template<DeviceType TDeviceType, TensorDataType TPrecision>
using Mila::Dnn::LanguageNetwork< TDeviceType, TPrecision >::NetworkBase = Network<TDeviceType, TPrecision>

◆ TensorType

template<DeviceType TDeviceType, TensorDataType TPrecision>
using Mila::Dnn::LanguageNetwork< TDeviceType, TPrecision >::TensorType = Tensor<TPrecision, MR>

◆ TokenIndexType

template<DeviceType TDeviceType, TensorDataType TPrecision>
using Mila::Dnn::LanguageNetwork< TDeviceType, TPrecision >::TokenIndexType = Tensor<TensorDataType::INT32, MR>

Constructor & Destructor Documentation

◆ LanguageNetwork()

template<DeviceType TDeviceType, TensorDataType TPrecision>
Mila::Dnn::LanguageNetwork< TDeviceType, TPrecision >::LanguageNetwork ( const std::string & name)
inline explicit

◆ ~LanguageNetwork()

template<DeviceType TDeviceType, TensorDataType TPrecision>
Mila::Dnn::LanguageNetwork< TDeviceType, TPrecision >::~LanguageNetwork ( )
override default

Member Function Documentation

◆ backward()

template<DeviceType TDeviceType, TensorDataType TPrecision>
virtual TokenIndexType & Mila::Dnn::LanguageNetwork< TDeviceType, TPrecision >::backward ( const TokenIndexType & input,
const TensorType & output_grad )
pure virtual

Full backward pass (training).

Parameters
input: Token indices [B, T].
output_grad: Gradient of the loss w.r.t. logits.
Returns
Gradient w.r.t. the input embeddings.

◆ decode()

template<DeviceType TDeviceType, TensorDataType TPrecision>
virtual TensorType & Mila::Dnn::LanguageNetwork< TDeviceType, TPrecision >::decode ( const TokenIndexType & input,
int position )
pure virtual

Inference decode — single-token autoregressive step.

Parameters
input: Single token index [B, 1].
position: Current sequence position (0-based).
Returns
Logits [B, 1, vocab_size].

◆ forward()

template<DeviceType TDeviceType, TensorDataType TPrecision>
virtual TensorType & Mila::Dnn::LanguageNetwork< TDeviceType, TPrecision >::forward ( const TokenIndexType & input)
pure virtual

Full-sequence forward pass.

Parameters
input: Token indices [B, T].
Returns
Logits [B, T, vocab_size].

◆ prefill()

template<DeviceType TDeviceType, TensorDataType TPrecision>
virtual TensorType & Mila::Dnn::LanguageNetwork< TDeviceType, TPrecision >::prefill ( const TokenIndexType & input)
pure virtual

Inference prefill — process full prompt and populate the KV cache.

Parameters
input: Full prompt token indices [B, T].
Returns
Logits for the last token position.

The documentation for this class was generated from the following file: