Skip to content

Commit b7d9084

Browse files
cosmicBboymax-raphael
authored andcommitted
implement timezone agnostic polars_engine.DateTime type (unionai-oss#1589)
Signed-off-by: cosmicBboy <[email protected]>
1 parent 743e1cf commit b7d9084

File tree

3 files changed

+116
-3
lines changed

3 files changed

+116
-3
lines changed

pandera/engines/polars_engine.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,16 @@
55
import decimal
66
import inspect
77
import warnings
8-
from typing import Any, Union, Optional, Iterable, Literal, Sequence, Tuple
8+
from typing import (
9+
Any,
10+
Union,
11+
Optional,
12+
Iterable,
13+
Literal,
14+
Sequence,
15+
Tuple,
16+
Type,
17+
)
918

1019

1120
import polars as pl
@@ -416,16 +425,26 @@ class Date(DataType, dtypes.Date):
416425
class DateTime(DataType, dtypes.DateTime):
417426
"""Polars datetime data type."""
418427

419-
type = pl.Datetime
428+
type: Type[pl.Datetime] = pl.Datetime
429+
time_zone_agnostic: bool = False
420430

421431
def __init__( # pylint:disable=super-init-not-called
422432
self,
423433
time_zone: Optional[str] = None,
424434
time_unit: Optional[str] = None,
435+
time_zone_agnostic: bool = False,
425436
) -> None:
437+
438+
_kwargs = {}
439+
if time_unit is not None:
440+
# avoid deprecated warning when initializing pl.Datetime:
441+
# passing time_unit=None is deprecated.
442+
_kwargs["time_unit"] = time_unit
443+
426444
object.__setattr__(
427-
self, "type", pl.Datetime(time_zone=time_zone, time_unit=time_unit)
445+
self, "type", pl.Datetime(time_zone=time_zone, **_kwargs)
428446
)
447+
object.__setattr__(self, "time_zone_agnostic", time_zone_agnostic)
429448

430449
@classmethod
431450
def from_parametrized_dtype(cls, polars_dtype: pl.Datetime):
@@ -435,6 +454,24 @@ def from_parametrized_dtype(cls, polars_dtype: pl.Datetime):
435454
time_zone=polars_dtype.time_zone, time_unit=polars_dtype.time_unit
436455
)
437456

457+
def check(
458+
self,
459+
pandera_dtype: dtypes.DataType,
460+
data_container: Optional[PolarsDataContainer] = None,
461+
) -> Union[bool, Iterable[bool]]:
462+
try:
463+
pandera_dtype = Engine.dtype(pandera_dtype)
464+
except TypeError:
465+
return False
466+
467+
if self.time_zone_agnostic:
468+
return (
469+
isinstance(pandera_dtype.type, pl.Datetime)
470+
and pandera_dtype.type.time_unit == self.type.time_unit
471+
)
472+
473+
return self.type == pandera_dtype.type and super().check(pandera_dtype)
474+
438475

439476
@Engine.register_dtype(
440477
equivalents=[

tests/polars/test_polars_container.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,14 @@
1111
import polars as pl
1212

1313
import pytest
14+
from hypothesis import given
15+
from hypothesis import strategies as st
16+
from polars.testing.parametric import dataframes, column
17+
1418
import pandera as pa
1519
from pandera import Check as C
1620
from pandera.api.polars.types import PolarsData
21+
from pandera.engines import polars_engine as pe
1722
from pandera.polars import Column, DataFrameSchema, DataFrameModel
1823

1924

@@ -528,3 +533,34 @@ class Config:
528533
lf_with_nested_types, lazy=True
529534
)
530535
assert validated_lf.collect().equals(validated_lf.collect())
536+
537+
538+
@pytest.mark.parametrize(
539+
"time_zone",
540+
[
541+
None,
542+
"UTC",
543+
"GMT",
544+
"EST",
545+
],
546+
)
547+
@given(st.data())
548+
def test_dataframe_schema_with_tz_agnostic_dates(time_zone, data):
549+
strategy = dataframes(
550+
column("datetime_col", dtype=pl.Datetime()),
551+
lazy=True,
552+
size=10,
553+
)
554+
lf = data.draw(strategy)
555+
lf = lf.cast({"datetime_col": pl.Datetime(time_zone=time_zone)})
556+
schema_tz_agnostic = DataFrameSchema(
557+
{"datetime_col": Column(pe.DateTime(time_zone_agnostic=True))}
558+
)
559+
schema_tz_agnostic.validate(lf)
560+
561+
schema_tz_sensitive = DataFrameSchema(
562+
{"datetime_col": Column(pe.DateTime(time_zone_agnostic=False))}
563+
)
564+
if time_zone:
565+
with pytest.raises(pa.errors.SchemaError):
566+
schema_tz_sensitive.validate(lf)

tests/polars/test_polars_dtypes.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
"""Polars dtype tests."""
2+
3+
import datetime
24
import decimal
35
from decimal import Decimal
46
from typing import Union, Tuple, Sequence
@@ -403,3 +405,41 @@ def test_polars_nested_dtypes_try_coercion(
403405
pe.Engine.dtype(noncoercible_dtype).try_coerce(PolarsData(data))
404406
except pandera.errors.ParserError as exc:
405407
assert exc.failure_cases.equals(data.collect())
408+
409+
410+
@pytest.mark.parametrize(
411+
"dtype",
412+
[
413+
"datetime",
414+
datetime.datetime,
415+
pl.Datetime,
416+
pl.Datetime(),
417+
pl.Datetime(time_unit="ns"),
418+
pl.Datetime(time_unit="us"),
419+
pl.Datetime(time_unit="ms"),
420+
pl.Datetime(time_zone="UTC"),
421+
],
422+
)
423+
def test_datetime_time_zone_agnostic(dtype):
424+
425+
tz_agnostic = pe.DateTime(time_zone_agnostic=True)
426+
dtype = pe.Engine.dtype(dtype)
427+
428+
if tz_agnostic.type.time_unit == getattr(dtype.type, "time_unit", "us"):
429+
# timezone agnostic pandera dtype should pass regardless of timezone
430+
assert tz_agnostic.check(dtype)
431+
else:
432+
# but fail if the time units don't match
433+
assert not tz_agnostic.check(dtype)
434+
435+
tz_sensitive = pe.DateTime()
436+
if getattr(dtype.type, "time_zone", None) is not None:
437+
assert not tz_sensitive.check(dtype)
438+
439+
tz_sensitive_utc = pe.DateTime(time_zone="UTC")
440+
if getattr(
441+
dtype.type, "time_zone", None
442+
) is None and tz_sensitive_utc.type.time_zone != getattr(
443+
dtype.type, "time_zone", None
444+
):
445+
assert not tz_sensitive_utc.check(dtype)

0 commit comments

Comments
 (0)