Mila 0.13.48
Deep Neural Network Library
Mila::Dnn Namespace Reference

Namespaces

namespace  Compute
namespace  Detail
namespace  detail
namespace  Extensibility
namespace  Optimizers
namespace  Quant
namespace  Serialization
namespace  Visualization

Classes

struct  AxisPartition
 Information about axis partitioning of a tensor.
class  BufferedTokenStreamer
 Buffers BufSize tokens before forwarding a contiguous span to Sink.
class  BuildContext
 Build-time context for Component::build().
class  Component
 Abstract base class for neural network components.
class  ComponentConfig
 Abstract base for component configuration objects.
class  ComponentFactory
 Factory for reconstructing components from serialized archives.
class  CompositeComponent
 A component that contains and manages child components.
class  ConstantLRScheduler
 Constant learning-rate scheduler.
class  CosineLRScheduler
 Cosine annealing scheduler.
class  CpuTensorDataTypeTraits
 CPU-specific traits for abstract tensor data types.
class  CrossEntropyConfig
 Configuration for fused SoftmaxCrossEntropy loss.
struct  dependent_false
class  Dropout
 Dropout regularization module for neural networks.
class  DropoutConfig
 Configuration class for Dropout module.
class  FusedComponent
 DEPRECATED.
class  Gelu
 Gaussian Error Linear Unit (GELU) activation component.
class  GeluConfig
 Configuration class for GELU module.
struct  GenerateParams
struct  GenerationStatistics
 Statistics captured during a single generateStreaming() call.
class  GptBlock
 Transformer decoder block as a composite component.
class  GptBlockConfig
 Configuration class for GPT transformer blocks.
class  GptConfig
 Network-level configuration for GPT-style transformer networks.
class  GptModel
 GPT inference model.
class  GptTransformer
 GPT-2 style transformer (decoder-only) for autoregressive token prediction.
class  GqaConfig
 Configuration class for the Grouped-Query Attention module.
class  GroupedQueryAttention
 Grouped-Query Attention module that accepts concatenated QKV input.
class  ITensor
 Abstract interface providing essential tensor information and data access.
class  LanguageModel
struct  LanguageModelConfig
 CRTP base configuration for all deployable Mila language models.
class  LanguageNetwork
class  LayerNorm
 Device-templated Layer Normalization component.
class  LayerNormConfig
class  LearningRateScheduler
 Abstract base for learning-rate schedulers.
class  Linear
 Device-templated fully connected (linear) component.
class  LinearConfig
 Configuration object for a Linear (fully connected) layer.
class  LinearLRScheduler
 Linear decay scheduler.
class  LlamaBlock
class  LlamaConfig
 Network-level configuration for LLaMA-style transformer networks.
class  LlamaModel
 LLaMA 3 compatible inference model.
struct  LlamaModelConfig
 Deployment configuration for Llama language models.
class  LlamaTransformer
 LLaMA-style transformer (decoder-only) for autoregressive token prediction.
class  Loss
 Abstract base class for neural network loss functions.
class  Lpe
 Encoder module for token and positional embeddings (device-templated).
class  LpeConfig
 Configuration class for the Learned Positional Encoder.
struct  MemoryStats
 Memory allocation breakdown for a single component.
class  MLP
 Multi-Layer Perceptron (MLP) composite component.
class  MLPConfig
 Configuration class for the Multi-Layer Perceptron (MLP) block.
class  Model
class  ModelConfig
 Abstract base configuration for all deployable Mila models.
struct  MultiAxisPartition
 Multi-axis partition for normalization over trailing dimensions.
class  MultiHeadAttention
 Multi-Head Attention module that accepts concatenated QKV input.
class  MultiHeadAttentionConfig
 Configuration class for the MultiHeadAttention module.
class  Network
 Root composite network container.
class  NetworkFactory
 Factory registry for Network deserialization.
class  Residual
 Device-templated Residual connection component.
class  ResidualConfig
 Configuration class for Residual connection component.
class  RmsNorm
 Device-templated RMS Normalization component.
class  RmsNormConfig
class  Rope
 Device-templated RoPE component.
class  RopeConfig
 Type-safe metadata container for component serialization.
class  Softmax
 Softmax activation module (device-templated).
class  SoftmaxConfig
 Configuration class for Softmax module.
class  SoftmaxCrossEntropy
 Fused SoftmaxCrossEntropy loss module (device-templated).
class  Swiglu
 SwiGLU activation component.
class  SwigluConfig
class  Tensor
 Device-aware N-dimensional tensor.
class  TensorBuffer
 Device-agnostic buffer for storing tensor data with abstract type system.
struct  TensorDataTypeMap
 Primary template for mapping concrete C++ types to TensorDataType.
struct  TensorDataTypeMap< __nv_fp8_e4m3 >
struct  TensorDataTypeMap< __nv_fp8_e5m2 >
struct  TensorDataTypeMap< float >
 Concrete type mapping for float (FP32).
struct  TensorDataTypeMap< half >
struct  TensorDataTypeMap< nv_bfloat16 >
struct  TensorDataTypeMap< std::int16_t >
 Concrete type mapping for 16-bit signed integer.
struct  TensorDataTypeMap< std::int32_t >
 Concrete type mapping for 32-bit signed integer.
struct  TensorDataTypeMap< std::int8_t >
 Concrete type mapping for 8-bit signed integer.
struct  TensorDataTypeMap< std::uint16_t >
 Concrete type mapping for 16-bit unsigned integer.
struct  TensorDataTypeMap< std::uint32_t >
 Concrete type mapping for 32-bit unsigned integer.
struct  TensorDataTypeMap< std::uint8_t >
 Concrete type mapping for 8-bit unsigned integer.
struct  TensorDataTypeTraits
 Compile-time traits for TensorDataType enumeration values.
struct  TensorDataTypeTraits< TensorDataType::BF16 >
 Traits specialization for 16-bit brain floating point.
struct  TensorDataTypeTraits< TensorDataType::FP16 >
 Traits specialization for 16-bit half precision floating point.
struct  TensorDataTypeTraits< TensorDataType::FP32 >
 Traits specialization for 32-bit IEEE 754 floating point.
struct  TensorDataTypeTraits< TensorDataType::FP4_E2M1 >
 Traits specialization for 4-bit floating point with E2M1 format.
struct  TensorDataTypeTraits< TensorDataType::FP4_E3M0 >
 Traits specialization for 4-bit floating point with E3M0 format.
struct  TensorDataTypeTraits< TensorDataType::FP8_E4M3 >
 Traits specialization for 8-bit floating point with E4M3 format.
struct  TensorDataTypeTraits< TensorDataType::FP8_E5M2 >
 Traits specialization for 8-bit floating point with E5M2 format.
struct  TensorDataTypeTraits< TensorDataType::INT16 >
 Traits specialization for 16-bit signed integer.
struct  TensorDataTypeTraits< TensorDataType::INT32 >
 Traits specialization for 32-bit signed integer.
struct  TensorDataTypeTraits< TensorDataType::INT8 >
 Traits specialization for 8-bit signed integer.
struct  TensorDataTypeTraits< TensorDataType::UINT16 >
 Traits specialization for 16-bit unsigned integer.
struct  TensorDataTypeTraits< TensorDataType::UINT32 >
 Traits specialization for 32-bit unsigned integer.
struct  TensorDataTypeTraits< TensorDataType::UINT8 >
 Traits specialization for 8-bit unsigned integer.
struct  TensorHostTypeMap
 Maps abstract TensorDataType to host-compatible C++ type and TensorDataType.
struct  TensorHostTypeMap< TensorDataType::BF16 >
 Host type for 16-bit brain floating point.
struct  TensorHostTypeMap< TensorDataType::FP16 >
 Host type for 16-bit half precision floating point.
struct  TensorHostTypeMap< TensorDataType::FP32 >
 Host type for 32-bit IEEE 754 floating point.
struct  TensorHostTypeMap< TensorDataType::FP8_E4M3 >
 Host type for 8-bit floating point with E4M3 format.
struct  TensorHostTypeMap< TensorDataType::FP8_E5M2 >
 Host type for 8-bit floating point with E5M2 format.
struct  TensorHostTypeMap< TensorDataType::INT16 >
 Host type for 16-bit signed integer.
struct  TensorHostTypeMap< TensorDataType::INT32 >
 Host type for 32-bit signed integer.
struct  TensorHostTypeMap< TensorDataType::INT8 >
 Host type for 8-bit signed integer.
struct  TensorHostTypeMap< TensorDataType::UINT16 >
 Host type for 16-bit unsigned integer.
struct  TensorHostTypeMap< TensorDataType::UINT32 >
 Host type for 32-bit unsigned integer.
struct  TensorHostTypeMap< TensorDataType::UINT8 >
 Host type for 8-bit unsigned integer.
struct  TensorOps
 Device-dispatched TensorOps interface template.
struct  TensorOps< Compute::DeviceType::Cpu >
struct  TensorOps< Compute::DeviceType::Cuda >
struct  TensorShape
 Fixed-capacity inline shape descriptor for N-dimensional tensors.
class  TokenEmbedding
 Pure token embedding component (device-templated).
class  TokenEmbeddingConfig
 Configuration for the TokenEmbedding component.
class  UniqueIdGenerator
 Thread-safe generator for unique tensor identifiers.
class  VulkanTensorTraits
 Vulkan-specific traits for abstract tensor data types.

Concepts

concept  DeviceOnlyTensorDataType
 Concept identifying device-only abstract data types.
concept  HostCompatibleTensorDataType
 Concept identifying host-compatible abstract data types.
concept  isValidTensor
 Primary tensor configuration validation concept.
concept  PrecisionSupportedOnDevice
 Concept to validate precision is supported on a device at compile-time.
concept  TokenSink
 Satisfied by any callable accepting a span of decoded tokens.
concept  TokenStreamer
 Satisfied by any callable accepting a single decoded token.
concept  ValidFloatTensorDataType
 Concept constraining abstract data types to floating-point formats.
concept  ValidIntegerTensorDataType
 Concept constraining abstract data types to integer formats.
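The TokenStreamer and TokenSink concepts, together with BufferedTokenStreamer ("Buffers BufSize tokens before forwarding a contiguous span to Sink"), describe a batching pattern. The sketch below shows that pattern in plain C++17 under stated assumptions: BufferedStreamerSketch, TokenIdSketch, and the explicit flush() are hypothetical names, token ids are assumed to be 32-bit integers, and the sink receives a std::vector instead of the span the real concepts use. It is an illustration, not Mila's implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Assumption: token ids are 32-bit signed integers (the real alias is Data::TokenId).
using TokenIdSketch = std::int32_t;

// Hypothetical stand-in for BufferedTokenStreamer. It plays the TokenStreamer
// role (called once per token) and forwards full batches to a TokenSink-like
// callable that receives a contiguous buffer.
template <std::size_t BufSize, typename Sink>
class BufferedStreamerSketch {
    static_assert(BufSize > 0, "BufSize must be positive");

public:
    explicit BufferedStreamerSketch(Sink sink) : sink_(std::move(sink)) {
        buffer_.reserve(BufSize);
    }

    // Accept one decoded token; forward the batch once BufSize tokens are held.
    void operator()(TokenIdSketch token) {
        buffer_.push_back(token);
        if (buffer_.size() == BufSize) {
            flush();
        }
    }

    // Forward whatever is buffered (e.g. the tail at the end of generation).
    void flush() {
        if (!buffer_.empty()) {
            sink_(buffer_);
            buffer_.clear();
        }
    }

private:
    Sink sink_;
    std::vector<TokenIdSketch> buffer_;
};
```

In a generation loop, a sink like this could hand each batch to a detokenizer or an output stream, trading a little latency for fewer calls.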

Typedefs

template<typename TInput = float, typename TOutput = TInput>
using Mila::Dnn::CpuDropout = Dropout<DeviceType::Cpu, TInput, TOutput>
 Type alias for CPU-based dropout module with customizable tensor types.
template<typename TInput = float, typename TOutput = TInput>
using Mila::Dnn::CudaDropout = Dropout<DeviceType::Cuda, TInput, TOutput>
 Type alias for CUDA-based dropout module with customizable tensor types.
using Mila::Dnn::dim_t = int64_t
 Integer type used for tensor dimensions and indices.
using Mila::Dnn::dtype_t = TensorDataType
 Alias for TensorDataType enumeration.
template<TensorDataType TDataType>
using Mila::Dnn::host_type_t = typename TensorHostTypeMap<TDataType>::host_type
 Convenience alias for accessing host type mapping.
template<TensorDataType TDataType>
using Mila::Dnn::host_value_t = std::conditional_t<TensorDataTypeTraits<TDataType>::is_integer_type, int32_t, float>
 Host value type for given abstract tensor data type.
template<TensorDataType TDataType>
using Mila::Dnn::HostTensor = Tensor<TDataType, Compute::CpuMemoryResource>
 Host tensor alias.
using Mila::Dnn::index_t = TensorShape
 Index descriptor for multi-dimensional element access.
using Mila::Dnn::json = nlohmann::json
using Mila::Dnn::shape_t = TensorShape
 Row-major shape descriptor for tensor dimensional sizes.
using Mila::Dnn::stride_t = TensorShape
 Stride descriptor (in elements) for each tensor dimension, row-major layout.
using TokenId = Data::TokenId

Enumerations

enum class  Mila::Dnn::ActivationType {
  None , Relu , Gelu , Silu ,
  Swiglu , Tanh , Sigmoid , LeakyRelu ,
  Mish
}
 Enumeration of supported activation function types.
enum class  Mila::Dnn::ApproximationMethod { Exact , Tanh , Sigmoid }
 Approximation methods usable by activation functions.
enum class  Mila::Dnn::AttentionType { MultiHead , GroupedQuery , MultiQuery }
 Enumeration of supported attention mechanism types.
enum class  Mila::Dnn::ComponentType : int {
  Unknown = 0 , Linear , Gelu , Swiglu ,
  LayerNorm , RmsNorm , Softmax , Dropout ,
  MultiHeadAttention , GroupedQueryAttention , Residual , TokenEmbedding ,
  Lpe , Rope , SoftmaxCrossEntropy , Mlp ,
  Transformer , Network , Gpt2 , Llama ,
  Mistral , Bert , CustomComponentStart = 1000 , MockComponent = CustomComponentStart
}
 Canonical list of framework-known component types.
enum class  Mila::Dnn::ConnectionType { Addition }
 Connection types supported by residual and skip-connection components.
enum class  Mila::Dnn::EncodingType { Learned , RoPE , ALiBi }
 Positional encoding strategies.
enum class  Mila::Dnn::KvCacheCompression { None , FP8 }
 KV cache storage and compression strategy for GroupedQueryAttention.
enum class  Mila::Dnn::NormType { LayerNorm , RMSNorm }
 Normalization type selection.
enum class  Mila::Dnn::RuntimeMode : uint8_t { Inference , Training }
 Runtime mode governing Model API and Network build policy.
enum class  Mila::Dnn::TensorDataType {
  FP32 , FP16 , BF16 , FP8_E4M3 ,
  FP8_E5M2 , FP4_E2M1 , FP4_E3M0 , INT8 ,
  INT16 , INT32 , UINT8 , UINT16 ,
  UINT32
}
 Enumeration of supported abstract tensor data types.
enum class  Mila::Dnn::TrainingMode : uint8_t { Normal , Eval }
 Runtime behavioral state for Components built with RuntimeMode::Training.
enum class  Mila::Dnn::WeightQuantization { None , FP8 , FP4 }
 Weight storage and matmul strategy for Linear components.

Functions

std::string Mila::Dnn::activationTypeToString (ActivationType type)
 Converts an ActivationType enum value to its string representation.
template<TensorDataType TDataType, typename TMemoryResource>
requires isValidTensor<TDataType, TMemoryResource>
void Mila::Dnn::add (const Tensor< TDataType, TMemoryResource > &a, const Tensor< TDataType, TMemoryResource > &b, Tensor< TDataType, TMemoryResource > &result, IExecutionContext *exec_context=nullptr)
 Element-wise addition with optional ExecutionContext (device-dispatched).
constexpr std::string_view Mila::Dnn::ApproximationMethodToString (ApproximationMethod m) noexcept
 Convert ApproximationMethod to a short string.
std::string Mila::Dnn::attentionTypeToString (AttentionType t)
 Convert AttentionType to string.
AxisPartition Mila::Dnn::computeAxisPartition (const shape_t &shape, dim_t axis, const char *op_name="Operation")
 Normalize and validate an axis, then compute partition sizes.
MultiAxisPartition Mila::Dnn::computeNormalizedShapePartition (const shape_t &shape, const shape_t &normalized_shape, const char *op_name="Operation")
 Compute partition for normalization over trailing dimensions.
int64_t Mila::Dnn::computeNumElements (const shape_t &shape)
 Compute total number of elements in a tensor shape.
int64_t Mila::Dnn::computePrefillChunkSize (int64_t batch, int64_t num_heads, int64_t head_dim, int64_t context_length, int64_t precision_bytes)
std::string Mila::Dnn::connectionTypeToString (ConnectionType type)
 Converts a ConnectionType enum value to its string representation.
template<TensorDataType TSrcDataType, typename TSrcMemoryResource, TensorDataType TDstDataType, typename TDstMemoryResource>
requires isValidTensor<TSrcDataType, TSrcMemoryResource> && isValidTensor<TDstDataType, TDstMemoryResource>
void Mila::Dnn::copy (const Tensor< TSrcDataType, TSrcMemoryResource > &src, Tensor< TDstDataType, TDstMemoryResource > &dst, IExecutionContext *exec_context=nullptr)
 Copies tensor data from source to destination tensor with optional ExecutionContext.
template<TensorDataType TDstDataType, typename TDstMemoryResource>
requires isValidTensor<TDstDataType, TDstMemoryResource>
void Mila::Dnn::copyFromBlob (const Serialization::ITensorBlob &blob, Tensor< TDstDataType, TDstMemoryResource > &dst, IExecutionContext *exec_context=nullptr)
template<TensorDataType TSrcDataType, TensorDataType TDstDataType, typename TDstMemoryResource>
requires isValidTensor<TDstDataType, TDstMemoryResource>
void Mila::Dnn::copyFromBlobWithConversion (const Serialization::ITensorBlob &blob, Tensor< TDstDataType, TDstMemoryResource > &dst, IExecutionContext *exec_context=nullptr)
 Copy a serialized blob into a destination tensor, converting element types.
template<TensorDataType TPrecision, typename MemoryResource>
void Mila::Dnn::debugDumpTensor (const ITensor &t, const std::string &label, size_t maxElements=8)
 Debug dump a concrete tensor to the log.
template<TensorDataType TDataType, typename TMemoryResource>
requires isValidTensor<TDataType, TMemoryResource>
void Mila::Dnn::divide (const Tensor< TDataType, TMemoryResource > &a, const Tensor< TDataType, TMemoryResource > &b, Tensor< TDataType, TMemoryResource > &result, IExecutionContext *exec_context=nullptr)
 Element-wise division with optional ExecutionContext (device-dispatched).
std::string Mila::Dnn::encodingTypeToString (EncodingType p)
 Convert EncodingType to string.
template<TensorDataType TDataType, typename TMemoryResource>
requires isValidTensor<TDataType, TMemoryResource>
void Mila::Dnn::fill (Tensor< TDataType, TMemoryResource > &tensor, host_value_t< TDataType > host_value, IExecutionContext *exec_context=nullptr)
 Fill a tensor with a scalar host value (device-dispatched) with optional ExecutionContext.
template<TensorDataType TDataType, typename TMemoryResource>
requires isValidTensor<TDataType, TMemoryResource>
void Mila::Dnn::fill (Tensor< TDataType, TMemoryResource > &tensor, std::span< const host_value_t< TDataType > > host_values, IExecutionContext *exec_context=nullptr)
 Copy host values into a tensor with device dispatch and optional ExecutionContext.
template<TensorDataType TDataType, typename TMemoryResource>
requires isValidTensor<TDataType, TMemoryResource>
void Mila::Dnn::fill_normal (Tensor< TDataType, TMemoryResource > &tensor, float mean, float stddev, IExecutionContext *exec_context=nullptr)
template<TensorDataType TDataType, typename TMemoryResource>
requires isValidTensor<TDataType, TMemoryResource>
void Mila::Dnn::fill_uniform (Tensor< TDataType, TMemoryResource > &tensor, host_value_t< TDataType > min_val, host_value_t< TDataType > max_val, IExecutionContext *exec_context=nullptr)
ComponentType Mila::Dnn::fromString (std::string_view s) noexcept
 Parse a case-insensitive component name into a ComponentType.
ComponentType Mila::Dnn::fromTypeId (std::string_view s) noexcept
 Map a short type identifier back to a ComponentType enum.
GptConfig Mila::Dnn::GPT2_Large ()
 GPT-2 Large (774M parameters).
GptConfig Mila::Dnn::GPT2_Medium ()
 GPT-2 Medium (345M parameters).
GptConfig Mila::Dnn::GPT2_Small ()
 GPT-2 Small (124M parameters).
GptConfig Mila::Dnn::GPT2_XL ()
 GPT-2 XL (1.5B parameters).
std::string Mila::Dnn::indexToString (const index_t &index)
LlamaConfig Mila::Dnn::Llama2_13B ()
 Llama 2 13B.
LlamaConfig Mila::Dnn::Llama2_70B ()
 Llama 2 70B.
LlamaConfig Mila::Dnn::Llama2_7B ()
 Llama 2 7B.
LlamaConfig Mila::Dnn::Llama3_1_405B ()
 Llama 3.1 405B.
LlamaConfig Mila::Dnn::Llama3_1_70B ()
 Llama 3.1 70B.
LlamaConfig Mila::Dnn::Llama3_1_8B ()
 Llama 3.1 8B.
LlamaConfig Mila::Dnn::Llama3_2_1B ()
 Llama 3.2 1B.
LlamaConfig Mila::Dnn::Llama3_2_3B ()
 Llama 3.2 3B.
LlamaConfig Mila::Dnn::Llama3_70B ()
 Llama 3 70B (Original release).
LlamaConfig Mila::Dnn::Llama3_8B ()
 Llama 3 8B (Original release).
template<TensorDataType TDataType, typename TMemoryResource>
requires isValidTensor<TDataType, TMemoryResource>
void Mila::Dnn::multiply (const Tensor< TDataType, TMemoryResource > &a, const Tensor< TDataType, TMemoryResource > &b, Tensor< TDataType, TMemoryResource > &result, IExecutionContext *exec_context=nullptr)
 Element-wise multiplication with optional ExecutionContext (device-dispatched).
std::string Mila::Dnn::normTypeToString (NormType n)
 Convert NormType to string.
template<TensorDataType TDataType, typename TMemoryResource>
requires isValidTensor<TDataType, TMemoryResource>
Tensor< TDataType, TMemoryResource > Mila::Dnn::operator* (const Tensor< TDataType, TMemoryResource > &a, const Tensor< TDataType, TMemoryResource > &b)
 Element-wise multiplication operator (always synchronous).
template<TensorDataType TDataType, typename TMemoryResource>
requires isValidTensor<TDataType, TMemoryResource>
Tensor< TDataType, TMemoryResource > Mila::Dnn::operator+ (const Tensor< TDataType, TMemoryResource > &a, const Tensor< TDataType, TMemoryResource > &b)
 Element-wise addition operator (always synchronous).
MemoryStats Mila::Dnn::operator+ (MemoryStats lhs, const MemoryStats &rhs) noexcept
 Aggregate two MemoryStats instances.
template<TensorDataType TDataType, typename TMemoryResource>
requires isValidTensor<TDataType, TMemoryResource>
Tensor< TDataType, TMemoryResource > Mila::Dnn::operator- (const Tensor< TDataType, TMemoryResource > &a, const Tensor< TDataType, TMemoryResource > &b)
 Element-wise subtraction operator (always synchronous).
template<TensorDataType TDataType, typename TMemoryResource>
requires isValidTensor<TDataType, TMemoryResource>
Tensor< TDataType, TMemoryResource > Mila::Dnn::operator/ (const Tensor< TDataType, TMemoryResource > &a, const Tensor< TDataType, TMemoryResource > &b)
 Element-wise division operator (always synchronous).
template<TensorDataType TDataType, typename TMemoryResource>
requires isValidTensor<TDataType, TMemoryResource>
std::ostream & Mila::Dnn::operator<< (std::ostream &os, const Tensor< TDataType, TMemoryResource > &tensor)
 Stream insertion operator for tensor output.
TensorDataType Mila::Dnn::parseTensorDataType (const std::string &type_str)
std::string Mila::Dnn::shapeToString (const shape_t &shape)
template<TensorDataType TDataType, typename TMemoryResource>
requires isValidTensor<TDataType, TMemoryResource>
void Mila::Dnn::split (const Tensor< TDataType, TMemoryResource > &input, Tensor< TDataType, TMemoryResource > &output_a, Tensor< TDataType, TMemoryResource > &output_b, IExecutionContext *exec_context=nullptr)
template<TensorDataType TDataType, typename TMemoryResource>
requires isValidTensor<TDataType, TMemoryResource>
void Mila::Dnn::split (const Tensor< TDataType, TMemoryResource > &input, Tensor< TDataType, TMemoryResource > &output_a, Tensor< TDataType, TMemoryResource > &output_b, Tensor< TDataType, TMemoryResource > &output_c, IExecutionContext *exec_context=nullptr)
std::string Mila::Dnn::strideToString (const stride_t &stride)
ActivationType Mila::Dnn::stringToActivationType (const std::string &name)
 Converts a string to its corresponding ActivationType enum value.
AttentionType Mila::Dnn::stringToAttentionType (const std::string &v)
 Parse string to AttentionType.
ConnectionType Mila::Dnn::stringToConnectionType (const std::string &name)
 Converts a string to its corresponding ConnectionType enum value.
EncodingType Mila::Dnn::stringToEncodingType (const std::string &v)
 Parse string to EncodingType.
NormType Mila::Dnn::stringToNormType (const std::string &v)
 Parse string to NormType.
template<TensorDataType TDataType, typename TMemoryResource>
requires isValidTensor<TDataType, TMemoryResource>
void Mila::Dnn::subtract (const Tensor< TDataType, TMemoryResource > &a, const Tensor< TDataType, TMemoryResource > &b, Tensor< TDataType, TMemoryResource > &result, IExecutionContext *exec_context=nullptr)
 Element-wise subtraction with optional ExecutionContext (device-dispatched).
template<TensorDataType TDataType, typename TMemoryResource>
requires isValidTensor<TDataType, TMemoryResource>
float Mila::Dnn::sum (const Tensor< TDataType, TMemoryResource > &tensor, IExecutionContext *exec_context=nullptr)
 Sum reduction with optional ExecutionContext (device-dispatched).
std::string Mila::Dnn::tensorDataTypeToString (TensorDataType type)
 Converts TensorDataType enumeration to human-readable string.
template<TensorDataType TDstDataType, TensorDataType TSrcDataType, typename TSrcMemoryResource>
requires isValidTensor<TSrcDataType, TSrcMemoryResource> && isValidTensor<TDstDataType, CpuMemoryResource>
Tensor< TDstDataType, CpuMemoryResource > Mila::Dnn::toHost (const Tensor< TSrcDataType, TSrcMemoryResource > &src, IExecutionContext *exec_context=nullptr)
 Create a host (CPU) tensor from src and copy data into it.
std::string Mila::Dnn::toString (ComponentType t) noexcept
 Convert a ComponentType enum value to its canonical name.
std::string Mila::Dnn::toTypeId (ComponentType t) noexcept
 Get the short 2..4 character type identifier for a ComponentType.
void Mila::Dnn::validateTensorSize (const shape_t &shape, int64_t expected_size, const char *tensor_name="tensor", const char *op_name="Operation")
 Validate that a tensor has the expected number of elements.
template<TensorDataType TDataType, typename TMemoryResource>
requires isValidTensor<TDataType, TMemoryResource>
void Mila::Dnn::zero (Tensor< TDataType, TMemoryResource > &tensor, IExecutionContext *exec_context=nullptr)
 Zero a tensor using the fastest backend implementation.
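computeAxisPartition is described as normalizing and validating an axis, then computing partition sizes. A common way to do this for a row-major tensor is to split the shape into the product of extents before the axis, the axis extent, and the product after it. The sketch assumes that layout and assumes negative axes count from the end; AxisPartitionSketch and partitionAxis are hypothetical names, not the library's API.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical mirror of AxisPartition; the real member names may differ.
struct AxisPartitionSketch {
    std::int64_t axis;        // normalized, non-negative axis
    std::int64_t outer_size;  // product of extents before the axis
    std::int64_t axis_size;   // extent of the axis itself
    std::int64_t inner_size;  // product of extents after the axis
};

// Sketch of computeAxisPartition: normalize a possibly negative axis,
// validate it against the rank, then split the shape around it.
AxisPartitionSketch partitionAxis(const std::vector<std::int64_t>& shape,
                                  std::int64_t axis,
                                  const char* op_name = "Operation") {
    const auto rank = static_cast<std::int64_t>(shape.size());
    if (axis < 0) {
        axis += rank;  // assumption: -1 means the last dimension
    }
    if (axis < 0 || axis >= rank) {
        throw std::out_of_range(std::string(op_name) + ": axis out of range");
    }
    AxisPartitionSketch p{axis, 1, shape[static_cast<std::size_t>(axis)], 1};
    for (std::int64_t i = 0; i < axis; ++i) {
        p.outer_size *= shape[static_cast<std::size_t>(i)];
    }
    for (std::int64_t i = axis + 1; i < rank; ++i) {
        p.inner_size *= shape[static_cast<std::size_t>(i)];
    }
    return p;
}
```

For a {B, T, C} activation and axis -1, outer_size is B * T rows, each holding C contiguous elements, which is the loop structure a softmax over channels would use.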

Variables

template<TensorDataType TDataType>
constexpr bool Mila::Dnn::is_host_float_type = std::is_floating_point_v<host_type_t<TDataType>>
 Checks if a TensorDataType maps to a floating-point host type.
template<TensorDataType TDataType>
constexpr bool Mila::Dnn::is_host_integer_type = std::is_integral_v<host_type_t<TDataType>>
 Checks if a TensorDataType maps to an integer host type.
constexpr int64_t Mila::Dnn::kPrefillScratchByteCap = int64_t{ 1536 } * 1024 * 1024

Typedef Documentation

◆ CpuDropout

template<typename TInput = float, typename TOutput = TInput>
using Mila::Dnn::CpuDropout = Dropout<DeviceType::Cpu, TInput, TOutput>
export

Type alias for CPU-based dropout module with customizable tensor types.

Template Parameters
TInput Data type of the input tensor elements.
TOutput Data type of the output tensor elements; defaults to TInput.

◆ CudaDropout

template<typename TInput = float, typename TOutput = TInput>
using Mila::Dnn::CudaDropout = Dropout<DeviceType::Cuda, TInput, TOutput>
export

Type alias for CUDA-based dropout module with customizable tensor types.

Template Parameters
TInput Data type of the input tensor elements.
TOutput Data type of the output tensor elements; defaults to TInput.

◆ dim_t

using Mila::Dnn::dim_t = int64_t
export

Integer type used for tensor dimensions and indices.

◆ dtype_t

Alias for TensorDataType enumeration.

Provides a concise alias for the TensorDataType enumeration to improve code readability in tensor-related contexts.

◆ host_type_t

template<TensorDataType TDataType>
using Mila::Dnn::host_type_t = typename TensorHostTypeMap<TDataType>::host_type
export

Convenience alias for accessing host type mapping.

Provides a more concise way to access the host type for a given abstract tensor data type, following modern C++ alias template patterns.

Template Parameters
TDataType Abstract tensor data type.

Example usage:

using HostType = host_type_t<TensorDataType::FP16>; // -> float
using IntHostType = host_type_t<TensorDataType::INT8>; // -> std::int8_t

◆ host_value_t

template<TensorDataType TDataType>
using Mila::Dnn::host_value_t = std::conditional_t<TensorDataTypeTraits<TDataType>::is_integer_type, int32_t, float>
export

Host value type for given abstract tensor data type.

Maps floating tensor types to float and integer tensor types to int32_t. Use this alias when declaring host-side buffers, spans or scalar arguments intended for conversion/transfer into tensors of TDataType.

Template Parameters
TDataType Abstract tensor data type from the TensorDataType enum.
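A cut-down illustration of how host_value_t selects its type. DTypeSketch and TraitsSketch are hypothetical stand-ins for TensorDataType and TensorDataTypeTraits, reduced to a handful of formats and an is_integer_type flag; the std::conditional_t expression matches the alias declared above.

```cpp
#include <cassert>
#include <cstdint>
#include <type_traits>

// Hypothetical cut-down enum and traits; the real TensorDataType and
// TensorDataTypeTraits carry many more formats and members.
enum class DTypeSketch { FP32, FP16, BF16, INT8, INT32, UINT8 };

template <DTypeSketch T>
struct TraitsSketch {
    static constexpr bool is_integer_type =
        T == DTypeSketch::INT8 || T == DTypeSketch::INT32 || T == DTypeSketch::UINT8;
};

// Same shape as host_value_t: integer formats -> int32_t, all others -> float.
template <DTypeSketch T>
using host_value_sketch_t =
    std::conditional_t<TraitsSketch<T>::is_integer_type, std::int32_t, float>;

static_assert(std::is_same<host_value_sketch_t<DTypeSketch::BF16>, float>::value,
              "BF16 scalars are passed as float on the host");
static_assert(std::is_same<host_value_sketch_t<DTypeSketch::UINT8>, std::int32_t>::value,
              "UINT8 scalars are passed as int32_t on the host");
```

A fill value for a BF16 tensor is therefore written as a plain float on the host and converted during the transfer.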

◆ HostTensor

template<TensorDataType TDataType>
using Mila::Dnn::HostTensor = Tensor<TDataType, Compute::CpuMemoryResource>
export

Host tensor alias.

◆ index_t

Index descriptor for multi-dimensional element access.

One index per tensor dimension. Valid indices satisfy: 0 <= index[i] < shape[i].
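The validity rule above, written out as a check. isValidIndex is a hypothetical helper, and std::vector<std::int64_t> stands in for index_t and shape_t (both TensorShape in Mila).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// True when idx has exactly one entry per dimension and 0 <= idx[i] < shape[i].
bool isValidIndex(const std::vector<std::int64_t>& shape,
                  const std::vector<std::int64_t>& idx) {
    if (idx.size() != shape.size()) {
        return false;  // wrong rank
    }
    for (std::size_t i = 0; i < shape.size(); ++i) {
        if (idx[i] < 0 || idx[i] >= shape[i]) {
            return false;  // out of bounds along dimension i
        }
    }
    return true;
}
```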

◆ json

using Mila::Dnn::json = nlohmann::json
export

◆ shape_t

Row-major shape descriptor for tensor dimensional sizes.

  • {} : scalar (rank 0)
  • {n} : 1D tensor of length n
  • {m, n} : 2D tensor, m rows and n columns
  • {B, T, C} : 3D activation (batch, sequence, channels)
  • {B, H, T, D} : 4D attention tensor

A zero in any position indicates an empty tensor.

Examples
Mila/Src/Dnn/Components/Normalization/Softmax.ixx.
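The shape conventions above imply a simple element count: the product of all extents. A minimal sketch, assuming this is what computeNumElements computes; numElements is a hypothetical name and std::vector<std::int64_t> stands in for shape_t. The empty product is 1, so the scalar shape {} holds one element, and any zero extent yields an empty tensor.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Number of elements described by a shape: the product of all extents.
// An empty shape (rank 0, a scalar) gives the empty product, 1.
std::int64_t numElements(const std::vector<std::int64_t>& shape) {
    std::int64_t n = 1;
    for (std::int64_t d : shape) {
        n *= d;  // a single zero extent makes the whole tensor empty
    }
    return n;
}
```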

◆ stride_t

Stride descriptor (in elements) for each tensor dimension, row-major layout.

stride_t[i] is the element count to advance one step along dimension i. Length equals shape.size(); empty for scalars.
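For a row-major layout, each stride is the product of the extents to its right, and an element's flat offset is the dot product of its index with the strides. The sketch assumes a dense, contiguous tensor; rowMajorStrides and flatOffset are hypothetical helpers, with std::vector<std::int64_t> standing in for shape_t, stride_t, and index_t.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Row-major (C-order) strides: the last dimension has stride 1, and each
// earlier stride is the product of all later extents.
std::vector<std::int64_t> rowMajorStrides(const std::vector<std::int64_t>& shape) {
    std::vector<std::int64_t> strides(shape.size());
    std::int64_t running = 1;
    for (std::size_t i = shape.size(); i-- > 0;) {
        strides[i] = running;
        running *= shape[i];
    }
    return strides;  // empty for a scalar shape {}
}

// Flat element offset of a multi-dimensional index: sum of index[i] * strides[i].
std::int64_t flatOffset(const std::vector<std::int64_t>& index,
                        const std::vector<std::int64_t>& strides) {
    std::int64_t offset = 0;
    for (std::size_t i = 0; i < index.size(); ++i) {
        offset += index[i] * strides[i];
    }
    return offset;
}
```

For shape {2, 3, 4} the strides are {12, 4, 1}, so index {1, 2, 3} lands at element 1*12 + 2*4 + 3*1 = 23.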

◆ TokenId

Enumeration Type Documentation

◆ ActivationType

enum class Mila::Dnn::ActivationType
export strong

Enumeration of supported activation function types.

This enum class defines the different activation functions that can be used throughout the Mila library, particularly in neural network layers.

Enumerator
None 

No activation (identity function).

Relu 

Rectified Linear Unit: max(0, x).

Gelu 

Gaussian Error Linear Unit: x * phi(x) where phi() is the standard Gaussian CDF.

Silu 

Sigmoid Linear Unit (Swish): x * sigmoid(x).

Swiglu 

SwiGLU: Swish-gated linear unit, SiLU(x1) * x2.

Tanh 

Hyperbolic Tangent: tanh(x).

Sigmoid 

Sigmoid function: 1 / (1 + exp(-x)).

LeakyRelu 

Leaky ReLU: max(alpha * x, x) where alpha is typically 0.01.

Mish 

Mish: x * tanh(softplus(x)).
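Scalar reference formulas for most of the enumerators above, useful when checking kernel output on the host. These are illustrative functions, not Mila's kernels. SwiGLU is written in its standard form, SiLU(gate) * up, where gate and up are the two halves of a projection; how Mila splits its input between the halves is an assumption left open here.

```cpp
#include <cassert>
#include <cmath>

// Illustrative host-side formulas; names are hypothetical, not Mila's API.
inline float relu(float x) { return x > 0.0f ? x : 0.0f; }
inline float sigmoid(float x) { return 1.0f / (1.0f + std::exp(-x)); }
inline float silu(float x) { return x * sigmoid(x); }  // Swish

// Exact GELU: x * Phi(x), with Phi the standard normal CDF expressed via erf.
inline float gelu(float x) {
    return 0.5f * x * (1.0f + std::erf(x / std::sqrt(2.0f)));
}

// Equivalent to max(alpha * x, x) for 0 < alpha < 1.
inline float leakyRelu(float x, float alpha = 0.01f) { return x > 0.0f ? x : alpha * x; }

inline float softplus(float x) { return std::log1p(std::exp(x)); }
inline float mish(float x) { return x * std::tanh(softplus(x)); }

// SwiGLU gates one half of a projection with SiLU of the other half.
inline float swiglu(float gate, float up) { return silu(gate) * up; }
```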

◆ ApproximationMethod

enum class Mila::Dnn::ApproximationMethod
export strong

Approximation methods usable by activation functions.

Enumerator
Exact 

Exact implementation using erf.

Tanh 

Fast tanh-based approximation.

Sigmoid 

Sigmoid-based approximation.

Examples
Mila/Src/Dnn/Components/Activations/Gelu/Gelu.ixx.
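The three methods trade accuracy for speed when evaluating GELU(x) = x * Phi(x). The sketch uses the widely published tanh approximation and the x * sigmoid(1.702 x) approximation; whether Mila's kernels use exactly these forms and constants is an assumption. For moderate inputs the tanh form is typically within about 1e-3 of the exact value, while the sigmoid form can differ by roughly 0.02.

```cpp
#include <cassert>
#include <cmath>
#include <initializer_list>

// Exact: uses erf (0.70710678 = 1 / sqrt(2)).
inline float geluExact(float x) {
    return 0.5f * x * (1.0f + std::erf(x * 0.70710678f));
}

// Tanh: 0.5 x (1 + tanh(sqrt(2/pi) (x + 0.044715 x^3))).
inline float geluTanh(float x) {
    constexpr float kSqrt2OverPi = 0.7978845608f;
    return 0.5f * x * (1.0f + std::tanh(kSqrt2OverPi * (x + 0.044715f * x * x * x)));
}

// Sigmoid: x * sigmoid(1.702 x), the cheapest of the three.
inline float geluSigmoid(float x) {
    return x / (1.0f + std::exp(-1.702f * x));
}
```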

◆ AttentionType

enum class Mila::Dnn::AttentionType
export strong

Enumeration of supported attention mechanism types.

Enumerator
MultiHead 

Multi-Head Attention (MHA): independent Q, K, V per head.

GroupedQuery 

Grouped Query Attention (GQA): Q heads grouped over shared K/V heads.

MultiQuery 

Multi-Query Attention (MQA): all Q heads share a single K/V head.
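One way to relate the three mechanisms is through the mapping from a query head to the K/V head it reads. The sketch assumes contiguous grouping (query heads 0 to g-1 share K/V head 0, and so on), which is the convention in common Llama-style implementations; kvHeadFor is a hypothetical helper, not part of Mila::Dnn.

```cpp
#include <cassert>
#include <stdexcept>

// Which K/V head serves query head q_head?
//   num_kv_heads == num_q_heads -> MultiHead (one K/V head per query head)
//   num_kv_heads == 1           -> MultiQuery (all query heads share one)
//   otherwise                   -> GroupedQuery (groups of num_q / num_kv)
int kvHeadFor(int q_head, int num_q_heads, int num_kv_heads) {
    if (num_kv_heads <= 0 || num_q_heads % num_kv_heads != 0) {
        throw std::invalid_argument("num_q_heads must be a multiple of num_kv_heads");
    }
    const int group_size = num_q_heads / num_kv_heads;
    return q_head / group_size;
}
```

The K/V cache then only needs num_kv_heads heads per layer, which is why GQA and MQA shrink it relative to MHA.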

◆ ComponentType

enum class Mila::Dnn::ComponentType : int
export strong

Canonical list of framework-known component types.

These values are used by the deserializer and factory code to identify component implementations. Values 1..999 are reserved for built-in components; values >= CustomComponentStart are available for user defined components or extensions.

Enumerator
Unknown 
Linear 
Gelu 
Swiglu 
LayerNorm 
RmsNorm 
Softmax 
Dropout 
MultiHeadAttention 
GroupedQueryAttention 
Residual 
TokenEmbedding 
Lpe 
Rope 
SoftmaxCrossEntropy 

WIP: fused softmax + cross-entropy loss, targeted for Llama training.

Mlp 
Transformer 
Network 
Gpt2 

GPT-2 style transformer network.

Llama 

LLaMA style transformer network.

Mistral 

Mistral style transformer network.

Bert 

BERT style transformer network.

CustomComponentStart 
MockComponent 

Example custom component for testing.

Examples
Mila/Src/Dnn/Components/Activations/Gelu/Gelu.ixx, Mila/Src/Dnn/Components/Normalization/Softmax.ixx, and Mila/Src/Dnn/Core/ComponentFactory.ixx.
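The value ranges described for ComponentType can be checked with plain integer comparisons. The helpers below are hypothetical and operate on the underlying int values; they assume only what the text states: 0 is Unknown, 1 to 999 are built-ins, and values at or above CustomComponentStart (1000) are extensions. Because MockComponent already occupies 1000, the sketch's example extension id starts at 1001.

```cpp
#include <cassert>

// Mirrors the documented boundary; not read from Mila headers.
constexpr int kCustomComponentStart = 1000;

constexpr bool isUnknownComponent(int value) { return value == 0; }
constexpr bool isBuiltInComponent(int value) {
    return value >= 1 && value < kCustomComponentStart;
}
constexpr bool isCustomComponent(int value) { return value >= kCustomComponentStart; }

// Hypothetical user extension id, placed after MockComponent (== 1000).
constexpr int kMyAttentionVariant = kCustomComponentStart + 1;
static_assert(isCustomComponent(kMyAttentionVariant), "extension ids live at or above 1000");
```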

◆ ConnectionType

enum class Mila::Dnn::ConnectionType
export strong

Connection types supported by residual and skip-connection components.

Defines how the input and transformed output are combined in residual and skip-connection architectures.

Currently only Addition is implemented. Other types (multiplication, concatenation) may be added in the future.

Enumerator
Addition 

Element-wise addition (y = x + F(x)).
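A minimal element-wise version of ConnectionType::Addition, where fx is the already computed output of the wrapped sub-layer F. residualAdd is a hypothetical name, and flat std::vector<float> buffers stand in for tensors; the real Residual component operates on Tensor objects and may dispatch to a device.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// y = x + F(x), element by element. fx must have the same number of elements as x.
std::vector<float> residualAdd(const std::vector<float>& x, const std::vector<float>& fx) {
    if (x.size() != fx.size()) {
        throw std::invalid_argument("residual operands must have the same shape");
    }
    std::vector<float> y(x.size());
    for (std::size_t i = 0; i < x.size(); ++i) {
        y[i] = x[i] + fx[i];
    }
    return y;
}
```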

◆ EncodingType

enum class Mila::Dnn::EncodingType
export, strong

Positional encoding strategies.

Enumerator
Learned 

Learned absolute position embeddings (GPT-2 style).

RoPE 

Rotary Position Embeddings (LLaMA style).

ALiBi 

Attention with Linear Biases (MPT / BLOOM style).
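RoPE encodes position by rotating pairs of query/key features through a position-dependent angle. Below is a minimal sketch of the standard rotary formulation. It rotates interleaved pairs (2i, 2i+1) with frequency base^(-2i/d). This reference does not say whether Mila's kernels pair features interleaved or half-split, so treat the hypothetical `applyRope` as illustrative only.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Rotates each feature pair (x[i], x[i+1]) of one head vector by the angle
// pos * base^(-i / d), where i is the even element index of the pair.
// `base` plays the role of the RoPE theta (e.g. 10000 or 500000).
// Assumes an even head dimension.
std::vector<float> applyRope(std::vector<float> x, long pos, double base = 10000.0) {
    const std::size_t d = x.size();
    for (std::size_t i = 0; i + 1 < d; i += 2) {
        const double freq = std::pow(base, -static_cast<double>(i) / static_cast<double>(d));
        const double angle = static_cast<double>(pos) * freq;
        const double c = std::cos(angle), s = std::sin(angle);
        const double a = x[i], b = x[i + 1];
        x[i]     = static_cast<float>(a * c - b * s);  // 2-D rotation of the pair
        x[i + 1] = static_cast<float>(a * s + b * c);
    }
    return x;
}
```

Each step is a pure rotation, so vector norms are preserved. The dot product of a query rotated at position m with a key rotated at position n depends only on m - n.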

◆ KvCacheCompression

enum class Mila::Dnn::KvCacheCompression
export, strong

KV cache storage and compression strategy for GroupedQueryAttention.

Maps to the TKvPolicy template parameter on GroupedQueryAttention and CudaGqaOp via the fromPretrained() runtime→compile-time bridge. The mapping is:

  • None → NoKvCompression (BF16 cache, no compression overhead)
  • FP8 → PerChannelKvFp8<> (FP8_E4M3 cache, per-head per-token float32 scales)

New compression algorithms (SlidingWindow, LowRank, TurboQuant) add a value here and a corresponding policy struct in KvCache.QuantPolicy — no other changes are required at this level.

Enumerator
None 

No compression — default; BF16 KV cache.

FP8 

FP8_E4M3 per-head per-token KV cache compression — Alpha.6 target.
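"Per-head per-token" scaling means that each head's K or V vector at each position gets its own float32 scale. Below is a minimal sketch that assumes an absmax scheme, mapping the largest magnitude onto the E4M3 maximum of 448. `quantizeKvRow` and `ScaledRow` are hypothetical. The sketch only scales and clamps; rounding onto the FP8 grid is left to the device conversion.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

constexpr float kFp8E4M3Max = 448.0f;  // largest finite FP8_E4M3 magnitude

struct ScaledRow {
    std::vector<float> values;  // row / scale, clamped to the E4M3 range
    float scale;                // float32 scale stored alongside the row
};

// One call per (head, token): `row` is that head's head_dim-long K or V vector.
ScaledRow quantizeKvRow(const std::vector<float>& row) {
    float amax = 0.0f;
    for (float v : row) amax = std::max(amax, std::fabs(v));
    const float scale = amax > 0.0f ? amax / kFp8E4M3Max : 1.0f;  // avoid divide-by-zero
    ScaledRow out{{}, scale};
    out.values.reserve(row.size());
    for (float v : row)
        out.values.push_back(std::clamp(v / scale, -kFp8E4M3Max, kFp8E4M3Max));
    return out;
}

// Dequantize: multiply each stored value back by its row's scale.
float dequantize(float stored, float scale) { return stored * scale; }
```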

◆ NormType

enum class Mila::Dnn::NormType
export, strong

Normalization type selection.

Enumerator
LayerNorm 

Standard LayerNorm (mean + variance).

RMSNorm 

Root Mean Square Norm (rescales by the root mean square; no mean subtraction).
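The two normalizations differ only in whether the mean is removed first. LayerNorm divides by the standard deviation about the mean. RMSNorm divides by the root of the mean of squares. Below is a minimal reference sketch for one row that omits the learned gain (and LayerNorm's bias). The function names are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// LayerNorm: subtract the mean, then divide by sqrt(variance + eps).
std::vector<float> layerNorm(const std::vector<float>& x, float eps = 1e-5f) {
    double mean = 0.0;
    for (float v : x) mean += v;
    mean /= x.size();
    double var = 0.0;
    for (float v : x) var += (v - mean) * (v - mean);
    var /= x.size();
    const double inv = 1.0 / std::sqrt(var + eps);
    std::vector<float> y;
    for (float v : x) y.push_back(static_cast<float>((v - mean) * inv));
    return y;
}

// RMSNorm: no mean subtraction; divide by sqrt(mean(x^2) + eps).
std::vector<float> rmsNorm(const std::vector<float>& x, float eps = 1e-5f) {
    double ms = 0.0;
    for (float v : x) ms += static_cast<double>(v) * v;
    ms /= x.size();
    const double inv = 1.0 / std::sqrt(ms + eps);
    std::vector<float> y;
    for (float v : x) y.push_back(static_cast<float>(v * inv));
    return y;
}
```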

◆ RuntimeMode

enum class Mila::Dnn::RuntimeMode : uint8_t
export, strong

Runtime mode governing Model API and Network build policy.

Immutable after Model construction. Determines which public API methods are valid and how the Network allocates its buffers.

Mode        Network build shape    Valid Model API
Inference   { 1, context_len }     generate()
Training    { batch, seq_len }     eval(), sample()
Enumerator
Inference 
Training 

◆ TensorDataType

enum class Mila::Dnn::TensorDataType
export, strong

Enumeration of supported abstract tensor data types.

Defines device-agnostic tensor data types that can be mapped to concrete implementations on different compute devices. This abstraction prevents host compilation issues with device-specific types while enabling compile-time dispatch and optimization.

Supported categories:

  • Standard floating-point: FP32
  • Reduced precision floating-point: FP16, BF16, FP8_E4M3, FP8_E5M2
  • Integer types: Various widths from 8-bit to 32-bit, signed and unsigned
Note
Device-only types (FP16, BF16, FP8) require device-accessible memory
Packed sub-byte types (FP4, INT4, UINT4) are planned for future implementation
Enumerator
FP32 

32-bit IEEE 754 floating point, host-compatible

FP16 

16-bit half precision floating point, device-only

BF16 

16-bit brain floating point, device-only

FP8_E4M3 

8-bit floating point with 4-bit exponent and 3-bit mantissa, device-only

FP8_E5M2 

8-bit floating point with 5-bit exponent and 2-bit mantissa, device-only

FP4_E2M1 

4-bit floating point with 2-bit exponent and 1-bit mantissa, packed, device-only

FP4_E3M0 

4-bit floating point with 3-bit exponent and 0-bit mantissa, packed, device-only

INT8 

8-bit signed integer

INT16 

16-bit signed integer, host-compatible

INT32 

32-bit signed integer, host-compatible

UINT8 

8-bit unsigned integer

UINT16 

16-bit unsigned integer, host-compatible

UINT32 

32-bit unsigned integer, host-compatible
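The FP8 names spell out the bit layout: one sign bit, then the exponent bits, then the mantissa bits. Below is a minimal host-side decoder for FP8_E4M3. It assumes the "E4M3FN" variant used by NVIDIA hardware: exponent bias 7, no infinities, S.1111.111 as the only NaN pattern, and 448 as the largest finite value. `decodeE4M3` is illustrative, not Mila code.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>

// Decodes one FP8_E4M3 byte: 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits.
float decodeE4M3(std::uint8_t bits) {
    const int sign     = (bits >> 7) & 0x1;
    const int exponent = (bits >> 3) & 0xF;
    const int mantissa = bits & 0x7;
    if (exponent == 0xF && mantissa == 0x7)
        return std::numeric_limits<float>::quiet_NaN();  // only NaN pattern; no infinities
    float magnitude;
    if (exponent == 0) {
        // Subnormal: 0.mmm * 2^(1 - 7) = mantissa * 2^-9
        magnitude = std::ldexp(static_cast<float>(mantissa), -9);
    } else {
        // Normal: 1.mmm * 2^(exponent - 7)
        magnitude = std::ldexp(1.0f + mantissa / 8.0f, exponent - 7);
    }
    return sign ? -magnitude : magnitude;
}
```

FP8_E5M2 trades one mantissa bit for one more exponent bit. That gives a wider range but coarser steps.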

◆ TrainingMode

enum class Mila::Dnn::TrainingMode : uint8_t
export, strong

Runtime behavioral state for Components built with RuntimeMode::Training.

TrainingMode governs the runtime behavioral state of a Component that was built with RuntimeMode::Training. It is orthogonal to RuntimeMode: RuntimeMode is a build-time allocation policy, while TrainingMode is a runtime behavioral toggle.

States

TrainingMode  Gradients  Dropout  Batch Norm
Normal        active     on       uses batch stats
Eval          inactive   off      uses running stats

Validity

TrainingMode is only meaningful on Components built with RuntimeMode::Training. Calling setTrainingMode() on a Component built with RuntimeMode::Inference throws std::runtime_error.

Ownership

Model drives transitions via Network::setTrainingMode() — the toggle is never exposed directly to the user.

Enumerator
Normal 

Gradients active, dropout on, batch norm uses batch stats.

Eval 

No gradients, dropout off, batch norm uses running stats.

Examples
/__w/Mila/Mila/Mila/Src/Dnn/Components/Activations/Gelu/Gelu.ixx, and /__w/Mila/Mila/Mila/Src/Dnn/Components/Normalization/Softmax.ixx.
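The validity rule amounts to a guard on the build-time mode. Below is a minimal sketch using local stand-ins for the two enums and a hypothetical `SketchComponent`. It mirrors only the documented contract: `setTrainingMode()` throws `std::runtime_error` on an inference build, and Eval turns dropout off.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>

// Local stand-ins for the documented enums.
enum class RuntimeMode : std::uint8_t { Inference, Training };
enum class TrainingMode : std::uint8_t { Normal, Eval };

// Hypothetical component: RuntimeMode is fixed at construction (build time);
// TrainingMode is a runtime toggle that is only legal in Training builds.
class SketchComponent {
public:
    explicit SketchComponent(RuntimeMode mode) : runtime_mode_(mode) {}

    void setTrainingMode(TrainingMode mode) {
        if (runtime_mode_ != RuntimeMode::Training)
            throw std::runtime_error("setTrainingMode() requires RuntimeMode::Training");
        training_mode_ = mode;
    }

    bool dropoutActive() const {
        return runtime_mode_ == RuntimeMode::Training && training_mode_ == TrainingMode::Normal;
    }

private:
    RuntimeMode runtime_mode_;
    TrainingMode training_mode_ = TrainingMode::Normal;
};
```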

◆ WeightQuantization

enum class Mila::Dnn::WeightQuantization
export, strong

Weight storage and matmul strategy for Linear components.

Maps to the TWeightQuant template parameter on Linear and CudaLinearOp via the fromPretrained() runtime→compile-time bridge. The mapping is:

  • None → NoWeightQuant (BF16 weights, standard cuBLASLt plan)
  • FP8 → PerChannelFp8<> (FP8_E4M3 weights, per-channel float32 scales)
  • FP4 → PerGroupFp4<> (future)

This enum is Mila API vocabulary. Callers set it via fluent methods on the concrete model config — they do not interact with the policy structs directly.

Enumerator
None 

BF16 weights — default; no quantization overhead.

FP8 

FP8_E4M3 per-channel weight quantization — Alpha.5 target.

FP4 

Per-group FP4 weight quantization — future target.
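A runtime→compile-time bridge of this kind is usually a switch that selects a template instantiation. Below is a minimal sketch with hypothetical stand-ins. `NoWeightQuant` and `PerChannelFp8` borrow the policy names above but are empty tags here. `SketchLinear` and `makeLinear` are invented for illustration. The real `fromPretrained()` signature is not documented in this reference.

```cpp
#include <cassert>
#include <memory>
#include <stdexcept>
#include <string>

enum class WeightQuantization { None, FP8, FP4 };  // local stand-in

// Empty tag types standing in for the policy structs.
struct NoWeightQuant { static constexpr const char* name = "bf16"; };
struct PerChannelFp8 { static constexpr const char* name = "fp8_e4m3"; };

struct ILinear {
    virtual ~ILinear() = default;
    virtual std::string policyName() const = 0;
};

// The policy is a compile-time parameter of the component.
template <typename TWeightQuant>
struct SketchLinear : ILinear {
    std::string policyName() const override { return TWeightQuant::name; }
};

// Runtime enum -> compile-time type: each case instantiates a different template.
std::unique_ptr<ILinear> makeLinear(WeightQuantization q) {
    switch (q) {
        case WeightQuantization::None: return std::make_unique<SketchLinear<NoWeightQuant>>();
        case WeightQuantization::FP8:  return std::make_unique<SketchLinear<PerChannelFp8>>();
        case WeightQuantization::FP4:  break;  // documented as a future target
    }
    throw std::invalid_argument("weight quantization mode not implemented");
}
```

The switch runs once at load time. Each instantiation can then carry its own kernels with no per-call branching on the quantization mode.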

Function Documentation

◆ activationTypeToString()

std::string Mila::Dnn::activationTypeToString ( ActivationType type)
inline, export

Converts an ActivationType enum value to its string representation.

Parameters
type: The ActivationType to convert
Returns
std::string The string representation of the activation type

◆ add()

template<TensorDataType TDataType, typename TMemoryResource>
requires isValidTensor<TDataType, TMemoryResource>
void Mila::Dnn::add ( const Tensor< TDataType, TMemoryResource > & a,
const Tensor< TDataType, TMemoryResource > & b,
Tensor< TDataType, TMemoryResource > & result,
IExecutionContext * exec_context = nullptr )
export

Element-wise addition with optional ExecutionContext (device-dispatched).

Computes result[i] = a[i] + b[i] for all elements. Automatically dispatches to the appropriate device implementation based on memory resource type.

Template Parameters
TDataType: Abstract tensor data type
TMemoryResource: Memory resource type determining the device
Parameters
a: First input tensor
b: Second input tensor
result: Output tensor (must be pre-allocated with matching shape)
exec_context: Optional execution context for stream control (borrowed, not owned)
Note
For CUDA tensors, use CudaExecutionContext; for CPU, parameter is ignored
exec_context must outlive this function call
When exec_context provided, caller controls synchronization
When null, uses default stream/execution and synchronizes before returning

Example:

// With explicit context (async)
auto ctx = std::make_unique<CudaExecutionContext>(0);
add(tensor_a, tensor_b, result, ctx.get());
ctx->synchronize();
// Without context (sync)
add(tensor_a, tensor_b, result); // Returns after completion
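Dispatching on the memory resource can be resolved entirely at compile time. Below is a minimal sketch that assumes each memory resource exposes a device tag selecting a `TensorOps<Tag>` specialization; the `TensorOps<Tag>` naming follows the `fill()` documentation. The tag types, `SketchTensor`, and the member names are hypothetical, and only a CPU backend is implemented.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

struct CpuTag {};
struct CudaTag {};

// Hypothetical memory resources, each naming the device that owns its memory.
struct CpuMemoryResource  { using DeviceTag = CpuTag; };
struct CudaMemoryResource { using DeviceTag = CudaTag; };

template <typename TMemoryResource>
struct SketchTensor {
    std::vector<float> data;
};

// Primary template is only declared; each device supplies a specialization.
template <typename Tag> struct TensorOps;

template <> struct TensorOps<CpuTag> {
    template <typename MR>
    static void add(const SketchTensor<MR>& a, const SketchTensor<MR>& b, SketchTensor<MR>& out) {
        if (a.data.size() != b.data.size() || out.data.size() != a.data.size())
            throw std::invalid_argument("add: shape mismatch");
        for (std::size_t i = 0; i < a.data.size(); ++i)
            out.data[i] = a.data[i] + b.data[i];
    }
};

// Front end: the backend is chosen from the tensor's memory resource type.
// In this sketch, SketchTensor<CudaMemoryResource> would fail to compile,
// because no TensorOps<CudaTag> specialization is defined.
template <typename MR>
void add(const SketchTensor<MR>& a, const SketchTensor<MR>& b, SketchTensor<MR>& out) {
    TensorOps<typename MR::DeviceTag>::add(a, b, out);
}
```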

◆ ApproximationMethodToString()

std::string_view Mila::Dnn::ApproximationMethodToString ( ApproximationMethod m)
constexpr, export, noexcept

Convert ApproximationMethod to a short string.

Returns a constexpr std::string_view suitable for logging/serialization.


◆ attentionTypeToString()

std::string Mila::Dnn::attentionTypeToString ( AttentionType t)
inline, export

Convert AttentionType to string.

◆ computeAxisPartition()

AxisPartition Mila::Dnn::computeAxisPartition ( const shape_t & shape,
dim_t axis,
const char * op_name = "Operation" )
export

Normalize and validate an axis, then compute partition sizes.

Parameters
shape: Tensor shape.
axis: Axis to normalize (supports negative indexing).
op_name: Operation name for error messages.
Returns
AxisPartition containing normalized axis and sizes.
Exceptions
std::runtime_error: If axis is out of range.
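An axis partition splits a shape into the dimensions before the axis, the axis itself, and the dimensions after it. Below is a minimal sketch of the normalize-then-partition steps. `SketchAxisPartition` and its field names are hypothetical, because the real AxisPartition layout is not documented here.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

struct SketchAxisPartition {
    std::int64_t axis;        // normalized, non-negative axis
    std::int64_t outer_size;  // product of dims before the axis
    std::int64_t axis_size;   // extent of the axis itself
    std::int64_t inner_size;  // product of dims after the axis
};

SketchAxisPartition partitionAxis(const std::vector<std::int64_t>& shape, std::int64_t axis,
                                  const char* op_name = "Operation") {
    const auto rank = static_cast<std::int64_t>(shape.size());
    if (axis < 0) axis += rank;  // -1 refers to the last dimension
    if (axis < 0 || axis >= rank)
        throw std::runtime_error(std::string(op_name) + ": axis out of range");
    SketchAxisPartition p{axis, 1, shape[static_cast<std::size_t>(axis)], 1};
    for (std::int64_t i = 0; i < axis; ++i) p.outer_size *= shape[static_cast<std::size_t>(i)];
    for (std::int64_t i = axis + 1; i < rank; ++i) p.inner_size *= shape[static_cast<std::size_t>(i)];
    return p;
}
```

For a contiguous row-major tensor, element (o, a, i) sits at flat offset (o * axis_size + a) * inner_size + i. That layout is what lets axis-wise kernels treat the tensor as outer_size × inner_size independent rows.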

◆ computeNormalizedShapePartition()

MultiAxisPartition Mila::Dnn::computeNormalizedShapePartition ( const shape_t & shape,
const shape_t & normalized_shape,
const char * op_name = "Operation" )
export

Compute partition for normalization over trailing dimensions.

Verifies that the trailing dimensions of shape match normalized_shape exactly, then computes outer and normalized sizes and shapes.

Parameters
shape: Input tensor shape.
normalized_shape: Expected trailing dimensions to normalize over.
op_name: Operation name for error messages.
Returns
MultiAxisPartition containing partition information.
Exceptions
std::runtime_error: If normalized_shape doesn't match trailing dims.
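Normalization over trailing dimensions reduces to two numbers: how many independent rows there are, and how many elements each row holds. Below is a minimal sketch. `SketchMultiAxisPartition` and `partitionTrailing` are hypothetical names, and the real MultiAxisPartition also records shapes.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

struct SketchMultiAxisPartition {
    std::int64_t outer_size;       // rows normalized independently
    std::int64_t normalized_size;  // elements per row
};

SketchMultiAxisPartition partitionTrailing(const std::vector<std::int64_t>& shape,
                                           const std::vector<std::int64_t>& normalized_shape,
                                           const char* op_name = "Operation") {
    if (normalized_shape.size() > shape.size())
        throw std::runtime_error(std::string(op_name) + ": normalized_shape has too many dims");
    const std::size_t lead = shape.size() - normalized_shape.size();
    SketchMultiAxisPartition p{1, 1};
    for (std::size_t i = 0; i < lead; ++i) p.outer_size *= shape[i];
    for (std::size_t j = 0; j < normalized_shape.size(); ++j) {
        if (shape[lead + j] != normalized_shape[j])  // trailing dims must match exactly
            throw std::runtime_error(std::string(op_name) + ": normalized_shape does not match trailing dims");
        p.normalized_size *= normalized_shape[j];
    }
    return p;
}
```

Normalizing a [batch=8, seq=128, 768] activation over {768}, for example, gives 1024 rows of 768 elements each.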

◆ computeNumElements()

int64_t Mila::Dnn::computeNumElements ( const shape_t & shape)
export

Compute total number of elements in a tensor shape.

Parameters
shape: Tensor shape.
Returns
int64_t Total number of elements.

◆ computePrefillChunkSize()

int64_t Mila::Dnn::computePrefillChunkSize ( int64_t batch,
int64_t num_heads,
int64_t head_dim,
int64_t context_length,
int64_t precision_bytes )
inline, export

◆ connectionTypeToString()

std::string Mila::Dnn::connectionTypeToString ( ConnectionType type)
inline, export

Converts a ConnectionType enum value to its string representation.

Parameters
type: The ConnectionType to convert
Returns
std::string The string representation of the connection type

Example:

auto name = connectionTypeToString(ConnectionType::Addition);
// name == "Addition"

◆ copy()

template<TensorDataType TSrcDataType, typename TSrcMemoryResource, TensorDataType TDstDataType, typename TDstMemoryResource>
requires isValidTensor<TSrcDataType, TSrcMemoryResource> && isValidTensor<TDstDataType, TDstMemoryResource>
void Mila::Dnn::copy ( const Tensor< TSrcDataType, TSrcMemoryResource > & src,
Tensor< TDstDataType, TDstMemoryResource > & dst,
IExecutionContext * exec_context = nullptr )
export

Copies tensor data from source to destination tensor with optional ExecutionContext.

Transfers data from source tensor to pre-allocated destination tensor. Both tensors must have compatible shapes (same dimensions). Supports type conversion and cross-device transfers with explicit stream control.

Device compatibility rules:

  • Host-accessible to host-accessible: Always allowed
  • Host-accessible to device-only: Uses destination device
  • Device-only to host-accessible: Uses source device
  • Device-only to device-only: Must be same device type (e.g., both CUDA)

ExecutionContext handling:

  • Optional ExecutionContext parameter for stream control (borrowed, not owned)
  • When provided, operations use the context's stream (caller controls sync)
  • When null, operations use default stream and synchronize before returning
  • Raw pointer semantics ensure zero overhead
Template Parameters
TSrcDataType: Source tensor data type
TSrcMemoryResource: Source memory resource type
TDstDataType: Destination tensor data type
TDstMemoryResource: Destination memory resource type
Parameters
src: Source tensor to copy from
dst: Destination tensor to copy to (must be pre-allocated)
exec_context: Optional execution context for stream control (borrowed, not owned)
Exceptions
std::runtime_error: If device-only tensors are on incompatible device types
Note
exec_context must outlive this function call
When exec_context provided, caller controls synchronization
When exec_context is null, uses default stream and synchronizes before returning
For CPU-only operations, exec_context parameter is ignored

Example:

// With explicit context (async)
auto ctx = std::make_unique<CudaExecutionContext>(0);
copy(src_tensor, dst_tensor, ctx.get());
ctx->synchronize();
// Without context (sync)
copy(src_tensor, dst_tensor); // Returns after completion
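The device compatibility rules for copy() decide which backend performs the transfer. Below is a minimal sketch of that decision. `DeviceType`, `Endpoint`, and `selectCopyDevice` are hypothetical. The sketch also assumes that a host-to-host copy runs on the CPU, which the rules leave unstated.

```cpp
#include <cassert>
#include <stdexcept>

enum class DeviceType { Cpu, Cuda, Other };  // hypothetical device kinds

// One side of a copy: which device owns the memory and whether the host can access it.
struct Endpoint {
    DeviceType device;
    bool host_accessible;
};

// Picks the device whose backend performs the copy.
DeviceType selectCopyDevice(const Endpoint& src, const Endpoint& dst) {
    if (src.host_accessible && dst.host_accessible)
        return DeviceType::Cpu;    // host -> host: always allowed (assumed CPU)
    if (src.host_accessible)
        return dst.device;         // host -> device-only: destination device
    if (dst.host_accessible)
        return src.device;         // device-only -> host: source device
    if (src.device != dst.device)  // device-only -> device-only: same device type required
        throw std::runtime_error("copy: device-only tensors on incompatible device types");
    return src.device;
}
```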

◆ copyFromBlob()

template<TensorDataType TDstDataType, typename TDstMemoryResource>
requires isValidTensor<TDstDataType, TDstMemoryResource>
void Mila::Dnn::copyFromBlob ( const Serialization::ITensorBlob & blob,
Tensor< TDstDataType, TDstMemoryResource > & dst,
IExecutionContext * exec_context = nullptr )
export

◆ copyFromBlobWithConversion()

template<TensorDataType TSrcDataType, TensorDataType TDstDataType, typename TDstMemoryResource>
requires isValidTensor<TDstDataType, TDstMemoryResource>
void Mila::Dnn::copyFromBlobWithConversion ( const Serialization::ITensorBlob & blob,
Tensor< TDstDataType, TDstMemoryResource > & dst,
IExecutionContext * exec_context = nullptr )
export

Copy a serialized blob into a destination tensor, converting element types.

Intended for quantize-on-load paths where the checkpoint dtype (TSrcDataType) differs from the weight storage dtype (TDstDataType). Shape is validated against the destination tensor. Dispatches to the device-specific backend.

Template Parameters
TSrcDataType: Blob element dtype (e.g. BF16).
TDstDataType: Destination tensor dtype (e.g. FP8_E4M3).
TDstMemoryResource: Destination memory resource.
Parameters
blob: Source tensor blob.
dst: Pre-allocated destination tensor.
exec_context: Optional execution context for stream control (borrowed).
Exceptions
std::invalid_argument: If blob shape != dst shape.
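Quantize-on-load begins by widening the checkpoint elements and checking shapes. Below is a minimal sketch that assumes a BF16 blob stored as raw 16-bit patterns. BF16 is the upper half of an IEEE-754 float32, so widening is a 16-bit left shift. `bf16ToFloat` and `convertBf16Blob` are hypothetical, and the sketch does not include the narrowing to FP8 that would follow.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <vector>

// BF16 keeps float32's sign bit, all 8 exponent bits, and the top 7 mantissa bits.
float bf16ToFloat(std::uint16_t bits) {
    const std::uint32_t wide = static_cast<std::uint32_t>(bits) << 16;
    float f;
    std::memcpy(&f, &wide, sizeof f);  // well-defined type pun
    return f;
}

std::vector<float> convertBf16Blob(const std::vector<std::uint16_t>& blob,
                                   const std::vector<std::int64_t>& blob_shape,
                                   const std::vector<std::int64_t>& dst_shape) {
    if (blob_shape != dst_shape)
        throw std::invalid_argument("blob shape != dst shape");
    std::vector<float> out;
    out.reserve(blob.size());
    for (std::uint16_t b : blob) out.push_back(bf16ToFloat(b));
    return out;
}
```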

◆ debugDumpTensor()

template<TensorDataType TPrecision, typename MemoryResource>
void Mila::Dnn::debugDumpTensor ( const ITensor & t,
const std::string & label,
size_t maxElements = 8 )
export

Debug dump a concrete tensor to the log.

Template parameters: TPrecision is the concrete tensor's data type and MemoryResource is its memory resource type; both must match the dynamic type of t.

The function attempts a dynamic_cast to the concrete Tensor<TPrecision, MemoryResource>. If the concrete tensor is host-accessible, its first maxElements values are printed directly; otherwise a host copy (Tensor<TPrecision, CpuMemoryResource>) is created and its first maxElements values are printed. This keeps logs short while still giving a quick numeric snapshot.

Notes:

  • Intended for short-lived debug logging in tests and components.
  • Callers should pick the template parameters that match the component context (e.g., TensorType alias).
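The core of debugDumpTensor is a checked downcast followed by a truncated print. Below is a minimal sketch of that pattern. `ITensorSketch`, `ConcreteTensorSketch`, and `dumpPrefix` are hypothetical. The line is returned rather than logged so it is easy to check. The host-copy step is omitted. What Mila does on a failed cast is not documented, so the "<type mismatch>" text is an assumption.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <sstream>
#include <string>
#include <vector>

struct ITensorSketch { virtual ~ITensorSketch() = default; };  // polymorphic base

template <typename T>
struct ConcreteTensorSketch : ITensorSketch {
    std::vector<T> values;
};

template <typename T>
std::string dumpPrefix(const ITensorSketch& t, const std::string& label, std::size_t maxElements = 8) {
    const auto* concrete = dynamic_cast<const ConcreteTensorSketch<T>*>(&t);
    if (!concrete) return label + ": <type mismatch>";  // template args don't match dynamic type
    std::ostringstream os;
    os << label << " [" << concrete->values.size() << "]:";
    const std::size_t n = std::min(maxElements, concrete->values.size());
    for (std::size_t i = 0; i < n; ++i) os << ' ' << concrete->values[i];
    if (n < concrete->values.size()) os << " ...";  // signal truncation
    return os.str();
}
```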

◆ divide()

template<TensorDataType TDataType, typename TMemoryResource>
requires isValidTensor<TDataType, TMemoryResource>
void Mila::Dnn::divide ( const Tensor< TDataType, TMemoryResource > & a,
const Tensor< TDataType, TMemoryResource > & b,
Tensor< TDataType, TMemoryResource > & result,
IExecutionContext * exec_context = nullptr )
export

Element-wise division with optional ExecutionContext (device-dispatched).

Computes result[i] = a[i] / b[i] for all elements.

Template Parameters
TDataType: Abstract tensor data type
TMemoryResource: Memory resource type determining the device
Parameters
a: First input tensor (dividend)
b: Second input tensor (divisor)
result: Output tensor (must be pre-allocated with matching shape)
exec_context: Optional execution context for stream control (borrowed, not owned)
Here is the caller graph for this function:

◆ encodingTypeToString()

std::string Mila::Dnn::encodingTypeToString ( EncodingType p)
inline, export

Convert EncodingType to string.

◆ fill() [1/2]

template<TensorDataType TDataType, typename TMemoryResource>
requires isValidTensor<TDataType, TMemoryResource>
void Mila::Dnn::fill ( Tensor< TDataType, TMemoryResource > & tensor,
host_value_t< TDataType > host_value,
IExecutionContext * exec_context = nullptr )
export

Fill a tensor with a scalar host value (device-dispatched) with optional ExecutionContext.

Forwards scalar fills to the device-specific TensorOps<Tag>::fill. Borrows execution context for stream control with zero overhead. The function signature enforces the expected host scalar representation for each abstract tensor data type via host_value_t<TDataType>.

Template Parameters
TDataType: Abstract tensor data type.
TMemoryResource: Memory resource type backing the tensor.
Parameters
tensor: Destination tensor to be filled. Must satisfy isValidTensor.
host_value: Scalar value in host representation to broadcast to the tensor.
exec_context: Optional execution context for stream control (borrowed, not owned)
Note
exec_context must outlive this function call
When exec_context provided, caller controls synchronization
When exec_context is null, uses default stream and synchronizes before returning
For CUDA tensors, use CudaExecutionContext; for CPU, parameter is ignored

Example:

// With explicit context (async)
auto ctx = std::make_unique<CudaExecutionContext>(0);
fill(float_tensor, 3.14f, ctx.get());
fill(int_tensor, 42, ctx.get());
ctx->synchronize();
// Without context (sync)
fill(float_tensor, 3.14f); // Returns after completion
fill(int_tensor, 42); // Returns after completion

◆ fill() [2/2]

template<TensorDataType TDataType, typename TMemoryResource>
requires isValidTensor<TDataType, TMemoryResource>
void Mila::Dnn::fill ( Tensor< TDataType, TMemoryResource > & tensor,
std::span< const host_value_t< TDataType > > host_values,
IExecutionContext * exec_context = nullptr )
export

Copy host values into a tensor with device dispatch and optional ExecutionContext.

Forwards the host->tensor copy operation (span form) to the device-specific implementation TensorOps<Tag>::fill. Borrows the execution context for stream control with zero overhead. Falls back to the default stream when no context is provided.

The host element type is selected by host_value_t<TDataType> so callers must provide values in the expected host representation (float for floating-point tensor types, int32_t for integer tensor types). The device implementation performs any necessary conversion/quantization.

Template Parameters
TDataType: Abstract tensor data type.
TMemoryResource: Memory resource type backing the tensor.
Parameters
tensor: Destination tensor to be filled. Must satisfy isValidTensor.
host_values: Span of host values in host representation (see host_value_t).
exec_context: Optional execution context for stream control (borrowed, not owned)
Note
exec_context must outlive this function call
When exec_context provided, caller controls synchronization
When exec_context is null, uses default stream and synchronizes before returning
For CUDA tensors, use CudaExecutionContext; for CPU, parameter is ignored

Example:

// With explicit context (async)
auto ctx = std::make_unique<CudaExecutionContext>(0);
std::vector<float> values = {1.0f, 2.0f, 3.0f};
fill(tensor, std::span{values}, ctx.get());
ctx->synchronize();
// Without context (sync)
fill(tensor, std::span{values}); // Returns after completion
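The host_value_t rule (float for floating-point tensor types, int32_t for integer tensor types) can be expressed as a type alias driven by the data-type enum. Below is a minimal sketch using a local subset of the enum. `DType`, `isFloatingPoint`, and `host_value_sketch_t` are hypothetical stand-ins for Mila's own definitions.

```cpp
#include <cassert>
#include <cstdint>
#include <type_traits>

// Local subset of the data-type enum, for illustration only.
enum class DType { FP32, FP16, BF16, FP8_E4M3, INT8, INT32, UINT8 };

constexpr bool isFloatingPoint(DType t) {
    return t == DType::FP32 || t == DType::FP16 || t == DType::BF16 || t == DType::FP8_E4M3;
}

// Floating-point tensors take float on the host; integer tensors take int32_t.
template <DType T>
using host_value_sketch_t = std::conditional_t<isFloatingPoint(T), float, std::int32_t>;

static_assert(std::is_same_v<host_value_sketch_t<DType::BF16>, float>);
static_assert(std::is_same_v<host_value_sketch_t<DType::UINT8>, std::int32_t>);
```

Because the alias fixes the host-side parameter type, a single fill() signature serves every tensor type. Narrowing to BF16, FP8, or a small integer happens on the device side.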

◆ fill_normal()

template<TensorDataType TDataType, typename TMemoryResource>
requires isValidTensor<TDataType, TMemoryResource>
void Mila::Dnn::fill_normal ( Tensor< TDataType, TMemoryResource > & tensor,
float mean,
float stddev,
IExecutionContext * exec_context = nullptr )
export

◆ fill_uniform()

template<TensorDataType TDataType, typename TMemoryResource>
requires isValidTensor<TDataType, TMemoryResource>
void Mila::Dnn::fill_uniform ( Tensor< TDataType, TMemoryResource > & tensor,
host_value_t< TDataType > min_val,
host_value_t< TDataType > max_val,
IExecutionContext * exec_context = nullptr )
export

◆ fromString()

ComponentType Mila::Dnn::fromString ( std::string_view s)
inline, export, noexcept

Parse a case-insensitive component name into a ComponentType.

Accepts canonical names (case-insensitive) produced by toString and returns the corresponding enum value. Returns ComponentType::Unknown if the input does not match any known type.

Parameters
s: Input string (case-insensitive)
Returns
ComponentType Matching enum value or Unknown
Examples
/__w/Mila/Mila/Mila/Src/Dnn/Core/ComponentFactory.ixx.
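fromString() is noexcept and falls back to Unknown instead of throwing. Below is a minimal sketch of that contract. `SketchComponentType` holds a few enumerators, and their spelled-out names are assumed to be the canonical strings. The comparison is ASCII-only so that it can be constexpr, which std::tolower is not.

```cpp
#include <cassert>
#include <cstddef>
#include <string_view>

// Local stand-in holding a few of the documented enumerators.
enum class SketchComponentType { Unknown, Linear, Gelu, LayerNorm };

// ASCII-only lowercasing; usable in constant expressions.
constexpr char asciiLower(char c) noexcept {
    return (c >= 'A' && c <= 'Z') ? static_cast<char>(c - 'A' + 'a') : c;
}

constexpr bool iequals(std::string_view a, std::string_view b) noexcept {
    if (a.size() != b.size()) return false;
    for (std::size_t i = 0; i < a.size(); ++i)
        if (asciiLower(a[i]) != asciiLower(b[i])) return false;
    return true;
}

// Case-insensitive, noexcept, Unknown on no match.
constexpr SketchComponentType parseComponentType(std::string_view s) noexcept {
    if (iequals(s, "Linear"))    return SketchComponentType::Linear;
    if (iequals(s, "Gelu"))      return SketchComponentType::Gelu;
    if (iequals(s, "LayerNorm")) return SketchComponentType::LayerNorm;
    return SketchComponentType::Unknown;
}

static_assert(parseComponentType("gElU") == SketchComponentType::Gelu);
```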

◆ fromTypeId()

ComponentType Mila::Dnn::fromTypeId ( std::string_view s)
inline, export, noexcept

Map a short type identifier back to a ComponentType enum.

Accepts the compact lowercase identifiers produced by toTypeId and returns the corresponding enum value. Returns ComponentType::Unknown for unrecognized identifiers.

Parameters
s: Short lowercase type id (for example "fc", "mlp", "tf")
Returns
ComponentType Matching enum value or Unknown

◆ GPT2_Large()

GptConfig Mila::Dnn::GPT2_Large ( )
export

GPT-2 Large (774M parameters).

Architecture:

  • 1280 embedding dim, 36 layers, 20 heads
  • 5120 hidden dim (4x embedding)

◆ GPT2_Medium()

GptConfig Mila::Dnn::GPT2_Medium ( )
export

GPT-2 Medium (345M parameters).

Architecture:

  • 1024 embedding dim, 24 layers, 16 heads
  • 4096 hidden dim (4x embedding)

◆ GPT2_Small()

GptConfig Mila::Dnn::GPT2_Small ( )
export

Usage Examples:

// Use a preset directly
auto config = Mila::Dnn::GPT2_Small();
auto network = GptNetwork(config, 50257, 1024, "gpt2_small");

// Customize a preset
auto config = Mila::Dnn::GPT2_Small()
    .withDropout(0.2f);  // Custom dropout

// Mix and match for research
auto custom = Mila::Dnn::GPT2_Medium()
    .withBias(false)           // Remove bias
    .withResidualScale(0.5f);  // Add residual scaling

GPT-2 Small (117M parameters)

Architecture:

  • 768 embedding dim, 12 layers, 12 heads
  • Standard multi-head attention
  • LayerNorm + GELU activation
  • 3072 hidden dim (4x embedding)
  • Bias enabled
  • Learned positional embeddings

◆ GPT2_XL()

GptConfig Mila::Dnn::GPT2_XL ( )
export

GPT-2 XL (1.5B parameters).

Architecture:

  • 1600 embedding dim, 48 layers, 25 heads
  • 6400 hidden dim (4x embedding)
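The four GPT-2 presets scale width, depth, and head count together, so every size keeps 64-wide attention heads and an MLP 4× the embedding width. Below is a small compile-time check of the figures listed above. `GptShape` is a local helper, not GptConfig.

```cpp
#include <cassert>

struct GptShape { int embedding_dim, num_layers, num_heads; };

constexpr GptShape kGpt2Presets[] = {
    {768, 12, 12},   // Small
    {1024, 24, 16},  // Medium
    {1280, 36, 20},  // Large
    {1600, 48, 25},  // XL
};

constexpr int headDim(const GptShape& s) { return s.embedding_dim / s.num_heads; }
constexpr int mlpHiddenDim(const GptShape& s) { return 4 * s.embedding_dim; }

// Every preset divides evenly into 64-wide heads.
static_assert(headDim(kGpt2Presets[0]) == 64 && headDim(kGpt2Presets[1]) == 64 &&
              headDim(kGpt2Presets[2]) == 64 && headDim(kGpt2Presets[3]) == 64);
```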

◆ indexToString()

std::string Mila::Dnn::indexToString ( const index_t & index)
export

◆ Llama2_13B()

LlamaConfig Mila::Dnn::Llama2_13B ( )
export

Llama 2 13B.

Architecture:

  • 5120 embedding dim, 40 layers, 40 heads
  • MHA (40 KV heads)
  • 13824 hidden dim

◆ Llama2_70B()

LlamaConfig Mila::Dnn::Llama2_70B ( )
export

Llama 2 70B.

Architecture:

  • 8192 embedding dim, 80 layers, 64 heads
  • Grouped Query Attention (8 KV heads)
  • 28672 hidden dim

◆ Llama2_7B()

LlamaConfig Mila::Dnn::Llama2_7B ( )
export

Llama 2 7B.

Architecture:

  • 4096 embedding dim, 32 layers, 32 heads
  • Grouped Query Attention (32 KV heads - effectively MHA)
  • 11008 hidden dim
  • RoPE theta=10000 (original)
  • 4k context window

◆ Llama3_1_405B()

LlamaConfig Mila::Dnn::Llama3_1_405B ( )
export

Llama 3.1 405B.

Architecture:

  • 16384 embedding dim, 126 layers, 128 heads
  • Grouped Query Attention (8 KV heads)
  • 53248 hidden dim
  • 128k context window

◆ Llama3_1_70B()

LlamaConfig Mila::Dnn::Llama3_1_70B ( )
export

Llama 3.1 70B.

Architecture:

  • 8192 embedding dim, 80 layers, 64 heads
  • Grouped Query Attention (8 KV heads)
  • 28672 hidden dim
  • 128k context window

◆ Llama3_1_8B()

LlamaConfig Mila::Dnn::Llama3_1_8B ( )
export

Llama 3.1 8B.

Architecture:

  • 4096 embedding dim, 32 layers
  • 32 heads
  • Grouped Query Attention (8 KV heads)
  • 14336 hidden dim (~3.5x for SwiGLU)
  • No bias
  • RoPE positional encoding (theta=500000)
  • 128k context window
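The irregular SwiGLU hidden sizes in these presets (11008, 13824, 14336, 28672) follow the sizing rule in Meta's reference Llama implementation. Start from 4·d and take two thirds of it; SwiGLU uses three weight matrices instead of two, so this keeps the parameter count close to a 4·d GELU MLP. Then apply an optional multiplier and round up to a multiple. The `multiple_of` and multiplier values in the checks are taken from Meta's published configs as understood here. Mila's presets hard-code the final numbers, so treat this sketch as a cross-check only.

```cpp
#include <cassert>

// Hidden size for a SwiGLU MLP with model width `dim`.
int swigluHiddenDim(int dim, int multiple_of, double ffn_dim_multiplier = 1.0) {
    int hidden = 4 * dim;
    hidden = 2 * hidden / 3;                                 // integer division truncates
    hidden = static_cast<int>(ffn_dim_multiplier * hidden);  // optional widening
    return multiple_of * ((hidden + multiple_of - 1) / multiple_of);  // round up
}
```

For Llama 3 8B, the result is 14336 = 3.5 × 4096. That is where the "~3.5x for SwiGLU" figure above comes from.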

◆ Llama3_2_1B()

LlamaConfig Mila::Dnn::Llama3_2_1B ( )
export

Usage Examples:

// Use a preset directly
auto config = Mila::Dnn::Llama3_2_1B();
auto network = LlamaNetwork(config, 128256, 131072, "llama3_2_1b");

// Customize a preset
auto config = Mila::Dnn::Llama3_8B()
    .withRoPETheta(1000000.0f);  // Extend context with higher theta

// Mix and match for research
auto custom = Mila::Dnn::Llama3_8B()
    .withNumKVHeads(32)        // Convert GQA to MHA
    .withResidualScale(0.5f);  // Add residual scaling

Llama 3.2 1B

Architecture:

  • 2048 embedding dim, 16 layers
  • 32 heads
  • Grouped Query Attention (8 KV heads)
  • RMSNorm + SwiGLU activation
  • 8192 hidden dim (~4x for SwiGLU)
  • No bias
  • RoPE positional encoding (theta=500000)
  • 128k context window

◆ Llama3_2_3B()

LlamaConfig Mila::Dnn::Llama3_2_3B ( )
export

Llama 3.2 3B.

Architecture:

  • 3072 embedding dim, 28 layers
  • 24 heads
  • Grouped Query Attention (8 KV heads)
  • 8192 hidden dim
  • 128k context window

◆ Llama3_70B()

LlamaConfig Mila::Dnn::Llama3_70B ( )
export

Llama 3 70B (Original release).

Architecture:

  • Same as 3.1 70B but with an 8k context window (no RoPE scaling)

◆ Llama3_8B()

LlamaConfig Mila::Dnn::Llama3_8B ( )
export

Llama 3 8B (Original release).

Architecture:

  • Same as 3.1 8B but with an 8k context window (no RoPE scaling)

◆ multiply()

template<TensorDataType TDataType, typename TMemoryResource>
requires isValidTensor<TDataType, TMemoryResource>
void Mila::Dnn::multiply ( const Tensor< TDataType, TMemoryResource > & a,
const Tensor< TDataType, TMemoryResource > & b,
Tensor< TDataType, TMemoryResource > & result,
IExecutionContext * exec_context = nullptr )
export

Element-wise multiplication with optional ExecutionContext (device-dispatched).

Computes result[i] = a[i] * b[i] for all elements.

Template Parameters
TDataType: Abstract tensor data type
TMemoryResource: Memory resource type determining the device
Parameters
a: First input tensor
b: Second input tensor
result: Output tensor (must be pre-allocated with matching shape)
exec_context: Optional execution context for stream control (borrowed, not owned)

◆ normTypeToString()

std::string Mila::Dnn::normTypeToString ( NormType n)
inline, export

Convert NormType to string.

◆ operator*()

template<TensorDataType TDataType, typename TMemoryResource>
requires isValidTensor<TDataType, TMemoryResource>
Tensor< TDataType, TMemoryResource > Mila::Dnn::operator* ( const Tensor< TDataType, TMemoryResource > & a,
const Tensor< TDataType, TMemoryResource > & b )
export

Element-wise multiplication operator (always synchronous).

Note
This operator always uses default execution and synchronizes.
For async operations with stream control, use multiply(a, b, result, ctx).

◆ operator+() [1/2]

template<TensorDataType TDataType, typename TMemoryResource>
requires isValidTensor<TDataType, TMemoryResource>
Tensor< TDataType, TMemoryResource > Mila::Dnn::operator+ ( const Tensor< TDataType, TMemoryResource > & a,
const Tensor< TDataType, TMemoryResource > & b )
export

Element-wise addition operator (always synchronous).

Note
This operator always uses default execution and synchronizes.
For async operations with stream control, use add(a, b, result, ctx).

◆ operator+() [2/2]

MemoryStats Mila::Dnn::operator+ ( MemoryStats lhs,
const MemoryStats & rhs )
nodiscard, export, noexcept

Aggregate two MemoryStats instances.

◆ operator-()

template<TensorDataType TDataType, typename TMemoryResource>
requires isValidTensor<TDataType, TMemoryResource>
Tensor< TDataType, TMemoryResource > Mila::Dnn::operator- ( const Tensor< TDataType, TMemoryResource > & a,
const Tensor< TDataType, TMemoryResource > & b )
export

Element-wise subtraction operator (always synchronous).

Note
This operator always uses default execution and synchronizes.
For async operations with stream control, use subtract(a, b, result, ctx).

◆ operator/()

template<TensorDataType TDataType, typename TMemoryResource>
requires isValidTensor<TDataType, TMemoryResource>
Tensor< TDataType, TMemoryResource > Mila::Dnn::operator/ ( const Tensor< TDataType, TMemoryResource > & a,
const Tensor< TDataType, TMemoryResource > & b )
export

Element-wise division operator (always synchronous).

Note
This operator always uses default execution and synchronizes.
For async operations with stream control, use divide(a, b, result, ctx).

◆ operator<<()

template<TensorDataType TDataType, typename TMemoryResource>
requires isValidTensor<TDataType, TMemoryResource>
std::ostream & Mila::Dnn::operator<< ( std::ostream & os,
const Tensor< TDataType, TMemoryResource > & tensor )
export

Stream insertion operator for tensor output.


◆ parseTensorDataType()

TensorDataType Mila::Dnn::parseTensorDataType ( const std::string & type_str)
export

◆ shapeToString()

std::string Mila::Dnn::shapeToString ( const shape_t & shape)
export

◆ split() [1/2]

template<TensorDataType TDataType, typename TMemoryResource>
requires isValidTensor<TDataType, TMemoryResource>
void Mila::Dnn::split ( const Tensor< TDataType, TMemoryResource > & input,
Tensor< TDataType, TMemoryResource > & output_a,
Tensor< TDataType, TMemoryResource > & output_b,
IExecutionContext * exec_context = nullptr )
export

◆ split() [2/2]

template<TensorDataType TDataType, typename TMemoryResource>
requires isValidTensor<TDataType, TMemoryResource>
void Mila::Dnn::split ( const Tensor< TDataType, TMemoryResource > & input,
Tensor< TDataType, TMemoryResource > & output_a,
Tensor< TDataType, TMemoryResource > & output_b,
Tensor< TDataType, TMemoryResource > & output_c,
IExecutionContext * exec_context = nullptr )
export

◆ strideToString()

std::string Mila::Dnn::strideToString ( const stride_t & stride)
export

◆ stringToActivationType()

ActivationType Mila::Dnn::stringToActivationType ( const std::string & name)
inline, export

Converts a string to its corresponding ActivationType enum value.

Parameters
name: The string representation of an activation function
Returns
ActivationType The corresponding enum value
Exceptions
std::invalid_argument: If the string doesn't match any known activation function

◆ stringToAttentionType()

AttentionType Mila::Dnn::stringToAttentionType ( const std::string & v)
inline, export

Parse string to AttentionType.

Exceptions
std::invalid_argument: On unknown value.

◆ stringToConnectionType()

ConnectionType Mila::Dnn::stringToConnectionType ( const std::string & name)
inline, export

Converts a string to its corresponding ConnectionType enum value.

Parameters
name: The string representation of a connection type
Returns
ConnectionType The corresponding enum value
Exceptions
std::invalid_argument: If the string doesn't match any known connection type

Example:

auto type = stringToConnectionType("Addition");
// type == ConnectionType::Addition

◆ stringToEncodingType()

EncodingType Mila::Dnn::stringToEncodingType ( const std::string & v)
inline export

Parse string to EncodingType (positional encoding type).

Exceptions
std::invalid_argument On an unknown value.

◆ stringToNormType()

NormType Mila::Dnn::stringToNormType ( const std::string & v)
inline export

Parse string to NormType.

Exceptions
std::invalid_argument On an unknown value.

◆ subtract()

template<TensorDataType TDataType, typename TMemoryResource>
requires isValidTensor<TDataType, TMemoryResource>
void Mila::Dnn::subtract ( const Tensor< TDataType, TMemoryResource > & a,
const Tensor< TDataType, TMemoryResource > & b,
Tensor< TDataType, TMemoryResource > & result,
IExecutionContext * exec_context = nullptr )
export

Element-wise subtraction with optional ExecutionContext (device-dispatched).

Computes result[i] = a[i] - b[i] for all elements.

Template Parameters
TDataType Abstract tensor data type
TMemoryResource Memory resource type determining device
Parameters
a First input tensor (minuend)
b Second input tensor (subtrahend)
result Output tensor (must be pre-allocated with matching shape)
exec_context Optional execution context for stream control (borrowed, not owned)
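A minimal CPU reference for the documented semantics, result[i] = a[i] - b[i], assuming tensors are modelled as flat std::vector<float> buffers. subtractSketch is a hypothetical name; it does not model device dispatch or exec_context, and throwing std::invalid_argument on a size mismatch is an assumption of the sketch rather than documented Mila behavior.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// CPU reference sketch of element-wise subtraction on flat buffers.
// Assumption: tensors are modelled as contiguous std::vector<float>;
// the real API operates on Tensor<TDataType, TMemoryResource>.
inline void subtractSketch( const std::vector<float>& a,
                            const std::vector<float>& b,
                            std::vector<float>& result )
{
    // The output must already be allocated with the inputs' element count.
    if ( a.size() != b.size() || result.size() != a.size() )
        throw std::invalid_argument( "subtractSketch: size mismatch" );

    for ( std::size_t i = 0; i < a.size(); ++i )
        result[ i ] = a[ i ] - b[ i ];  // minuend - subtrahend
}
```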

◆ sum()

template<TensorDataType TDataType, typename TMemoryResource>
requires isValidTensor<TDataType, TMemoryResource>
float Mila::Dnn::sum ( const Tensor< TDataType, TMemoryResource > & tensor,
IExecutionContext * exec_context = nullptr )
export

Sum reduction with optional ExecutionContext (device-dispatched).

Computes the sum of all elements in the tensor. Always synchronizes before returning the result (even when exec_context is provided).

Template Parameters
TDataType Abstract tensor data type
TMemoryResource Memory resource type determining device
Parameters
tensor Input tensor
exec_context Optional execution context for stream control (borrowed, not owned)
Returns
Sum of all elements as float
Note
Always returns after synchronization to ensure result validity
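A host-only sketch of the reduction, assuming a flat std::vector<float> input. sumSketch is hypothetical. Accumulating in double before narrowing to the float return type is a choice made here to limit rounding error and is not confirmed for Mila's backends. Because the data is already on the host, there is no stream to synchronize in this sketch.

```cpp
#include <cassert>
#include <numeric>
#include <vector>

// Sketch of a host-side sum reduction that returns float, like sum().
// Assumption: accumulate in double (the 0.0 initial value makes
// std::accumulate use double), then narrow once at the end.
inline float sumSketch( const std::vector<float>& values )
{
    const double acc = std::accumulate( values.begin(), values.end(), 0.0 );
    return static_cast<float>( acc );
}
```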

◆ tensorDataTypeToString()

std::string Mila::Dnn::tensorDataTypeToString ( TensorDataType type)
inline export

Converts a TensorDataType enumeration value to a human-readable string.


◆ toHost()

template<TensorDataType TDstDataType, TensorDataType TSrcDataType, typename TSrcMemoryResource>
requires isValidTensor<TSrcDataType, TSrcMemoryResource> && isValidTensor<TDstDataType, CpuMemoryResource>
Tensor< TDstDataType, CpuMemoryResource > Mila::Dnn::toHost ( const Tensor< TSrcDataType, TSrcMemoryResource > & src,
IExecutionContext * exec_context = nullptr )
export

Create a host (CPU) tensor from src and copy data into it.

By default the destination data type matches the source data type. The destination tensor preserves the source shape. An optional execution context may be supplied for device-side stream control when the source is device-resident.

Template Parameters
TSrcDataType Source tensor data type
TSrcMemoryResource Source memory resource type
TDstDataType Destination tensor data type (defaults to source type)
Parameters
src Source tensor to copy from
exec_context Optional execution context for stream control (borrowed)
Returns
Tensor on CPU with copied data
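The data-type conversion part of toHost() can be sketched on host buffers. In the hypothetical toHostSketch below, Dst = void stands for "same as the source type", mirroring the documented default while keeping the destination type as the first template parameter. The buffer types, the void convention, and the use of static_cast for element conversion are assumptions; the device-to-host transfer itself is not modelled.

```cpp
#include <cassert>
#include <type_traits>
#include <vector>

// Hypothetical sketch: Dst = void means "keep the source element type".
template <typename Dst = void, typename Src>
auto toHostSketch( const std::vector<Src>& src )
    -> std::vector<std::conditional_t<std::is_void_v<Dst>, Src, Dst>>
{
    using Out = std::conditional_t<std::is_void_v<Dst>, Src, Dst>;

    std::vector<Out> dst;
    dst.reserve( src.size() );                   // element count (shape) is preserved
    for ( const Src& v : src )
        dst.push_back( static_cast<Out>( v ) );  // element-wise conversion
    return dst;
}
```

Usage: toHostSketch( v ) returns a copy with v's element type, while toHostSketch<int>( v ) converts each element with static_cast<int>.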

◆ toString()

std::string Mila::Dnn::toString ( ComponentType t)
inline export noexcept

Convert a ComponentType enum value to its canonical name.

Returns a human-readable name suitable for logs and metadata fields (for example "Linear", "Transformer"). Always returns "Unknown" for unrecognized enum values.

Parameters
t ComponentType enum value
Returns
std::string Canonical name for the component type
Examples
/__w/Mila/Mila/Mila/Src/Dnn/Components/Activations/Gelu/Gelu.ixx, /__w/Mila/Mila/Mila/Src/Dnn/Components/Normalization/Softmax.ixx, and /__w/Mila/Mila/Mila/Src/Dnn/Core/ComponentFactory.ixx.

◆ toTypeId()

std::string Mila::Dnn::toTypeId ( ComponentType t)
inline export noexcept

Get the short (2 to 4 character) type identifier for a ComponentType.

The short type id is intended for compact labels in serialized metadata and concise diagnostics (examples: "fc" for Linear, "mlp" for MLP, "tf" for Transformer). Returns "Unknown" for unrecognized types.

Parameters
t ComponentType enum value
Returns
std::string Short lowercase identifier (2..4 chars) or "Unknown"
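toString() and toTypeId() can both be sketched as switch statements that fall back to "Unknown". ComponentTypeSketch and its three members are hypothetical stand-ins for ComponentType. The "Linear"/"Transformer" names and the "fc"/"mlp"/"tf" ids come from the descriptions above, while the canonical name "MLP" is an assumption.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for ComponentType (the real enum has more members).
enum class ComponentTypeSketch { Linear, MLP, Transformer };

// Canonical name for logs and metadata; "Unknown" for unrecognized values.
inline std::string toStringSketch( ComponentTypeSketch t ) noexcept
{
    switch ( t )
    {
        case ComponentTypeSketch::Linear:      return "Linear";
        case ComponentTypeSketch::MLP:         return "MLP";  // assumed spelling
        case ComponentTypeSketch::Transformer: return "Transformer";
    }
    return "Unknown";
}

// Short lowercase id for compact labels; "Unknown" for unrecognized values.
inline std::string toTypeIdSketch( ComponentTypeSketch t ) noexcept
{
    switch ( t )
    {
        case ComponentTypeSketch::Linear:      return "fc";
        case ComponentTypeSketch::MLP:         return "mlp";
        case ComponentTypeSketch::Transformer: return "tf";
    }
    return "Unknown";
}
```

Returning from each case and placing the fallback after the switch (rather than using a default label) lets compilers warn when a new enum member is added but not mapped.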

◆ validateTensorSize()

void Mila::Dnn::validateTensorSize ( const shape_t & shape,
int64_t expected_size,
const char * tensor_name = "tensor",
const char * op_name = "Operation" )
export

Validate that a tensor has the expected number of elements.

Parameters
shape Tensor shape to validate.
expected_size Expected number of elements.
tensor_name Tensor name used in the error message.
op_name Operation name used in the error message.
Exceptions
std::runtime_error If the element count of shape does not equal expected_size.
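A sketch of the check, assuming shape_t behaves like a std::vector<int64_t> of dimension extents. validateTensorSizeSketch, ShapeSketch, and the message format are hypothetical, and treating an empty (rank-0) shape as one element is also an assumption.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <stdexcept>
#include <string>
#include <vector>

// Assumption: shape_t is modelled as a vector of int64_t extents.
using ShapeSketch = std::vector<std::int64_t>;

inline void validateTensorSizeSketch( const ShapeSketch& shape,
                                      std::int64_t expected_size,
                                      const char* tensor_name = "tensor",
                                      const char* op_name = "Operation" )
{
    // Element count is the product of the extents (1 for an empty shape).
    const std::int64_t actual = std::accumulate(
        shape.begin(), shape.end(), std::int64_t{ 1 }, std::multiplies<std::int64_t>{} );

    if ( actual != expected_size )
        throw std::runtime_error( std::string( op_name ) + ": " + tensor_name +
                                  " has " + std::to_string( actual ) +
                                  " elements, expected " + std::to_string( expected_size ) );
}
```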

◆ zero()

template<TensorDataType TDataType, typename TMemoryResource>
requires isValidTensor<TDataType, TMemoryResource>
void Mila::Dnn::zero ( Tensor< TDataType, TMemoryResource > & tensor,
IExecutionContext * exec_context = nullptr )
export

Zero a tensor using the fastest backend implementation.

Forwards to the device-specific TensorOps<device>::zero implementation.

Template Parameters
TDataType Abstract tensor data type.
TMemoryResource Memory resource type backing the tensor.
Parameters
tensor Destination tensor to be zeroed.
exec_context Optional execution context for stream control (borrowed).
Note
If exec_context is provided, the backend should schedule the zero on the context's stream and avoid synchronizing. If exec_context is null, the backend may synchronize before returning to provide synchronous semantics.
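On the host, a fast zero can be a single memset, because the all-zero bit pattern represents 0 for integer types and +0.0 for IEEE-754 floating-point types. zeroSketch is hypothetical and assumes a contiguous host buffer. A device backend would presumably issue an asynchronous memset on the context's stream instead, which is where the scheduling note above applies.

```cpp
#include <cassert>
#include <cstring>
#include <type_traits>
#include <vector>

// Host sketch of a fast zero fill for a contiguous buffer.
// Assumption: an all-zero bit pattern means 0 for the element type, which
// holds for integers and IEEE-754 floats (it encodes +0.0).
template <typename T>
void zeroSketch( std::vector<T>& buffer )
{
    // std::vector<bool> has no data(), so bool is excluded.
    static_assert( std::is_arithmetic_v<T> && !std::is_same_v<T, bool>,
                   "zeroSketch expects non-bool arithmetic elements" );

    if ( !buffer.empty() )
        std::memset( buffer.data(), 0, buffer.size() * sizeof( T ) );
}
```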
Examples
/__w/Mila/Mila/Mila/Src/Dnn/Components/Activations/Gelu/Gelu.ixx.

Variable Documentation

◆ is_host_float_type

template<TensorDataType TDataType>
bool Mila::Dnn::is_host_float_type = std::is_floating_point_v<host_type_t<TDataType>>
constexpr export

Checks if a TensorDataType maps to a floating-point host type.

Compile-time utility to determine if the host representation of an abstract tensor data type is a floating-point type.

Template Parameters
TDataType Abstract tensor data type to check

◆ is_host_integer_type

template<TensorDataType TDataType>
bool Mila::Dnn::is_host_integer_type = std::is_integral_v<host_type_t<TDataType>>
constexpr export

Checks if a TensorDataType maps to an integer host type.

Compile-time utility to determine if the host representation of an abstract tensor data type is an integer type.

Template Parameters
TDataType Abstract tensor data type to check
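The two variable templates can be reproduced with a stand-in enum and host-type mapping. DataTypeSketch, HostTypeSketch, and host_type_sketch_t are hypothetical substitutes for TensorDataType and host_type_t. Only the std::is_floating_point_v and std::is_integral_v tests mirror the documented definitions.

```cpp
#include <cstdint>
#include <type_traits>

// Hypothetical stand-ins for TensorDataType and host_type_t.
enum class DataTypeSketch { FP32, FP64, INT32, UINT8 };

template <DataTypeSketch T> struct HostTypeSketch;
template <> struct HostTypeSketch<DataTypeSketch::FP32>  { using type = float; };
template <> struct HostTypeSketch<DataTypeSketch::FP64>  { using type = double; };
template <> struct HostTypeSketch<DataTypeSketch::INT32> { using type = std::int32_t; };
template <> struct HostTypeSketch<DataTypeSketch::UINT8> { using type = std::uint8_t; };

template <DataTypeSketch T>
using host_type_sketch_t = typename HostTypeSketch<T>::type;

// Same structure as is_host_float_type / is_host_integer_type.
template <DataTypeSketch T>
constexpr bool is_host_float_type_sketch = std::is_floating_point_v<host_type_sketch_t<T>>;

template <DataTypeSketch T>
constexpr bool is_host_integer_type_sketch = std::is_integral_v<host_type_sketch_t<T>>;

// Compile-time checks: these fail to compile if the mapping is wrong.
static_assert( is_host_float_type_sketch<DataTypeSketch::FP32> );
static_assert( !is_host_float_type_sketch<DataTypeSketch::INT32> );
static_assert( is_host_integer_type_sketch<DataTypeSketch::UINT8> );
```

Traits like these are commonly used in if constexpr branches or requires-clauses to select floating-point or integer code paths at compile time.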

◆ kPrefillScratchByteCap

int64_t Mila::Dnn::kPrefillScratchByteCap = int64_t{ 1536 } * 1024 * 1024
inline constexpr export
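The initializer evaluates to 1536 * 1024 * 1024 = 1,610,612,736 bytes, that is, 1536 MiB or 1.5 GiB. The compile-time checks below confirm this arithmetic against a copy of the constant under the hypothetical name kPrefillScratchByteCapSketch.

```cpp
#include <cstdint>

// Copy of the documented initializer under a hypothetical name.
inline constexpr std::int64_t kPrefillScratchByteCapSketch = std::int64_t{ 1536 } * 1024 * 1024;

// 1536 MiB expressed in bytes.
static_assert( kPrefillScratchByteCapSketch == 1'610'612'736 );
static_assert( kPrefillScratchByteCapSketch == 1536 * ( std::int64_t{ 1 } << 20 ) );

// 1.5 GiB: twice the cap equals 3 GiB.
static_assert( 2 * kPrefillScratchByteCapSketch == 3 * ( std::int64_t{ 1 } << 30 ) );
```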