Skip to content

Commit 7f9747a

Browse files
committed
Handle multiple definitions of the same relation properly
So that the last definition wins, including being the one whose span is used in generated code
1 parent a1e3997 commit 7f9747a

File tree

3 files changed

+97
-14
lines changed

3 files changed

+97
-14
lines changed

ascent_macro/src/ascent_hir.rs

+13-14
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use crate::ascent_syntax::{
1313
RuleNode, Signatures,
1414
};
1515
use crate::syn_utils::{expr_get_vars, pattern_get_vars};
16-
use crate::utils::{expr_to_ident, is_wild_card, tuple_type};
16+
use crate::utils::{dedup_all_keep_last_by, expr_to_ident, is_wild_card, tuple_type};
1717

1818
#[derive(Clone)]
1919
pub(crate) struct AscentConfig {
@@ -233,8 +233,12 @@ pub(crate) fn compile_ascent_program_to_hir(prog: &AscentProgram, is_parallel: b
233233
let mut relations_metadata = HashMap::with_capacity(num_relations);
234234
// let mut relations_no_indices = HashMap::new();
235235
let mut lattices_full_indices = HashMap::new();
236-
for rel in prog.relations.iter() {
237-
let rel_identity = RelationIdentity::from(rel);
236+
237+
let mut rel_identities = prog.relations.iter().map(|rel| (rel, RelationIdentity::from(rel))).collect_vec();
238+
dedup_all_keep_last_by(&mut rel_identities, |x, y| RelationIdentity::eq(&x.1, &y.1));
239+
240+
for (rel, rel_identity) in rel_identities {
241+
let ds_attribute = get_ds_attr(&rel.attrs)?;
238242

239243
if rel.is_lattice {
240244
let indices = (0..rel_identity.field_types.len() - 1).collect_vec();
@@ -252,9 +256,8 @@ pub(crate) fn compile_ascent_program_to_hir(prog: &AscentProgram, is_parallel: b
252256
if let Some(init_expr) = &rel.initialization {
253257
relations_initializations.insert(rel_identity.clone(), Rc::new(init_expr.clone()));
254258
}
255-
let ds_attribute = get_ds_attr(&rel.attrs)?;
256-
257-
let ds_attribute = match (ds_attribute, rel.is_lattice) {
259+
260+
let ds_attr = match (ds_attribute, rel.is_lattice) {
258261
(None, true) => None,
259262
(None, false) => Some(config.default_ds.clone()),
260263
(Some(attr), true) =>
@@ -273,7 +276,7 @@ pub(crate) fn compile_ascent_program_to_hir(prog: &AscentProgram, is_parallel: b
273276
.cloned()
274277
.collect_vec(),
275278
),
276-
ds_attr: ds_attribute,
279+
ds_attr,
277280
});
278281
// relations_no_indices.insert(rel_identity, rel_no_index);
279282
}
@@ -449,13 +452,9 @@ fn compile_rule_to_ir_rule(rule: &RuleNode, prog: &AscentProgram) -> syn::Result
449452
let mut head_clauses = vec![];
450453
for hcl_node in rule.head_clauses.iter() {
451454
let hcl_node = hcl_node.clause();
452-
let rel = prog.relations.iter().find(|r| hcl_node.rel == r.name);
453-
let rel = match rel {
454-
Some(rel) => rel,
455-
None => return Err(Error::new(hcl_node.rel.span(), format!("relation `{}` is not defined", hcl_node.rel))),
456-
};
457-
455+
let rel = prog_get_relation(prog, &hcl_node.rel, hcl_node.args.len())?;
458456
let rel = RelationIdentity::from(rel);
457+
459458
let head_clause = IrHeadClause {
460459
rel,
461460
args: hcl_node.args.iter().cloned().collect(),
@@ -514,7 +513,7 @@ pub fn get_indices_given_grounded_variables(args: &[Expr], vars: &[Ident]) -> Ve
514513
pub(crate) fn prog_get_relation<'a>(
515514
prog: &'a AscentProgram, name: &Ident, arity: usize,
516515
) -> syn::Result<&'a RelationNode> {
517-
let relation = prog.relations.iter().find(|r| name == &r.name);
516+
let relation = prog.relations.iter().rev().find(|r| name == &r.name);
518517
match relation {
519518
Some(rel) =>
520519
if rel.field_types.len() != arity {

ascent_macro/src/utils.rs

+67
Original file line numberDiff line numberDiff line change
@@ -274,3 +274,70 @@ fn test_subsumes_and_intersects() {
274274
assert_eq!(intersects(s1.iter(), s2.iter()), intersects_expected);
275275
}
276276
}
277+
278+
#[allow(unused)]
279+
pub fn dedup_all<T: Eq>(vec: &mut Vec<T>) {
280+
let mut delete = vec![false; vec.len()];
281+
for i in 0..vec.len() {
282+
if delete[i] {
283+
continue;
284+
}
285+
for j in i + 1..vec.len() {
286+
if vec[j] == vec[i] {
287+
delete[j] = true;
288+
}
289+
}
290+
}
291+
let mut delete = delete.into_iter();
292+
vec.retain(|_| !delete.next().unwrap());
293+
}
294+
295+
#[allow(unused)]
296+
pub fn dedup_all_keep_last_by<T, F: Fn(&T, &T) -> bool>(vec: &mut Vec<T>, compare: F) {
297+
let mut delete = vec![false; vec.len()];
298+
for i in (0..vec.len()).rev() {
299+
if delete[i] {
300+
continue;
301+
}
302+
for j in (0..i).rev() {
303+
if compare(&vec[j], &vec[i]) {
304+
delete[j] = true;
305+
}
306+
}
307+
}
308+
let mut delete = delete.into_iter();
309+
vec.retain(|_| !delete.next().unwrap());
310+
}
311+
312+
#[allow(unused)]
313+
pub fn dedup_all_keep_last<T: Eq>(vec: &mut Vec<T>) { dedup_all_keep_last_by(vec, |x, y| x == y); }
314+
315+
#[test]
316+
fn test_dedup_all() {
317+
let test_cases = [
318+
(vec![], vec![]),
319+
(vec![1], vec![1]),
320+
(vec![1, 1], vec![1]),
321+
(vec![1, 2, 2, 3, 1, 1, 4, 5, 6, 3, 2], vec![1, 2, 3, 4, 5, 6]),
322+
(vec![1, 1, 2, 2, 1, 3, 3, 3, 4], vec![1, 2, 3, 4]),
323+
];
324+
for (mut inp, expected_out) in test_cases {
325+
dedup_all(&mut inp);
326+
assert_eq!(inp, expected_out);
327+
}
328+
}
329+
330+
#[test]
331+
fn test_dedup_all_keep_last() {
332+
let test_cases = [
333+
(vec![], vec![]),
334+
(vec![1], vec![1]),
335+
(vec![1, 1], vec![1]),
336+
(vec![1, 2, 2, 3, 1, 1, 4, 5, 6, 3, 2], vec![1, 4, 5, 6, 3, 2]),
337+
(vec![1, 1, 2, 2, 1, 3, 3, 3, 4], vec![2, 1, 3, 4]),
338+
];
339+
for (mut inp, expected_out) in test_cases {
340+
dedup_all_keep_last(&mut inp);
341+
assert_eq!(inp, expected_out);
342+
}
343+
}

ascent_tests/src/tests.rs

+17
Original file line numberDiff line numberDiff line change
@@ -992,3 +992,20 @@ fn test_rel_empty_check() {
992992
println!("{:?}", res.path);
993993
assert_eq!(res.path.len(), 9 * 10 / 2);
994994
}
995+
996+
#[test]
997+
fn test_multiple_rel_definitions() {
998+
// When there are multiple definitions of a relation that agree on arity and column types,
999+
// the last one wins
1000+
let res = ascent_run! {
1001+
relation r1(usize);
1002+
relation r1(usize) = vec![(1,), (2,)];
1003+
1004+
relation r2(usize) = vec![(3,), (4,)];
1005+
relation r2(usize);
1006+
1007+
r2(x) <-- r1(x);
1008+
};
1009+
1010+
assert_rels_eq!(res.r2, vec![(1,), (2,)]);
1011+
}

0 commit comments

Comments
 (0)