Skip to content

Commit cd998c8

Browse files
committed
Enhancement: Add support for timezone-flexible DateTime (#1352)
Signed-off-by: Max Raphael <[email protected]> Enhancement: Add support for timezone-flexible DateTime (#1352) Signed-off-by: Max Raphael <[email protected]>
1 parent 33fe68b commit cd998c8

File tree

2 files changed

+258
-9
lines changed

2 files changed

+258
-9
lines changed

pandera/engines/pandas_engine.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -860,6 +860,19 @@ class DateTime(_BaseDateTime, dtypes.Timestamp):
860860
tz: Optional[datetime.tzinfo] = None
861861
"""The timezone."""
862862

863+
timezone_flexible: bool = False
864+
"""
865+
A flag indicating whether the datetime data should be handled flexibly with respect to timezones.
866+
867+
- If set to `True` and `coerce` is `False`, the function will accept datetimes with any timezone(s)
868+
but not timezone-naive datetimes. If passed, the `tz` argument will be ignored, as this use
869+
case is handled by setting `timezone_flexible=False`.
870+
871+
- If set to `True` and `coerce` is `True`, a `tz` must also be specified. The function will then
872+
accept datetimes with any timezone(s) and convert them to the specified tz, as well as
873+
timezone-naive datetimes, and localize them to the specified tz.
874+
"""
875+
863876
to_datetime_kwargs: Dict[str, Any] = dataclasses.field(
864877
default_factory=dict, compare=False, repr=False
865878
)
@@ -936,14 +949,112 @@ def from_parametrized_dtype(cls, pd_dtype: pd.DatetimeTZDtype):
936949
return cls(unit=pd_dtype.unit, tz=pd_dtype.tz) # type: ignore
937950

938951
def coerce(self, data_container: PandasObject) -> PandasObject:
    """Coerce ``data_container`` to this datetime dtype.

    When ``timezone_flexible`` is set, the container is first passed through
    ``_prepare_coerce_timezone_flexible``. The (possibly replaced) container
    is then coerced using ``self.type`` as the target pandas dtype.

    :param data_container: pandas object to coerce.
    :returns: the coerced pandas object.
    """
    # The preparation step must run before ``self.type`` is read below.
    prepared = (
        self._prepare_coerce_timezone_flexible(data_container=data_container)
        if self.timezone_flexible
        else data_container
    )
    return self._coerce(prepared, pandas_dtype=self.type)
940957

958+
def _prepare_coerce_timezone_flexible(
    self, data_container: PandasObject
) -> PandasObject:
    """Normalize ``data_container`` to ``self.tz`` ahead of coercion.

    Two input shapes are accepted:

    1. A container whose dtype is ``pd.DatetimeTZDtype`` (single timezone).
       The data is returned untouched.
    2. A container whose elements are all ``datetime.datetime`` instances.
       Timezone-aware elements are converted to ``self.tz`` and
       timezone-naive elements are localized to ``self.tz``.

    In both cases ``self.type`` is replaced by
    ``pd.DatetimeTZDtype(unit, self.tz)``.

    :param data_container: pandas object about to be coerced.
    :returns: the container to hand to the regular coercion routine.
    :raises errors.ParserError: if ``self.tz`` is not set, or if the data
        matches neither of the accepted shapes.
    """
    if not self.tz:
        # This method only runs when timezone_flexible is already True, so
        # the only actionable remedy is to provide a target timezone.
        raise errors.ParserError(
            "Cannot coerce datetimes with 'timezone_flexible=True' when "
            "'tz' is not specified. Specify the target timezone using the "
            "'tz' parameter.",
            failure_cases=utils.numpy_pandas_coerce_failure_cases(
                data_container, self
            ),
        )

    target_tz = self.tz

    if isinstance(data_container.dtype, pd.DatetimeTZDtype):
        # Single timezone: only the target dtype needs to be recorded.
        unit = self.unit if self.unit else data_container.dtype.unit
    elif all(isinstance(x, datetime.datetime) for x in data_container):
        # Mixed timezones and/or naive datetimes. The unit is read from the
        # incoming container before the container is replaced.
        unit = self.unit if self.unit else data_container.dtype.unit

        def _to_target_tz(value: Any) -> Any:
            """Convert an aware value, or localize a naive one, to target_tz."""
            stamp = pd.Timestamp(value)
            if stamp.tzinfo:
                return stamp.tz_convert(target_tz)
            return stamp.tz_localize(target_tz)

        # ``map`` keeps the original index and name; rebuilding through the
        # container constructor would silently reset them.
        data_container = data_container.map(_to_target_tz)
    else:
        raise errors.ParserError(
            "When timezone_flexible=True, data must either be:\n"
            "1. A Series with DatetimeTZDtype (timezone-aware datetime series), or\n"
            "2. A Series of datetime objects\n"
            f"Got data with dtype: {data_container.dtype}",
            failure_cases=utils.numpy_pandas_coerce_failure_cases(
                data_container, self
            ),
        )

    # NOTE(review): object.__setattr__ bypasses normal attribute assignment,
    # presumably because the dtype dataclass is frozen - confirm.
    object.__setattr__(self, "type", pd.DatetimeTZDtype(unit, target_tz))
    return data_container
1006+
9411007
def coerce_value(self, value: Any) -> Any:
    """Coerce a single value to the specified datetime type.

    The conversion function is obtained from ``_get_to_datetime_fn`` for the
    given value and invoked with ``self.to_datetime_kwargs``.
    """
    to_datetime = self._get_to_datetime_fn(value)
    return to_datetime(value, **self.to_datetime_kwargs)
9461012

1013+
def check(
    self,
    pandera_dtype: dtypes.DataType,
    data_container: Optional[PandasObject] = None,
) -> Union[bool, Iterable[bool]]:
    """Check ``pandera_dtype`` (and optionally the data) against this dtype.

    With ``timezone_flexible`` enabled, ``_prepare_check_timezone_flexible``
    runs first; the actual comparison is always delegated to the parent
    implementation.

    :param pandera_dtype: dtype to compare against.
    :param data_container: optional pandas object holding the data.
    :returns: whatever the parent ``check`` returns.
    """
    if not self.timezone_flexible:
        return super().check(pandera_dtype, data_container)

    self._prepare_check_timezone_flexible(
        pandera_dtype=pandera_dtype,
        data_container=data_container,
    )
    return super().check(pandera_dtype, data_container)
1023+
1024+
def _prepare_check_timezone_flexible(
    self,
    pandera_dtype: dtypes.DataType,
    data_container: Optional[PandasObject],
) -> None:
    """Adjust ``self.type`` so a timezone-flexible check can succeed.

    * If ``pandera_dtype`` is a ``DateTime`` with a timezone, ``self.tz`` and
      ``self.type`` are set to match that timezone.
    * Otherwise, if every element of ``data_container`` is a timezone-aware
      ``datetime.datetime``, ``self.type`` is set to the ``object`` dtype
      (pandas stores mixed-timezone data as ``object``).
    * Otherwise a ``ParserError`` is raised.

    :param pandera_dtype: dtype being checked against.
    :param data_container: optional pandas object holding the data.
    :raises errors.ParserError: if neither accepted condition holds,
        including when no data was supplied.
    """
    # Single timezone: describe the type as a timezone-aware DatetimeTZDtype.
    if isinstance(pandera_dtype, DateTime) and pandera_dtype.tz is not None:
        tz_aware_type = pd.DatetimeTZDtype(self.unit, pandera_dtype.tz)
        object.__setattr__(self, "tz", pandera_dtype.tz)
        object.__setattr__(self, "type", tz_aware_type)
        return

    # Mixed timezones: pandas defines the dtype as 'object'. The None guard
    # keeps a missing container from raising TypeError during iteration so
    # that the ParserError below is reported instead.
    if data_container is not None and all(
        isinstance(x, datetime.datetime) and x.tzinfo is not None
        for x in data_container
    ):
        object.__setattr__(self, "type", np.dtype("O"))
        return

    raise errors.ParserError(
        "When timezone_flexible=True, data must either be:\n"
        "1. A Series with DatetimeTZDtype (timezone-aware datetime series), or\n"
        "2. A Series of timezone-aware datetime objects\n"
        f"Got data with dtype: {data_container.dtype if data_container is not None else 'None'}",
        failure_cases=(
            utils.numpy_pandas_coerce_failure_cases(data_container, self)
            if data_container is not None
            else None
        ),
    )
1057+
9471058
def __str__(self) -> str:
9481059
if self.type == np.dtype("datetime64[ns]"):
9491060
return "datetime64[ns]"

tests/core/test_pandas_engine.py

Lines changed: 147 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
"""Test pandas engine."""
22

3-
from datetime import date
4-
from typing import Any, Set
3+
import datetime as dt
4+
from typing import Tuple, List, Optional, Any, Set
5+
from zoneinfo import ZoneInfo
56

67
import hypothesis
78
import hypothesis.extra.pandas as pd_st
@@ -13,8 +14,9 @@
1314
import pytz
1415
from hypothesis import given
1516

17+
from pandera import Field, DataFrameModel
1618
from pandera.engines import pandas_engine
17-
from pandera.errors import ParserError
19+
from pandera.errors import ParserError, SchemaError
1820

1921
UNSUPPORTED_DTYPE_CLS: Set[Any] = set()
2022

@@ -202,6 +204,141 @@ def test_pandas_datetimetz_dtype(timezone_aware, data, timezone):
202204
assert coerced_data.dt.tz == timezone
203205

204206

207+
def generate_test_cases_timezone_flexible() -> List[
    Tuple[
        List[Any],
        Optional[dt.tzinfo],
        bool,
        Optional[List[dt.datetime]],
        bool,
    ]
]:
    """
    Generate test parameter combinations for a fixed set of datetime lists.

    Returns:
        List of tuples:
        - List of input values (datetimes, plus one case with a bad type)
        - tz for DateTime constructor
        - coerce flag for Field constructor
        - expected output datetimes (None when an exception is expected)
        - raises flag (True if an exception is expected, False otherwise)
    """
    datetimes = [
        # multi tz and tz naive
        [
            dt.datetime(2023, 3, 1, 4, tzinfo=ZoneInfo("America/New_York")),
            dt.datetime(2023, 3, 1, 5, tzinfo=ZoneInfo("America/Los_Angeles")),
            dt.datetime(2023, 3, 1, 5),
        ],
        # multi tz
        [
            dt.datetime(2023, 3, 1, 4, tzinfo=ZoneInfo("America/New_York")),
            dt.datetime(2023, 3, 1, 5, tzinfo=ZoneInfo("America/Los_Angeles")),
        ],
        # tz naive
        [dt.datetime(2023, 3, 1, 4), dt.datetime(2023, 3, 1, 5)],
        # single tz
        [
            dt.datetime(2023, 3, 1, 4, tzinfo=ZoneInfo("America/New_York")),
            dt.datetime(2023, 3, 1, 5, tzinfo=ZoneInfo("America/New_York")),
        ],
    ]
    timezones = [
        None,
        ZoneInfo("America/Chicago"),
        dt.timezone(dt.timedelta(hours=2)),
    ]

    test_cases: List[Any] = []

    for datetime_list in datetimes:
        # Depends only on the input list, so compute it once per list.
        # The loop variable is deliberately not called `dt`, which would
        # shadow the datetime module alias.
        has_naive_datetime = any(
            moment.tzinfo is None for moment in datetime_list
        )
        for coerce in [True, False]:
            for tz in timezones:
                # Should raise an error when:
                # * coerce is False but there is a timezone-naive datetime
                # * coerce is True but tz is not set
                raises = (not coerce and has_naive_datetime) or (
                    coerce and tz is None
                )

                expected_output: Optional[List[dt.datetime]]
                if raises:
                    # No expected output since an exception will be raised
                    expected_output = None
                elif coerce:
                    # Naive datetimes are localized, aware ones converted.
                    localized = [
                        moment.replace(tzinfo=tz)
                        for moment in datetime_list
                        if moment.tzinfo is None
                    ]
                    converted = [
                        moment.astimezone(tz)
                        for moment in datetime_list
                        if moment.tzinfo is not None
                    ]
                    expected_output = localized + converted
                else:
                    # tz is ignored when not coercing
                    expected_output = datetime_list

                test_cases.append(
                    (datetime_list, tz, coerce, expected_output, raises)
                )

    # Final test cases with an improper element type; both must raise.
    bad_type_list = [
        dt.datetime(2023, 3, 1, 4, tzinfo=ZoneInfo("America/New_York")),
        "hello world",
    ]
    test_cases.extend(
        [
            (bad_type_list, None, True, None, True),
            (bad_type_list, None, False, None, True),
        ]
    )

    return test_cases
313+
314+
315+
@pytest.mark.parametrize(
    "examples, tz, coerce, expected_output, raises",
    generate_test_cases_timezone_flexible(),
)
def test_dt_timezone_flexible(examples, tz, coerce, expected_output, raises):
    """Validate a one-column model declared with timezone_flexible=True."""

    # A DataFrameModel is used instead of calling the dtype's coerce/check
    # directly because, with timezone_flexible, the dtype is set dynamically
    # based on the input data.
    class SimpleSchema(DataFrameModel):
        # pylint: disable=unexpected-keyword-arg,no-value-for-parameter
        datetime_column: pandas_engine.DateTime(
            timezone_flexible=True, tz=tz
        ) = Field(coerce=coerce)

    frame = pd.DataFrame({"datetime_column": examples})

    if raises:
        with pytest.raises(SchemaError):
            SimpleSchema.validate(frame)
        return

    validated = SimpleSchema.validate(frame)
    actual = sorted(validated["datetime_column"].tolist())
    assert actual == sorted(expected_output)
340+
341+
205342
@hypothesis.settings(max_examples=1000)
206343
@pytest.mark.parametrize("to_df", [True, False])
207344
@given(
@@ -225,7 +362,7 @@ def test_pandas_date_coerce_dtype(to_df, data):
225362
)
226363

227364
assert (
228-
coerced_data.applymap(lambda x: isinstance(x, date))
365+
coerced_data.applymap(lambda x: isinstance(x, dt.date))
229366
| coerced_data.isna()
230367
).all(axis=None)
231368
return
@@ -234,7 +371,8 @@ def test_pandas_date_coerce_dtype(to_df, data):
234371
coerced_data.isna().all() and coerced_data.dtype == "datetime64[ns]"
235372
)
236373
assert (
237-
coerced_data.map(lambda x: isinstance(x, date)) | coerced_data.isna()
374+
coerced_data.map(lambda x: isinstance(x, dt.date))
375+
| coerced_data.isna()
238376
).all()
239377

240378

@@ -246,8 +384,8 @@ def test_pandas_date_coerce_dtype(to_df, data):
246384
pyarrow.struct([("foo", pyarrow.int64()), ("bar", pyarrow.string())]),
247385
),
248386
(pd.Series([None, pd.NA, np.nan]), pyarrow.null),
249-
(pd.Series([None, date(1970, 1, 1)]), pyarrow.date32),
250-
(pd.Series([None, date(1970, 1, 1)]), pyarrow.date64),
387+
(pd.Series([None, dt.date(1970, 1, 1)]), pyarrow.date32),
388+
(pd.Series([None, dt.date(1970, 1, 1)]), pyarrow.date64),
251389
(pd.Series([1, 2]), pyarrow.duration("ns")),
252390
(pd.Series([1, 1e3, 1e6, 1e9, None]), pyarrow.time32("ms")),
253391
(pd.Series([1, 1e3, 1e6, 1e9, None]), pyarrow.time64("ns")),
@@ -292,8 +430,8 @@ def test_pandas_arrow_dtype(data, dtype):
292430
pyarrow.struct([("foo", pyarrow.string()), ("bar", pyarrow.int64())]),
293431
),
294432
(pd.Series(["a", "1"]), pyarrow.null),
295-
(pd.Series(["a", date(1970, 1, 1), "1970-01-01"]), pyarrow.date32),
296-
(pd.Series(["a", date(1970, 1, 1), "1970-01-01"]), pyarrow.date64),
433+
(pd.Series(["a", dt.date(1970, 1, 1), "1970-01-01"]), pyarrow.date32),
434+
(pd.Series(["a", dt.date(1970, 1, 1), "1970-01-01"]), pyarrow.date64),
297435
(pd.Series(["a"]), pyarrow.duration("ns")),
298436
(pd.Series(["a", "b"]), pyarrow.time32("ms")),
299437
(pd.Series(["a", "b"]), pyarrow.time64("ns")),

0 commit comments

Comments
 (0)