
Commit 18191da

fix(client): raise helpful error message for response_format misuse
1 parent 631a2a7 commit 18191da

File tree

2 files changed: +46 -0 lines changed

src/openai/resources/chat/completions.py (+11 lines)
@@ -2,10 +2,12 @@
 
 from __future__ import annotations
 
+import inspect
 from typing import Dict, List, Union, Iterable, Optional, overload
 from typing_extensions import Literal
 
 import httpx
+import pydantic
 
 from ... import _legacy_response
 from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven
@@ -647,6 +649,7 @@ def create(
         extra_body: Body | None = None,
         timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
     ) -> ChatCompletion | Stream[ChatCompletionChunk]:
+        validate_response_format(response_format)
         return self._post(
             "/chat/completions",
             body=maybe_transform(
@@ -1302,6 +1305,7 @@ async def create(
         extra_body: Body | None = None,
         timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
     ) -> ChatCompletion | AsyncStream[ChatCompletionChunk]:
+        validate_response_format(response_format)
         return await self._post(
             "/chat/completions",
             body=await async_maybe_transform(
@@ -1375,3 +1379,10 @@ def __init__(self, completions: AsyncCompletions) -> None:
         self.create = async_to_streamed_response_wrapper(
             completions.create,
         )
+
+
+def validate_response_format(response_format: object) -> None:
+    if inspect.isclass(response_format) and issubclass(response_format, pydantic.BaseModel):
+        raise TypeError(
+            "You tried to pass a `BaseModel` class to `chat.completions.create()`; You must use `beta.chat.completions.parse()` instead"
+        )
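The guard only rejects the misuse; callers who hit the new TypeError are pointed to the parse helper named in the message. A minimal sketch of the failing call and the suggested replacement, assuming a recent openai-python release that ships client.beta.chat.completions.parse(), an OPENAI_API_KEY in the environment, and an illustrative model name:

    import pydantic
    from openai import OpenAI


    class Answer(pydantic.BaseModel):
        value: str


    client = OpenAI()

    # After this commit, passing the BaseModel class itself raises immediately
    # instead of sending a malformed request to the API.
    try:
        client.chat.completions.create(
            model="gpt-4o",
            messages=[{"role": "user", "content": "What is 2 + 2?"}],
            response_format=Answer,  # type: ignore[arg-type]
        )
    except TypeError as exc:
        print(exc)  # "You tried to pass a `BaseModel` class to `chat.completions.create()`; ..."

    # The path the error message recommends: the beta parse helper accepts the
    # model class directly (sketch; exact fields may vary by SDK version).
    completion = client.beta.chat.completions.parse(
        model="gpt-4o",
        messages=[{"role": "user", "content": "What is 2 + 2?"}],
        response_format=Answer,
    )
    print(completion.choices[0].message.parsed)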

tests/api_resources/chat/test_completions.py (+35 lines)
@@ -6,6 +6,7 @@
 from typing import Any, cast
 
 import pytest
+import pydantic
 
 from openai import OpenAI, AsyncOpenAI
 from tests.utils import assert_matches_type
@@ -257,6 +258,23 @@ def test_streaming_response_create_overload_2(self, client: OpenAI) -> None:
 
         assert cast(Any, response.is_closed) is True
 
+    @parametrize
+    def test_method_create_disallows_pydantic(self, client: OpenAI) -> None:
+        class MyModel(pydantic.BaseModel):
+            a: str
+
+        with pytest.raises(TypeError, match=r"You tried to pass a `BaseModel` class"):
+            client.chat.completions.create(
+                messages=[
+                    {
+                        "content": "string",
+                        "role": "system",
+                    }
+                ],
+                model="gpt-4o",
+                response_format=cast(Any, MyModel),
+            )
+
 
 class TestAsyncCompletions:
     parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
@@ -498,3 +516,20 @@ async def test_streaming_response_create_overload_2(self, async_client: AsyncOpe
             await stream.close()
 
         assert cast(Any, response.is_closed) is True
+
+    @parametrize
+    async def test_method_create_disallows_pydantic(self, async_client: AsyncOpenAI) -> None:
+        class MyModel(pydantic.BaseModel):
+            a: str
+
+        with pytest.raises(TypeError, match=r"You tried to pass a `BaseModel` class"):
+            await async_client.chat.completions.create(
+                messages=[
+                    {
+                        "content": "string",
+                        "role": "system",
+                    }
+                ],
+                model="gpt-4o",
+                response_format=cast(Any, MyModel),
+            )
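The tests cover both the sync and async clients, so the guard fires on AsyncOpenAI as well. For completeness, a brief sketch of the async-side migration, assuming the same beta parse helper exists on the async client (the method name comes from the error message; response field names may vary by SDK version):

    import asyncio

    import pydantic
    from openai import AsyncOpenAI


    class Answer(pydantic.BaseModel):
        value: str


    async def main() -> None:
        client = AsyncOpenAI()  # assumes OPENAI_API_KEY is set

        # chat.completions.create() now raises TypeError for a BaseModel class,
        # so structured output goes through the beta parse helper instead.
        completion = await client.beta.chat.completions.parse(
            model="gpt-4o",
            messages=[{"role": "user", "content": "Reply with a short answer."}],
            response_format=Answer,
        )
        print(completion.choices[0].message.parsed)


    asyncio.run(main())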
