Skip to content

Commit 8d8bb70

Browse files
committed
feat(snowflake): add more array operations
1 parent e74328b commit 8d8bb70

File tree

2 files changed

+31
-14
lines changed

2 files changed

+31
-14
lines changed

ibis/backends/snowflake/registry.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ def _literal(t, op):
4646
return sa.func.timestamp_from_parts(*args)
4747
elif dtype.is_date():
4848
return sa.func.date_from_parts(value.year, value.month, value.day)
49+
elif dtype.is_array():
50+
return sa.func.array_construct(*value)
4951
return _postgres_literal(t, op)
5052

5153

@@ -88,6 +90,22 @@ def _extract_url_query(t, op):
8890
return sa.func.nullif(sa.func.as_varchar(r), "")
8991

9092

93+
def _array_slice(t, op):
94+
arg = t.translate(op.arg)
95+
96+
if (start := op.start) is not None:
97+
start = t.translate(start)
98+
else:
99+
start = 0
100+
101+
if (stop := op.stop) is not None:
102+
stop = t.translate(stop)
103+
else:
104+
stop = sa.func.array_size(arg)
105+
106+
return sa.func.array_slice(t.translate(op.arg), start, stop)
107+
108+
91109
_SF_POS_INF = sa.cast(sa.literal("Inf"), sa.FLOAT)
92110
_SF_NEG_INF = -_SF_POS_INF
93111
_SF_NAN = sa.cast(sa.literal("NaN"), sa.FLOAT)
@@ -158,6 +176,13 @@ def _extract_url_query(t, op):
158176
),
159177
# snowflake typeof only accepts VARIANT
160178
ops.TypeOf: unary(lambda arg: sa.func.typeof(sa.cast(arg, VARIANT))),
179+
ops.ArrayIndex: fixed_arity(sa.func.get, 2),
180+
ops.ArrayLength: fixed_arity(sa.func.array_size, 1),
181+
ops.ArrayConcat: fixed_arity(sa.func.array_cat, 2),
182+
ops.ArrayColumn: lambda t, op: sa.func.array_construct(
183+
*map(t.translate, op.cols)
184+
),
185+
ops.ArraySlice: _array_slice,
161186
}
162187
)
163188

@@ -169,12 +194,7 @@ def _extract_url_query(t, op):
169194
ops.NTile,
170195
ops.NthValue,
171196
# ibis.expr.operations.array
172-
ops.ArrayColumn,
173-
ops.ArrayConcat,
174-
ops.ArrayIndex,
175-
ops.ArrayLength,
176197
ops.ArrayRepeat,
177-
ops.ArraySlice,
178198
ops.Unnest,
179199
# ibis.expr.operations.maps
180200
ops.MapKeys,

ibis/backends/tests/test_array.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
]
2323

2424

25-
@pytest.mark.notimpl(["datafusion", "snowflake"])
25+
@pytest.mark.notimpl(["datafusion"])
2626
def test_array_column(backend, alltypes, df):
2727
expr = ibis.array([alltypes['double_col'], alltypes['double_col']])
2828
assert isinstance(expr, ir.ArrayColumn)
@@ -35,7 +35,6 @@ def test_array_column(backend, alltypes, df):
3535
backend.assert_series_equal(result, expected, check_names=False)
3636

3737

38-
@pytest.mark.notimpl(["snowflake"])
3938
def test_array_scalar(con):
4039
expr = ibis.array([1.0, 2.0, 3.0])
4140
assert isinstance(expr, ir.ArrayScalar)
@@ -48,7 +47,7 @@ def test_array_scalar(con):
4847
assert np.array_equal(result, expected)
4948

5049

51-
@pytest.mark.notimpl(["snowflake", "polars", "datafusion"])
50+
@pytest.mark.notimpl(["polars", "datafusion", "snowflake"])
5251
def test_array_repeat(con):
5352
expr = ibis.array([1.0, 2.0]) * 2
5453

@@ -61,7 +60,7 @@ def test_array_repeat(con):
6160

6261

6362
# Issues #2370
64-
@pytest.mark.notimpl(["datafusion", "snowflake"])
63+
@pytest.mark.notimpl(["datafusion"])
6564
def test_array_concat(con):
6665
left = ibis.literal([1, 2, 3])
6766
right = ibis.literal([2, 1])
@@ -74,13 +73,12 @@ def test_array_concat(con):
7473
assert np.array_equal(result, expected)
7574

7675

77-
@pytest.mark.notimpl(["datafusion", "snowflake"])
76+
@pytest.mark.notimpl(["datafusion"])
7877
def test_array_length(con):
7978
expr = ibis.literal([1, 2, 3]).length()
8079
assert con.execute(expr.name("tmp")) == 3
8180

8281

83-
@pytest.mark.notimpl(["snowflake"])
8482
def test_list_literal(con):
8583
arr = [1, 2, 3]
8684
expr = ibis.literal(arr)
@@ -91,7 +89,6 @@ def test_list_literal(con):
9189
assert np.array_equal(result, arr)
9290

9391

94-
@pytest.mark.notimpl(["snowflake"])
9592
def test_np_array_literal(con):
9693
arr = np.array([1, 2, 3])
9794
expr = ibis.literal(arr)
@@ -103,7 +100,7 @@ def test_np_array_literal(con):
103100

104101

105102
@pytest.mark.parametrize("idx", range(3))
106-
@pytest.mark.notimpl(["snowflake", "polars", "datafusion"])
103+
@pytest.mark.notimpl(["polars", "datafusion"])
107104
def test_array_index(con, idx):
108105
arr = [1, 2, 3]
109106
expr = ibis.literal(arr)
@@ -372,7 +369,7 @@ def test_unnest_default_name(con):
372369
(-3, -1),
373370
],
374371
)
375-
@pytest.mark.notimpl(["dask", "datafusion", "polars", "snowflake"])
372+
@pytest.mark.notimpl(["dask", "datafusion", "polars"])
376373
def test_array_slice(con, start, stop):
377374
array_types = con.tables.array_types
378375
expr = array_types.select(sliced=array_types.y[start:stop])

0 commit comments

Comments
 (0)