diff --git a/crates/red_knot_python_semantic/resources/mdtest/annotations/never.md b/crates/red_knot_python_semantic/resources/mdtest/annotations/never.md index 6699239873103..1a5c5dd000564 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/annotations/never.md +++ b/crates/red_knot_python_semantic/resources/mdtest/annotations/never.md @@ -47,7 +47,9 @@ def f(): ## `typing.Never` -`typing.Never` is only available in Python 3.11 and later: +`typing.Never` is only available in Python 3.11 and later. + +### Python 3.11 ```toml [environment] @@ -57,8 +59,17 @@ python-version = "3.11" ```py from typing import Never -x: Never +reveal_type(Never) # revealed: typing.Never +``` -def f(): - reveal_type(x) # revealed: Never +### Python 3.10 + +```toml +[environment] +python-version = "3.10" +``` + +```py +# error: [unresolved-import] +from typing import Never ``` diff --git a/crates/red_knot_python_semantic/resources/mdtest/assignment/annotations.md b/crates/red_knot_python_semantic/resources/mdtest/assignment/annotations.md index c3977ed46b6c4..f696cd4ea414f 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/assignment/annotations.md +++ b/crates/red_knot_python_semantic/resources/mdtest/assignment/annotations.md @@ -33,8 +33,6 @@ b: tuple[int] = (42,) c: tuple[str, int] = ("42", 42) d: tuple[tuple[str, str], tuple[int, int]] = (("foo", "foo"), (42, 42)) e: tuple[str, ...] = () -# TODO: we should not emit this error -# error: [call-possibly-unbound-method] "Method `__class_getitem__` of type `Literal[tuple]` is possibly unbound" f: tuple[str, *tuple[int, ...], bytes] = ("42", b"42") g: tuple[str, Unpack[tuple[int, ...]], bytes] = ("42", b"42") h: tuple[list[int], list[int]] = ([], []) diff --git a/crates/red_knot_python_semantic/resources/mdtest/boolean/short_circuit.md b/crates/red_knot_python_semantic/resources/mdtest/boolean/short_circuit.md index 9ee078606d5cd..6ad75f185bb35 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/boolean/short_circuit.md +++ b/crates/red_knot_python_semantic/resources/mdtest/boolean/short_circuit.md @@ -32,13 +32,10 @@ def _(flag: bool): ```py if True or (x := 1): - # TODO: infer that the second arm is never executed, and raise `unresolved-reference`. - # error: [possibly-unresolved-reference] - reveal_type(x) # revealed: Literal[1] + # error: [unresolved-reference] + reveal_type(x) # revealed: Unknown if True and (x := 1): - # TODO: infer that the second arm is always executed, do not raise a diagnostic - # error: [possibly-unresolved-reference] reveal_type(x) # revealed: Literal[1] ``` diff --git a/crates/red_knot_python_semantic/resources/mdtest/expression/if.md b/crates/red_knot_python_semantic/resources/mdtest/expression/if.md index 79faa45426855..6461522cefa46 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/expression/if.md +++ b/crates/red_knot_python_semantic/resources/mdtest/expression/if.md @@ -7,7 +7,7 @@ def _(flag: bool): reveal_type(1 if flag else 2) # revealed: Literal[1, 2] ``` -## Statically known branches +## Statically known conditions in if-expressions ```py reveal_type(1 if True else 2) # revealed: Literal[1] diff --git a/crates/red_knot_python_semantic/resources/mdtest/literal/ellipsis.md b/crates/red_knot_python_semantic/resources/mdtest/literal/ellipsis.md index 2b7bb7c61d9b1..241b498372d49 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/literal/ellipsis.md +++ b/crates/red_knot_python_semantic/resources/mdtest/literal/ellipsis.md @@ -1,7 +1,23 @@ # Ellipsis literals -## Simple +## Python 3.9 + +```toml +[environment] +python-version = "3.9" +``` + +```py +reveal_type(...) # revealed: ellipsis +``` + +## Python 3.10 + +```toml +[environment] +python-version = "3.10" +``` ```py -reveal_type(...) # revealed: EllipsisType | ellipsis +reveal_type(...) # revealed: EllipsisType ``` diff --git a/crates/red_knot_python_semantic/resources/mdtest/narrow/issubclass.md b/crates/red_knot_python_semantic/resources/mdtest/narrow/issubclass.md index 7da1ad3e36126..b1099a1f7ae83 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/narrow/issubclass.md +++ b/crates/red_knot_python_semantic/resources/mdtest/narrow/issubclass.md @@ -95,10 +95,14 @@ def _(t: type[object]): ### Handling of `None` +`types.NoneType` is only available in Python 3.10 and later: + +```toml +[environment] +python-version = "3.10" +``` + ```py -# TODO: this error should ideally go away once we (1) understand `sys.version_info` branches, -# and (2) set the target Python version for this test to 3.10. -# error: [possibly-unbound-import] "Member `NoneType` of module `types` is possibly unbound" from types import NoneType def _(flag: bool): diff --git a/crates/red_knot_python_semantic/resources/mdtest/statically_known_branches.md b/crates/red_knot_python_semantic/resources/mdtest/statically_known_branches.md new file mode 100644 index 0000000000000..ff6118a99290a --- /dev/null +++ b/crates/red_knot_python_semantic/resources/mdtest/statically_known_branches.md @@ -0,0 +1,1172 @@ +# Statically-known branches + +## Introduction + +We have the ability to infer precise types and boundness information for symbols that are defined in +branches whose conditions we can statically determine to be always true or always false. This is +useful for `sys.version_info` branches, which can make new features available based on the Python +version: + +```py path=module1.py +import sys + +if sys.version_info >= (3, 9): + SomeFeature = "available" +``` + +If we can statically determine that the condition is always true, then we can also understand that +`SomeFeature` is always bound, without raising any errors: + +```py path=test1.py +from module1 import SomeFeature + +# SomeFeature is unconditionally available here, because we are on Python 3.9 or newer: +reveal_type(SomeFeature) # revealed: Literal["available"] +``` + +Another scenario where this is useful is for `typing.TYPE_CHECKING` branches, which are often used +for conditional imports: + +```py path=module2.py +class SomeType: ... +``` + +```py path=test2.py +import typing + +if typing.TYPE_CHECKING: + from module2 import SomeType + +# `SomeType` is unconditionally available here for type checkers: +def f(s: SomeType) -> None: ... +``` + +The rest of this document contains tests for various cases where this feature can be used. + +## If statements + +### Always false + +#### If + +```py +x = 1 + +if False: + x = 2 + +reveal_type(x) # revealed: Literal[1] +``` + +#### Else + +```py +x = 1 + +if True: + pass +else: + x = 2 + +reveal_type(x) # revealed: Literal[1] +``` + +### Always true + +#### If + +```py +x = 1 + +if True: + x = 2 + +reveal_type(x) # revealed: Literal[2] +``` + +#### Else + +```py +x = 1 + +if False: + pass +else: + x = 2 + +reveal_type(x) # revealed: Literal[2] +``` + +### Ambiguous + +Just for comparison, we still infer the combined type if the condition is not statically known: + +```py +def flag() -> bool: ... + +x = 1 + +if flag(): + x = 2 + +reveal_type(x) # revealed: Literal[1, 2] +``` + +### Combination of always true and always false + +```py +x = 1 + +if True: + x = 2 +else: + x = 3 + +reveal_type(x) # revealed: Literal[2] +``` + +### `elif` branches + +#### Always false + +```py +def flag() -> bool: ... + +x = 1 + +if flag(): + x = 2 +elif False: + x = 3 +else: + x = 4 + +reveal_type(x) # revealed: Literal[2, 4] +``` + +#### Always true + +```py +def flag() -> bool: ... + +x = 1 + +if flag(): + x = 2 +elif True: + x = 3 +else: + x = 4 + +reveal_type(x) # revealed: Literal[2, 3] +``` + +#### Ambiguous + +```py +def flag() -> bool: ... + +x = 1 + +if flag(): + x = 2 +elif flag(): + x = 3 +else: + x = 4 + +reveal_type(x) # revealed: Literal[2, 3, 4] +``` + +#### Multiple `elif` branches, always false + +Make sure that we include bindings from all non-`False` branches: + +```py +def flag() -> bool: ... + +x = 1 + +if flag(): + x = 2 +elif flag(): + x = 3 +elif False: + x = 4 +elif False: + x = 5 +elif flag(): + x = 6 +elif flag(): + x = 7 +else: + x = 8 + +reveal_type(x) # revealed: Literal[2, 3, 6, 7, 8] +``` + +#### Multiple `elif` branches, always true + +Make sure that we only include the binding from the first `elif True` branch: + +```py +def flag() -> bool: ... + +x = 1 + +if flag(): + x = 2 +elif flag(): + x = 3 +elif True: + x = 4 +elif True: + x = 5 +elif flag(): + x = 6 +else: + x = 7 + +reveal_type(x) # revealed: Literal[2, 3, 4] +``` + +#### `elif` without `else` branch + +```py +def flag() -> bool: ... + +x = 1 + +if flag(): + x = 2 +elif True: + x = 3 + +# TODO: This should be Literal[2, 3] +reveal_type(x) # revealed: Literal[1, 2, 3] +``` + +### Nested conditionals + +#### `if True` inside `if True` + +```py +x = 1 + +if True: + if True: + x = 2 +else: + x = 3 + +reveal_type(x) # revealed: Literal[2] +``` + +#### `if False` inside `if True` + +```py +x = 1 + +if True: + if False: + x = 2 +else: + x = 3 + +reveal_type(x) # revealed: Literal[1] +``` + +#### `if ` inside `if True` + +```py +def flag() -> bool: ... + +x = 1 + +if True: + if flag(): + x = 2 +else: + x = 3 + +reveal_type(x) # revealed: Literal[1, 2] +``` + +#### `if True` inside `if ` + +```py +def flag() -> bool: ... + +x = 1 + +if flag(): + if True: + x = 2 +else: + x = 3 + +# TODO: This should be Literal[2, 3] +reveal_type(x) # revealed: Literal[1, 2, 3] +``` + +#### `if True` inside `if False` ... `else` + +```py +x = 1 + +if False: + x = 2 +else: + if True: + x = 3 + +reveal_type(x) # revealed: Literal[3] +``` + +#### `if False` inside `if False` ... `else` + +```py +x = 1 + +if False: + x = 2 +else: + if False: + x = 3 + +reveal_type(x) # revealed: Literal[1] +``` + +#### `if ` inside `if False` ... `else` + +```py +def flag() -> bool: ... + +x = 1 + +if False: + x = 2 +else: + if flag(): + x = 3 + +reveal_type(x) # revealed: Literal[1, 3] +``` + +### Nested conditionals (with inner `else`) + +#### `if True` inside `if True` + +```py +x = 1 + +if True: + if True: + x = 2 + else: + x = 3 +else: + x = 4 + +reveal_type(x) # revealed: Literal[2] +``` + +#### `if False` inside `if True` + +```py +x = 1 + +if True: + if False: + x = 2 + else: + x = 3 +else: + x = 4 + +reveal_type(x) # revealed: Literal[3] +``` + +#### `if ` inside `if True` + +```py +def flag() -> bool: ... + +x = 1 + +if True: + if flag(): + x = 2 + else: + x = 3 +else: + x = 4 + +reveal_type(x) # revealed: Literal[2, 3] +``` + +#### `if True` inside `if ` + +```py +def flag() -> bool: ... + +x = 1 + +if flag(): + if True: + x = 2 + else: + x = 3 +else: + x = 4 + +reveal_type(x) # revealed: Literal[2, 4] +``` + +#### `if True` inside `if False` ... `else` + +```py +x = 1 + +if False: + x = 2 +else: + if True: + x = 3 + else: + x = 4 + +reveal_type(x) # revealed: Literal[3] +``` + +#### `if False` inside `if False` ... `else` + +```py +x = 1 + +if False: + x = 2 +else: + if False: + x = 3 + else: + x = 4 + +reveal_type(x) # revealed: Literal[4] +``` + +#### `if ` inside `if False` ... `else` + +```py +def flag() -> bool: ... + +x = 1 + +if False: + x = 2 +else: + if flag(): + x = 3 + else: + x = 4 + +reveal_type(x) # revealed: Literal[3, 4] +``` + +### Combination with non-conditional control flow + +#### `try` ... `except` + +##### `if True` inside `try` + +```py +def may_raise() -> None: ... + +x = 1 + +try: + may_raise() + if True: + x = 2 + else: + x = 3 +except: + x = 4 + +reveal_type(x) # revealed: Literal[2, 4] +``` + +##### `try` inside `if True` + +```py +def may_raise() -> None: ... + +x = 1 + +if True: + try: + may_raise() + x = 2 + except KeyError: + x = 3 + except ValueError: + x = 4 +else: + x = 5 + +reveal_type(x) # revealed: Literal[2, 3, 4] +``` + +##### `try` with `else` inside `if True` + +```py +def may_raise() -> None: ... + +x = 1 + +if True: + try: + may_raise() + x = 2 + except KeyError: + x = 3 + else: + x = 4 +else: + x = 5 + +reveal_type(x) # revealed: Literal[3, 4] +``` + +##### `try` with `finally` inside `if True` + +```py +def may_raise() -> None: ... + +x = 1 + +if True: + try: + may_raise() + x = 2 + except KeyError: + x = 3 + else: + x = 4 + finally: + x = 5 +else: + x = 6 + +reveal_type(x) # revealed: Literal[5] +``` + +#### `for` loops + +##### `if True` inside `for` + +```py +def iterable() -> list[object]: ... + +x = 1 + +for _ in iterable(): + x = 2 + if True: + x = 3 + +# TODO: This should be Literal[1, 3] +reveal_type(x) # revealed: Literal[1, 2, 3] +``` + +##### `if True` inside `for` ... `else` + +```py +def iterable() -> list[object]: ... + +x = 1 + +for _ in iterable(): + x = 2 +else: + if True: + x = 3 + else: + x = 4 + +reveal_type(x) # revealed: Literal[3] +``` + +##### `for` inside `if True` + +```py +def iterable() -> list[object]: ... + +x = 1 + +if True: + for _ in iterable(): + x = 2 +else: + x = 3 + +reveal_type(x) # revealed: Literal[1, 2] +``` + +##### `for` ... `else` inside `if True` + +```py +def iterable() -> list[object]: ... + +x = 1 + +if True: + for _ in iterable(): + x = 2 + else: + x = 3 +else: + x = 4 + +reveal_type(x) # revealed: Literal[3] +``` + +##### `for` loop with `break` inside `if True` + +```py +def iterable() -> list[object]: ... + +x = 1 + +if True: + x = 2 + for _ in iterable(): + x = 3 + break + else: + x = 4 +else: + x = 5 + +reveal_type(x) # revealed: Literal[3, 4] +``` + +## If expressions + +See also: tests in [expression/if.md](expression/if.md). + +### Always true + +```py +x = 1 if True else 2 + +reveal_type(x) # revealed: Literal[1] +``` + +### Always false + +```py +x = 1 if False else 2 + +reveal_type(x) # revealed: Literal[2] +``` + +## Boolean expressions + +### Always true, `or` + +```py +(x := 1) or (x := 2) + +reveal_type(x) # revealed: Literal[1] +``` + +### Always true, `and` + +```py +(x := 1) and (x := 2) + +reveal_type(x) # revealed: Literal[2] +``` + +### Always false, `or` + +```py +(x := 0) or (x := 2) + +reveal_type(x) # revealed: Literal[2] +``` + +### Always false, `and` + +```py +(x := 0) and (x := 2) + +reveal_type(x) # revealed: Literal[0] +``` + +## While loops + +### Always false + +```py +x = 1 + +while False: + x = 2 + +reveal_type(x) # revealed: Literal[1] +``` + +### Always true + +```py +x = 1 + +while True: + x = 2 + break + +reveal_type(x) # revealed: Literal[2] +``` + +### Ambiguous + +Make sure that we still infer the combined type if the condition is not statically known: + +```py +def flag() -> bool: ... + +x = 1 + +while flag(): + x = 2 + +reveal_type(x) # revealed: Literal[1, 2] +``` + +### `while` ... `else` + +#### `while False` + +```py +while False: + x = 1 +else: + x = 2 + +reveal_type(x) # revealed: Literal[2] +``` + +#### `while True` + +```py +while True: + x = 1 + break +else: + x = 2 + +reveal_type(x) # revealed: Literal[1] +``` + +## `match` statements + +### Single-valued types, always true + +```py +x = 1 + +match "a": + case "a": + x = 2 + case "b": + x = 3 + +reveal_type(x) # revealed: Literal[2] +``` + +### Single-valued types, always false + +```py +x = 1 + +match "something else": + case "a": + x = 2 + case "b": + x = 3 + +reveal_type(x) # revealed: Literal[1] +``` + +### Single-valued types, with wildcard pattern + +This is a case that we can not handle at the moment. Our reasoning about match patterns is too +local. We can infer that the `x = 2` binding is unconditionally visible. But when we traverse all +bindings backwards, we first see the `x = 3` binding which is also visible. At the moment, we do not +mark it as *unconditionally* visible to avoid blocking off previous bindings (we would infer +`Literal[3]` otherwise). + +```py +x = 1 + +match "a": + case "a": + x = 2 + case _: + x = 3 + +# TODO: ideally, this should be Literal[2] +reveal_type(x) # revealed: Literal[2, 3] +``` + +### Non-single-valued types + +```py +def _(s: str): + match s: + case "a": + x = 1 + case _: + x = 2 + + reveal_type(x) # revealed: Literal[1, 2] +``` + +### `sys.version_info` + +```toml +[environment] +python-version = "3.13" +``` + +```py +import sys + +minor = "too old" + +match sys.version_info.minor: + case 12: + minor = 12 + case 13: + minor = 13 + case _: + pass + +reveal_type(minor) # revealed: Literal[13] +``` + +## Conditional declarations + +### Always false + +#### `if False` + +```py +x: str + +if False: + x: int + +def f() -> None: + reveal_type(x) # revealed: str +``` + +#### `if True … else` + +```py +x: str + +if True: + pass +else: + x: int + +def f() -> None: + reveal_type(x) # revealed: str +``` + +### Always true + +#### `if True` + +```py +x: str + +if True: + x: int + +def f() -> None: + reveal_type(x) # revealed: int +``` + +#### `if False … else` + +```py +x: str + +if False: + pass +else: + x: int + +def f() -> None: + reveal_type(x) # revealed: int +``` + +### Ambiguous + +```py +def flag() -> bool: ... + +x: str + +if flag(): + x: int + +def f() -> None: + reveal_type(x) # revealed: str | int +``` + +## Conditional function definitions + +```py +def f() -> int: ... +def g() -> int: ... + +if True: + def f() -> str: ... + +else: + def g() -> str: ... + +reveal_type(f()) # revealed: str +reveal_type(g()) # revealed: int +``` + +## Conditional class definitions + +```py +if True: + class C: + x: int = 1 + +else: + class C: + x: str = "a" + +reveal_type(C.x) # revealed: int +``` + +## Conditional class attributes + +```py +class C: + if True: + x: int = 1 + else: + x: str = "a" + +reveal_type(C.x) # revealed: int +``` + +## (Un)boundness + +### Unbound, `if False` + +```py +if False: + x = 1 + +# error: [unresolved-reference] +x +``` + +### Unbound, `if True … else` + +```py +if True: + pass +else: + x = 1 + +# error: [unresolved-reference] +x +``` + +### Bound, `if True` + +```py +if True: + x = 1 + +# x is always bound, no error +x +``` + +### Bound, `if False … else` + +```py +if False: + pass +else: + x = 1 + +# x is always bound, no error +x +``` + +### Ambiguous, possibly unbound + +For comparison, we still detect definitions inside non-statically known branches as possibly +unbound: + +```py +def flag() -> bool: ... + +if flag(): + x = 1 + +# error: [possibly-unresolved-reference] +x +``` + +### Nested conditionals + +```py +def flag() -> bool: ... + +if False: + if True: + unbound1 = 1 + +if True: + if False: + unbound2 = 1 + +if False: + if False: + unbound3 = 1 + +if False: + if flag(): + unbound4 = 1 + +if flag(): + if False: + unbound5 = 1 + +# error: [unresolved-reference] +# error: [unresolved-reference] +# error: [unresolved-reference] +# error: [unresolved-reference] +# error: [unresolved-reference] +(unbound1, unbound2, unbound3, unbound4, unbound5) +``` + +### Chained conditionals + +```py +if False: + x = 1 +if True: + x = 2 + +# x is always bound, no error +x + +if False: + y = 1 +if True: + y = 2 + +# y is always bound, no error +y + +if False: + z = 1 +if False: + z = 2 + +# z is never bound: +# error: [unresolved-reference] +z +``` + +### Public boundness + +```py +if True: + x = 1 + +def f(): + # x is always bound, no error + x +``` + +### Imports of conditionally defined symbols + +#### Always false, unbound + +```py path=module.py +if False: + symbol = 1 +``` + +```py +# error: [unresolved-import] +from module import symbol +``` + +#### Always true, bound + +```py path=module.py +if True: + symbol = 1 +``` + +```py +# no error +from module import symbol +``` + +#### Ambiguous, possibly unbound + +```py path=module.py +def flag() -> bool: ... + +if flag(): + symbol = 1 +``` + +```py +# error: [possibly-unbound-import] +from module import symbol +``` + +#### Always false, undeclared + +```py path=module.py +if False: + symbol: int +``` + +```py +# error: [unresolved-import] +from module import symbol + +reveal_type(symbol) # revealed: Unknown +``` + +#### Always true, declared + +```py path=module.py +if True: + symbol: int +``` + +```py +# no error +from module import symbol +``` diff --git a/crates/red_knot_python_semantic/resources/mdtest/subscript/tuple.md b/crates/red_knot_python_semantic/resources/mdtest/subscript/tuple.md index 1d27885567df2..88dd39144a6ae 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/subscript/tuple.md +++ b/crates/red_knot_python_semantic/resources/mdtest/subscript/tuple.md @@ -81,10 +81,7 @@ python-version = "3.9" ``` ```py -# TODO: -# * `tuple.__class_getitem__` is always bound on 3.9 (`sys.version_info`) -# * `tuple[int, str]` is a valid base (generics) -# error: [call-possibly-unbound-method] "Method `__class_getitem__` of type `Literal[tuple]` is possibly unbound" +# TODO: `tuple[int, str]` is a valid base (generics) # error: [invalid-base] "Invalid class base with type `GenericAlias` (all bases must be a class, `Any`, `Unknown` or `Todo`)" class A(tuple[int, str]): ... diff --git a/crates/red_knot_python_semantic/src/semantic_index.rs b/crates/red_knot_python_semantic/src/semantic_index.rs index 2771ad301e380..637b1b32a14b2 100644 --- a/crates/red_knot_python_semantic/src/semantic_index.rs +++ b/crates/red_knot_python_semantic/src/semantic_index.rs @@ -20,6 +20,7 @@ use crate::semantic_index::use_def::UseDefMap; use crate::Db; pub mod ast_ids; +pub(crate) mod branching_condition; mod builder; pub(crate) mod constraint; pub mod definition; @@ -28,7 +29,8 @@ pub mod symbol; mod use_def; pub(crate) use self::use_def::{ - BindingWithConstraints, BindingWithConstraintsIterator, DeclarationsIterator, + BindingWithConstraints, BindingWithConstraintsIterator, BranchingConditionsIterator, + DeclarationsIterator, }; type SymbolMap = hashbrown::HashMap; diff --git a/crates/red_knot_python_semantic/src/semantic_index/branching_condition.rs b/crates/red_knot_python_semantic/src/semantic_index/branching_condition.rs new file mode 100644 index 0000000000000..d43b3edf31781 --- /dev/null +++ b/crates/red_knot_python_semantic/src/semantic_index/branching_condition.rs @@ -0,0 +1,29 @@ +use super::constraint::Constraint; + +/// Used to represent active branching conditions that apply to a particular definition. +/// A definition can either be conditional on a specific constraint from a `if`, `elif`, +/// `while` statement, an `if`-expression, or a Boolean expression. Or it can be marked +/// as 'ambiguous' if it occurred in a control-flow path that is not conditional on any +/// specific expression that can be statically analyzed (`for` loop, `try` ... `except`). +/// +/// +/// For example: +/// ```py +/// a = 1 # no active branching conditions +/// +/// if test1: +/// b = 1 # ConditionalOn(test1) +/// +/// if test2: +/// c = 1 # ConditionalOn(test1), ConditionalOn(test2) +/// +/// for _ in range(10): +/// d = 1 # ConditionalOn(test1), Ambiguous +/// else: +/// d = 1 # ConditionalOn(~test1) +/// ``` +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub(crate) enum BranchingCondition<'db> { + ConditionalOn(Constraint<'db>), + Ambiguous, +} diff --git a/crates/red_knot_python_semantic/src/semantic_index/builder.rs b/crates/red_knot_python_semantic/src/semantic_index/builder.rs index 27e657aba959b..4617b4027a1b3 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/builder.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/builder.rs @@ -6,14 +6,15 @@ use rustc_hash::FxHashMap; use ruff_db::files::File; use ruff_db::parsed::ParsedModule; use ruff_index::IndexVec; -use ruff_python_ast as ast; use ruff_python_ast::name::Name; use ruff_python_ast::visitor::{walk_expr, walk_pattern, walk_stmt, Visitor}; +use ruff_python_ast::{self as ast, Pattern}; use ruff_python_ast::{BoolOp, Expr}; use crate::ast_node_ref::AstNodeRef; use crate::semantic_index::ast_ids::node_key::ExpressionNodeKey; use crate::semantic_index::ast_ids::AstIdsBuilder; +use crate::semantic_index::constraint::PatternConstraintKind; use crate::semantic_index::definition::{ AssignmentDefinitionNodeRef, ComprehensionDefinitionNodeRef, Definition, DefinitionNodeKey, DefinitionNodeRef, ForStmtDefinitionNodeRef, ImportFromDefinitionNodeRef, @@ -23,7 +24,7 @@ use crate::semantic_index::symbol::{ FileScopeId, NodeWithScopeKey, NodeWithScopeRef, Scope, ScopeId, ScopedSymbolId, SymbolTableBuilder, }; -use crate::semantic_index::use_def::{FlowSnapshot, UseDefMapBuilder}; +use crate::semantic_index::use_def::{BranchingConditionsSnapshot, FlowSnapshot, UseDefMapBuilder}; use crate::semantic_index::SemanticIndex; use crate::unpack::Unpack; use crate::Db; @@ -200,12 +201,28 @@ impl<'db> SemanticIndexBuilder<'db> { self.current_use_def_map().snapshot() } - fn flow_restore(&mut self, state: FlowSnapshot) { + fn branching_conditions_snapshot(&self) -> BranchingConditionsSnapshot { + self.current_use_def_map().branching_conditions_snapshot() + } + + fn flow_restore( + &mut self, + state: FlowSnapshot, + branching_conditions: BranchingConditionsSnapshot, + ) { self.current_use_def_map_mut().restore(state); + self.current_use_def_map_mut() + .restore_branching_conditions(branching_conditions); } - fn flow_merge(&mut self, state: FlowSnapshot) { + fn flow_merge( + &mut self, + state: FlowSnapshot, + branching_conditions: BranchingConditionsSnapshot, + ) { self.current_use_def_map_mut().merge(state); + self.current_use_def_map_mut() + .restore_branching_conditions(branching_conditions); } fn add_symbol(&mut self, name: Name) -> ScopedSymbolId { @@ -285,6 +302,10 @@ impl<'db> SemanticIndexBuilder<'db> { self.current_use_def_map_mut().record_constraint(constraint); } + fn record_ambiguous_branching(&mut self) { + self.current_use_def_map_mut().record_ambiguous_branching(); + } + fn build_constraint(&mut self, constraint_node: &Expr) -> Constraint<'db> { let expression = self.add_standalone_expression(constraint_node); Constraint { @@ -320,22 +341,24 @@ impl<'db> SemanticIndexBuilder<'db> { fn add_pattern_constraint( &mut self, - subject: &ast::Expr, + subject: Expression<'db>, pattern: &ast::Pattern, ) -> PatternConstraint<'db> { - #[allow(unsafe_code)] - let (subject, pattern) = unsafe { - ( - AstNodeRef::new(self.module.clone(), subject), - AstNodeRef::new(self.module.clone(), pattern), - ) + let kind = match pattern { + Pattern::MatchValue(pattern) => { + let value = self.add_standalone_expression(&pattern.value); + PatternConstraintKind::Value(value) + } + Pattern::MatchSingleton(singleton) => PatternConstraintKind::Singleton(singleton.value), + _ => PatternConstraintKind::Unsupported, }; + let pattern_constraint = PatternConstraint::new( self.db, self.file, self.current_scope(), subject, - pattern, + kind, countme::Count::default(), ); self.current_use_def_map_mut() @@ -785,6 +808,7 @@ where ast::Stmt::If(node) => { self.visit_expr(&node.test); let pre_if = self.flow_snapshot(); + let pre_if_conditions = self.branching_conditions_snapshot(); let constraint = self.record_expression_constraint(&node.test); let mut constraints = vec![constraint]; self.visit_body(&node.body); @@ -810,7 +834,7 @@ where post_clauses.push(self.flow_snapshot()); // we can only take an elif/else branch if none of the previous ones were // taken, so the block entry state is always `pre_if` - self.flow_restore(pre_if.clone()); + self.flow_restore(pre_if.clone(), pre_if_conditions.clone()); for constraint in &constraints { self.record_negated_constraint(*constraint); } @@ -821,7 +845,7 @@ where self.visit_body(clause_body); } for post_clause_state in post_clauses { - self.flow_merge(post_clause_state); + self.flow_merge(post_clause_state, pre_if_conditions.clone()); } } ast::Stmt::While(ast::StmtWhile { @@ -833,6 +857,7 @@ where self.visit_expr(test); let pre_loop = self.flow_snapshot(); + let pre_loop_conditions = self.branching_conditions_snapshot(); let constraint = self.record_expression_constraint(test); // Save aside any break states from an outer loop @@ -852,14 +877,14 @@ where // We may execute the `else` clause without ever executing the body, so merge in // the pre-loop state before visiting `else`. - self.flow_merge(pre_loop); + self.flow_merge(pre_loop, pre_loop_conditions.clone()); self.record_negated_constraint(constraint); self.visit_body(orelse); // Breaking out of a while loop bypasses the `else` clause, so merge in the break // states after visiting `else`. for break_state in break_states { - self.flow_merge(break_state); + self.flow_merge(break_state, pre_loop_conditions.clone()); } } ast::Stmt::With(ast::StmtWith { @@ -902,8 +927,11 @@ where self.visit_expr(iter); let pre_loop = self.flow_snapshot(); + let pre_loop_conditions = self.branching_conditions_snapshot(); let saved_break_states = std::mem::take(&mut self.loop_break_states); + self.record_ambiguous_branching(); + debug_assert_eq!(&self.current_assignments, &[]); self.push_assignment(for_stmt.into()); self.visit_expr(target); @@ -922,13 +950,14 @@ where // We may execute the `else` clause without ever executing the body, so merge in // the pre-loop state before visiting `else`. - self.flow_merge(pre_loop); + self.flow_merge(pre_loop, pre_loop_conditions.clone()); + self.record_ambiguous_branching(); self.visit_body(orelse); // Breaking out of a `for` loop bypasses the `else` clause, so merge in the break // states after visiting `else`. for break_state in break_states { - self.flow_merge(break_state); + self.flow_merge(break_state, pre_loop_conditions.clone()); } } ast::Stmt::Match(ast::StmtMatch { @@ -936,31 +965,32 @@ where cases, range: _, }) => { - self.add_standalone_expression(subject); + let subject_expr = self.add_standalone_expression(subject); self.visit_expr(subject); let after_subject = self.flow_snapshot(); + let after_subject_cs = self.branching_conditions_snapshot(); let Some((first, remaining)) = cases.split_first() else { return; }; - self.add_pattern_constraint(subject, &first.pattern); + self.add_pattern_constraint(subject_expr, &first.pattern); self.visit_match_case(first); let mut post_case_snapshots = vec![]; for case in remaining { post_case_snapshots.push(self.flow_snapshot()); - self.flow_restore(after_subject.clone()); - self.add_pattern_constraint(subject, &case.pattern); + self.flow_restore(after_subject.clone(), after_subject_cs.clone()); + self.add_pattern_constraint(subject_expr, &case.pattern); self.visit_match_case(case); } for post_clause_state in post_case_snapshots { - self.flow_merge(post_clause_state); + self.flow_merge(post_clause_state, after_subject_cs.clone()); } if !cases .last() .is_some_and(|case| case.guard.is_none() && case.pattern.is_wildcard()) { - self.flow_merge(after_subject); + self.flow_merge(after_subject, after_subject_cs); } } ast::Stmt::Try(ast::StmtTry { @@ -978,6 +1008,9 @@ where // We will merge this state with all of the intermediate // states during the `try` block before visiting those suites. let pre_try_block_state = self.flow_snapshot(); + let pre_try_block_conditions = self.branching_conditions_snapshot(); + + self.record_ambiguous_branching(); self.try_node_context_stack_manager.push_context(); @@ -1000,9 +1033,9 @@ where let post_try_block_state = self.flow_snapshot(); // Prepare for visiting the `except` block(s) - self.flow_restore(pre_try_block_state); + self.flow_restore(pre_try_block_state, pre_try_block_conditions.clone()); for state in try_block_snapshots { - self.flow_merge(state); + self.flow_merge(state, pre_try_block_conditions.clone()); } let pre_except_state = self.flow_snapshot(); @@ -1017,6 +1050,8 @@ where range: _, } = except_handler; + self.record_ambiguous_branching(); + if let Some(handled_exceptions) = handled_exceptions { self.visit_expr(handled_exceptions); } @@ -1044,19 +1079,24 @@ where // as we'll immediately call `self.flow_restore()` to a different state // as soon as this loop over the handlers terminates. if i < (num_handlers - 1) { - self.flow_restore(pre_except_state.clone()); + self.flow_restore( + pre_except_state.clone(), + pre_try_block_conditions.clone(), + ); } } // If we get to the `else` block, we know that 0 of the `except` blocks can have been executed, // and the entire `try` block must have been executed: - self.flow_restore(post_try_block_state); + self.flow_restore(post_try_block_state, pre_try_block_conditions.clone()); } + self.record_ambiguous_branching(); + self.visit_body(orelse); for post_except_state in post_except_states { - self.flow_merge(post_except_state); + self.flow_merge(post_except_state, pre_try_block_conditions.clone()); } // TODO: there's lots of complexity here that isn't yet handled by our model. @@ -1069,7 +1109,12 @@ where // For more details, see: // - https://astral-sh.notion.site/Exception-handler-control-flow-11348797e1ca80bb8ce1e9aedbbe439d // - https://github.com/astral-sh/ruff/pull/13633#discussion_r1788626702 + self.record_ambiguous_branching(); + self.visit_body(finalbody); + + self.current_use_def_map_mut() + .restore_branching_conditions(pre_try_block_conditions); } _ => { walk_stmt(self, stmt); @@ -1211,19 +1256,17 @@ where ast::Expr::If(ast::ExprIf { body, test, orelse, .. }) => { - // TODO detect statically known truthy or falsy test (via type inference, not naive - // AST inspection, so we can't simplify here, need to record test expression for - // later checking) self.visit_expr(test); let pre_if = self.flow_snapshot(); + let pre_if_conditions = self.branching_conditions_snapshot(); let constraint = self.record_expression_constraint(test); self.visit_expr(body); let post_body = self.flow_snapshot(); - self.flow_restore(pre_if); + self.flow_restore(pre_if, pre_if_conditions.clone()); self.record_negated_constraint(constraint); self.visit_expr(orelse); - self.flow_merge(post_body); + self.flow_merge(post_body, pre_if_conditions); } ast::Expr::ListComp( list_comprehension @ ast::ExprListComp { @@ -1280,11 +1323,8 @@ where range: _, op, }) => { - // TODO detect statically known truthy or falsy values (via type inference, not naive - // AST inspection, so we can't simplify here, need to record test expression for - // later checking) let mut snapshots = vec![]; - + let pre_op_conditions = self.branching_conditions_snapshot(); for (index, value) in values.iter().enumerate() { self.visit_expr(value); // In the last value we don't need to take a snapshot nor add a constraint @@ -1299,7 +1339,7 @@ where } } for snapshot in snapshots { - self.flow_merge(snapshot); + self.flow_merge(snapshot, pre_op_conditions.clone()); } } _ => { diff --git a/crates/red_knot_python_semantic/src/semantic_index/constraint.rs b/crates/red_knot_python_semantic/src/semantic_index/constraint.rs index 44b542f0e90ac..347f0ebaac4f7 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/constraint.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/constraint.rs @@ -1,7 +1,6 @@ use ruff_db::files::File; -use ruff_python_ast as ast; +use ruff_python_ast::Singleton; -use crate::ast_node_ref::AstNodeRef; use crate::db::Db; use crate::semantic_index::expression::Expression; use crate::semantic_index::symbol::{FileScopeId, ScopeId}; @@ -18,6 +17,14 @@ pub(crate) enum ConstraintNode<'db> { Pattern(PatternConstraint<'db>), } +/// Pattern kinds for which we do support type narrowing and/or static truthiness analysis. +#[derive(Debug, Clone, PartialEq)] +pub(crate) enum PatternConstraintKind<'db> { + Singleton(Singleton), + Value(Expression<'db>), + Unsupported, +} + #[salsa::tracked] pub(crate) struct PatternConstraint<'db> { #[id] @@ -28,11 +35,11 @@ pub(crate) struct PatternConstraint<'db> { #[no_eq] #[return_ref] - pub(crate) subject: AstNodeRef, + pub(crate) subject: Expression<'db>, #[no_eq] #[return_ref] - pub(crate) pattern: AstNodeRef, + pub(crate) kind: PatternConstraintKind<'db>, #[no_eq] count: countme::Count>, diff --git a/crates/red_knot_python_semantic/src/semantic_index/use_def.rs b/crates/red_knot_python_semantic/src/semantic_index/use_def.rs index 9f3e197c74eee..a85876be5f7a5 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/use_def.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/use_def.rs @@ -226,9 +226,14 @@ use self::symbol_state::{ ScopedConstraintId, ScopedDefinitionId, SymbolBindings, SymbolDeclarations, SymbolState, }; use crate::semantic_index::ast_ids::ScopedUseId; +use crate::semantic_index::branching_condition::BranchingCondition; use crate::semantic_index::definition::Definition; use crate::semantic_index::symbol::ScopedSymbolId; +use crate::semantic_index::use_def::symbol_state::{ + BranchingConditionIdIterator, BranchingConditions, ScopedBranchingConditionId, +}; use crate::symbol::Boundness; +use crate::types::StaticTruthiness; use ruff_index::IndexVec; use rustc_hash::FxHashMap; @@ -246,6 +251,9 @@ pub(crate) struct UseDefMap<'db> { /// Array of [`Constraint`] in this scope. all_constraints: IndexVec>, + /// Array of [`BranchingCondition`] in this scope. + all_branching_conditions: IndexVec>, + /// [`SymbolBindings`] reaching a [`ScopedUseId`]. bindings_by_use: IndexVec, @@ -275,12 +283,16 @@ impl<'db> UseDefMap<'db> { self.bindings_iterator(&self.bindings_by_use[use_id]) } - pub(crate) fn use_boundness(&self, use_id: ScopedUseId) -> Boundness { - if self.bindings_by_use[use_id].may_be_unbound() { - Boundness::PossiblyUnbound - } else { - Boundness::Bound - } + pub(crate) fn use_boundness( + &self, + db: &dyn crate::db::Db, + use_id: ScopedUseId, + ) -> Option { + let bindings = &self.bindings_by_use[use_id]; + let conditions_per_binding = self + .bindings_iterator(bindings) + .map(|binding| binding.branching_conditions); + analyze_boundness(db, conditions_per_binding, bindings.may_be_unbound()) } pub(crate) fn public_bindings( @@ -290,12 +302,16 @@ impl<'db> UseDefMap<'db> { self.bindings_iterator(self.public_symbols[symbol].bindings()) } - pub(crate) fn public_boundness(&self, symbol: ScopedSymbolId) -> Boundness { - if self.public_symbols[symbol].may_be_unbound() { - Boundness::PossiblyUnbound - } else { - Boundness::Bound - } + pub(crate) fn public_boundness( + &self, + db: &dyn crate::db::Db, + symbol: ScopedSymbolId, + ) -> Option { + let bindings = self.public_symbols[symbol].bindings(); + let conditions = self + .bindings_iterator(bindings) + .map(|binding| binding.branching_conditions); + analyze_boundness(db, conditions, bindings.may_be_unbound()) } pub(crate) fn bindings_at_declaration( @@ -331,10 +347,6 @@ impl<'db> UseDefMap<'db> { self.declarations_iterator(declarations) } - pub(crate) fn has_public_declarations(&self, symbol: ScopedSymbolId) -> bool { - !self.public_symbols[symbol].declarations().is_empty() - } - fn bindings_iterator<'a>( &'a self, bindings: &'a SymbolBindings, @@ -342,7 +354,8 @@ impl<'db> UseDefMap<'db> { BindingWithConstraintsIterator { all_definitions: &self.all_definitions, all_constraints: &self.all_constraints, - inner: bindings.iter(), + all_branching_conditions: &self.all_branching_conditions, + inner: bindings.iter_rev(), } } @@ -352,7 +365,8 @@ impl<'db> UseDefMap<'db> { ) -> DeclarationsIterator<'a, 'db> { DeclarationsIterator { all_definitions: &self.all_definitions, - inner: declarations.iter(), + all_branching_conditions: &self.all_branching_conditions, + inner: declarations.iter_rev(), may_be_undeclared: declarations.may_be_undeclared(), } } @@ -369,6 +383,7 @@ enum SymbolDefinitions { pub(crate) struct BindingWithConstraintsIterator<'map, 'db> { all_definitions: &'map IndexVec>, all_constraints: &'map IndexVec>, + all_branching_conditions: &'map IndexVec>, inner: BindingIdWithConstraintsIterator<'map>, } @@ -376,15 +391,17 @@ impl<'map, 'db> Iterator for BindingWithConstraintsIterator<'map, 'db> { type Item = BindingWithConstraints<'map, 'db>; fn next(&mut self) -> Option { - self.inner - .next() - .map(|def_id_with_constraints| BindingWithConstraints { - binding: self.all_definitions[def_id_with_constraints.definition], - constraints: ConstraintsIterator { - all_constraints: self.all_constraints, - constraint_ids: def_id_with_constraints.constraint_ids, - }, - }) + self.inner.next().map(|binding| BindingWithConstraints { + binding: self.all_definitions[binding.definition], + constraints: ConstraintsIterator { + all_constraints: self.all_constraints, + constraint_ids: binding.constraint_ids, + }, + branching_conditions: BranchingConditionsIterator { + all_branching_conditions: self.all_branching_conditions, + branching_condition_ids: binding.branching_conditions_ids, + }, + }) } } @@ -393,6 +410,7 @@ impl std::iter::FusedIterator for BindingWithConstraintsIterator<'_, '_> {} pub(crate) struct BindingWithConstraints<'map, 'db> { pub(crate) binding: Definition<'db>, pub(crate) constraints: ConstraintsIterator<'map, 'db>, + pub(crate) branching_conditions: BranchingConditionsIterator<'map, 'db>, } pub(crate) struct ConstraintsIterator<'map, 'db> { @@ -412,23 +430,60 @@ impl<'db> Iterator for ConstraintsIterator<'_, 'db> { impl std::iter::FusedIterator for ConstraintsIterator<'_, '_> {} +pub(crate) struct BranchingConditionsIterator<'map, 'db> { + all_branching_conditions: &'map IndexVec>, + branching_condition_ids: BranchingConditionIdIterator<'map>, +} + +impl<'db> Iterator for BranchingConditionsIterator<'_, 'db> { + type Item = BranchingCondition<'db>; + + fn next(&mut self) -> Option { + self.branching_condition_ids + .next() + .map(|branching_condition_id| self.all_branching_conditions[branching_condition_id]) + } +} + +impl std::iter::FusedIterator for BranchingConditionsIterator<'_, '_> {} + +#[derive(Clone)] pub(crate) struct DeclarationsIterator<'map, 'db> { all_definitions: &'map IndexVec>, + all_branching_conditions: &'map IndexVec>, inner: DeclarationIdIterator<'map>, may_be_undeclared: bool, } impl DeclarationsIterator<'_, '_> { - pub(crate) fn may_be_undeclared(&self) -> bool { - self.may_be_undeclared + pub(crate) fn declaredness(self, db: &dyn crate::db::Db) -> Option { + let may_be_undeclared = self.may_be_undeclared; + let conditions_per_binding = self.map(|(_, conditions)| conditions); + analyze_boundness(db, conditions_per_binding, may_be_undeclared) + } + + pub(crate) fn may_be_undeclared(self, db: &dyn crate::db::Db) -> bool { + match self.declaredness(db) { + Some(Boundness::Bound) => false, + Some(Boundness::PossiblyUnbound) => true, + None => true, + } } } -impl<'db> Iterator for DeclarationsIterator<'_, 'db> { - type Item = Definition<'db>; +impl<'map, 'db> Iterator for DeclarationsIterator<'map, 'db> { + type Item = (Definition<'db>, BranchingConditionsIterator<'map, 'db>); fn next(&mut self) -> Option { - self.inner.next().map(|def_id| self.all_definitions[def_id]) + self.inner.next().map(|(def_id, branching_condition_ids)| { + ( + self.all_definitions[def_id], + BranchingConditionsIterator { + all_branching_conditions: self.all_branching_conditions, + branching_condition_ids, + }, + ) + }) } } @@ -440,6 +495,10 @@ pub(super) struct FlowSnapshot { symbol_states: IndexVec, } +/// A snapshot of the active branching conditions at a particular point in control flow. +#[derive(Clone, Debug)] +pub(super) struct BranchingConditionsSnapshot(BranchingConditions); + #[derive(Debug, Default)] pub(super) struct UseDefMapBuilder<'db> { /// Append-only array of [`Definition`]. @@ -448,6 +507,12 @@ pub(super) struct UseDefMapBuilder<'db> { /// Append-only array of [`Constraint`]. all_constraints: IndexVec>, + /// Append-only array of [`BranchingCondition`]. + all_branching_conditions: IndexVec>, + + /// Active branching conditions. + active_branching_conditions: BranchingConditions, + /// Live bindings at each so-far-recorded use. bindings_by_use: IndexVec, @@ -471,7 +536,7 @@ impl<'db> UseDefMapBuilder<'db> { binding, SymbolDefinitions::Declarations(symbol_state.declarations().clone()), ); - symbol_state.record_binding(def_id); + symbol_state.record_binding(def_id, &self.active_branching_conditions); } pub(super) fn record_constraint(&mut self, constraint: Constraint<'db>) { @@ -479,6 +544,20 @@ impl<'db> UseDefMapBuilder<'db> { for state in &mut self.symbol_states { state.record_constraint(constraint_id); } + + self.record_branching_condition(BranchingCondition::ConditionalOn(constraint)); + } + + /// Marks a point in control-flow where we branch on a condition that we can not (or choose + /// not to) analyze statically. Examples are `try` blocks or `for` loops. + pub(super) fn record_ambiguous_branching(&mut self) { + self.record_branching_condition(BranchingCondition::Ambiguous); + } + + pub(super) fn record_branching_condition(&mut self, condition: BranchingCondition<'db>) { + let condition_id = self.all_branching_conditions.push(condition); + self.active_branching_conditions + .insert(condition_id.as_u32()); } pub(super) fn record_declaration( @@ -492,7 +571,7 @@ impl<'db> UseDefMapBuilder<'db> { declaration, SymbolDefinitions::Bindings(symbol_state.bindings().clone()), ); - symbol_state.record_declaration(def_id); + symbol_state.record_declaration(def_id, &self.active_branching_conditions); } pub(super) fn record_declaration_and_binding( @@ -503,8 +582,8 @@ impl<'db> UseDefMapBuilder<'db> { // We don't need to store anything in self.definitions_by_definition. let def_id = self.all_definitions.push(definition); let symbol_state = &mut self.symbol_states[symbol]; - symbol_state.record_declaration(def_id); - symbol_state.record_binding(def_id); + symbol_state.record_declaration(def_id, &self.active_branching_conditions); + symbol_state.record_binding(def_id, &self.active_branching_conditions); } pub(super) fn record_use(&mut self, symbol: ScopedSymbolId, use_id: ScopedUseId) { @@ -523,6 +602,10 @@ impl<'db> UseDefMapBuilder<'db> { } } + pub(super) fn branching_conditions_snapshot(&self) -> BranchingConditionsSnapshot { + BranchingConditionsSnapshot(self.active_branching_conditions.clone()) + } + /// Restore the current builder symbols state to the given snapshot. pub(super) fn restore(&mut self, snapshot: FlowSnapshot) { // We never remove symbols from `symbol_states` (it's an IndexVec, and the symbol @@ -541,6 +624,10 @@ impl<'db> UseDefMapBuilder<'db> { .resize(num_symbols, SymbolState::undefined()); } + pub(super) fn restore_branching_conditions(&mut self, snapshot: BranchingConditionsSnapshot) { + self.active_branching_conditions = snapshot.0; + } + /// Merge the given snapshot into the current state, reflecting that we might have taken either /// path to get here. The new state for each symbol should include definitions from both the /// prior state and the snapshot. @@ -572,9 +659,74 @@ impl<'db> UseDefMapBuilder<'db> { UseDefMap { all_definitions: self.all_definitions, all_constraints: self.all_constraints, + all_branching_conditions: self.all_branching_conditions, bindings_by_use: self.bindings_by_use, public_symbols: self.symbol_states, definitions_by_definition: self.definitions_by_definition, } } } + +/// Analyze the boundness (or declaredness) of a symbol based on all the branching conditions +/// that were active for each of its bindings (or declarations). +/// +/// Returns `None` if the symbol is definitely unbound. +/// +/// Consider this example: +/// ```py +/// if test: +/// x = 1 +/// ``` +/// +/// Depending on the static truthiness of `test`, `x` could either be definitely bound (if `test` +/// is always true), definitely unbound (if `test` is always false), or possibly unbound (if the +/// truthiness of `test` is ambiguous). +/// +/// If there are multiple bindings, the results need to be merged: +/// ```py +/// if test1: +/// x = 1 +/// if test2: +/// x = 2 +/// ``` +/// +/// Here, `x` is definitely bound if `test1` is always true OR if `test2` is always true. `x` is +/// definitely unbound if `test1` is always false AND `test2` is always false. `x` is possibly +/// unbound in all other cases. This logic is handled in [`StaticTruthiness::flow_merge`]. +/// +/// Finally, we also need to consider that a symbol could be definitely bound, even if we can not +/// statically infer the truthiness of a test condition. On such example is: +/// ```py +/// if test: +/// x = 1 +/// else: +/// x = 2 +/// ``` +/// Here, `x` is definitely bound, no matter the value of `test`. The `may_be_unbound` flag from +/// semantic index building is used to determine this (with a value of `false` for this case). +fn analyze_boundness<'db, 'map, C>( + db: &dyn crate::db::Db, + conditions_per_binding: C, + may_be_unbound: bool, +) -> Option +where + 'db: 'map, + C: Iterator>, +{ + let result = conditions_per_binding.fold(StaticTruthiness::no_bindings(), |r, conditions| { + r.flow_merge(&StaticTruthiness::analyze(db, conditions)) + }); + + let definitely_unbound = result.any_always_false; + let definitely_bound = result.all_always_true || !may_be_unbound; + + if definitely_unbound { + None + } else { + if definitely_bound { + Some(Boundness::Bound) + } else { + Some(Boundness::PossiblyUnbound) + } + } +} diff --git a/crates/red_knot_python_semantic/src/semantic_index/use_def/bitset.rs b/crates/red_knot_python_semantic/src/semantic_index/use_def/bitset.rs index 464f718e7b4f4..69c052e4ae164 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/use_def/bitset.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/use_def/bitset.rs @@ -32,10 +32,6 @@ impl BitSet { bitset } - pub(super) fn is_empty(&self) -> bool { - self.blocks().iter().all(|&b| b == 0) - } - /// Convert from Inline to Heap, if needed, and resize the Heap vector, if needed. fn resize(&mut self, value: u32) { let num_blocks_needed = (value / 64) + 1; @@ -97,19 +93,6 @@ impl BitSet { } } - /// Union in-place with another [`BitSet`]. - pub(super) fn union(&mut self, other: &BitSet) { - let mut max_len = self.blocks().len(); - let other_len = other.blocks().len(); - if other_len > max_len { - max_len = other_len; - self.resize_blocks(max_len); - } - for (my_block, other_block) in self.blocks_mut().iter_mut().zip(other.blocks()) { - *my_block |= other_block; - } - } - /// Return an iterator over the values (in ascending order) in this [`BitSet`]. pub(super) fn iter(&self) -> BitSetIterator<'_, B> { let blocks = self.blocks(); @@ -119,6 +102,19 @@ impl BitSet { current_block: blocks[0], } } + + pub(super) fn iter_rev(&self) -> ReverseBitSetIterator<'_, B> { + let num_blocks = self.blocks().len(); + + assert!(num_blocks > 0); + + let blocks = self.blocks(); + ReverseBitSetIterator { + blocks, + current_block_index: num_blocks - 1, + current_block: blocks[num_blocks - 1], + } + } } /// Iterator over values in a [`BitSet`]. @@ -158,10 +154,46 @@ impl Iterator for BitSetIterator<'_, B> { impl std::iter::FusedIterator for BitSetIterator<'_, B> {} +/// Iterates over values in a [`BitSet`], in reverse order (highest bit first). +#[derive(Debug, Clone)] +pub(super) struct ReverseBitSetIterator<'a, const B: usize> { + /// The blocks we are iterating over. + blocks: &'a [u64], + + /// The index of the block we are currently iterating through. + current_block_index: usize, + + /// The block we are currently iterating through (and zeroing as we go.) + current_block: u64, +} + +impl Iterator for ReverseBitSetIterator<'_, B> { + type Item = u32; + + fn next(&mut self) -> Option { + while self.current_block == 0 { + if self.current_block_index == 0 { + return None; + } + self.current_block_index -= 1; + self.current_block = self.blocks[self.current_block_index]; + } + // SAFETY: current_block is non-zero, so leading_zeros must be + // strictly less than 64. + let highest_bit_set = 63 - self.current_block.leading_zeros(); + // reset the highest bit + self.current_block &= !(1u64 << highest_bit_set); + // SAFETY: see above + #[allow(clippy::cast_possible_truncation)] + Some(highest_bit_set + (64 * self.current_block_index) as u32) + } +} + #[cfg(test)] mod tests { use super::BitSet; + #[track_caller] fn assert_bitset(bitset: &BitSet, contents: &[u32]) { assert_eq!(bitset.iter().collect::>(), contents); } @@ -239,59 +271,6 @@ mod tests { assert_bitset(&b1, &[89]); } - #[test] - fn union() { - let mut b1 = BitSet::<1>::with(2); - let b2 = BitSet::<1>::with(4); - - b1.union(&b2); - assert_bitset(&b1, &[2, 4]); - } - - #[test] - fn union_mixed_1() { - let mut b1 = BitSet::<1>::with(4); - let mut b2 = BitSet::<1>::with(4); - b1.insert(89); - b2.insert(5); - - b1.union(&b2); - assert_bitset(&b1, &[4, 5, 89]); - } - - #[test] - fn union_mixed_2() { - let mut b1 = BitSet::<1>::with(4); - let mut b2 = BitSet::<1>::with(4); - b1.insert(23); - b2.insert(89); - - b1.union(&b2); - assert_bitset(&b1, &[4, 23, 89]); - } - - #[test] - fn union_heap() { - let mut b1 = BitSet::<1>::with(4); - let mut b2 = BitSet::<1>::with(4); - b1.insert(89); - b2.insert(90); - - b1.union(&b2); - assert_bitset(&b1, &[4, 89, 90]); - } - - #[test] - fn union_heap_2() { - let mut b1 = BitSet::<1>::with(89); - let mut b2 = BitSet::<1>::with(89); - b1.insert(91); - b2.insert(90); - - b1.union(&b2); - assert_bitset(&b1, &[89, 90, 91]); - } - #[test] fn multiple_blocks() { let mut b = BitSet::<2>::with(120); @@ -301,9 +280,29 @@ mod tests { } #[test] - fn empty() { - let b = BitSet::<1>::default(); - - assert!(b.is_empty()); + fn reverse_iterator() { + let empty = BitSet::<1>::default(); + assert!(empty.iter_rev().next().is_none()); + + let single_element = BitSet::<1>::with(10); + assert_eq!(single_element.iter_rev().collect::>(), vec![10]); + + let mut single_block = BitSet::<1>::with(1); + single_block.insert(2); + single_block.insert(10); + assert_eq!(single_block.iter_rev().collect::>(), vec![10, 2, 1]); + + let mut multiple_blocks = BitSet::<1>::default(); + multiple_blocks.insert(1); + multiple_blocks.insert(2); + multiple_blocks.insert(3); + multiple_blocks.insert(70); + multiple_blocks.insert(71); + multiple_blocks.insert(1000); + assert!(matches!(multiple_blocks, BitSet::Heap(_))); + assert_eq!( + multiple_blocks.iter_rev().collect::>(), + vec![1000, 71, 70, 3, 2, 1] + ); } } diff --git a/crates/red_knot_python_semantic/src/semantic_index/use_def/symbol_state.rs b/crates/red_knot_python_semantic/src/semantic_index/use_def/symbol_state.rs index 506300067c952..3ec9f468475b5 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/use_def/symbol_state.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/use_def/symbol_state.rs @@ -43,7 +43,7 @@ //! //! Tracking live declarations is simpler, since constraints are not involved, but otherwise very //! similar to tracking live bindings. -use super::bitset::{BitSet, BitSetIterator}; +use super::bitset::{BitSet, BitSetIterator, ReverseBitSetIterator}; use ruff_index::newtype_index; use smallvec::SmallVec; @@ -55,19 +55,23 @@ pub(super) struct ScopedDefinitionId; #[newtype_index] pub(super) struct ScopedConstraintId; +/// A newtype-index for a [`crate::semantic_index::branching_condition::BranchingCondition`] in a particular scope. +#[newtype_index] +pub(super) struct ScopedBranchingConditionId; + /// Can reference this * 64 total definitions inline; more will fall back to the heap. const INLINE_BINDING_BLOCKS: usize = 3; /// A [`BitSet`] of [`ScopedDefinitionId`], representing live bindings of a symbol in a scope. type Bindings = BitSet; -type BindingsIterator<'a> = BitSetIterator<'a, INLINE_BINDING_BLOCKS>; +type ReverseBindingsIterator<'a> = ReverseBitSetIterator<'a, INLINE_BINDING_BLOCKS>; /// Can reference this * 64 total declarations inline; more will fall back to the heap. const INLINE_DECLARATION_BLOCKS: usize = 3; /// A [`BitSet`] of [`ScopedDefinitionId`], representing live declarations of a symbol in a scope. type Declarations = BitSet; -type DeclarationsIterator<'a> = BitSetIterator<'a, INLINE_DECLARATION_BLOCKS>; +type ReverseDeclarationsIterator<'a> = ReverseBitSetIterator<'a, INLINE_DECLARATION_BLOCKS>; /// Can reference this * 64 total constraints inline; more will fall back to the heap. const INLINE_CONSTRAINT_BLOCKS: usize = 2; @@ -75,18 +79,37 @@ const INLINE_CONSTRAINT_BLOCKS: usize = 2; /// Can keep inline this many live bindings per symbol at a given time; more will go to heap. const INLINE_BINDINGS_PER_SYMBOL: usize = 4; -/// One [`BitSet`] of applicable [`ScopedConstraintId`] per live binding. -type InlineConstraintArray = [BitSet; INLINE_BINDINGS_PER_SYMBOL]; -type Constraints = SmallVec; -type ConstraintsIterator<'a> = std::slice::Iter<'a, BitSet>; +/// Which constraints apply to a given binding? +type Constraints = BitSet; + +type InlineConstraintArray = [Constraints; INLINE_BINDINGS_PER_SYMBOL]; + +/// One [`BitSet`] of applicable [`ScopedConstraintId`]s per live binding. +type ConstraintsPerBinding = SmallVec; + +/// Iterate over all constraints for a single binding. +type ConstraintsIterator<'a> = std::slice::Iter<'a, Constraints>; type ConstraintsIntoIterator = smallvec::IntoIter; +/// Similar to what we have for constraints, but for active branching conditions. +const INLINE_BRANCHING_BLOCKS: usize = 2; +const INLINE_BRANCHING_CONDITIONS: usize = 4; +pub(super) type BranchingConditions = BitSet; +type InlineBranchingConditionsArray = [BranchingConditions; INLINE_BRANCHING_CONDITIONS]; +/// One [`BitSet`] of active [`ScopedBranchingConditionId`]s per live binding. +type BranchingConditionsPerBinding = SmallVec; +type BranchingConditionsIterator<'a> = std::slice::Iter<'a, BranchingConditions>; +type BranchingConditionsIntoIterator = smallvec::IntoIter; + /// Live declarations for a single symbol at some point in control flow. #[derive(Clone, Debug, PartialEq, Eq)] pub(super) struct SymbolDeclarations { /// [`BitSet`]: which declarations (as [`ScopedDefinitionId`]) can reach the current location? live_declarations: Declarations, + /// For each live declaration, which [`BranchingConditions`] were active at that declaration? + branching_conditions: BranchingConditionsPerBinding, + /// Could the symbol be un-declared at this point? may_be_undeclared: bool, } @@ -95,14 +118,26 @@ impl SymbolDeclarations { fn undeclared() -> Self { Self { live_declarations: Declarations::default(), + branching_conditions: BranchingConditionsPerBinding::default(), may_be_undeclared: true, } } /// Record a newly-encountered declaration for this symbol. - fn record_declaration(&mut self, declaration_id: ScopedDefinitionId) { + fn record_declaration( + &mut self, + declaration_id: ScopedDefinitionId, + branching_conditions: &BranchingConditions, + ) { self.live_declarations = Declarations::with(declaration_id.into()); self.may_be_undeclared = false; + + self.branching_conditions = BranchingConditionsPerBinding::with_capacity(1); + self.branching_conditions + .push(BranchingConditions::default()); + for active_constraint_id in branching_conditions.iter() { + self.branching_conditions[0].insert(active_constraint_id); + } } /// Add undeclared as a possibility for this symbol. @@ -111,16 +146,13 @@ impl SymbolDeclarations { } /// Return an iterator over live declarations for this symbol. - pub(super) fn iter(&self) -> DeclarationIdIterator { + pub(super) fn iter_rev(&self) -> DeclarationIdIterator { DeclarationIdIterator { - inner: self.live_declarations.iter(), + inner: self.live_declarations.iter_rev(), + branching_conditions: self.branching_conditions.iter().rev(), } } - pub(super) fn is_empty(&self) -> bool { - self.live_declarations.is_empty() - } - pub(super) fn may_be_undeclared(&self) -> bool { self.may_be_undeclared } @@ -136,7 +168,10 @@ pub(super) struct SymbolBindings { /// /// This is a [`smallvec::SmallVec`] which should always have one [`BitSet`] of constraints per /// binding in `live_bindings`. - constraints: Constraints, + constraints: ConstraintsPerBinding, + + /// For each live binding, which [`BranchingConditions`] were active at that binding? + branching_conditions: BranchingConditionsPerBinding, /// Could the symbol be unbound at this point? may_be_unbound: bool, @@ -146,7 +181,8 @@ impl SymbolBindings { fn unbound() -> Self { Self { live_bindings: Bindings::default(), - constraints: Constraints::default(), + constraints: ConstraintsPerBinding::default(), + branching_conditions: BranchingConditionsPerBinding::default(), may_be_unbound: true, } } @@ -157,12 +193,23 @@ impl SymbolBindings { } /// Record a newly-encountered binding for this symbol. - pub(super) fn record_binding(&mut self, binding_id: ScopedDefinitionId) { + pub(super) fn record_binding( + &mut self, + binding_id: ScopedDefinitionId, + branching_conditions: &BranchingConditions, + ) { // The new binding replaces all previous live bindings in this path, and has no // constraints. self.live_bindings = Bindings::with(binding_id.into()); - self.constraints = Constraints::with_capacity(1); - self.constraints.push(BitSet::default()); + self.constraints = ConstraintsPerBinding::with_capacity(1); + self.constraints.push(Constraints::default()); + + self.branching_conditions = BranchingConditionsPerBinding::with_capacity(1); + self.branching_conditions + .push(BranchingConditions::default()); + for id in branching_conditions.iter() { + self.branching_conditions[0].insert(id); + } self.may_be_unbound = false; } @@ -173,11 +220,12 @@ impl SymbolBindings { } } - /// Iterate over currently live bindings for this symbol. - pub(super) fn iter(&self) -> BindingIdWithConstraintsIterator { + /// Iterate over currently live bindings for this symbol, in reverse order. + pub(super) fn iter_rev(&self) -> BindingIdWithConstraintsIterator { BindingIdWithConstraintsIterator { - definitions: self.live_bindings.iter(), - constraints: self.constraints.iter(), + definitions: self.live_bindings.iter_rev(), + constraints: self.constraints.iter().rev(), + branching_conditions: self.branching_conditions.iter().rev(), } } @@ -207,8 +255,13 @@ impl SymbolState { } /// Record a newly-encountered binding for this symbol. - pub(super) fn record_binding(&mut self, binding_id: ScopedDefinitionId) { - self.bindings.record_binding(binding_id); + pub(super) fn record_binding( + &mut self, + binding_id: ScopedDefinitionId, + branching_conditions: &BranchingConditions, + ) { + self.bindings + .record_binding(binding_id, branching_conditions); } /// Add given constraint to all live bindings. @@ -222,8 +275,13 @@ impl SymbolState { } /// Record a newly-encountered declaration of this symbol. - pub(super) fn record_declaration(&mut self, declaration_id: ScopedDefinitionId) { - self.declarations.record_declaration(declaration_id); + pub(super) fn record_declaration( + &mut self, + declaration_id: ScopedDefinitionId, + branching_conditions: &BranchingConditions, + ) { + self.declarations + .record_declaration(declaration_id, branching_conditions); } /// Merge another [`SymbolState`] into this one. @@ -231,25 +289,26 @@ impl SymbolState { let mut a = Self { bindings: SymbolBindings { live_bindings: Bindings::default(), - constraints: Constraints::default(), + constraints: ConstraintsPerBinding::default(), + branching_conditions: BranchingConditionsPerBinding::default(), may_be_unbound: self.bindings.may_be_unbound || b.bindings.may_be_unbound, }, declarations: SymbolDeclarations { live_declarations: self.declarations.live_declarations.clone(), + branching_conditions: BranchingConditionsPerBinding::default(), may_be_undeclared: self.declarations.may_be_undeclared || b.declarations.may_be_undeclared, }, }; std::mem::swap(&mut a, self); - self.declarations - .live_declarations - .union(&b.declarations.live_declarations); let mut a_defs_iter = a.bindings.live_bindings.iter(); let mut b_defs_iter = b.bindings.live_bindings.iter(); let mut a_constraints_iter = a.bindings.constraints.into_iter(); let mut b_constraints_iter = b.bindings.constraints.into_iter(); + let mut a_conditions_iter = a.bindings.branching_conditions.into_iter(); + let mut b_conditions_iter = b.bindings.branching_conditions.into_iter(); let mut opt_a_def: Option = a_defs_iter.next(); let mut opt_b_def: Option = b_defs_iter.next(); @@ -261,7 +320,10 @@ impl SymbolState { // path is irrelevant. // Helper to push `def`, with constraints in `constraints_iter`, onto `self`. - let push = |def, constraints_iter: &mut ConstraintsIntoIterator, merged: &mut Self| { + let push = |def, + constraints_iter: &mut ConstraintsIntoIterator, + branching_conditions_iter: &mut BranchingConditionsIntoIterator, + merged: &mut Self| { merged.bindings.live_bindings.insert(def); // SAFETY: we only ever create SymbolState with either no definitions and no constraint // bitsets (`::unbound`) or one definition and one constraint bitset (`::with`), and @@ -271,7 +333,14 @@ impl SymbolState { let constraints = constraints_iter .next() .expect("definitions and constraints length mismatch"); + let branching_conditions = branching_conditions_iter + .next() + .expect("definitions and branching_conditions length mismatch"); merged.bindings.constraints.push(constraints); + merged + .bindings + .branching_conditions + .push(branching_conditions); }; loop { @@ -279,17 +348,17 @@ impl SymbolState { (Some(a_def), Some(b_def)) => match a_def.cmp(&b_def) { std::cmp::Ordering::Less => { // Next definition ID is only in `a`, push it to `self` and advance `a`. - push(a_def, &mut a_constraints_iter, self); + push(a_def, &mut a_constraints_iter, &mut a_conditions_iter, self); opt_a_def = a_defs_iter.next(); } std::cmp::Ordering::Greater => { // Next definition ID is only in `b`, push it to `self` and advance `b`. - push(b_def, &mut b_constraints_iter, self); + push(b_def, &mut b_constraints_iter, &mut b_conditions_iter, self); opt_b_def = b_defs_iter.next(); } std::cmp::Ordering::Equal => { // Next definition is in both; push to `self` and intersect constraints. - push(a_def, &mut b_constraints_iter, self); + push(a_def, &mut b_constraints_iter, &mut b_conditions_iter, self); // SAFETY: we only ever create SymbolState with either no definitions and // no constraint bitsets (`::unbound`) or one definition and one constraint // bitset (`::with`), and `::merge` always pushes one definition and one @@ -298,6 +367,10 @@ impl SymbolState { let a_constraints = a_constraints_iter .next() .expect("definitions and constraints length mismatch"); + // SAFETY: The same is true for branching_conditions. + a_conditions_iter + .next() + .expect("branching_conditions length mismatch"); // If the same definition is visible through both paths, any constraint // that applies on only one path is irrelevant to the resulting type from // unioning the two paths, so we intersect the constraints. @@ -312,17 +385,74 @@ impl SymbolState { }, (Some(a_def), None) => { // We've exhausted `b`, just push the def from `a` and move on to the next. - push(a_def, &mut a_constraints_iter, self); + push(a_def, &mut a_constraints_iter, &mut a_conditions_iter, self); opt_a_def = a_defs_iter.next(); } (None, Some(b_def)) => { // We've exhausted `a`, just push the def from `b` and move on to the next. - push(b_def, &mut b_constraints_iter, self); + push(b_def, &mut b_constraints_iter, &mut b_conditions_iter, self); opt_b_def = b_defs_iter.next(); } (None, None) => break, } } + + // Same as above, but for declarations. + let mut a_decls_iter = a.declarations.live_declarations.iter(); + let mut b_decls_iter = b.declarations.live_declarations.iter(); + let mut a_conditions_iter = a.declarations.branching_conditions.into_iter(); + let mut b_conditions_iter = b.declarations.branching_conditions.into_iter(); + + let mut opt_a_decl: Option = a_decls_iter.next(); + let mut opt_b_decl: Option = b_decls_iter.next(); + + let push = + |decl, conditions_iter: &mut BranchingConditionsIntoIterator, merged: &mut Self| { + merged.declarations.live_declarations.insert(decl); + let conditions = conditions_iter + .next() + .expect("declarations and branching_conditions length mismatch"); + merged.declarations.branching_conditions.push(conditions); + }; + + loop { + match (opt_a_decl, opt_b_decl) { + (Some(a_decl), Some(b_decl)) => { + match a_decl.cmp(&b_decl) { + std::cmp::Ordering::Less => { + push(a_decl, &mut a_conditions_iter, self); + opt_a_decl = a_decls_iter.next(); + } + std::cmp::Ordering::Greater => { + push(b_decl, &mut b_conditions_iter, self); + opt_b_decl = b_decls_iter.next(); + } + std::cmp::Ordering::Equal => { + push(a_decl, &mut b_conditions_iter, self); + self.declarations + .branching_conditions + .last_mut() + .expect("declarations and branching_conditions length mismatch") + .intersect(&a_conditions_iter.next().expect( + "declarations and branching_conditions length mismatch", + )); + + opt_a_decl = a_decls_iter.next(); + opt_b_decl = b_decls_iter.next(); + } + } + } + (Some(a_decl), None) => { + push(a_decl, &mut a_conditions_iter, self); + opt_a_decl = a_decls_iter.next(); + } + (None, Some(b_decl)) => { + push(b_decl, &mut b_conditions_iter, self); + opt_b_decl = b_decls_iter.next(); + } + (None, None) => break, + } + } } pub(super) fn bindings(&self) -> &SymbolBindings { @@ -332,11 +462,6 @@ impl SymbolState { pub(super) fn declarations(&self) -> &SymbolDeclarations { &self.declarations } - - /// Could the symbol be unbound? - pub(super) fn may_be_unbound(&self) -> bool { - self.bindings.may_be_unbound() - } } /// The default state of a symbol, if we've seen no definitions of it, is undefined (that is, @@ -353,26 +478,37 @@ impl Default for SymbolState { pub(super) struct BindingIdWithConstraints<'a> { pub(super) definition: ScopedDefinitionId, pub(super) constraint_ids: ConstraintIdIterator<'a>, + pub(super) branching_conditions_ids: BranchingConditionIdIterator<'a>, } #[derive(Debug)] pub(super) struct BindingIdWithConstraintsIterator<'a> { - definitions: BindingsIterator<'a>, - constraints: ConstraintsIterator<'a>, + definitions: ReverseBindingsIterator<'a>, + constraints: std::iter::Rev>, + branching_conditions: std::iter::Rev>, } impl<'a> Iterator for BindingIdWithConstraintsIterator<'a> { type Item = BindingIdWithConstraints<'a>; fn next(&mut self) -> Option { - match (self.definitions.next(), self.constraints.next()) { - (None, None) => None, - (Some(def), Some(constraints)) => Some(BindingIdWithConstraints { - definition: ScopedDefinitionId::from_u32(def), - constraint_ids: ConstraintIdIterator { - wrapped: constraints.iter(), - }, - }), + match ( + self.definitions.next(), + self.constraints.next(), + self.branching_conditions.next(), + ) { + (None, None, None) => None, + (Some(def), Some(constraints), Some(branching_conditions)) => { + Some(BindingIdWithConstraints { + definition: ScopedDefinitionId::from_u32(def), + constraint_ids: ConstraintIdIterator { + wrapped: constraints.iter(), + }, + branching_conditions_ids: BranchingConditionIdIterator { + wrapped: branching_conditions.iter(), + }, + }) + } // SAFETY: see above. _ => unreachable!("definitions and constraints length mismatch"), } @@ -397,15 +533,43 @@ impl Iterator for ConstraintIdIterator<'_> { impl std::iter::FusedIterator for ConstraintIdIterator<'_> {} #[derive(Debug)] +pub(super) struct BranchingConditionIdIterator<'a> { + wrapped: BitSetIterator<'a, INLINE_BRANCHING_BLOCKS>, +} + +impl Iterator for BranchingConditionIdIterator<'_> { + type Item = ScopedBranchingConditionId; + + fn next(&mut self) -> Option { + self.wrapped + .next() + .map(ScopedBranchingConditionId::from_u32) + } +} + +impl std::iter::FusedIterator for BranchingConditionIdIterator<'_> {} + +#[derive(Clone)] pub(super) struct DeclarationIdIterator<'a> { - inner: DeclarationsIterator<'a>, + inner: ReverseDeclarationsIterator<'a>, + branching_conditions: std::iter::Rev>, } -impl Iterator for DeclarationIdIterator<'_> { - type Item = ScopedDefinitionId; +impl<'a> Iterator for DeclarationIdIterator<'a> { + type Item = (ScopedDefinitionId, BranchingConditionIdIterator<'a>); fn next(&mut self) -> Option { - self.inner.next().map(ScopedDefinitionId::from_u32) + match (self.inner.next(), self.branching_conditions.next()) { + (None, None) => None, + (Some(declaration), Some(branching_conditions)) => Some(( + ScopedDefinitionId::from_u32(declaration), + BranchingConditionIdIterator { + wrapped: branching_conditions.iter(), + }, + )), + // SAFETY: see above. + _ => unreachable!("declarations and branching_conditions length mismatch"), + } } } @@ -413,13 +577,14 @@ impl std::iter::FusedIterator for DeclarationIdIterator<'_> {} #[cfg(test)] mod tests { - use super::{ScopedConstraintId, ScopedDefinitionId, SymbolState}; + use super::*; + #[track_caller] fn assert_bindings(symbol: &SymbolState, may_be_unbound: bool, expected: &[&str]) { - assert_eq!(symbol.may_be_unbound(), may_be_unbound); - let actual = symbol + assert_eq!(symbol.bindings.may_be_unbound, may_be_unbound); + let mut actual = symbol .bindings() - .iter() + .iter_rev() .map(|def_id_with_constraints| { format!( "{}<{}>", @@ -433,20 +598,23 @@ mod tests { ) }) .collect::>(); + actual.reverse(); assert_eq!(actual, expected); } + #[track_caller] pub(crate) fn assert_declarations( symbol: &SymbolState, may_be_undeclared: bool, expected: &[u32], ) { assert_eq!(symbol.declarations.may_be_undeclared(), may_be_undeclared); - let actual = symbol + let mut actual = symbol .declarations() - .iter() - .map(ScopedDefinitionId::as_u32) + .iter_rev() + .map(|(d, _)| d.as_u32()) .collect::>(); + actual.reverse(); assert_eq!(actual, expected); } @@ -460,7 +628,10 @@ mod tests { #[test] fn with() { let mut sym = SymbolState::undefined(); - sym.record_binding(ScopedDefinitionId::from_u32(0)); + sym.record_binding( + ScopedDefinitionId::from_u32(0), + &BranchingConditions::default(), + ); assert_bindings(&sym, false, &["0<>"]); } @@ -468,7 +639,10 @@ mod tests { #[test] fn set_may_be_unbound() { let mut sym = SymbolState::undefined(); - sym.record_binding(ScopedDefinitionId::from_u32(0)); + sym.record_binding( + ScopedDefinitionId::from_u32(0), + &BranchingConditions::default(), + ); sym.set_may_be_unbound(); assert_bindings(&sym, true, &["0<>"]); @@ -477,7 +651,10 @@ mod tests { #[test] fn record_constraint() { let mut sym = SymbolState::undefined(); - sym.record_binding(ScopedDefinitionId::from_u32(0)); + sym.record_binding( + ScopedDefinitionId::from_u32(0), + &BranchingConditions::default(), + ); sym.record_constraint(ScopedConstraintId::from_u32(0)); assert_bindings(&sym, false, &["0<0>"]); @@ -487,11 +664,17 @@ mod tests { fn merge() { // merging the same definition with the same constraint keeps the constraint let mut sym0a = SymbolState::undefined(); - sym0a.record_binding(ScopedDefinitionId::from_u32(0)); + sym0a.record_binding( + ScopedDefinitionId::from_u32(0), + &BranchingConditions::default(), + ); sym0a.record_constraint(ScopedConstraintId::from_u32(0)); let mut sym0b = SymbolState::undefined(); - sym0b.record_binding(ScopedDefinitionId::from_u32(0)); + sym0b.record_binding( + ScopedDefinitionId::from_u32(0), + &BranchingConditions::default(), + ); sym0b.record_constraint(ScopedConstraintId::from_u32(0)); sym0a.merge(sym0b); @@ -500,11 +683,17 @@ mod tests { // merging the same definition with differing constraints drops all constraints let mut sym1a = SymbolState::undefined(); - sym1a.record_binding(ScopedDefinitionId::from_u32(1)); + sym1a.record_binding( + ScopedDefinitionId::from_u32(1), + &BranchingConditions::default(), + ); sym1a.record_constraint(ScopedConstraintId::from_u32(1)); let mut sym1b = SymbolState::undefined(); - sym1b.record_binding(ScopedDefinitionId::from_u32(1)); + sym1b.record_binding( + ScopedDefinitionId::from_u32(1), + &BranchingConditions::default(), + ); sym1b.record_constraint(ScopedConstraintId::from_u32(2)); sym1a.merge(sym1b); @@ -513,7 +702,10 @@ mod tests { // merging a constrained definition with unbound keeps both let mut sym2a = SymbolState::undefined(); - sym2a.record_binding(ScopedDefinitionId::from_u32(2)); + sym2a.record_binding( + ScopedDefinitionId::from_u32(2), + &BranchingConditions::default(), + ); sym2a.record_constraint(ScopedConstraintId::from_u32(3)); let sym2b = SymbolState::undefined(); @@ -538,7 +730,10 @@ mod tests { #[test] fn record_declaration() { let mut sym = SymbolState::undefined(); - sym.record_declaration(ScopedDefinitionId::from_u32(1)); + sym.record_declaration( + ScopedDefinitionId::from_u32(1), + &BranchingConditions::default(), + ); assert_declarations(&sym, false, &[1]); } @@ -546,8 +741,14 @@ mod tests { #[test] fn record_declaration_override() { let mut sym = SymbolState::undefined(); - sym.record_declaration(ScopedDefinitionId::from_u32(1)); - sym.record_declaration(ScopedDefinitionId::from_u32(2)); + sym.record_declaration( + ScopedDefinitionId::from_u32(1), + &BranchingConditions::default(), + ); + sym.record_declaration( + ScopedDefinitionId::from_u32(2), + &BranchingConditions::default(), + ); assert_declarations(&sym, false, &[2]); } @@ -555,10 +756,16 @@ mod tests { #[test] fn record_declaration_merge() { let mut sym = SymbolState::undefined(); - sym.record_declaration(ScopedDefinitionId::from_u32(1)); + sym.record_declaration( + ScopedDefinitionId::from_u32(1), + &BranchingConditions::default(), + ); let mut sym2 = SymbolState::undefined(); - sym2.record_declaration(ScopedDefinitionId::from_u32(2)); + sym2.record_declaration( + ScopedDefinitionId::from_u32(2), + &BranchingConditions::default(), + ); sym.merge(sym2); @@ -568,7 +775,10 @@ mod tests { #[test] fn record_declaration_merge_partial_undeclared() { let mut sym = SymbolState::undefined(); - sym.record_declaration(ScopedDefinitionId::from_u32(1)); + sym.record_declaration( + ScopedDefinitionId::from_u32(1), + &BranchingConditions::default(), + ); let sym2 = SymbolState::undefined(); @@ -580,7 +790,10 @@ mod tests { #[test] fn set_may_be_undeclared() { let mut sym = SymbolState::undefined(); - sym.record_declaration(ScopedDefinitionId::from_u32(0)); + sym.record_declaration( + ScopedDefinitionId::from_u32(0), + &BranchingConditions::default(), + ); sym.set_may_be_undeclared(); assert_declarations(&sym, true, &[0]); diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 83b02eddb6cf5..c39757994bf6a 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -31,6 +31,7 @@ use crate::types::diagnostic::TypeCheckDiagnosticsBuilder; use crate::types::mro::{ClassBase, Mro, MroError, MroIterator}; use crate::types::narrow::narrowing_constraint; use crate::{Db, FxOrderSet, Module, Program, PythonVersion}; +pub(crate) use static_truthiness::StaticTruthiness; mod builder; mod call; @@ -40,6 +41,7 @@ mod infer; mod mro; mod narrow; mod signatures; +mod static_truthiness; mod string_annotation; mod unpacker; @@ -71,50 +73,63 @@ fn symbol_by_id<'db>(db: &'db dyn Db, scope: ScopeId<'db>, symbol: ScopedSymbolI // If the symbol is declared, the public type is based on declarations; otherwise, it's based // on inference from bindings. - if use_def.has_public_declarations(symbol) { - let declarations = use_def.public_declarations(symbol); - // If the symbol is undeclared in some paths, include the inferred type in the public type. - let undeclared_ty = if declarations.may_be_undeclared() { + let declaredness = use_def.public_declarations(symbol).declaredness(db); + + // TODO (ticket: https://github.com/astral-sh/ruff/issues/14297) Our handling of boundness + // currently only depends on bindings, and ignores declarations. This is inconsistent, since + // we only look at bindings if the symbol may be undeclared. Consider the following example: + // ```py + // x: int + // + // if flag: + // y: int + // else + // y = 3 + // ``` + // If we import from this module, we will currently report `x` as a definitely-bound symbol + // (even though it has no bindings at all!) but report `y` as possibly-unbound (even though + // every path has either a binding or a declaration for it.) + let undeclared_ty = match declaredness { + None => { + return bindings_ty(db, use_def.public_bindings(symbol)) + .map(|bindings_ty| { + if let Some(boundness) = use_def.public_boundness(db, symbol) { + Symbol::Type(bindings_ty, boundness) + } else { + Symbol::Unbound + } + }) + .unwrap_or(Symbol::Unbound); + } + Some(Boundness::PossiblyUnbound) => { + // If the symbol is undeclared in some paths, include the inferred type in the public type. Some( bindings_ty(db, use_def.public_bindings(symbol)) - .map(|bindings_ty| Symbol::Type(bindings_ty, use_def.public_boundness(symbol))) + .map(|bindings_ty| { + if let Some(boundness) = use_def.public_boundness(db, symbol) { + Symbol::Type(bindings_ty, boundness) + } else { + Symbol::Unbound + } + }) .unwrap_or(Symbol::Unbound), ) - } else { - None - }; - // Intentionally ignore conflicting declared types; that's not our problem, it's the - // problem of the module we are importing from. - - // TODO: Our handling of boundness currently only depends on bindings, and ignores - // declarations. This is inconsistent, since we only look at bindings if the symbol - // may be undeclared. Consider the following example: - // ```py - // x: int - // - // if flag: - // y: int - // else - // y = 3 - // ``` - // If we import from this module, we will currently report `x` as a definitely-bound - // symbol (even though it has no bindings at all!) but report `y` as possibly-unbound - // (even though every path has either a binding or a declaration for it.) - - match undeclared_ty { - Some(Symbol::Type(ty, boundness)) => Symbol::Type( - declarations_ty(db, declarations, Some(ty)).unwrap_or_else(|(ty, _)| ty), - boundness, - ), - None | Some(Symbol::Unbound) => Symbol::Type( - declarations_ty(db, declarations, None).unwrap_or_else(|(ty, _)| ty), - Boundness::Bound, - ), } - } else { - bindings_ty(db, use_def.public_bindings(symbol)) - .map(|bindings_ty| Symbol::Type(bindings_ty, use_def.public_boundness(symbol))) - .unwrap_or(Symbol::Unbound) + Some(Boundness::Bound) => None, + }; + + // Intentionally ignore conflicting declared types; that's not our problem, it's the + // problem of the module we are importing from. + let declarations = use_def.public_declarations(symbol); + match undeclared_ty { + Some(Symbol::Type(ty, boundness)) => Symbol::Type( + declarations_ty(db, declarations, Some(ty)).unwrap_or_else(|(ty, _)| ty), + boundness, + ), + None | Some(Symbol::Unbound) => Symbol::Type( + declarations_ty(db, declarations, None).unwrap_or_else(|(ty, _)| ty), + Boundness::Bound, + ), } } @@ -236,6 +251,52 @@ fn definition_expression_ty<'db>( } } +/// The 'visibility' of a binding or declaration. +/// +/// Consider the following example: +/// ```py +/// x = 1 +/// +/// if True: +/// x = 2 +/// +/// if False: +/// x = 3 +/// +/// if flag(): +/// x = 4 +/// ``` +/// When we infer the type of `x`, we look back "through" the bindings in reverse order. +/// The first binding is `x = 4`. It is "transparent" because we could have either taken +/// the `if flag()` branch or not. The second binding `x = 3` is "invisible" because we +/// can statically determine that the `if False` branch is never taken. The third binding +/// `x = 2` is "opaque" because we can statically determine that the `if True` branch is +/// always taken. If the visibility of a binding is "opaque", bindings behind it are not +/// visible. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum Visibility<'db> { + Invisible, + Transparent(Type<'db>), + Opaque(Type<'db>), +} + +impl<'db> Visibility<'db> { + fn is_not_opaque(&self) -> bool { + !matches!(self, Visibility::Opaque(_)) + } + + fn is_invisible(&self) -> bool { + matches!(self, Visibility::Invisible) + } + + fn unwrap_or_never(self) -> Type<'db> { + match self { + Visibility::Invisible => Type::Never, + Visibility::Transparent(ty) | Visibility::Opaque(ty) => ty, + } + } +} + /// Infer the combined type of an iterator of bindings. /// /// Will return a union if there is more than one binding. @@ -243,34 +304,60 @@ fn bindings_ty<'db>( db: &'db dyn Db, bindings_with_constraints: BindingWithConstraintsIterator<'_, 'db>, ) -> Option> { - let mut def_types = bindings_with_constraints.map( - |BindingWithConstraints { - binding, - constraints, - }| { - let mut constraint_tys = constraints - .filter_map(|constraint| narrowing_constraint(db, constraint, binding)) - .peekable(); - - let binding_ty = binding_ty(db, binding); - if constraint_tys.peek().is_some() { - constraint_tys - .fold( - IntersectionBuilder::new(db).add_positive(binding_ty), - IntersectionBuilder::add_positive, - ) - .build() - } else { - binding_ty - } - }, - ); + let types = bindings_with_constraints + .map( + |BindingWithConstraints { + binding, + constraints, + branching_conditions, + }| { + let result = StaticTruthiness::analyze(db, branching_conditions); + + if result.any_always_false { + Visibility::Invisible + } else { + let mut constraint_tys = constraints + .filter_map(|constraint| narrowing_constraint(db, constraint, binding)) + .peekable(); + + let binding_ty = binding_ty(db, binding); + let ty = if constraint_tys.peek().is_some() { + let intersection_ty = constraint_tys + .fold( + IntersectionBuilder::new(db).add_positive(binding_ty), + IntersectionBuilder::add_positive, + ) + .build(); + intersection_ty + } else { + binding_ty + }; + + if result.at_least_one_condition && result.all_always_true { + Visibility::Opaque(ty) + } else { + Visibility::Transparent(ty) + } + } + }, + ) + .take_while_inclusive(Visibility::is_not_opaque); - if let Some(first) = def_types.next() { - if let Some(second) = def_types.next() { + // TODO: try to get rid of the `collect` here + let types: Vec<_> = types.collect(); + + if !types.is_empty() && types.iter().all(Visibility::is_invisible) { + // If all bindings are invisible, the symbol is unbound. + return Some(Type::Unknown); + } + + let mut types = types.iter().map(|v| v.unwrap_or_never()).rev(); + + if let Some(first) = types.next() { + if let Some(second) = types.next() { Some(UnionType::from_elements( db, - [first, second].into_iter().chain(def_types), + [first, second].into_iter().chain(types), )) } else { Some(first) @@ -301,9 +388,28 @@ fn declarations_ty<'db>( declarations: DeclarationsIterator<'_, 'db>, undeclared_ty: Option>, ) -> DeclaredTypeResult<'db> { - let decl_types = declarations.map(|declaration| declaration_ty(db, declaration)); + let types = declarations + .map(|(declaration, branching_conditions)| { + let result = StaticTruthiness::analyze(db, branching_conditions); + + if result.any_always_false { + Visibility::Invisible + } else { + if result.at_least_one_condition && result.all_always_true { + Visibility::Opaque(declaration_ty(db, declaration)) + } else { + Visibility::Transparent(declaration_ty(db, declaration)) + } + } + }) + .take_while_inclusive(Visibility::is_not_opaque) + .map(Visibility::unwrap_or_never); - let mut all_types = undeclared_ty.into_iter().chain(decl_types); + // TODO: try to get rid of the `collect` here (see above) + let types: Vec<_> = types.collect(); + let types = types.into_iter().rev(); + + let mut all_types = undeclared_ty.into_iter().chain(types); let first = all_types.next().expect( "declarations_ty must not be called with zero declarations and no may-be-undeclared", @@ -827,26 +933,6 @@ impl<'db> Type<'db> { return false; } - // TODO: The following is a workaround that is required to unify the two different versions - // of `NoneType` and `NoDefaultType` in typeshed. This should not be required anymore once - // we understand `sys.version_info` branches. - if let ( - Type::Instance(InstanceType { class: self_class }), - Type::Instance(InstanceType { - class: target_class, - }), - ) = (self, other) - { - let self_known = self_class.known(db); - if matches!( - self_known, - Some(KnownClass::NoneType | KnownClass::NoDefaultType) - ) && self_known == target_class.known(db) - { - return true; - } - } - // type[object] ≡ type if let ( Type::SubclassOf(SubclassOfType { @@ -1437,7 +1523,7 @@ impl<'db> Type<'db> { /// /// This is used to determine the value that would be returned /// when `bool(x)` is called on an object `x`. - fn bool(&self, db: &'db dyn Db) -> Truthiness { + pub(crate) fn bool(&self, db: &'db dyn Db) -> Truthiness { match self { Type::Any | Type::Todo(_) | Type::Never | Type::Unknown => Truthiness::Ambiguous, Type::FunctionLiteral(_) => Truthiness::AlwaysTrue, @@ -2518,11 +2604,19 @@ pub enum Truthiness { } impl Truthiness { - const fn is_ambiguous(self) -> bool { + pub(crate) const fn is_ambiguous(self) -> bool { matches!(self, Truthiness::Ambiguous) } - const fn negate(self) -> Self { + pub(crate) const fn is_always_false(self) -> bool { + matches!(self, Truthiness::AlwaysFalse) + } + + pub(crate) const fn is_always_true(self) -> bool { + matches!(self, Truthiness::AlwaysTrue) + } + + pub(crate) const fn negate(self) -> Self { match self { Self::AlwaysTrue => Self::AlwaysFalse, Self::AlwaysFalse => Self::AlwaysTrue, @@ -2530,6 +2624,14 @@ impl Truthiness { } } + pub(crate) const fn negate_if(self, condition: bool) -> Self { + if condition { + self.negate() + } else { + self + } + } + fn into_type(self, db: &dyn Db) -> Type { match self { Self::AlwaysTrue => Type::BooleanLiteral(true), diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 1b732ce5311b8..49fcfd2d12feb 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -783,7 +783,7 @@ impl<'db> TypeInferenceBuilder<'db> { debug_assert!(binding.is_binding(self.db)); let use_def = self.index.use_def_map(binding.file_scope(self.db)); let declarations = use_def.declarations_at_binding(binding); - let undeclared_ty = if declarations.may_be_undeclared() { + let undeclared_ty = if declarations.clone().may_be_undeclared(self.db) { Some(Type::Unknown) } else { None @@ -1718,13 +1718,24 @@ impl<'db> TypeInferenceBuilder<'db> { fn infer_match_pattern(&mut self, pattern: &ast::Pattern) { // TODO(dhruvmanila): Add a Salsa query for inferring pattern types and matching against // the subject expression: https://github.com/astral-sh/ruff/pull/13147#discussion_r1739424510 + match pattern { + ast::Pattern::MatchValue(match_value) => { + self.infer_standalone_expression(&match_value.value); + } + _ => { + self.infer_match_pattern_impl(pattern); + } + } + } + + fn infer_match_pattern_impl(&mut self, pattern: &ast::Pattern) { match pattern { ast::Pattern::MatchValue(match_value) => { self.infer_expression(&match_value.value); } ast::Pattern::MatchSequence(match_sequence) => { for pattern in &match_sequence.patterns { - self.infer_match_pattern(pattern); + self.infer_match_pattern_impl(pattern); } } ast::Pattern::MatchMapping(match_mapping) => { @@ -1738,7 +1749,7 @@ impl<'db> TypeInferenceBuilder<'db> { self.infer_expression(key); } for pattern in patterns { - self.infer_match_pattern(pattern); + self.infer_match_pattern_impl(pattern); } } ast::Pattern::MatchClass(match_class) => { @@ -1748,21 +1759,21 @@ impl<'db> TypeInferenceBuilder<'db> { arguments, } = match_class; for pattern in &arguments.patterns { - self.infer_match_pattern(pattern); + self.infer_match_pattern_impl(pattern); } for keyword in &arguments.keywords { - self.infer_match_pattern(&keyword.pattern); + self.infer_match_pattern_impl(&keyword.pattern); } self.infer_expression(cls); } ast::Pattern::MatchAs(match_as) => { if let Some(pattern) = &match_as.pattern { - self.infer_match_pattern(pattern); + self.infer_match_pattern_impl(pattern); } } ast::Pattern::MatchOr(match_or) => { for pattern in &match_or.patterns { - self.infer_match_pattern(pattern); + self.infer_match_pattern_impl(pattern); } } ast::Pattern::MatchStar(_) | ast::Pattern::MatchSingleton(_) => {} @@ -2995,24 +3006,25 @@ impl<'db> TypeInferenceBuilder<'db> { if let Some(symbol) = self.index.symbol_table(file_scope_id).symbol_id_by_name(id) { ( bindings_ty(self.db, use_def.public_bindings(symbol)), - use_def.public_boundness(symbol), + use_def.public_boundness(self.db, symbol), ) } else { assert!( self.deferred_state.in_string_annotation(), "Expected the symbol table to create a symbol for every Name node" ); - (None, Boundness::PossiblyUnbound) + (None, Some(Boundness::PossiblyUnbound)) } } else { let use_id = name.scoped_use_id(self.db, self.scope()); ( bindings_ty(self.db, use_def.bindings_at_use(use_id)), - use_def.use_boundness(use_id), + use_def.use_boundness(self.db, use_id), ) }; - - if boundness == Boundness::PossiblyUnbound { + if boundness == Some(Boundness::Bound) { + bindings_ty.unwrap_or(Type::Unknown) + } else { match self.lookup_name(name) { Symbol::Type(looked_up_ty, looked_up_boundness) => { if looked_up_boundness == Boundness::PossiblyUnbound { @@ -3025,15 +3037,17 @@ impl<'db> TypeInferenceBuilder<'db> { } Symbol::Unbound => { if bindings_ty.is_some() { - self.diagnostics.add_possibly_unresolved_reference(name); + if boundness == Some(Boundness::PossiblyUnbound) { + self.diagnostics.add_possibly_unresolved_reference(name); + } else { + self.diagnostics.add_unresolved_reference(name); + } } else { self.diagnostics.add_unresolved_reference(name); } bindings_ty.unwrap_or(Type::Unknown) } } - } else { - bindings_ty.unwrap_or(Type::Unknown) } } diff --git a/crates/red_knot_python_semantic/src/types/narrow.rs b/crates/red_knot_python_semantic/src/types/narrow.rs index 69513ccfba4c1..bf74bdfd38364 100644 --- a/crates/red_knot_python_semantic/src/types/narrow.rs +++ b/crates/red_knot_python_semantic/src/types/narrow.rs @@ -1,5 +1,7 @@ use crate::semantic_index::ast_ids::HasScopedExpressionId; -use crate::semantic_index::constraint::{Constraint, ConstraintNode, PatternConstraint}; +use crate::semantic_index::constraint::{ + Constraint, ConstraintNode, PatternConstraint, PatternConstraintKind, +}; use crate::semantic_index::definition::Definition; use crate::semantic_index::expression::Expression; use crate::semantic_index::symbol::{ScopeId, ScopedSymbolId, SymbolTable}; @@ -215,31 +217,12 @@ impl<'db> NarrowingConstraintsBuilder<'db> { ) -> Option> { let subject = pattern.subject(self.db); - match pattern.pattern(self.db).node() { - ast::Pattern::MatchValue(_) => { - None // TODO - } - ast::Pattern::MatchSingleton(singleton_pattern) => { - self.evaluate_match_pattern_singleton(subject, singleton_pattern) - } - ast::Pattern::MatchSequence(_) => { - None // TODO - } - ast::Pattern::MatchMapping(_) => { - None // TODO - } - ast::Pattern::MatchClass(_) => { - None // TODO - } - ast::Pattern::MatchStar(_) => { - None // TODO - } - ast::Pattern::MatchAs(_) => { - None // TODO - } - ast::Pattern::MatchOr(_) => { - None // TODO + match pattern.kind(self.db) { + PatternConstraintKind::Singleton(singleton) => { + self.evaluate_match_pattern_singleton(*subject, *singleton) } + // TODO: support more pattern kinds + PatternConstraintKind::Value(_) | PatternConstraintKind::Unsupported => None, } } @@ -457,14 +440,14 @@ impl<'db> NarrowingConstraintsBuilder<'db> { fn evaluate_match_pattern_singleton( &mut self, - subject: &ast::Expr, - pattern: &ast::PatternMatchSingleton, + subject: Expression<'db>, + singleton: ast::Singleton, ) -> Option> { - if let Some(ast::ExprName { id, .. }) = subject.as_name_expr() { + if let Some(ast::ExprName { id, .. }) = subject.node_ref(self.db).as_name_expr() { // SAFETY: we should always have a symbol for every Name node. let symbol = self.symbols().symbol_id_by_name(id).unwrap(); - let ty = match pattern.value { + let ty = match singleton { ast::Singleton::None => Type::none(self.db), ast::Singleton::True => Type::BooleanLiteral(true), ast::Singleton::False => Type::BooleanLiteral(false), diff --git a/crates/red_knot_python_semantic/src/types/static_truthiness.rs b/crates/red_knot_python_semantic/src/types/static_truthiness.rs new file mode 100644 index 0000000000000..99fc94d5a0d65 --- /dev/null +++ b/crates/red_knot_python_semantic/src/types/static_truthiness.rs @@ -0,0 +1,137 @@ +use crate::semantic_index::{ + ast_ids::HasScopedExpressionId, + branching_condition::BranchingCondition, + constraint::{Constraint, ConstraintNode, PatternConstraintKind}, + BranchingConditionsIterator, +}; +use crate::types::{infer_expression_types, Truthiness}; +use crate::Db; + +/// The result of a static-truthiness analysis. +/// +/// Consider the following example: +/// ```py +/// a = 1 +/// if True: +/// b = 1 +/// if : +/// c = 1 +/// if False: +/// d = 1 +/// ``` +/// +/// Given an iterator over the branching conditions for each of these bindings, we would get: +/// ```txt +/// - a: {any_always_false: false, all_always_true: true, at_least_one_condition: false} +/// - b: {any_always_false: false, all_always_true: true, at_least_one_condition: true} +/// - c: {any_always_false: false, all_always_true: false, at_least_one_condition: true} +/// - d: {any_always_false: true, all_always_true: false, at_least_one_condition: true} +/// ``` +#[derive(Debug)] +pub(crate) struct StaticTruthiness { + /// Is any of the branching conditions always false? (false if there are no conditions) + pub(crate) any_always_false: bool, + /// Are all of the branching conditions always true? (true if there are no conditions) + pub(crate) all_always_true: bool, + /// Is there at least one branching condition? + pub(crate) at_least_one_condition: bool, +} + +impl StaticTruthiness { + /// Analyze the (statically known) truthiness for a list of branching conditions. + pub(crate) fn analyze<'db>( + db: &'db dyn Db, + branching_conditions: BranchingConditionsIterator<'_, 'db>, + ) -> Self { + let mut result = Self { + any_always_false: false, + all_always_true: true, + at_least_one_condition: false, + }; + + for condition in branching_conditions { + let truthiness = match condition { + BranchingCondition::ConditionalOn(Constraint { + node: ConstraintNode::Expression(test_expr), + is_positive, + }) => { + let inference = infer_expression_types(db, test_expr); + let scope = test_expr.scope(db); + let ty = inference + .expression_ty(test_expr.node_ref(db).scoped_expression_id(db, scope)); + + ty.bool(db).negate_if(!is_positive) + } + BranchingCondition::ConditionalOn(Constraint { + node: ConstraintNode::Pattern(inner), + .. + }) => match inner.kind(db) { + PatternConstraintKind::Value(value) => { + let subject_expression = inner.subject(db); + let inference = infer_expression_types(db, *subject_expression); + let scope = subject_expression.scope(db); + let subject_ty = inference.expression_ty( + subject_expression + .node_ref(db) + .scoped_expression_id(db, scope), + ); + + let inference = infer_expression_types(db, *value); + let scope = value.scope(db); + let value_ty = inference + .expression_ty(value.node_ref(db).scoped_expression_id(db, scope)); + + if subject_ty.is_single_valued(db) { + Truthiness::from(subject_ty.is_equivalent_to(db, value_ty)) + } else { + Truthiness::Ambiguous + } + } + PatternConstraintKind::Singleton(_) | PatternConstraintKind::Unsupported => { + Truthiness::Ambiguous + } + }, + BranchingCondition::Ambiguous => Truthiness::Ambiguous, + }; + + result.any_always_false |= truthiness.is_always_false(); + result.all_always_true &= truthiness.is_always_true(); + result.at_least_one_condition = true; + } + + result + } + + /// Merge two static truthiness results, as if they came from two different control-flow paths. + /// + /// Note that the logical operations are exactly opposite to what one would expect from the names + /// of the fields. The reason for this is that we want to draw conclusions like "this symbol can + /// not be bound because one of the branching conditions is always false". We can only draw this + /// conclusion if this is true in both control-flow paths. Similarly, we want to infer that the + /// binding of a symbol is unconditionally visible if all branching conditions are known to be + /// statically true. It is enough if this is the case for either of the two control-flow paths. + /// The other paths can not be taken if this is the case. + pub(crate) fn flow_merge(self, other: &Self) -> Self { + Self { + any_always_false: self.any_always_false && other.any_always_false, + all_always_true: self.all_always_true || other.all_always_true, + at_least_one_condition: self.at_least_one_condition && other.at_least_one_condition, + } + } + + /// A static truthiness result that states our knowledge before we have seen any bindings. + /// + /// This is used as a starting point for merging multiple results. + pub(crate) fn no_bindings() -> Self { + Self { + // Corresponds to "definitely unbound". Before we haven't seen any bindings, we + // can conclude that the symbol is not bound. + any_always_false: true, + // Corresponds to "definitely bound". Before we haven't seen any bindings, we + // can not conclude that the symbol is bound. + all_always_true: false, + // Irrelevant for this analysis. + at_least_one_condition: false, + } + } +}