Skip to content

Commit a1c21cb

Browse files
fix dataset update (#8581)
* fix dataset update * revert' * add changeset * add test * add changeset * changes * add template * add changeset * fix docstring * test postprocessing --------- Co-authored-by: gradio-pr-bot <[email protected]>
1 parent 2b0c157 commit a1c21cb

File tree

10 files changed

+102
-37
lines changed

10 files changed

+102
-37
lines changed

.changeset/soft-worms-remain.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
---
2+
"@gradio/dataset": patch
3+
"gradio": patch
4+
"website": patch
5+
---
6+
7+
fix:fix dataset update

gradio/blocks.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1723,12 +1723,12 @@ async def postprocess_data(
17231723
) from err
17241724

17251725
if block.stateful:
1726-
if not utils.is_update(predictions[i]):
1726+
if not utils.is_prop_update(predictions[i]):
17271727
state[block._id] = predictions[i]
17281728
output.append(None)
17291729
else:
17301730
prediction_value = predictions[i]
1731-
if utils.is_update(
1731+
if utils.is_prop_update(
17321732
prediction_value
17331733
): # if update is passed directly (deprecated), remove Nones
17341734
prediction_value = utils.delete_none(
@@ -1738,7 +1738,7 @@ async def postprocess_data(
17381738
if isinstance(prediction_value, Block):
17391739
prediction_value = prediction_value.constructor_args.copy()
17401740
prediction_value["__type__"] = "update"
1741-
if utils.is_update(prediction_value):
1741+
if utils.is_prop_update(prediction_value):
17421742
kwargs = state[block._id].constructor_args.copy()
17431743
kwargs.update(prediction_value)
17441744
kwargs.pop("value", None)

gradio/components/dataset.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import warnings
56
from typing import Any, Literal
67

78
from gradio_client.documentation import document
@@ -17,7 +18,8 @@
1718
@document()
1819
class Dataset(Component):
1920
"""
20-
Creates a gallery or table to display data samples. This component is designed for internal use to display examples.
21+
Creates a gallery or table to display data samples. This component is primarily designed for internal use to display examples.
22+
However, it can also be used directly to display a dataset and let users select examples.
2123
"""
2224

2325
EVENTS = [Events.click, Events.select]
@@ -26,7 +28,7 @@ def __init__(
2628
self,
2729
*,
2830
label: str | None = None,
29-
components: list[Component] | list[str],
31+
components: list[Component] | list[str] | None = None,
3032
component_props: list[dict[str, Any]] | None = None,
3133
samples: list[list[Any]] | None = None,
3234
headers: list[str] | None = None,
@@ -70,7 +72,7 @@ def __init__(
7072
self.container = container
7173
self.scale = scale
7274
self.min_width = min_width
73-
self._components = [get_component_instance(c) for c in components]
75+
self._components = [get_component_instance(c) for c in components or []]
7476
if component_props is None:
7577
self.component_props = [
7678
component.recover_kwargs(
@@ -131,29 +133,39 @@ def get_config(self):
131133

132134
return config
133135

134-
def preprocess(self, payload: int) -> int | list | None:
136+
def preprocess(self, payload: int | None) -> int | list | None:
135137
"""
136138
Parameters:
137139
payload: the index of the selected example in the dataset
138140
Returns:
139141
Passes the selected sample either as a `list` of data corresponding to each input component (if `type` is "value") or as an `int` index (if `type` is "index")
140142
"""
143+
if payload is None:
144+
return None
141145
if self.type == "index":
142146
return payload
143147
elif self.type == "values":
144148
return self.samples[payload]
145149

146-
def postprocess(self, samples: list[list]) -> dict:
150+
def postprocess(self, sample: int | list | None) -> int | None:
147151
"""
148152
Parameters:
149-
samples: Expects a `list[list]` corresponding to the dataset data, can be used to update the dataset.
153+
sample: Expects an `int` index or `list` of sample data. Returns the index of the sample in the dataset or `None` if the sample is not found.
150154
Returns:
151-
Returns the updated dataset data as a `dict` with the key "samples".
155+
Returns the index of the sample in the dataset.
152156
"""
153-
return {
154-
"samples": samples,
155-
"__type__": "update",
156-
}
157+
if sample is None or isinstance(sample, int):
158+
return sample
159+
if isinstance(sample, list):
160+
try:
161+
index = self.samples.index(sample)
162+
except ValueError:
163+
index = None
164+
warnings.warn(
165+
"The `Dataset` component does not support updating the dataset data by providing "
166+
"a set of list values. Instead, you should return a new Dataset(samples=...) object."
167+
)
168+
return index
157169

158170
def example_payload(self) -> Any:
159171
return 0

gradio/flagging.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def flag(
164164
) / client_utils.strip_invalid_filename_characters(
165165
getattr(component, "label", None) or f"component {idx}"
166166
)
167-
if utils.is_update(sample):
167+
if utils.is_prop_update(sample):
168168
csv_data.append(str(sample))
169169
else:
170170
data = (

gradio/helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,7 @@ def load_from_cache(self, example_id: int) -> list[Any]:
544544
component, components.File
545545
):
546546
value_to_use = value_as_dict
547-
if not utils.is_update(value_as_dict):
547+
if not utils.is_prop_update(value_as_dict):
548548
raise TypeError("value wasn't an update") # caught below
549549
output.append(value_as_dict)
550550
except (ValueError, TypeError, SyntaxError):

gradio/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -737,7 +737,7 @@ def validate_url(possible_url: str) -> bool:
737737
return False
738738

739739

740-
def is_update(val):
740+
def is_prop_update(val):
741741
return isinstance(val, dict) and "update" in val.get("__type__", "")
742742

743743

js/_website/src/lib/templates/gradio/03_components/dataset.svx

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,40 @@ def predict(···) -> list[list]
8686
<DemosSection demos={obj.demos} />
8787
{/if}
8888

89+
### Examples
90+
91+
**Updating a Dataset**
92+
93+
In this example, we display a text dataset using `gr.Dataset` and then update it when the user clicks a button:
94+
95+
```py
96+
import gradio as gr
97+
98+
philosophy_quotes = [
99+
["I think therefore I am."],
100+
["The unexamined life is not worth living."]
101+
]
102+
103+
startup_quotes = [
104+
["Ideas are easy. Implementation is hard"],
105+
["Make mistakes faster."]
106+
]
107+
108+
def show_startup_quotes():
109+
return gr.Dataset(samples=startup_quotes)
110+
111+
with gr.Blocks() as demo:
112+
textbox = gr.Textbox()
113+
dataset = gr.Dataset(components=[textbox], samples=philosophy_quotes)
114+
button = gr.Button()
115+
116+
button.click(show_startup_quotes, None, dataset)
117+
118+
demo.launch()
119+
```
120+
121+
122+
89123
{#if obj.fns && obj.fns.length > 0}
90124
<!--- Event Listeners -->
91125
### Event Listeners
@@ -97,3 +131,4 @@ def predict(···) -> list[list]
97131
### Guides
98132
<GuidesSection guides={obj.guides}/>
99133
{/if}
134+

js/dataset/Index.svelte

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
>;
1313
export let label = "Examples";
1414
export let headers: string[];
15-
export let samples: any[][];
15+
export let samples: any[][] | null = null;
1616
export let elem_id = "";
1717
export let elem_classes: string[] = [];
1818
export let visible = true;
@@ -34,7 +34,7 @@
3434
: `${root}/file=`;
3535
let page = 0;
3636
$: gallery = components.length < 2;
37-
let paginate = samples.length > samples_per_page;
37+
let paginate = samples ? samples.length > samples_per_page : false;
3838
3939
let selected_samples: any[][];
4040
let page_count: number;
@@ -51,6 +51,7 @@
5151
}
5252
5353
$: {
54+
samples = samples || [];
5455
paginate = samples.length > samples_per_page;
5556
if (paginate) {
5657
visible_pages = [];

test/components/test_dataset.py

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -43,27 +43,10 @@ def test_preprocessing(self):
4343
assert dataset.samples == [["value 1"], ["value 2"]]
4444

4545
def test_postprocessing(self):
46-
test_file_dir = Path(Path(__file__).parent, "test_files")
47-
bus = Path(test_file_dir, "bus.png")
48-
4946
dataset = gr.Dataset(
5047
components=["number", "textbox", "image", "html", "markdown"], type="index"
5148
)
52-
53-
output = dataset.postprocess(
54-
samples=[
55-
[5, "hello", bus, "<b>Bold</b>", "**Bold**"],
56-
[15, "hi", bus, "<i>Italics</i>", "*Italics*"],
57-
],
58-
)
59-
60-
assert output == {
61-
"samples": [
62-
[5, "hello", bus, "<b>Bold</b>", "**Bold**"],
63-
[15, "hi", bus, "<i>Italics</i>", "*Italics*"],
64-
],
65-
"__type__": "update",
66-
}
49+
assert dataset.postprocess(1) == 1
6750

6851

6952
@patch(

test/test_blocks.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -732,6 +732,33 @@ def infer(a, b):
732732
):
733733
await demo.postprocess_data(demo.fns[0], predictions=(1, 2), state=None)
734734

735+
@pytest.mark.asyncio
736+
async def test_dataset_is_updated(self):
737+
def update(value):
738+
return value, gr.Dataset(samples=[["New A"], ["New B"]])
739+
740+
with gr.Blocks() as demo:
741+
with gr.Row():
742+
textbox = gr.Textbox()
743+
dataset = gr.Dataset(
744+
components=["text"], samples=[["Original"]], label="Saved Prompts"
745+
)
746+
dataset.click(update, inputs=[dataset], outputs=[textbox, dataset])
747+
app, _, _ = demo.launch(prevent_thread_lock=True)
748+
749+
client = TestClient(app)
750+
751+
session_1 = client.post(
752+
"/api/predict/",
753+
json={"data": [0], "session_hash": "1", "fn_index": 0},
754+
)
755+
assert "Original" in session_1.json()["data"][0]
756+
session_2 = client.post(
757+
"/api/predict/",
758+
json={"data": [0], "session_hash": "1", "fn_index": 0},
759+
)
760+
assert "New" in session_2.json()["data"][0]
761+
735762

736763
class TestStateHolder:
737764
@pytest.mark.asyncio

0 commit comments

Comments
 (0)