Skip to content

Commit e3fde28

Browse files
authored
[flake8-pyi] Allow overloaded __exit__ and __aexit__ definitions (PYI036) (#11057)
1 parent 1c8849f commit e3fde28

File tree

6 files changed

+481
-50
lines changed

6 files changed

+481
-50
lines changed

crates/ruff_linter/resources/test/fixtures/flake8_pyi/PYI036.py

+95-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import typing
44
from collections.abc import Awaitable
55
from types import TracebackType
6-
from typing import Any, Type
6+
from typing import Any, Type, overload
77

88
import _typeshed
99
import typing_extensions
@@ -73,3 +73,97 @@ async def __aexit__(self, /, typ: type[BaseException] | None, *args: Any) -> Awa
7373
class BadSix:
7474
def __exit__(self, typ, exc, tb, weird_extra_arg, extra_arg2 = None) -> None: ... # PYI036: Extra arg must have default
7575
async def __aexit__(self, typ, exc, tb, *, weird_extra_arg) -> None: ... # PYI036: kwargs must have default
76+
77+
class AllPositionalOnlyArgs:
78+
def __exit__(self, typ: type[BaseException] | None, exc: BaseException | None, tb: TracebackType | None, /) -> None: ...
79+
async def __aexit__(self, typ: type[BaseException] | None, exc: BaseException | None, tb: TracebackType | None, /) -> None: ...
80+
81+
class BadAllPositionalOnlyArgs:
82+
def __exit__(self, typ: type[Exception] | None, exc: BaseException | None, tb: TracebackType | None, /) -> None: ...
83+
async def __aexit__(self, typ: type[BaseException] | None, exc: BaseException | None, tb: TracebackType, /) -> None: ...
84+
85+
# Definitions not in a class scope can do whatever, we don't care
86+
def __exit__(self, *args: bool) -> None: ...
87+
async def __aexit__(self, *, go_crazy: bytes) -> list[str]: ...
88+
89+
# Here come the overloads...
90+
91+
class AcceptableOverload1:
92+
@overload
93+
def __exit__(self, exc_typ: None, exc: None, exc_tb: None) -> None: ...
94+
@overload
95+
def __exit__(self, exc_typ: type[BaseException], exc: BaseException, exc_tb: TracebackType) -> None: ...
96+
def __exit__(self, exc_typ: type[BaseException] | None, exc: BaseException | None, exc_tb: TracebackType | None) -> None: ...
97+
98+
# Using `object` or `Unused` in an overload definition is kinda strange,
99+
# but let's allow it to be on the safe side
100+
class AcceptableOverload2:
101+
@overload
102+
def __exit__(self, exc_typ: None, exc: None, exc_tb: object) -> None: ...
103+
@overload
104+
def __exit__(self, exc_typ: Unused, exc: BaseException, exc_tb: object) -> None: ...
105+
def __exit__(self, exc_typ: type[BaseException] | None, exc: BaseException | None, exc_tb: TracebackType | None) -> None: ...
106+
107+
class AcceptableOverload3:
108+
# Just ignore any overloads that don't have exactly 3 annotated non-self parameters.
109+
# We don't have the ability (yet) to do arbitrary checking
110+
# of whether one function definition is a subtype of another...
111+
@overload
112+
def __exit__(self, exc_typ: bool, exc: bool, exc_tb: bool, weird_extra_arg: bool) -> None: ...
113+
@overload
114+
def __exit__(self, *args: object) -> None: ...
115+
def __exit__(self, *args: object) -> None: ...
116+
@overload
117+
async def __aexit__(self, exc_typ: bool, /, exc: bool, exc_tb: bool, *, keyword_only: str) -> None: ...
118+
@overload
119+
async def __aexit__(self, *args: object) -> None: ...
120+
async def __aexit__(self, *args: object) -> None: ...
121+
122+
class AcceptableOverload4:
123+
# Same as above
124+
@overload
125+
def __exit__(self, exc_typ: type[Exception], exc: type[Exception], exc_tb: types.TracebackType) -> None: ...
126+
@overload
127+
def __exit__(self, *args: object) -> None: ...
128+
def __exit__(self, *args: object) -> None: ...
129+
@overload
130+
async def __aexit__(self, exc_typ: type[Exception], exc: type[Exception], exc_tb: types.TracebackType, *, extra: str = "foo") -> None: ...
131+
@overload
132+
async def __aexit__(self, exc_typ: None, exc: None, tb: None) -> None: ...
133+
async def __aexit__(self, *args: object) -> None: ...
134+
135+
class StrangeNumberOfOverloads:
136+
# Only one overload? Type checkers will emit an error, but we should just ignore it
137+
@overload
138+
def __exit__(self, exc_typ: bool, exc: bool, tb: bool) -> None: ...
139+
def __exit__(self, exc_typ: type[BaseException] | None, exc: BaseException | None, tb: TracebackType | None) -> None: ...
140+
# More than two overloads? Anything could be going on; again, just ignore all the overloads
141+
@overload
142+
async def __aexit__(self, arg: bool) -> None: ...
143+
@overload
144+
async def __aexit__(self, arg: None, arg2: None, arg3: None) -> None: ...
145+
@overload
146+
async def __aexit__(self, arg: bool, arg2: bool, arg3: bool) -> None: ...
147+
async def __aexit__(self, exc_typ: type[BaseException] | None, exc: BaseException | None, tb: TracebackType | None) -> None: ...
148+
149+
# TODO: maybe we should emit an error on this one as well?
150+
class BizarreAsyncSyncOverloadMismatch:
151+
@overload
152+
def __exit__(self, exc_typ: bool, exc: bool, tb: bool) -> None: ...
153+
@overload
154+
async def __exit__(self, exc_typ: bool, exc: bool, tb: bool) -> None: ...
155+
def __exit__(self, *args: object) -> None: ...
156+
157+
class UnacceptableOverload1:
158+
@overload
159+
def __exit__(self, exc_typ: None, exc: None, tb: None) -> None: ... # Okay
160+
@overload
161+
def __exit__(self, exc_typ: Exception, exc: Exception, tb: TracebackType) -> None: ... # PYI036
162+
def __exit__(self, exc_typ: type[BaseException] | None, exc: BaseException | None, tb: TracebackType | None) -> None: ...
163+
164+
class UnacceptableOverload2:
165+
@overload
166+
def __exit__(self, exc_typ: type[BaseException] | None, exc: None, tb: None) -> None: ... # PYI036
167+
@overload
168+
def __exit__(self, exc_typ: object, exc: Exception, tb: builtins.TracebackType) -> None: ... # PYI036
169+
def __exit__(self, exc_typ: type[BaseException] | None, exc: BaseException | None, tb: TracebackType | None) -> None: ...

crates/ruff_linter/resources/test/fixtures/flake8_pyi/PYI036.pyi

+84-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import types
33
import typing
44
from collections.abc import Awaitable
55
from types import TracebackType
6-
from typing import Any, Type
6+
from typing import Any, Type, overload
77

88
import _typeshed
99
import typing_extensions
@@ -80,3 +80,86 @@ def isolated_scope():
8080

8181
class ShouldNotError:
8282
def __exit__(self, typ: Type[BaseException] | None, exc: BaseException | None, tb: TracebackType | None) -> None: ...
83+
84+
class AllPositionalOnlyArgs:
85+
def __exit__(self, typ: type[BaseException] | None, exc: BaseException | None, tb: TracebackType | None, /) -> None: ...
86+
async def __aexit__(self, typ: type[BaseException] | None, exc: BaseException | None, tb: TracebackType | None, /) -> None: ...
87+
88+
class BadAllPositionalOnlyArgs:
89+
def __exit__(self, typ: type[Exception] | None, exc: BaseException | None, tb: TracebackType | None, /) -> None: ...
90+
async def __aexit__(self, typ: type[BaseException] | None, exc: BaseException | None, tb: TracebackType, /) -> None: ...
91+
92+
# Definitions not in a class scope can do whatever, we don't care
93+
def __exit__(self, *args: bool) -> None: ...
94+
async def __aexit__(self, *, go_crazy: bytes) -> list[str]: ...
95+
96+
# Here come the overloads...
97+
98+
class AcceptableOverload1:
99+
@overload
100+
def __exit__(self, exc_typ: None, exc: None, exc_tb: None) -> None: ...
101+
@overload
102+
def __exit__(self, exc_typ: type[BaseException], exc: BaseException, exc_tb: TracebackType) -> None: ...
103+
104+
# Using `object` or `Unused` in an overload definition is kinda strange,
105+
# but let's allow it to be on the safe side
106+
class AcceptableOverload2:
107+
@overload
108+
def __exit__(self, exc_typ: None, exc: None, exc_tb: object) -> None: ...
109+
@overload
110+
def __exit__(self, exc_typ: Unused, exc: BaseException, exc_tb: object) -> None: ...
111+
112+
class AcceptableOverload3:
113+
# Just ignore any overloads that don't have exactly 3 annotated non-self parameters.
114+
# We don't have the ability (yet) to do arbitrary checking
115+
# of whether one function definition is a subtype of another...
116+
@overload
117+
def __exit__(self, exc_typ: bool, exc: bool, exc_tb: bool, weird_extra_arg: bool) -> None: ...
118+
@overload
119+
def __exit__(self, *args: object) -> None: ...
120+
@overload
121+
async def __aexit__(self, exc_typ: bool, /, exc: bool, exc_tb: bool, *, keyword_only: str) -> None: ...
122+
@overload
123+
async def __aexit__(self, *args: object) -> None: ...
124+
125+
class AcceptableOverload4:
126+
# Same as above
127+
@overload
128+
def __exit__(self, exc_typ: type[Exception], exc: type[Exception], exc_tb: types.TracebackType) -> None: ...
129+
@overload
130+
def __exit__(self, *args: object) -> None: ...
131+
@overload
132+
async def __aexit__(self, exc_typ: type[Exception], exc: type[Exception], exc_tb: types.TracebackType, *, extra: str = "foo") -> None: ...
133+
@overload
134+
async def __aexit__(self, exc_typ: None, exc: None, tb: None) -> None: ...
135+
136+
class StrangeNumberOfOverloads:
137+
# Only one overload? Type checkers will emit an error, but we should just ignore it
138+
@overload
139+
def __exit__(self, exc_typ: bool, exc: bool, tb: bool) -> None: ...
140+
# More than two overloads? Anything could be going on; again, just ignore all the overloads
141+
@overload
142+
async def __aexit__(self, arg: bool) -> None: ...
143+
@overload
144+
async def __aexit__(self, arg: None, arg2: None, arg3: None) -> None: ...
145+
@overload
146+
async def __aexit__(self, arg: bool, arg2: bool, arg3: bool) -> None: ...
147+
148+
# TODO: maybe we should emit an error on this one as well?
149+
class BizarreAsyncSyncOverloadMismatch:
150+
@overload
151+
def __exit__(self, exc_typ: bool, exc: bool, tb: bool) -> None: ...
152+
@overload
153+
async def __exit__(self, exc_typ: bool, exc: bool, tb: bool) -> None: ...
154+
155+
class UnacceptableOverload1:
156+
@overload
157+
def __exit__(self, exc_typ: None, exc: None, tb: None) -> None: ... # Okay
158+
@overload
159+
def __exit__(self, exc_typ: Exception, exc: Exception, tb: TracebackType) -> None: ... # PYI036
160+
161+
class UnacceptableOverload2:
162+
@overload
163+
def __exit__(self, exc_typ: type[BaseException] | None, exc: None, tb: None) -> None: ... # PYI036
164+
@overload
165+
def __exit__(self, exc_typ: object, exc: Exception, tb: builtins.TracebackType) -> None: ... # PYI036

crates/ruff_linter/src/checkers/ast/analyze/statement.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ pub(crate) fn statement(stmt: &Stmt, checker: &mut Checker) {
174174
}
175175
}
176176
if checker.enabled(Rule::BadExitAnnotation) {
177-
flake8_pyi::rules::bad_exit_annotation(checker, *is_async, name, parameters);
177+
flake8_pyi::rules::bad_exit_annotation(checker, function_def);
178178
}
179179
if checker.enabled(Rule::RedundantNumericUnion) {
180180
flake8_pyi::rules::redundant_numeric_union(checker, parameters);

0 commit comments

Comments
 (0)