Skip to content

Commit 3ce4f2a

Browse files
cpcloudkszucs
authored andcommitted
feat(sqlite): implement date_truncate
1 parent d630a77 commit 3ce4f2a

File tree

6 files changed

+72
-42
lines changed

6 files changed

+72
-42
lines changed

ibis/backends/base/sql/alchemy/registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,7 @@ def _sort_key(t, expr):
496496
ops.FloorDivide: _floor_divide,
497497
# other
498498
ops.SortKey: _sort_key,
499+
ops.Date: unary(lambda arg: sa.cast(arg, sa.DATE)),
499500
}
500501

501502

ibis/backends/mysql/registry.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,6 @@ def _day_of_week_name(t, expr):
240240
ops.Round: _round,
241241
ops.RandomScalar: _random,
242242
# dates and times
243-
ops.Date: unary(sa.func.date),
244243
ops.DateAdd: infix_op('+'),
245244
ops.DateSub: infix_op('-'),
246245
ops.DateDiff: fixed_arity(sa.func.datediff, 2),

ibis/backends/postgres/registry.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -119,17 +119,12 @@ def _timestamp_truncate(t, expr):
119119

120120

121121
def _interval_from_integer(t, expr):
122-
arg, unit = expr.op().args
123-
sa_arg = t.translate(arg)
122+
op = expr.op()
123+
sa_arg = t.translate(op.arg)
124124
interval = sa.text(f"INTERVAL '1 {expr.type().resolution}'")
125125
return sa_arg * interval
126126

127127

128-
def _timestamp_add(t, expr):
129-
sa_args = list(map(t.translate, expr.op().args))
130-
return sa_args[0] + sa_args[1]
131-
132-
133128
def _is_nan(t, expr):
134129
(arg,) = expr.op().args
135130
sa_arg = t.translate(arg)
@@ -678,7 +673,6 @@ def _day_of_week_name(t, expr):
678673
ops.Round: _round,
679674
ops.Modulus: _mod,
680675
# dates and times
681-
ops.Date: unary(lambda x: sa.cast(x, sa.Date)),
682676
ops.DateTruncate: _timestamp_truncate,
683677
ops.TimestampTruncate: _timestamp_truncate,
684678
ops.IntervalFromInteger: _interval_from_integer,

ibis/backends/sqlite/registry.py

Lines changed: 50 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import sqlalchemy as sa
22
import toolz
3+
from multipledispatch import Dispatcher
34

45
import ibis
56
import ibis.common.exceptions as com
67
import ibis.expr.datatypes as dt
78
import ibis.expr.operations as ops
89
import ibis.expr.types as ir
910
from ibis.backends.base.sql.alchemy import (
11+
AlchemyExprTranslator,
1012
fixed_arity,
1113
sqlalchemy_operation_registry,
1214
sqlalchemy_window_functions_registry,
@@ -19,32 +21,55 @@
1921
operation_registry.update(sqlalchemy_window_functions_registry)
2022

2123

22-
def _cast(t, expr):
23-
# It's not all fun and games with SQLite
24+
sqlite_cast = Dispatcher("sqlite_cast")
25+
26+
27+
@sqlite_cast.register(AlchemyExprTranslator, ir.IntegerValue, dt.Timestamp)
28+
def _unixepoch(t, arg, _):
29+
return sa.func.datetime(t.translate(arg), "unixepoch")
30+
31+
32+
@sqlite_cast.register(AlchemyExprTranslator, ir.StringValue, dt.Timestamp)
33+
def _string_to_timestamp(t, arg, _):
34+
return sa.func.strftime('%Y-%m-%d %H:%M:%f', t.translate(arg))
35+
36+
37+
@sqlite_cast.register(AlchemyExprTranslator, ir.IntegerValue, dt.Date)
38+
def _integer_to_date(t, arg, _):
39+
return sa.func.date(sa.func.datetime(t.translate(arg), "unixepoch"))
40+
41+
42+
@sqlite_cast.register(
43+
AlchemyExprTranslator,
44+
(ir.StringValue, ir.TimestampValue),
45+
dt.Date,
46+
)
47+
def _string_or_timestamp_to_date(t, arg, _):
48+
return sa.func.date(t.translate(arg))
49+
2450

51+
@sqlite_cast.register(
52+
AlchemyExprTranslator,
53+
ir.ValueExpr,
54+
(dt.Date, dt.Timestamp),
55+
)
56+
def _value_to_temporal(t, arg, _):
57+
raise com.UnsupportedOperationError(type(arg))
58+
59+
60+
@sqlite_cast.register(AlchemyExprTranslator, ir.CategoryValue, dt.Int32)
61+
def _category_to_int(t, arg, _):
62+
return t.translate(arg)
63+
64+
65+
@sqlite_cast.register(AlchemyExprTranslator, ir.ValueExpr, dt.DataType)
66+
def _default_cast_impl(t, arg, target_type):
67+
return sa.cast(t.translate(arg), t.get_sqla_type(target_type))
68+
69+
70+
def _cast(t, expr):
2571
op = expr.op()
26-
arg, target_type = op.args
27-
sa_arg = t.translate(arg)
28-
sa_type = t.get_sqla_type(target_type)
29-
30-
if isinstance(target_type, dt.Timestamp):
31-
if isinstance(arg, ir.IntegerValue):
32-
return sa.func.datetime(sa_arg, 'unixepoch')
33-
elif isinstance(arg, ir.StringValue):
34-
return sa.func.strftime('%Y-%m-%d %H:%M:%f', sa_arg)
35-
raise com.UnsupportedOperationError(type(arg))
36-
37-
if isinstance(target_type, dt.Date):
38-
if isinstance(arg, ir.IntegerValue):
39-
return sa.func.date(sa.func.datetime(sa_arg, 'unixepoch'))
40-
elif isinstance(arg, ir.StringValue):
41-
return sa.func.date(sa_arg)
42-
raise com.UnsupportedOperationError(type(arg))
43-
44-
if isinstance(arg, ir.CategoryValue) and target_type == 'int32':
45-
return sa_arg
46-
else:
47-
return sa.cast(sa_arg, sa_type)
72+
return sqlite_cast(t, op.arg, op.to)
4873

4974

5075
def _substr(t, expr):
@@ -237,6 +262,7 @@ def _rpad(t, expr):
237262
ops.Greatest: varargs(sa.func.max),
238263
ops.IfNull: fixed_arity(sa.func.ifnull, 2),
239264
ops.DateTruncate: _truncate(sa.func.date),
265+
ops.Date: unary(sa.func.date),
240266
ops.TimestampTruncate: _truncate(sa.func.datetime),
241267
ops.Strftime: _strftime,
242268
ops.ExtractYear: _strftime_int('%Y'),

ibis/backends/tests/test_param.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,8 @@ def test_floating_scalar_parameter(backend, alltypes, df, column, raw_value):
2727
('start_string', 'end_string'),
2828
[('2009-03-01', '2010-07-03'), ('2014-12-01', '2017-01-05')],
2929
)
30-
@pytest.mark.notimpl(["datafusion", "pyspark", "sqlite"])
31-
def test_date_scalar_parameter(
32-
backend, alltypes, df, start_string, end_string
33-
):
30+
@pytest.mark.notimpl(["datafusion", "pyspark"])
31+
def test_date_scalar_parameter(backend, alltypes, start_string, end_string):
3432
start, end = ibis.param(dt.date), ibis.param(dt.date)
3533

3634
col = alltypes.timestamp_col.date()

ibis/backends/tests/test_temporal.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,20 @@
1313

1414

1515
@pytest.mark.parametrize('attr', ['year', 'month', 'day'])
16-
@pytest.mark.notimpl(["datafusion", "sqlite"])
17-
def test_date_extract(backend, alltypes, df, attr):
18-
expr = getattr(alltypes.timestamp_col.date(), attr)()
16+
@pytest.mark.parametrize(
17+
"expr_fn",
18+
[
19+
param(lambda c: c.date(), id="date"),
20+
param(
21+
lambda c: c.cast("date"),
22+
id="cast",
23+
marks=pytest.mark.notimpl(["impala"]),
24+
),
25+
],
26+
)
27+
@pytest.mark.notimpl(["datafusion"])
28+
def test_date_extract(backend, alltypes, df, attr, expr_fn):
29+
expr = getattr(expr_fn(alltypes.timestamp_col), attr)()
1930
expected = getattr(df.timestamp_col.dt, attr).astype('int32')
2031

2132
result = expr.execute()
@@ -172,16 +183,17 @@ def test_timestamp_truncate(backend, alltypes, df, unit):
172183
"mysql",
173184
"postgres",
174185
"pyspark",
186+
"sqlite",
175187
]
176188
),
177189
),
178190
],
179191
)
180-
@pytest.mark.notimpl(["datafusion", "sqlite"])
192+
@pytest.mark.notimpl(["datafusion"])
181193
def test_date_truncate(backend, alltypes, df, unit):
182194
expr = alltypes.timestamp_col.date().truncate(unit)
183195

184-
dtype = f'datetime64[{unit}]'
196+
dtype = f"datetime64[{unit}]"
185197
expected = pd.Series(df.timestamp_col.values.astype(dtype))
186198

187199
result = expr.execute()

0 commit comments

Comments
 (0)