Skip to content

Commit ce15448

Browse files
authored
Schedule request in worker thread (#748)
Signed-off-by: penghuo <[email protected]>
1 parent 23a4a88 commit ce15448

File tree

4 files changed

+130
-45
lines changed

4 files changed

+130
-45
lines changed

legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlAction.java

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import static org.opensearch.rest.RestStatus.BAD_REQUEST;
1010
import static org.opensearch.rest.RestStatus.OK;
1111
import static org.opensearch.rest.RestStatus.SERVICE_UNAVAILABLE;
12+
import static org.opensearch.sql.opensearch.executor.Scheduler.schedule;
1213

1314
import com.alibaba.druid.sql.parser.ParserException;
1415
import com.google.common.collect.ImmutableList;
@@ -147,19 +148,27 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
147148

148149
Format format = SqlRequestParam.getFormat(request.params());
149150

150-
// Route request to new query engine if it's supported already
151-
SQLQueryRequest newSqlRequest = new SQLQueryRequest(sqlRequest.getJsonContent(),
152-
sqlRequest.getSql(), request.path(), request.params());
153-
RestChannelConsumer result = newSqlQueryHandler.prepareRequest(newSqlRequest, client);
154-
if (result != RestSQLQueryAction.NOT_SUPPORTED_YET) {
155-
LOG.info("[{}] Request is handled by new SQL query engine", QueryContext.getRequestId());
156-
return result;
157-
}
158-
LOG.debug("[{}] Request {} is not supported and falling back to old SQL engine",
159-
QueryContext.getRequestId(), newSqlRequest);
160-
161-
final QueryAction queryAction = explainRequest(client, sqlRequest, format);
162-
return channel -> executeSqlRequest(request, queryAction, client, channel);
151+
return channel -> schedule(client, () -> {
152+
try {
153+
// Route request to new query engine if it's supported already
154+
SQLQueryRequest newSqlRequest = new SQLQueryRequest(sqlRequest.getJsonContent(),
155+
sqlRequest.getSql(), request.path(), request.params());
156+
RestChannelConsumer result = newSqlQueryHandler.prepareRequest(newSqlRequest, client);
157+
if (result != RestSQLQueryAction.NOT_SUPPORTED_YET) {
158+
LOG.info("[{}] Request is handled by new SQL query engine",
159+
QueryContext.getRequestId());
160+
result.accept(channel);
161+
} else {
162+
LOG.debug("[{}] Request {} is not supported and falling back to old SQL engine",
163+
QueryContext.getRequestId(), newSqlRequest);
164+
QueryAction queryAction = explainRequest(client, sqlRequest, format);
165+
executeSqlRequest(request, queryAction, client, channel);
166+
}
167+
} catch (Exception e) {
168+
logAndPublishMetrics(e);
169+
reportError(channel, e, isClientError(e) ? BAD_REQUEST : SERVICE_UNAVAILABLE);
170+
}
171+
});
163172
} catch (Exception e) {
164173
logAndPublishMetrics(e);
165174
return channel -> reportError(channel, e, isClientError(e) ? BAD_REQUEST : SERVICE_UNAVAILABLE);
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.sql.opensearch.executor;
7+
8+
import java.util.Map;
9+
import lombok.experimental.UtilityClass;
10+
import org.apache.logging.log4j.ThreadContext;
11+
import org.opensearch.client.node.NodeClient;
12+
import org.opensearch.common.unit.TimeValue;
13+
import org.opensearch.threadpool.ThreadPool;
14+
15+
/** The scheduler which schedule the task run in sql-worker thread pool. */
16+
@UtilityClass
17+
public class Scheduler {
18+
19+
public static final String SQL_WORKER_THREAD_POOL_NAME = "sql-worker";
20+
21+
public static void schedule(NodeClient client, Runnable task) {
22+
ThreadPool threadPool = client.threadPool();
23+
threadPool.schedule(withCurrentContext(task), new TimeValue(0), SQL_WORKER_THREAD_POOL_NAME);
24+
}
25+
26+
private static Runnable withCurrentContext(final Runnable task) {
27+
final Map<String, String> currentContext = ThreadContext.getImmutableContext();
28+
return () -> {
29+
ThreadContext.putAll(currentContext);
30+
task.run();
31+
};
32+
}
33+
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.sql.opensearch.executor;
7+
8+
import static org.junit.Assert.assertTrue;
9+
import static org.mockito.ArgumentMatchers.any;
10+
import static org.mockito.Mockito.doAnswer;
11+
import static org.mockito.Mockito.mock;
12+
import static org.mockito.Mockito.when;
13+
14+
import java.util.concurrent.atomic.AtomicBoolean;
15+
import org.junit.jupiter.api.Test;
16+
import org.junit.jupiter.api.extension.ExtendWith;
17+
import org.mockito.junit.jupiter.MockitoExtension;
18+
import org.opensearch.client.node.NodeClient;
19+
import org.opensearch.threadpool.ThreadPool;
20+
21+
@ExtendWith(MockitoExtension.class)
22+
class SchedulerTest {
23+
@Test
24+
public void schedule() {
25+
NodeClient nodeClient = mock(NodeClient.class);
26+
ThreadPool threadPool = mock(ThreadPool.class);
27+
when(nodeClient.threadPool()).thenReturn(threadPool);
28+
29+
doAnswer(
30+
invocation -> {
31+
Runnable task = invocation.getArgument(0);
32+
task.run();
33+
return null;
34+
})
35+
.when(threadPool)
36+
.schedule(any(), any(), any());
37+
AtomicBoolean isRun = new AtomicBoolean(false);
38+
Scheduler.schedule(nodeClient, () -> isRun.set(true));
39+
assertTrue(isRun.get());
40+
}
41+
}

plugin/src/main/java/org/opensearch/sql/plugin/rest/RestPPLQueryAction.java

Lines changed: 34 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import static org.opensearch.rest.RestStatus.INTERNAL_SERVER_ERROR;
1010
import static org.opensearch.rest.RestStatus.OK;
1111
import static org.opensearch.rest.RestStatus.SERVICE_UNAVAILABLE;
12+
import static org.opensearch.sql.opensearch.executor.Scheduler.schedule;
1213

1314
import com.google.common.collect.ImmutableList;
1415
import java.util.Arrays;
@@ -112,42 +113,43 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient nod
112113
PPLQueryRequestFactory.getPPLRequest(request)
113114
);
114115

115-
return channel ->
116-
nodeClient.execute(
117-
PPLQueryAction.INSTANCE,
118-
transportPPLQueryRequest,
119-
new ActionListener<>() {
120-
@Override
121-
public void onResponse(TransportPPLQueryResponse response) {
122-
sendResponse(channel, OK, response.getResult());
123-
}
124-
125-
@Override
126-
public void onFailure(Exception e) {
127-
if (transportPPLQueryRequest.isExplainRequest()) {
128-
LOG.error("Error happened during explain", e);
129-
sendResponse(
130-
channel,
131-
INTERNAL_SERVER_ERROR,
132-
"Failed to explain the query due to error: " + e.getMessage());
133-
} else if (e instanceof IllegalAccessException) {
116+
return channel -> schedule(nodeClient, () ->
117+
nodeClient.execute(
118+
PPLQueryAction.INSTANCE,
119+
transportPPLQueryRequest,
120+
new ActionListener<>() {
121+
@Override
122+
public void onResponse(TransportPPLQueryResponse response) {
123+
sendResponse(channel, OK, response.getResult());
124+
}
125+
126+
@Override
127+
public void onFailure(Exception e) {
128+
if (transportPPLQueryRequest.isExplainRequest()) {
129+
LOG.error("Error happened during explain", e);
130+
sendResponse(
131+
channel,
132+
INTERNAL_SERVER_ERROR,
133+
"Failed to explain the query due to error: " + e.getMessage());
134+
} else if (e instanceof IllegalAccessException) {
135+
reportError(channel, e, BAD_REQUEST);
136+
} else {
137+
LOG.error("Error happened during query handling", e);
138+
if (isClientError(e)) {
139+
Metrics.getInstance()
140+
.getNumericalMetric(MetricName.PPL_FAILED_REQ_COUNT_CUS)
141+
.increment();
134142
reportError(channel, e, BAD_REQUEST);
135143
} else {
136-
LOG.error("Error happened during query handling", e);
137-
if (isClientError(e)) {
138-
Metrics.getInstance()
139-
.getNumericalMetric(MetricName.PPL_FAILED_REQ_COUNT_CUS)
140-
.increment();
141-
reportError(channel, e, BAD_REQUEST);
142-
} else {
143-
Metrics.getInstance()
144-
.getNumericalMetric(MetricName.PPL_FAILED_REQ_COUNT_SYS)
145-
.increment();
146-
reportError(channel, e, SERVICE_UNAVAILABLE);
147-
}
144+
Metrics.getInstance()
145+
.getNumericalMetric(MetricName.PPL_FAILED_REQ_COUNT_SYS)
146+
.increment();
147+
reportError(channel, e, SERVICE_UNAVAILABLE);
148148
}
149149
}
150-
});
150+
}
151+
}
152+
));
151153
}
152154

153155
private void sendResponse(RestChannel channel, RestStatus status, String content) {

0 commit comments

Comments
 (0)