2
2
from __future__ import annotations
3
3
4
4
import json
5
- from typing import TYPE_CHECKING , Any , Generic , Iterator , AsyncIterator
6
- from typing_extensions import override
5
+ from types import TracebackType
6
+ from typing import TYPE_CHECKING , Any , Generic , TypeVar , Iterator , AsyncIterator , cast
7
+ from typing_extensions import Self , override
7
8
8
9
import httpx
9
10
10
- from ._types import ResponseT
11
11
from ._utils import is_mapping
12
12
from ._exceptions import APIError
13
13
14
14
if TYPE_CHECKING :
15
15
from ._client import OpenAI , AsyncOpenAI
16
16
17
17
18
- class Stream (Generic [ResponseT ]):
18
+ _T = TypeVar ("_T" )
19
+
20
+
21
+ class Stream (Generic [_T ]):
19
22
"""Provides the core interface to iterate over a synchronous stream response."""
20
23
21
24
response : httpx .Response
22
25
23
26
def __init__ (
24
27
self ,
25
28
* ,
26
- cast_to : type [ResponseT ],
29
+ cast_to : type [_T ],
27
30
response : httpx .Response ,
28
31
client : OpenAI ,
29
32
) -> None :
@@ -33,18 +36,18 @@ def __init__(
33
36
self ._decoder = SSEDecoder ()
34
37
self ._iterator = self .__stream__ ()
35
38
36
- def __next__ (self ) -> ResponseT :
39
+ def __next__ (self ) -> _T :
37
40
return self ._iterator .__next__ ()
38
41
39
- def __iter__ (self ) -> Iterator [ResponseT ]:
42
+ def __iter__ (self ) -> Iterator [_T ]:
40
43
for item in self ._iterator :
41
44
yield item
42
45
43
46
def _iter_events (self ) -> Iterator [ServerSentEvent ]:
44
47
yield from self ._decoder .iter (self .response .iter_lines ())
45
48
46
- def __stream__ (self ) -> Iterator [ResponseT ]:
47
- cast_to = self ._cast_to
49
+ def __stream__ (self ) -> Iterator [_T ]:
50
+ cast_to = cast ( Any , self ._cast_to )
48
51
response = self .response
49
52
process_data = self ._client ._process_response_data
50
53
iterator = self ._iter_events ()
@@ -68,16 +71,35 @@ def __stream__(self) -> Iterator[ResponseT]:
68
71
for _sse in iterator :
69
72
...
70
73
74
+ def __enter__ (self ) -> Self :
75
+ return self
76
+
77
+ def __exit__ (
78
+ self ,
79
+ exc_type : type [BaseException ] | None ,
80
+ exc : BaseException | None ,
81
+ exc_tb : TracebackType | None ,
82
+ ) -> None :
83
+ self .close ()
84
+
85
+ def close (self ) -> None :
86
+ """
87
+ Close the response and release the connection.
88
+
89
+ Automatically called if the response body is read to completion.
90
+ """
91
+ self .response .close ()
71
92
72
- class AsyncStream (Generic [ResponseT ]):
93
+
94
+ class AsyncStream (Generic [_T ]):
73
95
"""Provides the core interface to iterate over an asynchronous stream response."""
74
96
75
97
response : httpx .Response
76
98
77
99
def __init__ (
78
100
self ,
79
101
* ,
80
- cast_to : type [ResponseT ],
102
+ cast_to : type [_T ],
81
103
response : httpx .Response ,
82
104
client : AsyncOpenAI ,
83
105
) -> None :
@@ -87,19 +109,19 @@ def __init__(
87
109
self ._decoder = SSEDecoder ()
88
110
self ._iterator = self .__stream__ ()
89
111
90
- async def __anext__ (self ) -> ResponseT :
112
+ async def __anext__ (self ) -> _T :
91
113
return await self ._iterator .__anext__ ()
92
114
93
- async def __aiter__ (self ) -> AsyncIterator [ResponseT ]:
115
+ async def __aiter__ (self ) -> AsyncIterator [_T ]:
94
116
async for item in self ._iterator :
95
117
yield item
96
118
97
119
async def _iter_events (self ) -> AsyncIterator [ServerSentEvent ]:
98
120
async for sse in self ._decoder .aiter (self .response .aiter_lines ()):
99
121
yield sse
100
122
101
- async def __stream__ (self ) -> AsyncIterator [ResponseT ]:
102
- cast_to = self ._cast_to
123
+ async def __stream__ (self ) -> AsyncIterator [_T ]:
124
+ cast_to = cast ( Any , self ._cast_to )
103
125
response = self .response
104
126
process_data = self ._client ._process_response_data
105
127
iterator = self ._iter_events ()
@@ -123,6 +145,25 @@ async def __stream__(self) -> AsyncIterator[ResponseT]:
123
145
async for _sse in iterator :
124
146
...
125
147
148
+ async def __aenter__ (self ) -> Self :
149
+ return self
150
+
151
+ async def __aexit__ (
152
+ self ,
153
+ exc_type : type [BaseException ] | None ,
154
+ exc : BaseException | None ,
155
+ exc_tb : TracebackType | None ,
156
+ ) -> None :
157
+ await self .close ()
158
+
159
+ async def close (self ) -> None :
160
+ """
161
+ Close the response and release the connection.
162
+
163
+ Automatically called if the response body is read to completion.
164
+ """
165
+ await self .response .aclose ()
166
+
126
167
127
168
class ServerSentEvent :
128
169
def __init__ (
0 commit comments