Mila 0.13.48
Deep Neural Network Library
Mila::Dnn::Model< TDeviceType, TPrecision > Class Template Reference [abstract] [export]
module Dnn.Model

Public Types

using NetworkType = Network<TDeviceType, TPrecision>

Public Member Functions

 Model (const Model &)=delete
 Model (Model &&)=default
virtual ~Model ()=default
DeviceId getDeviceId () const noexcept
 The device this model runs on.
MemoryStats getMemoryStats () const
 Current memory allocation breakdown for this model.
RuntimeMode getRuntimeMode () const noexcept
 The runtime mode this model was constructed for.
bool isEval () const noexcept
 True if this model is currently in eval sub-state.
bool isInferenceMode () const noexcept
 True if this model was constructed for inference.
bool isTrainingMode () const noexcept
 True if this model was constructed for training.
Model & operator= (const Model &)=delete
Model & operator= (Model &&)=default
void setEval (bool eval)
 Toggle eval sub-state for this model.
virtual std::string toString () const =0
 Human-readable summary of this model's configuration.
void train ()
 Run the training loop for this model.

Protected Member Functions

 Model (std::unique_ptr< NetworkType > network, RuntimeMode runtime_mode)
 Construct with a fully built network and runtime mode.
virtual void onTraining ()=0
 Training loop hook — derived class owns the implementation.

Protected Attributes

std::unique_ptr< NetworkType > network_
 The owned Network instance.

Private Member Functions

void ensureTrainingMode (const char *method) const

Private Attributes

RuntimeMode runtime_mode_

Member Typedef Documentation

◆ NetworkType

template<DeviceType TDeviceType, TensorDataType TPrecision>
using Mila::Dnn::Model< TDeviceType, TPrecision >::NetworkType = Network<TDeviceType, TPrecision>

Constructor & Destructor Documentation

◆ Model() [1/3]

template<DeviceType TDeviceType, TensorDataType TPrecision>
Mila::Dnn::Model< TDeviceType, TPrecision >::Model ( const Model< TDeviceType, TPrecision > & )
delete

◆ Model() [2/3]

template<DeviceType TDeviceType, TensorDataType TPrecision>
Mila::Dnn::Model< TDeviceType, TPrecision >::Model ( Model< TDeviceType, TPrecision > && )
default

◆ ~Model()

template<DeviceType TDeviceType, TensorDataType TPrecision>
virtual Mila::Dnn::Model< TDeviceType, TPrecision >::~Model ( )
virtual default

◆ Model() [3/3]

template<DeviceType TDeviceType, TensorDataType TPrecision>
Mila::Dnn::Model< TDeviceType, TPrecision >::Model ( std::unique_ptr< NetworkType > network,
RuntimeMode runtime_mode )
inline explicit protected

Construct with a fully built network and runtime mode.

Called by derived class constructors only. The network must already be built and have parameters loaded before this constructor is called.

Parameters
network: Fully built and loaded Network.
runtime_mode: Inference or Training; immutable after construction.
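
As a rough sketch of how a derived class might satisfy this contract, the example below finishes building its network first and only then forwards it to the protected constructor. Mila's module is not imported here: `Network`, `RuntimeMode`, and `Model` are simplified, non-template stand-ins written for illustration, and `TinyModel` and `buildNetwork()` are hypothetical names.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>

// Simplified stand-ins for illustration only; these are not Mila's real types.
enum class RuntimeMode { Inference, Training };

struct Network {
    bool built = false;  // stands in for "built and parameters loaded"
};

class Model {
public:
    Model(const Model&) = delete;
    Model& operator=(const Model&) = delete;
    Model(Model&&) = default;
    Model& operator=(Model&&) = default;
    virtual ~Model() = default;

    RuntimeMode getRuntimeMode() const noexcept { return runtime_mode_; }
    virtual std::string toString() const = 0;

protected:
    // Mirrors the documented protected constructor: takes ownership of a
    // finished network and fixes the runtime mode for the object's lifetime.
    explicit Model(std::unique_ptr<Network> network, RuntimeMode runtime_mode)
        : network_(std::move(network)), runtime_mode_(runtime_mode) {}

    std::unique_ptr<Network> network_;

private:
    RuntimeMode runtime_mode_;
};

// Hypothetical derived model: the network is complete before Model sees it.
class TinyModel : public Model {
public:
    explicit TinyModel(RuntimeMode mode) : Model(buildNetwork(), mode) {}

    std::string toString() const override {
        return (network_ && network_->built) ? "TinyModel(built)"
                                             : "TinyModel(unbuilt)";
    }

private:
    static std::unique_ptr<Network> buildNetwork() {
        auto net = std::make_unique<Network>();
        net->built = true;  // assumed stand-in for build + parameter load
        return net;
    }
};
```

Because the base constructor is protected, `Model` can only be constructed through a derived class such as `TinyModel`, which is what lets the derived class guarantee the network is ready.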

Member Function Documentation

◆ ensureTrainingMode()

template<DeviceType TDeviceType, TensorDataType TPrecision>
void Mila::Dnn::Model< TDeviceType, TPrecision >::ensureTrainingMode ( const char * method) const
inline private

◆ getDeviceId()

template<DeviceType TDeviceType, TensorDataType TPrecision>
DeviceId Mila::Dnn::Model< TDeviceType, TPrecision >::getDeviceId ( ) const
inline noexcept

The device this model runs on.


◆ getMemoryStats()

template<DeviceType TDeviceType, TensorDataType TPrecision>
MemoryStats Mila::Dnn::Model< TDeviceType, TPrecision >::getMemoryStats ( ) const
inline

Current memory allocation breakdown for this model.

◆ getRuntimeMode()

template<DeviceType TDeviceType, TensorDataType TPrecision>
RuntimeMode Mila::Dnn::Model< TDeviceType, TPrecision >::getRuntimeMode ( ) const
inline noexcept

The runtime mode this model was constructed for.

Immutable after construction. Governs which public API methods are valid.

◆ isEval()

template<DeviceType TDeviceType, TensorDataType TPrecision>
bool Mila::Dnn::Model< TDeviceType, TPrecision >::isEval ( ) const
inline noexcept

True if this model is currently in eval sub-state.

For RuntimeMode::Inference models this always returns true, because inference models never compute gradients. For RuntimeMode::Training models it reflects the most recent setEval() call.
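
The rule above can be restated as a single boolean expression. The sketch below is only an illustration of that rule: `isEvalFor` is a hypothetical free function, the enum is a local stand-in for Mila's `RuntimeMode`, and how the real class stores the flag is not shown in this reference.

```cpp
#include <cassert>

// Local stand-in for illustration; not Mila's real enum.
enum class RuntimeMode { Inference, Training };

// Hypothetical helper restating the documented rule: inference models are
// always in eval; training models report the flag last passed to setEval().
constexpr bool isEvalFor(RuntimeMode mode, bool last_set_eval) noexcept {
    return mode == RuntimeMode::Inference || last_set_eval;
}

static_assert(isEvalFor(RuntimeMode::Inference, false),
              "inference models are always in eval");
static_assert(!isEvalFor(RuntimeMode::Training, false),
              "training models follow the flag");
static_assert(isEvalFor(RuntimeMode::Training, true),
              "training models follow the flag");
```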


◆ isInferenceMode()

template<DeviceType TDeviceType, TensorDataType TPrecision>
bool Mila::Dnn::Model< TDeviceType, TPrecision >::isInferenceMode ( ) const
inline noexcept

True if this model was constructed for inference.

When true, the model-family inference API (e.g. generate()) is valid and train() will throw.


◆ isTrainingMode()

template<DeviceType TDeviceType, TensorDataType TPrecision>
bool Mila::Dnn::Model< TDeviceType, TPrecision >::isTrainingMode ( ) const
inline noexcept

True if this model was constructed for training.

When true, train() is valid and the model-family inference API will throw.

◆ onTraining()

template<DeviceType TDeviceType, TensorDataType TPrecision>
virtual void Mila::Dnn::Model< TDeviceType, TPrecision >::onTraining ( )
protected pure virtual

Training loop hook — derived class owns the implementation.

Called by train() after precondition enforcement. The derived class has total control over data loading, optimizer construction, loss computation, backward pass, checkpointing, and sampling.

Pure virtual — a model declaring RuntimeMode::Training must provide a training loop.

Implemented in Mila::Dnn::GptModel< TDeviceType, TPrecision >, and Mila::Dnn::LlamaModel< TDeviceType, TPrecision >.
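
The relationship between train() and onTraining() follows the familiar non-virtual interface pattern. The sketch below shows one way that arrangement could look. It is not Mila's source: `Model` here is a cut-down, non-template stand-in, `CountingModel` is hypothetical, and the exception message text is an assumption.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

// Local stand-in for illustration; not Mila's real enum.
enum class RuntimeMode { Inference, Training };

class Model {
public:
    explicit Model(RuntimeMode mode) : runtime_mode_(mode) {}
    virtual ~Model() = default;

    // Public, non-virtual entry point: enforce the precondition, then
    // hand control to the derived class's hook.
    void train() {
        ensureTrainingMode("train");
        onTraining();
    }

protected:
    // Hook that every concrete model must implement.
    virtual void onTraining() = 0;

private:
    void ensureTrainingMode(const char* method) const {
        if (runtime_mode_ != RuntimeMode::Training) {
            // Message wording is assumed for this sketch.
            throw std::runtime_error(std::string(method) +
                                     "() requires RuntimeMode::Training");
        }
    }

    RuntimeMode runtime_mode_;
};

// Hypothetical derived model whose "training loop" only counts steps.
class CountingModel : public Model {
public:
    using Model::Model;
    int steps = 0;

protected:
    void onTraining() override {
        for (int epoch = 0; epoch < 3; ++epoch) {
            ++steps;  // a real loop would load data, compute loss, and so on
        }
    }
};
```

Keeping train() non-virtual means the mode check cannot be bypassed by an override, while the loop body stays entirely in the derived class.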


◆ operator=() [1/2]

template<DeviceType TDeviceType, TensorDataType TPrecision>
Model & Mila::Dnn::Model< TDeviceType, TPrecision >::operator= ( const Model< TDeviceType, TPrecision > & )
delete

◆ operator=() [2/2]

template<DeviceType TDeviceType, TensorDataType TPrecision>
Model & Mila::Dnn::Model< TDeviceType, TPrecision >::operator= ( Model< TDeviceType, TPrecision > && )
default

◆ setEval()

template<DeviceType TDeviceType, TensorDataType TPrecision>
void Mila::Dnn::Model< TDeviceType, TPrecision >::setEval ( bool eval)
inline

Toggle eval sub-state for this model.

When eval is true, the forward pass runs without gradients, dropout is disabled, and batch norm uses running statistics. When eval is false, the full training pass is restored.

Cascades through Network to every Component and Operation in the graph via their onEvalChanging() hooks.

Only valid on models constructed with RuntimeMode::Training. Inference-mode models are always in eval state by definition.

Parameters
eval: true to enter eval sub-state, false to restore training.
Exceptions
std::runtime_error: if called on a RuntimeMode::Inference model.
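
To make the cascade concrete, here is a minimal sketch of a flag propagating from the model through the network to each component. Everything in it is a stand-in: `Component`, `Network::setEval`, the `onEvalChanging(bool)` signature, the `network()` accessor, and `makeModel` are assumptions made for this example, and Operations are omitted for brevity.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <utility>
#include <vector>

// Local stand-in for illustration; not Mila's real enum.
enum class RuntimeMode { Inference, Training };

// Stand-in component; the hook signature is assumed.
struct Component {
    bool eval = false;
    void onEvalChanging(bool new_eval) { eval = new_eval; }
};

// Stand-in network that forwards the flag to every component it owns.
struct Network {
    std::vector<Component> components;
    void setEval(bool eval) {
        for (auto& component : components) {
            component.onEvalChanging(eval);
        }
    }
};

class Model {
public:
    Model(std::unique_ptr<Network> network, RuntimeMode mode)
        : network_(std::move(network)), runtime_mode_(mode) {}

    // Rejects inference-mode models, then cascades through the network.
    void setEval(bool eval) {
        if (runtime_mode_ != RuntimeMode::Training) {
            throw std::runtime_error("setEval() requires RuntimeMode::Training");
        }
        network_->setEval(eval);
    }

    // Accessor added only so the example can inspect the result.
    const Network& network() const { return *network_; }

private:
    std::unique_ptr<Network> network_;
    RuntimeMode runtime_mode_;
};

// Hypothetical factory: a model whose network holds n components.
inline Model makeModel(RuntimeMode mode, std::size_t n) {
    auto net = std::make_unique<Network>();
    net->components.resize(n);
    return Model(std::move(net), mode);
}
```

A typical use in a training loop would be to call `setEval(true)` before a validation pass and `setEval(false)` before resuming optimization.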

◆ toString()

template<DeviceType TDeviceType, TensorDataType TPrecision>
virtual std::string Mila::Dnn::Model< TDeviceType, TPrecision >::toString ( ) const
pure virtual

Human-readable summary of this model's configuration.

Implemented in Mila::Dnn::GptModel< TDeviceType, TPrecision >, and Mila::Dnn::LlamaModel< TDeviceType, TPrecision >.

◆ train()

template<DeviceType TDeviceType, TensorDataType TPrecision>
void Mila::Dnn::Model< TDeviceType, TPrecision >::train ( )
inline

Run the training loop for this model.

Enforces the RuntimeMode::Training precondition, then delegates entirely to onTraining(). The derived class owns the loop: data loading, optimizer, loss, checkpointing, and sampling are all derived-class concerns.

Exceptions
std::runtime_error: if called on an Inference-mode model.

Member Data Documentation

◆ network_

template<DeviceType TDeviceType, TensorDataType TPrecision>
std::unique_ptr<NetworkType> Mila::Dnn::Model< TDeviceType, TPrecision >::network_
protected

The owned Network instance.

Accessible to derived classes for model-specific operations not covered by the base class API.

◆ runtime_mode_

template<DeviceType TDeviceType, TensorDataType TPrecision>
RuntimeMode Mila::Dnn::Model< TDeviceType, TPrecision >::runtime_mode_
private

The documentation for this class was generated from the following file: