Skip to content

Commit d0e4baa

Browse files
chore(internal): support more input types (#1211)
1 parent 7853a83 commit d0e4baa

File tree

5 files changed

+75
-1
lines changed

5 files changed

+75
-1
lines changed

src/openai/_files.py

+5
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,17 @@
1313
FileContent,
1414
RequestFiles,
1515
HttpxFileTypes,
16+
Base64FileInput,
1617
HttpxFileContent,
1718
HttpxRequestFiles,
1819
)
1920
from ._utils import is_tuple_t, is_mapping_t, is_sequence_t
2021

2122

23+
def is_base64_file_input(obj: object) -> TypeGuard[Base64FileInput]:
24+
return isinstance(obj, io.IOBase) or isinstance(obj, os.PathLike)
25+
26+
2227
def is_file_content(obj: object) -> TypeGuard[FileContent]:
2328
return (
2429
isinstance(obj, bytes) or isinstance(obj, tuple) or isinstance(obj, io.IOBase) or isinstance(obj, os.PathLike)

src/openai/_types.py

+2
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,10 @@
4141
ProxiesDict = Dict["str | URL", Union[None, str, URL, Proxy]]
4242
ProxiesTypes = Union[str, Proxy, ProxiesDict]
4343
if TYPE_CHECKING:
44+
Base64FileInput = Union[IO[bytes], PathLike[str]]
4445
FileContent = Union[IO[bytes], bytes, PathLike[str]]
4546
else:
47+
Base64FileInput = Union[IO[bytes], PathLike]
4648
FileContent = Union[IO[bytes], bytes, PathLike] # PathLike is not subscriptable in Python 3.8.
4749
FileTypes = Union[
4850
# file (or bytes)

src/openai/_utils/_transform.py

+38-1
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,21 @@
11
from __future__ import annotations
22

3+
import io
4+
import base64
5+
import pathlib
36
from typing import Any, Mapping, TypeVar, cast
47
from datetime import date, datetime
58
from typing_extensions import Literal, get_args, override, get_type_hints
69

10+
import anyio
711
import pydantic
812

913
from ._utils import (
1014
is_list,
1115
is_mapping,
1216
is_iterable,
1317
)
18+
from .._files import is_base64_file_input
1419
from ._typing import (
1520
is_list_type,
1621
is_union_type,
@@ -29,7 +34,7 @@
2934
# TODO: ensure works correctly with forward references in all cases
3035

3136

32-
PropertyFormat = Literal["iso8601", "custom"]
37+
PropertyFormat = Literal["iso8601", "base64", "custom"]
3338

3439

3540
class PropertyInfo:
@@ -201,6 +206,22 @@ def _format_data(data: object, format_: PropertyFormat, format_template: str | N
201206
if format_ == "custom" and format_template is not None:
202207
return data.strftime(format_template)
203208

209+
if format_ == "base64" and is_base64_file_input(data):
210+
binary: str | bytes | None = None
211+
212+
if isinstance(data, pathlib.Path):
213+
binary = data.read_bytes()
214+
elif isinstance(data, io.IOBase):
215+
binary = data.read()
216+
217+
if isinstance(binary, str): # type: ignore[unreachable]
218+
binary = binary.encode()
219+
220+
if not isinstance(binary, bytes):
221+
raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}")
222+
223+
return base64.b64encode(binary).decode("ascii")
224+
204225
return data
205226

206227

@@ -323,6 +344,22 @@ async def _async_format_data(data: object, format_: PropertyFormat, format_templ
323344
if format_ == "custom" and format_template is not None:
324345
return data.strftime(format_template)
325346

347+
if format_ == "base64" and is_base64_file_input(data):
348+
binary: str | bytes | None = None
349+
350+
if isinstance(data, pathlib.Path):
351+
binary = await anyio.Path(data).read_bytes()
352+
elif isinstance(data, io.IOBase):
353+
binary = data.read()
354+
355+
if isinstance(binary, str): # type: ignore[unreachable]
356+
binary = binary.encode()
357+
358+
if not isinstance(binary, bytes):
359+
raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}")
360+
361+
return base64.b64encode(binary).decode("ascii")
362+
326363
return data
327364

328365

tests/sample_file.txt

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Hello, world!

tests/test_transform.py

+29
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
from __future__ import annotations
22

3+
import io
4+
import pathlib
35
from typing import Any, List, Union, TypeVar, Iterable, Optional, cast
46
from datetime import date, datetime
57
from typing_extensions import Required, Annotated, TypedDict
68

79
import pytest
810

11+
from openai._types import Base64FileInput
912
from openai._utils import (
1013
PropertyInfo,
1114
transform as _transform,
@@ -17,6 +20,8 @@
1720

1821
_T = TypeVar("_T")
1922

23+
SAMPLE_FILE_PATH = pathlib.Path(__file__).parent.joinpath("sample_file.txt")
24+
2025

2126
async def transform(
2227
data: _T,
@@ -377,3 +382,27 @@ async def test_iterable_union_str(use_async: bool) -> None:
377382
assert cast(Any, await transform(iter([{"foo_baz": "bar"}]), Union[str, Iterable[Baz8]], use_async)) == [
378383
{"fooBaz": "bar"}
379384
]
385+
386+
387+
class TypedDictBase64Input(TypedDict):
388+
foo: Annotated[Union[str, Base64FileInput], PropertyInfo(format="base64")]
389+
390+
391+
@parametrize
392+
@pytest.mark.asyncio
393+
async def test_base64_file_input(use_async: bool) -> None:
394+
# strings are left as-is
395+
assert await transform({"foo": "bar"}, TypedDictBase64Input, use_async) == {"foo": "bar"}
396+
397+
# pathlib.Path is automatically converted to base64
398+
assert await transform({"foo": SAMPLE_FILE_PATH}, TypedDictBase64Input, use_async) == {
399+
"foo": "SGVsbG8sIHdvcmxkIQo="
400+
} # type: ignore[comparison-overlap]
401+
402+
# io instances are automatically converted to base64
403+
assert await transform({"foo": io.StringIO("Hello, world!")}, TypedDictBase64Input, use_async) == {
404+
"foo": "SGVsbG8sIHdvcmxkIQ=="
405+
} # type: ignore[comparison-overlap]
406+
assert await transform({"foo": io.BytesIO(b"Hello, world!")}, TypedDictBase64Input, use_async) == {
407+
"foo": "SGVsbG8sIHdvcmxkIQ=="
408+
} # type: ignore[comparison-overlap]

0 commit comments

Comments
 (0)