Skip to content

Commit 4314cdc

Browse files
chore(internal): minor core client restructuring (#1199)
1 parent e41abf7 commit 4314cdc

File tree

2 files changed

+32
-7
lines changed

2 files changed

+32
-7
lines changed

src/openai/_base_client.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@
7979
RAW_RESPONSE_HEADER,
8080
OVERRIDE_CAST_TO_HEADER,
8181
)
82-
from ._streaming import Stream, AsyncStream
82+
from ._streaming import Stream, SSEDecoder, AsyncStream, SSEBytesDecoder
8383
from ._exceptions import (
8484
APIStatusError,
8585
APITimeoutError,
@@ -431,6 +431,9 @@ def _prepare_url(self, url: str) -> URL:
431431

432432
return merge_url
433433

434+
def _make_sse_decoder(self) -> SSEDecoder | SSEBytesDecoder:
435+
return SSEDecoder()
436+
434437
def _build_request(
435438
self,
436439
options: FinalRequestOptions,

src/openai/_streaming.py

+28-6
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import inspect
66
from types import TracebackType
77
from typing import TYPE_CHECKING, Any, Generic, TypeVar, Iterator, AsyncIterator, cast
8-
from typing_extensions import Self, TypeGuard, override, get_origin
8+
from typing_extensions import Self, Protocol, TypeGuard, override, get_origin, runtime_checkable
99

1010
import httpx
1111

@@ -24,6 +24,8 @@ class Stream(Generic[_T]):
2424

2525
response: httpx.Response
2626

27+
_decoder: SSEDecoder | SSEBytesDecoder
28+
2729
def __init__(
2830
self,
2931
*,
@@ -34,7 +36,7 @@ def __init__(
3436
self.response = response
3537
self._cast_to = cast_to
3638
self._client = client
37-
self._decoder = SSEDecoder()
39+
self._decoder = client._make_sse_decoder()
3840
self._iterator = self.__stream__()
3941

4042
def __next__(self) -> _T:
@@ -45,7 +47,10 @@ def __iter__(self) -> Iterator[_T]:
4547
yield item
4648

4749
def _iter_events(self) -> Iterator[ServerSentEvent]:
48-
yield from self._decoder.iter(self.response.iter_lines())
50+
if isinstance(self._decoder, SSEBytesDecoder):
51+
yield from self._decoder.iter_bytes(self.response.iter_bytes())
52+
else:
53+
yield from self._decoder.iter(self.response.iter_lines())
4954

5055
def __stream__(self) -> Iterator[_T]:
5156
cast_to = cast(Any, self._cast_to)
@@ -97,6 +102,8 @@ class AsyncStream(Generic[_T]):
97102

98103
response: httpx.Response
99104

105+
_decoder: SSEDecoder | SSEBytesDecoder
106+
100107
def __init__(
101108
self,
102109
*,
@@ -107,7 +114,7 @@ def __init__(
107114
self.response = response
108115
self._cast_to = cast_to
109116
self._client = client
110-
self._decoder = SSEDecoder()
117+
self._decoder = client._make_sse_decoder()
111118
self._iterator = self.__stream__()
112119

113120
async def __anext__(self) -> _T:
@@ -118,8 +125,12 @@ async def __aiter__(self) -> AsyncIterator[_T]:
118125
yield item
119126

120127
async def _iter_events(self) -> AsyncIterator[ServerSentEvent]:
121-
async for sse in self._decoder.aiter(self.response.aiter_lines()):
122-
yield sse
128+
if isinstance(self._decoder, SSEBytesDecoder):
129+
async for sse in self._decoder.aiter_bytes(self.response.aiter_bytes()):
130+
yield sse
131+
else:
132+
async for sse in self._decoder.aiter(self.response.aiter_lines()):
133+
yield sse
123134

124135
async def __stream__(self) -> AsyncIterator[_T]:
125136
cast_to = cast(Any, self._cast_to)
@@ -284,6 +295,17 @@ def decode(self, line: str) -> ServerSentEvent | None:
284295
return None
285296

286297

298+
@runtime_checkable
299+
class SSEBytesDecoder(Protocol):
300+
def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]:
301+
"""Given an iterator that yields raw binary data, iterate over it & yield every event encountered"""
302+
...
303+
304+
def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]:
305+
"""Given an async iterator that yields raw binary data, iterate over it & yield every event encountered"""
306+
...
307+
308+
287309
def is_stream_class_type(typ: type) -> TypeGuard[type[Stream[object]] | type[AsyncStream[object]]]:
288310
"""TypeGuard for determining whether or not the given type is a subclass of `Stream` / `AsyncStream`"""
289311
origin = get_origin(typ) or typ

0 commit comments

Comments
 (0)