@@ -649,4 +649,24 @@ ur_result_t urKernelGetSuggestedLocalWorkSize(
649
649
std::copy (localWorkSize, localWorkSize + workDim, pSuggestedLocalWorkSize);
650
650
return UR_RESULT_SUCCESS;
651
651
}
652
+
653
+ ur_result_t urKernelSuggestMaxCooperativeGroupCountExp (
654
+ ur_kernel_handle_t hKernel, ur_device_handle_t hDevice, uint32_t workDim,
655
+ const size_t *pLocalWorkSize, size_t dynamicSharedMemorySize,
656
+ uint32_t *pGroupCountRet) {
657
+ (void )dynamicSharedMemorySize;
658
+
659
+ uint32_t wg[3 ];
660
+ wg[0 ] = ur_cast<uint32_t >(pLocalWorkSize[0 ]);
661
+ wg[1 ] = workDim >= 2 ? ur_cast<uint32_t >(pLocalWorkSize[1 ]) : 1 ;
662
+ wg[2 ] = workDim == 3 ? ur_cast<uint32_t >(pLocalWorkSize[2 ]) : 1 ;
663
+ ZE2UR_CALL (zeKernelSetGroupSize,
664
+ (hKernel->getZeHandle (hDevice), wg[0 ], wg[1 ], wg[2 ]));
665
+
666
+ uint32_t totalGroupCount = 0 ;
667
+ ZE2UR_CALL (zeKernelSuggestMaxCooperativeGroupCount,
668
+ (hKernel->getZeHandle (hDevice), &totalGroupCount));
669
+ *pGroupCountRet = totalGroupCount;
670
+ return UR_RESULT_SUCCESS;
671
+ }
652
672
} // namespace ur::level_zero
0 commit comments