Skip to content

Commit 758ec25

Browse files
committed
fix(pyspark): custom format converter to handle pyspark timestamps
1 parent bb92e9f commit 758ec25

File tree

2 files changed

+19
-9
lines changed

2 files changed

+19
-9
lines changed

ibis/backends/pyspark/__init__.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
from ibis.backends.pyspark.client import PySparkTable
3232
from ibis.backends.pyspark.compiler import PySparkExprTranslator
3333
from ibis.backends.pyspark.datatypes import PySparkType
34+
from ibis.common.temporal import normalize_timezone
35+
from ibis.formats.pandas import PandasData
3436

3537
if TYPE_CHECKING:
3638
from collections.abc import Mapping, Sequence
@@ -104,6 +106,18 @@ class PySparkCompiler(Compiler):
104106
table_set_formatter_class = PySparkTableSetFormatter
105107

106108

109+
class PySparkPandasData(PandasData):
    """PandasData variant that fixes up timestamps coming out of PySpark.

    PySpark's ``toPandas`` hands back timestamps whose timezone handling
    does not match ibis's expectations, so element conversion is
    overridden here.
    """

    @classmethod
    def convert_Timestamp_element(cls, dtype):
        """Return an element converter for a timestamp *dtype*.

        For a timezone-aware dtype the value is shifted into that zone;
        for a naive dtype the value is normalized to UTC and its tzinfo
        stripped.
        """
        tz = dtype.timezone

        def converter(value, tz=tz):
            if tz is None:
                # Naive target: canonicalize to UTC, then drop the tzinfo.
                return value.astimezone(normalize_timezone("UTC")).replace(
                    tzinfo=None
                )
            # Aware target: shift the value into the requested zone.
            return value.astimezone(normalize_timezone(tz))

        return converter
107121
class Backend(BaseSQLBackend, CanCreateDatabase):
108122
compiler = PySparkCompiler
109123
name = "pyspark"
@@ -219,7 +233,9 @@ def execute(self, expr: ir.Expr, **kwargs: Any) -> Any:
219233
df = self.compile(table_expr, **kwargs).toPandas()
220234

221235
# TODO: remove the extra conversion
222-
return expr.__pandas_result__(table_expr.__pandas_result__(df))
236+
return expr.__pandas_result__(
237+
PySparkPandasData.convert_table(df, table_expr.schema())
238+
)
223239

224240
def _fully_qualified_name(self, name, database):
225241
if is_fully_qualified(name):
@@ -232,17 +248,15 @@ def close(self):
232248
self._context.stop()
233249

234250
def fetch_from_cursor(self, cursor, schema):
235-
df = cursor.query.toPandas() # blocks until finished
236-
return schema.apply_to(df)
251+
return cursor.query.toPandas() # blocks until finished
237252

238253
def raw_sql(self, query: str) -> _PySparkCursor:
239254
query = self._session.sql(query)
240255
return _PySparkCursor(query)
241256

242257
def _get_schema_using_query(self, query):
243258
cursor = self.raw_sql(f"SELECT * FROM ({query}) t0 LIMIT 0")
244-
struct = PySparkType.to_ibis(cursor.query.schema)
245-
return sch.Schema(struct)
259+
return sch.Schema(PySparkType.to_ibis(cursor.query.schema))
246260

247261
def _get_jtable(self, name, database=None):
248262
get_table = self._catalog._jcatalog.getTable

ibis/backends/tests/test_temporal.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -344,10 +344,6 @@ def test_timestamp_extract_milliseconds(backend, alltypes, df):
344344
raises=GoogleBadRequest,
345345
reason="UNIX_SECONDS does not support DATETIME arguments",
346346
)
347-
@pytest.mark.xfail_version(
348-
pyspark=["pandas<2.1"],
349-
reason="test was adjusted to work with pandas 2.1 output; pyspark doesn't support pandas 2",
350-
)
351347
@pytest.mark.notimpl(["exasol"], raises=com.OperationNotDefinedError)
352348
def test_timestamp_extract_epoch_seconds(backend, alltypes, df):
353349
expr = alltypes.timestamp_col.epoch_seconds().name("tmp")

0 commit comments

Comments
 (0)