feat(ibis): persistent query cache #1088


Merged · 14 commits · Apr 2, 2025
9 changes: 8 additions & 1 deletion ibis-server/app/main.py
@@ -13,6 +13,7 @@
 from app.mdl.java_engine import JavaEngineConnector
 from app.middleware import ProcessTimeMiddleware, RequestLogMiddleware
 from app.model import ConfigModel, CustomHttpError
+from app.query_cache import QueryCacheManager
 from app.routers import v2, v3

 get_config().init_logger()
@@ -22,12 +23,18 @@
 # Use state to store the singleton instance
 class State(TypedDict):
     java_engine_connector: JavaEngineConnector
+    query_cache_manager: QueryCacheManager


 @asynccontextmanager
 async def lifespan(app: FastAPI) -> AsyncIterator[State]:
+    query_cache_manager = QueryCacheManager()
+
     async with JavaEngineConnector() as java_engine_connector:
-        yield {"java_engine_connector": java_engine_connector}
+        yield {
+            "java_engine_connector": java_engine_connector,
+            "query_cache_manager": query_cache_manager,
+        }


 app = FastAPI(lifespan=lifespan)
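
For context on how this state is consumed: FastAPI exposes the dictionary yielded by lifespan as attributes on request.state, which is how the routers below obtain the singleton. A minimal sketch, assuming the lifespan defined above (the /cache-info endpoint is hypothetical, for illustration only):

from fastapi import FastAPI, Request

app = FastAPI(lifespan=lifespan)  # lifespan as defined above

@app.get("/cache-info")  # hypothetical endpoint, not part of this PR
def cache_info(request: Request) -> dict:
    # FastAPI copies the dict yielded by lifespan onto request.state
    manager = request.state.query_cache_manager
    return {"cache_root": manager.root}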
75 changes: 75 additions & 0 deletions ibis-server/app/query_cache/__init__.py
@@ -0,0 +1,75 @@
import hashlib
from typing import Any, Optional

import ibis
import opendal
from loguru import logger
from opentelemetry import trace

tracer = trace.get_tracer(__name__)


class QueryCacheManager:
    def __init__(self, root: str = "/tmp/wren-engine/"):
        self.root = root

    @tracer.start_as_current_span("get_cache", kind=trace.SpanKind.INTERNAL)
    def get(self, data_source: str, sql: str, info) -> Optional[Any]:
        cache_key = self._generate_cache_key(data_source, sql, info)
        cache_file_name = self._get_cache_file_name(cache_key)
        op = self._get_dal_operator()
        full_path = self._get_full_path(cache_file_name)

        # Check if the cache file exists before attempting to read it
        if op.exists(cache_file_name):
            try:
                logger.info(f"\nReading query cache {cache_file_name}\n")
                cache = ibis.read_parquet(full_path)
                df = cache.execute()
                logger.info("\nquery cache to dataframe\n")
                return df
            except Exception as e:
                logger.debug(f"Failed to read query cache: {e}")
                return None

        return None

    @tracer.start_as_current_span("set_cache", kind=trace.SpanKind.INTERNAL)
    def set(self, data_source: str, sql: str, result: Any, info) -> None:
        cache_key = self._generate_cache_key(data_source, sql, info)
        cache_file_name = self._get_cache_file_name(cache_key)
        op = self._get_dal_operator()
        full_path = self._get_full_path(cache_file_name)

        try:
            # Open the cache file through opendal, then write the result
            # as Parquet via ibis if the file handle is writable
            with op.open(cache_file_name, mode="wb") as file:
                cache = ibis.memtable(result)
                logger.info(f"\nWriting query cache to {cache_file_name}\n")
                if file.writable():
                    cache.to_parquet(full_path)
        except Exception as e:
            logger.debug(f"Failed to write query cache: {e}")
            return

    def _generate_cache_key(self, data_source: str, sql: str, info) -> str:
        key_parts = [
            data_source,
            sql,
            info.host.get_secret_value(),
            info.port.get_secret_value(),
            info.user.get_secret_value(),
        ]
        key_string = "|".join(key_parts)

        return hashlib.sha256(key_string.encode()).hexdigest()

    def _get_cache_file_name(self, cache_key: str) -> str:
        return f"{cache_key}.cache"

    def _get_full_path(self, path: str) -> str:
        return self.root + path

    def _get_dal_operator(self) -> Any:
        # Default implementation uses the local filesystem
        return opendal.Operator("fs", root=self.root)
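
To make the read/write flow concrete, here is a minimal round-trip sketch, assuming a pandas DataFrame result and a hypothetical stand-in for the connection info (the real DTOs expose host/port/user as pydantic SecretStr fields, which is all _generate_cache_key touches):

import os

import pandas as pd
from pydantic import BaseModel, SecretStr

from app.query_cache import QueryCacheManager


class FakeConnectionInfo(BaseModel):  # hypothetical stand-in for a real DTO
    host: SecretStr = SecretStr("localhost")
    port: SecretStr = SecretStr("5432")
    user: SecretStr = SecretStr("wren")


root = "/tmp/wren-engine-demo/"
os.makedirs(root, exist_ok=True)  # ensure the demo root exists

manager = QueryCacheManager(root=root)
info = FakeConnectionInfo()
sql = "SELECT 1 AS x"

# First lookup misses; set() persists the result as a Parquet file
assert manager.get("postgres", sql, info) is None
manager.set("postgres", sql, pd.DataFrame({"x": [1]}), info)

# Second lookup hits and returns a DataFrame read back via ibis
df = manager.get("postgres", sql, info)
assert df is not None and df["x"].tolist() == [1]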
55 changes: 49 additions & 6 deletions ibis-server/app/routers/v2/connector.py
@@ -20,6 +20,7 @@
 from app.model.metadata.dto import Constraint, MetadataDTO, Table
 from app.model.metadata.factory import MetadataFactory
 from app.model.validator import Validator
+from app.query_cache import QueryCacheManager
 from app.util import build_context, pushdown_limit, to_json

 router = APIRouter(prefix="/connector")
@@ -30,23 +31,32 @@ def get_java_engine_connector(request: Request) -> JavaEngineConnector:
     return request.state.java_engine_connector


+def get_query_cache_manager(request: Request) -> QueryCacheManager:
+    return request.state.query_cache_manager
+
+
 @router.post(
     "/{data_source}/query", dependencies=[Depends(verify_query_dto)], deprecated=True
 )
 async def query(
     data_source: DataSource,
     dto: QueryDTO,
     dry_run: Annotated[bool, Query(alias="dryRun")] = False,
+    cache_enable: Annotated[bool, Query(alias="cacheEnable")] = False,
     limit: int | None = None,
     java_engine_connector: JavaEngineConnector = Depends(get_java_engine_connector),
+    query_cache_manager: QueryCacheManager = Depends(get_query_cache_manager),
     headers: Annotated[str | None, Header()] = None,
 ) -> Response:
-    span_name = (
-        f"v2_query_{data_source}_dry_run" if dry_run else f"v2_query_{data_source}"
-    )
+    span_name = f"v2_query_{data_source}"
+    if dry_run:
+        span_name += "_dry_run"
+    if cache_enable:
+        span_name += "_cache_enable"
+
     with tracer.start_as_current_span(
         name=span_name, kind=trace.SpanKind.SERVER, context=build_context(headers)
-    ):
+    ) as span:
         try:
             sql = pushdown_limit(dto.sql, limit)
         except Exception as e:
@@ -59,10 +69,43 @@ async def query(
             java_engine_connector=java_engine_connector,
         ).rewrite(sql)
         connector = Connector(data_source, dto.connection_info)
+
+        # Dry-run requests skip the query cache entirely
         if dry_run:
             connector.dry_run(rewritten_sql)
-            return Response(status_code=204)
-        return ORJSONResponse(to_json(connector.query(rewritten_sql, limit=limit)))
+            dry_response = Response(status_code=204)
+            dry_response.headers["X-Cache-Hit"] = "true"
+            return dry_response
+
+        # Not a dry run: consult the cache when caching is enabled
+        cached_result = None
+        cache_hit = False
+        enable_cache = cache_enable
+
+        if enable_cache:
+            cached_result = query_cache_manager.get(
+                str(data_source), dto.sql, dto.connection_info
+            )
+            cache_hit = cached_result is not None
+
+        if cache_hit:
+            span.add_event("cache hit")
+            response = ORJSONResponse(to_json(cached_result))
+            response.headers["X-Cache-Hit"] = str(cache_hit).lower()
+            return response
+        else:
+            result = connector.query(rewritten_sql, limit=limit)
+            if enable_cache:
+                query_cache_manager.set(
+                    data_source, dto.sql, result, dto.connection_info
+                )
+
+            response = ORJSONResponse(to_json(result))
+            response.headers["X-Cache-Hit"] = str(cache_hit).lower()
+            return response


 @router.post("/{data_source}/validate/{rule_name}", deprecated=True)
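
To exercise the new flag end to end, here is a hedged client-side sketch, assuming the server is reachable at localhost:8000 and that the connectionInfo and manifestStr values are filled in for your data source (placeholders below):

import httpx

payload = {
    "connectionInfo": {  # placeholder credentials; field names assumed
        "host": "...", "port": "...", "user": "...", "password": "...", "database": "...",
    },
    "manifestStr": "...",  # placeholder base64-encoded MDL manifest
    "sql": 'SELECT * FROM "Orders" LIMIT 10',
}

# cacheEnable=true opts the request into the persistent cache
resp = httpx.post(
    "http://localhost:8000/v2/connector/postgres/query",  # assumed deployment URL
    params={"cacheEnable": "true"},
    json=payload,
)
print(resp.headers["X-Cache-Hit"])  # "false" on a cold cache, "true" thereafter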
48 changes: 41 additions & 7 deletions ibis-server/app/routers/v3/connector.py
@@ -20,8 +20,9 @@
 from app.model.connector import Connector
 from app.model.data_source import DataSource
 from app.model.validator import Validator
+from app.query_cache import QueryCacheManager
 from app.routers import v2
-from app.routers.v2.connector import get_java_engine_connector
+from app.routers.v2.connector import get_java_engine_connector, get_query_cache_manager
 from app.util import build_context, pushdown_limit, to_json

 router = APIRouter(prefix="/connector")
@@ -36,16 +37,21 @@ async def query(
     data_source: DataSource,
     dto: QueryDTO,
     dry_run: Annotated[bool, Query(alias="dryRun")] = False,
+    cache_enable: Annotated[bool, Query(alias="cacheEnable")] = False,
     limit: int | None = None,
     headers: Annotated[str | None, Header()] = None,
     java_engine_connector: JavaEngineConnector = Depends(get_java_engine_connector),
+    query_cache_manager: QueryCacheManager = Depends(get_query_cache_manager),
 ) -> Response:
-    span_name = (
-        f"v3_query_{data_source}_dry_run" if dry_run else f"v3_query_{data_source}"
-    )
+    span_name = f"v3_query_{data_source}"
+    if dry_run:
+        span_name += "_dry_run"
+    if cache_enable:
+        span_name += "_cache_enable"
+
     with tracer.start_as_current_span(
         name=span_name, kind=trace.SpanKind.SERVER, context=build_context(headers)
-    ):
+    ) as span:
         try:
             sql = pushdown_limit(dto.sql, limit)
             rewritten_sql = await Rewriter(
@@ -54,8 +60,36 @@
             connector = Connector(data_source, dto.connection_info)
             if dry_run:
                 connector.dry_run(rewritten_sql)
-                return Response(status_code=204)
-            return ORJSONResponse(to_json(connector.query(rewritten_sql, limit=limit)))
+                dry_response = Response(status_code=204)
+                dry_response.headers["X-Cache-Hit"] = "true"
+                return dry_response
+
+            # Not a dry run: consult the cache when caching is enabled
+            cached_result = None
+            cache_hit = False
+            enable_cache = cache_enable
+
+            if enable_cache:
+                cached_result = query_cache_manager.get(
+                    str(data_source), dto.sql, dto.connection_info
+                )
+                cache_hit = cached_result is not None
+
+            if cache_hit:
+                span.add_event("cache hit")
+                response = ORJSONResponse(to_json(cached_result))
+                response.headers["X-Cache-Hit"] = str(cache_hit).lower()
+                return response
+            else:
+                result = connector.query(rewritten_sql, limit=limit)
+                if enable_cache:
+                    query_cache_manager.set(
+                        data_source, dto.sql, result, dto.connection_info
+                    )
+                response = ORJSONResponse(to_json(result))
+                response.headers["X-Cache-Hit"] = str(cache_hit).lower()
+                return response
         except Exception as e:
             logger.warning(
                 "Failed to execute v3 query, fallback to v2: {}\n" + MIGRATION_MESSAGE,
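
Worth noting: both routers key the cache on the original dto.sql plus host/port/user, not on the rewritten SQL, so two logically identical requests share an entry only when their raw SQL strings match exactly. A minimal sketch mirroring _generate_cache_key:

import hashlib


def cache_key(data_source: str, sql: str, host: str, port: str, user: str) -> str:
    # Mirrors QueryCacheManager._generate_cache_key: SHA-256 over a "|"-joined tuple
    return hashlib.sha256("|".join([data_source, sql, host, port, user]).encode()).hexdigest()


# Whitespace- or case-only differences in the SQL produce distinct cache entries
assert cache_key("postgres", "SELECT 1", "h", "5432", "u") != cache_key(
    "postgres", "select 1", "h", "5432", "u"
)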
36 changes: 36 additions & 0 deletions ibis-server/tests/routers/v2/connector/test_postgres.py
@@ -164,6 +164,42 @@ async def test_query(client, manifest_str, postgres: PostgresContainer):
 }


+async def test_query_with_cache(client, manifest_str, postgres: PostgresContainer):
+    connection_info = _to_connection_info(postgres)
+
+    # First request - should miss cache
+    response1 = await client.post(
+        url=f"{base_url}/query?cacheEnable=true",  # Enable cache
+        json={
+            "connectionInfo": connection_info,
+            "manifestStr": manifest_str,
+            "sql": 'SELECT * FROM "Orders" LIMIT 10',
+        },
+    )
+
+    assert response1.status_code == 200
+    assert response1.headers["X-Cache-Hit"] == "false"
+    result1 = response1.json()
+
+    # Second request with same SQL - should hit cache
+    response2 = await client.post(
+        url=f"{base_url}/query?cacheEnable=true",  # Enable cache
+        json={
+            "connectionInfo": connection_info,
+            "manifestStr": manifest_str,
+            "sql": 'SELECT * FROM "Orders" LIMIT 10',
+        },
+    )
+    assert response2.status_code == 200
+    assert response2.headers["X-Cache-Hit"] == "true"
+    result2 = response2.json()
+
+    # Verify results are identical
+    assert result1["data"] == result2["data"]
+    assert result1["columns"] == result2["columns"]
+    assert result1["dtypes"] == result2["dtypes"]
+
+
 async def test_query_with_connection_url(
     client, manifest_str, postgres: PostgresContainer
 ):
34 changes: 34 additions & 0 deletions ibis-server/tests/routers/v3/connector/postgres/test_query.py
@@ -136,6 +136,40 @@ async def test_query(client, manifest_str, connection_info):
 }


+async def test_query_with_cache(client, manifest_str, connection_info):
+    # First request - should miss cache
+    response1 = await client.post(
+        url=f"{base_url}/query?cacheEnable=true",  # Enable cache
+        json={
+            "connectionInfo": connection_info,
+            "manifestStr": manifest_str,
+            "sql": "SELECT * FROM wren.public.orders LIMIT 1",
+        },
+    )
+
+    assert response1.status_code == 200
+    assert response1.headers["X-Cache-Hit"] == "false"
+    result1 = response1.json()
+
+    # Second request with same SQL - should hit cache
+    response2 = await client.post(
+        url=f"{base_url}/query?cacheEnable=true",  # Enable cache
+        json={
+            "connectionInfo": connection_info,
+            "manifestStr": manifest_str,
+            "sql": "SELECT * FROM wren.public.orders LIMIT 1",
+        },
+    )
+
+    assert response2.status_code == 200
+    assert response2.headers["X-Cache-Hit"] == "true"
+    result2 = response2.json()
+
+    assert result1["data"] == result2["data"]
+    assert result1["columns"] == result2["columns"]
+    assert result1["dtypes"] == result2["dtypes"]
+
+
 async def test_query_with_connection_url(client, manifest_str, connection_url):
     response = await client.post(
         url=f"{base_url}/query",