diff --git a/.changeset/soft-worms-remain.md b/.changeset/soft-worms-remain.md
new file mode 100644
index 0000000000..5e71e60a1c
--- /dev/null
+++ b/.changeset/soft-worms-remain.md
@@ -0,0 +1,7 @@
+---
+"@gradio/dataset": patch
+"gradio": patch
+"website": patch
+---
+
+fix:fix dataset update
diff --git a/gradio/blocks.py b/gradio/blocks.py
index b64a4a224a..145f44cd37 100644
--- a/gradio/blocks.py
+++ b/gradio/blocks.py
@@ -1723,12 +1723,12 @@ async def postprocess_data(
) from err
if block.stateful:
- if not utils.is_update(predictions[i]):
+ if not utils.is_prop_update(predictions[i]):
state[block._id] = predictions[i]
output.append(None)
else:
prediction_value = predictions[i]
- if utils.is_update(
+ if utils.is_prop_update(
prediction_value
): # if update is passed directly (deprecated), remove Nones
prediction_value = utils.delete_none(
@@ -1738,7 +1738,7 @@ async def postprocess_data(
if isinstance(prediction_value, Block):
prediction_value = prediction_value.constructor_args.copy()
prediction_value["__type__"] = "update"
- if utils.is_update(prediction_value):
+ if utils.is_prop_update(prediction_value):
kwargs = state[block._id].constructor_args.copy()
kwargs.update(prediction_value)
kwargs.pop("value", None)
diff --git a/gradio/components/dataset.py b/gradio/components/dataset.py
index 77628f8720..982bd1f139 100644
--- a/gradio/components/dataset.py
+++ b/gradio/components/dataset.py
@@ -2,6 +2,7 @@
from __future__ import annotations
+import warnings
from typing import Any, Literal
from gradio_client.documentation import document
@@ -17,7 +18,8 @@
@document()
class Dataset(Component):
"""
- Creates a gallery or table to display data samples. This component is designed for internal use to display examples.
+ Creates a gallery or table to display data samples. This component is primarily designed for internal use to display examples.
+ However, it can also be used directly to display a dataset and let users select examples.
"""
EVENTS = [Events.click, Events.select]
@@ -26,7 +28,7 @@ def __init__(
self,
*,
label: str | None = None,
- components: list[Component] | list[str],
+ components: list[Component] | list[str] | None = None,
component_props: list[dict[str, Any]] | None = None,
samples: list[list[Any]] | None = None,
headers: list[str] | None = None,
@@ -70,7 +72,7 @@ def __init__(
self.container = container
self.scale = scale
self.min_width = min_width
- self._components = [get_component_instance(c) for c in components]
+ self._components = [get_component_instance(c) for c in components or []]
if component_props is None:
self.component_props = [
component.recover_kwargs(
@@ -131,29 +133,39 @@ def get_config(self):
return config
- def preprocess(self, payload: int) -> int | list | None:
+ def preprocess(self, payload: int | None) -> int | list | None:
"""
Parameters:
payload: the index of the selected example in the dataset
Returns:
Passes the selected sample either as a `list` of data corresponding to each input component (if `type` is "value") or as an `int` index (if `type` is "index")
"""
+ if payload is None:
+ return None
if self.type == "index":
return payload
elif self.type == "values":
return self.samples[payload]
- def postprocess(self, samples: list[list]) -> dict:
+ def postprocess(self, sample: int | list | None) -> int | None:
"""
Parameters:
- samples: Expects a `list[list]` corresponding to the dataset data, can be used to update the dataset.
+ sample: Expects an `int` index or `list` of sample data. Returns the index of the sample in the dataset or `None` if the sample is not found.
Returns:
- Returns the updated dataset data as a `dict` with the key "samples".
+ Returns the index of the sample in the dataset.
"""
- return {
- "samples": samples,
- "__type__": "update",
- }
+ if sample is None or isinstance(sample, int):
+ return sample
+ if isinstance(sample, list):
+ try:
+ index = self.samples.index(sample)
+ except ValueError:
+ index = None
+ warnings.warn(
+ "The `Dataset` component does not support updating the dataset data by providing "
+ "a set of list values. Instead, you should return a new Dataset(samples=...) object."
+ )
+ return index
def example_payload(self) -> Any:
return 0
diff --git a/gradio/flagging.py b/gradio/flagging.py
index c012835ca6..0bd164752b 100644
--- a/gradio/flagging.py
+++ b/gradio/flagging.py
@@ -164,7 +164,7 @@ def flag(
) / client_utils.strip_invalid_filename_characters(
getattr(component, "label", None) or f"component {idx}"
)
- if utils.is_update(sample):
+ if utils.is_prop_update(sample):
csv_data.append(str(sample))
else:
data = (
diff --git a/gradio/helpers.py b/gradio/helpers.py
index 0df85af994..0da9287b96 100644
--- a/gradio/helpers.py
+++ b/gradio/helpers.py
@@ -544,7 +544,7 @@ def load_from_cache(self, example_id: int) -> list[Any]:
component, components.File
):
value_to_use = value_as_dict
- if not utils.is_update(value_as_dict):
+ if not utils.is_prop_update(value_as_dict):
raise TypeError("value wasn't an update") # caught below
output.append(value_as_dict)
except (ValueError, TypeError, SyntaxError):
diff --git a/gradio/utils.py b/gradio/utils.py
index 76b882ff4d..eb2b25a770 100644
--- a/gradio/utils.py
+++ b/gradio/utils.py
@@ -737,7 +737,7 @@ def validate_url(possible_url: str) -> bool:
return False
-def is_update(val):
+def is_prop_update(val):
return isinstance(val, dict) and "update" in val.get("__type__", "")
diff --git a/js/_website/src/lib/templates/gradio/03_components/dataset.svx b/js/_website/src/lib/templates/gradio/03_components/dataset.svx
index c2f9f2ce9a..560785a8af 100644
--- a/js/_website/src/lib/templates/gradio/03_components/dataset.svx
+++ b/js/_website/src/lib/templates/gradio/03_components/dataset.svx
@@ -86,6 +86,40 @@ def predict(···) -> list[list]
{/if}
+### Examples
+
+**Updating a Dataset**
+
+In this example, we display a text dataset using `gr.Dataset` and then update it when the user clicks a button:
+
+```py
+import gradio as gr
+
+philosophy_quotes = [
+ ["I think therefore I am."],
+ ["The unexamined life is not worth living."]
+]
+
+startup_quotes = [
+ ["Ideas are easy. Implementation is hard"],
+ ["Make mistakes faster."]
+]
+
+def show_startup_quotes():
+ return gr.Dataset(samples=startup_quotes)
+
+with gr.Blocks() as demo:
+ textbox = gr.Textbox()
+ dataset = gr.Dataset(components=[textbox], samples=philosophy_quotes)
+ button = gr.Button()
+
+ button.click(show_startup_quotes, None, dataset)
+
+demo.launch()
+```
+
+
+
{#if obj.fns && obj.fns.length > 0}
### Event Listeners
@@ -97,3 +131,4 @@ def predict(···) -> list[list]
### Guides
{/if}
+
diff --git a/js/dataset/Index.svelte b/js/dataset/Index.svelte
index dab1c5969c..4b48b431b2 100644
--- a/js/dataset/Index.svelte
+++ b/js/dataset/Index.svelte
@@ -12,7 +12,7 @@
>;
export let label = "Examples";
export let headers: string[];
- export let samples: any[][];
+ export let samples: any[][] | null = null;
export let elem_id = "";
export let elem_classes: string[] = [];
export let visible = true;
@@ -34,7 +34,7 @@
: `${root}/file=`;
let page = 0;
$: gallery = components.length < 2;
- let paginate = samples.length > samples_per_page;
+ let paginate = samples ? samples.length > samples_per_page : false;
let selected_samples: any[][];
let page_count: number;
@@ -51,6 +51,7 @@
}
$: {
+ samples = samples || [];
paginate = samples.length > samples_per_page;
if (paginate) {
visible_pages = [];
diff --git a/test/components/test_dataset.py b/test/components/test_dataset.py
index 9030a4efbf..3101784777 100644
--- a/test/components/test_dataset.py
+++ b/test/components/test_dataset.py
@@ -43,27 +43,10 @@ def test_preprocessing(self):
assert dataset.samples == [["value 1"], ["value 2"]]
def test_postprocessing(self):
- test_file_dir = Path(Path(__file__).parent, "test_files")
- bus = Path(test_file_dir, "bus.png")
-
dataset = gr.Dataset(
components=["number", "textbox", "image", "html", "markdown"], type="index"
)
-
- output = dataset.postprocess(
- samples=[
- [5, "hello", bus, "Bold", "**Bold**"],
- [15, "hi", bus, "Italics", "*Italics*"],
- ],
- )
-
- assert output == {
- "samples": [
- [5, "hello", bus, "Bold", "**Bold**"],
- [15, "hi", bus, "Italics", "*Italics*"],
- ],
- "__type__": "update",
- }
+ assert dataset.postprocess(1) == 1
@patch(
diff --git a/test/test_blocks.py b/test/test_blocks.py
index 9b02df69a3..c74ef7a7f9 100644
--- a/test/test_blocks.py
+++ b/test/test_blocks.py
@@ -732,6 +732,33 @@ def infer(a, b):
):
await demo.postprocess_data(demo.fns[0], predictions=(1, 2), state=None)
+ @pytest.mark.asyncio
+ async def test_dataset_is_updated(self):
+ def update(value):
+ return value, gr.Dataset(samples=[["New A"], ["New B"]])
+
+ with gr.Blocks() as demo:
+ with gr.Row():
+ textbox = gr.Textbox()
+ dataset = gr.Dataset(
+ components=["text"], samples=[["Original"]], label="Saved Prompts"
+ )
+ dataset.click(update, inputs=[dataset], outputs=[textbox, dataset])
+ app, _, _ = demo.launch(prevent_thread_lock=True)
+
+ client = TestClient(app)
+
+ session_1 = client.post(
+ "/api/predict/",
+ json={"data": [0], "session_hash": "1", "fn_index": 0},
+ )
+ assert "Original" in session_1.json()["data"][0]
+ session_2 = client.post(
+ "/api/predict/",
+ json={"data": [0], "session_hash": "1", "fn_index": 0},
+ )
+ assert "New" in session_2.json()["data"][0]
+
class TestStateHolder:
@pytest.mark.asyncio