Skip to content

Commit 48712b4

Browse files
Add initial value setting for variables in Z3 API, solver, and optimize modules
1 parent 0ba306e commit 48712b4

31 files changed

+297
-9
lines changed

src/api/api_opt.cpp

+16-1
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,21 @@ extern "C" {
459459
Z3_CATCH;
460460
}
461461

462-
462+
void Z3_API Z3_optimize_set_initial_value(Z3_context c, Z3_optimize o, Z3_ast var, Z3_ast value) {
463+
Z3_TRY;
464+
LOG_Z3_optimize_set_initial_value(c, o, var, value);
465+
RESET_ERROR_CODE();
466+
if (to_expr(var)->get_sort() != to_expr(value)->get_sort()) {
467+
SET_ERROR_CODE(Z3_INVALID_USAGE, "variable and value should have same sort");
468+
return;
469+
}
470+
ast_manager& m = mk_c(c)->m();
471+
if (!m.is_value(to_expr(value))) {
472+
SET_ERROR_CODE(Z3_INVALID_USAGE, "a proper value was not supplied");
473+
return;
474+
}
475+
to_optimize_ptr(o)->initialize_value(to_expr(var), to_expr(value));
476+
Z3_CATCH;
477+
}
463478

464479
};

src/api/api_solver.cpp

+18
Original file line numberDiff line numberDiff line change
@@ -1143,5 +1143,23 @@ extern "C" {
11431143
Z3_CATCH_RETURN(nullptr);
11441144
}
11451145

1146+
void Z3_API Z3_solver_set_initial_value(Z3_context c, Z3_solver s, Z3_ast var, Z3_ast value) {
1147+
Z3_TRY;
1148+
LOG_Z3_solver_set_initial_value(c, s, var, value);
1149+
RESET_ERROR_CODE();
1150+
if (to_expr(var)->get_sort() != to_expr(value)->get_sort()) {
1151+
SET_ERROR_CODE(Z3_INVALID_USAGE, "variable and value should have same sort");
1152+
return;
1153+
}
1154+
ast_manager& m = mk_c(c)->m();
1155+
if (!m.is_value(to_expr(value))) {
1156+
SET_ERROR_CODE(Z3_INVALID_USAGE, "a proper value was not supplied");
1157+
return;
1158+
}
1159+
to_solver_ref(s)->user_propagate_initialize_value(to_expr(var), to_expr(value));
1160+
Z3_CATCH;
1161+
}
1162+
1163+
11461164

11471165
};

src/api/c++/z3++.h

+22
Original file line numberDiff line numberDiff line change
@@ -2865,6 +2865,17 @@ namespace z3 {
28652865
check_error();
28662866
return result;
28672867
}
2868+
void set_initial_value(expr const& var, expr const& value) {
2869+
Z3_solver_set_initial_value(ctx(), m_solver, var, value);
2870+
check_error();
2871+
}
2872+
void set_initial_value(expr const& var, int i) {
2873+
set_initial_value(var, ctx().num_val(i, var.get_sort()));
2874+
}
2875+
void set_initial_value(expr const& var, bool b) {
2876+
set_initial_value(var, ctx().bool_val(b));
2877+
}
2878+
28682879
expr proof() const { Z3_ast r = Z3_solver_get_proof(ctx(), m_solver); check_error(); return expr(ctx(), r); }
28692880
friend std::ostream & operator<<(std::ostream & out, solver const & s);
28702881

@@ -3330,6 +3341,17 @@ namespace z3 {
33303341
handle add(expr const& e, unsigned weight) {
33313342
return add_soft(e, weight);
33323343
}
3344+
void set_initial_value(expr const& var, expr const& value) {
3345+
Z3_optimize_set_initial_value(ctx(), m_opt, var, value);
3346+
check_error();
3347+
}
3348+
void set_initial_value(expr const& var, int i) {
3349+
set_initial_value(var, ctx().num_val(i, var.get_sort()));
3350+
}
3351+
void set_initial_value(expr const& var, bool b) {
3352+
set_initial_value(var, ctx().bool_val(b));
3353+
}
3354+
33333355
handle maximize(expr const& e) {
33343356
return handle(Z3_optimize_maximize(ctx(), m_opt, e));
33353357
}

src/api/python/z3/z3.py

+14
Original file line numberDiff line numberDiff line change
@@ -7353,6 +7353,13 @@ def trail_levels(self):
73537353
Z3_solver_get_levels(self.ctx.ref(), self.solver, trail.vector, len(trail), levels)
73547354
return trail, levels
73557355

7356+
def set_initial_value(self, var, value):
7357+
"""initialize the solver's state by setting the initial value of var to value
7358+
"""
7359+
s = var.sort()
7360+
value = s.cast(value)
7361+
Z3_solver_set_initial_value(self.ctx.ref(), self.solver, var.ast, value.ast)
7362+
73567363
def trail(self):
73577364
"""Return trail of the solver state after a check() call.
73587365
"""
@@ -8032,6 +8039,13 @@ def asoft(a):
80328039
return [asoft(a) for a in arg]
80338040
return asoft(arg)
80348041

8042+
def set_initial_value(self, var, value):
8043+
"""initialize the solver's state by setting the initial value of var to value
8044+
"""
8045+
s = var.sort()
8046+
value = s.cast(value)
8047+
Z3_optimize_set_initial_value(self.ctx.ref(), self.optimize, var.ast, value.ast)
8048+
80358049
def maximize(self, arg):
80368050
"""Add objective function to maximize."""
80378051
return OptimizeObjective(

src/api/z3_api.h

+12
Original file line numberDiff line numberDiff line change
@@ -7241,6 +7241,18 @@ extern "C" {
72417241

72427242
bool Z3_API Z3_solver_propagate_consequence(Z3_context c, Z3_solver_callback cb, unsigned num_fixed, Z3_ast const* fixed, unsigned num_eqs, Z3_ast const* eq_lhs, Z3_ast const* eq_rhs, Z3_ast conseq);
72437243

7244+
7245+
/**
7246+
\brief provide an initialization hint to the solver. The initialization hint is used to calibrate an initial value of the expression that
7247+
represents a variable. If the variable is Boolean, the initial phase is set according to \c value. If the variable is an integer or real,
7248+
the initial Simplex tableau is recalibrated to attempt to follow the value assignment.
7249+
7250+
def_API('Z3_solver_set_initial_value', VOID, (_in(CONTEXT), _in(SOLVER), _in(AST), _in(AST)))
7251+
*/
7252+
7253+
void Z3_API Z3_solver_set_initial_value(Z3_context c, Z3_solver s, Z3_ast var, Z3_ast value);
7254+
7255+
72447256
/**
72457257
\brief Check whether the assertions in a given solver are consistent or not.
72467258

src/api/z3_optimization.h

+12
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,18 @@ extern "C" {
139139
*/
140140
void Z3_API Z3_optimize_pop(Z3_context c, Z3_optimize d);
141141

142+
/**
143+
\brief provide an initialization hint to the solver.
144+
The initialization hint is used to calibrate an initial value of the expression that
145+
represents a variable. If the variable is Boolean, the initial phase is set
146+
according to \c value. If the variable is an integer or real,
147+
the initial Simplex tableau is recalibrated to attempt to follow the value assignment.
148+
149+
def_API('Z3_optimize_set_initial_value', VOID, (_in(CONTEXT), _in(OPTIMIZE), _in(AST), _in(AST)))
150+
*/
151+
152+
void Z3_API Z3_optimize_set_initial_value(Z3_context c, Z3_optimize o, Z3_ast var, Z3_ast value);
153+
142154
/**
143155
\brief Check consistency and produce optimal values.
144156
\param c - context

src/math/lp/lar_solver.cpp

+18
Original file line numberDiff line numberDiff line change
@@ -2081,6 +2081,24 @@ namespace lp {
20812081
lpvar lar_solver::to_column(unsigned ext_j) const {
20822082
return m_var_register.external_to_local(ext_j);
20832083
}
2084+
2085+
bool lar_solver::move_lpvar_to_value(lpvar j, mpq const& value) {
2086+
if (is_base(j))
2087+
return false;
2088+
2089+
impq ivalue(value);
2090+
auto& lcs = m_mpq_lar_core_solver;
2091+
auto& slv = m_mpq_lar_core_solver.m_r_solver;
2092+
2093+
if (slv.column_has_upper_bound(j) && lcs.m_r_upper_bounds()[j] < ivalue)
2094+
return false;
2095+
if (slv.column_has_lower_bound(j) && lcs.m_r_lower_bounds()[j] > ivalue)
2096+
return false;
2097+
2098+
set_value_for_nbasic_column(j, ivalue);
2099+
return true;
2100+
}
2101+
20842102

20852103
bool lar_solver::tighten_term_bounds_by_delta(lpvar j, const impq& delta) {
20862104
SASSERT(column_has_term(j));

src/math/lp/lar_solver.h

+1
Original file line numberDiff line numberDiff line change
@@ -623,6 +623,7 @@ class lar_solver : public column_namer {
623623
lp_status find_feasible_solution();
624624
void move_non_basic_columns_to_bounds();
625625
bool move_non_basic_column_to_bounds(unsigned j);
626+
bool move_lpvar_to_value(lpvar j, mpq const& value);
626627
inline bool r_basis_has_inf_int() const {
627628
for (unsigned j : r_basis()) {
628629
if (column_is_int(j) && !column_value_is_int(j))

src/opt/opt_context.cpp

+12-7
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,9 @@ namespace opt {
5858
}
5959

6060
void context::scoped_state::pop() {
61-
m_hard.resize(m_hard_lim.back());
62-
m_asms.resize(m_asms_lim.back());
61+
m_hard.shrink(m_hard_lim.back());
62+
m_asms.shrink(m_asms_lim.back());
63+
m_values.shrink(m_values_lim.back());
6364
unsigned k = m_objectives_term_trail_lim.back();
6465
while (m_objectives_term_trail.size() > k) {
6566
unsigned idx = m_objectives_term_trail.back();
@@ -79,6 +80,7 @@ namespace opt {
7980
m_objectives_lim.pop_back();
8081
m_hard_lim.pop_back();
8182
m_asms_lim.pop_back();
83+
m_values_lim.pop_back();
8284
}
8385

8486
void context::scoped_state::add(expr* hard) {
@@ -306,13 +308,11 @@ namespace opt {
306308
if (contains_quantifiers()) {
307309
warning_msg("optimization with quantified constraints is not supported");
308310
}
309-
#if 0
310-
if (is_qsat_opt()) {
311-
return run_qsat_opt();
312-
}
313-
#endif
314311
solver& s = get_solver();
315312
s.assert_expr(m_hard_constraints);
313+
for (auto const& [var, value] : m_scoped_state.m_values) {
314+
s.user_propagate_initialize_value(var, value);
315+
}
316316

317317
opt_params optp(m_params);
318318
symbol pri = optp.priority();
@@ -697,6 +697,11 @@ namespace opt {
697697
}
698698
}
699699

700+
void context::initialize_value(expr* var, expr* value) {
701+
m_scoped_state.m_values.push_back({expr_ref(var, m), expr_ref(value, m)});
702+
}
703+
704+
700705
/**
701706
* Set the solver to the SAT core.
702707
* It requres:

src/opt/opt_context.h

+4
Original file line numberDiff line numberDiff line change
@@ -140,12 +140,14 @@ namespace opt {
140140
unsigned_vector m_objectives_lim;
141141
unsigned_vector m_objectives_term_trail;
142142
unsigned_vector m_objectives_term_trail_lim;
143+
unsigned_vector m_values_lim;
143144
map_id m_indices;
144145

145146
public:
146147
expr_ref_vector m_hard;
147148
expr_ref_vector m_asms;
148149
vector<objective> m_objectives;
150+
vector<std::pair<expr_ref, expr_ref>> m_values;
149151

150152
scoped_state(ast_manager& m):
151153
m(m),
@@ -275,6 +277,8 @@ namespace opt {
275277

276278
void add_offset(unsigned id, rational const& o) override;
277279

280+
void initialize_value(expr* var, expr* value);
281+
278282
void register_on_model(on_model_t& ctx, std::function<void(on_model_t&, model_ref&)>& on_model) {
279283
m_on_model_ctx = ctx;
280284
m_on_model_eh = on_model;

src/opt/opt_solver.h

+1
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ namespace opt {
116116
phase* get_phase() override { return m_context.get_phase(); }
117117
void set_phase(phase* p) override { m_context.set_phase(p); }
118118
void move_to_front(expr* e) override { m_context.move_to_front(e); }
119+
void user_propagate_initialize_value(expr* var, expr* value) override { m_context.user_propagate_initialize_value(var, value); }
119120

120121
void set_logic(symbol const& logic);
121122

src/sat/sat_solver/inc_sat_solver.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -702,6 +702,10 @@ class inc_sat_solver : public solver {
702702
ensure_euf()->user_propagate_register_decide(r);
703703
}
704704

705+
void user_propagate_initialize_value(expr* var, expr* value) override {
706+
ensure_euf()->user_propagate_initialize_value(var, value);
707+
}
708+
705709

706710
private:
707711

src/sat/sat_solver/sat_smt_solver.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -577,6 +577,11 @@ class sat_smt_solver : public solver {
577577
ensure_euf()->user_propagate_register_decide(r);
578578
}
579579

580+
void user_propagate_initialize_value(expr* var, expr* value) override {
581+
ensure_euf()->user_propagate_initialize_value(var, value);
582+
}
583+
584+
580585
private:
581586

582587
void add_assumption(expr* a) {

src/sat/smt/euf_solver.cpp

+25
Original file line numberDiff line numberDiff line change
@@ -1256,6 +1256,31 @@ namespace euf {
12561256
add_solver(m_user_propagator);
12571257
}
12581258

1259+
void solver::user_propagate_initialize_value(expr* var, expr* value) {
1260+
if (m.is_bool(var)) {
1261+
auto lit = expr2literal(var);
1262+
if (lit == sat::null_literal) {
1263+
IF_VERBOSE(5, verbose_stream() << "no literal associated with " << mk_pp(var, m) << " := " << mk_pp(value, m) << "\n");
1264+
return;
1265+
}
1266+
if (m.is_true(value))
1267+
s().set_phase(lit);
1268+
else if (m.is_false(value))
1269+
s().set_phase(~lit);
1270+
else
1271+
IF_VERBOSE(5, verbose_stream() << "malformed value " << mk_pp(var, m) << " := " << mk_pp(value, m) << "\n");
1272+
return;
1273+
}
1274+
auto* th = m_id2solver.get(var->get_sort()->get_family_id(), nullptr);
1275+
if (!th) {
1276+
IF_VERBOSE(5, verbose_stream() << "no default initialization associated with " << mk_pp(var, m) << " := " << mk_pp(value, m) << "\n");
1277+
return;
1278+
}
1279+
// th->initialize_value(var, value);
1280+
IF_VERBOSE(5, verbose_stream() << "no default initialization associated with " << mk_pp(var, m) << " := " << mk_pp(value, m) << "\n");
1281+
}
1282+
1283+
12591284
bool solver::watches_fixed(enode* n) const {
12601285
return m_user_propagator && m_user_propagator->has_fixed() && n->get_th_var(m_user_propagator->get_id()) != null_theory_var;
12611286
}

src/sat/smt/euf_solver.h

+2
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,8 @@ namespace euf {
564564
m_user_propagator->add_expr(e);
565565
}
566566

567+
void user_propagate_initialize_value(expr* var, expr* value);
568+
567569
// solver factory
568570
::solver* mk_solver() { return m_mk_solver(); }
569571
void set_mk_solver(std::function<::solver*(void)>& mk) { m_mk_solver = mk; }

src/smt/smt_context.cpp

+41
Original file line numberDiff line numberDiff line change
@@ -2914,6 +2914,44 @@ namespace smt {
29142914
register_plugin(m_user_propagator);
29152915
}
29162916

2917+
void context::user_propagate_initialize_value(expr* var, expr* value) {
2918+
m_values.push_back({expr_ref(var, m), expr_ref(value, m)});
2919+
push_trail(push_back_vector(m_values));
2920+
}
2921+
2922+
void context::initialize_value(expr* var, expr* value) {
2923+
IF_VERBOSE(10, verbose_stream() << "context initialize " << mk_pp(var, m) << " := " << mk_pp(value, m) << "\n");
2924+
sort* s = var->get_sort();
2925+
ensure_internalized(var);
2926+
2927+
if (m.is_bool(s)) {
2928+
auto v = get_bool_var_of_id_option(var->get_id());
2929+
if (v == null_bool_var) {
2930+
IF_VERBOSE(5, verbose_stream() << "Boolean variable has no literal " << mk_pp(var, m) << " := " << mk_pp(value, m) << "\n");
2931+
return;
2932+
}
2933+
m_bdata[v].m_phase_available = true;
2934+
if (m.is_true(value))
2935+
m_bdata[v].m_phase = true;
2936+
else if (m.is_false(value))
2937+
m_bdata[v].m_phase = false;
2938+
else
2939+
IF_VERBOSE(5, verbose_stream() << "Boolean value is not constant " << mk_pp(var, m) << " := " << mk_pp(value, m) << "\n");
2940+
return;
2941+
}
2942+
2943+
if (!e_internalized(var))
2944+
return;
2945+
enode* n = get_enode(var);
2946+
theory* th = m_theories.get_plugin(s->get_family_id());
2947+
if (!th) {
2948+
IF_VERBOSE(5, verbose_stream() << "No theory is attached to variable " << mk_pp(var, m) << " := " << mk_pp(value, m) << "\n");
2949+
return;
2950+
}
2951+
th->initialize_value(var, value);
2952+
2953+
}
2954+
29172955
bool context::watches_fixed(enode* n) const {
29182956
return m_user_propagator && m_user_propagator->has_fixed() && n->get_th_var(m_user_propagator->get_family_id()) != null_theory_var;
29192957
}
@@ -3756,6 +3794,9 @@ namespace smt {
37563794
TRACE("search", display(tout); display_enodes_lbls(tout););
37573795
TRACE("search_detail", m_asserted_formulas.display(tout););
37583796
init_search();
3797+
for (auto const& [var, value] : m_values)
3798+
initialize_value(var, value);
3799+
37593800
flet<bool> l(m_searching, true);
37603801
TRACE("after_init_search", display(tout););
37613802
IF_VERBOSE(2, verbose_stream() << "(smt.searching)\n";);

0 commit comments

Comments
 (0)