Skip to content

Commit 597e345

Browse files
TeamSpen210CoolCat467A5rocks
authored
Use TypeVarTuple in our APIs (#2881)
* Use TypeVarTuple in various functions, except for Nursery.start(). That isn't handled by type checkers well yet. * Fix docs failure * Make gen_exports create an __all__ list Co-authored-by: CoolCat467 <[email protected]> Co-authored-by: EXPLOSION <[email protected]>
1 parent 31b87ad commit 597e345

24 files changed

+261
-74
lines changed

check.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ if [ $PYRIGHT -ne 0 ]; then
110110
fi
111111

112112
pyright src/trio/_tests/type_tests || EXIT_STATUS=$?
113+
pyright src/trio/_core/_tests/type_tests || EXIT_STATUS=$?
113114
echo "::endgroup::"
114115

115116
# Finally, leave a really clear warning of any issues and exit

docs/source/conf.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,6 @@
7373
# aliasing doesn't actually fix the warning for types.FrameType, but displaying
7474
# "types.FrameType" is more helpful than just "frame"
7575
"FrameType": "types.FrameType",
76-
"Context": "OpenSSL.SSL.Context",
7776
# SSLListener.accept's return type is seen as trio._ssl.SSLStream
7877
"SSLStream": "trio.SSLStream",
7978
}
@@ -91,6 +90,8 @@ def autodoc_process_signature(
9190
# Strip the type from the union, make it look like = ...
9291
signature = signature.replace(" | type[trio._core._local._NoValue]", "")
9392
signature = signature.replace("<class 'trio._core._local._NoValue'>", "...")
93+
if "DTLS" in name:
94+
signature = signature.replace("SSL.Context", "OpenSSL.SSL.Context")
9495
# Don't specify PathLike[str] | PathLike[bytes], this is just for humans.
9596
signature = signature.replace("StrOrBytesPath", "str | bytes | os.PathLike")
9697

newsfragments/2881.feature.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
`TypeVarTuple <https://docs.python.org/3.12/library/typing.html#typing.TypeVarTuple>`_ is now used to fully type :meth:`nursery.start_soon() <trio.Nursery.start_soon>`, :func:`trio.run()`, :func:`trio.to_thread.run_sync()`, and other similar functions accepting ``(func, *args)``. This means type checkers will be able to verify types are used correctly. :meth:`nursery.start() <trio.Nursery.start>` is not fully typed yet however.

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,8 @@ check_untyped_defs = true
157157

158158
[tool.pyright]
159159
pythonVersion = "3.8"
160+
reportUnnecessaryTypeIgnoreComment = true
161+
typeCheckingMode = "strict"
160162

161163
[tool.pytest.ini_options]
162164
addopts = ["--strict-markers", "--strict-config", "-p trio._tests.pytest_plugin"]
@@ -235,6 +237,8 @@ omit = [
235237
"*/trio/_core/_tests/test_multierror_scripts/*",
236238
# Omit the generated files in trio/_core starting with _generated_
237239
"*/trio/_core/_generated_*",
240+
# Type tests aren't intended to be run, just passed to type checkers.
241+
"*/type_tests/*",
238242
]
239243
# The test suite spawns subprocesses to test some stuff, so make sure
240244
# this doesn't corrupt the coverage files

src/trio/_core/_entry_queue.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,21 @@
22

33
import threading
44
from collections import deque
5-
from typing import Callable, Iterable, NoReturn, Tuple
5+
from typing import TYPE_CHECKING, Callable, NoReturn, Tuple
66

77
import attr
88

99
from .. import _core
1010
from .._util import NoPublicConstructor, final
1111
from ._wakeup_socketpair import WakeupSocketpair
1212

13-
# TODO: Type with TypeVarTuple, at least to an extent where it makes
14-
# the public interface safe.
13+
if TYPE_CHECKING:
14+
from typing_extensions import TypeVarTuple, Unpack
15+
16+
PosArgsT = TypeVarTuple("PosArgsT")
17+
1518
Function = Callable[..., object]
16-
Job = Tuple[Function, Iterable[object]]
19+
Job = Tuple[Function, Tuple[object, ...]]
1720

1821

1922
@attr.s(slots=True)
@@ -122,7 +125,10 @@ def size(self) -> int:
122125
return len(self.queue) + len(self.idempotent_queue)
123126

124127
def run_sync_soon(
125-
self, sync_fn: Function, *args: object, idempotent: bool = False
128+
self,
129+
sync_fn: Callable[[Unpack[PosArgsT]], object],
130+
*args: Unpack[PosArgsT],
131+
idempotent: bool = False,
126132
) -> None:
127133
with self.lock:
128134
if self.done:
@@ -163,7 +169,10 @@ class TrioToken(metaclass=NoPublicConstructor):
163169
_reentry_queue: EntryQueue = attr.ib()
164170

165171
def run_sync_soon(
166-
self, sync_fn: Function, *args: object, idempotent: bool = False
172+
self,
173+
sync_fn: Callable[[Unpack[PosArgsT]], object],
174+
*args: Unpack[PosArgsT],
175+
idempotent: bool = False,
167176
) -> None:
168177
"""Schedule a call to ``sync_fn(*args)`` to occur in the context of a
169178
Trio task.

src/trio/_core/_generated_instrumentation.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
if TYPE_CHECKING:
1212
from ._instrumentation import Instrument
1313

14+
__all__ = ["add_instrument", "remove_instrument"]
15+
1416

1517
def add_instrument(instrument: Instrument) -> None:
1618
"""Start instrumenting the current run loop with the given instrument.

src/trio/_core/_generated_io_epoll.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
assert not TYPE_CHECKING or sys.platform == "linux"
1616

1717

18+
__all__ = ["notify_closing", "wait_readable", "wait_writable"]
19+
20+
1821
async def wait_readable(fd: (int | _HasFileNo)) -> None:
1922
"""Block until the kernel reports that the given object is readable.
2023

src/trio/_core/_generated_io_kqueue.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,16 @@
1919
assert not TYPE_CHECKING or sys.platform == "darwin"
2020

2121

22+
__all__ = [
23+
"current_kqueue",
24+
"monitor_kevent",
25+
"notify_closing",
26+
"wait_kevent",
27+
"wait_readable",
28+
"wait_writable",
29+
]
30+
31+
2232
def current_kqueue() -> select.kqueue:
2333
"""TODO: these are implemented, but are currently more of a sketch than
2434
anything real. See `#26

src/trio/_core/_generated_io_windows.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,19 @@
1919
assert not TYPE_CHECKING or sys.platform == "win32"
2020

2121

22+
__all__ = [
23+
"current_iocp",
24+
"monitor_completion_key",
25+
"notify_closing",
26+
"readinto_overlapped",
27+
"register_with_iocp",
28+
"wait_overlapped",
29+
"wait_readable",
30+
"wait_writable",
31+
"write_overlapped",
32+
]
33+
34+
2235
async def wait_readable(sock: (_HasFileNo | int)) -> None:
2336
"""Block until the kernel reports that the given object is readable.
2437

src/trio/_core/_generated_run.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,23 @@
1313
from collections.abc import Awaitable, Callable
1414

1515
from outcome import Outcome
16+
from typing_extensions import Unpack
1617

1718
from .._abc import Clock
1819
from ._entry_queue import TrioToken
20+
from ._run import PosArgT
21+
22+
23+
__all__ = [
24+
"current_clock",
25+
"current_root_task",
26+
"current_statistics",
27+
"current_time",
28+
"current_trio_token",
29+
"reschedule",
30+
"spawn_system_task",
31+
"wait_all_tasks_blocked",
32+
]
1933

2034

2135
def current_statistics() -> RunStatistics:
@@ -113,8 +127,8 @@ def reschedule(task: Task, next_send: Outcome[Any] = _NO_SEND) -> None:
113127

114128

115129
def spawn_system_task(
116-
async_fn: Callable[..., Awaitable[object]],
117-
*args: object,
130+
async_fn: Callable[[Unpack[PosArgT]], Awaitable[object]],
131+
*args: Unpack[PosArgT],
118132
name: object = None,
119133
context: (contextvars.Context | None) = None,
120134
) -> Task:

src/trio/_core/_run.py

Lines changed: 36 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,12 @@
5353
if sys.version_info < (3, 11):
5454
from exceptiongroup import BaseExceptionGroup
5555

56+
FnT = TypeVar("FnT", bound="Callable[..., Any]")
57+
StatusT = TypeVar("StatusT")
58+
StatusT_co = TypeVar("StatusT_co", covariant=True)
59+
StatusT_contra = TypeVar("StatusT_contra", contravariant=True)
60+
RetT = TypeVar("RetT")
61+
5662

5763
if TYPE_CHECKING:
5864
import contextvars
@@ -70,19 +76,25 @@
7076
# for some strange reason Sphinx works with outcome.Outcome, but not Outcome, in
7177
# start_guest_run. Same with types.FrameType in iter_await_frames
7278
import outcome
73-
from typing_extensions import Self
79+
from typing_extensions import Self, TypeVarTuple, Unpack
80+
81+
PosArgT = TypeVarTuple("PosArgT")
82+
83+
# Needs to be guarded, since Unpack[] would be evaluated at runtime.
84+
class _NurseryStartFunc(Protocol[Unpack[PosArgT], StatusT_co]):
85+
"""Type of functions passed to `nursery.start() <trio.Nursery.start>`."""
86+
87+
def __call__(
88+
self, *args: Unpack[PosArgT], task_status: TaskStatus[StatusT_co]
89+
) -> Awaitable[object]:
90+
...
91+
7492

7593
DEADLINE_HEAP_MIN_PRUNE_THRESHOLD: Final = 1000
7694

7795
# Passed as a sentinel
7896
_NO_SEND: Final[Outcome[Any]] = cast("Outcome[Any]", object())
7997

80-
FnT = TypeVar("FnT", bound="Callable[..., Any]")
81-
StatusT = TypeVar("StatusT")
82-
StatusT_co = TypeVar("StatusT_co", covariant=True)
83-
StatusT_contra = TypeVar("StatusT_contra", contravariant=True)
84-
RetT = TypeVar("RetT")
85-
8698

8799
@final
88100
class _NoStatus(metaclass=NoPublicConstructor):
@@ -1119,9 +1131,8 @@ def aborted(raise_cancel: _core.RaiseCancelT) -> Abort:
11191131

11201132
def start_soon(
11211133
self,
1122-
# TODO: TypeVarTuple
1123-
async_fn: Callable[..., Awaitable[object]],
1124-
*args: object,
1134+
async_fn: Callable[[Unpack[PosArgT]], Awaitable[object]],
1135+
*args: Unpack[PosArgT],
11251136
name: object = None,
11261137
) -> None:
11271138
"""Creates a child task, scheduling ``await async_fn(*args)``.
@@ -1170,7 +1181,7 @@ async def start(
11701181
async_fn: Callable[..., Awaitable[object]],
11711182
*args: object,
11721183
name: object = None,
1173-
) -> StatusT:
1184+
) -> Any:
11741185
r"""Creates and initializes a child task.
11751186
11761187
Like :meth:`start_soon`, but blocks until the new task has
@@ -1219,7 +1230,7 @@ async def async_fn(arg1, arg2, *, task_status=trio.TASK_STATUS_IGNORED):
12191230
# `run` option, which would cause it to wrap a pre-started()
12201231
# exception in an extra ExceptionGroup. See #2611.
12211232
async with open_nursery(strict_exception_groups=False) as old_nursery:
1222-
task_status: _TaskStatus[StatusT] = _TaskStatus(old_nursery, self)
1233+
task_status: _TaskStatus[Any] = _TaskStatus(old_nursery, self)
12231234
thunk = functools.partial(async_fn, task_status=task_status)
12241235
task = GLOBAL_RUN_CONTEXT.runner.spawn_impl(
12251236
thunk, args, old_nursery, name
@@ -1232,7 +1243,7 @@ async def async_fn(arg1, arg2, *, task_status=trio.TASK_STATUS_IGNORED):
12321243
# (Any exceptions propagate directly out of the above.)
12331244
if task_status._value is _NoStatus:
12341245
raise RuntimeError("child exited without calling task_status.started()")
1235-
return task_status._value # type: ignore[return-value] # Mypy doesn't narrow yet.
1246+
return task_status._value
12361247
finally:
12371248
self._pending_starts -= 1
12381249
self._check_nursery_closed()
@@ -1690,9 +1701,8 @@ def reschedule( # type: ignore[misc]
16901701

16911702
def spawn_impl(
16921703
self,
1693-
# TODO: TypeVarTuple
1694-
async_fn: Callable[..., Awaitable[object]],
1695-
args: tuple[object, ...],
1704+
async_fn: Callable[[Unpack[PosArgT]], Awaitable[object]],
1705+
args: tuple[Unpack[PosArgT]],
16961706
nursery: Nursery | None,
16971707
name: object,
16981708
*,
@@ -1721,7 +1731,8 @@ def spawn_impl(
17211731
# Call the function and get the coroutine object, while giving helpful
17221732
# errors for common mistakes.
17231733
######
1724-
coro = context.run(coroutine_or_error, async_fn, *args)
1734+
# TypeVarTuple passed into ParamSpec function confuses Mypy.
1735+
coro = context.run(coroutine_or_error, async_fn, *args) # type: ignore[arg-type]
17251736

17261737
if name is None:
17271738
name = async_fn
@@ -1808,12 +1819,11 @@ def task_exited(self, task: Task, outcome: Outcome[Any]) -> None:
18081819
# System tasks and init
18091820
################
18101821

1811-
@_public # Type-ignore due to use of Any here.
1812-
def spawn_system_task( # type: ignore[misc]
1822+
@_public
1823+
def spawn_system_task(
18131824
self,
1814-
# TODO: TypeVarTuple
1815-
async_fn: Callable[..., Awaitable[object]],
1816-
*args: object,
1825+
async_fn: Callable[[Unpack[PosArgT]], Awaitable[object]],
1826+
*args: Unpack[PosArgT],
18171827
name: object = None,
18181828
context: contextvars.Context | None = None,
18191829
) -> Task:
@@ -1878,10 +1888,9 @@ def spawn_system_task( # type: ignore[misc]
18781888
)
18791889

18801890
async def init(
1881-
# TODO: TypeVarTuple
18821891
self,
1883-
async_fn: Callable[..., Awaitable[object]],
1884-
args: tuple[object, ...],
1892+
async_fn: Callable[[Unpack[PosArgT]], Awaitable[object]],
1893+
args: tuple[Unpack[PosArgT]],
18851894
) -> None:
18861895
# run_sync_soon task runs here:
18871896
async with open_nursery() as run_sync_soon_nursery:
@@ -2407,8 +2416,8 @@ def my_done_callback(run_outcome):
24072416
# straight through.
24082417
def unrolled_run(
24092418
runner: Runner,
2410-
async_fn: Callable[..., object],
2411-
args: tuple[object, ...],
2419+
async_fn: Callable[[Unpack[PosArgT]], Awaitable[object]],
2420+
args: tuple[Unpack[PosArgT]],
24122421
host_uses_signal_set_wakeup_fd: bool = False,
24132422
) -> Generator[float, EventResult, None]:
24142423
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True

src/trio/_core/_tests/test_guest_mode.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -658,8 +658,6 @@ async def trio_main() -> None:
658658
# Ensure we don't pollute the thread-level context if run under
659659
# an asyncio without contextvars support (3.6)
660660
context = contextvars.copy_context()
661-
if TYPE_CHECKING:
662-
aiotrio_run(trio_main, host_uses_signal_set_wakeup_fd=True)
663661
context.run(aiotrio_run, trio_main, host_uses_signal_set_wakeup_fd=True)
664662

665663
assert record == {("asyncio", "asyncio"), ("trio", "trio")}

src/trio/_core/_tests/test_local.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,8 @@ async def task1() -> None:
7777
t1.set("plaice")
7878
assert t1.get() == "plaice"
7979

80-
async def task2(tok: str) -> None:
81-
t1.reset(token)
80+
async def task2(tok: RunVarToken[str]) -> None:
81+
t1.reset(tok)
8282

8383
with pytest.raises(LookupError):
8484
t1.get()

0 commit comments

Comments
 (0)