Skip to content

Commit ff6100a

Browse files
committed
feat(postgres): implement maps in terms of JSONB instead of HSTORE
1 parent d7cd846 commit ff6100a

File tree

9 files changed

+125
-132
lines changed

9 files changed

+125
-132
lines changed

ci/schema/postgres.sql

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
CREATE EXTENSION IF NOT EXISTS hstore;
21
CREATE EXTENSION IF NOT EXISTS postgis;
32
CREATE EXTENSION IF NOT EXISTS plpython3u;
43
CREATE EXTENSION IF NOT EXISTS fuzzystrmatch;
@@ -298,10 +297,10 @@ INSERT INTO win VALUES
298297
('a', 4, 1);
299298

300299
DROP TABLE IF EXISTS map CASCADE;
301-
CREATE TABLE map (idx BIGINT, kv HSTORE);
300+
CREATE TABLE map (idx BIGINT, kv JSONB);
302301
INSERT INTO map VALUES
303-
(1, 'a=>1,b=>2,c=>3'),
304-
(2, 'd=>4,e=>5,c=>6');
302+
(1, '{"a": 1, "b": 2, "c": 3}'),
303+
(2, '{"d": 4, "e": 5, "f": 6}');
305304

306305
DROP TABLE IF EXISTS topk;
307306

ibis/backends/postgres/__init__.py

Lines changed: 3 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
import contextlib
66
import inspect
7-
import warnings
87
from operator import itemgetter
98
from typing import TYPE_CHECKING, Any
109
from urllib.parse import unquote_plus
@@ -208,7 +207,6 @@ def do_connect(
208207
database: str | None = None,
209208
schema: str | None = None,
210209
autocommit: bool = True,
211-
enable_map_support: bool = True,
212210
**kwargs: Any,
213211
) -> None:
214212
"""Create an Ibis client connected to PostgreSQL database.
@@ -229,9 +227,6 @@ def do_connect(
229227
PostgreSQL schema to use. If `None`, use the default `search_path`.
230228
autocommit
231229
Whether or not to autocommit
232-
enable_map_support
233-
Whether or not to enable map support. If `True`, the HSTORE
234-
extension will be loaded to support maps of string -> string.
235230
kwargs
236231
Additional keyword arguments to pass to the backend client connection.
237232
@@ -265,9 +260,6 @@ def do_connect(
265260
"""
266261
import pandas as pd
267262
import psycopg
268-
import psycopg.types.json
269-
270-
psycopg.types.json.set_json_loads(loads=lambda x: x)
271263

272264
self.con = psycopg.connect(
273265
host=host,
@@ -282,7 +274,7 @@ def do_connect(
282274

283275
self.con.adapters.register_dumper(type(pd.NaT), NatDumper)
284276

285-
self._post_connect(enable_map_support)
277+
self._post_connect()
286278

287279
@util.experimental
288280
@classmethod
@@ -300,26 +292,8 @@ def from_connection(cls, con: psycopg.Connection, /) -> Backend:
300292
new_backend._post_connect()
301293
return new_backend
302294

303-
def _post_connect(self, enable_map_support: bool = True) -> None:
304-
import psycopg.types
305-
import psycopg.types.hstore
306-
307-
con = self.con
308-
309-
try:
310-
# try to load hstore
311-
if enable_map_support:
312-
with con.cursor() as cursor, con.transaction():
313-
cursor.execute("CREATE EXTENSION IF NOT EXISTS hstore")
314-
psycopg.types.hstore.register_hstore(
315-
psycopg.types.TypeInfo.fetch(self.con, "hstore"), self.con
316-
)
317-
except psycopg.Error as e:
318-
warnings.warn(f"Failed to load hstore extension: {e}")
319-
except TypeError:
320-
pass
321-
322-
with con.cursor() as cursor, con.transaction():
295+
def _post_connect(self) -> None:
296+
with (con := self.con).cursor() as cursor, con.transaction():
323297
cursor.execute("SET TIMEZONE = UTC")
324298

325299
@property

ibis/backends/postgres/converter.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
import json
4+
35
from ibis.formats.pandas import PandasData
46

57

@@ -18,3 +20,19 @@ def convert_GeoSpatial(cls, s, dtype, pandas_type):
1820
@classmethod
1921
def convert_Binary(cls, s, dtype, pandas_type):
2022
return s.map(bytes, na_action="ignore")
23+
24+
@classmethod
25+
def convert_Map_element(cls, dtype):
26+
convert_key = cls.get_element_converter(dtype.key_type)
27+
convert_value = cls.get_element_converter(dtype.value_type)
28+
29+
def convert(raw_row):
30+
if raw_row is None:
31+
return raw_row
32+
33+
row = json.loads(raw_row)
34+
return dict(
35+
zip(map(convert_key, row.keys()), map(convert_value, row.values()))
36+
)
37+
38+
return convert

ibis/backends/postgres/tests/conftest.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
from collections.abc import Iterable
2626
from pathlib import Path
2727

28+
import ibis.expr.types as ir
29+
2830
PG_USER = os.environ.get(
2931
"IBIS_TEST_POSTGRES_USER", os.environ.get("PGUSER", "postgres")
3032
)
@@ -46,12 +48,17 @@ class TestConf(ServiceBackendTest):
4648

4749
returned_timestamp_unit = "s"
4850
supports_structs = False
51+
supports_map = True
4952
rounding_method = "half_to_even"
5053
service_name = "postgres"
5154
deps = ("psycopg",)
5255

5356
driver_supports_multiple_statements = True
5457

58+
@property
59+
def map(self) -> ir.Table | None:
60+
return self.connection.table("map").cast({"kv": "map<string, int>"})
61+
5562
@property
5663
def test_files(self) -> Iterable[Path]:
5764
return self.data_dir.joinpath("csv").glob("*.csv")

ibis/backends/sql/compilers/postgres.py

Lines changed: 86 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import inspect
4+
import json
45
import string
56
import textwrap
67
from functools import partial, reduce
@@ -114,9 +115,7 @@ class PostgresCompiler(SQLGlotCompiler):
114115
ops.GeoWithin: "st_within",
115116
ops.GeoX: "st_x",
116117
ops.GeoY: "st_y",
117-
ops.MapContains: "exist",
118-
ops.MapKeys: "akeys",
119-
ops.MapValues: "avals",
118+
ops.MapContains: "jsonb_contains",
120119
ops.RegexSearch: "regexp_like",
121120
ops.TimeFromHMS: "make_time",
122121
ops.RandomUUID: "gen_random_uuid",
@@ -130,8 +129,14 @@ def to_sqlglot(
130129
params: Mapping[ir.Expr, Any] | None = None,
131130
):
132131
table_expr = expr.as_table()
133-
geocols = table_expr.schema().geospatial
134-
conversions = {name: table_expr[name].as_ewkb() for name in geocols}
132+
schema = table_expr.schema()
133+
134+
conversions = {name: table_expr[name].as_ewkb() for name in schema.geospatial}
135+
conversions.update(
136+
(col, table_expr[col].cast(dt.string))
137+
for col, typ in schema.items()
138+
if typ.is_map() or typ.is_json()
139+
)
135140

136141
if conversions:
137142
table_expr = table_expr.mutate(**conversions)
@@ -160,25 +165,44 @@ def _compile_python_udf(self, udf_node: ops.ScalarUDF):
160165

161166
type_mapper = self.type_mapper
162167
argnames = udf_node.argnames
163-
return """\
168+
args = ", ".join(argnames)
169+
name = type(udf_node).__name__
170+
argsig = ", ".join(argnames)
171+
raw_args = [
172+
f"json.loads({argname})" if arg.dtype.is_map() else argname
173+
for argname, arg in zip(argnames, udf_node.args)
174+
]
175+
args = ", ".join(raw_args)
176+
call = f"{name}({args})"
177+
defn = """\
164178
CREATE OR REPLACE FUNCTION {ident}({signature})
165179
RETURNS {return_type}
166180
LANGUAGE {language}
167181
AS $$
182+
{json_import}
183+
def {name}({argsig}):
168184
{source}
169-
return {name}({args})
185+
return {call}
170186
$$""".format(
171-
name=type(udf_node).__name__,
172187
ident=self.__sql_name__(udf_node),
173188
signature=", ".join(
174189
f"{argname} {type_mapper.to_string(arg.dtype)}"
175190
for argname, arg in zip(argnames, udf_node.args)
176191
),
177192
return_type=type_mapper.to_string(udf_node.dtype),
178193
language=config.get("language", "plpython3u"),
179-
source=source,
180-
args=", ".join(argnames),
194+
json_import=(
195+
"import json"
196+
if udf_node.dtype.is_map()
197+
or any(arg.dtype.is_map() for arg in udf_node.args)
198+
else ""
199+
),
200+
name=name,
201+
argsig=argsig,
202+
source=textwrap.indent(source, " " * 4),
203+
call=call if not udf_node.dtype.is_map() else f"json.dumps({call})",
181204
)
205+
return defn
182206

183207
def visit_Mode(self, op, *, arg, where):
184208
expr = self.f.mode()
@@ -513,22 +537,61 @@ def visit_ToJSONArray(self, op, *, arg):
513537
def visit_Map(self, op, *, keys, values):
514538
# map(["a", "b"], NULL) results in {"a": NULL, "b": NULL} in regular postgres,
515539
# so we need to modify it to return NULL instead
516-
regular = self.f.map(keys, values)
517-
return self.if_(values.is_(NULL), NULL, regular)
540+
k, v = map(sg.to_identifier, "kv")
541+
regular = (
542+
sg.select(self.f.jsonb_object_agg(k, v))
543+
.from_(
544+
sg.select(
545+
self.f.unnest(keys).as_(k), self.f.unnest(values).as_(v)
546+
).subquery()
547+
)
548+
.subquery()
549+
)
550+
return self.if_(keys.is_(NULL).or_(values.is_(NULL)), NULL, regular)
518551

519552
def visit_MapLength(self, op, *, arg):
520-
return self.f.cardinality(self.f.akeys(arg))
553+
return (
554+
sg.select(self.f.count(sge.Star()))
555+
.from_(self.f.jsonb_object_keys(arg))
556+
.subquery()
557+
)
521558

522559
def visit_MapGet(self, op, *, arg, key, default):
523-
return self.if_(
524-
self.f.exist(arg, key),
525-
self.f.jsonb_extract_path_text(self.f.to_jsonb(arg), key),
526-
default,
527-
)
560+
if op.dtype.is_null():
561+
return NULL
562+
else:
563+
return self.cast(
564+
self.if_(
565+
self.f.jsonb_contains(arg, key),
566+
self.f.jsonb_extract_path_text(arg, key),
567+
default,
568+
),
569+
op.dtype,
570+
)
528571

529572
def visit_MapMerge(self, op, *, left, right):
530573
return sge.DPipe(this=left, expression=right)
531574

575+
def visit_MapKeys(self, op, *, arg):
576+
return self.if_(
577+
arg.is_(NULL), NULL, self.f.array(sg.select(self.f.jsonb_object_keys(arg)))
578+
)
579+
580+
def visit_MapValues(self, op, *, arg):
581+
col = gen_name("json_each_col")
582+
return self.if_(
583+
arg.is_(NULL),
584+
NULL,
585+
self.f.array(
586+
sg.select(
587+
sge.Dot(
588+
this=sg.to_identifier(col),
589+
expression=sg.to_identifier("value", quoted=True),
590+
)
591+
).from_(self.f.jsonb_each(arg).as_(col))
592+
),
593+
)
594+
532595
def visit_TypeOf(self, op, *, arg):
533596
typ = self.cast(self.f.pg_typeof(arg), dt.string)
534597
return self.if_(
@@ -591,6 +654,11 @@ def visit_NonNullLiteral(self, op, *, value, dtype):
591654
)
592655
elif dtype.is_json():
593656
return self.cast(value, dt.json)
657+
elif dtype.is_map():
658+
return sge.Cast(
659+
this=sge.convert(json.dumps(value)),
660+
to=sge.DataType(this=sge.DataType.Type.JSONB),
661+
)
594662
return None
595663

596664
def visit_TimestampFromYMDHMS(

ibis/backends/sql/datatypes.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -519,9 +519,7 @@ class PostgresType(SqlglotType):
519519
def _from_ibis_Map(cls, dtype: dt.Map) -> sge.DataType:
520520
if not dtype.key_type.is_string():
521521
raise com.IbisTypeError("Postgres only supports string keys in maps")
522-
if not dtype.value_type.is_string():
523-
raise com.IbisTypeError("Postgres only supports string values in maps")
524-
return sge.DataType(this=typecode.HSTORE)
522+
return sge.DataType(this=typecode.JSONB)
525523

526524
@classmethod
527525
def from_string(cls, text: str, nullable: bool | None = None) -> dt.DataType:

ibis/backends/sql/dialects.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,6 @@ class Polars(Postgres):
435435

436436

437437
Postgres.Generator.TRANSFORMS |= {
438-
sge.Map: rename_func("hstore"),
439438
sge.Split: rename_func("string_to_array"),
440439
sge.RegexpSplit: rename_func("regexp_split_to_array"),
441440
sge.DateFromParts: rename_func("make_date"),

0 commit comments

Comments
 (0)