Skip to content

Commit 99b6938

Browse files
authored
Allow to raise HTTPException before websocket.accept() (#2725)
* Allow to raise `HTTPException` before `websocket.accept()` * move << * Add documentation
1 parent 4ded4b7 commit 99b6938

File tree

5 files changed

+53
-54
lines changed

5 files changed

+53
-54
lines changed

docs/exceptions.md

+21-5
Original file line numberDiff line numberDiff line change
@@ -115,14 +115,30 @@ In order to deal with this behaviour correctly, the middleware stack of a
115115

116116
## HTTPException
117117

118-
The `HTTPException` class provides a base class that you can use for any
119-
handled exceptions. The `ExceptionMiddleware` implementation defaults to
120-
returning plain-text HTTP responses for any `HTTPException`.
118+
The `HTTPException` class provides a base class that you can use for any handled exceptions.
119+
The `ExceptionMiddleware` implementation defaults to returning plain-text HTTP responses for any `HTTPException`.
121120

122121
* `HTTPException(status_code, detail=None, headers=None)`
123122

124-
You should only raise `HTTPException` inside routing or endpoints. Middleware
125-
classes should instead just return appropriate responses directly.
123+
You should only raise `HTTPException` inside routing or endpoints.
124+
Middleware classes should instead just return appropriate responses directly.
125+
126+
You can use an `HTTPException` on a WebSocket endpoint in case it's raised before `websocket.accept()`.
127+
The connection is not upgraded to a WebSocket connection, and the proper HTTP response is returned.
128+
129+
```python
130+
from starlette.applications import Starlette
131+
from starlette.exceptions import HTTPException
132+
from starlette.routing import WebSocketRoute
133+
from starlette.websockets import WebSocket
134+
135+
136+
async def websocket_endpoint(websocket: WebSocket):
137+
raise HTTPException(status_code=400, detail="Bad request")
138+
139+
140+
app = Starlette(routes=[WebSocketRoute("/ws", websocket_endpoint)])
141+
```
126142

127143
## WebSocketException
128144

starlette/_exception_handler.py

+8-28
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,7 @@
66
from starlette.concurrency import run_in_threadpool
77
from starlette.exceptions import HTTPException
88
from starlette.requests import Request
9-
from starlette.types import (
10-
ASGIApp,
11-
ExceptionHandler,
12-
HTTPExceptionHandler,
13-
Message,
14-
Receive,
15-
Scope,
16-
Send,
17-
WebSocketExceptionHandler,
18-
)
9+
from starlette.types import ASGIApp, ExceptionHandler, Message, Receive, Scope, Send
1910
from starlette.websockets import WebSocket
2011

2112
ExceptionHandlers = typing.Dict[typing.Any, ExceptionHandler]
@@ -62,24 +53,13 @@ async def sender(message: Message) -> None:
6253
raise exc
6354

6455
if response_started:
65-
msg = "Caught handled exception, but response already started."
66-
raise RuntimeError(msg) from exc
67-
68-
if scope["type"] == "http":
69-
nonlocal conn
70-
handler = typing.cast(HTTPExceptionHandler, handler)
71-
conn = typing.cast(Request, conn)
72-
if is_async_callable(handler):
73-
response = await handler(conn, exc)
74-
else:
75-
response = await run_in_threadpool(handler, conn, exc)
56+
raise RuntimeError("Caught handled exception, but response already started.") from exc
57+
58+
if is_async_callable(handler):
59+
response = await handler(conn, exc)
60+
else:
61+
response = await run_in_threadpool(handler, conn, exc) # type: ignore
62+
if response is not None:
7663
await response(scope, receive, sender)
77-
elif scope["type"] == "websocket":
78-
handler = typing.cast(WebSocketExceptionHandler, handler)
79-
conn = typing.cast(WebSocket, conn)
80-
if is_async_callable(handler):
81-
await handler(conn, exc)
82-
else:
83-
await run_in_threadpool(handler, conn, exc)
8464

8565
return wrapped_app

starlette/testclient.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -178,11 +178,7 @@ def _raise_on_close(self, message: Message) -> None:
178178
body.append(message["body"])
179179
if not message.get("more_body", False):
180180
break
181-
raise WebSocketDenialResponse(
182-
status_code=status_code,
183-
headers=headers,
184-
content=b"".join(body),
185-
)
181+
raise WebSocketDenialResponse(status_code=status_code, headers=headers, content=b"".join(body))
186182

187183
def send(self, message: Message) -> None:
188184
self._receive_queue.put(message)

tests/test_applications.py

+19-5
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from pathlib import Path
44
from typing import AsyncGenerator, AsyncIterator, Generator
55

6-
import anyio
6+
import anyio.from_thread
77
import pytest
88

99
from starlette import status
@@ -17,7 +17,7 @@
1717
from starlette.responses import JSONResponse, PlainTextResponse
1818
from starlette.routing import Host, Mount, Route, Router, WebSocketRoute
1919
from starlette.staticfiles import StaticFiles
20-
from starlette.testclient import TestClient
20+
from starlette.testclient import TestClient, WebSocketDenialResponse
2121
from starlette.types import ASGIApp, Receive, Scope, Send
2222
from starlette.websockets import WebSocket
2323
from tests.types import TestClientFactory
@@ -71,11 +71,15 @@ async def websocket_endpoint(session: WebSocket) -> None:
7171
await session.close()
7272

7373

74-
async def websocket_raise_websocket(websocket: WebSocket) -> None:
74+
async def websocket_raise_websocket_exception(websocket: WebSocket) -> None:
7575
await websocket.accept()
7676
raise WebSocketException(code=status.WS_1003_UNSUPPORTED_DATA)
7777

7878

79+
async def websocket_raise_http_exception(websocket: WebSocket) -> None:
80+
raise HTTPException(status_code=401, detail="Unauthorized")
81+
82+
7983
class CustomWSException(Exception):
8084
pass
8185

@@ -118,7 +122,8 @@ def custom_ws_exception_handler(websocket: WebSocket, exc: CustomWSException) ->
118122
Route("/class", endpoint=Homepage),
119123
Route("/500", endpoint=runtime_error),
120124
WebSocketRoute("/ws", endpoint=websocket_endpoint),
121-
WebSocketRoute("/ws-raise-websocket", endpoint=websocket_raise_websocket),
125+
WebSocketRoute("/ws-raise-websocket", endpoint=websocket_raise_websocket_exception),
126+
WebSocketRoute("/ws-raise-http", endpoint=websocket_raise_http_exception),
122127
WebSocketRoute("/ws-raise-custom", endpoint=websocket_raise_custom),
123128
Mount("/users", app=users),
124129
Host("{subdomain}.example.org", app=subdomain),
@@ -219,6 +224,14 @@ def test_websocket_raise_websocket_exception(client: TestClient) -> None:
219224
}
220225

221226

227+
def test_websocket_raise_http_exception(client: TestClient) -> None:
228+
with pytest.raises(WebSocketDenialResponse) as exc:
229+
with client.websocket_connect("/ws-raise-http"):
230+
pass # pragma: no cover
231+
assert exc.value.status_code == 401
232+
assert exc.value.content == b'{"detail":"Unauthorized"}'
233+
234+
222235
def test_websocket_raise_custom_exception(client: TestClient) -> None:
223236
with client.websocket_connect("/ws-raise-custom") as session:
224237
response = session.receive()
@@ -243,7 +256,8 @@ def test_routes() -> None:
243256
Route("/class", endpoint=Homepage),
244257
Route("/500", endpoint=runtime_error, methods=["GET"]),
245258
WebSocketRoute("/ws", endpoint=websocket_endpoint),
246-
WebSocketRoute("/ws-raise-websocket", endpoint=websocket_raise_websocket),
259+
WebSocketRoute("/ws-raise-websocket", endpoint=websocket_raise_websocket_exception),
260+
WebSocketRoute("/ws-raise-http", endpoint=websocket_raise_http_exception),
247261
WebSocketRoute("/ws-raise-custom", endpoint=websocket_raise_custom),
248262
Mount(
249263
"/users",

tests/test_exceptions.py

+4-11
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
6363
Route("/with_headers", endpoint=with_headers),
6464
Route("/handled_exc_after_response", endpoint=HandledExcAfterResponse()),
6565
WebSocketRoute("/runtime_error", endpoint=raise_runtime_error),
66-
Route(
67-
"/consume_body_in_endpoint_and_handler",
68-
endpoint=read_body_and_raise_exc,
69-
methods=["POST"],
70-
),
66+
Route("/consume_body_in_endpoint_and_handler", endpoint=read_body_and_raise_exc, methods=["POST"]),
7167
]
7268
)
7369

@@ -114,13 +110,10 @@ def test_websockets_should_raise(client: TestClient) -> None:
114110
pass # pragma: no cover
115111

116112

117-
def test_handled_exc_after_response(
118-
test_client_factory: TestClientFactory,
119-
client: TestClient,
120-
) -> None:
113+
def test_handled_exc_after_response(test_client_factory: TestClientFactory, client: TestClient) -> None:
121114
# A 406 HttpException is raised *after* the response has already been sent.
122115
# The exception middleware should raise a RuntimeError.
123-
with pytest.raises(RuntimeError):
116+
with pytest.raises(RuntimeError, match="Caught handled exception, but response already started."):
124117
client.get("/handled_exc_after_response")
125118

126119
# If `raise_server_exceptions=False` then the test client will still allow
@@ -132,7 +125,7 @@ def test_handled_exc_after_response(
132125

133126

134127
def test_force_500_response(test_client_factory: TestClientFactory) -> None:
135-
# use a sentinal variable to make sure we actually
128+
# use a sentinel variable to make sure we actually
136129
# make it into the endpoint and don't get a 500
137130
# from an incorrect ASGI app signature or something
138131
called = False

0 commit comments

Comments
 (0)