diff --git a/.changeset/soft-worms-remain.md b/.changeset/soft-worms-remain.md new file mode 100644 index 0000000000..5e71e60a1c --- /dev/null +++ b/.changeset/soft-worms-remain.md @@ -0,0 +1,7 @@ +--- +"@gradio/dataset": patch +"gradio": patch +"website": patch +--- + +fix:fix dataset update diff --git a/gradio/blocks.py b/gradio/blocks.py index b64a4a224a..145f44cd37 100644 --- a/gradio/blocks.py +++ b/gradio/blocks.py @@ -1723,12 +1723,12 @@ async def postprocess_data( ) from err if block.stateful: - if not utils.is_update(predictions[i]): + if not utils.is_prop_update(predictions[i]): state[block._id] = predictions[i] output.append(None) else: prediction_value = predictions[i] - if utils.is_update( + if utils.is_prop_update( prediction_value ): # if update is passed directly (deprecated), remove Nones prediction_value = utils.delete_none( @@ -1738,7 +1738,7 @@ async def postprocess_data( if isinstance(prediction_value, Block): prediction_value = prediction_value.constructor_args.copy() prediction_value["__type__"] = "update" - if utils.is_update(prediction_value): + if utils.is_prop_update(prediction_value): kwargs = state[block._id].constructor_args.copy() kwargs.update(prediction_value) kwargs.pop("value", None) diff --git a/gradio/components/dataset.py b/gradio/components/dataset.py index 77628f8720..982bd1f139 100644 --- a/gradio/components/dataset.py +++ b/gradio/components/dataset.py @@ -2,6 +2,7 @@ from __future__ import annotations +import warnings from typing import Any, Literal from gradio_client.documentation import document @@ -17,7 +18,8 @@ @document() class Dataset(Component): """ - Creates a gallery or table to display data samples. This component is designed for internal use to display examples. + Creates a gallery or table to display data samples. This component is primarily designed for internal use to display examples. + However, it can also be used directly to display a dataset and let users select examples. """ EVENTS = [Events.click, Events.select] @@ -26,7 +28,7 @@ def __init__( self, *, label: str | None = None, - components: list[Component] | list[str], + components: list[Component] | list[str] | None = None, component_props: list[dict[str, Any]] | None = None, samples: list[list[Any]] | None = None, headers: list[str] | None = None, @@ -70,7 +72,7 @@ def __init__( self.container = container self.scale = scale self.min_width = min_width - self._components = [get_component_instance(c) for c in components] + self._components = [get_component_instance(c) for c in components or []] if component_props is None: self.component_props = [ component.recover_kwargs( @@ -131,29 +133,39 @@ def get_config(self): return config - def preprocess(self, payload: int) -> int | list | None: + def preprocess(self, payload: int | None) -> int | list | None: """ Parameters: payload: the index of the selected example in the dataset Returns: Passes the selected sample either as a `list` of data corresponding to each input component (if `type` is "value") or as an `int` index (if `type` is "index") """ + if payload is None: + return None if self.type == "index": return payload elif self.type == "values": return self.samples[payload] - def postprocess(self, samples: list[list]) -> dict: + def postprocess(self, sample: int | list | None) -> int | None: """ Parameters: - samples: Expects a `list[list]` corresponding to the dataset data, can be used to update the dataset. + sample: Expects an `int` index or `list` of sample data. Returns the index of the sample in the dataset or `None` if the sample is not found. Returns: - Returns the updated dataset data as a `dict` with the key "samples". + Returns the index of the sample in the dataset. """ - return { - "samples": samples, - "__type__": "update", - } + if sample is None or isinstance(sample, int): + return sample + if isinstance(sample, list): + try: + index = self.samples.index(sample) + except ValueError: + index = None + warnings.warn( + "The `Dataset` component does not support updating the dataset data by providing " + "a set of list values. Instead, you should return a new Dataset(samples=...) object." + ) + return index def example_payload(self) -> Any: return 0 diff --git a/gradio/flagging.py b/gradio/flagging.py index c012835ca6..0bd164752b 100644 --- a/gradio/flagging.py +++ b/gradio/flagging.py @@ -164,7 +164,7 @@ def flag( ) / client_utils.strip_invalid_filename_characters( getattr(component, "label", None) or f"component {idx}" ) - if utils.is_update(sample): + if utils.is_prop_update(sample): csv_data.append(str(sample)) else: data = ( diff --git a/gradio/helpers.py b/gradio/helpers.py index 0df85af994..0da9287b96 100644 --- a/gradio/helpers.py +++ b/gradio/helpers.py @@ -544,7 +544,7 @@ def load_from_cache(self, example_id: int) -> list[Any]: component, components.File ): value_to_use = value_as_dict - if not utils.is_update(value_as_dict): + if not utils.is_prop_update(value_as_dict): raise TypeError("value wasn't an update") # caught below output.append(value_as_dict) except (ValueError, TypeError, SyntaxError): diff --git a/gradio/utils.py b/gradio/utils.py index 76b882ff4d..eb2b25a770 100644 --- a/gradio/utils.py +++ b/gradio/utils.py @@ -737,7 +737,7 @@ def validate_url(possible_url: str) -> bool: return False -def is_update(val): +def is_prop_update(val): return isinstance(val, dict) and "update" in val.get("__type__", "") diff --git a/js/_website/src/lib/templates/gradio/03_components/dataset.svx b/js/_website/src/lib/templates/gradio/03_components/dataset.svx index c2f9f2ce9a..560785a8af 100644 --- a/js/_website/src/lib/templates/gradio/03_components/dataset.svx +++ b/js/_website/src/lib/templates/gradio/03_components/dataset.svx @@ -86,6 +86,40 @@ def predict(···) -> list[list] {/if} +### Examples + +**Updating a Dataset** + +In this example, we display a text dataset using `gr.Dataset` and then update it when the user clicks a button: + +```py +import gradio as gr + +philosophy_quotes = [ + ["I think therefore I am."], + ["The unexamined life is not worth living."] +] + +startup_quotes = [ + ["Ideas are easy. Implementation is hard"], + ["Make mistakes faster."] +] + +def show_startup_quotes(): + return gr.Dataset(samples=startup_quotes) + +with gr.Blocks() as demo: + textbox = gr.Textbox() + dataset = gr.Dataset(components=[textbox], samples=philosophy_quotes) + button = gr.Button() + + button.click(show_startup_quotes, None, dataset) + +demo.launch() +``` + + + {#if obj.fns && obj.fns.length > 0} ### Event Listeners @@ -97,3 +131,4 @@ def predict(···) -> list[list] ### Guides {/if} + diff --git a/js/dataset/Index.svelte b/js/dataset/Index.svelte index dab1c5969c..4b48b431b2 100644 --- a/js/dataset/Index.svelte +++ b/js/dataset/Index.svelte @@ -12,7 +12,7 @@ >; export let label = "Examples"; export let headers: string[]; - export let samples: any[][]; + export let samples: any[][] | null = null; export let elem_id = ""; export let elem_classes: string[] = []; export let visible = true; @@ -34,7 +34,7 @@ : `${root}/file=`; let page = 0; $: gallery = components.length < 2; - let paginate = samples.length > samples_per_page; + let paginate = samples ? samples.length > samples_per_page : false; let selected_samples: any[][]; let page_count: number; @@ -51,6 +51,7 @@ } $: { + samples = samples || []; paginate = samples.length > samples_per_page; if (paginate) { visible_pages = []; diff --git a/test/components/test_dataset.py b/test/components/test_dataset.py index 9030a4efbf..3101784777 100644 --- a/test/components/test_dataset.py +++ b/test/components/test_dataset.py @@ -43,27 +43,10 @@ def test_preprocessing(self): assert dataset.samples == [["value 1"], ["value 2"]] def test_postprocessing(self): - test_file_dir = Path(Path(__file__).parent, "test_files") - bus = Path(test_file_dir, "bus.png") - dataset = gr.Dataset( components=["number", "textbox", "image", "html", "markdown"], type="index" ) - - output = dataset.postprocess( - samples=[ - [5, "hello", bus, "Bold", "**Bold**"], - [15, "hi", bus, "Italics", "*Italics*"], - ], - ) - - assert output == { - "samples": [ - [5, "hello", bus, "Bold", "**Bold**"], - [15, "hi", bus, "Italics", "*Italics*"], - ], - "__type__": "update", - } + assert dataset.postprocess(1) == 1 @patch( diff --git a/test/test_blocks.py b/test/test_blocks.py index 9b02df69a3..c74ef7a7f9 100644 --- a/test/test_blocks.py +++ b/test/test_blocks.py @@ -732,6 +732,33 @@ def infer(a, b): ): await demo.postprocess_data(demo.fns[0], predictions=(1, 2), state=None) + @pytest.mark.asyncio + async def test_dataset_is_updated(self): + def update(value): + return value, gr.Dataset(samples=[["New A"], ["New B"]]) + + with gr.Blocks() as demo: + with gr.Row(): + textbox = gr.Textbox() + dataset = gr.Dataset( + components=["text"], samples=[["Original"]], label="Saved Prompts" + ) + dataset.click(update, inputs=[dataset], outputs=[textbox, dataset]) + app, _, _ = demo.launch(prevent_thread_lock=True) + + client = TestClient(app) + + session_1 = client.post( + "/api/predict/", + json={"data": [0], "session_hash": "1", "fn_index": 0}, + ) + assert "Original" in session_1.json()["data"][0] + session_2 = client.post( + "/api/predict/", + json={"data": [0], "session_hash": "1", "fn_index": 0}, + ) + assert "New" in session_2.json()["data"][0] + class TestStateHolder: @pytest.mark.asyncio