Skip to content

Commit 63f6b33

Browse files
joshuali925goyamegh
andcommitted
Add direct-query module for opensearch based implementation
Co-authored-by: Megha Goyal <[email protected]> Signed-off-by: Joshua Li <[email protected]>
1 parent ac6059f commit 63f6b33

22 files changed

+2439
-3
lines changed

direct-query/build.gradle

+106
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
plugins {
7+
id 'java-library'
8+
id "io.freefair.lombok"
9+
id 'jacoco'
10+
}
11+
12+
repositories {
13+
mavenCentral()
14+
}
15+
16+
dependencies {
17+
api project(':core')
18+
api project(':direct-query-core')
19+
implementation project(':protocol')
20+
implementation project(':opensearch')
21+
implementation project(':datasources')
22+
implementation project(':direct-query-core')
23+
implementation project(':async-query-core')
24+
25+
// Common dependencies
26+
implementation group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}"
27+
implementation group: 'org.json', name: 'json', version: '20231013'
28+
implementation group: 'commons-io', name: 'commons-io', version: "${commons_io_version}"
29+
30+
// Test dependencies
31+
testImplementation(platform("org.junit:junit-bom:5.9.3"))
32+
testImplementation 'org.junit.jupiter:junit-jupiter-api:5.9.3'
33+
testImplementation group: 'org.mockito', name: 'mockito-core', version: "${mockito_version}"
34+
testImplementation group: 'org.mockito', name: 'mockito-junit-jupiter', version: "${mockito_version}"
35+
36+
testCompileOnly('junit:junit:4.13.1') {
37+
exclude group: 'org.hamcrest', module: 'hamcrest-core'
38+
}
39+
testRuntimeOnly("org.junit.vintage:junit-vintage-engine") {
40+
exclude group: 'org.hamcrest', module: 'hamcrest-core'
41+
}
42+
testRuntimeOnly("org.junit.jupiter:junit-jupiter-engine") {
43+
exclude group: 'org.hamcrest', module: 'hamcrest-core'
44+
}
45+
testImplementation("org.opensearch.test:framework:${opensearch_version}")
46+
}
47+
48+
test {
49+
useJUnitPlatform()
50+
testLogging {
51+
events "failed"
52+
exceptionFormat "full"
53+
}
54+
}
55+
task junit4(type: Test) {
56+
useJUnitPlatform {
57+
includeEngines("junit-vintage")
58+
}
59+
systemProperty 'tests.security.manager', 'false'
60+
testLogging {
61+
events "failed"
62+
exceptionFormat "full"
63+
}
64+
}
65+
66+
jacocoTestReport {
67+
dependsOn test, junit4
68+
executionData test, junit4
69+
reports {
70+
html.required = true
71+
xml.required = true
72+
}
73+
afterEvaluate {
74+
classDirectories.setFrom(files(classDirectories.files.collect {
75+
fileTree(dir: it)
76+
}))
77+
}
78+
}
79+
80+
jacocoTestCoverageVerification {
81+
dependsOn test, junit4
82+
executionData test, junit4
83+
violationRules {
84+
rule {
85+
element = 'CLASS'
86+
excludes = [
87+
'org.opensearch.sql.directquery.transport.model.*'
88+
]
89+
limit {
90+
counter = 'LINE'
91+
minimum = 1.0
92+
}
93+
limit {
94+
counter = 'BRANCH'
95+
minimum = 1.0
96+
}
97+
}
98+
}
99+
afterEvaluate {
100+
classDirectories.setFrom(files(classDirectories.files.collect {
101+
fileTree(dir: it)
102+
}))
103+
}
104+
}
105+
check.dependsOn jacocoTestCoverageVerification
106+
jacocoTestCoverageVerification.dependsOn jacocoTestReport
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.sql.directquery.rest;
7+
8+
import static org.opensearch.core.rest.RestStatus.BAD_REQUEST;
9+
import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR;
10+
import static org.opensearch.rest.RestRequest.Method.POST;
11+
12+
import com.fasterxml.jackson.databind.ObjectMapper;
13+
import com.google.common.collect.ImmutableList;
14+
import java.util.List;
15+
import java.util.Map;
16+
import java.util.Objects;
17+
import lombok.RequiredArgsConstructor;
18+
import org.apache.logging.log4j.LogManager;
19+
import org.apache.logging.log4j.Logger;
20+
import org.opensearch.OpenSearchException;
21+
import org.opensearch.core.action.ActionListener;
22+
import org.opensearch.core.rest.RestStatus;
23+
import org.opensearch.rest.BaseRestHandler;
24+
import org.opensearch.rest.BytesRestResponse;
25+
import org.opensearch.rest.RestChannel;
26+
import org.opensearch.rest.RestRequest;
27+
import org.opensearch.sql.common.setting.Settings;
28+
import org.opensearch.sql.datasource.client.exceptions.DataSourceClientException;
29+
import org.opensearch.sql.datasources.exceptions.ErrorMessage;
30+
import org.opensearch.sql.datasources.utils.Scheduler;
31+
import org.opensearch.sql.directquery.rest.model.ExecuteDirectQueryRequest;
32+
import org.opensearch.sql.directquery.transport.TransportExecuteDirectQueryRequestAction;
33+
import org.opensearch.sql.directquery.transport.format.DirectQueryRequestConverter;
34+
import org.opensearch.sql.directquery.transport.model.ExecuteDirectQueryActionRequest;
35+
import org.opensearch.sql.directquery.transport.model.ExecuteDirectQueryActionResponse;
36+
import org.opensearch.sql.directquery.validator.DirectQueryRequestValidator;
37+
import org.opensearch.sql.opensearch.setting.OpenSearchSettings;
38+
import org.opensearch.sql.opensearch.util.RestRequestUtil;
39+
import org.opensearch.sql.protocol.response.format.JsonResponseFormatter;
40+
import org.opensearch.transport.client.node.NodeClient;
41+
42+
@RequiredArgsConstructor
43+
public class RestDirectQueryManagementAction extends BaseRestHandler {
44+
45+
public static final String DIRECT_QUERY_ACTIONS = "direct_query_actions";
46+
public static final String BASE_DIRECT_QUERY_ACTION_URL =
47+
"/_plugins/_directquery/_query/{dataSources}";
48+
49+
private static final Logger LOG = LogManager.getLogger(RestDirectQueryManagementAction.class);
50+
private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
51+
private final OpenSearchSettings settings;
52+
53+
@Override
54+
public String getName() {
55+
return DIRECT_QUERY_ACTIONS;
56+
}
57+
58+
@Override
59+
public List<Route> routes() {
60+
return ImmutableList.of(new Route(POST, BASE_DIRECT_QUERY_ACTION_URL));
61+
}
62+
63+
@Override
64+
protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient nodeClient) {
65+
// This line consumes the dataSources parameter from the path
66+
String dataSources = restRequest.param("dataSources");
67+
68+
// Also consume all other request parameters to prevent similar errors
69+
RestRequestUtil.consumeAllRequestParameters(restRequest);
70+
71+
if (!dataSourcesEnabled()) {
72+
return dataSourcesDisabledError(restRequest);
73+
}
74+
75+
if (Objects.requireNonNull(restRequest.method()) == POST) {
76+
return executeDirectQueryRequest(restRequest, nodeClient, dataSources);
77+
}
78+
return restChannel ->
79+
restChannel.sendResponse(
80+
new BytesRestResponse(
81+
RestStatus.METHOD_NOT_ALLOWED, String.valueOf(restRequest.method())));
82+
}
83+
84+
private RestChannelConsumer executeDirectQueryRequest(
85+
RestRequest restRequest, NodeClient nodeClient, String dataSources) {
86+
return restChannel -> {
87+
try {
88+
ExecuteDirectQueryRequest directQueryRequest =
89+
DirectQueryRequestConverter.fromXContentParser(restRequest.contentParser());
90+
91+
// If the datasource is not specified in the payload, use the path parameter
92+
if (directQueryRequest.getDataSources() == null) {
93+
directQueryRequest.setDataSources(dataSources);
94+
}
95+
96+
// Generate a session ID if one is not provided in the request
97+
if (directQueryRequest.getSessionId() == null) {
98+
directQueryRequest.setSessionId(java.util.UUID.randomUUID().toString());
99+
}
100+
101+
// Validate request using the dedicated validator
102+
DirectQueryRequestValidator.validateRequest(directQueryRequest);
103+
104+
Scheduler.schedule(
105+
nodeClient,
106+
() ->
107+
nodeClient.execute(
108+
TransportExecuteDirectQueryRequestAction.ACTION_TYPE,
109+
new ExecuteDirectQueryActionRequest(directQueryRequest),
110+
new ActionListener<>() {
111+
@Override
112+
public void onResponse(ExecuteDirectQueryActionResponse response) {
113+
// Format the response here at the REST layer using JsonResponseFormatter
114+
try {
115+
String formattedResponse = formatDirectQueryResponse(response);
116+
restChannel.sendResponse(
117+
new BytesRestResponse(
118+
RestStatus.OK,
119+
"application/json; charset=UTF-8",
120+
formattedResponse));
121+
} catch (Exception e) {
122+
handleException(e, restChannel, restRequest.method());
123+
}
124+
}
125+
126+
@Override
127+
public void onFailure(Exception e) {
128+
handleException(e, restChannel, restRequest.method());
129+
}
130+
}));
131+
} catch (Exception e) {
132+
handleException(e, restChannel, restRequest.method());
133+
}
134+
};
135+
}
136+
137+
/** Format the direct query response using JsonResponseFormatter */
138+
private String formatDirectQueryResponse(ExecuteDirectQueryActionResponse response) {
139+
try {
140+
// Create a formatter that converts the response to a pretty JSON format
141+
return new JsonResponseFormatter<ExecuteDirectQueryActionResponse>(
142+
JsonResponseFormatter.Style.PRETTY) {
143+
@Override
144+
protected Object buildJsonObject(ExecuteDirectQueryActionResponse response) {
145+
// Create a response object with the fields we want to expose
146+
return new DirectQueryResult(
147+
response.getQueryId(), response.getResults(), response.getSessionId());
148+
}
149+
}.format(response);
150+
} catch (Exception e) {
151+
LOG.error("Error formatting direct query response", e);
152+
return "{\"error\": \"" + e.getMessage() + "\"}";
153+
}
154+
}
155+
156+
/** Simple class to represent the formatted response */
157+
private static class DirectQueryResult {
158+
private final String queryId;
159+
private final Map<String, Object> results;
160+
private final String sessionId;
161+
162+
public DirectQueryResult(String queryId, Map<String, ?> results, String sessionId) {
163+
this.queryId = queryId;
164+
this.results = (Map<String, Object>) results;
165+
this.sessionId = sessionId;
166+
}
167+
168+
public String getQueryId() {
169+
return queryId;
170+
}
171+
172+
public Map<String, Object> getResults() {
173+
return results;
174+
}
175+
176+
public String getSessionId() {
177+
return sessionId;
178+
}
179+
}
180+
181+
private void handleException(
182+
Exception e, RestChannel restChannel, RestRequest.Method requestMethod) {
183+
if (e instanceof OpenSearchException) {
184+
OpenSearchException exception = (OpenSearchException) e;
185+
reportError(restChannel, exception, exception.status());
186+
} else {
187+
LOG.error("Error happened during request handling", e);
188+
if (isClientError(e)) {
189+
reportError(restChannel, e, BAD_REQUEST);
190+
} else {
191+
reportError(restChannel, e, INTERNAL_SERVER_ERROR);
192+
}
193+
}
194+
}
195+
196+
private void reportError(final RestChannel channel, final Exception e, final RestStatus status) {
197+
channel.sendResponse(
198+
new BytesRestResponse(status, new ErrorMessage(e, status.getStatus()).toString()));
199+
}
200+
201+
private static boolean isClientError(Exception e) {
202+
return e instanceof IllegalArgumentException
203+
|| e instanceof IllegalStateException
204+
|| e instanceof DataSourceClientException
205+
|| e instanceof IllegalAccessException;
206+
}
207+
208+
private boolean dataSourcesEnabled() {
209+
return settings.getSettingValue(Settings.Key.DATASOURCES_ENABLED);
210+
}
211+
212+
private RestChannelConsumer dataSourcesDisabledError(RestRequest request) {
213+
RestRequestUtil.consumeAllRequestParameters(request);
214+
215+
return channel -> {
216+
reportError(
217+
channel,
218+
new IllegalAccessException(
219+
String.format("%s setting is false", Settings.Key.DATASOURCES_ENABLED.getKeyValue())),
220+
BAD_REQUEST);
221+
};
222+
}
223+
}

0 commit comments

Comments
 (0)