Skip to content

Commit b438e15

Browse files
tfwillems (Thomas Willems) and cosmicBboy
authored
Feature: Add support for Generic to SchemaModel (#810)
* Adapt SchemaModel so that it can inherit from typing.Generic
* Extend SchemaModel to enable generic types in fields
* Fix linter

Co-authored-by: Thomas Willems <[email protected]>
Co-authored-by: cosmicBboy <[email protected]>
1 parent 9a463e1 commit b438e15

File tree

3 files changed

+262
-4
lines changed

3 files changed

+262
-4
lines changed

.pylintrc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ good-names=
1313
k,
1414
v,
1515
fp,
16+
bar,
1617
_IS_INFERRED,
1718

1819
[MESSAGES CONTROL]

pandera/model.py

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Class-based api"""
22

3+
import copy
34
import inspect
45
import os
56
import re
@@ -59,6 +60,9 @@
5960

6061

6162
MODEL_CACHE: Dict[Type["SchemaModel"], DataFrameSchema] = {}
63+
GENERIC_SCHEMA_CACHE: Dict[
64+
Tuple[Type["SchemaModel"], Tuple[Type[Any], ...]], Type["SchemaModel"]
65+
] = {}
6266
F = TypeVar("F", bound=Callable)
6367
TSchemaModel = TypeVar("TSchemaModel", bound="SchemaModel")
6468

@@ -190,6 +194,10 @@ def to_schema(cls) -> DataFrameSchema:
190194
}
191195

192196
cls.__fields__ = cls._collect_fields()
197+
for field, (annot_info, _) in cls.__fields__.items():
198+
if isinstance(annot_info.arg, TypeVar):
199+
raise SchemaInitError(f"Field {field} has a generic data type")
200+
193201
check_infos = typing.cast(
194202
List[FieldCheckInfo], cls._collect_check_infos(CHECK_KEY)
195203
)
@@ -372,7 +380,8 @@ def _get_model_attrs(cls) -> Dict[str, Any]:
372380
bases = inspect.getmro(cls)[:-1] # bases -> SchemaModel -> object
373381
attrs = {}
374382
for base in reversed(bases):
375-
attrs.update(base.__dict__)
383+
if issubclass(base, SchemaModel):
384+
attrs.update(base.__dict__)
376385
return attrs
377386

378387
@classmethod
@@ -412,7 +421,7 @@ def _collect_config_and_extras(
412421
) -> Tuple[Type[BaseConfig], Dict[str, Any]]:
413422
"""Collect config options from bases, splitting off unknown options."""
414423
bases = inspect.getmro(cls)[:-1]
415-
bases = typing.cast(Tuple[Type[SchemaModel]], bases)
424+
bases = tuple(base for base in bases if issubclass(base, SchemaModel))
416425
root_model, *models = reversed(bases)
417426

418427
options, extras = _extract_config_options_and_extras(root_model.Config)
@@ -434,7 +443,7 @@ def _collect_check_infos(cls, key: str) -> List[CheckInfo]:
434443
walk the inheritance tree.
435444
"""
436445
bases = inspect.getmro(cls)[:-2] # bases -> SchemaModel -> object
437-
bases = typing.cast(Tuple[Type[SchemaModel]], bases)
446+
bases = tuple(base for base in bases if issubclass(base, SchemaModel))
438447

439448
method_names = set()
440449
check_infos = []
@@ -512,6 +521,49 @@ def __modify_schema__(cls, field_schema):
512521
"""Update pydantic field schema."""
513522
field_schema.update(to_json_schema(cls.to_schema()))
514523

524+
def __class_getitem__(
    cls: Type[TSchemaModel],
    params: Union[Type[Any], Tuple[Type[Any], ...]],
) -> Type[TSchemaModel]:
    """Parameterize the class's generic arguments with the specified types.

    A new subclass of ``cls`` is created in which every field annotated
    with one of the class's type variables is re-annotated with the
    concrete type supplied for that variable. The resulting class is stored
    in ``GENERIC_SCHEMA_CACHE`` under the key ``(cls, params)`` so that
    repeating the same parameterization returns the same class object.

    :param params: a single type argument, or a tuple with one type
        argument per entry of ``cls.__parameters__``.
    :returns: the (possibly cached) parameterized subclass of ``cls``.
    :raises TypeError: if ``cls`` has no ``__parameters__`` attribute.
    :raises ValueError: if the number of supplied arguments differs from
        the number of type parameters of ``cls``.
    """
    if not hasattr(cls, "__parameters__"):
        raise TypeError(
            f"{cls.__name__} must inherit from typing.Generic before being parameterized"
        )
    # pylint: disable=no-member
    type_vars: Tuple[TypeVar, ...] = cls.__parameters__  # type: ignore

    # Normalize ``Model[int]`` to the same shape as ``Model[int, float]``.
    if not isinstance(params, tuple):
        params = (params,)
    if len(params) != len(type_vars):
        raise ValueError(
            f"Expected {len(type_vars)} generic arguments but found {len(params)}"
        )

    cache_key = (cls, params)
    cached_cls = GENERIC_SCHEMA_CACHE.get(cache_key)
    if cached_cls is not None:
        return typing.cast(Type[TSchemaModel], cached_cls)

    param_dict: Dict[TypeVar, Type[Any]] = dict(zip(type_vars, params))
    extra: Dict[str, Any] = {"__annotations__": {}}
    for field, (annot_info, field_info) in cls._collect_fields().items():
        # Only fields whose type argument is one of *this* class's type
        # variables are rewritten; every other field is inherited as-is.
        if (
            isinstance(annot_info.arg, TypeVar)
            and annot_info.arg in param_dict
        ):
            raw_annot = annot_info.origin[param_dict[annot_info.arg]]  # type: ignore
            if annot_info.optional:
                raw_annot = Optional[raw_annot]
            extra["__annotations__"][field] = raw_annot
            # Copy the field info so the new class does not share the
            # object held by the generic parent.
            extra[field] = copy.deepcopy(field_info)

    # Not every valid type argument exposes ``__name__``; fall back to its
    # repr instead of failing with AttributeError while building the name.
    param_names = ", ".join(getattr(p, "__name__", repr(p)) for p in params)
    parameterized_name = f"{cls.__name__}[{param_names}]"
    parameterized_cls = type(parameterized_name, (cls,), extra)
    GENERIC_SCHEMA_CACHE[cache_key] = parameterized_cls
    return parameterized_cls
566+
515567

516568
def _build_schema_index(
517569
indices: List[schema_components.Index], **multiindex_kwargs: Any

tests/core/test_model.py

Lines changed: 206 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@
33
import re
44
from copy import deepcopy
55
from decimal import Decimal # pylint:disable=C0415
6-
from typing import Any, Iterable, Optional
6+
from typing import Any, Generic, Iterable, Optional, TypeVar
77

88
import pandas as pd
99
import pytest
1010

1111
import pandera as pa
1212
import pandera.extensions as pax
13+
from pandera.errors import SchemaError, SchemaInitError
1314
from pandera.typing import DataFrame, Index, Series, String
1415

1516

@@ -1002,3 +1003,207 @@ class Config:
10021003
match="^expected series 'price' to have type float64, got int64$",
10031004
):
10041005
DataFrame[SchemaNoCoerce](raw_data)
1006+
1007+
1008+
def test_schema_model_generic_inheritance() -> None:
    """Check that a SchemaModel subclass may additionally derive from Generic."""

    T = TypeVar("T")

    class Base(pa.SchemaModel, Generic[T]):
        """Generic model whose classmethod is left unimplemented."""

        @classmethod
        def bar(cls) -> T:
            raise NotImplementedError

    class IntChild(Base[int]):
        """Parameterized with ``int``; overrides ``bar``."""

        @classmethod
        def bar(cls) -> int:
            return 1

    class StrChild(Base[str]):
        """Parameterized with ``str``; overrides ``bar``."""

        @classmethod
        def bar(cls) -> str:
            return "1"

    # The unparameterized base keeps its own implementation ...
    with pytest.raises(NotImplementedError):
        Base.bar()

    # ... while each parameterized child dispatches to its override.
    int_result = IntChild.bar()
    str_result = StrChild.bar()
    assert int_result == 1
    assert str_result == "1"
1032+
1033+
1034+
def test_generic_no_generic_fields() -> None:
    """A generic model without TypeVar-typed fields can build its schema."""
    T = TypeVar("T", int, float, str)

    class GenericModel(pa.SchemaModel, Generic[T]):
        """Declares a type parameter but uses it in no field."""

        x: Series[int]

    # Must not raise: no field carries an unresolved type variable.
    GenericModel.to_schema()
1041+
1042+
1043+
def test_generic_model_single_generic_field() -> None:
    """A single TypeVar-typed field is resolved by parameterizing the model."""
    T = TypeVar("T", int, float, str)

    class GenericModel(pa.SchemaModel, Generic[T]):
        """Model with one concrete and one generic column."""

        x: Series[int]
        y: Series[T]

    def frame_with_y(y_values) -> pd.DataFrame:
        """Build a fresh frame with a fixed ``x`` and the given ``y``."""
        return pd.DataFrame({"x": [1, 2, 3], "y": y_values})

    # The unparameterized model still has a generic field.
    with pytest.raises(SchemaInitError):
        GenericModel.to_schema()

    class IntModel(GenericModel[int]):
        ...

    IntModel.to_schema()

    class FloatModel(GenericModel[float]):
        ...

    FloatModel.to_schema()

    # Integer ``y``: accepted by the int model, rejected by the float model.
    IntModel.validate(frame_with_y([4, 5, 6]))
    with pytest.raises(SchemaError):
        FloatModel.validate(frame_with_y([4, 5, 6]))

    # Float ``y``: rejected by the int model, accepted by the float model.
    with pytest.raises(SchemaError):
        IntModel.validate(frame_with_y([4.0, 5, 6]))
    FloatModel.validate(frame_with_y([4.0, 5, 6]))
1070+
1071+
1072+
def test_generic_optional_field() -> None:
    """An ``Optional`` generic field stays optional after parameterization."""
    T = TypeVar("T", int, float, str)

    class GenericModel(pa.SchemaModel, Generic[T]):
        """Model whose generic column ``y`` is optional."""

        x: Series[int]
        y: Optional[Series[T]]

    def only_x() -> pd.DataFrame:
        """Frame that omits the optional ``y`` column."""
        return pd.DataFrame({"x": [1, 2, 3]})

    def x_and_y(y_values) -> pd.DataFrame:
        """Frame that includes ``y`` with the given values."""
        return pd.DataFrame({"x": [1, 2, 3], "y": y_values})

    class IntYModel(GenericModel[int]):
        ...

    IntYModel.validate(only_x())
    IntYModel.validate(x_and_y([4, 5, 6]))
    with pytest.raises(SchemaError):
        IntYModel.validate(x_and_y([4.0, 5.0, 6.0]))

    class FloatYModel(GenericModel[float]):
        ...

    FloatYModel.validate(only_x())
    with pytest.raises(SchemaError):
        FloatYModel.validate(x_and_y([4, 5, 6]))
    FloatYModel.validate(x_and_y([4.0, 5.0, 6.0]))
1096+
1097+
1098+
def test_generic_model_multiple_inheritance() -> None:
    """Two separately parameterized generic models can be combined as bases."""
    T = TypeVar("T", int, float, str)

    class GenericYModel(pa.SchemaModel, Generic[T]):
        """Concrete ``x`` column plus generic ``y`` column."""

        x: Series[int]
        y: Series[T]

    class GenericZModel(pa.SchemaModel, Generic[T]):
        """Generic ``z`` column only."""

        z: Series[T]

    def frame(x_values, y_values, z_values) -> pd.DataFrame:
        """Build a fresh three-column frame in ``x``, ``y``, ``z`` order."""
        return pd.DataFrame({"x": x_values, "y": y_values, "z": z_values})

    class IntYFloatZModel(GenericYModel[int], GenericZModel[float]):
        ...

    IntYFloatZModel.to_schema()
    IntYFloatZModel.validate(frame([1, 2, 3], [4, 5, 6], [1.0, 2.0, 3.0]))
    with pytest.raises(SchemaError):
        # ``y`` and ``z`` dtypes swapped.
        IntYFloatZModel.validate(
            frame([1, 2, 3], [4.0, 5.0, 6.0], [1, 2, 3])
        )
    with pytest.raises(SchemaError):
        # Non-integer ``x``.
        IntYFloatZModel.validate(
            frame(["a", "b", "c"], [4, 5, 6], [1.0, 2.0, 3.0])
        )

    class FloatYIntZModel(GenericYModel[float], GenericZModel[int]):
        ...

    FloatYIntZModel.to_schema()
    with pytest.raises(SchemaError):
        # ``y`` and ``z`` dtypes swapped.
        FloatYIntZModel.validate(
            frame([1, 2, 3], [4, 5, 6], [1.0, 2.0, 3.0])
        )
    FloatYIntZModel.validate(frame([1, 2, 3], [4.0, 5.0, 6.0], [1, 2, 3]))
    with pytest.raises(SchemaError):
        # Non-integer ``x``.
        FloatYIntZModel.validate(
            frame(["a", "b", "c"], [4.0, 5.0, 6.0], [1, 2, 3])
        )
1147+
1148+
1149+
def test_multiple_generic() -> None:
    """Check that a model with several type parameters resolves each one."""
    T1 = TypeVar("T1", int, float, str)
    T2 = TypeVar("T2", int, float, str)

    class GenericModel(pa.SchemaModel, Generic[T1, T2]):
        """Model with two independently typed generic columns."""

        y: Series[T1]
        z: Series[T2]

    def yz_frame(y_values, z_values) -> pd.DataFrame:
        """Build a fresh frame with ``y`` and ``z`` columns."""
        return pd.DataFrame({"y": y_values, "z": z_values})

    class IntYFloatZModel(GenericModel[int, float]):
        ...

    # The schema is requested twice in a row.
    # NOTE(review): presumably exercises the repeated/cached path - confirm.
    IntYFloatZModel.to_schema()
    IntYFloatZModel.to_schema()
    IntYFloatZModel.validate(yz_frame([4, 5, 6], [1.0, 2.0, 3.0]))
    with pytest.raises(SchemaError):
        IntYFloatZModel.validate(yz_frame([4.0, 5.0, 6.0], [1, 2, 3]))

    class FloatYIntZModel(GenericModel[float, int]):
        ...

    FloatYIntZModel.to_schema()
    with pytest.raises(SchemaError):
        FloatYIntZModel.validate(yz_frame([4, 5, 6], [1.0, 2.0, 3.0]))

    # This frame also carries an ``x`` column that the model does not declare.
    frame_with_extra_x = pd.DataFrame(
        {"x": [1, 2, 3], "y": [4.0, 5.0, 6.0], "z": [1, 2, 3]}
    )
    FloatYIntZModel.validate(frame_with_extra_x)
1182+
1183+
1184+
def test_repeated_generic() -> None:
    """Test that repeated use of Generic in a class hierarchy yields the correct types."""
    T1 = TypeVar("T1", int, float, str)
    T2 = TypeVar("T2", int, float, str)
    T3 = TypeVar("T3", int, float, str)

    class GenericYZModel(pa.SchemaModel, Generic[T1, T2]):
        """Both columns generic."""

        y: Series[T1]
        z: Series[T2]

    class IntYGenericZModel(GenericYZModel[int, T3], Generic[T3]):
        """``y`` fixed to ``int``; ``z`` re-exposed through a new type variable."""

    # ``z`` is still typed by a type variable at this level.
    with pytest.raises(SchemaInitError):
        IntYGenericZModel.to_schema()

    class IntYFloatZModel(IntYGenericZModel[float]):
        ...

    matching = pd.DataFrame({"y": [4, 5, 6], "z": [1.0, 2.0, 3.0]})
    IntYFloatZModel.validate(matching)

    swapped = pd.DataFrame({"y": [4.0, 5.0, 6.0], "z": [1, 2, 3]})
    with pytest.raises(SchemaError):
        IntYFloatZModel.validate(swapped)

0 commit comments

Comments
 (0)