1
1
"""Test pandas engine."""
2
2
3
- from datetime import date
4
- from typing import Any, Set
3
+ import datetime as dt
4
+ from typing import Tuple, List, Optional, Any, Set
5
+ from zoneinfo import ZoneInfo
5
6
6
7
import hypothesis
7
8
import hypothesis.extra.pandas as pd_st
13
14
import pytz
14
15
from hypothesis import given
15
16
17
+ from pandera import Field, DataFrameModel, errors
16
18
from pandera.engines import pandas_engine
17
- from pandera.errors import ParserError
19
+ from pandera.errors import ParserError, SchemaError
18
20
19
21
UNSUPPORTED_DTYPE_CLS: Set[Any] = set()
20
22
@@ -202,6 +204,141 @@ def test_pandas_datetimetz_dtype(timezone_aware, data, timezone):
202
204
assert coerced_data.dt.tz == timezone
203
205
204
206
207
+ def generate_test_cases_timezone_flexible() -> List[
208
+ Tuple[
209
+ List[dt.datetime],
210
+ Optional[dt.tzinfo],
211
+ bool,
212
+ List[dt.datetime],
213
+ bool,
214
+ ]
215
+ ]:
216
+ """
217
+ Generate test parameter combinations for a given list of datetime lists.
218
+
219
+ Returns:
220
+ List of tuples:
221
+ - List of input datetimes
222
+ - tz for DateTime constructor
223
+ - coerce flag for Field constructor
224
+ - expected output datetimes
225
+ - raises flag (True if an exception is expected, False otherwise)
226
+ """
227
+ datetimes = [
228
+ # multi tz and tz naive
229
+ [
230
+ dt.datetime(2023, 3, 1, 4, tzinfo=ZoneInfo("America/New_York")),
231
+ dt.datetime(2023, 3, 1, 5, tzinfo=ZoneInfo("America/Los_Angeles")),
232
+ dt.datetime(2023, 3, 1, 5),
233
+ ],
234
+ # multiz tz
235
+ [
236
+ dt.datetime(2023, 3, 1, 4, tzinfo=ZoneInfo("America/New_York")),
237
+ dt.datetime(2023, 3, 1, 5, tzinfo=ZoneInfo("America/Los_Angeles")),
238
+ ],
239
+ # tz naive
240
+ [dt.datetime(2023, 3, 1, 4), dt.datetime(2023, 3, 1, 5)],
241
+ # single tz
242
+ [
243
+ dt.datetime(2023, 3, 1, 4, tzinfo=ZoneInfo("America/New_York")),
244
+ dt.datetime(2023, 3, 1, 5, tzinfo=ZoneInfo("America/New_York")),
245
+ ],
246
+ ]
247
+
248
+ test_cases = []
249
+
250
+ for datetime_list in datetimes:
251
+ for coerce in [True, False]:
252
+ for tz in [
253
+ None,
254
+ ZoneInfo("America/Chicago"),
255
+ dt.timezone(dt.timedelta(hours=2)),
256
+ ]:
257
+ # Determine if the test should raise an exception
258
+ # Should raise error when:
259
+ # * coerce is False but there is a timezone-naive datetime
260
+ # * coerce is True but tz is not set
261
+ has_naive_datetime = any(
262
+ dt.tzinfo is None for dt in datetime_list
263
+ )
264
+ raises = (not coerce and has_naive_datetime) or (
265
+ coerce and tz is None
266
+ )
267
+
268
+ # Generate expected output
269
+ if raises:
270
+ expected_output = None # No expected output since an exception will be raised
271
+ else:
272
+ if coerce:
273
+ expected_output_naive = [
274
+ dt.replace(tzinfo=tz)
275
+ for dt in datetime_list
276
+ if dt.tzinfo is None
277
+ ]
278
+ expected_output_aware = [
279
+ dt.astimezone(tz)
280
+ for dt in datetime_list
281
+ if dt.tzinfo is not None
282
+ ]
283
+ expected_output = (
284
+ expected_output_naive + expected_output_aware
285
+ )
286
+ else:
287
+ # ignore tz
288
+ expected_output = datetime_list
289
+
290
+ test_case = (
291
+ datetime_list,
292
+ tz,
293
+ coerce,
294
+ expected_output,
295
+ raises,
296
+ )
297
+ test_cases.append(test_case)
298
+
299
+ # define final test cases with improper type
300
+ datetime_list = [
301
+ dt.datetime(2023, 3, 1, 4, tzinfo=ZoneInfo("America/New_York")),
302
+ "hello world",
303
+ ]
304
+ tz = None
305
+ expected_output = None
306
+ raises = True
307
+
308
+ bad_type_coerce = (datetime_list, tz, True, expected_output, raises)
309
+ bad_type_no_coerce = (datetime_list, tz, False, expected_output, raises)
310
+ test_cases.extend([bad_type_coerce, bad_type_no_coerce]) # type: ignore
311
+
312
+ return test_cases # type: ignore
313
+
314
+
315
+ @pytest.mark.parametrize(
316
+ "examples, tz, coerce, expected_output, raises",
317
+ generate_test_cases_timezone_flexible(),
318
+ )
319
+ def test_dt_timezone_flexible(examples, tz, coerce, expected_output, raises):
320
+ """Test that timezone_flexible works as expected"""
321
+
322
+ # Testing using a pandera DataFrameModel rather than directly calling dtype coerce or validate because with
323
+ # timezone_flexible, dtype is set dynamically based on the input data
324
+ class SimpleSchema(DataFrameModel):
325
+ # pylint: disable=unexpected-keyword-arg,no-value-for-parameter
326
+ datetime_column: pandas_engine.DateTime(
327
+ timezone_flexible=True, tz=tz
328
+ ) = Field(coerce=coerce)
329
+
330
+ data = pd.DataFrame({"datetime_column": examples})
331
+
332
+ if raises:
333
+ with pytest.raises((SchemaError, errors.ParserError)):
334
+ SimpleSchema.validate(data)
335
+ else:
336
+ validated_df = SimpleSchema.validate(data)
337
+ assert sorted(validated_df["datetime_column"].tolist()) == sorted(
338
+ expected_output
339
+ )
340
+
341
+
205
342
@hypothesis.settings(max_examples=1000)
206
343
@pytest.mark.parametrize("to_df", [True, False])
207
344
@given(
@@ -225,7 +362,7 @@ def test_pandas_date_coerce_dtype(to_df, data):
225
362
)
226
363
227
364
assert (
228
- coerced_data.applymap(lambda x: isinstance(x, date))
365
+ coerced_data.applymap(lambda x: isinstance(x, dt. date))
229
366
| coerced_data.isna()
230
367
).all(axis=None)
231
368
return
@@ -234,7 +371,8 @@ def test_pandas_date_coerce_dtype(to_df, data):
234
371
coerced_data.isna().all() and coerced_data.dtype == "datetime64[ns]"
235
372
)
236
373
assert (
237
- coerced_data.map(lambda x: isinstance(x, date)) | coerced_data.isna()
374
+ coerced_data.map(lambda x: isinstance(x, dt.date))
375
+ | coerced_data.isna()
238
376
).all()
239
377
240
378
@@ -246,8 +384,8 @@ def test_pandas_date_coerce_dtype(to_df, data):
246
384
pyarrow.struct([("foo", pyarrow.int64()), ("bar", pyarrow.string())]),
247
385
),
248
386
(pd.Series([None, pd.NA, np.nan]), pyarrow.null),
249
- (pd.Series([None, date(1970, 1, 1)]), pyarrow.date32),
250
- (pd.Series([None, date(1970, 1, 1)]), pyarrow.date64),
387
+ (pd.Series([None, dt. date(1970, 1, 1)]), pyarrow.date32),
388
+ (pd.Series([None, dt. date(1970, 1, 1)]), pyarrow.date64),
251
389
(pd.Series([1, 2]), pyarrow.duration("ns")),
252
390
(pd.Series([1, 1e3, 1e6, 1e9, None]), pyarrow.time32("ms")),
253
391
(pd.Series([1, 1e3, 1e6, 1e9, None]), pyarrow.time64("ns")),
@@ -292,8 +430,8 @@ def test_pandas_arrow_dtype(data, dtype):
292
430
pyarrow.struct([("foo", pyarrow.string()), ("bar", pyarrow.int64())]),
293
431
),
294
432
(pd.Series(["a", "1"]), pyarrow.null),
295
- (pd.Series(["a", date(1970, 1, 1), "1970-01-01"]), pyarrow.date32),
296
- (pd.Series(["a", date(1970, 1, 1), "1970-01-01"]), pyarrow.date64),
433
+ (pd.Series(["a", dt. date(1970, 1, 1), "1970-01-01"]), pyarrow.date32),
434
+ (pd.Series(["a", dt. date(1970, 1, 1), "1970-01-01"]), pyarrow.date64),
297
435
(pd.Series(["a"]), pyarrow.duration("ns")),
298
436
(pd.Series(["a", "b"]), pyarrow.time32("ms")),
299
437
(pd.Series(["a", "b"]), pyarrow.time64("ns")),
0 commit comments