diff --git a/shiny/_app.py b/shiny/_app.py index 6eeba7b68..48975dac3 100644 --- a/shiny/_app.py +++ b/shiny/_app.py @@ -6,7 +6,17 @@ from contextlib import AsyncExitStack, asynccontextmanager from inspect import signature from pathlib import Path -from typing import Any, Callable, Literal, Mapping, Optional, TypeVar, cast +from typing import ( + Any, + Awaitable, + Callable, + Literal, + Mapping, + Optional, + TypeVar, + Union, + cast, +) import starlette.applications import starlette.exceptions @@ -29,7 +39,7 @@ from ._connection import Connection, StarletteConnection from ._error import ErrorMiddleware from ._shinyenv import is_pyodide -from ._utils import guess_mime_type, is_async_callable, sort_keys_length +from ._utils import guess_mime_type, is_async_callable, sort_keys_length, wrap_async from .bookmark import _global as bookmark_global_state from .bookmark._global import as_bookmark_dir_fn from .bookmark._restore_state import RestoreContext, restore_context @@ -66,8 +76,8 @@ class App: returns a UI definition, if you need the UI definition to be created dynamically for each pageview. server - A function which is called once for each session, ensuring that each session is - independent. + A sync or async function which is called once for each session, ensuring that + each session is independent. static_assets Static files to be served by the app. If this is a string or Path object, it must be a directory, and it will be mounted at `/`. If this is a dictionary, @@ -113,7 +123,7 @@ def server(input: Inputs, output: Outputs, session: Session): """ ui: RenderedHTML | Callable[[Request], Tag | TagList] - server: Callable[[Inputs, Outputs, Session], None] + server: Callable[[Inputs, Outputs, Session], Awaitable[None]] _bookmark_save_dir_fn: BookmarkSaveDirFn | None _bookmark_restore_dir_fn: BookmarkRestoreDirFn | None @@ -123,7 +133,9 @@ def __init__( self, ui: Tag | TagList | Callable[[Request], Tag | TagList] | Path, server: ( - Callable[[Inputs], None] | Callable[[Inputs, Outputs, Session], None] | None + Callable[[Inputs], Awaitable[None] | None] + | Callable[[Inputs, Outputs, Session], Awaitable[None] | None] + | None ), *, static_assets: Optional[str | Path | Mapping[str, str | Path]] = None, @@ -136,13 +148,20 @@ def __init__( self._exit_stack = AsyncExitStack() if server is None: - self.server = noop_server_fn + self.server = wrap_async(noop_server_fn) elif len(signature(server).parameters) == 1: self.server = wrap_server_fn_with_output_session( - cast(Callable[[Inputs], None], server) + wrap_async( + cast(Callable[[Inputs], Union[Awaitable[None], None]], server) + ) ) elif len(signature(server).parameters) == 3: - self.server = cast(Callable[[Inputs, Outputs, Session], None], server) + self.server = wrap_async( + cast( + Callable[[Inputs, Outputs, Session], Union[Awaitable[None], None]], + server, + ) + ) else: raise ValueError( "`server` must have 1 (Inputs) or 3 parameters (Inputs, Outputs, Session)" @@ -571,10 +590,10 @@ def noop_server_fn(input: Inputs, output: Outputs, session: Session) -> None: def wrap_server_fn_with_output_session( - server: Callable[[Inputs], None], -) -> Callable[[Inputs, Outputs, Session], None]: - def _server(input: Inputs, output: Outputs, session: Session): + server: Callable[[Inputs], Awaitable[None]], +) -> Callable[[Inputs, Outputs, Session], Awaitable[None]]: + async def _server(input: Inputs, output: Outputs, session: Session): # Only has 1 parameter, ignore output, session - server(input) + await server(input) return _server diff --git a/shiny/_utils.py b/shiny/_utils.py index fc0eaa21f..799988401 100644 --- a/shiny/_utils.py +++ b/shiny/_utils.py @@ -262,7 +262,7 @@ def private_seed() -> Generator[None, None, None]: def wrap_async( - fn: Callable[P, R] | Callable[P, Awaitable[R]], + fn: Callable[P, R] | Callable[P, Awaitable[R]] | Callable[P, Awaitable[R] | R], ) -> Callable[P, Awaitable[R]]: """ Given a synchronous function that returns R, return an async function that wraps the @@ -270,7 +270,7 @@ def wrap_async( """ if is_async_callable(fn): - return fn + return cast(Callable[P, Awaitable[R]], fn) fn = cast(Callable[P, R], fn) @@ -362,10 +362,10 @@ def is_async_callable( return False -# def not_is_async_callable( -# obj: Callable[P, T] | Callable[P, Awaitable[T]] -# ) -> TypeGuard[Callable[P, T]]: -# return not is_async_callable(obj) +def not_is_async_callable( + obj: Callable[P, T] | Callable[P, Awaitable[T]], +) -> TypeGuard[Callable[P, T]]: + return not is_async_callable(obj) # See https://stackoverflow.com/a/59780868/412655 for an excellent explanation diff --git a/shiny/express/_module.py b/shiny/express/_module.py index 9036e287a..f49458ce2 100644 --- a/shiny/express/_module.py +++ b/shiny/express/_module.py @@ -1,8 +1,11 @@ +from __future__ import annotations + import functools -from typing import Callable, TypeVar +from typing import Awaitable, Callable, TypeVar, overload from .._docstring import add_example from .._typing_extensions import Concatenate, ParamSpec +from .._utils import is_async_callable, not_is_async_callable from ..module import Id from ..session._session import Inputs, Outputs, Session from ..session._utils import require_active_session, session_context @@ -16,9 +19,21 @@ @add_example(ex_dir="../api-examples/express_module") +# Use overloads so the function type stays the same for when the user calls it +@overload +def module( + fn: Callable[Concatenate[Inputs, Outputs, Session, P], Awaitable[R]], +) -> Callable[Concatenate[Id, P], Awaitable[R]]: ... +@overload def module( fn: Callable[Concatenate[Inputs, Outputs, Session, P], R], -) -> Callable[Concatenate[Id, P], R]: +) -> Callable[Concatenate[Id, P], R]: ... +def module( + fn: ( + Callable[Concatenate[Inputs, Outputs, Session, P], R] + | Callable[Concatenate[Inputs, Outputs, Session, P], Awaitable[R]] + ), +) -> Callable[Concatenate[Id, P], R] | Callable[Concatenate[Id, P], Awaitable[R]]: """ Create a Shiny module using Shiny Express syntax @@ -42,18 +57,43 @@ def module( """ fn = expressify(fn) - @functools.wraps(fn) - def wrapper(id: Id, *args: P.args, **kwargs: P.kwargs) -> R: - parent_session = require_active_session(None) - module_session = parent_session.make_scope(id) - - with session_context(module_session): - return fn( - module_session.input, - module_session.output, - module_session, - *args, - **kwargs, - ) - - return wrapper + if is_async_callable(fn): + # If the function is async, we need to wrap it in an async wrapper + @functools.wraps(fn) + async def async_wrapper(id: Id, *args: P.args, **kwargs: P.kwargs) -> R: + parent_session = require_active_session(None) + module_session = parent_session.make_scope(id) + + with session_context(module_session): + return await fn( + module_session.input, + module_session.output, + module_session, + *args, + **kwargs, + ) + + return async_wrapper + + # Required for type narrowing. `TypeIs` did not seem to work as expected here. + if not_is_async_callable(fn): + + @functools.wraps(fn) + def wrapper(id: Id, *args: P.args, **kwargs: P.kwargs) -> R: + parent_session = require_active_session(None) + module_session = parent_session.make_scope(id) + + with session_context(module_session): + return fn( + module_session.input, + module_session.output, + module_session, + *args, + **kwargs, + ) + + return wrapper + + raise RuntimeError( + "The provided function must be either synchronous or asynchronous." + ) diff --git a/shiny/module.py b/shiny/module.py index a47e52e5e..2f32a6f00 100644 --- a/shiny/module.py +++ b/shiny/module.py @@ -1,8 +1,7 @@ from __future__ import annotations -__all__ = ("current_namespace", "resolve_id", "ui", "server", "ResolvedId") - -from typing import TYPE_CHECKING, Callable, TypeVar +import functools +from typing import TYPE_CHECKING, Awaitable, Callable, TypeVar, overload from ._docstring import no_example from ._namespaces import ( @@ -13,10 +12,13 @@ resolve_id, ) from ._typing_extensions import Concatenate, ParamSpec +from ._utils import is_async_callable, not_is_async_callable if TYPE_CHECKING: from .session import Inputs, Outputs, Session +__all__ = ("current_namespace", "resolve_id", "ui", "server", "ResolvedId") + P = ParamSpec("P") R = TypeVar("R") @@ -34,15 +36,50 @@ def wrapper(id: Id, *args: P.args, **kwargs: P.kwargs) -> R: @no_example() +# Use overloads so the function type stays the same for when the user calls it +@overload +def server( + fn: Callable[Concatenate[Inputs, Outputs, Session, P], Awaitable[R]], +) -> Callable[Concatenate[str, P], Awaitable[R]]: ... +@overload def server( fn: Callable[Concatenate[Inputs, Outputs, Session, P], R], -) -> Callable[Concatenate[str, P], R]: +) -> Callable[Concatenate[str, P], R]: ... +def server( + fn: ( + Callable[Concatenate[Inputs, Outputs, Session, P], R] + | Callable[Concatenate[Inputs, Outputs, Session, P], Awaitable[R]] + ), +) -> Callable[Concatenate[str, P], R] | Callable[Concatenate[str, P], Awaitable[R]]: from .session import require_active_session, session_context - def wrapper(id: Id, *args: P.args, **kwargs: P.kwargs) -> R: - sess = require_active_session(None) - child_sess = sess.make_scope(id) - with session_context(child_sess): - return fn(child_sess.input, child_sess.output, child_sess, *args, **kwargs) + if is_async_callable(fn): - return wrapper + @functools.wraps(fn) + async def async_wrapper(id: Id, *args: P.args, **kwargs: P.kwargs) -> R: + sess = require_active_session(None) + child_sess = sess.make_scope(id) + with session_context(child_sess): + return await fn( + child_sess.input, child_sess.output, child_sess, *args, **kwargs + ) + + return async_wrapper + + # Required for type narrowing. `TypeIs` did not seem to work as expected here. + if not_is_async_callable(fn): + + @functools.wraps(fn) + def sync_wrapper(id: Id, *args: P.args, **kwargs: P.kwargs) -> R: + sess = require_active_session(None) + child_sess = sess.make_scope(id) + with session_context(child_sess): + return fn( + child_sess.input, child_sess.output, child_sess, *args, **kwargs + ) + + return sync_wrapper + + raise RuntimeError( + "The provided function must be either synchronous or asynchronous." + ) diff --git a/shiny/session/_session.py b/shiny/session/_session.py index 2d697d3f1..8eee4d41a 100644 --- a/shiny/session/_session.py +++ b/shiny/session/_session.py @@ -667,7 +667,7 @@ def verify_state(expected_state: ConnectionState) -> None: self._manage_inputs(message_obj["data"]) with session_context(self): - self.app.server(self.input, self.output, self) + await self.app.server(self.input, self.output, self) # TODO: Remove this call to reactive_flush() once https://github.com/posit-dev/py-shiny/issues/1889 is fixed # Workaround: Any `on_flushed()` calls from bookmark's `on_restored()` will be flushed here