diff --git a/src/ast/rewriter/seq_rewriter.cpp b/src/ast/rewriter/seq_rewriter.cpp index c314f2ed234..e627e2343d2 100644 --- a/src/ast/rewriter/seq_rewriter.cpp +++ b/src/ast/rewriter/seq_rewriter.cpp @@ -13,8 +13,7 @@ Module Name: Nikolaj Bjorner (nbjorner) 2015-12-5 Murphy Berzish 2017-02-21 - -Notes: + Caleb Stanford 2020-07-07 --*/ diff --git a/src/sat/ba/ba_solver.cpp b/src/sat/ba/ba_solver.cpp index 239ed723d49..646cb65341d 100644 --- a/src/sat/ba/ba_solver.cpp +++ b/src/sat/ba/ba_solver.cpp @@ -5052,6 +5052,91 @@ namespace sat { return ok; } + bool ba_solver::extract_pb(std::function& add_cardinality, + std::function& add_pb) { + + unsigned_vector coeffs; + literal_vector lits; + for (constraint* cp : m_constraints) { + switch (cp->tag()) { + case card_t: { + card const& c = cp->to_card(); + unsigned n = c.size(); + unsigned k = c.k(); + + if (c.lit() == null_literal) { + // c.lits() >= k + // <=> + // ~c.lits() <= n - k + lits.reset(); + for (unsigned j = 0; j < n; ++j) lits.push_back(c[j]); + add_cardinality(lits.size(), lits.c_ptr(), n - k); + } + else { + // + // c.lit() <=> c.lits() >= k + // + // (c.lits() < k) or c.lit() + // = (c.lits() + (n - k + 1)*~c.lit()) <= n + // + // ~c.lit() or (c.lits() >= k) + // = ~c.lit() or (~c.lits() <= n - k) + // = k*c.lit() + ~c.lits() <= n + // + lits.reset(); + coeffs.reset(); + for (literal l : c) lits.push_back(l), coeffs.push_back(1); + lits.push_back(~c.lit()); coeffs.push_back(n - k + 1); + add_pb(lits.size(), lits.c_ptr(), coeffs.c_ptr(), n); + + lits.reset(); + coeffs.reset(); + for (literal l : c) lits.push_back(~l), coeffs.push_back(1); + lits.push_back(c.lit()); coeffs.push_back(k); + add_pb(lits.size(), lits.c_ptr(), coeffs.c_ptr(), n); + } + break; + } + case ba_solver::pb_t: { + ba_solver::pb const& p = cp->to_pb(); + lits.reset(); + coeffs.reset(); + unsigned sum = 0; + for (ba_solver::wliteral wl : p) sum += wl.first; + + if (p.lit() == null_literal) { + // w1 + .. + w_n >= k + // <=> + // ~wl + ... + ~w_n <= sum_of_weights - k + for (ba_solver::wliteral wl : p) lits.push_back(~(wl.second)), coeffs.push_back(wl.first); + add_pb(lits.size(), lits.c_ptr(), coeffs.c_ptr(), sum - p.k()); + } + else { + // lit <=> w1 + .. + w_n >= k + // <=> + // lit or w1 + .. + w_n <= k - 1 + // ~lit or w1 + .. + w_n >= k + // <=> + // (sum - k + 1)*~lit + w1 + .. + w_n <= sum + // k*lit + ~wl + ... + ~w_n <= sum + lits.push_back(p.lit()), coeffs.push_back(p.k()); + for (ba_solver::wliteral wl : p) lits.push_back(~(wl.second)), coeffs.push_back(wl.first); + add_pb(lits.size(), lits.c_ptr(), coeffs.c_ptr(), sum); + + lits.reset(); + coeffs.reset(); + lits.push_back(~p.lit()), coeffs.push_back(sum + 1 - p.k()); + for (ba_solver::wliteral wl : p) lits.push_back(wl.second), coeffs.push_back(wl.first); + add_pb(lits.size(), lits.c_ptr(), coeffs.c_ptr(), sum); + } + break; + } + case ba_solver::xr_t: + return false; + } + } + return true; + } }; diff --git a/src/sat/ba/ba_solver.h b/src/sat/ba/ba_solver.h index 101f2dc5250..05373d59414 100644 --- a/src/sat/ba/ba_solver.h +++ b/src/sat/ba/ba_solver.h @@ -579,6 +579,9 @@ namespace sat { bool validate() override; + bool extract_pb(std::function& add_cardinlaity, + std::function& add_pb) override; + }; diff --git a/src/sat/euf/euf_solver.cpp b/src/sat/euf/euf_solver.cpp index 1da64970aa6..7a7e4e73f4e 100644 --- a/src/sat/euf/euf_solver.cpp +++ b/src/sat/euf/euf_solver.cpp @@ -418,5 +418,14 @@ namespace euf { NOT_IMPLEMENTED_YET(); return nullptr; } - + + bool solver::extract_pb(std::function& card, + std::function& pb) { + if (m_true) + return false; + for (auto* e : m_extensions) + if (!e->extract_pb(card, pb)) + return false; + return true; + } } diff --git a/src/sat/euf/euf_solver.h b/src/sat/euf/euf_solver.h index 0bc7bbf8633..90ccdf23ee3 100644 --- a/src/sat/euf/euf_solver.h +++ b/src/sat/euf/euf_solver.h @@ -131,6 +131,10 @@ namespace euf { bool check_model(sat::model const& m) const override; unsigned max_var(unsigned w) const override; + bool extract_pb(std::function& card, + std::function& pb) override; + + sat::literal internalize(sat::sat_internalizer& si, expr* e, bool sign, bool root) override; model_converter* get_model(); diff --git a/src/sat/sat_extension.h b/src/sat/sat_extension.h index f0f15d3074e..b96c577f8fa 100644 --- a/src/sat/sat_extension.h +++ b/src/sat/sat_extension.h @@ -18,6 +18,7 @@ Revision History: --*/ #pragma once +#include #include "sat/sat_types.h" #include "util/params.h" #include "util/statistics.h" @@ -83,6 +84,11 @@ namespace sat { virtual bool is_blocked(literal l, ext_constraint_idx) = 0; virtual bool check_model(model const& m) const = 0; virtual unsigned max_var(unsigned w) const = 0; + + virtual bool extract_pb(std::function& card, + std::function& pb) { + return true; + } }; }; diff --git a/src/sat/sat_local_search.cpp b/src/sat/sat_local_search.cpp index a951b728b8f..fd6f39d5d2e 100644 --- a/src/sat/sat_local_search.cpp +++ b/src/sat/sat_local_search.cpp @@ -322,6 +322,7 @@ namespace sat { add_unit(~c[0], null_literal); return; } + m_is_pb = true; unsigned id = m_constraints.size(); m_constraints.push_back(constraint(k, id)); for (unsigned i = 0; i < sz; ++i) { @@ -414,99 +415,16 @@ namespace sat { } m_num_non_binary_clauses = s.m_clauses.size(); - if (s.get_extension()) - throw default_exception("local search is incompatible with extensions"); -#if 0 // copy cardinality clauses - ba_solver* ext = dynamic_cast(s.get_extension()); - if (ext) { - unsigned_vector coeffs; - literal_vector lits; - for (ba_solver::constraint* cp : ext->m_constraints) { - switch (cp->tag()) { - case ba_solver::card_t: { - ba_solver::card const& c = cp->to_card(); - unsigned n = c.size(); - unsigned k = c.k(); - - if (c.lit() == null_literal) { - // c.lits() >= k - // <=> - // ~c.lits() <= n - k - lits.reset(); - for (unsigned j = 0; j < n; ++j) lits.push_back(c[j]); - add_cardinality(lits.size(), lits.c_ptr(), n - k); - } - else { - // - // c.lit() <=> c.lits() >= k - // - // (c.lits() < k) or c.lit() - // = (c.lits() + (n - k + 1)*~c.lit()) <= n - // - // ~c.lit() or (c.lits() >= k) - // = ~c.lit() or (~c.lits() <= n - k) - // = k*c.lit() + ~c.lits() <= n - // - m_is_pb = true; - lits.reset(); - coeffs.reset(); - for (literal l : c) lits.push_back(l), coeffs.push_back(1); - lits.push_back(~c.lit()); coeffs.push_back(n - k + 1); - add_pb(lits.size(), lits.c_ptr(), coeffs.c_ptr(), n); - - lits.reset(); - coeffs.reset(); - for (literal l : c) lits.push_back(~l), coeffs.push_back(1); - lits.push_back(c.lit()); coeffs.push_back(k); - add_pb(lits.size(), lits.c_ptr(), coeffs.c_ptr(), n); - } - break; - } - case ba_solver::pb_t: { - ba_solver::pb const& p = cp->to_pb(); - lits.reset(); - coeffs.reset(); - m_is_pb = true; - unsigned sum = 0; - for (ba_solver::wliteral wl : p) sum += wl.first; - - if (p.lit() == null_literal) { - // w1 + .. + w_n >= k - // <=> - // ~wl + ... + ~w_n <= sum_of_weights - k - for (ba_solver::wliteral wl : p) lits.push_back(~(wl.second)), coeffs.push_back(wl.first); - add_pb(lits.size(), lits.c_ptr(), coeffs.c_ptr(), sum - p.k()); - } - else { - // lit <=> w1 + .. + w_n >= k - // <=> - // lit or w1 + .. + w_n <= k - 1 - // ~lit or w1 + .. + w_n >= k - // <=> - // (sum - k + 1)*~lit + w1 + .. + w_n <= sum - // k*lit + ~wl + ... + ~w_n <= sum - lits.push_back(p.lit()), coeffs.push_back(p.k()); - for (ba_solver::wliteral wl : p) lits.push_back(~(wl.second)), coeffs.push_back(wl.first); - add_pb(lits.size(), lits.c_ptr(), coeffs.c_ptr(), sum); - - lits.reset(); - coeffs.reset(); - lits.push_back(~p.lit()), coeffs.push_back(sum + 1 - p.k()); - for (ba_solver::wliteral wl : p) lits.push_back(wl.second), coeffs.push_back(wl.first); - add_pb(lits.size(), lits.c_ptr(), coeffs.c_ptr(), sum); - } - break; - } - case ba_solver::xr_t: - throw default_exception("local search is incompatible with enabling xor solving"); - break; - } - } - } -#endif - + extension* ext = s.get_extension(); + std::function card = + [&](unsigned sz, literal const* c, unsigned k) { add_cardinality(sz, c, k); }; + std::function pb = + [&](unsigned sz, literal const* c, unsigned const* coeffs, unsigned k) { add_pb(sz, c, coeffs, k); }; + if (ext && !ext->extract_pb(card, pb)) + throw default_exception("local search is incomplete with extensions beyond PB"); + if (_init) { init(); }