Skip to content

Commit 73d2c0a

Browse files
committed
[syntax-errors] Allow yield in base classes and annotations
Summary -- This PR fixes the issue pointed out by @JelleZijlstra in #17101 (comment). Namely, I conflated two very different errors from CPython: ```pycon >>> def m[T](x: (yield from 1)): ... File "<python-input-310>", line 1 def m[T](x: (yield from 1)): ... ^^^^^^^^^^^^ SyntaxError: yield expression cannot be used within the definition of a generic >>> def m(x: (yield from 1)): ... File "<python-input-311>", line 1 def m(x: (yield from 1)): ... ^^^^^^^^^^^^ SyntaxError: 'yield from' outside function >>> def outer(): ... def m(x: (yield from 1)): ... ... >>> ``` I thought the second error was the same as the first, but `yield` (and `yield from`) is actually valid in this position when inside a function scope. The same is true for base classes, as pointed out in the original comment. We don't currently raise an error for `yield` outside of a function, but that should be handled separately. On the upside, this had the benefit of removing the `InvalidExpressionPosition::BaseClass` variant, the `InvalidExpressionPosition::TypeAnnotation` variant, and the `allow_named_expr` field from the visitor because they were all no longer used. Test Plan -- Updated inline tests.
1 parent 5cee346 commit 73d2c0a

12 files changed

+611
-690
lines changed

crates/ruff_python_parser/resources/inline/err/invalid_annotation_class.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
class F[T](y := list): ...
2-
class G((yield 1)): ...
3-
class H((yield from 1)): ...
42
class I[T]((yield 1)): ...
53
class J[T]((yield from 1)): ...
64
class K[T: (yield 1)]: ... # yield in TypeVar

crates/ruff_python_parser/resources/inline/err/invalid_annotation_function.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,9 @@
11
def f[T]() -> (y := 3): ...
22
def g[T](arg: (x := 1)): ...
33
def h[T](x: (yield 1)): ...
4-
def i(x: (yield 1)): ...
54
def j[T]() -> (yield 1): ...
6-
def k() -> (yield 1): ...
75
def l[T](x: (yield from 1)): ...
8-
def m(x: (yield from 1)): ...
96
def n[T]() -> (yield from 1): ...
10-
def o() -> (yield from 1): ...
117
def p[T: (yield 1)](): ... # yield in TypeVar bound
128
def q[T = (yield 1)](): ... # yield in TypeVar default
139
def r[*Ts = (yield 1)](): ... # yield in TypeVarTuple default
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,4 @@
11
class F(y := list): ...
2+
def f():
3+
class G((yield 1)): ...
4+
class H((yield from 1)): ...
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,7 @@
11
def f() -> (y := 3): ...
22
def g(arg: (x := 1)): ...
3+
def outer():
4+
def i(x: (yield 1)): ...
5+
def k() -> (yield 1): ...
6+
def m(x: (yield from 1)): ...
7+
def o() -> (yield from 1): ...

crates/ruff_python_parser/src/semantic_errors.rs

Lines changed: 16 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -120,26 +120,27 @@ impl SemanticSyntaxChecker {
120120
fn check_annotation<Ctx: SemanticSyntaxContext>(stmt: &ast::Stmt, ctx: &Ctx) {
121121
match stmt {
122122
Stmt::FunctionDef(ast::StmtFunctionDef {
123-
type_params,
123+
type_params: Some(type_params),
124124
parameters,
125125
returns,
126126
..
127127
}) => {
128128
// test_ok valid_annotation_function
129129
// def f() -> (y := 3): ...
130130
// def g(arg: (x := 1)): ...
131+
// def outer():
132+
// def i(x: (yield 1)): ...
133+
// def k() -> (yield 1): ...
134+
// def m(x: (yield from 1)): ...
135+
// def o() -> (yield from 1): ...
131136

132137
// test_err invalid_annotation_function
133138
// def f[T]() -> (y := 3): ...
134139
// def g[T](arg: (x := 1)): ...
135140
// def h[T](x: (yield 1)): ...
136-
// def i(x: (yield 1)): ...
137141
// def j[T]() -> (yield 1): ...
138-
// def k() -> (yield 1): ...
139142
// def l[T](x: (yield from 1)): ...
140-
// def m(x: (yield from 1)): ...
141143
// def n[T]() -> (yield from 1): ...
142-
// def o() -> (yield from 1): ...
143144
// def p[T: (yield 1)](): ... # yield in TypeVar bound
144145
// def q[T = (yield 1)](): ... # yield in TypeVar default
145146
// def r[*Ts = (yield 1)](): ... # yield in TypeVarTuple default
@@ -148,20 +149,11 @@ impl SemanticSyntaxChecker {
148149
// def u[T = (x := 1)](): ... # named expr in TypeVar default
149150
// def v[*Ts = (x := 1)](): ... # named expr in TypeVarTuple default
150151
// def w[**Ts = (x := 1)](): ... # named expr in ParamSpec default
151-
let is_generic = type_params.is_some();
152152
let mut visitor = InvalidExpressionVisitor {
153-
allow_named_expr: !is_generic,
154-
position: InvalidExpressionPosition::TypeAnnotation,
153+
position: InvalidExpressionPosition::GenericDefinition,
155154
ctx,
156155
};
157-
if let Some(type_params) = type_params {
158-
visitor.visit_type_params(type_params);
159-
}
160-
if is_generic {
161-
visitor.position = InvalidExpressionPosition::GenericDefinition;
162-
} else {
163-
visitor.position = InvalidExpressionPosition::TypeAnnotation;
164-
}
156+
visitor.visit_type_params(type_params);
165157
for param in parameters
166158
.iter()
167159
.filter_map(ast::AnyParameterRef::annotation)
@@ -173,36 +165,29 @@ impl SemanticSyntaxChecker {
173165
}
174166
}
175167
Stmt::ClassDef(ast::StmtClassDef {
176-
type_params,
168+
type_params: Some(type_params),
177169
arguments,
178170
..
179171
}) => {
180172
// test_ok valid_annotation_class
181173
// class F(y := list): ...
174+
// def f():
175+
// class G((yield 1)): ...
176+
// class H((yield from 1)): ...
182177

183178
// test_err invalid_annotation_class
184179
// class F[T](y := list): ...
185-
// class G((yield 1)): ...
186-
// class H((yield from 1)): ...
187180
// class I[T]((yield 1)): ...
188181
// class J[T]((yield from 1)): ...
189182
// class K[T: (yield 1)]: ... # yield in TypeVar
190183
// class L[T: (x := 1)]: ... # named expr in TypeVar
191-
let is_generic = type_params.is_some();
192184
let mut visitor = InvalidExpressionVisitor {
193-
allow_named_expr: !is_generic,
194-
position: InvalidExpressionPosition::TypeAnnotation,
185+
position: InvalidExpressionPosition::TypeVarBound,
195186
ctx,
196187
};
197-
if let Some(type_params) = type_params {
198-
visitor.visit_type_params(type_params);
199-
}
200-
if is_generic {
201-
visitor.position = InvalidExpressionPosition::GenericDefinition;
202-
} else {
203-
visitor.position = InvalidExpressionPosition::BaseClass;
204-
}
188+
visitor.visit_type_params(type_params);
205189
if let Some(arguments) = arguments {
190+
visitor.position = InvalidExpressionPosition::GenericDefinition;
206191
visitor.visit_arguments(arguments);
207192
}
208193
}
@@ -217,7 +202,6 @@ impl SemanticSyntaxChecker {
217202
// type Y = (yield 1) # yield in value
218203
// type Y = (x := 1) # named expr in value
219204
let mut visitor = InvalidExpressionVisitor {
220-
allow_named_expr: false,
221205
position: InvalidExpressionPosition::TypeAlias,
222206
ctx,
223207
};
@@ -625,12 +609,6 @@ impl Display for SemanticSyntaxError {
625609
write!(f, "cannot delete `__debug__` on Python {python_version} (syntax was removed in 3.9)")
626610
}
627611
},
628-
SemanticSyntaxErrorKind::InvalidExpression(
629-
kind,
630-
InvalidExpressionPosition::BaseClass,
631-
) => {
632-
write!(f, "{kind} cannot be used as a base class")
633-
}
634612
SemanticSyntaxErrorKind::InvalidExpression(kind, position) => {
635613
write!(f, "{kind} cannot be used within a {position}")
636614
}
@@ -857,8 +835,6 @@ pub enum InvalidExpressionPosition {
857835
TypeVarDefault,
858836
TypeVarTupleDefault,
859837
ParamSpecDefault,
860-
TypeAnnotation,
861-
BaseClass,
862838
GenericDefinition,
863839
TypeAlias,
864840
}
@@ -870,9 +846,7 @@ impl Display for InvalidExpressionPosition {
870846
InvalidExpressionPosition::TypeVarDefault => "TypeVar default",
871847
InvalidExpressionPosition::TypeVarTupleDefault => "TypeVarTuple default",
872848
InvalidExpressionPosition::ParamSpecDefault => "ParamSpec default",
873-
InvalidExpressionPosition::TypeAnnotation => "type annotation",
874849
InvalidExpressionPosition::GenericDefinition => "generic definition",
875-
InvalidExpressionPosition::BaseClass => "base class",
876850
InvalidExpressionPosition::TypeAlias => "type alias",
877851
})
878852
}
@@ -1086,16 +1060,6 @@ impl<'a, Ctx: SemanticSyntaxContext> MatchPatternVisitor<'a, Ctx> {
10861060
}
10871061

10881062
struct InvalidExpressionVisitor<'a, Ctx> {
1089-
/// Allow named expressions (`x := ...`) to appear in annotations.
1090-
///
1091-
/// These are allowed in non-generic functions, for example:
1092-
///
1093-
/// ```python
1094-
/// def foo(arg: (x := int)): ... # ok
1095-
/// def foo[T](arg: (x := int)): ... # syntax error
1096-
/// ```
1097-
allow_named_expr: bool,
1098-
10991063
/// Context used for emitting errors.
11001064
ctx: &'a Ctx,
11011065

@@ -1108,7 +1072,7 @@ where
11081072
{
11091073
fn visit_expr(&mut self, expr: &Expr) {
11101074
match expr {
1111-
Expr::Named(ast::ExprNamed { range, .. }) if !self.allow_named_expr => {
1075+
Expr::Named(ast::ExprNamed { range, .. }) => {
11121076
SemanticSyntaxChecker::add_error(
11131077
self.ctx,
11141078
SemanticSyntaxErrorKind::InvalidExpression(

crates/ruff_python_parser/tests/snapshots/invalid_syntax@function_def_invalid_return_expr.py.snap

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -179,13 +179,3 @@ Module(
179179
3 | def foo() -> yield x: ...
180180
| ^^^^^^^ Syntax Error: Yield expression cannot be used here
181181
|
182-
183-
184-
## Semantic Syntax Errors
185-
186-
|
187-
1 | def foo() -> *int: ...
188-
2 | def foo() -> (*int): ...
189-
3 | def foo() -> yield x: ...
190-
| ^^^^^^^ Syntax Error: yield expression cannot be used within a type annotation
191-
|

0 commit comments

Comments
 (0)