Skip to content

Commit f5a0a5a

Browse files
committed
feat(api): add ibis.range function for generating sequences
1 parent 5d1fadf commit f5a0a5a

File tree

13 files changed

+355
-5
lines changed

13 files changed

+355
-5
lines changed

ibis/backends/bigquery/registry.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -776,6 +776,16 @@ def _group_concat(translator, op):
776776
return f"STRING_AGG({arg}, {sep})"
777777

778778

779+
def _integer_range(translator, op):
780+
start = translator.translate(op.start)
781+
stop = translator.translate(op.stop)
782+
step = translator.translate(op.step)
783+
n = f"FLOOR(({stop} - {start}) / NULLIF({step}, 0))"
784+
gen_array = f"GENERATE_ARRAY({start}, {stop}, {step})"
785+
inner = f"SELECT x FROM UNNEST({gen_array}) x WHERE x <> {stop}"
786+
return f"IF({n} > 0, ARRAY({inner}), [])"
787+
788+
779789
OPERATION_REGISTRY = {
780790
**operation_registry,
781791
# Literal
@@ -939,6 +949,7 @@ def _group_concat(translator, op):
939949
ops.TimeDelta: _time_delta,
940950
ops.DateDelta: _date_delta,
941951
ops.TimestampDelta: _timestamp_delta,
952+
ops.IntegerRange: _integer_range,
942953
}
943954

944955
_invalid_operations = {

ibis/backends/clickhouse/compiler/values.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -817,6 +817,7 @@ def formatter(op, *, left, right, **_):
817817
ops.ExtractFragment: "fragment",
818818
ops.ArrayPosition: "indexOf",
819819
ops.ArrayFlatten: "arrayFlatten",
820+
ops.IntegerRange: "range",
820821
}
821822

822823

ibis/backends/duckdb/registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,7 @@ def _to_json_collection(t, op):
497497
ops.ToJSONMap: _to_json_collection,
498498
ops.ToJSONArray: _to_json_collection,
499499
ops.ArrayFlatten: unary(sa.func.flatten),
500+
ops.IntegerRange: fixed_arity(sa.func.range, 3),
500501
}
501502
)
502503

ibis/backends/polars/compiler.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1198,3 +1198,20 @@ def execute_agg_udf(op, **kw):
11981198
if (where := op.where) is not None:
11991199
first = first.filter(translate(where, **kw))
12001200
return getattr(first, op.__func_name__)(*rest)
1201+
1202+
1203+
@translate.register(ops.IntegerRange)
1204+
def execute_integer_range(op, **kw):
1205+
if not isinstance(op.step, ops.Literal):
1206+
raise NotImplementedError("Dynamic step not supported by Polars")
1207+
step = op.step.value
1208+
1209+
dtype = dtype_to_polars(op.dtype)
1210+
empty = pl.int_ranges(0, 0, dtype=dtype)
1211+
1212+
if step == 0:
1213+
return empty
1214+
1215+
start = translate(op.start, **kw)
1216+
stop = translate(op.stop, **kw)
1217+
return pl.int_ranges(start, stop, step, dtype=dtype)

ibis/backends/postgres/registry.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -604,6 +604,27 @@ def _array_filter(t, op):
604604
)
605605

606606

607+
def _integer_range(t, op):
608+
start = t.translate(op.start)
609+
stop = t.translate(op.stop)
610+
step = t.translate(op.step)
611+
satype = t.get_sqla_type(op.dtype)
612+
# `sequence` doesn't allow arguments that would produce an empty range, so
613+
# check that first
614+
n = sa.func.floor((stop - start) / sa.func.nullif(step, 0))
615+
seq = sa.func.generate_series(start, stop, step, type_=satype)
616+
return sa.case(
617+
# TODO(cpcloud): revisit using array_remove when my brain is working
618+
(
619+
n > 0,
620+
sa.func.array_remove(
621+
sa.func.array(sa.select(seq).scalar_subquery()), stop, type_=satype
622+
),
623+
),
624+
else_=sa.cast(pg.array([]), satype),
625+
)
626+
627+
607628
operation_registry.update(
608629
{
609630
ops.Literal: _literal,
@@ -802,5 +823,6 @@ def _array_filter(t, op):
802823
ops.ArrayPosition: fixed_arity(_array_position, 2),
803824
ops.ArrayMap: _array_map,
804825
ops.ArrayFilter: _array_filter,
826+
ops.IntegerRange: _integer_range,
805827
}
806828
)

ibis/backends/pyspark/compiler.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2080,3 +2080,20 @@ def compile_levenshtein(t, op, **kwargs):
20802080
@compiles(ops.ArrayFlatten)
20812081
def compile_flatten(t, op, **kwargs):
20822082
return F.flatten(t.translate(op.arg, **kwargs))
2083+
2084+
2085+
@compiles(ops.IntegerRange)
2086+
def compile_integer_range(t, op, **kwargs):
2087+
start = t.translate(op.start, **kwargs)
2088+
stop = t.translate(op.stop, **kwargs)
2089+
step = t.translate(op.step, **kwargs)
2090+
2091+
denom = F.when(step == 0, F.lit(None)).otherwise(step)
2092+
n = F.floor((stop - start) / denom)
2093+
seq = F.sequence(start, stop, step)
2094+
seq = F.slice(
2095+
seq,
2096+
1,
2097+
F.size(seq) - F.when(F.element_at(seq, F.size(seq)) == stop, 1).otherwise(0),
2098+
)
2099+
return F.when(n > 0, seq).otherwise(F.array())

ibis/backends/snowflake/registry.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,14 @@ def _timestamp_bucket(t, op):
491491
lambda part, left, right: sa.func.timestampdiff(part, right, left), 3
492492
),
493493
ops.TimestampBucket: _timestamp_bucket,
494+
ops.IntegerRange: fixed_arity(
495+
lambda start, stop, step: sa.func.iff(
496+
step != 0,
497+
sa.func.array_generate_range(start, stop, step),
498+
sa.func.array_construct(),
499+
),
500+
3,
501+
),
494502
}
495503
)
496504

ibis/backends/tests/test_array.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,12 @@
3939
except ImportError:
4040
PySparkAnalysisException = None
4141

42+
43+
try:
44+
from polars.exceptions import PolarsInvalidOperationError
45+
except ImportError:
46+
PolarsInvalidOperationError = None
47+
4248
pytestmark = [
4349
pytest.mark.never(
4450
["sqlite", "mysql", "mssql"],
@@ -910,3 +916,135 @@ def test_array_flatten(backend, flatten_data, column, expected):
910916
expr = t[column].flatten()
911917
result = backend.connection.execute(expr)
912918
backend.assert_series_equal(result, expected, check_names=False)
919+
920+
921+
polars_overflow = pytest.mark.notyet(
922+
["polars"],
923+
reason="but polars overflows allocation with some inputs",
924+
raises=BaseException,
925+
)
926+
927+
928+
@pytest.mark.notyet(
929+
["datafusion"],
930+
reason="range isn't implemented upstream",
931+
raises=com.OperationNotDefinedError,
932+
)
933+
@pytest.mark.notimpl(["flink", "pandas", "dask"], raises=com.OperationNotDefinedError)
934+
@pytest.mark.parametrize("n", [param(-2, marks=[polars_overflow]), 0, 2])
935+
def test_range_single_argument(con, n):
936+
expr = ibis.range(n)
937+
result = con.execute(expr)
938+
assert list(result) == list(range(n))
939+
940+
941+
@pytest.mark.notyet(
942+
["datafusion"],
943+
reason="range and unnest aren't implemented upstream",
944+
raises=com.OperationNotDefinedError,
945+
)
946+
@pytest.mark.parametrize(
947+
"n",
948+
[
949+
param(
950+
-2,
951+
marks=[
952+
pytest.mark.broken(
953+
["snowflake"],
954+
reason="snowflake unnests empty arrays to null",
955+
raises=AssertionError,
956+
)
957+
],
958+
),
959+
param(
960+
0,
961+
marks=[
962+
pytest.mark.broken(
963+
["snowflake"],
964+
reason="snowflake unnests empty arrays to null",
965+
raises=AssertionError,
966+
)
967+
],
968+
),
969+
2,
970+
],
971+
)
972+
@pytest.mark.notimpl(
973+
["polars", "flink", "pandas", "dask"], raises=com.OperationNotDefinedError
974+
)
975+
def test_range_single_argument_unnest(con, n):
976+
expr = ibis.range(n).unnest()
977+
result = con.execute(expr)
978+
tm.assert_series_equal(
979+
result,
980+
pd.Series(list(range(n)), dtype=result.dtype, name=expr.get_name()),
981+
check_index=False,
982+
)
983+
984+
985+
@pytest.mark.parametrize(
986+
"step",
987+
[
988+
param(
989+
-2,
990+
marks=[
991+
pytest.mark.notyet(
992+
["polars"],
993+
reason="panic upstream",
994+
raises=PolarsInvalidOperationError,
995+
)
996+
],
997+
),
998+
param(
999+
-1,
1000+
marks=[
1001+
pytest.mark.notyet(
1002+
["polars"],
1003+
reason="panic upstream",
1004+
raises=PolarsInvalidOperationError,
1005+
)
1006+
],
1007+
),
1008+
1,
1009+
2,
1010+
],
1011+
)
1012+
@pytest.mark.parametrize(
1013+
("start", "stop"),
1014+
[
1015+
param(-7, -7),
1016+
param(-7, 0),
1017+
param(-7, 7),
1018+
param(0, -7, marks=[polars_overflow]),
1019+
param(0, 0),
1020+
param(0, 7),
1021+
param(7, -7, marks=[polars_overflow]),
1022+
param(7, 0, marks=[polars_overflow]),
1023+
param(7, 7),
1024+
],
1025+
)
1026+
@pytest.mark.notyet(
1027+
["datafusion"],
1028+
reason="range and unnest aren't implemented upstream",
1029+
raises=com.OperationNotDefinedError,
1030+
)
1031+
@pytest.mark.notimpl(["flink", "pandas", "dask"], raises=com.OperationNotDefinedError)
1032+
def test_range_start_stop_step(con, start, stop, step):
1033+
expr = ibis.range(start, stop, step)
1034+
result = con.execute(expr)
1035+
assert list(result) == list(range(start, stop, step))
1036+
1037+
1038+
@pytest.mark.parametrize("stop", [-7, 0, 7])
1039+
@pytest.mark.parametrize("start", [-7, 0, 7])
1040+
@pytest.mark.notyet(
1041+
["clickhouse"], raises=ClickhouseDatabaseError, reason="not supported upstream"
1042+
)
1043+
@pytest.mark.notyet(
1044+
["datafusion"], raises=com.OperationNotDefinedError, reason="not supported upstream"
1045+
)
1046+
@pytest.mark.notimpl(["flink", "pandas", "dask"], raises=com.OperationNotDefinedError)
1047+
def test_range_start_stop_step_zero(con, start, stop):
1048+
expr = ibis.range(start, stop, 0)
1049+
result = con.execute(expr)
1050+
assert list(result) == []

ibis/backends/trino/registry.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,24 @@ def _interval_from_integer(t, op):
350350
return sa.type_coerce(sa.func.parse_duration(arg), INTERVAL)
351351

352352

353+
def _integer_range(t, op):
354+
start = t.translate(op.start)
355+
stop = t.translate(op.stop)
356+
step = t.translate(op.step)
357+
satype = t.get_sqla_type(op.dtype)
358+
# `sequence` doesn't allow arguments that would produce an empty range, so
359+
# check that first
360+
n = sa.func.floor((stop - start) / sa.func.nullif(step, 0))
361+
return if_(
362+
n > 0,
363+
# TODO(cpcloud): revisit using array_remove when my brain is working
364+
sa.func.array_remove(
365+
sa.func.sequence(start, stop, step, type_=satype), stop, type_=satype
366+
),
367+
sa.literal_column("ARRAY[]"),
368+
)
369+
370+
353371
operation_registry.update(
354372
{
355373
# conditional expressions
@@ -547,6 +565,7 @@ def _interval_from_integer(t, op):
547565
ops.IntervalAdd: fixed_arity(operator.add, 2),
548566
ops.IntervalSubtract: fixed_arity(operator.sub, 2),
549567
ops.IntervalFromInteger: _interval_from_integer,
568+
ops.IntegerRange: _integer_range,
550569
}
551570
)
552571

0 commit comments

Comments
 (0)