Skip to content

Commit 92d83d2

Browse files
committed
Move logic to StackCallerProtectionDomainChainExtractor
Signed-off-by: Craig Perkins <[email protected]>
1 parent 9f12fe8 commit 92d83d2

File tree

3 files changed

+7
-45
lines changed

3 files changed

+7
-45
lines changed

libs/agent-sm/agent/src/main/java/org/opensearch/javaagent/FileInterceptor.java

+1-21
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,6 @@
2121
import java.security.Policy;
2222
import java.security.ProtectionDomain;
2323
import java.util.Collection;
24-
import java.util.ArrayList;
25-
import java.util.List;
26-
import java.lang.StackWalker.StackFrame;
27-
import java.util.stream.Collectors;
2824

2925
import net.bytebuddy.asm.Advice;
3026

@@ -70,23 +66,7 @@ public static void intercept(@Advice.AllArguments Object[] args, @Advice.Origin
7066
}
7167

7268
final StackWalker walker = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE);
73-
final Collection<ProtectionDomain> callers = walker.walk(s -> {
74-
List<ProtectionDomain> domains = new ArrayList<>();
75-
boolean foundPrivileged = false;
76-
77-
for (StackFrame frame : s.toList()) {
78-
if (frame.getClassName().equals("java.security.AccessController") &&
79-
frame.getMethodName().equals("doPrivileged")) {
80-
foundPrivileged = true;
81-
break;
82-
}
83-
Class<?> callerClass = frame.getDeclaringClass();
84-
domains.add(callerClass.getProtectionDomain());
85-
}
86-
87-
return foundPrivileged ? domains : s.map(f -> f.getDeclaringClass().getProtectionDomain())
88-
.collect(Collectors.toList());
89-
});
69+
final Collection<ProtectionDomain> callers = walker.walk(StackCallerProtectionDomainChainExtractor.INSTANCE);
9070

9171
final String name = method.getName();
9272
boolean isMutating = name.equals("move") || name.equals("write") || name.startsWith("create");

libs/agent-sm/agent/src/main/java/org/opensearch/javaagent/SocketChannelInterceptor.java

+1-22
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
import org.opensearch.javaagent.bootstrap.AgentPolicy;
1212

13-
import java.lang.StackWalker.Option;
1413
import java.lang.reflect.Method;
1514
import java.net.InetSocketAddress;
1615
import java.net.NetPermission;
@@ -19,10 +18,6 @@
1918
import java.security.Policy;
2019
import java.security.ProtectionDomain;
2120
import java.util.Collection;
22-
import java.util.ArrayList;
23-
import java.util.List;
24-
import java.lang.StackWalker.StackFrame;
25-
import java.util.stream.Collectors;
2621

2722
import net.bytebuddy.asm.Advice;
2823
import net.bytebuddy.asm.Advice.Origin;
@@ -51,23 +46,7 @@ public static void intercept(@Advice.AllArguments Object[] args, @Origin Method
5146
}
5247

5348
final StackWalker walker = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE);
54-
final Collection<ProtectionDomain> callers = walker.walk(s -> {
55-
List<ProtectionDomain> domains = new ArrayList<>();
56-
boolean foundPrivileged = false;
57-
58-
for (StackFrame frame : s.toList()) {
59-
if (frame.getClassName().equals("java.security.AccessController") &&
60-
frame.getMethodName().equals("doPrivileged")) {
61-
foundPrivileged = true;
62-
break;
63-
}
64-
Class<?> callerClass = frame.getDeclaringClass();
65-
domains.add(callerClass.getProtectionDomain());
66-
}
67-
68-
return foundPrivileged ? domains : s.map(f -> f.getDeclaringClass().getProtectionDomain())
69-
.collect(Collectors.toList());
70-
});
49+
final Collection<ProtectionDomain> callers = walker.walk(StackCallerProtectionDomainChainExtractor.INSTANCE);
7150

7251
if (args[0] instanceof InetSocketAddress address) {
7352
if (!AgentPolicy.isTrustedHost(address.getHostString())) {

libs/agent-sm/agent/src/main/java/org/opensearch/javaagent/StackCallerProtectionDomainChainExtractor.java

+5-2
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,12 @@ private StackCallerProtectionDomainChainExtractor() {}
3535
*/
3636
@Override
3737
public Collection<ProtectionDomain> apply(Stream<StackFrame> frames) {
38-
return frames.map(StackFrame::getDeclaringClass)
38+
return frames.takeWhile(
39+
frame -> !frame.getClassName().equals("java.security.AccessController") || !frame.getMethodName().equals("doPrivileged")
40+
)
41+
.map(StackFrame::getDeclaringClass)
3942
.map(Class::getProtectionDomain)
40-
.filter(pd -> pd.getCodeSource() != null) /* JDK */
43+
.filter(pd -> pd.getCodeSource() != null) // Filter out JDK classes
4144
.collect(Collectors.toSet());
4245
}
4346
}

0 commit comments

Comments
 (0)