|
13 | 13 | #include "context.hpp"
|
14 | 14 | #include "kernel.hpp"
|
15 | 15 | #include "memory.hpp"
|
| 16 | +#include "queue_api.hpp" |
16 | 17 |
|
17 | 18 | #include "../device.hpp"
|
18 | 19 | #include "../helpers/kernel_helpers.hpp"
|
@@ -624,4 +625,28 @@ ur_result_t urKernelGetInfo(ur_kernel_handle_t hKernel,
|
624 | 625 | } catch (...) {
|
625 | 626 | return exceptionToResult(std::current_exception());
|
626 | 627 | }
|
| 628 | + |
| 629 | +ur_result_t urKernelGetSuggestedLocalWorkSize( |
| 630 | + ur_kernel_handle_t hKernel, ur_queue_handle_t hQueue, uint32_t workDim, |
| 631 | + [[maybe_unused]] const size_t *pGlobalWorkOffset, |
| 632 | + const size_t *pGlobalWorkSize, size_t *pSuggestedLocalWorkSize) { |
| 633 | + UR_ASSERT(workDim > 0, UR_RESULT_ERROR_INVALID_WORK_DIMENSION); |
| 634 | + UR_ASSERT(workDim < 4, UR_RESULT_ERROR_INVALID_WORK_DIMENSION); |
| 635 | + UR_ASSERT(pSuggestedLocalWorkSize != nullptr, |
| 636 | + UR_RESULT_ERROR_INVALID_NULL_POINTER); |
| 637 | + |
| 638 | + uint32_t localWorkSize[3]; |
| 639 | + size_t globalWorkSize3D[3]{1, 1, 1}; |
| 640 | + std::copy(pGlobalWorkSize, pGlobalWorkSize + workDim, globalWorkSize3D); |
| 641 | + |
| 642 | + ur_device_handle_t hDevice; |
| 643 | + UR_CALL(hQueue->queueGetInfo(UR_QUEUE_INFO_DEVICE, sizeof(hDevice), |
| 644 | + reinterpret_cast<void *>(&hDevice), nullptr)); |
| 645 | + |
| 646 | + UR_CALL(getSuggestedLocalWorkSize(hDevice, hKernel->getZeHandle(hDevice), |
| 647 | + globalWorkSize3D, localWorkSize)); |
| 648 | + |
| 649 | + std::copy(localWorkSize, localWorkSize + workDim, pSuggestedLocalWorkSize); |
| 650 | + return UR_RESULT_SUCCESS; |
| 651 | +} |
627 | 652 | } // namespace ur::level_zero
|
0 commit comments