Mila
Deep Neural Network Library
Loading...
Searching...
No Matches
Mila::Dnn::FusedModule< TPrecision, TInput, TDeviceType > Class Template Reference [export]
Inheritance diagram for Mila::Dnn::FusedModule< TPrecision, TInput, TDeviceType >:
Collaboration diagram for Mila::Dnn::FusedModule< TPrecision, TInput, TDeviceType >:

Public Member Functions

 FusedModule (std::shared_ptr< FusedOp > op)
 
void build (Device, DType) override
 
void forward (const Tensor &input, Tensor &output) override
 
- Public Member Functions inherited from Mila::Dnn::Module< TDeviceType, TInput, TOutput >
 Module (const std::string &device_name, const ComponentConfig &config)
 Constructor with device name.
 
 Module (std::shared_ptr< DeviceContext > context, const ComponentConfig &config)
 Constructor with a specific device context.
 
virtual ~Module ()=default
 Virtual destructor for proper cleanup in derived classes.
 
std::shared_ptr< Compute::DeviceContext > getDeviceContext () const
 Get the device context for this module.
 
Compute::DeviceType getDeviceType () const
 Get the device type of the current device context.
 
std::string getName () const
 Get the name of the module.
 
const auto & getParameterTensors () const
 Get the parameter tensors of this module.
 
const ComputePrecision::Policy getPrecision () const
 
const auto & getStateTensors () const
 Get the state tensors of this module.
 
bool isTraining () const
 Check if the module is in training mode.
 
virtual void load (ModelArchive &archive)=0
 Load the module state from a zip archive.
 
virtual size_t parameterCount () const =0
 Get the number of trainable parameters in the module.
 
virtual void save (ModelArchive &archive) const =0
 Save the module state to a zip archive.
 
virtual void setTraining (bool is_training)
 Set the training mode of this module.
 
virtual std::string toString () const =0
 Convert the module to a string representation.
 

Private Attributes

std::shared_ptr< FusedOp > op_
 

Additional Inherited Members

- Public Types inherited from Mila::Dnn::Module< TDeviceType, TInput, TOutput >
using MR = std::conditional_t< TDeviceType==DeviceType::Cuda, CudaMemoryResource, CpuMemoryResource >
 
- Protected Member Functions inherited from Mila::Dnn::Module< TDeviceType, TInput, TOutput >
const std::string parametersToString () const
 Helper method to convert parameters to string representation.
 
const std::string stateToString () const
 Helper method to convert state tensors to string representation.
 
- Protected Attributes inherited from Mila::Dnn::Module< TDeviceType, TInput, TOutput >
std::unordered_map< std::string, std::shared_ptr< Tensor< TOutput, MR > > > parameter_map_ = {}
 Map of parameter names to parameter tensors.
 
std::unordered_map< std::string, std::shared_ptr< Tensor< TOutput, MR > > > state_map_ = {}
 Map of state names to state tensors.
 

Constructor & Destructor Documentation

◆ FusedModule()

template<typename TPrecision , typename TInput = TPrecision, Compute::DeviceType TDeviceType = Compute::DeviceType::Cuda>
Mila::Dnn::FusedModule< TPrecision, TInput, TDeviceType >::FusedModule ( std::shared_ptr< FusedOp > op)
inline explicit

Member Function Documentation

◆ build()

template<typename TPrecision , typename TInput = TPrecision, Compute::DeviceType TDeviceType = Compute::DeviceType::Cuda>
void Mila::Dnn::FusedModule< TPrecision, TInput, TDeviceType >::build ( Device  ,
DType   
)
inline override

◆ forward()

template<typename TPrecision , typename TInput = TPrecision, Compute::DeviceType TDeviceType = Compute::DeviceType::Cuda>
void Mila::Dnn::FusedModule< TPrecision, TInput, TDeviceType >::forward ( const Tensor & input,
Tensor & output 
)
inline override

Member Data Documentation

◆ op_

template<typename TPrecision , typename TInput = TPrecision, Compute::DeviceType TDeviceType = Compute::DeviceType::Cuda>
std::shared_ptr<FusedOp> Mila::Dnn::FusedModule< TPrecision, TInput, TDeviceType >::op_
private

The documentation for this class was generated from the following file: