Skip to content

Commit c1d64ee

Browse files
committed
Auto merge of #708 - nikomatsakis:extract-generic-cache-logic, r=jackh726
rework recursive solver for better integration into an expanded version of salsa. The overall plan is [described here](https://hackmd.io/FCrUiW27TnKw3MTvRtfPjQ). I should probably move that into a tracking issue or project board or *something*. =) This branch so far contains: * [x] Separate out the logic to iterate until a fixed point is reached and to manage the caching during that iteration from the logic to search program clauses.
2 parents 9456600 + 8ce0360 commit c1d64ee

File tree

9 files changed

+459
-316
lines changed

9 files changed

+459
-316
lines changed

chalk-integration/src/lib.rs

+6-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ pub mod tls;
1414
use chalk_engine::solve::SLGSolver;
1515
use chalk_ir::interner::HasInterner;
1616
use chalk_ir::Binders;
17-
use chalk_recursive::RecursiveSolver;
17+
use chalk_recursive::{Cache, RecursiveSolver};
1818
use chalk_solve::Solver;
1919
use interner::ChalkIr;
2020

@@ -104,7 +104,11 @@ impl SolverChoice {
104104
} => Box::new(RecursiveSolver::new(
105105
overflow_depth,
106106
max_size,
107-
caching_enabled,
107+
if caching_enabled {
108+
Some(Cache::default())
109+
} else {
110+
None
111+
},
108112
)),
109113
}
110114
}

chalk-recursive/src/fixed_point.rs

+237
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
use std::fmt::Debug;
2+
use std::hash::Hash;
3+
use tracing::debug;
4+
use tracing::{info, instrument};
5+
6+
mod cache;
7+
mod search_graph;
8+
mod stack;
9+
10+
pub use cache::Cache;
11+
use search_graph::{DepthFirstNumber, SearchGraph};
12+
use stack::{Stack, StackDepth};
13+
14+
pub(super) struct RecursiveContext<K, V>
15+
where
16+
K: Hash + Eq + Debug + Clone,
17+
V: Debug + Clone,
18+
{
19+
stack: Stack,
20+
21+
/// The "search graph" stores "in-progress results" that are still being
22+
/// solved.
23+
search_graph: SearchGraph<K, V>,
24+
25+
/// The "cache" stores results for goals that we have completely solved.
26+
/// Things are added to the cache when we have completely processed their
27+
/// result.
28+
cache: Option<Cache<K, V>>,
29+
30+
/// The maximum size for goals.
31+
max_size: usize,
32+
}
33+
34+
pub(super) trait SolverStuff<K, V>: Copy
35+
where
36+
K: Hash + Eq + Debug + Clone,
37+
V: Debug + Clone,
38+
{
39+
fn is_coinductive_goal(self, goal: &K) -> bool;
40+
fn initial_value(self, goal: &K, coinductive_goal: bool) -> V;
41+
fn solve_iteration(
42+
self,
43+
context: &mut RecursiveContext<K, V>,
44+
goal: &K,
45+
minimums: &mut Minimums,
46+
) -> V;
47+
fn reached_fixed_point(self, old_value: &V, new_value: &V) -> bool;
48+
fn error_value(self) -> V;
49+
}
50+
51+
/// The `minimums` struct is used while solving to track whether we encountered
52+
/// any cycles in the process.
53+
#[derive(Copy, Clone, Debug)]
54+
pub(super) struct Minimums {
55+
positive: DepthFirstNumber,
56+
}
57+
58+
impl Minimums {
59+
pub fn new() -> Self {
60+
Minimums {
61+
positive: DepthFirstNumber::MAX,
62+
}
63+
}
64+
65+
pub fn update_from(&mut self, minimums: Minimums) {
66+
self.positive = ::std::cmp::min(self.positive, minimums.positive);
67+
}
68+
}
69+
70+
impl<K, V> RecursiveContext<K, V>
71+
where
72+
K: Hash + Eq + Debug + Clone,
73+
V: Debug + Clone,
74+
{
75+
pub fn new(overflow_depth: usize, max_size: usize, cache: Option<Cache<K, V>>) -> Self {
76+
RecursiveContext {
77+
stack: Stack::new(overflow_depth),
78+
search_graph: SearchGraph::new(),
79+
cache,
80+
max_size,
81+
}
82+
}
83+
84+
pub fn max_size(&self) -> usize {
85+
self.max_size
86+
}
87+
88+
/// Solves a canonical goal. The substitution returned in the
89+
/// solution will be for the fully decomposed goal. For example, given the
90+
/// program
91+
///
92+
/// ```ignore
93+
/// struct u8 { }
94+
/// struct SomeType<T> { }
95+
/// trait Foo<T> { }
96+
/// impl<U> Foo<u8> for SomeType<U> { }
97+
/// ```
98+
///
99+
/// and the goal `exists<V> { forall<U> { SomeType<U>: Foo<V> }
100+
/// }`, `into_peeled_goal` can be used to create a canonical goal
101+
/// `SomeType<!1>: Foo<?0>`. This function will then return a
102+
/// solution with the substitution `?0 := u8`.
103+
pub fn solve_root_goal(
104+
&mut self,
105+
canonical_goal: &K,
106+
solver_stuff: impl SolverStuff<K, V>,
107+
) -> V {
108+
debug!("solve_root_goal(canonical_goal={:?})", canonical_goal);
109+
assert!(self.stack.is_empty());
110+
let minimums = &mut Minimums::new();
111+
self.solve_goal(canonical_goal, minimums, solver_stuff)
112+
}
113+
114+
/// Attempt to solve a goal that has been fully broken down into leaf form
115+
/// and canonicalized. This is where the action really happens, and is the
116+
/// place where we would perform caching in rustc (and may eventually do in Chalk).
117+
#[instrument(level = "info", skip(self, minimums, solver_stuff,))]
118+
pub fn solve_goal(
119+
&mut self,
120+
goal: &K,
121+
minimums: &mut Minimums,
122+
solver_stuff: impl SolverStuff<K, V>,
123+
) -> V {
124+
// First check the cache.
125+
if let Some(cache) = &self.cache {
126+
if let Some(value) = cache.get(&goal) {
127+
debug!("solve_reduced_goal: cache hit, value={:?}", value);
128+
return value.clone();
129+
}
130+
}
131+
132+
// Next, check if the goal is in the search tree already.
133+
if let Some(dfn) = self.search_graph.lookup(&goal) {
134+
// Check if this table is still on the stack.
135+
if let Some(depth) = self.search_graph[dfn].stack_depth {
136+
self.stack[depth].flag_cycle();
137+
// Mixed cycles are not allowed. For more information about this
138+
// see the corresponding section in the coinduction chapter:
139+
// https://rust-lang.github.io/chalk/book/recursive/coinduction.html#mixed-co-inductive-and-inductive-cycles
140+
if self.stack.mixed_inductive_coinductive_cycle_from(depth) {
141+
return solver_stuff.error_value();
142+
}
143+
}
144+
145+
minimums.update_from(self.search_graph[dfn].links);
146+
147+
// Return the solution from the table.
148+
let previous_solution = self.search_graph[dfn].solution.clone();
149+
info!(
150+
"solve_goal: cycle detected, previous solution {:?}",
151+
previous_solution,
152+
);
153+
previous_solution
154+
} else {
155+
// Otherwise, push the goal onto the stack and create a table.
156+
// The initial result for this table depends on whether the goal is coinductive.
157+
let coinductive_goal = solver_stuff.is_coinductive_goal(goal);
158+
let initial_solution = solver_stuff.initial_value(goal, coinductive_goal);
159+
let depth = self.stack.push(coinductive_goal);
160+
let dfn = self.search_graph.insert(&goal, depth, initial_solution);
161+
162+
let subgoal_minimums = self.solve_new_subgoal(&goal, depth, dfn, solver_stuff);
163+
164+
self.search_graph[dfn].links = subgoal_minimums;
165+
self.search_graph[dfn].stack_depth = None;
166+
self.stack.pop(depth);
167+
minimums.update_from(subgoal_minimums);
168+
169+
// Read final result from table.
170+
let result = self.search_graph[dfn].solution.clone();
171+
172+
// If processing this subgoal did not involve anything
173+
// outside of its subtree, then we can promote it to the
174+
// cache now. This is a sort of hack to alleviate the
175+
// worst of the repeated work that we do during tabling.
176+
if subgoal_minimums.positive >= dfn {
177+
if let Some(cache) = &mut self.cache {
178+
self.search_graph.move_to_cache(dfn, cache);
179+
debug!("solve_reduced_goal: SCC head encountered, moving to cache");
180+
} else {
181+
debug!(
182+
"solve_reduced_goal: SCC head encountered, rolling back as caching disabled"
183+
);
184+
self.search_graph.rollback_to(dfn);
185+
}
186+
}
187+
188+
info!("solve_goal: solution = {:?}", result);
189+
result
190+
}
191+
}
192+
193+
#[instrument(level = "debug", skip(self, solver_stuff))]
194+
fn solve_new_subgoal(
195+
&mut self,
196+
canonical_goal: &K,
197+
depth: StackDepth,
198+
dfn: DepthFirstNumber,
199+
solver_stuff: impl SolverStuff<K, V>,
200+
) -> Minimums {
201+
// We start with `answer = None` and try to solve the goal. At the end of the iteration,
202+
// `answer` will be updated with the result of the solving process. If we detect a cycle
203+
// during the solving process, we cache `answer` and try to solve the goal again. We repeat
204+
// until we reach a fixed point for `answer`.
205+
// Considering the partial order:
206+
// - None < Some(Unique) < Some(Ambiguous)
207+
// - None < Some(CannotProve)
208+
// the function which maps the loop iteration to `answer` is a nondecreasing function
209+
// so this function will eventually be constant and the loop terminates.
210+
loop {
211+
let minimums = &mut Minimums::new();
212+
let current_answer = solver_stuff.solve_iteration(self, &canonical_goal, minimums);
213+
214+
debug!(
215+
"solve_new_subgoal: loop iteration result = {:?} with minimums {:?}",
216+
current_answer, minimums
217+
);
218+
219+
if !self.stack[depth].read_and_reset_cycle_flag() {
220+
// None of our subgoals depended on us directly.
221+
// We can return.
222+
self.search_graph[dfn].solution = current_answer;
223+
return *minimums;
224+
}
225+
226+
let old_answer =
227+
std::mem::replace(&mut self.search_graph[dfn].solution, current_answer);
228+
229+
if solver_stuff.reached_fixed_point(&old_answer, &self.search_graph[dfn].solution) {
230+
return *minimums;
231+
}
232+
233+
// Otherwise: rollback the search tree and try again.
234+
self.search_graph.rollback_to(dfn + 1);
235+
}
236+
}
237+
}
+88
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
use rustc_hash::FxHashMap;
2+
use std::fmt::Debug;
3+
use std::hash::Hash;
4+
use std::sync::{Arc, Mutex};
5+
use tracing::debug;
6+
use tracing::instrument;
7+
/// The "cache" stores results for goals that we have completely solved.
8+
/// Things are added to the cache when we have completely processed their
9+
/// result, and it can be shared amongst many solvers.
10+
pub struct Cache<K, V>
11+
where
12+
K: Hash + Eq + Debug,
13+
V: Debug + Clone,
14+
{
15+
data: Arc<Mutex<CacheData<K, V>>>,
16+
}
17+
struct CacheData<K, V>
18+
where
19+
K: Hash + Eq + Debug,
20+
V: Debug + Clone,
21+
{
22+
cache: FxHashMap<K, V>,
23+
}
24+
25+
impl<K, V> Cache<K, V>
26+
where
27+
K: Hash + Eq + Debug,
28+
V: Debug + Clone,
29+
{
30+
pub fn new() -> Self {
31+
Self::default()
32+
}
33+
34+
/// Record a cache result.
35+
#[instrument(skip(self))]
36+
pub fn insert(&self, goal: K, result: V) {
37+
let mut data = self.data.lock().unwrap();
38+
data.cache.insert(goal, result);
39+
}
40+
41+
/// Record a cache result.
42+
pub fn get(&self, goal: &K) -> Option<V> {
43+
let data = self.data.lock().unwrap();
44+
if let Some(result) = data.cache.get(&goal) {
45+
debug!(?goal, ?result, "Cache hit");
46+
Some(result.clone())
47+
} else {
48+
debug!(?goal, "Cache miss");
49+
None
50+
}
51+
}
52+
}
53+
54+
impl<K, V> Clone for Cache<K, V>
55+
where
56+
K: Hash + Eq + Debug,
57+
V: Debug + Clone,
58+
{
59+
fn clone(&self) -> Self {
60+
Self {
61+
data: self.data.clone(),
62+
}
63+
}
64+
}
65+
66+
impl<K, V> Default for Cache<K, V>
67+
where
68+
K: Hash + Eq + Debug,
69+
V: Debug + Clone,
70+
{
71+
fn default() -> Self {
72+
Self {
73+
data: Default::default(),
74+
}
75+
}
76+
}
77+
78+
impl<K, V> Default for CacheData<K, V>
79+
where
80+
K: Hash + Eq + Debug,
81+
V: Debug + Clone,
82+
{
83+
fn default() -> Self {
84+
Self {
85+
cache: Default::default(),
86+
}
87+
}
88+
}

0 commit comments

Comments
 (0)