Mila 0.13.48
Deep Neural Network Library
Loading...
Searching...
No Matches
/__w/Mila/Mila/Mila/Src/Dnn/Compute/ExecutionContext.ixx

Validate and cast IExecutionContext to device-specific execution context.

Generic helper for operation constructors. Validates that the provided context is non-null and matches the expected device type, then casts it to the concrete device-specific type.

Template Parameters
TDeviceType — The expected device type
Parameters
context — The execution context to validate
op_name — Operation name used in error messages
Returns
The validated execution context, cast to ExecutionContext<TDeviceType>*
Exceptions
std::invalid_argument — if context is null or its device type does not match TDeviceType

// Example: an operation constructor validates the type-erased context it receives
// and stores the device-specific (CUDA) context, rejecting anything else up front.
// NOTE(review): the helper defined in ExecutionContext.ixx is spelled
// validateExecutionContext_ (trailing underscore); confirm that the un-suffixed
// validateExecutionContext used here resolves to it or to another existing overload.
CudaGeluOp(IExecutionContext* context, const GeluConfig& config) : cuda_context_(validateExecutionContext<DeviceType::Cuda>(context, "CudaGeluOp")) , config_(config) {}

module;
//#include <cuda_runtime.h>
//#include <cublasLt.h>
#ifdef USE_CUDNN
#include <cudnn.h>
#endif
#include <cassert>
//#include <memory>
#include <string>
#include <stdexcept>
//#include <format>
export import :Cpu;
#ifdef MILA_HAS_CUDA
export import :Cuda;
#endif
#ifdef MILA_HAS_METAL
export import :Metal;
#endif
#ifdef MILA_HAS_ROCM
export import :Rocm;
#endif
{
/**
 * @brief Downcast a type-erased execution context to its device-specific type.
 *
 * The device-type match is verified only by an assert, which is compiled out
 * when NDEBUG is defined. Callers must therefore already know that @p ctx
 * belongs to TDeviceType.
 *
 * @tparam TDeviceType Device type the context is expected to belong to.
 * @param ctx Context to downcast; may be null.
 * @return @p ctx as ExecutionContext<TDeviceType>*, or nullptr when @p ctx is null.
 */
export template<DeviceType TDeviceType>
[[nodiscard]] ExecutionContext<TDeviceType>* cast_context_( IExecutionContext* ctx ) noexcept
{
    using Target = ExecutionContext<TDeviceType>;

    if ( ctx ) {
        // Debug-only sanity check of the caller's assumption.
        assert( ctx->getDeviceId().type == TDeviceType && "Device type mismatch in context cast" );
        return static_cast<Target*>( ctx );
    }

    return nullptr;
}
/**
 * @brief Const overload: downcast a read-only type-erased context to its device-specific type.
 *
 * Like the non-const overload, this only asserts the device-type match, and the
 * assert is compiled out when NDEBUG is defined.
 *
 * @tparam TDeviceType Device type the context is expected to belong to.
 * @param ctx Read-only context to downcast; may be null.
 * @return @p ctx as const ExecutionContext<TDeviceType>*, or nullptr when @p ctx is null.
 */
export template<DeviceType TDeviceType>
[[nodiscard]] const ExecutionContext<TDeviceType>* cast_context_( const IExecutionContext* ctx ) noexcept
{
    using ConstTarget = const ExecutionContext<TDeviceType>;

    if ( ctx ) {
        // Debug-only sanity check of the caller's assumption.
        assert( ctx->getDeviceId().type == TDeviceType && "Device type mismatch in context cast" );
        return static_cast<ConstTarget*>( ctx );
    }

    return nullptr;
}
/**
 * @brief Validate and cast an IExecutionContext to a device-specific execution context.
 *
 * Generic helper for operation constructors. Unlike cast_context_, which only
 * asserts, this check always runs and reports failures by throwing.
 *
 * @tparam TDeviceType The expected device type.
 * @param context The execution context to validate; null is rejected.
 * @param op_name Operation name, used as the prefix of error messages.
 * @return @p context cast to ExecutionContext<TDeviceType>*; never null.
 * @throws std::invalid_argument if @p context is null or its device type is not TDeviceType.
 */
export template<DeviceType TDeviceType>
ExecutionContext<TDeviceType>* validateExecutionContext_(
    IExecutionContext* context,
    const std::string& op_name )
{
    if ( !context ) {
        throw std::invalid_argument( op_name + " requires a non-null execution context" );
    }

    const auto actual_type = context->getDeviceId().type;

    if ( actual_type != TDeviceType ) {
        // Built with std::string += instead of std::format (<format> is not included
        // in this module). += accepts std::string, const char* and std::string_view
        // alike, so this does not depend on deviceTypeToString's exact return type.
        // NOTE(review): deviceTypeToString is assumed to be visible from the
        // device-type module imported by this file; confirm.
        std::string message = op_name;
        message += " requires ";
        message += deviceTypeToString( TDeviceType );
        message += " execution context, got ";
        message += deviceTypeToString( actual_type );

        throw std::invalid_argument( message );
    }

    return static_cast<ExecutionContext<TDeviceType>*>( context );
}
}
Definition: Device.ixx:15
ExecutionContext<TDeviceType>* cast_context_(IExecutionContext* ctx) noexcept
Definition: ExecutionContext.ixx:58
ExecutionContext<TDeviceType>* validateExecutionContext_(IExecutionContext* context, const std::string& op_name)
Definition: ExecutionContext.ixx:102