Skip to content

Commit a0aac66

Browse files
Adds strict_cors parameter to launch() (#8959)
* prevent null origin requests by default * changes * add changeset * format --------- Co-authored-by: gradio-pr-bot <[email protected]>
1 parent 51b7a8b commit a0aac66

File tree

5 files changed

+50
-7
lines changed

5 files changed

+50
-7
lines changed

.changeset/hungry-tips-sin.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"gradio": minor
3+
---
4+
5+
feat:Adds `strict_cors` parameter to `launch()`

gradio/blocks.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2181,6 +2181,7 @@ def launch(
21812181
max_file_size: str | int | None = None,
21822182
_frontend: bool = True,
21832183
enable_monitoring: bool = False,
2184+
strict_cors: bool = True,
21842185
) -> tuple[FastAPI, str, str]:
21852186
"""
21862187
Launches a simple web server that serves the demo. Can also be used to create a
@@ -2216,6 +2217,7 @@ def launch(
22162217
share_server_protocol: Use this to specify the protocol to use for the share links. Defaults to "https", unless a custom share_server_address is provided, in which case it defaults to "http". If you are using a custom share_server_address and want to use https, you must set this to "https".
22172218
auth_dependency: A function that takes a FastAPI request and returns a string user ID or None. If the function returns None for a specific request, that user is not authorized to access the app (they will see a 401 Unauthorized response). To be used with external authentication systems like OAuth. Cannot be used with `auth`.
22182219
max_file_size: The maximum file size in bytes that can be uploaded. Can be a string of the form "<value><unit>", where value is any positive integer and unit is one of "b", "kb", "mb", "gb", "tb". If None, no limit is set.
2220+
strict_cors: If True, prevents external domains from making requests to a Gradio server running on localhost. If False, allows requests to localhost that originate from localhost but also, crucially, from "null". This parameter should normally be True to prevent CSRF attacks but may need to be False when embedding a *locally-running Gradio app* using web components.
22192221
Returns:
22202222
app: FastAPI app object that is running the demo
22212223
local_url: Locally accessible link to the demo
@@ -2279,7 +2281,6 @@ def reverse(text):
22792281
self.root_path = os.environ.get("GRADIO_ROOT_PATH", "")
22802282
else:
22812283
self.root_path = root_path
2282-
22832284
self.show_api = show_api
22842285

22852286
if allowed_paths:
@@ -2322,7 +2323,10 @@ def reverse(text):
23222323
self._queue.max_thread_count = max_threads
23232324
# self.server_app is included for backwards compatibility
23242325
self.server_app = self.app = App.create_app(
2325-
self, auth_dependency=auth_dependency, app_kwargs=app_kwargs
2326+
self,
2327+
auth_dependency=auth_dependency,
2328+
app_kwargs=app_kwargs,
2329+
strict_cors=strict_cors,
23262330
)
23272331

23282332
if self.is_running:

gradio/route_utils.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -701,6 +701,7 @@ class CustomCORSMiddleware:
701701
def __init__(
702702
self,
703703
app: ASGIApp,
704+
strict_cors: bool = True,
704705
) -> None:
705706
self.app = app
706707
self.all_methods = ("DELETE", "GET", "HEAD", "OPTIONS", "PATCH", "POST", "PUT")
@@ -711,9 +712,12 @@ def __init__(
711712
}
712713
self.simple_headers = {"Access-Control-Allow-Credentials": "true"}
713714
# Any of these hosts suggests that the Gradio app is running locally.
714-
# Note: "null" is a special case that happens if a Gradio app is running
715-
# as an embedded web component in a local static webpage.
716-
self.localhost_aliases = ["localhost", "127.0.0.1", "0.0.0.0", "null"]
715+
self.localhost_aliases = ["localhost", "127.0.0.1", "0.0.0.0"]
716+
if not strict_cors: # type: ignore
717+
# Note: "null" is a special case that happens if a Gradio app is running
718+
# as an embedded web component in a local static webpage. However, it can
719+
# also be used maliciously for CSRF attacks, so it is not allowed by default.
720+
self.localhost_aliases.append("null")
717721

718722
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
719723
if scope["type"] != "http":

gradio/routes.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,7 @@ def create_app(
240240
blocks: gradio.Blocks,
241241
app_kwargs: Dict[str, Any] | None = None,
242242
auth_dependency: Callable[[fastapi.Request], str | None] | None = None,
243+
strict_cors: bool = True,
243244
) -> App:
244245
app_kwargs = app_kwargs or {}
245246
app_kwargs.setdefault("default_response_class", ORJSONResponse)
@@ -251,7 +252,7 @@ def create_app(
251252
app.configure_app(blocks)
252253

253254
if not wasm_utils.IS_WASM:
254-
app.add_middleware(CustomCORSMiddleware)
255+
app.add_middleware(CustomCORSMiddleware, strict_cors=strict_cors)
255256

256257
@app.get("/user")
257258
@app.get("/user/")

test/test_routes.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -508,7 +508,7 @@ def test_can_get_config_that_includes_non_pickle_able_objects(self):
508508
response = client.get("/config/")
509509
assert response.is_success
510510

511-
def test_cors_restrictions(self):
511+
def test_default_cors_restrictions(self):
512512
io = gr.Interface(lambda s: s.name, gr.File(), gr.File())
513513
app, _, _ = io.launch(prevent_thread_lock=True)
514514
client = TestClient(app)
@@ -518,12 +518,41 @@ def test_cors_restrictions(self):
518518
}
519519
file_response = client.get("/config", headers=custom_headers)
520520
assert "access-control-allow-origin" not in file_response.headers
521+
522+
custom_headers = {
523+
"host": "localhost:7860",
524+
"origin": "null",
525+
}
526+
file_response = client.get("/config", headers=custom_headers)
527+
assert "access-control-allow-origin" not in file_response.headers
528+
521529
custom_headers = {
522530
"host": "localhost:7860",
523531
"origin": "127.0.0.1",
524532
}
525533
file_response = client.get("/config", headers=custom_headers)
526534
assert file_response.headers["access-control-allow-origin"] == "127.0.0.1"
535+
536+
io.close()
537+
538+
def test_loose_cors_restrictions(self):
539+
io = gr.Interface(lambda s: s.name, gr.File(), gr.File())
540+
app, _, _ = io.launch(prevent_thread_lock=True, strict_cors=False)
541+
client = TestClient(app)
542+
custom_headers = {
543+
"host": "localhost:7860",
544+
"origin": "https://example.com",
545+
}
546+
file_response = client.get("/config", headers=custom_headers)
547+
assert "access-control-allow-origin" not in file_response.headers
548+
549+
custom_headers = {
550+
"host": "localhost:7860",
551+
"origin": "null",
552+
}
553+
file_response = client.get("/config", headers=custom_headers)
554+
assert file_response.headers["access-control-allow-origin"] == "null"
555+
527556
io.close()
528557

529558
def test_delete_cache(self, connect, gradio_temp_dir, capsys):

0 commit comments

Comments
 (0)