Mila
Deep Neural Network Library
Provides type traits for tensor data types with compile-time type information.
#include <vector>
#include <type_traits>
#include <cstdint>
#include <string_view>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cuda_fp8.h>
Classes
struct Mila::Dnn::TensorTrait< __nv_fp8_e4m3 >
Specialization of TensorTrait for the 8-bit floating-point type E4M3 (4 exponent bits, 3 mantissa bits).
struct Mila::Dnn::TensorTrait< __nv_fp8_e5m2 >
Specialization of TensorTrait for the 8-bit floating-point type E5M2 (5 exponent bits, 2 mantissa bits).
struct Mila::Dnn::TensorTrait< float >
Specialization of TensorTrait for the float type.
struct Mila::Dnn::TensorTrait< half >
Specialization of TensorTrait for the 16-bit half-precision floating-point type.
struct Mila::Dnn::TensorTrait< int >
Specialization of TensorTrait for the 32-bit signed integer type.
struct Mila::Dnn::TensorTrait< int16_t >
Specialization of TensorTrait for the 16-bit signed integer type.
struct Mila::Dnn::TensorTrait< nv_bfloat16 >
Specialization of TensorTrait for the NVIDIA bfloat16 (16-bit brain floating-point) type.
struct Mila::Dnn::TensorTrait< uint16_t >
Specialization of TensorTrait for the 16-bit unsigned integer type.
struct Mila::Dnn::TensorTrait< uint32_t >
Specialization of TensorTrait for the 32-bit unsigned integer type.
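The list above only names the specializations; it does not show what a TensorTrait carries. The following is a minimal sketch of the general pattern, not Mila's actual definitions: the member names type_name, size_in_bytes and is_float_type, and the name strings such as "FP32", are hypothetical. The CUDA types (half, nv_bfloat16, the FP8 types) are left out so the sketch compiles without the CUDA toolkit; presumably the real header specializes them along the same lines.

```cpp
#include <cstddef>
#include <cstdint>
#include <string_view>

// Hypothetical sketch, not Mila's source. Member names and strings are assumptions.
namespace sketch {

// The primary template is declared but never defined, so naming
// TensorTrait<T> for an unsupported T is a compile-time error.
template <typename T>
struct TensorTrait;

template <>
struct TensorTrait<float> {
    static constexpr std::string_view type_name = "FP32";
    static constexpr std::size_t size_in_bytes = sizeof(float);
    static constexpr bool is_float_type = true;
};

template <>
struct TensorTrait<std::int16_t> {
    static constexpr std::string_view type_name = "INT16";
    static constexpr std::size_t size_in_bytes = sizeof(std::int16_t);
    static constexpr bool is_float_type = false;
};

template <>
struct TensorTrait<std::uint32_t> {
    static constexpr std::string_view type_name = "UINT32";
    static constexpr std::size_t size_in_bytes = sizeof(std::uint32_t);
    static constexpr bool is_float_type = false;
};

}  // namespace sketch

// All of this information is available at compile time.
static_assert(sketch::TensorTrait<float>::size_in_bytes == sizeof(float));
static_assert(sketch::TensorTrait<std::int16_t>::type_name == "INT16");
static_assert(!sketch::TensorTrait<std::uint32_t>::is_float_type);
```

Leaving the primary template undefined (rather than giving it default members) is one common way to make unsupported element types fail early; whether Mila does exactly this is an assumption.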
Namespaces
namespace Mila
namespace Mila::Dnn
Concepts
concept Mila::Dnn::ValidTensorType
Constrains a type to those that have a valid TensorTrait specialization.
concept Mila::Dnn::ValidFloatTensorType
Constrains a type to valid floating-point tensor types.
concept Mila::Dnn::ValidFloatTensorTypes
Requires both of its type arguments to be valid floating-point tensor types.
concept Mila::Dnn::ValidTensorTypes
Requires both the input type and the compute type to have valid TensorTrait specializations.
Functions
template<typename T>
constexpr std::string_view Mila::Dnn::tensor_type_name()
Get the string representation of a tensor element type.
template<typename T>
constexpr size_t Mila::Dnn::tensor_type_size()
Get the size in bytes of a tensor element type.
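Both functions are constexpr templates over the element type, so a natural reading is that each forwards to the corresponding TensorTrait specialization. The sketch below follows the documented signatures but is only an assumption about their bodies; the trait member names, the name strings, and the buffer_bytes helper are hypothetical and not part of the header.

```cpp
#include <cstddef>
#include <cstdint>
#include <string_view>

// Hypothetical sketch, not Mila's source: trait members and strings are assumptions.
namespace sketch {

template <typename T>
struct TensorTrait;  // only specialized element types are usable

template <>
struct TensorTrait<float> {
    static constexpr std::string_view type_name = "FP32";
    static constexpr std::size_t size_in_bytes = sizeof(float);
};

template <>
struct TensorTrait<std::uint16_t> {
    static constexpr std::string_view type_name = "UINT16";
    static constexpr std::size_t size_in_bytes = sizeof(std::uint16_t);
};

// Same shape as the documented functions: each forwards to the trait,
// so the result is a compile-time constant.
template <typename T>
constexpr std::string_view tensor_type_name() {
    return TensorTrait<T>::type_name;
}

template <typename T>
constexpr std::size_t tensor_type_size() {
    return TensorTrait<T>::size_in_bytes;
}

// Hypothetical helper: bytes needed for `count` elements of type T.
template <typename T>
constexpr std::size_t buffer_bytes(std::size_t count) {
    return count * tensor_type_size<T>();
}

}  // namespace sketch

static_assert(sketch::tensor_type_name<float>() == "FP32");
static_assert(sketch::tensor_type_size<std::uint16_t>() == 2);
static_assert(sketch::buffer_bytes<float>(1024) == 1024 * sizeof(float));
```

Calling either function with a type that has no TensorTrait specialization fails to compile, which keeps element-size mistakes out of allocation code.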