Skip to content

Commit 1acadbe

Browse files
committed
Enhancement: Add support for timezone-flexible DateTime (unionai-oss#1352)
Signed-off-by: Max Raphael <[email protected]>
1 parent 33fe68b commit 1acadbe

File tree

2 files changed

+236
-9
lines changed

2 files changed

+236
-9
lines changed

pandera/engines/pandas_engine.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -860,6 +860,13 @@ class DateTime(_BaseDateTime, dtypes.Timestamp):
860860
tz: Optional[datetime.tzinfo] = None
861861
"""The timezone."""
862862

863+
timezone_flexible: bool = False
864+
"""
865+
A flag indicating whether the datetime data should be handled flexibly with respect to timezones.
866+
When set to True, the function will ignore 'tz' and allow datetimes with any timezone(s). If coerce is set to True,
867+
the function can accept timezone-naive datetimes, and will convert all datetimes to the specified tz (which is then required).
868+
"""
869+
863870
to_datetime_kwargs: Dict[str, Any] = dataclasses.field(
864871
default_factory=dict, compare=False, repr=False
865872
)
@@ -936,14 +943,94 @@ def from_parametrized_dtype(cls, pd_dtype: pd.DatetimeTZDtype):
936943
return cls(unit=pd_dtype.unit, tz=pd_dtype.tz) # type: ignore
937944

938945
def coerce(self, data_container: PandasObject) -> PandasObject:
    """Coerce a pandas object to this datetime dtype.

    When ``timezone_flexible`` is set, the data is first passed through
    ``_prepare_coerce_timezone_flexible``; ``self.type`` is read only after
    that step because the helper may update it.
    """
    prepared = (
        self._prepare_coerce_timezone_flexible(data_container=data_container)
        if self.timezone_flexible
        else data_container
    )
    return self._coerce(prepared, pandas_dtype=self.type)
940951

952+
def _prepare_coerce_timezone_flexible(
953+
self, data_container: PandasObject
954+
) -> PandasObject:
955+
if not self.tz:
956+
raise errors.ParserError(
957+
"Cannot coerce timezone-naive datetimes when 'tz' is not specified. "
958+
"Either specify a timezone using 'tz' parameter or set 'timezone_flexible=True' "
959+
"to allow flexible timezone handling.",
960+
failure_cases=utils.numpy_pandas_coerce_failure_cases(
961+
data_container, self
962+
),
963+
)
964+
# If there is a single timezone, define the type as a timezone-aware DatetimeTZDtype
965+
if isinstance(data_container.dtype, pd.DatetimeTZDtype):
966+
tz = self.tz if self.tz else data_container.dtype.tz
967+
unit = self.unit if self.unit else data_container.dtype.unit
968+
type_ = pd.DatetimeTZDtype(unit, tz)
969+
object.__setattr__(self, "tz", tz)
970+
object.__setattr__(self, "type", type_)
971+
# If there are multiple timezones, convert them to the specified tz (default 'UTC') and set the type accordingly
972+
elif all(isinstance(x, datetime.datetime) for x in data_container):
973+
container_type = type(data_container)
974+
tz = self.tz if self.tz else "UTC"
975+
unit = self.unit if self.unit else data_container.dtype.unit
976+
data_container = container_type(
977+
[
978+
(
979+
pd.Timestamp(ts).tz_convert(tz)
980+
if pd.Timestamp(ts).tzinfo
981+
else pd.Timestamp(ts).tz_localize(tz)
982+
)
983+
for ts in data_container
984+
]
985+
)
986+
type_ = pd.DatetimeTZDtype(unit, tz)
987+
object.__setattr__(self, "tz", tz)
988+
object.__setattr__(self, "type", type_)
989+
else:
990+
# Prepare to raise exception, adding type strictly for the check_dtype error message
991+
object.__setattr__(self, "type", "datetime64[ns, <timezone>]")
992+
return data_container
993+
941994
def coerce_value(self, value: Any) -> Any:
    """Coerce a value to the specified datetime type."""
    to_datetime = self._get_to_datetime_fn(value)
    return to_datetime(value, **self.to_datetime_kwargs)
946999

1000+
def check(
    self,
    pandera_dtype: dtypes.DataType,
    data_container: Optional[PandasObject] = None,
) -> Union[bool, Iterable[bool]]:
    """Check ``pandera_dtype`` against this dtype.

    With ``timezone_flexible`` enabled, ``_prepare_check_timezone_flexible``
    runs first and may update ``self.type`` before the check is delegated to
    the parent class.
    """
    if not self.timezone_flexible:
        return super().check(pandera_dtype, data_container)

    self._prepare_check_timezone_flexible(
        pandera_dtype=pandera_dtype, data_container=data_container
    )
    return super().check(pandera_dtype, data_container)
1010+
1011+
def _prepare_check_timezone_flexible(
    self,
    pandera_dtype: dtypes.DataType,
    data_container: Optional[PandasObject],
) -> None:
    """Set ``self.type`` to the dtype the timezone-flexible check compares against.

    The instance is updated in place via ``object.__setattr__``.

    :param pandera_dtype: dtype being checked against this one.
    :param data_container: the data being validated, or ``None`` when only
        the dtype is available. ``None`` is treated like data that is not
        made up entirely of timezone-aware datetimes.
    """
    # If there is a single timezone, define the type as a timezone-aware DatetimeTZDtype
    if isinstance(pandera_dtype, DateTime) and pandera_dtype.tz is not None:
        type_ = pd.DatetimeTZDtype(self.unit, pandera_dtype.tz)
        object.__setattr__(self, "tz", pandera_dtype.tz)
        object.__setattr__(self, "type", type_)
    # If the data has a mix of timezones, pandas defines the dtype as 'object'
    elif data_container is not None and all(
        isinstance(x, datetime.datetime) and x.tzinfo is not None
        for x in data_container
    ):
        object.__setattr__(self, "type", np.dtype("O"))
    else:
        # Prepare to raise exception, adding type strictly for the check_dtype error message
        object.__setattr__(self, "type", "datetime64[ns, <timezone>]")
1033+
9471034
def __str__(self) -> str:
9481035
if self.type == np.dtype("datetime64[ns]"):
9491036
return "datetime64[ns]"

tests/core/test_pandas_engine.py

Lines changed: 149 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
"""Test pandas engine."""
22

3-
from datetime import date
4-
from typing import Any, Set
3+
import datetime as dt
4+
from typing import Tuple, List, Optional, Any, Set
5+
from zoneinfo import ZoneInfo
56

67
import hypothesis
78
import hypothesis.extra.pandas as pd_st
@@ -13,8 +14,9 @@
1314
import pytz
1415
from hypothesis import given
1516

17+
from pandera import Field, DataFrameModel
1618
from pandera.engines import pandas_engine
17-
from pandera.errors import ParserError
19+
from pandera.errors import ParserError, SchemaError
1820

1921
UNSUPPORTED_DTYPE_CLS: Set[Any] = set()
2022

@@ -202,6 +204,143 @@ def test_pandas_datetimetz_dtype(timezone_aware, data, timezone):
202204
assert coerced_data.dt.tz == timezone
203205

204206

207+
def generate_test_cases_timezone_flexible() -> List[
    Tuple[
        List[Any],
        Optional[dt.tzinfo],
        bool,
        Optional[List[dt.datetime]],
        bool,
    ]
]:
    """
    Generate test parameter combinations for the timezone_flexible tests.

    Every input list is combined with each coerce flag and each tz, followed
    by two cases whose input contains a non-datetime value.

    Returns:
        List of tuples:
        - List of input datetimes
        - tz for DateTime constructor
        - coerce flag for Field constructor
        - expected output datetimes (None if an exception is expected)
        - raises flag (True if an exception is expected, False otherwise)
    """
    new_york = ZoneInfo("America/New_York")
    los_angeles = ZoneInfo("America/Los_Angeles")

    datetimes = [
        # multi tz and tz naive
        [
            dt.datetime(2023, 3, 1, 4, tzinfo=new_york),
            dt.datetime(2023, 3, 1, 5, tzinfo=los_angeles),
            dt.datetime(2023, 3, 1, 5),
        ],
        # multi tz
        [
            dt.datetime(2023, 3, 1, 4, tzinfo=new_york),
            dt.datetime(2023, 3, 1, 5, tzinfo=los_angeles),
        ],
        # tz naive
        [dt.datetime(2023, 3, 1, 4), dt.datetime(2023, 3, 1, 5)],
        # single tz
        [
            dt.datetime(2023, 3, 1, 4, tzinfo=new_york),
            dt.datetime(2023, 3, 1, 5, tzinfo=new_york),
        ],
    ]
    timezones = [
        None,
        ZoneInfo("America/Chicago"),
        dt.timezone(dt.timedelta(hours=2)),
    ]

    test_cases: List[
        Tuple[
            List[Any],
            Optional[dt.tzinfo],
            bool,
            Optional[List[dt.datetime]],
            bool,
        ]
    ] = []

    for datetime_list in datetimes:
        # Depends only on the input list, so compute it once per list.
        has_naive_datetime = any(
            value.tzinfo is None for value in datetime_list
        )
        for coerce in (True, False):
            for tz in timezones:
                # Should raise error when:
                # * coerce is False but there is a timezone-naive datetime
                # * coerce is True but tz is not set
                raises = (not coerce and has_naive_datetime) or (
                    coerce and tz is None
                )

                expected_output: Optional[List[dt.datetime]]
                if raises:
                    # No expected output since an exception will be raised
                    expected_output = None
                elif coerce:
                    # tz is never None here (that combination raises):
                    # localize naive inputs and convert aware inputs to tz.
                    expected_output_naive = [
                        value.replace(tzinfo=tz)
                        for value in datetime_list
                        if value.tzinfo is None
                    ]
                    expected_output_aware = [
                        value.astimezone(tz)
                        for value in datetime_list
                        if value.tzinfo is not None
                    ]
                    expected_output = (
                        expected_output_naive + expected_output_aware
                    )
                else:
                    # ignore tz
                    expected_output = datetime_list

                test_cases.append(
                    (datetime_list, tz, coerce, expected_output, raises)
                )

    # define final test cases with improper type
    bad_type_list: List[Any] = [
        dt.datetime(2023, 3, 1, 4, tzinfo=new_york),
        "hello world",
    ]
    test_cases.extend(
        [
            (bad_type_list, None, True, None, True),
            (bad_type_list, None, False, None, True),
        ]
    )

    return test_cases
315+
316+
317+
@pytest.mark.parametrize(
    "examples, tz, coerce, expected_output, raises",
    generate_test_cases_timezone_flexible(),
)
def test_dt_timezone_flexible(examples, tz, coerce, expected_output, raises):
    """Validate timezone_flexible DateTime columns through a DataFrameModel."""

    # The dtype is resolved dynamically from the input data when
    # timezone_flexible is set, so the test goes through schema validation
    # instead of calling the dtype's coerce / check directly.
    class SimpleSchema(DataFrameModel):
        # pylint: disable=unexpected-keyword-arg,no-value-for-parameter
        datetime_column: pandas_engine.DateTime(
            timezone_flexible=True, tz=tz
        ) = Field(coerce=coerce)

    frame = pd.DataFrame({"datetime_column": examples})

    if raises:
        with pytest.raises(SchemaError):
            SimpleSchema.validate(frame)
        return

    validated_values = SimpleSchema.validate(frame)["datetime_column"].tolist()
    assert sorted(validated_values) == sorted(expected_output)
342+
343+
205344
@hypothesis.settings(max_examples=1000)
206345
@pytest.mark.parametrize("to_df", [True, False])
207346
@given(
@@ -225,7 +364,7 @@ def test_pandas_date_coerce_dtype(to_df, data):
225364
)
226365

227366
assert (
228-
coerced_data.applymap(lambda x: isinstance(x, date))
367+
coerced_data.applymap(lambda x: isinstance(x, dt.date))
229368
| coerced_data.isna()
230369
).all(axis=None)
231370
return
@@ -234,7 +373,8 @@ def test_pandas_date_coerce_dtype(to_df, data):
234373
coerced_data.isna().all() and coerced_data.dtype == "datetime64[ns]"
235374
)
236375
assert (
237-
coerced_data.map(lambda x: isinstance(x, date)) | coerced_data.isna()
376+
coerced_data.map(lambda x: isinstance(x, dt.date))
377+
| coerced_data.isna()
238378
).all()
239379

240380

@@ -246,8 +386,8 @@ def test_pandas_date_coerce_dtype(to_df, data):
246386
pyarrow.struct([("foo", pyarrow.int64()), ("bar", pyarrow.string())]),
247387
),
248388
(pd.Series([None, pd.NA, np.nan]), pyarrow.null),
249-
(pd.Series([None, date(1970, 1, 1)]), pyarrow.date32),
250-
(pd.Series([None, date(1970, 1, 1)]), pyarrow.date64),
389+
(pd.Series([None, dt.date(1970, 1, 1)]), pyarrow.date32),
390+
(pd.Series([None, dt.date(1970, 1, 1)]), pyarrow.date64),
251391
(pd.Series([1, 2]), pyarrow.duration("ns")),
252392
(pd.Series([1, 1e3, 1e6, 1e9, None]), pyarrow.time32("ms")),
253393
(pd.Series([1, 1e3, 1e6, 1e9, None]), pyarrow.time64("ns")),
@@ -292,8 +432,8 @@ def test_pandas_arrow_dtype(data, dtype):
292432
pyarrow.struct([("foo", pyarrow.string()), ("bar", pyarrow.int64())]),
293433
),
294434
(pd.Series(["a", "1"]), pyarrow.null),
295-
(pd.Series(["a", date(1970, 1, 1), "1970-01-01"]), pyarrow.date32),
296-
(pd.Series(["a", date(1970, 1, 1), "1970-01-01"]), pyarrow.date64),
435+
(pd.Series(["a", dt.date(1970, 1, 1), "1970-01-01"]), pyarrow.date32),
436+
(pd.Series(["a", dt.date(1970, 1, 1), "1970-01-01"]), pyarrow.date64),
297437
(pd.Series(["a"]), pyarrow.duration("ns")),
298438
(pd.Series(["a", "b"]), pyarrow.time32("ms")),
299439
(pd.Series(["a", "b"]), pyarrow.time64("ns")),

0 commit comments

Comments
 (0)