Skip to content

Commit 75aa6cf

Browse files
committed
Auto merge of #774 - detrumi:should-continue, r=jackh726
Implement should_continue in chalk-recursive This just returns `NoSolution` if it shouldn't continue, but that should already be useful to rust-analyzer. Note: Cloning of `should_continue` is a workaround to a rustc bug ([#95734](rust-lang/rust#95734))
2 parents 7efd275 + f6ac6f5 commit 75aa6cf

File tree

5 files changed

+72
-29
lines changed

5 files changed

+72
-29
lines changed

chalk-engine/src/slg/aggregate.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ pub trait AggregateOps<I: Interner> {
1717
&self,
1818
root_goal: &UCanonical<InEnvironment<Goal<I>>>,
1919
answers: impl context::AnswerStream<I>,
20-
should_continue: impl std::ops::Fn() -> bool,
20+
should_continue: impl std::ops::Fn() -> bool + Clone,
2121
) -> Option<Solution<I>>;
2222
}
2323

@@ -28,7 +28,7 @@ impl<I: Interner> AggregateOps<I> for SlgContextOps<'_, I> {
2828
&self,
2929
root_goal: &UCanonical<InEnvironment<Goal<I>>>,
3030
mut answers: impl context::AnswerStream<I>,
31-
should_continue: impl std::ops::Fn() -> bool,
31+
should_continue: impl std::ops::Fn() -> bool + Clone,
3232
) -> Option<Solution<I>> {
3333
let interner = self.program.interner();
3434
let CompleteAnswer { subst, ambiguous } = match answers.next_answer(&should_continue) {

chalk-recursive/src/fixed_point.rs

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ where
4343
context: &mut RecursiveContext<K, V>,
4444
goal: &K,
4545
minimums: &mut Minimums,
46+
should_continue: impl std::ops::Fn() -> bool + Clone,
4647
) -> V;
4748
fn reached_fixed_point(self, old_value: &V, new_value: &V) -> bool;
4849
fn error_value(self) -> V;
@@ -104,22 +105,24 @@ where
104105
&mut self,
105106
canonical_goal: &K,
106107
solver_stuff: impl SolverStuff<K, V>,
108+
should_continue: impl std::ops::Fn() -> bool + Clone,
107109
) -> V {
108110
debug!("solve_root_goal(canonical_goal={:?})", canonical_goal);
109111
assert!(self.stack.is_empty());
110112
let minimums = &mut Minimums::new();
111-
self.solve_goal(canonical_goal, minimums, solver_stuff)
113+
self.solve_goal(canonical_goal, minimums, solver_stuff, should_continue)
112114
}
113115

114116
/// Attempt to solve a goal that has been fully broken down into leaf form
115117
/// and canonicalized. This is where the action really happens, and is the
116118
/// place where we would perform caching in rustc (and may eventually do in Chalk).
117-
#[instrument(level = "info", skip(self, minimums, solver_stuff,))]
119+
#[instrument(level = "info", skip(self, minimums, solver_stuff, should_continue))]
118120
pub fn solve_goal(
119121
&mut self,
120122
goal: &K,
121123
minimums: &mut Minimums,
122124
solver_stuff: impl SolverStuff<K, V>,
125+
should_continue: impl std::ops::Fn() -> bool + Clone,
123126
) -> V {
124127
// First check the cache.
125128
if let Some(cache) = &self.cache {
@@ -159,7 +162,8 @@ where
159162
let depth = self.stack.push(coinductive_goal);
160163
let dfn = self.search_graph.insert(goal, depth, initial_solution);
161164

162-
let subgoal_minimums = self.solve_new_subgoal(goal, depth, dfn, solver_stuff);
165+
let subgoal_minimums =
166+
self.solve_new_subgoal(goal, depth, dfn, solver_stuff, should_continue);
163167

164168
self.search_graph[dfn].links = subgoal_minimums;
165169
self.search_graph[dfn].stack_depth = None;
@@ -190,13 +194,14 @@ where
190194
}
191195
}
192196

193-
#[instrument(level = "debug", skip(self, solver_stuff))]
197+
#[instrument(level = "debug", skip(self, solver_stuff, should_continue))]
194198
fn solve_new_subgoal(
195199
&mut self,
196200
canonical_goal: &K,
197201
depth: StackDepth,
198202
dfn: DepthFirstNumber,
199203
solver_stuff: impl SolverStuff<K, V>,
204+
should_continue: impl std::ops::Fn() -> bool + Clone,
200205
) -> Minimums {
201206
// We start with `answer = None` and try to solve the goal. At the end of the iteration,
202207
// `answer` will be updated with the result of the solving process. If we detect a cycle
@@ -209,7 +214,12 @@ where
209214
// so this function will eventually be constant and the loop terminates.
210215
loop {
211216
let minimums = &mut Minimums::new();
212-
let current_answer = solver_stuff.solve_iteration(self, canonical_goal, minimums);
217+
let current_answer = solver_stuff.solve_iteration(
218+
self,
219+
canonical_goal,
220+
minimums,
221+
should_continue.clone(), // Note: cloning required as workaround for https://github.com/rust-lang/rust/issues/95734
222+
);
213223

214224
debug!(
215225
"solve_new_subgoal: loop iteration result = {:?} with minimums {:?}",

chalk-recursive/src/fulfill.rs

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -342,24 +342,31 @@ impl<'s, I: Interner, Solver: SolveDatabase<I>> Fulfill<'s, I, Solver> {
342342
Ok(())
343343
}
344344

345-
#[instrument(level = "debug", skip(self, minimums))]
345+
#[instrument(level = "debug", skip(self, minimums, should_continue))]
346346
fn prove(
347347
&mut self,
348348
wc: InEnvironment<Goal<I>>,
349349
minimums: &mut Minimums,
350+
should_continue: impl std::ops::Fn() -> bool + Clone,
350351
) -> Fallible<PositiveSolution<I>> {
351352
let interner = self.solver.interner();
352353
let (quantified, free_vars) = canonicalize(&mut self.infer, interner, wc);
353354
let (quantified, universes) = u_canonicalize(&mut self.infer, interner, &quantified);
354-
let result = self.solver.solve_goal(quantified, minimums);
355+
let result = self
356+
.solver
357+
.solve_goal(quantified, minimums, should_continue);
355358
Ok(PositiveSolution {
356359
free_vars,
357360
universes,
358361
solution: result?,
359362
})
360363
}
361364

362-
fn refute(&mut self, goal: InEnvironment<Goal<I>>) -> Fallible<NegativeSolution> {
365+
fn refute(
366+
&mut self,
367+
goal: InEnvironment<Goal<I>>,
368+
should_continue: impl std::ops::Fn() -> bool + Clone,
369+
) -> Fallible<NegativeSolution> {
363370
let canonicalized = match self
364371
.infer
365372
.invert_then_canonicalize(self.solver.interner(), goal)
@@ -376,7 +383,10 @@ impl<'s, I: Interner, Solver: SolveDatabase<I>> Fulfill<'s, I, Solver> {
376383
let (quantified, _) =
377384
u_canonicalize(&mut self.infer, self.solver.interner(), &canonicalized);
378385
let mut minimums = Minimums::new(); // FIXME -- minimums here seems wrong
379-
if let Ok(solution) = self.solver.solve_goal(quantified, &mut minimums) {
386+
if let Ok(solution) = self
387+
.solver
388+
.solve_goal(quantified, &mut minimums, should_continue)
389+
{
380390
if solution.is_unique() {
381391
Err(NoSolution)
382392
} else {
@@ -431,7 +441,11 @@ impl<'s, I: Interner, Solver: SolveDatabase<I>> Fulfill<'s, I, Solver> {
431441
}
432442
}
433443

434-
fn fulfill(&mut self, minimums: &mut Minimums) -> Fallible<Outcome> {
444+
fn fulfill(
445+
&mut self,
446+
minimums: &mut Minimums,
447+
should_continue: impl std::ops::Fn() -> bool + Clone,
448+
) -> Fallible<Outcome> {
435449
debug_span!("fulfill", obligations=?self.obligations);
436450

437451
// Try to solve all the obligations. We do this via a fixed-point
@@ -460,7 +474,7 @@ impl<'s, I: Interner, Solver: SolveDatabase<I>> Fulfill<'s, I, Solver> {
460474
free_vars,
461475
universes,
462476
solution,
463-
} = self.prove(wc.clone(), minimums)?;
477+
} = self.prove(wc.clone(), minimums, should_continue.clone())?;
464478

465479
if let Some(constrained_subst) = solution.definite_subst(self.interner()) {
466480
// If the substitution is trivial, we won't actually make any progress by applying it!
@@ -484,7 +498,7 @@ impl<'s, I: Interner, Solver: SolveDatabase<I>> Fulfill<'s, I, Solver> {
484498
solution.is_ambig()
485499
}
486500
Obligation::Refute(goal) => {
487-
let answer = self.refute(goal.clone())?;
501+
let answer = self.refute(goal.clone(), should_continue.clone())?;
488502
answer == NegativeSolution::Ambiguous
489503
}
490504
};
@@ -514,8 +528,12 @@ impl<'s, I: Interner, Solver: SolveDatabase<I>> Fulfill<'s, I, Solver> {
514528
/// Try to fulfill all pending obligations and build the resulting
515529
/// solution. The returned solution will transform `subst` substitution with
516530
/// the outcome of type inference by updating the replacements it provides.
517-
pub(super) fn solve(mut self, minimums: &mut Minimums) -> Fallible<Solution<I>> {
518-
let outcome = match self.fulfill(minimums) {
531+
pub(super) fn solve(
532+
mut self,
533+
minimums: &mut Minimums,
534+
should_continue: impl std::ops::Fn() -> bool + Clone,
535+
) -> Fallible<Solution<I>> {
536+
let outcome = match self.fulfill(minimums, should_continue.clone()) {
519537
Ok(o) => o,
520538
Err(e) => return Err(e),
521539
};
@@ -567,7 +585,7 @@ impl<'s, I: Interner, Solver: SolveDatabase<I>> Fulfill<'s, I, Solver> {
567585
free_vars,
568586
universes,
569587
solution,
570-
} = self.prove(goal, minimums).unwrap();
588+
} = self.prove(goal, minimums, should_continue.clone()).unwrap();
571589
if let Some(constrained_subst) =
572590
solution.constrained_subst(self.solver.interner())
573591
{

chalk-recursive/src/recursive.rs

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,9 @@ impl<I: Interner> SolverStuff<UCanonicalGoal<I>, Fallible<Solution<I>>> for &dyn
7676
context: &mut RecursiveContext<UCanonicalGoal<I>, Fallible<Solution<I>>>,
7777
goal: &UCanonicalGoal<I>,
7878
minimums: &mut Minimums,
79+
should_continue: impl std::ops::Fn() -> bool + Clone,
7980
) -> Fallible<Solution<I>> {
80-
Solver::new(context, self).solve_iteration(goal, minimums)
81+
Solver::new(context, self).solve_iteration(goal, minimums, should_continue)
8182
}
8283

8384
fn reached_fixed_point(
@@ -108,8 +109,10 @@ impl<'me, I: Interner> SolveDatabase<I> for Solver<'me, I> {
108109
&mut self,
109110
goal: UCanonicalGoal<I>,
110111
minimums: &mut Minimums,
112+
should_continue: impl std::ops::Fn() -> bool + Clone,
111113
) -> Fallible<Solution<I>> {
112-
self.context.solve_goal(&goal, minimums, self.program)
114+
self.context
115+
.solve_goal(&goal, minimums, self.program, should_continue)
113116
}
114117

115118
fn interner(&self) -> I {
@@ -131,17 +134,18 @@ impl<I: Interner> chalk_solve::Solver<I> for RecursiveSolver<I> {
131134
program: &dyn RustIrDatabase<I>,
132135
goal: &UCanonical<InEnvironment<Goal<I>>>,
133136
) -> Option<chalk_solve::Solution<I>> {
134-
self.ctx.solve_root_goal(goal, program).ok()
137+
self.ctx.solve_root_goal(goal, program, || true).ok()
135138
}
136139

137140
fn solve_limited(
138141
&mut self,
139142
program: &dyn RustIrDatabase<I>,
140143
goal: &UCanonical<InEnvironment<Goal<I>>>,
141-
_should_continue: &dyn std::ops::Fn() -> bool,
144+
should_continue: &dyn std::ops::Fn() -> bool,
142145
) -> Option<chalk_solve::Solution<I>> {
143-
// TODO support should_continue in recursive solver
144-
self.ctx.solve_root_goal(goal, program).ok()
146+
self.ctx
147+
.solve_root_goal(goal, program, should_continue)
148+
.ok()
145149
}
146150

147151
fn solve_multiple(

chalk-recursive/src/solve.rs

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ pub(super) trait SolveDatabase<I: Interner>: Sized {
2020
&mut self,
2121
goal: UCanonical<InEnvironment<Goal<I>>>,
2222
minimums: &mut Minimums,
23+
should_continue: impl std::ops::Fn() -> bool + Clone,
2324
) -> Fallible<Solution<I>>;
2425

2526
fn max_size(&self) -> usize;
@@ -35,12 +36,17 @@ pub(super) trait SolveIteration<I: Interner>: SolveDatabase<I> {
3536
/// Executes one iteration of the recursive solver, computing the current
3637
/// solution to the given canonical goal. This is used as part of a loop in
3738
/// the case of cyclic goals.
38-
#[instrument(level = "debug", skip(self))]
39+
#[instrument(level = "debug", skip(self, should_continue))]
3940
fn solve_iteration(
4041
&mut self,
4142
canonical_goal: &UCanonicalGoal<I>,
4243
minimums: &mut Minimums,
44+
should_continue: impl std::ops::Fn() -> bool + Clone,
4345
) -> Fallible<Solution<I>> {
46+
if !should_continue() {
47+
return Ok(Solution::Ambig(Guidance::Unknown));
48+
}
49+
4450
let UCanonical {
4551
universes,
4652
canonical:
@@ -72,7 +78,7 @@ pub(super) trait SolveIteration<I: Interner>: SolveDatabase<I> {
7278
let prog_solution = {
7379
debug_span!("prog_clauses");
7480

75-
self.solve_from_clauses(&canonical_goal, minimums)
81+
self.solve_from_clauses(&canonical_goal, minimums, should_continue)
7682
};
7783
debug!(?prog_solution);
7884

@@ -88,7 +94,7 @@ pub(super) trait SolveIteration<I: Interner>: SolveDatabase<I> {
8894
},
8995
};
9096

91-
self.solve_via_simplification(&canonical_goal, minimums)
97+
self.solve_via_simplification(&canonical_goal, minimums, should_continue)
9298
}
9399
}
94100
}
@@ -103,15 +109,16 @@ where
103109

104110
/// Helper methods for `solve_iteration`, private to this module.
105111
trait SolveIterationHelpers<I: Interner>: SolveDatabase<I> {
106-
#[instrument(level = "debug", skip(self, minimums))]
112+
#[instrument(level = "debug", skip(self, minimums, should_continue))]
107113
fn solve_via_simplification(
108114
&mut self,
109115
canonical_goal: &UCanonicalGoal<I>,
110116
minimums: &mut Minimums,
117+
should_continue: impl std::ops::Fn() -> bool + Clone,
111118
) -> Fallible<Solution<I>> {
112119
let (infer, subst, goal) = self.new_inference_table(canonical_goal);
113120
match Fulfill::new_with_simplification(self, infer, subst, goal) {
114-
Ok(fulfill) => fulfill.solve(minimums),
121+
Ok(fulfill) => fulfill.solve(minimums, should_continue),
115122
Err(e) => Err(e),
116123
}
117124
}
@@ -123,6 +130,7 @@ trait SolveIterationHelpers<I: Interner>: SolveDatabase<I> {
123130
&mut self,
124131
canonical_goal: &UCanonical<InEnvironment<DomainGoal<I>>>,
125132
minimums: &mut Minimums,
133+
should_continue: impl std::ops::Fn() -> bool + Clone,
126134
) -> Fallible<Solution<I>> {
127135
let mut clauses = vec![];
128136

@@ -159,7 +167,10 @@ trait SolveIterationHelpers<I: Interner>: SolveDatabase<I> {
159167
let subst = subst.clone();
160168
let goal = goal.clone();
161169
let res = match Fulfill::new_with_clause(self, infer, subst, goal, implication) {
162-
Ok(fulfill) => (fulfill.solve(minimums), implication.skip_binders().priority),
170+
Ok(fulfill) => (
171+
fulfill.solve(minimums, should_continue.clone()),
172+
implication.skip_binders().priority,
173+
),
163174
Err(e) => (Err(e), ClausePriority::High),
164175
};
165176

0 commit comments

Comments
 (0)