Mila
Deep Neural Network Library
Abstract base class for all components in the Mila framework.
Public Member Functions
  Component(const std::string& device_name)
      Constructor with device name.
  Component(std::shared_ptr<DeviceContext> context)
      Constructor with a specific device context.
  virtual ~Component() = default
      Virtual destructor for proper cleanup in derived classes.
  const ComputePrecision& getComputePrecision() const
      Get the compute precision policy for this component.
  std::shared_ptr<DeviceContext> getDeviceContext() const
      Get the device context for this component.
  DeviceType getDeviceType() const
      Get the device type of the current device context.
  std::string getName() const
      Get the name of the component.
  bool isTraining() const
      Check whether the component is in training mode.
  Component& setComputePrecisionPolicy(ComputePrecision::Policy policy)
      Set the compute precision policy explicitly.
  void setName(const std::string& name)
      Set the name of the component.
  virtual void setTraining(bool is_training)
      Set the training mode of this component.
  virtual std::string toString() const = 0
      Convert the component to a string representation.
Protected Member Functions
  virtual void validateDeviceType() const
      Validate that the device type matches the derived class requirements.
Static Private Member Functions
  static std::shared_ptr<DeviceContext> createContext(const std::string& device_name)
      Helper method to create a DeviceContext from a device name.
Private Attributes
  ComputePrecision compute_precision_
      The compute precision policy for this component.
  std::shared_ptr<DeviceContext> device_context_
      The device context used for this component's computations.
  bool is_training_ = false
      Whether the component is in training mode.
  std::string name_ = "unnamed"
      The name of the component.
Abstract base class for all components in the Mila framework.
The Component class establishes a common interface and behavior for both Modules and Operations. It handles device context management, naming, compute precision, and training mode.
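Because `toString()` is pure virtual, `Component` cannot be instantiated directly; a concrete Module or Operation must override it. The sketch below illustrates that pattern with a hypothetical, stripped-down stand-in (`ComponentSketch`) that models only the naming, training-mode and `toString()` parts of the interface. `ComponentSketch` and `ScaleSketch` are illustrative names, not Mila's API, and the device-context and precision members are deliberately omitted.

```cpp
#include <cassert>
#include <sstream>
#include <string>

// Hypothetical, stripped-down stand-in for Mila's Component. The real class
// also manages a DeviceContext and a ComputePrecision policy.
class ComponentSketch {
public:
    virtual ~ComponentSketch() = default;

    std::string getName() const { return name_; }
    void setName(const std::string& name) { name_ = name; }

    bool isTraining() const { return is_training_; }
    virtual void setTraining(bool is_training) { is_training_ = is_training; }

    // Pure virtual: every concrete component must describe itself.
    virtual std::string toString() const = 0;

private:
    std::string name_ = "unnamed";  // same default as the documented name_
    bool is_training_ = false;      // same default as the documented is_training_
};

// Hypothetical concrete component used only for illustration.
class ScaleSketch : public ComponentSketch {
public:
    explicit ScaleSketch(float factor) : factor_(factor) {}

    std::string toString() const override {
        std::ostringstream oss;
        oss << "Scale(name=" << getName() << ", factor=" << factor_
            << ", training=" << (isTraining() ? "true" : "false") << ")";
        return oss.str();
    }

private:
    float factor_;
};
```

A derived object is used through the base interface, so generic code can log or inspect any component via `toString()` without knowing its concrete type.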

Constructor & Destructor Documentation

Component(const std::string& device_name)  [inline, explicit]
Constructor with device name.
Parameters:
  device_name  The name of the device to use (e.g., "CPU", "CUDA:0").
  policy       The compute precision policy to use (default is Auto).

Component(std::shared_ptr<DeviceContext> context)  [inline, explicit]
Constructor with a specific device context.
Parameters:
  context  The device context to use for this component.
  policy   The compute precision policy to use (default is Auto).
Throws:
  std::invalid_argument  If the provided context is nullptr.
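The documented nullptr check can be sketched as follows. `DeviceContextStub` and `ContextHolderSketch` are hypothetical stand-ins (Mila's `DeviceContext` holds far more state); the point is only that validating once in the constructor lets every later call assume the context is non-null.

```cpp
#include <cassert>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>

// Hypothetical stand-in for Mila's DeviceContext.
struct DeviceContextStub {
    std::string device_name;
};

class ContextHolderSketch {
public:
    // Reject a null context up front, mirroring the documented contract.
    explicit ContextHolderSketch(std::shared_ptr<DeviceContextStub> context)
        : device_context_(std::move(context)) {
        if (!device_context_) {
            throw std::invalid_argument("Device context cannot be nullptr.");
        }
    }

    std::shared_ptr<DeviceContextStub> getDeviceContext() const {
        return device_context_;
    }

private:
    std::shared_ptr<DeviceContextStub> device_context_;
};
```

Because the context is held by `std::shared_ptr`, several components can share one context object and it stays alive as long as any of them does.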

virtual ~Component()  [virtual, default]
Virtual destructor for proper cleanup in derived classes.
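
Member Function Documentation

static std::shared_ptr<DeviceContext> createContext(const std::string& device_name)  [inline, static, private]
Helper method to create a DeviceContext from a device name.
Parameters:
  device_name  Name of the device to create a context for.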
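How a name such as "CPU" or "CUDA:0" might be turned into a context is not documented here, so the following is only a sketch under stated assumptions: the hypothetical `createContextSketch` accepts exactly "CPU" or "CUDA:<ordinal>" and throws `std::invalid_argument` for anything else. `DeviceTypeStub`, `DeviceContextStub` and that error behavior are illustrative inventions, not Mila's actual parsing rules.

```cpp
#include <cassert>
#include <cctype>
#include <memory>
#include <stdexcept>
#include <string>

// Hypothetical types; Mila's DeviceType and DeviceContext may differ.
enum class DeviceTypeStub { Cpu, Cuda };

struct DeviceContextStub {
    DeviceTypeStub type;
    int index;  // device ordinal; 0 for the CPU
};

// Sketch of a name-to-context factory for strings such as "CPU" or "CUDA:0".
std::shared_ptr<DeviceContextStub> createContextSketch(const std::string& device_name) {
    if (device_name == "CPU") {
        return std::make_shared<DeviceContextStub>(DeviceContextStub{DeviceTypeStub::Cpu, 0});
    }
    const std::string prefix = "CUDA:";
    // Accept "CUDA:" followed by one or more decimal digits.
    if (device_name.rfind(prefix, 0) == 0 && device_name.size() > prefix.size()) {
        const std::string ordinal = device_name.substr(prefix.size());
        for (char c : ordinal) {
            if (!std::isdigit(static_cast<unsigned char>(c))) {
                throw std::invalid_argument("Invalid device name: " + device_name);
            }
        }
        return std::make_shared<DeviceContextStub>(
            DeviceContextStub{DeviceTypeStub::Cuda, std::stoi(ordinal)});
    }
    throw std::invalid_argument("Invalid device name: " + device_name);
}
```

Keeping this helper static and private, as the reference lists it, means the string-based constructor can delegate to it while callers only ever see the two public constructors.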

const ComputePrecision& getComputePrecision() const  [inline]
Get the compute precision policy for this component.

std::shared_ptr<DeviceContext> getDeviceContext() const  [inline]
Get the device context for this component.

DeviceType getDeviceType() const  [inline]
Get the device type of the current device context.

std::string getName() const  [inline]
Get the name of the component.

bool isTraining() const  [inline]
Check whether the component is in training mode.

Component& setComputePrecisionPolicy(ComputePrecision::Policy policy)  [inline]
Set the compute precision policy explicitly.
Parameters:
  policy  The precision policy to use.
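Since `setComputePrecisionPolicy` returns `Component&` while `setName` returns `void`, configuration calls can be chained as long as `setName` comes last. The sketch below assumes the returned reference is `*this`, which the reference does not state explicitly. `PolicyStub` is hypothetical: only `Auto` is named in this documentation, and `Performance`/`Accuracy` are placeholders.

```cpp
#include <cassert>
#include <string>

// Hypothetical policy enum: only Auto appears in the documentation;
// Performance and Accuracy are placeholder enumerators.
enum class PolicyStub { Auto, Performance, Accuracy };

class PrecisionComponentSketch {
public:
    // Returning *this (an assumption) enables fluent chaining.
    PrecisionComponentSketch& setComputePrecisionPolicy(PolicyStub policy) {
        policy_ = policy;
        return *this;
    }

    // Returns void, matching the documented setName signature.
    void setName(const std::string& name) { name_ = name; }

    PolicyStub getPolicy() const { return policy_; }
    std::string getName() const { return name_; }

private:
    PolicyStub policy_ = PolicyStub::Auto;  // documented default is Auto
    std::string name_ = "unnamed";
};
```

For example, `c.setComputePrecisionPolicy(PolicyStub::Accuracy).setName("fc1");` compiles, whereas putting another call after `setName` would not.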

void setName(const std::string& name)  [inline]
Set the name of the component.
Parameters:
  name  The name to set. Must not be empty and must not contain a dot ('.').
Throws:
  std::invalid_argument  If the name is empty or contains a dot.
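The documented name rules translate directly into two checks. This is a minimal sketch on a hypothetical `NamedSketch` class, not Mila's implementation. Assigning `name_` only after both checks pass means a rejected name leaves the previous name unchanged.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

// Hypothetical class modelling only the naming rules of Component.
class NamedSketch {
public:
    void setName(const std::string& name) {
        if (name.empty()) {
            throw std::invalid_argument("Component name cannot be empty.");
        }
        if (name.find('.') != std::string::npos) {
            throw std::invalid_argument("Component name cannot contain a dot ('.').");
        }
        name_ = name;  // only reached when both checks pass
    }

    std::string getName() const { return name_; }

private:
    std::string name_ = "unnamed";
};
```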

virtual void setTraining(bool is_training)  [inline, virtual]
Set the training mode of this component.
Parameters:
  is_training  true to put the component in training mode, false for inference mode.
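Because `setTraining` is virtual, a derived class can extend it. One plausible use, shown here as an assumption rather than documented Mila behavior, is a composite that forwards the mode to its children. `TrainableSketch` and `SequentialSketch` are hypothetical names.

```cpp
#include <cassert>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical minimal base modelling only the training-mode part of Component.
class TrainableSketch {
public:
    virtual ~TrainableSketch() = default;
    bool isTraining() const { return is_training_; }
    virtual void setTraining(bool is_training) { is_training_ = is_training; }

private:
    bool is_training_ = false;
};

// Hypothetical composite that forwards the training mode to its children.
class SequentialSketch : public TrainableSketch {
public:
    void add(std::shared_ptr<TrainableSketch> child) {
        children_.push_back(std::move(child));
    }

    void setTraining(bool is_training) override {
        TrainableSketch::setTraining(is_training);  // update own flag first
        for (auto& child : children_) {
            child->setTraining(is_training);        // virtual call: nested composites recurse
        }
    }

private:
    std::vector<std::shared_ptr<TrainableSketch>> children_;
};
```

Calling the base implementation first keeps the composite's own `isTraining()` consistent with its children.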

virtual std::string toString() const  [pure virtual]
Convert the component to a string representation.

virtual void validateDeviceType() const  [inline, protected, virtual]
Validate that the device type matches the derived class requirements.
Derived classes should override this method to enforce their device type constraints.
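A device-specific operation might override this hook as sketched below. `DeviceKind`, `DeviceBoundSketch`, `CudaOnlySketch`, the no-op base default, and the choice of `std::runtime_error` are all assumptions for illustration. The check is invoked from the derived constructor because, during construction of the base subobject, a virtual call does not dispatch to the derived override.

```cpp
#include <cassert>
#include <stdexcept>

// Hypothetical device kinds; Mila's DeviceType enumerators may differ.
enum class DeviceKind { Cpu, Cuda };

class DeviceBoundSketch {
public:
    explicit DeviceBoundSketch(DeviceKind kind) : kind_(kind) {}
    virtual ~DeviceBoundSketch() = default;
    DeviceKind getDeviceType() const { return kind_; }

protected:
    // Assumed default: accept any device type.
    virtual void validateDeviceType() const {}

private:
    DeviceKind kind_;
};

// Hypothetical operation that only supports CUDA devices.
class CudaOnlySketch : public DeviceBoundSketch {
public:
    explicit CudaOnlySketch(DeviceKind kind) : DeviceBoundSketch(kind) {
        // Called here rather than in the base constructor, where the
        // virtual call would not reach this override.
        validateDeviceType();
    }

protected:
    void validateDeviceType() const override {
        if (getDeviceType() != DeviceKind::Cuda) {
            throw std::runtime_error("CudaOnlySketch requires a CUDA device.");
        }
    }
};
```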

Member Data Documentation

ComputePrecision compute_precision_  [private]
The compute precision policy for this component.

std::shared_ptr<DeviceContext> device_context_  [private]
The device context used for this component's computations.

bool is_training_ = false  [private]
Whether the component is in training mode.

std::string name_ = "unnamed"  [private]
The name of the component.