
feat(ibis): persistent query cache #1088


Merged · 14 commits · Apr 2, 2025
9 changes: 8 additions & 1 deletion ibis-server/app/main.py
@@ -13,6 +13,7 @@
from app.mdl.java_engine import JavaEngineConnector
from app.middleware import ProcessTimeMiddleware, RequestLogMiddleware
from app.model import ConfigModel, CustomHttpError
from app.query_cache import QueryCacheManager
from app.routers import v2, v3

get_config().init_logger()
@@ -22,12 +23,18 @@
# Use state to store the singleton instance
class State(TypedDict):
    java_engine_connector: JavaEngineConnector
    query_cache_manager: QueryCacheManager


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[State]:
    query_cache_manager = QueryCacheManager()

    async with JavaEngineConnector() as java_engine_connector:
        yield {"java_engine_connector": java_engine_connector}
        yield {
            "java_engine_connector": java_engine_connector,
            "query_cache_manager": query_cache_manager,
        }


app = FastAPI(lifespan=lifespan)
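
Note: FastAPI (via Starlette's lifespan state) exposes each key of the mapping yielded by lifespan() as an attribute on request.state, which is how the routers below retrieve the cache manager singleton. A minimal sketch, reusing the app defined above (the /ping route is hypothetical, for illustration only):

from fastapi import Request


@app.get("/ping")  # hypothetical route, not part of this PR
async def ping(request: Request) -> dict:
    # Keys yielded by lifespan() become attributes on request.state
    manager = request.state.query_cache_manager
    return {"cache_manager": type(manager).__name__}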
61 changes: 61 additions & 0 deletions ibis-server/app/query_cache/__init__.py
@@ -0,0 +1,61 @@
import hashlib
import os
from typing import Any, Optional

import ibis
from loguru import logger
from opentelemetry import trace

tracer = trace.get_tracer(__name__)


class QueryCacheManager:
    def __init__(self):
        pass

    @tracer.start_as_current_span("get_cache", kind=trace.SpanKind.INTERNAL)
    def get(self, data_source: str, sql: str, info) -> Optional[Any]:
        cache_key = self._generate_cache_key(data_source, sql, info)
        cache_path = self._get_cache_path(cache_key)

        # Check if the cache file exists
        if os.path.exists(cache_path):
            try:
                cache = ibis.read_parquet(cache_path)
                df = cache.execute()
                return df
            except Exception as e:
                logger.debug(f"Failed to read query cache: {e}")
                return None

        return None

    @tracer.start_as_current_span("set_cache", kind=trace.SpanKind.INTERNAL)
    def set(self, data_source: str, sql: str, result: Any, info) -> None:
        cache_key = self._generate_cache_key(data_source, sql, info)
        cache_path = self._get_cache_path(cache_key)

        try:
            # Create the cache directory if it doesn't exist
            os.makedirs(os.path.dirname(cache_path), exist_ok=True)
            cache = ibis.memtable(result)
            logger.info(f"\nWriting query cache to {cache_path}\n")
            cache.to_parquet(cache_path)
        except Exception as e:
            logger.debug(f"Failed to write query cache: {e}")
            return

    def _generate_cache_key(self, data_source: str, sql: str, info) -> str:
        key_parts = [
            data_source,
            sql,
            info.host.get_secret_value(),
            info.port.get_secret_value(),
            info.user.get_secret_value(),
        ]
        key_string = "|".join(key_parts)

        return hashlib.sha256(key_string.encode()).hexdigest()

    def _get_cache_path(self, cache_key: str) -> str:
        return f"/tmp/wren-engine/{cache_key}.cache"
55 changes: 49 additions & 6 deletions ibis-server/app/routers/v2/connector.py
@@ -20,6 +20,7 @@
from app.model.metadata.dto import Constraint, MetadataDTO, Table
from app.model.metadata.factory import MetadataFactory
from app.model.validator import Validator
from app.query_cache import QueryCacheManager
from app.util import build_context, pushdown_limit, to_json

router = APIRouter(prefix="/connector")
@@ -30,23 +31,32 @@ def get_java_engine_connector(request: Request) -> JavaEngineConnector:
    return request.state.java_engine_connector


def get_query_cache_manager(request: Request) -> QueryCacheManager:
    return request.state.query_cache_manager


@router.post(
    "/{data_source}/query", dependencies=[Depends(verify_query_dto)], deprecated=True
)
async def query(
    data_source: DataSource,
    dto: QueryDTO,
    dry_run: Annotated[bool, Query(alias="dryRun")] = False,
    cache_enable: Annotated[bool, Query(alias="cacheEnable")] = False,
    limit: int | None = None,
    java_engine_connector: JavaEngineConnector = Depends(get_java_engine_connector),
    query_cache_manager: QueryCacheManager = Depends(get_query_cache_manager),
    headers: Annotated[str | None, Header()] = None,
) -> Response:
    span_name = (
        f"v2_query_{data_source}_dry_run" if dry_run else f"v2_query_{data_source}"
    )
    span_name = f"v2_query_{data_source}"
    if dry_run:
        span_name += "_dry_run"
    if cache_enable:
        span_name += "_cache_enable"

    with tracer.start_as_current_span(
        name=span_name, kind=trace.SpanKind.SERVER, context=build_context(headers)
    ):
    ) as span:
        try:
            sql = pushdown_limit(dto.sql, limit)
        except Exception as e:
@@ -59,10 +69,43 @@ async def query(
            java_engine_connector=java_engine_connector,
        ).rewrite(sql)
        connector = Connector(data_source, dto.connection_info)

        # Dry-run requests skip the query cache entirely
        if dry_run:
            connector.dry_run(rewritten_sql)
            return Response(status_code=204)
        return ORJSONResponse(to_json(connector.query(rewritten_sql, limit=limit)))
            dry_response = Response(status_code=204)
            dry_response.headers["X-Cache-Hit"] = "true"
            return dry_response

        # Not a dry run: check the query cache first
        cached_result = None
        cache_hit = False
        enable_cache = cache_enable

        if enable_cache:
            cached_result = query_cache_manager.get(
                str(data_source), dto.sql, dto.connection_info
            )
            cache_hit = cached_result is not None

        if cache_hit:
            span.add_event("cache hit")
            response = ORJSONResponse(to_json(cached_result))
            response.headers["X-Cache-Hit"] = str(cache_hit).lower()
            return response
        else:
            result = connector.query(rewritten_sql, limit=limit)
            if enable_cache:
                query_cache_manager.set(
                    data_source, dto.sql, result, dto.connection_info
                )

            response = ORJSONResponse(to_json(result))
            response.headers["X-Cache-Hit"] = str(cache_hit).lower()
            return response
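
From a client's perspective, opting in is just the cacheEnable query parameter; the X-Cache-Hit response header reports whether the cached result was served. A sketch with httpx — the server address, path prefix, and connection-info fields below are assumptions for illustration, not values fixed by this PR:

import httpx

payload = {
    "connectionInfo": {"host": "...", "port": 5432, "user": "...", "password": "...", "database": "..."},  # placeholder credentials
    "manifestStr": "...",  # base64-encoded MDL manifest
    "sql": 'SELECT * FROM "Orders" LIMIT 10',
}
with httpx.Client(base_url="http://localhost:8000") as client:  # assumed server address
    first = client.post("/v2/connector/postgres/query?cacheEnable=true", json=payload)
    second = client.post("/v2/connector/postgres/query?cacheEnable=true", json=payload)
    print(first.headers["X-Cache-Hit"], second.headers["X-Cache-Hit"])  # "false", then "true"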


@router.post("/{data_source}/validate/{rule_name}", deprecated=True)
48 changes: 41 additions & 7 deletions ibis-server/app/routers/v3/connector.py
@@ -20,8 +20,9 @@
from app.model.connector import Connector
from app.model.data_source import DataSource
from app.model.validator import Validator
from app.query_cache import QueryCacheManager
from app.routers import v2
from app.routers.v2.connector import get_java_engine_connector
from app.routers.v2.connector import get_java_engine_connector, get_query_cache_manager
from app.util import build_context, pushdown_limit, to_json

router = APIRouter(prefix="/connector")
@@ -36,16 +37,21 @@ async def query(
    data_source: DataSource,
    dto: QueryDTO,
    dry_run: Annotated[bool, Query(alias="dryRun")] = False,
    cache_enable: Annotated[bool, Query(alias="cacheEnable")] = False,
    limit: int | None = None,
    headers: Annotated[str | None, Header()] = None,
    java_engine_connector: JavaEngineConnector = Depends(get_java_engine_connector),
    query_cache_manager: QueryCacheManager = Depends(get_query_cache_manager),
) -> Response:
    span_name = (
        f"v3_query_{data_source}_dry_run" if dry_run else f"v3_query_{data_source}"
    )
    span_name = f"v3_query_{data_source}"
    if dry_run:
        span_name += "_dry_run"
    if cache_enable:
        span_name += "_cache_enable"

    with tracer.start_as_current_span(
        name=span_name, kind=trace.SpanKind.SERVER, context=build_context(headers)
    ):
    ) as span:
        try:
            sql = pushdown_limit(dto.sql, limit)
            rewritten_sql = await Rewriter(
@@ -54,8 +60,36 @@
            connector = Connector(data_source, dto.connection_info)
            if dry_run:
                connector.dry_run(rewritten_sql)
                return Response(status_code=204)
            return ORJSONResponse(to_json(connector.query(rewritten_sql, limit=limit)))
                dry_response = Response(status_code=204)
                dry_response.headers["X-Cache-Hit"] = "true"
                return dry_response

            # Not a dry run: check the query cache first
            cached_result = None
            cache_hit = False
            enable_cache = cache_enable

            if enable_cache:
                cached_result = query_cache_manager.get(
                    str(data_source), dto.sql, dto.connection_info
                )
                cache_hit = cached_result is not None

            if cache_hit:
                span.add_event("cache hit")
                response = ORJSONResponse(to_json(cached_result))
                response.headers["X-Cache-Hit"] = str(cache_hit).lower()
                return response
            else:
                result = connector.query(rewritten_sql, limit=limit)
                if enable_cache:
                    query_cache_manager.set(
                        data_source, dto.sql, result, dto.connection_info
                    )
                response = ORJSONResponse(to_json(result))
                response.headers["X-Cache-Hit"] = str(cache_hit).lower()
                return response
        except Exception as e:
            logger.warning(
                "Failed to execute v3 query, fallback to v2: {}\n" + MIGRATION_MESSAGE,
57 changes: 57 additions & 0 deletions ibis-server/tests/routers/v2/connector/test_postgres.py
@@ -1,4 +1,6 @@
import base64
import os
import shutil
from urllib.parse import quote_plus, urlparse

import orjson
@@ -10,6 +12,7 @@
from testcontainers.postgres import PostgresContainer

from app.model.validator import rules
from app.query_cache import QueryCacheManager
from tests.conftest import file_path

pytestmark = pytest.mark.postgres
@@ -124,6 +127,15 @@ def postgres(request) -> PostgresContainer:
    return pg


@pytest.fixture(scope="function")
def cache_dir():
    temp_dir = "/tmp/wren-engine-test"
    os.makedirs(temp_dir, exist_ok=True)
    yield temp_dir
    # Clean up after the test
    shutil.rmtree(temp_dir, ignore_errors=True)


async def test_query(client, manifest_str, postgres: PostgresContainer):
    connection_info = _to_connection_info(postgres)
    response = await client.post(
@@ -164,6 +176,51 @@ async def test_query(client, manifest_str, postgres: PostgresContainer):
    }


async def test_query_with_cache(
    client, manifest_str, postgres: PostgresContainer, cache_dir, monkeypatch
):
    # Override the cache path to use our test directory
    monkeypatch.setattr(
        QueryCacheManager,
        "_get_cache_path",
        lambda self, key: f"{cache_dir}/{key}.cache",
    )

    connection_info = _to_connection_info(postgres)

    # First request - should miss the cache
    response1 = await client.post(
        url=f"{base_url}/query?cacheEnable=true",  # Enable the cache
        json={
            "connectionInfo": connection_info,
            "manifestStr": manifest_str,
            "sql": 'SELECT * FROM "Orders" LIMIT 10',
        },
    )

    assert response1.status_code == 200
    assert response1.headers["X-Cache-Hit"] == "false"
    result1 = response1.json()

    # Second request with the same SQL - should hit the cache
    response2 = await client.post(
        url=f"{base_url}/query?cacheEnable=true",  # Enable the cache
        json={
            "connectionInfo": connection_info,
            "manifestStr": manifest_str,
            "sql": 'SELECT * FROM "Orders" LIMIT 10',
        },
    )
    assert response2.status_code == 200
    assert response2.headers["X-Cache-Hit"] == "true"
    result2 = response2.json()

    # Verify the results are identical
    assert result1["data"] == result2["data"]
    assert result1["columns"] == result2["columns"]
    assert result1["dtypes"] == result2["dtypes"]


async def test_query_with_connection_url(
    client, manifest_str, postgres: PostgresContainer
):