Skip to content

Commit f85a9c6

Browse files
fix(types): handle more discriminated union shapes (#2206)
1 parent e2def44 commit f85a9c6

File tree

2 files changed

+37
-2
lines changed

2 files changed

+37
-2
lines changed

src/openai/_models.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@
6666
from ._constants import RAW_RESPONSE_HEADER
6767

6868
if TYPE_CHECKING:
69-
from pydantic_core.core_schema import ModelField, LiteralSchema, ModelFieldsSchema
69+
from pydantic_core.core_schema import ModelField, ModelSchema, LiteralSchema, ModelFieldsSchema
7070

7171
__all__ = ["BaseModel", "GenericModel"]
7272

@@ -671,15 +671,18 @@ def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any,
671671

672672
def _extract_field_schema_pv2(model: type[BaseModel], field_name: str) -> ModelField | None:
673673
schema = model.__pydantic_core_schema__
674+
if schema["type"] == "definitions":
675+
schema = schema["schema"]
676+
674677
if schema["type"] != "model":
675678
return None
676679

680+
schema = cast("ModelSchema", schema)
677681
fields_schema = schema["schema"]
678682
if fields_schema["type"] != "model-fields":
679683
return None
680684

681685
fields_schema = cast("ModelFieldsSchema", fields_schema)
682-
683686
field = fields_schema["fields"].get(field_name)
684687
if not field:
685688
return None

tests/test_models.py

+32
Original file line numberDiff line numberDiff line change
@@ -854,3 +854,35 @@ class Model(BaseModel):
854854
m = construct_type(value={"cls": "foo"}, type_=Model)
855855
assert isinstance(m, Model)
856856
assert isinstance(m.cls, str)
857+
858+
859+
def test_discriminated_union_case() -> None:
860+
class A(BaseModel):
861+
type: Literal["a"]
862+
863+
data: bool
864+
865+
class B(BaseModel):
866+
type: Literal["b"]
867+
868+
data: List[Union[A, object]]
869+
870+
class ModelA(BaseModel):
871+
type: Literal["modelA"]
872+
873+
data: int
874+
875+
class ModelB(BaseModel):
876+
type: Literal["modelB"]
877+
878+
required: str
879+
880+
data: Union[A, B]
881+
882+
# when constructing ModelA | ModelB, value data doesn't match ModelB exactly - missing `required`
883+
m = construct_type(
884+
value={"type": "modelB", "data": {"type": "a", "data": True}},
885+
type_=cast(Any, Annotated[Union[ModelA, ModelB], PropertyInfo(discriminator="type")]),
886+
)
887+
888+
assert isinstance(m, ModelB)

0 commit comments

Comments
 (0)