
Commit 7ae6e25

feat(snowflake): add more map operations
1 parent 8d8bb70 commit 7ae6e25

File tree

3 files changed (+38 -15 lines)

ci/schema/snowflake.sql
ibis/backends/snowflake/registry.py
ibis/backends/tests/test_map.py

ci/schema/snowflake.sql

Lines changed: 7 additions & 0 deletions

@@ -94,6 +94,13 @@ INSERT INTO array_types ("x", "y", "z", "grouper", "scalar_column", "multi_dim")
 SELECT [2, NULL, 3], ['b', NULL, 'c'], NULL, 'b', 5.0, NULL UNION
 SELECT [4, NULL, NULL, 5], ['d', NULL, NULL, 'e'], [4.0, NULL, NULL, 5.0], 'c', 6.0, [[1, 2, 3]];
 
+CREATE OR REPLACE TABLE map ("kv" OBJECT);
+
+INSERT INTO map ("kv")
+SELECT object_construct('a', 1, 'b', 2, 'c', 3) UNION
+SELECT object_construct('d', 4, 'e', 5, 'c', 6);
+
+
 CREATE OR REPLACE TABLE struct ("abc" OBJECT);
 
 INSERT INTO struct ("abc")
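
Note: the two OBJECT rows inserted into the new map table correspond to the following Python mappings (a restatement of the SQL above for reference, not code from the commit):

rows = [
    {"a": 1, "b": 2, "c": 3},
    {"d": 4, "e": 5, "c": 6},
]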

ibis/backends/snowflake/registry.py

Lines changed: 19 additions & 2 deletions

@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+import itertools
+
 import numpy as np
 import sqlalchemy as sa
 from snowflake.sqlalchemy.custom_types import VARIANT
@@ -48,6 +50,10 @@ def _literal(t, op):
         return sa.func.date_from_parts(value.year, value.month, value.day)
     elif dtype.is_array():
         return sa.func.array_construct(*value)
+    elif dtype.is_map():
+        return sa.func.object_construct_keep_null(
+            *zip(itertools.chain.from_iterable(value.items()))
+        )
     return _postgres_literal(t, op)
 
 
@@ -116,6 +122,19 @@ def _array_slice(t, op):
     ops.StructField: fixed_arity(sa.func.get, 2),
     ops.StringFind: _string_find,
     ops.MapKeys: unary(sa.func.object_keys),
+    ops.MapGet: fixed_arity(
+        lambda arg, key, default: sa.func.coalesce(
+            sa.func.get(arg, key), sa.cast(default, VARIANT)
+        ),
+        3,
+    ),
+    ops.MapContains: fixed_arity(
+        lambda arg, key: sa.func.array_contains(
+            sa.func.cast(key, VARIANT), sa.func.object_keys(arg)
+        ),
+        2,
+    ),
+    ops.MapLength: unary(lambda arg: sa.func.array_size(sa.func.object_keys(arg))),
     ops.BitwiseLeftShift: fixed_arity(sa.func.bitshiftleft, 2),
     ops.BitwiseRightShift: fixed_arity(sa.func.bitshiftright, 2),
     ops.Ln: unary(sa.func.ln),
@@ -196,8 +215,6 @@ def _array_slice(t, op):
     # ibis.expr.operations.array
     ops.ArrayRepeat,
     ops.Unnest,
-    # ibis.expr.operations.maps
-    ops.MapKeys,
-    # ibis.expr.operations.reductions
+    # ibis.expr.operations.reductions
     ops.All,
     ops.Any,
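
Note: the new registry entries map ibis map operations onto Snowflake's semi-structured functions. A rough usage sketch of the expression API they enable; the SQL in the comments indicates the shape produced by the rules above, not verbatim compiler output:

import ibis

m = ibis.literal({"a": 1, "b": 2})  # map literal -> OBJECT_CONSTRUCT_KEEP_NULL(...)

m.keys()         # ops.MapKeys     -> OBJECT_KEYS(...)
m.get("a", 0)    # ops.MapGet      -> COALESCE(GET(..., 'a'), CAST(0 AS VARIANT))
m.contains("a")  # ops.MapContains -> ARRAY_CONTAINS(CAST('a' AS VARIANT), OBJECT_KEYS(...))
m.length()       # ops.MapLength   -> ARRAY_SIZE(OBJECT_KEYS(...))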

ibis/backends/tests/test_map.py

Lines changed: 12 additions & 13 deletions

@@ -5,21 +5,15 @@
 import ibis.expr.datatypes as dt
 
 pytestmark = [
-    pytest.mark.never(["sqlite", "mysql"], reason="No map support"),
+    pytest.mark.never(
+        ["sqlite", "mysql", "mssql", "postgres"], reason="No map support"
+    ),
+    pytest.mark.notyet(
+        ["bigquery", "impala"], reason="backend doesn't implement map types"
+    ),
     pytest.mark.notimpl(
-        [
-            "duckdb",
-            "postgres",
-            "impala",
-            "datafusion",
-            "pyspark",
-            "snowflake",
-            "polars",
-            "mssql",
-        ],
-        reason="Not implemented yet",
+        ["duckdb", "datafusion", "pyspark", "polars"], reason="Not implemented yet"
     ),
-    pytest.mark.notyet(["bigquery"], reason="BigQuery doesn't implement map types"),
 ]
 
 
@@ -49,6 +43,7 @@ def test_literal_map_values(con):
 
 
 @pytest.mark.notimpl(["trino"])
+@pytest.mark.notyet(["snowflake"])
 def test_scalar_isin_literal_map_keys(con):
     mapping = ibis.literal({'a': 1, 'b': 2})
     a = ibis.literal('a')
@@ -78,6 +73,7 @@ def test_map_scalar_contains_key_column(backend, alltypes, df):
     backend.assert_series_equal(result, expected)
 
 
+@pytest.mark.notyet(["snowflake"])
 def test_map_column_contains_key_scalar(backend, alltypes, df):
     expr = ibis.map(ibis.array([alltypes.string_col]), ibis.array([alltypes.int_col]))
     series = df.apply(lambda row: {row['string_col']: row['int_col']}, axis=1)
@@ -88,12 +84,14 @@ def test_map_column_contains_key_scalar(backend, alltypes, df):
     backend.assert_series_equal(result, series)
 
 
+@pytest.mark.notyet(["snowflake"])
 def test_map_column_contains_key_column(backend, alltypes, df):
     expr = ibis.map(ibis.array([alltypes.string_col]), ibis.array([alltypes.int_col]))
     result = expr.contains(alltypes.string_col).name('tmp').execute()
     assert result.all()
 
 
+@pytest.mark.notyet(["snowflake"])
 def test_literal_map_merge(con):
     a = ibis.literal({'a': 0, 'b': 2})
     b = ibis.literal({'a': 1, 'c': 3})
@@ -126,6 +124,7 @@ def test_literal_map_get_broadcast(backend, alltypes, df):
     backend.assert_series_equal(result, expected)
 
 
+@pytest.mark.notyet(["snowflake"])
 def test_map_construction(con, alltypes, df):
     expr = ibis.map(['a', 'b'], [1, 2])
     result = con.execute(expr.name('tmp'))
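
Note: several of the tests newly marked notyet for Snowflake build maps from parallel key/value arrays rather than from literals. For reference, the shape of that construction (illustrative values, not code from the commit):

import ibis

keys = ibis.array(["a", "b"])
values = ibis.array([1, 2])
m = ibis.map(keys, values)  # ops.Map: construct a map from parallel arrays
m.contains("a")             # the array-based construction exercised by the skipped tests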
