Mila 0.13.48
Deep Neural Network Library
Mila::Dnn::Compute::Cuda::RandomOps Struct Reference export

Static Public Member Functions

template<TensorDataType TDataType, typename TMemoryResource>
requires isValidTensor<TDataType, TMemoryResource> && TensorDataTypeTraits<TDataType>::is_float_type
static void fill_normal (Tensor< TDataType, TMemoryResource > &tensor, float mean, float stddev, IExecutionContext *exec_context=nullptr)
 Fill a float tensor with values drawn from N(mean, stddev^2) using cuRAND.
template<TensorDataType TDataType, typename TMemoryResource>
requires isValidTensor<TDataType, TMemoryResource> && TensorDataTypeTraits<TDataType>::is_float_type
static void fill_uniform (Tensor< TDataType, TMemoryResource > &tensor, float min_val, float max_val, IExecutionContext *exec_context=nullptr)
 Fill a float tensor with uniform values in [min_val, max_val) using cuRAND.

Static Private Member Functions

static curandGenerator_t make_temp_generator_ (cudaStream_t stream)

Member Function Documentation

◆ fill_normal()

template<TensorDataType TDataType, typename TMemoryResource>
requires isValidTensor<TDataType, TMemoryResource> && TensorDataTypeTraits<TDataType>::is_float_type
void Mila::Dnn::Compute::Cuda::RandomOps::fill_normal ( Tensor< TDataType, TMemoryResource > & tensor,
float mean,
float stddev,
IExecutionContext * exec_context = nullptr )
inline static

Fill a float tensor with values drawn from N(mean, stddev^2) using cuRAND.

When exec_context is provided, the cached generator from CudaExecutionContext is reused. Without a context, a temporary generator is created and destroyed per call, seeded from Core::RandomGenerator.

curandGenerateNormal requires an even element count, because the Box-Muller transform produces samples in pairs. If the tensor has an odd element count n, a temporary device buffer of n+1 elements is allocated, the full even count is generated into it, and the first n elements are copied to the tensor. Even-sized tensors are generated in place and incur no extra allocation or copy.
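
The odd-count handling described above can be illustrated with a host-only sketch. This is not the library's implementation: `even_count`, `generate_normal_even`, and `fill_normal_padded` are hypothetical names, `std::vector` stands in for device memory, and `std::normal_distribution` stands in for cuRAND. Only the control flow (direct generation for even n, scratch buffer of n+1 for odd n) mirrors the description.

```cpp
#include <algorithm>
#include <cstddef>
#include <random>
#include <stdexcept>
#include <vector>

// Round an element count up to the next even number (hypothetical helper).
inline std::size_t even_count(std::size_t n) {
    return n + (n % 2);
}

// Host-side stand-in for a generator that only accepts an even element
// count, as curandGenerateNormal does for pseudorandom generators.
inline void generate_normal_even(float* out, std::size_t count,
                                 float mean, float stddev, std::mt19937& rng) {
    if (count % 2 != 0) {
        throw std::invalid_argument("count must be even");
    }
    std::normal_distribution<float> dist(mean, stddev);
    for (std::size_t i = 0; i < count; ++i) {
        out[i] = dist(rng);
    }
}

// Fill dst with dst.size() normal samples. A scratch buffer is used only
// when the size is odd. Returns true if the scratch path was taken.
inline bool fill_normal_padded(std::vector<float>& dst, float mean,
                               float stddev, std::mt19937& rng) {
    const std::size_t n = dst.size();
    if (n == 0) {
        return false;  // nothing to generate
    }
    if (n % 2 == 0) {
        // Even count: generate straight into the destination.
        generate_normal_even(dst.data(), n, mean, stddev, rng);
        return false;
    }
    // Odd count: generate n+1 samples into scratch, keep the first n.
    std::vector<float> scratch(even_count(n));
    generate_normal_even(scratch.data(), scratch.size(), mean, stddev, rng);
    std::copy_n(scratch.begin(), n, dst.begin());
    return true;
}
```

In a device implementation the scratch buffer and the copy would be a device allocation and a device-to-device copy on the same stream as the generator; the sketch only shows why the extra cost applies to odd sizes alone.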

Template Parameters
TDataType        Floating-point tensor data type.
TMemoryResource  CUDA memory resource type.
Parameters
tensor        Destination tensor (pre-allocated).
mean          Mean of the normal distribution.
stddev        Standard deviation of the normal distribution.
exec_context  Optional execution context for stream and generator reuse (borrowed, not owned).
Note
Only native FP32 tensors are currently supported. FP16 and BF16 would require generating into a temporary float buffer followed by a conversion pass, which is not yet implemented.
Exceptions
std::runtime_error  On cuRAND or CUDA failure.

◆ fill_uniform()

template<TensorDataType TDataType, typename TMemoryResource>
requires isValidTensor<TDataType, TMemoryResource> && TensorDataTypeTraits<TDataType>::is_float_type
void Mila::Dnn::Compute::Cuda::RandomOps::fill_uniform ( Tensor< TDataType, TMemoryResource > & tensor,
float min_val,
float max_val,
IExecutionContext * exec_context = nullptr )
inline static

Fill a float tensor with uniform values in [min_val, max_val) using cuRAND.

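cuRAND generates unit-interval samples, which a device kernel scales and shifts to the [min_val, max_val) range. Note that the cuRAND documentation specifies the output of curandGenerateUniform as (0, 1], excluding 0.0 and including 1.0, rather than [0, 1). When exec_context is provided, the cached generator is reused. Without a context, a temporary generator is created and destroyed per call.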
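
The scale-and-shift step can be sketched on the host as a per-element affine map. This is an illustration under stated assumptions, not the library's kernel: `scale_shift` and `fill_uniform_host` are hypothetical names, and `std::uniform_real_distribution` (which yields values in [0, 1)) stands in for the cuRAND generator.

```cpp
#include <cstddef>
#include <random>
#include <vector>

// Map a unit-interval sample u onto the target range:
// u = 0 maps to min_val, and u approaching 1 approaches max_val.
inline float scale_shift(float u, float min_val, float max_val) {
    return min_val + u * (max_val - min_val);
}

// Host-side stand-in for "generate unit samples, then run the scaling kernel".
inline void fill_uniform_host(std::vector<float>& dst, float min_val,
                              float max_val, std::mt19937& rng) {
    std::uniform_real_distribution<float> unit(0.0f, 1.0f);
    for (float& x : dst) {
        x = scale_shift(unit(rng), min_val, max_val);
    }
}
```

On the device, the same expression would typically run once per element in a one-dimensional kernel launched on the generator's stream. Because of single-precision rounding, a sample very close to 1 can round up to max_val after scaling, so callers that rely on a strictly exclusive upper bound may want to clamp.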

Template Parameters
TDataType        Floating-point tensor data type.
TMemoryResource  CUDA memory resource type.
Parameters
tensor        Destination tensor (pre-allocated).
min_val       Lower bound of the uniform range (inclusive).
max_val       Upper bound of the uniform range (exclusive).
exec_context  Optional execution context for stream and generator reuse (borrowed, not owned).
Note
Only FP32 native tensors are currently supported.
Exceptions
std::runtime_error  On cuRAND or CUDA failure.

◆ make_temp_generator_()

curandGenerator_t Mila::Dnn::Compute::Cuda::RandomOps::make_temp_generator_ ( cudaStream_t stream)
inline static private

The documentation for this struct was generated from the following file: