Skip to content

Commit ca5184f

Browse files
committed
further improve handling of walruses in lambdas and comprehensions
1 parent 79bd553 commit ca5184f

File tree

4 files changed

+79
-48
lines changed

4 files changed

+79
-48
lines changed

crates/red_knot_python_semantic/resources/mdtest/import/star.md

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -212,9 +212,9 @@ print((
212212
))
213213
```
214214

215-
### Definitions in comprehension-like scopes are not global definitions
215+
### Definitions in function-like scopes are not global definitions
216216

217-
Except for some cases involving walrus expressions.
217+
Except for some cases involving walrus expressions inside comprehension scopes.
218218

219219
`a.py`:
220220

@@ -231,10 +231,13 @@ class Iterable:
231231
{b for b in Iterable()}
232232
{c: c for c in Iterable()}
233233
(d for d in Iterable())
234+
lambda e: (f := 42)
234235

235-
[(e := f * 2) for f in Iterable()]
236-
[g for h in Iterable() if (g := h - 10) > 0]
237-
{(i := j * 2): (k := j * 3) for j in Iterable()}
236+
# Definitions created by walruses in a comprehension scope are unique;
237+
# they "leak out" of the scope and are stored in the surrounding scope
238+
[(g := h * 2) for h in Iterable()]
239+
[i for j in Iterable() if (i := j - 10) > 0]
240+
{(k := l * 2): (m := l * 3) for l in Iterable()}
238241
```
239242

240243
`b.py`:
@@ -251,6 +254,8 @@ reveal_type(c) # revealed: Unknown
251254
# error: [unresolved-reference]
252255
reveal_type(d) # revealed: Unknown
253256
# error: [unresolved-reference]
257+
reveal_type(e) # revealed: Unknown
258+
# error: [unresolved-reference]
254259
reveal_type(f) # revealed: Unknown
255260
# error: [unresolved-reference]
256261
reveal_type(h) # revealed: Unknown
@@ -260,10 +265,10 @@ reveal_type(j) # revealed: Unknown
260265
# TODO: these should all reveal `int`
261266
# (we don't generally model elsewhere in red-knot that bindings from walruses
262267
# "leak" from comprehension scopes into outer scopes, but we should)
263-
reveal_type(e) # revealed: Unknown
264268
reveal_type(g) # revealed: Unknown
265269
reveal_type(i) # revealed: Unknown
266270
reveal_type(k) # revealed: Unknown
271+
reveal_type(m) # revealed: Unknown
267272
```
268273

269274
### An annotation without a value is a definition in a stub but not a `.py` file

crates/red_knot_python_semantic/src/semantic_index/re_exports.rs

Lines changed: 64 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -221,48 +221,46 @@ impl<'db> Visitor<'db> for ExportFinder<'db> {
221221
}
222222
}
223223

224-
ast::Expr::SetComp(ast::ExprSetComp {
225-
elt,
226-
generators,
227-
range: _,
228-
})
229-
| ast::Expr::ListComp(ast::ExprListComp {
230-
elt,
231-
generators,
232-
range: _,
233-
})
234-
| ast::Expr::Generator(ast::ExprGenerator {
235-
elt,
236-
generators,
237-
range: _,
238-
parenthesized: _,
239-
}) => {
240-
let mut walrus_finder = WalrusFinder {
241-
export_finder: self,
242-
};
243-
walrus_finder.visit_expr(elt);
244-
for generator in generators {
245-
walrus_finder.visit_comprehension(generator);
246-
}
247-
}
224+
ast::Expr::Lambda(_)
225+
| ast::Expr::BooleanLiteral(_)
226+
| ast::Expr::NoneLiteral(_)
227+
| ast::Expr::NumberLiteral(_)
228+
| ast::Expr::BytesLiteral(_)
229+
| ast::Expr::EllipsisLiteral(_)
230+
| ast::Expr::StringLiteral(_) => {}
248231

249-
ast::Expr::DictComp(ast::ExprDictComp {
250-
key,
251-
value,
252-
generators,
253-
range: _,
254-
}) => {
232+
// Walrus definitions "leak" from comprehension scopes into the comprehension's
233+
// enclosing scope; they thus need special handling
234+
ast::Expr::SetComp(_)
235+
| ast::Expr::ListComp(_)
236+
| ast::Expr::Generator(_)
237+
| ast::Expr::DictComp(_) => {
255238
let mut walrus_finder = WalrusFinder {
256239
export_finder: self,
257240
};
258-
walrus_finder.visit_expr(key);
259-
walrus_finder.visit_expr(value);
260-
for generator in generators {
261-
walrus_finder.visit_comprehension(generator);
262-
}
241+
walk_expr(&mut walrus_finder, expr);
263242
}
264243

265-
_ => walk_expr(self, expr),
244+
ast::Expr::BoolOp(_)
245+
| ast::Expr::Named(_)
246+
| ast::Expr::BinOp(_)
247+
| ast::Expr::UnaryOp(_)
248+
| ast::Expr::If(_)
249+
| ast::Expr::Attribute(_)
250+
| ast::Expr::Subscript(_)
251+
| ast::Expr::Starred(_)
252+
| ast::Expr::Call(_)
253+
| ast::Expr::Compare(_)
254+
| ast::Expr::Yield(_)
255+
| ast::Expr::YieldFrom(_)
256+
| ast::Expr::FString(_)
257+
| ast::Expr::Tuple(_)
258+
| ast::Expr::List(_)
259+
| ast::Expr::Slice(_)
260+
| ast::Expr::IpyEscapeCommand(_)
261+
| ast::Expr::Dict(_)
262+
| ast::Expr::Set(_)
263+
| ast::Expr::Await(_) => walk_expr(self, expr),
266264
}
267265
}
268266
}
@@ -274,10 +272,20 @@ struct WalrusFinder<'a, 'db> {
274272
impl<'db> Visitor<'db> for WalrusFinder<'_, 'db> {
275273
fn visit_expr(&mut self, expr: &'db ast::Expr) {
276274
match expr {
275+
// anything that creates a nested scope or that cannot contain a walrus
276+
// can be short-circuited
277277
ast::Expr::DictComp(_)
278278
| ast::Expr::SetComp(_)
279279
| ast::Expr::ListComp(_)
280-
| ast::Expr::Generator(_) => {}
280+
| ast::Expr::Generator(_)
281+
| ast::Expr::Lambda(_)
282+
| ast::Expr::BooleanLiteral(_)
283+
| ast::Expr::NoneLiteral(_)
284+
| ast::Expr::NumberLiteral(_)
285+
| ast::Expr::BytesLiteral(_)
286+
| ast::Expr::EllipsisLiteral(_)
287+
| ast::Expr::StringLiteral(_)
288+
| ast::Expr::Name(_) => {}
281289

282290
ast::Expr::Named(ast::ExprNamed {
283291
target,
@@ -294,7 +302,25 @@ impl<'db> Visitor<'db> for WalrusFinder<'_, 'db> {
294302
}
295303
}
296304

297-
_ => walk_expr(self, expr),
305+
ast::Expr::BoolOp(_)
306+
| ast::Expr::BinOp(_)
307+
| ast::Expr::UnaryOp(_)
308+
| ast::Expr::If(_)
309+
| ast::Expr::Attribute(_)
310+
| ast::Expr::Subscript(_)
311+
| ast::Expr::Starred(_)
312+
| ast::Expr::Call(_)
313+
| ast::Expr::Compare(_)
314+
| ast::Expr::Yield(_)
315+
| ast::Expr::YieldFrom(_)
316+
| ast::Expr::FString(_)
317+
| ast::Expr::Tuple(_)
318+
| ast::Expr::List(_)
319+
| ast::Expr::Slice(_)
320+
| ast::Expr::IpyEscapeCommand(_)
321+
| ast::Expr::Dict(_)
322+
| ast::Expr::Set(_)
323+
| ast::Expr::Await(_) => walk_expr(self, expr),
298324
}
299325
}
300326
}

crates/red_knot_python_semantic/src/semantic_model.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -149,9 +149,8 @@ macro_rules! impl_binding_has_ty {
149149
#[inline]
150150
fn inferred_type<'db>(&self, model: &SemanticModel<'db>) -> Type<'db> {
151151
let index = semantic_index(model.db, model.file);
152-
let definitions = index.definitions(self);
153-
debug_assert_eq!(definitions.len(), 1);
154-
binding_type(model.db, definitions[0])
152+
let binding = index.expect_single_definition(self);
153+
binding_type(model.db, binding)
155154
}
156155
}
157156
};

crates/red_knot_python_semantic/src/types/infer.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1339,7 +1339,8 @@ impl<'db> TypeInferenceBuilder<'db> {
13391339

13401340
fn infer_definition(&mut self, node: impl Into<DefinitionNodeKey> + std::fmt::Debug + Copy) {
13411341
let definition = self.index.expect_single_definition(node);
1342-
self.extend(infer_definition_types(self.db(), definition));
1342+
let result = infer_definition_types(self.db(), definition);
1343+
self.extend(result);
13431344
}
13441345

13451346
fn infer_function_definition_statement(&mut self, function: &ast::StmtFunctionDef) {

0 commit comments

Comments
 (0)