@@ -42,6 +42,7 @@ Algorithm for extracting canonical form from an E-graph:
42
42
#include " ast/ast_pp.h"
43
43
#include " ast/ast_util.h"
44
44
#include " ast/euf/euf_egraph.h"
45
+ #include " ast/rewriter/var_subst.h"
45
46
#include " ast/simplifiers/euf_completion.h"
46
47
#include " ast/shared_occs.h"
47
48
@@ -50,6 +51,7 @@ namespace euf {
50
51
completion::completion (ast_manager& m, dependent_expr_state& fmls):
51
52
dependent_expr_simplifier (m, fmls),
52
53
m_egraph (m),
54
+ m_mam (mam::mk(*this , *this )),
53
55
m_canonical (m),
54
56
m_eargs (m),
55
57
m_deps (m),
@@ -58,6 +60,19 @@ namespace euf {
58
60
m_ff = m_egraph.mk (m.mk_false (), 0 , 0 , nullptr );
59
61
m_rewriter.set_order_eq (true );
60
62
m_rewriter.set_flat_and_or (false );
63
+
64
+ std::function<void (euf::enode*, euf::enode*)> _on_merge =
65
+ [&](euf::enode* root, euf::enode* other) {
66
+ m_mam->on_merge (root, other);
67
+ };
68
+
69
+ std::function<void (euf::enode*)> _on_make =
70
+ [&](euf::enode* n) {
71
+ m_mam->add_node (n, false );
72
+ };
73
+
74
+ m_egraph.set_on_merge (_on_merge);
75
+ m_egraph.set_on_make (_on_make);
61
76
}
62
77
63
78
void completion::reduce () {
@@ -75,33 +90,67 @@ namespace euf {
75
90
void completion::add_egraph () {
76
91
m_nodes_to_canonize.reset ();
77
92
unsigned sz = qtail ();
93
+
94
+ for (unsigned i = qhead (); i < sz; ++i) {
95
+ auto [f, p, d] = m_fmls[i]();
96
+ add_constraint (f, d);
97
+ }
98
+ m_should_propagate = true ;
99
+ while (m_should_propagate) {
100
+ m_should_propagate = false ;
101
+ m_egraph.propagate ();
102
+ m_mam->propagate ();
103
+ }
104
+ }
105
+
106
+ void completion::add_constraint (expr* f, expr_dependency* d) {
78
107
auto add_children = [&](enode* n) {
79
108
for (auto * ch : enode_args (n))
80
109
m_nodes_to_canonize.push_back (ch);
81
110
};
82
-
83
- for (unsigned i = qhead (); i < sz; ++i) {
84
- expr* x, * y;
85
- auto [f, p, d] = m_fmls[i]();
86
- if (m.is_eq (f, x, y)) {
87
- enode* a = mk_enode (x);
88
- enode* b = mk_enode (y);
89
- m_egraph.merge (a, b, d);
90
- add_children (a);
91
- add_children (b);
92
- }
93
- else if (m.is_not (f, f)) {
94
- enode* n = mk_enode (f);
95
- m_egraph.merge (n, m_ff, d);
96
- add_children (n);
97
- }
98
- else {
99
- enode* n = mk_enode (f);
100
- m_egraph.merge (n, m_tt, d);
101
- add_children (n);
111
+ expr* x, * y;
112
+ if (m.is_eq (f, x, y)) {
113
+ enode* a = mk_enode (x);
114
+ enode* b = mk_enode (y);
115
+ m_egraph.merge (a, b, d);
116
+ add_children (a);
117
+ add_children (b);
118
+ }
119
+ else if (m.is_not (f, f)) {
120
+ enode* n = mk_enode (f);
121
+ m_egraph.merge (n, m_ff, d);
122
+ add_children (n);
123
+ }
124
+ else {
125
+ enode* n = mk_enode (f);
126
+ m_egraph.merge (n, m_tt, d);
127
+ add_children (n);
128
+ if (is_forall (f)) {
129
+ quantifier* q = to_quantifier (f);
130
+ ptr_vector<app> ground;
131
+ for (unsigned i = 0 ; i < q->get_num_patterns (); ++i) {
132
+ auto p = to_app (q->get_pattern (i));
133
+ mam::ground_subterms (p, ground);
134
+ for (expr* g : ground)
135
+ mk_enode (g);
136
+ m_mam->add_pattern (q, p);
137
+ }
138
+ if (!get_dependency (q)) {
139
+ m_q2dep.insert (q, d);
140
+ get_trail ().push (insert_obj_map (m_q2dep, q));
141
+ }
102
142
}
103
143
}
104
- m_egraph.propagate ();
144
+ }
145
+
146
+ void completion::on_binding (quantifier* q, app* pat, enode* const * binding, unsigned mg, unsigned ming, unsigned mx) {
147
+ var_subst subst (m);
148
+ expr_ref_vector _binding (m);
149
+ for (unsigned i = 0 ; i < q->get_num_decls (); ++i)
150
+ _binding.push_back (binding[i]->get_expr ());
151
+ expr_ref r = subst (q->get_expr (), _binding);
152
+ add_constraint (r, get_dependency (q));
153
+ m_should_propagate = true ;
105
154
}
106
155
107
156
void completion::read_egraph () {
0 commit comments