Skip to content

Commit 1ad2023

Browse files
author
Arne Recknagel
committed
resolves: unionai-oss#992
Signed-off-by: Arne Recknagel <[email protected]>
1 parent 5aa7795 commit 1ad2023

File tree

4 files changed

+75
-3
lines changed

4 files changed

+75
-3
lines changed

pandera/model.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
FieldInfo,
3838
)
3939
from .schemas import DataFrameSchema
40-
from .typing import INDEX_TYPES, SERIES_TYPES, AnnotationInfo
40+
from .typing import INDEX_TYPES, SERIES_TYPES, AnnotationInfo, DataFrame
4141
from .typing.common import DataFrameBase
4242
from .typing.config import BaseConfig
4343

@@ -521,6 +521,12 @@ def __modify_schema__(cls, field_schema):
521521
"""Update pydantic field schema."""
522522
field_schema.update(to_json_schema(cls.to_schema()))
523523

524+
@classmethod
525+
@docstring_substitution(validate_doc=DataFrameSchema.empty.__doc__)
526+
def empty(cls: Type[TSchemaModel]) -> DataFrame[TSchemaModel]:
527+
"""%(validate_doc)s"""
528+
return cast(DataFrame[TSchemaModel], cls.to_schema().empty())
529+
524530
def __class_getitem__(
525531
cls: Type[TSchemaModel],
526532
params: Union[Type[Any], Tuple[Type[Any], ...]],

pandera/schemas.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -798,9 +798,9 @@ def _validate(
798798
with ps.option_context(
799799
"compute.ops_on_diff_frames", True
800800
):
801-
failure_cases = df_to_validate.loc[duplicates, lst]
801+
failure_cases = df_to_validate.loc[duplicates, lst] # type: ignore
802802
else:
803-
failure_cases = df_to_validate.loc[duplicates, lst]
803+
failure_cases = df_to_validate.loc[duplicates, lst] # type: ignore
804804

805805
failure_cases = reshape_failure_cases(failure_cases)
806806
error_handler.collect_error(
@@ -1747,6 +1747,15 @@ def _pydantic_validate(cls, schema: Any) -> "DataFrameSchema":
17471747

17481748
return cast("DataFrameSchema", schema)
17491749

1750+
def empty(self) -> pd.DataFrame:
1751+
"""Return an empty dataframe with correctly named and typed columns."""
1752+
coerce_old = self.coerce
1753+
try:
1754+
self.coerce = True
1755+
return self.coerce_dtype(pd.DataFrame(columns=[*self.columns]))
1756+
finally:
1757+
self.coerce = coerce_old
1758+
17501759

17511760
class SeriesSchemaBase:
17521761
"""Base series validator object."""

tests/core/test_model.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,20 @@ class EmptyParentSchema(EmptySchema):
7070
assert empty_parent_schema == EmptyParentSchema.to_schema()
7171

7272

73+
def test_create_empty_dataframe():
74+
"""Ensure that SchemaModel proxies the `empty` method correctly."""
75+
76+
class Schema(pa.SchemaModel):
77+
col_a: pa.typing.Series[int]
78+
col_b: pa.typing.Series[str]
79+
col_c: pa.typing.Series[float]
80+
81+
result = Schema.empty()
82+
83+
assert result.empty
84+
assert Schema.validate(result).empty
85+
86+
7387
def test_invalid_annotations() -> None:
7488
"""Test that SchemaModel.to_schema() fails if annotations or types are not
7589
recognized.

tests/core/test_schemas.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from pandera.dtypes import UniqueSettings
2626
from pandera.engines.pandas_engine import Engine
2727
from pandera.schemas import SeriesSchemaBase
28+
from pandera.engines.engine import DataType
2829

2930

3031
def test_dataframe_schema() -> None:
@@ -1283,6 +1284,48 @@ def test_lazy_dataframe_validation_nullable_with_checks() -> None:
12831284
)
12841285

12851286

1287+
def test_schema_empty():
1288+
"""Ensure that an empty dataframe works for all valid pandera dtypes."""
1289+
# make sure all subclasses of pandera.dtypes.DataType are instantiated
1290+
import pandera.dtypes # pylint: disable=C0415:
1291+
import pandera.engines.pandas_engine # pylint: disable=C0415:
1292+
import pandera.engines.numpy_engine # pylint: disable=C0415:
1293+
1294+
# find all subclasses
1295+
def get_subclasses(parent: type):
1296+
yield parent
1297+
for child in parent.__subclasses__():
1298+
yield from get_subclasses(child)
1299+
1300+
# create a valid schema of all possible dtypes pandera supports
1301+
skip = {
1302+
# these are abstract and should be excluded from column creation
1303+
pandera.dtypes.DataType,
1304+
pandera.engines.numpy_engine.DataType,
1305+
pandera.engines.pandas_engine.DataType,
1306+
# these fail when trying to make columns from them, but probably shouldn't
1307+
pandera.dtypes._Number,
1308+
pandera.dtypes._PhysicalNumber,
1309+
pandera.engines.pandas_engine.Period,
1310+
pandera.engines.pandas_engine.PydanticModel,
1311+
pandera.engines.pandas_engine.Interval,
1312+
# these fail during validation, and definitely shouldn't
1313+
pandera.engines.numpy_engine.DateTime64,
1314+
pandera.engines.numpy_engine.Bytes,
1315+
}
1316+
columns = {
1317+
f"{cls.__module__}.{cls.__qualname__}": pandera.Column(cls)
1318+
for cls in get_subclasses(DataType)
1319+
if cls not in skip
1320+
}
1321+
schema = pandera.DataFrameSchema(columns=columns)
1322+
1323+
result = schema.empty()
1324+
1325+
assert result.empty
1326+
assert schema.validate(result).empty
1327+
1328+
12861329
@pytest.mark.parametrize(
12871330
"schema_cls, data",
12881331
[

0 commit comments

Comments
 (0)