1
1
"""Test pandas engine."""
2
2
3
- from datetime import date
4
- from typing import Any , Set
3
+ import datetime as dt
4
+ from typing import Tuple , List , Optional , Any , Set
5
+ from zoneinfo import ZoneInfo
5
6
6
7
import hypothesis
7
8
import hypothesis .extra .pandas as pd_st
13
14
import pytz
14
15
from hypothesis import given
15
16
17
+ from pandera import Field , DataFrameModel
16
18
from pandera .engines import pandas_engine
17
- from pandera .errors import ParserError
19
+ from pandera .errors import ParserError , SchemaError
18
20
19
21
# Module-level set, created empty here.
# NOTE(review): the name suggests it collects dtype classes that are not
# supported by the installed pandas version; the code that fills or reads it
# is not visible in this excerpt -- confirm before relying on it.
UNSUPPORTED_DTYPE_CLS : Set [Any ] = set ()
20
22
@@ -202,6 +204,143 @@ def test_pandas_datetimetz_dtype(timezone_aware, data, timezone):
202
204
assert coerced_data .dt .tz == timezone
203
205
204
206
207
def generate_test_cases_timezone_flexible() -> List[
    Tuple[
        List[Any],
        Optional[dt.tzinfo],
        bool,
        Optional[List[dt.datetime]],
        bool,
    ]
]:
    """
    Generate the parameter combinations for ``test_dt_timezone_flexible``.

    Every fixed list of input datetimes is combined with every ``coerce``
    flag and every candidate ``tz``; two extra cases with a non-datetime
    value are appended at the end.

    Returns:
        List of tuples:
        - List of input values (datetimes, or a datetime and a string for
          the two trailing bad-type cases)
        - tz for DateTime constructor
        - coerce flag for Field constructor
        - expected output datetimes, or None when an exception is expected
        - raises flag (True if an exception is expected, False otherwise)
    """
    datetimes: List[List[dt.datetime]] = [
        # multi tz and tz naive
        [
            dt.datetime(2023, 3, 1, 4, tzinfo=ZoneInfo("America/New_York")),
            dt.datetime(2023, 3, 1, 5, tzinfo=ZoneInfo("America/Los_Angeles")),
            dt.datetime(2023, 3, 1, 5),
        ],
        # multi tz
        [
            dt.datetime(2023, 3, 1, 4, tzinfo=ZoneInfo("America/New_York")),
            dt.datetime(2023, 3, 1, 5, tzinfo=ZoneInfo("America/Los_Angeles")),
        ],
        # tz naive
        [dt.datetime(2023, 3, 1, 4), dt.datetime(2023, 3, 1, 5)],
        # single tz
        [
            dt.datetime(2023, 3, 1, 4, tzinfo=ZoneInfo("America/New_York")),
            dt.datetime(2023, 3, 1, 5, tzinfo=ZoneInfo("America/New_York")),
        ],
    ]
    # Candidate ``tz`` arguments; built once because they do not depend on
    # the loop variables.
    timezones: List[Optional[dt.tzinfo]] = [
        None,
        ZoneInfo("America/Chicago"),
        dt.timezone(dt.timedelta(hours=2)),
    ]

    test_cases: List[
        Tuple[
            List[Any],
            Optional[dt.tzinfo],
            bool,
            Optional[List[dt.datetime]],
            bool,
        ]
    ] = []
    expected_output: Optional[List[dt.datetime]]

    for datetime_list in datetimes:
        # Loop variables are named ``value`` so they do not shadow the
        # ``dt`` module alias used throughout this function.
        has_naive_datetime = any(
            value.tzinfo is None for value in datetime_list
        )
        for coerce in [True, False]:
            for tz in timezones:
                # Should raise error when:
                # * coerce is False but there is a timezone-naive datetime
                # * coerce is True but tz is not set
                raises = (not coerce and has_naive_datetime) or (
                    coerce and tz is None
                )

                if raises:
                    # No expected output since an exception will be raised
                    expected_output = None
                elif coerce:
                    # ``tz`` is never None here: coerce with tz=None is a
                    # raising case handled above, so no default is needed.
                    # Naive values are localized, aware values converted;
                    # naive results come first (the consumer sorts anyway).
                    expected_output_naive = [
                        value.replace(tzinfo=tz)
                        for value in datetime_list
                        if value.tzinfo is None
                    ]
                    expected_output_aware = [
                        value.astimezone(tz)
                        for value in datetime_list
                        if value.tzinfo is not None
                    ]
                    expected_output = (
                        expected_output_naive + expected_output_aware
                    )
                else:
                    # without coercion tz is ignored and data is unchanged
                    expected_output = datetime_list

                test_cases.append(
                    (datetime_list, tz, coerce, expected_output, raises)
                )

    # final test cases with an improper (non-datetime) value: these are
    # expected to raise regardless of the coerce flag
    bad_examples: List[Any] = [
        dt.datetime(2023, 3, 1, 4, tzinfo=ZoneInfo("America/New_York")),
        "hello world",
    ]
    test_cases.extend(
        [
            (bad_examples, None, True, None, True),
            (bad_examples, None, False, None, True),
        ]
    )

    return test_cases
315
+
316
+
317
@pytest.mark.parametrize(
    "examples, tz, coerce, expected_output, raises",
    generate_test_cases_timezone_flexible(),
)
def test_dt_timezone_flexible(examples, tz, coerce, expected_output, raises):
    """Check validation of a ``timezone_flexible`` DateTime column.

    Each parametrized case supplies the column values, the ``tz`` and
    ``coerce`` settings, and either the expected validated values or a flag
    saying that validation must fail with ``SchemaError``.
    """

    # A DataFrameModel is used instead of calling dtype coerce/validate
    # directly because, with timezone_flexible, the dtype is set dynamically
    # based on the input data.
    class SimpleSchema(DataFrameModel):
        # pylint: disable=unexpected-keyword-arg,no-value-for-parameter
        datetime_column: pandas_engine.DateTime(
            timezone_flexible=True, tz=tz
        ) = Field(coerce=coerce)

    frame = pd.DataFrame({"datetime_column": examples})

    # Failure path first: nothing else to check once the error is raised.
    if raises:
        with pytest.raises(SchemaError):
            SimpleSchema.validate(frame)
        return

    validated = SimpleSchema.validate(frame)
    actual_values = sorted(validated["datetime_column"].tolist())
    # Order-insensitive comparison: both sides are sorted before comparing.
    assert actual_values == sorted(expected_output)
342
+
343
+
205
344
@hypothesis .settings (max_examples = 1000 )
206
345
@pytest .mark .parametrize ("to_df" , [True , False ])
207
346
@given (
@@ -225,7 +364,7 @@ def test_pandas_date_coerce_dtype(to_df, data):
225
364
)
226
365
227
366
assert (
228
- coerced_data .applymap (lambda x : isinstance (x , date ))
367
+ coerced_data .applymap (lambda x : isinstance (x , dt . date ))
229
368
| coerced_data .isna ()
230
369
).all (axis = None )
231
370
return
@@ -234,7 +373,8 @@ def test_pandas_date_coerce_dtype(to_df, data):
234
373
coerced_data .isna ().all () and coerced_data .dtype == "datetime64[ns]"
235
374
)
236
375
assert (
237
- coerced_data .map (lambda x : isinstance (x , date )) | coerced_data .isna ()
376
+ coerced_data .map (lambda x : isinstance (x , dt .date ))
377
+ | coerced_data .isna ()
238
378
).all ()
239
379
240
380
@@ -246,8 +386,8 @@ def test_pandas_date_coerce_dtype(to_df, data):
246
386
pyarrow .struct ([("foo" , pyarrow .int64 ()), ("bar" , pyarrow .string ())]),
247
387
),
248
388
(pd .Series ([None , pd .NA , np .nan ]), pyarrow .null ),
249
- (pd .Series ([None , date (1970 , 1 , 1 )]), pyarrow .date32 ),
250
- (pd .Series ([None , date (1970 , 1 , 1 )]), pyarrow .date64 ),
389
+ (pd .Series ([None , dt . date (1970 , 1 , 1 )]), pyarrow .date32 ),
390
+ (pd .Series ([None , dt . date (1970 , 1 , 1 )]), pyarrow .date64 ),
251
391
(pd .Series ([1 , 2 ]), pyarrow .duration ("ns" )),
252
392
(pd .Series ([1 , 1e3 , 1e6 , 1e9 , None ]), pyarrow .time32 ("ms" )),
253
393
(pd .Series ([1 , 1e3 , 1e6 , 1e9 , None ]), pyarrow .time64 ("ns" )),
@@ -292,8 +432,8 @@ def test_pandas_arrow_dtype(data, dtype):
292
432
pyarrow .struct ([("foo" , pyarrow .string ()), ("bar" , pyarrow .int64 ())]),
293
433
),
294
434
(pd .Series (["a" , "1" ]), pyarrow .null ),
295
- (pd .Series (["a" , date (1970 , 1 , 1 ), "1970-01-01" ]), pyarrow .date32 ),
296
- (pd .Series (["a" , date (1970 , 1 , 1 ), "1970-01-01" ]), pyarrow .date64 ),
435
+ (pd .Series (["a" , dt . date (1970 , 1 , 1 ), "1970-01-01" ]), pyarrow .date32 ),
436
+ (pd .Series (["a" , dt . date (1970 , 1 , 1 ), "1970-01-01" ]), pyarrow .date64 ),
297
437
(pd .Series (["a" ]), pyarrow .duration ("ns" )),
298
438
(pd .Series (["a" , "b" ]), pyarrow .time32 ("ms" )),
299
439
(pd .Series (["a" , "b" ]), pyarrow .time64 ("ns" )),
0 commit comments