diff --git a/src/sat/smt/array_model.cpp b/src/sat/smt/array_model.cpp index 74594e5a0b1..ecaedd97bde 100644 --- a/src/sat/smt/array_model.cpp +++ b/src/sat/smt/array_model.cpp @@ -20,7 +20,11 @@ Module Name: #include "sat/smt/euf_solver.h" namespace array { - + + + void solver::init_model() { + collect_defaults(); + } bool solver::add_dep(euf::enode* n, top_sort& dep) { if (!a.is_array(n->get_expr())) { @@ -41,6 +45,10 @@ namespace array { for (euf::enode* k : euf::enode_class(n)) if (a.is_const(k->get_expr())) dep.add(n, k->get_arg(0)); + theory_var v = get_th_var(n); + euf::enode* d = get_default(v); + if (d) + dep.add(n, d); if (!dep.deps().contains(n)) dep.insert(n, nullptr); return true; @@ -57,6 +65,17 @@ namespace array { func_interp * fi = alloc(func_interp, m, arity); mdl.register_decl(f, fi); + theory_var v = get_th_var(n); + euf::enode* d = get_default(v); + if (d && !fi->get_else()) + fi->set_else(values.get(d->get_root_id())); + + if (!fi->get_else() && get_else(v)) + fi->set_else(get_else(v)); + +#if 0 + // this functionality is already taken care of by model_init. + if (!fi->get_else()) for (euf::enode* k : euf::enode_class(n)) if (a.is_const(k->get_expr())) @@ -66,6 +85,7 @@ namespace array { for (euf::enode* p : euf::enode_parents(n)) if (a.is_default(p->get_expr())) fi->set_else(values.get(p->get_root_id())); +#endif if (!fi->get_else()) { expr* else_value = nullptr; @@ -90,6 +110,9 @@ namespace array { fi->set_else(else_value); } + if (!get_else(v) && fi->get_else()) + set_else(v, fi->get_else()); + for (euf::enode* p : euf::enode_parents(n)) { if (a.is_select(p->get_expr()) && p->get_arg(0)->get_root() == n) { expr* value = values.get(p->get_root_id(), nullptr); @@ -175,4 +198,104 @@ namespace array { return table_diff(r1, r2, else1) || table_diff(r2, r1, else2); } + void solver::collect_defaults() { + unsigned num_vars = get_num_vars(); + m_defaults.reset(); + m_else_values.reset(); + m_parents.reset(); + m_parents.resize(num_vars, -1); + m_defaults.resize(num_vars); + m_else_values.resize(num_vars); + + // + // Create equivalence classes for defaults. + // + for (unsigned v = 0; v < num_vars; ++v) { + euf::enode * n = var2enode(v); + expr* e = n->get_expr(); + + theory_var r = get_representative(v); + + mg_merge(v, r); + + if (a.is_const(e)) + set_default(v, n->get_arg(0)); + else if (a.is_store(e)) { + theory_var w = get_th_var(n->get_arg(0)); + SASSERT(w != euf::null_theory_var); + mg_merge(v, get_representative(w)); + TRACE("array", tout << "merge: " << ctx.bpp(n) << " " << v << " " << w << "\n";); + } + else if (a.is_default(e)) { + theory_var w = get_th_var(n->get_arg(0)); + SASSERT(w != euf::null_theory_var); + set_default(w, n); + } + } + } + + void solver::set_default(theory_var v, euf::enode* n) { + TRACE("array", tout << "set default: " << v << " " << ctx.bpp(n) << "\n";); + v = mg_find(v); + if (!m_defaults[v]) + m_defaults[v] = n; + } + + euf::enode* solver::get_default(theory_var v) { + return m_defaults[mg_find(v)]; + } + + void solver::set_else(theory_var v, expr* e) { + m_else_values[mg_find(v)] = e; + } + + expr* solver::get_else(theory_var v) { + return m_else_values[mg_find(v)]; + } + + euf::theory_var solver::mg_find(theory_var n) { + if (m_parents[n] < 0) + return n; + theory_var n0 = n; + n = m_parents[n0]; + if (m_parents[n] < -1) + return n; + while (m_parents[n] >= 0) + n = m_parents[n]; + // compress path. + while (m_parents[n0] >= 0) { + theory_var n1 = m_parents[n0]; + m_parents[n0] = n; + n0 = n1; + } + return n; + } + + void solver::mg_merge(theory_var u, theory_var v) { + u = mg_find(u); + v = mg_find(v); + if (u != v) { + SASSERT(m_parents[u] < 0); + SASSERT(m_parents[v] < 0); + if (m_parents[u] > m_parents[v]) + std::swap(u, v); + m_parents[u] += m_parents[v]; + m_parents[v] = u; + + if (!m_defaults[u]) + m_defaults[u] = m_defaults[v]; + + CTRACE("array", m_defaults[v], + tout << ctx.bpp(m_defaults[v]->get_root()) << "\n"; + tout << ctx.bpp(m_defaults[u]->get_root()) << "\n"; + ); + + // NB. it may be the case that m_defaults[u] != m_defaults[v] + // when m and n are finite arrays. + + } + } + + + } diff --git a/src/sat/smt/array_solver.h b/src/sat/smt/array_solver.h index 26f2d902298..b64dbfff53a 100644 --- a/src/sat/smt/array_solver.h +++ b/src/sat/smt/array_solver.h @@ -219,7 +219,17 @@ namespace array { void pop_core(unsigned n) override; // models + euf::enode_vector m_defaults; // temporary field for model construction + ptr_vector m_else_values; // + svector m_parents; // temporary field for model construction bool have_different_model_values(theory_var v1, theory_var v2); + void collect_defaults(); + void mg_merge(theory_var u, theory_var v); + theory_var mg_find(theory_var n); + void set_default(theory_var v, euf::enode* n); + euf::enode* get_default(theory_var v); + void set_else(theory_var v, expr* e); + expr* get_else(theory_var v); // diagnostics std::ostream& display_info(std::ostream& out, char const* id, euf::enode_vector const& v) const; @@ -244,6 +254,7 @@ namespace array { bool use_diseqs() const override { return true; } void new_diseq_eh(euf::th_eq const& eq) override; bool unit_propagate() override; + void init_model() override; void add_value(euf::enode* n, model& mdl, expr_ref_vector& values) override; bool add_dep(euf::enode* n, top_sort& dep) override; sat::literal internalize(expr* e, bool sign, bool root, bool learned) override; diff --git a/src/sat/smt/sat_th.cpp b/src/sat/smt/sat_th.cpp index 2870c263f4c..1538f9a3ba5 100644 --- a/src/sat/smt/sat_th.cpp +++ b/src/sat/smt/sat_th.cpp @@ -101,6 +101,11 @@ namespace euf { theory_var th_euf_solver::get_th_var(expr* e) const { return get_th_var(ctx.get_enode(e)); } + + theory_var th_euf_solver::get_representative(theory_var v) const { + euf::enode* r = var2enode(v)->get_root(); + return get_th_var(r); + } void th_euf_solver::push_core() { m_var2enode_lim.push_back(m_var2enode.size()); diff --git a/src/sat/smt/sat_th.h b/src/sat/smt/sat_th.h index f8e26e3452d..bb9358e5932 100644 --- a/src/sat/smt/sat_th.h +++ b/src/sat/smt/sat_th.h @@ -182,6 +182,7 @@ namespace euf { sat::literal mk_literal(expr* e) const; theory_var get_th_var(enode* n) const { return n->get_th_var(get_id()); } theory_var get_th_var(expr* e) const; + theory_var get_representative(theory_var v) const; trail_stack& get_trail_stack(); bool is_attached_to_var(enode* n) const; bool is_root(theory_var v) const { return var2enode(v)->is_root(); }