Skip to content

Commit 890eaa3

Browse files
Allow displaying SVG images securely in gr.Image and gr.Gallery components (#10269)
* changes * changes * add changeset * changes * add changeset * changes * changes * changes * add changeset * add changeset * add changeset * format fe * changes * changes * changes * revert * revert more * revert * add changeset * more changes * add changeset * changes * add changeset * format * add changeset * changes * changes * svg * changes * format * add changeset * fix tests --------- Co-authored-by: gradio-pr-bot <[email protected]>
1 parent 99123e7 commit 890eaa3

File tree

10 files changed

+156
-53
lines changed

10 files changed

+156
-53
lines changed

.changeset/eleven-suits-itch.md

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
---
2+
"@gradio/gallery": patch
3+
"@gradio/image": patch
4+
"gradio": patch
5+
---
6+
7+
fix:Allow displaying SVG images securely in `gr.Image` and `gr.Gallery` components

gradio/components/gallery.py

+20-10
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
Optional,
1313
Union,
1414
)
15-
from urllib.parse import urlparse
15+
from urllib.parse import quote, urlparse
1616

1717
import numpy as np
1818
import PIL.Image
@@ -21,9 +21,9 @@
2121
from gradio_client.documentation import document
2222
from gradio_client.utils import is_http_url_like
2323

24-
from gradio import processing_utils, utils, wasm_utils
24+
from gradio import image_utils, processing_utils, utils, wasm_utils
2525
from gradio.components.base import Component
26-
from gradio.data_classes import FileData, GradioModel, GradioRootModel
26+
from gradio.data_classes import FileData, GradioModel, GradioRootModel, ImageData
2727
from gradio.events import Events
2828
from gradio.exceptions import Error
2929

@@ -35,7 +35,7 @@
3535

3636

3737
class GalleryImage(GradioModel):
38-
image: FileData
38+
image: ImageData
3939
caption: Optional[str] = None
4040

4141

@@ -188,7 +188,7 @@ def preprocess(
188188
if isinstance(gallery_element, GalleryVideo):
189189
file_path = gallery_element.video.path
190190
else:
191-
file_path = gallery_element.image.path
191+
file_path = gallery_element.image.path or ""
192192
if self.file_types and not client_utils.is_valid_file(
193193
file_path, self.file_types
194194
):
@@ -216,6 +216,10 @@ def postprocess(
216216
"""
217217
if value is None:
218218
return GalleryData(root=[])
219+
if isinstance(value, str):
220+
raise ValueError(
221+
"The `value` passed into `gr.Gallery` must be a list of images or videos, or list of (media, caption) tuples."
222+
)
219223
output = []
220224

221225
def _save(img):
@@ -236,14 +240,20 @@ def _save(img):
236240
)
237241
file_path = str(utils.abspath(file))
238242
elif isinstance(img, str):
239-
file_path = img
240-
mime_type = client_utils.get_mimetype(file_path)
241-
if is_http_url_like(img):
243+
mime_type = client_utils.get_mimetype(img)
244+
if img.lower().endswith(".svg"):
245+
svg_content = image_utils.extract_svg_content(img)
246+
orig_name = Path(img).name
247+
url = f"data:image/svg+xml,{quote(svg_content)}"
248+
file_path = None
249+
elif is_http_url_like(img):
242250
url = img
243251
orig_name = Path(urlparse(img).path).name
252+
file_path = img
244253
else:
245254
url = None
246255
orig_name = Path(img).name
256+
file_path = img
247257
elif isinstance(img, Path):
248258
file_path = str(img)
249259
orig_name = img.name
@@ -253,7 +263,7 @@ def _save(img):
253263
if mime_type is not None and "video" in mime_type:
254264
return GalleryVideo(
255265
video=FileData(
256-
path=file_path,
266+
path=file_path, # type: ignore
257267
url=url,
258268
orig_name=orig_name,
259269
mime_type=mime_type,
@@ -262,7 +272,7 @@ def _save(img):
262272
)
263273
else:
264274
return GalleryImage(
265-
image=FileData(
275+
image=ImageData(
266276
path=file_path,
267277
url=url,
268278
orig_name=orig_name,

gradio/components/image.py

+11-29
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,18 @@
55
import warnings
66
from collections.abc import Callable, Sequence
77
from pathlib import Path
8-
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
8+
from typing import TYPE_CHECKING, Any, Literal, cast
9+
from urllib.parse import quote
910

1011
import numpy as np
1112
import PIL.Image
1213
from gradio_client import handle_file
1314
from gradio_client.documentation import document
1415
from PIL import ImageOps
15-
from pydantic import ConfigDict, Field
1616

1717
from gradio import image_utils, utils
1818
from gradio.components.base import Component, StreamingInput
19-
from gradio.data_classes import GradioModel
19+
from gradio.data_classes import Base64ImageData, ImageData
2020
from gradio.events import Events
2121
from gradio.exceptions import Error
2222

@@ -26,28 +26,6 @@
2626
PIL.Image.init() # fixes https://github.com/gradio-app/gradio/issues/2843
2727

2828

29-
class ImageData(GradioModel):
30-
path: Optional[str] = Field(default=None, description="Path to a local file")
31-
url: Optional[str] = Field(
32-
default=None, description="Publicly available url or base64 encoded image"
33-
)
34-
size: Optional[int] = Field(default=None, description="Size of image in bytes")
35-
orig_name: Optional[str] = Field(default=None, description="Original filename")
36-
mime_type: Optional[str] = Field(default=None, description="mime type of image")
37-
is_stream: bool = Field(default=False, description="Can always be set to False")
38-
meta: dict = {"_type": "gradio.FileData"}
39-
40-
model_config = ConfigDict(
41-
json_schema_extra={
42-
"description": "For input, either path or url must be provided. For output, path is always provided."
43-
}
44-
)
45-
46-
47-
class Base64ImageData(GradioModel):
48-
url: str = Field(description="base64 encoded image")
49-
50-
5129
@document()
5230
class Image(StreamingInput, Component):
5331
"""
@@ -112,7 +90,7 @@ def __init__(
11290
width: The width of the component, specified in pixels if a number is passed, or in CSS units if a string is passed. This has no effect on the preprocessed image file or numpy array, but will affect the displayed image.
11391
image_mode: The pixel format and color depth that the image should be loaded and preprocessed as. "RGB" will load the image as a color image, or "L" as black-and-white. See https://pillow.readthedocs.io/en/stable/handbook/concepts.html for other supported image modes and their meaning. This parameter has no effect on SVG or GIF files. If set to None, the image_mode will be inferred from the image file type (e.g. "RGBA" for a .png image, "RGB" in most other cases).
11492
sources: List of sources for the image. "upload" creates a box where user can drop an image file, "webcam" allows user to take snapshot from their webcam, "clipboard" allows users to paste an image from the clipboard. If None, defaults to ["upload", "webcam", "clipboard"] if streaming is False, otherwise defaults to ["webcam"].
115-
type: The format the image is converted before being passed into the prediction function. "numpy" converts the image to a numpy array with shape (height, width, 3) and values from 0 to 255, "pil" converts the image to a PIL image object, "filepath" passes a str path to a temporary file containing the image. If the image is SVG, the `type` is ignored and the filepath of the SVG is returned. To support animated GIFs in input, the `type` should be set to "filepath" or "pil".
93+
type: The format the image is converted before being passed into the prediction function. "numpy" converts the image to a numpy array with shape (height, width, 3) and values from 0 to 255, "pil" converts the image to a PIL image object, "filepath" passes a str path to a temporary file containing the image. To support animated GIFs in input, the `type` should be set to "filepath" or "pil". To support SVGs, the `type` should be set to "filepath".
11694
label: the label for this component. Appears above the component and is also used as the header if there are a table of examples for this component. If None and used in a `gr.Interface`, the label will be the name of the parameter this component is assigned to.
11795
every: Continously calls `value` to recalculate it if `value` is a function (has no effect otherwise). Can provide a Timer whose tick resets `value`, or a float that provides the regular interval for the reset Timer.
11896
inputs: Components that are used as inputs to calculate `value` if `value` is a function (has no effect otherwise). `value` is recalculated any time the inputs change.
@@ -198,7 +176,7 @@ def preprocess(
198176
Parameters:
199177
payload: image data in the form of a FileData object
200178
Returns:
201-
Passes the uploaded image as a `numpy.array`, `PIL.Image` or `str` filepath depending on `type`. For SVGs, the `type` parameter is ignored and the filepath of the SVG is returned.
179+
Passes the uploaded image as a `numpy.array`, `PIL.Image` or `str` filepath depending on `type`.
202180
"""
203181
if payload is None:
204182
return payload
@@ -227,7 +205,7 @@ def preprocess(
227205
if suffix.lower() == "svg":
228206
if self.type == "filepath":
229207
return str(file_path)
230-
raise Error("SVG files are not supported as input images.")
208+
raise Error("SVG files are not supported as input images for this app.")
231209

232210
im = PIL.Image.open(file_path)
233211
if self.type == "filepath" and (self.image_mode in [None, im.mode]):
@@ -267,7 +245,11 @@ def postprocess(
267245
if value is None:
268246
return None
269247
if isinstance(value, str) and value.lower().endswith(".svg"):
270-
return ImageData(path=value, orig_name=Path(value).name)
248+
svg_content = image_utils.extract_svg_content(value)
249+
return ImageData(
250+
orig_name=Path(value).name,
251+
url=f"data:image/svg+xml,{quote(svg_content)}",
252+
)
271253
if self.streaming:
272254
if isinstance(value, np.ndarray):
273255
return Base64ImageData(

gradio/data_classes.py

+24
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
from gradio_client.utils import is_file_obj_with_meta, traverse
2525
from pydantic import (
2626
BaseModel,
27+
ConfigDict,
28+
Field,
2729
GetCoreSchemaHandler,
2830
GetJsonSchemaHandler,
2931
RootModel,
@@ -391,3 +393,25 @@ class MediaStreamChunk(TypedDict):
391393
duration: float
392394
extension: str
393395
id: NotRequired[str]
396+
397+
398+
class ImageData(GradioModel):
399+
path: Optional[str] = Field(default=None, description="Path to a local file")
400+
url: Optional[str] = Field(
401+
default=None, description="Publicly available url or base64 encoded image"
402+
)
403+
size: Optional[int] = Field(default=None, description="Size of image in bytes")
404+
orig_name: Optional[str] = Field(default=None, description="Original filename")
405+
mime_type: Optional[str] = Field(default=None, description="mime type of image")
406+
is_stream: bool = Field(default=False, description="Can always be set to False")
407+
meta: dict = {"_type": "gradio.FileData"}
408+
409+
model_config = ConfigDict(
410+
json_schema_extra={
411+
"description": "For input, either path or url must be provided. For output, path is always provided."
412+
}
413+
)
414+
415+
416+
class Base64ImageData(GradioModel):
417+
url: str = Field(description="base64 encoded image")

gradio/image_utils.py

+21-1
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55
from pathlib import Path
66
from typing import Literal, cast
77

8+
import httpx
89
import numpy as np
910
import PIL.Image
10-
from gradio_client.utils import get_mimetype
11+
from gradio_client.utils import get_mimetype, is_http_url_like
1112
from PIL import ImageOps
1213

1314
from gradio import processing_utils
@@ -152,3 +153,22 @@ def encode_image_file_to_base64(image_file: str | Path) -> str:
152153
bytes_data = f.read()
153154
base64_str = str(base64.b64encode(bytes_data), "utf-8")
154155
return f"data:{mime_type};base64," + base64_str
156+
157+
158+
def extract_svg_content(image_file: str | Path) -> str:
159+
"""
160+
Provided a path or URL to an SVG file, return the SVG content as a string.
161+
Parameters:
162+
image_file: Local file path or URL to an SVG file
163+
Returns:
164+
str: The SVG content as a string
165+
"""
166+
image_file = str(image_file)
167+
if is_http_url_like(image_file):
168+
response = httpx.get(image_file)
169+
response.raise_for_status() # Raise an error for bad status codes
170+
return response.text
171+
else:
172+
with open(image_file) as file:
173+
svg_content = file.read()
174+
return svg_content

js/gallery/Index.svelte

+27-6
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
<script lang="ts">
66
import type { GalleryImage, GalleryVideo } from "./types";
7+
import type { FileData } from "@gradio/client";
78
import type { Gradio, ShareData, SelectData } from "@gradio/utils";
89
import { Block, UploadText } from "@gradio/atoms";
910
import Gallery from "./shared/Gallery.svelte";
@@ -52,6 +53,30 @@
5253
5354
$: no_value = value === null ? true : value.length === 0;
5455
$: selected_index, dispatch("prop_change", { selected_index });
56+
57+
async function process_upload_files(
58+
files: FileData[]
59+
): Promise<GalleryData[]> {
60+
const processed_files = await Promise.all(
61+
files.map(async (x) => {
62+
if (x.path?.toLowerCase().endsWith(".svg") && x.url) {
63+
const response = await fetch(x.url);
64+
const svgContent = await response.text();
65+
return {
66+
...x,
67+
url: `data:image/svg+xml,${encodeURIComponent(svgContent)}`
68+
};
69+
}
70+
return x;
71+
})
72+
);
73+
74+
return processed_files.map((x) =>
75+
x.mime_type?.includes("video")
76+
? { video: x, caption: null }
77+
: { image: x, caption: null }
78+
);
79+
}
5580
</script>
5681

5782
<Block
@@ -83,13 +108,9 @@
83108
i18n={gradio.i18n}
84109
upload={(...args) => gradio.client.upload(...args)}
85110
stream_handler={(...args) => gradio.client.stream(...args)}
86-
on:upload={(e) => {
111+
on:upload={async (e) => {
87112
const files = Array.isArray(e.detail) ? e.detail : [e.detail];
88-
value = files.map((x) =>
89-
x.mime_type?.includes("video")
90-
? { video: x, caption: null }
91-
: { image: x, caption: null }
92-
);
113+
value = await process_upload_files(files);
93114
gradio.dispatch("upload", value);
94115
}}
95116
on:error={({ detail }) => {

js/image/shared/ImageUploader.svelte

+13-3
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,20 @@
4545
4646
export let webcam_constraints: { [key: string]: any } | undefined = undefined;
4747
48-
function handle_upload({ detail }: CustomEvent<FileData>): void {
49-
// only trigger streaming event if streaming
48+
async function handle_upload({
49+
detail
50+
}: CustomEvent<FileData>): Promise<void> {
5051
if (!streaming) {
51-
value = detail;
52+
if (detail.path?.toLowerCase().endsWith(".svg") && detail.url) {
53+
const response = await fetch(detail.url);
54+
const svgContent = await response.text();
55+
value = {
56+
...detail,
57+
url: `data:image/svg+xml,${encodeURIComponent(svgContent)}`
58+
};
59+
} else {
60+
value = detail;
61+
}
5262
dispatch("upload");
5363
}
5464
}

test/components/test_gallery.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import gradio as gr
77
from gradio.components.gallery import GalleryImage
8-
from gradio.data_classes import FileData
8+
from gradio.data_classes import ImageData
99

1010

1111
class TestGallery:
@@ -96,7 +96,7 @@ def test_gallery_preprocess(self):
9696
from gradio.components.gallery import GalleryData, GalleryImage
9797

9898
gallery = gr.Gallery()
99-
img = GalleryImage(image=FileData(path="test/test_files/bus.png"))
99+
img = GalleryImage(image=ImageData(path="test/test_files/bus.png"))
100100
data = GalleryData(root=[img])
101101

102102
assert (preprocessed := gallery.preprocess(data))
@@ -115,7 +115,7 @@ def test_gallery_preprocess(self):
115115
)
116116

117117
img_captions = GalleryImage(
118-
image=FileData(path="test/test_files/bus.png"), caption="bus"
118+
image=ImageData(path="test/test_files/bus.png"), caption="bus"
119119
)
120120
data = GalleryData(root=[img_captions])
121121
assert (preprocess := gr.Gallery().preprocess(data))
@@ -127,4 +127,6 @@ def test_gallery_format(self):
127127
[np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)]
128128
)
129129
if isinstance(output.root[0], GalleryImage):
130-
assert output.root[0].image.path.endswith(".jpeg")
130+
assert output.root[0].image.path and output.root[0].image.path.endswith(
131+
".jpeg"
132+
)

test/test_files/file_icon.svg

+1
Loading

0 commit comments

Comments
 (0)