Mila
Deep Neural Network Library
Mila::Dnn::Component Class Reference (abstract, export)

Abstract base class for all components in the Mila framework.


Public Member Functions

 Component (const std::string &device_name)
 Constructor with device name.
 
 Component (std::shared_ptr< DeviceContext > context)
 Constructor with a specific device context.
 
virtual ~Component ()=default
 Virtual destructor for proper cleanup in derived classes.
 
const ComputePrecision & getComputePrecision () const
 Get the compute precision policy for this component.
 
std::shared_ptr< DeviceContext > getDeviceContext () const
 Get the device context for this component.
 
DeviceType getDeviceType () const
 Get the device type of the current device context.
 
std::string getName () const
 Get the name of the component.
 
bool isTraining () const
 Check if the component is in training mode.
 
Component & setComputePrecisionPolicy (ComputePrecision::Policy policy)
 Set the compute precision policy explicitly.
 
void setName (const std::string &name)
 Set the name of the component.
 
virtual void setTraining (bool is_training)
 Set the training mode of this component.
 
virtual std::string toString () const =0
 Convert the component to a string representation.
 

Protected Member Functions

virtual void validateDeviceType () const
 Validate that the device type matches the derived class requirements.
 

Static Private Member Functions

static std::shared_ptr< DeviceContext > createContext (const std::string &device_name)
 Helper method to create a DeviceContext from a device name.
 

Private Attributes

ComputePrecision compute_precision_
 The compute precision policy for this component.
 
std::shared_ptr< DeviceContext > device_context_
 The device context used for this component's computations.
 
bool is_training_ = false
 Whether the component is in training mode.
 
std::string name_ = "unnamed"
 The name of the component.
 

Detailed Description

Abstract base class for all components in the Mila framework.

The Component class establishes a common interface and behavior for both Modules and Operations. It handles device context management, naming, compute precision, and training mode.
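The sketch below illustrates the pattern this base class sets up: a derived class inherits naming and training-mode state and must implement the pure virtual `toString()`. It is a simplified, self-contained stand-in rather than the real `Mila::Dnn::Component`: `ComponentSketch` omits the device context, the precision policy, and `setName()` validation, and `IdentityOp` is a hypothetical derived class used only for illustration.

```cpp
#include <string>

// Simplified stand-in for Mila::Dnn::Component (assumption: models only the
// naming and training-mode behavior described on this page).
class ComponentSketch {
public:
    virtual ~ComponentSketch() = default;

    std::string getName() const { return name_; }
    void setName(const std::string& name) { name_ = name; }  // validation omitted

    bool isTraining() const { return is_training_; }
    virtual void setTraining(bool is_training) { is_training_ = is_training; }

    // Pure virtual: every concrete component must describe itself.
    virtual std::string toString() const = 0;

private:
    std::string name_ = "unnamed";  // same default as Component::name_
    bool is_training_ = false;      // components start in inference mode
};

// Hypothetical concrete component used only for illustration.
class IdentityOp : public ComponentSketch {
public:
    std::string toString() const override {
        return "IdentityOp(name=" + getName() +
               ", training=" + (isTraining() ? "true" : "false") + ")";
    }
};
```

Because `toString()` is pure virtual, `ComponentSketch` itself cannot be instantiated; only complete derived classes such as `IdentityOp` can.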

Constructor & Destructor Documentation

◆ Component() [1/2]

Mila::Dnn::Component::Component ( const std::string &  device_name)
inline, explicit

Constructor with device name.

Parameters
device_name: The name of the device to use (e.g., "CPU", "CUDA:0").
Note: this overload takes no policy argument; the compute precision policy starts at its default (Auto) and can be changed with setComputePrecisionPolicy().

◆ Component() [2/2]

Mila::Dnn::Component::Component ( std::shared_ptr< DeviceContext > context)
inline, explicit

Constructor with a specific device context.

Parameters
context: The device context to use for this component.
Note: this overload takes no policy argument; the compute precision policy starts at its default (Auto) and can be changed with setComputePrecisionPolicy().
Exceptions
std::invalid_argument: If the provided context is nullptr.
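A minimal sketch of the documented contract for the context overload: the component shares ownership of an existing context and rejects a null pointer with `std::invalid_argument`. `DeviceContextStub` and `ContextOwner` are hypothetical stand-ins, and the exception message is illustrative, not Mila's.

```cpp
#include <memory>
#include <stdexcept>
#include <utility>

// Hypothetical placeholder for Mila's DeviceContext; the real class carries
// device state that is not modeled here.
struct DeviceContextStub {};

class ContextOwner {
public:
    // Mirrors the documented contract: a null context is rejected.
    explicit ContextOwner(std::shared_ptr<DeviceContextStub> context)
        : device_context_(std::move(context)) {
        if (!device_context_) {
            throw std::invalid_argument("Device context cannot be nullptr.");
        }
    }

    // Returns a copy of the shared_ptr, so callers share (not take) ownership.
    std::shared_ptr<DeviceContextStub> getDeviceContext() const { return device_context_; }

private:
    std::shared_ptr<DeviceContextStub> device_context_;
};
```

Passing a `std::shared_ptr` lets several components reuse one context; the caller's pointer and the component's copy refer to the same object.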

◆ ~Component()

virtual Mila::Dnn::Component::~Component ( )
virtual, default

Virtual destructor for proper cleanup in derived classes.

Member Function Documentation

◆ createContext()

static std::shared_ptr< DeviceContext > Mila::Dnn::Component::createContext ( const std::string &  device_name)
inline, static, private

Helper method to create a DeviceContext from a device name.

Parameters
device_name: Name of the device to create a context for.
Returns
std::shared_ptr<DeviceContext> The created device context.

◆ getComputePrecision()

const ComputePrecision & Mila::Dnn::Component::getComputePrecision ( ) const
inline

Get the compute precision policy for this component.

Returns
const ComputePrecision& The compute precision policy.

◆ getDeviceContext()

std::shared_ptr< DeviceContext > Mila::Dnn::Component::getDeviceContext ( ) const
inline

Get the device context for this component.

Returns
std::shared_ptr<DeviceContext> The device context.

◆ getDeviceType()

DeviceType Mila::Dnn::Component::getDeviceType ( ) const
inline

Get the device type of the current device context.

Returns
DeviceType The device type (CPU or CUDA).
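Device names on this page take forms such as "CPU" and "CUDA:0", and the device type is either CPU or CUDA. The sketch below shows one way such a name could map to a type and index. It is not Mila's `createContext()` logic: `parseDeviceName`, the `DeviceType` enumerator spellings, the accepted formats, and the convention of reporting index 0 for the CPU are all assumptions.

```cpp
#include <cstddef>
#include <stdexcept>
#include <string>
#include <utility>

// Stand-in for Mila's DeviceType (this page lists CPU and CUDA).
enum class DeviceType { Cpu, Cuda };

// Hypothetical parser for names such as "CPU" or "CUDA:0"; the real
// DeviceContext may accept other spellings.
std::pair<DeviceType, int> parseDeviceName(const std::string& device_name) {
    if (device_name == "CPU") {
        return {DeviceType::Cpu, 0};  // index 0 is this sketch's convention
    }
    const std::string prefix = "CUDA:";
    if (device_name.rfind(prefix, 0) == 0 && device_name.size() > prefix.size()) {
        const std::string index_text = device_name.substr(prefix.size());
        std::size_t consumed = 0;
        // std::stoi throws std::invalid_argument if index_text has no leading
        // number, or std::out_of_range if the value does not fit in an int.
        const int index = std::stoi(index_text, &consumed);
        if (consumed != index_text.size() || index < 0) {
            throw std::invalid_argument("Invalid CUDA device index in: " + device_name);
        }
        return {DeviceType::Cuda, index};
    }
    throw std::invalid_argument("Unrecognized device name: " + device_name);
}
```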

◆ getName()

std::string Mila::Dnn::Component::getName ( ) const
inline

Get the name of the component.

Returns
std::string Name of the component.

◆ isTraining()

bool Mila::Dnn::Component::isTraining ( ) const
inline

Check if the component is in training mode.

Returns
true If the component is in training mode.
false If the component is in inference mode.

◆ setComputePrecisionPolicy()

Component & Mila::Dnn::Component::setComputePrecisionPolicy ( ComputePrecision::Policy  policy)
inline

Set the compute precision policy explicitly.

Parameters
policy: The precision policy to use.
Returns
Component& Reference to this component for method chaining.
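Because `setComputePrecisionPolicy()` returns `Component&`, another member call can follow it in the same expression. The sketch below shows that chaining with a hypothetical `ChainableComponent` stand-in. Only the `Auto` policy is named on this page; `Performance` and `Accuracy` are hypothetical placeholders, and treating `Auto` as the default is an assumption taken from the constructor notes. As in the real class, `setName()` returns `void`, so it can end a chain but not continue one.

```cpp
#include <string>

// Stand-in for ComputePrecision::Policy. Only Auto is named on this page;
// Performance and Accuracy are hypothetical placeholders.
enum class Policy { Auto, Performance, Accuracy };

class ChainableComponent {
public:
    // Returning *this by reference enables chaining, mirroring the
    // documented Component& return type.
    ChainableComponent& setComputePrecisionPolicy(Policy policy) {
        policy_ = policy;
        return *this;
    }

    // Returns void, as Component::setName does.
    void setName(const std::string& name) { name_ = name; }

    Policy getPolicy() const { return policy_; }
    std::string getName() const { return name_; }

private:
    Policy policy_ = Policy::Auto;  // assumed default
    std::string name_ = "unnamed";
};
```

Usage: `comp.setComputePrecisionPolicy(Policy::Performance).setName("fc1");` sets both fields in one statement.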

◆ setName()

void Mila::Dnn::Component::setName ( const std::string &  name)
inline

Set the name of the component.

Parameters
name: The name to set. Must not be empty and must not contain a dot ('.').
Exceptions
std::invalid_argument: If the name is empty or contains a dot.
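A minimal sketch of the documented `setName()` rules, using a hypothetical `NamedComponent` stand-in. The exception messages are illustrative. Why the dot is forbidden is not stated here; one plausible reason (an assumption) is that dots are reserved as separators in hierarchical names such as "encoder.fc1".

```cpp
#include <stdexcept>
#include <string>

// Hypothetical stand-in that enforces the documented naming rules.
class NamedComponent {
public:
    void setName(const std::string& name) {
        if (name.empty()) {
            throw std::invalid_argument("Component name cannot be empty.");
        }
        // Assumption: '.' may be reserved for hierarchical names like "encoder.fc1".
        if (name.find('.') != std::string::npos) {
            throw std::invalid_argument("Component name cannot contain a dot ('.').");
        }
        name_ = name;  // only assigned after both checks pass
    }

    std::string getName() const { return name_; }

private:
    std::string name_ = "unnamed";  // same default as Component::name_
};
```

Since the checks run before the assignment, a rejected name leaves the previous name unchanged.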

◆ setTraining()

virtual void Mila::Dnn::Component::setTraining ( bool  is_training)
inline, virtual

Set the training mode of this component.

Parameters
is_training: true to put the component in training mode, false to put it in inference mode.
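Because `setTraining()` is virtual, a derived class can extend it. The sketch below assumes a container-style component that forwards the mode to its children; `TrainableBase` and `SequentialSketch` are hypothetical stand-ins, and this page does not say whether Mila's own composite modules do this.

```cpp
#include <memory>
#include <utility>
#include <vector>

// Minimal stand-in base with the documented virtual setTraining().
class TrainableBase {
public:
    virtual ~TrainableBase() = default;
    bool isTraining() const { return is_training_; }
    virtual void setTraining(bool is_training) { is_training_ = is_training; }

private:
    bool is_training_ = false;  // inference mode by default
};

// Hypothetical container that overrides setTraining() to propagate the mode.
class SequentialSketch : public TrainableBase {
public:
    void add(std::shared_ptr<TrainableBase> child) { children_.push_back(std::move(child)); }

    void setTraining(bool is_training) override {
        TrainableBase::setTraining(is_training);  // update this component first
        for (auto& child : children_) {
            child->setTraining(is_training);      // then every child
        }
    }

private:
    std::vector<std::shared_ptr<TrainableBase>> children_;
};
```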

◆ toString()

virtual std::string Mila::Dnn::Component::toString ( ) const
pure virtual

Convert the component to a string representation.

Returns
std::string String representation of the component.

◆ validateDeviceType()

virtual void Mila::Dnn::Component::validateDeviceType ( ) const
inline, protected, virtual

Validate that the device type matches the derived class requirements.

This method should be overridden in derived classes to enforce device type constraints.
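The sketch below shows a derived class enforcing a CUDA-only constraint by overriding `validateDeviceType()`. `DeviceBoundBase`, `CudaOnlyOp`, the `DeviceType` enumerator spellings, the permissive base behavior, and the choice of `std::runtime_error` are all assumptions for illustration. The check is called from the derived constructor because, during the base constructor, a virtual call would still dispatch to the base version.

```cpp
#include <stdexcept>

enum class DeviceType { Cpu, Cuda };  // stand-in for Mila's DeviceType

class DeviceBoundBase {
public:
    explicit DeviceBoundBase(DeviceType type) : type_(type) {}
    virtual ~DeviceBoundBase() = default;
    DeviceType getDeviceType() const { return type_; }

protected:
    // Assumed default: accept any device type.
    virtual void validateDeviceType() const {}

private:
    DeviceType type_;
};

// Hypothetical CUDA-only component.
class CudaOnlyOp : public DeviceBoundBase {
public:
    explicit CudaOnlyOp(DeviceType type) : DeviceBoundBase(type) {
        // Dynamic type is CudaOnlyOp here, so the override below runs.
        validateDeviceType();
    }

protected:
    void validateDeviceType() const override {
        if (getDeviceType() != DeviceType::Cuda) {
            throw std::runtime_error("CudaOnlyOp requires a CUDA device context.");
        }
    }
};
```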


Member Data Documentation

◆ compute_precision_

ComputePrecision Mila::Dnn::Component::compute_precision_
private

The compute precision policy for this component.

◆ device_context_

std::shared_ptr<DeviceContext> Mila::Dnn::Component::device_context_
private

The device context used for this component's computations.

◆ is_training_

bool Mila::Dnn::Component::is_training_ = false
private

Whether the component is in training mode.

◆ name_

std::string Mila::Dnn::Component::name_ = "unnamed"
private

The name of the component.


The documentation for this class was generated from the following file: