Skip to content

Commit 364c57b

Browse files
committed
Enhancement: Add support for timezone-flexible DateTime (unionai-oss#1352)
Signed-off-by: Max Raphael <[email protected]>
1 parent 33fe68b commit 364c57b

File tree

2 files changed

+172
-9
lines changed

2 files changed

+172
-9
lines changed

pandera/engines/pandas_engine.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -860,6 +860,13 @@ class DateTime(_BaseDateTime, dtypes.Timestamp):
860860
tz: Optional[datetime.tzinfo] = None
861861
"""The timezone."""
862862

863+
timezone_flexible: bool = False
864+
"""
865+
A flag indicating whether the datetime data should be handled flexibly with respect to timezones.
866+
When set to True, the function will ignore 'tz' and allow datetimes with any timezone(s). If coerce is set to True,
867+
the function can accept timezone-naive datetimes, and will convert all datetimes to the specified tz (or 'UTC').
868+
"""
869+
863870
to_datetime_kwargs: Dict[str, Any] = dataclasses.field(
864871
default_factory=dict, compare=False, repr=False
865872
)
@@ -936,14 +943,65 @@ def from_parametrized_dtype(cls, pd_dtype: pd.DatetimeTZDtype):
936943
return cls(unit=pd_dtype.unit, tz=pd_dtype.tz) # type: ignore
937944

938945
def coerce(self, data_container: PandasObject) -> PandasObject:
946+
if self.timezone_flexible:
947+
data_container = self._prepare_coerce_timezone_flexible(data_container=data_container)
939948
return self._coerce(data_container, pandas_dtype=self.type)
940949

950+
def _prepare_coerce_timezone_flexible(self, data_container: PandasObject) -> PandasObject:
951+
# If there is a single timezone, define the type as a timezone-aware DatetimeTZDtype
952+
if isinstance(data_container.dtype, pd.DatetimeTZDtype):
953+
tz = self.tz if self.tz else data_container.dtype.tz
954+
unit = self.unit if self.unit else data_container.dtype.unit
955+
type_ = pd.DatetimeTZDtype(unit, tz)
956+
object.__setattr__(self, "tz", tz)
957+
object.__setattr__(self, "type", type_)
958+
# If there are multiple timezones, convert them to the specified tz (default 'UTC') and set the type accordingly
959+
elif all(isinstance(x, datetime.datetime) for x in data_container):
960+
container_type = type(data_container)
961+
tz = self.tz if self.tz else 'UTC'
962+
unit = self.unit if self.unit else data_container.dtype.unit
963+
data_container = container_type(
964+
[pd.Timestamp(ts).tz_convert(tz) if pd.Timestamp(ts).tzinfo else pd.Timestamp(ts).tz_localize(tz)
965+
for ts in data_container]
966+
)
967+
type_ = pd.DatetimeTZDtype(unit, tz)
968+
object.__setattr__(self, "tz", tz)
969+
object.__setattr__(self, "type", type_)
970+
else:
971+
# Prepare to raise exception, adding type strictly for the check_dtype error message
972+
object.__setattr__(self, "type", "datetime64[ns, <timezone>]")
973+
return data_container
974+
941975
def coerce_value(self, value: Any) -> Any:
942976
"""Coerce an value to specified datatime type."""
943977
return self._get_to_datetime_fn(value)(
944978
value, **self.to_datetime_kwargs
945979
)
946980

981+
def check(
982+
self,
983+
pandera_dtype: dtypes.DataType,
984+
data_container: Optional[PandasObject] = None,
985+
) -> Union[bool, Iterable[bool]]:
986+
if self.timezone_flexible:
987+
self._prepare_check_timezone_flexible(pandera_dtype=pandera_dtype, data_container=data_container)
988+
return super().check(pandera_dtype, data_container)
989+
990+
def _prepare_check_timezone_flexible(
991+
self, pandera_dtype: dtypes.DataType, data_container: Optional[PandasObject]
992+
) -> None:
993+
# If there is a single timezone, define the type as a timezone-aware DatetimeTZDtype
994+
if isinstance(pandera_dtype, DateTime) and pandera_dtype.tz is not None:
995+
type_ = pd.DatetimeTZDtype(self.unit, pandera_dtype.tz)
996+
object.__setattr__(self, "tz", pandera_dtype.tz)
997+
object.__setattr__(self, "type", type_)
998+
# If the data has a mix of timezones, pandas defines the dtype as 'object
999+
elif all(isinstance(x, datetime.datetime) and x.tzinfo is not None for x in data_container):
1000+
object.__setattr__(self, "type", np.dtype('O'))
1001+
else:
1002+
# Prepare to raise exception, adding type strictly for the check_dtype error message
1003+
object.__setattr__(self, "type", "datetime64[ns, <timezone>]")
1004+
9471005
def __str__(self) -> str:
9481006
if self.type == np.dtype("datetime64[ns]"):
9491007
return "datetime64[ns]"

tests/core/test_pandas_engine.py

Lines changed: 114 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
"""Test pandas engine."""
22

3-
from datetime import date
4-
from typing import Any, Set
3+
import datetime as dt
4+
from typing import Tuple, List, Optional, Any, Set
5+
from zoneinfo import ZoneInfo
56

67
import hypothesis
78
import hypothesis.extra.pandas as pd_st
@@ -13,8 +14,9 @@
1314
import pytz
1415
from hypothesis import given
1516

17+
from pandera import Field, DataFrameModel
1618
from pandera.engines import pandas_engine
17-
from pandera.errors import ParserError
19+
from pandera.errors import ParserError, SchemaError
1820

1921
UNSUPPORTED_DTYPE_CLS: Set[Any] = set()
2022

@@ -202,6 +204,109 @@ def test_pandas_datetimetz_dtype(timezone_aware, data, timezone):
202204
assert coerced_data.dt.tz == timezone
203205

204206

207+
def generate_test_cases_timezone_flexible() -> List[
208+
Tuple[List[dt.datetime], Optional[dt.tzinfo], bool, List[dt.datetime], bool]
209+
]:
210+
"""
211+
Generate test parameter combinations for a given list of datetime lists.
212+
213+
Returns:
214+
List of tuples:
215+
- List of input datetimes
216+
- tz for DateTime constructor
217+
- coerce flag for Field constructor
218+
- expected output datetimes
219+
- raises flag (True if an exception is expected, False otherwise)
220+
"""
221+
datetimes = [
222+
# multi tz and tz naive
223+
[
224+
dt.datetime(2023, 3, 1, 4, tzinfo=ZoneInfo('America/New_York')),
225+
dt.datetime(2023, 3, 1, 5, tzinfo=ZoneInfo('America/Los_Angeles')),
226+
dt.datetime(2023, 3, 1, 5)
227+
],
228+
# multiz tz
229+
[
230+
dt.datetime(2023, 3, 1, 4, tzinfo=ZoneInfo('America/New_York')),
231+
dt.datetime(2023, 3, 1, 5, tzinfo=ZoneInfo('America/Los_Angeles'))
232+
],
233+
# tz naive
234+
[
235+
dt.datetime(2023, 3, 1, 4),
236+
dt.datetime(2023, 3, 1, 5)
237+
],
238+
# single tz
239+
[
240+
dt.datetime(2023, 3, 1, 4, tzinfo=ZoneInfo('America/New_York')),
241+
dt.datetime(2023, 3, 1, 5, tzinfo=ZoneInfo('America/New_York'))
242+
]
243+
]
244+
245+
test_cases = []
246+
247+
for datetime_list in datetimes:
248+
for coerce in [True, False]:
249+
for tz in [None, ZoneInfo("America/Chicago"), dt.timezone(dt.timedelta(hours=2))]:
250+
# Determine if the test should raise an exception
251+
has_naive_datetime = any([dt.tzinfo is None for dt in datetime_list])
252+
raises = has_naive_datetime and not coerce
253+
254+
# Generate expected output
255+
if raises:
256+
expected_output = None # No expected output since an exception will be raised
257+
else:
258+
if coerce:
259+
# localize / convert the input datetimes to the specified tz or 'UTC' (default)
260+
use_tz = tz if tz else ZoneInfo("UTC")
261+
expected_output_naive = [
262+
dt.replace(tzinfo=use_tz) for dt in datetime_list if dt.tzinfo is None
263+
]
264+
expected_output_aware = [
265+
dt.astimezone(use_tz) for dt in datetime_list if dt.tzinfo is not None
266+
]
267+
expected_output = expected_output_naive + expected_output_aware
268+
else:
269+
# ignore tz
270+
expected_output = datetime_list
271+
272+
test_case = (datetime_list, tz, coerce, expected_output, raises)
273+
test_cases.append(test_case)
274+
275+
# define final test cases with improper type
276+
datetime_list = [dt.datetime(2023, 3, 1, 4, tzinfo=ZoneInfo('America/New_York')), "hello world"]
277+
tz = None
278+
expected_output = None
279+
raises = True
280+
281+
bad_type_coerce = (datetime_list, tz, True, expected_output, raises)
282+
bad_type_no_coerce = (datetime_list, tz, False, expected_output, raises)
283+
test_cases.extend([bad_type_coerce, bad_type_no_coerce])
284+
285+
return test_cases
286+
287+
288+
@pytest.mark.parametrize(
289+
"examples, tz, coerce, expected_output, raises",
290+
generate_test_cases_timezone_flexible()
291+
)
292+
def test_dt_timezone_flexible(examples, tz, coerce, expected_output, raises):
293+
"""Test that timezone_flexible works as expected"""
294+
295+
# Testing using a pandera DataFrameModel rather than directly calling dtype coerce or validate because with
296+
# timezone_flexible, dtype is set dynamically based on the input data
297+
class SimpleSchema(DataFrameModel):
298+
datetime_column: pandas_engine.DateTime(timezone_flexible=True, tz=tz) = Field(coerce=coerce)
299+
300+
data = pd.DataFrame({'datetime_column': examples})
301+
302+
if raises:
303+
with pytest.raises(SchemaError):
304+
SimpleSchema.validate(data)
305+
else:
306+
validated_df = SimpleSchema.validate(data)
307+
assert sorted(validated_df['datetime_column'].tolist()) == sorted(expected_output)
308+
309+
205310
@hypothesis.settings(max_examples=1000)
206311
@pytest.mark.parametrize("to_df", [True, False])
207312
@given(
@@ -225,7 +330,7 @@ def test_pandas_date_coerce_dtype(to_df, data):
225330
)
226331

227332
assert (
228-
coerced_data.applymap(lambda x: isinstance(x, date))
333+
coerced_data.applymap(lambda x: isinstance(x, dt.date))
229334
| coerced_data.isna()
230335
).all(axis=None)
231336
return
@@ -234,7 +339,7 @@ def test_pandas_date_coerce_dtype(to_df, data):
234339
coerced_data.isna().all() and coerced_data.dtype == "datetime64[ns]"
235340
)
236341
assert (
237-
coerced_data.map(lambda x: isinstance(x, date)) | coerced_data.isna()
342+
coerced_data.map(lambda x: isinstance(x, dt.date)) | coerced_data.isna()
238343
).all()
239344

240345

@@ -246,8 +351,8 @@ def test_pandas_date_coerce_dtype(to_df, data):
246351
pyarrow.struct([("foo", pyarrow.int64()), ("bar", pyarrow.string())]),
247352
),
248353
(pd.Series([None, pd.NA, np.nan]), pyarrow.null),
249-
(pd.Series([None, date(1970, 1, 1)]), pyarrow.date32),
250-
(pd.Series([None, date(1970, 1, 1)]), pyarrow.date64),
354+
(pd.Series([None, dt.date(1970, 1, 1)]), pyarrow.date32),
355+
(pd.Series([None, dt.date(1970, 1, 1)]), pyarrow.date64),
251356
(pd.Series([1, 2]), pyarrow.duration("ns")),
252357
(pd.Series([1, 1e3, 1e6, 1e9, None]), pyarrow.time32("ms")),
253358
(pd.Series([1, 1e3, 1e6, 1e9, None]), pyarrow.time64("ns")),
@@ -292,8 +397,8 @@ def test_pandas_arrow_dtype(data, dtype):
292397
pyarrow.struct([("foo", pyarrow.string()), ("bar", pyarrow.int64())]),
293398
),
294399
(pd.Series(["a", "1"]), pyarrow.null),
295-
(pd.Series(["a", date(1970, 1, 1), "1970-01-01"]), pyarrow.date32),
296-
(pd.Series(["a", date(1970, 1, 1), "1970-01-01"]), pyarrow.date64),
400+
(pd.Series(["a", dt.date(1970, 1, 1), "1970-01-01"]), pyarrow.date32),
401+
(pd.Series(["a", dt.date(1970, 1, 1), "1970-01-01"]), pyarrow.date64),
297402
(pd.Series(["a"]), pyarrow.duration("ns")),
298403
(pd.Series(["a", "b"]), pyarrow.time32("ms")),
299404
(pd.Series(["a", "b"]), pyarrow.time64("ns")),

0 commit comments

Comments
 (0)