Mila 0.13.48
Deep Neural Network Library
Mila::Dnn::Compute Namespace Reference

Namespaces

namespace  Cpu
namespace  Cuda
namespace  OperationNames

Classes

struct  always_false
class  BinaryOperation
class  CpuAdamWOptimizer
 CPU-specific AdamW optimizer implementation.
class  CpuAttentionOp
 CPU implementation of Multi-Head Attention operation.
class  CpuAttentionOpRegistrar
class  CpuCrossEntropyOp
 CPU implementation of the cross entropy loss operation for neural networks.
class  CpuCrossEntropyOpRegistrar
 Class responsible for registering the CpuCrossEntropyOp operation.
class  CpuDevice
 Class representing a CPU compute device.
class  CpuDeviceRegistrar
 CPU device plugin for device-agnostic registration.
class  CpuEncoderOp
 CPU implementation of the Encoder operation.
class  CpuEncoderOpRegistrar
 Registrar for CpuEncoderOp operation.
class  CpuGeluOp
 CPU implementation of GELU activation operation using abstract TensorDataType.
class  CpuGeluOpRegistrar
 Class responsible for registering CPU GELU operations.
class  CpuLayerNormOp
 CPU implementation of Layer Normalization using abstract TensorDataType API.
class  CpuLayerNormOpRegistrar
class  CpuLinearOp
 CPU implementation of Linear operation using abstract TensorDataType API.
class  CpuLinearOpRegistrar
class  CpuMemoryResource
 CPU memory resource for host-accessible memory allocation.
class  CpuResidualOp
 CPU Residual operation (FP32) implementing BinaryOperation interface.
class  CpuResidualOpRegistrar
 Registrar for CPU Residual operation (FP32).
class  CpuSoftmaxCrossEntropyOp
 Fused CPU implementation of Softmax + CrossEntropy using abstract TensorDataType API.
class  CpuSoftmaxCrossEntropyOpRegistrar
 Registrar for fused Softmax+CrossEntropy operation.
class  CpuSoftmaxOp
 CPU implementation of Softmax using abstract TensorDataType API.
class  CpuSoftmaxOpRegistrar
class  CublasLtError
class  CudaAdamWOptimizer
 CUDA-specific AdamW optimizer implementation.
class  CudaBadAlloc
struct  CudaDataTypeMap
 Helper struct to map C++ types to CUDA data types for cuBLASLt.
struct  CudaDataTypeMap< __nv_bfloat16 >
struct  CudaDataTypeMap< float >
struct  CudaDataTypeMap< half >
class  CudaDevice
 Class representing a CUDA compute device instance.
class  CudaDeviceMemoryResource
 CUDA device memory resource for GPU-accessible memory allocation.
class  CudaDeviceProps
 Wrapper for CUDA device properties with cached values.
class  CudaDeviceRegistrar
 CUDA device registrar for device-agnostic registration.
class  CudaError
 Exception class for CUDA runtime errors.
class  CudaManagedMemoryResource
 CUDA managed memory resource for unified host/device accessible memory.
class  CudaPinnedMemoryResource
 CUDA pinned memory resource for fast host/device transfer memory.
class  CudaTimer
 GPU-accurate interval timer using a CUDA event pair.
class  Device
 Abstract interface for compute device implementations.
struct  DeviceAccessible
class  DeviceConstructionKey
 Construction key for device factories.
struct  DeviceId
 Lightweight identifier for a compute device.
class  DeviceRegistrar
 Device-agnostic registrar for automatic device discovery and registration.
class  DeviceRegistry
 Registry of discovered compute devices with lazy instantiation.
struct  DeviceTypeTraits
struct  DeviceTypeTraits< DeviceType::Cpu >
 DeviceTypeTraits specialization for the CPU device.
struct  DeviceTypeTraits< DeviceType::Cuda >
 DeviceTypeTraits specialization for the CUDA device.
class  ExecutionContext
 Templated execution context for device-specific operations.
class  ExecutionContext< DeviceType::Cpu >
 CPU execution context specialization.
class  ExecutionContext< DeviceType::Cuda >
 CUDA execution context specialization.
class  ExecutionContext< DeviceType::Metal >
 Metal execution context specialization.
class  ExecutionContext< DeviceType::Vulkan >
 Vulkan execution context specialization.
struct  GqaState
 Non-owning pointers to shared transient GQA scratch buffers.
struct  HostAccessible
class  IExecutionContext
 Type-erased execution context interface.
struct  IKvCacheLifecycle
 Capability interface for KV-cache state management.
struct  IKvInference
 Compute interface for attention operations that maintain a KV cache.
struct  IPackedKvInference
 KV-cache inference interface for packed-QKV MHA backends.
struct  IPositionalDecode
 Capability interface for position-dependent unary operations.
struct  IPositionalPairedOp
 Capability interface for position-dependent paired operations.
struct  LinearOpTypeMap< DeviceType::Cpu, TensorDataType::FP32 >
class  MemoryResource
 Clean memory resource abstraction for device-specific memory allocation.
struct  MemoryResourceTraits
 Memory resource traits for compile-time dispatch optimization.
struct  MemoryResourceTraits< CpuMemoryResource >
 CPU-specific memory resource traits providing detailed CPU backend characteristics.
struct  MemoryResourceTraits< CudaDeviceMemoryResource >
 CUDA device memory resource traits providing detailed GPU backend characteristics.
struct  MemoryResourceTraits< CudaManagedMemoryResource >
 CUDA managed memory resource traits providing unified memory characteristics.
struct  MemoryResourceTraits< CudaPinnedMemoryResource >
 CUDA pinned memory resource traits providing fast transfer characteristics.
struct  MemoryStats
 Global memory statistics for all TrackedMemoryResource instances.
class  MetalDevice
 Class representing a Metal compute device instance.
class  MetalDevicePlugin
 Metal device plugin for device-agnostic registration.
class  MetalMemoryResource
 Stub implementation for non-Apple platforms.
class  Operation
class  OperationRegistry
 Central registry for typed, device-aware compute operations.
class  OperationsRegistrar
 Class to manage compute operations initialization.
struct  OperationTraits
 Primary traits template for unified compile-time operation dispatch.
struct  OperationTraits< OperationType::CrossEntropyOp, DeviceType::Cuda, TensorDataType::BF16, void >
struct  OperationTraits< OperationType::CrossEntropyOp, DeviceType::Cuda, TensorDataType::FP32, void >
struct  OperationTraits< OperationType::GeluOp, DeviceType::Cpu, TensorDataType::FP32, void >
struct  OperationTraits< OperationType::GeluOp, DeviceType::Cuda, TensorDataType::BF16, void >
struct  OperationTraits< OperationType::GeluOp, DeviceType::Cuda, TensorDataType::FP32, void >
struct  OperationTraits< OperationType::GroupedQueryAttentionOp, DeviceType::Cuda, TensorDataType::BF16, NoKvCompression >
 Unquantized BF16 path. No KV cache compression. Standard inference precision.
struct  OperationTraits< OperationType::GroupedQueryAttentionOp, DeviceType::Cuda, TensorDataType::FP32, NoKvCompression >
 Unquantized FP32 path. No KV cache compression.
struct  OperationTraits< OperationType::LinearOp, DeviceType::Cpu, TensorDataType::FP32, NoWeightQuant >
struct  OperationTraits< OperationType::LinearOp, DeviceType::Cuda, TensorDataType::BF16, NoWeightQuant >
 Unquantized BF16 path. Standard inference precision.
struct  OperationTraits< OperationType::LinearOp, DeviceType::Cuda, TensorDataType::BF16, PerChannelFp8<> >
 FP8 per-channel quantized BF16 path. Requires SM >= 8.0 (Ampere+).
struct  OperationTraits< OperationType::LinearOp, DeviceType::Cuda, TensorDataType::BF16, PerGroupFp4< 128 > >
 FP4 E2M1 per-group quantized BF16 path. W4A16 fused GEMM with E2M1 decode, group_size=128. Requires SM >= 8.0.
struct  OperationTraits< OperationType::LinearOp, DeviceType::Cuda, TensorDataType::BF16, PerGroupFp4< 64 > >
 FP4 E2M1 per-group quantized BF16 path. W4A16 fused GEMM with E2M1 decode, group_size=64. Requires SM >= 8.0.
struct  OperationTraits< OperationType::LinearOp, DeviceType::Cuda, TensorDataType::BF16, PerGroupInt4< 128 > >
 INT4 per-group quantized BF16 path. W4A16 fused GEMM, group_size=128. Requires SM >= 8.0.
struct  OperationTraits< OperationType::LinearOp, DeviceType::Cuda, TensorDataType::BF16, PerGroupInt4< 64 > >
 INT4 per-group quantized BF16 path. W4A16 fused GEMM, group_size=64. Requires SM >= 8.0.
struct  OperationTraits< OperationType::LinearOp, DeviceType::Cuda, TensorDataType::FP32, NoWeightQuant >
 Unquantized FP32 path. Retained for validation and reference.
struct  OperationTraits< OperationType::LpeOp, DeviceType::Cpu, TensorDataType::FP32, void >
struct  OperationTraits< OperationType::LpeOp, DeviceType::Cuda, TensorDataType::BF16, void >
struct  OperationTraits< OperationType::LpeOp, DeviceType::Cuda, TensorDataType::FP32, void >
struct  OperationTraits< OperationType::MultiHeadAttentionOp, DeviceType::Cpu, TensorDataType::FP32, void >
struct  OperationTraits< OperationType::MultiHeadAttentionOp, DeviceType::Cuda, TensorDataType::BF16, void >
struct  OperationTraits< OperationType::MultiHeadAttentionOp, DeviceType::Cuda, TensorDataType::FP32, void >
struct  OperationTraits< OperationType::ResidualOp, DeviceType::Cpu, TensorDataType::FP32, void >
struct  OperationTraits< OperationType::ResidualOp, DeviceType::Cuda, TensorDataType::BF16, void >
struct  OperationTraits< OperationType::ResidualOp, DeviceType::Cuda, TensorDataType::FP32, void >
struct  OperationTraits< OperationType::RmsNormOp, DeviceType::Cuda, TensorDataType::BF16, void >
struct  OperationTraits< OperationType::RmsNormOp, DeviceType::Cuda, TensorDataType::FP32, void >
struct  OperationTraits< OperationType::RopeOp, DeviceType::Cuda, TensorDataType::BF16, void >
struct  OperationTraits< OperationType::RopeOp, DeviceType::Cuda, TensorDataType::FP32, void >
struct  OperationTraits< OperationType::SoftmaxOp, DeviceType::Cpu, TensorDataType::FP32, void >
struct  OperationTraits< OperationType::SoftmaxOp, DeviceType::Cuda, TensorDataType::BF16, void >
struct  OperationTraits< OperationType::SoftmaxOp, DeviceType::Cuda, TensorDataType::FP32, void >
struct  OperationTraits< OperationType::SwigluOp, DeviceType::Cuda, TensorDataType::BF16, void >
struct  OperationTraits< OperationType::SwigluOp, DeviceType::Cuda, TensorDataType::FP32, void >
struct  OperationTraits< OperationType::TokenEmbeddingOp, DeviceType::Cuda, TensorDataType::BF16, void >
struct  OperationTraits< OperationType::TokenEmbeddingOp, DeviceType::Cuda, TensorDataType::FP32, void >
class  Optimizer
 Abstract base class for parameter optimizers.
class  PairedOperation
 Abstract base for paired operations: two inputs -> two outputs.
class  TrackedMemoryResource
 A memory resource wrapper that tracks allocation and deallocation statistics.
class  UnaryOperation
class  VulkanDevice
 Class representing a Vulkan compute device instance.
class  VulkanMemoryResource
 Stub implementation for platforms without Vulkan support.
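
The OperationTraits specializations listed above select an implementation from an (operation, device, data type, policy) tuple, and the always_false helper hints at the common trick of rejecting unlisted tuples at compile time. A minimal standalone sketch of that pattern follows; the trimmed enums, the Impl member, the OpImpl alias, and the kernel structs are hypothetical stand-ins, not Mila's definitions.

```cpp
#include <string_view>
#include <type_traits>

// Trimmed, hypothetical stand-ins for the enums documented on this page.
enum class OperationType { LinearOp, GeluOp };
enum class DeviceType { Cpu, Cuda };
enum class TensorDataType { FP32, BF16 };

// Hypothetical quantization policy tags, modeled on the names in the list above.
struct NoWeightQuant {};
template <int GroupSize> struct PerGroupInt4 {};

// Dependent false value: the static_assert below fires only if the primary template is instantiated.
template <typename...> struct always_false : std::false_type {};

// Primary template: any tuple without a specialization is a compile-time error.
template <OperationType Op, DeviceType Dev, TensorDataType DT, typename Policy = void>
struct OperationTraits {
    static_assert(always_false<Policy>::value, "unsupported operation/device/data type/policy combination");
};

// Hypothetical kernel implementations.
struct CpuLinearFp32 { static constexpr std::string_view name = "CpuLinear/FP32"; };
struct CudaLinearBf16Int4G128 { static constexpr std::string_view name = "CudaLinear/BF16/INT4-g128"; };

// Each supported tuple maps to exactly one implementation type.
template <>
struct OperationTraits<OperationType::LinearOp, DeviceType::Cpu, TensorDataType::FP32, NoWeightQuant> {
    using Impl = CpuLinearFp32;
};

template <>
struct OperationTraits<OperationType::LinearOp, DeviceType::Cuda, TensorDataType::BF16, PerGroupInt4<128>> {
    using Impl = CudaLinearBf16Int4G128;
};

// Convenience alias that resolves the implementation type for a tuple.
template <OperationType Op, DeviceType Dev, TensorDataType DT, typename Policy = void>
using OpImpl = typename OperationTraits<Op, Dev, DT, Policy>::Impl;
```

With this shape, requesting an unlisted tuple such as OpImpl<OperationType::GeluOp, DeviceType::Cpu, TensorDataType::BF16> instantiates the primary template and trips the static_assert, so unsupported combinations surface at compile time rather than at run time.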

Concepts

concept  GqaOpConcept
 Contract for GroupedQueryAttentionOp: positional forward and backward.
concept  IsAcceleratorMemoryResource
 Concept for identifying compute accelerator memory resources.
concept  IsCacheCoherent
 Concept for cache-coherent memory resources.
concept  IsDeviceMemoryResource
 Concept for identifying device-based memory resources.
concept  IsHighBandwidth
 Concept for memory resources with high bandwidth characteristics.
concept  IsHostMemoryResource
 Concept for identifying host-accessible memory resources.
concept  LinearOpConcept
 Contract for LinearOp: typed forward matmul and backward weight/input gradients.
concept  OptimizedForCoalescing
 Concept for memory resources optimized for coalesced access patterns.
concept  RequiresContextBinding
 Concept for memory resources requiring CUDA context binding.
concept  SamplingOpConcept
 Contract for SamplingOp: in-place token sampling from a logits tensor.
concept  SupportsConcurrentKernels
 Concept for memory resources supporting concurrent kernel execution.
concept  SupportsPeerAccess
 Concept for memory resources optimized for device-to-device transfers.
concept  SupportsSIMD
 Concept for CPU memory resources with SIMD support.
concept  SupportsTextureMemory
 Concept for memory resources supporting texture memory access.
concept  SupportsThreading
 Concept for memory resources with threading support.
concept  SupportsUnifiedMemory
 Concept for CUDA memory resources with unified memory support.
concept  UnaryOpConcept
 Contract for policy-free unary ops (Softmax, RmsNorm, LayerNorm, Residual, ...).

Typedefs

using AdamWConfig = Mila::Dnn::Optimizers::AdamWConfig
using Mila::Dnn::Compute::CpuExecutionContext = ExecutionContext<DeviceType::Cpu>
using Mila::Dnn::Compute::CudaExecutionContext = ExecutionContext<DeviceType::Cuda>
using Mila::Dnn::Compute::HostMemoryResource = CpuMemoryResource
 Alias for CpuMemoryResource that represents host-accessible memory.
using OptimizerConfig = Mila::Dnn::Optimizers::AdamWConfig

Enumerations

enum class  Mila::Dnn::Compute::DeviceType { Cpu , Cuda , Metal , Rocm }
 Enumeration of supported compute device types.
enum class  Mila::Dnn::Compute::OperationType {
  CrossEntropyOp , TokenEmbeddingOp , LpeOp , RopeOp ,
  FusedOp , LinearOp , GeluOp , SwigluOp ,
  LayerNormOp , RmsNormOp , MultiHeadAttentionOp , GroupedQueryAttentionOp ,
  ResidualOp , SoftmaxOp , DropoutOp , SamplingOp ,
  SoftmaxCrossEntropyOp
}
 Enumeration of all supported neural network operation types.

Functions

template<typename Tp, typename Tg>
void adamw_update (Tp *params_memory, float *master_params_memory, Tg *grads_memory, float *m_memory, float *v_memory, size_t num_parameters, ptrdiff_t w_stride, ptrdiff_t g_stride, ptrdiff_t s_stride, int num_slices, float learning_rate, float beta1, float beta2, int t, float eps, float weight_decay, float grad_scale, unsigned int seed, cudaStream_t stream)
template<DeviceType TDeviceType>
const ExecutionContext< TDeviceType > * Mila::Dnn::Compute::cast_context_ (const IExecutionContext *ctx) noexcept
template<DeviceType TDeviceType>
ExecutionContext< TDeviceType > * Mila::Dnn::Compute::cast_context_ (IExecutionContext *ctx) noexcept
std::unique_ptr< IExecutionContext > Mila::Dnn::Compute::createExecutionContext (DeviceId device_id)
 Create execution context for specified device.
void Mila::Dnn::Compute::cublasLtCheckStatus (cublasStatus_t status, const std::source_location &location=std::source_location::current())
 Checks the status of a cuBLASLt operation and throws if an error occurred.
void cuda_mha_forward_fp16 (half *Y, half *qkvr, half *att, const half *X, int B, int T, int C, int NH, cudaStream_t stream)
void cuda_mha_forward_fp32 (float *Y, float *qkvr, float *att, const float *X, int B, int T, int C, int NH, cudaStream_t stream)
template<typename TPrecision>
void cuda_softmax_crossentropy_backward (TPrecision *dX, const TPrecision *dY_loss, const TPrecision *Y, const int *targets, int batch_size, int seq_len, int vocab_size, cudaStream_t stream)
template<typename TPrecision>
void cuda_softmax_crossentropy_forward (TPrecision *Y_loss, TPrecision *Y, const TPrecision *X, const int *targets, int batch_size, int seq_len, int vocab_size, cudaStream_t stream)
void Mila::Dnn::Compute::cudaCheckLastError (const std::source_location &location=std::source_location::current())
 Checks the last CUDA error and throws if an error occurred.
void Mila::Dnn::Compute::cudaCheckStatus (cudaError_t status, const std::source_location &location=std::source_location::current())
 Checks the status of a CUDA operation and throws if an error occurred.
std::string Mila::Dnn::Compute::deviceTypeToString (DeviceType device_type)
 Converts a DeviceType to its string representation.
DeviceId Mila::Dnn::Compute::getBestDevice (DeviceType type, bool preferMemory=false)
 Gets the best DeviceId of a specific type based on performance characteristics.
std::size_t Mila::Dnn::Compute::getDeviceCount (DeviceType type) noexcept
 Count instantiated devices of the given DeviceType.
template<DeviceType TDevice, TensorDataType TDataType>
std::vector< std::string > Mila::Dnn::Compute::getRegisteredOperations ()
 Templated helper returning registered operation names for a compile-time device and tensor data type.
template<DeviceType TDevice, TensorDataType TDataType>
bool Mila::Dnn::Compute::isOperationRegistered (const std::string &operation_name)
 Templated helper that checks whether a named operation is registered for a compile-time device and tensor data type.
std::vector< std::string > Mila::Dnn::Compute::listDevicesByName ()
 Lists all available compute devices by name.
std::vector< std::string > Mila::Dnn::Compute::listDevicesByType (DeviceType type)
 Lists compute devices of a specific type.
template<typename OpType, DeviceType DT, typename ConfigT>
std::shared_ptr< OpType > makeOpInstance (IExecutionContext *ctx, const ConfigT &cfg)
 Attempt to construct an OpType instance from a raw IExecutionContext*.
std::string_view Mila::Dnn::Compute::operationTypeToString (OperationType op)
template<DeviceType TDataType, typename OpType, Dnn::TensorDataType TA, Dnn::TensorDataType TB, Dnn::TensorDataType TP = TA>
void Mila::Dnn::Compute::registerBinaryOpType (const std::string &opName)
 Register a binary operation type with OperationRegistry using a common factory pattern.
template<DeviceType TDataType, typename OpType, Dnn::TensorDataType TA, Dnn::TensorDataType TB = TA, Dnn::TensorDataType TP = TA>
void Mila::Dnn::Compute::registerPairedOpType (const std::string &opName)
 Register a paired operation type with OperationRegistry using a common factory pattern.
template<DeviceType TDataType, typename OpType, Dnn::TensorDataType TInput, Dnn::TensorDataType TPrecision = TInput>
void Mila::Dnn::Compute::registerUnaryOpType (std::string_view op_name)
 Register a unary operation type with OperationRegistry using a common factory pattern.
DeviceType Mila::Dnn::Compute::toDeviceType (std::string_view device_type)
 Converts a string to the corresponding DeviceType.
template<DeviceType TDeviceType>
ExecutionContext< TDeviceType > * Mila::Dnn::Compute::validateExecutionContext_ (IExecutionContext *context, const std::string &op_name)

Variables

constexpr float GELU_SCALING_FACTOR = 0.7978845608f
 Scaling factor for GELU tanh approximation: sqrt(2/pi).

Typedef Documentation

◆ AdamWConfig

◆ CpuExecutionContext

◆ CudaExecutionContext

◆ HostMemoryResource

Alias for CpuMemoryResource that represents host-accessible memory.

This alias provides a semantic name that describes the memory's accessibility characteristics rather than its implementation details. Use HostMemoryResource when you need memory that can be directly accessed from host (CPU) code.

See also
CpuMemoryResource

◆ OptimizerConfig

Enumeration Type Documentation

◆ DeviceType

enum class Mila::Dnn::Compute::DeviceType
export strong

Enumeration of supported compute device types.

Defines the types of compute devices on which tensor and neural network operations can run.

Enumerator
Cpu 

CPU device type.

Cuda 

CUDA GPU device type.

Metal 
Rocm 

◆ OperationType

enum class Mila::Dnn::Compute::OperationType
export strong

Enumeration of all supported neural network operation types.

This enumeration defines the different types of operations that can be executed by the compute framework. Each operation type corresponds to a specific neural network function or layer.

Enumerator
CrossEntropyOp 

Cross entropy loss operation (host-based; used by GPT reference implementation).

TokenEmbeddingOp 

Token embedding operation.

LpeOp 

Learned Positional Embedding operation for transformer architecture.

RopeOp 

Rotary Position Embedding operation for transformer architecture.

FusedOp 

Fused operation combining multiple operations for performance optimization.

LinearOp 

Linear (fully connected/dense) layer operation.

GeluOp 

Gaussian Error Linear Unit activation function.

SwigluOp 

SwiGLU activation function.

LayerNormOp 

Layer normalization operation.

RmsNormOp 

RMS normalization operation.

MultiHeadAttentionOp 

Multi-head attention operation (MHA) for transformers.

GroupedQueryAttentionOp 

Grouped Query Attention (GQA).

ResidualOp 

Residual connection operation.

SoftmaxOp 

Softmax activation function.

DropoutOp 

Dropout regularization operation.

SamplingOp 

Device-side token sampling from logits.

SoftmaxCrossEntropyOp 

WIP: fused softmax + cross-entropy loss, targeted at Llama training.

Function Documentation

◆ adamw_update()

template<typename Tp, typename Tg>
void Mila::Dnn::Compute::adamw_update ( Tp * params_memory,
float * master_params_memory,
Tg * grads_memory,
float * m_memory,
float * v_memory,
size_t num_parameters,
ptrdiff_t w_stride,
ptrdiff_t g_stride,
ptrdiff_t s_stride,
int num_slices,
float learning_rate,
float beta1,
float beta2,
int t,
float eps,
float weight_decay,
float grad_scale,
unsigned int seed,
cudaStream_t stream )
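
Only the signature of adamw_update is documented here. For orientation, a scalar CPU reference of the standard AdamW step (bias-corrected moments plus decoupled weight decay) is sketched below. It assumes grad_scale multiplies the raw gradient and deliberately omits the strides, slices, FP32 master copy, and seed, whose roles this page does not document; adamw_reference is a hypothetical name, and this is not Mila's kernel.

```cpp
#include <cmath>
#include <cstddef>

// Hypothetical CPU reference of one AdamW step over a flat parameter array.
// t is the 1-based step count used for bias correction.
void adamw_reference(float* params, const float* grads, float* m, float* v,
                     std::size_t num_parameters, float learning_rate,
                     float beta1, float beta2, int t, float eps,
                     float weight_decay, float grad_scale) {
    const float bias1 = 1.0f - std::pow(beta1, static_cast<float>(t));
    const float bias2 = 1.0f - std::pow(beta2, static_cast<float>(t));
    for (std::size_t i = 0; i < num_parameters; ++i) {
        const float g = grads[i] * grad_scale;        // assumption: grad_scale rescales the raw gradient
        m[i] = beta1 * m[i] + (1.0f - beta1) * g;      // first-moment moving average
        v[i] = beta2 * v[i] + (1.0f - beta2) * g * g;  // second-moment moving average
        const float m_hat = m[i] / bias1;              // bias-corrected moments
        const float v_hat = v[i] / bias2;
        // Decoupled weight decay: applied to the parameter directly, not folded into the gradient.
        params[i] -= learning_rate * (m_hat / (std::sqrt(v_hat) + eps) + weight_decay * params[i]);
    }
}
```

On the first step with zero-initialized moments, bias correction makes m_hat / sqrt(v_hat) equal to the sign of the gradient, so each parameter moves by roughly learning_rate regardless of gradient magnitude.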

◆ cast_context_() [1/2]

template<DeviceType TDeviceType>
const ExecutionContext< TDeviceType > * Mila::Dnn::Compute::cast_context_ ( const IExecutionContext * ctx)
nodiscard export noexcept

◆ cast_context_() [2/2]

template<DeviceType TDeviceType>
ExecutionContext< TDeviceType > * Mila::Dnn::Compute::cast_context_ ( IExecutionContext * ctx)
nodiscard export noexcept
Examples
/__w/Mila/Mila/Mila/Src/Dnn/Compute/ExecutionContext.ixx.

◆ createExecutionContext()

std::unique_ptr< IExecutionContext > Mila::Dnn::Compute::createExecutionContext ( DeviceId device_id)
export

Create execution context for specified device.

Factory function that returns a type-erased IExecutionContext, hiding device-specific implementation details from callers.
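
A standalone sketch of the type-erased factory pattern described here: callers hold a std::unique_ptr to an abstract interface, and a checked downcast recovers the concrete device context when device-specific code needs it. The types are simplified stand-ins; deviceType(), deviceIndex(), and contextAs are hypothetical names, not Mila's API.

```cpp
#include <memory>
#include <stdexcept>

enum class DeviceType { Cpu, Cuda };

// Simplified stand-in for DeviceId: a device type plus an ordinal.
struct DeviceId {
    DeviceType type;
    int index;
};

// Type-erased interface that callers program against.
class IExecutionContext {
public:
    virtual ~IExecutionContext() = default;
    virtual DeviceType deviceType() const noexcept = 0;
};

// Device-specific implementation selected by a template argument.
template <DeviceType TDevice>
class ExecutionContext final : public IExecutionContext {
public:
    explicit ExecutionContext(int device_index) : device_index_(device_index) {}
    DeviceType deviceType() const noexcept override { return TDevice; }
    int deviceIndex() const noexcept { return device_index_; }
private:
    int device_index_;
};

// Factory: the switch is the only place that names concrete context types.
std::unique_ptr<IExecutionContext> createExecutionContext(DeviceId id) {
    switch (id.type) {
        case DeviceType::Cpu:  return std::make_unique<ExecutionContext<DeviceType::Cpu>>(id.index);
        case DeviceType::Cuda: return std::make_unique<ExecutionContext<DeviceType::Cuda>>(id.index);
    }
    throw std::invalid_argument("createExecutionContext: unsupported device type");
}

// Checked downcast back to the concrete context; returns nullptr on a device mismatch.
template <DeviceType TDevice>
ExecutionContext<TDevice>* contextAs(IExecutionContext* ctx) noexcept {
    return dynamic_cast<ExecutionContext<TDevice>*>(ctx);
}
```

The cast_context_ overloads on this page have a similar shape (IExecutionContext* in, ExecutionContext<TDeviceType>* out, noexcept), although their exact checks are not documented here.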

Examples
/__w/Mila/Mila/Mila/Src/Dnn/Components/Activations/Gelu/Gelu.ixx, and /__w/Mila/Mila/Mila/Src/Dnn/Components/Normalization/Softmax.ixx.

◆ cublasLtCheckStatus()

void Mila::Dnn::Compute::cublasLtCheckStatus ( cublasStatus_t status,
const std::source_location & location = std::source_location::current() )
inline export

Checks the status of a cuBLASLt operation and throws if an error occurred.

Parameters
status The cuBLASLt error status code to check.
location Source location information (automatically populated by default).
Exceptions
CublasLtError if the status is not CUBLAS_STATUS_SUCCESS.

◆ cuda_mha_forward_fp16()

void Mila::Dnn::Compute::cuda_mha_forward_fp16 ( half * Y,
half * qkvr,
half * att,
const half * X,
int B,
int T,
int C,
int NH,
cudaStream_t stream )

◆ cuda_mha_forward_fp32()

void Mila::Dnn::Compute::cuda_mha_forward_fp32 ( float * Y,
float * qkvr,
float * att,
const float * X,
int B,
int T,
int C,
int NH,
cudaStream_t stream )

◆ cuda_softmax_crossentropy_backward()

template<typename TPrecision>
void Mila::Dnn::Compute::cuda_softmax_crossentropy_backward ( TPrecision * dX,
const TPrecision * dY_loss,
const TPrecision * Y,
const int * targets,
int batch_size,
int seq_len,
int vocab_size,
cudaStream_t stream )

◆ cuda_softmax_crossentropy_forward()

template<typename TPrecision>
void Mila::Dnn::Compute::cuda_softmax_crossentropy_forward ( TPrecision * Y_loss,
TPrecision * Y,
const TPrecision * X,
const int * targets,
int batch_size,
int seq_len,
int vocab_size,
cudaStream_t stream )
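
These fused kernels are documented only by their signatures. Assuming the usual layout of batch_size * seq_len rows of vocab_size logits, with Y holding softmax probabilities and Y_loss one loss value per row, a float-only CPU reference of the math they presumably compute is sketched below: a numerically stable softmax, loss = -log p[target], and the fused gradient dX = dY_loss * (Y - onehot(target)). The _ref function names are hypothetical, and the real kernels are templated on TPrecision.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>

// Forward: per row, a stable softmax into Y and -log(probability of the target) into Y_loss.
void softmax_crossentropy_forward_ref(float* Y_loss, float* Y, const float* X,
                                      const int* targets, int batch_size, int seq_len, int vocab_size) {
    const std::size_t rows = static_cast<std::size_t>(batch_size) * seq_len;
    for (std::size_t r = 0; r < rows; ++r) {
        const float* x = X + r * vocab_size;
        float* y = Y + r * vocab_size;
        const float max_logit = *std::max_element(x, x + vocab_size); // subtract the max for stability
        float sum = 0.0f;
        for (int j = 0; j < vocab_size; ++j) {
            y[j] = std::exp(x[j] - max_logit);
            sum += y[j];
        }
        for (int j = 0; j < vocab_size; ++j) y[j] /= sum;
        Y_loss[r] = -std::log(y[targets[r]]);
    }
}

// Backward: softmax and cross-entropy fuse into dX = dY_loss * (Y - onehot(target)).
void softmax_crossentropy_backward_ref(float* dX, const float* dY_loss, const float* Y,
                                       const int* targets, int batch_size, int seq_len, int vocab_size) {
    const std::size_t rows = static_cast<std::size_t>(batch_size) * seq_len;
    for (std::size_t r = 0; r < rows; ++r) {
        for (int j = 0; j < vocab_size; ++j) {
            const float indicator = (j == targets[r]) ? 1.0f : 0.0f;
            dX[r * vocab_size + j] = dY_loss[r] * (Y[r * vocab_size + j] - indicator);
        }
    }
}
```

Fusing avoids materializing the separate softmax Jacobian: the combined gradient needs only the stored probabilities and the target index.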

◆ cudaCheckLastError()

void Mila::Dnn::Compute::cudaCheckLastError ( const std::source_location & location = std::source_location::current())
export

Checks the last CUDA error and throws if an error occurred.

Parameters
location Source location information (automatically populated by default).
Exceptions
CudaError if the last error is not cudaSuccess.

◆ cudaCheckStatus()

void Mila::Dnn::Compute::cudaCheckStatus ( cudaError_t status,
const std::source_location & location = std::source_location::current() )
export

Checks the status of a CUDA operation and throws if an error occurred.

Parameters
status The CUDA error status code to check.
location Source location information (automatically populated by default).
Exceptions
CudaError if the status is not cudaSuccess.

◆ deviceTypeToString()

std::string Mila::Dnn::Compute::deviceTypeToString ( DeviceType device_type)
export

Converts a DeviceType to its string representation.

Parameters
device_type The device type to convert.
Returns
std::string The string representation of the device type (e.g., "CPU" or "CUDA").
Exceptions
std::invalid_argument If the device type is invalid.
Examples
/__w/Mila/Mila/Mila/Src/Dnn/Components/Activations/Gelu/Gelu.ixx, and /__w/Mila/Mila/Mila/Src/Dnn/Components/Normalization/Softmax.ixx.

◆ getBestDevice()

DeviceId Mila::Dnn::Compute::getBestDevice ( DeviceType type,
bool preferMemory = false )
export

Gets the best DeviceId of a specific type based on performance characteristics.

For CUDA, the function consults Cuda::getBestDeviceId(preferMemory) and returns the matching DeviceId when one is available. If no device of the requested type is found, it returns a DeviceId with the requested type and index == -1 to indicate "none".

Parameters
type The device type to filter by (e.g., DeviceType::Cuda).
preferMemory When true, prioritizes memory bandwidth over compute capability.
Returns
DeviceId Best available device id for the requested type, or index == -1 if none
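
The "index == -1 means none" convention can be illustrated with a standalone sketch. The DeviceCandidate struct, its fields, and the scoring rule (memory bandwidth when preferMemory is true, otherwise a compute score) are hypothetical; this page does not document how Cuda::getBestDeviceId ranks devices.

```cpp
#include <vector>

enum class DeviceType { Cpu, Cuda };

struct DeviceId {
    DeviceType type;
    int index; // -1 means "no device of this type"
};

// Hypothetical per-device summary used for ranking; scores are assumed non-negative.
struct DeviceCandidate {
    DeviceType type;
    int index;
    double compute_score;
    double memory_bandwidth_gbps;
};

DeviceId getBestDevice(const std::vector<DeviceCandidate>& devices, DeviceType type,
                       bool preferMemory = false) {
    DeviceId best{type, -1}; // stays -1 if nothing matches the requested type
    double best_score = -1.0;
    for (const auto& d : devices) {
        if (d.type != type) continue;
        const double score = preferMemory ? d.memory_bandwidth_gbps : d.compute_score;
        if (score > best_score) {
            best_score = score;
            best.index = d.index;
        }
    }
    return best;
}
```

Returning a typed sentinel rather than throwing keeps the call non-failing; callers check index before using the result.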

◆ getDeviceCount()

std::size_t Mila::Dnn::Compute::getDeviceCount ( DeviceType type)
export noexcept

Count instantiated devices of the given DeviceType.

Lightweight, thread-safe convenience helper intended for tests and diagnostics. Non-throwing; returns 0 on error.


◆ getRegisteredOperations()

template<DeviceType TDevice, TensorDataType TDataType>
std::vector< std::string > Mila::Dnn::Compute::getRegisteredOperations ( )
export

Templated helper returning registered operation names for a compile-time device and tensor data type.


◆ isOperationRegistered()

template<DeviceType TDevice, TensorDataType TDataType>
bool Mila::Dnn::Compute::isOperationRegistered ( const std::string & operation_name)
export

Templated helper that checks whether a named operation is registered for a compile-time device and tensor data type.


◆ listDevicesByName()

std::vector< std::string > Mila::Dnn::Compute::listDevicesByName ( )
export

Lists all available compute devices by name.

Returns
std::vector<std::string> A list of device identifiers (e.g., "CUDA:0", "CPU:0").

◆ listDevicesByType()

std::vector< std::string > Mila::Dnn::Compute::listDevicesByType ( DeviceType type)
export

Lists compute devices of a specific type.

Parameters
type The device type to filter by
Returns
std::vector<std::string> List of matching device identifiers

◆ makeOpInstance()

template<typename OpType, DeviceType DT, typename ConfigT>
std::shared_ptr< OpType > Mila::Dnn::Compute::makeOpInstance ( IExecutionContext * ctx,
const ConfigT & cfg )

Attempt to construct an OpType instance from a raw IExecutionContext*.

Supports three constructor shapes (in priority order):

  • OpType(IExecutionContext*, const ConfigT&)
  • OpType(ExecutionContext<DT>*, const ConfigT&)
  • OpType(std::shared_ptr<ExecutionContext<DT>>, const ConfigT&)

Throws std::invalid_argument when a required dynamic cast fails.
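
The three-way constructor dispatch above can be sketched with if constexpr and std::is_constructible_v. The context types are simplified, hypothetical stand-ins, and adapting the raw pointer to a non-owning std::shared_ptr (no-op deleter) for the third shape is an assumption, since this page does not say how ownership is handled.

```cpp
#include <memory>
#include <stdexcept>
#include <type_traits>

enum class DeviceType { Cpu, Cuda };

class IExecutionContext {
public:
    virtual ~IExecutionContext() = default;
};

template <DeviceType DT>
class ExecutionContext : public IExecutionContext {};

template <typename OpType, DeviceType DT, typename ConfigT>
std::shared_ptr<OpType> makeOpInstance(IExecutionContext* ctx, const ConfigT& cfg) {
    using Ctx = ExecutionContext<DT>;
    if constexpr (std::is_constructible_v<OpType, IExecutionContext*, const ConfigT&>) {
        // Shape 1: the op accepts the type-erased context directly.
        return std::make_shared<OpType>(ctx, cfg);
    } else {
        auto* typed = dynamic_cast<Ctx*>(ctx);
        if (typed == nullptr) {
            throw std::invalid_argument("makeOpInstance: execution context does not match the requested device");
        }
        if constexpr (std::is_constructible_v<OpType, Ctx*, const ConfigT&>) {
            // Shape 2: the op wants a raw typed context.
            return std::make_shared<OpType>(typed, cfg);
        } else {
            static_assert(std::is_constructible_v<OpType, std::shared_ptr<Ctx>, const ConfigT&>,
                          "OpType has none of the supported constructor shapes");
            // Shape 3: the op wants a shared_ptr; the no-op deleter leaves ownership with the caller (assumption).
            return std::make_shared<OpType>(std::shared_ptr<Ctx>(typed, [](Ctx*) {}), cfg);
        }
    }
}
```

Because each branch is discarded at compile time unless selected, an op only needs to provide one of the three constructors.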


◆ operationTypeToString()

std::string_view Mila::Dnn::Compute::operationTypeToString ( OperationType op)
export

◆ registerBinaryOpType()

template<DeviceType TDataType, typename OpType, Dnn::TensorDataType TA, Dnn::TensorDataType TB, Dnn::TensorDataType TP = TA>
void Mila::Dnn::Compute::registerBinaryOpType ( const std::string & opName)
export

Register a binary operation type with OperationRegistry using a common factory pattern.

Template parameter ordering:

  • TDataType : DeviceType
  • OpType : Concrete operation class (must define using ConfigType = ...)
  • TA : Input A precision
  • TB : Input B precision
  • TP : Compute/output precision (defaults to TA)

◆ registerPairedOpType()

template<DeviceType TDataType, typename OpType, Dnn::TensorDataType TA, Dnn::TensorDataType TB = TA, Dnn::TensorDataType TP = TA>
void Mila::Dnn::Compute::registerPairedOpType ( const std::string & opName)
export

Register a paired operation type with OperationRegistry using a common factory pattern.

Template parameter ordering:

  • TDataType : DeviceType
  • OpType : Concrete operation class (must define using ConfigType = ...)
  • TA : Input A precision
  • TB : Input B precision (defaults to TA)
  • TP : Compute/output precision (defaults to TA)

◆ registerUnaryOpType()

template<DeviceType TDataType, typename OpType, Dnn::TensorDataType TInput, Dnn::TensorDataType TPrecision = TInput>
void Mila::Dnn::Compute::registerUnaryOpType ( std::string_view op_name)
export

Register a unary operation type with OperationRegistry using a common factory pattern.

Template parameter ordering and names:

  • TDataType : DeviceType
  • OpType : Concrete operation class (must define using ConfigType = ...)
  • TInput : Abstract input tensor precision
  • TPrecision: Compute/output precision (defaults to TInput)

The factory lambda casts the ComponentConfig to OpType::ConfigType and forwards the IExecutionContext* together with the concrete config to the operation constructor via makeOpInstance.
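
The registration flow described above, in which a factory lambda is stored under a name and downcasts a generic config to OpType::ConfigType before constructing the op, can be sketched with a map of std::function factories. Everything here is a hypothetical simplification: the registry is keyed by name only (the templated helpers on this page suggest Mila's is also keyed by device and data type), ComponentConfig is a bare polymorphic base, IOperation is invented, and ops take only a config.

```cpp
#include <functional>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <string_view>
#include <utility>

// Polymorphic base config; each concrete op defines a ConfigType deriving from it.
struct ComponentConfig {
    virtual ~ComponentConfig() = default;
};

struct IOperation {
    virtual ~IOperation() = default;
    virtual std::string name() const = 0;
};

class OperationRegistry {
public:
    using Factory = std::function<std::shared_ptr<IOperation>(const ComponentConfig&)>;

    static OperationRegistry& instance() {
        static OperationRegistry registry; // function-local static: initialized once, thread-safe
        return registry;
    }

    void registerFactory(std::string_view name, Factory factory) {
        factories_[std::string(name)] = std::move(factory);
    }

    std::shared_ptr<IOperation> create(const std::string& name, const ComponentConfig& cfg) const {
        auto it = factories_.find(name);
        if (it == factories_.end()) throw std::invalid_argument("unregistered operation: " + name);
        return it->second(cfg);
    }

    bool isRegistered(const std::string& name) const { return factories_.count(name) != 0; }

private:
    std::map<std::string, Factory> factories_;
};

// Common factory pattern: the lambda downcasts the generic config to OpType::ConfigType.
template <typename OpType>
void registerUnaryOpType(std::string_view op_name) {
    OperationRegistry::instance().registerFactory(
        op_name, [](const ComponentConfig& cfg) -> std::shared_ptr<IOperation> {
            // A reference dynamic_cast throws std::bad_cast if the config has the wrong type.
            const auto& typed = dynamic_cast<const typename OpType::ConfigType&>(cfg);
            return std::make_shared<OpType>(typed);
        });
}
```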


◆ toDeviceType()

DeviceType Mila::Dnn::Compute::toDeviceType ( std::string_view device_type)
export

Converts a string to the corresponding DeviceType.

Performs case-insensitive matching to convert device type strings to the corresponding enum value.

Parameters
device_type The string representation of the device type.
Returns
DeviceType The corresponding device type enum value.
Exceptions
std::invalid_argument If the string does not represent a valid device type. Valid options are: "CPU", "CUDA", "AUTO".
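
A standalone sketch of case-insensitive parsing and its inverse, limited to the two names spelled out on this page ("CPU" and "CUDA"). "AUTO" is also listed as accepted input, but its mapping is not documented, so the sketch leaves it out; the enum is a trimmed copy, and this is not Mila's source.

```cpp
#include <algorithm>
#include <cctype>
#include <stdexcept>
#include <string>
#include <string_view>

enum class DeviceType { Cpu, Cuda };

std::string deviceTypeToString(DeviceType device_type) {
    switch (device_type) {
        case DeviceType::Cpu:  return "CPU";
        case DeviceType::Cuda: return "CUDA";
    }
    throw std::invalid_argument("deviceTypeToString: invalid device type");
}

DeviceType toDeviceType(std::string_view device_type) {
    std::string upper(device_type);
    // Convert through unsigned char: passing a negative char to std::toupper is undefined behavior.
    std::transform(upper.begin(), upper.end(), upper.begin(),
                   [](unsigned char c) { return static_cast<char>(std::toupper(c)); });
    if (upper == "CPU") return DeviceType::Cpu;
    if (upper == "CUDA") return DeviceType::Cuda;
    throw std::invalid_argument("toDeviceType: unknown device type '" + std::string(device_type) + "'");
}
```

Normalizing to upper case once and comparing against canonical spellings keeps the two functions round-trippable: deviceTypeToString(toDeviceType(s)) yields the canonical name for any accepted casing of s.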

◆ validateExecutionContext_()

template<DeviceType TDeviceType>
ExecutionContext< TDeviceType > * Mila::Dnn::Compute::validateExecutionContext_ ( IExecutionContext * context,
const std::string & op_name )
export
Examples
/__w/Mila/Mila/Mila/Src/Dnn/Compute/ExecutionContext.ixx.

Variable Documentation

◆ GELU_SCALING_FACTOR

float Mila::Dnn::Compute::GELU_SCALING_FACTOR = 0.7978845608f
constexpr

Scaling factor for GELU tanh approximation: sqrt(2/pi).

Used in the tanh approximation formula: GELU(x) ~= 0.5x(1 + tanh(sqrt(2/pi)(x + 0.044715x^3)))
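
Plugging the constant into the formula gives a scalar reference that is handy for checking kernel output. A minimal sketch; gelu_tanh is a hypothetical name, not a Mila function.

```cpp
#include <cmath>

constexpr float GELU_SCALING_FACTOR = 0.7978845608f; // sqrt(2/pi)

// Tanh approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
inline float gelu_tanh(float x) {
    const float cube = 0.044715f * x * x * x;
    return 0.5f * x * (1.0f + std::tanh(GELU_SCALING_FACTOR * (x + cube)));
}
```

For example, gelu_tanh(1.0f) is approximately 0.8412, and gelu_tanh(0.0f) is exactly 0.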