Skip to content

Commit e7b3343

Browse files
authored
feat(ibis): pushdown the limit of the query request into SQL (#1082)
1 parent 8db5987 commit e7b3343

File tree

10 files changed

+158
-6
lines changed

10 files changed

+158
-6
lines changed

ibis-server/app/mdl/rewriter.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ async def rewrite(self, manifest_str: str, sql: str) -> str:
105105

106106
@staticmethod
107107
def handle_extract_exception(e: Exception):
108-
logger.error("Error when extracting manifest: {}", e)
108+
logger.warning("Error when extracting manifest: {}", e)
109109

110110

111111
class EmbeddedEngineRewriter:

ibis-server/app/routers/v2/connector.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from fastapi import APIRouter, Depends, Header, Query, Request, Response
44
from fastapi.responses import ORJSONResponse
5+
from loguru import logger
56
from opentelemetry import trace
67

78
from app.dependencies import verify_query_dto
@@ -19,7 +20,7 @@
1920
from app.model.metadata.dto import Constraint, MetadataDTO, Table
2021
from app.model.metadata.factory import MetadataFactory
2122
from app.model.validator import Validator
22-
from app.util import build_context, to_json
23+
from app.util import build_context, pushdown_limit, to_json
2324

2425
router = APIRouter(prefix="/connector")
2526
tracer = trace.get_tracer(__name__)
@@ -44,11 +45,17 @@ async def query(
4445
with tracer.start_as_current_span(
4546
name=span_name, kind=trace.SpanKind.SERVER, context=build_context(headers)
4647
):
48+
try:
49+
sql = pushdown_limit(dto.sql, limit)
50+
except Exception as e:
51+
logger.warning("Failed to pushdown limit. Using original SQL: {}", e)
52+
sql = dto.sql
53+
4754
rewritten_sql = await Rewriter(
4855
dto.manifest_str,
4956
data_source=data_source,
5057
java_engine_connector=java_engine_connector,
51-
).rewrite(dto.sql)
58+
).rewrite(sql)
5259
connector = Connector(data_source, dto.connection_info)
5360
if dry_run:
5461
connector.dry_run(rewritten_sql)

ibis-server/app/routers/v3/connector.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from fastapi import APIRouter, Depends, Header, Query, Response
44
from fastapi.responses import ORJSONResponse
5+
from loguru import logger
56
from opentelemetry import trace
67

78
from app.config import get_config
@@ -18,7 +19,7 @@
1819
from app.model.connector import Connector
1920
from app.model.data_source import DataSource
2021
from app.model.validator import Validator
21-
from app.util import build_context, to_json
22+
from app.util import build_context, pushdown_limit, to_json
2223

2324
router = APIRouter(prefix="/connector")
2425
tracer = trace.get_tracer(__name__)
@@ -38,9 +39,15 @@ async def query(
3839
with tracer.start_as_current_span(
3940
name=span_name, kind=trace.SpanKind.SERVER, context=build_context(headers)
4041
):
42+
try:
43+
sql = pushdown_limit(dto.sql, limit)
44+
except Exception:
45+
logger.warning("Failed to pushdown limit. Using original SQL")
46+
sql = dto.sql
47+
4148
rewritten_sql = await Rewriter(
4249
dto.manifest_str, data_source=data_source, experiment=True
43-
).rewrite(dto.sql)
50+
).rewrite(sql)
4451
connector = Connector(data_source, dto.connection_info)
4552
if dry_run:
4653
connector.dry_run(rewritten_sql)

ibis-server/app/util.py

+7
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import orjson
66
import pandas as pd
7+
import wren_core
78
from fastapi import Header
89
from opentelemetry import trace
910
from opentelemetry.context import Context
@@ -98,3 +99,9 @@ def build_context(headers: Header) -> Context:
9899
if headers is None:
99100
return None
100101
return extract(headers)
102+
103+
104+
@tracer.start_as_current_span("pushdown_limit", kind=trace.SpanKind.INTERNAL)
105+
def pushdown_limit(sql: str, limit: int | None) -> str:
106+
ctx = wren_core.SessionContext()
107+
return ctx.pushdown_limit(sql, limit)

ibis-server/tests/routers/v2/connector/test_postgres.py

+30
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,36 @@ async def test_format_floating(client, manifest_str, postgres):
285285
assert result["data"][0][26] == "12300.00000"
286286

287287

288+
async def test_limit_pushdown(client, manifest_str, postgres: PostgresContainer):
289+
connection_info = _to_connection_info(postgres)
290+
response = await client.post(
291+
url=f"{base_url}/query",
292+
params={"limit": 10},
293+
json={
294+
"connectionInfo": connection_info,
295+
"manifestStr": manifest_str,
296+
"sql": 'SELECT * FROM "Orders" LIMIT 100',
297+
},
298+
)
299+
assert response.status_code == 200
300+
result = response.json()
301+
assert len(result["data"]) == 10
302+
303+
connection_info = _to_connection_info(postgres)
304+
response = await client.post(
305+
url=f"{base_url}/query",
306+
params={"limit": 10},
307+
json={
308+
"connectionInfo": connection_info,
309+
"manifestStr": manifest_str,
310+
"sql": 'SELECT count(*) FILTER (where orderkey > 10) FROM "Orders" LIMIT 100',
311+
},
312+
)
313+
assert response.status_code == 200
314+
result = response.json()
315+
assert len(result["data"]) == 1
316+
317+
288318
async def test_dry_run_with_connection_url_and_password_with_bracket_should_not_raise_value_error(
289319
client, manifest_str, postgres: PostgresContainer
290320
):

ibis-server/tests/routers/v3/connector/postgres/test_query.py

+15
Original file line numberDiff line numberDiff line change
@@ -338,3 +338,18 @@ async def test_query_with_keyword_filter(client, manifest_str, connection_info):
338338
)
339339
assert response.status_code == 200
340340
assert response.text is not None
341+
342+
343+
async def test_limit_pushdown(client, manifest_str, connection_info):
344+
response = await client.post(
345+
url=f"{base_url}/query",
346+
params={"limit": 10},
347+
json={
348+
"connectionInfo": connection_info,
349+
"manifestStr": manifest_str,
350+
"sql": "SELECT * FROM orders LIMIT 100",
351+
},
352+
)
353+
assert response.status_code == 200
354+
result = response.json()
355+
assert len(result["data"]) == 10

wren-core-py/src/context.rs

+40-1
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,17 @@ use pyo3::{pyclass, pymethods, PyErr, PyResult};
2323
use std::collections::hash_map::Entry;
2424
use std::collections::HashMap;
2525
use std::hash::Hash;
26+
use std::ops::ControlFlow;
2627
use std::sync::Arc;
28+
use wren_core::ast::{visit_statements_mut, Expr, Statement, Value};
29+
use wren_core::dialect::GenericDialect;
2730
use wren_core::logical_plan::utils::map_data_type;
2831
use wren_core::mdl::context::create_ctx_with_mdl;
2932
use wren_core::mdl::function::{
3033
ByPassAggregateUDF, ByPassScalarUDF, ByPassWindowFunction, FunctionType,
3134
RemoteFunction,
3235
};
3336
use wren_core::{mdl, AggregateUDF, AnalyzedWrenMDL, ScalarUDF, WindowUDF};
34-
3537
/// The Python wrapper for the Wren Core session context.
3638
#[pyclass(name = "SessionContext")]
3739
#[derive(Clone)]
@@ -188,6 +190,43 @@ impl PySessionContext {
188190
});
189191
Ok(builder.values().cloned().collect())
190192
}
193+
194+
/// Push down the limit to the given SQL.
195+
/// If the limit is None, the SQL will be returned as is.
196+
/// If the limit is greater than the pushdown limit, the limit will be replaced with the pushdown limit.
197+
/// Otherwise, the limit will be kept as is.
198+
#[pyo3(signature = (sql, limit=None))]
199+
pub fn pushdown_limit(&self, sql: &str, limit: Option<usize>) -> PyResult<String> {
200+
if limit.is_none() {
201+
return Ok(sql.to_string());
202+
}
203+
let pushdown = limit.unwrap();
204+
let mut statements =
205+
wren_core::parser::Parser::parse_sql(&GenericDialect {}, sql)
206+
.map_err(CoreError::from)?;
207+
if statements.len() != 1 {
208+
return Err(CoreError::new("Only one statement is allowed").into());
209+
}
210+
visit_statements_mut(&mut statements, |stmt| {
211+
if let Statement::Query(q) = stmt {
212+
if let Some(limit) = &q.limit {
213+
if let Expr::Value(Value::Number(n, is)) = limit {
214+
if n.parse::<usize>().unwrap() > pushdown {
215+
q.limit = Some(Expr::Value(Value::Number(
216+
pushdown.to_string(),
217+
is.clone(),
218+
)));
219+
}
220+
}
221+
} else {
222+
q.limit =
223+
Some(Expr::Value(Value::Number(pushdown.to_string(), false)));
224+
}
225+
}
226+
ControlFlow::<()>::Continue(())
227+
});
228+
Ok(statements[0].to_string())
229+
}
191230
}
192231

193232
impl PySessionContext {

wren-core-py/src/errors.rs

+13
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use std::num::ParseIntError;
12
use base64::DecodeError;
23
use pyo3::exceptions::PyException;
34
use pyo3::PyErr;
@@ -54,6 +55,18 @@ impl From<wren_core::DataFusionError> for CoreError {
5455
}
5556
}
5657

58+
impl From<wren_core::parser::ParserError> for CoreError {
59+
fn from(err: wren_core::parser::ParserError) -> Self {
60+
CoreError::new(&format!("Parser error: {}", err))
61+
}
62+
}
63+
64+
impl From<ParseIntError> for CoreError {
65+
fn from(err: ParseIntError) -> Self {
66+
CoreError::new(&format!("ParseIntError: {}", err))
67+
}
68+
}
69+
5770
impl From<csv::Error> for CoreError {
5871
fn from(err: csv::Error) -> Self {
5972
CoreError::new(&format!("CSV error: {}", err))

wren-core-py/tests/test_modeling_core.py

+33
Original file line numberDiff line numberDiff line change
@@ -205,3 +205,36 @@ def test_to_json_base64():
205205
decoded_manifest = json.loads(json_str)
206206
assert decoded_manifest["catalog"] == "my_catalog"
207207
assert len(decoded_manifest["models"]) == 3
208+
209+
210+
def test_limit_pushdown():
211+
session_context = SessionContext()
212+
sql = "SELECT * FROM my_catalog.my_schema.customer"
213+
assert (
214+
session_context.pushdown_limit(sql, 10)
215+
== "SELECT * FROM my_catalog.my_schema.customer LIMIT 10"
216+
)
217+
218+
sql = "SELECT * FROM my_catalog.my_schema.customer LIMIT 100"
219+
assert (
220+
session_context.pushdown_limit(sql, 10)
221+
== "SELECT * FROM my_catalog.my_schema.customer LIMIT 10"
222+
)
223+
224+
sql = "SELECT * FROM my_catalog.my_schema.customer LIMIT 10"
225+
assert (
226+
session_context.pushdown_limit(sql, 100)
227+
== "SELECT * FROM my_catalog.my_schema.customer LIMIT 10"
228+
)
229+
230+
sql = "SELECT * FROM my_catalog.my_schema.customer LIMIT 10 OFFSET 5"
231+
assert (
232+
session_context.pushdown_limit(sql, 100)
233+
== "SELECT * FROM my_catalog.my_schema.customer LIMIT 10 OFFSET 5"
234+
)
235+
236+
sql = "SELECT * FROM my_catalog.my_schema.customer LIMIT 100 OFFSET 5"
237+
assert (
238+
session_context.pushdown_limit(sql, 10)
239+
== "SELECT * FROM my_catalog.my_schema.customer LIMIT 10 OFFSET 5"
240+
)

wren-core/core/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,5 @@ pub mod mdl;
44
pub use datafusion::error::DataFusionError;
55
pub use datafusion::logical_expr::{AggregateUDF, ScalarUDF, WindowUDF};
66
pub use datafusion::prelude::SessionContext;
7+
pub use datafusion::sql::sqlparser::*;
78
pub use mdl::AnalyzedWrenMDL;

0 commit comments

Comments
 (0)