Skip to content

Implement backpressure for HTTP request body and WebSocket messages #427

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ dependencies = [
"itsdangerous",
"jinja2",
"markupsafe",
"typing-extensions; python_version < '3.10'",
"typing-extensions; python_version < '3.11'",
"werkzeug>=3.0",
]

Expand Down
3 changes: 2 additions & 1 deletion src/quart/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
from .signals import websocket_tearing_down
from .templating import _default_template_ctx_processor
from .templating import Environment
from .testing import make_test_body_chunks
from .testing import make_test_body_with_headers
from .testing import make_test_headers_path_and_query_string
from .testing import make_test_scope
Expand Down Expand Up @@ -1363,10 +1364,10 @@ def test_request_context(
headers,
root_path,
http_version,
body_chunks=make_test_body_chunks(request_body),
send_push_promise=send_push_promise,
scope=scope,
)
request.body.set_result(request_body)
return self.request_context(request)

def add_background_task(self, func: Callable, *args: Any, **kwargs: Any) -> None:
Expand Down
17 changes: 10 additions & 7 deletions src/quart/asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from .signals import websocket_received
from .signals import websocket_sent
from .typing import ResponseTypes
from .utils import AsyncQueueIterator
from .utils import cancel_tasks
from .utils import encode_headers
from .utils import raise_task_exceptions
Expand All @@ -46,28 +47,29 @@ class ASGIHTTPConnection:
def __init__(self, app: Quart, scope: HTTPScope) -> None:
self.app = app
self.scope = scope
self.queue: AsyncQueueIterator[bytes] = AsyncQueueIterator(1)

async def __call__(
self, receive: ASGIReceiveCallable, send: ASGISendCallable
) -> None:
request = self._create_request_from_scope(send)
receiver_task = asyncio.ensure_future(self.handle_messages(request, receive))
receiver_task = asyncio.ensure_future(self.handle_messages(receive))
handler_task = asyncio.ensure_future(self.handle_request(request, send))
done, pending = await asyncio.wait(
[handler_task, receiver_task], return_when=asyncio.FIRST_COMPLETED
)
await cancel_tasks(pending)
raise_task_exceptions(done)

async def handle_messages(
self, request: Request, receive: ASGIReceiveCallable
) -> None:
async def handle_messages(self, receive: ASGIReceiveCallable) -> None:
queue = self.queue # for quicker access in the loop

while True:
message = await receive()
if message["type"] == "http.request":
request.body.append(message.get("body", b""))
await queue.put(message.get("body", b""))
if not message.get("more_body", False):
request.body.set_complete()
queue.set_complete()
elif message["type"] == "http.disconnect":
return

Expand Down Expand Up @@ -99,6 +101,7 @@ def _create_request_from_scope(self, send: ASGISendCallable) -> Request:
self.scope["http_version"],
max_content_length=self.app.config["MAX_CONTENT_LENGTH"],
body_timeout=self.app.config["BODY_TIMEOUT"],
body_chunks=self.queue,
send_push_promise=partial(self._send_push_promise, send),
scope=self.scope,
)
Expand Down Expand Up @@ -180,7 +183,7 @@ class ASGIWebsocketConnection:
def __init__(self, app: Quart, scope: WebsocketScope) -> None:
self.app = app
self.scope = scope
self.queue: asyncio.Queue = asyncio.Queue()
self.queue: asyncio.Queue = asyncio.Queue(1)
self._accepted = False
self._closed = False

Expand Down
2 changes: 2 additions & 0 deletions src/quart/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .app import TestApp
from .client import QuartClient
from .connections import WebsocketResponseError
from .utils import make_test_body_chunks
from .utils import make_test_body_with_headers
from .utils import make_test_headers_path_and_query_string
from .utils import make_test_scope
Expand All @@ -35,6 +36,7 @@ def invoke(self, cli: Any = None, args: Any = None, **kwargs: Any) -> Any: # ty


__all__ = (
"make_test_body_chunks",
"make_test_body_with_headers",
"make_test_headers_path_and_query_string",
"make_test_scope",
Expand Down
6 changes: 6 additions & 0 deletions src/quart/testing/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from collections.abc import AsyncIterator
from typing import Any
from typing import AnyStr
from typing import cast
Expand Down Expand Up @@ -218,6 +219,11 @@ def make_test_scope(
return cast(Scope, scope)


async def make_test_body_chunks(*chunks: bytes) -> AsyncIterator[bytes]:
for chunk in chunks:
yield chunk


async def no_op_push(path: str, headers: Headers) -> None:
"""A push promise sender that does nothing.

Expand Down
63 changes: 63 additions & 0 deletions src/quart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@
from .typing import Event
from .typing import FilePath

if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self

if TYPE_CHECKING:
from .wrappers.response import Response # noqa: F401

Expand Down Expand Up @@ -184,3 +189,61 @@ def raise_task_exceptions(tasks: set[asyncio.Task]) -> None:
for task in tasks:
if not task.cancelled() and task.exception() is not None:
raise task.exception()


# Dummy type used in AsyncQueueIterator to wakeup an await without sending any
# data. (None isn't used for that, because the generic type T could allow None
# as valid data in the queue.)
class _AsyncQueueWakeup:
pass


# Items go in using an async queue interface, and come out via async iteration.
class AsyncQueueIterator(AsyncIterator[T]):
_queue: asyncio.Queue[T | _AsyncQueueWakeup]
_complete: bool

def __init__(self, maxsize: int = 0) -> None:
self._queue = asyncio.Queue(maxsize)
self._complete = False # In Python 3.13, use queue's shutdown() instead

def __aiter__(self) -> Self:
return self

async def __anext__(self) -> T:
while not (self._queue.empty() and self._complete):
item = await self._queue.get()

if not isinstance(item, _AsyncQueueWakeup):
return item

raise StopAsyncIteration()

def empty(self) -> bool:
return self._queue.empty()

def full(self) -> bool:
return self._queue.full()

def complete(self) -> bool:
return self._complete

def _reject_if_complete(self) -> None:
if self._complete:
raise RuntimeError("already complete")

async def put(self, item: T) -> None:
self._reject_if_complete()

await self._queue.put(item)

def put_nowait(self, item: T) -> None:
self._reject_if_complete()

self._queue.put_nowait(item)

def set_complete(self) -> None:
self._complete = True

if self._queue.empty(): # so a get() might be waiting
self._queue.put_nowait(_AsyncQueueWakeup())
95 changes: 40 additions & 55 deletions src/quart/wrappers/request.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
from collections.abc import AsyncIterator
from collections.abc import Awaitable
from collections.abc import Generator
from typing import Any
Expand Down Expand Up @@ -49,11 +50,13 @@ class Body:
"""

def __init__(
self, expected_content_length: int | None, max_content_length: int | None
self,
chunks: AsyncIterator[bytes],
expected_content_length: int | None,
max_content_length: int | None,
) -> None:
self._data = bytearray()
self._complete: asyncio.Event = asyncio.Event()
self._has_data: asyncio.Event = asyncio.Event()
self._chunks = chunks
self._received_content_length = 0
self._max_content_length = max_content_length
# Exceptions must be raised within application (not ASGI)
# calls, this is achieved by having the ASGI methods set this
Expand All @@ -73,56 +76,30 @@ async def __anext__(self) -> bytes:
if self._must_raise is not None:
raise self._must_raise

# if we got all of the data in the first shot, then self._complete is
# set and self._has_data will not get set again, so skip the await
# if we already have completed everything
if not self._complete.is_set():
await self._has_data.wait()

if self._complete.is_set() and len(self._data) == 0:
raise StopAsyncIteration()

data = bytes(self._data)
self._data.clear()
self._has_data.clear()
return data

def __await__(self) -> Generator[Any, None, Any]:
# Must check the _must_raise before and after waiting on the
# completion event as it may change whilst waiting and the
# event may not be set if there is already an issue.
if self._must_raise is not None:
raise self._must_raise

yield from self._complete.wait().__await__()
data = await self._chunks.__anext__()

if self._must_raise is not None:
raise self._must_raise
return bytes(self._data)
self._received_content_length += len(data)

def append(self, data: bytes) -> None:
if data == b"" or self._must_raise is not None:
return
self._data.extend(data)
self._has_data.set()
if (
self._max_content_length is not None
and len(self._data) > self._max_content_length
and self._received_content_length > self._max_content_length
):
self._must_raise = RequestEntityTooLarge()
self.set_complete()
raise RequestEntityTooLarge()

return data

def __await__(self) -> Generator[Any, None, Any]:
async def accumulate_data() -> bytes:
data = bytearray()

def set_complete(self) -> None:
self._complete.set()
self._has_data.set()
# Receive chunks of data from the client and build up the complete
# request body.
async for data_chunk in self:
data.extend(data_chunk)

def set_result(self, data: bytes) -> None:
"""Convenience method, mainly for testing."""
self.append(data)
self.set_complete()
return bytes(data)

def clear(self) -> None:
self._data.clear()
return accumulate_data().__await__()


class Request(BaseRequestWebsocket):
Expand Down Expand Up @@ -158,6 +135,7 @@ def __init__(
*,
max_content_length: int | None = None,
body_timeout: int | None = None,
body_chunks: AsyncIterator[bytes],
send_push_promise: Callable[[str, Headers], Awaitable[None]],
) -> None:
"""Create a request object.
Expand All @@ -171,10 +149,10 @@ def __init__(
root_path: The root path that should be prepended to all
routes.
http_version: The HTTP version of the request.
body: An awaitable future for the body data i.e.
``data = await body``
max_content_length: The maximum length in bytes of the
body (None implies no limit in Quart).
body_chunks: An async iterable that provides the request body as a
sequence of data chunks.
body_timeout: The maximum time (seconds) to wait for the
body before timing out.
send_push_promise: An awaitable to send a push promise based
Expand All @@ -185,7 +163,12 @@ def __init__(
method, scheme, path, query_string, headers, root_path, http_version, scope
)
self.body_timeout = body_timeout
self.body = self.body_class(self.content_length, max_content_length)
self.body = self.body_class(
body_chunks,
self.content_length,
max_content_length,
)
self._cached_data: str | bytes | None = None
self._cached_json: dict[bool, Any] = {False: Ellipsis, True: Ellipsis}
self._form: MultiDict | None = None
self._files: MultiDict | None = None
Expand Down Expand Up @@ -271,6 +254,9 @@ async def get_data(
parse_form_data: Parse the data as form data first, return any
remaining data.
"""
if self._cached_data is not None:
return self._cached_data

if parse_form_data:
await self._load_form_data()

Expand All @@ -279,13 +265,12 @@ async def get_data(
except asyncio.TimeoutError as e:
raise RequestTimeout() from e
else:
if not cache:
self.body.clear()
data = raw_data.decode() if as_text else raw_data

if as_text:
return raw_data.decode()
else:
return raw_data
if cache:
self._cached_data = data

return data

@property
async def values(self) -> CombinedMultiDict:
Expand Down
3 changes: 3 additions & 0 deletions tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from quart.globals import websocket
from quart.sessions import SecureCookieSession
from quart.sessions import SessionInterface
from quart.testing import make_test_body_chunks
from quart.testing import no_op_push
from quart.testing import WebsocketResponseError
from quart.typing import ResponseReturnValue
Expand Down Expand Up @@ -273,6 +274,7 @@ async def index() -> NoReturn:
"",
"1.1",
http_scope,
body_chunks=make_test_body_chunks(),
send_push_promise=no_op_push,
)
with pytest.raises(asyncio.CancelledError):
Expand Down Expand Up @@ -390,6 +392,7 @@ async def exception() -> ResponseReturnValue:
"",
"1.1",
http_scope,
body_chunks=make_test_body_chunks(),
send_push_promise=no_op_push,
)
)
Expand Down
Loading
Loading