
Commit db162bf

enable lazy caching for chatinterface (#10015)
* lazy chat
* add changeset
* lazy caching
* lazy caching
* revert
* fix this
* changes
* changes
* format
* changes
* add env variable
* revert
* add changeset
* lazy
* fix
* chat interface
* fix test

---------

Co-authored-by: gradio-pr-bot <[email protected]>
1 parent 369a44e commit db162bf

9 files changed, +87 -26 lines
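For orientation, here is a minimal, hypothetical usage sketch of the feature this commit enables. The `echo` function and example strings are illustrative; `cache_examples`, `cache_mode`, and `type` are the actual `gr.ChatInterface` parameters exercised in the diffs below. With `cache_mode="lazy"`, each example's response is computed and cached the first time any user clicks it, rather than all examples being cached when the app launches:

```python
import gradio as gr

# Minimal echo bot; the function and example strings are illustrative.
# With cache_mode="lazy", the response for each example is generated and
# cached the first time any user clicks it, instead of every example
# being cached eagerly at app launch.
def echo(message: str, history: list[dict]) -> str:
    return message

demo = gr.ChatInterface(
    fn=echo,
    examples=[["Hey"], ["What is Gradio?"]],
    cache_examples=True,
    cache_mode="lazy",
    type="messages",
)

if __name__ == "__main__":
    demo.launch()
```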

.changeset/many-horses-judge.md

Lines changed: 5 additions & 0 deletions
```diff
@@ -0,0 +1,5 @@
+---
+"gradio": patch
+---
+
+fix:enable lazy caching for chatinterface
```

.config/playwright-setup.js

Lines changed: 2 additions & 1 deletion
```diff
@@ -84,7 +84,8 @@ function spawn_gradio_app(app, port, verbose) {
 			...process.env,
 			PYTHONUNBUFFERED: "true",
 			GRADIO_ANALYTICS_ENABLED: "False",
-			GRADIO_IS_E2E_TEST: "1"
+			GRADIO_IS_E2E_TEST: "1",
+			GRADIO_RESET_EXAMPLES_CACHE: "True"
 		}
 	});
 	_process.stdout.setEncoding("utf8");
```
demo/test_chatinterface_examples/lazy_caching_examples_testcase.py

Lines changed: 27 additions & 0 deletions
```diff
@@ -0,0 +1,27 @@
+import gradio as gr
+
+def generate(
+    message: str,
+    chat_history: list[dict],
+):
+
+    output = ""
+    for character in message:
+        output += character
+        yield output
+
+
+demo = gr.ChatInterface(
+    fn=generate,
+    examples=[
+        ["Hey"],
+        ["Can you explain briefly to me what is the Python programming language?"],
+    ],
+    cache_examples=True,
+    cache_mode="lazy",
+    type="messages",
+)
+
+
+if __name__ == "__main__":
+    demo.launch()
```
demo/test_chatinterface_examples/run.ipynb

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
1-
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: test_chatinterface_examples"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_examples/eager_caching_examples_testcase.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_examples/multimodal_messages_examples_testcase.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_examples/multimodal_tuples_examples_testcase.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_examples/tuples_examples_testcase.py"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "def generate(\n", " message: str,\n", " chat_history: list[dict],\n", "):\n", "\n", " output = \"\"\n", " for character in message:\n", " output += character\n", " yield output\n", "\n", "\n", "demo = gr.ChatInterface(\n", " fn=generate,\n", " examples=[\n", " [\"Hey\"],\n", " [\"Can you explain briefly to me what is the Python programming language?\"],\n", " ],\n", " cache_examples=False,\n", " type=\"messages\",\n", ")\n", "\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
1+
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: test_chatinterface_examples"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_examples/eager_caching_examples_testcase.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_examples/lazy_caching_examples_testcase.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_examples/multimodal_messages_examples_testcase.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_examples/multimodal_tuples_examples_testcase.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_examples/tuples_examples_testcase.py"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "def generate(\n", " message: str,\n", " chat_history: list[dict],\n", "):\n", "\n", " output = \"\"\n", " for character in message:\n", " output += character\n", " yield output\n", "\n", "\n", "demo = gr.ChatInterface(\n", " fn=generate,\n", " examples=[\n", " [\"Hey\"],\n", " [\"Can you explain briefly to me what is the Python programming language?\"],\n", " ],\n", " cache_examples=False,\n", " type=\"messages\",\n", ")\n", "\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}

gradio/chat_interface.py

Lines changed: 14 additions & 11 deletions
```diff
@@ -8,7 +8,7 @@
 import functools
 import inspect
 import warnings
-from collections.abc import AsyncGenerator, Callable, Sequence
+from collections.abc import AsyncGenerator, Callable, Generator, Sequence
 from pathlib import Path
 from typing import Literal, Union, cast
 
@@ -111,7 +111,7 @@ def __init__(
             example_labels: labels for the examples, to be displayed instead of the examples themselves. If provided, should be a list of strings with the same length as the examples list. Only applies when examples are displayed within the chatbot (i.e. when `additional_inputs` is not provided).
             example_icons: icons for the examples, to be displayed above the examples. If provided, should be a list of string URLs or local paths with the same length as the examples list. Only applies when examples are displayed within the chatbot (i.e. when `additional_inputs` is not provided).
             cache_examples: if True, caches examples in the server for fast runtime in examples. The default option in HuggingFace Spaces is True. The default option elsewhere is False.
-            cache_mode: if "eager", all examples are cached at app launch. The "lazy" option is not yet supported. If None, will use the GRADIO_CACHE_MODE environment variable if defined, or default to "eager".
+            cache_mode: if "eager", all examples are cached at app launch. If "lazy", examples are cached for all users after the first use by any user of the app. If None, will use the GRADIO_CACHE_MODE environment variable if defined, or default to "eager".
             title: a title for the interface; if provided, appears above chatbot in large font. Also used as the tab title when opened in a browser window.
             description: a description for the interface; if provided, appears above the chatbot and beneath the title in regular font. Accepts Markdown and HTML content.
             theme: a Theme object or a string representing a theme. If a string, will look for a built-in theme with that name (e.g. "soft" or "default"), or will attempt to load a theme from the Hugging Face Hub (e.g. "gradio/monochrome"). If None, will use the Default theme.
@@ -369,7 +369,7 @@ def _setup_events(self) -> None:
             and self.examples
             and not self._additional_inputs_in_examples
         ):
-            if self.cache_examples and self.cache_mode == "eager":
+            if self.cache_examples:
                 self.chatbot.example_select(
                     self.example_clicked,
                     None,
@@ -718,15 +718,15 @@ def option_clicked(
 
     def example_clicked(
         self, example: SelectData
-    ) -> tuple[TupleFormat | list[MessageDict], str | MultimodalPostprocess]:
+    ) -> Generator[
+        tuple[TupleFormat | list[MessageDict], str | MultimodalPostprocess], None, None
+    ]:
         """
-        When an example is clicked, the chat history is set to the complete example value
-        (including files). The saved input value is also set to complete example value
-        if multimodal is True, otherwise it is set to the text of the example.
+        When an example is clicked, the chat history (and saved input) is initially set only
+        to the example message. Then, if example caching is enabled, the cached response is loaded
+        and added to the chat history as well.
         """
-        if self.cache_examples and self.cache_mode == "eager":
-            history = self.examples_handler.load_from_cache(example.index)[0].root
-        elif self.type == "tuples":
+        if self.type == "tuples":
             history = [(example.value["text"], None)]
             for file in example.value.get("files", []):
                 history.append(((file["path"]), None))
@@ -735,7 +735,10 @@ def example_clicked(
             for file in example.value.get("files", []):
                 history.append(MessageDict(role="user", content=file))
         message = example.value if self.multimodal else example.value["text"]
-        return history, message
+        yield history, message
+        if self.cache_examples:
+            history = self.examples_handler.load_from_cache(example.index)[0].root
+            yield history, message
 
     def _process_example(
         self, message: ExampleMessage | str, response: MessageDict | str | None
```
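In effect, `example_clicked` is now a generator that yields twice: first the history containing only the clicked example (so the UI updates immediately), then the history including the response loaded from the cache, which under lazy caching may be computed during this first access. A minimal sketch of that pattern, with hypothetical names (`load_cached_response` stands in for `self.examples_handler.load_from_cache`):

```python
from collections.abc import Generator

# Sketch of the two-step yield pattern adopted above; names are illustrative.
def example_clicked_sketch(
    example_text: str, load_cached_response
) -> Generator[list[dict], None, None]:
    history = [{"role": "user", "content": example_text}]
    yield history  # first paint: show the clicked example immediately
    response = load_cached_response()  # may trigger lazy caching on first use
    yield history + [{"role": "assistant", "content": response}]  # add cached reply
```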

gradio/helpers.py

Lines changed: 22 additions & 2 deletions
```diff
@@ -9,6 +9,7 @@
 import csv
 import inspect
 import os
+import shutil
 import warnings
 from collections.abc import Callable, Iterable, Sequence
 from functools import partial
@@ -276,6 +277,11 @@ def __init__(
             simplify_file_data=False, verbose=False, dataset_file_name="log.csv"
         )
         self.cached_folder = utils.get_cache_folder() / str(self.dataset._id)
+        if (
+            os.environ.get("GRADIO_RESET_EXAMPLES_CACHE") == "True"
+            and self.cached_folder.exists()
+        ):
+            shutil.rmtree(self.cached_folder)
         self.cached_file = Path(self.cached_folder) / "log.csv"
         self.cached_indices_file = Path(self.cached_folder) / "indices.csv"
         self.run_on_click = run_on_click
@@ -495,13 +501,15 @@ def sync_lazy_cache(self, example_value: tuple[int, list[Any]], *input_values):
         with open(self.cached_indices_file, "a") as f:
             f.write(f"{example_index}\n")
 
-    async def cache(self) -> None:
+    async def cache(self, example_id: int | None = None) -> None:
         """
         Caches examples so that their predictions can be shown immediately.
+        Parameters:
+            example_id: The id of the example to process (zero-indexed). If None, all examples are cached.
         """
         if self.root_block is None:
             raise Error("Cannot cache examples if not in a Blocks context.")
-        if Path(self.cached_file).exists():
+        if Path(self.cached_file).exists() and example_id is None:
             print(
                 f"Using cache from '{utils.abspath(self.cached_folder)}' directory. If method or examples have changed since last caching, delete this folder to clear cache.\n"
             )
@@ -548,6 +556,8 @@ async def get_final_item(*args):
         if self.outputs is None:
             raise ValueError("self.outputs is missing")
         for i, example in enumerate(self.examples):
+            if example_id is not None and i != example_id:
+                continue
             print(f"Caching example {i + 1}/{len(self.examples)}")
             processed_input = self._get_processed_example(example)
             if self.batch:
@@ -574,6 +584,16 @@ def load_from_cache(self, example_id: int) -> list[Any]:
         Parameters:
             example_id: The id of the example to process (zero-indexed).
         """
+        if self.cache_examples == "lazy":
+            if (cached_index := self._get_cached_index_if_cached(example_id)) is None:
+                client_utils.synchronize_async(self.cache, example_id)
+                with open(self.cached_indices_file, "a") as f:
+                    f.write(f"{example_id}\n")
+                with open(self.cached_indices_file) as f:
+                    example_id = len(f.readlines()) - 1
+            else:
+                example_id = cached_index
+
         with open(self.cached_file, encoding="utf-8") as cache:
             examples = list(csv.reader(cache))
         example = examples[example_id + 1]  # +1 to adjust for header
```
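The lazy path above relies on two files: `log.csv` stores cached predictions in the order examples were first clicked, and `indices.csv` records which example id each row belongs to. Below is a simplified, hypothetical sketch of that lookup, assuming a headerless `log.csv` (the real file has a header row, hence the `+1` in the diff) and a `run_example` callable standing in for running the app's function:

```python
import csv
from pathlib import Path

# Simplified sketch of lazy-cache lookup; illustrative, not the Gradio source.
def load_lazily(example_id: int, cache_dir: Path, run_example) -> list[str]:
    indices_file = cache_dir / "indices.csv"
    log_file = cache_dir / "log.csv"
    cached_ids = indices_file.read_text().splitlines() if indices_file.exists() else []
    if str(example_id) in cached_ids:
        # Already cached: its row in log.csv is its position in indices.csv.
        row = cached_ids.index(str(example_id))
    else:
        # First click by any user: run the function, append prediction and id.
        with open(log_file, "a", newline="") as f:
            csv.writer(f).writerow(run_example(example_id))
        with open(indices_file, "a") as f:
            f.write(f"{example_id}\n")
        row = len(cached_ids)  # the row we just appended
    with open(log_file, encoding="utf-8") as f:
        return list(csv.reader(f))[row]
```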

guides/04_additional-features/09_environment-variables.md

Lines changed: 10 additions & 0 deletions
````diff
@@ -163,6 +163,16 @@ Environment variables in Gradio provide a way to customize your applications and
 export GRADIO_NODE_NUM_PORTS=200
 ```
 
+### 18. `GRADIO_RESET_EXAMPLES_CACHE`
+
+- **Description**: If set to "True", Gradio will delete and recreate the examples cache directory when the app starts instead of reusing cached examples if they already exist.
+- **Default**: `"False"`
+- **Options**: `"True"`, `"False"`
+- **Example**:
+```sh
+export GRADIO_RESET_EXAMPLES_CACHE="True"
+```
+
 ## How to Set Environment Variables
 
 To set environment variables in your terminal, use the `export` command followed by the variable name and its value. For example:
````

js/spa/test/test_chatinterface_examples.spec.ts

Lines changed: 2 additions & 1 deletion
```diff
@@ -5,7 +5,8 @@ const cases = [
 	"tuples_examples",
 	"multimodal_tuples_examples",
 	"multimodal_messages_examples",
-	"eager_caching_examples"
+	"eager_caching_examples",
+	"lazy_caching_examples"
 ];
 
 for (const test_case of cases) {
```

test/test_chat_interface.py

Lines changed: 4 additions & 10 deletions
```diff
@@ -95,7 +95,7 @@ def test_example_caching(self, connect):
         assert prediction_hi[0].root[0] == ("hi", "hi hi")
 
     @pytest.mark.asyncio
-    async def test_example_caching_lazy(self, connect):
+    async def test_example_caching_lazy(self):
         with patch(
             "gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())
         ):
@@ -105,16 +105,10 @@ async def test_example_caching_lazy(self, connect):
                 cache_examples=True,
                 cache_mode="lazy",
             )
-            async for _ in chatbot.examples_handler.async_lazy_cache(
-                (0, ["hello"]), "hello"
-            ):
-                pass
-            with connect(chatbot):
-                prediction_hello = chatbot.examples_handler.load_from_cache(0)
+            prediction_hello = chatbot.examples_handler.load_from_cache(0)
             assert prediction_hello[0].root[0] == ("hello", "hello hello")
-            with pytest.raises(IndexError):
-                prediction_hi = chatbot.examples_handler.load_from_cache(1)
-                assert prediction_hi[0].root[0] == ("hi", "hi hi")
+            prediction_hi = chatbot.examples_handler.load_from_cache(1)
+            assert prediction_hi[0].root[0] == ("hi", "hi hi")
 
     def test_example_caching_async(self, connect):
         with patch(
```
