Skip to content

Commit 7339902

Browse files
authored
[red-knot] Optimise visibility constraints for *-import definitions (#17317)
1 parent ff376fc commit 7339902

File tree

2 files changed

+103
-25
lines changed

2 files changed

+103
-25
lines changed

crates/red_knot_python_semantic/src/semantic_index/builder.rs

Lines changed: 52 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -331,12 +331,15 @@ impl<'db> SemanticIndexBuilder<'db> {
331331
self.current_use_def_map_mut().merge(state);
332332
}
333333

334-
fn add_symbol(&mut self, name: Name) -> ScopedSymbolId {
334+
/// Return a 2-element tuple, where the first element is the [`ScopedSymbolId`] of the
335+
/// symbol added, and the second element is a boolean indicating whether the symbol was *newly*
336+
/// added or not
337+
fn add_symbol(&mut self, name: Name) -> (ScopedSymbolId, bool) {
335338
let (symbol_id, added) = self.current_symbol_table().add_symbol(name);
336339
if added {
337340
self.current_use_def_map_mut().add_symbol(symbol_id);
338341
}
339-
symbol_id
342+
(symbol_id, added)
340343
}
341344

342345
fn mark_symbol_bound(&mut self, id: ScopedSymbolId) {
@@ -516,6 +519,7 @@ impl<'db> SemanticIndexBuilder<'db> {
516519
}
517520

518521
/// Records a visibility constraint by applying it to all live bindings and declarations.
522+
#[must_use = "A visibility constraint must always be negated after it is added"]
519523
fn record_visibility_constraint(
520524
&mut self,
521525
predicate: Predicate<'db>,
@@ -747,7 +751,7 @@ impl<'db> SemanticIndexBuilder<'db> {
747751
..
748752
}) => (name, &None, default),
749753
};
750-
let symbol = self.add_symbol(name.id.clone());
754+
let (symbol, _) = self.add_symbol(name.id.clone());
751755
// TODO create Definition for PEP 695 typevars
752756
// note that the "bound" on the typevar is a totally different thing than whether
753757
// or not a name is "bound" by a typevar declaration; the latter is always true.
@@ -841,20 +845,20 @@ impl<'db> SemanticIndexBuilder<'db> {
841845
self.declare_parameter(parameter);
842846
}
843847
if let Some(vararg) = parameters.vararg.as_ref() {
844-
let symbol = self.add_symbol(vararg.name.id().clone());
848+
let (symbol, _) = self.add_symbol(vararg.name.id().clone());
845849
self.add_definition(
846850
symbol,
847851
DefinitionNodeRef::VariadicPositionalParameter(vararg),
848852
);
849853
}
850854
if let Some(kwarg) = parameters.kwarg.as_ref() {
851-
let symbol = self.add_symbol(kwarg.name.id().clone());
855+
let (symbol, _) = self.add_symbol(kwarg.name.id().clone());
852856
self.add_definition(symbol, DefinitionNodeRef::VariadicKeywordParameter(kwarg));
853857
}
854858
}
855859

856860
fn declare_parameter(&mut self, parameter: &'db ast::ParameterWithDefault) {
857-
let symbol = self.add_symbol(parameter.name().id().clone());
861+
let (symbol, _) = self.add_symbol(parameter.name().id().clone());
858862

859863
let definition = self.add_definition(symbol, parameter);
860864

@@ -1071,7 +1075,7 @@ where
10711075
// The symbol for the function name itself has to be evaluated
10721076
// at the end to match the runtime evaluation of parameter defaults
10731077
// and return-type annotations.
1074-
let symbol = self.add_symbol(name.id.clone());
1078+
let (symbol, _) = self.add_symbol(name.id.clone());
10751079
self.add_definition(symbol, function_def);
10761080
}
10771081
ast::Stmt::ClassDef(class) => {
@@ -1095,11 +1099,11 @@ where
10951099
);
10961100

10971101
// In Python runtime semantics, a class is registered after its scope is evaluated.
1098-
let symbol = self.add_symbol(class.name.id.clone());
1102+
let (symbol, _) = self.add_symbol(class.name.id.clone());
10991103
self.add_definition(symbol, class);
11001104
}
11011105
ast::Stmt::TypeAlias(type_alias) => {
1102-
let symbol = self.add_symbol(
1106+
let (symbol, _) = self.add_symbol(
11031107
type_alias
11041108
.name
11051109
.as_name_expr()
@@ -1133,7 +1137,7 @@ where
11331137
(Name::new(alias.name.id.split('.').next().unwrap()), false)
11341138
};
11351139

1136-
let symbol = self.add_symbol(symbol_name);
1140+
let (symbol, _) = self.add_symbol(symbol_name);
11371141
self.add_definition(
11381142
symbol,
11391143
ImportDefinitionNodeRef {
@@ -1200,7 +1204,7 @@ where
12001204
//
12011205
// For more details, see the doc-comment on `StarImportPlaceholderPredicate`.
12021206
for export in exported_names(self.db, referenced_module) {
1203-
let symbol_id = self.add_symbol(export.clone());
1207+
let (symbol_id, newly_added) = self.add_symbol(export.clone());
12041208
let node_ref = StarImportDefinitionNodeRef { node, symbol_id };
12051209
let star_import = StarImportPlaceholderPredicate::new(
12061210
self.db,
@@ -1210,13 +1214,38 @@ where
12101214
);
12111215
let pre_definition = self.flow_snapshot();
12121216
self.push_additional_definition(symbol_id, node_ref);
1213-
let constraint_id =
1214-
self.record_visibility_constraint(star_import.into());
1215-
let post_definition = self.flow_snapshot();
1216-
self.flow_restore(pre_definition.clone());
1217-
self.record_negated_visibility_constraint(constraint_id);
1218-
self.flow_merge(post_definition);
1219-
self.simplify_visibility_constraints(pre_definition);
1217+
1218+
// Fast path for if there were no previous definitions
1219+
// of the symbol defined through the `*` import:
1220+
// we can apply the visibility constraint to *only* the added definition,
1221+
// rather than all definitions
1222+
if newly_added {
1223+
let constraint_id = self
1224+
.current_use_def_map_mut()
1225+
.record_star_import_visibility_constraint(
1226+
star_import,
1227+
symbol_id,
1228+
);
1229+
1230+
let post_definition = self.flow_snapshot();
1231+
self.flow_restore(pre_definition);
1232+
1233+
self.current_use_def_map_mut()
1234+
.negate_star_import_visibility_constraint(
1235+
symbol_id,
1236+
constraint_id,
1237+
);
1238+
1239+
self.flow_merge(post_definition);
1240+
} else {
1241+
let constraint_id =
1242+
self.record_visibility_constraint(star_import.into());
1243+
let post_definition = self.flow_snapshot();
1244+
self.flow_restore(pre_definition.clone());
1245+
self.record_negated_visibility_constraint(constraint_id);
1246+
self.flow_merge(post_definition);
1247+
self.simplify_visibility_constraints(pre_definition);
1248+
}
12201249
}
12211250

12221251
continue;
@@ -1236,7 +1265,7 @@ where
12361265
self.has_future_annotations |= alias.name.id == "annotations"
12371266
&& node.module.as_deref() == Some("__future__");
12381267

1239-
let symbol = self.add_symbol(symbol_name.clone());
1268+
let (symbol, _) = self.add_symbol(symbol_name.clone());
12401269

12411270
self.add_definition(
12421271
symbol,
@@ -1636,7 +1665,7 @@ where
16361665
// which is invalid syntax. However, it's still pretty obvious here that the user
16371666
// *wanted* `e` to be bound, so we should still create a definition here nonetheless.
16381667
if let Some(symbol_name) = symbol_name {
1639-
let symbol = self.add_symbol(symbol_name.id.clone());
1668+
let (symbol, _) = self.add_symbol(symbol_name.id.clone());
16401669

16411670
self.add_definition(
16421671
symbol,
@@ -1721,7 +1750,7 @@ where
17211750
(ast::ExprContext::Del, _) => (false, true),
17221751
(ast::ExprContext::Invalid, _) => (false, false),
17231752
};
1724-
let symbol = self.add_symbol(id.clone());
1753+
let (symbol, _) = self.add_symbol(id.clone());
17251754

17261755
if is_use {
17271756
self.mark_symbol_used(symbol);
@@ -2007,7 +2036,7 @@ where
20072036
range: _,
20082037
}) = pattern
20092038
{
2010-
let symbol = self.add_symbol(name.id().clone());
2039+
let (symbol, _) = self.add_symbol(name.id().clone());
20112040
let state = self.current_match_case.as_ref().unwrap();
20122041
self.add_definition(
20132042
symbol,
@@ -2028,7 +2057,7 @@ where
20282057
rest: Some(name), ..
20292058
}) = pattern
20302059
{
2031-
let symbol = self.add_symbol(name.id().clone());
2060+
let (symbol, _) = self.add_symbol(name.id().clone());
20322061
let state = self.current_match_case.as_ref().unwrap();
20332062
self.add_definition(
20342063
symbol,

crates/red_knot_python_semantic/src/semantic_index/use_def.rs

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ use crate::semantic_index::narrowing_constraints::{
269269
NarrowingConstraints, NarrowingConstraintsBuilder, NarrowingConstraintsIterator,
270270
};
271271
use crate::semantic_index::predicate::{
272-
Predicate, Predicates, PredicatesBuilder, ScopedPredicateId,
272+
Predicate, Predicates, PredicatesBuilder, ScopedPredicateId, StarImportPlaceholderPredicate,
273273
};
274274
use crate::semantic_index::symbol::{FileScopeId, ScopedSymbolId};
275275
use crate::semantic_index::visibility_constraints::{
@@ -603,7 +603,7 @@ pub(super) struct UseDefMapBuilder<'db> {
603603
/// x # we store a reachability constraint of [test] for this use of `x`
604604
///
605605
/// y = 2
606-
///
606+
///
607607
/// # we record a visibility constraint of [test] here, which retroactively affects
608608
/// # the `y = 1` and the `y = 2` binding.
609609
/// else:
@@ -701,6 +701,34 @@ impl<'db> UseDefMapBuilder<'db> {
701701
.add_and_constraint(self.scope_start_visibility, constraint);
702702
}
703703

704+
#[must_use = "A `*`-import visibility constraint must always be negated after it is added"]
705+
pub(super) fn record_star_import_visibility_constraint(
706+
&mut self,
707+
star_import: StarImportPlaceholderPredicate<'db>,
708+
symbol: ScopedSymbolId,
709+
) -> StarImportVisibilityConstraintId {
710+
let predicate_id = self.add_predicate(star_import.into());
711+
let visibility_id = self.visibility_constraints.add_atom(predicate_id);
712+
self.symbol_states[symbol]
713+
.record_visibility_constraint(&mut self.visibility_constraints, visibility_id);
714+
StarImportVisibilityConstraintId(visibility_id)
715+
}
716+
717+
pub(super) fn negate_star_import_visibility_constraint(
718+
&mut self,
719+
symbol_id: ScopedSymbolId,
720+
constraint: StarImportVisibilityConstraintId,
721+
) {
722+
let negated_constraint = self
723+
.visibility_constraints
724+
.add_not_constraint(constraint.into_scoped_constraint_id());
725+
self.symbol_states[symbol_id]
726+
.record_visibility_constraint(&mut self.visibility_constraints, negated_constraint);
727+
self.scope_start_visibility = self
728+
.visibility_constraints
729+
.add_and_constraint(self.scope_start_visibility, negated_constraint);
730+
}
731+
704732
/// This method resets the visibility constraints for all symbols to a previous state
705733
/// *if* there have been no new declarations or bindings since then. Consider the
706734
/// following example:
@@ -900,3 +928,24 @@ impl<'db> UseDefMapBuilder<'db> {
900928
}
901929
}
902930
}
931+
932+
/// Newtype wrapper over [`ScopedVisibilityConstraintId`] to improve type safety.
933+
///
934+
/// By returning this type from [`UseDefMapBuilder::record_star_import_visibility_constraint`]
935+
/// rather than [`ScopedVisibilityConstraintId`] directly, we ensure that
936+
/// [`UseDefMapBuilder::negate_star_import_visibility_constraint`] must be called after the
937+
/// visibility constraint has been added, and we ensure that
938+
/// [`super::SemanticIndexBuilder::record_negated_visibility_constraint`] *cannot* be called with
939+
/// the narrowing constraint (which would lead to incorrect behaviour).
940+
///
941+
/// This type is defined here rather than in the [`super::visibility_constraints`] module
942+
/// because it should only ever be constructed and deconstructed from methods in the
943+
/// [`UseDefMapBuilder`].
944+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
945+
pub(super) struct StarImportVisibilityConstraintId(ScopedVisibilityConstraintId);
946+
947+
impl StarImportVisibilityConstraintId {
948+
fn into_scoped_constraint_id(self) -> ScopedVisibilityConstraintId {
949+
self.0
950+
}
951+
}

0 commit comments

Comments
 (0)