@@ -2303,12 +2303,11 @@ ur_mem_flags_t AccessModeToUr(access::mode AccessorMode) {
2303
2303
}
2304
2304
}
2305
2305
2306
- void SetArgBasedOnType (
2307
- const AdapterPtr &Adapter, ur_kernel_handle_t Kernel,
2306
+ void GetUrArgsBasedOnType (
2308
2307
const std::shared_ptr<device_image_impl> &DeviceImageImpl,
2309
2308
const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
2310
2309
const ContextImplPtr &ContextImpl, detail::ArgDesc &Arg,
2311
- size_t NextTrueIndex) {
2310
+ size_t NextTrueIndex, std::vector<ur_exp_kernel_arg_properties_t> &UrArgs ) {
2312
2311
switch (Arg.MType ) {
2313
2312
case kernel_param_kind_t ::kind_dynamic_work_group_memory:
2314
2313
break ;
@@ -2328,52 +2327,63 @@ void SetArgBasedOnType(
2328
2327
getMemAllocationFunc
2329
2328
? reinterpret_cast <ur_mem_handle_t >(getMemAllocationFunc (Req))
2330
2329
: nullptr ;
2331
- ur_kernel_arg_mem_obj_properties_t MemObjData{};
2332
- MemObjData.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2333
- MemObjData.memoryAccess = AccessModeToUr (Req->MAccessMode );
2334
- Adapter->call <UrApiKind::urKernelSetArgMemObj>(Kernel, NextTrueIndex,
2335
- &MemObjData, MemArg);
2330
+ ur_exp_kernel_arg_value_t Value = {};
2331
+ Value.memObjTuple = {MemArg, AccessModeToUr (Req->MAccessMode )};
2332
+ UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2333
+ UR_EXP_KERNEL_ARG_TYPE_MEM_OBJ,
2334
+ static_cast <uint32_t >(NextTrueIndex), sizeof (MemArg),
2335
+ Value});
2336
2336
break ;
2337
2337
}
2338
2338
case kernel_param_kind_t ::kind_std_layout: {
2339
+ ur_exp_kernel_arg_type_t Type;
2339
2340
if (Arg.MPtr ) {
2340
- Adapter->call <UrApiKind::urKernelSetArgValue>(
2341
- Kernel, NextTrueIndex, Arg.MSize , nullptr , Arg.MPtr );
2341
+ Type = UR_EXP_KERNEL_ARG_TYPE_VALUE;
2342
2342
} else {
2343
- Adapter->call <UrApiKind::urKernelSetArgLocal>(Kernel, NextTrueIndex,
2344
- Arg.MSize , nullptr );
2343
+ Type = UR_EXP_KERNEL_ARG_TYPE_LOCAL;
2345
2344
}
2345
+ UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES,
2346
+ nullptr ,
2347
+ Type,
2348
+ static_cast <uint32_t >(NextTrueIndex),
2349
+ static_cast <size_t >(Arg.MSize ),
2350
+ {Arg.MPtr }});
2346
2351
2347
2352
break ;
2348
2353
}
2349
2354
case kernel_param_kind_t ::kind_sampler: {
2350
2355
sampler *SamplerPtr = (sampler *)Arg.MPtr ;
2351
- ur_sampler_handle_t Sampler =
2352
- (ur_sampler_handle_t )detail::getSyclObjImpl (*SamplerPtr)
2353
- ->getOrCreateSampler (ContextImpl);
2354
- Adapter->call <UrApiKind::urKernelSetArgSampler>(Kernel, NextTrueIndex,
2355
- nullptr , Sampler);
2356
+ ur_exp_kernel_arg_value_t Value = {};
2357
+ Value.sampler = (ur_sampler_handle_t )detail::getSyclObjImpl (*SamplerPtr)
2358
+ ->getOrCreateSampler (ContextImpl);
2359
+ UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2360
+ UR_EXP_KERNEL_ARG_TYPE_SAMPLER,
2361
+ static_cast <uint32_t >(NextTrueIndex),
2362
+ sizeof (ur_sampler_handle_t ), Value});
2356
2363
break ;
2357
2364
}
2358
2365
case kernel_param_kind_t ::kind_pointer: {
2366
+ void *Ptr = *static_cast <void *const *>(Arg.MPtr );
2359
2367
// We need to dereference this to get the actual USM allocation - that's the
2360
2368
// pointer UR is expecting.
2361
- const void *Ptr = *static_cast <const void *const *>(Arg.MPtr );
2362
- Adapter->call <UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
2363
- nullptr , Ptr);
2369
+ UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES,
2370
+ nullptr ,
2371
+ UR_EXP_KERNEL_ARG_TYPE_POINTER,
2372
+ static_cast <uint32_t >(NextTrueIndex),
2373
+ sizeof (Ptr),
2374
+ {Ptr}});
2364
2375
break ;
2365
2376
}
2366
2377
case kernel_param_kind_t ::kind_specialization_constants_buffer: {
2367
2378
assert (DeviceImageImpl != nullptr );
2368
2379
ur_mem_handle_t SpecConstsBuffer =
2369
2380
DeviceImageImpl->get_spec_const_buffer_ref ();
2370
-
2371
- ur_kernel_arg_mem_obj_properties_t MemObjProps{};
2372
- MemObjProps.pNext = nullptr ;
2373
- MemObjProps.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2374
- MemObjProps.memoryAccess = UR_MEM_FLAG_READ_ONLY;
2375
- Adapter->call <UrApiKind::urKernelSetArgMemObj>(
2376
- Kernel, NextTrueIndex, &MemObjProps, SpecConstsBuffer);
2381
+ ur_exp_kernel_arg_value_t Value = {};
2382
+ Value.memObjTuple = {SpecConstsBuffer, UR_MEM_FLAG_READ_ONLY};
2383
+ UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2384
+ UR_EXP_KERNEL_ARG_TYPE_MEM_OBJ,
2385
+ static_cast <uint32_t >(NextTrueIndex),
2386
+ sizeof (SpecConstsBuffer), Value});
2377
2387
break ;
2378
2388
}
2379
2389
case kernel_param_kind_t ::kind_invalid:
@@ -2408,22 +2418,33 @@ static ur_result_t SetKernelParamsAndLaunch(
2408
2418
: Empty);
2409
2419
}
2410
2420
2421
+ std::vector<ur_exp_kernel_arg_properties_t > UrArgs;
2422
+ UrArgs.reserve (Args.size ());
2423
+
2411
2424
if (KernelFuncPtr && !KernelHasSpecialCaptures) {
2412
- auto setFunc = [&Adapter, Kernel ,
2425
+ auto setFunc = [&UrArgs ,
2413
2426
KernelFuncPtr](const detail::kernel_param_desc_t &ParamDesc,
2414
2427
size_t NextTrueIndex) {
2415
2428
const void *ArgPtr = (const char *)KernelFuncPtr + ParamDesc.offset ;
2416
2429
switch (ParamDesc.kind ) {
2417
2430
case kernel_param_kind_t ::kind_std_layout: {
2418
2431
int Size = ParamDesc.info ;
2419
- Adapter->call <UrApiKind::urKernelSetArgValue>(Kernel, NextTrueIndex,
2420
- Size, nullptr , ArgPtr);
2432
+ UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES,
2433
+ nullptr ,
2434
+ UR_EXP_KERNEL_ARG_TYPE_VALUE,
2435
+ static_cast <uint32_t >(NextTrueIndex),
2436
+ static_cast <size_t >(Size),
2437
+ {ArgPtr}});
2421
2438
break ;
2422
2439
}
2423
2440
case kernel_param_kind_t ::kind_pointer: {
2424
2441
const void *Ptr = *static_cast <const void *const *>(ArgPtr);
2425
- Adapter->call <UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
2426
- nullptr , Ptr);
2442
+ UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES,
2443
+ nullptr ,
2444
+ UR_EXP_KERNEL_ARG_TYPE_POINTER,
2445
+ static_cast <uint32_t >(NextTrueIndex),
2446
+ sizeof (Ptr),
2447
+ {Ptr}});
2427
2448
break ;
2428
2449
}
2429
2450
default :
@@ -2433,10 +2454,11 @@ static ur_result_t SetKernelParamsAndLaunch(
2433
2454
applyFuncOnFilteredArgs (EliminatedArgMask, KernelNumArgs,
2434
2455
KernelParamDescGetter, setFunc);
2435
2456
} else {
2436
- auto setFunc = [&Adapter, Kernel, &DeviceImageImpl, &getMemAllocationFunc,
2437
- &Queue](detail::ArgDesc &Arg, size_t NextTrueIndex) {
2438
- SetArgBasedOnType (Adapter, Kernel, DeviceImageImpl, getMemAllocationFunc,
2439
- Queue.getContextImplPtr (), Arg, NextTrueIndex);
2457
+ auto setFunc = [&DeviceImageImpl, &getMemAllocationFunc, &Queue,
2458
+ &UrArgs](detail::ArgDesc &Arg, size_t NextTrueIndex) {
2459
+ GetUrArgsBasedOnType (DeviceImageImpl, getMemAllocationFunc,
2460
+ Queue.getContextImplPtr (), Arg, NextTrueIndex,
2461
+ UrArgs);
2440
2462
};
2441
2463
applyFuncOnFilteredArgs (EliminatedArgMask, Args, setFunc);
2442
2464
}
@@ -2449,8 +2471,12 @@ static ur_result_t SetKernelParamsAndLaunch(
2449
2471
// CUDA-style local memory setting. Note that we may have -1 as a position,
2450
2472
// this indicates the buffer is actually unused and was elided.
2451
2473
if (ImplicitLocalArg.has_value () && ImplicitLocalArg.value () != -1 ) {
2452
- Adapter->call <UrApiKind::urKernelSetArgLocal>(
2453
- Kernel, ImplicitLocalArg.value (), WorkGroupMemorySize, nullptr );
2474
+ UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES,
2475
+ nullptr ,
2476
+ UR_EXP_KERNEL_ARG_TYPE_LOCAL,
2477
+ static_cast <uint32_t >(ImplicitLocalArg.value ()),
2478
+ WorkGroupMemorySize,
2479
+ {nullptr }});
2454
2480
}
2455
2481
2456
2482
adjustNDRangePerKernel (NDRDesc, Kernel, Queue.getDeviceImpl ());
@@ -2508,20 +2534,103 @@ static ur_result_t SetKernelParamsAndLaunch(
2508
2534
{{WorkGroupMemorySize}}});
2509
2535
}
2510
2536
ur_event_handle_t UREvent = nullptr ;
2511
- ur_result_t Error = Adapter->call_nocheck <UrApiKind::urEnqueueKernelLaunch>(
2512
- Queue.getHandleRef (), Kernel, NDRDesc.Dims ,
2513
- HasOffset ? &NDRDesc.GlobalOffset [0 ] : nullptr , &NDRDesc.GlobalSize [0 ],
2514
- LocalSize, property_list.size (),
2515
- property_list.empty () ? nullptr : property_list.data (), RawEvents.size (),
2516
- RawEvents.empty () ? nullptr : &RawEvents[0 ],
2517
- OutEventImpl ? &UREvent : nullptr );
2537
+ ur_result_t Error =
2538
+ Adapter->call_nocheck <UrApiKind::urEnqueueKernelLaunchWithArgsExp>(
2539
+ Queue.getHandleRef (), Kernel,
2540
+ HasOffset ? &NDRDesc.GlobalOffset [0 ] : nullptr ,
2541
+ &NDRDesc.GlobalSize [0 ], LocalSize, UrArgs.size (), UrArgs.data (),
2542
+ property_list.size (),
2543
+ property_list.empty () ? nullptr : property_list.data (),
2544
+ RawEvents.size (), RawEvents.empty () ? nullptr : &RawEvents[0 ],
2545
+ OutEventImpl ? &UREvent : nullptr );
2518
2546
if (Error == UR_RESULT_SUCCESS && OutEventImpl) {
2519
2547
OutEventImpl->setHandle (UREvent);
2520
2548
}
2521
2549
2522
2550
return Error;
2523
2551
}
2524
2552
2553
// Sets a single kernel argument on `Kernel` at position `NextTrueIndex`,
// choosing the UR "set argument" entry point from the kind stored in
// `Arg.MType`. Each call goes through `Adapter->call<...>`.
//
// Parameters, as used by the body below:
//   Adapter              - dispatches the urKernelSetArg* calls.
//   Kernel               - kernel handle the argument is set on.
//   DeviceImageImpl      - read only for the specialization-constants buffer
//                          kind, where it is asserted to be non-null.
//   getMemAllocationFunc - maps an accessor Requirement to its memory
//                          allocation; may be empty (see the accessor case).
//   ContextImpl          - passed to getOrCreateSampler() for sampler args.
//   Arg                  - argument descriptor; MType, MPtr and MSize are read.
//   NextTrueIndex        - argument index forwarded to UR.
//
// Throws sycl::exception (errc::runtime) for kind_invalid.
+ void SetArgBasedOnType (
2554
+ const AdapterPtr &Adapter, ur_kernel_handle_t Kernel,
2555
+ const std::shared_ptr<device_image_impl> &DeviceImageImpl,
2556
+ const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
2557
+ const ContextImplPtr &ContextImpl, detail::ArgDesc &Arg,
2558
+ size_t NextTrueIndex) {
2559
+ switch (Arg.MType ) {
2560
// The next three kinds are deliberately no-ops in this function: nothing is
// set on the kernel for them here.
+ case kernel_param_kind_t ::kind_dynamic_work_group_memory:
2561
+ break ;
2562
+ case kernel_param_kind_t ::kind_work_group_memory:
2563
+ break ;
2564
+ case kernel_param_kind_t ::kind_stream:
2565
+ break ;
2566
// Accessors (dynamic or not) are passed as UR memory objects; the access
// mode of the requirement is translated with AccessModeToUr().
+ case kernel_param_kind_t ::kind_dynamic_accessor:
2567
+ case kernel_param_kind_t ::kind_accessor: {
2568
+ Requirement *Req = (Requirement *)(Arg.MPtr );
2569
+
2570
+ // getMemAllocationFunc is nullptr when there are no requirements. However,
2571
+ // we may pass default constructed accessors to a command, which don't add
2572
+ // requirements. In such case, getMemAllocationFunc is nullptr, but it's a
2573
+ // valid case, so we need to properly handle it.
2574
+ ur_mem_handle_t MemArg =
2575
+ getMemAllocationFunc
2576
+ ? reinterpret_cast <ur_mem_handle_t >(getMemAllocationFunc (Req))
2577
+ : nullptr ;
2578
+ ur_kernel_arg_mem_obj_properties_t MemObjData{};
2579
+ MemObjData.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2580
+ MemObjData.memoryAccess = AccessModeToUr (Req->MAccessMode );
2581
+ Adapter->call <UrApiKind::urKernelSetArgMemObj>(Kernel, NextTrueIndex,
2582
+ &MemObjData, MemArg);
2583
+ break ;
2584
+ }
2585
// Standard-layout arguments: a non-null MPtr is passed by value (MSize
// bytes); a null MPtr is set as a local argument of MSize bytes instead.
+ case kernel_param_kind_t ::kind_std_layout: {
2586
+ if (Arg.MPtr ) {
2587
+ Adapter->call <UrApiKind::urKernelSetArgValue>(
2588
+ Kernel, NextTrueIndex, Arg.MSize , nullptr , Arg.MPtr );
2589
+ } else {
2590
+ Adapter->call <UrApiKind::urKernelSetArgLocal>(Kernel, NextTrueIndex,
2591
+ Arg.MSize , nullptr );
2592
+ }
2593
+
2594
+ break ;
2595
+ }
2596
// Samplers: MPtr is treated as a sycl `sampler`, whose implementation object
// yields the UR sampler handle for the given context.
+ case kernel_param_kind_t ::kind_sampler: {
2597
+ sampler *SamplerPtr = (sampler *)Arg.MPtr ;
2598
+ ur_sampler_handle_t Sampler =
2599
+ (ur_sampler_handle_t )detail::getSyclObjImpl (*SamplerPtr)
2600
+ ->getOrCreateSampler (ContextImpl);
2601
+ Adapter->call <UrApiKind::urKernelSetArgSampler>(Kernel, NextTrueIndex,
2602
+ nullptr , Sampler);
2603
+ break ;
2604
+ }
2605
+ case kernel_param_kind_t ::kind_pointer: {
2606
+ // We need to dereference this to get the actual USM allocation - that's the
2607
+ // pointer UR is expecting.
2608
+ const void *Ptr = *static_cast <const void *const *>(Arg.MPtr );
2609
+ Adapter->call <UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
2610
+ nullptr , Ptr);
2611
+ break ;
2612
+ }
2613
// Specialization constants buffer: taken from the device image and always
// passed as a read-only memory object.
+ case kernel_param_kind_t ::kind_specialization_constants_buffer: {
2614
+ assert (DeviceImageImpl != nullptr );
2615
+ ur_mem_handle_t SpecConstsBuffer =
2616
+ DeviceImageImpl->get_spec_const_buffer_ref ();
2617
+
2618
+ ur_kernel_arg_mem_obj_properties_t MemObjProps{};
2619
+ MemObjProps.pNext = nullptr ;
2620
+ MemObjProps.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2621
+ MemObjProps.memoryAccess = UR_MEM_FLAG_READ_ONLY;
2622
+ Adapter->call <UrApiKind::urKernelSetArgMemObj>(
2623
+ Kernel, NextTrueIndex, &MemObjProps, SpecConstsBuffer);
2624
+ break ;
2625
+ }
2626
// An invalid kind is reported as a runtime error. The `break` after the
// throw is never reached.
+ case kernel_param_kind_t ::kind_invalid:
2627
+ throw sycl::exception (sycl::make_error_code (sycl::errc::runtime),
2628
+ " Invalid kernel param kind " +
2629
+ codeToString (UR_RESULT_ERROR_INVALID_VALUE));
2630
+ break ;
2631
+ }
2632
+ }
2633
+
2525
2634
static std::tuple<ur_kernel_handle_t , std::shared_ptr<device_image_impl>,
2526
2635
const KernelArgMask *>
2527
2636
getCGKernelInfo (const CGExecKernel &CommandGroup, ContextImplPtr ContextImpl,
0 commit comments