From 438868a2b6f9fae05100b766cc3c6d4df37762d5 Mon Sep 17 00:00:00 2001 From: Craig Perkins Date: Thu, 10 Apr 2025 15:58:20 -0400 Subject: [PATCH] Limit stack walking to frames before AccessController.doPrivileged Signed-off-by: Craig Perkins Signed-off-by: Andrew Ross --- .../javaagent/SocketChannelInterceptor.java | 3 +- ...kCallerProtectionDomainChainExtractor.java | 7 +- ...kCallerProtectionDomainExtractorTests.java | 118 ++++++++++++++++++ 3 files changed, 124 insertions(+), 4 deletions(-) create mode 100644 libs/agent-sm/agent/src/test/java/org/opensearch/javaagent/StackCallerProtectionDomainExtractorTests.java diff --git a/libs/agent-sm/agent/src/main/java/org/opensearch/javaagent/SocketChannelInterceptor.java b/libs/agent-sm/agent/src/main/java/org/opensearch/javaagent/SocketChannelInterceptor.java index 3ac48f9e72f74..93daeccb6503f 100644 --- a/libs/agent-sm/agent/src/main/java/org/opensearch/javaagent/SocketChannelInterceptor.java +++ b/libs/agent-sm/agent/src/main/java/org/opensearch/javaagent/SocketChannelInterceptor.java @@ -10,7 +10,6 @@ import org.opensearch.javaagent.bootstrap.AgentPolicy; -import java.lang.StackWalker.Option; import java.lang.reflect.Method; import java.net.InetSocketAddress; import java.net.NetPermission; @@ -46,7 +45,7 @@ public static void intercept(@Advice.AllArguments Object[] args, @Origin Method return; /* noop */ } - final StackWalker walker = StackWalker.getInstance(Option.RETAIN_CLASS_REFERENCE); + final StackWalker walker = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE); final Collection callers = walker.walk(StackCallerProtectionDomainChainExtractor.INSTANCE); if (args[0] instanceof InetSocketAddress address) { diff --git a/libs/agent-sm/agent/src/main/java/org/opensearch/javaagent/StackCallerProtectionDomainChainExtractor.java b/libs/agent-sm/agent/src/main/java/org/opensearch/javaagent/StackCallerProtectionDomainChainExtractor.java index 69b91d0d8b74c..f4a1382254b0f 100644 --- a/libs/agent-sm/agent/src/main/java/org/opensearch/javaagent/StackCallerProtectionDomainChainExtractor.java +++ b/libs/agent-sm/agent/src/main/java/org/opensearch/javaagent/StackCallerProtectionDomainChainExtractor.java @@ -35,9 +35,12 @@ private StackCallerProtectionDomainChainExtractor() {} */ @Override public Collection apply(Stream frames) { - return frames.map(StackFrame::getDeclaringClass) + return frames.takeWhile( + frame -> !(frame.getClassName().equals("java.security.AccessController") && frame.getMethodName().equals("doPrivileged")) + ) + .map(StackFrame::getDeclaringClass) .map(Class::getProtectionDomain) - .filter(pd -> pd.getCodeSource() != null) /* JDK */ + .filter(pd -> pd.getCodeSource() != null) // Filter out JDK classes .collect(Collectors.toSet()); } } diff --git a/libs/agent-sm/agent/src/test/java/org/opensearch/javaagent/StackCallerProtectionDomainExtractorTests.java b/libs/agent-sm/agent/src/test/java/org/opensearch/javaagent/StackCallerProtectionDomainExtractorTests.java new file mode 100644 index 0000000000000..4f26a97d0ff12 --- /dev/null +++ b/libs/agent-sm/agent/src/test/java/org/opensearch/javaagent/StackCallerProtectionDomainExtractorTests.java @@ -0,0 +1,118 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.javaagent; + +import org.junit.Test; + +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.security.ProtectionDomain; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.hasItem; +import static org.junit.Assert.assertEquals; + +public class StackCallerProtectionDomainExtractorTests { + + private static List indirectlyCaptureStackFrames() { + return captureStackFrames(); + } + + private static List captureStackFrames() { + // OPTION.RETAIN_CLASS_REFERENCE lets you do f.getDeclaringClass() if you need it + StackWalker walker = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE); + return walker.walk(frames -> frames.collect(Collectors.toList())); + } + + @Test + public void testSimpleProtectionDomainExtraction() throws Exception { + StackCallerProtectionDomainChainExtractor extractor = StackCallerProtectionDomainChainExtractor.INSTANCE; + Set protectionDomains = (Set) extractor.apply(captureStackFrames().stream()); + assertEquals(7, protectionDomains.size()); + List simpleNames = protectionDomains.stream().map(pd -> { + try { + return pd.getCodeSource().getLocation().toURI(); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + }) + .map(URI::getPath) + .map(Paths::get) + .map(Path::getFileName) + .map(Path::toString) + // strip trailing “-VERSION.jar” if present + .map(name -> name.replaceFirst("-\\d[\\d\\.]*\\.jar$", "")) + // otherwise strip “.jar” + .map(name -> name.replaceFirst("\\.jar$", "")) + .toList(); + assertThat( + simpleNames, + containsInAnyOrder( + "gradle-worker", + "gradle-worker-main", + "gradle-messaging", + "gradle-testing-base-infrastructure", + "test", // from the build/classes/java/test directory + "junit", + "gradle-testing-jvm-infrastructure" + ) + ); + } + + @Test + public void testIndirectlyCaptureStackFramesInListOfFrames() throws Exception { + List stackFrames = indirectlyCaptureStackFrames(); + List methodNames = stackFrames.stream().map(StackWalker.StackFrame::getMethodName).toList(); + assertThat(methodNames, hasItem("indirectlyCaptureStackFrames")); + } + + @Test + @SuppressWarnings("removal") + public void testStackTruncationWithAccessController() throws Exception { + AccessController.doPrivileged(new PrivilegedAction() { + @Override + public Void run() { + StackCallerProtectionDomainChainExtractor extractor = StackCallerProtectionDomainChainExtractor.INSTANCE; + Set protectionDomains = (Set) extractor.apply(captureStackFrames().stream()); + assertEquals(1, protectionDomains.size()); + List simpleNames = protectionDomains.stream().map(pd -> { + try { + return pd.getCodeSource().getLocation().toURI(); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + }) + .map(URI::getPath) + .map(Paths::get) + .map(Path::getFileName) + .map(Path::toString) + // strip trailing “-VERSION.jar” if present + .map(name -> name.replaceFirst("-\\d[\\d\\.]*\\.jar$", "")) + // otherwise strip “.jar” + .map(name -> name.replaceFirst("\\.jar$", "")) + .toList(); + assertThat( + simpleNames, + containsInAnyOrder( + "test" // from the build/classes/java/test directory + ) + ); + return null; + } + }); + } +}