Model (std::unique_ptr< NetworkType > network, RuntimeMode runtime_mode)
    Construct with a fully built network and runtime mode.
virtual void onTraining () = 0
    Training loop hook; the derived class owns the implementation.
◆ NetworkType
◆ Model() [1/3]
Mila::Dnn::Model< TDeviceType, TPrecision >::Model ( const Model< TDeviceType, TPrecision > & )   [delete]
◆ Model() [2/3]
Mila::Dnn::Model< TDeviceType, TPrecision >::Model ( Model< TDeviceType, TPrecision > && )   [default]
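Together, the deleted copy constructor and the defaulted move constructor make a Model move-only. The following is a minimal, hypothetical reduction of that pattern: the template parameters are dropped, `Network` is an empty stand-in struct, and the `network_` member is assumed to be a `std::unique_ptr` (which matches the `std::unique_ptr< NetworkType >` constructor argument but is not confirmed for the member itself).

```cpp
#include <cassert>
#include <memory>
#include <type_traits>
#include <utility>

// Hypothetical stand-in for the Mila Network type; only ownership matters here.
struct Network {};

// Sketch of the two documented special members (template parameters omitted).
class Model {
public:
    explicit Model(std::unique_ptr<Network> network) : network_(std::move(network)) {}
    virtual ~Model() = default;

    Model(const Model&) = delete;  // [1/3]: the network is uniquely owned and cannot be copied
    Model(Model&&) = default;      // [2/3]: ownership of the network moves with the model

    const Network* network() const { return network_.get(); }

private:
    std::unique_ptr<Network> network_;  // assumed member type
};

static_assert(!std::is_copy_constructible_v<Model>, "copy construction is deleted");
static_assert(std::is_move_constructible_v<Model>, "move construction is available");
```

In this sketch, moving a model transfers the network pointer and leaves the source holding a null network, which is the standard behavior of a moved-from `std::unique_ptr`.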
◆ ~Model()
◆ Model() [3/3]
Construct with a fully built network and runtime mode.
Intended to be called only from derived-class constructors. The network must already be built, with its parameters loaded, before this constructor is called.
Parameters
    network       Fully built and loaded Network.
    runtime_mode  Inference or Training; immutable after construction.
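The construction order described above (the derived class builds and loads the network, then hands it to the base) can be sketched as follows. Everything here is a hypothetical stand-in: `Network::ready`, `TinyModel`, `buildNetwork()` and `hasNetwork()` are invented for illustration, the constructor is assumed to be `protected`, and the null/ready check is an assumption of this sketch rather than documented behavior.

```cpp
#include <cassert>
#include <memory>
#include <stdexcept>
#include <utility>

// Hypothetical stand-ins for Mila types; names mirror the docs, bodies are invented.
enum class RuntimeMode { Inference, Training };

struct Network {
    bool ready = false;  // hypothetical flag: graph built and parameters loaded
};

class Model {
public:
    virtual ~Model() = default;
    RuntimeMode getRuntimeMode() const { return runtime_mode_; }

protected:
    // Only derived-class constructors call this. The validation below is assumed.
    Model(std::unique_ptr<Network> network, RuntimeMode runtime_mode)
        : network_(std::move(network)), runtime_mode_(runtime_mode) {
        if (!network_ || !network_->ready) {
            throw std::invalid_argument("Model expects a fully built and loaded Network");
        }
    }

    std::unique_ptr<Network> network_;  // owned; reachable from derived classes

private:
    RuntimeMode runtime_mode_;  // no setter, so it stays fixed after construction
};

// Hypothetical model family: builds and loads its network before the base is constructed.
class TinyModel : public Model {
public:
    explicit TinyModel(RuntimeMode mode) : Model(buildNetwork(), mode) {}

    bool hasNetwork() const { return network_ != nullptr; }  // derived-class access to network_

private:
    static std::unique_ptr<Network> buildNetwork() {
        auto net = std::make_unique<Network>();
        net->ready = true;  // stands in for graph construction and parameter loading
        return net;
    }
};
```

Because `buildNetwork()` is static, it can run in the member-initializer list before the `Model` base subobject exists, so the base only ever sees a finished network.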
◆ ensureTrainingMode()
void Mila::Dnn::Model< TDeviceType, TPrecision >::ensureTrainingMode ( const char * method ) const   [inline, private]
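The page gives only the signature of ensureTrainingMode(). A plausible reading is a guard that throws when a training-only method is called on a model that was not built for training, using `method` to name the rejected call. The sketch below assumes exactly that: the `std::runtime_error` type, the message text, and the choice of setEval() as a caller are assumptions, and `GuardedModel` is a hypothetical class.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for Mila's RuntimeMode.
enum class RuntimeMode { Inference, Training };

class GuardedModel {
public:
    explicit GuardedModel(RuntimeMode mode)
        : runtime_mode_(mode), eval_(mode == RuntimeMode::Inference) {}  // inference is always eval

    // Assumed caller of the guard; the docs do not say which methods use it.
    void setEval(bool eval) {
        ensureTrainingMode("setEval");
        eval_ = eval;
    }

    bool isEval() const { return eval_; }

private:
    // Throws if the model was not constructed for training; `method` names the
    // public call that was rejected so the message points back at the caller.
    void ensureTrainingMode(const char* method) const {
        if (runtime_mode_ != RuntimeMode::Training) {
            throw std::runtime_error(std::string(method) +
                                     "() requires a model constructed with RuntimeMode::Training");
        }
    }

    RuntimeMode runtime_mode_;
    bool eval_;
};
```

Taking a `const char*` rather than a `std::string` keeps the guard cheap on the success path: the message string is only built when the check fails.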
◆ getDeviceId()
The device this model runs on.
◆ getMemoryStats()
Current memory allocation breakdown for this model.
◆ getRuntimeMode()
The runtime mode this model was constructed for.
Immutable after construction. Governs which public API methods are valid.
◆ isEval()
◆ isInferenceMode()
True if this model was constructed for inference.
The model-family inference API (e.g. generate()) is valid. train() will throw.
◆ isTrainingMode()
True if this model was constructed for training.
train() is valid. The model-family inference API will throw.
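The mode predicates and the rule that the runtime mode decides which API is valid can be sketched as below. `TextModel` is a hypothetical model family, and although generate() is named in the docs, its signature, its body, and the use of `std::runtime_error` for a mode mismatch are invented for illustration.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for Mila's RuntimeMode.
enum class RuntimeMode { Inference, Training };

class Model {
public:
    virtual ~Model() = default;
    RuntimeMode getRuntimeMode() const { return runtime_mode_; }
    bool isInferenceMode() const { return runtime_mode_ == RuntimeMode::Inference; }
    bool isTrainingMode() const { return runtime_mode_ == RuntimeMode::Training; }

protected:
    explicit Model(RuntimeMode mode) : runtime_mode_(mode) {}

private:
    RuntimeMode runtime_mode_;  // fixed at construction
};

// Hypothetical model family exposing an inference API.
class TextModel : public Model {
public:
    explicit TextModel(RuntimeMode mode) : Model(mode) {}

    std::string generate(const std::string& prompt) const {
        if (!isInferenceMode()) {
            throw std::runtime_error("generate() requires RuntimeMode::Inference");
        }
        return prompt + "...";  // placeholder for real decoding
    }
};
```

Since the mode cannot change after construction, each predicate gives the same answer for the lifetime of the object, and exactly one of them is true.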
◆ onTraining()
◆ operator=() [1/2]
◆ operator=() [2/2]
◆ setEval()
Set the eval sub-state for this model.
When eval is true, the forward pass runs without gradients, dropout is disabled, and batch norm uses running statistics. When eval is false, the full training pass is restored.
The change cascades through the Network to every Component and Operation in the graph via their onEvalChanging() hooks.
Only valid on models constructed with RuntimeMode::Training. Inference-mode models are always in the eval state by definition.
Parameters
    eval  true to enter the eval sub-state, false to restore training.
Exceptions
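The cascade described for setEval() can be pictured as a loop over the network's components, each reacting through an onEvalChanging() hook. This is a simplified, hypothetical sketch: `Component`, `Dropout` and `Network` are minimal stand-ins, Operations are not modelled, and both the early return on an unchanged state and calling the hook before the flag flips are assumptions suggested by the hook's name.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical component base with an onEvalChanging() hook.
class Component {
public:
    virtual ~Component() = default;

    void setEval(bool eval) {
        if (eval == eval_) return;  // no-op when the state is unchanged (assumption)
        onEvalChanging(eval);       // component reacts before the flag flips (assumption)
        eval_ = eval;
    }

    bool isEval() const { return eval_; }

protected:
    virtual void onEvalChanging(bool /*eval*/) {}

private:
    bool eval_ = false;
};

// Example reaction: dropout is disabled in eval.
class Dropout : public Component {
public:
    bool active() const { return active_; }

protected:
    void onEvalChanging(bool eval) override { active_ = !eval; }

private:
    bool active_ = true;
};

// Hypothetical network that forwards the eval state to every component it owns.
class Network {
public:
    void add(std::unique_ptr<Component> c) { components_.push_back(std::move(c)); }

    void setEval(bool eval) {
        for (auto& c : components_) c->setEval(eval);  // cascade through the graph
    }

    Component& at(std::size_t i) { return *components_[i]; }

private:
    std::vector<std::unique_ptr<Component>> components_;
};
```

Putting the state change in each component's hook keeps Network generic: it does not need to know which components behave differently in eval.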
◆ toString()
virtual std::string Mila::Dnn::Model< TDeviceType, TPrecision >::toString ( ) const   [pure virtual]
◆ train()
Run the training loop for this model.
Enforces the RuntimeMode::Training precondition, then delegates entirely to onTraining(). The derived class owns the loop: data loading, optimizer, loss, checkpointing, and sampling are all derived-class concerns.
Exceptions
    std::runtime_error  if called on an Inference-mode model.
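The train()/onTraining() split is a template-method arrangement: the base checks the runtime mode and the derived class supplies the loop. The sketch below is a hypothetical reduction. `ToyModel` and its step counter are invented, train() is shown as non-virtual and the exception message is illustrative, and a real onTraining() would do the data loading, optimizer and loss work listed above.

```cpp
#include <cassert>
#include <stdexcept>

// Hypothetical stand-in for Mila's RuntimeMode.
enum class RuntimeMode { Inference, Training };

class Model {
public:
    virtual ~Model() = default;

    // Entry point: enforce the precondition, then hand off entirely.
    void train() {
        if (runtime_mode_ != RuntimeMode::Training) {
            throw std::runtime_error("train() called on an Inference-mode model");
        }
        onTraining();
    }

protected:
    explicit Model(RuntimeMode mode) : runtime_mode_(mode) {}

    // The derived class owns the training loop.
    virtual void onTraining() = 0;

private:
    RuntimeMode runtime_mode_;
};

// Hypothetical derived model whose "loop" only counts steps.
class ToyModel : public Model {
public:
    ToyModel(RuntimeMode mode, int steps) : Model(mode), steps_(steps) {}

    int stepsRun() const { return steps_run_; }

protected:
    void onTraining() override {
        for (int step = 0; step < steps_; ++step) {
            // A real model would load a batch, run forward/backward and step the optimizer here.
            ++steps_run_;
        }
    }

private:
    int steps_;
    int steps_run_ = 0;
};
```

Because the check lives in the base, no derived class can forget it, and an Inference-mode model never enters onTraining() at all.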
◆ network_
The owned Network instance.
Accessible to derived classes for model-specific operations not covered by the base class API.
◆ runtime_mode_
The documentation for this class was generated from the following file: