Skip to content

Commit 193c381

Browse files
authored
[red-knot] Check whether two callable types are equivalent (#16698)
## Summary This PR checks whether two callable types are equivalent or not. This is required because for an equivalence relationship, the default value does not necessarily need to be the same but if the parameter in one of the callable has a default value then the corresponding parameter in the other callable should also have a default value. This is the main reason a manual implementation is required. And, as per https://typing.python.org/en/latest/spec/callables.html#id4, the default _type_ doesn't participate in a subtype relationship, only the optionality (required or not) participates. This means that the following two callable types are equivalent: ```py def f1(a: int = 1) -> None: ... def f2(a: int = 2) -> None: ... ``` Additionally, the name of positional-only, variadic and keyword-variadic are not required to be the same for an equivalence relation. A potential solution to avoid the manual implementation would be to only store whether a parameter has a default value or not but the type is currently required to check for assignability. ## Test plan Add tests for callable types in `is_equivalent_to.md`
1 parent 63e78b4 commit 193c381

File tree

3 files changed

+216
-3
lines changed

3 files changed

+216
-3
lines changed

crates/red_knot_python_semantic/resources/mdtest/type_properties/is_equivalent_to.md

+122
Original file line numberDiff line numberDiff line change
@@ -118,4 +118,126 @@ class R: ...
118118
static_assert(is_equivalent_to(Intersection[tuple[P | Q], R], Intersection[tuple[Q | P], R]))
119119
```
120120

121+
## Callable
122+
123+
### Equivalent
124+
125+
For an equivalence relationship, the default value does not necessarily need to be the same but if
126+
the parameter in one of the callable has a default value then the corresponding parameter in the
127+
other callable should also have a default value.
128+
129+
```py
130+
from knot_extensions import CallableTypeFromFunction, is_equivalent_to, static_assert
131+
from typing import Callable
132+
133+
def f1(a: int = 1) -> None: ...
134+
def f2(a: int = 2) -> None: ...
135+
136+
static_assert(is_equivalent_to(CallableTypeFromFunction[f1], CallableTypeFromFunction[f2]))
137+
```
138+
139+
The names of the positional-only, variadic and keyword-variadic parameters does not need to be the
140+
same.
141+
142+
```py
143+
def f3(a1: int, /, *args1: int, **kwargs2: int) -> None: ...
144+
def f4(a2: int, /, *args2: int, **kwargs1: int) -> None: ...
145+
146+
static_assert(is_equivalent_to(CallableTypeFromFunction[f3], CallableTypeFromFunction[f4]))
147+
```
148+
149+
Putting it all together, the following two callables are equivalent:
150+
151+
```py
152+
def f5(a1: int, /, b: float, c: bool = False, *args1: int, d: int = 1, e: str, **kwargs1: float) -> None: ...
153+
def f6(a2: int, /, b: float, c: bool = True, *args2: int, d: int = 2, e: str, **kwargs2: float) -> None: ...
154+
155+
static_assert(is_equivalent_to(CallableTypeFromFunction[f5], CallableTypeFromFunction[f6]))
156+
```
157+
158+
### Not equivalent
159+
160+
There are multiple cases when two callable types are not equivalent which are enumerated below.
161+
162+
```py
163+
from knot_extensions import CallableTypeFromFunction, is_equivalent_to, static_assert
164+
from typing import Callable
165+
```
166+
167+
When the number of parameters is different:
168+
169+
```py
170+
def f1(a: int) -> None: ...
171+
def f2(a: int, b: int) -> None: ...
172+
173+
static_assert(not is_equivalent_to(CallableTypeFromFunction[f1], CallableTypeFromFunction[f2]))
174+
```
175+
176+
When either of the callable types uses a gradual form for the parameters:
177+
178+
```py
179+
static_assert(not is_equivalent_to(Callable[..., None], Callable[[int], None]))
180+
static_assert(not is_equivalent_to(Callable[[int], None], Callable[..., None]))
181+
```
182+
183+
When the return types are not equivalent or absent in one or both of the callable types:
184+
185+
```py
186+
def f3(): ...
187+
def f4() -> None: ...
188+
189+
static_assert(not is_equivalent_to(Callable[[], int], Callable[[], None]))
190+
static_assert(not is_equivalent_to(CallableTypeFromFunction[f3], CallableTypeFromFunction[f3]))
191+
static_assert(not is_equivalent_to(CallableTypeFromFunction[f3], CallableTypeFromFunction[f4]))
192+
static_assert(not is_equivalent_to(CallableTypeFromFunction[f4], CallableTypeFromFunction[f3]))
193+
```
194+
195+
When the parameter names are different:
196+
197+
```py
198+
def f5(a: int) -> None: ...
199+
def f6(b: int) -> None: ...
200+
201+
static_assert(not is_equivalent_to(CallableTypeFromFunction[f5], CallableTypeFromFunction[f6]))
202+
```
203+
204+
When only one of the callable types has parameter names:
205+
206+
```py
207+
static_assert(not is_equivalent_to(CallableTypeFromFunction[f5], Callable[[int], None]))
208+
```
209+
210+
When the parameter kinds are different:
211+
212+
```py
213+
def f7(a: int, /) -> None: ...
214+
def f8(a: int) -> None: ...
215+
216+
static_assert(not is_equivalent_to(CallableTypeFromFunction[f7], CallableTypeFromFunction[f8]))
217+
```
218+
219+
When the annotated types of the parameters are not equivalent or absent in one or both of the
220+
callable types:
221+
222+
```py
223+
def f9(a: int) -> None: ...
224+
def f10(a: str) -> None: ...
225+
def f11(a) -> None: ...
226+
227+
static_assert(not is_equivalent_to(CallableTypeFromFunction[f9], CallableTypeFromFunction[f10]))
228+
static_assert(not is_equivalent_to(CallableTypeFromFunction[f10], CallableTypeFromFunction[f11]))
229+
static_assert(not is_equivalent_to(CallableTypeFromFunction[f11], CallableTypeFromFunction[f10]))
230+
static_assert(not is_equivalent_to(CallableTypeFromFunction[f11], CallableTypeFromFunction[f11]))
231+
```
232+
233+
When the default value for a parameter is present only in one of the callable type:
234+
235+
```py
236+
def f12(a: int) -> None: ...
237+
def f13(a: int = 2) -> None: ...
238+
239+
static_assert(not is_equivalent_to(CallableTypeFromFunction[f12], CallableTypeFromFunction[f13]))
240+
static_assert(not is_equivalent_to(CallableTypeFromFunction[f13], CallableTypeFromFunction[f12]))
241+
```
242+
121243
[the equivalence relation]: https://typing.readthedocs.io/en/latest/spec/glossary.html#term-equivalent

crates/red_knot_python_semantic/src/types.rs

+89-3
Original file line numberDiff line numberDiff line change
@@ -898,6 +898,10 @@ impl<'db> Type<'db> {
898898
left.is_equivalent_to(db, right)
899899
}
900900
(Type::Tuple(left), Type::Tuple(right)) => left.is_equivalent_to(db, right),
901+
(
902+
Type::Callable(CallableType::General(left)),
903+
Type::Callable(CallableType::General(right)),
904+
) => left.is_equivalent_to(db, right),
901905
_ => self == other && self.is_fully_static(db) && other.is_fully_static(db),
902906
}
903907
}
@@ -4362,10 +4366,8 @@ impl<'db> FunctionType<'db> {
43624366

43634367
/// Convert the `FunctionType` into a [`Type::Callable`].
43644368
///
4365-
/// Returns `None` if the function is overloaded. This powers the `CallableTypeFromFunction`
4366-
/// special form from the `knot_extensions` module.
4369+
/// This powers the `CallableTypeFromFunction` special form from the `knot_extensions` module.
43674370
pub(crate) fn into_callable_type(self, db: &'db dyn Db) -> Type<'db> {
4368-
// TODO: Add support for overloaded callables
43694371
Type::Callable(CallableType::General(GeneralCallableType::new(
43704372
db,
43714373
self.signature(db).clone(),
@@ -4611,6 +4613,90 @@ impl<'db> GeneralCallableType<'db> {
46114613
.is_some_and(|return_type| return_type.is_fully_static(db))
46124614
}
46134615

4616+
/// Return `true` if `self` represents the exact same set of possible runtime objects as `other`.
4617+
pub(crate) fn is_equivalent_to(self, db: &'db dyn Db, other: Self) -> bool {
4618+
let self_signature = self.signature(db);
4619+
let other_signature = other.signature(db);
4620+
4621+
let self_parameters = self_signature.parameters();
4622+
let other_parameters = other_signature.parameters();
4623+
4624+
if self_parameters.len() != other_parameters.len() {
4625+
return false;
4626+
}
4627+
4628+
if self_parameters.is_gradual() || other_parameters.is_gradual() {
4629+
return false;
4630+
}
4631+
4632+
// Check equivalence relationship between two optional types. If either of them is `None`,
4633+
// then it is not a fully static type which means it's not equivalent either.
4634+
let is_equivalent = |self_type: Option<Type<'db>>, other_type: Option<Type<'db>>| match (
4635+
self_type, other_type,
4636+
) {
4637+
(Some(self_type), Some(other_type)) => self_type.is_equivalent_to(db, other_type),
4638+
_ => false,
4639+
};
4640+
4641+
if !is_equivalent(self_signature.return_ty, other_signature.return_ty) {
4642+
return false;
4643+
}
4644+
4645+
for (self_parameter, other_parameter) in self_parameters.iter().zip(other_parameters) {
4646+
match (self_parameter.kind(), other_parameter.kind()) {
4647+
(
4648+
ParameterKind::PositionalOnly {
4649+
default_ty: self_default,
4650+
..
4651+
},
4652+
ParameterKind::PositionalOnly {
4653+
default_ty: other_default,
4654+
..
4655+
},
4656+
) if self_default.is_some() == other_default.is_some() => {}
4657+
4658+
(
4659+
ParameterKind::PositionalOrKeyword {
4660+
name: self_name,
4661+
default_ty: self_default,
4662+
},
4663+
ParameterKind::PositionalOrKeyword {
4664+
name: other_name,
4665+
default_ty: other_default,
4666+
},
4667+
) if self_default.is_some() == other_default.is_some()
4668+
&& self_name == other_name => {}
4669+
4670+
(ParameterKind::Variadic { .. }, ParameterKind::Variadic { .. }) => {}
4671+
4672+
(
4673+
ParameterKind::KeywordOnly {
4674+
name: self_name,
4675+
default_ty: self_default,
4676+
},
4677+
ParameterKind::KeywordOnly {
4678+
name: other_name,
4679+
default_ty: other_default,
4680+
},
4681+
) if self_default.is_some() == other_default.is_some()
4682+
&& self_name == other_name => {}
4683+
4684+
(ParameterKind::KeywordVariadic { .. }, ParameterKind::KeywordVariadic { .. }) => {}
4685+
4686+
_ => return false,
4687+
}
4688+
4689+
if !is_equivalent(
4690+
self_parameter.annotated_type(),
4691+
other_parameter.annotated_type(),
4692+
) {
4693+
return false;
4694+
}
4695+
}
4696+
4697+
true
4698+
}
4699+
46144700
/// Return `true` if `self` has exactly the same set of possible static materializations as
46154701
/// `other` (if `self` represents the same set of possible sets of possible runtime objects as
46164702
/// `other`).

crates/red_knot_python_semantic/src/types/signatures.rs

+5
Original file line numberDiff line numberDiff line change
@@ -601,6 +601,11 @@ impl<'db> Parameter<'db> {
601601
self.annotated_ty
602602
}
603603

604+
/// Kind of the parameter.
605+
pub(crate) fn kind(&self) -> &ParameterKind<'db> {
606+
&self.kind
607+
}
608+
604609
/// Name of the parameter (if it has one).
605610
pub(crate) fn name(&self) -> Option<&ast::name::Name> {
606611
match &self.kind {

0 commit comments

Comments
 (0)