Mila
Deep Neural Network Library
|
Public Member Functions | |
FusedModule (std::shared_ptr< FusedOp > op) | |
void | build (Device, DType) override |
void | forward (const Tensor &input, Tensor &output) override |
Members inherited from Module | |
Module (const std::string &device_name, const ComponentConfig &config) | |
Constructor with device name. | |
Module (std::shared_ptr< DeviceContext > context, const ComponentConfig &config) | |
Constructor with a specific device context. | |
virtual | ~Module ()=default |
Virtual destructor for proper cleanup in derived classes. | |
std::shared_ptr< Compute::DeviceContext > | getDeviceContext () const |
Get the device context for this module. | |
Compute::DeviceType | getDeviceType () const |
Get the device type of the current device context. | |
std::string | getName () const |
Get the name of the module. | |
const auto & | getParameterTensors () const |
Get the parameter tensors of this module. | |
const ComputePrecision::Policy & | getPrecision () const |
const auto & | getStateTensors () const |
Get the state tensors of this module. | |
bool | isTraining () const |
Check if the module is in training mode. | |
virtual void | load (ModelArchive &archive)=0 |
Load the module state from a zip archive. | |
virtual size_t | parameterCount () const =0 |
Get the number of trainable parameters in the module. | |
virtual void | save (ModelArchive &archive) const =0 |
Save the module state to a zip archive. | |
virtual void | setTraining (bool is_training) |
Set the training mode of this module. | |
virtual std::string | toString () const =0 |
Convert the module to a string representation. | |
Private Attributes | |
std::shared_ptr< FusedOp > | op_ |
Additional Inherited Members | |
Types inherited from Module
using | MR = std::conditional_t< TDeviceType==DeviceType::Cuda, CudaMemoryResource, CpuMemoryResource > |
Member functions inherited from Module
const std::string | parametersToString () const |
Helper method to convert parameters to string representation. | |
const std::string | stateToString () const |
Helper method to convert state tensors to string representation. | |
Attributes inherited from Module
std::unordered_map< std::string, std::shared_ptr< Tensor< TOutput, MR > > > | parameter_map_ = {} |
Map of parameter names to parameter tensors. | |
std::unordered_map< std::string, std::shared_ptr< Tensor< TOutput, MR > > > | state_map_ = {} |
Map of state names to state tensors. | |
|
inline explicit
|
inline override
|
inline override
|
private |