
feat(ibis): persistent query cache #1088

Merged: 14 commits, Apr 2, 2025
9 changes: 8 additions & 1 deletion ibis-server/app/main.py
@@ -13,6 +13,7 @@
from app.mdl.java_engine import JavaEngineConnector
from app.middleware import ProcessTimeMiddleware, RequestLogMiddleware
from app.model import ConfigModel, CustomHttpError
from app.query_cache import QueryCacheManager
from app.routers import v2, v3

get_config().init_logger()
@@ -22,12 +23,18 @@
# Use state to store the singleton instance
class State(TypedDict):
java_engine_connector: JavaEngineConnector
query_cache_manager: QueryCacheManager


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[State]:
query_cache_manager = QueryCacheManager()

async with JavaEngineConnector() as java_engine_connector:
yield {"java_engine_connector": java_engine_connector}
yield {
"java_engine_connector": java_engine_connector,
"query_cache_manager": query_cache_manager,
}


app = FastAPI(lifespan=lifespan)
1 change: 1 addition & 0 deletions ibis-server/app/model/__init__.py
@@ -18,6 +18,7 @@ class QueryDTO(BaseModel):
sql: str
manifest_str: str = manifest_str_field
connection_info: ConnectionInfo = connection_info_field
enable_cache: bool = False


class QueryBigQueryDTO(QueryDTO):
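
For context, a request body that opts into caching could look like the sketch below; the connection details and manifest are placeholders, and the flag is written as snake_case "enable_cache" on the assumption that the new field declares no alias (if the model aliases it to camelCase like manifestStr and connectionInfo, use "enableCache" instead).

# Hypothetical request body for POST /v2/connector/postgres/query.
payload = {
    "sql": 'SELECT * FROM "Orders" LIMIT 10',
    "manifestStr": "<base64-encoded manifest>",
    "connectionInfo": {
        "host": "localhost",
        "port": 5432,
        "user": "test",
        "password": "test",
        "database": "test",
    },
    "enable_cache": True,  # new flag; omitted or false keeps today's behaviour
}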
56 changes: 56 additions & 0 deletions ibis-server/app/query_cache/__init__.py
@@ -0,0 +1,56 @@
import hashlib
import os
from typing import Any, Optional

import ibis
from loguru import logger


class QueryCacheManager:
def __init__(self):
pass

def get(self, data_source: str, sql: str, info) -> Optional[Any]:
cache_key = self._generate_cache_key(data_source, sql, info)
cache_path = self._get_cache_path(cache_key)

# Check if cache file exists
if os.path.exists(cache_path):
try:
cache = ibis.read_parquet(cache_path)
df = cache.execute()
return df
except Exception as e:
logger.debug(f"Failed to read query cache {e}")
return None

return None

def set(self, data_source: str, sql: str, result: Any, info) -> None:
cache_key = self._generate_cache_key(data_source, sql, info)
cache_path = self._get_cache_path(cache_key)

try:
# Create cache directory if it doesn't exist
os.makedirs(os.path.dirname(cache_path), exist_ok=True)
cache = ibis.memtable(result)
logger.info(f"\nWriting query cache to {cache_path}\n")
cache.to_parquet(cache_path)
except Exception as e:
logger.debug(f"Failed to write query cache: {e}")
return

def _generate_cache_key(self, data_source: str, sql: str, info) -> str:
key_parts = [
data_source,
sql,
info.host.get_secret_value(),
info.port.get_secret_value(),
info.user.get_secret_value(),
]
key_string = "|".join(key_parts)

return hashlib.sha256(key_string.encode()).hexdigest()

def _get_cache_path(self, cache_key: str) -> str:
return f"/tmp/wren-engine/{cache_key}.cache"
57 changes: 47 additions & 10 deletions ibis-server/app/routers/v2/connector.py
@@ -19,6 +19,7 @@
from app.model.metadata.dto import Constraint, MetadataDTO, Table
from app.model.metadata.factory import MetadataFactory
from app.model.validator import Validator
from app.query_cache import QueryCacheManager
from app.util import build_context, to_json

router = APIRouter(prefix="/connector")
@@ -29,13 +30,18 @@ def get_java_engine_connector(request: Request) -> JavaEngineConnector:
return request.state.java_engine_connector


def get_query_cache_manager(request: Request) -> QueryCacheManager:
return request.state.query_cache_manager


@router.post("/{data_source}/query", dependencies=[Depends(verify_query_dto)])
async def query(
data_source: DataSource,
dto: QueryDTO,
dry_run: Annotated[bool, Query(alias="dryRun")] = False,
limit: int | None = None,
java_engine_connector: JavaEngineConnector = Depends(get_java_engine_connector),
query_cache_manager: QueryCacheManager = Depends(get_query_cache_manager),
headers: Annotated[str | None, Header()] = None,
) -> Response:
span_name = (
@@ -44,16 +50,47 @@ async def query(
with tracer.start_as_current_span(
name=span_name, kind=trace.SpanKind.SERVER, context=build_context(headers)
):
rewritten_sql = await Rewriter(
dto.manifest_str,
data_source=data_source,
java_engine_connector=java_engine_connector,
).rewrite(dto.sql)
connector = Connector(data_source, dto.connection_info)
if dry_run:
connector.dry_run(rewritten_sql)
return Response(status_code=204)
return ORJSONResponse(to_json(connector.query(rewritten_sql, limit=limit)))
cached_result = None
cache_hit = False
enable_cache = dto.enable_cache

if enable_cache:
cached_result = query_cache_manager.get(
str(data_source), dto.sql, dto.connection_info
)
cache_hit = cached_result is not None

# Cache Hit !
if cached_result is not None:
response = ORJSONResponse(to_json(cached_result))
response.headers["X-Cache-Hit"] = str(cache_hit).lower()
return response
# Cache Miss
else:
rewritten_sql = await Rewriter(
dto.manifest_str,
data_source=data_source,
java_engine_connector=java_engine_connector,
).rewrite(dto.sql)

connector = Connector(data_source, dto.connection_info)

if dry_run:
connector.dry_run(rewritten_sql)
dry_response = Response(status_code=204)
dry_response.headers["X-Cache-Hit"] = str(cache_hit).lower()
return dry_response
else:
# missing cache and not dry run
# so we need to query the datasource and cache the result
result = connector.query(rewritten_sql, limit=limit)
if enable_cache:
query_cache_manager.set(
data_source, dto.sql, result, dto.connection_info
)
response = ORJSONResponse(to_json(result))
response.headers["X-Cache-Hit"] = str(cache_hit).lower()
return response


@router.post("/{data_source}/validate/{rule_name}")
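
To make the new X-Cache-Hit header concrete, here is a hedged sketch of what a caller observes when it issues the same cacheable query twice. The httpx client, base URL, and port are assumptions, and payload is the request body sketched earlier with "enable_cache" set to true.

import httpx


async def query_twice(payload: dict) -> None:
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        first = await client.post("/v2/connector/postgres/query", json=payload)
        second = await client.post("/v2/connector/postgres/query", json=payload)
        # The first call misses and populates the parquet cache; the second
        # call is answered from it.
        print(first.headers["X-Cache-Hit"])   # "false"
        print(second.headers["X-Cache-Hit"])  # "true"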
1 change: 1 addition & 0 deletions ibis-server/pyproject.toml
@@ -75,6 +75,7 @@ markers = [
"minio_file: mark a test as a minio file test",
"gcs_file: mark a test as a gcs file test",
"beta: mark a test as a test for beta versions of the engine",
"cache: mark a test as a cache test",
]

[tool.ruff]
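
With the marker registered, the new cache tests can be selected on their own with pytest -m cache (run from the ibis-server directory, assuming the usual test invocation).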
183 changes: 183 additions & 0 deletions ibis-server/tests/routers/v2/query_cache/test_cache.py
@@ -0,0 +1,183 @@
import base64
import os
import shutil

import orjson
import pandas as pd
import pytest
import sqlalchemy
from testcontainers.postgres import PostgresContainer

from app.query_cache import QueryCacheManager
from tests.conftest import file_path

pytestmark = pytest.mark.cache

base_url = "/v2/connector/postgres"

manifest = {
"catalog": "my_catalog",
"schema": "my_schema",
"models": [
{
"name": "Orders",
"tableReference": {
"schema": "public",
"table": "orders",
},
"columns": [
{"name": "orderkey", "expression": "o_orderkey", "type": "integer"},
{"name": "custkey", "expression": "o_custkey", "type": "integer"},
{
"name": "orderstatus",
"expression": "o_orderstatus",
"type": "varchar",
},
{
"name": "totalprice",
"expression": "o_totalprice",
"type": "float",
},
{"name": "orderdate", "expression": "o_orderdate", "type": "date"},
{
"name": "order_cust_key",
"expression": "concat(o_orderkey, '_', o_custkey)",
"type": "varchar",
},
{
"name": "timestamp",
"expression": "cast('2024-01-01T23:59:59' as timestamp)",
"type": "timestamp",
},
{
"name": "timestamptz",
"expression": "cast('2024-01-01T23:59:59' as timestamp with time zone)",
"type": "timestamp",
},
{
"name": "test_null_time",
"expression": "cast(NULL as timestamp)",
"type": "timestamp",
},
{
"name": "bytea_column",
"expression": "cast('abc' as bytea)",
"type": "bytea",
},
],
"primaryKey": "orderkey",
},
{
"name": "Customer",
"refSql": "SELECT * FROM public.customer",
"columns": [
{
"name": "custkey",
"type": "integer",
"expression": "c_custkey",
},
{
"name": "orders",
"type": "Orders",
"relationship": "CustomerOrders",
},
{
"name": "orders_key",
"type": "varchar",
"isCalculated": True,
"expression": "orders.orderkey",
},
],
},
],
"relationships": [
{
"name": "CustomerOrders",
"models": ["Customer", "Orders"],
"joinType": "ONE_TO_MANY",
"condition": "Customer.custkey = Orders.custkey",
}
],
}


@pytest.fixture(scope="module")
def manifest_str():
return base64.b64encode(orjson.dumps(manifest)).decode("utf-8")


@pytest.fixture(scope="module")
def postgres(request) -> PostgresContainer:
pg = PostgresContainer("postgres:16-alpine").start()
engine = sqlalchemy.create_engine(pg.get_connection_url())
# Load sample data
pd.read_parquet(file_path("resource/tpch/data/orders.parquet")).to_sql(
"orders", engine, index=False
)
pd.read_parquet(file_path("resource/tpch/data/customer.parquet")).to_sql(
"customer", engine, index=False
)
request.addfinalizer(pg.stop)
return pg


@pytest.fixture(scope="function")
def cache_dir():
temp_dir = "/tmp/wren-engine-test"
os.makedirs(temp_dir, exist_ok=True)
yield temp_dir
# Clean up after the test
shutil.rmtree(temp_dir, ignore_errors=True)


async def test_query_with_cache(
client, manifest_str, postgres: PostgresContainer, cache_dir, monkeypatch
):
# Override the cache path to use our test directory
monkeypatch.setattr(
QueryCacheManager,
"_get_cache_path",
lambda self, key: f"{cache_dir}/{key}.cache",
)

connection_info = _to_connection_info(postgres)
test_sql = 'SELECT * FROM "Orders" LIMIT 10'

# First request - should miss cache
response1 = await client.post(
url=f"{base_url}/query",
json={
"connectionInfo": connection_info,
"manifestStr": manifest_str,
"sql": test_sql,
},
)
assert response1.status_code == 200
result1 = response1.json()

# Second request with same SQL - should hit cache
response2 = await client.post(
url=f"{base_url}/query",
json={
"connectionInfo": connection_info,
"manifestStr": manifest_str,
"sql": test_sql,
},
)
assert response2.status_code == 200
result2 = response2.json()

# Verify results are identical
assert result1["data"] == result2["data"]
assert result1["columns"] == result2["columns"]
assert result1["dtypes"] == result2["dtypes"]


def _to_connection_info(pg: PostgresContainer):
return {
"host": pg.get_container_host_ip(),
"port": pg.get_exposed_port(pg.port),
"user": pg.username,
"password": pg.password,
"database": pg.dbname,
}
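
A natural follow-up test (not part of this diff) would exercise the flag explicitly and assert the X-Cache-Hit header. A sketch under the same fixtures, again assuming the snake_case "enable_cache" key:

async def test_query_cache_hit_header(
    client, manifest_str, postgres: PostgresContainer, cache_dir, monkeypatch
):
    # Redirect cache files into the per-test directory, as above.
    monkeypatch.setattr(
        QueryCacheManager,
        "_get_cache_path",
        lambda self, key: f"{cache_dir}/{key}.cache",
    )
    payload = {
        "connectionInfo": _to_connection_info(postgres),
        "manifestStr": manifest_str,
        "sql": 'SELECT * FROM "Orders" LIMIT 5',
        "enable_cache": True,
    }

    # The first request populates the cache, the second should be served from it.
    miss = await client.post(url=f"{base_url}/query", json=payload)
    hit = await client.post(url=f"{base_url}/query", json=payload)

    assert miss.status_code == 200
    assert hit.status_code == 200
    assert miss.headers["X-Cache-Hit"] == "false"
    assert hit.headers["X-Cache-Hit"] == "true"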