Validate and cast IExecutionContext to device-specific execution context.

Generic helper for operation constructors. Validates that the provided context matches the expected device type and casts it to the concrete type.
Template Parameters:
- `TDeviceType` — the expected device type
Parameters:
- `context` — the execution context to validate
- `op_name` — operation name used in error messages
Returns:
- The validated execution context, cast to the device-specific type
Exceptions:
- `std::invalid_argument` — if `context` is null or its device type doesn't match `TDeviceType`
/// Construct a CUDA GELU operation bound to @p context.
///
/// @p context is passed through validateExecutionContext<DeviceType::Cuda>, and the
/// result initializes cuda_context_. The helper is presumably the validating caster whose
/// contract is to throw std::invalid_argument on a null or non-CUDA context (confirm).
/// @p config initializes config_.
CudaGeluOp( IExecutionContext* context, const GeluConfig& config )
    : cuda_context_( validateExecutionContext<DeviceType::Cuda>( context, "CudaGeluOp" ) ),
      config_( config )
{
}
module;
#ifdef USE_CUDNN
#include <cudnn.h>
#endif
#include <cassert>
#include <string>
#include <stdexcept>
export import :Cpu;
#ifdef MILA_HAS_CUDA
export import :Cuda;
#endif
#ifdef MILA_HAS_METAL
export import :Metal;
#endif
#ifdef MILA_HAS_ROCM
export import :Rocm;
#endif
{
/// @brief Downcast a mutable IExecutionContext to ExecutionContext<TDeviceType>.
///
/// Unchecked fast-path cast. A null input yields nullptr. The device-type match is
/// verified only by an assert, so with NDEBUG defined a mismatched context is cast
/// anyway. The caller must guarantee the device type.
///
/// @tparam TDeviceType Device type the context is expected to belong to.
/// @param ctx Context to cast; may be null.
/// @return @p ctx as ExecutionContext<TDeviceType>*, or nullptr when @p ctx is null.
export template<DeviceType TDeviceType>
[[nodiscard]] ExecutionContext<TDeviceType>*
cast_context_( IExecutionContext* ctx ) noexcept
{
    using Concrete = ExecutionContext<TDeviceType>;

    if ( ctx == nullptr )
    {
        return nullptr;
    }

    // Debug-only guard; compiled out under NDEBUG.
    assert( ctx->getDeviceId().type == TDeviceType && "Device type mismatch in context cast" );

    Concrete* const concrete = static_cast<Concrete*>( ctx );
    return concrete;
}
/// @brief Downcast a const IExecutionContext to const ExecutionContext<TDeviceType>.
///
/// Const-qualified counterpart of the mutable overload. A null input yields nullptr.
/// The device-type match is checked only by an assert, which is disabled under NDEBUG.
///
/// @tparam TDeviceType Device type the context is expected to belong to.
/// @param ctx Context to cast; may be null.
/// @return @p ctx as const ExecutionContext<TDeviceType>*, or nullptr when @p ctx is null.
export template<DeviceType TDeviceType>
[[nodiscard]] const ExecutionContext<TDeviceType>*
cast_context_( const IExecutionContext* ctx ) noexcept
{
    using ConstConcrete = const ExecutionContext<TDeviceType>;

    if ( ctx == nullptr )
    {
        return nullptr;
    }

    // Debug-only guard; compiled out under NDEBUG.
    assert( ctx->getDeviceId().type == TDeviceType && "Device type mismatch in context cast" );

    ConstConcrete* const concrete = static_cast<ConstConcrete*>( ctx );
    return concrete;
}
export template<DeviceType TDeviceType>
IExecutionContext* context,
const std::string& op_name )
{
if ( !context ) {
throw std::invalid_argument( "{} requires a non-null execution context" );
}
if ( context->getDeviceId().type != TDeviceType ) {
throw std::invalid_argument( "{} requires {} execution context, got {}" );
}
return static_cast<ExecutionContext<TDeviceType>*>(context);
}
}
ExecutionContext< TDeviceType > * cast_context_(IExecutionContext *ctx) noexcept
Definition ExecutionContext.ixx:58
ExecutionContext< TDeviceType > * validateExecutionContext_(IExecutionContext *context, const std::string &op_name)
Definition ExecutionContext.ixx:102