Skip to content

Commit 16b0c4c

Browse files
committed
feat(datafusion): basic map operations
1 parent ff6100a commit 16b0c4c

File tree

3 files changed

+27
-18
lines changed

3 files changed

+27
-18
lines changed

ibis/backends/sql/compilers/datafusion.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ class DataFusionCompiler(SQLGlotCompiler):
6767
ops.EndsWith: "ends_with",
6868
ops.ArrayIntersect: "array_intersect",
6969
ops.ArrayUnion: "array_union",
70+
ops.MapKeys: "map_keys",
71+
ops.MapValues: "map_values",
7072
}
7173

7274
def _to_timestamp(self, value, target_dtype, literal=False):
@@ -541,5 +543,16 @@ def visit_ArrayConcat(self, op, *, arg):
541543
map(partial(self.cast, to=op.dtype), arg),
542544
)
543545

546+
def visit_MapGet(self, op, *, arg, key, default):
547+
if op.dtype.is_null():
548+
return NULL
549+
return self.f.coalesce(self.f.map_extract(arg, key)[1], default)
550+
551+
def visit_MapContains(self, op, *, arg, key):
552+
return self.f.array_has(self.f.map_keys(arg), key)
553+
554+
def visit_MapLength(self, op, *, arg):
555+
return self.f.array_length(self.f.map_keys(arg))
556+
544557

545558
compiler = DataFusionCompiler()

ibis/backends/sql/datatypes.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,14 @@ class DataFusionType(PostgresType):
556556
"float64": dt.float64,
557557
}
558558

559+
@classmethod
560+
def _from_ibis_Map(cls, dtype: dt.Map) -> sge.DataType:
561+
key_type = cls.from_ibis(dtype.key_type)
562+
value_type = cls.from_ibis(dtype.value_type)
563+
return sge.DataType(
564+
this=typecode.MAP, expressions=[key_type, value_type], nested=True
565+
)
566+
559567

560568
class MySQLType(SqlglotType):
561569
dialect = "mysql"

ibis/backends/tests/test_map.py

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,8 @@ def test_map_values_nulls(con, map):
146146
),
147147
"a",
148148
marks=[
149-
pytest.mark.notyet("clickhouse", reason="nested types can't be NULL")
149+
pytest.mark.notyet("clickhouse", reason="nested types can't be NULL"),
150+
pytest.mark.notyet(["datafusion"], raises=Exception, strict=False),
150151
],
151152
id="null_both_non_null_key",
152153
),
@@ -165,22 +166,23 @@ def test_map_values_nulls(con, map):
165166
ibis.literal(None, type="map<string, string>"),
166167
"a",
167168
marks=[
168-
pytest.mark.notyet("clickhouse", reason="nested types can't be NULL")
169+
pytest.mark.notyet("clickhouse", reason="nested types can't be NULL"),
170+
mark_notyet_datafusion,
169171
],
170172
id="null_map_non_null_key",
171173
),
172174
param(
173175
ibis.literal(None, type="map<string, string>"),
174176
ibis.literal(None, type="string"),
175177
marks=[
176-
pytest.mark.notyet("clickhouse", reason="nested types can't be NULL")
178+
pytest.mark.notyet("clickhouse", reason="nested types can't be NULL"),
179+
mark_notyet_datafusion,
177180
],
178181
id="null_map_null_key",
179182
),
180183
],
181184
)
182185
@pytest.mark.parametrize("method", ["get", "contains"])
183-
@mark_notyet_datafusion
184186
def test_map_get_contains_nulls(con, map, key, method):
185187
expr = getattr(map, method)
186188
assert con.execute(expr(key)) is None
@@ -251,7 +253,6 @@ def test_column_map_merge(backend):
251253
tm.assert_series_equal(result, expected)
252254

253255

254-
@mark_notyet_datafusion
255256
def test_literal_map_keys(con):
256257
mapping = ibis.literal({"1": "a", "2": "b"})
257258
expr = mapping.keys().name("tmp")
@@ -262,7 +263,6 @@ def test_literal_map_keys(con):
262263
assert np.array_equal(result, ["1", "2"])
263264

264265

265-
@mark_notyet_datafusion
266266
def test_literal_map_values(con):
267267
mapping = ibis.literal({"1": "a", "2": "b"})
268268
expr = mapping.values().name("tmp")
@@ -271,7 +271,6 @@ def test_literal_map_values(con):
271271
assert np.array_equal(result, ["a", "b"])
272272

273273

274-
@mark_notyet_datafusion
275274
def test_scalar_isin_literal_map_keys(con):
276275
mapping = ibis.literal({"a": 1, "b": 2})
277276
a = ibis.literal("a")
@@ -282,7 +281,6 @@ def test_scalar_isin_literal_map_keys(con):
282281
assert con.execute(false) == False # noqa: E712
283282

284283

285-
@mark_notyet_datafusion
286284
def test_map_scalar_contains_key_scalar(con):
287285
mapping = ibis.literal({"a": 1, "b": 2})
288286
a = ibis.literal("a")
@@ -293,7 +291,6 @@ def test_map_scalar_contains_key_scalar(con):
293291
assert con.execute(false) == False # noqa: E712
294292

295293

296-
@mark_notyet_datafusion
297294
def test_map_scalar_contains_key_column(backend, alltypes, df):
298295
value = {"1": "a", "3": "c"}
299296
mapping = ibis.literal(value)
@@ -303,7 +300,6 @@ def test_map_scalar_contains_key_column(backend, alltypes, df):
303300
backend.assert_series_equal(result, expected)
304301

305302

306-
@mark_notyet_datafusion
307303
def test_map_column_contains_key_scalar(backend, alltypes, df):
308304
expr = ibis.map(ibis.array([alltypes.string_col]), ibis.array([alltypes.int_col]))
309305
series = df.apply(lambda row: {row["string_col"]: row["int_col"]}, axis=1)
@@ -314,7 +310,6 @@ def test_map_column_contains_key_scalar(backend, alltypes, df):
314310
backend.assert_series_equal(result, series)
315311

316312

317-
@mark_notyet_datafusion
318313
def test_map_column_contains_key_column(alltypes):
319314
map_expr = ibis.map(
320315
ibis.array([alltypes.string_col]), ibis.array([alltypes.int_col])
@@ -466,7 +461,6 @@ def test_map_get_all_types(con, keys, values):
466461

467462

468463
@keys
469-
@mark_notyet_datafusion
470464
def test_map_contains_all_types(con, keys):
471465
a = ibis.array(keys)
472466
m = ibis.map(a, a)
@@ -521,23 +515,20 @@ def test_map_construct_array_column(con, alltypes, df):
521515

522516

523517
@mark_notyet_postgres
524-
@mark_notyet_datafusion
525518
def test_map_get_with_compatible_value_smaller(con):
526519
value = ibis.literal({"A": 1000, "B": 2000})
527520
expr = value.get("C", 3)
528521
assert con.execute(expr) == 3
529522

530523

531524
@mark_notyet_postgres
532-
@mark_notyet_datafusion
533525
def test_map_get_with_compatible_value_bigger(con):
534526
value = ibis.literal({"A": 1, "B": 2})
535527
expr = value.get("C", 3000)
536528
assert con.execute(expr) == 3000
537529

538530

539531
@mark_notyet_postgres
540-
@mark_notyet_datafusion
541532
def test_map_get_with_incompatible_value_different_kind(con):
542533
value = ibis.literal({"A": 1000, "B": 2000})
543534
expr = value.get("C", 3.0)
@@ -574,7 +565,6 @@ def test_map_get_with_null_on_null_type_with_null(con, null_value):
574565
["flink"], raises=Py4JJavaError, reason="Flink cannot handle typeless nulls"
575566
)
576567
@mark_notyet_postgres
577-
@mark_notyet_datafusion
578568
def test_map_get_with_null_on_null_type_with_non_null(con):
579569
value = ibis.literal({"A": None, "B": None})
580570
expr = value.get("C", 1)
@@ -600,7 +590,6 @@ def test_map_create_table(con, temp_table):
600590
raises=exc.OperationNotDefinedError,
601591
reason="No translation rule for <class 'ibis.expr.operations.maps.MapLength'>",
602592
)
603-
@mark_notyet_datafusion
604593
def test_map_length(con):
605594
expr = ibis.literal(dict(a="A", b="B")).length()
606595
assert con.execute(expr) == 2
@@ -613,7 +602,6 @@ def test_map_keys_unnest(backend):
613602
assert frozenset(result) == frozenset("abcdef")
614603

615604

616-
@mark_notyet_datafusion
617605
def test_map_contains_null(con):
618606
expr = ibis.map(["a"], ibis.literal([None], type="array<string>"))
619607
assert con.execute(expr.contains("a"))

0 commit comments

Comments
 (0)