Skip to content

Commit c135c5c

Browse files
committed
refactor code based on comments
Signed-off-by: Chenyang Ji <[email protected]>
1 parent b34314c commit c135c5c

File tree

5 files changed

+43
-50
lines changed

5 files changed

+43
-50
lines changed

plugins/query-insights/src/main/java/org/opensearch/plugin/insights/core/listener/QueryInsightsListener.java

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import org.opensearch.action.search.SearchRequestOperationsListener;
1717
import org.opensearch.cluster.service.ClusterService;
1818
import org.opensearch.common.inject.Inject;
19-
import org.opensearch.common.util.concurrent.ThreadContext;
2019
import org.opensearch.core.xcontent.ToXContent;
2120
import org.opensearch.plugin.insights.core.service.QueryInsightsService;
2221
import org.opensearch.plugin.insights.rules.model.Attribute;
@@ -153,18 +152,16 @@ public void onRequestEnd(final SearchPhaseContext context, final SearchRequestCo
153152
// Get internal computed and user provided labels
154153
Map<String, Object> labels = new HashMap<>();
155154
// Retrieve user provided label if exists
156-
ThreadContext threadContext = threadPool.getThreadContext();
157-
String userProvidedLabel = threadContext.getRequestHeadersOnly().get(Task.X_OPAQUE_ID);
155+
String userProvidedLabel = RequestLabelingService.getUserProvidedTag(threadPool);
158156
if (userProvidedLabel != null) {
159157
labels.put(Task.X_OPAQUE_ID, userProvidedLabel);
160158
}
161159
// Retrieve computed labels if exists
162-
Map<String, Object> computedLabels = threadContext.getTransient(RequestLabelingService.COMPUTED_LABELS);
160+
Map<String, Object> computedLabels = RequestLabelingService.getRuleBasedLabels(threadPool);
163161
if (computedLabels != null) {
164162
labels.putAll(computedLabels);
165163
}
166164
attributes.put(Attribute.LABELS, labels);
167-
168165
// construct SearchQueryRecord from attributes and measurements
169166
SearchQueryRecord record = new SearchQueryRecord(request.getOrCreateAbsoluteStartMillis(), measurements, attributes);
170167
queryInsightsService.addRecord(record);

plugins/query-insights/src/test/java/org/opensearch/plugin/insights/core/listener/QueryInsightsListenerTests.java

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919
import org.opensearch.common.util.concurrent.ThreadContext;
2020
import org.opensearch.plugin.insights.core.service.QueryInsightsService;
2121
import org.opensearch.plugin.insights.core.service.TopQueriesService;
22+
import org.opensearch.plugin.insights.rules.model.Attribute;
2223
import org.opensearch.plugin.insights.rules.model.MetricType;
24+
import org.opensearch.plugin.insights.rules.model.SearchQueryRecord;
2325
import org.opensearch.plugin.insights.settings.QueryInsightsSettings;
2426
import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
2527
import org.opensearch.search.aggregations.support.ValueType;
@@ -35,10 +37,13 @@
3537
import java.util.Collections;
3638
import java.util.HashMap;
3739
import java.util.List;
40+
import java.util.Locale;
3841
import java.util.Map;
3942
import java.util.concurrent.CountDownLatch;
4043
import java.util.concurrent.Phaser;
4144

45+
import org.mockito.ArgumentCaptor;
46+
4247
import static org.mockito.ArgumentMatchers.any;
4348
import static org.mockito.Mockito.mock;
4449
import static org.mockito.Mockito.times;
@@ -70,11 +75,12 @@ public void setup() {
7075
when(queryInsightsService.getTopQueriesService(MetricType.LATENCY)).thenReturn(topQueriesService);
7176

7277
ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
73-
threadContext.setHeaders(new Tuple<>(Collections.singletonMap(Task.X_OPAQUE_ID, "test"), new HashMap<>()));
74-
threadContext.putTransient(RequestLabelingService.COMPUTED_LABELS, Map.of("a", "b"));
78+
threadContext.setHeaders(new Tuple<>(Collections.singletonMap(Task.X_OPAQUE_ID, "userLabel"), new HashMap<>()));
79+
threadContext.putTransient(RequestLabelingService.RULE_BASED_LABELS, Map.of("labelKey", "labelValue"));
7580
when(threadPool.getThreadContext()).thenReturn(threadContext);
7681
}
7782

83+
@SuppressWarnings("unchecked")
7884
public void testOnRequestEnd() throws InterruptedException {
7985
Long timestamp = System.currentTimeMillis() - 100L;
8086
SearchType searchType = SearchType.QUERY_THEN_FETCH;
@@ -101,10 +107,19 @@ public void testOnRequestEnd() throws InterruptedException {
101107
when(searchRequestContext.phaseTookMap()).thenReturn(phaseLatencyMap);
102108
when(searchPhaseContext.getRequest()).thenReturn(searchRequest);
103109
when(searchPhaseContext.getNumShards()).thenReturn(numberOfShards);
110+
ArgumentCaptor<SearchQueryRecord> captor = ArgumentCaptor.forClass(SearchQueryRecord.class);
104111

105112
queryInsightsListener.onRequestEnd(searchPhaseContext, searchRequestContext);
106113

107-
verify(queryInsightsService, times(1)).addRecord(any());
114+
verify(queryInsightsService, times(1)).addRecord(captor.capture());
115+
SearchQueryRecord generatedRecord = captor.getValue();
116+
assertEquals(timestamp.longValue(), generatedRecord.getTimestamp());
117+
assertEquals(numberOfShards, generatedRecord.getAttributes().get(Attribute.TOTAL_SHARDS));
118+
assertEquals(searchType.toString().toLowerCase(Locale.ROOT), generatedRecord.getAttributes().get(Attribute.SEARCH_TYPE));
119+
assertEquals(searchSourceBuilder.toString(), generatedRecord.getAttributes().get(Attribute.SOURCE));
120+
Map<String, String> labels = (Map<String, String>) generatedRecord.getAttributes().get(Attribute.LABELS);
121+
assertEquals("labelValue", labels.get("labelKey"));
122+
assertEquals("userLabel", labels.get(Task.X_OPAQUE_ID));
108123
}
109124

110125
public void testConcurrentOnRequestEnd() throws InterruptedException {

server/src/main/java/org/opensearch/search/labels/RequestLabelingService.java

Lines changed: 19 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import org.opensearch.threadpool.ThreadPool;
1515

1616
import java.util.List;
17+
import java.util.Locale;
1718
import java.util.Map;
1819
import java.util.stream.Collectors;
1920

@@ -25,7 +26,7 @@ public class RequestLabelingService {
2526
/**
2627
* Field name for computed labels
2728
*/
28-
public static final String COMPUTED_LABELS = "computed_labels";
29+
public static final String RULE_BASED_LABELS = "rule_based_labels";
2930
private final ThreadPool threadPool;
3031
private final List<Rule> rules;
3132

@@ -35,42 +36,34 @@ public RequestLabelingService(final ThreadPool threadPool, final List<Rule> rule
3536
}
3637

3738
/**
38-
* Get all the existing rules
39-
*
40-
* @return list of existing rules
41-
*/
42-
public List<Rule> getRules() {
43-
return rules;
44-
}
45-
46-
/**
47-
* Add a labeling rule to the service
39+
* Evaluate all labeling rules and store the computed rules into thread context
4840
*
49-
* @param rule {@link Rule}
41+
* @param searchRequest {@link SearchRequest}
5042
*/
51-
public void addRule(final Rule rule) {
52-
this.rules.add(rule);
43+
public void applyAllRules(final SearchRequest searchRequest) {
44+
Map<String, Object> labels = rules.stream()
45+
.map(rule -> rule.evaluate(threadPool.getThreadContext(), searchRequest))
46+
.flatMap(m -> m.entrySet().stream())
47+
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (existing, replacement) -> replacement));
48+
String userProvidedTag = getUserProvidedTag(threadPool);
49+
if (labels.containsKey(Task.X_OPAQUE_ID) && userProvidedTag.equals(labels.get(Task.X_OPAQUE_ID))) {
50+
throw new IllegalArgumentException(
51+
String.format(Locale.ROOT, "Unexpected label %s found: %s", Task.X_OPAQUE_ID, userProvidedTag)
52+
);
53+
}
54+
threadPool.getThreadContext().putTransient(RULE_BASED_LABELS, labels);
5355
}
5456

5557
/**
5658
* Get the user provided tag from the X-Opaque-Id header
5759
*
5860
* @return user provided tag
5961
*/
60-
public String getUserProvidedTag() {
62+
public static String getUserProvidedTag(ThreadPool threadPool) {
6163
return threadPool.getThreadContext().getRequestHeadersOnly().getOrDefault(Task.X_OPAQUE_ID, null);
6264
}
6365

64-
/**
65-
* Evaluate all labeling rules and store the computed rules into thread context
66-
*
67-
* @param searchRequest {@link SearchRequest}
68-
*/
69-
public void applyAllRules(final SearchRequest searchRequest) {
70-
Map<String, Object> labels = rules.stream()
71-
.map(rule -> rule.evaluate(threadPool.getThreadContext(), searchRequest))
72-
.flatMap(m -> m.entrySet().stream())
73-
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (existing, replacement) -> replacement));
74-
threadPool.getThreadContext().putTransient(COMPUTED_LABELS, labels);
66+
public static Map<String, Object> getRuleBasedLabels(ThreadPool threadPool) {
67+
return threadPool.getThreadContext().getTransient(RequestLabelingService.RULE_BASED_LABELS);
7568
}
7669
}

server/src/main/java/org/opensearch/search/labels/SearchRequestLabelingListener.java

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
package org.opensearch.search.labels;
1010

11-
import org.opensearch.action.search.SearchPhaseContext;
1211
import org.opensearch.action.search.SearchRequestContext;
1312
import org.opensearch.action.search.SearchRequestOperationsListener;
1413

@@ -29,7 +28,4 @@ public void onRequestStart(SearchRequestContext searchRequestContext) {
2928
// add tags to search request
3029
requestLabelingService.applyAllRules(searchRequestContext.getRequest());
3130
}
32-
33-
@Override
34-
public void onRequestEnd(SearchPhaseContext context, SearchRequestContext searchRequestContext) {}
3531
}

server/src/test/java/org/opensearch/search/labels/RequestLabelingServiceTests.java

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -42,18 +42,10 @@ public void setUpVariables() {
4242
when(threadPool.getThreadContext()).thenReturn(threadContext);
4343
}
4444

45-
public void testAddRule() {
46-
Rule mockRule = mock(Rule.class);
47-
requestLabelingService.addRule(mockRule);
48-
List<Rule> rules = requestLabelingService.getRules();
49-
assertEquals(1, rules.size());
50-
assertEquals(mockRule, rules.get(0));
51-
}
52-
5345
public void testGetUserProvidedTag() {
5446
String expectedTag = "test-tag";
5547
threadContext.setHeaders(new Tuple<>(Collections.singletonMap(Task.X_OPAQUE_ID, expectedTag), new HashMap<>()));
56-
String actualTag = requestLabelingService.getUserProvidedTag();
48+
String actualTag = RequestLabelingService.getUserProvidedTag(threadPool);
5749
assertEquals(expectedTag, actualTag);
5850
}
5951

@@ -63,7 +55,7 @@ public void testBasicApplyAllRules() {
6355
when(mockRule1.evaluate(threadContext, mockSearchRequest)).thenReturn(mockLabelMap);
6456
rules.add(mockRule1);
6557
requestLabelingService.applyAllRules(mockSearchRequest);
66-
Map<String, Object> computedLabels = threadContext.getTransient(RequestLabelingService.COMPUTED_LABELS);
58+
Map<String, Object> computedLabels = threadContext.getTransient(RequestLabelingService.RULE_BASED_LABELS);
6759
assertEquals(1, computedLabels.size());
6860
assertEquals("value1", computedLabels.get("label1"));
6961
}
@@ -77,7 +69,7 @@ public void testApplyAllRulesWithConflict() {
7769
rules.add(mockRule1);
7870
rules.add(mockRule2);
7971
requestLabelingService.applyAllRules(mockSearchRequest);
80-
Map<String, Object> computedLabels = threadContext.getTransient(RequestLabelingService.COMPUTED_LABELS);
72+
Map<String, Object> computedLabels = threadContext.getTransient(RequestLabelingService.RULE_BASED_LABELS);
8173
assertEquals(1, computedLabels.size());
8274
assertEquals("value2", computedLabels.get("conflictingLabel"));
8375
}
@@ -91,7 +83,7 @@ public void testApplyAllRulesWithoutConflict() {
9183
rules.add(mockRule1);
9284
rules.add(mockRule2);
9385
requestLabelingService.applyAllRules(mockSearchRequest);
94-
Map<String, Object> computedLabels = threadContext.getTransient(RequestLabelingService.COMPUTED_LABELS);
86+
Map<String, Object> computedLabels = threadContext.getTransient(RequestLabelingService.RULE_BASED_LABELS);
9587
assertEquals(2, computedLabels.size());
9688
assertEquals("value1", computedLabels.get("label1"));
9789
assertEquals("value2", computedLabels.get("label2"));

0 commit comments

Comments
 (0)