Skip to content

Commit 5230d86

Browse files
committed
Provide the batch SQL analysis API (#663)
* remove unused api * add Deprecated annotation for v1 api * add analysis batch api * add required ibis api * enhance python test * fix test
1 parent b06f733 commit 5230d86

File tree

9 files changed

+185
-63
lines changed

9 files changed

+185
-63
lines changed

ibis-server/app/mdl/analyzer.py

+13
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,16 @@ def analyze(manifest_str: str, sql: str) -> list[dict]:
1717
return r.json() if r.status_code == httpx.codes.OK else r.raise_for_status()
1818
except httpx.ConnectError as e:
1919
raise ConnectionError(f"Can not connect to Wren Engine: {e}") from e
20+
21+
22+
def analyze_batch(manifest_str: str, sqls: list[str]) -> list[list[dict]]:
23+
try:
24+
r = httpx.request(
25+
method="GET",
26+
url=f"{wren_engine_endpoint}/v2/analysis/sqls",
27+
headers={"Content-Type": "application/json", "Accept": "application/json"},
28+
content=orjson.dumps({"manifestStr": manifest_str, "sqls": sqls}),
29+
)
30+
return r.json() if r.status_code == httpx.codes.OK else r.raise_for_status()
31+
except httpx.ConnectError as e:
32+
raise ConnectionError(f"Can not connect to Wren Engine: {e}") from e

ibis-server/app/model/__init__.py

+5
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,11 @@ class AnalyzeSQLDTO(BaseModel):
134134
sql: str
135135

136136

137+
class AnalyzeSQLBatchDTO(BaseModel):
138+
manifest_str: str = manifest_str_field
139+
sqls: list[str]
140+
141+
137142
class DryPlanDTO(BaseModel):
138143
manifest_str: str = manifest_str_field
139144
sql: str
+8-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from fastapi import APIRouter
22

33
from app.logger import log_dto
4-
from app.mdl.analyzer import analyze
5-
from app.model import AnalyzeSQLDTO
4+
from app.mdl.analyzer import analyze, analyze_batch
5+
from app.model import AnalyzeSQLBatchDTO, AnalyzeSQLDTO
66

77
router = APIRouter(prefix="/analysis", tags=["analysis"])
88

@@ -11,3 +11,9 @@
1111
@log_dto
1212
def analyze_sql(dto: AnalyzeSQLDTO) -> list[dict]:
1313
return analyze(dto.manifest_str, dto.sql)
14+
15+
16+
@router.get("/sqls")
17+
@log_dto
18+
def analyze_sql_batch(dto: AnalyzeSQLBatchDTO) -> list[list[dict]]:
19+
return analyze_batch(dto.manifest_str, dto.sqls)

ibis-server/tests/routers/v2/test_analysis.py

+38
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,31 @@ def test_analysis_sql_order_by(manifest_hash):
213213
assert result[0]["sortings"][1]["nodeLocation"] == {"line": 1, "column": 52}
214214

215215

216+
def test_analysis_sqls(manifest_hash):
217+
result = get_sql_analysis_batch(
218+
{
219+
"manifestStr": manifest_hash,
220+
"sqls": [
221+
"SELECT * FROM customer",
222+
"SELECT custkey, count(*) FROM customer GROUP BY 1",
223+
"WITH t1 AS (SELECT * FROM customer) SELECT * FROM t1",
224+
"SELECT * FROM orders WHERE orderkey = 1 UNION SELECT * FROM orders where orderkey = 2",
225+
],
226+
}
227+
)
228+
assert len(result) == 4
229+
assert len(result[0]) == 1
230+
assert result[0][0]["relation"]["tableName"] == "customer"
231+
assert len(result[1]) == 1
232+
assert result[1][0]["relation"]["tableName"] == "customer"
233+
assert len(result[2]) == 2
234+
assert result[2][0]["relation"]["tableName"] == "customer"
235+
assert result[2][1]["relation"]["tableName"] == "t1"
236+
assert len(result[3]) == 2
237+
assert result[3][0]["relation"]["tableName"] == "orders"
238+
assert result[3][1]["relation"]["tableName"] == "orders"
239+
240+
216241
def get_sql_analysis(input_dto):
217242
response = client.request(
218243
method="GET",
@@ -224,3 +249,16 @@ def get_sql_analysis(input_dto):
224249
)
225250
assert response.status_code == 200
226251
return response.json()
252+
253+
254+
def get_sql_analysis_batch(input_dto):
255+
response = client.request(
256+
method="GET",
257+
url="/v2/analysis/sqls",
258+
json={
259+
"manifestStr": input_dto["manifestStr"],
260+
"sqls": input_dto["sqls"],
261+
},
262+
)
263+
assert response.status_code == 200
264+
return response.json()

wren-main/src/main/java/io/wren/main/web/AnalysisResource.java

+1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import static io.wren.main.web.WrenExceptionMapper.bindAsyncResponse;
4444
import static jakarta.ws.rs.core.MediaType.APPLICATION_JSON;
4545

46+
@Deprecated
4647
@Path("/v1/analysis")
4748
public class AnalysisResource
4849
{

wren-main/src/main/java/io/wren/main/web/AnalysisResourceV2.java

+32
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import io.wren.base.WrenMDL;
2121
import io.wren.base.sqlrewrite.analyzer.decisionpoint.DecisionPointAnalyzer;
2222
import io.wren.main.ManifestCacheManager;
23+
import io.wren.main.web.dto.SqlAnalysisInputBatchDto;
2324
import io.wren.main.web.dto.SqlAnalysisInputDtoV2;
2425
import jakarta.ws.rs.Consumes;
2526
import jakarta.ws.rs.GET;
@@ -69,4 +70,35 @@ public void getSqlAnalysis(
6970
})
7071
.whenComplete(bindAsyncResponse(asyncResponse));
7172
}
73+
74+
@GET
75+
@Path("/sqls")
76+
@Consumes(APPLICATION_JSON)
77+
@Produces(APPLICATION_JSON)
78+
public void getSqlAnalysisBatch(
79+
SqlAnalysisInputBatchDto inputBatchDto,
80+
@Suspended AsyncResponse asyncResponse)
81+
{
82+
CompletableFuture
83+
.supplyAsync(() ->
84+
Optional.ofNullable(inputBatchDto.getManifestStr())
85+
.orElseThrow(() -> new IllegalArgumentException("Manifest is required")))
86+
.thenApply(manifestStr -> {
87+
try {
88+
return WrenMDL.fromJson(new String(Base64.getDecoder().decode(manifestStr), UTF_8));
89+
}
90+
catch (IOException e) {
91+
throw new RuntimeException(e);
92+
}
93+
})
94+
.thenApply(mdl ->
95+
inputBatchDto.getSqls().stream().map(sql -> {
96+
Statement statement = parseSql(sql);
97+
return DecisionPointAnalyzer.analyze(
98+
statement,
99+
SessionContext.builder().setCatalog(mdl.getCatalog()).setSchema(mdl.getSchema()).build(),
100+
mdl).stream().map(AnalysisResource::toQueryAnalysisDto).toList();
101+
}).toList())
102+
.whenComplete(bindAsyncResponse(asyncResponse));
103+
}
72104
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
15+
package io.wren.main.web.dto;
16+
17+
import com.fasterxml.jackson.annotation.JsonCreator;
18+
import com.fasterxml.jackson.annotation.JsonProperty;
19+
20+
import java.util.List;
21+
22+
public class SqlAnalysisInputBatchDto
23+
{
24+
private final String manifestStr;
25+
private final List<String> sqls;
26+
27+
@JsonCreator
28+
public SqlAnalysisInputBatchDto(String manifestStr, List<String> sqls)
29+
{
30+
this.manifestStr = manifestStr;
31+
this.sqls = sqls == null ? List.of() : sqls;
32+
}
33+
34+
@JsonProperty
35+
public String getManifestStr()
36+
{
37+
return manifestStr;
38+
}
39+
40+
@JsonProperty
41+
public List<String> getSqls()
42+
{
43+
return sqls;
44+
}
45+
}

wren-tests/src/test/java/io/wren/testing/RequireWrenServer.java

+12-61
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,6 @@
3030
import io.wren.base.dto.Manifest;
3131
import io.wren.main.connector.duckdb.DuckDBMetadata;
3232
import io.wren.main.validation.ValidationResult;
33-
import io.wren.main.web.dto.CheckOutputDto;
34-
import io.wren.main.web.dto.DeployInputDto;
3533
import io.wren.main.web.dto.DryPlanDto;
3634
import io.wren.main.web.dto.DryPlanDtoV2;
3735
import io.wren.main.web.dto.DuckDBQueryDtoV2;
@@ -41,6 +39,7 @@
4139
import io.wren.main.web.dto.PreviewDtoV2;
4240
import io.wren.main.web.dto.QueryAnalysisDto;
4341
import io.wren.main.web.dto.QueryResultDto;
42+
import io.wren.main.web.dto.SqlAnalysisInputBatchDto;
4443
import io.wren.main.web.dto.SqlAnalysisInputDto;
4544
import io.wren.main.web.dto.SqlAnalysisInputDtoV2;
4645
import io.wren.main.web.dto.ValidateDto;
@@ -53,10 +52,6 @@
5352

5453
import java.io.IOException;
5554
import java.util.List;
56-
import java.util.concurrent.CompletableFuture;
57-
import java.util.concurrent.ExecutionException;
58-
import java.util.concurrent.TimeUnit;
59-
import java.util.concurrent.TimeoutException;
6055

6156
import static com.google.common.net.HttpHeaders.CONTENT_TYPE;
6257
import static io.airlift.http.client.JsonBodyGenerator.jsonBodyGenerator;
@@ -81,11 +76,9 @@ public abstract class RequireWrenServer
8176
protected HttpClient client;
8277

8378
private static final JsonCodec<ErrorMessageDto> ERROR_CODEC = jsonCodec(ErrorMessageDto.class);
84-
private static final JsonCodec<CheckOutputDto> CHECK_OUTPUT_DTO_CODEC = jsonCodec(CheckOutputDto.class);
8579
public static final JsonCodec<Manifest> MANIFEST_JSON_CODEC = jsonCodec(Manifest.class);
8680
private static final JsonCodec<PreviewDto> PREVIEW_DTO_CODEC = jsonCodec(PreviewDto.class);
8781
private static final JsonCodec<PreviewDtoV2> PREVIEW_DTO_V2_CODEC = jsonCodec(PreviewDtoV2.class);
88-
private static final JsonCodec<DeployInputDto> DEPLOY_INPUT_DTO_JSON_CODEC = jsonCodec(DeployInputDto.class);
8982
private static final JsonCodec<SqlAnalysisInputDto> SQL_ANALYSIS_INPUT_DTO_CODEC = jsonCodec(SqlAnalysisInputDto.class);
9083
private static final JsonCodec<SqlAnalysisInputDtoV2> SQL_ANALYSIS_INPUT_DTO_V2_CODEC = jsonCodec(SqlAnalysisInputDtoV2.class);
9184
private static final JsonCodec<ConfigManager.ConfigEntry> CONFIG_ENTRY_JSON_CODEC = jsonCodec(ConfigManager.ConfigEntry.class);
@@ -100,6 +93,8 @@ public abstract class RequireWrenServer
10093
private static final JsonCodec<List<QueryAnalysisDto>> QUERY_ANALYSIS_DTO_LIST_CODEC = listJsonCodec(QueryAnalysisDto.class);
10194
private static final JsonCodec<DuckDBQueryDtoV2> DUCKDB_QUERY_DTO_V2_JSON_CODEC = jsonCodec(DuckDBQueryDtoV2.class);
10295
private static final JsonCodec<DuckDBValidationDtoV2> DUCKDB_VALIDATION_DTO_V2_JSON_CODEC = jsonCodec(DuckDBValidationDtoV2.class);
96+
private static final JsonCodec<SqlAnalysisInputBatchDto> SQL_ANALYSIS_INPUT_BATCH_DTO_CODEC = jsonCodec(SqlAnalysisInputBatchDto.class);
97+
private static final JsonCodec<List<List<QueryAnalysisDto>>> QUERY_ANALYSIS_DTO_LIST_LIST_CODEC = listJsonCodec(listJsonCodec(QueryAnalysisDto.class));
10398

10499
public RequireWrenServer() {}
105100

@@ -299,78 +294,34 @@ protected void validateDuckDBV2(String ruleName, DuckDBValidationDtoV2 duckDBVal
299294
}
300295
}
301296

302-
protected void deployMDL(DeployInputDto dto)
303-
{
304-
Request request = preparePost()
305-
.setUri(server().getHttpServerBasedUrl().resolve("/v1/mdl/deploy"))
306-
.setHeader(CONTENT_TYPE, "application/json")
307-
.setBodyGenerator(jsonBodyGenerator(DEPLOY_INPUT_DTO_JSON_CODEC, dto))
308-
.build();
309-
310-
StringResponseHandler.StringResponse response = executeHttpRequest(request, createStringResponseHandler());
311-
if (response.getStatusCode() != 202) {
312-
getWebApplicationException(response);
313-
}
314-
}
315-
316-
protected Manifest getCurrentManifest()
317-
{
318-
Request request = prepareGet()
319-
.setUri(server().getHttpServerBasedUrl().resolve("/v1/mdl"))
320-
.build();
321-
322-
StringResponseHandler.StringResponse response = executeHttpRequest(request, createStringResponseHandler());
323-
if (response.getStatusCode() != 200) {
324-
getWebApplicationException(response);
325-
}
326-
return MANIFEST_JSON_CODEC.fromJson(response.getBody());
327-
}
328-
329-
protected CheckOutputDto getDeployStatus()
297+
protected List<QueryAnalysisDto> getSqlAnalysis(SqlAnalysisInputDto inputDto)
330298
{
331299
Request request = prepareGet()
332-
.setUri(server().getHttpServerBasedUrl().resolve("/v1/mdl/status"))
300+
.setUri(server().getHttpServerBasedUrl().resolve("/v1/analysis/sql"))
301+
.setHeader(CONTENT_TYPE, "application/json")
302+
.setBodyGenerator(jsonBodyGenerator(SQL_ANALYSIS_INPUT_DTO_CODEC, inputDto))
333303
.build();
334304

335305
StringResponseHandler.StringResponse response = executeHttpRequest(request, createStringResponseHandler());
336306
if (response.getStatusCode() != 200) {
337307
getWebApplicationException(response);
338308
}
339-
return CHECK_OUTPUT_DTO_CODEC.fromJson(response.getBody());
340-
}
341-
342-
protected void waitUntilReady()
343-
throws ExecutionException, InterruptedException, TimeoutException
344-
{
345-
CompletableFuture.runAsync(() -> {
346-
while (true) {
347-
CheckOutputDto checkOutputDto = getDeployStatus();
348-
if (checkOutputDto.getStatus() == CheckOutputDto.Status.READY) {
349-
break;
350-
}
351-
try {
352-
Thread.sleep(1000);
353-
}
354-
catch (InterruptedException e) {
355-
throw new AssertionError("Status doesn't change to READY", e);
356-
}
357-
}
358-
}).get(60, TimeUnit.SECONDS);
309+
return QUERY_ANALYSIS_DTO_LIST_CODEC.fromJson(response.getBody());
359310
}
360311

361-
protected List<QueryAnalysisDto> getSqlAnalysis(SqlAnalysisInputDto inputDto)
312+
protected List<List<QueryAnalysisDto>> getSqlAnalysisBatch(SqlAnalysisInputBatchDto inputBatchDto)
362313
{
363314
Request request = prepareGet()
364-
.setUri(server().getHttpServerBasedUrl().resolve("/v1/analysis/sql"))
315+
.setUri(server().getHttpServerBasedUrl().resolve("/v2/analysis/sqls"))
365316
.setHeader(CONTENT_TYPE, "application/json")
366-
.setBodyGenerator(jsonBodyGenerator(SQL_ANALYSIS_INPUT_DTO_CODEC, inputDto))
317+
.setBodyGenerator(jsonBodyGenerator(SQL_ANALYSIS_INPUT_BATCH_DTO_CODEC, inputBatchDto))
367318
.build();
368319

369320
StringResponseHandler.StringResponse response = executeHttpRequest(request, createStringResponseHandler());
370321
if (response.getStatusCode() != 200) {
371322
getWebApplicationException(response);
372323
}
373-
return QUERY_ANALYSIS_DTO_LIST_CODEC.fromJson(response.getBody());
324+
return QUERY_ANALYSIS_DTO_LIST_LIST_CODEC.fromJson(response.getBody());
374325
}
375326

376327
protected List<QueryAnalysisDto> getSqlAnalysisV2(SqlAnalysisInputDtoV2 inputDto)

wren-tests/src/test/java/io/wren/testing/TestAnalysisResource.java

+31
Original file line numberDiff line numberDiff line change
@@ -22,19 +22,22 @@
2222
import io.wren.base.sqlrewrite.analyzer.decisionpoint.FilterAnalysis;
2323
import io.wren.base.sqlrewrite.analyzer.decisionpoint.RelationAnalysis;
2424
import io.wren.main.web.dto.QueryAnalysisDto;
25+
import io.wren.main.web.dto.SqlAnalysisInputBatchDto;
2526
import io.wren.main.web.dto.SqlAnalysisInputDto;
2627
import org.testng.annotations.Test;
2728

2829
import java.io.IOException;
2930
import java.nio.file.Files;
3031
import java.nio.file.Path;
32+
import java.util.Base64;
3133
import java.util.List;
3234
import java.util.Set;
3335

3436
import static io.wren.base.dto.Model.onTableReference;
3537
import static io.wren.base.dto.TableReference.tableReference;
3638
import static io.wren.main.web.dto.NodeLocationDto.nodeLocationDto;
3739
import static io.wren.testing.AbstractTestFramework.DEFAULT_SESSION_CONTEXT;
40+
import static java.nio.charset.StandardCharsets.UTF_8;
3841
import static org.assertj.core.api.Assertions.assertThat;
3942

4043
public class TestAnalysisResource
@@ -184,4 +187,32 @@ public void testBasic()
184187
assertThat(result.get(0).getSortings().get(1).getOrdering()).isEqualTo(SortItem.Ordering.DESCENDING.name());
185188
assertThat(result.get(0).getSortings().get(1).getNodeLocation()).isEqualTo(nodeLocationDto(1, 52));
186189
}
190+
191+
@Test
192+
public void testBatchAnalysis()
193+
{
194+
SqlAnalysisInputBatchDto inputBatchDto = new SqlAnalysisInputBatchDto(
195+
base64Encode(toJson(manifest)),
196+
List.of("select * from customer",
197+
"select custkey, count(*) from customer group by 1",
198+
"with t1 as (select * from customer) select * from t1",
199+
"select * from orders where orderstatus = 'O' union select * from orders where orderstatus = 'F'"));
200+
201+
List<List<QueryAnalysisDto>> results = getSqlAnalysisBatch(inputBatchDto);
202+
assertThat(results.size()).isEqualTo(4);
203+
assertThat(results.get(0).size()).isEqualTo(1);
204+
assertThat(results.get(1).size()).isEqualTo(1);
205+
assertThat(results.get(2).size()).isEqualTo(2);
206+
assertThat(results.get(3).size()).isEqualTo(2);
207+
}
208+
209+
private String toJson(Manifest manifest)
210+
{
211+
return MANIFEST_JSON_CODEC.toJson(manifest);
212+
}
213+
214+
private String base64Encode(String str)
215+
{
216+
return Base64.getEncoder().encodeToString(str.getBytes(UTF_8));
217+
}
187218
}

0 commit comments

Comments
 (0)