Skip to content

Commit 79cef68

Browse files
authored
feat(timestamps): add support for timestamp/date +/- intervals for additional backends (#9799)
1 parent de6d988 commit 79cef68

File tree

8 files changed

+103
-96
lines changed

8 files changed

+103
-96
lines changed

ibis/backends/sql/compilers/base.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -942,9 +942,7 @@ def visit_DayOfWeekName(self, op, *, arg):
942942
)
943943

944944
def visit_IntervalFromInteger(self, op, *, arg, unit):
945-
return sge.Interval(
946-
this=sge.convert(arg), unit=sge.Var(this=unit.singular.upper())
947-
)
945+
return sge.Interval(this=arg, unit=self.v[unit.singular.upper()])
948946

949947
### String Instruments
950948
def visit_Strip(self, op, *, arg):

ibis/backends/sql/compilers/mssql.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -86,14 +86,12 @@ class MSSQLCompiler(SQLGlotCompiler):
8686
ops.BitXor,
8787
ops.Covariance,
8888
ops.CountDistinctStar,
89-
ops.DateAdd,
9089
ops.DateDiff,
91-
ops.DateSub,
9290
ops.EndsWith,
9391
ops.IntervalAdd,
94-
ops.IntervalFromInteger,
95-
ops.IntervalMultiply,
9692
ops.IntervalSubtract,
93+
ops.IntervalMultiply,
94+
ops.IntervalFloorDivide,
9795
ops.IsInf,
9896
ops.IsNan,
9997
ops.LPad,
@@ -115,9 +113,7 @@ class MSSQLCompiler(SQLGlotCompiler):
115113
ops.StringToDate,
116114
ops.StringToTimestamp,
117115
ops.StructColumn,
118-
ops.TimestampAdd,
119116
ops.TimestampDiff,
120-
ops.TimestampSub,
121117
ops.Unnest,
122118
)
123119

@@ -404,6 +400,8 @@ def visit_Cast(self, op, *, arg, to):
404400
return arg
405401
elif from_.is_integer() and to.is_timestamp():
406402
return self.f.dateadd(self.v.s, arg, "1970-01-01 00:00:00")
403+
elif from_.is_integer() and to.is_interval():
404+
return sge.Interval(this=arg, unit=self.v[to.unit.singular])
407405
return super().visit_Cast(op, arg=arg, to=to)
408406

409407
def visit_Sum(self, op, *, arg, where):
@@ -500,5 +498,18 @@ def visit_Select(self, op, *, parent, selections, predicates, qualified, sort_ke
500498

501499
return result
502500

501+
def visit_TimestampAdd(self, op, *, left, right):
502+
return self.f.dateadd(
503+
right.unit, self.cast(right.this, dt.int64), left, dialect=self.dialect
504+
)
505+
506+
def visit_TimestampSub(self, op, *, left, right):
507+
return self.f.dateadd(
508+
right.unit, -self.cast(right.this, dt.int64), left, dialect=self.dialect
509+
)
510+
511+
visit_DateAdd = visit_TimestampAdd
512+
visit_DateSub = visit_TimestampSub
513+
503514

504515
compiler = MSSQLCompiler()

ibis/backends/sql/compilers/mysql.py

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -120,22 +120,16 @@ def visit_Cast(self, op, *, arg, to):
120120
# for TEXT (except when casting of course!)
121121
return arg
122122
elif from_.is_integer() and to.is_interval():
123-
return self.visit_IntervalFromInteger(
124-
ops.IntervalFromInteger(op.arg, unit=to.unit), arg=arg, unit=to.unit
125-
)
123+
return sge.Interval(this=arg, unit=self.v[to.unit.singular.upper()])
126124
elif from_.is_integer() and to.is_timestamp():
127125
return self.f.from_unixtime(arg)
128126
return super().visit_Cast(op, arg=arg, to=to)
129127

130128
def visit_TimestampDiff(self, op, *, left, right):
131-
return self.f.timestampdiff(
132-
sge.Var(this="SECOND"), right, left, dialect=self.dialect
133-
)
129+
return self.f.timestampdiff(self.v.SECOND, right, left, dialect=self.dialect)
134130

135131
def visit_DateDiff(self, op, *, left, right):
136-
return self.f.timestampdiff(
137-
sge.Var(this="DAY"), right, left, dialect=self.dialect
138-
)
132+
return self.f.timestampdiff(self.v.DAY, right, left, dialect=self.dialect)
139133

140134
def visit_ApproxCountDistinct(self, op, *, arg, where):
141135
if where is not None:
@@ -317,16 +311,16 @@ def visit_DateTimestampTruncate(self, op, *, arg, unit):
317311

318312
def visit_DateTimeDelta(self, op, *, left, right, part):
319313
return self.f.timestampdiff(
320-
sge.Var(this=part.this), right, left, dialect=self.dialect
314+
self.v[part.this], right, left, dialect=self.dialect
321315
)
322316

323317
visit_TimeDelta = visit_DateDelta = visit_DateTimeDelta
324318

325319
def visit_ExtractMillisecond(self, op, *, arg):
326-
return self.f.floor(self.f.extract(sge.Var(this="microsecond"), arg) / 1_000)
320+
return self.f.floor(self.f.extract(self.v.microsecond, arg) / 1_000)
327321

328322
def visit_ExtractMicrosecond(self, op, *, arg):
329-
return self.f.floor(self.f.extract(sge.Var(this="microsecond"), arg))
323+
return self.f.floor(self.f.extract(self.v.microsecond, arg))
330324

331325
def visit_Strip(self, op, *, arg):
332326
return self.visit_LRStrip(op, arg=arg, position="BOTH")
@@ -337,14 +331,9 @@ def visit_LStrip(self, op, *, arg):
337331
def visit_RStrip(self, op, *, arg):
338332
return self.visit_LRStrip(op, arg=arg, position="TRAILING")
339333

340-
def visit_IntervalFromInteger(self, op, *, arg, unit):
341-
return sge.Interval(this=arg, unit=sge.Var(this=op.resolution.upper()))
342-
343334
def visit_TimestampAdd(self, op, *, left, right):
344335
if op.right.dtype.unit.short == "ms":
345-
right = sge.Interval(
346-
this=right.this * 1_000, unit=sge.Var(this="MICROSECOND")
347-
)
336+
right = sge.Interval(this=right.this * 1_000, unit=self.v.MICROSECOND)
348337
return self.f.date_add(left, right, dialect=self.dialect)
349338

350339
def visit_UnwrapJSONString(self, op, *, arg):

ibis/backends/sql/compilers/oracle.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@ class OracleCompiler(SQLGlotCompiler):
6767
ops.TimestampDelta,
6868
ops.TimestampFromYMDHMS,
6969
ops.TimeFromHMS,
70-
ops.IntervalFromInteger,
7170
ops.DayOfWeekIndex,
7271
ops.DayOfWeekName,
7372
ops.DateDiff,
@@ -138,30 +137,37 @@ def visit_Literal(self, op, *, value, dtype):
138137
elif dtype.is_uuid():
139138
return sge.convert(str(value))
140139
elif dtype.is_interval():
141-
if dtype.unit.short in ("Y", "M"):
142-
return self.f.numtoyminterval(value, dtype.unit.name)
143-
elif dtype.unit.short in ("D", "h", "m", "s"):
144-
return self.f.numtodsinterval(value, dtype.unit.name)
145-
else:
146-
raise com.UnsupportedOperationError(
147-
f"Intervals with precision {dtype.unit.name} not supported in Oracle."
148-
)
140+
return self._value_to_interval(value, dtype.unit)
149141

150142
return super().visit_Literal(op, value=value, dtype=dtype)
151143

144+
def _value_to_interval(self, arg, unit):
145+
short = unit.short
146+
147+
if short in ("Y", "M"):
148+
return self.f.numtoyminterval(arg, unit.singular)
149+
elif short in ("D", "h", "m", "s"):
150+
return self.f.numtodsinterval(arg, unit.singular)
151+
elif short == "ms":
152+
return self.f.numtodsinterval(arg / 1e3, "second")
153+
elif short in "us":
154+
return self.f.numtodsinterval(arg / 1e6, "second")
155+
elif short in "ns":
156+
return self.f.numtodsinterval(arg / 1e9, "second")
157+
else:
158+
raise com.UnsupportedArgumentError(
159+
f"Interval {unit.name} not supported by Oracle"
160+
)
161+
152162
def visit_Cast(self, op, *, arg, to):
153-
if to.is_interval():
163+
from_ = op.arg.dtype
164+
if from_.is_numeric() and to.is_interval():
154165
# CASTing to an INTERVAL in Oracle requires specifying digits of
155166
# precision that are a pain. There are two helper functions that
156167
# should be used instead.
157-
if to.unit.short in ("D", "h", "m", "s"):
158-
return self.f.numtodsinterval(arg, to.unit.name)
159-
elif to.unit.short in ("Y", "M"):
160-
return self.f.numtoyminterval(arg, to.unit.name)
161-
else:
162-
raise com.UnsupportedArgumentError(
163-
f"Interval {to.unit.name} not supported by Oracle"
164-
)
168+
return self._value_to_interval(arg, to.unit)
169+
elif from_.is_string() and to.is_date():
170+
return self.f.to_date(arg, "FXYYYY-MM-DD")
165171
return self.cast(arg, to)
166172

167173
def visit_Limit(self, op, *, parent, n, offset):
@@ -457,5 +463,8 @@ def visit_GroupConcat(self, op, *, arg, where, sep, order_by):
457463

458464
return out
459465

466+
def visit_IntervalFromInteger(self, op, *, arg, unit):
467+
return self._value_to_interval(arg, unit)
468+
460469

461470
compiler = OracleCompiler()

ibis/backends/sql/compilers/snowflake.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -378,9 +378,6 @@ def visit_TimestampSub(self, op, *, left, right):
378378
visit_DateAdd = visit_TimestampAdd
379379
visit_DateSub = visit_TimestampSub
380380

381-
def visit_IntervalFromInteger(self, op, *, arg, unit):
382-
return sge.Interval(this=arg, unit=self.v[unit.name])
383-
384381
def visit_IntegerRange(self, op, *, start, stop, step):
385382
return self.if_(
386383
step.neq(0), self.f.array_generate_range(start, stop, step), self.f.array()

ibis/backends/tests/errors.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,9 @@
133133

134134
try:
135135
from oracledb.exceptions import DatabaseError as OracleDatabaseError
136+
from oracledb.exceptions import InterfaceError as OracleInterfaceError
136137
except ImportError:
137-
OracleDatabaseError = None
138+
OracleDatabaseError = OracleInterfaceError = None
138139

139140
try:
140141
from pyodbc import DataError as PyODBCDataError

ibis/backends/tests/test_param.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def test_scalar_param(backend, alltypes, df, value, dtype, col):
157157
["2009-01-20", datetime.date(2009, 1, 20), datetime.datetime(2009, 1, 20)],
158158
ids=["string", "date", "datetime"],
159159
)
160-
@pytest.mark.notimpl(["druid", "oracle"])
160+
@pytest.mark.notimpl(["druid"])
161161
def test_scalar_param_date(backend, alltypes, value):
162162
param = ibis.param("date")
163163
ds_col = alltypes.date_string_col

0 commit comments

Comments
 (0)