Mila 0.13.48
Deep Neural Network Library
Loading...
Searching...
No Matches
Dnn.Component Module Reference

Exported Modules

module  Dnn.ITensor
module  Compute.DeviceType
module  Dnn.ComponentType
module  Compute.Device
module  Compute.IExecutionContext
module  Dnn.TensorDataTypeTraits
module  Dnn.TensorDataType
module  Serialization.ModelArchive
module  Dnn.Tensor
module  Serialization.Tensor
module  Compute.DeviceId
module  Dnn.TensorTypes
module  Serialization.Mode

Classes

class  Mila::Dnn::BuildContext
 Build-time context for Component::build(). More...
class  Mila::Dnn::Component< TDeviceType, TPrecision >
 Abstract base class for neural network components. More...
struct  Mila::Dnn::MemoryStats
 Memory allocation breakdown for a single component. More...

Enumerations

enum class  Mila::Dnn::TrainingMode : uint8_t { Normal , Eval }
 Runtime behavioral state for Components built with RuntimeMode::Training. More...

Functions

 Component (const std::string &name)
 Construct a component with the required name identifier.
virtual ~Component ()=default
virtual void build (const BuildContext &context) final
 Build the component with the provided BuildContext (canonical overload).
void ensureBuilt (const char *method) const
 Throws if the component has not yet been built.
virtual DeviceId getDeviceId () const =0
 Get the compute device id associated with this component.
static constexpr DeviceType getDeviceType ()
 Compile-time device type for this component instance.
IExecutionContext * getExecutionContext () const
 Get the shared execution context.
virtual std::vector< ITensor * > getGradients () const =0
 Return non-owning pointers to parameter gradient tensors.
virtual MemoryStats getMemoryStats () const =0
 Return the current memory allocation breakdown for this component.
const std::string getName () const
 Get the component's name identifier.
virtual std::vector< std::string > getParameterNames () const
 List all available parameter names for this component.
virtual std::vector< ITensor * > getParameters () const =0
 Return non-owning pointers to parameter tensors.
static constexpr TensorDataType getPrecision () noexcept
 Compile-time tensor precision for this component instance.
RuntimeMode getRuntimeMode () const noexcept
 Get the RuntimeMode of this component (as configured through its BuildContext).
TrainingMode getTrainingMode () const noexcept
 The current runtime behavioral mode of this Component.
virtual const ComponentType getType () const =0
 Get the component type identifier.
bool hasExecutionContext () const noexcept
 Check if execution context has been set.
virtual bool isBuilt () const final
 Returns true if build() has completed successfully.
static bool isIdentifier (const std::string &s) noexcept
 Checks if a string is a valid component identifier.
bool isInferenceMode () const noexcept
bool isTrainingMode () const noexcept
virtual void loadParameter (const std::string &name, const Serialization::ITensorBlob &blob)
 Load a parameter from serialized tensor data.
template<TensorDataType TParameterPrecision, typename TMemoryResource>
void loadParameterFromBlob (const std::string &param_name, const Serialization::ITensorBlob &blob, Tensor< TParameterPrecision, TMemoryResource > &target, const shape_t &expected_shape)
 Load a tensor blob into a parameter tensor with validation.
virtual void onBuilding (const BuildContext &config)
 Hook invoked by build() to allocate component buffers.
virtual void onExecutionContextSet ()
 Lifecycle hook: called immediately after the ExecutionContext is set.
virtual void onTrainingModeChanging (TrainingMode mode)
 Hook called before TrainingMode transitions.
MemoryStats Mila::Dnn::operator+ (MemoryStats lhs, const MemoryStats &rhs) noexcept
 Aggregate two MemoryStats instances.
virtual size_t parameterCount () const =0
 Return number of trainable parameters.
virtual void save_ (ModelArchive &archive, SerializationMode mode) const =0
void setExecutionContext (IExecutionContext *context)
 Set the execution context for this component.
void setTrainingMode (TrainingMode mode)
 Set the runtime behavioral mode for this Component.
virtual void synchronize ()=0
 Wait for outstanding device work submitted by this component.
virtual std::string toString () const =0
 Produce a short, human-readable description of the component.
static const std::string & validateName (const std::string &name)
 Validates the component name.
virtual void zeroGradients ()
 Clear all model-owned gradients for this component.

Variables

BuildContext build_context_ { shape_t{ 1 }, RuntimeMode::Training }
 The BuildContext stored at build time.
bool built_ { false }
IExecutionContext * exec_context_ { nullptr }
std::string name_
TrainingMode training_mode_ { TrainingMode::Normal }
std::mutex training_mode_mutex_

Files

file  /__w/Mila/Mila/Mila/Src/Dnn/Core/Component.ixx
 Base component interface for Mila DNN components.
file  /__w/Mila/Mila/Mila/Src/Dnn/Core/Comonent.TrainingMode.ixx
file  /__w/Mila/Mila/Mila/Src/Dnn/Core/Component.BuildContext.ixx
file  /__w/Mila/Mila/Mila/Src/Dnn/Core/Component.MemoryStats.ixx