diff --git a/ibis-server/app/mdl/rewriter.py b/ibis-server/app/mdl/rewriter.py index 65df64183..e0182013b 100644 --- a/ibis-server/app/mdl/rewriter.py +++ b/ibis-server/app/mdl/rewriter.py @@ -105,7 +105,7 @@ async def rewrite(self, manifest_str: str, sql: str) -> str: @staticmethod def handle_extract_exception(e: Exception): - logger.error("Error when extracting manifest: {}", e) + logger.warning("Error when extracting manifest: {}", e) class EmbeddedEngineRewriter: diff --git a/ibis-server/app/routers/v2/connector.py b/ibis-server/app/routers/v2/connector.py index 61d2ea514..500742f87 100644 --- a/ibis-server/app/routers/v2/connector.py +++ b/ibis-server/app/routers/v2/connector.py @@ -2,6 +2,7 @@ from fastapi import APIRouter, Depends, Header, Query, Request, Response from fastapi.responses import ORJSONResponse +from loguru import logger from opentelemetry import trace from app.dependencies import verify_query_dto @@ -19,7 +20,7 @@ from app.model.metadata.dto import Constraint, MetadataDTO, Table from app.model.metadata.factory import MetadataFactory from app.model.validator import Validator -from app.util import build_context, to_json +from app.util import build_context, pushdown_limit, to_json router = APIRouter(prefix="/connector") tracer = trace.get_tracer(__name__) @@ -44,11 +45,17 @@ async def query( with tracer.start_as_current_span( name=span_name, kind=trace.SpanKind.SERVER, context=build_context(headers) ): + try: + sql = pushdown_limit(dto.sql, limit) + except Exception as e: + logger.warning("Failed to pushdown limit. Using original SQL: {}", e) + sql = dto.sql + rewritten_sql = await Rewriter( dto.manifest_str, data_source=data_source, java_engine_connector=java_engine_connector, - ).rewrite(dto.sql) + ).rewrite(sql) connector = Connector(data_source, dto.connection_info) if dry_run: connector.dry_run(rewritten_sql) diff --git a/ibis-server/app/routers/v3/connector.py b/ibis-server/app/routers/v3/connector.py index 4e8dd378a..d9e7e3006 100644 --- a/ibis-server/app/routers/v3/connector.py +++ b/ibis-server/app/routers/v3/connector.py @@ -2,6 +2,7 @@ from fastapi import APIRouter, Depends, Header, Query, Response from fastapi.responses import ORJSONResponse +from loguru import logger from opentelemetry import trace from app.config import get_config @@ -18,7 +19,7 @@ from app.model.connector import Connector from app.model.data_source import DataSource from app.model.validator import Validator -from app.util import build_context, to_json +from app.util import build_context, pushdown_limit, to_json router = APIRouter(prefix="/connector") tracer = trace.get_tracer(__name__) @@ -38,9 +39,15 @@ async def query( with tracer.start_as_current_span( name=span_name, kind=trace.SpanKind.SERVER, context=build_context(headers) ): + try: + sql = pushdown_limit(dto.sql, limit) + except Exception: + logger.warning("Failed to pushdown limit. Using original SQL") + sql = dto.sql + rewritten_sql = await Rewriter( dto.manifest_str, data_source=data_source, experiment=True - ).rewrite(dto.sql) + ).rewrite(sql) connector = Connector(data_source, dto.connection_info) if dry_run: connector.dry_run(rewritten_sql) diff --git a/ibis-server/app/util.py b/ibis-server/app/util.py index ff03919df..9addf0759 100644 --- a/ibis-server/app/util.py +++ b/ibis-server/app/util.py @@ -4,6 +4,7 @@ import orjson import pandas as pd +import wren_core from fastapi import Header from opentelemetry import trace from opentelemetry.context import Context @@ -98,3 +99,9 @@ def build_context(headers: Header) -> Context: if headers is None: return None return extract(headers) + + +@tracer.start_as_current_span("pushdown_limit", kind=trace.SpanKind.INTERNAL) +def pushdown_limit(sql: str, limit: int | None) -> str: + ctx = wren_core.SessionContext() + return ctx.pushdown_limit(sql, limit) diff --git a/ibis-server/tests/routers/v2/connector/test_postgres.py b/ibis-server/tests/routers/v2/connector/test_postgres.py index 64e1977db..4c6a11a6b 100644 --- a/ibis-server/tests/routers/v2/connector/test_postgres.py +++ b/ibis-server/tests/routers/v2/connector/test_postgres.py @@ -285,6 +285,36 @@ async def test_format_floating(client, manifest_str, postgres): assert result["data"][0][26] == "12300.00000" +async def test_limit_pushdown(client, manifest_str, postgres: PostgresContainer): + connection_info = _to_connection_info(postgres) + response = await client.post( + url=f"{base_url}/query", + params={"limit": 10}, + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": 'SELECT * FROM "Orders" LIMIT 100', + }, + ) + assert response.status_code == 200 + result = response.json() + assert len(result["data"]) == 10 + + connection_info = _to_connection_info(postgres) + response = await client.post( + url=f"{base_url}/query", + params={"limit": 10}, + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": 'SELECT count(*) FILTER (where orderkey > 10) FROM "Orders" LIMIT 100', + }, + ) + assert response.status_code == 200 + result = response.json() + assert len(result["data"]) == 1 + + async def test_dry_run_with_connection_url_and_password_with_bracket_should_not_raise_value_error( client, manifest_str, postgres: PostgresContainer ): diff --git a/ibis-server/tests/routers/v3/connector/postgres/test_query.py b/ibis-server/tests/routers/v3/connector/postgres/test_query.py index fa2d237c4..039f70405 100644 --- a/ibis-server/tests/routers/v3/connector/postgres/test_query.py +++ b/ibis-server/tests/routers/v3/connector/postgres/test_query.py @@ -338,3 +338,18 @@ async def test_query_with_keyword_filter(client, manifest_str, connection_info): ) assert response.status_code == 200 assert response.text is not None + + +async def test_limit_pushdown(client, manifest_str, connection_info): + response = await client.post( + url=f"{base_url}/query", + params={"limit": 10}, + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": "SELECT * FROM orders LIMIT 100", + }, + ) + assert response.status_code == 200 + result = response.json() + assert len(result["data"]) == 10 diff --git a/wren-core-py/src/context.rs b/wren-core-py/src/context.rs index f17b77a81..fae3f88d6 100644 --- a/wren-core-py/src/context.rs +++ b/wren-core-py/src/context.rs @@ -23,7 +23,10 @@ use pyo3::{pyclass, pymethods, PyErr, PyResult}; use std::collections::hash_map::Entry; use std::collections::HashMap; use std::hash::Hash; +use std::ops::ControlFlow; use std::sync::Arc; +use wren_core::ast::{visit_statements_mut, Expr, Statement, Value}; +use wren_core::dialect::GenericDialect; use wren_core::logical_plan::utils::map_data_type; use wren_core::mdl::context::create_ctx_with_mdl; use wren_core::mdl::function::{ @@ -31,7 +34,6 @@ use wren_core::mdl::function::{ RemoteFunction, }; use wren_core::{mdl, AggregateUDF, AnalyzedWrenMDL, ScalarUDF, WindowUDF}; - /// The Python wrapper for the Wren Core session context. #[pyclass(name = "SessionContext")] #[derive(Clone)] @@ -188,6 +190,43 @@ impl PySessionContext { }); Ok(builder.values().cloned().collect()) } + + /// Push down the limit to the given SQL. + /// If the limit is None, the SQL will be returned as is. + /// If the limit is greater than the pushdown limit, the limit will be replaced with the pushdown limit. + /// Otherwise, the limit will be kept as is. + #[pyo3(signature = (sql, limit=None))] + pub fn pushdown_limit(&self, sql: &str, limit: Option) -> PyResult { + if limit.is_none() { + return Ok(sql.to_string()); + } + let pushdown = limit.unwrap(); + let mut statements = + wren_core::parser::Parser::parse_sql(&GenericDialect {}, sql) + .map_err(CoreError::from)?; + if statements.len() != 1 { + return Err(CoreError::new("Only one statement is allowed").into()); + } + visit_statements_mut(&mut statements, |stmt| { + if let Statement::Query(q) = stmt { + if let Some(limit) = &q.limit { + if let Expr::Value(Value::Number(n, is)) = limit { + if n.parse::().unwrap() > pushdown { + q.limit = Some(Expr::Value(Value::Number( + pushdown.to_string(), + is.clone(), + ))); + } + } + } else { + q.limit = + Some(Expr::Value(Value::Number(pushdown.to_string(), false))); + } + } + ControlFlow::<()>::Continue(()) + }); + Ok(statements[0].to_string()) + } } impl PySessionContext { diff --git a/wren-core-py/src/errors.rs b/wren-core-py/src/errors.rs index 56dc403eb..b6196d5c4 100644 --- a/wren-core-py/src/errors.rs +++ b/wren-core-py/src/errors.rs @@ -1,3 +1,4 @@ +use std::num::ParseIntError; use base64::DecodeError; use pyo3::exceptions::PyException; use pyo3::PyErr; @@ -54,6 +55,18 @@ impl From for CoreError { } } +impl From for CoreError { + fn from(err: wren_core::parser::ParserError) -> Self { + CoreError::new(&format!("Parser error: {}", err)) + } +} + +impl From for CoreError { + fn from(err: ParseIntError) -> Self { + CoreError::new(&format!("ParseIntError: {}", err)) + } +} + impl From for CoreError { fn from(err: csv::Error) -> Self { CoreError::new(&format!("CSV error: {}", err)) diff --git a/wren-core-py/tests/test_modeling_core.py b/wren-core-py/tests/test_modeling_core.py index d86880d90..81d54a692 100644 --- a/wren-core-py/tests/test_modeling_core.py +++ b/wren-core-py/tests/test_modeling_core.py @@ -205,3 +205,36 @@ def test_to_json_base64(): decoded_manifest = json.loads(json_str) assert decoded_manifest["catalog"] == "my_catalog" assert len(decoded_manifest["models"]) == 3 + + +def test_limit_pushdown(): + session_context = SessionContext() + sql = "SELECT * FROM my_catalog.my_schema.customer" + assert ( + session_context.pushdown_limit(sql, 10) + == "SELECT * FROM my_catalog.my_schema.customer LIMIT 10" + ) + + sql = "SELECT * FROM my_catalog.my_schema.customer LIMIT 100" + assert ( + session_context.pushdown_limit(sql, 10) + == "SELECT * FROM my_catalog.my_schema.customer LIMIT 10" + ) + + sql = "SELECT * FROM my_catalog.my_schema.customer LIMIT 10" + assert ( + session_context.pushdown_limit(sql, 100) + == "SELECT * FROM my_catalog.my_schema.customer LIMIT 10" + ) + + sql = "SELECT * FROM my_catalog.my_schema.customer LIMIT 10 OFFSET 5" + assert ( + session_context.pushdown_limit(sql, 100) + == "SELECT * FROM my_catalog.my_schema.customer LIMIT 10 OFFSET 5" + ) + + sql = "SELECT * FROM my_catalog.my_schema.customer LIMIT 100 OFFSET 5" + assert ( + session_context.pushdown_limit(sql, 10) + == "SELECT * FROM my_catalog.my_schema.customer LIMIT 10 OFFSET 5" + ) diff --git a/wren-core/core/src/lib.rs b/wren-core/core/src/lib.rs index 491d82071..b0686bb75 100644 --- a/wren-core/core/src/lib.rs +++ b/wren-core/core/src/lib.rs @@ -4,4 +4,5 @@ pub mod mdl; pub use datafusion::error::DataFusionError; pub use datafusion::logical_expr::{AggregateUDF, ScalarUDF, WindowUDF}; pub use datafusion::prelude::SessionContext; +pub use datafusion::sql::sqlparser::*; pub use mdl::AnalyzedWrenMDL;