Skip to content

Commit 8598f81

Browse files
chore(internal): support parsing Annotated types (#1222)
1 parent 3c2e815 commit 8598f81

File tree

7 files changed

+108
-6
lines changed

7 files changed

+108
-6
lines changed

src/openai/_legacy_response.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import pydantic
1414

1515
from ._types import NoneType
16-
from ._utils import is_given
16+
from ._utils import is_given, extract_type_arg, is_annotated_type
1717
from ._models import BaseModel, is_basemodel
1818
from ._constants import RAW_RESPONSE_HEADER
1919
from ._streaming import Stream, AsyncStream, is_stream_class_type, extract_stream_chunk_type
@@ -174,6 +174,10 @@ def elapsed(self) -> datetime.timedelta:
174174
return self.http_response.elapsed
175175

176176
def _parse(self, *, to: type[_T] | None = None) -> R | _T:
177+
# unwrap `Annotated[T, ...]` -> `T`
178+
if to and is_annotated_type(to):
179+
to = extract_type_arg(to, 0)
180+
177181
if self._stream:
178182
if to:
179183
if not is_stream_class_type(to):
@@ -215,6 +219,11 @@ def _parse(self, *, to: type[_T] | None = None) -> R | _T:
215219
)
216220

217221
cast_to = to if to is not None else self._cast_to
222+
223+
# unwrap `Annotated[T, ...]` -> `T`
224+
if is_annotated_type(cast_to):
225+
cast_to = extract_type_arg(cast_to, 0)
226+
218227
if cast_to is NoneType:
219228
return cast(R, None)
220229

src/openai/_models.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,16 @@
3030
AnyMapping,
3131
HttpxRequestFiles,
3232
)
33-
from ._utils import is_list, is_given, is_mapping, parse_date, parse_datetime, strip_not_given
33+
from ._utils import (
34+
is_list,
35+
is_given,
36+
is_mapping,
37+
parse_date,
38+
parse_datetime,
39+
strip_not_given,
40+
extract_type_arg,
41+
is_annotated_type,
42+
)
3443
from ._compat import (
3544
PYDANTIC_V2,
3645
ConfigDict,
@@ -275,6 +284,9 @@ def construct_type(*, value: object, type_: type) -> object:
275284
276285
If the given value does not match the expected type then it is returned as-is.
277286
"""
287+
# unwrap `Annotated[T, ...]` -> `T`
288+
if is_annotated_type(type_):
289+
type_ = extract_type_arg(type_, 0)
278290

279291
# we need to use the origin class for any types that are subscripted generics
280292
# e.g. Dict[str, object]

src/openai/_response.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
import pydantic
2626

2727
from ._types import NoneType
28-
from ._utils import is_given, extract_type_var_from_base
28+
from ._utils import is_given, extract_type_arg, is_annotated_type, extract_type_var_from_base
2929
from ._models import BaseModel, is_basemodel
3030
from ._constants import RAW_RESPONSE_HEADER, OVERRIDE_CAST_TO_HEADER
3131
from ._streaming import Stream, AsyncStream, is_stream_class_type, extract_stream_chunk_type
@@ -121,6 +121,10 @@ def __repr__(self) -> str:
121121
)
122122

123123
def _parse(self, *, to: type[_T] | None = None) -> R | _T:
124+
# unwrap `Annotated[T, ...]` -> `T`
125+
if to and is_annotated_type(to):
126+
to = extract_type_arg(to, 0)
127+
124128
if self._is_sse_stream:
125129
if to:
126130
if not is_stream_class_type(to):
@@ -162,6 +166,11 @@ def _parse(self, *, to: type[_T] | None = None) -> R | _T:
162166
)
163167

164168
cast_to = to if to is not None else self._cast_to
169+
170+
# unwrap `Annotated[T, ...]` -> `T`
171+
if is_annotated_type(cast_to):
172+
cast_to = extract_type_arg(cast_to, 0)
173+
165174
if cast_to is NoneType:
166175
return cast(R, None)
167176

tests/test_legacy_response.py

+19
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import json
2+
from typing import cast
3+
from typing_extensions import Annotated
24

35
import httpx
46
import pytest
@@ -63,3 +65,20 @@ def test_response_parse_custom_model(client: OpenAI) -> None:
6365
obj = response.parse(to=CustomModel)
6466
assert obj.foo == "hello!"
6567
assert obj.bar == 2
68+
69+
70+
def test_response_parse_annotated_type(client: OpenAI) -> None:
71+
response = LegacyAPIResponse(
72+
raw=httpx.Response(200, content=json.dumps({"foo": "hello!", "bar": 2})),
73+
client=client,
74+
stream=False,
75+
stream_cls=None,
76+
cast_to=str,
77+
options=FinalRequestOptions.construct(method="get", url="/foo"),
78+
)
79+
80+
obj = response.parse(
81+
to=cast("type[CustomModel]", Annotated[CustomModel, "random metadata"]),
82+
)
83+
assert obj.foo == "hello!"
84+
assert obj.bar == 2

tests/test_models.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
import json
22
from typing import Any, Dict, List, Union, Optional, cast
33
from datetime import datetime, timezone
4-
from typing_extensions import Literal
4+
from typing_extensions import Literal, Annotated
55

66
import pytest
77
import pydantic
88
from pydantic import Field
99

1010
from openai._compat import PYDANTIC_V2, parse_obj, model_dump, model_json
11-
from openai._models import BaseModel
11+
from openai._models import BaseModel, construct_type
1212

1313

1414
class BasicModel(BaseModel):
@@ -571,3 +571,15 @@ class OurModel(BaseModel):
571571
foo: Optional[str] = None
572572

573573
takes_pydantic(OurModel())
574+
575+
576+
def test_annotated_types() -> None:
577+
class Model(BaseModel):
578+
value: str
579+
580+
m = construct_type(
581+
value={"value": "foo"},
582+
type_=cast(Any, Annotated[Model, "random metadata"]),
583+
)
584+
assert isinstance(m, Model)
585+
assert m.value == "foo"

tests/test_response.py

+36-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import json
2-
from typing import List
2+
from typing import List, cast
3+
from typing_extensions import Annotated
34

45
import httpx
56
import pytest
@@ -157,3 +158,37 @@ async def test_async_response_parse_custom_model(async_client: AsyncOpenAI) -> N
157158
obj = await response.parse(to=CustomModel)
158159
assert obj.foo == "hello!"
159160
assert obj.bar == 2
161+
162+
163+
def test_response_parse_annotated_type(client: OpenAI) -> None:
164+
response = APIResponse(
165+
raw=httpx.Response(200, content=json.dumps({"foo": "hello!", "bar": 2})),
166+
client=client,
167+
stream=False,
168+
stream_cls=None,
169+
cast_to=str,
170+
options=FinalRequestOptions.construct(method="get", url="/foo"),
171+
)
172+
173+
obj = response.parse(
174+
to=cast("type[CustomModel]", Annotated[CustomModel, "random metadata"]),
175+
)
176+
assert obj.foo == "hello!"
177+
assert obj.bar == 2
178+
179+
180+
async def test_async_response_parse_annotated_type(async_client: AsyncOpenAI) -> None:
181+
response = AsyncAPIResponse(
182+
raw=httpx.Response(200, content=json.dumps({"foo": "hello!", "bar": 2})),
183+
client=async_client,
184+
stream=False,
185+
stream_cls=None,
186+
cast_to=str,
187+
options=FinalRequestOptions.construct(method="get", url="/foo"),
188+
)
189+
190+
obj = await response.parse(
191+
to=cast("type[CustomModel]", Annotated[CustomModel, "random metadata"]),
192+
)
193+
assert obj.foo == "hello!"
194+
assert obj.bar == 2

tests/utils.py

+6
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
is_list,
1515
is_list_type,
1616
is_union_type,
17+
extract_type_arg,
18+
is_annotated_type,
1719
)
1820
from openai._compat import PYDANTIC_V2, field_outer_type, get_model_fields
1921
from openai._models import BaseModel
@@ -49,6 +51,10 @@ def assert_matches_type(
4951
path: list[str],
5052
allow_none: bool = False,
5153
) -> None:
54+
# unwrap `Annotated[T, ...]` -> `T`
55+
if is_annotated_type(type_):
56+
type_ = extract_type_arg(type_, 0)
57+
5258
if allow_none and value is None:
5359
return
5460

0 commit comments

Comments
 (0)