Skip to content

Commit e6d4355

Browse files
authored
Merge pull request #2515 from igchor/l0_v2_coop
[L0 v2] implement urKernelSuggestMaxCooperativeGroupCountExp
2 parents 421b755 + bfae55d commit e6d4355

File tree

2 files changed

+20
-8
lines changed

2 files changed

+20
-8
lines changed

source/adapters/level_zero/v2/api.cpp

-8
Original file line numberDiff line numberDiff line change
@@ -474,14 +474,6 @@ ur_result_t urCommandBufferCommandGetInfoExp(
474474
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
475475
}
476476

477-
ur_result_t urKernelSuggestMaxCooperativeGroupCountExp(
478-
ur_kernel_handle_t hKernel, ur_device_handle_t hDevice, uint32_t workDim,
479-
const size_t *pLocalWorkSize, size_t dynamicSharedMemorySize,
480-
uint32_t *pGroupCountRet) {
481-
logger::error("{} function not implemented!", __FUNCTION__);
482-
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
483-
}
484-
485477
ur_result_t urUSMImportExp(ur_context_handle_t hContext, void *pMem,
486478
size_t size) {
487479
logger::error("{} function not implemented!", __FUNCTION__);

source/adapters/level_zero/v2/kernel.cpp

+20
Original file line numberDiff line numberDiff line change
@@ -649,4 +649,24 @@ ur_result_t urKernelGetSuggestedLocalWorkSize(
649649
std::copy(localWorkSize, localWorkSize + workDim, pSuggestedLocalWorkSize);
650650
return UR_RESULT_SUCCESS;
651651
}
652+
653+
ur_result_t urKernelSuggestMaxCooperativeGroupCountExp(
654+
ur_kernel_handle_t hKernel, ur_device_handle_t hDevice, uint32_t workDim,
655+
const size_t *pLocalWorkSize, size_t dynamicSharedMemorySize,
656+
uint32_t *pGroupCountRet) {
657+
(void)dynamicSharedMemorySize;
658+
659+
uint32_t wg[3];
660+
wg[0] = ur_cast<uint32_t>(pLocalWorkSize[0]);
661+
wg[1] = workDim >= 2 ? ur_cast<uint32_t>(pLocalWorkSize[1]) : 1;
662+
wg[2] = workDim == 3 ? ur_cast<uint32_t>(pLocalWorkSize[2]) : 1;
663+
ZE2UR_CALL(zeKernelSetGroupSize,
664+
(hKernel->getZeHandle(hDevice), wg[0], wg[1], wg[2]));
665+
666+
uint32_t totalGroupCount = 0;
667+
ZE2UR_CALL(zeKernelSuggestMaxCooperativeGroupCount,
668+
(hKernel->getZeHandle(hDevice), &totalGroupCount));
669+
*pGroupCountRet = totalGroupCount;
670+
return UR_RESULT_SUCCESS;
671+
}
652672
} // namespace ur::level_zero

0 commit comments

Comments
 (0)