5
5
import inspect
6
6
from types import TracebackType
7
7
from typing import TYPE_CHECKING , Any , Generic , TypeVar , Iterator , AsyncIterator , cast
8
- from typing_extensions import Self , TypeGuard , override , get_origin
8
+ from typing_extensions import Self , Protocol , TypeGuard , override , get_origin , runtime_checkable
9
9
10
10
import httpx
11
11
@@ -23,6 +23,8 @@ class Stream(Generic[_T]):
23
23
24
24
response : httpx .Response
25
25
26
+ _decoder : SSEDecoder | SSEBytesDecoder
27
+
26
28
def __init__ (
27
29
self ,
28
30
* ,
@@ -33,7 +35,7 @@ def __init__(
33
35
self .response = response
34
36
self ._cast_to = cast_to
35
37
self ._client = client
36
- self ._decoder = SSEDecoder ()
38
+ self ._decoder = client . _make_sse_decoder ()
37
39
self ._iterator = self .__stream__ ()
38
40
39
41
def __next__ (self ) -> _T :
@@ -44,7 +46,10 @@ def __iter__(self) -> Iterator[_T]:
44
46
yield item
45
47
46
48
def _iter_events (self ) -> Iterator [ServerSentEvent ]:
47
- yield from self ._decoder .iter (self .response .iter_lines ())
49
+ if isinstance (self ._decoder , SSEBytesDecoder ):
50
+ yield from self ._decoder .iter_bytes (self .response .iter_bytes ())
51
+ else :
52
+ yield from self ._decoder .iter (self .response .iter_lines ())
48
53
49
54
def __stream__ (self ) -> Iterator [_T ]:
50
55
cast_to = cast (Any , self ._cast_to )
@@ -84,6 +89,8 @@ class AsyncStream(Generic[_T]):
84
89
85
90
response : httpx .Response
86
91
92
+ _decoder : SSEDecoder | SSEBytesDecoder
93
+
87
94
def __init__ (
88
95
self ,
89
96
* ,
@@ -94,7 +101,7 @@ def __init__(
94
101
self .response = response
95
102
self ._cast_to = cast_to
96
103
self ._client = client
97
- self ._decoder = SSEDecoder ()
104
+ self ._decoder = client . _make_sse_decoder ()
98
105
self ._iterator = self .__stream__ ()
99
106
100
107
async def __anext__ (self ) -> _T :
@@ -105,8 +112,12 @@ async def __aiter__(self) -> AsyncIterator[_T]:
105
112
yield item
106
113
107
114
async def _iter_events (self ) -> AsyncIterator [ServerSentEvent ]:
108
- async for sse in self ._decoder .aiter (self .response .aiter_lines ()):
109
- yield sse
115
+ if isinstance (self ._decoder , SSEBytesDecoder ):
116
+ async for sse in self ._decoder .aiter_bytes (self .response .aiter_bytes ()):
117
+ yield sse
118
+ else :
119
+ async for sse in self ._decoder .aiter (self .response .aiter_lines ()):
120
+ yield sse
110
121
111
122
async def __stream__ (self ) -> AsyncIterator [_T ]:
112
123
cast_to = cast (Any , self ._cast_to )
@@ -259,6 +270,17 @@ def decode(self, line: str) -> ServerSentEvent | None:
259
270
return None
260
271
261
272
273
+ @runtime_checkable
274
+ class SSEBytesDecoder (Protocol ):
275
+ def iter_bytes (self , iterator : Iterator [bytes ]) -> Iterator [ServerSentEvent ]:
276
+ """Given an iterator that yields raw binary data, iterate over it & yield every event encountered"""
277
+ ...
278
+
279
+ def aiter_bytes (self , iterator : AsyncIterator [bytes ]) -> AsyncIterator [ServerSentEvent ]:
280
+ """Given an async iterator that yields raw binary data, iterate over it & yield every event encountered"""
281
+ ...
282
+
283
+
262
284
def is_stream_class_type (typ : type ) -> TypeGuard [type [Stream [object ]] | type [AsyncStream [object ]]]:
263
285
"""TypeGuard for determining whether or not the given type is a subclass of `Stream` / `AsyncStream`"""
264
286
origin = get_origin (typ ) or typ
0 commit comments