Skip to content

Commit 799b613

Browse files
avoid repeated internalization of lambda #4169
Signed-off-by: Nikolaj Bjorner <[email protected]>
1 parent 7ae2047 commit 799b613

8 files changed

+88
-31
lines changed

src/smt/params/smt_params.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ void smt_params::updt_local_params(params_ref const & _p) {
2626
m_random_seed = p.random_seed();
2727
m_relevancy_lvl = p.relevancy();
2828
m_ematching = p.ematching();
29+
m_induction = p.induction();
2930
m_clause_proof = p.clause_proof();
3031
m_phase_selection = static_cast<phase_selection>(p.phase_selection());
3132
if (m_phase_selection > PS_THEORY) throw default_exception("illegal phase selection numeral");
@@ -111,6 +112,7 @@ void smt_params::display(std::ostream & out) const {
111112
DISPLAY_PARAM(m_display_features);
112113
DISPLAY_PARAM(m_new_core2th_eq);
113114
DISPLAY_PARAM(m_ematching);
115+
DISPLAY_PARAM(m_induction);
114116
DISPLAY_PARAM(m_clause_proof);
115117

116118
DISPLAY_PARAM(m_case_split_strategy);

src/smt/params/smt_params.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ struct smt_params : public preprocessor_params,
110110
bool m_display_features;
111111
bool m_new_core2th_eq;
112112
bool m_ematching;
113+
bool m_induction;
113114
bool m_clause_proof;
114115

115116
// -----------------------------------
@@ -262,6 +263,7 @@ struct smt_params : public preprocessor_params,
262263
m_display_features(false),
263264
m_new_core2th_eq(true),
264265
m_ematching(true),
266+
m_induction(false),
265267
m_clause_proof(false),
266268
m_case_split_strategy(CS_ACTIVITY_DELAY_NEW),
267269
m_rel_case_split_order(0),

src/smt/params/smt_params_helper.pyg

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def_module_params(module_name='smt',
3737
('qi.cost', STRING, '(+ weight generation)', 'expression specifying what is the cost of a given quantifier instantiation'),
3838
('qi.max_multi_patterns', UINT, 0, 'specify the number of extra multi patterns'),
3939
('qi.quick_checker', UINT, 0, 'specify quick checker mode, 0 - no quick checker, 1 - using unsat instances, 2 - using both unsat and no-sat instances'),
40+
('induction', BOOL, False, 'enable generation of induction lemmas'),
4041
('bv.reflect', BOOL, True, 'create enode for every bit-vector term'),
4142
('bv.enable_int2bv', BOOL, True, 'enable support for int2bv and bv2int operators'),
4243
('arith.random_initial_value', BOOL, False, 'use random initial values in the simplex-based procedure for linear arithmetic'),

src/smt/smt_induction.cpp

Lines changed: 63 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ literal_vector collect_induction_literals::pre_select() {
5555
continue;
5656
result.push_back(lit);
5757
}
58+
TRACE("induction", ctx.display(tout << "literal index: " << m_literal_index << "\n" << result << "\n"););
59+
5860
ctx.push_trail(value_trail<context, unsigned>(m_literal_index));
5961
m_literal_index = ctx.assigned_literals().size();
6062
return result;
@@ -68,11 +70,6 @@ void collect_induction_literals::model_sweep_filter(literal_vector& candidates)
6870
vector<expr_ref_vector> values;
6971
vs(terms, values);
7072
unsigned j = 0;
71-
IF_VERBOSE(1,
72-
verbose_stream() << "terms: " << terms << "\n";
73-
for (auto const& vec : values) {
74-
verbose_stream() << "assignment: " << vec << "\n";
75-
});
7673
for (unsigned i = 0; i < terms.size(); ++i) {
7774
literal lit = candidates[i];
7875
bool is_viable_candidate = true;
@@ -109,14 +106,26 @@ literal_vector collect_induction_literals::operator()() {
109106
// create_induction_lemmas
110107

111108
bool create_induction_lemmas::is_induction_candidate(enode* n) {
112-
expr* e = n->get_owner();
109+
app* e = n->get_owner();
113110
if (m.is_value(e))
114111
return false;
115-
// TBD: filter if n is equivalent to a value.
112+
bool in_good_context = false;
113+
for (enode* p : n->get_parents()) {
114+
app* o = p->get_owner();
115+
if (o->get_family_id() != m.get_basic_family_id())
116+
in_good_context = true;
117+
}
118+
if (!in_good_context)
119+
return false;
120+
121+
// avoid recursively unfolding skolem terms.
122+
if (e->get_num_args() > 0 && e->get_family_id() == null_family_id) {
123+
return false;
124+
}
116125
sort* s = m.get_sort(e);
117126
if (m_dt.is_datatype(s) && m_dt.is_recursive(s))
118127
return true;
119-
128+
120129
// potentially also induction on integers, sequences
121130
// m_arith.is_int(s)
122131
// return true;
@@ -160,14 +169,24 @@ enode_vector create_induction_lemmas::induction_positions(enode* n) {
160169
* TDD: add depth throttle.
161170
*/
162171
void create_induction_lemmas::abstract(enode* n, enode* t, expr* x, abstractions& result) {
172+
std::cout << "abs: " << result.size() << ": " << mk_pp(n->get_owner(), m) << "\n";
163173
if (n->get_root() == t->get_root()) {
164174
result.push_back(abstraction(m, x, n->get_owner(), t->get_owner()));
165175
}
176+
#if 0
177+
// check if n is a s
178+
if (is_skolem(n->get_owner())) {
179+
result.push_back(abstraction(m, n->get_owner()));
180+
return;
181+
}
182+
#endif
183+
166184
abstraction_args r1, r2;
167185
r1.push_back(abstraction_arg(m));
168186
for (enode* arg : enode::args(n)) {
169187
unsigned n = result.size();
170188
abstract(arg, t, x, result);
189+
std::cout << result.size() << "\n";
171190
for (unsigned i = n; i < result.size(); ++i) {
172191
abstraction& a = result[i];
173192
for (auto const& v : r1) {
@@ -193,7 +212,9 @@ void create_induction_lemmas::filter_abstractions(bool sign, abstractions& abs)
193212
vector<expr_ref_vector> values;
194213
expr_ref_vector fmls(m);
195214
for (auto & a : abs) fmls.push_back(a.m_term);
215+
std::cout << "sweep\n";
196216
vs(fmls, values);
217+
std::cout << "done sweep\n";
197218
unsigned j = 0;
198219
for (unsigned i = 0; i < fmls.size(); ++i) {
199220
bool all_cex = true;
@@ -207,15 +228,17 @@ void create_induction_lemmas::filter_abstractions(bool sign, abstractions& abs)
207228
abs[j++] = abs.get(i);
208229
}
209230
}
231+
std::cout << "resulting size: " << j << " down from " << abs.size() << "\n";
210232
abs.shrink(j);
211233
}
212234

213235
/*
214236
* Create simple induction lemmas of the form:
215237
*
216-
* lit & a.eqs() & is-c(t) => is-c(sk);
217238
* lit & a.eqs() => alpha
218-
* lit & a.eqs() & is-c(t) => ~beta
239+
* alpha & is-c(sk) => ~beta
240+
*
241+
* alpha & is-c(t) => is-c(sk);
219242
*
220243
* where
221244
* lit = is a formula containing t
@@ -242,50 +265,59 @@ void create_induction_lemmas::create_lemmas(expr* t, expr* sk, abstraction& a, l
242265
return;
243266
expr_ref alpha = a.m_term;
244267
auto const& eqs = a.m_eqs;
268+
literal alpha_lit = null_literal;
245269
literal_vector common_literals;
246270
for (func_decl* c : *m_dt.get_datatype_constructors(s)) {
247271
func_decl* is_c = m_dt.get_constructor_recognizer(c);
248272
bool has_1recursive = false;
249273
for (func_decl* acc : *m_dt.get_constructor_accessors(c)) {
250274
if (acc->get_range() != s)
251275
continue;
252-
if (common_literals.empty()) {
253-
common_literals.push_back(~lit);
254-
for (auto const& p : eqs) {
255-
common_literals.push_back(~mk_literal(m.mk_eq(p.first, p.second)));
256-
}
276+
if (alpha_lit == null_literal) {
277+
alpha_lit = mk_literal(alpha);
278+
if (lit.sign()) alpha_lit.neg();
257279
}
258280
has_1recursive = true;
259-
literal_vector lits(common_literals);
260-
lits.push_back(~mk_literal(m.mk_app(is_c, t)));
261281
expr_ref beta(alpha);
262282
expr_safe_replace rep(m);
263283
rep.insert(sk, m.mk_app(acc, sk));
264284
rep(beta);
265285
literal b_lit = mk_literal(beta);
266286
if (lit.sign()) b_lit.neg();
287+
288+
// alpha & is_c(sk) => ~beta
289+
literal_vector lits;
290+
lits.push_back(~alpha_lit);
291+
lits.push_back(~mk_literal(m.mk_app(is_c, sk)));
267292
lits.push_back(~b_lit);
268293
add_th_lemma(lits);
269294
}
295+
296+
// alpha & is_c(t) => is_c(sk)
270297
if (has_1recursive) {
271-
literal_vector lits(common_literals);
298+
literal_vector lits;
299+
lits.push_back(~alpha_lit);
272300
lits.push_back(~mk_literal(m.mk_app(is_c, t)));
273301
lits.push_back(mk_literal(m.mk_app(is_c, sk)));
274302
add_th_lemma(lits);
275303
}
276304
}
277-
if (!common_literals.empty()) {
278-
literal_vector lits(common_literals);
279-
literal a_lit = mk_literal(alpha);
280-
if (lit.sign()) a_lit.neg();
281-
lits.push_back(a_lit);
305+
306+
// phi & eqs => alpha
307+
if (alpha_lit != null_literal) {
308+
literal_vector lits;
309+
lits.push_back(~lit);
310+
for (auto const& p : eqs) {
311+
lits.push_back(~mk_literal(m.mk_eq(p.first, p.second)));
312+
}
313+
lits.push_back(alpha_lit);
282314
add_th_lemma(lits);
283315
}
284316
}
285317

286318
void create_induction_lemmas::add_th_lemma(literal_vector const& lits) {
287319
IF_VERBOSE(1, ctx.display_literals_verbose(verbose_stream() << "lemma:\n", lits) << "\n");
288-
ctx.mk_clause(lits.size(), lits.c_ptr(), nullptr, smt::CLS_TH_LEMMA);
320+
ctx.mk_clause(lits.size(), lits.c_ptr(), nullptr, smt::CLS_TH_AXIOM); // CLS_TH_LEMMA, but then should re-instance if GC'ed
289321
++m_num_lemmas;
290322
}
291323

@@ -301,8 +333,8 @@ literal create_induction_lemmas::mk_literal(expr* e) {
301333
func_decl* create_induction_lemmas::mk_skolem(sort* s) {
302334
func_decl* f = nullptr;
303335
if (!m_sort2skolem.find(s, f)) {
304-
sort* domain[2] = { s, m.mk_bool_sort() };
305-
f = m.mk_fresh_func_decl("sk", 2, domain, s);
336+
sort* domain[3] = { m_a.mk_int(), s, m.mk_bool_sort() };
337+
f = m.mk_fresh_func_decl("sk", 3, domain, s);
306338
m_pinned.push_back(f);
307339
m_pinned.push_back(s);
308340
m_sort2skolem.insert(s, f);
@@ -314,10 +346,11 @@ func_decl* create_induction_lemmas::mk_skolem(sort* s) {
314346
bool create_induction_lemmas::operator()(literal lit) {
315347
unsigned num = m_num_lemmas;
316348
enode* r = ctx.bool_var2enode(lit.var());
349+
unsigned position = 0;
317350
for (enode* n : induction_positions(r)) {
318351
expr* t = n->get_owner();
319352
sort* s = m.get_sort(t);
320-
expr_ref sk(m.mk_app(mk_skolem(s), t, r->get_owner()), m);
353+
expr_ref sk(m.mk_app(mk_skolem(s), m_a.mk_int(position), t, r->get_owner()), m);
321354
std::cout << "abstract " << mk_pp(t, m) << " " << sk << "\n";
322355
abstractions abs;
323356
abstract(r, n, sk, abs);
@@ -326,6 +359,8 @@ bool create_induction_lemmas::operator()(literal lit) {
326359
for (abstraction& a : abs) {
327360
create_lemmas(t, sk, a, lit);
328361
}
362+
std::cout << "lemmas created\n";
363+
++position;
329364
}
330365
return m_num_lemmas > num;
331366
}
@@ -335,6 +370,7 @@ create_induction_lemmas::create_induction_lemmas(context& ctx, ast_manager& m, v
335370
m(m),
336371
vs(vs),
337372
m_dt(m),
373+
m_a(m),
338374
m_pinned(m),
339375
m_num_lemmas(0)
340376
{}

src/smt/smt_induction.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "smt/smt_types.h"
2121
#include "ast/rewriter/value_sweep.h"
2222
#include "ast/datatype_decl_plugin.h"
23+
#include "ast/arith_decl_plugin.h"
2324

2425
namespace smt {
2526

@@ -52,6 +53,7 @@ namespace smt {
5253
ast_manager& m;
5354
value_sweep& vs;
5455
datatype::util m_dt;
56+
arith_util m_a;
5557
obj_map<sort, func_decl*> m_sort2skolem;
5658
ast_ref_vector m_pinned;
5759
unsigned m_num_lemmas;

src/smt/smt_internalizer.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -558,6 +558,9 @@ namespace smt {
558558
void context::internalize_lambda(quantifier * q) {
559559
TRACE("internalize_quantifier", tout << mk_pp(q, m) << "\n";);
560560
SASSERT(is_lambda(q));
561+
if (e_internalized(q)) {
562+
return;
563+
}
561564
app_ref lam_name(m.mk_fresh_const("lambda", m.get_sort(q)), m);
562565
app_ref eq(m), lam_app(m);
563566
expr_ref_vector vars(m);

src/smt/theory_recfun.cpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ Revision History:
2121
#include "ast/ast_util.h"
2222
#include "ast/for_each_expr.h"
2323
#include "smt/theory_recfun.h"
24-
#include "smt/params/smt_params_helper.hpp"
2524

2625

2726
#define TRACEFN(x) TRACE("recfun", tout << x << '\n';)
@@ -357,6 +356,7 @@ namespace smt {
357356
literal_vector preds;
358357
auto & vars = e.m_def->get_vars();
359358

359+
unsigned max_depth = 0;
360360
for (recfun::case_def const & c : e.m_def->get_cases()) {
361361
// applied predicate to `args`
362362
app_ref pred_applied = c.apply_case_predicate(e.m_args);
@@ -376,15 +376,25 @@ namespace smt {
376376
}
377377
else if (!is_enabled_guard(pred_applied)) {
378378
disable_guard(pred_applied, guards);
379+
max_depth = std::max(depth, max_depth);
379380
continue;
380381
}
381382
activate_guard(pred_applied, guards);
382383
}
383384
// the disjunction of branches is asserted
384385
// to close the available cases.
385-
std::function<literal_vector(void)> fn2 = [&]() { return preds; };
386-
scoped_trace_stream _tr2(*this, fn2);
387-
ctx().mk_th_axiom(get_id(), preds);
386+
{
387+
scoped_trace_stream _tr2(*this, preds);
388+
ctx().mk_th_axiom(get_id(), preds);
389+
}
390+
(void)max_depth;
391+
// add_induction_lemmas(max_depth);
392+
}
393+
394+
void theory_recfun::add_induction_lemmas(unsigned depth) {
395+
if (depth > 4 && ctx().get_fparams().m_induction && induction::should_try(ctx())) {
396+
ctx().get_induction()();
397+
}
388398
}
389399

390400
void theory_recfun::activate_guard(expr* pred_applied, expr_ref_vector const& guards) {

src/smt/theory_recfun.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ namespace smt {
127127
void assert_body_axiom(body_expansion & e);
128128
literal mk_literal(expr* e);
129129

130+
void add_induction_lemmas(unsigned depth);
130131
void disable_guard(expr* guard, expr_ref_vector const& guards);
131132
unsigned get_depth(expr* e);
132133
void set_depth(unsigned d, expr* e);

0 commit comments

Comments
 (0)