Mila 0.13.48
Deep Neural Network Library
Mila::Dnn::Compute::CudaDevice Class Reference (export)

Class representing a CUDA compute device instance.


Public Member Functions

 CudaDevice (DeviceConstructionKey key, DeviceId device_id)
 Construct CUDA device with validation.
std::pair< int, int > getComputeCapability () const
 Gets the compute capability as a (major, minor) pair.
int getComputeCapabilityVersion () const
 Gets the compute capability as a single number.
DeviceId getDeviceId () const override
 Gets the device identifier.
std::string getDeviceName () const override
 Gets the device name.
constexpr DeviceType getDeviceType () const override
 Gets the device type.
int getMaxThreadsPerBlock () const
 Gets the maximum number of threads per block.
int getMultiprocessorCount () const
 Gets the number of multiprocessors.
const CudaDeviceProps & getProperties () const
 Gets the properties of this CUDA device.
size_t getSharedMemoryPerBlock () const
 Gets the shared memory per block in bytes.
size_t getTotalGlobalMemory () const
 Gets the total global memory size in bytes.
int getWarpSize () const
 Gets the warp size.
bool hasTensorCores () const
 Checks if the device has Tensor Cores.
bool isBf16Supported () const
 Checks if the device supports BF16 (bfloat16 precision).
bool isFp16Supported () const
 Checks if the device supports FP16 (half precision).
bool isFp8Supported () const
 Checks if the device supports FP8 (8-bit float precision).
bool isInt8Supported () const
 Checks if the device supports INT8 tensor cores.
Public Member Functions inherited from Mila::Dnn::Compute::Device
virtual ~Device ()=default

Static Private Member Functions

static DeviceId validateDeviceId (DeviceId device_id)
 Validates CUDA device ID.

Private Attributes

DeviceId device_id_
CudaDeviceProps props_

Additional Inherited Members

Static Public Member Functions inherited from Mila::Dnn::Compute::Device
static constexpr DeviceId Cpu () noexcept
 Create CPU device identifier.
static constexpr DeviceId Cuda (int index) noexcept
 Create CUDA device identifier.
template<DeviceType TDeviceType>
static constexpr DeviceId getDeviceId (int index) noexcept
static constexpr DeviceId Metal (int index) noexcept
 Create Metal device identifier.
static constexpr DeviceId Rocm (int index) noexcept
 Create ROCm device identifier.

Detailed Description

Class representing a CUDA compute device instance.

Provides an interface to interact with a specific NVIDIA CUDA-capable GPU. Handles device properties and capabilities for a single device instance.

Device instances are created exclusively by DeviceFactory (via DeviceRegistry). Users should obtain devices through DeviceRegistry::getDevice().

Precision Support:

  • FP32: All CUDA devices (SM 1.0+)
  • FP16: Pascal and newer (SM 6.0+)
  • BF16: Ampere and newer (SM 8.0+)
  • FP8: Hopper and newer (SM 9.0+)
  • INT8: Turing and newer (SM 7.5+)
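
The thresholds above can be read as simple comparisons against the compute capability. The following is a minimal sketch under that assumption, not Mila's implementation: the `sketch` namespace, `PrecisionSupport`, and `precisionSupportFor` are hypothetical names introduced only for illustration, and the single-number encoding (86 for SM 8.6) follows the example given for getComputeCapabilityVersion().

```cpp
#include <cassert>

namespace sketch {

// Hypothetical stand-alone model of the thresholds listed above; not Mila's code.
struct PrecisionSupport {
    bool fp32;
    bool fp16;
    bool bf16;
    bool fp8;
    bool int8_tensor_cores;
};

// `cc` is the compute capability as one number (86 means SM 8.6).
constexpr PrecisionSupport precisionSupportFor(int cc) {
    return PrecisionSupport{
        cc >= 10,  // FP32: all CUDA devices (SM 1.0+)
        cc >= 60,  // FP16: Pascal and newer
        cc >= 80,  // BF16: Ampere and newer
        cc >= 90,  // FP8: Hopper and newer
        cc >= 75   // INT8 tensor cores: Turing and newer
    };
}

// Compile-time checks for an SM 8.6 device.
static_assert(precisionSupportFor(86).bf16, "BF16 expected at SM 8.0+");
static_assert(!precisionSupportFor(86).fp8, "FP8 needs SM 9.0+ per the list above");

}  // namespace sketch
```

With a real device, the corresponding queries would be isFp16Supported(), isBf16Supported(), isFp8Supported(), and isInt8Supported() on the CudaDevice instance.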

Constructor & Destructor Documentation

◆ CudaDevice()

Mila::Dnn::Compute::CudaDevice::CudaDevice ( DeviceConstructionKey key,
DeviceId device_id )
inline explicit

Construct CUDA device with validation.

Validates that the device ID is registered with DeviceRegistry, then queries and caches the device properties from the CUDA runtime.

Parameters
  key        Construction key ensuring only DeviceRegistry can create instances
  device_id  Device identifier to initialize
Exceptions
  std::invalid_argument  If device_id validation fails
  std::runtime_error     If device is not registered or CUDA operations fail
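
A construction key of this kind is commonly implemented with the passkey idiom: the key type has a private constructor and befriends the one class allowed to create it. The sketch below illustrates that idiom only. How DeviceConstructionKey is actually defined in Mila is not shown on this page, and every name in the `sketch` namespace is hypothetical.

```cpp
#include <cassert>
#include <memory>
#include <type_traits>

namespace sketch {

class Registry;

// Only Registry can create a key, so only Registry can call Device's constructor.
class ConstructionKey {
    friend class Registry;
    ConstructionKey() {}  // user-provided and private: blocks `ConstructionKey{}` elsewhere
};

class Device {
public:
    // Public constructor, but unusable without a key.
    explicit Device(ConstructionKey, int index) : index_(index) {}
    int index() const { return index_; }

private:
    int index_;
};

class Registry {
public:
    static std::shared_ptr<Device> create(int index) {
        // The key is built here, inside the befriended class.
        return std::make_shared<Device>(ConstructionKey{}, index);
    }
};

// Code outside Registry cannot default-construct a key.
static_assert(!std::is_default_constructible<ConstructionKey>::value,
              "ConstructionKey must not be constructible by arbitrary code");

}  // namespace sketch
```

The constructor can stay public, which keeps `std::make_shared` usable, while creation is still limited to the class that can produce a key.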

Member Function Documentation

◆ getComputeCapability()

std::pair< int, int > Mila::Dnn::Compute::CudaDevice::getComputeCapability ( ) const
inline

Gets the compute capability as a (major, minor) pair.

Returns
std::pair<int, int> Major and minor version (e.g., {8, 6} for SM 8.6).

◆ getComputeCapabilityVersion()

int Mila::Dnn::Compute::CudaDevice::getComputeCapabilityVersion ( ) const
inline

Gets the compute capability as a single number.

Returns
int Compute capability (e.g., 86 for SM 8.6).
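
The two examples ({8, 6} and 86) suggest the single number is major * 10 + minor. The helpers below sketch that conversion under this assumption. They are hypothetical, not part of Mila, and they assume the minor version is a single digit.

```cpp
#include <cassert>
#include <utility>

namespace sketch {

// Hypothetical helper: combine major/minor into one number (8, 6 -> 86).
constexpr int encodeComputeCapability(int sm_major, int sm_minor) {
    return sm_major * 10 + sm_minor;
}

// Hypothetical helper: split the single number back into {major, minor}.
inline std::pair<int, int> decodeComputeCapability(int version) {
    return std::make_pair(version / 10, version % 10);
}

static_assert(encodeComputeCapability(8, 6) == 86, "SM 8.6 encodes as 86");

}  // namespace sketch
```

The single-number form makes threshold checks such as `version >= 80` straightforward, which is presumably why both accessors exist.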

◆ getDeviceId()

DeviceId Mila::Dnn::Compute::CudaDevice::getDeviceId ( ) const
inline override virtual

Gets the device identifier.

Returns
DeviceId The identifier for this CUDA device (type + index).

Implements Mila::Dnn::Compute::Device.

◆ getDeviceName()

std::string Mila::Dnn::Compute::CudaDevice::getDeviceName ( ) const
inline override virtual

Gets the device name.

Returns
std::string The device name (e.g., "CUDA:0", "CUDA:1").

Implements Mila::Dnn::Compute::Device.
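
The documented name format is a type prefix followed by the device index. A minimal sketch of that formatting, using a hypothetical free function rather than Mila's own code:

```cpp
#include <cassert>
#include <string>

namespace sketch {

// Hypothetical helper: builds names in the documented "CUDA:<index>" form.
inline std::string cudaDeviceName(int index) {
    return "CUDA:" + std::to_string(index);
}

}  // namespace sketch
```

Note that this name identifies the device slot, not the marketing name of the GPU.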

◆ getDeviceType()

DeviceType Mila::Dnn::Compute::CudaDevice::getDeviceType ( ) const
inline constexpr override virtual

Gets the device type.

Returns
DeviceType The device type (Cuda).

Implements Mila::Dnn::Compute::Device.

◆ getMaxThreadsPerBlock()

int Mila::Dnn::Compute::CudaDevice::getMaxThreadsPerBlock ( ) const
inline

Gets the maximum number of threads per block.

Returns
int Maximum threads per block.

◆ getMultiprocessorCount()

int Mila::Dnn::Compute::CudaDevice::getMultiprocessorCount ( ) const
inline

Gets the number of multiprocessors.

Returns
int Number of streaming multiprocessors.

◆ getProperties()

const CudaDeviceProps & Mila::Dnn::Compute::CudaDevice::getProperties ( ) const
inline

Gets the properties of this CUDA device.

Returns
const CudaDeviceProps& Reference to the device properties.

◆ getSharedMemoryPerBlock()

size_t Mila::Dnn::Compute::CudaDevice::getSharedMemoryPerBlock ( ) const
inline

Gets the shared memory per block in bytes.

Returns
size_t Shared memory per block in bytes.

◆ getTotalGlobalMemory()

size_t Mila::Dnn::Compute::CudaDevice::getTotalGlobalMemory ( ) const
inline

Gets the total global memory size in bytes.

Returns
size_t Total global memory in bytes.

◆ getWarpSize()

int Mila::Dnn::Compute::CudaDevice::getWarpSize ( ) const
inline

Gets the warp size.

Returns
int Warp size (typically 32).

◆ hasTensorCores()

bool Mila::Dnn::Compute::CudaDevice::hasTensorCores ( ) const
inline

Checks if the device has Tensor Cores.

Tensor Cores are available on Volta and newer (SM 7.0+).

Returns
bool True if Tensor Cores are available.

◆ isBf16Supported()

bool Mila::Dnn::Compute::CudaDevice::isBf16Supported ( ) const
inline

Checks if the device supports BF16 (bfloat16 precision).

BF16 is supported on Ampere and newer architectures (SM 8.0+).

Returns
bool True if BF16 is supported.

◆ isFp16Supported()

bool Mila::Dnn::Compute::CudaDevice::isFp16Supported ( ) const
inline

Checks if the device supports FP16 (half precision).

FP16 is supported on Pascal and newer architectures (SM 6.0+).

Returns
bool True if FP16 is supported.

◆ isFp8Supported()

bool Mila::Dnn::Compute::CudaDevice::isFp8Supported ( ) const
inline

Checks if the device supports FP8 (8-bit float precision).

FP8 is supported on Hopper and newer architectures (SM 9.0+).

Returns
bool True if FP8 is supported.

◆ isInt8Supported()

bool Mila::Dnn::Compute::CudaDevice::isInt8Supported ( ) const
inline

Checks if the device supports INT8 tensor cores.

INT8 tensor cores are supported on Turing and newer (SM 7.5+).

Returns
bool True if INT8 tensor cores are supported.

◆ validateDeviceId()

DeviceId Mila::Dnn::Compute::CudaDevice::validateDeviceId ( DeviceId device_id)
inline static private

Validates CUDA device ID.

Ensures device_id has correct type (Cuda), non-negative index, and is within the range of available CUDA devices.

Parameters
  device_id  Device identifier to validate.
Returns
  DeviceId The validated device identifier.
Exceptions
  std::invalid_argument  If device_id type is not Cuda or index is negative.
  std::runtime_error     If CUDA device count query fails or index is out of range.
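
The checks and exception types described above can be sketched as follows. This is a stand-alone model, not Mila's implementation: the `sketch` types are hypothetical stand-ins, and the `available_devices` parameter replaces the CUDA runtime device count query so the example runs without a GPU. A negative count is used here to model a failed query, which is an assumption of this sketch.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

namespace sketch {

// Hypothetical stand-ins for Mila's DeviceType and DeviceId.
enum class DeviceType { Cpu, Cuda, Metal, Rocm };

struct DeviceId {
    DeviceType type;
    int index;
};

// `available_devices` stands in for the CUDA runtime device count query.
inline DeviceId validateCudaDeviceId(DeviceId id, int available_devices) {
    if (id.type != DeviceType::Cuda) {
        throw std::invalid_argument("device type is not Cuda");
    }
    if (id.index < 0) {
        throw std::invalid_argument("device index is negative");
    }
    if (available_devices < 0) {
        throw std::runtime_error("device count query failed");
    }
    if (id.index >= available_devices) {
        throw std::runtime_error("device index " + std::to_string(id.index) +
                                 " is out of range");
    }
    return id;  // returned unchanged once every check has passed
}

}  // namespace sketch
```

Returning the validated identifier lets the check run inside a constructor's member initializer list, before any other member is initialized.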

Member Data Documentation

◆ device_id_

DeviceId Mila::Dnn::Compute::CudaDevice::device_id_
private

◆ props_

CudaDeviceProps Mila::Dnn::Compute::CudaDevice::props_
private
