Skip to content

Commit a0701b5

Browse files
committed
chore(internal): minor options / compat functions updates (#1549)
1 parent ee1c62e commit a0701b5

File tree

3 files changed

+18
-13
lines changed

3 files changed

+18
-13
lines changed

src/openai/_base_client.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -880,9 +880,9 @@ def __exit__(
880880
def _prepare_options(
881881
self,
882882
options: FinalRequestOptions, # noqa: ARG002
883-
) -> None:
883+
) -> FinalRequestOptions:
884884
"""Hook for mutating the given options"""
885-
return None
885+
return options
886886

887887
def _prepare_request(
888888
self,
@@ -962,7 +962,7 @@ def _request(
962962
input_options = model_copy(options)
963963

964964
cast_to = self._maybe_override_cast_to(cast_to, options)
965-
self._prepare_options(options)
965+
options = self._prepare_options(options)
966966

967967
retries = self._remaining_retries(remaining_retries, options)
968968
request = self._build_request(options)
@@ -1457,9 +1457,9 @@ async def __aexit__(
14571457
async def _prepare_options(
14581458
self,
14591459
options: FinalRequestOptions, # noqa: ARG002
1460-
) -> None:
1460+
) -> FinalRequestOptions:
14611461
"""Hook for mutating the given options"""
1462-
return None
1462+
return options
14631463

14641464
async def _prepare_request(
14651465
self,
@@ -1544,7 +1544,7 @@ async def _request(
15441544
input_options = model_copy(options)
15451545

15461546
cast_to = self._maybe_override_cast_to(cast_to, options)
1547-
await self._prepare_options(options)
1547+
options = await self._prepare_options(options)
15481548

15491549
retries = self._remaining_retries(remaining_retries, options)
15501550
request = self._build_request(options)

src/openai/_compat.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -118,10 +118,10 @@ def get_model_fields(model: type[pydantic.BaseModel]) -> dict[str, FieldInfo]:
118118
return model.__fields__ # type: ignore
119119

120120

121-
def model_copy(model: _ModelT) -> _ModelT:
121+
def model_copy(model: _ModelT, *, deep: bool = False) -> _ModelT:
122122
if PYDANTIC_V2:
123-
return model.model_copy()
124-
return model.copy() # type: ignore
123+
return model.model_copy(deep=deep)
124+
return model.copy(deep=deep) # type: ignore
125125

126126

127127
def model_json(model: pydantic.BaseModel, *, indent: int | None = None) -> str:

src/openai/lib/azure.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from .._types import NOT_GIVEN, Omit, Timeout, NotGiven
1111
from .._utils import is_given, is_mapping
1212
from .._client import OpenAI, AsyncOpenAI
13+
from .._compat import model_copy
1314
from .._models import FinalRequestOptions
1415
from .._streaming import Stream, AsyncStream
1516
from .._exceptions import OpenAIError
@@ -281,8 +282,10 @@ def _get_azure_ad_token(self) -> str | None:
281282
return None
282283

283284
@override
284-
def _prepare_options(self, options: FinalRequestOptions) -> None:
285+
def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
285286
headers: dict[str, str | Omit] = {**options.headers} if is_given(options.headers) else {}
287+
288+
options = model_copy(options)
286289
options.headers = headers
287290

288291
azure_ad_token = self._get_azure_ad_token()
@@ -296,7 +299,7 @@ def _prepare_options(self, options: FinalRequestOptions) -> None:
296299
# should never be hit
297300
raise ValueError("Unable to handle auth")
298301

299-
return super()._prepare_options(options)
302+
return options
300303

301304

302305
class AsyncAzureOpenAI(BaseAzureClient[httpx.AsyncClient, AsyncStream[Any]], AsyncOpenAI):
@@ -524,8 +527,10 @@ async def _get_azure_ad_token(self) -> str | None:
524527
return None
525528

526529
@override
527-
async def _prepare_options(self, options: FinalRequestOptions) -> None:
530+
async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
528531
headers: dict[str, str | Omit] = {**options.headers} if is_given(options.headers) else {}
532+
533+
options = model_copy(options)
529534
options.headers = headers
530535

531536
azure_ad_token = await self._get_azure_ad_token()
@@ -539,4 +544,4 @@ async def _prepare_options(self, options: FinalRequestOptions) -> None:
539544
# should never be hit
540545
raise ValueError("Unable to handle auth")
541546

542-
return await super()._prepare_options(options)
547+
return options

0 commit comments

Comments
 (0)