Skip to content

Commit 5ba576a

Browse files
chore(internal): minor utils restructuring (#992)
1 parent 6c3427d commit 5ba576a

File tree

8 files changed

+183
-66
lines changed

8 files changed

+183
-66
lines changed

src/openai/_response.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55
import datetime
66
import functools
77
from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, cast
8-
from typing_extensions import Awaitable, ParamSpec, get_args, override, get_origin
8+
from typing_extensions import Awaitable, ParamSpec, override, get_origin
99

1010
import httpx
1111

1212
from ._types import NoneType, UnknownResponse, BinaryResponseContent
13-
from ._utils import is_given
13+
from ._utils import is_given, extract_type_var_from_base
1414
from ._models import BaseModel, is_basemodel
1515
from ._constants import RAW_RESPONSE_HEADER
1616
from ._exceptions import APIResponseValidationError
@@ -221,12 +221,13 @@ def __init__(self) -> None:
221221

222222

223223
def _extract_stream_chunk_type(stream_cls: type) -> type:
224-
args = get_args(stream_cls)
225-
if not args:
226-
raise TypeError(
227-
f"Expected stream_cls to have been given a generic type argument, e.g. Stream[Foo] but received {stream_cls}",
228-
)
229-
return cast(type, args[0])
224+
from ._base_client import Stream, AsyncStream
225+
226+
return extract_type_var_from_base(
227+
stream_cls,
228+
index=0,
229+
generic_bases=cast("tuple[type, ...]", (Stream, AsyncStream)),
230+
)
230231

231232

232233
def to_raw_response_wrapper(func: Callable[P, R]) -> Callable[P, APIResponse[R]]:

src/openai/_streaming.py

+56-15
Original file line numberDiff line numberDiff line change
@@ -2,28 +2,31 @@
22
from __future__ import annotations
33

44
import json
5-
from typing import TYPE_CHECKING, Any, Generic, Iterator, AsyncIterator
6-
from typing_extensions import override
5+
from types import TracebackType
6+
from typing import TYPE_CHECKING, Any, Generic, TypeVar, Iterator, AsyncIterator, cast
7+
from typing_extensions import Self, override
78

89
import httpx
910

10-
from ._types import ResponseT
1111
from ._utils import is_mapping
1212
from ._exceptions import APIError
1313

1414
if TYPE_CHECKING:
1515
from ._client import OpenAI, AsyncOpenAI
1616

1717

18-
class Stream(Generic[ResponseT]):
18+
_T = TypeVar("_T")
19+
20+
21+
class Stream(Generic[_T]):
1922
"""Provides the core interface to iterate over a synchronous stream response."""
2023

2124
response: httpx.Response
2225

2326
def __init__(
2427
self,
2528
*,
26-
cast_to: type[ResponseT],
29+
cast_to: type[_T],
2730
response: httpx.Response,
2831
client: OpenAI,
2932
) -> None:
@@ -33,18 +36,18 @@ def __init__(
3336
self._decoder = SSEDecoder()
3437
self._iterator = self.__stream__()
3538

36-
def __next__(self) -> ResponseT:
39+
def __next__(self) -> _T:
3740
return self._iterator.__next__()
3841

39-
def __iter__(self) -> Iterator[ResponseT]:
42+
def __iter__(self) -> Iterator[_T]:
4043
for item in self._iterator:
4144
yield item
4245

4346
def _iter_events(self) -> Iterator[ServerSentEvent]:
4447
yield from self._decoder.iter(self.response.iter_lines())
4548

46-
def __stream__(self) -> Iterator[ResponseT]:
47-
cast_to = self._cast_to
49+
def __stream__(self) -> Iterator[_T]:
50+
cast_to = cast(Any, self._cast_to)
4851
response = self.response
4952
process_data = self._client._process_response_data
5053
iterator = self._iter_events()
@@ -68,16 +71,35 @@ def __stream__(self) -> Iterator[ResponseT]:
6871
for _sse in iterator:
6972
...
7073

74+
def __enter__(self) -> Self:
75+
return self
76+
77+
def __exit__(
78+
self,
79+
exc_type: type[BaseException] | None,
80+
exc: BaseException | None,
81+
exc_tb: TracebackType | None,
82+
) -> None:
83+
self.close()
84+
85+
def close(self) -> None:
86+
"""
87+
Close the response and release the connection.
88+
89+
Automatically called if the response body is read to completion.
90+
"""
91+
self.response.close()
7192

72-
class AsyncStream(Generic[ResponseT]):
93+
94+
class AsyncStream(Generic[_T]):
7395
"""Provides the core interface to iterate over an asynchronous stream response."""
7496

7597
response: httpx.Response
7698

7799
def __init__(
78100
self,
79101
*,
80-
cast_to: type[ResponseT],
102+
cast_to: type[_T],
81103
response: httpx.Response,
82104
client: AsyncOpenAI,
83105
) -> None:
@@ -87,19 +109,19 @@ def __init__(
87109
self._decoder = SSEDecoder()
88110
self._iterator = self.__stream__()
89111

90-
async def __anext__(self) -> ResponseT:
112+
async def __anext__(self) -> _T:
91113
return await self._iterator.__anext__()
92114

93-
async def __aiter__(self) -> AsyncIterator[ResponseT]:
115+
async def __aiter__(self) -> AsyncIterator[_T]:
94116
async for item in self._iterator:
95117
yield item
96118

97119
async def _iter_events(self) -> AsyncIterator[ServerSentEvent]:
98120
async for sse in self._decoder.aiter(self.response.aiter_lines()):
99121
yield sse
100122

101-
async def __stream__(self) -> AsyncIterator[ResponseT]:
102-
cast_to = self._cast_to
123+
async def __stream__(self) -> AsyncIterator[_T]:
124+
cast_to = cast(Any, self._cast_to)
103125
response = self.response
104126
process_data = self._client._process_response_data
105127
iterator = self._iter_events()
@@ -123,6 +145,25 @@ async def __stream__(self) -> AsyncIterator[ResponseT]:
123145
async for _sse in iterator:
124146
...
125147

148+
async def __aenter__(self) -> Self:
149+
return self
150+
151+
async def __aexit__(
152+
self,
153+
exc_type: type[BaseException] | None,
154+
exc: BaseException | None,
155+
exc_tb: TracebackType | None,
156+
) -> None:
157+
await self.close()
158+
159+
async def close(self) -> None:
160+
"""
161+
Close the response and release the connection.
162+
163+
Automatically called if the response body is read to completion.
164+
"""
165+
await self.response.aclose()
166+
126167

127168
class ServerSentEvent:
128169
def __init__(

src/openai/_types.py

+14
Original file line numberDiff line numberDiff line change
@@ -353,3 +353,17 @@ def get(self, __key: str) -> str | None:
353353
IncEx: TypeAlias = "set[int] | set[str] | dict[int, Any] | dict[str, Any] | None"
354354

355355
PostParser = Callable[[Any], Any]
356+
357+
358+
@runtime_checkable
359+
class InheritsGeneric(Protocol):
360+
"""Represents a type that has inherited from `Generic`
361+
The `__orig_bases__` property can be used to determine the resolved
362+
type variable for a given base class.
363+
"""
364+
365+
__orig_bases__: tuple[_GenericAlias]
366+
367+
368+
class _GenericAlias(Protocol):
369+
__origin__: type[object]

src/openai/_utils/__init__.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -9,29 +9,32 @@
99
from ._utils import parse_date as parse_date
1010
from ._utils import is_sequence as is_sequence
1111
from ._utils import coerce_float as coerce_float
12-
from ._utils import is_list_type as is_list_type
1312
from ._utils import is_mapping_t as is_mapping_t
1413
from ._utils import removeprefix as removeprefix
1514
from ._utils import removesuffix as removesuffix
1615
from ._utils import extract_files as extract_files
1716
from ._utils import is_sequence_t as is_sequence_t
18-
from ._utils import is_union_type as is_union_type
1917
from ._utils import required_args as required_args
2018
from ._utils import coerce_boolean as coerce_boolean
2119
from ._utils import coerce_integer as coerce_integer
2220
from ._utils import file_from_path as file_from_path
2321
from ._utils import parse_datetime as parse_datetime
2422
from ._utils import strip_not_given as strip_not_given
2523
from ._utils import deepcopy_minimal as deepcopy_minimal
26-
from ._utils import extract_type_arg as extract_type_arg
27-
from ._utils import is_required_type as is_required_type
2824
from ._utils import get_async_library as get_async_library
29-
from ._utils import is_annotated_type as is_annotated_type
3025
from ._utils import maybe_coerce_float as maybe_coerce_float
3126
from ._utils import get_required_header as get_required_header
3227
from ._utils import maybe_coerce_boolean as maybe_coerce_boolean
3328
from ._utils import maybe_coerce_integer as maybe_coerce_integer
34-
from ._utils import strip_annotated_type as strip_annotated_type
29+
from ._typing import is_list_type as is_list_type
30+
from ._typing import is_union_type as is_union_type
31+
from ._typing import extract_type_arg as extract_type_arg
32+
from ._typing import is_required_type as is_required_type
33+
from ._typing import is_annotated_type as is_annotated_type
34+
from ._typing import strip_annotated_type as strip_annotated_type
35+
from ._typing import extract_type_var_from_base as extract_type_var_from_base
36+
from ._streams import consume_sync_iterator as consume_sync_iterator
37+
from ._streams import consume_async_iterator as consume_async_iterator
3538
from ._transform import PropertyInfo as PropertyInfo
3639
from ._transform import transform as transform
3740
from ._transform import maybe_transform as maybe_transform

src/openai/_utils/_streams.py

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from typing import Any
2+
from typing_extensions import Iterator, AsyncIterator
3+
4+
5+
def consume_sync_iterator(iterator: Iterator[Any]) -> None:
6+
for _ in iterator:
7+
...
8+
9+
10+
async def consume_async_iterator(iterator: AsyncIterator[Any]) -> None:
11+
async for _ in iterator:
12+
...

src/openai/_utils/_transform.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,8 @@
66

77
import pydantic
88

9-
from ._utils import (
10-
is_list,
11-
is_mapping,
9+
from ._utils import is_list, is_mapping
10+
from ._typing import (
1211
is_list_type,
1312
is_union_type,
1413
extract_type_arg,

src/openai/_utils/_typing.py

+80
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
from __future__ import annotations
2+
3+
from typing import Any, cast
4+
from typing_extensions import Required, Annotated, get_args, get_origin
5+
6+
from .._types import InheritsGeneric
7+
from .._compat import is_union as _is_union
8+
9+
10+
def is_annotated_type(typ: type) -> bool:
11+
return get_origin(typ) == Annotated
12+
13+
14+
def is_list_type(typ: type) -> bool:
15+
return (get_origin(typ) or typ) == list
16+
17+
18+
def is_union_type(typ: type) -> bool:
19+
return _is_union(get_origin(typ))
20+
21+
22+
def is_required_type(typ: type) -> bool:
23+
return get_origin(typ) == Required
24+
25+
26+
# Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]]
27+
def strip_annotated_type(typ: type) -> type:
28+
if is_required_type(typ) or is_annotated_type(typ):
29+
return strip_annotated_type(cast(type, get_args(typ)[0]))
30+
31+
return typ
32+
33+
34+
def extract_type_arg(typ: type, index: int) -> type:
35+
args = get_args(typ)
36+
try:
37+
return cast(type, args[index])
38+
except IndexError as err:
39+
raise RuntimeError(f"Expected type {typ} to have a type argument at index {index} but it did not") from err
40+
41+
42+
def extract_type_var_from_base(typ: type, *, generic_bases: tuple[type, ...], index: int) -> type:
43+
"""Given a type like `Foo[T]`, returns the generic type variable `T`.
44+
45+
This also handles the case where a concrete subclass is given, e.g.
46+
```py
47+
class MyResponse(Foo[bytes]):
48+
...
49+
50+
extract_type_var(MyResponse, bases=(Foo,), index=0) -> bytes
51+
```
52+
"""
53+
cls = cast(object, get_origin(typ) or typ)
54+
if cls in generic_bases:
55+
# we're given the class directly
56+
return extract_type_arg(typ, index)
57+
58+
# if a subclass is given
59+
# ---
60+
# this is needed as __orig_bases__ is not present in the typeshed stubs
61+
# because it is intended to be for internal use only, however there does
62+
# not seem to be a way to resolve generic TypeVars for inherited subclasses
63+
# without using it.
64+
if isinstance(cls, InheritsGeneric):
65+
target_base_class: Any | None = None
66+
for base in cls.__orig_bases__:
67+
if base.__origin__ in generic_bases:
68+
target_base_class = base
69+
break
70+
71+
if target_base_class is None:
72+
raise RuntimeError(
73+
"Could not find the generic base class;\n"
74+
"This should never happen;\n"
75+
f"Does {cls} inherit from one of {generic_bases} ?"
76+
)
77+
78+
return extract_type_arg(target_base_class, index)
79+
80+
raise RuntimeError(f"Could not resolve inner type variable at index {index} for {typ}")

src/openai/_utils/_utils.py

+1-34
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,11 @@
1616
overload,
1717
)
1818
from pathlib import Path
19-
from typing_extensions import Required, Annotated, TypeGuard, get_args, get_origin
19+
from typing_extensions import TypeGuard
2020

2121
import sniffio
2222

2323
from .._types import Headers, NotGiven, FileTypes, NotGivenOr, HeadersLike
24-
from .._compat import is_union as _is_union
2524
from .._compat import parse_date as parse_date
2625
from .._compat import parse_datetime as parse_datetime
2726

@@ -166,38 +165,6 @@ def is_list(obj: object) -> TypeGuard[list[object]]:
166165
return isinstance(obj, list)
167166

168167

169-
def is_annotated_type(typ: type) -> bool:
170-
return get_origin(typ) == Annotated
171-
172-
173-
def is_list_type(typ: type) -> bool:
174-
return (get_origin(typ) or typ) == list
175-
176-
177-
def is_union_type(typ: type) -> bool:
178-
return _is_union(get_origin(typ))
179-
180-
181-
def is_required_type(typ: type) -> bool:
182-
return get_origin(typ) == Required
183-
184-
185-
# Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]]
186-
def strip_annotated_type(typ: type) -> type:
187-
if is_required_type(typ) or is_annotated_type(typ):
188-
return strip_annotated_type(cast(type, get_args(typ)[0]))
189-
190-
return typ
191-
192-
193-
def extract_type_arg(typ: type, index: int) -> type:
194-
args = get_args(typ)
195-
try:
196-
return cast(type, args[index])
197-
except IndexError as err:
198-
raise RuntimeError(f"Expected type {typ} to have a type argument at index {index} but it did not") from err
199-
200-
201168
def deepcopy_minimal(item: _T) -> _T:
202169
"""Minimal reimplementation of copy.deepcopy() that will only copy certain object types:
203170

0 commit comments

Comments
 (0)