Skip to content

Commit 75f594d

Browse files
committed
feat(sqlite): support most date/timestamp interval arithmetic
1 parent fe29210 commit 75f594d

File tree

4 files changed

+119
-57
lines changed

4 files changed

+119
-57
lines changed

.github/workflows/ibis-backends.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,10 @@ jobs:
481481
- name: show installed deps
482482
run: poetry run pip list
483483

484+
- name: show version of python-linked sqlite
485+
if: matrix.backend.name == 'sqlite'
486+
run: poetry run python -c 'import sqlite3; print(sqlite3.sqlite_version)'
487+
484488
- name: "run parallel tests: ${{ matrix.backend.name }}"
485489
if: ${{ !matrix.backend.serial }}
486490
run: just ci-check -m ${{ matrix.backend.name }} --numprocesses auto --dist=loadgroup

ibis/backends/conftest.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def is_older_than(module_name, given_version):
151151
# For now, many of our tests don't do this, and we're working to change this situation
152152
# by improving all tests file by file. All files that have already been improved are
153153
# added to this list to prevent regression.
154-
FIlES_WITH_STRICT_EXCEPTION_CHECK = [
154+
FILES_WITH_STRICT_EXCEPTION_CHECK = [
155155
"ibis/backends/tests/test_api.py",
156156
"ibis/backends/tests/test_array.py",
157157
"ibis/backends/tests/test_aggregation.py",
@@ -337,7 +337,7 @@ def _filter_none_from_raises(kwargs):
337337
for marker in item.iter_markers(name="notimpl"):
338338
if backend in marker.args[0]:
339339
if (
340-
item.location[0] in FIlES_WITH_STRICT_EXCEPTION_CHECK
340+
item.location[0] in FILES_WITH_STRICT_EXCEPTION_CHECK
341341
and "raises" not in marker.kwargs.keys()
342342
):
343343
raise ValueError("notimpl requires a raises")
@@ -351,7 +351,7 @@ def _filter_none_from_raises(kwargs):
351351
for marker in item.iter_markers(name="notyet"):
352352
if backend in marker.args[0]:
353353
if (
354-
item.location[0] in FIlES_WITH_STRICT_EXCEPTION_CHECK
354+
item.location[0] in FILES_WITH_STRICT_EXCEPTION_CHECK
355355
and "raises" not in marker.kwargs.keys()
356356
):
357357
raise ValueError("notyet requires a raises")

ibis/backends/sql/compilers/sqlite.py

Lines changed: 58 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import math
4+
import sqlite3
45

56
import sqlglot as sg
67
import sqlglot.expressions as sge
@@ -19,6 +20,8 @@ class SQLiteCompiler(SQLGlotCompiler):
1920

2021
dialect = SQLite
2122
type_mapper = SQLiteType
23+
supports_time_shift_modifiers = sqlite3.sqlite_version_info >= (3, 46, 0)
24+
supports_subsec = sqlite3.sqlite_version_info >= (3, 42, 0)
2225

2326
# We could set `supports_order_by=True` for SQLite >= 3.44.0 (2023-11-01).
2427
agg = AggGen(supports_filter=True)
@@ -53,10 +56,7 @@ class SQLiteCompiler(SQLGlotCompiler):
5356
ops.IntervalSubtract,
5457
ops.IntervalMultiply,
5558
ops.IntervalFloorDivide,
56-
ops.IntervalFromInteger,
5759
ops.TimestampBucket,
58-
ops.TimestampAdd,
59-
ops.TimestampSub,
6060
ops.TimestampDiff,
6161
ops.StringToDate,
6262
ops.StringToTimestamp,
@@ -333,18 +333,65 @@ def visit_TimestampTruncate(self, op, *, arg, unit):
333333
return self._temporal_truncate(self.f.anon.datetime, arg, unit)
334334

335335
def visit_DateArithmetic(self, op, *, left, right):
336-
unit = op.right.dtype.unit
337-
sign = "+" if isinstance(op, ops.DateAdd) else "-"
338-
if unit not in (IntervalUnit.YEAR, IntervalUnit.MONTH, IntervalUnit.DAY):
336+
right = right.this
337+
338+
if (unit := op.right.dtype.unit) in (
339+
IntervalUnit.QUARTER,
340+
IntervalUnit.MICROSECOND,
341+
IntervalUnit.NANOSECOND,
342+
):
339343
raise com.UnsupportedOperationError(
340-
"SQLite does not allow binary op {sign!r} with INTERVAL offset {unit}"
344+
f"SQLite does not support `{unit}` units in temporal arithmetic"
341345
)
342-
if isinstance(op.right, ops.Literal):
343-
return self.f.date(left, f"{sign}{op.right.value} {unit.plural}")
346+
elif unit == IntervalUnit.WEEK:
347+
unit = IntervalUnit.DAY
348+
right *= 7
349+
elif unit == IntervalUnit.MILLISECOND:
350+
# sqlite doesn't allow milliseconds, so divide milliseconds by 1e3 to
351+
# get seconds, and change the unit to seconds
352+
unit = IntervalUnit.SECOND
353+
right /= 1e3
354+
355+
# compute whether we're adding or subtracting an interval
356+
sign = "+" if isinstance(op, (ops.DateAdd, ops.TimestampAdd)) else "-"
357+
358+
modifiers = []
359+
360+
# floor the result if the unit is a year, month, or day to match other
361+
# backend behavior
362+
if unit in (IntervalUnit.YEAR, IntervalUnit.MONTH, IntervalUnit.DAY):
363+
if not self.supports_time_shift_modifiers:
364+
raise com.UnsupportedOperationError(
365+
"SQLite does not support time shift modifiers until version 3.46; "
366+
f"found version {sqlite3.sqlite_version}"
367+
)
368+
modifiers.append("floor")
369+
370+
if isinstance(op, (ops.TimestampAdd, ops.TimestampSub)):
371+
# if the left operand is a timestamp, return as much precision as
372+
# possible
373+
if not self.supports_subsec:
374+
raise com.UnsupportedOperationError(
375+
"SQLite does not support subsecond resolution until version 3.42; "
376+
f"found version {sqlite3.sqlite_version}"
377+
)
378+
func = self.f.datetime
379+
modifiers.append("subsec")
344380
else:
345-
return self.f.date(left, self.f.concat(sign, right, f" {unit.plural}"))
381+
func = self.f.date
382+
383+
return func(
384+
left,
385+
self.f.concat(
386+
sign, self.cast(right, dt.string), " ", unit.singular.lower()
387+
),
388+
*modifiers,
389+
dialect=self.dialect,
390+
)
346391

347-
visit_DateAdd = visit_DateSub = visit_DateArithmetic
392+
visit_TimestampAdd = visit_TimestampSub = visit_DateAdd = visit_DateSub = (
393+
visit_DateArithmetic
394+
)
348395

349396
def visit_DateDiff(self, op, *, left, right):
350397
return self.f.julianday(left) - self.f.julianday(right)

0 commit comments

Comments
 (0)