Commit 0ead096
fix(databricks/pyspark): unify timestamp/timestamp_ntz behavior
1 parent 3b4e56f

File tree: 19 files changed, +289 −95 lines

.github/workflows/ibis-backends.yml
Lines changed: 8 additions & 1 deletion

@@ -557,6 +557,13 @@ jobs:
           - pyspark==3.3.4
           - pandas==1.5.3
           - numpy==1.23.5
+      - python-version: "3.9"
+        pyspark-minor-version: "3.4"
+        tag: local
+        deps:
+          - pyspark==3.4.4
+          - pandas==1.5.3
+          - numpy==1.23.5
       - python-version: "3.11"
         pyspark-minor-version: "3.5"
         tag: local
@@ -609,7 +616,7 @@ jobs:

       # it requires a version of pandas that pyspark is not compatible with
       - name: remove lonboard
-        if: matrix.pyspark-minor-version == '3.3'
+        if: matrix.pyspark-minor-version != '3.5'
        run: uv remove --group docs --no-sync lonboard

      - name: install pyspark-specific dependencies

ibis/backends/bigquery/tests/unit/test_compiler.py
Lines changed: 1 addition & 1 deletion

@@ -116,7 +116,7 @@ def test_integer_to_timestamp(case, unit, snapshot):


 @pytest.mark.parametrize(
-    ("case",),
+    "case",
     [
         param("a\\b\\c", id="escape_backslash"),
         param("a\ab\bc\fd\ne\rf\tg\vh", id="escape_ascii_sequences"),

ibis/backends/polars/__init__.py
Lines changed: 1 addition & 1 deletion

@@ -559,7 +559,7 @@ def _read_in_memory(source: Any, table_name: str, _conn: Backend, **kwargs: Any)

 @_read_in_memory.register("ibis.expr.types.Table")
 def _table(source, table_name, _conn, **kwargs: Any):
-    _conn._add_table(table_name, source.to_polars())
+    _conn._add_table(table_name, _conn.to_polars(source))


 @_read_in_memory.register("polars.DataFrame")
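
For context: the rewritten line converts the source expression through the receiving backend's own to_polars method rather than the expression's to_polars, so registration executes against the connection being populated. A hedged usage sketch of that backend method, assuming an in-memory polars backend:

    import ibis

    # Sketch only: materialize an ibis table as a polars DataFrame via the
    # backend method used by the fixed registration path.
    con = ibis.polars.connect()
    t = ibis.memtable({"a": [1, 2, 3]})
    df = con.to_polars(t)  # a polars.DataFrame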

ibis/backends/postgres/tests/test_udf.py
Lines changed: 1 addition & 1 deletion

@@ -74,7 +74,7 @@ def con_for_udf(con, sql_table_setup, sql_define_udf, sql_define_py_udf, test_da
     c.execute(sql_table_setup)
     c.execute(sql_define_udf)
     c.execute(sql_define_py_udf)
-    yield con
+    return con


 @pytest.fixture
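
For context: this fixture has no teardown code after the yield, so a plain return is equivalent and states the intent directly. A minimal sketch of when each form applies (Conn is a stand-in class, not from the commit):

    import pytest

    class Conn:
        def close(self):  # stand-in for a real connection
            pass

    @pytest.fixture
    def with_teardown():
        conn = Conn()
        yield conn   # yield when cleanup must run after the test
        conn.close()

    @pytest.fixture
    def without_teardown():
        return Conn()  # no cleanup needed, so return suffices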

ibis/backends/pyspark/__init__.py
Lines changed: 8 additions & 5 deletions

@@ -51,9 +51,12 @@ class SparkConnectGrpcException(Exception):

 from ibis.expr.api import Watermark

+
 PYSPARK_VERSION = vparse(pyspark.__version__)
-PYSPARK_LT_34 = PYSPARK_VERSION < vparse("3.4")
-PYSPARK_LT_35 = PYSPARK_VERSION < vparse("3.5")
+PYSPARK_33 = vparse("3.3") <= PYSPARK_VERSION < vparse("3.4")
+PYSPARK_35 = vparse("3.5") <= PYSPARK_VERSION
+SUPPORTS_TIMESTAMP_NTZ = vparse("3.4") <= PYSPARK_VERSION
+
 ConnectionMode = Literal["streaming", "batch"]


@@ -244,7 +247,7 @@ def _active_catalog_database(self, catalog: str | None, db: str | None):
         if catalog is None and db is None:
             yield
             return
-        if catalog is not None and PYSPARK_LT_34:
+        if catalog is not None and PYSPARK_33:
             raise com.UnsupportedArgumentError(
                 "Catalogs are not supported in pyspark < 3.4"
             )
@@ -313,7 +316,7 @@ def _active_catalog_database(self, catalog: str | None, db: str | None):

     @contextlib.contextmanager
     def _active_catalog(self, name: str | None):
-        if name is None or PYSPARK_LT_34:
+        if name is None or PYSPARK_33:
             yield
             return

@@ -408,7 +411,7 @@ def _register_udfs(self, expr: ir.Expr) -> None:
                 spark_udf = F.udf(udf_func, udf_return)
             elif udf.__input_type__ == InputType.PYARROW:
                 # raise not implemented error if running on pyspark < 3.5
-                if PYSPARK_LT_35:
+                if not PYSPARK_35:
                     raise NotImplementedError(
                         "pyarrow UDFs are only supported in pyspark >= 3.5"
                     )
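
For context: the open-ended PYSPARK_LT_34/PYSPARK_LT_35 flags give way to a bounded PYSPARK_33 range check, an at-least PYSPARK_35 check, and an explicit SUPPORTS_TIMESTAMP_NTZ capability flag. A small standalone illustration of how the new flags evaluate, using a hypothetical installed version:

    from packaging.version import parse as vparse

    version = vparse("3.4.4")  # hypothetical pyspark.__version__

    PYSPARK_33 = vparse("3.3") <= version < vparse("3.4")  # False: past the 3.3 series
    PYSPARK_35 = vparse("3.5") <= version                  # False: below 3.5
    SUPPORTS_TIMESTAMP_NTZ = vparse("3.4") <= version      # True: TIMESTAMP_NTZ usable

    print(PYSPARK_33, PYSPARK_35, SUPPORTS_TIMESTAMP_NTZ)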

ibis/backends/pyspark/datatypes.py
Lines changed: 52 additions & 37 deletions

@@ -1,20 +1,15 @@
 from __future__ import annotations

-import pyspark
+from functools import partial
+from inspect import isclass
+
 import pyspark.sql.types as pt
-from packaging.version import parse as vparse

 import ibis.common.exceptions as com
 import ibis.expr.datatypes as dt
 import ibis.expr.schema as sch
 from ibis.formats import SchemaMapper, TypeMapper

-# DayTimeIntervalType introduced in Spark 3.2 (at least) but didn't show up in
-# PySpark until version 3.3
-PYSPARK_33 = vparse(pyspark.__version__) >= vparse("3.3")
-PYSPARK_35 = vparse(pyspark.__version__) >= vparse("3.5")
-
-
 _from_pyspark_dtypes = {
     pt.BinaryType: dt.Binary,
     pt.BooleanType: dt.Boolean,
@@ -27,52 +22,64 @@
     pt.NullType: dt.Null,
     pt.ShortType: dt.Int16,
     pt.StringType: dt.String,
-    pt.TimestampType: dt.Timestamp,
 }

-_to_pyspark_dtypes = {v: k for k, v in _from_pyspark_dtypes.items()}
+try:
+    _from_pyspark_dtypes[pt.TimestampNTZType] = dt.Timestamp
+except AttributeError:
+    _from_pyspark_dtypes[pt.TimestampType] = dt.Timestamp
+else:
+    _from_pyspark_dtypes[pt.TimestampType] = partial(dt.Timestamp, timezone="UTC")
+
+_to_pyspark_dtypes = {
+    v: k
+    for k, v in _from_pyspark_dtypes.items()
+    if isclass(v) and not issubclass(v, dt.Timestamp) and not isinstance(v, partial)
+}
 _to_pyspark_dtypes[dt.JSON] = pt.StringType
 _to_pyspark_dtypes[dt.UUID] = pt.StringType


-if PYSPARK_33:
-    _pyspark_interval_units = {
-        pt.DayTimeIntervalType.SECOND: "s",
-        pt.DayTimeIntervalType.MINUTE: "m",
-        pt.DayTimeIntervalType.HOUR: "h",
-        pt.DayTimeIntervalType.DAY: "D",
-    }
-
-
 class PySparkType(TypeMapper):
     @classmethod
     def to_ibis(cls, typ, nullable=True):
         """Convert a pyspark type to an ibis type."""
+        from ibis.backends.pyspark import SUPPORTS_TIMESTAMP_NTZ
+
         if isinstance(typ, pt.DecimalType):
             return dt.Decimal(typ.precision, typ.scale, nullable=nullable)
         elif isinstance(typ, pt.ArrayType):
             return dt.Array(cls.to_ibis(typ.elementType), nullable=nullable)
         elif isinstance(typ, pt.MapType):
             return dt.Map(
-                cls.to_ibis(typ.keyType),
-                cls.to_ibis(typ.valueType),
-                nullable=nullable,
+                cls.to_ibis(typ.keyType), cls.to_ibis(typ.valueType), nullable=nullable
             )
         elif isinstance(typ, pt.StructType):
             fields = {f.name: cls.to_ibis(f.dataType) for f in typ.fields}

             return dt.Struct(fields, nullable=nullable)
-        elif PYSPARK_33 and isinstance(typ, pt.DayTimeIntervalType):
+        elif isinstance(typ, pt.DayTimeIntervalType):
+            pyspark_interval_units = {
+                pt.DayTimeIntervalType.SECOND: "s",
+                pt.DayTimeIntervalType.MINUTE: "m",
+                pt.DayTimeIntervalType.HOUR: "h",
+                pt.DayTimeIntervalType.DAY: "D",
+            }
+
             if (
                 typ.startField == typ.endField
-                and typ.startField in _pyspark_interval_units
+                and typ.startField in pyspark_interval_units
             ):
-                unit = _pyspark_interval_units[typ.startField]
+                unit = pyspark_interval_units[typ.startField]
                 return dt.Interval(unit, nullable=nullable)
             else:
                 raise com.IbisTypeError(f"{typ!r} couldn't be converted to Interval")
-        elif PYSPARK_35 and isinstance(typ, pt.TimestampNTZType):
-            return dt.Timestamp(nullable=nullable)
+        elif isinstance(typ, pt.TimestampNTZType):
+            if SUPPORTS_TIMESTAMP_NTZ:
+                return dt.Timestamp(nullable=nullable)
+            raise com.UnsupportedBackendType(
+                "PySpark<3.4 doesn't properly support timestamps without a timezone"
+            )
         elif isinstance(typ, pt.UserDefinedType):
             return cls.to_ibis(typ.sqlType(), nullable=nullable)
         else:
@@ -85,6 +92,8 @@ def to_ibis(cls, typ, nullable=True):

     @classmethod
     def from_ibis(cls, dtype):
+        from ibis.backends.pyspark import SUPPORTS_TIMESTAMP_NTZ
+
         if dtype.is_decimal():
             return pt.DecimalType(dtype.precision, dtype.scale)
         elif dtype.is_array():
@@ -97,11 +106,21 @@ def from_ibis(cls, dtype):
             value_contains_null = dtype.value_type.nullable
             return pt.MapType(key_type, value_type, value_contains_null)
         elif dtype.is_struct():
-            fields = [
-                pt.StructField(n, cls.from_ibis(t), t.nullable)
-                for n, t in dtype.fields.items()
-            ]
-            return pt.StructType(fields)
+            return pt.StructType(
+                [
+                    pt.StructField(field, cls.from_ibis(dtype), dtype.nullable)
+                    for field, dtype in dtype.fields.items()
+                ]
+            )
+        elif dtype.is_timestamp():
+            if dtype.timezone is not None:
+                return pt.TimestampType()
+            else:
+                if not SUPPORTS_TIMESTAMP_NTZ:
+                    raise com.UnsupportedBackendType(
+                        "PySpark<3.4 doesn't properly support timestamps without a timezone"
+                    )
+                return pt.TimestampNTZType()
         else:
             try:
                 return _to_pyspark_dtypes[type(dtype)]()
@@ -114,11 +133,7 @@ def from_ibis(cls, dtype):
 class PySparkSchema(SchemaMapper):
     @classmethod
     def from_ibis(cls, schema):
-        fields = [
-            pt.StructField(name, PySparkType.from_ibis(dtype), dtype.nullable)
-            for name, dtype in schema.items()
-        ]
-        return pt.StructType(fields)
+        return PySparkType.from_ibis(schema.as_struct())

     @classmethod
     def to_ibis(cls, schema):
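
The try/except at import time feature-detects TimestampNTZType, which PySpark added in 3.4: when it exists, Spark's TIMESTAMP (session-local, stored as UTC) maps to a timezone-aware ibis timestamp via partial(dt.Timestamp, timezone="UTC") and TIMESTAMP_NTZ maps to a naive one; on older PySpark, TIMESTAMP keeps its old naive mapping. A hedged round-trip sketch, assuming pyspark>=3.4 and that to_ibis's fallthrough instantiates the _from_pyspark_dtypes entry with nullable=True:

    import pyspark.sql.types as pt

    import ibis.expr.datatypes as dt
    from ibis.backends.pyspark.datatypes import PySparkType

    # Spark TIMESTAMP -> timezone-aware ibis timestamp (UTC)
    assert PySparkType.to_ibis(pt.TimestampType()) == dt.Timestamp(timezone="UTC")

    # Spark TIMESTAMP_NTZ -> naive ibis timestamp
    assert PySparkType.to_ibis(pt.TimestampNTZType()) == dt.timestamp

    # Going back, the timezone decides which Spark type is produced
    assert isinstance(PySparkType.from_ibis(dt.Timestamp(timezone="UTC")), pt.TimestampType)
    assert isinstance(PySparkType.from_ibis(dt.timestamp), pt.TimestampNTZType)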

ibis/backends/pyspark/tests/conftest.py
Lines changed: 9 additions & 1 deletion

@@ -337,8 +337,16 @@ def _load_data(self, **_: Any) -> None:

         for name, schema in TEST_TABLES.items():
             path = str(self.data_dir / "directory" / "parquet" / name)
+            sch = ibis.schema(
+                {
+                    col: dtype.copy(timezone="UTC")
+                    if dtype.is_timestamp()
+                    else dtype
+                    for col, dtype in schema.items()
+                }
+            )
             t = (
-                s.readStream.schema(PySparkSchema.from_ibis(schema))
+                s.readStream.schema(PySparkSchema.from_ibis(sch))
                .parquet(path)
                .repartition(num_partitions)
            )
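
For context: under the new mapping a naive ibis timestamp becomes TimestampNTZType, so the streaming fixture rewrites timestamp columns to UTC before building the reader schema, presumably to match how the parquet test data round-trips. The comprehension from the diff works standalone:

    import ibis

    schema = ibis.schema({"ts": "timestamp", "x": "int64"})

    utc_schema = ibis.schema(
        {
            col: dtype.copy(timezone="UTC") if dtype.is_timestamp() else dtype
            for col, dtype in schema.items()
        }
    )
    # utc_schema: ts is now timestamp('UTC'); x is unchanged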

ibis/backends/pyspark/tests/test_basic.py
Lines changed: 21 additions & 20 deletions

@@ -9,7 +9,7 @@
 from pytest import param

 import ibis
-from ibis.common.exceptions import IbisTypeError
+import ibis.common.exceptions as com

 pyspark = pytest.importorskip("pyspark")

@@ -119,30 +119,31 @@ def test_alias_after_select(t):


 def test_interval_columns_invalid(con):
-    df_interval_invalid = con._session.createDataFrame(
-        [[timedelta(days=10, hours=10, minutes=10, seconds=10)]],
-        pt.StructType(
-            [
-                pt.StructField(
-                    "interval_day_hour",
-                    pt.DayTimeIntervalType(
-                        pt.DayTimeIntervalType.DAY, pt.DayTimeIntervalType.SECOND
-                    ),
-                )
-            ]
-        ),
+    data = [[timedelta(days=10, hours=10, minutes=10, seconds=10)]]
+    schema = pt.StructType(
+        [
+            pt.StructField(
+                "interval_day_hour",
+                pt.DayTimeIntervalType(
+                    pt.DayTimeIntervalType.DAY, pt.DayTimeIntervalType.SECOND
+                ),
+            )
+        ]
     )

-    df_interval_invalid.createTempView("invalid_interval_table")
-    msg = r"DayTimeIntervalType.+ couldn't be converted to Interval"
-    with pytest.raises(IbisTypeError, match=msg):
-        con.table("invalid_interval_table")
+    name = "invalid_interval_table"
+
+    con._session.createDataFrame(data, schema).createTempView(name)
+
+    with pytest.raises(
+        com.IbisTypeError, match="DayTimeIntervalType.+ couldn't be converted"
+    ):
+        con.table(name)


 def test_string_literal_backslash_escaping(con):
-    expr = ibis.literal("\\d\\e")
-    result = con.execute(expr)
-    assert result == "\\d\\e"
+    input = r"\d\e"
+    assert con.execute(ibis.literal(input)) == input


 def test_connect_without_explicit_session():
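
For context: the raw-string literal in the rewritten escaping test denotes exactly the same characters as the old doubled-backslash form:

    # r"\d\e" and "\\d\\e" are the same four-character string
    assert r"\d\e" == "\\d\\e"
    assert len(r"\d\e") == 4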

ibis/backends/pyspark/tests/test_ddl.py
Lines changed: 12 additions & 10 deletions

@@ -9,6 +9,7 @@

 import ibis
 from ibis import util
+from ibis.backends.pyspark import PYSPARK_33
 from ibis.backends.tests.errors import PySparkAnalysisException
 from ibis.tests.util import assert_equal

@@ -92,12 +93,13 @@ def test_ctas_from_table_expr(con, alltypes, temp_table_db):

 def test_create_empty_table(con, temp_table):
     schema = ibis.schema(
-        [
-            ("a", "string"),
-            ("b", "timestamp"),
-            ("c", "decimal(12, 8)"),
-            ("d", "double"),
-        ]
+        {
+            "a": "string",
+            "b": "timestamp('UTC')",
+            "c": "decimal(12, 8)",
+            "d": "double",
+        }
+        | ({"e": "timestamp"} if not PYSPARK_33 else {})
     )

     con.create_table(temp_table, schema=schema)
@@ -181,9 +183,9 @@ def test_create_table_reserved_identifier(con, alltypes, keyword_t):


 @pytest.mark.xfail_version(
-    pyspark=["pyspark<3.5"],
+    pyspark=["pyspark<3.4"],
     raises=ValueError,
-    reason="PySparkAnalysisException is not available in PySpark <3.5",
+    reason="PySparkAnalysisException is not available in PySpark <3.4",
 )
 def test_create_database_exists(con):
     con.create_database(dbname := util.gen_name("dbname"))
@@ -197,9 +199,9 @@ def test_create_database_exists(con):


 @pytest.mark.xfail_version(
-    pyspark=["pyspark<3.5"],
+    pyspark=["pyspark<3.4"],
     raises=ValueError,
-    reason="PySparkAnalysisException is not available in PySpark <3.5",
+    reason="PySparkAnalysisException is not available in PySpark <3.4",
 )
 def test_drop_database_exists(con):
     con.create_database(dbname := util.gen_name("dbname"))
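
For context: the DDL test schema now uses explicit timezone-aware columns and appends a naive-timestamp column "e" only where TIMESTAMP_NTZ exists (PySpark >= 3.4). The dict-union spelling requires Python >= 3.9, matching the CI floor added above. A standalone sketch with an illustrative flag value:

    import ibis

    PYSPARK_33 = False  # illustrative: pretend we're on pyspark >= 3.4

    schema = ibis.schema(
        {
            "a": "string",
            "b": "timestamp('UTC')",
            "c": "decimal(12, 8)",
            "d": "double",
        }
        | ({"e": "timestamp"} if not PYSPARK_33 else {})
    )
    # schema includes "e" as a naive timestamp when the flag is False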
