Skip to content

Pydantic 2 support #847

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 24 commits into from
Sep 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions mp_api/client/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -842,7 +842,7 @@ def _submit_request_and_process(
data_model(
**{
field: value
for field, value in raw_doc.dict().items()
for field, value in raw_doc.model_dump().items()
if field in set_fields
}
)
Expand Down Expand Up @@ -877,29 +877,29 @@ def _submit_request_and_process(

def _generate_returned_model(self, doc):
set_fields = [
field for field, _ in doc if field in doc.dict(exclude_unset=True)
field for field, _ in doc if field in doc.model_dump(exclude_unset=True)
]
unset_fields = [field for field in doc.__fields__ if field not in set_fields]
unset_fields = [field for field in doc.model_fields if field not in set_fields]

data_model = create_model(
"MPDataDoc",
fields_not_requested=unset_fields,
fields_not_requested=(list[str], unset_fields),
__base__=self.document_model,
)

data_model.__fields__ = {
data_model.model_fields = {
**{
name: description
for name, description in data_model.__fields__.items()
for name, description in data_model.model_fields.items()
if name in set_fields
},
"fields_not_requested": data_model.__fields__["fields_not_requested"],
"fields_not_requested": data_model.model_fields["fields_not_requested"],
}

def new_repr(self) -> str:
extra = ",\n".join(
f"\033[1m{n}\033[0;0m={getattr(self, n)!r}"
for n in data_model.__fields__
for n in data_model.model_fields
)

s = f"\033[4m\033[1m{self.__class__.__name__}<{self.__class__.__base__.__name__}>\033[0;0m\033[0;0m(\n{extra}\n)" # noqa: E501
Expand All @@ -908,7 +908,7 @@ def new_repr(self) -> str:
def new_str(self) -> str:
extra = ",\n".join(
f"\033[1m{n}\033[0;0m={getattr(self, n)!r}"
for n in data_model.__fields__
for n in data_model.model_fields
if n != "fields_not_requested"
)

Expand All @@ -927,7 +927,7 @@ def new_getattr(self, attr) -> str:
)

def new_dict(self, *args, **kwargs):
d = super(data_model, self).dict(*args, **kwargs)
d = super(data_model, self).model_dump(*args, **kwargs)
return jsanitize(d)

data_model.__repr__ = new_repr
Expand Down Expand Up @@ -1155,7 +1155,7 @@ def count(self, criteria: dict | None = None) -> int | str:
def available_fields(self) -> list[str]:
if self.document_model is None:
return ["Unknown fields."]
return list(self.document_model.schema()["properties"].keys()) # type: ignore
return list(self.document_model.model_json_schema()["properties"].keys()) # type: ignore

def __repr__(self): # pragma: no cover
return f"<{self.__class__.__name__} {self.endpoint}>"
Expand Down
3 changes: 2 additions & 1 deletion mp_api/client/core/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from multiprocessing import cpu_count
from typing import List

from pydantic import BaseSettings, Field
from pydantic import Field
from pydantic_settings import BaseSettings
from pymatgen.core import _load_pmg_settings

from mp_api.client import __file__ as root_dir
Expand Down
35 changes: 18 additions & 17 deletions mp_api/client/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

import re
from functools import cache
from typing import get_args
from typing import Optional, get_args

from maggma.utils import get_flat_models_from_model
from monty.json import MSONable
from pydantic import BaseModel
from pydantic.schema import get_flat_models_from_model
from pydantic.utils import lenient_issubclass
from pydantic._internal._utils import lenient_issubclass
from pydantic.fields import FieldInfo


def validate_ids(id_list: list[str]):
Expand Down Expand Up @@ -62,33 +63,33 @@ def api_sanitize(

for model in models:
model_fields_to_leave = {f[1] for f in fields_tuples if model.__name__ == f[0]}
for name, field in model.__fields__.items():
field_type = field.type_

if name not in model_fields_to_leave:
field.required = False
field.default = None
field.default_factory = None
field.allow_none = True
field.field_info.default = None
field.field_info.default_factory = None
for name in model.model_fields:
field = model.model_fields[name]
field_type = field.annotation

if field_type is not None and allow_dict_msonable:
if lenient_issubclass(field_type, MSONable):
field.type_ = allow_msonable_dict(field_type)
field_type = allow_msonable_dict(field_type)
else:
for sub_type in get_args(field_type):
if lenient_issubclass(sub_type, MSONable):
allow_msonable_dict(sub_type)
field.populate_validators()

if name not in model_fields_to_leave:
new_field = FieldInfo.from_annotated_attribute(
Optional[field_type], None
)
model.model_fields[name] = new_field

model.model_rebuild(force=True)

return pydantic_model


def allow_msonable_dict(monty_cls: type[MSONable]):
"""Patch Monty to allow for dict values for MSONable."""

def validate_monty(cls, v):
def validate_monty(cls, v, _):
"""Stub validator for MSONable as a dictionary only."""
if isinstance(v, cls):
return v
Expand All @@ -110,6 +111,6 @@ def validate_monty(cls, v):
else:
raise ValueError(f"Must provide {cls.__name__} or MSONable dictionary")

monty_cls.validate_monty = classmethod(validate_monty)
monty_cls.validate_monty_v2 = classmethod(validate_monty)

return monty_cls
2 changes: 1 addition & 1 deletion mp_api/client/mprester.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,7 +748,7 @@ def get_entries(
if property_data:
for property in property_data:
entry_dict["data"][property] = (
doc.dict()[property]
doc.model_dump()[property]
if self.use_document_model
else doc[property]
)
Expand Down
6 changes: 3 additions & 3 deletions mp_api/client/routes/materials/electronic_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def get_bandstructure_from_material_id(
f"No {path_type.value} band structure data found for {material_id}"
)
else:
bs_data = bs_data.dict()
bs_data = bs_data.model_dump()

if bs_data.get(path_type.value, None):
bs_task_id = bs_data[path_type.value]["task_id"]
Expand All @@ -303,7 +303,7 @@ def get_bandstructure_from_material_id(
f"No uniform band structure data found for {material_id}"
)
else:
bs_data = bs_data.dict()
bs_data = bs_data.model_dump()

if bs_data.get("total", None):
bs_task_id = bs_data["total"]["1"]["task_id"]
Expand Down Expand Up @@ -444,7 +444,7 @@ def get_dos_from_material_id(self, material_id: str):

dos_data = es_rester.get_data_by_id(
document_id=material_id, fields=["dos"]
).dict()
).model_dump()

if dos_data["dos"]:
dos_task_id = dos_data["dos"]["total"]["1"]["task_id"]
Expand Down
7 changes: 4 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,17 @@ classifiers = [
dependencies = [
"setuptools",
"msgpack",
"maggma",
"pymatgen>=2022.3.7",
"typing-extensions>=3.7.4.1",
"requests>=2.23.0",
"monty>=2021.3.12",
"emmet-core>=0.54.0",
"monty>=2023.9.25",
"emmet-core>=0.69.2",
]
dynamic = ["version"]

[project.optional-dependencies]
all = ["emmet-core[all]>=0.54.0", "custodian", "mpcontribs-client", "boto3"]
all = ["emmet-core[all]>=0.69.1", "custodian", "mpcontribs-client", "boto3"]
test = [
"pre-commit",
"pytest",
Expand Down
2 changes: 1 addition & 1 deletion tests/materials/core_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def client_search_testing(
"num_chunks": 1,
}

doc = search_method(**q)[0].dict()
doc = search_method(**q)[0].model_dump()

for sub_field in sub_doc_fields:
if sub_field in doc:
Expand Down
4 changes: 2 additions & 2 deletions tests/materials/test_electronic_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def test_bs_client(bs_rester):
"chunk_size": 1,
"num_chunks": 1,
}
doc = search_method(**q)[0].dict()
doc = search_method(**q)[0].model_dump()

for sub_field in bs_sub_doc_fields:
if sub_field in doc:
Expand Down Expand Up @@ -137,7 +137,7 @@ def test_dos_client(dos_rester):
"chunk_size": 1,
"num_chunks": 1,
}
doc = search_method(**q)[0].dict()
doc = search_method(**q)[0].model_dump()
for sub_field in dos_sub_doc_fields:
if sub_field in doc:
doc = doc[sub_field]
Expand Down
2 changes: 1 addition & 1 deletion tests/molecules/core_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def client_search_testing(
docs = search_method(**q)

if len(docs) > 0:
doc = docs[0].dict()
doc = docs[0].model_dump()
else:
raise ValueError("No documents returned")

Expand Down
2 changes: 1 addition & 1 deletion tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def test_generic_get_methods(rester):

if name not in search_only_resters:
doc = rester.get_data_by_id(
doc.dict()[rester.primary_key], fields=[rester.primary_key]
doc.model_dump()[rester.primary_key], fields=[rester.primary_key]
)
assert isinstance(doc, rester.document_model)

Expand Down