Skip to content

Commit 4dc2c25

Browse files
authored
[red-knot] Fix type inference for except* definitions (#13320)
1 parent b72d49b commit 4dc2c25

File tree

3 files changed

+179
-51
lines changed

3 files changed

+179
-51
lines changed

crates/red_knot_python_semantic/src/semantic_index/builder.rs

Lines changed: 48 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@ use crate::semantic_index::SemanticIndex;
2727
use crate::Db;
2828

2929
use super::constraint::{Constraint, PatternConstraint};
30-
use super::definition::{MatchPatternDefinitionNodeRef, WithItemDefinitionNodeRef};
30+
use super::definition::{
31+
ExceptHandlerDefinitionNodeRef, MatchPatternDefinitionNodeRef, WithItemDefinitionNodeRef,
32+
};
3133

3234
pub(super) struct SemanticIndexBuilder<'db> {
3335
// Builder state
@@ -696,6 +698,51 @@ where
696698
self.flow_merge(after_subject);
697699
}
698700
}
701+
ast::Stmt::Try(ast::StmtTry {
702+
body,
703+
handlers,
704+
orelse,
705+
finalbody,
706+
is_star,
707+
range: _,
708+
}) => {
709+
self.visit_body(body);
710+
711+
for except_handler in handlers {
712+
let ast::ExceptHandler::ExceptHandler(except_handler) = except_handler;
713+
let ast::ExceptHandlerExceptHandler {
714+
name: symbol_name,
715+
type_: handled_exceptions,
716+
body: handler_body,
717+
range: _,
718+
} = except_handler;
719+
720+
if let Some(handled_exceptions) = handled_exceptions {
721+
self.visit_expr(handled_exceptions);
722+
}
723+
724+
// If `handled_exceptions` above was `None`, it's something like `except as e:`,
725+
// which is invalid syntax. However, it's still pretty obvious here that the user
726+
// *wanted* `e` to be bound, so we should still create a definition here nonetheless.
727+
if let Some(symbol_name) = symbol_name {
728+
let symbol = self
729+
.add_or_update_symbol(symbol_name.id.clone(), SymbolFlags::IS_DEFINED);
730+
731+
self.add_definition(
732+
symbol,
733+
DefinitionNodeRef::ExceptHandler(ExceptHandlerDefinitionNodeRef {
734+
handler: except_handler,
735+
is_star: *is_star,
736+
}),
737+
);
738+
}
739+
740+
self.visit_body(handler_body);
741+
}
742+
743+
self.visit_body(orelse);
744+
self.visit_body(finalbody);
745+
}
699746
_ => {
700747
walk_stmt(self, stmt);
701748
}
@@ -958,30 +1005,6 @@ where
9581005

9591006
self.current_match_case.as_mut().unwrap().index += 1;
9601007
}
961-
962-
fn visit_except_handler(&mut self, except_handler: &'ast ast::ExceptHandler) {
963-
let ast::ExceptHandler::ExceptHandler(except_handler) = except_handler;
964-
let ast::ExceptHandlerExceptHandler {
965-
name: symbol_name,
966-
type_: handled_exceptions,
967-
body,
968-
range: _,
969-
} = except_handler;
970-
971-
if let Some(handled_exceptions) = handled_exceptions {
972-
self.visit_expr(handled_exceptions);
973-
}
974-
975-
// If `handled_exceptions` above was `None`, it's something like `except as e:`,
976-
// which is invalid syntax. However, it's still pretty obvious here that the user
977-
// *wanted* `e` to be bound, so we should still create a definition here nonetheless.
978-
if let Some(symbol_name) = symbol_name {
979-
let symbol = self.add_or_update_symbol(symbol_name.id.clone(), SymbolFlags::IS_DEFINED);
980-
self.add_definition(symbol, except_handler);
981-
}
982-
983-
self.visit_body(body);
984-
}
9851008
}
9861009

9871010
#[derive(Copy, Clone, Debug)]

crates/red_knot_python_semantic/src/semantic_index/definition.rs

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ pub(crate) enum DefinitionNodeRef<'a> {
5050
Parameter(ast::AnyParameterRef<'a>),
5151
WithItem(WithItemDefinitionNodeRef<'a>),
5252
MatchPattern(MatchPatternDefinitionNodeRef<'a>),
53-
ExceptHandler(&'a ast::ExceptHandlerExceptHandler),
53+
ExceptHandler(ExceptHandlerDefinitionNodeRef<'a>),
5454
}
5555

5656
impl<'a> From<&'a ast::StmtFunctionDef> for DefinitionNodeRef<'a> {
@@ -131,12 +131,6 @@ impl<'a> From<MatchPatternDefinitionNodeRef<'a>> for DefinitionNodeRef<'a> {
131131
}
132132
}
133133

134-
impl<'a> From<&'a ast::ExceptHandlerExceptHandler> for DefinitionNodeRef<'a> {
135-
fn from(node: &'a ast::ExceptHandlerExceptHandler) -> Self {
136-
Self::ExceptHandler(node)
137-
}
138-
}
139-
140134
#[derive(Copy, Clone, Debug)]
141135
pub(crate) struct ImportFromDefinitionNodeRef<'a> {
142136
pub(crate) node: &'a ast::StmtImportFrom,
@@ -162,6 +156,12 @@ pub(crate) struct ForStmtDefinitionNodeRef<'a> {
162156
pub(crate) is_async: bool,
163157
}
164158

159+
#[derive(Copy, Clone, Debug)]
160+
pub(crate) struct ExceptHandlerDefinitionNodeRef<'a> {
161+
pub(crate) handler: &'a ast::ExceptHandlerExceptHandler,
162+
pub(crate) is_star: bool,
163+
}
164+
165165
#[derive(Copy, Clone, Debug)]
166166
pub(crate) struct ComprehensionDefinitionNodeRef<'a> {
167167
pub(crate) iterable: &'a ast::Expr,
@@ -258,9 +258,13 @@ impl DefinitionNodeRef<'_> {
258258
identifier: AstNodeRef::new(parsed, identifier),
259259
index,
260260
}),
261-
DefinitionNodeRef::ExceptHandler(handler) => {
262-
DefinitionKind::ExceptHandler(AstNodeRef::new(parsed, handler))
263-
}
261+
DefinitionNodeRef::ExceptHandler(ExceptHandlerDefinitionNodeRef {
262+
handler,
263+
is_star,
264+
}) => DefinitionKind::ExceptHandler(ExceptHandlerDefinitionKind {
265+
handler: AstNodeRef::new(parsed.clone(), handler),
266+
is_star,
267+
}),
264268
}
265269
}
266270

@@ -293,7 +297,7 @@ impl DefinitionNodeRef<'_> {
293297
Self::MatchPattern(MatchPatternDefinitionNodeRef { identifier, .. }) => {
294298
identifier.into()
295299
}
296-
Self::ExceptHandler(handler) => handler.into(),
300+
Self::ExceptHandler(ExceptHandlerDefinitionNodeRef { handler, .. }) => handler.into(),
297301
}
298302
}
299303
}
@@ -314,7 +318,7 @@ pub enum DefinitionKind {
314318
ParameterWithDefault(AstNodeRef<ast::ParameterWithDefault>),
315319
WithItem(WithItemDefinitionKind),
316320
MatchPattern(MatchPatternDefinitionKind),
317-
ExceptHandler(AstNodeRef<ast::ExceptHandlerExceptHandler>),
321+
ExceptHandler(ExceptHandlerDefinitionKind),
318322
}
319323

320324
#[derive(Clone, Debug)]
@@ -430,6 +434,22 @@ impl ForStmtDefinitionKind {
430434
}
431435
}
432436

437+
#[derive(Clone, Debug)]
438+
pub struct ExceptHandlerDefinitionKind {
439+
handler: AstNodeRef<ast::ExceptHandlerExceptHandler>,
440+
is_star: bool,
441+
}
442+
443+
impl ExceptHandlerDefinitionKind {
444+
pub(crate) fn handled_exceptions(&self) -> Option<&ast::Expr> {
445+
self.handler.node().type_.as_deref()
446+
}
447+
448+
pub(crate) fn is_star(&self) -> bool {
449+
self.is_star
450+
}
451+
}
452+
433453
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
434454
pub(crate) struct DefinitionNodeKey(NodeKey);
435455

crates/red_knot_python_semantic/src/types/infer.rs

Lines changed: 99 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,9 @@ use ruff_text_size::Ranged;
4040
use crate::module_name::ModuleName;
4141
use crate::module_resolver::{file_to_module, resolve_module};
4242
use crate::semantic_index::ast_ids::{HasScopedAstId, HasScopedUseId, ScopedExpressionId};
43-
use crate::semantic_index::definition::{Definition, DefinitionKind, DefinitionNodeKey};
43+
use crate::semantic_index::definition::{
44+
Definition, DefinitionKind, DefinitionNodeKey, ExceptHandlerDefinitionKind,
45+
};
4446
use crate::semantic_index::expression::Expression;
4547
use crate::semantic_index::semantic_index;
4648
use crate::semantic_index::symbol::{NodeWithScopeKind, NodeWithScopeRef, ScopeId};
@@ -426,8 +428,8 @@ impl<'db> TypeInferenceBuilder<'db> {
426428
definition,
427429
);
428430
}
429-
DefinitionKind::ExceptHandler(handler) => {
430-
self.infer_except_handler_definition(handler, definition);
431+
DefinitionKind::ExceptHandler(except_handler_definition) => {
432+
self.infer_except_handler_definition(except_handler_definition, definition);
431433
}
432434
}
433435
}
@@ -821,22 +823,29 @@ impl<'db> TypeInferenceBuilder<'db> {
821823

822824
fn infer_except_handler_definition(
823825
&mut self,
824-
handler: &'db ast::ExceptHandlerExceptHandler,
826+
except_handler_definition: &ExceptHandlerDefinitionKind,
825827
definition: Definition<'db>,
826828
) {
827-
let node_ty = handler
828-
.type_
829-
.as_deref()
829+
let node_ty = except_handler_definition
830+
.handled_exceptions()
830831
.map(|ty| self.infer_expression(ty))
831832
.unwrap_or(Type::Unknown);
832833

833-
// TODO: anything that's a consistent subtype of
834-
// `type[BaseException] | tuple[type[BaseException], ...]` should be valid;
835-
// anything else should be invalid --Alex
836-
let symbol_ty = match node_ty {
837-
Type::Any | Type::Unknown => node_ty,
838-
Type::Class(class_ty) => Type::Instance(class_ty),
839-
_ => Type::Unknown,
834+
let symbol_ty = if except_handler_definition.is_star() {
835+
// TODO should be generic --Alex
836+
//
837+
// TODO should infer `ExceptionGroup` if all caught exceptions
838+
// are subclasses of `Exception` --Alex
839+
builtins_symbol_ty(self.db, "BaseExceptionGroup").to_instance(self.db)
840+
} else {
841+
// TODO: anything that's a consistent subtype of
842+
// `type[BaseException] | tuple[type[BaseException], ...]` should be valid;
843+
// anything else should be invalid --Alex
844+
match node_ty {
845+
Type::Any | Type::Unknown => node_ty,
846+
Type::Class(class_ty) => Type::Instance(class_ty),
847+
_ => Type::Unknown,
848+
}
840849
};
841850

842851
self.types.definitions.insert(definition, symbol_ty);
@@ -4563,6 +4572,82 @@ mod tests {
45634572
Ok(())
45644573
}
45654574

4575+
#[test]
4576+
fn except_star_handler_baseexception() -> anyhow::Result<()> {
4577+
let mut db = setup_db();
4578+
4579+
db.write_dedented(
4580+
"src/a.py",
4581+
"
4582+
try:
4583+
x
4584+
except* BaseException as e:
4585+
pass
4586+
",
4587+
)?;
4588+
4589+
assert_file_diagnostics(&db, "src/a.py", &[]);
4590+
4591+
// TODO: once we support `sys.version_info` branches,
4592+
// we can set `--target-version=py311` in this test
4593+
// and the inferred type will just be `BaseExceptionGroup` --Alex
4594+
assert_public_ty(&db, "src/a.py", "e", "Unknown | BaseExceptionGroup");
4595+
4596+
Ok(())
4597+
}
4598+
4599+
#[test]
4600+
fn except_star_handler() -> anyhow::Result<()> {
4601+
let mut db = setup_db();
4602+
4603+
db.write_dedented(
4604+
"src/a.py",
4605+
"
4606+
try:
4607+
x
4608+
except* OSError as e:
4609+
pass
4610+
",
4611+
)?;
4612+
4613+
assert_file_diagnostics(&db, "src/a.py", &[]);
4614+
4615+
// TODO: once we support `sys.version_info` branches,
4616+
// we can set `--target-version=py311` in this test
4617+
// and the inferred type will just be `BaseExceptionGroup` --Alex
4618+
//
4619+
// TODO more precise would be `ExceptionGroup[OSError]` --Alex
4620+
assert_public_ty(&db, "src/a.py", "e", "Unknown | BaseExceptionGroup");
4621+
4622+
Ok(())
4623+
}
4624+
4625+
#[test]
4626+
fn except_star_handler_multiple_types() -> anyhow::Result<()> {
4627+
let mut db = setup_db();
4628+
4629+
db.write_dedented(
4630+
"src/a.py",
4631+
"
4632+
try:
4633+
x
4634+
except* (TypeError, AttributeError) as e:
4635+
pass
4636+
",
4637+
)?;
4638+
4639+
assert_file_diagnostics(&db, "src/a.py", &[]);
4640+
4641+
// TODO: once we support `sys.version_info` branches,
4642+
// we can set `--target-version=py311` in this test
4643+
// and the inferred type will just be `BaseExceptionGroup` --Alex
4644+
//
4645+
// TODO more precise would be `ExceptionGroup[TypeError | AttributeError]` --Alex
4646+
assert_public_ty(&db, "src/a.py", "e", "Unknown | BaseExceptionGroup");
4647+
4648+
Ok(())
4649+
}
4650+
45664651
#[test]
45674652
fn basic_comprehension() -> anyhow::Result<()> {
45684653
let mut db = setup_db();

0 commit comments

Comments
 (0)