Mila
Deep Neural Network Library


Public Member Functions
| | FusedModule (std::shared_ptr< FusedOp > op) | |
| void | build (Device, DType) override | |
| void | forward (const Tensor &input, Tensor &output) override | |
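The member list suggests a thin wrapper: FusedModule holds a shared FusedOp in op_ and presumably forwards build() and forward() to it. The sketch below illustrates that delegation pattern only. Device, DType, Tensor, FusedOp, and the ScaleRelu op are hypothetical stand-ins, not Mila's types, and the sketch omits the Module< TDeviceType, TInput, TOutput > base class that the real FusedModule derives from.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical stand-ins, NOT Mila's real types: Tensor is a flat float
// vector, and Device/DType are plain enums.
enum class Device { Cpu, Cuda };
enum class DType { Float32, Float16 };
using Tensor = std::vector<float>;

// Hypothetical fused-operation interface.
struct FusedOp {
    virtual ~FusedOp() = default;
    virtual void build(Device device, DType dtype) = 0;
    virtual void forward(const Tensor& input, Tensor& output) = 0;
};

// Hypothetical example op: scale followed by ReLU, fused into one loop so
// no intermediate tensor is materialized.
struct ScaleRelu : FusedOp {
    explicit ScaleRelu(float s) : scale(s) {}
    void build(Device, DType) override { built = true; }
    void forward(const Tensor& input, Tensor& output) override {
        output.resize(input.size());
        for (std::size_t i = 0; i < input.size(); ++i) {
            const float v = input[i] * scale;
            output[i] = v > 0.0f ? v : 0.0f;
        }
    }
    float scale;
    bool built = false;
};

// Sketch of the wrapper: the module owns a shared FusedOp (op_) and
// delegates both build() and forward() to it.
class FusedModule {
public:
    explicit FusedModule(std::shared_ptr<FusedOp> op) : op_(std::move(op)) {
        assert(op_ != nullptr && "FusedModule requires a non-null op");
    }
    void build(Device device, DType dtype) { op_->build(device, dtype); }
    void forward(const Tensor& input, Tensor& output) { op_->forward(input, output); }

private:
    std::shared_ptr<FusedOp> op_;  // shared ownership, as in the documented op_ member
};
```

For example, wrapping `std::make_shared<ScaleRelu>(2.0f)` and calling `forward` on `{-1.0f, 0.5f, 3.0f}` yields `{0.0f, 1.0f, 6.0f}`. Holding the op through `std::shared_ptr` lets several modules reuse one fused kernel object.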
Public Member Functions inherited from Mila::Dnn::Module< TDeviceType, TInput, TOutput >
| | Module (const std::string &device_name, const ComponentConfig &config) | Constructor with device name. |
| | Module (std::shared_ptr< DeviceContext > context, const ComponentConfig &config) | Constructor with a specific device context. |
| virtual | ~Module ()=default | Virtual destructor for proper cleanup in derived classes. |
| std::shared_ptr< Compute::DeviceContext > | getDeviceContext () const | Get the device context for this module. |
| Compute::DeviceType | getDeviceType () const | Get the device type of the current device context. |
| std::string | getName () const | Get the name of the module. |
| const auto & | getParameterTensors () const | Get the parameter tensors of this module. |
| const ComputePrecision::Policy & | getPrecision () const | |
| const auto & | getStateTensors () const | Get the state tensors of this module. |
| bool | isTraining () const | Check if the module is in training mode. |
| virtual void | load (ModelArchive &archive)=0 | Load the module state from a zip archive. |
| virtual size_t | parameterCount () const =0 | Get the number of trainable parameters in the module. |
| virtual void | save (ModelArchive &archive) const =0 | Save the module state to a zip archive. |
| virtual void | setTraining (bool is_training) | Set the training mode of this module. |
| virtual std::string | toString () const =0 | Convert the module to a string representation. |
Private Attributes
| std::shared_ptr< FusedOp > | op_ | |
Additional Inherited Members

Public Types inherited from Mila::Dnn::Module< TDeviceType, TInput, TOutput >
| using | MR = std::conditional_t< TDeviceType==DeviceType::Cuda, CudaMemoryResource, CpuMemoryResource > | |
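The MR alias selects the memory resource at compile time from the device template parameter. The following C++17 sketch reproduces only that selection logic; DeviceType, CpuMemoryResource, CudaMemoryResource, and ModuleSketch are hypothetical empty stand-ins for the real Mila types.

```cpp
#include <type_traits>

// Hypothetical stand-ins for Mila's DeviceType and memory-resource classes.
enum class DeviceType { Cpu, Cuda };
struct CpuMemoryResource {};
struct CudaMemoryResource {};

// Mirrors the documented alias: MR is CudaMemoryResource when TDeviceType
// is Cuda, and CpuMemoryResource for any other device type.
template <DeviceType TDeviceType>
struct ModuleSketch {
    using MR = std::conditional_t<TDeviceType == DeviceType::Cuda,
                                  CudaMemoryResource, CpuMemoryResource>;
};

// The selection is checked entirely at compile time.
static_assert(std::is_same_v<ModuleSketch<DeviceType::Cuda>::MR, CudaMemoryResource>);
static_assert(std::is_same_v<ModuleSketch<DeviceType::Cpu>::MR, CpuMemoryResource>);
```

Because the choice is made by the type system, parameter_map_ and state_map_ in a CPU instantiation hold tensors over CpuMemoryResource, with no runtime branch.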
Protected Member Functions inherited from Mila::Dnn::Module< TDeviceType, TInput, TOutput >
| const std::string | parametersToString () const | Helper method to convert parameters to string representation. |
| const std::string | stateToString () const | Helper method to convert state tensors to string representation. |
Protected Attributes inherited from Mila::Dnn::Module< TDeviceType, TInput, TOutput >
| std::unordered_map< std::string, std::shared_ptr< Tensor< TOutput, MR > > > | parameter_map_ = {} | Map of parameter names to parameter tensors. |
| std::unordered_map< std::string, std::shared_ptr< Tensor< TOutput, MR > > > | state_map_ = {} | Map of state names to state tensors. |
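parameterCount() is pure virtual, so each derived module supplies its own count. One plausible implementation, sketched here as an assumption rather than Mila's actual code, sums the element counts of every tensor in parameter_map_ and ignores state_map_, on the assumption that state tensors are non-trainable buffers. TensorSketch and ParameterMap are hypothetical stand-ins for Tensor< TOutput, MR > and the documented map type.

```cpp
#include <cassert>  // assert(), used by the checks in the usage note
#include <cstddef>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for Tensor<TOutput, MR>: only the element count matters here.
struct TensorSketch {
    std::vector<float> data;
    std::size_t size() const { return data.size(); }
};

// Same shape as the documented parameter_map_: name -> shared tensor.
using ParameterMap =
    std::unordered_map<std::string, std::shared_ptr<TensorSketch>>;

// Assumed parameterCount(): total elements across all registered parameter
// tensors. Null entries are skipped defensively.
std::size_t parameterCount(const ParameterMap& parameter_map) {
    std::size_t total = 0;
    for (const auto& entry : parameter_map) {
        if (entry.second) {
            total += entry.second->size();
        }
    }
    return total;
}
```

For a layer with a 12-element weight and a 4-element bias, this returns 12 + 4 = 16, for example `assert(parameterCount(params) == 16);`.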
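Constructor & Destructor Documentation
| | FusedModule (std::shared_ptr< FusedOp > op) | inline explicit |

Member Function Documentation
| void | build (Device, DType) | inline override |
| void | forward (const Tensor &input, Tensor &output) | inline override |

Member Data Documentation
| std::shared_ptr< FusedOp > | op_ | private |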