Skip to content

Commit 47ccd9f

Browse files
authored
make finding coerce failure cases faster (#792)
* make finding coerce failure cases faster * fix tests * remove unneeded import * fix tests, coverage
1 parent 9a43c14 commit 47ccd9f

File tree

5 files changed

+114
-43
lines changed

5 files changed

+114
-43
lines changed

pandera/dtypes.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ def coerce(self, data_container: Any):
3131
"""Coerce data container to the data type."""
3232
raise NotImplementedError()
3333

34+
def coerce_value(self, value: Any):
35+
"""Coerce an value to a particular type."""
36+
raise NotImplementedError()
37+
3438
def try_coerce(self, data_container: Any):
3539
"""Coerce data container to the data type,
3640
raises a `~pandera.errors.ParserError` if the coercion fails

pandera/engines/numpy_engine.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ def coerce(self, data_container: PandasObject) -> PandasObject:
5151
coerced.__str__()
5252
return coerced
5353

54+
def coerce_value(self, value: Any) -> Any:
55+
"""Coerce an value to a particular type."""
56+
return self.type.type(value)
57+
5458
def try_coerce(
5559
self, data_container: Union[PandasObject, np.ndarray]
5660
) -> Union[PandasObject, np.ndarray]:

pandera/engines/pandas_engine.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,10 @@ def coerce(self, data_container: PandasObject) -> PandasObject:
8181
coerced.__str__()
8282
return coerced
8383

84+
def coerce_value(self, value: Any) -> Any:
85+
"""Coerce an value to a particular type."""
86+
return self.type.type(value)
87+
8488
def try_coerce(self, data_container: PandasObject) -> PandasObject:
8589
try:
8690
return self.coerce(data_container)
@@ -197,6 +201,15 @@ class BOOL(DataType, dtypes.Bool):
197201
"""Semantic representation of a :class:`pandas.BooleanDtype`."""
198202

199203
type = pd.BooleanDtype()
204+
_bool_like = frozenset({True, False})
205+
206+
def coerce_value(self, value: Any) -> Any:
207+
"""Coerce an value to specified datatime type."""
208+
if value not in self._bool_like:
209+
raise TypeError(
210+
f"value {value} cannot be coerced to type {self.type}"
211+
)
212+
return super().coerce_value(value)
200213

201214

202215
###############################################################################
@@ -416,6 +429,26 @@ def __init__( # pylint:disable=super-init-not-called
416429
pd.CategoricalDtype(self.categories, self.ordered),
417430
)
418431

432+
def coerce(self, data_container: PandasObject) -> PandasObject:
433+
"""Pure coerce without catching exceptions."""
434+
coerced = data_container.astype(self.type)
435+
if (coerced.isna() & data_container.notna()).any(axis=None):
436+
raise TypeError(
437+
f"Data container cannot be coerced to type {self.type}"
438+
)
439+
if type(data_container).__module__.startswith("modin.pandas"):
440+
# NOTE: this is a hack to enable catching of errors in modin
441+
coerced.__str__()
442+
return coerced
443+
444+
def coerce_value(self, value: Any) -> Any:
445+
"""Coerce an value to a particular type."""
446+
if value not in self.categories: # type: ignore
447+
raise TypeError(
448+
f"value {value} cannot be coerced to type {self.type}"
449+
)
450+
return value
451+
419452
@classmethod
420453
def from_parametrized_dtype(
421454
cls, cat: Union[dtypes.Category, pd.CategoricalDtype]
@@ -589,6 +622,12 @@ def _to_datetime(col: pd.Series) -> pd.Series:
589622
return data_container.transform(_to_datetime)
590623
return _to_datetime(data_container)
591624

625+
def coerce_value(self, value: Any) -> Any:
626+
"""Coerce an value to specified datatime type."""
627+
if value is pd.NaT:
628+
return value
629+
return super().coerce_value(value)
630+
592631
@classmethod
593632
def from_parametrized_dtype(cls, pd_dtype: pd.DatetimeTZDtype):
594633
"""Convert a :class:`pandas.DatetimeTZDtype` to

pandera/engines/utils.py

Lines changed: 3 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Engine module utilities."""
22

3-
import itertools
43
from typing import Any, Union
54

65
import numpy as np
@@ -23,53 +22,14 @@ def numpy_pandas_coercible(series: pd.Series, type_: Any) -> pd.Series:
2322

2423
data_type = pandas_engine.Engine.dtype(type_)
2524

26-
def _bisect(series):
27-
assert (
28-
series.shape[0] >= 2
29-
), "cannot bisect a pandas Series of length < 2"
30-
bisect_index = series.shape[0] // 2
31-
return [series.iloc[:bisect_index], series.iloc[bisect_index:]]
32-
33-
def _coercible(series):
25+
def _coercible(x):
3426
try:
35-
data_type.coerce(series)
27+
data_type.coerce_value(x)
3628
return True
3729
except Exception: # pylint:disable=broad-except
3830
return False
3931

40-
search_list = [series] if series.size == 1 else _bisect(series)
41-
failure_index = []
42-
while search_list:
43-
candidates = []
44-
for _series in search_list:
45-
if _series.shape[0] == 1 and not _coercible(_series):
46-
# if series is reduced to a single value and isn't coercible,
47-
# keep track of its index value.
48-
failure_index.append(_series.index.item())
49-
elif not _coercible(_series):
50-
# if the series length > 1, add it to the candidates list
51-
# to be further bisected
52-
candidates.append(_series)
53-
54-
# the new search list is a flat list of bisected series views.
55-
search_list = list(
56-
itertools.chain.from_iterable([_bisect(c) for c in candidates])
57-
)
58-
59-
# NOTE: this is a hack to support koalas. This needs to be thoroughly
60-
# tested, right now koalas returns NA when a dtype value can't be coerced
61-
# into the target dtype.
62-
if type(series).__module__.startswith(
63-
"databricks.koalas"
64-
): # pragma: no cover
65-
out = type(series)(
66-
series.index.isin(failure_index).to_series().to_numpy(), # type: ignore[union-attr]
67-
index=series.index.values.to_numpy(),
68-
name=series.name,
69-
)
70-
out.index.name = series.index.name
71-
return out
72-
return pd.Series(~series.index.isin(failure_index), index=series.index)
32+
return series.map(_coercible)
7333

7434

7535
def numpy_pandas_coerce_failure_cases(

tests/core/test_pandas_engine.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
"""Test numpy engine."""
22

3+
import hypothesis.strategies as st
34
import pandas as pd
45
import pytest
6+
from hypothesis import given
57

68
from pandera.engines import pandas_engine
79
from pandera.errors import ParserError
@@ -42,3 +44,65 @@ def test_pandas_data_type_coerce(data_type):
4244
data_type().try_coerce(pd.Series(["1", "2", "a"]))
4345
except ParserError as exc:
4446
assert exc.failure_cases.shape[0] > 0
47+
48+
49+
CATEGORIES = ["A", "B", "C"]
50+
51+
52+
@given(st.lists(st.sampled_from(CATEGORIES), min_size=5))
53+
def test_pandas_category_dtype(data):
54+
"""Test pandas_engine.Category correctly coerces valid categorical data."""
55+
data = pd.Series(data)
56+
dtype = pandas_engine.Category(CATEGORIES)
57+
coerced_data = dtype.coerce(data)
58+
assert dtype.check(coerced_data.dtype)
59+
60+
for _, value in data.iteritems():
61+
coerced_value = dtype.coerce_value(value)
62+
assert coerced_value in CATEGORIES
63+
64+
65+
@given(st.lists(st.sampled_from(["X", "Y", "Z"]), min_size=5))
66+
def test_pandas_category_dtype_error(data):
67+
"""Test pandas_engine.Category raises TypeErrors on invalid data."""
68+
data = pd.Series(data)
69+
dtype = pandas_engine.Category(CATEGORIES)
70+
71+
with pytest.raises(TypeError):
72+
dtype.coerce(data)
73+
74+
for _, value in data.iteritems():
75+
with pytest.raises(TypeError):
76+
dtype.coerce_value(value)
77+
78+
79+
@given(st.lists(st.sampled_from([1, 0, 1.0, 0.0, True, False]), min_size=5))
80+
def test_pandas_boolean_native_type(data):
81+
"""Test native pandas bool type correctly coerces valid bool-like data."""
82+
data = pd.Series(data)
83+
dtype = pandas_engine.Engine.dtype("boolean")
84+
85+
# the BooleanDtype can't handle Series of non-boolean, mixed dtypes
86+
if data.dtype == "object":
87+
with pytest.raises(TypeError):
88+
dtype.coerce(data)
89+
else:
90+
coerced_data = dtype.coerce(data)
91+
assert dtype.check(coerced_data.dtype)
92+
93+
for _, value in data.iteritems():
94+
dtype.coerce_value(value)
95+
96+
97+
@given(st.lists(st.sampled_from(["A", "True", "False", 5, -1]), min_size=5))
98+
def test_pandas_boolean_native_type_error(data):
99+
"""Test native pandas bool type raises TypeErrors on non-bool-like data."""
100+
data = pd.Series(data)
101+
dtype = pandas_engine.Engine.dtype("boolean")
102+
103+
with pytest.raises(TypeError):
104+
dtype.coerce(data)
105+
106+
for _, value in data.iteritems():
107+
with pytest.raises(TypeError):
108+
dtype.coerce_value(value)

0 commit comments

Comments
 (0)