Commit c567fe0 (parent f20e34e)

feat(api): add timestamp range

16 files changed: +497 -43 lines
ibis/backends/bigquery/registry.py

Lines changed: 58 additions & 9 deletions
@@ -776,14 +776,62 @@ def _group_concat(translator, op):
     return f"STRING_AGG({arg}, {sep})"
 
 
-def _integer_range(translator, op):
-    start = translator.translate(op.start)
-    stop = translator.translate(op.stop)
-    step = translator.translate(op.step)
-    n = f"FLOOR(({stop} - {start}) / NULLIF({step}, 0))"
-    gen_array = f"GENERATE_ARRAY({start}, {stop}, {step})"
-    inner = f"SELECT x FROM UNNEST({gen_array}) x WHERE x <> {stop}"
-    return f"IF({n} > 0, ARRAY({inner}), [])"
+def _zero(dtype):
+    if dtype.is_interval():
+        return "MAKE_INTERVAL()"
+    return "0"
+
+
+def _sign(value, dtype):
+    if dtype.is_interval():
+        zero = _zero(dtype)
+        return f"""\
+CASE
+  WHEN {value} < {zero} THEN -1
+  WHEN {value} = {zero} THEN 0
+  WHEN {value} > {zero} THEN 1
+  ELSE NULL
+END"""
+    return f"SIGN({value})"
+
+
+def _nullifzero(step, zero, step_dtype):
+    if step_dtype.is_interval():
+        return f"IF({step} = {zero}, NULL, {step})"
+    return f"NULLIF({step}, {zero})"
+
+
+def _make_range(func):
+    def _range(translator, op):
+        start = translator.translate(op.start)
+        stop = translator.translate(op.stop)
+        step = translator.translate(op.step)
+
+        step_dtype = op.step.dtype
+        step_sign = _sign(step, step_dtype)
+        delta_sign = _sign(step, step_dtype)
+        zero = _zero(step_dtype)
+        nullifzero = _nullifzero(step, zero, step_dtype)
+
+        condition = f"{nullifzero} IS NOT NULL AND {step_sign} = {delta_sign}"
+        gen_array = f"{func}({start}, {stop}, {step})"
+        inner = f"SELECT x FROM UNNEST({gen_array}) x WHERE x <> {stop}"
+        return f"IF({condition}, ARRAY({inner}), [])"
+
+    return _range
+
+
+def _timestamp_range(translator, op):
+    start = op.start
+    stop = op.stop
+
+    if start.dtype.timezone is None or stop.dtype.timezone is None:
+        raise com.IbisTypeError(
+            "Timestamps without timezone values are not supported when generating timestamp ranges"
+        )
+
+    rule = _make_range("GENERATE_TIMESTAMP_ARRAY")
+    return rule(translator, op)
 
 
 OPERATION_REGISTRY = {
@@ -949,7 +997,8 @@ def _integer_range(translator, op):
     ops.TimeDelta: _time_delta,
     ops.DateDelta: _date_delta,
     ops.TimestampDelta: _timestamp_delta,
-    ops.IntegerRange: _integer_range,
+    ops.IntegerRange: _make_range("GENERATE_ARRAY"),
+    ops.TimestampRange: _timestamp_range,
 }
 
 _invalid_operations = {
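
A rough illustration of the SQL string the new BigQuery rule assembles (a hand-built sketch, not compiler output; start_ts/stop_ts stand in for the translated start/stop expressions, and the NULLIF/sign guard from the code above is abbreviated to `condition`):

# Hand-assembled sketch of the core string _make_range("GENERATE_TIMESTAMP_ARRAY") produces.
start, stop, step = "start_ts", "stop_ts", "INTERVAL 1 DAY"
condition = f"IF({step} = MAKE_INTERVAL(), NULL, {step}) IS NOT NULL"  # plus the sign check
gen_array = f"GENERATE_TIMESTAMP_ARRAY({start}, {stop}, {step})"
inner = f"SELECT x FROM UNNEST({gen_array}) x WHERE x <> {stop}"
print(f"IF({condition}, ARRAY({inner}), [])")

The trailing WHERE x <> {stop} is what keeps the range half-open, since GENERATE_TIMESTAMP_ARRAY includes the stop bound when it lands exactly on a step.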

ibis/backends/clickhouse/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -3,6 +3,7 @@
 import ast
 import atexit
 import glob
+import warnings
 from contextlib import closing, suppress
 from functools import partial
 from typing import TYPE_CHECKING, Any, Literal
@@ -169,6 +170,11 @@ def do_connect(
             compress=compression,
             **kwargs,
         )
+        try:
+            with closing(self.raw_sql("SET session_timezone = 'UTC'")):
+                pass
+        except Exception as e:  # noqa: BLE001
+            warnings.warn(f"Could not set timezone to UTC: {e}", category=UserWarning)
        self._temp_views = set()
 
    @property

ibis/backends/clickhouse/compiler/values.py

Lines changed: 27 additions & 0 deletions
@@ -1017,3 +1017,30 @@ def _agg_udf(op, *, where, **kw) -> str:
 @translate_val.register(ops.TimestampDelta)
 def _delta(op, *, part, left, right, **_):
     return sg.exp.DateDiff(this=left, expression=right, unit=part)
+
+
+@translate_val.register(ops.TimestampRange)
+def _timestamp_range(op, *, start, stop, step, **_):
+    unit = op.step.dtype.unit.name.lower()
+
+    if not isinstance(op.step, ops.Literal):
+        raise com.UnsupportedOperationError(
+            "ClickHouse doesn't support non-literal step values"
+        )
+
+    step_value = op.step.value
+
+    offset = sg.to_identifier("offset")
+
+    # e.g., offset -> dateAdd(DAY, offset, start)
+    func = sg.exp.Lambda(
+        this=F.dateAdd(sg.to_identifier(unit), offset, start), expressions=[offset]
+    )
+
+    if step_value == 0:
+        return F.array()
+
+    result = F.arrayMap(
+        func, F.range(0, F.timestampDiff(unit, start, stop), step_value)
+    )
+    return result
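
For a literal one-day step, the rule lowers to an arrayMap over integer offsets; a minimal hand-written sketch of the resulting ClickHouse expression shape (start_ts/stop_ts are placeholders for the translated operands, not code from the commit):

# Offsets come from range()/timestampDiff() and are mapped back onto start with dateAdd().
start, stop, unit, step_value = "start_ts", "stop_ts", "day", 1
print(
    f"arrayMap(offset -> dateAdd({unit}, offset, {start}), "
    f"range(0, timestampDiff({unit}, {start}, {stop}), {step_value}))"
)

Because range() excludes its upper bound, the stop timestamp itself is never produced, matching the half-open behavior of the other backends.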

ibis/backends/duckdb/registry.py

Lines changed: 3 additions & 1 deletion
@@ -183,7 +183,7 @@ def _literal(t, op):
     sqla_type = t.get_sqla_type(dtype)
 
     if dtype.is_interval():
-        return sa.literal_column(f"INTERVAL '{value} {dtype.resolution}'")
+        return getattr(sa.func, f"to_{dtype.unit.plural}")(value)
     elif dtype.is_array():
         values = value.tolist() if isinstance(value, np.ndarray) else value
         return sa.cast(sa.func.list_value(*values), sqla_type)
@@ -550,6 +550,8 @@ def _array_remove(t, op):
         ops.GeoWithin: fixed_arity(sa.func.ST_Within, 2),
         ops.GeoX: unary(sa.func.ST_X),
         ops.GeoY: unary(sa.func.ST_Y),
+        # other ops
+        ops.TimestampRange: fixed_arity(sa.func.range, 3),
     }
 )
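
DuckDB has a native range() over timestamps with an exclusive upper bound, which is why a plain fixed_arity mapping suffices here; the interval-literal hunk swaps the INTERVAL string for DuckDB's to_days/to_hours/... constructors. A quick way to see the primitive in action (assumes the duckdb Python package; not part of this commit):

import duckdb

# DuckDB's range() over timestamps is right-exclusive, so the stop bound is
# not included: the same half-open semantics as the other backends.
print(
    duckdb.sql(
        "SELECT range(TIMESTAMP '2023-01-01', TIMESTAMP '2023-01-04', INTERVAL 1 DAY)"
    ).fetchall()
)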

ibis/backends/polars/compiler.py

Lines changed: 17 additions & 1 deletion
@@ -1205,7 +1205,9 @@ def execute_agg_udf(op, **kw):
 @translate.register(ops.IntegerRange)
 def execute_integer_range(op, **kw):
     if not isinstance(op.step, ops.Literal):
-        raise NotImplementedError("Dynamic step not supported by Polars")
+        raise com.UnsupportedOperationError(
+            "Dynamic integer step not supported by Polars"
+        )
     step = op.step.value
 
     dtype = dtype_to_polars(op.dtype)
@@ -1217,3 +1219,17 @@ def execute_integer_range(op, **kw):
     start = translate(op.start, **kw)
     stop = translate(op.stop, **kw)
     return pl.int_ranges(start, stop, step, dtype=dtype)
+
+
+@translate.register(ops.TimestampRange)
+def execute_timestamp_range(op, **kw):
+    if not isinstance(op.step, ops.Literal):
+        raise com.UnsupportedOperationError(
+            "Dynamic interval step not supported by Polars"
+        )
+    step = op.step.value
+    unit = op.step.dtype.unit.value
+
+    start = translate(op.start, **kw)
+    stop = translate(op.stop, **kw)
+    return pl.datetime_ranges(start, stop, f"{step}{unit}", closed="left")
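
A small standalone illustration (assumes the polars package; not from the commit) of the primitive the new rule lowers to: pl.datetime_ranges with a step string such as "1d" and closed="left", which excludes the stop point.

from datetime import datetime

import polars as pl

df = pl.DataFrame(
    {
        "start": [datetime(2023, 1, 1)],
        "stop": [datetime(2023, 1, 4)],
    }
)
# closed="left" keeps the range half-open: 2023-01-04 is not included
print(df.select(pl.datetime_ranges("start", "stop", "1d", closed="left")))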

ibis/backends/postgres/registry.py

Lines changed: 25 additions & 7 deletions
@@ -618,19 +618,36 @@ def _array_filter(t, op):
     )
 
 
-def _integer_range(t, op):
+def zero_value(dtype):
+    if dtype.is_interval():
+        return sa.func.make_interval()
+    return 0
+
+
+def interval_sign(v):
+    zero = sa.func.make_interval()
+    return sa.case((v == zero, 0), (v < zero, -1), (v > zero, 1))
+
+
+def _sign(value, dtype):
+    if dtype.is_interval():
+        return interval_sign(value)
+    return sa.func.sign(value)
+
+
+def _range(t, op):
     start = t.translate(op.start)
     stop = t.translate(op.stop)
     step = t.translate(op.step)
     satype = t.get_sqla_type(op.dtype)
-    # `sequence` doesn't allow arguments that would produce an empty range, so
-    # check that first
-    n = sa.func.floor((stop - start) / sa.func.nullif(step, 0))
     seq = sa.func.generate_series(start, stop, step, type_=satype)
+    zero = zero_value(op.step.dtype)
     return sa.case(
-        # TODO(cpcloud): revisit using array_remove when my brain is working
         (
-            n > 0,
+            sa.and_(
+                sa.func.nullif(step, zero).is_not(None),
+                _sign(step, op.step.dtype) == _sign(stop - start, op.step.dtype),
+            ),
            sa.func.array_remove(
                sa.func.array(sa.select(seq).scalar_subquery()), stop, type_=satype
            ),
@@ -839,6 +856,7 @@ def _integer_range(t, op):
         ops.ArrayPosition: fixed_arity(_array_position, 2),
         ops.ArrayMap: _array_map,
         ops.ArrayFilter: _array_filter,
-        ops.IntegerRange: _integer_range,
+        ops.IntegerRange: _range,
+        ops.TimestampRange: _range,
     }
 )
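
Postgres has no SIGN() for intervals, which is what interval_sign works around: it compares the value against make_interval(), the zero-length interval. A self-contained sketch of how that renders (assumes SQLAlchemy; not code from the commit):

import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

step = sa.column("step")
zero = sa.func.make_interval()
sign = sa.case((step == zero, 0), (step < zero, -1), (step > zero, 1))
print(
    sign.compile(
        dialect=postgresql.dialect(), compile_kwargs={"literal_binds": True}
    )
)
# CASE WHEN step = make_interval() THEN 0 WHEN step < make_interval() THEN -1
# WHEN step > make_interval() THEN 1 END

generate_series is inclusive of the stop value when it lands on a step, so _range strips it with array_remove to keep the result half-open.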

ibis/backends/pyspark/compiler.py

Lines changed: 36 additions & 9 deletions
@@ -2082,18 +2082,45 @@ def compile_flatten(t, op, **kwargs):
     return F.flatten(t.translate(op.arg, **kwargs))
 
 
+def _zero_value(dtype):
+    if dtype.is_interval():
+        return F.expr(f"INTERVAL 0 {dtype.resolution}")
+    return F.lit(0)
+
+
+def _build_sequence(start, stop, step, zero):
+    seq = F.sequence(start, stop, step)
+    length = F.size(seq)
+    last_element = F.element_at(seq, length)
+    # slice off the last element if we'd be inclusive on the right
+    seq = F.when(last_element == stop, F.slice(seq, 1, length - 1)).otherwise(seq)
+    return F.when(
+        (step != zero) & (F.signum(step) == F.signum(stop - start)), seq
+    ).otherwise(F.array())
+
+
 @compiles(ops.IntegerRange)
 def compile_integer_range(t, op, **kwargs):
     start = t.translate(op.start, **kwargs)
     stop = t.translate(op.stop, **kwargs)
     step = t.translate(op.step, **kwargs)
 
-    denom = F.when(step == 0, F.lit(None)).otherwise(step)
-    n = F.floor((stop - start) / denom)
-    seq = F.sequence(start, stop, step)
-    seq = F.slice(
-        seq,
-        1,
-        F.size(seq) - F.when(F.element_at(seq, F.size(seq)) == stop, 1).otherwise(0),
-    )
-    return F.when(n > 0, seq).otherwise(F.array())
+    return _build_sequence(start, stop, step, _zero_value(op.step.dtype))
+
+
+@compiles(ops.TimestampRange)
+def compile_timestamp_range(t, op, **kwargs):
+    start = t.translate(op.start, **kwargs)
+    stop = t.translate(op.stop, **kwargs)
+
+    if not isinstance(op.step, ops.Literal):
+        raise com.UnsupportedOperationError(
+            "`step` argument of timestamp range must be a literal"
+        )
+
+    step_value = op.step.value
+    unit = op.step.dtype.resolution
+
+    step = F.expr(f"INTERVAL {step_value} {unit}")
+
+    return _build_sequence(start, stop, step, _zero_value(op.step.dtype))
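
The _build_sequence helper captures the semantics all of these backends aim for: a half-open [start, stop) range that is empty when the step is zero or points away from stop. A pure-Python reference model (illustrative only, not code from the commit):

from datetime import datetime, timedelta


def timestamp_range(start: datetime, stop: datetime, step: timedelta) -> list[datetime]:
    zero = timedelta(0)
    # empty when the step is zero or its sign disagrees with (stop - start)
    if step == zero or (step > zero) != (stop - start > zero):
        return []
    out, current = [], start
    while (current < stop) if step > zero else (current > stop):
        out.append(current)
        current += step
    return out


print(timestamp_range(datetime(2023, 1, 1), datetime(2023, 1, 4), timedelta(days=1)))
# [datetime.datetime(2023, 1, 1, 0, 0), datetime.datetime(2023, 1, 2, 0, 0),
#  datetime.datetime(2023, 1, 3, 0, 0)]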

ibis/backends/snowflake/converter.py

Lines changed: 26 additions & 1 deletion
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import datetime
 from typing import TYPE_CHECKING
 
 from ibis.formats.pandas import PandasData
@@ -18,7 +19,31 @@ def convert_JSON(s, dtype, pandas_type):
         converter = SnowflakePandasData.convert_JSON_element(dtype)
         return s.map(converter, na_action="ignore").astype("object")
 
-    convert_Struct = convert_Array = convert_Map = convert_JSON
+    convert_Struct = convert_Map = convert_JSON
+
+    @staticmethod
+    def get_element_converter(dtype):
+        funcgen = getattr(
+            SnowflakePandasData,
+            f"convert_{type(dtype).__name__}_element",
+            lambda _: lambda x: x,
+        )
+        return funcgen(dtype)
+
+    def convert_Timestamp_element(dtype):
+        return lambda values: list(map(datetime.datetime.fromisoformat, values))
+
+    def convert_Date_element(dtype):
+        return lambda values: list(map(datetime.date.fromisoformat, values))
+
+    def convert_Time_element(dtype):
+        return lambda values: list(map(datetime.time.fromisoformat, values))
+
+    @staticmethod
+    def convert_Array(s, dtype, pandas_type):
+        raw_json_objects = SnowflakePandasData.convert_JSON(s, dtype, pandas_type)
+        converter = SnowflakePandasData.get_element_converter(dtype.value_type)
+        return raw_json_objects.map(converter, na_action="ignore")
 
 
 class SnowflakePyArrowData(PyArrowData):
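
Snowflake hands array elements back as ISO-8601 strings, so the new element converters rebuild Python temporal objects per row. A tiny illustration of what convert_Timestamp_element produces (not from the commit):

import datetime

convert = lambda values: list(map(datetime.datetime.fromisoformat, values))
print(convert(["2023-01-01T00:00:00+00:00", "2023-01-02T00:00:00+00:00"]))
# [datetime.datetime(2023, 1, 1, 0, 0, tzinfo=datetime.timezone.utc),
#  datetime.datetime(2023, 1, 2, 0, 0, tzinfo=datetime.timezone.utc)]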

ibis/backends/snowflake/registry.py

Lines changed: 60 additions & 0 deletions
@@ -280,6 +280,65 @@ def _timestamp_bucket(t, op):
     )
 
 
+class _flatten(sa.sql.functions.GenericFunction):
+    def __init__(self, arg, *, type: sa.types.TypeEngine) -> None:
+        super().__init__(arg)
+        self.type = sa.sql.sqltypes.TableValueType(
+            sa.Column("index", sa.BIGINT()), sa.Column("value", type)
+        )
+
+
+@compiles(_flatten, "snowflake")
+def compiles_flatten(element, compiler, **kw):
+    (arg,) = element.clauses.clauses
+    return f"TABLE(FLATTEN(INPUT => {compiler.process(arg, **kw)}, MODE => 'ARRAY'))"
+
+
+def _timestamp_range(t, op):
+    if not isinstance(op.step, ops.Literal):
+        raise com.UnsupportedOperationError("`step` argument must be a literal")
+
+    start = t.translate(op.start)
+    stop = t.translate(op.stop)
+
+    unit = op.step.dtype.unit.name.lower()
+    step = op.step.value
+
+    value_type = op.dtype.value_type
+
+    f = _flatten(
+        sa.func.array_generate_range(0, sa.func.datediff(unit, start, stop), step),
+        type=t.get_sqla_type(op.start.dtype),
+    ).alias("f")
+    return sa.func.iff(
+        step != 0,
+        sa.select(
+            sa.func.array_agg(
+                sa.func.replace(
+                    # conversion to varchar is necessary to control
+                    # the timestamp format
+                    #
+                    # otherwise, since timestamps in arrays become strings
+                    # anyway due to lack of parameterized type support in
+                    # Snowflake the format depends on a session parameter
+                    sa.func.to_varchar(
+                        sa.func.dateadd(unit, f.c.value, start),
+                        'YYYY-MM-DD"T"HH24:MI:SS.FF6'
+                        + (value_type.timezone is not None) * "TZH:TZM",
+                    ),
+                    # timezones are always hour:minute offsets from UTC, not
+                    # named, so replacing "Z" shouldn't be an issue
+                    "Z",
+                    "+00:00",
+                ),
+            )
+        )
+        .select_from(f)
+        .scalar_subquery(),
+        sa.func.array_construct(),
+    )
+
+
 _TIMESTAMP_UNITS_TO_SCALE = {"s": 0, "ms": 3, "us": 6, "ns": 9}
 
 _SF_POS_INF = sa.func.to_double("Inf")
@@ -504,6 +563,7 @@ def _timestamp_bucket(t, op):
            ),
            3,
        ),
+    ops.TimestampRange: _timestamp_range,
 }
)
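
Roughly, the Snowflake rule produces SQL of the following shape (a hand-written sketch with a one-second unit, a timezone-aware element type, and placeholder identifiers; not actual compiler output): offsets come from ARRAY_GENERATE_RANGE, are flattened into rows, added back onto the start with DATEADD, and formatted explicitly so the strings stored in the returned array are stable ISO-8601 timestamps.

# Hand-written sketch of the generated SQL shape; start_ts, stop_ts and step
# stand in for the translated operands and the literal step value.
sketch = """
IFF(
  step <> 0,
  (
    SELECT ARRAY_AGG(
      REPLACE(
        TO_VARCHAR(DATEADD(second, f.value, start_ts), 'YYYY-MM-DD"T"HH24:MI:SS.FF6TZH:TZM'),
        'Z', '+00:00'
      )
    )
    FROM TABLE(FLATTEN(INPUT => ARRAY_GENERATE_RANGE(0, DATEDIFF(second, start_ts, stop_ts), step),
                       MODE => 'ARRAY')) f
  ),
  ARRAY_CONSTRUCT()
)
"""
print(sketch)

ARRAY_GENERATE_RANGE excludes its upper bound, so the result is half-open like the other backends.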
