Skip to content

Fix api event drops #6556

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 44 commits into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from 30 commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
e91912d
changes
aliabid94 Nov 22, 2023
0092230
changes
aliabid94 Nov 23, 2023
85ae8f9
add changeset
gradio-pr-bot Nov 23, 2023
1885bb9
changes
aliabid94 Nov 24, 2023
b7d35c8
Merge branch 'fix_api_event_dops' of https://github.com/gradio-app/gr…
aliabid94 Nov 24, 2023
57c6e5b
changes
aliabid94 Nov 29, 2023
ca73780
changes
aliabid94 Nov 29, 2023
ba37ca8
changs
aliabid94 Nov 29, 2023
41eb16a
chagnes
aliabid94 Nov 29, 2023
f1d8baf
changes
Dec 5, 2023
31cf67f
changes
Dec 5, 2023
ca1d7fa
changes
Dec 5, 2023
6a66067
changes
Dec 5, 2023
4ceb71a
changes
Dec 5, 2023
edc95e4
changes
Dec 5, 2023
a7db781
changes
Dec 5, 2023
c6d2bc9
changes
Dec 5, 2023
e9db998
Merge remote-tracking branch 'origin' into fix_api_event_dops
Dec 5, 2023
8456e68
changes
Dec 5, 2023
2fa939c
changes
Dec 5, 2023
fdaa96b
changes~git push
Dec 5, 2023
f27f336
changes
Dec 6, 2023
27ad025
changes
Dec 6, 2023
8ea90a4
chagmes
Dec 6, 2023
acd1f1d
changes
Dec 7, 2023
692f2d9
changes
Dec 7, 2023
64d30bd
Merge remote-tracking branch 'origin' into fix_api_event_dops
Dec 7, 2023
1f53bf3
changes
Dec 7, 2023
9fa6d3d
changes
Dec 7, 2023
1ff6e7d
Merge branch 'main' into fix_api_event_dops
abidlabs Dec 7, 2023
448f9c4
changes
Dec 11, 2023
7dbe37c
Merge remote-tracking branch 'origin' into fix_api_event_dops
Dec 11, 2023
7f6ebe4
changes
Dec 11, 2023
5c6349b
changes
Dec 11, 2023
7cde497
Merge remote-tracking branch 'origin' into fix_api_event_dops
Dec 11, 2023
b722e46
changes
Dec 12, 2023
6e5a093
changes
Dec 12, 2023
5c598a0
changes
Dec 12, 2023
b7ef2c5
change
Dec 12, 2023
778f15b
changes
Dec 12, 2023
bba88da
changes
Dec 12, 2023
caf6786
Merge remote-tracking branch 'origin' into fix_api_event_dops
Dec 12, 2023
a23d771
changes
Dec 12, 2023
aaa63fb
changes
Dec 12, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .changeset/ripe-spiders-love.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
"@gradio/client": patch
"gradio": patch
"gradio_client": patch
---

fix:Fix api event drops
146 changes: 144 additions & 2 deletions client/js/src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,9 @@ export function api_factory(

const session_hash = Math.random().toString(36).substring(2);
const last_status: Record<string, Status["stage"]> = {};
let stream_open = false;
let event_stream: EventSource | null = null;
const event_callbacks: Record<string, () => Promise<void>> = {};
let config: Config;
let api_map: Record<string, number> = {};

Expand Down Expand Up @@ -437,7 +440,7 @@ export function api_factory(

let websocket: WebSocket;
let eventSource: EventSource;
let protocol = config.protocol ?? "sse";
let protocol = config.protocol ?? "ws";

const _endpoint = typeof endpoint === "number" ? "/predict" : endpoint;
let payload: Payload;
Expand Down Expand Up @@ -646,7 +649,7 @@ export function api_factory(
websocket.send(JSON.stringify({ hash: session_hash }))
);
}
} else {
} else if (protocol == "sse") {
fire_event({
type: "status",
stage: "pending",
Expand Down Expand Up @@ -766,6 +769,121 @@ export function api_factory(
}
}
};
} else if (protocol == "sse_v1") {
fire_event({
type: "status",
stage: "pending",
queue: true,
endpoint: _endpoint,
fn_index,
time: new Date()
});

post_data(
`${http_protocol}//${resolve_root(
host,
config.path,
true
)}/queue/data?${url_params}`,
{
...payload,
session_hash
},
hf_token
).then(([response, status]) => {
if (status !== 200) {
fire_event({
type: "status",
stage: "error",
message: BROKEN_CONNECTION_MSG,
queue: true,
endpoint: _endpoint,
fn_index,
time: new Date()
});
} else {
event_id = response.event_id as string;
if (!stream_open) {
open_stream();
}

// Per-event handler registered with the shared SSE stream (dispatched by
// open_stream via event_callbacks). Translates raw queue messages into
// client "status"/"data"/"log" events, and tears the stream down once no
// events remain pending.
// FIX: an async function must be typed Promise<void>, not void.
let callback = async function (_data: object): Promise<void> {
	const { type, status, data } = handle_message(
		_data,
		last_status[fn_index]
	);

	if (type === "update" && status && !complete) {
		// call 'status' listeners
		fire_event({
			type: "status",
			endpoint: _endpoint,
			fn_index,
			time: new Date(),
			...status
		});
	} else if (type === "complete") {
		complete = status;
	} else if (type === "log") {
		fire_event({
			type: "log",
			log: data.log,
			level: data.level,
			endpoint: _endpoint,
			fn_index
		});
	} else if (type === "generating") {
		fire_event({
			type: "status",
			time: new Date(),
			...status,
			stage: status?.stage!,
			queue: true,
			endpoint: _endpoint,
			fn_index
		});
	}
	if (data) {
		fire_event({
			type: "data",
			time: new Date(),
			data: transform_files
				? transform_output(
						data.data,
						api_info,
						config.root,
						config.root_url
					)
				: data.data,
			endpoint: _endpoint,
			fn_index
		});

		if (complete) {
			fire_event({
				type: "status",
				time: new Date(),
				...complete,
				stage: status?.stage!,
				queue: true,
				endpoint: _endpoint,
				fn_index
			});
		}
	}

	// FIX: status may be undefined for some message types (the branches
	// above already use status?.); guard here too so a terminal check does
	// not raise an unhandled rejection inside the stream handler.
	if (status?.stage === "complete" || status?.stage === "error") {
		if (event_callbacks[event_id]) {
			delete event_callbacks[event_id];
			// Last pending event for this session: release the shared stream.
			if (Object.keys(event_callbacks).length === 0) {
				close_stream();
			}
		}
	}
};
event_callbacks[event_id] = callback;
}
});
}
});

Expand Down Expand Up @@ -864,6 +982,30 @@ export function api_factory(
};
}

/**
 * Opens the single shared SSE connection for this session (/queue/join)
 * and dispatches each incoming message to the callback registered under
 * its event_id in `event_callbacks`. Only one stream is kept open at a
 * time, tracked by the `stream_open` flag.
 */
function open_stream(): void {
	stream_open = true;
	let params = new URLSearchParams({
		session_hash: session_hash
	}).toString();
	let url = new URL(
		`${http_protocol}//${resolve_root(
			host,
			config.path,
			true
		)}/queue/join?${params}`
	);
	event_stream = new EventSource(url);
	event_stream.onmessage = async function (event) {
		let _data = JSON.parse(event.data);
		// FIX: a message can arrive before its callback is registered —
		// open_stream() is called before `event_callbacks[event_id]` is
		// assigned in the /queue/data response handler — so an unguarded
		// call here throws a TypeError and kills the handler. Optional
		// call skips such early messages instead of crashing.
		// TODO(review): consider buffering early messages rather than
		// dropping them, given this PR targets event drops.
		await event_callbacks[_data.event_id]?.(_data);
	};
}

/**
 * Marks the shared session stream as closed and disposes of the
 * underlying EventSource, if one was ever opened.
 */
function close_stream(): void {
	stream_open = false;
	if (event_stream !== null) {
		event_stream.close();
	}
}

async function component_server(
component_id: number,
fn_name: string,
Expand Down
2 changes: 1 addition & 1 deletion client/js/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ export interface Config {
show_api: boolean;
stylesheets: string[];
path: string;
protocol?: "sse" | "ws";
protocol?: "sse_v1" | "sse" | "ws";
}

export interface Payload {
Expand Down
126 changes: 112 additions & 14 deletions client/python/gradio_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from gradio_client.utils import (
Communicator,
JobStatus,
Message,
Status,
StatusUpdate,
)
Expand Down Expand Up @@ -139,10 +140,12 @@ def __init__(
self._info = self._get_api_info()
self.session_hash = str(uuid.uuid4())

protocol = self.config.get("protocol")
endpoint_class = Endpoint if protocol == "sse" else EndpointV3Compatibility
self.protocol: str = self.config.get("protocol", "ws")
endpoint_class = (
Endpoint if self.protocol.startswith("sse") else EndpointV3Compatibility
)
self.endpoints = [
endpoint_class(self, fn_index, dependency)
endpoint_class(self, fn_index, dependency, self.protocol)
for fn_index, dependency in enumerate(self.config["dependencies"])
]

Expand All @@ -152,6 +155,78 @@ def __init__(
# Disable telemetry by setting the env variable HF_HUB_DISABLE_TELEMETRY=1
threading.Thread(target=self._telemetry_thread).start()

self.stream_open = False
self.streaming_future: Future | None = None
self.pending_messages_per_event: dict[str, list[Message]] = {}
self.pending_event_ids: set[str] = set()

async def stream_messages(self) -> None:
    """Listen on the shared SSE stream for this session and route every
    message into the per-event pending queue.

    Returns once all pending events have completed (the stream is then
    re-opened lazily by the next prediction), or after the server reports
    it has stopped, in which case the stop message is fanned out to every
    pending event so their consumers stop waiting.
    """
    try:
        async with httpx.AsyncClient(timeout=httpx.Timeout(timeout=None)) as client:
            async with client.stream(
                "GET",
                self.sse_url,
                params={"session_hash": self.session_hash},
                headers=self.headers,
                cookies=self.cookies,
            ) as response:
                async for line in response.aiter_text():
                    if line.startswith("data:"):
                        resp = json.loads(line[5:])
                        if resp["msg"] == "heartbeat":
                            continue
                        elif resp["msg"] == "server_stopped":
                            # FIX: debug print removed (flagged in review).
                            # Broadcast the shutdown message to every
                            # pending event before giving up the stream.
                            for (
                                pending_messages
                            ) in self.pending_messages_per_event.values():
                                pending_messages.append(resp)
                            break
                        event_id = resp["event_id"]
                        if event_id not in self.pending_messages_per_event:
                            self.pending_messages_per_event[event_id] = []
                        self.pending_messages_per_event[event_id].append(resp)
                        if resp["msg"] == "process_completed":
                            self.pending_event_ids.remove(event_id)
                            # No work left: stop listening so the client
                            # does not hold an idle connection open.
                            if len(self.pending_event_ids) == 0:
                                self.stream_open = False
                                return
                    elif line == "":
                        continue
                    else:
                        raise ValueError(f"Unexpected SSE line: '{line}'")
    except BaseException as e:
        import traceback

        traceback.print_exc()
        raise e

async def send_data(self, data, hash_data):
    """POST a prediction payload to ``/queue/data`` and lazily open the
    shared SSE stream for this session if it is not already open.

    Parameters:
        data: the prediction payload for the endpoint.
        hash_data: session identifiers merged into the payload.
    Returns:
        The server-assigned ``event_id`` for the submitted prediction.
    Raises:
        httpx.HTTPStatusError: if the server rejects the payload.
    """
    async with httpx.AsyncClient() as client:
        req = await client.post(
            self.sse_data_url,
            json={**data, **hash_data},
            headers=self.headers,
            cookies=self.cookies,
        )
        req.raise_for_status()
        resp = req.json()
        event_id = resp["event_id"]

    if not self.stream_open:
        self.stream_open = True

        def open_stream():
            return utils.synchronize_async(self.stream_messages)

        # One listener thread serves all pending events for this session;
        # only (re)submit when no stream task is currently running.
        if self.streaming_future is None or self.streaming_future.done():
            self.streaming_future = self.executor.submit(open_stream)
            # FIX: debug print removed (flagged in review). f.result()
            # still re-raises any exception from the streaming thread so
            # failures are not silently swallowed.
            self.streaming_future.add_done_callback(lambda f: f.result())

    return event_id

@classmethod
def duplicate(
cls,
Expand Down Expand Up @@ -340,7 +415,7 @@ def submit(
inferred_fn_index = self._infer_fn_index(api_name, fn_index)

helper = None
if self.endpoints[inferred_fn_index].protocol in ("ws", "sse"):
if self.endpoints[inferred_fn_index].protocol in ("ws", "sse", "sse_v1"):
helper = self.new_helper(inferred_fn_index)
end_to_end_fn = self.endpoints[inferred_fn_index].make_end_to_end_fn(helper)
future = self.executor.submit(end_to_end_fn, *args)
Expand Down Expand Up @@ -806,15 +881,17 @@ class ReplaceMe:
class Endpoint:
"""Helper class for storing all the information about a single API endpoint."""

def __init__(self, client: Client, fn_index: int, dependency: dict):
def __init__(
self, client: Client, fn_index: int, dependency: dict, protocol: str = "sse_v1"
):
self.client: Client = client
self.fn_index = fn_index
self.dependency = dependency
api_name = dependency.get("api_name")
self.api_name: str | Literal[False] | None = (
"/" + api_name if isinstance(api_name, str) else api_name
)
self.protocol = "sse"
self.protocol = protocol
self.input_component_types = [
self._get_component_type(id_) for id_ in dependency["inputs"]
]
Expand Down Expand Up @@ -891,7 +968,20 @@ def _predict(*data) -> tuple:
"session_hash": self.client.session_hash,
}

result = utils.synchronize_async(self._sse_fn, data, hash_data, helper)
if self.protocol == "sse":
result = utils.synchronize_async(
self._sse_fn_v0, data, hash_data, helper
)
elif self.protocol == "sse_v1":
event_id = utils.synchronize_async(
self.client.send_data, data, hash_data
)
self.client.pending_event_ids.add(event_id)
self.client.pending_messages_per_event[event_id] = []
result = utils.synchronize_async(self._sse_fn_v1, helper, event_id)
else:
raise ValueError(f"Unsupported protocol: {self.protocol}")

if "error" in result:
raise ValueError(result["error"])

Expand Down Expand Up @@ -1068,24 +1158,32 @@ def process_predictions(self, *predictions):
predictions = self.reduce_singleton_output(*predictions)
return predictions

async def _sse_fn(self, data: dict, hash_data: dict, helper: Communicator):
async def _sse_fn_v0(self, data: dict, hash_data: dict, helper: Communicator):
async with httpx.AsyncClient(timeout=httpx.Timeout(timeout=None)) as client:
return await utils.get_pred_from_sse(
return await utils.get_pred_from_sse_v0(
client,
data,
hash_data,
helper,
sse_url=self.client.sse_url,
sse_data_url=self.client.sse_data_url,
headers=self.client.headers,
cookies=self.client.cookies,
self.client.sse_url,
self.client.sse_data_url,
self.client.headers,
self.client.cookies,
Comment on lines +1180 to +1183
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

jw why did we remove the kwargs?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no particular reason, seemed arbitrary some were kwargs and some args

)

async def _sse_fn_v1(self, helper: Communicator, event_id: str):
return await utils.get_pred_from_sse_v1(
helper,
self.client.cookies,
self.client.pending_messages_per_event,
event_id,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you'll need to send headers here, otherwise it won't work with private spaces

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this function doesn't make any backend calls, other than check_for_cancel, which doesn't use headers previously either. Will add to check_for_cancel though

)


class EndpointV3Compatibility:
"""Endpoint class for connecting to v3 endpoints. Backwards compatibility."""

def __init__(self, client: Client, fn_index: int, dependency: dict):
def __init__(self, client: Client, fn_index: int, dependency: dict, *args):
self.client: Client = client
self.fn_index = fn_index
self.dependency = dependency
Expand Down
Loading