
Commit dd759d3

feat(snowflake): make literal maps and params work
1 parent 045edc7 commit dd759d3
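
As a quick illustration of the user-facing behavior this commit enables (a minimal sketch based on the new test_map_construct_dict test below; `con` is assumed to be an already-configured Snowflake backend connection and is not part of the commit):

import ibis

# `con` is assumed to be an existing Snowflake connection, e.g. from ibis.snowflake.connect(...)
expr = ibis.map(['a', 'b'], [1, 2])     # literal map expression
result = con.execute(expr.name('tmp'))  # now compiles to TO_OBJECT(PARSE_JSON(...)) on Snowflake
assert result == {'a': 1, 'b': 2}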

3 files changed: +23 −17 lines changed

ibis/backends/snowflake/registry.py

Lines changed: 15 additions & 3 deletions
@@ -1,10 +1,11 @@
 from __future__ import annotations
 
 import itertools
+import json
 
 import numpy as np
 import sqlalchemy as sa
-from snowflake.sqlalchemy.custom_types import VARIANT
+from snowflake.sqlalchemy import VARIANT
 
 import ibis.expr.operations as ops
 from ibis.backends.base.sql.alchemy.registry import (
@@ -52,7 +53,7 @@ def _literal(t, op):
         return sa.func.array_construct(*value)
     elif dtype.is_map():
         return sa.func.object_construct_keep_null(
-            *zip(itertools.chain.from_iterable(value.items()))
+            *itertools.chain.from_iterable(value.items())
         )
     return _postgres_literal(t, op)
 
@@ -112,6 +113,17 @@ def _array_slice(t, op):
     return sa.func.array_slice(t.translate(op.arg), start, stop)
 
 
+def _map(_, op):
+    if not (
+        isinstance(keys := op.keys, ops.Literal)
+        and isinstance(values := op.values, ops.Literal)
+    ):
+        raise TypeError("Both keys and values of an `ibis.map` call must be literals")
+
+    obj = dict(zip(keys.value, values.value))
+    return sa.func.to_object(sa.func.parse_json(json.dumps(obj, separators=",:")))
+
+
 _SF_POS_INF = sa.cast(sa.literal("Inf"), sa.FLOAT)
 _SF_NEG_INF = -_SF_POS_INF
 _SF_NAN = sa.cast(sa.literal("NaN"), sa.FLOAT)
@@ -201,7 +213,7 @@ def _array_slice(t, op):
         ops.ArrayColumn: lambda t, op: sa.func.array_construct(
            *map(t.translate, op.cols)
        ),
-        ops.ArraySlice: _array_slice,
+        ops.Map: _map,
    }
 )
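
For context, a standalone sketch of the translation strategy the new _map helper uses: build a plain dict from the literal keys and values, serialize it to compact JSON, and wrap it in Snowflake's PARSE_JSON/TO_OBJECT. The helper name _demo_map_sql and the rendered output are illustrative only, not part of the commit:

import json

import sqlalchemy as sa


def _demo_map_sql(keys, values) -> str:
    # same construction as _map: pair up literal keys and values
    obj = dict(zip(keys, values))
    # separators=",:" unpacks to (",", ":"), i.e. compact JSON with no spaces
    expr = sa.func.to_object(sa.func.parse_json(json.dumps(obj, separators=",:")))
    # render with literal binds so the JSON string shows up inline
    return str(expr.compile(compile_kwargs={"literal_binds": True}))


print(_demo_map_sql(["a", "b"], [1, 2]))
# prints roughly: to_object(parse_json('{"a":1,"b":2}'))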

ibis/backends/tests/test_map.py

Lines changed: 6 additions & 2 deletions
@@ -124,12 +124,16 @@ def test_literal_map_get_broadcast(backend, alltypes, df):
     backend.assert_series_equal(result, expected)
 
 
-@pytest.mark.notyet(["snowflake"])
-def test_map_construction(con, alltypes, df):
+def test_map_construct_dict(con):
     expr = ibis.map(['a', 'b'], [1, 2])
     result = con.execute(expr.name('tmp'))
     assert result == {'a': 1, 'b': 2}
 
+
+@pytest.mark.notimpl(
+    ["snowflake"], reason="unclear how to implement two arrays -> object construction"
+)
+def test_map_construct_array_column(con, alltypes, df):
     expr = ibis.map(ibis.array([alltypes.string_col]), ibis.array([alltypes.int_col]))
     result = con.execute(expr)
     expected = df.apply(lambda row: {row['string_col']: row['int_col']}, axis=1)
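
The split into two tests mirrors what the registry change can and cannot handle: with list literals the keys and values are known at compile time and can be serialized eagerly, while column-backed arrays only have values per row. A small sketch of the two expression shapes (the table t is illustrative, not from the test suite):

import ibis

t = ibis.table([("string_col", "string"), ("int_col", "int64")], name="t")

# literal keys and values: serializable up front, handled by the new _map rule
lit_map = ibis.map(["a", "b"], [1, 2])

# column-backed arrays: values differ per row, so Snowflake would need a
# two-arrays -> OBJECT construction at query time; marked notimpl above
col_map = ibis.map(ibis.array([t.string_col]), ibis.array([t.int_col]))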

ibis/backends/tests/test_param.py

Lines changed: 2 additions & 12 deletions
@@ -58,7 +58,7 @@ def test_timestamp_accepts_date_literals(alltypes):
     assert expr.compile(params=params) is not None
 
 
-@pytest.mark.notimpl(["dask", "datafusion", "impala", "pandas", "pyspark", "snowflake"])
+@pytest.mark.notimpl(["dask", "datafusion", "impala", "pandas", "pyspark"])
 @pytest.mark.never(
     ["mysql", "sqlite", "mssql"], reason="backend will never implement array types"
 )
@@ -84,17 +84,7 @@ def test_scalar_param_struct(con):
 
 
 @pytest.mark.notimpl(
-    [
-        "clickhouse",
-        "datafusion",
-        # TODO: duckdb maps are tricky because they are multimaps
-        "duckdb",
-        "impala",
-        "pyspark",
-        "snowflake",
-        "polars",
-        "trino",
-    ]
+    ["clickhouse", "datafusion", "duckdb", "impala", "pyspark", "polars", "trino"]
 )
 @pytest.mark.never(
     ["mysql", "sqlite", "mssql"],

0 commit comments