Commit 850dcf8

support pydantic v2 (#1253)
* support pydantic v2
* fix tests for py 3.11
* fix tests for py 3.7
* bump cache
* fix fastapi bug
* fix pydantic test
* dont run fastapi ci for py3.7
* debug
* debugging
* debug
* skip fastapi tests with pydantic > v2
* fix BaseSettings
* dont check docstrings for pydantic v2
* [wip] need to figure out how to replace use of ModelField in the fastapi.UploadFile type
* fix fastapi
  Signed-off-by: Niels Bantilan <[email protected]>
* update dependencies
  Signed-off-by: Niels Bantilan <[email protected]>
* mypy
  Signed-off-by: Niels Bantilan <[email protected]>
* ignore modin-ray in pydantic v2
  Signed-off-by: Niels Bantilan <[email protected]>
* update
  Signed-off-by: Niels Bantilan <[email protected]>
* update ci
  Signed-off-by: Niels Bantilan <[email protected]>
* update ci
  Signed-off-by: Niels Bantilan <[email protected]>
* update ci
  Signed-off-by: Niels Bantilan <[email protected]>
* update pydantic version
* update ci

---------

Signed-off-by: Niels Bantilan <[email protected]>
1 parent d71d890 commit 850dcf8

20 files changed, +334 -146 lines changed

.github/workflows/ci-tests.yml

+21 -10
@@ -86,7 +86,7 @@ jobs:

   tests:
     name: >
-      CI Tests (${{ matrix.python-version }}, ${{ matrix.os }}, pandas-${{ matrix.pandas-version }})
+      CI Tests (${{ matrix.python-version }}, ${{ matrix.os }}, pandas-${{ matrix.pandas-version }}, pydantic-${{ matrix.pydantic-version }})
     runs-on: ${{ matrix.os }}
     defaults:
       run:
@@ -101,10 +101,11 @@ jobs:
       matrix:
         os: ["ubuntu-latest", "macos-latest", "windows-latest"]
         python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
-        pandas-version: ["1.3.0", "1.5.2", "2.0.1"]
+        pandas-version: ["1.5.3", "2.0.3"]
+        pydantic-version: ["1.10.11", "2.3.0"]
         exclude:
         - python-version: "3.7"
-          pandas-version: "2.0.1"
+          pandas-version: "2.0.3"
         - python-version: "3.7"
           pandas-version: "1.5.2"
         - python-version: "3.10"
@@ -163,19 +164,26 @@ jobs:

     # need to install pandas via pip: conda installation is on the fritz
     - name: Install Conda Deps [pandas 2]
-      if: ${{ matrix.pandas-version == '2.0.1' }}
+      if: ${{ matrix.pandas-version == '2.0.3' }}
       run: |
         mamba install -c conda-forge asv pandas geopandas bokeh
         mamba env update -n pandera-dev -f environment.yml
         pip install pandas==${{ matrix.pandas-version }}
         pip install --user dask>=2023.3.2

     - name: Install Conda Deps
-      if: ${{ matrix.pandas-version != '2.0.1' }}
+      if: ${{ matrix.pandas-version != '2.0.3' }}
       run: |
         mamba install -c conda-forge asv pandas==${{ matrix.pandas-version }} geopandas bokeh
         mamba env update -n pandera-dev -f environment.yml

+    - name: Install Pydantic Deps
+      run: pip install -U --upgrade-strategy only-if-needed pydantic==${{ matrix.pydantic-version }}
+
+    - name: Install Pydantic v2 Deps
+      if : ${{ matrix.pydantic-version == '2.3.0' }}
+      run: pip install fastapi>=0.100.0
+
     - run: |
         conda info
         conda list
@@ -200,21 +208,24 @@ jobs:
       run: pytest tests/strategies ${{ env.PYTEST_FLAGS }} ${{ env.HYPOTHESIS_FLAGS }}

     - name: Unit Tests - FastAPI
+      # there's an issue with the fastapi tests in CI that's not reproducible locally
+      # when pydantic > v2
+      if: ${{ matrix.python-version != '3.7' }}
       run: pytest tests/fastapi ${{ env.PYTEST_FLAGS }}

     - name: Unit Tests - GeoPandas
       run: pytest tests/geopandas ${{ env.PYTEST_FLAGS }}

     - name: Unit Tests - Dask
-      if: ${{ matrix.pandas-version != '2.0.1' }}
+      if: ${{ matrix.pandas-version != '2.0.3' }}
       run: pytest tests/dask ${{ env.PYTEST_FLAGS }}

     - name: Unit Tests - Pyspark
-      if: ${{ matrix.os != 'windows-latest' && !contains(fromJson('["3.7", "3.10", "3.11"]'), matrix.python-version) && matrix.pandas-version != '2.0.1' }}
+      if: ${{ matrix.os != 'windows-latest' && !contains(fromJson('["3.7", "3.10", "3.11"]'), matrix.python-version) && matrix.pandas-version != '2.0.3' }}
       run: pytest tests/pyspark ${{ env.PYTEST_FLAGS }}

     - name: Unit Tests - Modin-Dask
-      if: ${{ !contains(fromJson('["3.11"]'), matrix.python-version) && matrix.pandas-version != '2.0.1' }}
+      if: ${{ !contains(fromJson('["3.11"]'), matrix.python-version) && matrix.pandas-version != '2.0.3' }}
       run: pytest tests/modin ${{ env.PYTEST_FLAGS }}
       env:
         CI_MODIN_ENGINES: dask
@@ -233,9 +244,9 @@ jobs:
       uses: codecov/codecov-action@v3

     - name: Check Docstrings
-      if: ${{ matrix.os != 'windows-latest' && !contains(fromJson('["3.7", "3.10", "3.11"]'), matrix.python-version) }}
+      if: ${{ matrix.os != 'windows-latest' && !contains(fromJson('["3.7", "3.10", "3.11"]'), matrix.python-version) && matrix.pydantic-version != '2.0.2' }}
       run: nox ${{ env.NOX_FLAGS }} --session doctests

     - name: Check Docs
-      if: ${{ matrix.os != 'windows-latest' && !contains(fromJson('["3.7", "3.10", "3.11"]'), matrix.python-version) }}
+      if: ${{ matrix.os != 'windows-latest' && !contains(fromJson('["3.7", "3.10", "3.11"]'), matrix.python-version) && matrix.pydantic-version != '2.0.2' }}
       run: nox ${{ env.NOX_FLAGS }} --session docs

.pylintrc

+2 -1
@@ -47,4 +47,5 @@ disable=
     arguments-differ,
     unnecessary-dunder-call,
     use-dict-literal,
-    invalid-name
+    invalid-name,
+    import-outside-toplevel

docs/source/dtype_validation.rst

+1 -1
@@ -194,7 +194,7 @@ For example:

     from typing import Dict, List, Tuple, NamedTuple

-    if sys.version_info >= (3, 9):
+    if sys.version_info >= (3, 12):
         from typing import TypedDict
         # use typing_extensions.TypedDict for python < 3.9 in order to support
         # run-time availability of optional/required fields

environment.yml

+1 -1
@@ -18,7 +18,7 @@ dependencies:
   - typing_extensions >= 3.7.4.3
   - frictionless <= 4.40.8  # v5.* introduces breaking changes
   - pyarrow
-  - pydantic < 2.0.0
+  - pydantic
   - multimethod

   # mypy extra

pandera/api/pandas/array.py

+21 -4
@@ -13,7 +13,12 @@
 from pandera.api.hypotheses import Hypothesis
 from pandera.api.pandas.types import CheckList, PandasDtypeInputTypes, is_field
 from pandera.dtypes import DataType, UniqueSettings
-from pandera.engines import pandas_engine
+from pandera.engines import pandas_engine, PYDANTIC_V2
+
+if PYDANTIC_V2:
+    from pydantic_core import core_schema
+    from pydantic import GetCoreSchemaHandler
+

 TArraySchemaBase = TypeVar("TArraySchemaBase", bound="ArraySchema")

@@ -203,9 +208,21 @@ def __call__(
     def __eq__(self, other):
         return self.__dict__ == other.__dict__

-    @classmethod
-    def __get_validators__(cls):
-        yield cls._pydantic_validate
+    if PYDANTIC_V2:
+
+        @classmethod
+        def __get_pydantic_core_schema__(
+            cls, _source_type: Any, _handler: GetCoreSchemaHandler
+        ) -> core_schema.CoreSchema:
+            return core_schema.no_info_plain_validator_function(
+                cls._pydantic_validate,  # type: ignore[misc]
+            )
+
+    else:
+
+        @classmethod
+        def __get_validators__(cls):
+            yield cls._pydantic_validate

     @classmethod
     def _pydantic_validate(  # type: ignore

pandera/api/pandas/container.py

+20 -4
@@ -21,7 +21,11 @@
     StrictType,
 )
 from pandera.dtypes import DataType, UniqueSettings
-from pandera.engines import pandas_engine
+from pandera.engines import pandas_engine, PYDANTIC_V2
+
+if PYDANTIC_V2:
+    from pydantic_core import core_schema
+    from pydantic import GetCoreSchemaHandler

 N_INDENT_SPACES = 4

@@ -517,9 +521,21 @@ def _compare_dict(obj):

         return _compare_dict(self) == _compare_dict(other)

-    @classmethod
-    def __get_validators__(cls):
-        yield cls._pydantic_validate
+    if PYDANTIC_V2:
+
+        @classmethod
+        def __get_pydantic_core_schema__(
+            cls, _source_type: Any, _handler: GetCoreSchemaHandler
+        ) -> core_schema.CoreSchema:
+            return core_schema.no_info_plain_validator_function(
+                cls._pydantic_validate,
+            )
+
+    else:
+
+        @classmethod
+        def __get_validators__(cls):
+            yield cls._pydantic_validate

     @classmethod
     def _pydantic_validate(cls, schema: Any) -> "DataFrameSchema":

pandera/api/pandas/model.py

+47 -26
@@ -37,23 +37,21 @@
     FieldInfo,
 )
 from pandera.api.pandas.model_config import BaseConfig
+from pandera.engines import PYDANTIC_V2
 from pandera.errors import SchemaInitError
 from pandera.strategies import pandas_strategies as st
 from pandera.typing import INDEX_TYPES, SERIES_TYPES, AnnotationInfo
 from pandera.typing.common import DataFrameBase

+if PYDANTIC_V2:
+    from pydantic_core import core_schema
+    from pydantic import GetJsonSchemaHandler, GetCoreSchemaHandler
+
 try:
     from typing_extensions import get_type_hints
 except ImportError:  # pragma: no cover
     from typing import get_type_hints  # type: ignore

-try:
-    from pydantic.fields import ModelField  # pylint:disable=unused-import
-
-    HAS_PYDANTIC = True
-except ImportError:
-    HAS_PYDANTIC = False
-

 SchemaIndex = Union[Index, MultiIndex]

@@ -538,8 +536,19 @@ def _extract_df_checks(cls, check_infos: List[CheckInfo]) -> List[Check]:
         return [check_info.to_check(cls) for check_info in check_infos]

     @classmethod
-    def __get_validators__(cls):
-        yield cls.pydantic_validate
+    def get_metadata(cls) -> Optional[dict]:
+        """Provide metadata for columns and schema level"""
+        res: Dict[Any, Any] = {"columns": {}}
+        columns = cls._collect_fields()
+
+        for k, (_, v) in columns.items():
+            res["columns"][k] = v.properties["metadata"]
+
+        res["dataframe"] = cls.Config.metadata
+
+        meta = {}
+        meta[cls.Config.name] = res
+        return meta

     @classmethod
     def pydantic_validate(cls, schema_model: Any) -> "DataFrameModel":
@@ -562,25 +571,37 @@ def pydantic_validate(cls, schema_model: Any) -> "DataFrameModel":

         return cast("DataFrameModel", schema_model)

-    @classmethod
-    def get_metadata(cls) -> Optional[dict]:
-        """Provide metadata for columns and schema level"""
-        res: Dict[Any, Any] = {"columns": {}}
-        columns = cls._collect_fields()
-
-        for k, (_, v) in columns.items():
-            res["columns"][k] = v.properties["metadata"]
+    if PYDANTIC_V2:

-        res["dataframe"] = cls.Config.metadata
-
-        meta = {}
-        meta[cls.Config.name] = res
-        return meta
+        @classmethod
+        def __get_pydantic_core_schema__(
+            cls, _source_type: Any, _handler: GetCoreSchemaHandler
+        ) -> core_schema.CoreSchema:
+            return core_schema.no_info_plain_validator_function(
+                cls.pydantic_validate,
+            )

-    @classmethod
-    def __modify_schema__(cls, field_schema):
-        """Update pydantic field schema."""
-        field_schema.update(_to_json_schema(cls.to_schema()))
+        @classmethod
+        def __get_pydantic_json_schema__(
+            cls,
+            _core_schema: core_schema.CoreSchema,
+            _handler: GetJsonSchemaHandler,
+        ):
+            """Update pydantic field schema."""
+            json_schema = _handler(_core_schema)
+            json_schema = _handler.resolve_ref_schema(json_schema)
+            json_schema.update(_to_json_schema(cls.to_schema()))
+
+    else:
+
+        @classmethod
+        def __modify_schema__(cls, field_schema):
+            """Update pydantic field schema."""
+            field_schema.update(_to_json_schema(cls.to_schema()))
+
+        @classmethod
+        def __get_validators__(cls):
+            yield cls.pydantic_validate


 SchemaModel = DataFrameModel
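
Note on the pattern used in array.py, container.py and model.py: each class picks its pydantic hook with an `if`/`else` inside the class body, so only one set of hook methods exists on the class at import time. A minimal sketch of that mechanism, with no pydantic dependency, might look like the following. The names `make_model`, `new_style_hook` and `old_style_hook` are hypothetical stand-ins, not pandera or pydantic API.

```python
def make_model(use_v2: bool):
    """Build a class whose hook methods depend on a flag.

    `use_v2` stands in for the PYDANTIC_V2 constant from the diff.
    """

    class Model:
        @classmethod
        def validate(cls, value):
            # Placeholder validator: returns the value unchanged.
            return value

        # Statements in a class body run once, when the class is created,
        # so only one branch ever defines its methods.
        if use_v2:

            @classmethod
            def new_style_hook(cls):
                # Stand-in for a hook that returns a single validator.
                return cls.validate

        else:

            @classmethod
            def old_style_hook(cls):
                # Stand-in for a hook that yields validators one by one.
                yield cls.validate

    return Model


ModelV2 = make_model(True)
ModelV1 = make_model(False)
print(hasattr(ModelV2, "new_style_hook"), hasattr(ModelV2, "old_style_hook"))
print(hasattr(ModelV1, "new_style_hook"), hasattr(ModelV1, "old_style_hook"))
```

The branch is evaluated once per class definition, so callers never pay for the check at validation time.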

pandera/api/pyspark/model.py

-7
@@ -45,13 +45,6 @@
 except ImportError:  # pragma: no cover
     from typing import get_type_hints  # type: ignore

-try:
-    from pydantic.fields import ModelField  # pylint:disable=unused-import
-
-    HAS_PYDANTIC = True
-except ImportError:  # pragma: no cover
-    HAS_PYDANTIC = False
-

 _CONFIG_KEY = "Config"
 MODEL_CACHE: Dict[Type["DataFrameModel"], DataFrameSchema] = {}

pandera/config.py

+13 -8
@@ -1,7 +1,9 @@
 """Pandera configuration."""

+import os
 from enum import Enum
-from pydantic import BaseSettings
+
+from pydantic import BaseModel


 class ValidationDepth(Enum):
@@ -12,7 +14,7 @@ class ValidationDepth(Enum):
     SCHEMA_AND_DATA = "SCHEMA_AND_DATA"


-class PanderaConfig(BaseSettings):
+class PanderaConfig(BaseModel):
     """Pandera config base class.

     This should pick up environment variables automatically, e.g.:
@@ -23,11 +25,14 @@ class PanderaConfig(BaseSettings):
     validation_enabled: bool = True
     validation_depth: ValidationDepth = ValidationDepth.SCHEMA_AND_DATA

-    class Config:
-        """Pydantic configuration settings."""
-
-        env_prefix = "pandera_"
-

 # this config variable should be accessible globally
-CONFIG = PanderaConfig()
+CONFIG = PanderaConfig(
+    validation_enabled=os.environ.get(
+        "PANDERA_VALIDATION_ENABLED",
+        True,
+    ),
+    validation_depth=os.environ.get(
+        "PANDERA_VALIDATION_DEPTH", ValidationDepth.SCHEMA_AND_DATA
+    ),
+)
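
Note on the config.py change: `os.environ.get` returns a string whenever the variable is set, so the new `CONFIG` construction leans on pydantic to coerce values such as `"False"` or `"SCHEMA_AND_DATA"` into `bool` and `ValidationDepth`. A stdlib-only sketch of the same lookup with the coercion written out explicitly could look like this. `SketchConfig`, `parse_bool` and `config_from_env` are hypothetical names, the accepted boolean spellings are an assumption, and the `SCHEMA_ONLY` member is assumed rather than shown in the diff.

```python
import os
from dataclasses import dataclass
from enum import Enum
from typing import Mapping, Optional


class ValidationDepth(Enum):
    SCHEMA_ONLY = "SCHEMA_ONLY"  # assumed member, not visible in the diff
    SCHEMA_AND_DATA = "SCHEMA_AND_DATA"


# Assumed spellings; pydantic's own accepted set may differ.
_TRUE = {"1", "true", "t", "yes", "y", "on"}
_FALSE = {"0", "false", "f", "no", "n", "off"}


def parse_bool(raw: str) -> bool:
    """Convert an environment string to a bool, rejecting unknown spellings."""
    value = raw.strip().lower()
    if value in _TRUE:
        return True
    if value in _FALSE:
        return False
    raise ValueError(f"cannot interpret {raw!r} as a boolean")


@dataclass(frozen=True)
class SketchConfig:
    validation_enabled: bool = True
    validation_depth: ValidationDepth = ValidationDepth.SCHEMA_AND_DATA


def config_from_env(environ: Optional[Mapping[str, str]] = None) -> SketchConfig:
    """Read both settings from the environment, falling back to defaults."""
    env = os.environ if environ is None else environ
    enabled = env.get("PANDERA_VALIDATION_ENABLED")
    depth = env.get("PANDERA_VALIDATION_DEPTH")
    return SketchConfig(
        validation_enabled=True if enabled is None else parse_bool(enabled),
        validation_depth=(
            ValidationDepth.SCHEMA_AND_DATA
            if depth is None
            else ValidationDepth(depth)  # lookup by enum value
        ),
    )


print(config_from_env({"PANDERA_VALIDATION_ENABLED": "False"}))
```

Passing a mapping instead of reading `os.environ` directly keeps the function easy to test without touching the process environment.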

pandera/engines/__init__.py

+6
@@ -0,0 +1,6 @@
+"""Pandera type engines."""
+
+from pandera.engines.utils import pydantic_version
+
+
+PYDANTIC_V2 = pydantic_version().release >= (2, 0, 0)
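
Note on the `PYDANTIC_V2` flag: it compares a release tuple against `(2, 0, 0)`. The body of `pydantic_version()` is not part of this diff, so the sketch below only illustrates the tuple comparison, using a hand-rolled parser. `release_tuple` and `is_v2` are hypothetical helpers, not pandera functions. The sketch compares against `(2,)` rather than `(2, 0, 0)` so that a short version string such as `"2"` still counts as v2, since `(2,) >= (2, 0, 0)` is `False` in Python.

```python
import re
from typing import Tuple


def release_tuple(version: str) -> Tuple[int, ...]:
    """Extract the leading numeric release segment, e.g. '2.0b3' -> (2, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", version.strip())
    if match is None:
        raise ValueError(f"no numeric release segment in {version!r}")
    return tuple(int(part) for part in match.group().split("."))


def is_v2(version: str) -> bool:
    """True when the major version is 2 or higher."""
    # Tuples compare element by element; (2,) is the smallest tuple with major 2.
    return release_tuple(version) >= (2,)


for candidate in ("1.10.11", "2.3.0"):
    print(candidate, is_v2(candidate))
```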

pandera/engines/engine.py

+1 -1
@@ -27,7 +27,7 @@


 # register different TypedDict type depending on python version
-if sys.version_info >= (3, 9):
+if sys.version_info >= (3, 12):
     from typing import TypedDict
 else:
     from typing_extensions import TypedDict  # noqa