Skip to content

Commit 2f4195c

Browse files
committed
add required ibis api
1 parent 3c275db commit 2f4195c

File tree

4 files changed

+58
-2
lines changed

4 files changed

+58
-2
lines changed

ibis-server/app/mdl/analyzer.py

+13
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,16 @@ def analyze(manifest_str: str, sql: str) -> list[dict]:
1717
return r.json() if r.status_code == httpx.codes.OK else r.raise_for_status()
1818
except httpx.ConnectError as e:
1919
raise ConnectionError(f"Can not connect to Wren Engine: {e}") from e
20+
21+
22+
def analyze_batch(manifest_str: str, sqls: list[str]) -> list[list[dict]]:
23+
try:
24+
r = httpx.request(
25+
method="GET",
26+
url=f"{wren_engine_endpoint}/v2/analysis/sqls",
27+
headers={"Content-Type": "application/json", "Accept": "application/json"},
28+
content=orjson.dumps({"manifestStr": manifest_str, "sqls": sqls}),
29+
)
30+
return r.json() if r.status_code == httpx.codes.OK else r.raise_for_status()
31+
except httpx.ConnectError as e:
32+
raise ConnectionError(f"Can not connect to Wren Engine: {e}") from e

ibis-server/app/model/__init__.py

+5
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,11 @@ class AnalyzeSQLDTO(BaseModel):
134134
sql: str
135135

136136

137+
class AnalyzeSQLBatchDTO(BaseModel):
138+
manifest_str: str = manifest_str_field
139+
sqls: list[str]
140+
141+
137142
class DryPlanDTO(BaseModel):
138143
manifest_str: str = manifest_str_field
139144
sql: str
+8-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from fastapi import APIRouter
22

33
from app.logger import log_dto
4-
from app.mdl.analyzer import analyze
5-
from app.model import AnalyzeSQLDTO
4+
from app.mdl.analyzer import analyze, analyze_batch
5+
from app.model import AnalyzeSQLBatchDTO, AnalyzeSQLDTO
66

77
router = APIRouter(prefix="/analysis", tags=["analysis"])
88

@@ -11,3 +11,9 @@
1111
@log_dto
1212
def analyze_sql(dto: AnalyzeSQLDTO) -> list[dict]:
1313
return analyze(dto.manifest_str, dto.sql)
14+
15+
16+
@router.get("/sqls")
17+
@log_dto
18+
def analyze_sql_batch(dto: AnalyzeSQLBatchDTO) -> list[list[dict]]:
19+
return analyze_batch(dto.manifest_str, dto.sqls)

ibis-server/tests/routers/v2/test_analysis.py

+32
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,25 @@ def test_analysis_sql_order_by():
212212
assert result[0]["sortings"][1]["nodeLocation"] == {"line": 1, "column": 52}
213213

214214

215+
def test_analysis_sqls():
216+
result = get_sql_analysis_batch(
217+
{
218+
"manifestStr": manifest_str,
219+
"sqls": [
220+
"SELECT * FROM customer",
221+
"SELECT custkey, count(*) FROM customer GROUP BY 1",
222+
"WITH t1 AS (SELECT * FROM customer) SELECT * FROM t1",
223+
"SELECT * FROM orders WHERE orderkey = 1 UNION SELECT * FROM orders where orderkey = 2",
224+
],
225+
}
226+
)
227+
assert len(result) == 4
228+
assert len(result[0]) == 1
229+
assert len(result[1]) == 1
230+
assert len(result[2]) == 2
231+
assert len(result[3]) == 2
232+
233+
215234
def get_sql_analysis(input_dto):
216235
response = client.request(
217236
method="GET",
@@ -223,3 +242,16 @@ def get_sql_analysis(input_dto):
223242
)
224243
assert response.status_code == 200
225244
return response.json()
245+
246+
247+
def get_sql_analysis_batch(input_dto):
248+
response = client.request(
249+
method="GET",
250+
url="/v2/analysis/sqls",
251+
json={
252+
"manifestStr": input_dto["manifestStr"],
253+
"sqls": input_dto["sqls"],
254+
},
255+
)
256+
assert response.status_code == 200
257+
return response.json()

0 commit comments

Comments
 (0)