
Commit 101bee9

feat(parsing): add support for pydantic dataclasses (#1655)
1 parent af2a1ca commit 101bee9
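
In practical terms, the change lets `client.beta.chat.completions.parse()` accept a type decorated with `@pydantic.dataclass` as `response_format`, not just `pydantic.BaseModel` subclasses. A minimal usage sketch (not part of the diff; the model name and prompt are taken from the new test below, and a configured `OPENAI_API_KEY` is assumed):

from typing import List

import openai
from pydantic.dataclasses import dataclass


@dataclass
class CalendarEvent:
    name: str
    date: str
    participants: List[str]


client = openai.OpenAI()  # reads OPENAI_API_KEY from the environment
completion = client.beta.chat.completions.parse(
    model="gpt-4o-2024-08-06",
    messages=[
        {"role": "system", "content": "Extract the event information."},
        {"role": "user", "content": "Alice and Bob are going to a science fair on Friday."},
    ],
    response_format=CalendarEvent,  # a pydantic dataclass, not a BaseModel subclass
)
event = completion.choices[0].message.parsed  # e.g. CalendarEvent(name='Science Fair', ...)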

File tree

3 files changed: +99 -14 lines

  src/openai/lib/_parsing/_completions.py
  src/openai/lib/_pydantic.py
  tests/lib/chat/test_completions.py


src/openai/lib/_parsing/_completions.py (+19 -9)

@@ -9,9 +9,9 @@
 from .._tools import PydanticFunctionTool
 from ..._types import NOT_GIVEN, NotGiven
 from ..._utils import is_dict, is_given
-from ..._compat import model_parse_json
+from ..._compat import PYDANTIC_V2, model_parse_json
 from ..._models import construct_type_unchecked
-from .._pydantic import to_strict_json_schema
+from .._pydantic import is_basemodel_type, to_strict_json_schema, is_dataclass_like_type
 from ...types.chat import (
     ParsedChoice,
     ChatCompletion,

@@ -216,14 +216,16 @@ def is_parseable_tool(input_tool: ChatCompletionToolParam) -> bool:
     return cast(FunctionDefinition, input_fn).get("strict") or False


-def is_basemodel_type(typ: type) -> TypeGuard[type[pydantic.BaseModel]]:
-    return issubclass(typ, pydantic.BaseModel)
-
-
 def _parse_content(response_format: type[ResponseFormatT], content: str) -> ResponseFormatT:
     if is_basemodel_type(response_format):
         return cast(ResponseFormatT, model_parse_json(response_format, content))

+    if is_dataclass_like_type(response_format):
+        if not PYDANTIC_V2:
+            raise TypeError(f"Non BaseModel types are only supported with Pydantic v2 - {response_format}")
+
+        return pydantic.TypeAdapter(response_format).validate_json(content)
+
     raise TypeError(f"Unable to automatically parse response format type {response_format}")

@@ -241,14 +243,22 @@ def type_to_response_format_param(
     # can only be a `type`
     response_format = cast(type, response_format)

-    if not is_basemodel_type(response_format):
+    json_schema_type: type[pydantic.BaseModel] | pydantic.TypeAdapter[Any] | None = None
+
+    if is_basemodel_type(response_format):
+        name = response_format.__name__
+        json_schema_type = response_format
+    elif is_dataclass_like_type(response_format):
+        name = response_format.__name__
+        json_schema_type = pydantic.TypeAdapter(response_format)
+    else:
         raise TypeError(f"Unsupported response_format type - {response_format}")

     return {
         "type": "json_schema",
         "json_schema": {
-            "schema": to_strict_json_schema(response_format),
-            "name": response_format.__name__,
+            "schema": to_strict_json_schema(json_schema_type),
+            "name": name,
             "strict": True,
         },
     }
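
For illustration only (not SDK code): the new branch in `_parse_content` hands dataclass-like types to `pydantic.TypeAdapter`, so parsing the model's JSON output reduces to a sketch like the following, assuming Pydantic v2:

from typing import List

import pydantic
from pydantic.dataclasses import dataclass


@dataclass
class CalendarEvent:
    name: str
    date: str
    participants: List[str]


# validate_json parses the JSON string and validates it against the dataclass annotations
event = pydantic.TypeAdapter(CalendarEvent).validate_json(
    '{"name": "Science Fair", "date": "Friday", "participants": ["Alice", "Bob"]}'
)
print(event)  # CalendarEvent(name='Science Fair', date='Friday', participants=['Alice', 'Bob'])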

src/openai/lib/_pydantic.py (+22 -4)

@@ -1,17 +1,26 @@
 from __future__ import annotations

-from typing import Any
+import inspect
+from typing import Any, TypeVar
 from typing_extensions import TypeGuard

 import pydantic

 from .._types import NOT_GIVEN
 from .._utils import is_dict as _is_dict, is_list
-from .._compat import model_json_schema
+from .._compat import PYDANTIC_V2, model_json_schema

+_T = TypeVar("_T")
+
+
+def to_strict_json_schema(model: type[pydantic.BaseModel] | pydantic.TypeAdapter[Any]) -> dict[str, Any]:
+    if inspect.isclass(model) and is_basemodel_type(model):
+        schema = model_json_schema(model)
+    elif PYDANTIC_V2 and isinstance(model, pydantic.TypeAdapter):
+        schema = model.json_schema()
+    else:
+        raise TypeError(f"Non BaseModel types are only supported with Pydantic v2 - {model}")

-def to_strict_json_schema(model: type[pydantic.BaseModel]) -> dict[str, Any]:
-    schema = model_json_schema(model)
     return _ensure_strict_json_schema(schema, path=(), root=schema)

@@ -117,6 +126,15 @@ def resolve_ref(*, root: dict[str, object], ref: str) -> object:
     return resolved


+def is_basemodel_type(typ: type) -> TypeGuard[type[pydantic.BaseModel]]:
+    return issubclass(typ, pydantic.BaseModel)
+
+
+def is_dataclass_like_type(typ: type) -> bool:
+    """Returns True if the given type likely used `@pydantic.dataclass`"""
+    return hasattr(typ, "__pydantic_config__")
+
+
 def is_dict(obj: object) -> TypeGuard[dict[str, object]]:
     # just pretend that we know there are only `str` keys
     # as that check is not worth the performance cost
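
As a rough illustration of the two new helpers (a sketch, assuming Pydantic v2): `@pydantic.dataclass` attaches a `__pydantic_config__` attribute that `is_dataclass_like_type` looks for, and `pydantic.TypeAdapter(...).json_schema()` supplies the plain schema that `to_strict_json_schema` then rewrites into strict mode:

import dataclasses

import pydantic
from pydantic.dataclasses import dataclass as pydantic_dataclass


@pydantic_dataclass
class CalendarEvent:
    name: str
    date: str


@dataclasses.dataclass
class PlainEvent:  # a stdlib dataclass, for contrast
    name: str


print(hasattr(CalendarEvent, "__pydantic_config__"))  # True  -> treated as dataclass-like
print(hasattr(PlainEvent, "__pydantic_config__"))     # False -> falls through to the TypeError paths above
print(pydantic.TypeAdapter(CalendarEvent).json_schema())  # non-strict schema, before _ensure_strict_json_schema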

tests/lib/chat/test_completions.py (+58 -1)

@@ -3,7 +3,7 @@
 import os
 import json
 from enum import Enum
-from typing import Any, Callable, Optional
+from typing import Any, List, Callable, Optional
 from typing_extensions import Literal, TypeVar

 import httpx

@@ -317,6 +317,63 @@ class Location(BaseModel):
     )


+@pytest.mark.respx(base_url=base_url)
+@pytest.mark.skipif(not PYDANTIC_V2, reason="dataclasses only supported in v2")
+def test_parse_pydantic_dataclass(client: OpenAI, respx_mock: MockRouter, monkeypatch: pytest.MonkeyPatch) -> None:
+    from pydantic.dataclasses import dataclass
+
+    @dataclass
+    class CalendarEvent:
+        name: str
+        date: str
+        participants: List[str]
+
+    completion = _make_snapshot_request(
+        lambda c: c.beta.chat.completions.parse(
+            model="gpt-4o-2024-08-06",
+            messages=[
+                {"role": "system", "content": "Extract the event information."},
+                {"role": "user", "content": "Alice and Bob are going to a science fair on Friday."},
+            ],
+            response_format=CalendarEvent,
+        ),
+        content_snapshot=snapshot(
+            '{"id": "chatcmpl-9wdGqXkJJARAz7rOrLH5u5FBwLjF3", "object": "chat.completion", "created": 1723761008, "model": "gpt-4o-2024-08-06", "choices": [{"index": 0, "message": {"role": "assistant", "content": "{\\"name\\":\\"Science Fair\\",\\"date\\":\\"Friday\\",\\"participants\\":[\\"Alice\\",\\"Bob\\"]}", "refusal": null}, "logprobs": null, "finish_reason": "stop"}], "usage": {"prompt_tokens": 32, "completion_tokens": 17, "total_tokens": 49}, "system_fingerprint": "fp_2a322c9ffc"}'
+        ),
+        mock_client=client,
+        respx_mock=respx_mock,
+    )
+
+    assert print_obj(completion, monkeypatch) == snapshot(
+        """\
+ParsedChatCompletion[CalendarEvent](
+    choices=[
+        ParsedChoice[CalendarEvent](
+            finish_reason='stop',
+            index=0,
+            logprobs=None,
+            message=ParsedChatCompletionMessage[CalendarEvent](
+                content='{"name":"Science Fair","date":"Friday","participants":["Alice","Bob"]}',
+                function_call=None,
+                parsed=CalendarEvent(name='Science Fair', date='Friday', participants=['Alice', 'Bob']),
+                refusal=None,
+                role='assistant',
+                tool_calls=[]
+            )
+        )
+    ],
+    created=1723761008,
+    id='chatcmpl-9wdGqXkJJARAz7rOrLH5u5FBwLjF3',
+    model='gpt-4o-2024-08-06',
+    object='chat.completion',
+    service_tier=None,
+    system_fingerprint='fp_2a322c9ffc',
+    usage=CompletionUsage(completion_tokens=17, prompt_tokens=32, total_tokens=49)
+)
+"""
+    )
+
+
 @pytest.mark.respx(base_url=base_url)
 def test_pydantic_tool_model_all_types(client: OpenAI, respx_mock: MockRouter, monkeypatch: pytest.MonkeyPatch) -> None:
     completion = _make_snapshot_request(
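
The new test, like the library branches above, is gated on Pydantic v2 via the `PYDANTIC_V2` flag imported from `openai._compat`. Its exact definition is not part of this diff; a rough sketch of an equivalent check, for reference only:

import pydantic

# assumed to be roughly what the SDK's PYDANTIC_V2 flag amounts to
PYDANTIC_V2 = pydantic.VERSION.startswith("2.")
print(PYDANTIC_V2)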
