Skip to content

feat(ibis): pushdown the limit of the query request into SQL #1082

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Mar 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ibis-server/app/mdl/rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ async def rewrite(self, manifest_str: str, sql: str) -> str:

@staticmethod
def handle_extract_exception(e: Exception):
logger.error("Error when extracting manifest: {}", e)
logger.warning("Error when extracting manifest: {}", e)


class EmbeddedEngineRewriter:
Expand Down
11 changes: 9 additions & 2 deletions ibis-server/app/routers/v2/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from fastapi import APIRouter, Depends, Header, Query, Request, Response
from fastapi.responses import ORJSONResponse
from loguru import logger
from opentelemetry import trace

from app.dependencies import verify_query_dto
Expand All @@ -19,7 +20,7 @@
from app.model.metadata.dto import Constraint, MetadataDTO, Table
from app.model.metadata.factory import MetadataFactory
from app.model.validator import Validator
from app.util import build_context, to_json
from app.util import build_context, pushdown_limit, to_json

router = APIRouter(prefix="/connector")
tracer = trace.get_tracer(__name__)
Expand All @@ -44,11 +45,17 @@ async def query(
with tracer.start_as_current_span(
name=span_name, kind=trace.SpanKind.SERVER, context=build_context(headers)
):
try:
sql = pushdown_limit(dto.sql, limit)
except Exception as e:
logger.warning("Failed to pushdown limit. Using original SQL: {}", e)
sql = dto.sql

rewritten_sql = await Rewriter(
dto.manifest_str,
data_source=data_source,
java_engine_connector=java_engine_connector,
).rewrite(dto.sql)
).rewrite(sql)
connector = Connector(data_source, dto.connection_info)
if dry_run:
connector.dry_run(rewritten_sql)
Expand Down
11 changes: 9 additions & 2 deletions ibis-server/app/routers/v3/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from fastapi import APIRouter, Depends, Header, Query, Response
from fastapi.responses import ORJSONResponse
from loguru import logger
from opentelemetry import trace

from app.config import get_config
Expand All @@ -18,7 +19,7 @@
from app.model.connector import Connector
from app.model.data_source import DataSource
from app.model.validator import Validator
from app.util import build_context, to_json
from app.util import build_context, pushdown_limit, to_json

router = APIRouter(prefix="/connector")
tracer = trace.get_tracer(__name__)
Expand All @@ -38,9 +39,15 @@ async def query(
with tracer.start_as_current_span(
name=span_name, kind=trace.SpanKind.SERVER, context=build_context(headers)
):
try:
sql = pushdown_limit(dto.sql, limit)
except Exception:
logger.warning("Failed to pushdown limit. Using original SQL")
sql = dto.sql

rewritten_sql = await Rewriter(
dto.manifest_str, data_source=data_source, experiment=True
).rewrite(dto.sql)
).rewrite(sql)
connector = Connector(data_source, dto.connection_info)
if dry_run:
connector.dry_run(rewritten_sql)
Expand Down
7 changes: 7 additions & 0 deletions ibis-server/app/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import orjson
import pandas as pd
import wren_core
from fastapi import Header
from opentelemetry import trace
from opentelemetry.context import Context
Expand Down Expand Up @@ -98,3 +99,9 @@ def build_context(headers: Header) -> Context:
if headers is None:
return None
return extract(headers)


@tracer.start_as_current_span("pushdown_limit", kind=trace.SpanKind.INTERNAL)
def pushdown_limit(sql: str, limit: int | None) -> str:
ctx = wren_core.SessionContext()
return ctx.pushdown_limit(sql, limit)
30 changes: 30 additions & 0 deletions ibis-server/tests/routers/v2/connector/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,36 @@ async def test_format_floating(client, manifest_str, postgres):
assert result["data"][0][26] == "12300.00000"


async def test_limit_pushdown(client, manifest_str, postgres: PostgresContainer):
connection_info = _to_connection_info(postgres)
response = await client.post(
url=f"{base_url}/query",
params={"limit": 10},
json={
"connectionInfo": connection_info,
"manifestStr": manifest_str,
"sql": 'SELECT * FROM "Orders" LIMIT 100',
},
)
assert response.status_code == 200
result = response.json()
assert len(result["data"]) == 10

connection_info = _to_connection_info(postgres)
response = await client.post(
url=f"{base_url}/query",
params={"limit": 10},
json={
"connectionInfo": connection_info,
"manifestStr": manifest_str,
"sql": 'SELECT count(*) FILTER (where orderkey > 10) FROM "Orders" LIMIT 100',
},
)
assert response.status_code == 200
result = response.json()
assert len(result["data"]) == 1


async def test_dry_run_with_connection_url_and_password_with_bracket_should_not_raise_value_error(
client, manifest_str, postgres: PostgresContainer
):
Expand Down
15 changes: 15 additions & 0 deletions ibis-server/tests/routers/v3/connector/postgres/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,3 +338,18 @@ async def test_query_with_keyword_filter(client, manifest_str, connection_info):
)
assert response.status_code == 200
assert response.text is not None


async def test_limit_pushdown(client, manifest_str, connection_info):
response = await client.post(
url=f"{base_url}/query",
params={"limit": 10},
json={
"connectionInfo": connection_info,
"manifestStr": manifest_str,
"sql": "SELECT * FROM orders LIMIT 100",
},
)
assert response.status_code == 200
result = response.json()
assert len(result["data"]) == 10
41 changes: 40 additions & 1 deletion wren-core-py/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,17 @@ use pyo3::{pyclass, pymethods, PyErr, PyResult};
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::hash::Hash;
use std::ops::ControlFlow;
use std::sync::Arc;
use wren_core::ast::{visit_statements_mut, Expr, Statement, Value};
use wren_core::dialect::GenericDialect;
use wren_core::logical_plan::utils::map_data_type;
use wren_core::mdl::context::create_ctx_with_mdl;
use wren_core::mdl::function::{
ByPassAggregateUDF, ByPassScalarUDF, ByPassWindowFunction, FunctionType,
RemoteFunction,
};
use wren_core::{mdl, AggregateUDF, AnalyzedWrenMDL, ScalarUDF, WindowUDF};

/// The Python wrapper for the Wren Core session context.
#[pyclass(name = "SessionContext")]
#[derive(Clone)]
Expand Down Expand Up @@ -188,6 +190,43 @@ impl PySessionContext {
});
Ok(builder.values().cloned().collect())
}

/// Push down the limit to the given SQL.
/// If the limit is None, the SQL will be returned as is.
/// If the limit is greater than the pushdown limit, the limit will be replaced with the pushdown limit.
/// Otherwise, the limit will be kept as is.
#[pyo3(signature = (sql, limit=None))]
pub fn pushdown_limit(&self, sql: &str, limit: Option<usize>) -> PyResult<String> {
if limit.is_none() {
return Ok(sql.to_string());
}
let pushdown = limit.unwrap();
let mut statements =
wren_core::parser::Parser::parse_sql(&GenericDialect {}, sql)
.map_err(CoreError::from)?;
if statements.len() != 1 {
return Err(CoreError::new("Only one statement is allowed").into());
}
visit_statements_mut(&mut statements, |stmt| {
if let Statement::Query(q) = stmt {
if let Some(limit) = &q.limit {
if let Expr::Value(Value::Number(n, is)) = limit {
if n.parse::<usize>().unwrap() > pushdown {
q.limit = Some(Expr::Value(Value::Number(
pushdown.to_string(),
is.clone(),
)));
}
}
} else {
q.limit =
Some(Expr::Value(Value::Number(pushdown.to_string(), false)));
}
}
ControlFlow::<()>::Continue(())
});
Ok(statements[0].to_string())
}
}

impl PySessionContext {
Expand Down
13 changes: 13 additions & 0 deletions wren-core-py/src/errors.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::num::ParseIntError;
use base64::DecodeError;
use pyo3::exceptions::PyException;
use pyo3::PyErr;
Expand Down Expand Up @@ -54,6 +55,18 @@ impl From<wren_core::DataFusionError> for CoreError {
}
}

impl From<wren_core::parser::ParserError> for CoreError {
fn from(err: wren_core::parser::ParserError) -> Self {
CoreError::new(&format!("Parser error: {}", err))
}
}

impl From<ParseIntError> for CoreError {
fn from(err: ParseIntError) -> Self {
CoreError::new(&format!("ParseIntError: {}", err))
}
}

impl From<csv::Error> for CoreError {
fn from(err: csv::Error) -> Self {
CoreError::new(&format!("CSV error: {}", err))
Expand Down
33 changes: 33 additions & 0 deletions wren-core-py/tests/test_modeling_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,3 +205,36 @@ def test_to_json_base64():
decoded_manifest = json.loads(json_str)
assert decoded_manifest["catalog"] == "my_catalog"
assert len(decoded_manifest["models"]) == 3


def test_limit_pushdown():
session_context = SessionContext()
sql = "SELECT * FROM my_catalog.my_schema.customer"
assert (
session_context.pushdown_limit(sql, 10)
== "SELECT * FROM my_catalog.my_schema.customer LIMIT 10"
)

sql = "SELECT * FROM my_catalog.my_schema.customer LIMIT 100"
assert (
session_context.pushdown_limit(sql, 10)
== "SELECT * FROM my_catalog.my_schema.customer LIMIT 10"
)

sql = "SELECT * FROM my_catalog.my_schema.customer LIMIT 10"
assert (
session_context.pushdown_limit(sql, 100)
== "SELECT * FROM my_catalog.my_schema.customer LIMIT 10"
)

sql = "SELECT * FROM my_catalog.my_schema.customer LIMIT 10 OFFSET 5"
assert (
session_context.pushdown_limit(sql, 100)
== "SELECT * FROM my_catalog.my_schema.customer LIMIT 10 OFFSET 5"
)

sql = "SELECT * FROM my_catalog.my_schema.customer LIMIT 100 OFFSET 5"
assert (
session_context.pushdown_limit(sql, 10)
== "SELECT * FROM my_catalog.my_schema.customer LIMIT 10 OFFSET 5"
)
1 change: 1 addition & 0 deletions wren-core/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ pub mod mdl;
pub use datafusion::error::DataFusionError;
pub use datafusion::logical_expr::{AggregateUDF, ScalarUDF, WindowUDF};
pub use datafusion::prelude::SessionContext;
pub use datafusion::sql::sqlparser::*;
pub use mdl::AnalyzedWrenMDL;