diff --git a/smt/expr.cpp b/smt/expr.cpp index 81161f656..dd0eb3e09 100644 --- a/smt/expr.cpp +++ b/smt/expr.cpp @@ -1464,7 +1464,7 @@ expr expr::cmp_eq(const expr &rhs, bool simplify) const { return false; return rhs == *this; } - // constants on rhs from now. + // constants on rhs from now on. if (rhs.isTrue()) return *this; @@ -1649,6 +1649,14 @@ void expr::operator|=(const expr &rhs) { *this = *this || rhs; } +expr expr::mk_and(const vector &vals) { + expr ret(true); + for (auto &e : vals) { + ret &= e; + } + return ret; +} + expr expr::mk_and(const set &vals) { expr ret(true); for (auto &e : vals) { @@ -2235,6 +2243,12 @@ expr expr::subst(const vector> &repls) const { return Z3_substitute(ctx(), ast(), repls.size(), from.get(), to.get()); } +expr expr::subst_simplify(const vector> &repls) const { + if (repls.empty()) + return *this; + return subst(repls).simplify(); +} + expr expr::subst(const expr &from, const expr &to) const { C(from, to); auto f = from(); @@ -2248,6 +2262,20 @@ expr expr::subst_var(const expr &repl) const { return Z3_substitute_vars(ctx(), ast(), 1, &r); } +expr expr::propagate(const AndExpr &constraints) const { + C(); + auto from = make_unique(constraints.exprs.size()); + auto to = make_unique(constraints.exprs.size()); + expr true_expr(true); + unsigned i = 0; + for (auto &e : constraints.exprs) { + C2(e); + from[i] = e(); + to[i++] = true_expr(); + } + return Z3_substitute(ctx(), ast(), i, from.get(), to.get()); +} + set expr::vars() const { return vars({ this }); } diff --git a/smt/expr.h b/smt/expr.h index 99c157334..f9ac0d1b6 100644 --- a/smt/expr.h +++ b/smt/expr.h @@ -20,6 +20,8 @@ typedef struct _Z3_sort* Z3_sort; namespace smt { +class AndExpr; + class expr { uintptr_t ptr = 0; @@ -288,6 +290,7 @@ class expr { void operator&=(const expr &rhs); void operator|=(const expr &rhs); + static expr mk_and(const std::vector &vals); static expr mk_and(const std::set &vals); static expr mk_or(const std::set &vals); @@ -363,11 +366,15 @@ class expr { // replace v1 -> v2 expr subst(const std::vector> &repls) const; + expr subst_simplify(const std::vector> &repls) const; expr subst(const expr &from, const expr &to) const; // replace the 1st quantified variable expr subst_var(const expr &repl) const; + // turn all expressions in 'constraints' into true + expr propagate(const AndExpr &constraints) const; + std::set vars() const; static std::set vars(const std::vector &exprs); diff --git a/smt/exprs.cpp b/smt/exprs.cpp index cb6c628b4..5cee1439c 100644 --- a/smt/exprs.cpp +++ b/smt/exprs.cpp @@ -4,6 +4,7 @@ #include "smt/exprs.h" #include "smt/smt.h" #include "util/compiler.h" +#include #include using namespace std; @@ -45,6 +46,12 @@ void AndExpr::del(const AndExpr &other) { exprs.erase(e); } +expr AndExpr::propagate(const AndExpr &other) const { + vector ret; + ranges::set_difference(exprs, other.exprs, back_inserter(ret), less{}); + return expr::mk_and(ret).propagate(other); +} + void AndExpr::reset() { exprs.clear(); } diff --git a/smt/exprs.h b/smt/exprs.h index f258ad6da..3250ea073 100644 --- a/smt/exprs.h +++ b/smt/exprs.h @@ -29,6 +29,7 @@ class AndExpr { void add(expr &&e, unsigned limit = 16); void add(const AndExpr &other); void del(const AndExpr &other); + expr propagate(const AndExpr &other) const; void reset(); bool contains(const expr &e) const; expr operator()() const; @@ -37,6 +38,7 @@ class AndExpr { auto operator<=>(const AndExpr&) const = default; friend std::ostream &operator<<(std::ostream &os, const AndExpr &e); template friend class DisjointExpr; + friend class expr; };