Skip to content

Commit a9098a5

Browse files
committed
optimise l terms addition
Signed-off-by: Lev Nachmanson <[email protected]>
1 parent 3e7e903 commit a9098a5

File tree

1 file changed

+155
-45
lines changed

1 file changed

+155
-45
lines changed

src/math/lp/dioph_eq.cpp

+155-45
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,16 @@ namespace lp {
4848
return t.find(j) != t.end();
4949
}
5050

51+
struct iv {
52+
mpq m_coeff;
53+
unsigned m_j;
54+
lpvar var() const { return m_j; }
55+
const mpq & coeff() const { return m_coeff; }
56+
mpq & coeff() { return m_coeff; }
57+
iv() {}
58+
iv(const mpq& v, unsigned j) : m_coeff(v), m_j(j) {}
59+
};
60+
5161
struct bijection {
5262
std::unordered_map<unsigned, unsigned> m_map;
5363
std::unordered_map<unsigned, unsigned> m_rev_map;
@@ -72,11 +82,11 @@ namespace lp {
7282
m_map.erase(key);
7383
}
7484
bool has_val(unsigned b) const {
75-
return contains(m_rev_map, b);
85+
return m_rev_map.find(b) != m_rev_map.end();
7686
}
7787

7888
bool has_key(unsigned a) const {
79-
return contains(m_map, a);
89+
return m_map.find(a) != m_map.end();
8090
}
8191

8292
void transpose_val(unsigned b0, unsigned b1) {
@@ -148,7 +158,7 @@ namespace lp {
148158
}
149159

150160
bool has_key(unsigned j) const { return m_bij.has_key(j); }
151-
161+
bool has_second_key(unsigned j) const { return m_bij.has_val(j);}
152162
// Get the data by 'a', look up b in m_bij, then read from m_data
153163
const T& get_by_key(unsigned a) const {
154164
unsigned b = m_bij[a]; // relies on operator[](unsigned) from bijection
@@ -188,23 +198,6 @@ namespace lp {
188198
return m_c;
189199
}
190200
term_o() : m_c(0) {}
191-
void substitute_var_with_term(const term_o& t, unsigned col_to_subs) {
192-
mpq a = get_coeff(
193-
col_to_subs); // need to copy it becase the pointer value can be
194-
// changed during the next loop
195-
const mpq& coeff = t.get_coeff(col_to_subs);
196-
SASSERT(coeff.is_one() || coeff.is_minus_one());
197-
if (coeff.is_one()) {
198-
a = -a;
199-
}
200-
for (auto p : t) {
201-
if (p.j() == col_to_subs)
202-
continue;
203-
this->add_monomial(a * p.coeff(), p.j());
204-
}
205-
this->c() += a * t.c();
206-
this->m_coeffs.erase(col_to_subs);
207-
}
208201

209202
friend term_o operator*(const mpq& k, const term_o& term) {
210203
term_o r;
@@ -260,6 +253,17 @@ namespace lp {
260253
[](int j) -> std::string { return "x" + std::to_string(j); }, out);
261254
}
262255

256+
257+
// used for debug only
258+
std::ostream& print_lar_term_L(const std_vector<iv>& t, std::ostream& out) const {
259+
vector<std::pair<mpq, unsigned>> tp;
260+
for (const auto & p : t) {
261+
tp.push_back(std::make_pair(p.coeff(), p.var()));
262+
}
263+
return print_linear_combination_customized(
264+
tp, [](int j) -> std::string { return "x" + std::to_string(j); }, out);
265+
}
266+
263267
std::ostream& print_term_o(term_o const& term, std::ostream& out) const {
264268
if (term.size() == 0 && term.c().is_zero()) {
265269
out << "0";
@@ -331,8 +335,104 @@ namespace lp {
331335
// iterate over all rows from 0 to m_e_matrix.row_count() - 1 and return those i such !m_k2s.has_val(i)
332336
// set S - iterate over bijection m_k2s
333337
mpq m_c; // the constant of the equation
334-
lar_term m_tmp_l;
338+
struct term_with_index {
339+
// The invariant is that m_index[m_data[k].var()] = k, for each 0 <= k < m_data.size(),
340+
// and m_index[j] = -1, or m_tmp[m_index[j]].var() = j, for every 0 <= j < m_index.size().
341+
// For example m_data = [(coeff, 5), (coeff, 3)]
342+
// then m_index = [-1,-1, -1, 1, -1, 0, -1, ....].
343+
std_vector<iv> m_data;
344+
std_vector<int> m_index;
345+
// used for debug only
346+
lar_term to_term() const {
347+
lar_term r;
348+
for (const auto& p: m_data) {
349+
r.add_monomial(p.coeff(), p.var());
350+
}
351+
return r;
352+
}
353+
void add(const mpq& a, unsigned j) {
354+
SASSERT(!a.is_zero());
355+
// Expand m_index if needed
356+
if (j >= m_index.size()) {
357+
m_index.resize(j + 1, -1);
358+
}
359+
360+
int idx = m_index[j];
361+
if (idx == -1) {
362+
// Insert a new monomial { a, j } into m_data
363+
m_data.push_back({ a, j });
364+
m_index[j] = static_cast<int>(m_data.size() - 1);
365+
} else {
366+
// Accumulate the coefficient
367+
m_data[idx].coeff() += a;
368+
// If the coefficient becomes zero, remove the entry
369+
if (m_data[idx].coeff().is_zero()) {
370+
int last = static_cast<int>(m_data.size() - 1);
371+
// Swap with the last element for efficient removal
372+
if (idx != last) {
373+
auto tmp = m_data[last];
374+
m_data[idx] = tmp;
375+
m_index[tmp.var()] = idx;
376+
}
377+
m_data.pop_back();
378+
m_index[j] = -1;
379+
}
380+
}
381+
SASSERT(invariant());
382+
}
383+
384+
bool invariant() const {
385+
// 1. For each j in [0..m_index.size()), if m_index[j] = -1, ensure no m_data[k].var() == j
386+
// otherwise verify m_data[m_index[j]].var() == j
387+
for (unsigned j = 0; j < m_index.size(); j++) {
388+
int idx = m_index[j];
389+
if (idx == -1) {
390+
// Check that j is not in m_data
391+
for (unsigned k = 0; k < m_data.size(); ++k) {
392+
if (m_data[k].var() == j) {
393+
return false;
394+
}
395+
}
396+
}
397+
else {
398+
// Check that var() in m_data[idx] matches j
399+
if (idx < 0 || static_cast<unsigned>(idx) >= m_data.size()) {
400+
return false;
401+
}
402+
if (m_data[idx].var() != j || m_data[idx].coeff().is_zero()) {
403+
return false;
404+
}
405+
}
406+
}
407+
// 2. For each item in m_data, check that m_index[m_data[k].var()] == k
408+
// and that the coeff() is non-zero
409+
for (unsigned k = 0; k < m_data.size(); ++k) {
410+
unsigned var = m_data[k].var();
411+
if (var >= m_index.size()) {
412+
return false;
413+
}
414+
if (m_index[var] != static_cast<int>(k)) {
415+
return false;
416+
}
417+
if (m_data[k].coeff().is_zero()) {
418+
return false;
419+
}
420+
}
335421

422+
return true;
423+
}
424+
void clear() {
425+
for (const auto& p: m_data) {
426+
m_index[p.var()] = -1;
427+
}
428+
m_data.clear();
429+
SASSERT(invariant());
430+
}
431+
432+
};
433+
434+
term_with_index m_term_with_index;
435+
336436
bijection m_k2s;
337437
bij_map<lar_term> m_fresh_k2xt_terms;
338438
// m_row2fresh_defs[i] is the set of all fresh variables xt
@@ -981,6 +1081,7 @@ namespace lp {
9811081
mpq gcd_of_coeffs(const K& k) {
9821082
mpq g(0);
9831083
for (const auto& p : k) {
1084+
SASSERT(p.coeff().is_int());
9841085
if (g.is_zero())
9851086
g = abs(p.coeff());
9861087
else
@@ -1107,8 +1208,14 @@ namespace lp {
11071208
// there is no change in m_l_matrix
11081209
TRACE("dioph_eq", tout << "after subs k:" << k << "\n";
11091210
print_term_o(create_term_from_ind_c(), tout) << std::endl;
1110-
tout << "m_tmp_l:{"; print_lar_term_L(m_tmp_l, tout);
1111-
tout << "}, opened:"; print_ml(m_tmp_l, tout) << std::endl;);
1211+
tout << "m_term_with_index:{"; print_lar_term_L(m_term_with_index.m_data, tout);
1212+
tout << "}, opened:"; print_ml(m_term_with_index.to_term(), tout) << std::endl;);
1213+
}
1214+
1215+
void add_l_row_to_term_with_index(const mpq& coeff, unsigned ei) {
1216+
for (const auto & p: m_l_matrix.m_rows[ei]) {
1217+
m_term_with_index.add(coeff * p.coeff(), p.var());
1218+
}
11121219
}
11131220

11141221
void subs_front_in_indexed_vector_by_S(unsigned k, std::queue<unsigned> &q) {
@@ -1142,12 +1249,11 @@ namespace lp {
11421249
q.push(j);
11431250
}
11441251
m_c += coeff * e.m_c;
1145-
1146-
m_tmp_l += coeff * l_term_from_row(sub_index(k)); // improve later
1252+
add_l_row_to_term_with_index(coeff, sub_index(k));
11471253
TRACE("dioph_eq", tout << "after subs k:" << k << "\n";
11481254
print_term_o(create_term_from_ind_c(), tout) << std::endl;
1149-
tout << "m_tmp_l:{"; print_lar_term_L(m_tmp_l, tout);
1150-
tout << "}, opened:"; print_ml(m_tmp_l, tout) << std::endl;);
1255+
tout << "m_term_with_index:{"; print_lar_term_L(m_term_with_index.to_term(), tout);
1256+
tout << "}, opened:"; print_ml(m_term_with_index.to_term(), tout) << std::endl;);
11511257
}
11521258

11531259
bool is_substituted_by_fresh(unsigned k) const {
@@ -1263,7 +1369,7 @@ namespace lp {
12631369
m_indexed_work_vector.clear();
12641370
m_indexed_work_vector.resize(m_e_matrix.column_count());
12651371
m_c = 0;
1266-
m_tmp_l = lar_term();
1372+
m_term_with_index.clear();
12671373
for (const auto& p : lar_t) {
12681374
SASSERT(p.coeff().is_int());
12691375
if (is_fixed(p.j()))
@@ -1300,41 +1406,41 @@ namespace lp {
13001406
tout << "from ind:";
13011407
print_term_o(create_term_from_ind_c(), tout) << std::endl;
13021408
tout << "m_tmp_l:";
1303-
print_lar_term_L(m_tmp_l, tout) << std::endl;);
1409+
print_lar_term_L(m_term_with_index.to_term(), tout) << std::endl;);
13041410
subs_indexed_vector_with_S(q);
13051411
// if(
13061412
// fix_vars(term_to_tighten + open_ml(m_tmp_l)) !=
13071413
// term_to_lar_solver(remove_fresh_vars(create_term_from_ind_c())))
13081414
// enable_trace("dioph_eq");
13091415

1310-
TRACE("dioph_eq", tout << "after subs\n";
1416+
TRACE("dioph_eq_deb", tout << "after subs\n";
13111417
print_term_o(create_term_from_ind_c(), tout) << std::endl;
13121418
tout << "term_to_tighten:";
13131419
print_lar_term_L(term_to_tighten, tout) << std::endl;
1314-
tout << "m_tmp_l:"; print_lar_term_L(m_tmp_l, tout) << std::endl;
1420+
tout << "m_tmp_l:"; print_lar_term_L(m_term_with_index.to_term(), tout) << std::endl;
13151421
tout << "open_ml:";
1316-
print_lar_term_L(open_ml(m_tmp_l), tout) << std::endl;
1422+
print_lar_term_L(open_ml(m_term_with_index.to_term()), tout) << std::endl;
13171423
tout << "term_to_tighten + open_ml:";
1318-
print_term_o(term_to_tighten + open_ml(m_tmp_l), tout)
1424+
print_term_o(term_to_tighten + open_ml(m_term_with_index.to_term()), tout)
13191425
<< std::endl;
1320-
term_o ls = fix_vars(term_to_tighten + open_ml(m_tmp_l));
1426+
term_o ls = fix_vars(term_to_tighten + open_ml(m_term_with_index.to_term()));
13211427
tout << "ls:"; print_term_o(ls,tout) << std::endl;
13221428
term_o rs = term_to_lar_solver(remove_fresh_vars(create_term_from_ind_c()));
13231429
tout << "rs:"; print_term_o(rs, tout ) << std::endl;
13241430
term_o diff = ls - rs;
13251431
if (!diff.is_empty()) {
13261432
tout << "diff:"; print_term_o(diff, tout ) << std::endl;
13271433
}
1328-
13291434
);
1435+
13301436
SASSERT(
1331-
fix_vars(term_to_tighten + open_ml(m_tmp_l)) ==
1437+
fix_vars(term_to_tighten + open_ml(m_term_with_index.to_term())) ==
13321438
term_to_lar_solver(remove_fresh_vars(create_term_from_ind_c())));
13331439
mpq g = gcd_of_coeffs(m_indexed_work_vector);
13341440
TRACE("dioph_eq", tout << "after process_q_with_S\nt:";
13351441
print_term_o(create_term_from_ind_c(), tout) << std::endl;
13361442
tout << "g:" << g << std::endl;
1337-
/*tout << "dep:"; print_dep(tout, m_tmp_l) << std::endl;*/);
1443+
/*tout << "dep:"; print_dep(tout, m_term_with_index.m_data) << std::endl;*/);
13381444

13391445
if (g.is_one())
13401446
return false;
@@ -1364,7 +1470,7 @@ namespace lp {
13641470
if (m_c > rs || (is_strict && m_c == rs)) {
13651471
u_dependency* dep =
13661472
lra.mk_join(explain_fixed(lra.get_term(j)),
1367-
explain_fixed_in_meta_term(m_tmp_l));
1473+
explain_fixed_in_meta_term(m_term_with_index.m_data));
13681474
dep = lra.mk_join(
13691475
dep, lra.get_bound_constraint_witnesses_for_column(j));
13701476
for (constraint_index ci : lra.flatten(dep)) {
@@ -1377,7 +1483,7 @@ namespace lp {
13771483
if (m_c < rs || (is_strict && m_c == rs)) {
13781484
u_dependency* dep =
13791485
lra.mk_join(explain_fixed(lra.get_term(j)),
1380-
explain_fixed_in_meta_term(m_tmp_l));
1486+
explain_fixed_in_meta_term(m_term_with_index.m_data));
13811487
dep = lra.mk_join(
13821488
dep, lra.get_bound_constraint_witnesses_for_column(j));
13831489
for (constraint_index ci : lra.flatten(dep)) {
@@ -1432,7 +1538,7 @@ namespace lp {
14321538
lconstraint_kind kind =
14331539
upper ? lconstraint_kind::LE : lconstraint_kind::GE;
14341540
u_dependency* dep = prev_dep;
1435-
dep = lra.mk_join(dep, explain_fixed_in_meta_term(m_tmp_l));
1541+
dep = lra.mk_join(dep, explain_fixed_in_meta_term(m_term_with_index.m_data));
14361542
u_dependency* j_bound_dep = upper
14371543
? lra.get_column_upper_bound_witness(j)
14381544
: lra.get_column_lower_bound_witness(j);
@@ -1969,10 +2075,11 @@ namespace lp {
19692075

19702076
mpq coeff = m_e_matrix.get_val(c);
19712077
TRACE("dioph_eq", tout << "before pivot entry :"; print_entry(c.var(), tout) << std::endl;);
2078+
unsigned c_row = c.var();
19722079
m_e_matrix.pivot_term_to_row_given_cell(t, c, j, j_sign);
19732080
TRACE("dioph_eq", tout << "after pivoting c_row:";
1974-
print_entry(c.var(), tout););
1975-
SASSERT(entry_invariant(c.var()));
2081+
print_entry(c_row, tout););
2082+
SASSERT(entry_invariant(c_row));
19762083
cell_to_process--;
19772084
}
19782085
SASSERT(is_eliminated_from_f(j));
@@ -2207,7 +2314,8 @@ namespace lp {
22072314
unsigned h = -1;
22082315
unsigned n = 0; // number of choices for a fresh variable
22092316
mpq the_smallest_ahk;
2210-
unsigned kh, kh_sign;
2317+
unsigned kh;
2318+
int kh_sign;
22112319
for (unsigned ei=0; ei < m_e_matrix.row_count(); ei++) {
22122320
if (belongs_to_s(ei)) continue;
22132321
if (m_e_matrix.m_rows[ei].size() == 0) {
@@ -2240,7 +2348,7 @@ namespace lp {
22402348
kh_sign = k_sign;
22412349
}
22422350
}
2243-
if (h == UINT_MAX) return false;
2351+
if (h == -1) return false;
22442352
SASSERT(!the_smallest_ahk.is_one());
22452353
fresh_var_step(h, kh, the_smallest_ahk * mpq(kh_sign));
22462354
return true;
@@ -2265,7 +2373,9 @@ namespace lp {
22652373
}
22662374

22672375
bool var_is_fresh(unsigned j) const {
2268-
return m_var_register.local_to_external(j) == UINT_MAX;
2376+
bool ret = m_fresh_k2xt_terms.has_second_key(j);
2377+
SASSERT(!ret || m_var_register.local_to_external(j) == UINT_MAX);
2378+
return ret;
22692379
}
22702380

22712381
};

0 commit comments

Comments
 (0)