@@ -4,6 +4,7 @@ use std::ops::Deref;
4
4
use rustc_hash:: FxHashSet ;
5
5
6
6
use crate :: types:: class_base:: ClassBase ;
7
+ use crate :: types:: generics:: Specialization ;
7
8
use crate :: types:: { ClassLiteralType , ClassType , Type } ;
8
9
use crate :: Db ;
9
10
@@ -47,9 +48,11 @@ impl<'db> Mro<'db> {
47
48
pub ( super ) fn of_class (
48
49
db : & ' db dyn Db ,
49
50
class : ClassLiteralType < ' db > ,
51
+ specialization : Option < Specialization < ' db > > ,
50
52
) -> Result < Self , MroError < ' db > > {
51
- Self :: of_class_impl ( db, class)
52
- . map_err ( |err| err. into_mro_error ( db, class. default_specialization ( db) ) )
53
+ Self :: of_class_impl ( db, class, specialization) . map_err ( |err| {
54
+ err. into_mro_error ( db, class. apply_optional_specialization ( db, specialization) )
55
+ } )
53
56
}
54
57
55
58
pub ( super ) fn from_error ( db : & ' db dyn Db , class : ClassType < ' db > ) -> Self {
@@ -63,22 +66,26 @@ impl<'db> Mro<'db> {
63
66
fn of_class_impl (
64
67
db : & ' db dyn Db ,
65
68
class : ClassLiteralType < ' db > ,
69
+ specialization : Option < Specialization < ' db > > ,
66
70
) -> Result < Self , MroErrorKind < ' db > > {
67
71
let class_bases = class. explicit_bases ( db) ;
68
72
69
73
if !class_bases. is_empty ( ) && class. inheritance_cycle ( db) . is_some ( ) {
70
74
// We emit errors for cyclically defined classes elsewhere.
71
75
// It's important that we don't even try to infer the MRO for a cyclically defined class,
72
76
// or we'll end up in an infinite loop.
73
- return Ok ( Mro :: from_error ( db, class. default_specialization ( db) ) ) ;
77
+ return Ok ( Mro :: from_error (
78
+ db,
79
+ class. apply_optional_specialization ( db, specialization) ,
80
+ ) ) ;
74
81
}
75
82
76
83
match class_bases {
77
84
// `builtins.object` is the special case:
78
85
// the only class in Python that has an MRO with length <2
79
86
[ ] if class. is_object ( db) => Ok ( Self :: from ( [
80
87
// object is not generic, so the default specialization should be a no-op
81
- ClassBase :: Class ( class. default_specialization ( db) ) ,
88
+ ClassBase :: Class ( class. apply_optional_specialization ( db, specialization ) ) ,
82
89
] ) ) ,
83
90
84
91
// All other classes in Python have an MRO with length >=2.
@@ -95,7 +102,7 @@ impl<'db> Mro<'db> {
95
102
// (<class '__main__.Foo'>, <class 'object'>)
96
103
// ```
97
104
[ ] => Ok ( Self :: from ( [
98
- ClassBase :: Class ( class. default_specialization ( db) ) ,
105
+ ClassBase :: Class ( class. apply_optional_specialization ( db, specialization ) ) ,
99
106
ClassBase :: object ( db) ,
100
107
] ) ) ,
101
108
@@ -107,11 +114,11 @@ impl<'db> Mro<'db> {
107
114
[ single_base] => ClassBase :: try_from_type ( db, * single_base) . map_or_else (
108
115
|| Err ( MroErrorKind :: InvalidBases ( Box :: from ( [ ( 0 , * single_base) ] ) ) ) ,
109
116
|single_base| {
110
- Ok (
111
- std :: iter :: once ( ClassBase :: Class ( class. default_specialization ( db) ) )
112
- . chain ( single_base . mro ( db ) )
113
- . collect ( ) ,
114
- )
117
+ Ok ( std :: iter :: once ( ClassBase :: Class (
118
+ class. apply_optional_specialization ( db, specialization ) ,
119
+ ) )
120
+ . chain ( single_base . mro ( db ) )
121
+ . collect ( ) )
115
122
} ,
116
123
) ,
117
124
@@ -136,7 +143,7 @@ impl<'db> Mro<'db> {
136
143
}
137
144
138
145
let mut seqs = vec ! [ VecDeque :: from( [ ClassBase :: Class (
139
- class. default_specialization ( db) ,
146
+ class. apply_optional_specialization ( db, specialization ) ,
140
147
) ] ) ] ;
141
148
for base in & valid_bases {
142
149
seqs. push ( base. mro ( db) . collect ( ) ) ;
@@ -152,7 +159,8 @@ impl<'db> Mro<'db> {
152
159
. filter_map ( |( index, base) | Some ( ( index, base. into_class ( ) ?) ) )
153
160
{
154
161
if !seen_bases. insert ( base) {
155
- duplicate_bases. push ( ( index, base. class_literal ( db) ) ) ;
162
+ let ( base_class_literal, _) = base. class_literal ( db) ;
163
+ duplicate_bases. push ( ( index, base_class_literal) ) ;
156
164
}
157
165
}
158
166
@@ -219,6 +227,9 @@ pub(super) struct MroIterator<'db> {
219
227
/// The class whose MRO we're iterating over
220
228
class : ClassLiteralType < ' db > ,
221
229
230
+ /// The specialization to apply to each MRO element, if any
231
+ specialization : Option < Specialization < ' db > > ,
232
+
222
233
/// Whether or not we've already yielded the first element of the MRO
223
234
first_element_yielded : bool ,
224
235
@@ -231,10 +242,15 @@ pub(super) struct MroIterator<'db> {
231
242
}
232
243
233
244
impl < ' db > MroIterator < ' db > {
234
- pub ( super ) fn new ( db : & ' db dyn Db , class : ClassLiteralType < ' db > ) -> Self {
245
+ pub ( super ) fn new (
246
+ db : & ' db dyn Db ,
247
+ class : ClassLiteralType < ' db > ,
248
+ specialization : Option < Specialization < ' db > > ,
249
+ ) -> Self {
235
250
Self {
236
251
db,
237
252
class,
253
+ specialization,
238
254
first_element_yielded : false ,
239
255
subsequent_elements : None ,
240
256
}
@@ -245,7 +261,7 @@ impl<'db> MroIterator<'db> {
245
261
fn full_mro_except_first_element ( & mut self ) -> impl Iterator < Item = ClassBase < ' db > > + ' _ {
246
262
self . subsequent_elements
247
263
. get_or_insert_with ( || {
248
- let mut full_mro_iter = match self . class . try_mro ( self . db ) {
264
+ let mut full_mro_iter = match self . class . try_mro ( self . db , self . specialization ) {
249
265
Ok ( mro) => mro. iter ( ) ,
250
266
Err ( error) => error. fallback_mro ( ) . iter ( ) ,
251
267
} ;
@@ -262,7 +278,10 @@ impl<'db> Iterator for MroIterator<'db> {
262
278
fn next ( & mut self ) -> Option < Self :: Item > {
263
279
if !self . first_element_yielded {
264
280
self . first_element_yielded = true ;
265
- return Some ( ClassBase :: Class ( self . class . default_specialization ( self . db ) ) ) ;
281
+ return Some ( ClassBase :: Class (
282
+ self . class
283
+ . apply_optional_specialization ( self . db , self . specialization ) ,
284
+ ) ) ;
266
285
}
267
286
self . full_mro_except_first_element ( ) . next ( )
268
287
}
0 commit comments