Skip to content

Commit 2564d04

Browse files
committed
Refactor where predicates, and reserve for attributes support
1 parent 328b759 commit 2564d04

File tree

46 files changed

+491
-375
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+491
-375
lines changed

compiler/rustc_ast/src/ast.rs

+31-6
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ use rustc_macros::{Decodable, Encodable, HashStable_Generic};
3232
pub use rustc_span::AttrId;
3333
use rustc_span::source_map::{Spanned, respan};
3434
use rustc_span::symbol::{Ident, Symbol, kw, sym};
35-
use rustc_span::{DUMMY_SP, ErrorGuaranteed, Span};
35+
use rustc_span::{DUMMY_SP, ErrorGuaranteed, Span, SyntaxContext};
3636
use thin_vec::{ThinVec, thin_vec};
3737

3838
pub use crate::format::*;
@@ -428,7 +428,32 @@ impl Default for WhereClause {
428428

429429
/// A single predicate in a where-clause.
430430
#[derive(Clone, Encodable, Decodable, Debug)]
431-
pub enum WherePredicate {
431+
pub struct WherePredicate {
432+
pub kind: WherePredicateKind,
433+
pub id: NodeId,
434+
pub span: Span,
435+
}
436+
437+
impl WherePredicate {
438+
pub fn with_kind(&self, kind: WherePredicateKind) -> WherePredicate {
439+
self.map_kind(None, |_| kind)
440+
}
441+
pub fn map_kind(
442+
&self,
443+
ctxt: Option<SyntaxContext>,
444+
f: impl FnOnce(&WherePredicateKind) -> WherePredicateKind,
445+
) -> WherePredicate {
446+
WherePredicate {
447+
kind: f(&self.kind),
448+
id: DUMMY_NODE_ID,
449+
span: ctxt.map_or(self.span, |ctxt| self.span.with_ctxt(ctxt)),
450+
}
451+
}
452+
}
453+
454+
/// Predicate kind in where-clause.
455+
#[derive(Clone, Encodable, Decodable, Debug)]
456+
pub enum WherePredicateKind {
432457
/// A type bound (e.g., `for<'c> Foo: Send + Clone + 'c`).
433458
BoundPredicate(WhereBoundPredicate),
434459
/// A lifetime predicate (e.g., `'a: 'b + 'c`).
@@ -437,12 +462,12 @@ pub enum WherePredicate {
437462
EqPredicate(WhereEqPredicate),
438463
}
439464

440-
impl WherePredicate {
465+
impl WherePredicateKind {
441466
pub fn span(&self) -> Span {
442467
match self {
443-
WherePredicate::BoundPredicate(p) => p.span,
444-
WherePredicate::RegionPredicate(p) => p.span,
445-
WherePredicate::EqPredicate(p) => p.span,
468+
WherePredicateKind::BoundPredicate(p) => p.span,
469+
WherePredicateKind::RegionPredicate(p) => p.span,
470+
WherePredicateKind::EqPredicate(p) => p.span,
446471
}
447472
}
448473
}

compiler/rustc_ast/src/ast_traits.rs

+13-3
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use crate::tokenstream::LazyAttrTokenStream;
1111
use crate::{
1212
Arm, AssocItem, AttrItem, AttrKind, AttrVec, Attribute, Block, Crate, Expr, ExprField,
1313
FieldDef, ForeignItem, GenericParam, Item, NodeId, Param, Pat, PatField, Path, Stmt, StmtKind,
14-
Ty, Variant, Visibility,
14+
Ty, Variant, Visibility, WherePredicate,
1515
};
1616

1717
/// A utility trait to reduce boilerplate.
@@ -79,6 +79,7 @@ impl_has_node_id!(
7979
Stmt,
8080
Ty,
8181
Variant,
82+
WherePredicate,
8283
);
8384

8485
impl<T: AstDeref<Target: HasNodeId>> HasNodeId for T {
@@ -127,7 +128,16 @@ macro_rules! impl_has_tokens_none {
127128
}
128129

129130
impl_has_tokens!(AssocItem, AttrItem, Block, Expr, ForeignItem, Item, Pat, Path, Ty, Visibility);
130-
impl_has_tokens_none!(Arm, ExprField, FieldDef, GenericParam, Param, PatField, Variant);
131+
impl_has_tokens_none!(
132+
Arm,
133+
ExprField,
134+
FieldDef,
135+
GenericParam,
136+
Param,
137+
PatField,
138+
Variant,
139+
WherePredicate
140+
);
131141

132142
impl<T: AstDeref<Target: HasTokens>> HasTokens for T {
133143
fn tokens(&self) -> Option<&LazyAttrTokenStream> {
@@ -290,7 +300,7 @@ impl_has_attrs!(
290300
PatField,
291301
Variant,
292302
);
293-
impl_has_attrs_none!(Attribute, AttrItem, Block, Pat, Path, Ty, Visibility);
303+
impl_has_attrs_none!(Attribute, AttrItem, Block, Pat, Path, Ty, Visibility, WherePredicate);
294304

295305
impl<T: AstDeref<Target: HasAttrs>> HasAttrs for T {
296306
const SUPPORTS_CUSTOM_INNER_ATTRS: bool = T::Target::SUPPORTS_CUSTOM_INNER_ATTRS;

compiler/rustc_ast/src/mut_visit.rs

+26-8
Original file line numberDiff line numberDiff line change
@@ -286,8 +286,15 @@ pub trait MutVisitor: Sized {
286286
walk_where_clause(self, where_clause);
287287
}
288288

289-
fn visit_where_predicate(&mut self, where_predicate: &mut WherePredicate) {
290-
walk_where_predicate(self, where_predicate);
289+
fn flat_map_where_predicate(
290+
&mut self,
291+
where_predicate: WherePredicate,
292+
) -> SmallVec<[WherePredicate; 1]> {
293+
walk_flat_map_where_predicate(self, where_predicate)
294+
}
295+
296+
fn visit_where_predicate_kind(&mut self, kind: &mut WherePredicateKind) {
297+
walk_where_predicate_kind(self, kind)
291298
}
292299

293300
fn visit_vis(&mut self, vis: &mut Visibility) {
@@ -987,26 +994,37 @@ fn walk_ty_alias_where_clauses<T: MutVisitor>(vis: &mut T, tawcs: &mut TyAliasWh
987994

988995
fn walk_where_clause<T: MutVisitor>(vis: &mut T, wc: &mut WhereClause) {
989996
let WhereClause { has_where_token: _, predicates, span } = wc;
990-
visit_thin_vec(predicates, |predicate| vis.visit_where_predicate(predicate));
997+
predicates.flat_map_in_place(|predicate| vis.flat_map_where_predicate(predicate));
991998
vis.visit_span(span);
992999
}
9931000

994-
fn walk_where_predicate<T: MutVisitor>(vis: &mut T, pred: &mut WherePredicate) {
995-
match pred {
996-
WherePredicate::BoundPredicate(bp) => {
1001+
pub fn walk_flat_map_where_predicate<T: MutVisitor>(
1002+
vis: &mut T,
1003+
mut pred: WherePredicate,
1004+
) -> SmallVec<[WherePredicate; 1]> {
1005+
let WherePredicate { ref mut kind, ref mut id, ref mut span } = pred;
1006+
vis.visit_id(id);
1007+
vis.visit_where_predicate_kind(kind);
1008+
vis.visit_span(span);
1009+
smallvec![pred]
1010+
}
1011+
1012+
pub fn walk_where_predicate_kind<T: MutVisitor>(vis: &mut T, kind: &mut WherePredicateKind) {
1013+
match kind {
1014+
WherePredicateKind::BoundPredicate(bp) => {
9971015
let WhereBoundPredicate { span, bound_generic_params, bounded_ty, bounds } = bp;
9981016
bound_generic_params.flat_map_in_place(|param| vis.flat_map_generic_param(param));
9991017
vis.visit_ty(bounded_ty);
10001018
visit_vec(bounds, |bound| vis.visit_param_bound(bound, BoundKind::Bound));
10011019
vis.visit_span(span);
10021020
}
1003-
WherePredicate::RegionPredicate(rp) => {
1021+
WherePredicateKind::RegionPredicate(rp) => {
10041022
let WhereRegionPredicate { span, lifetime, bounds } = rp;
10051023
vis.visit_lifetime(lifetime);
10061024
visit_vec(bounds, |bound| vis.visit_param_bound(bound, BoundKind::Bound));
10071025
vis.visit_span(span);
10081026
}
1009-
WherePredicate::EqPredicate(ep) => {
1027+
WherePredicateKind::EqPredicate(ep) => {
10101028
let WhereEqPredicate { span, lhs_ty, rhs_ty } = ep;
10111029
vis.visit_ty(lhs_ty);
10121030
vis.visit_ty(rhs_ty);

compiler/rustc_ast/src/visit.rs

+15-4
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,9 @@ pub trait Visitor<'ast>: Sized {
188188
fn visit_where_predicate(&mut self, p: &'ast WherePredicate) -> Self::Result {
189189
walk_where_predicate(self, p)
190190
}
191+
fn visit_where_predicate_kind(&mut self, k: &'ast WherePredicateKind) -> Self::Result {
192+
walk_where_predicate_kind(self, k)
193+
}
191194
fn visit_fn(&mut self, fk: FnKind<'ast>, _: Span, _: NodeId) -> Self::Result {
192195
walk_fn(self, fk)
193196
}
@@ -786,8 +789,16 @@ pub fn walk_where_predicate<'a, V: Visitor<'a>>(
786789
visitor: &mut V,
787790
predicate: &'a WherePredicate,
788791
) -> V::Result {
789-
match predicate {
790-
WherePredicate::BoundPredicate(WhereBoundPredicate {
792+
let WherePredicate { kind, id: _, span: _ } = predicate;
793+
visitor.visit_where_predicate_kind(kind)
794+
}
795+
796+
pub fn walk_where_predicate_kind<'a, V: Visitor<'a>>(
797+
visitor: &mut V,
798+
kind: &'a WherePredicateKind,
799+
) -> V::Result {
800+
match kind {
801+
WherePredicateKind::BoundPredicate(WhereBoundPredicate {
791802
bounded_ty,
792803
bounds,
793804
bound_generic_params,
@@ -797,11 +808,11 @@ pub fn walk_where_predicate<'a, V: Visitor<'a>>(
797808
try_visit!(visitor.visit_ty(bounded_ty));
798809
walk_list!(visitor, visit_param_bound, bounds, BoundKind::Bound);
799810
}
800-
WherePredicate::RegionPredicate(WhereRegionPredicate { lifetime, bounds, span: _ }) => {
811+
WherePredicateKind::RegionPredicate(WhereRegionPredicate { lifetime, bounds, span: _ }) => {
801812
try_visit!(visitor.visit_lifetime(lifetime, LifetimeCtxt::Bound));
802813
walk_list!(visitor, visit_param_bound, bounds, BoundKind::Bound);
803814
}
804-
WherePredicate::EqPredicate(WhereEqPredicate { lhs_ty, rhs_ty, span: _ }) => {
815+
WherePredicateKind::EqPredicate(WhereEqPredicate { lhs_ty, rhs_ty, span: _ }) => {
805816
try_visit!(visitor.visit_ty(lhs_ty));
806817
try_visit!(visitor.visit_ty(rhs_ty));
807818
}

compiler/rustc_ast_lowering/src/index.rs

+4-9
Original file line numberDiff line numberDiff line change
@@ -381,15 +381,10 @@ impl<'a, 'hir> Visitor<'hir> for NodeCollector<'a, 'hir> {
381381
}
382382

383383
fn visit_where_predicate(&mut self, predicate: &'hir WherePredicate<'hir>) {
384-
match predicate {
385-
WherePredicate::BoundPredicate(pred) => {
386-
self.insert(pred.span, pred.hir_id, Node::WhereBoundPredicate(pred));
387-
self.with_parent(pred.hir_id, |this| {
388-
intravisit::walk_where_predicate(this, predicate)
389-
})
390-
}
391-
_ => intravisit::walk_where_predicate(self, predicate),
392-
}
384+
self.insert(predicate.span, predicate.hir_id, Node::WherePredicate(predicate));
385+
self.with_parent(predicate.hir_id, |this| {
386+
intravisit::walk_where_predicate(this, predicate)
387+
});
393388
}
394389

395390
fn visit_array_length(&mut self, len: &'hir ArrayLen<'hir>) {

compiler/rustc_ast_lowering/src/item.rs

+32-27
Original file line numberDiff line numberDiff line change
@@ -1400,7 +1400,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
14001400
// keep track of the Span info. Now, `<dyn HirTyLowerer>::add_implicit_sized_bound`
14011401
// checks both param bounds and where clauses for `?Sized`.
14021402
for pred in &generics.where_clause.predicates {
1403-
let WherePredicate::BoundPredicate(bound_pred) = pred else {
1403+
let WherePredicateKind::BoundPredicate(ref bound_pred) = pred.kind else {
14041404
continue;
14051405
};
14061406
let compute_is_param = || {
@@ -1538,8 +1538,8 @@ impl<'hir> LoweringContext<'_, 'hir> {
15381538
});
15391539
let span = self.lower_span(span);
15401540

1541-
match kind {
1542-
GenericParamKind::Const { .. } => None,
1541+
let kind = match kind {
1542+
GenericParamKind::Const { .. } => return None,
15431543
GenericParamKind::Type { .. } => {
15441544
let def_id = self.local_def_id(id).to_def_id();
15451545
let hir_id = self.next_id();
@@ -1554,38 +1554,38 @@ impl<'hir> LoweringContext<'_, 'hir> {
15541554
let ty_id = self.next_id();
15551555
let bounded_ty =
15561556
self.ty_path(ty_id, param_span, hir::QPath::Resolved(None, ty_path));
1557-
Some(hir::WherePredicate::BoundPredicate(hir::WhereBoundPredicate {
1558-
hir_id: self.next_id(),
1557+
hir::WherePredicateKind::BoundPredicate(hir::WhereBoundPredicate {
15591558
bounded_ty: self.arena.alloc(bounded_ty),
15601559
bounds,
15611560
span,
15621561
bound_generic_params: &[],
15631562
origin,
1564-
}))
1563+
})
15651564
}
15661565
GenericParamKind::Lifetime => {
15671566
let ident = self.lower_ident(ident);
15681567
let lt_id = self.next_node_id();
15691568
let lifetime = self.new_named_lifetime(id, lt_id, ident);
1570-
Some(hir::WherePredicate::RegionPredicate(hir::WhereRegionPredicate {
1569+
hir::WherePredicateKind::RegionPredicate(hir::WhereRegionPredicate {
15711570
lifetime,
15721571
span,
15731572
bounds,
15741573
in_where_clause: false,
1575-
}))
1574+
})
15761575
}
1577-
}
1576+
};
1577+
Some(hir::WherePredicate { hir_id: self.next_id(), kind: self.arena.alloc(kind), span })
15781578
}
15791579

15801580
fn lower_where_predicate(&mut self, pred: &WherePredicate) -> hir::WherePredicate<'hir> {
1581-
match pred {
1582-
WherePredicate::BoundPredicate(WhereBoundPredicate {
1581+
let hir_id = self.lower_node_id(pred.id);
1582+
let kind = match &pred.kind {
1583+
WherePredicateKind::BoundPredicate(WhereBoundPredicate {
15831584
bound_generic_params,
15841585
bounded_ty,
15851586
bounds,
15861587
span,
1587-
}) => hir::WherePredicate::BoundPredicate(hir::WhereBoundPredicate {
1588-
hir_id: self.next_id(),
1588+
}) => hir::WherePredicateKind::BoundPredicate(hir::WhereBoundPredicate {
15891589
bound_generic_params: self
15901590
.lower_generic_params(bound_generic_params, hir::GenericParamSource::Binder),
15911591
bounded_ty: self
@@ -1597,26 +1597,31 @@ impl<'hir> LoweringContext<'_, 'hir> {
15971597
span: self.lower_span(*span),
15981598
origin: PredicateOrigin::WhereClause,
15991599
}),
1600-
WherePredicate::RegionPredicate(WhereRegionPredicate { lifetime, bounds, span }) => {
1601-
hir::WherePredicate::RegionPredicate(hir::WhereRegionPredicate {
1602-
span: self.lower_span(*span),
1603-
lifetime: self.lower_lifetime(lifetime),
1604-
bounds: self.lower_param_bounds(
1605-
bounds,
1606-
ImplTraitContext::Disallowed(ImplTraitPosition::Bound),
1607-
),
1608-
in_where_clause: true,
1609-
})
1610-
}
1611-
WherePredicate::EqPredicate(WhereEqPredicate { lhs_ty, rhs_ty, span }) => {
1612-
hir::WherePredicate::EqPredicate(hir::WhereEqPredicate {
1600+
WherePredicateKind::RegionPredicate(WhereRegionPredicate {
1601+
lifetime,
1602+
bounds,
1603+
span,
1604+
}) => hir::WherePredicateKind::RegionPredicate(hir::WhereRegionPredicate {
1605+
span: self.lower_span(*span),
1606+
lifetime: self.lower_lifetime(lifetime),
1607+
bounds: self.lower_param_bounds(
1608+
bounds,
1609+
ImplTraitContext::Disallowed(ImplTraitPosition::Bound),
1610+
),
1611+
in_where_clause: true,
1612+
}),
1613+
WherePredicateKind::EqPredicate(WhereEqPredicate { lhs_ty, rhs_ty, span }) => {
1614+
hir::WherePredicateKind::EqPredicate(hir::WhereEqPredicate {
16131615
lhs_ty: self
16141616
.lower_ty(lhs_ty, ImplTraitContext::Disallowed(ImplTraitPosition::Bound)),
16151617
rhs_ty: self
16161618
.lower_ty(rhs_ty, ImplTraitContext::Disallowed(ImplTraitPosition::Bound)),
16171619
span: self.lower_span(*span),
16181620
})
16191621
}
1620-
}
1622+
};
1623+
let kind = self.arena.alloc(kind);
1624+
let span = self.lower_span(pred.span);
1625+
hir::WherePredicate { hir_id, kind, span }
16211626
}
16221627
}

0 commit comments

Comments
 (0)