Skip to content

Commit aa64990

Browse files
committed
Specialize generic base class in generic subclass
1 parent d3fd822 commit aa64990

File tree

8 files changed

+149
-89
lines changed

8 files changed

+149
-89
lines changed

crates/red_knot_python_semantic/resources/mdtest/generics/classes.md

+1-2
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,7 @@ class Base[T]:
176176
class Sub[U](Base[U]): ...
177177

178178
reveal_type(Base[int].x) # revealed: int | None
179-
# TODO: revealed: int | None
180-
reveal_type(Sub[int].x) # revealed: T | None
179+
reveal_type(Sub[int].x) # revealed: int | None
181180
```
182181

183182
## Cyclic class definition

crates/red_knot_python_semantic/resources/mdtest/stubs/class.md

+1-2
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@ class Foo[T]: ...
1111
class Bar(Foo[Bar]): ...
1212

1313
reveal_type(Bar) # revealed: Literal[Bar]
14-
# TODO: Instead of `Literal[Foo]`, we might eventually want to show a type that involves the type parameter.
15-
reveal_type(Bar.__mro__) # revealed: tuple[Literal[Bar], Literal[Foo], Literal[object]]
14+
reveal_type(Bar.__mro__) # revealed: tuple[Literal[Bar], Foo[Bar], Literal[object]]
1615
```
1716

1817
## Access to attributes declared in stubs

crates/red_knot_python_semantic/src/types.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -931,7 +931,7 @@ impl<'db> Type<'db> {
931931
(Type::ClassLiteral(class), Type::SubclassOf(target_subclass_ty)) => target_subclass_ty
932932
.subclass_of()
933933
.into_class()
934-
.is_some_and(|target_class| class.is_subclass_of(db, target_class)),
934+
.is_some_and(|target_class| class.is_subclass_of(db, None, target_class)),
935935
(Type::GenericAlias(alias), Type::SubclassOf(target_subclass_ty)) => target_subclass_ty
936936
.subclass_of()
937937
.into_class()
@@ -1439,7 +1439,7 @@ impl<'db> Type<'db> {
14391439
| (Type::ClassLiteral(class_b), Type::SubclassOf(subclass_of_ty)) => {
14401440
match subclass_of_ty.subclass_of() {
14411441
ClassBase::Dynamic(_) => false,
1442-
ClassBase::Class(class_a) => !class_b.is_subclass_of(db, class_a),
1442+
ClassBase::Class(class_a) => !class_b.is_subclass_of(db, None, class_a),
14431443
}
14441444
}
14451445

@@ -1979,7 +1979,7 @@ impl<'db> Type<'db> {
19791979
"__get__" | "__set__" | "__delete__",
19801980
) => Some(Symbol::Unbound.into()),
19811981

1982-
_ => Some(class.class_member(db, name)),
1982+
_ => Some(class.class_member(db, None, name)),
19831983
}
19841984
}
19851985

crates/red_knot_python_semantic/src/types/class.rs

+99-62
Large diffs are not rendered by default.

crates/red_knot_python_semantic/src/types/class_base.rs

+6-1
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,12 @@ impl<'db> ClassBase<'db> {
6767
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
6868
match self.base {
6969
ClassBase::Dynamic(dynamic) => dynamic.fmt(f),
70-
ClassBase::Class(class) => write!(f, "<class '{}'>", class.name(self.db)),
70+
ClassBase::Class(class @ ClassType::NonGeneric(_)) => {
71+
write!(f, "<class '{}'>", class.name(self.db))
72+
}
73+
ClassBase::Class(ClassType::Generic(alias)) => {
74+
write!(f, "<class '{}'>", alias.display(self.db))
75+
}
7176
}
7277
}
7378
}

crates/red_knot_python_semantic/src/types/infer.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -755,7 +755,7 @@ impl<'db> TypeInferenceBuilder<'db> {
755755
}
756756

757757
// (3) Check that the class's MRO is resolvable
758-
match class.try_mro(self.db()).as_ref() {
758+
match class.try_mro(self.db(), None).as_ref() {
759759
Err(mro_error) => {
760760
match mro_error.reason() {
761761
MroErrorKind::DuplicateBases(duplicates) => {
@@ -4361,7 +4361,7 @@ impl<'db> TypeInferenceBuilder<'db> {
43614361
LookupError::Unbound(_) => {
43624362
let bound_on_instance = match value_type {
43634363
Type::ClassLiteral(class) => {
4364-
!class.instance_member(db, attr).symbol.is_unbound()
4364+
!class.instance_member(db, None, attr).symbol.is_unbound()
43654365
}
43664366
Type::SubclassOf(subclass_of @ SubclassOfType { .. }) => {
43674367
match subclass_of.subclass_of() {

crates/red_knot_python_semantic/src/types/mro.rs

+34-15
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use std::ops::Deref;
44
use rustc_hash::FxHashSet;
55

66
use crate::types::class_base::ClassBase;
7+
use crate::types::generics::Specialization;
78
use crate::types::{ClassLiteralType, ClassType, Type};
89
use crate::Db;
910

@@ -47,9 +48,11 @@ impl<'db> Mro<'db> {
4748
pub(super) fn of_class(
4849
db: &'db dyn Db,
4950
class: ClassLiteralType<'db>,
51+
specialization: Option<Specialization<'db>>,
5052
) -> Result<Self, MroError<'db>> {
51-
Self::of_class_impl(db, class)
52-
.map_err(|err| err.into_mro_error(db, class.default_specialization(db)))
53+
Self::of_class_impl(db, class, specialization).map_err(|err| {
54+
err.into_mro_error(db, class.apply_optional_specialization(db, specialization))
55+
})
5356
}
5457

5558
pub(super) fn from_error(db: &'db dyn Db, class: ClassType<'db>) -> Self {
@@ -63,22 +66,26 @@ impl<'db> Mro<'db> {
6366
fn of_class_impl(
6467
db: &'db dyn Db,
6568
class: ClassLiteralType<'db>,
69+
specialization: Option<Specialization<'db>>,
6670
) -> Result<Self, MroErrorKind<'db>> {
6771
let class_bases = class.explicit_bases(db);
6872

6973
if !class_bases.is_empty() && class.inheritance_cycle(db).is_some() {
7074
// We emit errors for cyclically defined classes elsewhere.
7175
// It's important that we don't even try to infer the MRO for a cyclically defined class,
7276
// or we'll end up in an infinite loop.
73-
return Ok(Mro::from_error(db, class.default_specialization(db)));
77+
return Ok(Mro::from_error(
78+
db,
79+
class.apply_optional_specialization(db, specialization),
80+
));
7481
}
7582

7683
match class_bases {
7784
// `builtins.object` is the special case:
7885
// the only class in Python that has an MRO with length <2
7986
[] if class.is_object(db) => Ok(Self::from([
8087
// object is not generic, so the default specialization should be a no-op
81-
ClassBase::Class(class.default_specialization(db)),
88+
ClassBase::Class(class.apply_optional_specialization(db, specialization)),
8289
])),
8390

8491
// All other classes in Python have an MRO with length >=2.
@@ -95,7 +102,7 @@ impl<'db> Mro<'db> {
95102
// (<class '__main__.Foo'>, <class 'object'>)
96103
// ```
97104
[] => Ok(Self::from([
98-
ClassBase::Class(class.default_specialization(db)),
105+
ClassBase::Class(class.apply_optional_specialization(db, specialization)),
99106
ClassBase::object(db),
100107
])),
101108

@@ -107,11 +114,11 @@ impl<'db> Mro<'db> {
107114
[single_base] => ClassBase::try_from_type(db, *single_base).map_or_else(
108115
|| Err(MroErrorKind::InvalidBases(Box::from([(0, *single_base)]))),
109116
|single_base| {
110-
Ok(
111-
std::iter::once(ClassBase::Class(class.default_specialization(db)))
112-
.chain(single_base.mro(db))
113-
.collect(),
114-
)
117+
Ok(std::iter::once(ClassBase::Class(
118+
class.apply_optional_specialization(db, specialization),
119+
))
120+
.chain(single_base.mro(db))
121+
.collect())
115122
},
116123
),
117124

@@ -136,7 +143,7 @@ impl<'db> Mro<'db> {
136143
}
137144

138145
let mut seqs = vec![VecDeque::from([ClassBase::Class(
139-
class.default_specialization(db),
146+
class.apply_optional_specialization(db, specialization),
140147
)])];
141148
for base in &valid_bases {
142149
seqs.push(base.mro(db).collect());
@@ -152,7 +159,8 @@ impl<'db> Mro<'db> {
152159
.filter_map(|(index, base)| Some((index, base.into_class()?)))
153160
{
154161
if !seen_bases.insert(base) {
155-
duplicate_bases.push((index, base.class_literal(db)));
162+
let (base_class_literal, _) = base.class_literal(db);
163+
duplicate_bases.push((index, base_class_literal));
156164
}
157165
}
158166

@@ -219,6 +227,9 @@ pub(super) struct MroIterator<'db> {
219227
/// The class whose MRO we're iterating over
220228
class: ClassLiteralType<'db>,
221229

230+
/// The specialization to apply to each MRO element, if any
231+
specialization: Option<Specialization<'db>>,
232+
222233
/// Whether or not we've already yielded the first element of the MRO
223234
first_element_yielded: bool,
224235

@@ -231,10 +242,15 @@ pub(super) struct MroIterator<'db> {
231242
}
232243

233244
impl<'db> MroIterator<'db> {
234-
pub(super) fn new(db: &'db dyn Db, class: ClassLiteralType<'db>) -> Self {
245+
pub(super) fn new(
246+
db: &'db dyn Db,
247+
class: ClassLiteralType<'db>,
248+
specialization: Option<Specialization<'db>>,
249+
) -> Self {
235250
Self {
236251
db,
237252
class,
253+
specialization,
238254
first_element_yielded: false,
239255
subsequent_elements: None,
240256
}
@@ -245,7 +261,7 @@ impl<'db> MroIterator<'db> {
245261
fn full_mro_except_first_element(&mut self) -> impl Iterator<Item = ClassBase<'db>> + '_ {
246262
self.subsequent_elements
247263
.get_or_insert_with(|| {
248-
let mut full_mro_iter = match self.class.try_mro(self.db) {
264+
let mut full_mro_iter = match self.class.try_mro(self.db, self.specialization) {
249265
Ok(mro) => mro.iter(),
250266
Err(error) => error.fallback_mro().iter(),
251267
};
@@ -262,7 +278,10 @@ impl<'db> Iterator for MroIterator<'db> {
262278
fn next(&mut self) -> Option<Self::Item> {
263279
if !self.first_element_yielded {
264280
self.first_element_yielded = true;
265-
return Some(ClassBase::Class(self.class.default_specialization(self.db)));
281+
return Some(ClassBase::Class(
282+
self.class
283+
.apply_optional_specialization(self.db, self.specialization),
284+
));
266285
}
267286
self.full_mro_except_first_element().next()
268287
}

crates/red_knot_python_semantic/src/types/slots.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,13 @@ pub(super) fn check_class_slots(
6666
continue;
6767
};
6868

69-
let solid_base = base.iter_mro(db).find_map(|current| {
69+
let solid_base = base.iter_mro(db, None).find_map(|current| {
7070
let ClassBase::Class(current) = current else {
7171
return None;
7272
};
7373

74-
match SlotsKind::from(db, current.class_literal(db)) {
74+
let (class_literal, _) = current.class_literal(db);
75+
match SlotsKind::from(db, class_literal) {
7576
SlotsKind::NotEmpty => Some(current),
7677
SlotsKind::NotSpecified | SlotsKind::Empty => None,
7778
SlotsKind::Dynamic => None,

0 commit comments

Comments
 (0)