diff --git a/build.gradle b/build.gradle index 4b8b9af63..a7f1d2934 100644 --- a/build.gradle +++ b/build.gradle @@ -83,6 +83,7 @@ buildscript { //****************************************************************************/ plugins { + id 'eclipse' id 'java-library' id 'java-test-fixtures' id 'idea' @@ -96,7 +97,7 @@ apply plugin: 'opensearch.opensearchplugin' apply plugin: 'opensearch.rest-test' apply plugin: 'opensearch.pluginzip' apply plugin: 'opensearch.repositories' - +apply plugin: 'opensearch.java-agent' def opensearch_tmp_dir = rootProject.file('build/private/opensearch_tmp').absoluteFile opensearch_tmp_dir.mkdirs() diff --git a/remote-index-build-client/build.gradle b/remote-index-build-client/build.gradle index 6b96f9273..2dea8ce1a 100644 --- a/remote-index-build-client/build.gradle +++ b/remote-index-build-client/build.gradle @@ -10,6 +10,7 @@ plugins { id "io.freefair.lombok" id 'com.diffplug.spotless' version '6.25.0' id 'opensearch.build' + id 'opensearch.java-agent' } repositories { diff --git a/src/main/java/org/opensearch/knn/jni/PlatformUtils.java b/src/main/java/org/opensearch/knn/jni/PlatformUtils.java index a67a88487..0ec544a1f 100644 --- a/src/main/java/org/opensearch/knn/jni/PlatformUtils.java +++ b/src/main/java/org/opensearch/knn/jni/PlatformUtils.java @@ -15,6 +15,7 @@ import org.apache.commons.lang.StringUtils; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; + import oshi.util.platform.mac.SysctlUtil; import java.nio.file.Files; @@ -27,9 +28,18 @@ import java.util.stream.Stream; public class PlatformUtils { - private static final Logger logger = LogManager.getLogger(PlatformUtils.class); + private static volatile Boolean isAVX2Supported; + private static volatile Boolean isAVX512Supported; + private static volatile Boolean isAVX512SPRSupported; + + static void reset() { + isAVX2Supported = null; + isAVX512Supported = null; + isAVX512SPRSupported = null; + } + /** * Verify if the underlying system supports AVX2 SIMD Optimization or not * 1. If the architecture is not x86 return false. @@ -41,22 +51,26 @@ public class PlatformUtils { */ public static boolean isAVX2SupportedBySystem() { if (!Platform.isIntel() || Platform.isWindows()) { - return false; + isAVX2Supported = false; } - if (Platform.isMac()) { + if (isAVX2Supported != null) { + return isAVX2Supported; + } + if (Platform.isMac()) { // sysctl or system control retrieves system info and allows processes with appropriate privileges // to set system info. This system info contains the machine dependent cpu features that are supported by it. // On MacOS, if the underlying processor supports AVX2 instruction set, it will be listed under the "leaf7" // subset of instructions ("sysctl -a | grep machdep.cpu.leaf7_features"). // https://developer.apple.com/library/archive/documentation/System/Conceptual/ManPages_iPhoneOS/man3/sysctl.3.html try { - return AccessController.doPrivileged((PrivilegedExceptionAction) () -> { + isAVX2Supported = AccessController.doPrivileged((PrivilegedExceptionAction) () -> { String flags = SysctlUtil.sysctl("machdep.cpu.leaf7_features", "empty"); return (flags.toLowerCase(Locale.ROOT)).contains("avx2"); }); } catch (Exception e) { + isAVX2Supported = false; logger.error("[KNN] Error fetching cpu flags info. [{}]", e.getMessage(), e); } @@ -70,25 +84,32 @@ public static boolean isAVX2SupportedBySystem() { // https://ark.intel.com/content/www/us/en/ark/products/199285/intel-pentium-gold-g6600-processor-4m-cache-4-20-ghz.html String fileName = "/proc/cpuinfo"; try { - return AccessController.doPrivileged( + isAVX2Supported = AccessController.doPrivileged( (PrivilegedExceptionAction) () -> (Boolean) Files.lines(Paths.get(fileName)) .filter(s -> s.startsWith("flags")) .anyMatch(s -> StringUtils.containsIgnoreCase(s, "avx2")) ); } catch (Exception e) { + isAVX2Supported = false; logger.error("[KNN] Error reading file [{}]. [{}]", fileName, e.getMessage(), e); } } - return false; + return isAVX2Supported; } public static boolean isAVX512SupportedBySystem() { - return areAVX512FlagsAvailable(new String[] { "avx512f", "avx512cd", "avx512vl", "avx512dq", "avx512bw" }); + if (isAVX512Supported == null) { + isAVX512Supported = areAVX512FlagsAvailable(new String[] { "avx512f", "avx512cd", "avx512vl", "avx512dq", "avx512bw" }); + } + return isAVX512Supported; } public static boolean isAVX512SPRSupportedBySystem() { - return areAVX512FlagsAvailable(new String[] { "avx512_fp16", "avx512_bf16", "avx512_vpopcntdq" }); + if (isAVX512SPRSupported == null) { + isAVX512SPRSupported = areAVX512FlagsAvailable(new String[] { "avx512_fp16", "avx512_bf16", "avx512_vpopcntdq" }); + } + return isAVX512SPRSupported; } private static boolean areAVX512FlagsAvailable(String[] avx512) { diff --git a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java index e673c5d72..c1d3eace7 100644 --- a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java +++ b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java @@ -47,6 +47,7 @@ import org.opensearch.knn.indices.ModelCache; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelGraveyard; +import org.opensearch.knn.jni.PlatformUtils; import org.opensearch.knn.plugin.rest.RestClearCacheHandler; import org.opensearch.knn.plugin.rest.RestDeleteModelHandler; import org.opensearch.knn.plugin.rest.RestGetModelHandler; @@ -117,6 +118,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.concurrent.ForkJoinPool; import java.util.function.Supplier; import static java.util.Collections.singletonList; @@ -174,6 +176,14 @@ public class KNNPlugin extends Plugin private ClusterService clusterService; private Supplier repositoriesServiceSupplier; + static { + ForkJoinPool.commonPool().execute(() -> { + PlatformUtils.isAVX2SupportedBySystem(); + PlatformUtils.isAVX512SupportedBySystem(); + PlatformUtils.isAVX512SPRSupportedBySystem(); + }); + } + @Override public Map getMappers() { return Collections.singletonMap( diff --git a/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java b/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java index 2ec9ce6b5..928a7ea7b 100644 --- a/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java +++ b/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java @@ -41,6 +41,9 @@ import org.opensearch.test.OpenSearchSingleNodeTestCase; import org.opensearch.test.hamcrest.OpenSearchAssertions; +import com.carrotsearch.randomizedtesting.ThreadFilter; +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakFilters; + import java.io.IOException; import java.util.Base64; import java.util.Collection; @@ -63,7 +66,18 @@ import static org.opensearch.knn.common.KNNConstants.MODEL_TIMESTAMP; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; +@ThreadLeakFilters(defaultFilters = true, filters = { KNNSingleNodeTestCase.ForkJoinFilter.class }) public class KNNSingleNodeTestCase extends OpenSearchSingleNodeTestCase { + /** + * The the ForkJoinPool.commonPool() never terminates until program shutdown. + */ + public static final class ForkJoinFilter implements ThreadFilter { + @Override + public boolean reject(Thread t) { + return t.getName().startsWith("ForkJoinPool.commonPool-worker"); + } + } + @Override public void setUp() throws Exception { super.setUp(); diff --git a/src/test/java/org/opensearch/knn/jni/PlatformUtilTests.java b/src/test/java/org/opensearch/knn/jni/PlatformUtilTests.java index c524d211d..bc66b934a 100644 --- a/src/test/java/org/opensearch/knn/jni/PlatformUtilTests.java +++ b/src/test/java/org/opensearch/knn/jni/PlatformUtilTests.java @@ -12,8 +12,11 @@ package org.opensearch.knn.jni; import com.sun.jna.Platform; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; import org.mockito.MockedStatic; -import org.opensearch.knn.KNNTestCase; import oshi.util.platform.mac.SysctlUtil; import java.nio.file.Files; @@ -25,10 +28,16 @@ import static org.opensearch.knn.jni.PlatformUtils.isAVX512SupportedBySystem; import static org.opensearch.knn.jni.PlatformUtils.isAVX512SPRSupportedBySystem; -public class PlatformUtilTests extends KNNTestCase { +public class PlatformUtilTests extends Assert { public static final String MAC_CPU_FEATURES = "machdep.cpu.leaf7_features"; public static final String LINUX_PROC_CPU_INFO = "/proc/cpuinfo"; + @Before + public void setUp() { + PlatformUtils.reset(); + } + + @Test public void testIsAVX2SupportedBySystem_platformIsNotIntel_returnsFalse() { try (MockedStatic mockedPlatform = mockStatic(Platform.class)) { mockedPlatform.when(Platform::isIntel).thenReturn(false); @@ -36,6 +45,7 @@ public void testIsAVX2SupportedBySystem_platformIsNotIntel_returnsFalse() { } } + @Test public void testIsAVX2SupportedBySystem_platformIsIntelWithOSAsWindows_returnsFalse() { try (MockedStatic mockedPlatform = mockStatic(Platform.class)) { mockedPlatform.when(Platform::isIntel).thenReturn(true); @@ -44,6 +54,7 @@ public void testIsAVX2SupportedBySystem_platformIsIntelWithOSAsWindows_returnsFa } } + @Test public void testIsAVX2SupportedBySystem_platformIsMac_returnsTrue() { try (MockedStatic mockedPlatform = mockStatic(Platform.class)) { mockedPlatform.when(Platform::isIntel).thenReturn(true); @@ -59,6 +70,7 @@ public void testIsAVX2SupportedBySystem_platformIsMac_returnsTrue() { } } + @Test public void testIsAVX2SupportedBySystem_platformIsMac_returnsFalse() { try (MockedStatic mockedPlatform = mockStatic(Platform.class)) { mockedPlatform.when(Platform::isIntel).thenReturn(true); @@ -72,6 +84,7 @@ public void testIsAVX2SupportedBySystem_platformIsMac_returnsFalse() { } + @Test public void testIsAVX2SupportedBySystem_platformIsMac_throwsExceptionReturnsFalse() { try (MockedStatic mockedPlatform = mockStatic(Platform.class)) { mockedPlatform.when(Platform::isIntel).thenReturn(true); @@ -98,6 +111,7 @@ public void testIsAVX2SupportedBySystem_platformIsLinux_returnsTrue() { } } + @Test public void testIsAVX2SupportedBySystem_platformIsLinux_returnsFalse() { try (MockedStatic mockedPlatform = mockStatic(Platform.class)) { mockedPlatform.when(Platform::isIntel).thenReturn(true); @@ -112,6 +126,7 @@ public void testIsAVX2SupportedBySystem_platformIsLinux_returnsFalse() { } + @Test public void testIsAVX2SupportedBySystem_platformIsLinux_throwsExceptionReturnsFalse() { try (MockedStatic mockedPlatform = mockStatic(Platform.class)) { mockedPlatform.when(Platform::isIntel).thenReturn(true); @@ -127,7 +142,7 @@ public void testIsAVX2SupportedBySystem_platformIsLinux_throwsExceptionReturnsFa } // AVX512 tests - + @Test public void testIsAVX512SupportedBySystem_platformIsNotIntel_returnsFalse() { try (MockedStatic mockedPlatform = mockStatic(Platform.class)) { mockedPlatform.when(Platform::isIntel).thenReturn(false); @@ -135,6 +150,7 @@ public void testIsAVX512SupportedBySystem_platformIsNotIntel_returnsFalse() { } } + @Test public void testIsAVX512SupportedBySystem_platformIsMac_returnsFalse() { try (MockedStatic mockedPlatform = mockStatic(Platform.class)) { mockedPlatform.when(Platform::isMac).thenReturn(false); @@ -142,6 +158,7 @@ public void testIsAVX512SupportedBySystem_platformIsMac_returnsFalse() { } } + @Test public void testIsAVX512SupportedBySystem_platformIsIntelMac_returnsFalse() { try (MockedStatic mockedPlatform = mockStatic(Platform.class)) { mockedPlatform.when(Platform::isIntel).thenReturn(true); @@ -150,6 +167,7 @@ public void testIsAVX512SupportedBySystem_platformIsIntelMac_returnsFalse() { } } + @Test public void testIsAVX512SupportedBySystem_platformIsIntelWithOSAsWindows_returnsFalse() { try (MockedStatic mockedPlatform = mockStatic(Platform.class)) { mockedPlatform.when(Platform::isIntel).thenReturn(true); @@ -158,6 +176,7 @@ public void testIsAVX512SupportedBySystem_platformIsIntelWithOSAsWindows_returns } } + @Test public void testIsAVX512SupportedBySystem_platformIsLinuxAllAVX512FlagsPresent_returnsTrue() { try (MockedStatic mockedPlatform = mockStatic(Platform.class)) { mockedPlatform.when(Platform::isIntel).thenReturn(true); @@ -171,6 +190,7 @@ public void testIsAVX512SupportedBySystem_platformIsLinuxAllAVX512FlagsPresent_r } } + @Test public void testIsAVX512SupportedBySystem_platformIsLinuxSomeAVX512FlagsPresent_returnsFalse() { try (MockedStatic mockedPlatform = mockStatic(Platform.class)) { mockedPlatform.when(Platform::isIntel).thenReturn(true); @@ -185,7 +205,7 @@ public void testIsAVX512SupportedBySystem_platformIsLinuxSomeAVX512FlagsPresent_ } // Tests AVX512 instructions available since Intel(R) Sapphire Rapids. - + @Test public void testIsAVX512SPRSupportedBySystem_platformIsNotIntel_returnsFalse() { try (MockedStatic mockedPlatform = mockStatic(Platform.class)) { mockedPlatform.when(Platform::isIntel).thenReturn(false); @@ -193,6 +213,7 @@ public void testIsAVX512SPRSupportedBySystem_platformIsNotIntel_returnsFalse() { } } + @Test public void testIsAVX512SPRSupportedBySystem_platformIsMac_returnsFalse() { try (MockedStatic mockedPlatform = mockStatic(Platform.class)) { mockedPlatform.when(Platform::isMac).thenReturn(false); @@ -200,6 +221,7 @@ public void testIsAVX512SPRSupportedBySystem_platformIsMac_returnsFalse() { } } + @Test public void testIsAVX512SPRSupportedBySystem_platformIsIntelMac_returnsFalse() { try (MockedStatic mockedPlatform = mockStatic(Platform.class)) { mockedPlatform.when(Platform::isIntel).thenReturn(true); @@ -208,6 +230,7 @@ public void testIsAVX512SPRSupportedBySystem_platformIsIntelMac_returnsFalse() { } } + @Test public void testIsAVX512SPRSupportedBySystem_platformIsIntelWithOSAsWindows_returnsFalse() { try (MockedStatic mockedPlatform = mockStatic(Platform.class)) { mockedPlatform.when(Platform::isIntel).thenReturn(true); @@ -216,6 +239,7 @@ public void testIsAVX512SPRSupportedBySystem_platformIsIntelWithOSAsWindows_retu } } + @Test public void testIsAVX512SPRSupportedBySystem_platformIsLinuxAllAVX512SPRFlagsPresent_returnsTrue() { try (MockedStatic mockedPlatform = mockStatic(Platform.class)) { mockedPlatform.when(Platform::isIntel).thenReturn(true); @@ -229,6 +253,7 @@ public void testIsAVX512SPRSupportedBySystem_platformIsLinuxAllAVX512SPRFlagsPre } } + @Test public void testIsAVX512SPRSupportedBySystem_platformIsLinuxSomeAVX512SPRFlagsPresent_returnsFalse() { try (MockedStatic mockedPlatform = mockStatic(Platform.class)) { mockedPlatform.when(Platform::isIntel).thenReturn(true);