-
Notifications
You must be signed in to change notification settings - Fork 2.9k
Fix api event drops #6556
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix api event drops #6556
Changes from 30 commits
e91912d
0092230
85ae8f9
1885bb9
b7d35c8
57c6e5b
ca73780
ba37ca8
41eb16a
f1d8baf
31cf67f
ca1d7fa
6a66067
4ceb71a
edc95e4
a7db781
c6d2bc9
e9db998
8456e68
2fa939c
fdaa96b
f27f336
27ad025
8ea90a4
acd1f1d
692f2d9
64d30bd
1f53bf3
9fa6d3d
1ff6e7d
448f9c4
7dbe37c
7f6ebe4
5c6349b
7cde497
b722e46
6e5a093
5c598a0
b7ef2c5
778f15b
bba88da
caf6786
a23d771
aaa63fb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
--- | ||
"@gradio/client": patch | ||
"gradio": patch | ||
"gradio_client": patch | ||
--- | ||
|
||
fix: Fix api event drops |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -36,6 +36,7 @@ | |
from gradio_client.utils import ( | ||
Communicator, | ||
JobStatus, | ||
Message, | ||
Status, | ||
StatusUpdate, | ||
) | ||
|
@@ -139,10 +140,12 @@ def __init__( | |
self._info = self._get_api_info() | ||
self.session_hash = str(uuid.uuid4()) | ||
|
||
protocol = self.config.get("protocol") | ||
endpoint_class = Endpoint if protocol == "sse" else EndpointV3Compatibility | ||
self.protocol: str = self.config.get("protocol", "ws") | ||
endpoint_class = ( | ||
Endpoint if self.protocol.startswith("sse") else EndpointV3Compatibility | ||
) | ||
self.endpoints = [ | ||
endpoint_class(self, fn_index, dependency) | ||
endpoint_class(self, fn_index, dependency, self.protocol) | ||
for fn_index, dependency in enumerate(self.config["dependencies"]) | ||
] | ||
|
||
|
@@ -152,6 +155,78 @@ def __init__( | |
# Disable telemetry by setting the env variable HF_HUB_DISABLE_TELEMETRY=1 | ||
threading.Thread(target=self._telemetry_thread).start() | ||
|
||
self.stream_open = False | ||
self.streaming_future: Future | None = None | ||
self.pending_messages_per_event: dict[str, list[Message]] = {} | ||
self.pending_event_ids: set[str] = set() | ||
|
||
async def stream_messages(self) -> None: | ||
try: | ||
async with httpx.AsyncClient(timeout=httpx.Timeout(timeout=None)) as client: | ||
async with client.stream( | ||
"GET", | ||
self.sse_url, | ||
params={"session_hash": self.session_hash}, | ||
headers=self.headers, | ||
cookies=self.cookies, | ||
) as response: | ||
async for line in response.aiter_text(): | ||
if line.startswith("data:"): | ||
resp = json.loads(line[5:]) | ||
if resp["msg"] == "heartbeat": | ||
continue | ||
elif resp["msg"] == "server_stopped": | ||
print("Server stopped!!!", self.src) | ||
for ( | ||
pending_messages | ||
) in self.pending_messages_per_event.values(): | ||
pending_messages.append(resp) | ||
break | ||
event_id = resp["event_id"] | ||
if event_id not in self.pending_messages_per_event: | ||
self.pending_messages_per_event[event_id] = [] | ||
self.pending_messages_per_event[event_id].append(resp) | ||
if resp["msg"] == "process_completed": | ||
self.pending_event_ids.remove(event_id) | ||
if len(self.pending_event_ids) == 0: | ||
self.stream_open = False | ||
return | ||
elif line == "": | ||
continue | ||
else: | ||
raise ValueError(f"Unexpected SSE line: '{line}'") | ||
except BaseException as e: | ||
import traceback | ||
|
||
traceback.print_exc() | ||
raise e | ||
|
||
async def send_data(self, data, hash_data): | ||
async with httpx.AsyncClient() as client: | ||
req = await client.post( | ||
self.sse_data_url, | ||
json={**data, **hash_data}, | ||
headers=self.headers, | ||
cookies=self.cookies, | ||
) | ||
req.raise_for_status() | ||
resp = req.json() | ||
event_id = resp["event_id"] | ||
|
||
if not self.stream_open: | ||
self.stream_open = True | ||
|
||
def open_stream(): | ||
return utils.synchronize_async(self.stream_messages) | ||
|
||
if self.streaming_future is None or self.streaming_future.done(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think right now we're trying to only open one stream for all pending predictions like we do in the js client. I understand we're doing that in the front-end for the browser limitations but do we have to do that in the python client? I don't think we do because we can just connect to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The backend currently only supports a stream listening to a session_id. This makes it impossible to have a separate stream per event from the client as you suggest, unless every client call had a separate session_id. We probably don't want that behaviour because of session state. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Got it! We talked about exploring "api-only" usage next year. Would be good to revisit this then because the api is getting complex (for valid reasons) but it limits the ability to use gradio from another language |
||
self.streaming_future = self.executor.submit(open_stream) | ||
self.streaming_future.add_done_callback( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's remove this print statement |
||
lambda f: print("res:", f.result()) | ||
) | ||
|
||
return event_id | ||
|
||
@classmethod | ||
def duplicate( | ||
cls, | ||
|
@@ -340,7 +415,7 @@ def submit( | |
inferred_fn_index = self._infer_fn_index(api_name, fn_index) | ||
|
||
helper = None | ||
if self.endpoints[inferred_fn_index].protocol in ("ws", "sse"): | ||
if self.endpoints[inferred_fn_index].protocol in ("ws", "sse", "sse_v1"): | ||
helper = self.new_helper(inferred_fn_index) | ||
end_to_end_fn = self.endpoints[inferred_fn_index].make_end_to_end_fn(helper) | ||
future = self.executor.submit(end_to_end_fn, *args) | ||
|
@@ -806,15 +881,17 @@ class ReplaceMe: | |
class Endpoint: | ||
"""Helper class for storing all the information about a single API endpoint.""" | ||
|
||
def __init__(self, client: Client, fn_index: int, dependency: dict): | ||
def __init__( | ||
self, client: Client, fn_index: int, dependency: dict, protocol: str = "sse_v1" | ||
): | ||
self.client: Client = client | ||
self.fn_index = fn_index | ||
self.dependency = dependency | ||
api_name = dependency.get("api_name") | ||
self.api_name: str | Literal[False] | None = ( | ||
"/" + api_name if isinstance(api_name, str) else api_name | ||
) | ||
self.protocol = "sse" | ||
self.protocol = protocol | ||
self.input_component_types = [ | ||
self._get_component_type(id_) for id_ in dependency["inputs"] | ||
] | ||
|
@@ -891,7 +968,20 @@ def _predict(*data) -> tuple: | |
"session_hash": self.client.session_hash, | ||
} | ||
|
||
result = utils.synchronize_async(self._sse_fn, data, hash_data, helper) | ||
if self.protocol == "sse": | ||
result = utils.synchronize_async( | ||
self._sse_fn_v0, data, hash_data, helper | ||
) | ||
elif self.protocol == "sse_v1": | ||
event_id = utils.synchronize_async( | ||
self.client.send_data, data, hash_data | ||
) | ||
self.client.pending_event_ids.add(event_id) | ||
self.client.pending_messages_per_event[event_id] = [] | ||
result = utils.synchronize_async(self._sse_fn_v1, helper, event_id) | ||
else: | ||
raise ValueError(f"Unsupported protocol: {self.protocol}") | ||
|
||
if "error" in result: | ||
raise ValueError(result["error"]) | ||
|
||
|
@@ -1068,24 +1158,32 @@ def process_predictions(self, *predictions): | |
predictions = self.reduce_singleton_output(*predictions) | ||
return predictions | ||
|
||
async def _sse_fn(self, data: dict, hash_data: dict, helper: Communicator): | ||
async def _sse_fn_v0(self, data: dict, hash_data: dict, helper: Communicator): | ||
async with httpx.AsyncClient(timeout=httpx.Timeout(timeout=None)) as client: | ||
return await utils.get_pred_from_sse( | ||
return await utils.get_pred_from_sse_v0( | ||
client, | ||
data, | ||
hash_data, | ||
helper, | ||
sse_url=self.client.sse_url, | ||
sse_data_url=self.client.sse_data_url, | ||
headers=self.client.headers, | ||
cookies=self.client.cookies, | ||
self.client.sse_url, | ||
self.client.sse_data_url, | ||
self.client.headers, | ||
self.client.cookies, | ||
Comment on lines +1180 to +1183
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. jw why did we remove the kwargs? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no particular reason, seemed arbitrary some were kwargs and some args |
||
) | ||
|
||
async def _sse_fn_v1(self, helper: Communicator, event_id: str): | ||
return await utils.get_pred_from_sse_v1( | ||
helper, | ||
self.client.cookies, | ||
self.client.pending_messages_per_event, | ||
event_id, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you'll need to send headers here, otherwise it won't work with private spaces There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this function doesn't make any backend calls, other than check_for_cancel, which doesn't use headers previously either. Will add to check_for_cancel though |
||
) | ||
|
||
|
||
class EndpointV3Compatibility: | ||
"""Endpoint class for connecting to v3 endpoints. Backwards compatibility.""" | ||
|
||
def __init__(self, client: Client, fn_index: int, dependency: dict): | ||
def __init__(self, client: Client, fn_index: int, dependency: dict, *args): | ||
self.client: Client = client | ||
self.fn_index = fn_index | ||
self.dependency = dependency | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove this print statement.