Skip to content

Commit b6cefb9

Browse files
committed
feat(sqlalchemy): make ibis.connect with sqlalchemy backends
1 parent 2c48835 commit b6cefb9

File tree

9 files changed

+179
-81
lines changed

9 files changed

+179
-81
lines changed

.github/workflows/ibis-backends.yml

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ jobs:
4747
title: Datafusion
4848
- name: pyspark
4949
title: PySpark
50+
parallel: false
5051
- name: mysql
5152
title: MySQL
5253
services:
@@ -80,6 +81,7 @@ jobs:
8081
backend:
8182
name: impala
8283
title: Impala
84+
parallel: false
8385
services:
8486
- impala
8587
- kudu
@@ -93,6 +95,7 @@ jobs:
9395
backend:
9496
name: impala
9597
title: Impala
98+
parallel: false
9699
services:
97100
- impala
98101
- kudu
@@ -158,12 +161,12 @@ jobs:
158161
- name: download backend data
159162
run: just download-data
160163

161-
- name: run parallel ${{ matrix.backend.name }} tests
162-
if: ${{ matrix.backend.name != 'pyspark' && matrix.backend.name != 'impala' }}
164+
- name: "run parallel tests: ${{ matrix.backend.name }}"
165+
if: ${{ matrix.backend.parallel || matrix.backend.parallel == null }}
163166
run: just ci-check -m ${{ matrix.backend.name }} --numprocesses auto --dist=loadgroup
164167

165-
- name: run non-parallel ${{ matrix.backend.name }} tests
166-
if: ${{ matrix.backend.name == 'pyspark' || matrix.backend.name == 'impala' }}
168+
- name: "run serial tests: ${{ matrix.backend.name }}"
169+
if: ${{ !matrix.backend.parallel }}
167170
run: just ci-check -m ${{ matrix.backend.name }}
168171
env:
169172
IBIS_TEST_NN_HOST: localhost

.pre-commit-config.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,14 @@ repos:
1616
rev: 5.10.1
1717
hooks:
1818
- id: isort
19-
- repo: https://github.com/pycqa/flake8
20-
rev: 5.0.4
21-
hooks:
22-
- id: flake8
2319
- repo: https://github.com/psf/black
2420
rev: 22.6.0
2521
hooks:
2622
- id: black
23+
- repo: https://github.com/pycqa/flake8
24+
rev: 5.0.4
25+
hooks:
26+
- id: flake8
2727
- repo: https://github.com/MarcoGorelli/absolufy-imports
2828
rev: v0.3.1
2929
hooks:

ibis/backends/base/__init__.py

Lines changed: 66 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,11 @@
33
import abc
44
import collections.abc
55
import functools
6+
import importlib.metadata
67
import keyword
78
import re
9+
import sys
10+
import urllib.parse
811
from pathlib import Path
912
from typing import (
1013
TYPE_CHECKING,
@@ -14,6 +17,7 @@
1417
Iterable,
1518
Iterator,
1619
Mapping,
20+
MutableMapping,
1721
)
1822

1923
if TYPE_CHECKING:
@@ -277,6 +281,16 @@ def connect(self, *args, **kwargs) -> BaseBackend:
277281
new_backend.reconnect()
278282
return new_backend
279283

284+
def _from_url(self, url: str) -> BaseBackend:
285+
"""Construct an ibis backend from a SQLAlchemy-conforming URL."""
286+
raise NotImplementedError(
287+
f"`_from_url` not implemented for the {self.name} backend"
288+
)
289+
290+
@staticmethod
291+
def _convert_kwargs(kwargs: MutableMapping) -> None:
292+
"""Manipulate keyword arguments to `.connect` method."""
293+
280294
def reconnect(self) -> None:
281295
"""Reconnect to the database already configured with connect."""
282296
self.do_connect(*self._con_args, **self._con_kwargs)
@@ -672,21 +686,60 @@ def has_operation(cls, operation: type[ops.Value]) -> bool:
672686
_connect = RegexDispatcher("_connect")
673687

674688

675-
@_connect.register(r"(?P<backend>.+)://(?P<path>.*)", priority=10)
676-
def _(_: str, *, backend: str, path: str, **kwargs: Any) -> BaseBackend:
689+
@functools.lru_cache(maxsize=None)
690+
def _get_backend_names() -> frozenset[str]:
691+
"""Return the set of known backend names.
692+
693+
Notes
694+
-----
695+
This function returns a frozenset to prevent cache pollution.
696+
697+
If a `set` is used, then any in-place modifications to the set
698+
are visible to every caller of this function.
699+
"""
700+
701+
if sys.version_info < (3, 10):
702+
entrypoints = importlib.metadata.entry_points()["ibis.backends"]
703+
else:
704+
entrypoints = importlib.metadata.entry_points(group="ibis.backends")
705+
return frozenset(ep.name for ep in entrypoints)
706+
707+
708+
_PATTERN = "|".join(
709+
sorted(_get_backend_names().difference(("duckdb", "sqlite")))
710+
)
711+
712+
713+
@_connect.register(rf"(?P<backend>{_PATTERN})://.+", priority=12)
714+
def _(url: str, *, backend: str, **kwargs: Any) -> BaseBackend:
677715
"""Connect to given `backend` with `path`.
678716
679717
Examples
680718
--------
681-
>>> con = ibis.connect("duckdb://relative/path/to/data.db")
682719
>>> con = ibis.connect("postgres://user:pass@hostname:port/database")
720+
>>> con = ibis.connect("mysql://user:pass@hostname:port/database")
683721
"""
684-
instance = getattr(ibis, backend)
722+
instance: BaseBackend = getattr(ibis, backend)
685723
backend += (backend == "postgres") * "ql"
686-
try:
687-
return instance.connect(url=f"{backend}://{path}", **kwargs)
688-
except TypeError:
689-
return instance.connect(path, **kwargs)
724+
params = "?" * bool(kwargs) + urllib.parse.urlencode(kwargs)
725+
url += params
726+
return instance._from_url(url)
727+
728+
729+
@_connect.register(r"(?P<backend>duckdb|sqlite)://(?P<path>.*)", priority=12)
730+
def _(_: str, *, backend: str, path: str, **kwargs: Any) -> BaseBackend:
731+
"""Connect to given `backend` with `path`.
732+
733+
Examples
734+
--------
735+
>>> con = ibis.connect("duckdb://relative/path/to/data.db")
736+
>>> con = ibis.connect("sqlite:///absolute/path/to/data.db")
737+
"""
738+
instance: BaseBackend = getattr(ibis, backend)
739+
params = "?" * bool(kwargs) + urllib.parse.urlencode(kwargs)
740+
path += params
741+
# extra slash for sqlalchemy
742+
return instance._from_url(f"{backend}:///{path}")
690743

691744

692745
@_connect.register(r"file://(?P<path>.*)", priority=10)
@@ -716,7 +769,7 @@ def connect(resource: Path | str, **_: Any) -> BaseBackend:
716769
717770
Examples
718771
--------
719-
>>> con = ibis.connect("duckdb://relative/path/to/data.db")
772+
>>> con = ibis.connect("duckdb:///absolute/path/to/data.db")
720773
>>> con = ibis.connect("relative/path/to/data.duckdb")
721774
"""
722775
raise NotImplementedError(type(resource))
@@ -752,29 +805,20 @@ def _(
752805
Examples
753806
--------
754807
>>> con = ibis.connect("duckdb://relative/path/to/data.csv")
755-
>>> con = ibis.connect("duckdb://relative/path/to/more/data.parquet")
808+
>>> con = ibis.connect("duckdb:///absolute/path/to/more/data.parquet")
756809
"""
757810
con = getattr(ibis, backend).connect(**kwargs)
758811
con.register(f"{extension}://{filename}")
759812
return con
760813

761814

762-
@_connect.register(
763-
r"(?P<filename>.+\.(?P<extension>parquet|csv))",
764-
priority=8,
765-
)
766-
def _(
767-
_: str,
768-
*,
769-
filename: str,
770-
extension: str,
771-
**kwargs: Any,
772-
) -> BaseBackend:
815+
@_connect.register(r".+\.(?:parquet|csv)", priority=8)
816+
def _(filename: str, **kwargs: Any) -> BaseBackend:
773817
"""Connect to `duckdb` and register a parquet or csv file.
774818
775819
Examples
776820
--------
777821
>>> con = ibis.connect("relative/path/to/data.csv")
778822
>>> con = ibis.connect("relative/path/to/more/data.parquet")
779823
"""
780-
return _connect(f"duckdb://{filename}", **kwargs)
824+
return _connect(f"duckdb:///{filename}", **kwargs)

ibis/backends/base/sql/__init__.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from functools import lru_cache
66
from typing import Any, Mapping
77

8+
import sqlalchemy as sa
9+
810
import ibis.expr.operations as ops
911
import ibis.expr.schema as sch
1012
import ibis.expr.types as ir
@@ -25,6 +27,33 @@ class BaseSQLBackend(BaseBackend):
2527
table_class = ops.DatabaseTable
2628
table_expr_class = ir.Table
2729

30+
def _from_url(self, url: str) -> BaseBackend:
31+
"""Connect to a backend using a URL `url`.
32+
33+
Parameters
34+
----------
35+
url
36+
URL with which to connect to a backend.
37+
38+
Returns
39+
-------
40+
BaseBackend
41+
A backend instance
42+
"""
43+
url = sa.engine.make_url(url)
44+
45+
kwargs = {
46+
name: value
47+
for name in ("host", "port", "database", "password")
48+
if (value := getattr(url, name, None))
49+
}
50+
if username := url.username:
51+
kwargs["user"] = username
52+
53+
kwargs.update(url.query)
54+
self._convert_kwargs(kwargs)
55+
return self.connect(**kwargs)
56+
2857
def table(self, name: str, database: str | None = None) -> ir.Table:
2958
"""Construct a table expression.
3059

ibis/backends/conftest.py

Lines changed: 13 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import ibis
2222
import ibis.util as util
23+
from ibis.backends.base import _get_backend_names
2324

2425
TEST_TABLES = {
2526
"functional_alltypes": ibis.schema(
@@ -221,28 +222,6 @@ def _random_identifier(suffix: str) -> str:
221222
return f"__ibis_test_{suffix}_{util.guid()}"
222223

223224

224-
@lru_cache(maxsize=None)
225-
def _get_backend_names() -> frozenset[str]:
226-
"""Return the set of known backend names.
227-
228-
Notes
229-
-----
230-
This function returns a frozenset to prevent cache pollution.
231-
232-
If a `set` is used, then any in-place modifications to the set
233-
are visible to every caller of this function.
234-
"""
235-
import sys
236-
237-
if sys.version_info < (3, 10):
238-
entrypoints = list(importlib.metadata.entry_points()['ibis.backends'])
239-
else:
240-
entrypoints = list(
241-
importlib.metadata.entry_points(group="ibis.backends")
242-
)
243-
return frozenset(ep.name for ep in entrypoints)
244-
245-
246225
def _get_backend_conf(backend_str: str):
247226
"""Convert a backend string to the test class for the backend."""
248227
conftest = importlib.import_module(
@@ -300,6 +279,8 @@ def pytest_ignore_collect(path, config):
300279
def pytest_collection_modifyitems(session, config, items):
301280
# add the backend marker to any tests are inside "ibis/backends"
302281
all_backends = _get_backend_names()
282+
xdist_group_markers = []
283+
303284
for item in items:
304285
parts = item.path.parts
305286
backend = _get_backend_from_parts(parts)
@@ -310,10 +291,16 @@ def pytest_collection_modifyitems(session, config, items):
310291
# anything else is a "core" test and is run by default
311292
item.add_marker(pytest.mark.core)
312293

313-
if "sqlite" in item.nodeid:
314-
item.add_marker(pytest.mark.xdist_group(name="sqlite"))
315-
if "duckdb" in item.nodeid:
316-
item.add_marker(pytest.mark.xdist_group(name="duckdb"))
294+
for name in ("duckdb", "sqlite"):
295+
# build a list of markers so we're don't invalidate the item's
296+
# marker iterator
297+
for _ in item.iter_markers(name=name):
298+
xdist_group_markers.append(
299+
(item, pytest.mark.xdist_group(name=name))
300+
)
301+
302+
for item, marker in xdist_group_markers:
303+
item.add_marker(marker)
317304

318305

319306
@lru_cache(maxsize=None)

ibis/backends/pyspark/__init__.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,19 @@ class Options(ibis.config.BaseModel):
9696
description="Treat NaNs in floating point expressions as NULL.",
9797
)
9898

99-
def do_connect(self, session: pyspark.sql.SparkSession) -> None:
99+
def _from_url(self, url: str) -> Backend:
100+
"""Construct a PySpark backend from a URL `url`."""
101+
url = sa.engine.make_url(url)
102+
params = list(itertools.chain.from_iterable(url.query.items()))
103+
if database := url.database:
104+
params.append("spark.sql.warehouse.dir")
105+
params.append(str(Path(database).absolute()))
106+
107+
builder = SparkSession.builder.config(*params)
108+
session = builder.getOrCreate()
109+
return self.connect(session)
110+
111+
def do_connect(self, session: SparkSession) -> None:
100112
"""Create a PySpark `Backend` for use with Ibis.
101113
102114
Parameters

0 commit comments

Comments
 (0)