|
3 | 3 | import re
|
4 | 4 | from copy import deepcopy
|
5 | 5 | from decimal import Decimal # pylint:disable=C0415
|
6 |
| -from typing import Any, Iterable, Optional |
| 6 | +from typing import Any, Generic, Iterable, Optional, TypeVar |
7 | 7 |
|
8 | 8 | import pandas as pd
|
9 | 9 | import pytest
|
10 | 10 |
|
11 | 11 | import pandera as pa
|
12 | 12 | import pandera.extensions as pax
|
| 13 | +from pandera.errors import SchemaError, SchemaInitError |
13 | 14 | from pandera.typing import DataFrame, Index, Series, String
|
14 | 15 |
|
15 | 16 |
|
@@ -1002,3 +1003,207 @@ class Config:
|
1002 | 1003 | match="^expected series 'price' to have type float64, got int64$",
|
1003 | 1004 | ):
|
1004 | 1005 | DataFrame[SchemaNoCoerce](raw_data)
|
| 1006 | + |
| 1007 | + |
| 1008 | +def test_schema_model_generic_inheritance() -> None: |
| 1009 | + """Test that a schema model subclass can also inherit from typing.Generic""" |
| 1010 | + |
| 1011 | + T = TypeVar("T") |
| 1012 | + |
| 1013 | + class Foo(pa.SchemaModel, Generic[T]): |
| 1014 | + @classmethod |
| 1015 | + def bar(cls) -> T: |
| 1016 | + raise NotImplementedError |
| 1017 | + |
| 1018 | + class Bar1(Foo[int]): |
| 1019 | + @classmethod |
| 1020 | + def bar(cls) -> int: |
| 1021 | + return 1 |
| 1022 | + |
| 1023 | + class Bar2(Foo[str]): |
| 1024 | + @classmethod |
| 1025 | + def bar(cls) -> str: |
| 1026 | + return "1" |
| 1027 | + |
| 1028 | + with pytest.raises(NotImplementedError): |
| 1029 | + Foo.bar() |
| 1030 | + assert Bar1.bar() == 1 |
| 1031 | + assert Bar2.bar() == "1" |
| 1032 | + |
| 1033 | + |
| 1034 | +def test_generic_no_generic_fields() -> None: |
| 1035 | + T = TypeVar("T", int, float, str) |
| 1036 | + |
| 1037 | + class GenericModel(pa.SchemaModel, Generic[T]): |
| 1038 | + x: Series[int] |
| 1039 | + |
| 1040 | + GenericModel.to_schema() |
| 1041 | + |
| 1042 | + |
| 1043 | +def test_generic_model_single_generic_field() -> None: |
| 1044 | + T = TypeVar("T", int, float, str) |
| 1045 | + |
| 1046 | + class GenericModel(pa.SchemaModel, Generic[T]): |
| 1047 | + x: Series[int] |
| 1048 | + y: Series[T] |
| 1049 | + |
| 1050 | + with pytest.raises(SchemaInitError): |
| 1051 | + GenericModel.to_schema() |
| 1052 | + |
| 1053 | + class IntModel(GenericModel[int]): |
| 1054 | + ... |
| 1055 | + |
| 1056 | + IntModel.to_schema() |
| 1057 | + |
| 1058 | + class FloatModel(GenericModel[float]): |
| 1059 | + ... |
| 1060 | + |
| 1061 | + FloatModel.to_schema() |
| 1062 | + |
| 1063 | + IntModel.validate(pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]})) |
| 1064 | + with pytest.raises(SchemaError): |
| 1065 | + FloatModel.validate(pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]})) |
| 1066 | + |
| 1067 | + with pytest.raises(SchemaError): |
| 1068 | + IntModel.validate(pd.DataFrame({"x": [1, 2, 3], "y": [4.0, 5, 6]})) |
| 1069 | + FloatModel.validate(pd.DataFrame({"x": [1, 2, 3], "y": [4.0, 5, 6]})) |
| 1070 | + |
| 1071 | + |
| 1072 | +def test_generic_optional_field() -> None: |
| 1073 | + T = TypeVar("T", int, float, str) |
| 1074 | + |
| 1075 | + class GenericModel(pa.SchemaModel, Generic[T]): |
| 1076 | + x: Series[int] |
| 1077 | + y: Optional[Series[T]] |
| 1078 | + |
| 1079 | + class IntYModel(GenericModel[int]): |
| 1080 | + ... |
| 1081 | + |
| 1082 | + IntYModel.validate(pd.DataFrame({"x": [1, 2, 3]})) |
| 1083 | + IntYModel.validate(pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]})) |
| 1084 | + with pytest.raises(SchemaError): |
| 1085 | + IntYModel.validate( |
| 1086 | + pd.DataFrame({"x": [1, 2, 3], "y": [4.0, 5.0, 6.0]}) |
| 1087 | + ) |
| 1088 | + |
| 1089 | + class FloatYModel(GenericModel[float]): |
| 1090 | + ... |
| 1091 | + |
| 1092 | + FloatYModel.validate(pd.DataFrame({"x": [1, 2, 3]})) |
| 1093 | + with pytest.raises(SchemaError): |
| 1094 | + FloatYModel.validate(pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]})) |
| 1095 | + FloatYModel.validate(pd.DataFrame({"x": [1, 2, 3], "y": [4.0, 5.0, 6.0]})) |
| 1096 | + |
| 1097 | + |
| 1098 | +def test_generic_model_multiple_inheritance() -> None: |
| 1099 | + T = TypeVar("T", int, float, str) |
| 1100 | + |
| 1101 | + class GenericYModel(pa.SchemaModel, Generic[T]): |
| 1102 | + x: Series[int] |
| 1103 | + y: Series[T] |
| 1104 | + |
| 1105 | + class GenericZModel(pa.SchemaModel, Generic[T]): |
| 1106 | + z: Series[T] |
| 1107 | + |
| 1108 | + class IntYFloatZModel(GenericYModel[int], GenericZModel[float]): |
| 1109 | + ... |
| 1110 | + |
| 1111 | + IntYFloatZModel.to_schema() |
| 1112 | + IntYFloatZModel.validate( |
| 1113 | + pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6], "z": [1.0, 2.0, 3.0]}) |
| 1114 | + ) |
| 1115 | + with pytest.raises(SchemaError): |
| 1116 | + IntYFloatZModel.validate( |
| 1117 | + pd.DataFrame( |
| 1118 | + {"x": [1, 2, 3], "y": [4.0, 5.0, 6.0], "z": [1, 2, 3]} |
| 1119 | + ) |
| 1120 | + ) |
| 1121 | + with pytest.raises(SchemaError): |
| 1122 | + IntYFloatZModel.validate( |
| 1123 | + pd.DataFrame( |
| 1124 | + {"x": ["a", "b", "c"], "y": [4, 5, 6], "z": [1.0, 2.0, 3.0]} |
| 1125 | + ) |
| 1126 | + ) |
| 1127 | + |
| 1128 | + class FloatYIntZModel(GenericYModel[float], GenericZModel[int]): |
| 1129 | + ... |
| 1130 | + |
| 1131 | + FloatYIntZModel.to_schema() |
| 1132 | + with pytest.raises(SchemaError): |
| 1133 | + FloatYIntZModel.validate( |
| 1134 | + pd.DataFrame( |
| 1135 | + {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1.0, 2.0, 3.0]} |
| 1136 | + ) |
| 1137 | + ) |
| 1138 | + FloatYIntZModel.validate( |
| 1139 | + pd.DataFrame({"x": [1, 2, 3], "y": [4.0, 5.0, 6.0], "z": [1, 2, 3]}) |
| 1140 | + ) |
| 1141 | + with pytest.raises(SchemaError): |
| 1142 | + FloatYIntZModel.validate( |
| 1143 | + pd.DataFrame( |
| 1144 | + {"x": ["a", "b", "c"], "y": [4.0, 5.0, 6.0], "z": [1, 2, 3]} |
| 1145 | + ) |
| 1146 | + ) |
| 1147 | + |
| 1148 | + |
| 1149 | +def test_multiple_generic() -> None: |
| 1150 | + """Test that a generic schema with multiple types is handled correctly""" |
| 1151 | + T1 = TypeVar("T1", int, float, str) |
| 1152 | + T2 = TypeVar("T2", int, float, str) |
| 1153 | + |
| 1154 | + class GenericModel(pa.SchemaModel, Generic[T1, T2]): |
| 1155 | + y: Series[T1] |
| 1156 | + z: Series[T2] |
| 1157 | + |
| 1158 | + class IntYFloatZModel(GenericModel[int, float]): |
| 1159 | + ... |
| 1160 | + |
| 1161 | + IntYFloatZModel.to_schema() |
| 1162 | + IntYFloatZModel.to_schema() |
| 1163 | + IntYFloatZModel.validate( |
| 1164 | + pd.DataFrame({"y": [4, 5, 6], "z": [1.0, 2.0, 3.0]}) |
| 1165 | + ) |
| 1166 | + with pytest.raises(SchemaError): |
| 1167 | + IntYFloatZModel.validate( |
| 1168 | + pd.DataFrame({"y": [4.0, 5.0, 6.0], "z": [1, 2, 3]}) |
| 1169 | + ) |
| 1170 | + |
| 1171 | + class FloatYIntZModel(GenericModel[float, int]): |
| 1172 | + ... |
| 1173 | + |
| 1174 | + FloatYIntZModel.to_schema() |
| 1175 | + with pytest.raises(SchemaError): |
| 1176 | + FloatYIntZModel.validate( |
| 1177 | + pd.DataFrame({"y": [4, 5, 6], "z": [1.0, 2.0, 3.0]}) |
| 1178 | + ) |
| 1179 | + FloatYIntZModel.validate( |
| 1180 | + pd.DataFrame({"x": [1, 2, 3], "y": [4.0, 5.0, 6.0], "z": [1, 2, 3]}) |
| 1181 | + ) |
| 1182 | + |
| 1183 | + |
| 1184 | +def test_repeated_generic() -> None: |
| 1185 | + """Test that repeated use of Generic in a class hierachy results in the correct types""" |
| 1186 | + T1 = TypeVar("T1", int, float, str) |
| 1187 | + T2 = TypeVar("T2", int, float, str) |
| 1188 | + T3 = TypeVar("T3", int, float, str) |
| 1189 | + |
| 1190 | + class GenericYZModel(pa.SchemaModel, Generic[T1, T2]): |
| 1191 | + y: Series[T1] |
| 1192 | + z: Series[T2] |
| 1193 | + |
| 1194 | + class IntYGenericZModel(GenericYZModel[int, T3], Generic[T3]): |
| 1195 | + ... |
| 1196 | + |
| 1197 | + with pytest.raises(SchemaInitError): |
| 1198 | + IntYGenericZModel.to_schema() |
| 1199 | + |
| 1200 | + class IntYFloatZModel(IntYGenericZModel[float]): |
| 1201 | + ... |
| 1202 | + |
| 1203 | + IntYFloatZModel.validate( |
| 1204 | + pd.DataFrame({"y": [4, 5, 6], "z": [1.0, 2.0, 3.0]}) |
| 1205 | + ) |
| 1206 | + with pytest.raises(SchemaError): |
| 1207 | + IntYFloatZModel.validate( |
| 1208 | + pd.DataFrame({"y": [4.0, 5.0, 6.0], "z": [1, 2, 3]}) |
| 1209 | + ) |
0 commit comments