Mila
Deep Neural Network Library
TensorTraits.ixx File Reference

Provides type traits for tensor data types with compile-time type information.

#include <vector>
#include <type_traits>
#include <cstdint>
#include <string_view>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cuda_fp8.h>

Classes

struct  Mila::Dnn::TensorTrait< __nv_fp8_e4m3 >
 Specialization of TensorTrait for the 8-bit floating-point type e4m3 (4 exponent bits, 3 mantissa bits).
 
struct  Mila::Dnn::TensorTrait< __nv_fp8_e5m2 >
 Specialization of TensorTrait for the alternative 8-bit floating-point type e5m2 (5 exponent bits, 2 mantissa bits).
 
struct  Mila::Dnn::TensorTrait< float >
 Specialization of TensorTrait for the float type.
 
struct  Mila::Dnn::TensorTrait< half >
 Specialization of TensorTrait for the half-precision (16-bit) floating-point type.
 
struct  Mila::Dnn::TensorTrait< int >
 Specialization of TensorTrait for the 32-bit signed integer type.
 
struct  Mila::Dnn::TensorTrait< int16_t >
 Specialization of TensorTrait for the 16-bit signed integer type.
 
struct  Mila::Dnn::TensorTrait< nv_bfloat16 >
 Specialization of TensorTrait for the NVIDIA bfloat16 type.
 
struct  Mila::Dnn::TensorTrait< uint16_t >
 Specialization of TensorTrait for the 16-bit unsigned integer type.
 
struct  Mila::Dnn::TensorTrait< uint32_t >
 Specialization of TensorTrait for the 32-bit unsigned integer type.
 
 
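The specializations listed above follow the usual trait pattern: a primary template with no definition, plus one full specialization per supported element type. The following is a minimal C++20 sketch of that pattern, not Mila's actual definition. The `sketch` namespace, the member names `type_name`, `size_in_bytes` and `is_float_type`, and the string values are all hypothetical, and the CUDA types are left out so the sketch compiles without the CUDA toolkit.

```cpp
#include <cstddef>
#include <cstdint>
#include <string_view>

namespace sketch {

// Primary template: declared but never defined, so naming TensorTrait<T>
// for an unsupported T is a compile-time error instead of a silent fallback.
template <typename T>
struct TensorTrait;

// Each supported element type gets a full specialization carrying its
// compile-time metadata. Member names and string values are hypothetical.
template <>
struct TensorTrait<float> {
    static constexpr std::string_view type_name = "FP32";
    static constexpr std::size_t size_in_bytes = sizeof(float);
    static constexpr bool is_float_type = true;
};

template <>
struct TensorTrait<std::int16_t> {
    static constexpr std::string_view type_name = "INT16";
    static constexpr std::size_t size_in_bytes = sizeof(std::int16_t);
    static constexpr bool is_float_type = false;
};

template <>
struct TensorTrait<std::uint32_t> {
    static constexpr std::string_view type_name = "UINT32";
    static constexpr std::size_t size_in_bytes = sizeof(std::uint32_t);
    static constexpr bool is_float_type = false;
};

}  // namespace sketch

// Every query resolves at compile time; no runtime lookup table is involved.
static_assert(sketch::TensorTrait<float>::is_float_type);
static_assert(sketch::TensorTrait<std::uint32_t>::size_in_bytes == 4);
```

Leaving the primary template undefined is the design choice that makes the concepts in this file possible: "has a trait" reduces to "TensorTrait<T> is a complete type".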

Namespaces

namespace  Mila
 
namespace  Mila::Dnn
 

Concepts

concept  Mila::Dnn::ValidTensorType
 Concept that constrains types to those with valid tensor trait specializations.
 
concept  Mila::Dnn::ValidFloatTensorType
 Concept that constrains types to valid floating-point tensor types.
 
concept  Mila::Dnn::ValidFloatTensorTypes
 Concept that verifies both types are valid floating-point tensor types.
 
concept  Mila::Dnn::ValidTensorTypes
 Concept that verifies both input and compute types have valid tensor trait mappings.
 

Functions

template<typename T >
constexpr std::string_view Mila::Dnn::tensor_type_name ()
 Get the string representation of a tensor element type.
 
template<typename T >
constexpr size_t Mila::Dnn::tensor_type_size ()
 Get the size in bytes of a tensor element type.
 
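Both function templates are presumably thin `constexpr` wrappers that read from the matching TensorTrait specialization. The sketch below assumes hypothetical trait members `type_name` and `size_in_bytes`; only the function names and return types come from this page. The `buffer_bytes` helper is likewise hypothetical and shows a typical use: sizing an allocation entirely at compile time.

```cpp
#include <cstddef>
#include <cstdint>
#include <string_view>

namespace sketch {

template <typename T>
struct TensorTrait;  // undefined primary template

// Hypothetical trait members; values are illustrative only.
template <> struct TensorTrait<float> {
    static constexpr std::string_view type_name = "FP32";
    static constexpr std::size_t size_in_bytes = sizeof(float);
};
template <> struct TensorTrait<std::uint16_t> {
    static constexpr std::string_view type_name = "UINT16";
    static constexpr std::size_t size_in_bytes = sizeof(std::uint16_t);
};

// String representation of the element type, available at compile time.
template <typename T>
constexpr std::string_view tensor_type_name() {
    return TensorTrait<T>::type_name;
}

// Size in bytes of one element of type T.
template <typename T>
constexpr std::size_t tensor_type_size() {
    return TensorTrait<T>::size_in_bytes;
}

// Hypothetical helper: bytes needed to hold n elements of type T.
template <typename T>
constexpr std::size_t buffer_bytes(std::size_t n) {
    return n * tensor_type_size<T>();
}

}  // namespace sketch

// A 1024-element float buffer needs 4096 bytes, checked at compile time.
static_assert(sketch::buffer_bytes<float>(1024) == 4096);
```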

Detailed Description

Provides type traits for tensor data types with compile-time type information.