Skip to content

Commit 369a44e

Browse files
abidlabsgradio-pr-bothannahblair
authored
Add ability to provide preset response options in gr.Chatbot / gr.ChatInterface (#9989)
* options * add changeset * list * types * add changeset * types * docs * changes * more docs * chatbot * changes * changes * changes * format * notebooks * typedict * docs * console logs * docs * docs * styling * docs * Update guides/05_chatbots/01_creating-a-chatbot-fast.md Co-authored-by: Hannah <[email protected]> * Update guides/05_chatbots/01_creating-a-chatbot-fast.md Co-authored-by: Hannah <[email protected]> * Update guides/05_chatbots/01_creating-a-chatbot-fast.md Co-authored-by: Hannah <[email protected]> * Update guides/05_chatbots/01_creating-a-chatbot-fast.md Co-authored-by: Hannah <[email protected]> * Update guides/05_chatbots/02_chat_interface_examples.md Co-authored-by: Hannah <[email protected]> * Update guides/05_chatbots/01_creating-a-chatbot-fast.md Co-authored-by: Hannah <[email protected]> * restore --------- Co-authored-by: gradio-pr-bot <[email protected]> Co-authored-by: Hannah <[email protected]>
1 parent 74f22d5 commit 369a44e

16 files changed

+447
-206
lines changed

.changeset/orange-cobras-suffer.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
---
2+
"@gradio/chatbot": minor
3+
"gradio": minor
4+
---
5+
6+
feat:Add ability to provide preset response options in `gr.Chatbot` / `gr.ChatInterface`

demo/chatinterface_options/run.ipynb

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatinterface_options"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "example_code = \"\"\"\n", "Here's the code I generated:\n", "\n", "```python\n", "def greet(x):\n", " return f\"Hello, {x}!\"\n", "```\n", "\n", "Is this correct?\n", "\"\"\"\n", "\n", "def chat(message, history):\n", " if message == \"Yes, that's correct.\":\n", " return \"Great!\"\n", " else:\n", " return {\n", " \"role\": \"assistant\",\n", " \"content\": example_code,\n", " \"options\": [\n", " {\"value\": \"Yes, that's correct.\", \"label\": \"Yes\"},\n", " {\"value\": \"No\"}\n", " ]\n", " }\n", "\n", "demo = gr.ChatInterface(\n", " chat,\n", " type=\"messages\",\n", " examples=[\"Write a Python function that takes a string and returns a greeting.\"]\n", ")\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}

demo/chatinterface_options/run.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import gradio as gr
2+
3+
example_code = """
4+
Here's the code I generated:
5+
6+
```python
7+
def greet(x):
8+
return f"Hello, {x}!"
9+
```
10+
11+
Is this correct?
12+
"""
13+
14+
def chat(message, history):
15+
if message == "Yes, that's correct.":
16+
return "Great!"
17+
else:
18+
return {
19+
"role": "assistant",
20+
"content": example_code,
21+
"options": [
22+
{"value": "Yes, that's correct.", "label": "Yes"},
23+
{"value": "No"}
24+
]
25+
}
26+
27+
demo = gr.ChatInterface(
28+
chat,
29+
type="messages",
30+
examples=["Write a Python function that takes a string and returns a greeting."]
31+
)
32+
33+
if __name__ == "__main__":
34+
demo.launch()

gradio/chat_interface.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,24 @@ def _setup_events(self) -> None:
444444
queue=False,
445445
)
446446

447+
self.chatbot.option_select(
448+
self.option_clicked,
449+
[self.chatbot],
450+
[self.chatbot, self.saved_input],
451+
show_api=False,
452+
).then(
453+
submit_fn,
454+
[self.saved_input, self.chatbot],
455+
[self.chatbot],
456+
show_api=False,
457+
concurrency_limit=cast(
458+
Union[int, Literal["default"], None], self.concurrency_limit
459+
),
460+
show_progress=cast(
461+
Literal["full", "minimal", "hidden"], self.show_progress
462+
),
463+
)
464+
447465
def _setup_stop_events(
448466
self, event_triggers: list[Callable], events_to_cancel: list[Dependency]
449467
) -> None:
@@ -686,6 +704,18 @@ async def _stream_fn(
686704
self._append_history(history_with_input, response, first_response=False)
687705
yield history_with_input
688706

707+
def option_clicked(
708+
self, history: list[MessageDict], option: SelectData
709+
) -> tuple[TupleFormat | list[MessageDict], str | MultimodalPostprocess]:
710+
"""
711+
When an option is clicked, the chat history is appended with the option value.
712+
The saved input value is also set to option value. Note that event can only
713+
be called if self.type is "messages" since options are only available for this
714+
chatbot type.
715+
"""
716+
history.append({"role": "user", "content": option.value})
717+
return history, option.value
718+
689719
def example_clicked(
690720
self, example: SelectData
691721
) -> tuple[TupleFormat | list[MessageDict], str | MultimodalPostprocess]:

gradio/components/chatbot.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,14 @@
1212
Any,
1313
Literal,
1414
Optional,
15-
TypedDict,
1615
Union,
1716
cast,
1817
)
1918

2019
from gradio_client import utils as client_utils
2120
from gradio_client.documentation import document
2221
from pydantic import Field
23-
from typing_extensions import NotRequired
22+
from typing_extensions import NotRequired, TypedDict
2423

2524
from gradio import utils
2625
from gradio.component_meta import ComponentMeta
@@ -37,6 +36,11 @@ class MetadataDict(TypedDict):
3736
title: Union[str, None]
3837

3938

39+
class Option(TypedDict):
40+
label: NotRequired[str]
41+
value: str
42+
43+
4044
class FileDataDict(TypedDict):
4145
path: str # server filepath
4246
url: NotRequired[Optional[str]] # normalised server url
@@ -51,6 +55,7 @@ class MessageDict(TypedDict):
5155
content: str | FileDataDict | tuple | Component
5256
role: Literal["user", "assistant", "system"]
5357
metadata: NotRequired[MetadataDict]
58+
options: NotRequired[list[Option]]
5459

5560

5661
class FileMessage(GradioModel):
@@ -82,6 +87,7 @@ class Message(GradioModel):
8287
role: str
8388
metadata: Metadata = Field(default_factory=Metadata)
8489
content: Union[str, FileMessage, ComponentMessage]
90+
options: Optional[list[Option]] = None
8591

8692

8793
class ExampleMessage(TypedDict):
@@ -102,6 +108,7 @@ class ChatMessage:
102108
role: Literal["user", "assistant", "system"]
103109
content: str | FileData | Component | FileDataDict | tuple | list
104110
metadata: MetadataDict | Metadata = field(default_factory=Metadata)
111+
options: Optional[list[Option]] = None
105112

106113

107114
class ChatbotDataMessages(GradioRootModel):
@@ -150,6 +157,7 @@ class Chatbot(Component):
150157
Events.retry,
151158
Events.undo,
152159
Events.example_select,
160+
Events.option_select,
153161
Events.clear,
154162
Events.copy,
155163
]
@@ -502,6 +510,7 @@ def _postprocess_message_messages(
502510
role=message.role,
503511
content=message.content, # type: ignore
504512
metadata=message.metadata, # type: ignore
513+
options=message.options,
505514
)
506515
elif isinstance(message, Message):
507516
return message

gradio/events.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -963,10 +963,12 @@ class Events:
963963
)
964964
example_select = EventListener(
965965
"example_select",
966-
config_data=lambda: {"example_selectable": False},
967-
callback=lambda block: setattr(block, "example_selectable", True),
968966
doc="This listener is triggered when the user clicks on an example from within the {{ component }}. This event has SelectData of type gradio.SelectData that carries information, accessible through SelectData.index and SelectData.value. See SelectData documentation on how to use this event data.",
969967
)
968+
option_select = EventListener(
969+
"option_select",
970+
doc="This listener is triggered when the user clicks on an option from within the {{ component }}. This event has SelectData of type gradio.SelectData that carries information, accessible through SelectData.index and SelectData.value. See SelectData documentation on how to use this event data.",
971+
)
970972
load = EventListener(
971973
"load",
972974
doc="This listener is triggered when the {{ component }} initially loads in the browser.",

0 commit comments

Comments
 (0)