Skip to content

[UR][SYCL] Introduce UR api to set kernel args + launch in one call. #18764

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 21 commits into
base: sycl
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
cebad02
[UR][SYCL] Introduce UR api to set kernel args + launch in one call.
aarongreig Jun 3, 2025
eff5f5e
Merge branch 'sycl' into aaron/enqueueKernelWithArgs
aarongreig Jun 10, 2025
82176ff
Fix tsan launchinfo
aarongreig Jun 10, 2025
9ccdfcd
Fix unit tests.
aarongreig Jun 10, 2025
2762c70
Fix native cpu + some cuda/hip fails.
aarongreig Jun 12, 2025
e25f390
Mechanically replace urEnqueueKernelLaunch in e2e tests.
aarongreig Jun 12, 2025
ffa9a11
Merge branch 'sycl' into aaron/enqueueKernelWithArgs
aarongreig Jun 12, 2025
3c87170
Fix a couple of tests and an oversight in the sanitizer layer.
aarongreig Jun 13, 2025
3c38b26
Fix fallthrough.
aarongreig Jun 13, 2025
d118cec
Remove missed SetArg calls.
aarongreig Jun 13, 2025
28d539e
Fix a test and move asan kernel arg handling to helpers.
aarongreig Jun 16, 2025
1fda654
Add missing locks to sanitizer launch with args.
aarongreig Jun 17, 2025
26976b6
Set kernel args in sanitizer layers rather than passing them through.
aarongreig Jun 18, 2025
70aa909
Merge branch 'sycl' into aaron/enqueueKernelWithArgs
aarongreig Jun 18, 2025
cda9d00
Fix printing by adding separate value union member.
aarongreig Jun 18, 2025
29cbd08
Fix global size validation and add some negative tests.
aarongreig Jun 18, 2025
9153d56
Expand testing + validation.
aarongreig Jun 20, 2025
74e19e1
Add kernel arg storage to queue_impl rather than re-allocate for ever…
aarongreig Jun 20, 2025
2879552
Spec cleanup, add rst file
aarongreig Jun 20, 2025
4744151
Fix unittest build.
aarongreig Jun 20, 2025
bafcebd
Minor l0 fix: handle sampler args with SetArgValueHelper.
aarongreig Jun 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions sycl/source/detail/queue_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -855,6 +855,13 @@ void queue_impl::verifyProps(const property_list &Props) const {
CheckPropertiesWithData);
}

/// Hands out the queue's reusable kernel-argument buffer.
///
/// The buffer (the MKernelArgStorage member) is emptied and then asked to
/// hold at least \p size elements before a reference to it is returned, so
/// whatever it contained before this call is gone. No lock is taken here.
///
/// \param size number of argument descriptors the caller expects to append.
/// \return a reference to the member vector, now empty.
std::vector<ur_exp_kernel_arg_properties_t> &
queue_impl::getKernelArgStorage(uint32_t size) {
  auto &Storage = MKernelArgStorage;
  // Drop the previous kernel's descriptors first, then make sure there is
  // room for the upcoming ones.
  Storage.clear();
  Storage.reserve(size);
  return Storage;
}

} // namespace detail
} // namespace _V1
} // namespace sycl
9 changes: 9 additions & 0 deletions sycl/source/detail/queue_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,11 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
}
#endif

/// Clears MKernelArgStorage, reserves capacity for \p size elements, and
/// returns a reference to it. Not inherently thread safe.
///
/// Callers must bind the result to a reference: a by-value copy (e.g. plain
/// `auto`) is a separate vector, so elements pushed into it do not reuse
/// this storage.
std::vector<ur_exp_kernel_arg_properties_t> &
getKernelArgStorage(uint32_t size);

protected:
template <typename HandlerType = handler>
EventImplPtr insertHelperBarrier(const HandlerType &Handler) {
Expand Down Expand Up @@ -999,6 +1004,10 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
/// List of queues created for FPGA device from a single SYCL queue.
ur_queue_handle_t MQueue;

// Scratch vector of kernel argument descriptors, handed out by
// getKernelArgStorage(). To avoid re-allocating it every time a kernel is
// enqueued we keep this vector around and .clear()/.reserve() it for each
// kernel instead.
// NOTE(review): getKernelArgStorage() takes no lock before touching this
// member - confirm that callers serialize access.
std::vector<ur_exp_kernel_arg_properties_t> MKernelArgStorage;

// Access should be guarded with MMutex
struct DependencyTrackingItems {
// This event is employed for enhanced dependency tracking with in-order
Expand Down
199 changes: 152 additions & 47 deletions sycl/source/detail/scheduler/commands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2302,12 +2302,11 @@ ur_mem_flags_t AccessModeToUr(access::mode AccessorMode) {
}
}

void SetArgBasedOnType(
const AdapterPtr &Adapter, ur_kernel_handle_t Kernel,
void GetUrArgsBasedOnType(
const std::shared_ptr<device_image_impl> &DeviceImageImpl,
const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
const ContextImplPtr &ContextImpl, detail::ArgDesc &Arg,
size_t NextTrueIndex) {
size_t NextTrueIndex, std::vector<ur_exp_kernel_arg_properties_t> &UrArgs) {
switch (Arg.MType) {
case kernel_param_kind_t::kind_dynamic_work_group_memory:
break;
Expand All @@ -2327,52 +2326,61 @@ void SetArgBasedOnType(
getMemAllocationFunc
? reinterpret_cast<ur_mem_handle_t>(getMemAllocationFunc(Req))
: nullptr;
ur_kernel_arg_mem_obj_properties_t MemObjData{};
MemObjData.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
MemObjData.memoryAccess = AccessModeToUr(Req->MAccessMode);
Adapter->call<UrApiKind::urKernelSetArgMemObj>(Kernel, NextTrueIndex,
&MemObjData, MemArg);
ur_exp_kernel_arg_value_t Value = {};
Value.memObjTuple = {MemArg, AccessModeToUr(Req->MAccessMode)};
UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr,
UR_EXP_KERNEL_ARG_TYPE_MEM_OBJ,
static_cast<uint32_t>(NextTrueIndex), sizeof(MemArg),
Value});
break;
}
case kernel_param_kind_t::kind_std_layout: {
ur_exp_kernel_arg_type_t Type;
if (Arg.MPtr) {
Adapter->call<UrApiKind::urKernelSetArgValue>(
Kernel, NextTrueIndex, Arg.MSize, nullptr, Arg.MPtr);
Type = UR_EXP_KERNEL_ARG_TYPE_VALUE;
} else {
Adapter->call<UrApiKind::urKernelSetArgLocal>(Kernel, NextTrueIndex,
Arg.MSize, nullptr);
Type = UR_EXP_KERNEL_ARG_TYPE_LOCAL;
}
ur_exp_kernel_arg_value_t Value = {};
Value.value = {Arg.MPtr};
UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr,
Type, static_cast<uint32_t>(NextTrueIndex),
static_cast<size_t>(Arg.MSize), Value});

break;
}
case kernel_param_kind_t::kind_sampler: {
sampler *SamplerPtr = (sampler *)Arg.MPtr;
ur_sampler_handle_t Sampler =
(ur_sampler_handle_t)detail::getSyclObjImpl(*SamplerPtr)
->getOrCreateSampler(ContextImpl);
Adapter->call<UrApiKind::urKernelSetArgSampler>(Kernel, NextTrueIndex,
nullptr, Sampler);
ur_exp_kernel_arg_value_t Value = {};
Value.sampler = (ur_sampler_handle_t)detail::getSyclObjImpl(*SamplerPtr)
->getOrCreateSampler(ContextImpl);
UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr,
UR_EXP_KERNEL_ARG_TYPE_SAMPLER,
static_cast<uint32_t>(NextTrueIndex),
sizeof(ur_sampler_handle_t), Value});
break;
}
case kernel_param_kind_t::kind_pointer: {
// We need to de-rerence this to get the actual USM allocation - that's the
ur_exp_kernel_arg_value_t Value = {};
// We need to de-rerence to get the actual USM allocation - that's the
// pointer UR is expecting.
const void *Ptr = *static_cast<const void *const *>(Arg.MPtr);
Adapter->call<UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
nullptr, Ptr);
Value.pointer = *static_cast<void *const *>(Arg.MPtr);
UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr,
UR_EXP_KERNEL_ARG_TYPE_POINTER,
static_cast<uint32_t>(NextTrueIndex), sizeof(Arg.MPtr),
Value});
break;
}
case kernel_param_kind_t::kind_specialization_constants_buffer: {
assert(DeviceImageImpl != nullptr);
ur_mem_handle_t SpecConstsBuffer =
DeviceImageImpl->get_spec_const_buffer_ref();

ur_kernel_arg_mem_obj_properties_t MemObjProps{};
MemObjProps.pNext = nullptr;
MemObjProps.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
MemObjProps.memoryAccess = UR_MEM_FLAG_READ_ONLY;
Adapter->call<UrApiKind::urKernelSetArgMemObj>(
Kernel, NextTrueIndex, &MemObjProps, SpecConstsBuffer);
ur_exp_kernel_arg_value_t Value = {};
Value.memObjTuple = {SpecConstsBuffer, UR_MEM_FLAG_READ_ONLY};
UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr,
UR_EXP_KERNEL_ARG_TYPE_MEM_OBJ,
static_cast<uint32_t>(NextTrueIndex),
sizeof(SpecConstsBuffer), Value});
break;
}
case kernel_param_kind_t::kind_invalid:
Expand Down Expand Up @@ -2407,22 +2415,31 @@ static ur_result_t SetKernelParamsAndLaunch(
: Empty);
}

auto UrArgs = Queue.getKernelArgStorage(Args.size());

if (KernelFuncPtr && !KernelHasSpecialCaptures) {
auto setFunc = [&Adapter, Kernel,
auto setFunc = [&UrArgs,
KernelFuncPtr](const detail::kernel_param_desc_t &ParamDesc,
size_t NextTrueIndex) {
const void *ArgPtr = (const char *)KernelFuncPtr + ParamDesc.offset;
switch (ParamDesc.kind) {
case kernel_param_kind_t::kind_std_layout: {
int Size = ParamDesc.info;
Adapter->call<UrApiKind::urKernelSetArgValue>(Kernel, NextTrueIndex,
Size, nullptr, ArgPtr);
ur_exp_kernel_arg_value_t Value = {};
Value.value = ArgPtr;
UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr,
UR_EXP_KERNEL_ARG_TYPE_VALUE,
static_cast<uint32_t>(NextTrueIndex),
static_cast<size_t>(Size), Value});
break;
}
case kernel_param_kind_t::kind_pointer: {
const void *Ptr = *static_cast<const void *const *>(ArgPtr);
Adapter->call<UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
nullptr, Ptr);
ur_exp_kernel_arg_value_t Value = {};
Value.pointer = *static_cast<const void *const *>(ArgPtr);
UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr,
UR_EXP_KERNEL_ARG_TYPE_POINTER,
static_cast<uint32_t>(NextTrueIndex),
sizeof(Value.pointer), Value});
break;
}
default:
Expand All @@ -2432,10 +2449,11 @@ static ur_result_t SetKernelParamsAndLaunch(
applyFuncOnFilteredArgs(EliminatedArgMask, KernelNumArgs,
KernelParamDescGetter, setFunc);
} else {
auto setFunc = [&Adapter, Kernel, &DeviceImageImpl, &getMemAllocationFunc,
&Queue](detail::ArgDesc &Arg, size_t NextTrueIndex) {
SetArgBasedOnType(Adapter, Kernel, DeviceImageImpl, getMemAllocationFunc,
Queue.getContextImplPtr(), Arg, NextTrueIndex);
auto setFunc = [&DeviceImageImpl, &getMemAllocationFunc, &Queue,
&UrArgs](detail::ArgDesc &Arg, size_t NextTrueIndex) {
GetUrArgsBasedOnType(DeviceImageImpl, getMemAllocationFunc,
Queue.getContextImplPtr(), Arg, NextTrueIndex,
UrArgs);
};
applyFuncOnFilteredArgs(EliminatedArgMask, Args, setFunc);
}
Expand All @@ -2448,8 +2466,12 @@ static ur_result_t SetKernelParamsAndLaunch(
// CUDA-style local memory setting. Note that we may have -1 as a position,
// this indicates the buffer is actually unused and was elided.
if (ImplicitLocalArg.has_value() && ImplicitLocalArg.value() != -1) {
Adapter->call<UrApiKind::urKernelSetArgLocal>(
Kernel, ImplicitLocalArg.value(), WorkGroupMemorySize, nullptr);
UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES,
nullptr,
UR_EXP_KERNEL_ARG_TYPE_LOCAL,
static_cast<uint32_t>(ImplicitLocalArg.value()),
WorkGroupMemorySize,
{nullptr}});
}

adjustNDRangePerKernel(NDRDesc, Kernel, Queue.getDeviceImpl());
Expand Down Expand Up @@ -2507,20 +2529,103 @@ static ur_result_t SetKernelParamsAndLaunch(
{{WorkGroupMemorySize}}});
}
ur_event_handle_t UREvent = nullptr;
ur_result_t Error = Adapter->call_nocheck<UrApiKind::urEnqueueKernelLaunch>(
Queue.getHandleRef(), Kernel, NDRDesc.Dims,
HasOffset ? &NDRDesc.GlobalOffset[0] : nullptr, &NDRDesc.GlobalSize[0],
LocalSize, property_list.size(),
property_list.empty() ? nullptr : property_list.data(), RawEvents.size(),
RawEvents.empty() ? nullptr : &RawEvents[0],
OutEventImpl ? &UREvent : nullptr);
ur_result_t Error =
Adapter->call_nocheck<UrApiKind::urEnqueueKernelLaunchWithArgsExp>(
Queue.getHandleRef(), Kernel,
HasOffset ? &NDRDesc.GlobalOffset[0] : nullptr,
&NDRDesc.GlobalSize[0], LocalSize, UrArgs.size(), UrArgs.data(),
property_list.size(),
property_list.empty() ? nullptr : property_list.data(),
RawEvents.size(), RawEvents.empty() ? nullptr : &RawEvents[0],
OutEventImpl ? &UREvent : nullptr);
if (Error == UR_RESULT_SUCCESS && OutEventImpl) {
OutEventImpl->setHandle(UREvent);
}

return Error;
}

/// Sets one kernel argument on \p Kernel through the UR adapter, picking the
/// urKernelSetArg* entry point that matches the argument's kind.
///
/// \param Adapter adapter used to issue the UR calls.
/// \param Kernel kernel whose argument is being set.
/// \param DeviceImageImpl device image that supplies the specialization
///        constants buffer; must be non-null for that argument kind.
/// \param getMemAllocationFunc maps an accessor requirement to its memory
///        allocation; may be empty, in which case a null mem handle is used.
/// \param ContextImpl context used to get or create the UR sampler.
/// \param Arg description of the argument (kind, pointer and size).
/// \param NextTrueIndex index of the argument as seen by the kernel.
/// \throws sycl::exception (errc::runtime) for kind_invalid.
void SetArgBasedOnType(
    const AdapterPtr &Adapter, ur_kernel_handle_t Kernel,
    const std::shared_ptr<device_image_impl> &DeviceImageImpl,
    const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
    const ContextImplPtr &ContextImpl, detail::ArgDesc &Arg,
    size_t NextTrueIndex) {
  switch (Arg.MType) {
  // Nothing is set through this function for these kinds.
  case kernel_param_kind_t::kind_dynamic_work_group_memory:
  case kernel_param_kind_t::kind_work_group_memory:
  case kernel_param_kind_t::kind_stream:
    break;

  case kernel_param_kind_t::kind_dynamic_accessor:
  case kernel_param_kind_t::kind_accessor: {
    auto *AccReq = static_cast<Requirement *>(Arg.MPtr);

    // A default constructed accessor adds no requirement, so the allocation
    // lookup function may legitimately be empty. That is a valid case: the
    // argument is then set with a null memory handle.
    ur_mem_handle_t MemHandle = nullptr;
    if (getMemAllocationFunc)
      MemHandle =
          reinterpret_cast<ur_mem_handle_t>(getMemAllocationFunc(AccReq));

    ur_kernel_arg_mem_obj_properties_t AccProps{};
    AccProps.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
    AccProps.memoryAccess = AccessModeToUr(AccReq->MAccessMode);
    Adapter->call<UrApiKind::urKernelSetArgMemObj>(Kernel, NextTrueIndex,
                                                   &AccProps, MemHandle);
    break;
  }

  case kernel_param_kind_t::kind_std_layout: {
    // Without a data pointer the argument is set as local memory of
    // Arg.MSize bytes; otherwise the pointed-to bytes are passed by value.
    if (!Arg.MPtr) {
      Adapter->call<UrApiKind::urKernelSetArgLocal>(Kernel, NextTrueIndex,
                                                    Arg.MSize, nullptr);
      break;
    }
    Adapter->call<UrApiKind::urKernelSetArgValue>(
        Kernel, NextTrueIndex, Arg.MSize, nullptr, Arg.MPtr);
    break;
  }

  case kernel_param_kind_t::kind_sampler: {
    auto *SyclSampler = static_cast<sampler *>(Arg.MPtr);
    ur_sampler_handle_t UrSampler =
        (ur_sampler_handle_t)detail::getSyclObjImpl(*SyclSampler)
            ->getOrCreateSampler(ContextImpl);
    Adapter->call<UrApiKind::urKernelSetArgSampler>(Kernel, NextTrueIndex,
                                                    nullptr, UrSampler);
    break;
  }

  case kernel_param_kind_t::kind_pointer: {
    // Arg.MPtr points at the stored pointer; dereference it to obtain the
    // actual USM allocation, which is the pointer UR is expecting.
    const void *UsmPtr = *static_cast<const void *const *>(Arg.MPtr);
    Adapter->call<UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
                                                    nullptr, UsmPtr);
    break;
  }

  case kernel_param_kind_t::kind_specialization_constants_buffer: {
    assert(DeviceImageImpl != nullptr);
    ur_mem_handle_t SpecBuffer = DeviceImageImpl->get_spec_const_buffer_ref();

    // The specialization constants buffer is passed as read-only.
    ur_kernel_arg_mem_obj_properties_t SpecProps{};
    SpecProps.pNext = nullptr;
    SpecProps.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
    SpecProps.memoryAccess = UR_MEM_FLAG_READ_ONLY;
    Adapter->call<UrApiKind::urKernelSetArgMemObj>(Kernel, NextTrueIndex,
                                                   &SpecProps, SpecBuffer);
    break;
  }

  case kernel_param_kind_t::kind_invalid:
    throw sycl::exception(sycl::make_error_code(sycl::errc::runtime),
                          "Invalid kernel param kind " +
                              codeToString(UR_RESULT_ERROR_INVALID_VALUE));
  }
}

static std::tuple<ur_kernel_handle_t, std::shared_ptr<device_image_impl>,
const KernelArgMask *>
getCGKernelInfo(const CGExecKernel &CommandGroup, ContextImplPtr ContextImpl,
Expand Down
4 changes: 2 additions & 2 deletions sycl/test-e2e/Adapters/level_zero/batch_barrier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ int main(int argc, char *argv[]) {
queue q;

submit_kernel(q); // starts a batch
// CHECK: ---> urEnqueueKernelLaunch
// CHECK: ---> urEnqueueKernelLaunchWithArgsExp
// CHECK-NOT: zeCommandQueueExecuteCommandLists

// Initialize Level Zero driver is required if this test is linked
Expand All @@ -41,7 +41,7 @@ int main(int argc, char *argv[]) {
// CHECK-NOT: zeCommandQueueExecuteCommandLists

submit_kernel(q);
// CHECK: ---> urEnqueueKernelLaunch
// CHECK: ---> urEnqueueKernelLaunchWithArgsExp
// CHECK-NOT: zeCommandQueueExecuteCommandLists

// interop should close the batch
Expand Down
Loading