Skip to content

Commit 058439d

Browse files
authored
[syntax-errors] Async comprehension in sync comprehension (#17177)
Summary -- Detect async comprehensions nested in sync comprehensions in async functions before Python 3.11, when this was [changed]. The actual logic of this rule is very straightforward, but properly tracking the async scopes took a bit of work. An alternative to the current approach is to offload the `in_async_context` check into the `SemanticSyntaxContext` trait, but that actually required much more extensive changes to the `TestContext` and also to ruff's semantic model, as you can see in the changes up to 31554b4. This version has the benefit of mostly centralizing the state tracking in `SemanticSyntaxChecker`, although there was some subtlety around deferred function body traversal that made the changes to `Checker` more intrusive too (hence the new linter test). The `Checkpoint` struct/system is obviously overkill for now since it's only tracking a single `bool`, but I thought it might be more useful later. [changed]: python/cpython#77527 Test Plan -- New inline tests and a new linter integration test.
1 parent dc02732 commit 058439d

18 files changed

+2076
-28
lines changed
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"id": "6a70904e-dbfe-441c-99ec-12e6cf57f8ba",
7+
"metadata": {},
8+
"outputs": [],
9+
"source": [
10+
"async def elements(n): yield n"
11+
]
12+
},
13+
{
14+
"cell_type": "code",
15+
"execution_count": null,
16+
"id": "5412fc2f-76eb-42c0-8db1-b5af6fdc46aa",
17+
"metadata": {},
18+
"outputs": [],
19+
"source": [
20+
"[x async for x in elements(5)] # okay, async at top level"
21+
]
22+
},
23+
{
24+
"cell_type": "code",
25+
"execution_count": null,
26+
"id": "dc3c94a7-2e64-42de-9351-260b3f41c3fd",
27+
"metadata": {
28+
"scrolled": true
29+
},
30+
"outputs": [],
31+
"source": [
32+
"[[x async for x in elements(5)] for i in range(5)] # error on 3.10, okay after"
33+
]
34+
}
35+
],
36+
"metadata": {
37+
"kernelspec": {
38+
"display_name": "Python 3 (ipykernel)",
39+
"language": "python",
40+
"name": "python3"
41+
},
42+
"language_info": {
43+
"codemirror_mode": {
44+
"name": "ipython",
45+
"version": 3
46+
},
47+
"file_extension": ".py",
48+
"mimetype": "text/x-python",
49+
"name": "python",
50+
"nbconvert_exporter": "python",
51+
"pygments_lexer": "ipython3",
52+
"version": "3.10.16"
53+
}
54+
},
55+
"nbformat": 4,
56+
"nbformat_minor": 5
57+
}

crates/ruff_linter/src/checkers/ast/mod.rs

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ use std::path::Path;
2727
use itertools::Itertools;
2828
use log::debug;
2929
use ruff_python_parser::semantic_errors::{
30-
SemanticSyntaxChecker, SemanticSyntaxContext, SemanticSyntaxError, SemanticSyntaxErrorKind,
30+
Checkpoint, SemanticSyntaxChecker, SemanticSyntaxContext, SemanticSyntaxError,
31+
SemanticSyntaxErrorKind,
3132
};
3233
use rustc_hash::{FxHashMap, FxHashSet};
3334

@@ -282,7 +283,7 @@ impl<'a> Checker<'a> {
282283
last_stmt_end: TextSize::default(),
283284
docstring_state: DocstringState::default(),
284285
target_version,
285-
semantic_checker: SemanticSyntaxChecker::new(),
286+
semantic_checker: SemanticSyntaxChecker::new(source_type),
286287
semantic_errors: RefCell::default(),
287288
}
288289
}
@@ -525,10 +526,14 @@ impl<'a> Checker<'a> {
525526
self.target_version
526527
}
527528

528-
fn with_semantic_checker(&mut self, f: impl FnOnce(&mut SemanticSyntaxChecker, &Checker)) {
529+
fn with_semantic_checker(
530+
&mut self,
531+
f: impl FnOnce(&mut SemanticSyntaxChecker, &Checker) -> Checkpoint,
532+
) -> Checkpoint {
529533
let mut checker = std::mem::take(&mut self.semantic_checker);
530-
f(&mut checker, self);
534+
let checkpoint = f(&mut checker, self);
531535
self.semantic_checker = checker;
536+
checkpoint
532537
}
533538
}
534539

@@ -576,7 +581,8 @@ impl SemanticSyntaxContext for Checker<'_> {
576581
| SemanticSyntaxErrorKind::InvalidExpression(..)
577582
| SemanticSyntaxErrorKind::DuplicateMatchKey(_)
578583
| SemanticSyntaxErrorKind::DuplicateMatchClassAttribute(_)
579-
| SemanticSyntaxErrorKind::InvalidStarExpression => {
584+
| SemanticSyntaxErrorKind::InvalidStarExpression
585+
| SemanticSyntaxErrorKind::AsyncComprehensionOutsideAsyncFunction(_) => {
580586
if self.settings.preview.is_enabled() {
581587
self.semantic_errors.borrow_mut().push(error);
582588
}
@@ -595,7 +601,13 @@ impl SemanticSyntaxContext for Checker<'_> {
595601

596602
impl<'a> Visitor<'a> for Checker<'a> {
597603
fn visit_stmt(&mut self, stmt: &'a Stmt) {
598-
self.with_semantic_checker(|semantic, context| semantic.visit_stmt(stmt, context));
604+
// For functions, defer semantic syntax error checks until the body of the function is
605+
// visited
606+
let checkpoint = if stmt.is_function_def_stmt() {
607+
None
608+
} else {
609+
Some(self.with_semantic_checker(|semantic, context| semantic.enter_stmt(stmt, context)))
610+
};
599611

600612
// Step 0: Pre-processing
601613
self.semantic.push_node(stmt);
@@ -1198,6 +1210,10 @@ impl<'a> Visitor<'a> for Checker<'a> {
11981210
self.semantic.flags = flags_snapshot;
11991211
self.semantic.pop_node();
12001212
self.last_stmt_end = stmt.end();
1213+
1214+
if let Some(checkpoint) = checkpoint {
1215+
self.semantic_checker.exit_stmt(checkpoint);
1216+
}
12011217
}
12021218

12031219
fn visit_annotation(&mut self, expr: &'a Expr) {
@@ -1208,7 +1224,8 @@ impl<'a> Visitor<'a> for Checker<'a> {
12081224
}
12091225

12101226
fn visit_expr(&mut self, expr: &'a Expr) {
1211-
self.with_semantic_checker(|semantic, context| semantic.visit_expr(expr, context));
1227+
let checkpoint =
1228+
self.with_semantic_checker(|semantic, context| semantic.enter_expr(expr, context));
12121229

12131230
// Step 0: Pre-processing
12141231
if self.source_type.is_stub()
@@ -1755,6 +1772,8 @@ impl<'a> Visitor<'a> for Checker<'a> {
17551772
self.semantic.flags = flags_snapshot;
17561773
analyze::expression(expr, self);
17571774
self.semantic.pop_node();
1775+
1776+
self.semantic_checker.exit_expr(checkpoint);
17581777
}
17591778

17601779
fn visit_except_handler(&mut self, except_handler: &'a ExceptHandler) {
@@ -2590,17 +2609,24 @@ impl<'a> Checker<'a> {
25902609
for snapshot in deferred_functions {
25912610
self.semantic.restore(snapshot);
25922611

2612+
let stmt = self.semantic.current_statement();
2613+
25932614
let Stmt::FunctionDef(ast::StmtFunctionDef {
25942615
body, parameters, ..
2595-
}) = self.semantic.current_statement()
2616+
}) = stmt
25962617
else {
25972618
unreachable!("Expected Stmt::FunctionDef")
25982619
};
25992620

2621+
let checkpoint = self
2622+
.with_semantic_checker(|semantic, context| semantic.enter_stmt(stmt, context));
2623+
26002624
self.visit_parameters(parameters);
26012625
// Set the docstring state before visiting the function body.
26022626
self.docstring_state = DocstringState::Expected(ExpectedDocstringKind::Function);
26032627
self.visit_body(body);
2628+
2629+
self.semantic_checker.exit_stmt(checkpoint);
26042630
}
26052631
}
26062632
self.semantic.restore(snapshot);

crates/ruff_linter/src/linter.rs

Lines changed: 127 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -777,14 +777,22 @@ mod tests {
777777
use std::path::Path;
778778

779779
use anyhow::Result;
780+
use ruff_python_ast::{PySourceType, PythonVersion};
781+
use ruff_python_codegen::Stylist;
782+
use ruff_python_index::Indexer;
783+
use ruff_python_parser::ParseOptions;
784+
use ruff_python_trivia::textwrap::dedent;
785+
use ruff_text_size::Ranged;
780786
use test_case::test_case;
781787

782788
use ruff_notebook::{Notebook, NotebookError};
783789

790+
use crate::linter::check_path;
791+
use crate::message::Message;
784792
use crate::registry::Rule;
785793
use crate::source_kind::SourceKind;
786794
use crate::test::{assert_notebook_path, test_contents, TestedNotebook};
787-
use crate::{assert_messages, settings};
795+
use crate::{assert_messages, directives, settings, Locator};
788796

789797
/// Construct a path to a Jupyter notebook in the `resources/test/fixtures/jupyter` directory.
790798
fn notebook_path(path: impl AsRef<Path>) -> std::path::PathBuf {
@@ -934,4 +942,122 @@ mod tests {
934942
}
935943
Ok(())
936944
}
945+
946+
/// Wrapper around `test_contents_syntax_errors` for testing a snippet of code instead of a
947+
/// file.
948+
fn test_snippet_syntax_errors(
949+
contents: &str,
950+
settings: &settings::LinterSettings,
951+
) -> Vec<Message> {
952+
let contents = dedent(contents);
953+
test_contents_syntax_errors(
954+
&SourceKind::Python(contents.to_string()),
955+
Path::new("<filename>"),
956+
settings,
957+
)
958+
}
959+
960+
/// A custom test runner that prints syntax errors in addition to other diagnostics. Adapted
961+
/// from `flakes` in pyflakes/mod.rs.
962+
fn test_contents_syntax_errors(
963+
source_kind: &SourceKind,
964+
path: &Path,
965+
settings: &settings::LinterSettings,
966+
) -> Vec<Message> {
967+
let source_type = PySourceType::from(path);
968+
let options =
969+
ParseOptions::from(source_type).with_target_version(settings.unresolved_target_version);
970+
let parsed = ruff_python_parser::parse_unchecked(source_kind.source_code(), options)
971+
.try_into_module()
972+
.expect("PySourceType always parses into a module");
973+
let locator = Locator::new(source_kind.source_code());
974+
let stylist = Stylist::from_tokens(parsed.tokens(), locator.contents());
975+
let indexer = Indexer::from_tokens(parsed.tokens(), locator.contents());
976+
let directives = directives::extract_directives(
977+
parsed.tokens(),
978+
directives::Flags::from_settings(settings),
979+
&locator,
980+
&indexer,
981+
);
982+
let mut messages = check_path(
983+
path,
984+
None,
985+
&locator,
986+
&stylist,
987+
&indexer,
988+
&directives,
989+
settings,
990+
settings::flags::Noqa::Enabled,
991+
source_kind,
992+
source_type,
993+
&parsed,
994+
settings.unresolved_target_version,
995+
);
996+
messages.sort_by_key(Ranged::start);
997+
messages
998+
}
999+
1000+
#[test_case(
1001+
"error_on_310",
1002+
"async def f(): return [[x async for x in foo(n)] for n in range(3)]",
1003+
PythonVersion::PY310
1004+
)]
1005+
#[test_case(
1006+
"okay_on_311",
1007+
"async def f(): return [[x async for x in foo(n)] for n in range(3)]",
1008+
PythonVersion::PY311
1009+
)]
1010+
#[test_case(
1011+
"okay_on_310",
1012+
"async def test(): return [[x async for x in elements(n)] async for n in range(3)]",
1013+
PythonVersion::PY310
1014+
)]
1015+
#[test_case(
1016+
"deferred_function_body",
1017+
"
1018+
async def f(): [x for x in foo()] and [x async for x in foo()]
1019+
async def f():
1020+
def g(): ...
1021+
[x async for x in foo()]
1022+
",
1023+
PythonVersion::PY310
1024+
)]
1025+
fn test_async_comprehension_in_sync_comprehension(
1026+
name: &str,
1027+
contents: &str,
1028+
python_version: PythonVersion,
1029+
) {
1030+
let snapshot = format!("async_comprehension_in_sync_comprehension_{name}_{python_version}");
1031+
let messages = test_snippet_syntax_errors(
1032+
contents,
1033+
&settings::LinterSettings {
1034+
rules: settings::rule_table::RuleTable::empty(),
1035+
unresolved_target_version: python_version,
1036+
preview: settings::types::PreviewMode::Enabled,
1037+
..Default::default()
1038+
},
1039+
);
1040+
assert_messages!(snapshot, messages);
1041+
}
1042+
1043+
#[test_case(PythonVersion::PY310)]
1044+
#[test_case(PythonVersion::PY311)]
1045+
fn test_async_comprehension_notebook(python_version: PythonVersion) -> Result<()> {
1046+
let snapshot =
1047+
format!("async_comprehension_in_sync_comprehension_notebook_{python_version}");
1048+
let path = Path::new("resources/test/fixtures/syntax_errors/async_comprehension.ipynb");
1049+
let messages = test_contents_syntax_errors(
1050+
&SourceKind::IpyNotebook(Notebook::from_path(path)?),
1051+
path,
1052+
&settings::LinterSettings {
1053+
unresolved_target_version: python_version,
1054+
rules: settings::rule_table::RuleTable::empty(),
1055+
preview: settings::types::PreviewMode::Enabled,
1056+
..Default::default()
1057+
},
1058+
);
1059+
assert_messages!(snapshot, messages);
1060+
1061+
Ok(())
1062+
}
9371063
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
---
2+
source: crates/ruff_linter/src/linter.rs
3+
---
4+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
---
2+
source: crates/ruff_linter/src/linter.rs
3+
---
4+
<filename>:1:27: SyntaxError: cannot use an asynchronous comprehension outside of an asynchronous function on Python 3.10 (syntax was added in 3.11)
5+
|
6+
1 | async def f(): return [[x async for x in foo(n)] for n in range(3)]
7+
| ^^^^^^^^^^^^^^^^^^^^^
8+
|
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
---
2+
source: crates/ruff_linter/src/linter.rs
3+
---
4+
resources/test/fixtures/syntax_errors/async_comprehension.ipynb:3:5: SyntaxError: cannot use an asynchronous comprehension outside of an asynchronous function on Python 3.10 (syntax was added in 3.11)
5+
|
6+
1 | async def elements(n): yield n
7+
2 | [x async for x in elements(5)] # okay, async at top level
8+
3 | [[x async for x in elements(5)] for i in range(5)] # error on 3.10, okay after
9+
| ^^^^^^^^^^^^^^^^^^^^^^^^^^
10+
|
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
---
2+
source: crates/ruff_linter/src/linter.rs
3+
---
4+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
---
2+
source: crates/ruff_linter/src/linter.rs
3+
---
4+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
---
2+
source: crates/ruff_linter/src/linter.rs
3+
---
4+
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# parse_options: {"target-version": "3.10"}
2+
async def f(): return [[x async for x in foo(n)] for n in range(3)] # list
3+
async def g(): return [{x: 1 async for x in foo(n)} for n in range(3)] # dict
4+
async def h(): return [{x async for x in foo(n)} for n in range(3)] # set
5+
async def i(): return [([y async for y in range(1)], [z for z in range(2)]) for x in range(5)]
6+
async def j(): return [([y for y in range(1)], [z async for z in range(2)]) for x in range(5)]

0 commit comments

Comments
 (0)