Skip to content

Commit cebad02

Browse files
committed
[UR][SYCL] Introduce UR api to set kernel args + launch in one call.
1 parent 3a0acb8 commit cebad02

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+2576
-82
lines changed

sycl/source/detail/scheduler/commands.cpp

Lines changed: 154 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -2303,12 +2303,11 @@ ur_mem_flags_t AccessModeToUr(access::mode AccessorMode) {
23032303
}
23042304
}
23052305

2306-
void SetArgBasedOnType(
2307-
const AdapterPtr &Adapter, ur_kernel_handle_t Kernel,
2306+
void GetUrArgsBasedOnType(
23082307
const std::shared_ptr<device_image_impl> &DeviceImageImpl,
23092308
const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
23102309
const ContextImplPtr &ContextImpl, detail::ArgDesc &Arg,
2311-
size_t NextTrueIndex) {
2310+
size_t NextTrueIndex, std::vector<ur_exp_kernel_arg_properties_t> &UrArgs) {
23122311
switch (Arg.MType) {
23132312
case kernel_param_kind_t::kind_dynamic_work_group_memory:
23142313
break;
@@ -2328,52 +2327,63 @@ void SetArgBasedOnType(
23282327
getMemAllocationFunc
23292328
? reinterpret_cast<ur_mem_handle_t>(getMemAllocationFunc(Req))
23302329
: nullptr;
2331-
ur_kernel_arg_mem_obj_properties_t MemObjData{};
2332-
MemObjData.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2333-
MemObjData.memoryAccess = AccessModeToUr(Req->MAccessMode);
2334-
Adapter->call<UrApiKind::urKernelSetArgMemObj>(Kernel, NextTrueIndex,
2335-
&MemObjData, MemArg);
2330+
ur_exp_kernel_arg_value_t Value = {};
2331+
Value.memObjTuple = {MemArg, AccessModeToUr(Req->MAccessMode)};
2332+
UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr,
2333+
UR_EXP_KERNEL_ARG_TYPE_MEM_OBJ,
2334+
static_cast<uint32_t>(NextTrueIndex), sizeof(MemArg),
2335+
Value});
23362336
break;
23372337
}
23382338
case kernel_param_kind_t::kind_std_layout: {
2339+
ur_exp_kernel_arg_type_t Type;
23392340
if (Arg.MPtr) {
2340-
Adapter->call<UrApiKind::urKernelSetArgValue>(
2341-
Kernel, NextTrueIndex, Arg.MSize, nullptr, Arg.MPtr);
2341+
Type = UR_EXP_KERNEL_ARG_TYPE_VALUE;
23422342
} else {
2343-
Adapter->call<UrApiKind::urKernelSetArgLocal>(Kernel, NextTrueIndex,
2344-
Arg.MSize, nullptr);
2343+
Type = UR_EXP_KERNEL_ARG_TYPE_LOCAL;
23452344
}
2345+
UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES,
2346+
nullptr,
2347+
Type,
2348+
static_cast<uint32_t>(NextTrueIndex),
2349+
static_cast<size_t>(Arg.MSize),
2350+
{Arg.MPtr}});
23462351

23472352
break;
23482353
}
23492354
case kernel_param_kind_t::kind_sampler: {
23502355
sampler *SamplerPtr = (sampler *)Arg.MPtr;
2351-
ur_sampler_handle_t Sampler =
2352-
(ur_sampler_handle_t)detail::getSyclObjImpl(*SamplerPtr)
2353-
->getOrCreateSampler(ContextImpl);
2354-
Adapter->call<UrApiKind::urKernelSetArgSampler>(Kernel, NextTrueIndex,
2355-
nullptr, Sampler);
2356+
ur_exp_kernel_arg_value_t Value = {};
2357+
Value.sampler = (ur_sampler_handle_t)detail::getSyclObjImpl(*SamplerPtr)
2358+
->getOrCreateSampler(ContextImpl);
2359+
UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr,
2360+
UR_EXP_KERNEL_ARG_TYPE_SAMPLER,
2361+
static_cast<uint32_t>(NextTrueIndex),
2362+
sizeof(ur_sampler_handle_t), Value});
23562363
break;
23572364
}
23582365
case kernel_param_kind_t::kind_pointer: {
2366+
void *Ptr = *static_cast<void *const *>(Arg.MPtr);
23592367
// We need to dereference this to get the actual USM allocation - that's the
23602368
// pointer UR is expecting.
2361-
const void *Ptr = *static_cast<const void *const *>(Arg.MPtr);
2362-
Adapter->call<UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
2363-
nullptr, Ptr);
2369+
UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES,
2370+
nullptr,
2371+
UR_EXP_KERNEL_ARG_TYPE_POINTER,
2372+
static_cast<uint32_t>(NextTrueIndex),
2373+
sizeof(Ptr),
2374+
{Ptr}});
23642375
break;
23652376
}
23662377
case kernel_param_kind_t::kind_specialization_constants_buffer: {
23672378
assert(DeviceImageImpl != nullptr);
23682379
ur_mem_handle_t SpecConstsBuffer =
23692380
DeviceImageImpl->get_spec_const_buffer_ref();
2370-
2371-
ur_kernel_arg_mem_obj_properties_t MemObjProps{};
2372-
MemObjProps.pNext = nullptr;
2373-
MemObjProps.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2374-
MemObjProps.memoryAccess = UR_MEM_FLAG_READ_ONLY;
2375-
Adapter->call<UrApiKind::urKernelSetArgMemObj>(
2376-
Kernel, NextTrueIndex, &MemObjProps, SpecConstsBuffer);
2381+
ur_exp_kernel_arg_value_t Value = {};
2382+
Value.memObjTuple = {SpecConstsBuffer, UR_MEM_FLAG_READ_ONLY};
2383+
UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr,
2384+
UR_EXP_KERNEL_ARG_TYPE_MEM_OBJ,
2385+
static_cast<uint32_t>(NextTrueIndex),
2386+
sizeof(SpecConstsBuffer), Value});
23772387
break;
23782388
}
23792389
case kernel_param_kind_t::kind_invalid:
@@ -2408,22 +2418,33 @@ static ur_result_t SetKernelParamsAndLaunch(
24082418
: Empty);
24092419
}
24102420

2421+
std::vector<ur_exp_kernel_arg_properties_t> UrArgs;
2422+
UrArgs.reserve(Args.size());
2423+
24112424
if (KernelFuncPtr && !KernelHasSpecialCaptures) {
2412-
auto setFunc = [&Adapter, Kernel,
2425+
auto setFunc = [&UrArgs,
24132426
KernelFuncPtr](const detail::kernel_param_desc_t &ParamDesc,
24142427
size_t NextTrueIndex) {
24152428
const void *ArgPtr = (const char *)KernelFuncPtr + ParamDesc.offset;
24162429
switch (ParamDesc.kind) {
24172430
case kernel_param_kind_t::kind_std_layout: {
24182431
int Size = ParamDesc.info;
2419-
Adapter->call<UrApiKind::urKernelSetArgValue>(Kernel, NextTrueIndex,
2420-
Size, nullptr, ArgPtr);
2432+
UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES,
2433+
nullptr,
2434+
UR_EXP_KERNEL_ARG_TYPE_VALUE,
2435+
static_cast<uint32_t>(NextTrueIndex),
2436+
static_cast<size_t>(Size),
2437+
{ArgPtr}});
24212438
break;
24222439
}
24232440
case kernel_param_kind_t::kind_pointer: {
24242441
const void *Ptr = *static_cast<const void *const *>(ArgPtr);
2425-
Adapter->call<UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
2426-
nullptr, Ptr);
2442+
UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES,
2443+
nullptr,
2444+
UR_EXP_KERNEL_ARG_TYPE_POINTER,
2445+
static_cast<uint32_t>(NextTrueIndex),
2446+
sizeof(Ptr),
2447+
{Ptr}});
24272448
break;
24282449
}
24292450
default:
@@ -2433,10 +2454,11 @@ static ur_result_t SetKernelParamsAndLaunch(
24332454
applyFuncOnFilteredArgs(EliminatedArgMask, KernelNumArgs,
24342455
KernelParamDescGetter, setFunc);
24352456
} else {
2436-
auto setFunc = [&Adapter, Kernel, &DeviceImageImpl, &getMemAllocationFunc,
2437-
&Queue](detail::ArgDesc &Arg, size_t NextTrueIndex) {
2438-
SetArgBasedOnType(Adapter, Kernel, DeviceImageImpl, getMemAllocationFunc,
2439-
Queue.getContextImplPtr(), Arg, NextTrueIndex);
2457+
auto setFunc = [&DeviceImageImpl, &getMemAllocationFunc, &Queue,
2458+
&UrArgs](detail::ArgDesc &Arg, size_t NextTrueIndex) {
2459+
GetUrArgsBasedOnType(DeviceImageImpl, getMemAllocationFunc,
2460+
Queue.getContextImplPtr(), Arg, NextTrueIndex,
2461+
UrArgs);
24402462
};
24412463
applyFuncOnFilteredArgs(EliminatedArgMask, Args, setFunc);
24422464
}
@@ -2449,8 +2471,12 @@ static ur_result_t SetKernelParamsAndLaunch(
24492471
// CUDA-style local memory setting. Note that we may have -1 as a position,
24502472
// this indicates the buffer is actually unused and was elided.
24512473
if (ImplicitLocalArg.has_value() && ImplicitLocalArg.value() != -1) {
2452-
Adapter->call<UrApiKind::urKernelSetArgLocal>(
2453-
Kernel, ImplicitLocalArg.value(), WorkGroupMemorySize, nullptr);
2474+
UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES,
2475+
nullptr,
2476+
UR_EXP_KERNEL_ARG_TYPE_LOCAL,
2477+
static_cast<uint32_t>(ImplicitLocalArg.value()),
2478+
WorkGroupMemorySize,
2479+
{nullptr}});
24542480
}
24552481

24562482
adjustNDRangePerKernel(NDRDesc, Kernel, Queue.getDeviceImpl());
@@ -2508,20 +2534,103 @@ static ur_result_t SetKernelParamsAndLaunch(
25082534
{{WorkGroupMemorySize}}});
25092535
}
25102536
ur_event_handle_t UREvent = nullptr;
2511-
ur_result_t Error = Adapter->call_nocheck<UrApiKind::urEnqueueKernelLaunch>(
2512-
Queue.getHandleRef(), Kernel, NDRDesc.Dims,
2513-
HasOffset ? &NDRDesc.GlobalOffset[0] : nullptr, &NDRDesc.GlobalSize[0],
2514-
LocalSize, property_list.size(),
2515-
property_list.empty() ? nullptr : property_list.data(), RawEvents.size(),
2516-
RawEvents.empty() ? nullptr : &RawEvents[0],
2517-
OutEventImpl ? &UREvent : nullptr);
2537+
ur_result_t Error =
2538+
Adapter->call_nocheck<UrApiKind::urEnqueueKernelLaunchWithArgsExp>(
2539+
Queue.getHandleRef(), Kernel,
2540+
HasOffset ? &NDRDesc.GlobalOffset[0] : nullptr,
2541+
&NDRDesc.GlobalSize[0], LocalSize, UrArgs.size(), UrArgs.data(),
2542+
property_list.size(),
2543+
property_list.empty() ? nullptr : property_list.data(),
2544+
RawEvents.size(), RawEvents.empty() ? nullptr : &RawEvents[0],
2545+
OutEventImpl ? &UREvent : nullptr);
25182546
if (Error == UR_RESULT_SUCCESS && OutEventImpl) {
25192547
OutEventImpl->setHandle(UREvent);
25202548
}
25212549

25222550
return Error;
25232551
}
25242552

2553+
void SetArgBasedOnType(
2554+
const AdapterPtr &Adapter, ur_kernel_handle_t Kernel,
2555+
const std::shared_ptr<device_image_impl> &DeviceImageImpl,
2556+
const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
2557+
const ContextImplPtr &ContextImpl, detail::ArgDesc &Arg,
2558+
size_t NextTrueIndex) {
2559+
switch (Arg.MType) {
2560+
case kernel_param_kind_t::kind_dynamic_work_group_memory:
2561+
break;
2562+
case kernel_param_kind_t::kind_work_group_memory:
2563+
break;
2564+
case kernel_param_kind_t::kind_stream:
2565+
break;
2566+
case kernel_param_kind_t::kind_dynamic_accessor:
2567+
case kernel_param_kind_t::kind_accessor: {
2568+
Requirement *Req = (Requirement *)(Arg.MPtr);
2569+
2570+
// getMemAllocationFunc is nullptr when there are no requirements. However,
2571+
// we may pass default constructed accessors to a command, which don't add
2572+
// requirements. In such case, getMemAllocationFunc is nullptr, but it's a
2573+
// valid case, so we need to properly handle it.
2574+
ur_mem_handle_t MemArg =
2575+
getMemAllocationFunc
2576+
? reinterpret_cast<ur_mem_handle_t>(getMemAllocationFunc(Req))
2577+
: nullptr;
2578+
ur_kernel_arg_mem_obj_properties_t MemObjData{};
2579+
MemObjData.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2580+
MemObjData.memoryAccess = AccessModeToUr(Req->MAccessMode);
2581+
Adapter->call<UrApiKind::urKernelSetArgMemObj>(Kernel, NextTrueIndex,
2582+
&MemObjData, MemArg);
2583+
break;
2584+
}
2585+
case kernel_param_kind_t::kind_std_layout: {
2586+
if (Arg.MPtr) {
2587+
Adapter->call<UrApiKind::urKernelSetArgValue>(
2588+
Kernel, NextTrueIndex, Arg.MSize, nullptr, Arg.MPtr);
2589+
} else {
2590+
Adapter->call<UrApiKind::urKernelSetArgLocal>(Kernel, NextTrueIndex,
2591+
Arg.MSize, nullptr);
2592+
}
2593+
2594+
break;
2595+
}
2596+
case kernel_param_kind_t::kind_sampler: {
2597+
sampler *SamplerPtr = (sampler *)Arg.MPtr;
2598+
ur_sampler_handle_t Sampler =
2599+
(ur_sampler_handle_t)detail::getSyclObjImpl(*SamplerPtr)
2600+
->getOrCreateSampler(ContextImpl);
2601+
Adapter->call<UrApiKind::urKernelSetArgSampler>(Kernel, NextTrueIndex,
2602+
nullptr, Sampler);
2603+
break;
2604+
}
2605+
case kernel_param_kind_t::kind_pointer: {
2606+
// We need to dereference this to get the actual USM allocation - that's the
2607+
// pointer UR is expecting.
2608+
const void *Ptr = *static_cast<const void *const *>(Arg.MPtr);
2609+
Adapter->call<UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
2610+
nullptr, Ptr);
2611+
break;
2612+
}
2613+
case kernel_param_kind_t::kind_specialization_constants_buffer: {
2614+
assert(DeviceImageImpl != nullptr);
2615+
ur_mem_handle_t SpecConstsBuffer =
2616+
DeviceImageImpl->get_spec_const_buffer_ref();
2617+
2618+
ur_kernel_arg_mem_obj_properties_t MemObjProps{};
2619+
MemObjProps.pNext = nullptr;
2620+
MemObjProps.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2621+
MemObjProps.memoryAccess = UR_MEM_FLAG_READ_ONLY;
2622+
Adapter->call<UrApiKind::urKernelSetArgMemObj>(
2623+
Kernel, NextTrueIndex, &MemObjProps, SpecConstsBuffer);
2624+
break;
2625+
}
2626+
case kernel_param_kind_t::kind_invalid:
2627+
throw sycl::exception(sycl::make_error_code(sycl::errc::runtime),
2628+
"Invalid kernel param kind " +
2629+
codeToString(UR_RESULT_ERROR_INVALID_VALUE));
2630+
break;
2631+
}
2632+
}
2633+
25252634
static std::tuple<ur_kernel_handle_t, std::shared_ptr<device_image_impl>,
25262635
const KernelArgMask *>
25272636
getCGKernelInfo(const CGExecKernel &CommandGroup, ContextImplPtr ContextImpl,

sycl/tools/xpti_helpers/usm_analyzer.hpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,11 @@ class USMAnalyzer {
254254
handleKernelSetArgPointer(
255255
static_cast<ur_kernel_set_arg_pointer_params_t *>(Data->args_data));
256256
return;
257+
case UR_FUNCTION_ENQUEUE_KERNEL_LAUNCH_WITH_ARGS_EXP:
258+
handleEnqueueKernelLaunchWithArgsExp(
259+
static_cast<ur_enqueue_kernel_launch_with_args_exp_params_t *>(
260+
Data->args_data));
261+
return;
257262
default:
258263
return;
259264
}
@@ -421,4 +426,18 @@ class USMAnalyzer {
421426
"kernel parameter with index = " + std::to_string(*Params->pargIndex),
422427
Ptr, 0 /*no data how it will be used in kernel*/, "kernel");
423428
}
429+
430+
static void handleEnqueueKernelLaunchWithArgsExp(
431+
const ur_enqueue_kernel_launch_with_args_exp_params_t *Params) {
432+
// Search for pointer args and validate the pointers
433+
for (uint32_t i = 0; i < *Params->pnumArgs; i++) {
434+
if ((*Params->ppArgs)[i].type == UR_EXP_KERNEL_ARG_TYPE_POINTER) {
435+
void *Ptr = (const_cast<void *>((*Params->ppArgs)[i].arg.pointer));
436+
CheckPointerValidness("kernel parameter with index = " +
437+
std::to_string((*Params->ppArgs)[i].index),
438+
Ptr, 0 /*no data how it will be used in kernel*/,
439+
"kernel");
440+
}
441+
}
442+
}
424443
};

0 commit comments

Comments
 (0)