
Commit 1be31c1

Authored by Ali Abid (aliabid94) and gradio-pr-bot
Allow editable ChatInterface (#10229)
* changes
* add changeset
* changes
* changes
* changes

Co-authored-by: Ali Abid <[email protected]>
Co-authored-by: gradio-pr-bot <[email protected]>
1 parent 506bd28 commit 1be31c1

File tree

6 files changed: +60 -3 lines changed

.changeset/slimy-pants-hang.md (+5)

@@ -0,0 +1,5 @@
+---
+"gradio": minor
+---
+
+feat:Allow editable ChatInterface
demo/test_chatinterface_streaming_echo notebook (embedded source updated to match run.py below)

@@ -1 +1 @@
-{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: test_chatinterface_streaming_echo"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_streaming_echo/messages_testcase.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_streaming_echo/multimodal_messages_testcase.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_streaming_echo/multimodal_non_stream_testcase.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_streaming_echo/multimodal_tuples_testcase.py"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "runs = 0\n", "\n", "def reset_runs():\n", " global runs\n", " runs = 0\n", "\n", "def slow_echo(message, history):\n", " global runs # i didn't want to add state or anything to this demo\n", " runs = runs + 1\n", " for i in range(len(message)):\n", " yield f\"Run {runs} - You typed: \" + message[: i + 1]\n", "\n", "chat = gr.ChatInterface(slow_echo, fill_height=True)\n", "\n", "with gr.Blocks() as demo:\n", " chat.render()\n", " # We reset the global variable to minimize flakes\n", " # this works because CI runs only one test at at time\n", " # need to use gr.State if we want to parallelize this test\n", " # currently chatinterface does not support that\n", " demo.unload(reset_runs)\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
+{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: test_chatinterface_streaming_echo"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_streaming_echo/messages_testcase.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_streaming_echo/multimodal_messages_testcase.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_streaming_echo/multimodal_non_stream_testcase.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_streaming_echo/multimodal_tuples_testcase.py"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "runs = 0\n", "\n", "def reset_runs():\n", " global runs\n", " runs = 0\n", "\n", "def slow_echo(message, history):\n", " global runs # i didn't want to add state or anything to this demo\n", " runs = runs + 1\n", " for i in range(len(message)):\n", " yield f\"Run {runs} - You typed: \" + message[: i + 1]\n", "\n", "chat = gr.ChatInterface(slow_echo, fill_height=True, editable=True)\n", "\n", "with gr.Blocks() as demo:\n", " chat.render()\n", " # We reset the global variable to minimize flakes\n", " # this works because CI runs only one test at at time\n", " # need to use gr.State if we want to parallelize this test\n", " # currently chatinterface does not support that\n", " demo.unload(reset_runs)\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}

demo/test_chatinterface_streaming_echo/run.py (+1 -1)

@@ -12,7 +12,7 @@ def slow_echo(message, history):
     for i in range(len(message)):
         yield f"Run {runs} - You typed: " + message[: i + 1]
 
-chat = gr.ChatInterface(slow_echo, fill_height=True)
+chat = gr.ChatInterface(slow_echo, fill_height=True, editable=True)
 
 with gr.Blocks() as demo:
     chat.render()
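
For orientation, the smallest standalone use of the new flag looks roughly like the sketch below (illustrative only; `echo` is a placeholder chat function, not part of this diff).

import gradio as gr

def echo(message, history):
    # Any (message, history) chat function works. With editable=True, editing a
    # past user message truncates the conversation at that point and re-runs
    # this function to regenerate the response.
    return "You typed: " + message

# editable=True is the parameter added in this commit.
gr.ChatInterface(echo, type="messages", editable=True).launch()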

gradio/chat_interface.py (+25 -1)

@@ -36,7 +36,7 @@
 )
 from gradio.components.multimodal_textbox import MultimodalPostprocess, MultimodalValue
 from gradio.context import get_blocks_context
-from gradio.events import Dependency, SelectData
+from gradio.events import Dependency, EditData, SelectData
 from gradio.helpers import create_examples as Examples  # noqa: N812
 from gradio.helpers import special_args, update
 from gradio.layouts import Accordion, Column, Group, Row

@@ -75,6 +75,7 @@ def __init__(
         additional_inputs: str | Component | list[str | Component] | None = None,
         additional_inputs_accordion: str | Accordion | None = None,
         additional_outputs: Component | list[Component] | None = None,
+        editable: bool = False,
         examples: list[str] | list[MultimodalValue] | list[list] | None = None,
         example_labels: list[str] | None = None,
         example_icons: list[str] | None = None,

@@ -108,6 +109,7 @@ def __init__(
             type: The format of the messages passed into the chat history parameter of `fn`. If "messages", passes the history as a list of dictionaries with openai-style "role" and "content" keys. The "content" key's value should be one of the following - (1) strings in valid Markdown (2) a dictionary with a "path" key and value corresponding to the file to display or (3) an instance of a Gradio component: at the moment gr.Image, gr.Plot, gr.Video, gr.Gallery, gr.Audio, and gr.HTML are supported. The "role" key should be one of 'user' or 'assistant'. Any other roles will not be displayed in the output. If this parameter is 'tuples' (deprecated), passes the chat history as a `list[list[str | None | tuple]]`, i.e. a list of lists. The inner list should have 2 elements: the user message and the response message.
             chatbot: an instance of the gr.Chatbot component to use for the chat interface, if you would like to customize the chatbot properties. If not provided, a default gr.Chatbot component will be created.
             textbox: an instance of the gr.Textbox or gr.MultimodalTextbox component to use for the chat interface, if you would like to customize the textbox properties. If not provided, a default gr.Textbox or gr.MultimodalTextbox component will be created.
+            editable: if True, users can edit past messages to regenerate responses.
             additional_inputs: an instance or list of instances of gradio components (or their string shortcuts) to use as additional inputs to the chatbot. If the components are not already rendered in a surrounding Blocks, then the components will be displayed under the chatbot, in an accordion. The values of these components will be passed into `fn` as arguments in order after the chat history.
             additional_inputs_accordion: if a string is provided, this is the label of the `gr.Accordion` to use to contain additional inputs. A `gr.Accordion` object can be provided as well to configure other properties of the container holding the additional inputs. Defaults to a `gr.Accordion(label="Additional Inputs", open=False)`. This parameter is only used if `additional_inputs` is provided.
             additional_outputs: an instance or list of instances of gradio components to use as additional outputs from the chat function. These must be components that are already defined in the same Blocks scope. If provided, the chat function should return additional values for these components. See $demo/chatinterface_artifacts.

@@ -173,6 +175,7 @@ def __init__(
         self.run_examples_on_click = run_examples_on_click
         self.cache_examples = cache_examples
         self.cache_mode = cache_mode
+        self.editable = editable
         self.additional_inputs = [
             get_component_instance(i)
             for i in utils.none_or_singleton_to_list(additional_inputs)

@@ -490,6 +493,14 @@ def _setup_events(self) -> None:
 
         self.chatbot.clear(**synchronize_chat_state_kwargs)
 
+        if self.editable:
+            self.chatbot.edit(
+                self._edit_message,
+                [self.chatbot],
+                [self.chatbot, self.chatbot_state, self.saved_input],
+                show_api=False,
+            ).success(**submit_fn_kwargs).success(**synchronize_chat_state_kwargs)
+
     def _setup_stop_events(
         self, event_triggers: list[Callable], events_to_cancel: list[Dependency]
     ) -> None:
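
The wiring in the hunk above relies on Gradio's event chaining: `.edit(...)` returns a Dependency, and each `.success(...)` step runs only if the previous handler finished without raising. A reduced sketch of the same pattern with generic components (illustrative only, not the ChatInterface internals; `validate` and `greet` are made-up handlers):

import gradio as gr

with gr.Blocks() as demo:
    name = gr.Textbox(label="Name")
    greeting = gr.Textbox(label="Greeting")

    def validate(n):
        # Raising gr.Error aborts the chain, so the .success() step never runs.
        if not n.strip():
            raise gr.Error("Please enter a name")
        return n.strip()

    def greet(n):
        return f"Hello, {n}!"

    # Same shape as chatbot.edit(...).success(...).success(...) in _setup_events.
    name.submit(validate, name, name).success(greet, name, greeting)

if __name__ == "__main__":
    demo.launch()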
@@ -712,6 +723,19 @@ def example_populated(self, example: SelectData):
         else:
             return example.value["text"]
 
+    def _edit_message(
+        self, history: list[MessageDict] | TupleFormat, edit_data: EditData
+    ) -> tuple[
+        list[MessageDict] | TupleFormat,
+        list[MessageDict] | TupleFormat,
+        str | MultimodalPostprocess,
+    ]:
+        if isinstance(edit_data.index, (list, tuple)):
+            history = history[: edit_data.index[0]]
+        else:
+            history = history[: edit_data.index]
+        return history, history, edit_data.value
+
     def example_clicked(
         self, example: SelectData
     ) -> Generator[
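
The truncation logic in `_edit_message` can be exercised in isolation with plain lists. A standalone sketch follows; `FakeEditData` is a stand-in for `gradio.events.EditData`, whose `.index` may be an int or a (row, column) pair depending on the chat format and whose `.value` is the edited text.

from dataclasses import dataclass

@dataclass
class FakeEditData:
    # Stand-in for gradio.events.EditData as used by _edit_message.
    index: int | tuple[int, int]
    value: str

def edit_message(history, edit_data):
    # Mirrors ChatInterface._edit_message: drop the edited message and everything
    # after it, then hand back the edited text so it can be resubmitted.
    if isinstance(edit_data.index, (list, tuple)):
        history = history[: edit_data.index[0]]
    else:
        history = history[: edit_data.index]
    return history, history, edit_data.value

history = [
    {"role": "user", "content": "Lets"},
    {"role": "assistant", "content": "You typed: Lets"},
    {"role": "user", "content": "Test"},
    {"role": "assistant", "content": "You typed: Test"},
]

# Editing the second user message (index 2) keeps only the first exchange and
# returns "Do" as the text to resubmit.
truncated, _, resubmit = edit_message(history, FakeEditData(index=2, value="Do"))
print(truncated)   # first user/assistant pair only
print(resubmit)    # "Do"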

gradio/events.py (+3)

@@ -906,6 +906,9 @@ class Events:
     edit = EventListener(
         "edit",
         doc="This listener is triggered when the user edits the {{ component }} (e.g. image) using the built-in editor.",
+        callback=lambda block: setattr(block, "editable", "user")
+        if getattr(block, "editable", None) is None
+        else None,
     )
     clear = EventListener(
         "clear",

js/spa/test/test_chatinterface_streaming_echo.spec.ts (+25)

@@ -115,3 +115,28 @@ test("test stopping generation", async ({ page }) => {
 	await expect(current_content).toBe(new_content);
 	await expect(new_content!.length).toBeLessThan(3000);
 });
+
+test("editing messages", async ({ page }) => {
+	const submit_button = page.locator(".submit-button");
+	const textbox = page.locator(".input-container textarea");
+	const chatbot = page.getByLabel("chatbot conversation");
+
+	await textbox.fill("Lets");
+	await submit_button.click();
+	await expect(chatbot).toContainText("You typed: Lets");
+
+	await textbox.fill("Test");
+	await submit_button.click();
+	await expect(chatbot).toContainText("You typed: Test");
+
+	await textbox.fill("This");
+	await submit_button.click();
+	await expect(chatbot).toContainText("You typed: This");
+
+	await page.getByLabel("Edit").nth(1).click();
+	await page.getByLabel("chatbot conversation").getByRole("textbox").fill("Do");
+	await page.getByLabel("Submit").click();
+
+	await expect(chatbot).toContainText("You typed: Do");
+	await expect(chatbot).not.toContainText("You typed: This");
+});
