support pydantic v2 #1253

Merged (25 commits) on Aug 26, 2023
Changes from all commits
31 changes: 21 additions & 10 deletions .github/workflows/ci-tests.yml
@@ -86,7 +86,7 @@ jobs:

tests:
name: >
CI Tests (${{ matrix.python-version }}, ${{ matrix.os }}, pandas-${{ matrix.pandas-version }})
CI Tests (${{ matrix.python-version }}, ${{ matrix.os }}, pandas-${{ matrix.pandas-version }}, pydantic-${{ matrix.pydantic-version }})
runs-on: ${{ matrix.os }}
defaults:
run:
@@ -101,10 +101,11 @@ jobs:
matrix:
os: ["ubuntu-latest", "macos-latest", "windows-latest"]
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
pandas-version: ["1.3.0", "1.5.2", "2.0.1"]
pandas-version: ["1.5.3", "2.0.3"]
pydantic-version: ["1.10.11", "2.3.0"]
exclude:
- python-version: "3.7"
pandas-version: "2.0.1"
pandas-version: "2.0.3"
- python-version: "3.7"
pandas-version: "1.5.2"
- python-version: "3.10"
@@ -163,19 +164,26 @@ jobs:

# need to install pandas via pip: conda installation is on the fritz
- name: Install Conda Deps [pandas 2]
if: ${{ matrix.pandas-version == '2.0.1' }}
if: ${{ matrix.pandas-version == '2.0.3' }}
run: |
mamba install -c conda-forge asv pandas geopandas bokeh
mamba env update -n pandera-dev -f environment.yml
pip install pandas==${{ matrix.pandas-version }}
pip install --user dask>=2023.3.2

- name: Install Conda Deps
if: ${{ matrix.pandas-version != '2.0.1' }}
if: ${{ matrix.pandas-version != '2.0.3' }}
run: |
mamba install -c conda-forge asv pandas==${{ matrix.pandas-version }} geopandas bokeh
mamba env update -n pandera-dev -f environment.yml

- name: Install Pydantic Deps
run: pip install -U --upgrade-strategy only-if-needed pydantic==${{ matrix.pydantic-version }}

- name: Install Pydantic v2 Deps
if : ${{ matrix.pydantic-version == '2.3.0' }}
run: pip install fastapi>=0.100.0

- run: |
conda info
conda list
@@ -200,21 +208,24 @@ jobs:
run: pytest tests/strategies ${{ env.PYTEST_FLAGS }} ${{ env.HYPOTHESIS_FLAGS }}

- name: Unit Tests - FastAPI
# there's an issue with the fastapi tests in CI that's not reproducible locally
# when pydantic > v2
if: ${{ matrix.python-version != '3.7' }}
run: pytest tests/fastapi ${{ env.PYTEST_FLAGS }}

- name: Unit Tests - GeoPandas
run: pytest tests/geopandas ${{ env.PYTEST_FLAGS }}

- name: Unit Tests - Dask
if: ${{ matrix.pandas-version != '2.0.1' }}
if: ${{ matrix.pandas-version != '2.0.3' }}
run: pytest tests/dask ${{ env.PYTEST_FLAGS }}

- name: Unit Tests - Pyspark
if: ${{ matrix.os != 'windows-latest' && !contains(fromJson('["3.7", "3.10", "3.11"]'), matrix.python-version) && matrix.pandas-version != '2.0.1' }}
if: ${{ matrix.os != 'windows-latest' && !contains(fromJson('["3.7", "3.10", "3.11"]'), matrix.python-version) && matrix.pandas-version != '2.0.3' }}
run: pytest tests/pyspark ${{ env.PYTEST_FLAGS }}

- name: Unit Tests - Modin-Dask
if: ${{ !contains(fromJson('["3.11"]'), matrix.python-version) && matrix.pandas-version != '2.0.1' }}
if: ${{ !contains(fromJson('["3.11"]'), matrix.python-version) && matrix.pandas-version != '2.0.3' }}
run: pytest tests/modin ${{ env.PYTEST_FLAGS }}
env:
CI_MODIN_ENGINES: dask
@@ -233,9 +244,9 @@ jobs:
uses: codecov/codecov-action@v3

- name: Check Docstrings
if: ${{ matrix.os != 'windows-latest' && !contains(fromJson('["3.7", "3.10", "3.11"]'), matrix.python-version) }}
if: ${{ matrix.os != 'windows-latest' && !contains(fromJson('["3.7", "3.10", "3.11"]'), matrix.python-version) && matrix.pydantic-version != '2.0.2' }}
run: nox ${{ env.NOX_FLAGS }} --session doctests

- name: Check Docs
if: ${{ matrix.os != 'windows-latest' && !contains(fromJson('["3.7", "3.10", "3.11"]'), matrix.python-version) }}
if: ${{ matrix.os != 'windows-latest' && !contains(fromJson('["3.7", "3.10", "3.11"]'), matrix.python-version) && matrix.pydantic-version != '2.0.2' }}
run: nox ${{ env.NOX_FLAGS }} --session docs
3 changes: 2 additions & 1 deletion .pylintrc
@@ -47,4 +47,5 @@ disable=
arguments-differ,
unnecessary-dunder-call,
use-dict-literal,
invalid-name
invalid-name,
import-outside-toplevel
2 changes: 1 addition & 1 deletion docs/source/dtype_validation.rst
@@ -194,7 +194,7 @@ For example:

from typing import Dict, List, Tuple, NamedTuple

if sys.version_info >= (3, 9):
if sys.version_info >= (3, 12):
from typing import TypedDict
# use typing_extensions.TypedDict for python < 3.9 in order to support
# run-time availability of optional/required fields
2 changes: 1 addition & 1 deletion environment.yml
@@ -18,7 +18,7 @@ dependencies:
- typing_extensions >= 3.7.4.3
- frictionless <= 4.40.8 # v5.* introduces breaking changes
- pyarrow
- pydantic < 2.0.0
- pydantic
- multimethod

# mypy extra
25 changes: 21 additions & 4 deletions pandera/api/pandas/array.py
@@ -13,7 +13,12 @@
from pandera.api.hypotheses import Hypothesis
from pandera.api.pandas.types import CheckList, PandasDtypeInputTypes, is_field
from pandera.dtypes import DataType, UniqueSettings
from pandera.engines import pandas_engine
from pandera.engines import pandas_engine, PYDANTIC_V2

if PYDANTIC_V2:
from pydantic_core import core_schema
from pydantic import GetCoreSchemaHandler


TArraySchemaBase = TypeVar("TArraySchemaBase", bound="ArraySchema")

@@ -203,9 +208,21 @@ def __call__(
def __eq__(self, other):
return self.__dict__ == other.__dict__

@classmethod
def __get_validators__(cls):
yield cls._pydantic_validate
if PYDANTIC_V2:

@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
return core_schema.no_info_plain_validator_function(
cls._pydantic_validate, # type: ignore[misc]
)

else:

@classmethod
def __get_validators__(cls):
yield cls._pydantic_validate

@classmethod
def _pydantic_validate( # type: ignore
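For context (not part of this diff): the pydantic v2 hook used above replaces `__get_validators__` with `__get_pydantic_core_schema__`. A minimal, self-contained illustration of the same `no_info_plain_validator_function` pattern, using a toy type rather than pandera code:

```python
from typing import Any

from pydantic import BaseModel, GetCoreSchemaHandler
from pydantic_core import core_schema


class UpperStr(str):
    """Toy type that validates by upper-casing the incoming value."""

    @classmethod
    def _validate(cls, value: Any) -> "UpperStr":
        return cls(str(value).upper())

    @classmethod
    def __get_pydantic_core_schema__(
        cls, _source_type: Any, _handler: GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        # same pattern as the pandera hook above: register a plain validator
        # function that receives the raw value and returns an instance
        return core_schema.no_info_plain_validator_function(cls._validate)


class Model(BaseModel):
    name: UpperStr


print(Model(name="abc").name)  # "ABC"
```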
24 changes: 20 additions & 4 deletions pandera/api/pandas/container.py
@@ -21,7 +21,11 @@
StrictType,
)
from pandera.dtypes import DataType, UniqueSettings
from pandera.engines import pandas_engine
from pandera.engines import pandas_engine, PYDANTIC_V2

if PYDANTIC_V2:
from pydantic_core import core_schema
from pydantic import GetCoreSchemaHandler

N_INDENT_SPACES = 4

@@ -516,9 +520,21 @@ def _compare_dict(obj):

return _compare_dict(self) == _compare_dict(other)

@classmethod
def __get_validators__(cls):
yield cls._pydantic_validate
if PYDANTIC_V2:

@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
return core_schema.no_info_plain_validator_function(
cls._pydantic_validate,
)

else:

@classmethod
def __get_validators__(cls):
yield cls._pydantic_validate

@classmethod
def _pydantic_validate(cls, schema: Any) -> "DataFrameSchema":
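For context (not part of this diff): with both hooks registered, pandera schemas keep working as pydantic field types on either major version. A hedged usage sketch under pydantic v2; the model and field names are illustrative:

```python
import pandera as pa
import pydantic


class PipelineConfig(pydantic.BaseModel):
    # the value is validated by DataFrameSchema._pydantic_validate via the
    # __get_pydantic_core_schema__ hook added above
    output_schema: pa.DataFrameSchema


config = PipelineConfig(
    output_schema=pa.DataFrameSchema({"price": pa.Column(float, pa.Check.ge(0))})
)
print(config.output_schema)
```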
73 changes: 47 additions & 26 deletions pandera/api/pandas/model.py
@@ -36,23 +36,21 @@
FieldInfo,
)
from pandera.api.pandas.model_config import BaseConfig
from pandera.engines import PYDANTIC_V2
from pandera.errors import SchemaInitError
from pandera.strategies import pandas_strategies as st
from pandera.typing import INDEX_TYPES, SERIES_TYPES, AnnotationInfo
from pandera.typing.common import DataFrameBase

if PYDANTIC_V2:
from pydantic_core import core_schema
from pydantic import GetJsonSchemaHandler, GetCoreSchemaHandler

try:
from typing_extensions import get_type_hints
except ImportError:
from typing import get_type_hints # type: ignore

try:
from pydantic.fields import ModelField # pylint:disable=unused-import

HAS_PYDANTIC = True
except ImportError:
HAS_PYDANTIC = False


SchemaIndex = Union[Index, MultiIndex]

@@ -533,8 +531,19 @@ def _extract_df_checks(cls, check_infos: List[CheckInfo]) -> List[Check]:
return [check_info.to_check(cls) for check_info in check_infos]

@classmethod
def __get_validators__(cls):
yield cls.pydantic_validate
def get_metadata(cls) -> Optional[dict]:
"""Provide metadata for columns and schema level"""
res: Dict[Any, Any] = {"columns": {}}
columns = cls._collect_fields()

for k, (_, v) in columns.items():
res["columns"][k] = v.properties["metadata"]

res["dataframe"] = cls.Config.metadata

meta = {}
meta[cls.Config.name] = res
return meta

@classmethod
def pydantic_validate(cls, schema_model: Any) -> "DataFrameModel":
@@ -557,25 +566,37 @@ def pydantic_validate(cls, schema_model: Any) -> "DataFrameModel":

return cast("DataFrameModel", schema_model)

@classmethod
def get_metadata(cls) -> Optional[dict]:
"""Provide metadata for columns and schema level"""
res: Dict[Any, Any] = {"columns": {}}
columns = cls._collect_fields()

for k, (_, v) in columns.items():
res["columns"][k] = v.properties["metadata"]

res["dataframe"] = cls.Config.metadata
if PYDANTIC_V2:

meta = {}
meta[cls.Config.name] = res
return meta
@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
return core_schema.no_info_plain_validator_function(
cls.pydantic_validate,
)

@classmethod
def __modify_schema__(cls, field_schema):
"""Update pydantic field schema."""
field_schema.update(_to_json_schema(cls.to_schema()))
@classmethod
def __get_pydantic_json_schema__(
cls,
_core_schema: core_schema.CoreSchema,
_handler: GetJsonSchemaHandler,
):
"""Update pydantic field schema."""
json_schema = _handler(_core_schema)
json_schema = _handler.resolve_ref_schema(json_schema)
json_schema.update(_to_json_schema(cls.to_schema()))

else:

@classmethod
def __modify_schema__(cls, field_schema):
"""Update pydantic field schema."""
field_schema.update(_to_json_schema(cls.to_schema()))

@classmethod
def __get_validators__(cls):
yield cls.pydantic_validate


SchemaModel = DataFrameModel
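For context (not part of this diff): a hedged sketch of the `get_metadata` classmethod shown above, assuming `pa.Field` and the model `Config` accept a `metadata` argument as in pandera's schema-metadata support; all names are illustrative:

```python
import pandera as pa


class UserSchema(pa.DataFrameModel):
    age: int = pa.Field(ge=0, metadata={"unit": "years"})

    class Config:
        name = "UserSchema"
        metadata = {"owner": "data-team"}


# roughly: {"UserSchema": {"columns": {"age": {"unit": "years"}},
#                          "dataframe": {"owner": "data-team"}}}
print(UserSchema.get_metadata())
```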
7 changes: 0 additions & 7 deletions pandera/api/pyspark/model.py
@@ -45,13 +45,6 @@
except ImportError: # pragma: no cover
from typing import get_type_hints # type: ignore

try:
from pydantic.fields import ModelField # pylint:disable=unused-import

HAS_PYDANTIC = True
except ImportError: # pragma: no cover
HAS_PYDANTIC = False


_CONFIG_KEY = "Config"
MODEL_CACHE: Dict[Type["DataFrameModel"], DataFrameSchema] = {}
21 changes: 13 additions & 8 deletions pandera/config.py
@@ -1,7 +1,9 @@
"""Pandera configuration."""

import os
from enum import Enum
from pydantic import BaseSettings

from pydantic import BaseModel


class ValidationDepth(Enum):
@@ -12,7 +14,7 @@ class ValidationDepth(Enum):
SCHEMA_AND_DATA = "SCHEMA_AND_DATA"


class PanderaConfig(BaseSettings):
class PanderaConfig(BaseModel):
"""Pandera config base class.

This should pick up environment variables automatically, e.g.:
@@ -23,11 +25,14 @@ class PanderaConfig(BaseSettings):
validation_enabled: bool = True
validation_depth: ValidationDepth = ValidationDepth.SCHEMA_AND_DATA

class Config:
"""Pydantic configuration settings."""

env_prefix = "pandera_"


# this config variable should be accessible globally
CONFIG = PanderaConfig()
CONFIG = PanderaConfig(
validation_enabled=os.environ.get(
"PANDERA_VALIDATION_ENABLED",
True,
),
validation_depth=os.environ.get(
"PANDERA_VALIDATION_DEPTH", ValidationDepth.SCHEMA_AND_DATA
),
)
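For context (not part of this diff): `BaseSettings` moved to the separate pydantic-settings package in pydantic v2, so the config now reads environment variables explicitly at import time rather than through pydantic. A usage sketch; the variables must be set before pandera is first imported, since `CONFIG` is built when the module loads:

```python
import os

# must happen before pandera (and therefore pandera.config) is imported
os.environ["PANDERA_VALIDATION_ENABLED"] = "False"

from pandera.config import CONFIG  # noqa: E402

print(CONFIG.validation_enabled)  # False (pydantic coerces the "False" string)
print(CONFIG.validation_depth)    # ValidationDepth.SCHEMA_AND_DATA (the default)
```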
6 changes: 6 additions & 0 deletions pandera/engines/__init__.py
@@ -0,0 +1,6 @@
"""Pandera type engines."""

from pandera.engines.utils import pydantic_version


PYDANTIC_V2 = pydantic_version().release >= (2, 0, 0)
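For context (not part of this diff): `pydantic_version()` lives in `pandera.engines.utils`, which is not shown in this excerpt. A hypothetical sketch of such a helper, assuming it returns a `packaging.version.Version` so that `.release >= (2, 0, 0)` works as used above:

```python
import pydantic
from packaging import version


def pydantic_version() -> version.Version:
    """Return the installed pydantic version as a parsed Version object."""
    # pydantic.VERSION is a plain string in both v1 and v2, e.g. "2.3.0"
    return version.parse(pydantic.VERSION)


PYDANTIC_V2 = pydantic_version().release >= (2, 0, 0)
```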
2 changes: 1 addition & 1 deletion pandera/engines/engine.py
@@ -27,7 +27,7 @@


# register different TypedDict type depending on python version
if sys.version_info >= (3, 9):
if sys.version_info >= (3, 12):
from typing import TypedDict
else:
from typing_extensions import TypedDict # noqa