Skip to content

Commit ac132e3

Browse files
valgaiabidlabsgradio-pr-bot
authored
Support the use of custom authentication mechanism, timeouts, and other httpx parameters in Python Client (#8862)
* gradio Client now supports the use of custom authentication mechanism with httpx * Fix formatting issues * Replace specific parameter `httpx_auth` by a more general parameter `httpx_kwargs`. * add changeset * typing * future --------- Co-authored-by: Abubakar Abid <[email protected]> Co-authored-by: gradio-pr-bot <[email protected]>
1 parent 7f1a78c commit ac132e3

File tree

5 files changed

+49
-6
lines changed

5 files changed

+49
-6
lines changed

.changeset/clean-eagles-taste.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
---
2+
"gradio": minor
3+
"gradio_client": minor
4+
---
5+
6+
feat:Support the use of custom authentication mechanism, timeouts, and other `httpx` parameters in Python Client

client/python/gradio_client/client.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def __init__(
7979
max_workers: int = 40,
8080
verbose: bool = True,
8181
auth: tuple[str, str] | None = None,
82+
httpx_kwargs: dict[str, Any] | None = None,
8283
*,
8384
headers: dict[str, str] | None = None,
8485
download_files: str | Path | Literal[False] = DEFAULT_TEMP_DIR,
@@ -94,6 +95,7 @@ def __init__(
9495
headers: Additional headers to send to the remote Gradio app on every request. By default only the HF authorization and user-agent headers are sent. This parameter will override the default headers if they have the same keys.
9596
download_files: Directory where the client should download output files on the local machine from the remote API. By default, uses the value of the GRADIO_TEMP_DIR environment variable which, if not set by the user, is a temporary directory on your machine. If False, the client does not download files and returns a FileData dataclass object with the filepath on the remote machine instead.
9697
ssl_verify: If False, skips certificate validation which allows the client to connect to Gradio apps that are using self-signed certificates.
98+
httpx_kwargs: Additional keyword arguments to pass to `httpx.Client`, `httpx.stream`, `httpx.get` and `httpx.post`. This can be used to set timeouts, proxies, http auth, etc.
9799
"""
98100
self.verbose = verbose
99101
self.hf_token = hf_token
@@ -143,6 +145,7 @@ def __init__(
143145
if self.verbose:
144146
print(f"Loaded as API: {self.src} ✔")
145147

148+
self.httpx_kwargs = {} if httpx_kwargs is None else httpx_kwargs
146149
if auth is not None:
147150
self._login(auth)
148151

@@ -202,13 +205,15 @@ def _stream_heartbeat(self):
202205
while True:
203206
url = self.heartbeat_url.format(session_hash=self.session_hash)
204207
try:
208+
httpx_kwargs = self.httpx_kwargs.copy()
209+
httpx_kwargs.setdefault("timeout", 20)
205210
with httpx.stream(
206211
"GET",
207212
url,
208213
headers=self.headers,
209214
cookies=self.cookies,
210215
verify=self.ssl_verify,
211-
timeout=20,
216+
**httpx_kwargs,
212217
) as response:
213218
for _ in response.iter_lines():
214219
if self._refresh_heartbeat.is_set():
@@ -223,8 +228,11 @@ def stream_messages(
223228
self, protocol: Literal["sse_v1", "sse_v2", "sse_v2.1", "sse_v3"]
224229
) -> None:
225230
try:
231+
httpx_kwargs = self.httpx_kwargs.copy()
232+
httpx_kwargs.setdefault("timeout", httpx.Timeout(timeout=None))
226233
with httpx.Client(
227-
timeout=httpx.Timeout(timeout=None), verify=self.ssl_verify
234+
verify=self.ssl_verify,
235+
**httpx_kwargs,
228236
) as client:
229237
with client.stream(
230238
"GET",
@@ -284,6 +292,7 @@ def send_data(self, data, hash_data, protocol):
284292
headers=self.headers,
285293
cookies=self.cookies,
286294
verify=self.ssl_verify,
295+
**self.httpx_kwargs,
287296
)
288297
if req.status_code == 503:
289298
raise QueueError("Queue is full! Please try again.")
@@ -549,6 +558,7 @@ def _get_api_info(self):
549558
headers=self.headers,
550559
cookies=self.cookies,
551560
verify=self.ssl_verify,
561+
**self.httpx_kwargs,
552562
)
553563
if r.is_success:
554564
info = r.json()
@@ -561,6 +571,7 @@ def _get_api_info(self):
561571
"config": json.dumps(self.config),
562572
"serialize": False,
563573
},
574+
**self.httpx_kwargs,
564575
)
565576
if fetch.is_success:
566577
info = fetch.json()["api"]
@@ -823,6 +834,7 @@ def _login(self, auth: tuple[str, str]):
823834
urllib.parse.urljoin(self.src, utils.LOGIN_URL),
824835
data={"username": auth[0], "password": auth[1]},
825836
verify=self.ssl_verify,
837+
**self.httpx_kwargs,
826838
)
827839
if not resp.is_success:
828840
if resp.status_code == 401:
@@ -841,6 +853,7 @@ def _get_config(self) -> dict:
841853
headers=self.headers,
842854
cookies=self.cookies,
843855
verify=self.ssl_verify,
856+
**self.httpx_kwargs,
844857
)
845858
if r.is_success:
846859
return r.json()
@@ -854,6 +867,7 @@ def _get_config(self) -> dict:
854867
headers=self.headers,
855868
cookies=self.cookies,
856869
verify=self.ssl_verify,
870+
**self.httpx_kwargs,
857871
)
858872
if not r.is_success:
859873
raise ValueError(f"Could not fetch config for {self.src}")
@@ -1185,6 +1199,7 @@ def _cancel():
11851199
headers=self.client.headers,
11861200
cookies=self.client.cookies,
11871201
verify=self.client.ssl_verify,
1202+
**self.client.httpx_kwargs,
11881203
)
11891204

11901205
return _cancel
@@ -1331,6 +1346,7 @@ def _upload_file(self, f: dict, data_index: int) -> dict[str, str]:
13311346
cookies=self.client.cookies,
13321347
verify=self.client.ssl_verify,
13331348
files=files,
1349+
**self.client.httpx_kwargs,
13341350
)
13351351
r.raise_for_status()
13361352
result = r.json()
@@ -1360,6 +1376,7 @@ def _download_file(self, x: dict) -> str:
13601376
cookies=self.client.cookies,
13611377
verify=self.client.ssl_verify,
13621378
follow_redirects=True,
1379+
**self.client.httpx_kwargs,
13631380
) as response:
13641381
response.raise_for_status()
13651382
with open(temp_dir / Path(url_path).name, "wb") as f:
@@ -1375,7 +1392,9 @@ def _download_file(self, x: dict) -> str:
13751392

13761393
def _sse_fn_v0(self, data: dict, hash_data: dict, helper: Communicator):
13771394
with httpx.Client(
1378-
timeout=httpx.Timeout(timeout=None), verify=self.client.ssl_verify
1395+
timeout=httpx.Timeout(timeout=None),
1396+
verify=self.client.ssl_verify,
1397+
**self.client.httpx_kwargs,
13791398
) as client:
13801399
return utils.get_pred_from_sse_v0(
13811400
client,

client/python/gradio_client/compatibility.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ def _predict(*data) -> tuple:
101101
headers=self.client.headers,
102102
json=data,
103103
verify=self.client.ssl_verify,
104+
auth=self.client.httpx_auth,
104105
)
105106
result = json.loads(response.content.decode("utf-8"))
106107
try:
@@ -154,6 +155,7 @@ def _upload(
154155
headers=self.client.headers,
155156
files=files,
156157
verify=self.client.ssl_verify,
158+
auth=self.client.httpx_auth,
157159
)
158160
if r.status_code != 200:
159161
uploaded = file_paths

client/python/test/test_client.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import json
24
import os
35
import pathlib
@@ -38,11 +40,14 @@
3840
def connect(
3941
demo: gr.Blocks,
4042
download_files: str = DEFAULT_TEMP_DIR,
43+
client_kwargs: dict | None = None,
4144
**kwargs,
4245
):
4346
_, local_url, _ = demo.launch(prevent_thread_lock=True, **kwargs)
47+
if client_kwargs is None:
48+
client_kwargs = {}
4449
try:
45-
yield Client(local_url, download_files=download_files)
50+
yield Client(local_url, download_files=download_files, **client_kwargs)
4651
finally:
4752
# A more verbose version of .close()
4853
# because we should set a timeout
@@ -1406,3 +1411,13 @@ def test_upstream_exceptions(count_generator_demo_exception):
14061411
match="The upstream Gradio app has raised an exception but has not enabled verbose error reporting.",
14071412
):
14081413
client.predict(7, api_name="/count")
1414+
1415+
1416+
def test_httpx_kwargs(increment_demo):
1417+
with connect(
1418+
increment_demo, client_kwargs={"httpx_kwargs": {"timeout": 5}}
1419+
) as client:
1420+
with patch("httpx.post", MagicMock()) as mock_post:
1421+
with pytest.raises(Exception):
1422+
client.predict(1, api_name="/increment_with_queue")
1423+
assert mock_post.call_args.kwargs["timeout"] == 5

gradio/monitoring_dashboard.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def gen_plot(start, end, selected_fn):
6262
df = df[(df["time"] >= start) & (df["time"] <= end)]
6363
df["time"] = pd.to_datetime(df["time"], unit="s")
6464

65-
unique_users = len(df["session_hash"].unique())
65+
unique_users = len(df["session_hash"].unique()) # type: ignore
6666
total_requests = len(df)
6767
process_time = round(df["process_time"].mean(), 2)
6868

@@ -74,7 +74,8 @@ def gen_plot(start, end, selected_fn):
7474
if duration >= 60 * 60 * 3
7575
else "1m"
7676
)
77-
df = df.drop(columns=["session_hash"])
77+
df = df.drop(columns=["session_hash"]) # type: ignore
78+
assert isinstance(df, pd.DataFrame) # noqa: S101
7879
return (
7980
gr.BarPlot(value=df, x_bin=x_bin, x_lim=[start, end]),
8081
unique_users,

0 commit comments

Comments
 (0)