From 32d7b128fbf71c9556a0a935e3d2c9e2ee5f0579 Mon Sep 17 00:00:00 2001 From: Jacob Walls Date: Sat, 6 May 2023 11:47:10 -0400 Subject: [PATCH 01/16] Improve typing of inference functions --- astroid/inference.py | 1 - astroid/inference_tip.py | 32 +++++++++++++++++++------------- astroid/nodes/node_ng.py | 4 ++-- astroid/typing.py | 24 ++++++++++++++++++++---- 4 files changed, 41 insertions(+), 20 deletions(-) diff --git a/astroid/inference.py b/astroid/inference.py index 4dadc11698..1729d81df7 100644 --- a/astroid/inference.py +++ b/astroid/inference.py @@ -254,7 +254,6 @@ def infer_name( return bases._infer_stmts(stmts, context, frame) -# pylint: disable=no-value-for-parameter # The order of the decorators here is important # See https://github.com/pylint-dev/astroid/commit/0a8a75db30da060a24922e05048bc270230f5 nodes.Name._infer = decorators.raise_if_nothing_inferred( diff --git a/astroid/inference_tip.py b/astroid/inference_tip.py index 92cb6b4fe1..2529b98f3d 100644 --- a/astroid/inference_tip.py +++ b/astroid/inference_tip.py @@ -6,20 +6,19 @@ from __future__ import annotations -import sys from collections.abc import Callable, Iterator +from typing import TYPE_CHECKING from astroid.context import InferenceContext from astroid.exceptions import InferenceOverwriteError, UseInferenceDefault from astroid.nodes import NodeNG -from astroid.typing import InferenceResult, InferFn - -if sys.version_info >= (3, 11): - from typing import ParamSpec -else: - from typing_extensions import ParamSpec - -_P = ParamSpec("_P") +from astroid.typing import ( + _P, + InferenceResult, + InferFn, + InferFnExplicit, + InferFnTransform, +) _cache: dict[ tuple[InferFn, NodeNG, InferenceContext | None], list[InferenceResult] @@ -35,12 +34,18 @@ def clear_inference_tip_cache() -> None: def _inference_tip_cached( func: Callable[_P, Iterator[InferenceResult]], -) -> Callable[_P, Iterator[InferenceResult]]: +) -> InferFnExplicit: """Cache decorator used for inference tips.""" - def inner(*args: _P.args, **kwargs: _P.kwargs) -> Iterator[InferenceResult]: + def inner( + *args: _P.args, **kwargs: _P.kwargs + ) -> Iterator[InferenceResult] | list[InferenceResult]: node = args[0] context = args[1] + if TYPE_CHECKING: + assert isinstance(node, NodeNG) + assert context is None or isinstance(context, InferenceContext) + partial_cache_key = (func, node) if partial_cache_key in _CURRENTLY_INFERRING: # If through recursion we end up trying to infer the same @@ -64,7 +69,9 @@ def inner(*args: _P.args, **kwargs: _P.kwargs) -> Iterator[InferenceResult]: return inner -def inference_tip(infer_function: InferFn, raise_on_overwrite: bool = False) -> InferFn: +def inference_tip( + infer_function: InferFn, raise_on_overwrite: bool = False +) -> InferFnTransform: """Given an instance specific inference function, return a function to be given to AstroidManager().register_transform to set this inference function. @@ -100,7 +107,6 @@ def transform(node: NodeNG, infer_function: InferFn = infer_function) -> NodeNG: node=node, ) ) - # pylint: disable=no-value-for-parameter node._explicit_inference = _inference_tip_cached(infer_function) return node diff --git a/astroid/nodes/node_ng.py b/astroid/nodes/node_ng.py index de5dec77d7..9ad7c4bf9a 100644 --- a/astroid/nodes/node_ng.py +++ b/astroid/nodes/node_ng.py @@ -35,7 +35,7 @@ from astroid.nodes.as_string import AsStringVisitor from astroid.nodes.const import OP_PRECEDENCE from astroid.nodes.utils import Position -from astroid.typing import InferenceErrorInfo, InferenceResult, InferFn +from astroid.typing import InferenceErrorInfo, InferenceResult, InferFnExplicit if TYPE_CHECKING: from astroid import nodes @@ -80,7 +80,7 @@ class NodeNG: _other_other_fields: ClassVar[tuple[str, ...]] = () """Attributes that contain AST-dependent fields.""" # instance specific inference function infer(node, context) - _explicit_inference: InferFn | None = None + _explicit_inference: InferFnExplicit | None = None def __init__( self, diff --git a/astroid/typing.py b/astroid/typing.py index f42832e47e..df55093f31 100644 --- a/astroid/typing.py +++ b/astroid/typing.py @@ -4,14 +4,28 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Generator, TypedDict, TypeVar, Union +import sys +from typing import ( + TYPE_CHECKING, + Callable, + Generator, + Iterator, + TypedDict, + TypeVar, + Union, +) if TYPE_CHECKING: from astroid import bases, exceptions, nodes, transforms, util from astroid.context import InferenceContext from astroid.interpreter._import import spec +if sys.version_info >= (3, 11): + from typing import ParamSpec +else: + from typing_extensions import ParamSpec +_P = ParamSpec("_P") _NodesT = TypeVar("_NodesT", bound="nodes.NodeNG") @@ -24,9 +38,6 @@ class InferenceErrorInfo(TypedDict): context: InferenceContext | None -InferFn = Callable[..., Any] - - class AstroidManagerBrain(TypedDict): """Dictionary to store relevant information for a AstroidManager class.""" @@ -67,3 +78,8 @@ class AstroidManagerBrain(TypedDict): ], Generator[InferenceResult, None, None], ] + +InferFn = Callable[..., Iterator[InferenceResult]] +# pylint: disable-next=unsupported-binary-operation +InferFnExplicit = Callable[_P, Iterator[InferenceResult] | list[InferenceResult]] +InferFnTransform = Callable[[_NodesT, InferFn], _NodesT] From 28b021446dba60905f174364b6b04f5d185d17fc Mon Sep 17 00:00:00 2001 From: Jacob Walls Date: Sat, 6 May 2023 11:51:24 -0400 Subject: [PATCH 02/16] fixup! Improve typing --- astroid/typing.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/astroid/typing.py b/astroid/typing.py index df55093f31..844ece11c6 100644 --- a/astroid/typing.py +++ b/astroid/typing.py @@ -80,6 +80,5 @@ class AstroidManagerBrain(TypedDict): ] InferFn = Callable[..., Iterator[InferenceResult]] -# pylint: disable-next=unsupported-binary-operation -InferFnExplicit = Callable[_P, Iterator[InferenceResult] | list[InferenceResult]] +InferFnExplicit = Callable[_P, Union[Iterator[InferenceResult], list[InferenceResult]]] InferFnTransform = Callable[[_NodesT, InferFn], _NodesT] From 91b4128b4e4f2ec3c68982c3f30f515e6dc02d19 Mon Sep 17 00:00:00 2001 From: Jacob Walls Date: Sat, 6 May 2023 12:15:06 -0400 Subject: [PATCH 03/16] Avoid assertions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Daniƫl van Noord <13665637+DanielNoord@users.noreply.github.com> --- astroid/inference_tip.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/astroid/inference_tip.py b/astroid/inference_tip.py index 2529b98f3d..741941b645 100644 --- a/astroid/inference_tip.py +++ b/astroid/inference_tip.py @@ -40,11 +40,8 @@ def _inference_tip_cached( def inner( *args: _P.args, **kwargs: _P.kwargs ) -> Iterator[InferenceResult] | list[InferenceResult]: - node = args[0] - context = args[1] - if TYPE_CHECKING: - assert isinstance(node, NodeNG) - assert context is None or isinstance(context, InferenceContext) + node: NodeNG = args[0] + context: InferenceContext | None = args[1] partial_cache_key = (func, node) if partial_cache_key in _CURRENTLY_INFERRING: From 113d4bcfddcc486672fc19de263bd5389ef5c04c Mon Sep 17 00:00:00 2001 From: Jacob Walls Date: Sat, 6 May 2023 12:15:32 -0400 Subject: [PATCH 04/16] Remove import --- astroid/inference_tip.py | 1 - 1 file changed, 1 deletion(-) diff --git a/astroid/inference_tip.py b/astroid/inference_tip.py index 741941b645..8c3d3048d3 100644 --- a/astroid/inference_tip.py +++ b/astroid/inference_tip.py @@ -7,7 +7,6 @@ from __future__ import annotations from collections.abc import Callable, Iterator -from typing import TYPE_CHECKING from astroid.context import InferenceContext from astroid.exceptions import InferenceOverwriteError, UseInferenceDefault From 5fd854fff4439c4b322a706687d069e25e6a802e Mon Sep 17 00:00:00 2001 From: Jacob Walls Date: Sat, 6 May 2023 12:25:55 -0400 Subject: [PATCH 05/16] Cast types --- astroid/inference_tip.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/astroid/inference_tip.py b/astroid/inference_tip.py index 8c3d3048d3..a7f612c9cf 100644 --- a/astroid/inference_tip.py +++ b/astroid/inference_tip.py @@ -6,18 +6,21 @@ from __future__ import annotations +import sys from collections.abc import Callable, Iterator +from typing import cast from astroid.context import InferenceContext from astroid.exceptions import InferenceOverwriteError, UseInferenceDefault from astroid.nodes import NodeNG -from astroid.typing import ( - _P, - InferenceResult, - InferFn, - InferFnExplicit, - InferFnTransform, -) +from astroid.typing import InferenceResult, InferFn, InferFnExplicit, InferFnTransform + +if sys.version_info >= (3, 11): + from typing import ParamSpec +else: + from typing_extensions import ParamSpec + +_P = ParamSpec("_P") _cache: dict[ tuple[InferFn, NodeNG, InferenceContext | None], list[InferenceResult] @@ -39,8 +42,8 @@ def _inference_tip_cached( def inner( *args: _P.args, **kwargs: _P.kwargs ) -> Iterator[InferenceResult] | list[InferenceResult]: - node: NodeNG = args[0] - context: InferenceContext | None = args[1] + node = cast(NodeNG, args[0]) + context = cast(InferenceContext | None, args[1]) partial_cache_key = (func, node) if partial_cache_key in _CURRENTLY_INFERRING: From 99752cf52a99658742c99570179304f027d5eb7d Mon Sep 17 00:00:00 2001 From: Jacob Walls Date: Sat, 6 May 2023 12:31:30 -0400 Subject: [PATCH 06/16] fixup -- optional --- astroid/inference_tip.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/astroid/inference_tip.py b/astroid/inference_tip.py index a7f612c9cf..fa13dd7459 100644 --- a/astroid/inference_tip.py +++ b/astroid/inference_tip.py @@ -8,7 +8,7 @@ import sys from collections.abc import Callable, Iterator -from typing import cast +from typing import Optional, cast from astroid.context import InferenceContext from astroid.exceptions import InferenceOverwriteError, UseInferenceDefault @@ -43,7 +43,7 @@ def inner( *args: _P.args, **kwargs: _P.kwargs ) -> Iterator[InferenceResult] | list[InferenceResult]: node = cast(NodeNG, args[0]) - context = cast(InferenceContext | None, args[1]) + context = cast(Optional[InferenceContext], args[1]) partial_cache_key = (func, node) if partial_cache_key in _CURRENTLY_INFERRING: From 06f42bba092a4e7873c3600e6ab9664747aa71a0 Mon Sep 17 00:00:00 2001 From: Jacob Walls Date: Sat, 6 May 2023 12:40:57 -0400 Subject: [PATCH 07/16] Remove support for dynamic infererence function arguments --- astroid/inference_tip.py | 18 +++--------------- astroid/typing.py | 15 ++++++--------- 2 files changed, 9 insertions(+), 24 deletions(-) diff --git a/astroid/inference_tip.py b/astroid/inference_tip.py index fa13dd7459..1a80102410 100644 --- a/astroid/inference_tip.py +++ b/astroid/inference_tip.py @@ -6,22 +6,13 @@ from __future__ import annotations -import sys from collections.abc import Callable, Iterator -from typing import Optional, cast from astroid.context import InferenceContext from astroid.exceptions import InferenceOverwriteError, UseInferenceDefault from astroid.nodes import NodeNG from astroid.typing import InferenceResult, InferFn, InferFnExplicit, InferFnTransform -if sys.version_info >= (3, 11): - from typing import ParamSpec -else: - from typing_extensions import ParamSpec - -_P = ParamSpec("_P") - _cache: dict[ tuple[InferFn, NodeNG, InferenceContext | None], list[InferenceResult] ] = {} @@ -35,16 +26,13 @@ def clear_inference_tip_cache() -> None: def _inference_tip_cached( - func: Callable[_P, Iterator[InferenceResult]], + func: Callable[[NodeNG, InferenceContext | None], Iterator[InferenceResult]], ) -> InferFnExplicit: """Cache decorator used for inference tips.""" def inner( - *args: _P.args, **kwargs: _P.kwargs + node: NodeNG, context: InferenceContext | None ) -> Iterator[InferenceResult] | list[InferenceResult]: - node = cast(NodeNG, args[0]) - context = cast(Optional[InferenceContext], args[1]) - partial_cache_key = (func, node) if partial_cache_key in _CURRENTLY_INFERRING: # If through recursion we end up trying to infer the same @@ -59,7 +47,7 @@ def inner( # with slightly different contexts while still passing the simple # test cases included with this commit. _CURRENTLY_INFERRING.add(partial_cache_key) - result = _cache[func, node, context] = list(func(*args, **kwargs)) + result = _cache[func, node, context] = list(func(node, context)) # Remove recursion guard. _CURRENTLY_INFERRING.remove(partial_cache_key) diff --git a/astroid/typing.py b/astroid/typing.py index 844ece11c6..911c4654d7 100644 --- a/astroid/typing.py +++ b/astroid/typing.py @@ -4,12 +4,12 @@ from __future__ import annotations -import sys from typing import ( TYPE_CHECKING, Callable, Generator, Iterator, + Optional, TypedDict, TypeVar, Union, @@ -20,12 +20,6 @@ from astroid.context import InferenceContext from astroid.interpreter._import import spec -if sys.version_info >= (3, 11): - from typing import ParamSpec -else: - from typing_extensions import ParamSpec - -_P = ParamSpec("_P") _NodesT = TypeVar("_NodesT", bound="nodes.NodeNG") @@ -79,6 +73,9 @@ class AstroidManagerBrain(TypedDict): Generator[InferenceResult, None, None], ] -InferFn = Callable[..., Iterator[InferenceResult]] -InferFnExplicit = Callable[_P, Union[Iterator[InferenceResult], list[InferenceResult]]] +InferFn = Callable[[_NodesT, Optional["InferenceContext"]], Iterator[InferenceResult]] +InferFnExplicit = Callable[ + [_NodesT, Optional["InferenceContext"]], + Union[Iterator[InferenceResult], list[InferenceResult]], +] InferFnTransform = Callable[[_NodesT, InferFn], _NodesT] From c3ce316af51f63fabf2d78d806a44d913eb16b53 Mon Sep 17 00:00:00 2001 From: Jacob Walls Date: Sat, 6 May 2023 12:42:18 -0400 Subject: [PATCH 08/16] Restore blank line --- astroid/typing.py | 1 + 1 file changed, 1 insertion(+) diff --git a/astroid/typing.py b/astroid/typing.py index 911c4654d7..b7f06cc04b 100644 --- a/astroid/typing.py +++ b/astroid/typing.py @@ -20,6 +20,7 @@ from astroid.context import InferenceContext from astroid.interpreter._import import spec + _NodesT = TypeVar("_NodesT", bound="nodes.NodeNG") From 7df995f3911a0c277928427c162b3383b67d3d8e Mon Sep 17 00:00:00 2001 From: Jacob Walls Date: Sat, 6 May 2023 12:51:33 -0400 Subject: [PATCH 09/16] import typing.List --- astroid/typing.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/astroid/typing.py b/astroid/typing.py index b7f06cc04b..85d7031d04 100644 --- a/astroid/typing.py +++ b/astroid/typing.py @@ -9,6 +9,7 @@ Callable, Generator, Iterator, + List, Optional, TypedDict, TypeVar, @@ -77,6 +78,6 @@ class AstroidManagerBrain(TypedDict): InferFn = Callable[[_NodesT, Optional["InferenceContext"]], Iterator[InferenceResult]] InferFnExplicit = Callable[ [_NodesT, Optional["InferenceContext"]], - Union[Iterator[InferenceResult], list[InferenceResult]], + Union[Iterator[InferenceResult], List[InferenceResult]], ] InferFnTransform = Callable[[_NodesT, InferFn], _NodesT] From 5e07b6aaa8396c2e350c11f74cab6ee0a148d55c Mon Sep 17 00:00:00 2001 From: Jacob Walls Date: Sat, 6 May 2023 13:46:14 -0400 Subject: [PATCH 10/16] Restore support for **kwargs --- astroid/inference_tip.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/astroid/inference_tip.py b/astroid/inference_tip.py index 1a80102410..ee936ed5c8 100644 --- a/astroid/inference_tip.py +++ b/astroid/inference_tip.py @@ -7,6 +7,7 @@ from __future__ import annotations from collections.abc import Callable, Iterator +from typing import Any from astroid.context import InferenceContext from astroid.exceptions import InferenceOverwriteError, UseInferenceDefault @@ -31,7 +32,7 @@ def _inference_tip_cached( """Cache decorator used for inference tips.""" def inner( - node: NodeNG, context: InferenceContext | None + node: NodeNG, context: InferenceContext | None, **kwargs: Any, ) -> Iterator[InferenceResult] | list[InferenceResult]: partial_cache_key = (func, node) if partial_cache_key in _CURRENTLY_INFERRING: @@ -47,7 +48,7 @@ def inner( # with slightly different contexts while still passing the simple # test cases included with this commit. _CURRENTLY_INFERRING.add(partial_cache_key) - result = _cache[func, node, context] = list(func(node, context)) + result = _cache[func, node, context] = list(func(node, context, **kwargs)) # Remove recursion guard. _CURRENTLY_INFERRING.remove(partial_cache_key) From fd0b10556317aca7b795e2622b92bf8e94ddd5d7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 6 May 2023 17:46:33 +0000 Subject: [PATCH 11/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- astroid/inference_tip.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/astroid/inference_tip.py b/astroid/inference_tip.py index ee936ed5c8..148d20ffde 100644 --- a/astroid/inference_tip.py +++ b/astroid/inference_tip.py @@ -32,7 +32,9 @@ def _inference_tip_cached( """Cache decorator used for inference tips.""" def inner( - node: NodeNG, context: InferenceContext | None, **kwargs: Any, + node: NodeNG, + context: InferenceContext | None, + **kwargs: Any, ) -> Iterator[InferenceResult] | list[InferenceResult]: partial_cache_key = (func, node) if partial_cache_key in _CURRENTLY_INFERRING: From 887faf028e0cf7638c509ae91b30f77f155d102a Mon Sep 17 00:00:00 2001 From: Jacob Walls Date: Sun, 7 May 2023 12:43:01 -0400 Subject: [PATCH 12/16] Use InferFn --- astroid/inference_tip.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/astroid/inference_tip.py b/astroid/inference_tip.py index 148d20ffde..cbaed4ff37 100644 --- a/astroid/inference_tip.py +++ b/astroid/inference_tip.py @@ -26,9 +26,7 @@ def clear_inference_tip_cache() -> None: _cache.clear() -def _inference_tip_cached( - func: Callable[[NodeNG, InferenceContext | None], Iterator[InferenceResult]], -) -> InferFnExplicit: +def _inference_tip_cached(func: InferFn) -> InferFnExplicit: """Cache decorator used for inference tips.""" def inner( From 079d83bd861592245c0a0c23be07ff13ffc34851 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Noord?= <13665637+DanielNoord@users.noreply.github.com> Date: Sun, 7 May 2023 20:42:12 +0200 Subject: [PATCH 13/16] Use ``Protocols`` --- astroid/inference_tip.py | 39 +++++++++++++++++++------------- astroid/nodes/node_ng.py | 11 +++++++-- astroid/transforms.py | 11 ++++----- astroid/typing.py | 48 ++++++++++++++++++++++++++++++---------- 4 files changed, 73 insertions(+), 36 deletions(-) diff --git a/astroid/inference_tip.py b/astroid/inference_tip.py index cbaed4ff37..a20a651cd9 100644 --- a/astroid/inference_tip.py +++ b/astroid/inference_tip.py @@ -6,19 +6,25 @@ from __future__ import annotations -from collections.abc import Callable, Iterator -from typing import Any +from collections.abc import Generator +from typing import Any, TypeVar from astroid.context import InferenceContext from astroid.exceptions import InferenceOverwriteError, UseInferenceDefault from astroid.nodes import NodeNG -from astroid.typing import InferenceResult, InferFn, InferFnExplicit, InferFnTransform +from astroid.typing import ( + InferenceResult, + InferFn, + TransformFn, +) _cache: dict[ - tuple[InferFn, NodeNG, InferenceContext | None], list[InferenceResult] + tuple[InferFn[Any], NodeNG, InferenceContext | None], list[InferenceResult] ] = {} -_CURRENTLY_INFERRING: set[tuple[InferFn, NodeNG]] = set() +_CURRENTLY_INFERRING: set[tuple[InferFn[Any], NodeNG]] = set() + +_NodesT = TypeVar("_NodesT", bound=NodeNG) def clear_inference_tip_cache() -> None: @@ -26,21 +32,22 @@ def clear_inference_tip_cache() -> None: _cache.clear() -def _inference_tip_cached(func: InferFn) -> InferFnExplicit: +def _inference_tip_cached(func: InferFn[_NodesT]) -> InferFn[_NodesT]: """Cache decorator used for inference tips.""" def inner( - node: NodeNG, - context: InferenceContext | None, + node: _NodesT, + context: InferenceContext | None = None, **kwargs: Any, - ) -> Iterator[InferenceResult] | list[InferenceResult]: + ) -> Generator[InferenceResult, None, None]: partial_cache_key = (func, node) if partial_cache_key in _CURRENTLY_INFERRING: # If through recursion we end up trying to infer the same # func + node we raise here. raise UseInferenceDefault try: - return _cache[func, node, context] + yield from _cache[func, node, context] + return except KeyError: # Recursion guard with a partial cache key. # Using the full key causes a recursion error on PyPy. @@ -48,18 +55,18 @@ def inner( # with slightly different contexts while still passing the simple # test cases included with this commit. _CURRENTLY_INFERRING.add(partial_cache_key) - result = _cache[func, node, context] = list(func(node, context, **kwargs)) + _cache[func, node, context] = list(func(node, context, **kwargs)) # Remove recursion guard. _CURRENTLY_INFERRING.remove(partial_cache_key) - return iter(result) + yield from _cache[func, node, context] return inner def inference_tip( - infer_function: InferFn, raise_on_overwrite: bool = False -) -> InferFnTransform: + infer_function: InferFn[_NodesT], raise_on_overwrite: bool = False +) -> TransformFn[_NodesT]: """Given an instance specific inference function, return a function to be given to AstroidManager().register_transform to set this inference function. @@ -81,7 +88,9 @@ def inference_tip( excess overwrites. """ - def transform(node: NodeNG, infer_function: InferFn = infer_function) -> NodeNG: + def transform( + node: _NodesT, infer_function: InferFn[_NodesT] = infer_function + ) -> _NodesT: if ( raise_on_overwrite and node._explicit_inference is not None diff --git a/astroid/nodes/node_ng.py b/astroid/nodes/node_ng.py index 9ad7c4bf9a..9857c62cd7 100644 --- a/astroid/nodes/node_ng.py +++ b/astroid/nodes/node_ng.py @@ -5,6 +5,7 @@ from __future__ import annotations import pprint +import sys import warnings from collections.abc import Generator, Iterator from functools import cached_property @@ -35,7 +36,13 @@ from astroid.nodes.as_string import AsStringVisitor from astroid.nodes.const import OP_PRECEDENCE from astroid.nodes.utils import Position -from astroid.typing import InferenceErrorInfo, InferenceResult, InferFnExplicit +from astroid.typing import InferenceErrorInfo, InferenceResult, InferFn + +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self + if TYPE_CHECKING: from astroid import nodes @@ -80,7 +87,7 @@ class NodeNG: _other_other_fields: ClassVar[tuple[str, ...]] = () """Attributes that contain AST-dependent fields.""" # instance specific inference function infer(node, context) - _explicit_inference: InferFnExplicit | None = None + _explicit_inference: InferFn[Self] | None = None def __init__( self, diff --git a/astroid/transforms.py b/astroid/transforms.py index f6c727948d..29332223f8 100644 --- a/astroid/transforms.py +++ b/astroid/transforms.py @@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, TypeVar, Union, cast, overload from astroid.context import _invalidate_cache -from astroid.typing import SuccessfulInferenceResult +from astroid.typing import SuccessfulInferenceResult, TransformFn if TYPE_CHECKING: from astroid import nodes @@ -17,9 +17,6 @@ _SuccessfulInferenceResultT = TypeVar( "_SuccessfulInferenceResultT", bound=SuccessfulInferenceResult ) - _Transform = Callable[ - [_SuccessfulInferenceResultT], Optional[SuccessfulInferenceResult] - ] _Predicate = Optional[Callable[[_SuccessfulInferenceResultT], bool]] _Vistables = Union[ @@ -52,7 +49,7 @@ def __init__(self) -> None: type[SuccessfulInferenceResult], list[ tuple[ - _Transform[SuccessfulInferenceResult], + TransformFn[SuccessfulInferenceResult], _Predicate[SuccessfulInferenceResult], ] ], @@ -123,7 +120,7 @@ def _visit_generic(self, node: _Vistables) -> _VisitReturns: def register_transform( self, node_class: type[_SuccessfulInferenceResultT], - transform: _Transform[_SuccessfulInferenceResultT], + transform: TransformFn[_SuccessfulInferenceResultT], predicate: _Predicate[_SuccessfulInferenceResultT] | None = None, ) -> None: """Register `transform(node)` function to be applied on the given node. @@ -139,7 +136,7 @@ def register_transform( def unregister_transform( self, node_class: type[_SuccessfulInferenceResultT], - transform: _Transform[_SuccessfulInferenceResultT], + transform: TransformFn[_SuccessfulInferenceResultT], predicate: _Predicate[_SuccessfulInferenceResultT] | None = None, ) -> None: """Unregister the given transform.""" diff --git a/astroid/typing.py b/astroid/typing.py index 85d7031d04..584dec117b 100644 --- a/astroid/typing.py +++ b/astroid/typing.py @@ -6,11 +6,11 @@ from typing import ( TYPE_CHECKING, + Any, Callable, Generator, - Iterator, - List, - Optional, + Generic, + Protocol, TypedDict, TypeVar, Union, @@ -22,9 +22,6 @@ from astroid.interpreter._import import spec -_NodesT = TypeVar("_NodesT", bound="nodes.NodeNG") - - class InferenceErrorInfo(TypedDict): """Store additional Inference error information raised with StopIteration exception. @@ -53,6 +50,11 @@ class AstroidManagerBrain(TypedDict): _SuccessfulInferenceResultT = TypeVar( "_SuccessfulInferenceResultT", bound=SuccessfulInferenceResult ) +_SuccessfulInferenceResultT_contra = TypeVar( + "_SuccessfulInferenceResultT_contra", + bound=SuccessfulInferenceResult, + contravariant=True, +) ConstFactoryResult = Union[ "nodes.List", @@ -75,9 +77,31 @@ class AstroidManagerBrain(TypedDict): Generator[InferenceResult, None, None], ] -InferFn = Callable[[_NodesT, Optional["InferenceContext"]], Iterator[InferenceResult]] -InferFnExplicit = Callable[ - [_NodesT, Optional["InferenceContext"]], - Union[Iterator[InferenceResult], List[InferenceResult]], -] -InferFnTransform = Callable[[_NodesT, InferFn], _NodesT] + +class InferFn(Protocol, Generic[_SuccessfulInferenceResultT_contra]): + def __call__( + self, + node: _SuccessfulInferenceResultT_contra, + context: InferenceContext | None = None, + **kwargs: Any, + ) -> Generator[InferenceResult, None, None]: + ... + + +class SuccessfulInferFn(Protocol, Generic[_SuccessfulInferenceResultT_contra]): + def __call__( + self, + node: _SuccessfulInferenceResultT_contra, + context: InferenceContext | None = None, + **kwargs: Any, + ) -> Generator[SuccessfulInferenceResult, None, None]: + ... + + +class TransformFn(Protocol, Generic[_SuccessfulInferenceResultT]): + def __call__( + self, + node: _SuccessfulInferenceResultT, + infer_function: InferFn[_SuccessfulInferenceResultT] = ..., + ) -> _SuccessfulInferenceResultT | None: + ... From b3b6c53a8024cf26e7b0fb2b663475174b63b686 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Noord?= <13665637+DanielNoord@users.noreply.github.com> Date: Sun, 7 May 2023 20:43:34 +0200 Subject: [PATCH 14/16] Remove unused ``Protocol`` --- astroid/typing.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/astroid/typing.py b/astroid/typing.py index 584dec117b..76b1ed9bc6 100644 --- a/astroid/typing.py +++ b/astroid/typing.py @@ -88,16 +88,6 @@ def __call__( ... -class SuccessfulInferFn(Protocol, Generic[_SuccessfulInferenceResultT_contra]): - def __call__( - self, - node: _SuccessfulInferenceResultT_contra, - context: InferenceContext | None = None, - **kwargs: Any, - ) -> Generator[SuccessfulInferenceResult, None, None]: - ... - - class TransformFn(Protocol, Generic[_SuccessfulInferenceResultT]): def __call__( self, From 57425f28f210df3749d91d6378bc5eba42edbc1e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Noord?= <13665637+DanielNoord@users.noreply.github.com> Date: Sun, 7 May 2023 23:22:04 +0200 Subject: [PATCH 15/16] Implement review feedback --- astroid/inference_tip.py | 4 ++-- astroid/typing.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/astroid/inference_tip.py b/astroid/inference_tip.py index a20a651cd9..44a7fcf15a 100644 --- a/astroid/inference_tip.py +++ b/astroid/inference_tip.py @@ -55,11 +55,11 @@ def inner( # with slightly different contexts while still passing the simple # test cases included with this commit. _CURRENTLY_INFERRING.add(partial_cache_key) - _cache[func, node, context] = list(func(node, context, **kwargs)) + result = _cache[func, node, context] = list(func(node, context, **kwargs)) # Remove recursion guard. _CURRENTLY_INFERRING.remove(partial_cache_key) - yield from _cache[func, node, context] + yield from result return inner diff --git a/astroid/typing.py b/astroid/typing.py index 76b1ed9bc6..0ae30fcc28 100644 --- a/astroid/typing.py +++ b/astroid/typing.py @@ -85,7 +85,7 @@ def __call__( context: InferenceContext | None = None, **kwargs: Any, ) -> Generator[InferenceResult, None, None]: - ... + ... # pragma: no cover class TransformFn(Protocol, Generic[_SuccessfulInferenceResultT]): @@ -94,4 +94,4 @@ def __call__( node: _SuccessfulInferenceResultT, infer_function: InferFn[_SuccessfulInferenceResultT] = ..., ) -> _SuccessfulInferenceResultT | None: - ... + ... # pragma: no cover From b39202c4bea5beecb9b51364d1a6fadc8ff7dbfe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Noord?= <13665637+DanielNoord@users.noreply.github.com> Date: Sun, 7 May 2023 23:50:21 +0200 Subject: [PATCH 16/16] Add type ignore --- astroid/nodes/node_ng.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/astroid/nodes/node_ng.py b/astroid/nodes/node_ng.py index 9857c62cd7..31c842ee50 100644 --- a/astroid/nodes/node_ng.py +++ b/astroid/nodes/node_ng.py @@ -144,9 +144,17 @@ def infer( # explicit_inference is not bound, give it self explicitly try: if context is None: - yield from self._explicit_inference(self, context, **kwargs) + yield from self._explicit_inference( + self, # type: ignore[arg-type] + context, + **kwargs, + ) return - for result in self._explicit_inference(self, context, **kwargs): + for result in self._explicit_inference( + self, # type: ignore[arg-type] + context, + **kwargs, + ): context.nodes_inferred += 1 yield result return