Skip to content

Commit 067120d

Browse files
tokokocpcloud
authored andcommitted
feat(pyspark): add partial support for interval types
1 parent e2c159c commit 067120d

File tree

3 files changed

+101
-1
lines changed

3 files changed

+101
-1
lines changed

ibis/backends/pyspark/datatypes.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,28 @@ def _spark_struct(spark_dtype_obj, nullable=True):
7979
return dt.Struct(fields, nullable=nullable)
8080

8181

82+
_SPARK_INTERVAL_TO_IBIS_INTERVAL = {
83+
pt.DayTimeIntervalType.SECOND: 's',
84+
pt.DayTimeIntervalType.MINUTE: 'm',
85+
pt.DayTimeIntervalType.HOUR: 'h',
86+
pt.DayTimeIntervalType.DAY: 'D',
87+
}
88+
89+
90+
@dt.dtype.register(pt.DayTimeIntervalType)
91+
def _spark_struct(spark_dtype_obj, nullable=True):
92+
if (
93+
spark_dtype_obj.startField == spark_dtype_obj.endField
94+
and spark_dtype_obj.startField in _SPARK_INTERVAL_TO_IBIS_INTERVAL
95+
):
96+
return dt.Interval(
97+
_SPARK_INTERVAL_TO_IBIS_INTERVAL[spark_dtype_obj.startField],
98+
nullable=nullable,
99+
)
100+
else:
101+
raise com.IbisTypeError("DayTimeIntervalType couldn't be converted to Interval")
102+
103+
82104
_IBIS_DTYPE_TO_SPARK_DTYPE = {v: k for k, v in _SPARK_DTYPE_TO_IBIS_DTYPE.items()}
83105
_IBIS_DTYPE_TO_SPARK_DTYPE[dt.JSON] = pt.StringType
84106

ibis/backends/pyspark/tests/conftest.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import os
4-
from datetime import datetime, timezone
4+
from datetime import datetime, timezone, timedelta
55

66
import numpy as np
77
import pandas as pd
@@ -276,6 +276,63 @@ def client(data_directory):
276276

277277
df_time_indexed.createTempView('time_indexed_table')
278278

279+
df_interval = client._session.createDataFrame(
280+
[
281+
[
282+
timedelta(days=10),
283+
timedelta(hours=10),
284+
timedelta(minutes=10),
285+
timedelta(seconds=10),
286+
]
287+
],
288+
pt.StructType(
289+
[
290+
pt.StructField(
291+
"interval_day",
292+
pt.DayTimeIntervalType(
293+
pt.DayTimeIntervalType.DAY, pt.DayTimeIntervalType.DAY
294+
),
295+
),
296+
pt.StructField(
297+
"interval_hour",
298+
pt.DayTimeIntervalType(
299+
pt.DayTimeIntervalType.HOUR, pt.DayTimeIntervalType.HOUR
300+
),
301+
),
302+
pt.StructField(
303+
"interval_minute",
304+
pt.DayTimeIntervalType(
305+
pt.DayTimeIntervalType.MINUTE, pt.DayTimeIntervalType.MINUTE
306+
),
307+
),
308+
pt.StructField(
309+
"interval_second",
310+
pt.DayTimeIntervalType(
311+
pt.DayTimeIntervalType.SECOND, pt.DayTimeIntervalType.SECOND
312+
),
313+
),
314+
]
315+
),
316+
)
317+
318+
df_interval.createTempView('interval_table')
319+
320+
df_interval_invalid = client._session.createDataFrame(
321+
[[timedelta(days=10, hours=10, minutes=10, seconds=10)]],
322+
pt.StructType(
323+
[
324+
pt.StructField(
325+
"interval_day_hour",
326+
pt.DayTimeIntervalType(
327+
pt.DayTimeIntervalType.DAY, pt.DayTimeIntervalType.HOUR
328+
),
329+
)
330+
]
331+
),
332+
)
333+
334+
df_interval_invalid.createTempView('invalid_interval_table')
335+
279336
return client
280337

281338

ibis/backends/pyspark/tests/test_basic.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
import pyspark.sql.functions as F # noqa: E402
1111

1212
from ibis.backends.pyspark.compiler import _can_be_replaced_by_column_name # noqa: E402
13+
from ibis.expr import datatypes as dt
14+
from ibis.common.exceptions import IbisTypeError
1315

1416

1517
def test_basic(client):
@@ -211,3 +213,22 @@ def test_can_be_replaced_by_column_name(selection_fn, selection_idx, expected):
211213
selection_to_test = table.op().selections[selection_idx]
212214
result = _can_be_replaced_by_column_name(selection_to_test, table.op().table)
213215
assert result == expected
216+
217+
218+
def test_interval_columns(client):
219+
table = client.table('interval_table')
220+
assert table.schema() == ibis.schema(
221+
pairs=[
222+
('interval_day', dt.Interval('D')),
223+
('interval_hour', dt.Interval('h')),
224+
('interval_minute', dt.Interval('m')),
225+
('interval_second', dt.Interval('s')),
226+
]
227+
)
228+
229+
230+
def test_interval_columns_invalid(client):
231+
with pytest.raises(
232+
IbisTypeError, match="DayTimeIntervalType couldn't be converted to Interval"
233+
):
234+
client.table('invalid_interval_table')

0 commit comments

Comments
 (0)