5
5
import inspect
6
6
from types import TracebackType
7
7
from typing import TYPE_CHECKING , Any , Generic , TypeVar , Iterator , AsyncIterator , cast
8
- from typing_extensions import Self , TypeGuard , override , get_origin
8
+ from typing_extensions import Self , Protocol , TypeGuard , override , get_origin , runtime_checkable
9
9
10
10
import httpx
11
11
@@ -24,6 +24,8 @@ class Stream(Generic[_T]):
24
24
25
25
response : httpx .Response
26
26
27
+ _decoder : SSEDecoder | SSEBytesDecoder
28
+
27
29
def __init__ (
28
30
self ,
29
31
* ,
@@ -34,7 +36,7 @@ def __init__(
34
36
self .response = response
35
37
self ._cast_to = cast_to
36
38
self ._client = client
37
- self ._decoder = SSEDecoder ()
39
+ self ._decoder = client . _make_sse_decoder ()
38
40
self ._iterator = self .__stream__ ()
39
41
40
42
def __next__ (self ) -> _T :
@@ -45,7 +47,10 @@ def __iter__(self) -> Iterator[_T]:
45
47
yield item
46
48
47
49
def _iter_events (self ) -> Iterator [ServerSentEvent ]:
48
- yield from self ._decoder .iter (self .response .iter_lines ())
50
+ if isinstance (self ._decoder , SSEBytesDecoder ):
51
+ yield from self ._decoder .iter_bytes (self .response .iter_bytes ())
52
+ else :
53
+ yield from self ._decoder .iter (self .response .iter_lines ())
49
54
50
55
def __stream__ (self ) -> Iterator [_T ]:
51
56
cast_to = cast (Any , self ._cast_to )
@@ -97,6 +102,8 @@ class AsyncStream(Generic[_T]):
97
102
98
103
response : httpx .Response
99
104
105
+ _decoder : SSEDecoder | SSEBytesDecoder
106
+
100
107
def __init__ (
101
108
self ,
102
109
* ,
@@ -107,7 +114,7 @@ def __init__(
107
114
self .response = response
108
115
self ._cast_to = cast_to
109
116
self ._client = client
110
- self ._decoder = SSEDecoder ()
117
+ self ._decoder = client . _make_sse_decoder ()
111
118
self ._iterator = self .__stream__ ()
112
119
113
120
async def __anext__ (self ) -> _T :
@@ -118,8 +125,12 @@ async def __aiter__(self) -> AsyncIterator[_T]:
118
125
yield item
119
126
120
127
async def _iter_events (self ) -> AsyncIterator [ServerSentEvent ]:
121
- async for sse in self ._decoder .aiter (self .response .aiter_lines ()):
122
- yield sse
128
+ if isinstance (self ._decoder , SSEBytesDecoder ):
129
+ async for sse in self ._decoder .aiter_bytes (self .response .aiter_bytes ()):
130
+ yield sse
131
+ else :
132
+ async for sse in self ._decoder .aiter (self .response .aiter_lines ()):
133
+ yield sse
123
134
124
135
async def __stream__ (self ) -> AsyncIterator [_T ]:
125
136
cast_to = cast (Any , self ._cast_to )
@@ -284,6 +295,17 @@ def decode(self, line: str) -> ServerSentEvent | None:
284
295
return None
285
296
286
297
298
+ @runtime_checkable
299
+ class SSEBytesDecoder (Protocol ):
300
+ def iter_bytes (self , iterator : Iterator [bytes ]) -> Iterator [ServerSentEvent ]:
301
+ """Given an iterator that yields raw binary data, iterate over it & yield every event encountered"""
302
+ ...
303
+
304
+ def aiter_bytes (self , iterator : AsyncIterator [bytes ]) -> AsyncIterator [ServerSentEvent ]:
305
+ """Given an async iterator that yields raw binary data, iterate over it & yield every event encountered"""
306
+ ...
307
+
308
+
287
309
def is_stream_class_type (typ : type ) -> TypeGuard [type [Stream [object ]] | type [AsyncStream [object ]]]:
288
310
"""TypeGuard for determining whether or not the given type is a subclass of `Stream` / `AsyncStream`"""
289
311
origin = get_origin (typ ) or typ
0 commit comments