Mila
Deep Neural Network Library
Mila::Dnn::Data::DataLoader< TInput, TTarget, TMemoryResource > Class Template Reference (abstract, export)

Abstract base class for data loaders used in training and evaluation.

Public Member Functions

 DataLoader (size_t batch_size)
 Constructs a new DataLoader with the specified batch size.
 
virtual ~DataLoader ()=default
 Virtual destructor to ensure proper cleanup in derived classes.
 
size_t batchSize () const
 Gets the size of each batch.
 
size_t currentBatch () const
 Gets the current batch index.
 
virtual const Tensor< TInput, TMemoryResource > & inputs () const =0
 Gets the input tensor containing the current batch of input data (const version).
 
virtual Tensor< TInput, TMemoryResource > & inputs ()=0
 Gets the input tensor containing the current batch of input data.
 
virtual void nextBatch ()=0
 Loads the next batch of data from the dataset.
 
virtual size_t numBatches () const =0
 Gets the total number of batches in the dataset.
 
virtual void reset ()
 Resets the loader to the beginning of the dataset.
 
virtual const Tensor< TTarget, TMemoryResource > & targets () const =0
 Gets the target tensor containing the current batch of target data (const version).
 
virtual Tensor< TTarget, TMemoryResource > & targets ()=0
 Gets the target tensor containing the current batch of target data.
 

Protected Attributes

size_t batch_size_
 Number of samples in each batch.
 
size_t current_batch_
 Index of the current batch (0-based).
 

Detailed Description

template<typename TInput, typename TTarget = TInput, typename TMemoryResource = CpuMemoryResource>
requires ValidTensorTypes<TInput, TTarget> && (std::is_same_v<TMemoryResource, CudaPinnedMemoryResource> || std::is_same_v<TMemoryResource, CpuMemoryResource>)
class Mila::Dnn::Data::DataLoader< TInput, TTarget, TMemoryResource >

Abstract base class for data loaders used in training and evaluation.

The DataLoader class provides a generic interface for loading batches of data from various sources (files, databases, etc.) into tensors that can be used for model training and evaluation. It supports both ordinary CPU memory and CUDA pinned (page-locked) host memory; pinned memory allows faster, asynchronous transfers to GPU devices.
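To make the contract concrete, here is a minimal sketch of a derived loader that serves batches from in-memory vectors. It assumes nothing beyond the members documented on this page, and every name in it is hypothetical: `MockTensor` stands in for Mila's `Tensor<T, TMemoryResource>`, the memory-resource parameter is dropped, `DataLoaderSketch` mirrors the documented interface, and `VectorLoader` is an illustrative subclass, not part of Mila. The base `reset()` behaviour (zeroing `current_batch_`) is also an assumption.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical stand-in for Mila's Tensor<T, TMemoryResource>; not the real type.
template <typename T>
struct MockTensor {
    std::vector<T> data;
};

// Hypothetical mirror of the documented DataLoader interface, minus the memory resource.
template <typename TInput, typename TTarget = TInput>
class DataLoaderSketch {
public:
    explicit DataLoaderSketch(std::size_t batch_size) : batch_size_(batch_size) {}
    virtual ~DataLoaderSketch() = default;

    std::size_t batchSize() const { return batch_size_; }
    std::size_t currentBatch() const { return current_batch_; }

    virtual const MockTensor<TInput>& inputs() const = 0;
    virtual MockTensor<TInput>& inputs() = 0;
    virtual const MockTensor<TTarget>& targets() const = 0;
    virtual MockTensor<TTarget>& targets() = 0;
    virtual void nextBatch() = 0;
    virtual std::size_t numBatches() const = 0;
    virtual void reset() { current_batch_ = 0; }  // assumed base behaviour

protected:
    std::size_t batch_size_;
    std::size_t current_batch_ = 0;
};

// Hypothetical loader serving fixed-size batches from two equally sized vectors.
class VectorLoader : public DataLoaderSketch<float, int> {
public:
    VectorLoader(std::vector<float> xs, std::vector<int> ys, std::size_t batch_size)
        : DataLoaderSketch<float, int>(batch_size), xs_(std::move(xs)), ys_(std::move(ys)) {}

    const MockTensor<float>& inputs() const override { return inputs_; }
    MockTensor<float>& inputs() override { return inputs_; }
    const MockTensor<int>& targets() const override { return targets_; }
    MockTensor<int>& targets() override { return targets_; }

    // A trailing partial batch is dropped (a design choice of this sketch).
    std::size_t numBatches() const override { return xs_.size() / batch_size_; }

    void nextBatch() override {
        if (next_ >= numBatches()) throw std::out_of_range("no batches left; call reset()");
        const auto begin = static_cast<std::ptrdiff_t>(next_ * batch_size_);
        const auto end = begin + static_cast<std::ptrdiff_t>(batch_size_);
        inputs_.data.assign(xs_.begin() + begin, xs_.begin() + end);
        targets_.data.assign(ys_.begin() + begin, ys_.begin() + end);
        current_batch_ = next_++;  // record the 0-based index only after loading succeeds
    }

    void reset() override {
        DataLoaderSketch<float, int>::reset();  // base bookkeeping
        next_ = 0;                              // rewind this loader's cursor
    }

private:
    std::vector<float> xs_;
    std::vector<int> ys_;
    MockTensor<float> inputs_;
    MockTensor<int> targets_;
    std::size_t next_ = 0;  // index of the batch the next call will load
};
```

The separate `next_` cursor lets `currentBatch()` report the index of the most recently loaded batch, as described for `currentBatch()`, while `nextBatch()` still knows where to read next.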

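Template Parameters
TInput: The element type of the input tensor.
TTarget: The element type of the target tensor (defaults to TInput). TInput and TTarget must together satisfy the ValidTensorTypes concept.
TMemoryResource: The memory resource type used for both tensors, either CpuMemoryResource (the default) or CudaPinnedMemoryResource.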

Constructor & Destructor Documentation

◆ DataLoader()

template<typename TInput , typename TTarget = TInput, typename TMemoryResource = CpuMemoryResource>
Mila::Dnn::Data::DataLoader< TInput, TTarget, TMemoryResource >::DataLoader ( size_t  batch_size)
inline

Constructs a new DataLoader with the specified batch size.

Parameters
batch_size: The number of samples to include in each batch.

◆ ~DataLoader()

template<typename TInput , typename TTarget = TInput, typename TMemoryResource = CpuMemoryResource>
virtual Mila::Dnn::Data::DataLoader< TInput, TTarget, TMemoryResource >::~DataLoader ( )
virtual, default

Virtual destructor to ensure proper cleanup in derived classes.

Member Function Documentation

◆ batchSize()

template<typename TInput , typename TTarget = TInput, typename TMemoryResource = CpuMemoryResource>
size_t Mila::Dnn::Data::DataLoader< TInput, TTarget, TMemoryResource >::batchSize ( ) const
inline

Gets the size of each batch.

Returns the number of samples in each batch, as specified when the DataLoader was constructed.

Returns
size_t The batch size.

◆ currentBatch()

template<typename TInput , typename TTarget = TInput, typename TMemoryResource = CpuMemoryResource>
size_t Mila::Dnn::Data::DataLoader< TInput, TTarget, TMemoryResource >::currentBatch ( ) const
inline

Gets the current batch index.

Returns the index of the batch that was most recently loaded.

Returns
size_t The current batch index (0-based).

◆ inputs() [1/2]

template<typename TInput , typename TTarget = TInput, typename TMemoryResource = CpuMemoryResource>
virtual const Tensor< TInput, TMemoryResource > & Mila::Dnn::Data::DataLoader< TInput, TTarget, TMemoryResource >::inputs ( ) const
pure virtual

Gets the input tensor containing the current batch of input data (const version).

This method must be implemented by derived classes to provide read-only access to the tensor containing input data for the current batch.

Returns
const Tensor<TInput, TMemoryResource>& Const reference to the input tensor.

◆ inputs() [2/2]

template<typename TInput , typename TTarget = TInput, typename TMemoryResource = CpuMemoryResource>
virtual Tensor< TInput, TMemoryResource > & Mila::Dnn::Data::DataLoader< TInput, TTarget, TMemoryResource >::inputs ( )
pure virtual

Gets the input tensor containing the current batch of input data.

This method must be implemented by derived classes to provide access to the tensor containing input data for the current batch.

Returns
Tensor<TInput, TMemoryResource>& Reference to the input tensor.
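The two inputs() overloads let read-only code work through a const reference while preprocessing code can modify the batch in place. A minimal sketch of how overload resolution picks between them; `InputsOnlyLoader`, `sumInputs`, and `scaleInputs` are hypothetical, and `std::vector<float>` stands in for `Tensor<float, TMemoryResource>`.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical minimal loader exposing only the two inputs() overloads.
class InputsOnlyLoader {
public:
    explicit InputsOnlyLoader(std::vector<float> batch) : batch_(std::move(batch)) {}
    const std::vector<float>& inputs() const { return batch_; }  // read-only access
    std::vector<float>& inputs() { return batch_; }              // mutable access

private:
    std::vector<float> batch_;
};

// Called through a const reference, so the const overload is selected.
float sumInputs(const InputsOnlyLoader& loader) {
    float total = 0.0f;
    for (float x : loader.inputs()) total += x;
    return total;
}

// Called through a non-const reference, so the mutable overload is selected.
void scaleInputs(InputsOnlyLoader& loader, float factor) {
    for (float& x : loader.inputs()) x *= factor;
}
```

The targets() pair follows the same pattern for the target tensor.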

◆ nextBatch()

template<typename TInput , typename TTarget = TInput, typename TMemoryResource = CpuMemoryResource>
virtual void Mila::Dnn::Data::DataLoader< TInput, TTarget, TMemoryResource >::nextBatch ( )
pure virtual

Loads the next batch of data from the dataset.

This method must be implemented by derived classes to load the next batch of data into the input and target tensors. After a batch has been loaded successfully, the implementation should update current_batch_ so that currentBatch() reports the index of that batch.
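A training or evaluation pass typically resets the loader, then calls nextBatch() once per batch before reading inputs() and targets(). The sketch below writes that loop as a generic function over any type with those members; `runEpoch` and `CountingLoader` are hypothetical (Mila may provide its own training utilities), and `CountingLoader` fills each batch with its own index so the order of batches is easy to check.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: one pass over the data using only the documented members
// reset(), numBatches(), nextBatch(), inputs(), and targets().
template <typename Loader, typename Step>
void runEpoch(Loader& loader, Step&& step) {
    loader.reset();
    for (std::size_t i = 0; i < loader.numBatches(); ++i) {
        loader.nextBatch();                       // load batch i
        step(loader.inputs(), loader.targets());  // e.g. forward and backward pass
    }
}

// Hypothetical loader whose batch contents equal the batch index.
class CountingLoader {
public:
    CountingLoader(std::size_t batches, std::size_t batch_size)
        : batches_(batches), batch_size_(batch_size) {}

    void reset() { next_ = 0; current_batch_ = 0; }
    std::size_t numBatches() const { return batches_; }
    std::size_t currentBatch() const { return current_batch_; }

    void nextBatch() {
        inputs_.assign(batch_size_, static_cast<float>(next_));
        targets_.assign(batch_size_, static_cast<int>(next_));
        current_batch_ = next_++;  // updated only after the batch is filled
    }

    const std::vector<float>& inputs() const { return inputs_; }
    const std::vector<int>& targets() const { return targets_; }

private:
    std::size_t batches_;
    std::size_t batch_size_;
    std::size_t next_ = 0;
    std::size_t current_batch_ = 0;
    std::vector<float> inputs_;
    std::vector<int> targets_;
};
```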

◆ numBatches()

template<typename TInput , typename TTarget = TInput, typename TMemoryResource = CpuMemoryResource>
virtual size_t Mila::Dnn::Data::DataLoader< TInput, TTarget, TMemoryResource >::numBatches ( ) const
pure virtual

Gets the total number of batches in the dataset.

This method must be implemented by derived classes to report the total number of batches available in the dataset.

Returns
size_t The number of batches.
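This page does not say whether a final, incomplete batch counts toward numBatches(); each derived class has to choose. A minimal sketch of the two common policies, using a hypothetical `batchCount` helper: floor division drops the partial batch, ceiling division keeps it.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>

// Hypothetical helper: batches needed for num_samples at batch_size.
// drop_last = true discards a trailing partial batch; false keeps it.
std::size_t batchCount(std::size_t num_samples, std::size_t batch_size, bool drop_last) {
    if (batch_size == 0) throw std::invalid_argument("batch_size must be positive");
    const std::size_t full = num_samples / batch_size;
    const bool has_partial = (num_samples % batch_size) != 0;
    return (drop_last || !has_partial) ? full : full + 1;  // ceiling without overflow
}
```

For example, 10 samples with a batch size of 3 give 3 batches when the partial batch is dropped and 4 when it is kept.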

◆ reset()

template<typename TInput , typename TTarget = TInput, typename TMemoryResource = CpuMemoryResource>
virtual void Mila::Dnn::Data::DataLoader< TInput, TTarget, TMemoryResource >::reset ( )
inline, virtual

Resets the loader to the beginning of the dataset.

Calling this method resets the internal state of the data loader so that loading starts again from the first batch. Derived classes may override this method to reset additional state of their own.
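A derived class that overrides reset() will usually want to keep the base-class bookkeeping as well. A minimal sketch of that pattern, assuming (not confirmed by this page) that the base implementation zeroes current_batch_; `ResettableBase`, `StreamingLoader`, and `loadBatch` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical trimmed base: only the reset-related parts of the interface.
class ResettableBase {
public:
    virtual ~ResettableBase() = default;
    virtual void reset() { current_batch_ = 0; }  // assumed base behaviour
    std::size_t currentBatch() const { return current_batch_; }

protected:
    std::size_t current_batch_ = 0;
};

// Hypothetical derived loader that also tracks a read position in its source.
class StreamingLoader : public ResettableBase {
public:
    void loadBatch(std::size_t bytes) {
        read_offset_ += bytes;        // pretend `bytes` were read from the source
        current_batch_ = loaded_++;   // 0-based index of the batch just loaded
    }
    std::size_t readOffset() const { return read_offset_; }

    void reset() override {
        ResettableBase::reset();  // restore the base-class batch index
        read_offset_ = 0;         // then rewind the derived-class state
        loaded_ = 0;
    }

private:
    std::size_t read_offset_ = 0;
    std::size_t loaded_ = 0;
};
```

Because reset() is virtual, calling it through a base-class reference still runs the derived override.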

◆ targets() [1/2]

template<typename TInput , typename TTarget = TInput, typename TMemoryResource = CpuMemoryResource>
virtual const Tensor< TTarget, TMemoryResource > & Mila::Dnn::Data::DataLoader< TInput, TTarget, TMemoryResource >::targets ( ) const
pure virtual

Gets the target tensor containing the current batch of target data (const version).

This method must be implemented by derived classes to provide read-only access to the tensor containing target/label data for the current batch.

Returns
const Tensor<TTarget, TMemoryResource>& Const reference to the target tensor.

◆ targets() [2/2]

template<typename TInput , typename TTarget = TInput, typename TMemoryResource = CpuMemoryResource>
virtual Tensor< TTarget, TMemoryResource > & Mila::Dnn::Data::DataLoader< TInput, TTarget, TMemoryResource >::targets ( )
pure virtual

Gets the target tensor containing the current batch of target data.

This method must be implemented by derived classes to provide access to the tensor containing target/label data for the current batch.

Returns
Tensor<TTarget, TMemoryResource>& Reference to the target tensor.

Member Data Documentation

◆ batch_size_

template<typename TInput , typename TTarget = TInput, typename TMemoryResource = CpuMemoryResource>
size_t Mila::Dnn::Data::DataLoader< TInput, TTarget, TMemoryResource >::batch_size_
protected

Number of samples in each batch.

◆ current_batch_

template<typename TInput , typename TTarget = TInput, typename TMemoryResource = CpuMemoryResource>
size_t Mila::Dnn::Data::DataLoader< TInput, TTarget, TMemoryResource >::current_batch_
protected

Index of the current batch (0-based).


The documentation for this class was generated from the following file: