Skip to content

Commit 82ce2de

Browse files
committed
fix cast() calls that pass recursive protocols as the first argument
1 parent 85cbc33 commit 82ce2de

File tree

6 files changed

+152
-55
lines changed

6 files changed

+152
-55
lines changed

crates/ty_python_semantic/resources/mdtest/protocols.md

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1935,6 +1935,27 @@ def f(x: PGconn):
19351935
isinstance(x, Connection)
19361936
```
19371937

1938+
### Recursive protocols used as the first argument to `cast()`
1939+
1940+
These caused issues in an early version of our `Protocol` implementation due to the fact that we use
1941+
a recursive function in our `cast()` implementation to check whether a type contains `Unknown` or
1942+
`Todo`. Recklessly recursing into a type causes stack overflows if the type is recursive:
1943+
1944+
```toml
1945+
[environment]
1946+
python-version = "3.12"
1947+
```
1948+
1949+
```py
1950+
from typing import cast, Protocol
1951+
1952+
class Iterator[T](Protocol):
1953+
def __iter__(self) -> Iterator[T]: ...
1954+
1955+
def f(value: Iterator):
1956+
cast(Iterator, value) # error: [redundant-cast]
1957+
```
1958+
19381959
## TODO
19391960

19401961
Add tests for:

crates/ty_python_semantic/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ mod util;
4242
pub mod pull_types;
4343

4444
type FxOrderSet<V> = ordermap::set::OrderSet<V, BuildHasherDefault<FxHasher>>;
45+
type FxIndexSet<V> = indexmap::IndexSet<V, BuildHasherDefault<FxHasher>>;
4546
type FxIndexMap<K, V> = indexmap::IndexMap<K, V, BuildHasherDefault<FxHasher>>;
4647

4748
/// Returns the default registry with all known semantic lints.

crates/ty_python_semantic/src/types.rs

Lines changed: 91 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ pub(crate) use crate::types::narrow::infer_narrowing_constraint;
5353
use crate::types::signatures::{Parameter, ParameterForm, Parameters};
5454
use crate::types::tuple::{TupleSpec, TupleType};
5555
pub use crate::util::diagnostics::add_inferred_python_version_hint_to_diagnostic;
56-
use crate::{Db, FxOrderSet, Module, Program};
56+
use crate::{Db, FxIndexSet, FxOrderSet, Module, Program};
5757
pub(crate) use class::{ClassLiteral, ClassType, GenericAlias, KnownClass};
5858
use instance::Protocol;
5959
pub use instance::{NominalInstanceType, ProtocolInstanceType};
@@ -415,12 +415,18 @@ impl<'db> PropertyInstanceType<'db> {
415415
)
416416
}
417417

418-
fn any_over_type(self, db: &'db dyn Db, type_fn: &dyn Fn(Type<'db>) -> bool) -> bool {
418+
fn any_over_type_impl(
419+
self,
420+
db: &'db dyn Db,
421+
type_fn: &dyn Fn(Type<'db>) -> bool,
422+
recursive_fallback: bool,
423+
seen_types: &mut FxIndexSet<Type<'db>>,
424+
) -> bool {
419425
self.getter(db)
420-
.is_some_and(|ty| ty.any_over_type(db, type_fn))
421-
|| self
422-
.setter(db)
423-
.is_some_and(|ty| ty.any_over_type(db, type_fn))
426+
.is_some_and(|ty| ty.any_over_type_impl(db, type_fn, recursive_fallback, seen_types))
427+
|| self.setter(db).is_some_and(|ty| {
428+
ty.any_over_type_impl(db, type_fn, recursive_fallback, seen_types)
429+
})
424430
}
425431
}
426432

@@ -755,8 +761,27 @@ impl<'db> Type<'db> {
755761
}
756762
}
757763

764+
pub(crate) fn any_over_type(
765+
self,
766+
db: &'db dyn Db,
767+
type_fn: &dyn Fn(Type<'db>) -> bool,
768+
recursive_fallback: bool,
769+
) -> bool {
770+
self.any_over_type_impl(db, type_fn, recursive_fallback, &mut FxIndexSet::default())
771+
}
772+
758773
/// Return `true` if `self`, or any of the types contained in `self`, match the closure passed in.
759-
pub fn any_over_type(self, db: &'db dyn Db, type_fn: &dyn Fn(Type<'db>) -> bool) -> bool {
774+
pub(crate) fn any_over_type_impl(
775+
self,
776+
db: &'db dyn Db,
777+
type_fn: &dyn Fn(Type<'db>) -> bool,
778+
recursive_fallback: bool,
779+
seen_types: &mut FxIndexSet<Type<'db>>,
780+
) -> bool {
781+
if !seen_types.insert(self) {
782+
return recursive_fallback;
783+
}
784+
760785
if type_fn(self) {
761786
return true;
762787
}
@@ -787,64 +812,83 @@ impl<'db> Type<'db> {
787812
.types(db)
788813
.iter()
789814
.copied()
790-
.any(|ty| ty.any_over_type(db, type_fn)),
791-
792-
Self::Callable(callable) => callable.any_over_type(db, type_fn),
815+
.any(|ty| ty.any_over_type_impl(db, type_fn, recursive_fallback, seen_types)),
793816

794-
Self::SubclassOf(subclass_of) => {
795-
Type::from(subclass_of.subclass_of()).any_over_type(db, type_fn)
817+
Self::Callable(callable) => {
818+
callable.any_over_type_impl(db, type_fn, recursive_fallback, seen_types)
796819
}
797820

821+
Self::SubclassOf(subclass_of) => Type::from(subclass_of.subclass_of())
822+
.any_over_type_impl(db, type_fn, recursive_fallback, seen_types),
823+
798824
Self::TypeVar(typevar) => match typevar.bound_or_constraints(db) {
799825
None => false,
800826
Some(TypeVarBoundOrConstraints::UpperBound(bound)) => {
801-
bound.any_over_type(db, type_fn)
827+
bound.any_over_type_impl(db, type_fn, recursive_fallback, seen_types)
828+
}
829+
Some(TypeVarBoundOrConstraints::Constraints(constraints)) => {
830+
constraints.elements(db).iter().any(|constraint| {
831+
constraint.any_over_type_impl(db, type_fn, recursive_fallback, seen_types)
832+
})
802833
}
803-
Some(TypeVarBoundOrConstraints::Constraints(constraints)) => constraints
804-
.elements(db)
805-
.iter()
806-
.any(|constraint| constraint.any_over_type(db, type_fn)),
807834
},
808835

809836
Self::BoundSuper(bound_super) => {
810-
Type::from(bound_super.pivot_class(db)).any_over_type(db, type_fn)
811-
|| Type::from(bound_super.owner(db)).any_over_type(db, type_fn)
837+
Type::from(bound_super.pivot_class(db)).any_over_type_impl(
838+
db,
839+
type_fn,
840+
recursive_fallback,
841+
seen_types,
842+
) || Type::from(bound_super.owner(db)).any_over_type_impl(
843+
db,
844+
type_fn,
845+
recursive_fallback,
846+
seen_types,
847+
)
812848
}
813849

814850
Self::Tuple(tuple) => tuple
815851
.tuple(db)
816852
.all_elements()
817-
.any(|ty| ty.any_over_type(db, type_fn)),
853+
.any(|ty| ty.any_over_type_impl(db, type_fn, recursive_fallback, seen_types)),
818854

819855
Self::Union(union) => union
820856
.elements(db)
821857
.iter()
822-
.any(|ty| ty.any_over_type(db, type_fn)),
858+
.any(|ty| ty.any_over_type_impl(db, type_fn, recursive_fallback, seen_types)),
823859

824860
Self::Intersection(intersection) => {
825861
intersection
826862
.positive(db)
827863
.iter()
828-
.any(|ty| ty.any_over_type(db, type_fn))
829-
|| intersection
830-
.negative(db)
831-
.iter()
832-
.any(|ty| ty.any_over_type(db, type_fn))
864+
.any(|ty| ty.any_over_type_impl(db, type_fn, recursive_fallback, seen_types))
865+
|| intersection.negative(db).iter().any(|ty| {
866+
ty.any_over_type_impl(db, type_fn, recursive_fallback, seen_types)
867+
})
833868
}
834869

835-
Self::ProtocolInstance(protocol) => protocol.any_over_type(db, type_fn),
836-
Self::PropertyInstance(property) => property.any_over_type(db, type_fn),
870+
Self::ProtocolInstance(protocol) => {
871+
protocol.any_over_type_impl(db, type_fn, recursive_fallback, seen_types)
872+
}
873+
Self::PropertyInstance(property) => {
874+
property.any_over_type_impl(db, type_fn, recursive_fallback, seen_types)
875+
}
837876

838877
Self::NominalInstance(instance) => match instance.class {
839878
ClassType::NonGeneric(_) => false,
840-
ClassType::Generic(generic) => generic
841-
.specialization(db)
842-
.types(db)
843-
.iter()
844-
.any(|ty| ty.any_over_type(db, type_fn)),
879+
ClassType::Generic(generic) => {
880+
generic.specialization(db).types(db).iter().any(|ty| {
881+
ty.any_over_type_impl(db, type_fn, recursive_fallback, seen_types)
882+
})
883+
}
845884
},
846885

847-
Self::TypeIs(type_is) => type_is.return_type(db).any_over_type(db, type_fn),
886+
Self::TypeIs(type_is) => type_is.return_type(db).any_over_type_impl(
887+
db,
888+
type_fn,
889+
recursive_fallback,
890+
seen_types,
891+
),
848892
}
849893
}
850894

@@ -7298,15 +7342,21 @@ impl<'db> CallableType<'db> {
72987342
.is_equivalent_to(db, other.signatures(db))
72997343
}
73007344

7301-
fn any_over_type(self, db: &'db dyn Db, type_fn: &dyn Fn(Type<'db>) -> bool) -> bool {
7345+
fn any_over_type_impl(
7346+
self,
7347+
db: &'db dyn Db,
7348+
type_fn: &dyn Fn(Type<'db>) -> bool,
7349+
recursive_fallback: bool,
7350+
seen_types: &mut FxIndexSet<Type<'db>>,
7351+
) -> bool {
73027352
self.signatures(db).iter().any(|signature| {
73037353
signature.parameters().iter().any(|param| {
7304-
param
7305-
.annotated_type()
7306-
.is_some_and(|ty| ty.any_over_type(db, type_fn))
7307-
}) || signature
7308-
.return_ty
7309-
.is_some_and(|ty| ty.any_over_type(db, type_fn))
7354+
param.annotated_type().is_some_and(|ty| {
7355+
ty.any_over_type_impl(db, type_fn, recursive_fallback, seen_types)
7356+
})
7357+
}) || signature.return_ty.is_some_and(|ty| {
7358+
ty.any_over_type_impl(db, type_fn, recursive_fallback, seen_types)
7359+
})
73107360
})
73117361
}
73127362
}

crates/ty_python_semantic/src/types/function.rs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1145,11 +1145,18 @@ impl KnownFunction {
11451145
let [Some(casted_type), Some(source_type)] = parameter_types else {
11461146
return None;
11471147
};
1148-
let contains_unknown_or_todo =
1149-
|ty| matches!(ty, Type::Dynamic(dynamic) if dynamic != DynamicType::Any);
1148+
1149+
let contains_unknown_or_todo = |ty: Type<'db>| {
1150+
ty.any_over_type(
1151+
db,
1152+
&|ty| matches!(ty, Type::Dynamic(dynamic) if dynamic != DynamicType::Any),
1153+
false,
1154+
)
1155+
};
1156+
11501157
if source_type.is_equivalent_to(db, *casted_type)
1151-
&& !casted_type.any_over_type(db, &|ty| contains_unknown_or_todo(ty))
1152-
&& !source_type.any_over_type(db, &|ty| contains_unknown_or_todo(ty))
1158+
&& !contains_unknown_or_todo(*casted_type)
1159+
&& !contains_unknown_or_todo(*source_type)
11531160
{
11541161
let builder = context.report_lint(&REDUNDANT_CAST, call_expression)?;
11551162
builder.into_diagnostic(format_args!(

crates/ty_python_semantic/src/types/instance.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use super::{ClassType, KnownClass, SubclassOfType, Type, TypeVarVariance};
77
use crate::place::PlaceAndQualifiers;
88
use crate::types::tuple::TupleType;
99
use crate::types::{DynamicType, TypeMapping, TypeRelation, TypeVarInstance, TypeVisitor};
10-
use crate::{Db, FxOrderSet};
10+
use crate::{Db, FxIndexSet, FxOrderSet};
1111

1212
pub(super) use synthesized_protocol::SynthesizedProtocolType;
1313

@@ -230,12 +230,16 @@ impl<'db> ProtocolInstanceType<'db> {
230230
}
231231

232232
/// Return `true` if the types of any of the members match the closure passed in.
233-
pub(super) fn any_over_type(
233+
pub(super) fn any_over_type_impl(
234234
self,
235235
db: &'db dyn Db,
236236
type_fn: &dyn Fn(Type<'db>) -> bool,
237+
recursive_fallback: bool,
238+
seen_types: &mut FxIndexSet<Type<'db>>,
237239
) -> bool {
238-
self.inner.interface(db).any_over_type(db, type_fn)
240+
self.inner
241+
.interface(db)
242+
.any_over_type_impl(db, type_fn, recursive_fallback, seen_types)
239243
}
240244

241245
/// Return `true` if this protocol type has the given type relation to the protocol `other`.

crates/ty_python_semantic/src/types/protocol_class.rs

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use itertools::Itertools;
55
use ruff_python_ast::name::Name;
66

77
use crate::{
8-
Db, FxOrderSet,
8+
Db, FxIndexSet, FxOrderSet,
99
place::{Boundness, Place, PlaceAndQualifiers, place_from_bindings, place_from_declarations},
1010
semantic_index::{place_table, use_def_map},
1111
types::{
@@ -153,13 +153,15 @@ impl<'db> ProtocolInterface<'db> {
153153
}
154154

155155
/// Return `true` if the types of any of the members match the closure passed in.
156-
pub(super) fn any_over_type(
156+
pub(super) fn any_over_type_impl(
157157
self,
158158
db: &'db dyn Db,
159159
type_fn: &dyn Fn(Type<'db>) -> bool,
160+
recursive_fallback: bool,
161+
seen_types: &mut FxIndexSet<Type<'db>>,
160162
) -> bool {
161163
self.members(db)
162-
.any(|member| member.any_over_type(db, type_fn))
164+
.any(|member| member.any_over_type_impl(db, type_fn, recursive_fallback, seen_types))
163165
}
164166

165167
pub(super) fn normalized_impl(self, db: &'db dyn Db, visitor: &mut TypeVisitor<'db>) -> Self {
@@ -372,11 +374,23 @@ impl<'a, 'db> ProtocolMember<'a, 'db> {
372374
}
373375
}
374376

375-
fn any_over_type(&self, db: &'db dyn Db, type_fn: &dyn Fn(Type<'db>) -> bool) -> bool {
377+
fn any_over_type_impl(
378+
&self,
379+
db: &'db dyn Db,
380+
type_fn: &dyn Fn(Type<'db>) -> bool,
381+
recursive_fallback: bool,
382+
seen_types: &mut FxIndexSet<Type<'db>>,
383+
) -> bool {
376384
match &self.kind {
377-
ProtocolMemberKind::Method(callable) => callable.any_over_type(db, type_fn),
378-
ProtocolMemberKind::Property(property) => property.any_over_type(db, type_fn),
379-
ProtocolMemberKind::Other(ty) => ty.any_over_type(db, type_fn),
385+
ProtocolMemberKind::Method(callable) => {
386+
callable.any_over_type_impl(db, type_fn, recursive_fallback, seen_types)
387+
}
388+
ProtocolMemberKind::Property(property) => {
389+
property.any_over_type_impl(db, type_fn, recursive_fallback, seen_types)
390+
}
391+
ProtocolMemberKind::Other(ty) => {
392+
ty.any_over_type_impl(db, type_fn, recursive_fallback, seen_types)
393+
}
380394
}
381395
}
382396
}

0 commit comments

Comments
 (0)