Skip to content

Commit 287926e

Browse files
fix(azure): refresh auth token during retries (#1533)
1 parent 41f682b commit 287926e

File tree

1 file changed

+86
-2
lines changed

1 file changed

+86
-2
lines changed

tests/lib/test_azure.py

+86-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1-
from typing import Union
2-
from typing_extensions import Literal
1+
from typing import Union, cast
2+
from typing_extensions import Literal, Protocol
33

4+
import httpx
45
import pytest
6+
from respx import MockRouter
57

68
from openai._models import FinalRequestOptions
79
from openai.lib.azure import AzureOpenAI, AsyncAzureOpenAI
@@ -22,6 +24,10 @@
2224
)
2325

2426

27+
class MockRequestCall(Protocol):
28+
request: httpx.Request
29+
30+
2531
@pytest.mark.parametrize("client", [sync_client, async_client])
2632
def test_implicit_deployment_path(client: Client) -> None:
2733
req = client._build_request(
@@ -64,3 +70,81 @@ def test_client_copying_override_options(client: Client) -> None:
6470
api_version="2022-05-01",
6571
)
6672
assert copied._custom_query == {"api-version": "2022-05-01"}
73+
74+
75+
@pytest.mark.respx()
76+
def test_client_token_provider_refresh_sync(respx_mock: MockRouter) -> None:
77+
respx_mock.post(
78+
"https://example-resource.azure.openai.com/openai/deployments/gpt-4/chat/completions?api-version=2024-02-01"
79+
).mock(
80+
side_effect=[
81+
httpx.Response(500, json={"error": "server error"}),
82+
httpx.Response(200, json={"foo": "bar"}),
83+
]
84+
)
85+
86+
counter = 0
87+
88+
def token_provider() -> str:
89+
nonlocal counter
90+
91+
counter += 1
92+
93+
if counter == 1:
94+
return "first"
95+
96+
return "second"
97+
98+
client = AzureOpenAI(
99+
api_version="2024-02-01",
100+
azure_ad_token_provider=token_provider,
101+
azure_endpoint="https://example-resource.azure.openai.com",
102+
)
103+
client.chat.completions.create(messages=[], model="gpt-4")
104+
105+
calls = cast("list[MockRequestCall]", respx_mock.calls)
106+
107+
assert len(calls) == 2
108+
109+
assert calls[0].request.headers.get("Authorization") == "Bearer first"
110+
assert calls[1].request.headers.get("Authorization") == "Bearer second"
111+
112+
113+
@pytest.mark.asyncio
114+
@pytest.mark.respx()
115+
async def test_client_token_provider_refresh_async(respx_mock: MockRouter) -> None:
116+
respx_mock.post(
117+
"https://example-resource.azure.openai.com/openai/deployments/gpt-4/chat/completions?api-version=2024-02-01"
118+
).mock(
119+
side_effect=[
120+
httpx.Response(500, json={"error": "server error"}),
121+
httpx.Response(200, json={"foo": "bar"}),
122+
]
123+
)
124+
125+
counter = 0
126+
127+
def token_provider() -> str:
128+
nonlocal counter
129+
130+
counter += 1
131+
132+
if counter == 1:
133+
return "first"
134+
135+
return "second"
136+
137+
client = AsyncAzureOpenAI(
138+
api_version="2024-02-01",
139+
azure_ad_token_provider=token_provider,
140+
azure_endpoint="https://example-resource.azure.openai.com",
141+
)
142+
143+
await client.chat.completions.create(messages=[], model="gpt-4")
144+
145+
calls = cast("list[MockRequestCall]", respx_mock.calls)
146+
147+
assert len(calls) == 2
148+
149+
assert calls[0].request.headers.get("Authorization") == "Bearer first"
150+
assert calls[1].request.headers.get("Authorization") == "Bearer second"

0 commit comments

Comments
 (0)