Skip to content

Commit c7a3d29

Browse files
committed
fix(json schema): unravel $refs alongside additional keys
1 parent 53d964d commit c7a3d29

File tree

3 files changed

+119
-17
lines changed

3 files changed

+119
-17
lines changed

src/openai/lib/_pydantic.py

+59-14
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,32 @@
1010

1111

1212
def to_strict_json_schema(model: type[pydantic.BaseModel]) -> dict[str, Any]:
13-
return _ensure_strict_json_schema(model_json_schema(model), path=())
13+
schema = model_json_schema(model)
14+
return _ensure_strict_json_schema(schema, path=(), root=schema)
1415

1516

1617
def _ensure_strict_json_schema(
1718
json_schema: object,
19+
*,
1820
path: tuple[str, ...],
21+
root: dict[str, object],
1922
) -> dict[str, Any]:
2023
"""Mutates the given JSON schema to ensure it conforms to the `strict` standard
2124
that the API expects.
2225
"""
2326
if not is_dict(json_schema):
2427
raise TypeError(f"Expected {json_schema} to be a dictionary; path={path}")
2528

29+
defs = json_schema.get("$defs")
30+
if is_dict(defs):
31+
for def_name, def_schema in defs.items():
32+
_ensure_strict_json_schema(def_schema, path=(*path, "$defs", def_name), root=root)
33+
34+
definitions = json_schema.get("definitions")
35+
if is_dict(definitions):
36+
for definition_name, definition_schema in definitions.items():
37+
_ensure_strict_json_schema(definition_schema, path=(*path, "definitions", definition_name), root=root)
38+
2639
typ = json_schema.get("type")
2740
if typ == "object" and "additionalProperties" not in json_schema:
2841
json_schema["additionalProperties"] = False
@@ -33,48 +46,80 @@ def _ensure_strict_json_schema(
3346
if is_dict(properties):
3447
json_schema["required"] = [prop for prop in properties.keys()]
3548
json_schema["properties"] = {
36-
key: _ensure_strict_json_schema(prop_schema, path=(*path, "properties", key))
49+
key: _ensure_strict_json_schema(prop_schema, path=(*path, "properties", key), root=root)
3750
for key, prop_schema in properties.items()
3851
}
3952

4053
# arrays
4154
# { 'type': 'array', 'items': {...} }
4255
items = json_schema.get("items")
4356
if is_dict(items):
44-
json_schema["items"] = _ensure_strict_json_schema(items, path=(*path, "items"))
57+
json_schema["items"] = _ensure_strict_json_schema(items, path=(*path, "items"), root=root)
4558

4659
# unions
4760
any_of = json_schema.get("anyOf")
4861
if is_list(any_of):
4962
json_schema["anyOf"] = [
50-
_ensure_strict_json_schema(variant, path=(*path, "anyOf", str(i))) for i, variant in enumerate(any_of)
63+
_ensure_strict_json_schema(variant, path=(*path, "anyOf", str(i)), root=root)
64+
for i, variant in enumerate(any_of)
5165
]
5266

5367
# intersections
5468
all_of = json_schema.get("allOf")
5569
if is_list(all_of):
5670
if len(all_of) == 1:
57-
json_schema.update(_ensure_strict_json_schema(all_of[0], path=(*path, "allOf", "0")))
71+
json_schema.update(_ensure_strict_json_schema(all_of[0], path=(*path, "allOf", "0"), root=root))
5872
json_schema.pop("allOf")
5973
else:
6074
json_schema["allOf"] = [
61-
_ensure_strict_json_schema(entry, path=(*path, "allOf", str(i))) for i, entry in enumerate(all_of)
75+
_ensure_strict_json_schema(entry, path=(*path, "allOf", str(i)), root=root)
76+
for i, entry in enumerate(all_of)
6277
]
6378

64-
defs = json_schema.get("$defs")
65-
if is_dict(defs):
66-
for def_name, def_schema in defs.items():
67-
_ensure_strict_json_schema(def_schema, path=(*path, "$defs", def_name))
79+
# we can't use `$ref`s if there are also other properties defined, e.g.
80+
# `{"$ref": "...", "description": "my description"}`
81+
#
82+
# so we unravel the ref
83+
# `{"type": "string", "description": "my description"}`
84+
ref = json_schema.get("$ref")
85+
if ref and has_more_than_n_keys(json_schema, 1):
86+
assert isinstance(ref, str), f"Received non-string $ref - {ref}"
6887

69-
definitions = json_schema.get("definitions")
70-
if is_dict(definitions):
71-
for definition_name, definition_schema in definitions.items():
72-
_ensure_strict_json_schema(definition_schema, path=(*path, "definitions", definition_name))
88+
resolved = resolve_ref(root=root, ref=ref)
89+
if not is_dict(resolved):
90+
raise ValueError(f"Expected `$ref: {ref}` to resolved to a dictionary but got {resolved}")
91+
92+
# properties from the json schema take priority over the ones on the `$ref`
93+
json_schema.update({**resolved, **json_schema})
94+
json_schema.pop("$ref")
7395

7496
return json_schema
7597

7698

99+
def resolve_ref(*, root: dict[str, object], ref: str) -> object:
100+
if not ref.startswith("#/"):
101+
raise ValueError(f"Unexpected $ref format {ref!r}; Does not start with #/")
102+
103+
path = ref[2:].split("/")
104+
resolved = root
105+
for key in path:
106+
value = resolved[key]
107+
assert is_dict(value), f"encountered non-dictionary entry while resolving {ref} - {resolved}"
108+
resolved = value
109+
110+
return resolved
111+
112+
77113
def is_dict(obj: object) -> TypeGuard[dict[str, object]]:
78114
# just pretend that we know there are only `str` keys
79115
# as that check is not worth the performance cost
80116
return _is_dict(obj)
117+
118+
119+
def has_more_than_n_keys(obj: dict[str, object], n: int) -> bool:
120+
i = 0
121+
for _ in obj.keys():
122+
i += 1
123+
if i > n:
124+
return True
125+
return False

tests/lib/chat/test_completions.py

+49-1
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,14 @@
22

33
import os
44
import json
5+
from enum import Enum
56
from typing import Any, Callable
67
from typing_extensions import Literal, TypeVar
78

89
import httpx
910
import pytest
1011
from respx import MockRouter
11-
from pydantic import BaseModel
12+
from pydantic import Field, BaseModel
1213
from inline_snapshot import snapshot
1314

1415
import openai
@@ -133,6 +134,53 @@ class Location(BaseModel):
133134
)
134135

135136

137+
@pytest.mark.respx(base_url=base_url)
138+
def test_parse_pydantic_model_enum(client: OpenAI, respx_mock: MockRouter, monkeypatch: pytest.MonkeyPatch) -> None:
139+
class Color(Enum):
140+
"""The detected color"""
141+
142+
RED = "red"
143+
BLUE = "blue"
144+
GREEN = "green"
145+
146+
class ColorDetection(BaseModel):
147+
color: Color
148+
hex_color_code: str = Field(description="The hex color code of the detected color")
149+
150+
completion = _make_snapshot_request(
151+
lambda c: c.beta.chat.completions.parse(
152+
model="gpt-4o-2024-08-06",
153+
messages=[
154+
{"role": "user", "content": "What color is a Coke can?"},
155+
],
156+
response_format=ColorDetection,
157+
),
158+
content_snapshot=snapshot(
159+
'{"id": "chatcmpl-9vK4UZVr385F2UgZlP1ShwPn2nFxG", "object": "chat.completion", "created": 1723448878, "model": "gpt-4o-2024-08-06", "choices": [{"index": 0, "message": {"role": "assistant", "content": "{\\"color\\":\\"red\\",\\"hex_color_code\\":\\"#FF0000\\"}", "refusal": null}, "logprobs": null, "finish_reason": "stop"}], "usage": {"prompt_tokens": 18, "completion_tokens": 14, "total_tokens": 32}, "system_fingerprint": "fp_845eaabc1f"}'
160+
),
161+
mock_client=client,
162+
respx_mock=respx_mock,
163+
)
164+
165+
assert print_obj(completion.choices[0], monkeypatch) == snapshot(
166+
"""\
167+
ParsedChoice[ColorDetection](
168+
finish_reason='stop',
169+
index=0,
170+
logprobs=None,
171+
message=ParsedChatCompletionMessage[ColorDetection](
172+
content='{"color":"red","hex_color_code":"#FF0000"}',
173+
function_call=None,
174+
parsed=ColorDetection(color=<Color.RED: 'red'>, hex_color_code='#FF0000'),
175+
refusal=None,
176+
role='assistant',
177+
tool_calls=[]
178+
)
179+
)
180+
"""
181+
)
182+
183+
136184
@pytest.mark.respx(base_url=base_url)
137185
def test_parse_pydantic_model_multiple_choices(
138186
client: OpenAI, respx_mock: MockRouter, monkeypatch: pytest.MonkeyPatch

tests/lib/test_pydantic.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,12 @@ def test_enums() -> None:
186186
"parameters": {
187187
"$defs": {"Color": {"enum": ["red", "blue", "green"], "title": "Color", "type": "string"}},
188188
"properties": {
189-
"color": {"description": "The detected color", "$ref": "#/$defs/Color"},
189+
"color": {
190+
"description": "The detected color",
191+
"enum": ["red", "blue", "green"],
192+
"title": "Color",
193+
"type": "string",
194+
},
190195
"hex_color_code": {
191196
"description": "The hex color code of the detected color",
192197
"title": "Hex Color Code",
@@ -207,7 +212,11 @@ def test_enums() -> None:
207212
"strict": True,
208213
"parameters": {
209214
"properties": {
210-
"color": {"description": "The detected color", "$ref": "#/definitions/Color"},
215+
"color": {
216+
"description": "The detected color",
217+
"title": "Color",
218+
"enum": ["red", "blue", "green"],
219+
},
211220
"hex_color_code": {
212221
"description": "The hex color code of the detected color",
213222
"title": "Hex Color Code",

0 commit comments

Comments
 (0)