From 976e4c91b0f151483e3cf047e294fb5e58ef7580 Mon Sep 17 00:00:00 2001 From: Caleb Stanford Date: Thu, 30 Jul 2020 16:54:49 -0400 Subject: [PATCH] Integrate new regex solver (#4602) * std::cout debugging statements * comment out std::cout debugging as this is now a shared fork * convert std::cout to TRACE statements for seq_rewriter and seq_regex * add cases to min_length and max_length for regexes * bug fix * update min_length and max_length functions for REs * initial pass on simplifying derivative normal forms by eliminating redundant predicates locally * add seq_regex_brief trace statements * working on debugging ref count issue * fix ref count bug and convert trace statements to seq_regex_brief * add compact tracing for cache hits/misses * seq_regex fix cache hit/miss tracing and wrapper around is_nullable * minor * label and disable more experimental changes for testing * minor documentation / tracing * a few more @EXP annotations * dead state elimination skeleton code * progress on dead state elimination * more progress on dead state elimination * refactor dead state class to separate self-contained state_graph class * finish factoring state_graph to only work with unsigned values, and implement separate functionality for expr* logic * implement get_all_derivatives, add debug tracing * trace statements for debugging is_nullable loop bug * fix is_nullable loop bug * comment out local nullable change and mark experimental * pretty printing for state_graph * rewrite state graph to remove the fragile assumption that all edges from a state are added at a time * start of general cycle detection check + fix some comments * implement full cycle detection procedure * normalize derivative conditions to form 'ele <= a' * order derivative conditions by character code * fix confusing names m_to and m_from * assign increasing state IDs from 1 instead of using get_id on AST node * remove elim_condition call in get_dall_derivatives * use u_map instead of uint_map to avoid memory leak * remove unnecessary call to is_ground * debugging * small improvements to seq_regex_brief tracing * fix bug on evil2 example * save work * new propagate code * work in progress on using same seq sort for deriv calls * avoid re-computing derivatives: use same head var for every derivative call * use min_length on regexes to prune search * simple implementation of can_be_in_cycle using rank function idea * add a disabled experimental change * minor cleanup comments, etc. * seq_rewriter cleanup for PR * typo noticed by Nikolaj * move state graph to util/state_graph * re-add accidentally removed line * clean up seq_regex code removing obsolete functions and comments * a few more cleanup items * remove experimental functionality for integration * fix compilation * remove some tracing and TODOs * remove old comment * update copyright dates to 2020 * feedback from Nikolaj * use [] for map access * make state_graph methods constant * avoid recursion in mark_dead_recursive and mark_live_recursive * a possible bug fix in propagate_nonempty * write down list of invariants in state_graph * implement partial invariant check and insert CASSERT statements * expand on invariant check and tracing * finish state graph invariant check * minor tweaks * regex propagation: convert first two axioms to propagations * remove obsolete regex solver functionality Co-authored-by: calebstanford-msr --- src/ast/rewriter/seq_rewriter.cpp | 2 +- src/ast/rewriter/seq_rewriter.h | 1 - src/ast/seq_decl_plugin.cpp | 40 ++- src/smt/seq_regex.cpp | 477 +++++++++++++++++------------- src/smt/seq_regex.h | 69 +++-- src/util/CMakeLists.txt | 2 + src/util/state_graph.cpp | 410 +++++++++++++++++++++++++ src/util/state_graph.h | 182 ++++++++++++ 8 files changed, 924 insertions(+), 259 deletions(-) create mode 100644 src/util/state_graph.cpp create mode 100644 src/util/state_graph.h diff --git a/src/ast/rewriter/seq_rewriter.cpp b/src/ast/rewriter/seq_rewriter.cpp index 97cc4d1a694..dd07362b61d 100644 --- a/src/ast/rewriter/seq_rewriter.cpp +++ b/src/ast/rewriter/seq_rewriter.cpp @@ -2617,7 +2617,7 @@ expr_ref seq_rewriter::mk_der_compl(expr* r) { Make an re_predicate with an arbitrary condition cond, enforcing derivative normal form on how conditions are written. - Tries to rewrites everything to (ele <= x) constraints: + Tries to rewrite everything to (ele <= x) constraints: (ele = a) => ite(ele <= a-1, none, ite(ele <= a, epsilon, none)) (a = ele) => " (a <= ele) => ite(ele <= a-1, none, epsilon) diff --git a/src/ast/rewriter/seq_rewriter.h b/src/ast/rewriter/seq_rewriter.h index 8e23ed0fe87..b9cc2707bda 100644 --- a/src/ast/rewriter/seq_rewriter.h +++ b/src/ast/rewriter/seq_rewriter.h @@ -342,6 +342,5 @@ class seq_rewriter { // heuristic elimination of element from condition that comes form a derivative. // special case optimization for conjunctions of equalities, disequalities and ranges. void elim_condition(expr* elem, expr_ref& cond); - }; diff --git a/src/ast/seq_decl_plugin.cpp b/src/ast/seq_decl_plugin.cpp index ce1516f41c0..3040ee8999a 100644 --- a/src/ast/seq_decl_plugin.cpp +++ b/src/ast/seq_decl_plugin.cpp @@ -1316,22 +1316,21 @@ unsigned seq_util::re::min_length(expr* r) const { unsigned lo = 0, hi = 0; if (is_empty(r)) return UINT_MAX; - if (is_concat(r, r1, r2)) + if (is_concat(r, r1, r2)) return u.max_plus(min_length(r1), min_length(r2)); - if (m.is_ite(r, s, r1, r2)) + if (is_union(r, r1, r2) || m.is_ite(r, s, r1, r2)) return std::min(min_length(r1), min_length(r2)); - if (is_diff(r, r1, r2)) - return min_length(r1); - if (is_union(r, r1, r2)) - return std::min(min_length(r1), min_length(r2)); - if (is_intersection(r, r1, r2)) + if (is_intersection(r, r1, r2)) return std::max(min_length(r1), min_length(r2)); - if (is_loop(r, r1, lo, hi)) + if (is_diff(r, r1, r2) || is_reverse(r, r1) || is_plus(r, r1)) + return min_length(r1); + if (is_loop(r, r1, lo) || is_loop(r, r1, lo, hi)) return u.max_mul(lo, min_length(r1)); - if (is_range(r)) - return 1; - if (is_to_re(r, s)) + if (is_to_re(r, s)) return u.str.min_length(s); + if (is_range(r) || is_of_pred(r) || is_full_char(r)) + return 1; + // Else: star, option, complement, full_seq, derivative return 0; } @@ -1341,22 +1340,21 @@ unsigned seq_util::re::max_length(expr* r) const { unsigned lo = 0, hi = 0; if (is_empty(r)) return 0; - if (is_concat(r, r1, r2)) + if (is_concat(r, r1, r2)) return u.max_plus(max_length(r1), max_length(r2)); - if (m.is_ite(r, s, r1, r2)) + if (is_union(r, r1, r2) || m.is_ite(r, s, r1, r2)) return std::max(max_length(r1), max_length(r2)); - if (is_diff(r, r1, r2)) - return max_length(r1); - if (is_union(r, r1, r2)) - return std::max(max_length(r1), max_length(r2)); - if (is_intersection(r, r1, r2)) + if (is_intersection(r, r1, r2)) return std::min(max_length(r1), max_length(r2)); + if (is_diff(r, r1, r2) || is_reverse(r, r1) || is_opt(r, r1)) + return max_length(r1); if (is_loop(r, r1, lo, hi)) return u.max_mul(hi, max_length(r1)); - if (is_range(r)) - return 1; - if (is_to_re(r, s)) + if (is_to_re(r, s)) return u.str.max_length(s); + if (is_range(r) || is_of_pred(r) || is_full_char(r)) + return 1; + // Else: star, plus, complement, full_seq, loop(r,r1,lo), derivative return UINT_MAX; } diff --git a/src/smt/seq_regex.cpp b/src/smt/seq_regex.cpp index 41a79ae9efb..ff41590b3a3 100644 --- a/src/smt/seq_regex.cpp +++ b/src/smt/seq_regex.cpp @@ -1,5 +1,5 @@ /*++ -Copyright (c) 2011 Microsoft Corporation +Copyright (c) 2020 Microsoft Corporation Module Name: @@ -24,7 +24,8 @@ namespace smt { seq_regex::seq_regex(theory_seq& th): th(th), ctx(th.get_context()), - m(th.get_manager()) + m(th.get_manager()), + m_state_to_expr(m) {} seq_util& seq_regex::u() { return th.m_util; } @@ -35,34 +36,6 @@ namespace smt { arith_util& seq_regex::a() { return th.m_autil; } void seq_regex::rewrite(expr_ref& e) { th.m_rewrite(e); } - bool seq_regex::can_propagate() const { - for (auto const& p : m_to_propagate) { - literal trigger = p.m_trigger; - if (trigger == null_literal || ctx.get_assignment(trigger) != l_undef) - return true; - } - return false; - } - - bool seq_regex::propagate() { - bool change = false; - for (unsigned i = 0; !ctx.inconsistent() && i < m_to_propagate.size(); ++i) { - propagation_lit const& pl = m_to_propagate[i]; - literal trigger = pl.m_trigger; - if (trigger != null_literal && ctx.get_assignment(trigger) == l_undef) - continue; - if (propagate(pl.m_lit, trigger)) { - m_to_propagate.erase_and_swap(i--); - change = true; - } - else if (trigger != pl.m_trigger) { - m_to_propagate.set(i, propagation_lit(pl.m_lit, trigger)); - } - - } - return change; - } - /** * is_string_equality holds of str.in_re s R, * @@ -103,14 +76,14 @@ namespace smt { } /** - * Propagate the atom (str.in.re s r) + * Propagate the atom (str.in_re s r) * * Propagation implements the following inference rules * - * (not (str.in.re s r)) => (str.in.re s (complement r)) - * (str.in.re s r) => r != {} + * (not (str.in_re s r)) => (str.in_re s (complement r)) + * (str.in_re s r) => r != {} * - * (str.in.re s r) => (accept s 0 r) + * (str.in_re s r) => (accept s 0 r) */ void seq_regex::propagate_in_re(literal lit) { @@ -118,7 +91,9 @@ namespace smt { expr* e = ctx.bool_var2expr(lit.var()); VERIFY(str().is_in_re(e, s, r)); - TRACE("seq", tout << "propagate " << lit.sign() << " " << mk_pp(e, m) << "\n";); + TRACE("seq_regex", tout << "propagate in RE: " << lit.sign() << " " << mk_pp(e, m) << std::endl;); + STRACE("seq_regex_brief", tout << "PIR(" << mk_pp(s, m) << "," + << state_str(r) << ") ";); // convert negative negative membership literals to positive // ~(s in R) => s in C(R) @@ -140,21 +115,6 @@ namespace smt { if (is_string_equality(lit)) return; - // - // TBD s in R => R != {} - // non-emptiness enforcement could instead of here, - // be added to propagate_accept after some threshold is met. - // - if (false) { - expr_ref is_empty(m.mk_eq(r, re().mk_empty(m.get_sort(s))), m); - rewrite(is_empty); - literal is_emptyl = th.mk_literal(is_empty); - if (ctx.get_assignment(is_emptyl) != l_false) { - th.propagate_lit(nullptr, 1, &lit, ~is_emptyl); - return; - } - } - expr_ref zero(a().mk_int(0), m); expr_ref acc = sk().mk_accept(s, zero, r); literal acc_lit = th.mk_literal(acc); @@ -164,27 +124,30 @@ namespace smt { th.propagate_lit(nullptr, 1, &lit, acc_lit); } - void seq_regex::propagate_accept(literal lit) { - // std::cout << "PA "; - literal t = null_literal; - if (!propagate(lit, t)) - m_to_propagate.push_back(propagation_lit(lit, t)); - } - /** * Propagate the atom (accept s i r) - * - * Propagation implements the following inference rules * - * (accept s i r[if(c,r1,r2)]) & c => (accept s i r[r1]) - * (accept s i r[if(c,r1,r2)]) & ~c => (accept s i r[r2]) - * (accept s i r) & nullable(r) => len(s) >= i - * (accept s i r) & ~nullable(r) => len(s) >= i + 1 - * (accept s i r) & len(s) <= i => nullable(r) - * (accept s i r) & len(s) > i => (accept s (+ i 1) D(nth(s,i), r)) + * Propagation triggers updating the state graph for dead state detection: + * (accept s i r) => update_state_graph(r) + * (accept s i r) & dead(r) => false + * + * Propagation is also blocked under certain conditions to throttle + * state space exploration past a certain point: see block_unfolding + * + * Otherwise, propagation implements the following inference rules: + * + * Rule 1. (accept s i r) => len(s) >= i + min_len(r) + * Rule 2. (accept s i r) & len(s) <= i => nullable(r) + * Rule 3. (accept s i r) and len(s) > i => + * (accept s (i + 1) (derivative s[i] r) + * + * Acceptance of a derivative is unfolded into a disjunction over + * all derivatives. Effectively, this implements the following rule, + * but all in one step: + * (accept s i (ite c r1 r2)) => + * c & (accept s i r1) \/ ~c & (accept s i r2) */ - - bool seq_regex::propagate(literal lit, literal& trigger) { + void seq_regex::propagate_accept(literal lit) { SASSERT(!lit.sign()); expr* s = nullptr, *i = nullptr, *r = nullptr; @@ -192,146 +155,78 @@ namespace smt { unsigned idx = 0; VERIFY(sk().is_accept(e, s, i, idx, r)); - // std::cout << "\nP " << idx << " " << r->get_id() << " "; - - TRACE("seq", tout << "propagate " << mk_pp(e, m) << "\n";); + TRACE("seq_regex", tout << "propagate accept: " + << mk_pp(e, m) << std::endl;); + STRACE("seq_regex_brief", tout << std::endl + << "PA(" << mk_pp(s, m) << "@" << idx + << "," << state_str(r) << ") ";); if (re().is_empty(r)) { + STRACE("seq_regex_brief", tout << "(empty) ";); th.add_axiom(~lit); - return true; + return; } - if (block_unfolding(lit, idx)) - return true; - - propagate_nullable(lit, s, idx, r); - - return propagate_derivative(lit, e, s, i, idx, r, trigger); - } + update_state_graph(r); - /** - Implement the two axioms as propagations: - - (accept s i r) => len(s) >= i - (accept s i r) & ~nullable(r) => len(s) >= i + 1 - - evaluate nullable(r): - nullable(r) := true -> propagate: (accept s i r) => len(s) >= i - nullable(r) := false -> propagate: (accept s i r) => len(s) >= i + 1 - - Otherwise: - propagate: (accept s i r) => len(s) >= i - evaluate len(s) <= i: - len(s) <= i := undef -> axiom: (accept s i r) & len(s) <= i => nullable(r) - len(s) <= i := true -> propagate: (accept s i r) & len(s) <= i => nullable(r) - len(s) <= i := false -> noop. - - */ + if (m_state_graph.is_dead(get_state_id(r))) { + STRACE("seq_regex_brief", tout << "(dead) ";); + th.add_axiom(~lit); + return; + } - void seq_regex::propagate_nullable(literal lit, expr* s, unsigned idx, expr* r) { - // std::cout << "PN "; - expr_ref is_nullable = seq_rw().is_nullable(r); - rewrite(is_nullable); - literal len_s_ge_i = th.m_ax.mk_ge(th.mk_len(s), idx); - if (m.is_true(is_nullable)) { - th.propagate_lit(nullptr, 1,&lit, len_s_ge_i); + if (block_unfolding(lit, idx)) { + STRACE("seq_regex_brief", tout << "(blocked) ";); + return; } - else if (m.is_false(is_nullable)) { - th.propagate_lit(nullptr, 1, &lit, th.m_ax.mk_ge(th.mk_len(s), idx + 1)); - //unsigned len = std::max(1u, re().min_length(r)); - //th.propagate_lit(nullptr, 1, &lit, th.m_ax.mk_ge(th.mk_len(s), idx + re().min_length(r))); + + STRACE("seq_regex_brief", tout << "(unfold) ";); + + // Rule 1: use min_length to prune search + expr_ref s_to_re(re().mk_to_re(s), m); + expr_ref s_plus_r(re().mk_concat(s_to_re, r), m); + unsigned min_len = re().min_length(s_plus_r); + literal len_s_ge_min = th.m_ax.mk_ge(th.mk_len(s), min_len); + th.propagate_lit(nullptr, 1, &lit, len_s_ge_min); + // Axiom equivalent to the above: th.add_axiom(~lit, len_s_ge_min); + + // Rule 2: nullable check + literal len_s_le_i = th.m_ax.mk_le(th.mk_len(s), idx); + expr_ref is_nullable = is_nullable_wrapper(r); + if (m.is_false(is_nullable)) { + th.propagate_lit(nullptr, 1, &lit, ~len_s_le_i); } - else { - literal is_nullable_lit = th.mk_literal(is_nullable); + else if (!m.is_true(is_nullable)) { + // is_nullable did not simplify + literal is_nullable_lit = th.mk_literal(is_nullable_wrapper(r)); ctx.mark_as_relevant(is_nullable_lit); - literal len_s_le_i = th.m_ax.mk_le(th.mk_len(s), idx); - switch (ctx.get_assignment(len_s_le_i)) { - case l_undef: - th.add_axiom(~lit, ~len_s_le_i, is_nullable_lit); - break; - case l_true: { - literal lits[2] = { lit, len_s_le_i }; - th.propagate_lit(nullptr, 2, lits, is_nullable_lit); - break; - } - case l_false: - break; - } - th.propagate_lit(nullptr, 1, &lit, len_s_ge_i); + th.add_axiom(~lit, ~len_s_le_i, is_nullable_lit); } - } - - bool seq_regex::propagate_derivative(literal lit, expr* e, expr* s, expr* i, unsigned idx, expr* r, literal& trigger) { - // (accept s i R) & len(s) > i => (accept s (+ i 1) D(nth(s, i), R)) or conds - // std::cout << "PD "; - expr_ref d(m); - expr_ref head = th.mk_nth(s, i); - - d = derivative_wrapper(m.mk_var(0, m.get_sort(head)), r); - // timer tm; - // std::cout << d->get_id() << " " << tm.get_seconds() << "\n"; - //if (tm.get_seconds() > 0.3) - // std::cout << d << "\n"; - // std::cout.flush(); - literal_vector conds; - conds.push_back(~lit); - conds.push_back(th.m_ax.mk_le(th.mk_len(s), idx)); - expr* cond = nullptr, *tt = nullptr, *el = nullptr; - var_subst subst(m); - expr_ref_vector sub(m); - sub.push_back(head); - // s in R[if(p,R1,R2)] & p => s in R[R1] - // s in R[if(p,R1,R2)] & ~p => s in R[R2] - while (m.is_ite(d, cond, tt, el)) { - literal lcond = th.mk_literal(subst(cond, sub)); - switch (ctx.get_assignment(lcond)) { - case l_true: - conds.push_back(~lcond); - d = tt; - break; - case l_false: - conds.push_back(lcond); - d = el; - break; - case l_undef: -#if 1 - ctx.mark_as_relevant(lcond); - trigger = lcond; - return false; -#else - if (re().is_empty(tt)) { - literal_vector ensure_false(conds); - ensure_false.push_back(~lcond); - th.add_axiom(ensure_false); - conds.push_back(lcond); - d = el; - } - else if (re().is_empty(el)) { - literal_vector ensure_true(conds); - ensure_true.push_back(lcond); - th.add_axiom(ensure_true); - conds.push_back(~lcond); - d = tt; - } - else { - ctx.mark_as_relevant(lcond); - trigger = lcond; - return false; - } - break; -#endif - } - } - if (!is_ground(d)) { - d = subst(d, sub); + + // Rule 3: derivative unfolding + literal_vector accept_next; + expr_ref hd = th.mk_nth(s, i); + expr_ref deriv(m); + deriv = derivative_wrapper(hd, r); + accept_next.push_back(~lit); + accept_next.push_back(len_s_le_i); + expr_ref_pair_vector cofactors(m); + get_cofactors(deriv, cofactors); + for (auto const& p : cofactors) { + if (m.is_false(p.first) || re().is_empty(p.second)) continue; + expr_ref cond(p.first, m); + expr_ref deriv_leaf(p.second, m); + + expr_ref acc = sk().mk_accept(s, a().mk_int(idx + 1), deriv_leaf); + expr_ref choice(m.mk_and(cond, acc), m); + literal choice_lit = th.mk_literal(choice); + accept_next.push_back(choice_lit); + // TBD: try prioritizing unvisited states here over visited + // ones (in the state graph), to improve performance + STRACE("seq_regex_verbose", tout << "added choice: " + << mk_pp(choice, m) << std::endl;); } - // at this point there should be no free variables as the ites are at top-level. - if (!re().is_empty(d)) - conds.push_back(th.mk_literal(sk().mk_accept(s, a().mk_int(idx + 1), d))); - th.add_axiom(conds); - TRACE("seq", tout << "unfold " << head << "\n" << mk_pp(r, m) << "\n";); - // std::cout << "D "; - return true; + th.add_axiom(accept_next); } /** @@ -352,7 +247,7 @@ namespace smt { * within the same Regex. */ bool seq_regex::coallesce_in_re(literal lit) { - return false; + return false; // disabled expr* s = nullptr, *r = nullptr; expr* e = ctx.bool_var2expr(lit.var()); VERIFY(str().is_in_re(e, s, r)); @@ -372,7 +267,7 @@ namespace smt { th.m_trail_stack.push(vector_value_trail(m_s_in_re, i)); m_s_in_re[i].m_active = false; IF_VERBOSE(11, verbose_stream() << "Intersect " << regex << " " << - mk_pp(entry.m_re, m) << " " << mk_pp(s, m) << " " << mk_pp(entry.m_s, m) << "\n";); + mk_pp(entry.m_re, m) << " " << mk_pp(s, m) << " " << mk_pp(entry.m_s, m) << std::endl;); regex = re().mk_inter(entry.m_re, regex); rewrite(regex); lits.push_back(~entry.m_lit); @@ -402,17 +297,71 @@ namespace smt { } /* - Wrapper around the regex symbolic derivative from the rewriter. + Wrapper around calls to is_nullable from the seq rewriter. + + Note: the nullable wrapper and derivative wrapper actually use + different sequence rewriters; these are at: + m_seq_rewrite + (returned by seq_rw()) + th.m_rewrite.m_imp->m_cfg.m_seq_rw + (private, can't be accessed directly) + As a result operations are cached separately for the nullable + and derivative calls. TBD if caching them using the same rewriter + makes any difference. + */ + expr_ref seq_regex::is_nullable_wrapper(expr* r) { + STRACE("seq_regex", tout << "nullable: " << mk_pp(r, m) << std::endl;); + + expr_ref result = seq_rw().is_nullable(r); + rewrite(result); + + STRACE("seq_regex", tout << "nullable result: " << mk_pp(result, m) << std::endl;); + STRACE("seq_regex_brief", tout << "n(" << state_str(r) << ")=" + << mk_pp(result, m) << " ";); + + return result; + } + + /* + Wrapper around the regex symbolic derivative from the seq rewriter. Ensures that the derivative is written in a normalized BDD form with optimizations for if-then-else expressions involving the head. + + Note: the nullable wrapper and derivative wrapper actually use + different sequence rewriters; these are at: + m_seq_rewrite + (returned by seq_rw()) + th.m_rewrite.m_imp->m_cfg.m_seq_rw + (private, can't be accessed directly) + As a result operations are cached separately for the nullable + and derivative calls. TBD if caching them using the same rewriter + makes any difference. */ expr_ref seq_regex::derivative_wrapper(expr* hd, expr* r) { - expr_ref result = expr_ref(re().mk_derivative(hd, r), m); + STRACE("seq_regex", tout << "derivative(" << mk_pp(hd, m) << "): " << mk_pp(r, m) << std::endl;); + + // Use canonical variable for head + expr_ref hd_canon(m.mk_var(0, m.get_sort(hd)), m); + expr_ref result(re().mk_derivative(hd_canon, r), m); rewrite(result); + + // Substitute with real head + var_subst subst(m); + expr_ref_vector sub(m); + sub.push_back(hd); + result = subst(result, sub); + + STRACE("seq_regex", tout << "derivative result: " << mk_pp(result, m) << std::endl;); + STRACE("seq_regex_brief", tout << "d(" << state_str(r) << ")=" + << state_str(result) << " ";); + return result; } void seq_regex::propagate_eq(expr* r1, expr* r2) { + TRACE("seq_regex", tout << "propagate EQ: " << mk_pp(r1, m) << ", " << mk_pp(r2, m) << std::endl;); + STRACE("seq_regex_brief", tout << "PEQ ";); + sort* seq_sort = nullptr; VERIFY(u().is_re(r1, seq_sort)); expr_ref r = symmetric_diff(r1, r2); @@ -423,6 +372,9 @@ namespace smt { } void seq_regex::propagate_ne(expr* r1, expr* r2) { + TRACE("seq_regex", tout << "propagate NEQ: " << mk_pp(r1, m) << ", " << mk_pp(r2, m) << std::endl;); + STRACE("seq_regex_brief", tout << "PNEQ ";); + sort* seq_sort = nullptr; VERIFY(u().is_re(r1, seq_sort)); expr_ref r = symmetric_diff(r1, r2); @@ -452,18 +404,25 @@ namespace smt { void seq_regex::propagate_is_non_empty(literal lit) { expr* e = ctx.bool_var2expr(lit.var()), *r = nullptr, *u = nullptr, *n = nullptr; VERIFY(sk().is_is_non_empty(e, r, u, n)); - expr_ref is_nullable = seq_rw().is_nullable(r); - rewrite(is_nullable); + + TRACE("seq_regex", tout << "propagate nonempty: " << mk_pp(e, m) << std::endl;); + STRACE("seq_regex_brief", tout + << std::endl << "PNE(" << expr_id_str(e) << "," << state_str(r) + << "," << expr_id_str(u) << "," << expr_id_str(n) << ") ";); + + expr_ref is_nullable = is_nullable_wrapper(r); if (m.is_true(is_nullable)) return; literal null_lit = th.mk_literal(is_nullable); expr_ref hd = mk_first(r, n); expr_ref d(m); d = derivative_wrapper(hd, r); + literal_vector lits; lits.push_back(~lit); if (null_lit != false_literal) lits.push_back(null_lit); + expr_ref_pair_vector cofactors(m); get_cofactors(d, cofactors); for (auto const& p : cofactors) { @@ -474,11 +433,12 @@ namespace smt { rewrite(cond); if (m.is_false(cond)) continue; - expr_ref next_non_empty = sk().mk_is_non_empty(p.second, re().mk_union(u, p.second), n); + expr_ref next_non_empty = sk().mk_is_non_empty(p.second, re().mk_union(u, r), n); if (!m.is_true(cond)) next_non_empty = m.mk_and(cond, next_non_empty); lits.push_back(th.mk_literal(next_non_empty)); } + th.add_axiom(lits); } @@ -498,6 +458,25 @@ namespace smt { } } + void seq_regex::get_all_derivatives(expr* r, expr_ref_vector& results) { + // Get derivative + sort* seq_sort = nullptr; + VERIFY(u().is_re(r, seq_sort)); + expr_ref n(m.mk_fresh_const("re.char", seq_sort), m); + expr_ref hd = mk_first(r, n); + expr_ref d(m); + d = derivative_wrapper(hd, r); + // Use get_cofactors method and try to filter out unsatisfiable conds + expr_ref_pair_vector cofactors(m); + get_cofactors(d, cofactors); + STRACE("seq_regex_verbose", tout << "getting all derivatives of: " << mk_pp(r, m) << std::endl;); + for (auto const& p : cofactors) { + if (m.is_false(p.first) || re().is_empty(p.second)) continue; + STRACE("seq_regex_verbose", tout << "adding derivative: " << mk_pp(p.second, m) << std::endl;); + results.push_back(p.second); + } + } + /* is_empty(r, u) => ~is_nullable(r) is_empty(r, u) => (forall x . ~cond(x)) or is_empty(r1, u union r) for (cond, r) in min-terms(D(x,r)) @@ -507,8 +486,13 @@ namespace smt { void seq_regex::propagate_is_empty(literal lit) { expr* e = ctx.bool_var2expr(lit.var()), *r = nullptr, *u = nullptr, *n = nullptr; VERIFY(sk().is_is_empty(e, r, u, n)); - expr_ref is_nullable = seq_rw().is_nullable(r); - rewrite(is_nullable); + expr_ref is_nullable = is_nullable_wrapper(r); + + TRACE("seq_regex", tout << "propagate empty: " << mk_pp(e, m) << std::endl;); + STRACE("seq_regex_brief", tout + << std::endl << "PE(" << expr_id_str(e) << "," << state_str(r) + << "," << expr_id_str(u) << "," << expr_id_str(n) << ") ";); + if (m.is_true(is_nullable)) { th.add_axiom(~lit); return; @@ -546,4 +530,89 @@ namespace smt { VERIFY(u().is_seq(seq_sort, elem_sort)); return sk().mk("re.first", n, a().mk_int(r->get_id()), elem_sort); } + + /** + * Dead state elimination using the state_graph class + */ + + unsigned seq_regex::get_state_id(expr* e) { + // Assign increasing IDs starting from 1 + if (!m_expr_to_state.contains(e)) { + m_state_to_expr.push_back(e); + unsigned new_id = m_state_to_expr.size(); + m_expr_to_state.insert(e, new_id); + STRACE("seq_regex_brief", tout << "new(" << expr_id_str(e) + << ")=" << state_str(e) << " ";); + } + return m_expr_to_state.find(e); + } + expr* seq_regex::get_expr_from_id(unsigned id) { + SASSERT(id >= 1); + SASSERT(id <= m_state_to_expr.size()); + return m_state_to_expr.get(id); + } + + bool seq_regex::can_be_in_cycle(expr *r1, expr *r2) { + // TBD: This can be used to optimize the state graph: + // return false here if it is known that r1 -> r2 can never be + // in a cycle. There are various easy syntactic checks on r1 and r2 + // that can be used to infer this (e.g. star height, or length if + // both are star-free). + // This check need not be sound, but if it is not, some dead states + // will be missed. + return true; + } + + /* + Update the state graph with expression r and all its derivatives. + */ + bool seq_regex::update_state_graph(expr* r) { + unsigned r_id = get_state_id(r); + if (m_state_graph.is_done(r_id)) return false; + if (m_state_graph.get_size() >= m_max_state_graph_size) { + STRACE("seq_regex", tout << "Warning: ignored state graph update -- max size of seen states reached!" << std::endl;); + STRACE("seq_regex_brief", tout << "(MAX SIZE REACHED) ";); + return false; + } + STRACE("seq_regex", tout << "Updating state graph for regex " + << mk_pp(r, m) << ") ";); + // Add state + m_state_graph.add_state(r_id); + STRACE("seq_regex_brief", tout << std::endl << "USG(" + << state_str(r) << ") ";); + expr_ref r_nullable = is_nullable_wrapper(r); + if (m.is_true(r_nullable)) { + m_state_graph.mark_live(r_id); + } + else { + // Add edges to all derivatives + expr_ref_vector derivatives(m); + STRACE("seq_regex_verbose", tout + << std::endl << " getting all derivs: " << r_id << " ";); + get_all_derivatives(r, derivatives); + for (auto const& dr: derivatives) { + unsigned dr_id = get_state_id(dr); + STRACE("seq_regex_verbose", tout + << std::endl << " traversing deriv: " << dr_id << " ";); + m_state_graph.add_state(dr_id); + bool maybecycle = can_be_in_cycle(r, dr); + m_state_graph.add_edge(r_id, dr_id, maybecycle); + } + m_state_graph.mark_done(r_id); + } + STRACE("seq_regex_brief", tout << std::endl;); + STRACE("seq_regex_brief", m_state_graph.display(tout);); + return true; + } + + std::string seq_regex::state_str(expr* e) { + if (m_expr_to_state.contains(e)) + return std::to_string(get_state_id(e)); + else + return expr_id_str(e); + } + std::string seq_regex::expr_id_str(expr* e) { + return std::string("id") + std::to_string(e->get_id()); + } + } diff --git a/src/smt/seq_regex.h b/src/smt/seq_regex.h index 1d77cf81dbe..7d4cfb65112 100644 --- a/src/smt/seq_regex.h +++ b/src/smt/seq_regex.h @@ -1,5 +1,5 @@ /*++ -Copyright (c) 2011 Microsoft Corporation +Copyright (c) 2020 Microsoft Corporation Module Name: @@ -17,6 +17,7 @@ Module Name: #pragma once #include "util/scoped_vector.h" +#include "util/state_graph.h" #include "ast/seq_decl_plugin.h" #include "ast/rewriter/seq_rewriter.h" #include "smt/smt_context.h" @@ -27,6 +28,7 @@ namespace smt { class theory_seq; class seq_regex { + // Data about a constraint of the form (str.in_re s R) struct s_in_re { literal m_lit; expr* m_s; @@ -36,20 +38,34 @@ namespace smt { m_lit(l), m_s(s), m_re(r), m_active(true) {} }; - struct propagation_lit { - literal m_lit; - literal m_trigger; - propagation_lit(literal lit, literal t): m_lit(lit), m_trigger(t) {} - propagation_lit(literal lit): m_lit(lit), m_trigger(null_literal) {} - propagation_lit(): m_lit(null_literal), m_trigger(null_literal) {} - }; - - theory_seq& th; - context& ctx; - ast_manager& m; - vector m_s_in_re; - scoped_vector m_to_propagate; - + theory_seq& th; + context& ctx; + ast_manager& m; + vector m_s_in_re; + + /* + state_graph for dead state detection, and associated methods + */ + state_graph m_state_graph; + ptr_addr_map m_expr_to_state; + expr_ref_vector m_state_to_expr; + unsigned m_max_state_graph_size { 10000 }; + // Convert between expressions and states (IDs) + unsigned get_state_id(expr* e); + expr* get_expr_from_id(unsigned id); + // Cycle-detection heuristic + // Note: Doesn't need to be sound or complete (doesn't affect soundness) + bool can_be_in_cycle(expr* r1, expr* r2); + // Update the graph + bool update_state_graph(expr* r); + + // Printing expressions for seq_regex_brief + std::string state_str(expr* e); + std::string expr_id_str(expr* e); + + /* + Solvers and utilities + */ seq_util& u(); class seq_util::re& re(); class seq_util::str& str(); @@ -63,42 +79,32 @@ namespace smt { bool coallesce_in_re(literal lit); - bool propagate(literal lit, literal& trigger); - bool block_unfolding(literal lit, unsigned i); - void propagate_nullable(literal lit, expr* s, unsigned idx, expr* r); - - bool propagate_derivative(literal lit, expr* e, expr* s, expr* i, unsigned idx, expr* r, literal& trigger); - expr_ref mk_first(expr* r, expr* n); - expr_ref unroll_non_empty(expr* r, expr_mark& seen, unsigned depth); - bool is_member(expr* r, expr* u); expr_ref symmetric_diff(expr* r1, expr* r2); + expr_ref is_nullable_wrapper(expr* r); expr_ref derivative_wrapper(expr* hd, expr* r); void get_cofactors(expr* r, expr_ref_vector& conds, expr_ref_pair_vector& result); - void get_cofactors(expr* r, expr_ref_pair_vector& result) { expr_ref_vector conds(m); get_cofactors(r, conds, result); } + void get_all_derivatives(expr* r, expr_ref_vector& results); public: seq_regex(theory_seq& th); - void push_scope() { m_to_propagate.push_scope(); } - - void pop_scope(unsigned num_scopes) { m_to_propagate.pop_scope(num_scopes); } - - bool can_propagate() const; - - bool propagate(); + void push_scope() {} + void pop_scope(unsigned num_scopes) {} + bool can_propagate() const { return false; } + bool propagate() const { return false; } void propagate_in_re(literal lit); @@ -117,4 +123,3 @@ namespace smt { }; }; - diff --git a/src/util/CMakeLists.txt b/src/util/CMakeLists.txt index 13bc4715cdc..21d45453ce7 100644 --- a/src/util/CMakeLists.txt +++ b/src/util/CMakeLists.txt @@ -51,6 +51,7 @@ z3_add_component(util small_object_allocator.cpp smt2_util.cpp stack.cpp + state_graph.cpp statistics.cpp symbol.cpp timeit.cpp @@ -67,6 +68,7 @@ z3_add_component(util prime_generator.h rational.h rlimit.h + state_graph.h symbol.h trace.h ) diff --git a/src/util/state_graph.cpp b/src/util/state_graph.cpp new file mode 100644 index 00000000000..baaf345d2be --- /dev/null +++ b/src/util/state_graph.cpp @@ -0,0 +1,410 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + state_graph.cpp + +Abstract: + + Data structure for incrementally tracking "live" and "dead" states in an + abstract transition system. + +Author: + + Caleb Stanford (calebstanford-msr / cdstanford) 2020-7 + +--*/ + +#include "state_graph.h" + +void state_graph::add_state_core(state s) { + STRACE("state_graph", tout << "add(" << s << ") ";); + SASSERT(!m_seen.contains(s)); + // Ensure corresponding var in union find structure + while (s >= m_state_ufind.get_num_vars()) { + m_state_ufind.mk_var(); + } + // Initialize as unvisited + m_seen.insert(s); + m_unexplored.insert(s); + m_targets.insert(s, state_set()); + m_sources.insert(s, state_set()); + m_sources_maybecycle.insert(s, state_set()); +} +void state_graph::remove_state_core(state s) { + // This is a partial deletion -- the state is still seen and can't be + // added again later. + // The state should be unknown, and all edges to or from the state + // should already have been renamed. + STRACE("state_graph", tout << "del(" << s << ") ";); + SASSERT(m_seen.contains(s)); + SASSERT(!m_state_ufind.is_root(s)); + SASSERT(m_unknown.contains(s)); + m_targets.remove(s); + m_sources.remove(s); + m_sources_maybecycle.remove(s); + m_unknown.remove(s); +} + +void state_graph::mark_unknown_core(state s) { + STRACE("state_graph", tout << "unk(" << s << ") ";); + SASSERT(m_state_ufind.is_root(s)); + SASSERT(m_unexplored.contains(s)); + m_unexplored.remove(s); + m_unknown.insert(s); +} +void state_graph::mark_live_core(state s) { + STRACE("state_graph", tout << "live(" << s << ") ";); + SASSERT(m_state_ufind.is_root(s)); + SASSERT(m_unknown.contains(s)); + m_unknown.remove(s); + m_live.insert(s); +} +void state_graph::mark_dead_core(state s) { + STRACE("state_graph", tout << "dead(" << s << ") ";); + SASSERT(m_state_ufind.is_root(s)); + SASSERT(m_unknown.contains(s)); + m_unknown.remove(s); + m_dead.insert(s); +} + +/* + Add edge to the graph. + - If the annotation 'maybecycle' is false, then the user is sure + that this edge will never be part of a cycle. + - May already exist, in which case maybecycle = false overrides + maybecycle = true. +*/ +void state_graph::add_edge_core(state s1, state s2, bool maybecycle) { + STRACE("state_graph", tout << "add(" << s1 << "," << s2 << "," + << (maybecycle ? "y" : "n") << ") ";); + SASSERT(m_state_ufind.is_root(s1)); + SASSERT(m_state_ufind.is_root(s2)); + if (s1 == s2) return; + if (!m_targets[s1].contains(s2)) { + // add new edge + m_targets[s1].insert(s2); + m_sources[s2].insert(s1); + if (maybecycle) m_sources_maybecycle[s2].insert(s1); + } + else if (!maybecycle && m_sources_maybecycle[s2].contains(s1)) { + // update existing edge + m_sources_maybecycle[s2].remove(s1); + } +} +void state_graph::remove_edge_core(state s1, state s2) { + STRACE("state_graph", tout << "del(" << s1 << "," << s2 << ") ";); + SASSERT(m_targets[s1].contains(s2)); + SASSERT(m_sources[s2].contains(s1)); + m_targets[s1].remove(s2); + m_sources[s2].remove(s1); + m_sources_maybecycle[s2].remove(s1); +} +void state_graph::rename_edge_core(state old1, state old2, + state new1, state new2) { + SASSERT(m_targets[old1].contains(old2)); + SASSERT(m_sources[old2].contains(old1)); + bool maybecycle = m_sources_maybecycle[old2].contains(old1); + remove_edge_core(old1, old2); + add_edge_core(new1, new2, maybecycle); +} + +/* + Merge two states or more generally a set of states into one, + returning the new state. Also merges associated edges. + + Preconditions: + - The set should be nonempty + - Every state in the set should be unknown + - Each state should currently exist + - If passing a set of states by reference, it should not be a set + from the edge relations, as merging states modifies edge relations. +*/ +auto state_graph::merge_states(state s1, state s2) -> state { + SASSERT(m_state_ufind.is_root(s1)); + SASSERT(m_state_ufind.is_root(s2)); + SASSERT(m_unknown.contains(s1)); + SASSERT(m_unknown.contains(s2)); + STRACE("state_graph", tout << "merge(" << s1 << "," << s2 << ") ";); + m_state_ufind.merge(s1, s2); + if (m_state_ufind.is_root(s2)) std::swap(s1, s2); + // rename s2 to s1 in edges + for (auto s_to: m_targets[s2]) { + rename_edge_core(s2, s_to, s1, s_to); + } + for (auto s_from: m_sources[s2]) { + rename_edge_core(s_from, s2, s_from, s1); + } + remove_state_core(s2); + return s1; +} +auto state_graph::merge_states(state_set& s_set) -> state { + SASSERT(s_set.num_elems() > 0); + state prev_s = 0; // initialization here optional + bool first_iter = true; + for (auto s: s_set) { + if (first_iter) { + prev_s = s; + first_iter = false; + continue; + } + prev_s = merge_states(prev_s, s); + } + return prev_s; +} + +/* + If s is not live, mark it, and recurse on all states into s + Precondition: s is live or unknown +*/ +void state_graph::mark_live_recursive(state s) { + SASSERT(m_live.contains(s) || m_unknown.contains(s)); + vector to_search; + to_search.push_back(s); + while (to_search.size() > 0) { + state x = to_search.back(); + to_search.pop_back(); + SASSERT(m_live.contains(x) || m_unknown.contains(x)); + if (m_live.contains(x)) continue; + mark_live_core(x); + for (auto x_from: m_sources[x]) { + to_search.push_back(x_from); + } + } +} + +/* + Check if all targets of a state are dead. + Precondition: s is unknown +*/ +bool state_graph::all_targets_dead(state s) { + SASSERT(m_unknown.contains(s)); + for (auto s_to: m_targets[s]) { + // unknown pointing to live should have been marked as live! + SASSERT(!m_live.contains(s_to)); + if (m_unknown.contains(s_to) || m_unexplored.contains(s_to)) + return false; + } + return true; +} +/* + Check if s is now known to be dead. If so, mark and recurse + on all states into s. + Precondition: s is live, dead, or unknown +*/ +void state_graph::mark_dead_recursive(state s) { + SASSERT(m_live.contains(s) || m_dead.contains(s) || m_unknown.contains(s)); + vector to_search; + to_search.push_back(s); + while (to_search.size() > 0) { + state x = to_search.back(); + to_search.pop_back(); + if (!m_unknown.contains(x)) continue; + if (!all_targets_dead(x)) continue; + // x is unknown and all targets from x are dead + mark_dead_core(x); + for (auto x_from: m_sources[x]) { + to_search.push_back(x_from); + } + } +} + +/* + Merge all cycles of unknown states containing s into one state. + Return the new state + Precondition: s is unknown. +*/ +auto state_graph::merge_all_cycles(state s) -> state { + SASSERT(m_unknown.contains(s)); + // Visit states in a DFS backwards from s + state_set visited; // all backwards edges pushed + state_set resolved; // known in SCC or not + state_set scc; // known in SCC + resolved.insert(s); + scc.insert(s); + vector to_search; + to_search.push_back(s); + while (to_search.size() > 0) { + state x = to_search.back(); + if (!visited.contains(x)) { + visited.insert(x); + // recurse backwards only on maybecycle edges + // and only on unknown states + for (auto y: m_sources_maybecycle[x]) { + if (m_unknown.contains(y)) + to_search.push_back(y); + } + } + else if (!resolved.contains(x)) { + resolved.insert(x); + to_search.pop_back(); + // determine in SCC or not + for (auto y: m_sources_maybecycle[x]) { + if (scc.contains(y)) { + scc.insert(x); + break; + } + } + } + else { + to_search.pop_back(); + } + } + // scc is the union of all cycles containing s + return merge_states(scc); +} + +/* + Exposed methods +*/ + +void state_graph::add_state(state s) { + if (m_seen.contains(s)) return; + STRACE("state_graph", tout << "[state_graph] adding state " << s << ": ";); + add_state_core(s); + CASSERT("state_graph", check_invariant()); + STRACE("state_graph", tout << std::endl;); +} +void state_graph::mark_live(state s) { + STRACE("state_graph", tout << "[state_graph] marking live " << s << ": ";); + SASSERT(m_unexplored.contains(s) || m_live.contains(s)); + SASSERT(m_state_ufind.is_root(s)); + if (m_unexplored.contains(s)) mark_unknown_core(s); + mark_live_recursive(s); + CASSERT("state_graph", check_invariant()); + STRACE("state_graph", tout << std::endl;); +} +void state_graph::add_edge(state s1, state s2, bool maybecycle) { + STRACE("state_graph", tout << "[state_graph] adding edge " + << s1 << "->" << s2 << ": ";); + SASSERT(m_unexplored.contains(s1) || m_live.contains(s1)); + SASSERT(m_state_ufind.is_root(s1)); + SASSERT(m_seen.contains(s2)); + s2 = m_state_ufind.find(s2); + add_edge_core(s1, s2, maybecycle); + if (m_live.contains(s2)) mark_live(s1); + CASSERT("state_graph", check_invariant()); + STRACE("state_graph", tout << std::endl;); +} +void state_graph::mark_done(state s) { + SASSERT(m_unexplored.contains(s) || m_live.contains(s)); + SASSERT(m_state_ufind.is_root(s)); + if (m_live.contains(s)) return; + STRACE("state_graph", tout << "[state_graph] marking done " << s << ": ";); + if (m_unexplored.contains(s)) mark_unknown_core(s); + s = merge_all_cycles(s); + mark_dead_recursive(s); // check if dead + CASSERT("state_graph", check_invariant()); + STRACE("state_graph", tout << std::endl;); +} + +unsigned state_graph::get_size() const { + return m_state_ufind.get_num_vars(); +} + +bool state_graph::is_seen(state s) const { + return m_seen.contains(s); +} +bool state_graph::is_live(state s) const { + return m_live.contains(m_state_ufind.find(s)); +} +bool state_graph::is_dead(state s) const { + return m_dead.contains(m_state_ufind.find(s)); +} +bool state_graph::is_done(state s) const { + return m_seen.contains(s) && !m_unexplored.contains(m_state_ufind.find(s)); +} + +/* + Class invariants check (and associated auxiliary functions) + + check_invariant performs a sequence of SASSERT assertions, + then always returns true. +*/ +#ifdef Z3DEBUG +bool state_graph::is_subset(state_set set1, state_set set2) const { + for (auto s1: set1) { + if (!set2.contains(s1)) return false; + } + return true; +} +bool state_graph::is_disjoint(state_set set1, state_set set2) const { + for (auto s1: set1) { + if (set2.contains(s1)) return false; + } + return true; +} +#define ASSERT_FOR_ALL_STATES(STATESET, COND) { \ + for (auto s: STATESET) { SASSERT(COND); }} ((void) 0) +#define ASSERT_FOR_ALL_EDGES(EDGEREL, COND) { \ + for (auto e: (EDGEREL)) { \ + state s1 = e.m_key; for (auto s2: e.m_value) { SASSERT(COND); } \ + }} ((void) 0) +bool state_graph::check_invariant() const { + // Check state invariants + SASSERT(is_subset(m_live, m_seen)); + SASSERT(is_subset(m_dead, m_seen)); + SASSERT(is_subset(m_unknown, m_seen)); + SASSERT(is_subset(m_unexplored, m_seen)); + SASSERT(is_disjoint(m_live, m_dead)); + SASSERT(is_disjoint(m_live, m_unknown)); + SASSERT(is_disjoint(m_live, m_unexplored)); + SASSERT(is_disjoint(m_dead, m_unknown)); + SASSERT(is_disjoint(m_dead, m_unexplored)); + SASSERT(is_disjoint(m_unknown, m_unexplored)); + ASSERT_FOR_ALL_STATES(m_seen, s < m_state_ufind.get_num_vars()); + ASSERT_FOR_ALL_STATES(m_seen, + (m_state_ufind.is_root(s) == + (m_live.contains(s) || m_dead.contains(s) || + m_unknown.contains(s) || m_unexplored.contains(s)))); + // Check edge invariants + ASSERT_FOR_ALL_EDGES(m_sources_maybecycle, m_sources[s1].contains(s2)); + ASSERT_FOR_ALL_EDGES(m_sources, m_targets[s2].contains(s1)); + ASSERT_FOR_ALL_EDGES(m_targets, m_sources[s2].contains(s1)); + ASSERT_FOR_ALL_EDGES(m_targets, + m_state_ufind.is_root(s1) && m_state_ufind.is_root(s2)); + ASSERT_FOR_ALL_EDGES(m_targets, s1 != s2); + // Check relationship between states and edges + ASSERT_FOR_ALL_EDGES(m_targets, + !m_live.contains(s2) || m_live.contains(s1)); + ASSERT_FOR_ALL_STATES(m_dead, is_subset(m_targets[s], m_dead)); + ASSERT_FOR_ALL_STATES(m_unknown, !is_subset(m_targets[s], m_dead)); + // For the "no cycles" of unknown states on maybecycle edges, + // we only do a partial check for cycles of size 2 + ASSERT_FOR_ALL_EDGES(m_sources_maybecycle, + !(m_unknown.contains(s1) && m_unknown.contains(s2) && + m_sources_maybecycle[s2].contains(s1))); + + STRACE("state_graph", tout << "(invariant passed) ";); + return true; +} +#endif + +/* + Pretty printing +*/ +std::ostream& state_graph::display(std::ostream& o) const { + o << "---------- State Graph ----------" << std::endl + << "Seen:"; + for (auto s: m_seen) { + o << " " << s; + state s_root = m_state_ufind.find(s); + if (s_root != s) + o << "(=" << s_root << ")"; + } + o << std::endl + << "Live:" << m_live << std::endl + << "Dead:" << m_dead << std::endl + << "Unknown:" << m_unknown << std::endl + << "Unexplored:" << m_unexplored << std::endl + << "Edges:" << std::endl; + for (auto s1: m_seen) { + if (m_state_ufind.is_root(s1)) { + o << " " << s1 << " -> " << m_targets[s1] << std::endl; + } + } + o << "---------------------------------" << std::endl; + + return o; +} diff --git a/src/util/state_graph.h b/src/util/state_graph.h new file mode 100644 index 00000000000..190e3261a35 --- /dev/null +++ b/src/util/state_graph.h @@ -0,0 +1,182 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + state_graph.h + +Abstract: + + Data structure for incrementally tracking "live" and "dead" states in an + abstract transition system. + +Author: + + Caleb Stanford (calebstanford-msr / cdstanford) 2020-7 + +--*/ + +#pragma once + +#include "util/map.h" +#include "util/uint_set.h" +#include "util/union_find.h" +#include "util/vector.h" + +/* + state_graph + + Data structure which is capable of incrementally tracking + live states and dead states. + + "States" are integers. States and edges are added to the data + structure incrementally. + - States can be marked as "live" or "done". + "Done" signals that (1) no more outgoing edges will be + added and (2) the state will not be marked as live. The data + structure then tracks + which other states are live (can reach a live state), dead + (can't reach a live state), or neither. + - Some edges are labeled as not contained in a cycle. This is to + optimize search if it is known by the user of the structure + that no cycle will ever contain this edge. + + Internally, we use union_find to identify states within an SCC, + and incrementally update SCCs, while propagating backwards + live and dead SCCs. +*/ +class state_graph { +public: + typedef unsigned state; + typedef uint_set state_set; + typedef u_map edge_rel; + typedef basic_union_find state_ufind; + +private: + /* + All states are internally exactly one of: + - live: known to reach a live state + - dead: known to never reach a live state + - unknown: all outgoing edges have been added, but the + state is not known to be live or dead + - unexplored: not all outgoing edges have been added + + As SCCs are merged, some states become aliases, and a + union find data structure collapses a now obsolete + state to its current representative. m_seen keeps track + of states we have seen, including obsolete states. + */ + state_set m_live; + state_set m_dead; + state_set m_unknown; + state_set m_unexplored; + + state_set m_seen; + state_ufind m_state_ufind; + + /* + Edges are saved in both from and to maps. + A subset of edges are also marked as possibly being + part of a cycle by being stored in m_sources_maybecycle. + */ + edge_rel m_sources; + edge_rel m_targets; + edge_rel m_sources_maybecycle; + + /* + CLASS INVARIANTS + + *** To enable checking invariants, run z3 with -dbg:state_graph + (must also be in debug mode) *** + + State invariants: + - live, dead, unknown, and unexplored form a partition of + the set of roots in m_state_ufind + - all of these are subsets of m_seen + - everything in m_seen is an integer less than the number of variables + in m_state_ufind + + Edge invariants: + - all edges are between roots of m_state_ufind + - m_sources and m_targets are converses of each other + - no self-loops + - m_sources_maybecycle is a subrelation of m_sources + + Relationship between states and edges: + - every state with a live target is live + - every state with a dead source is dead + - every state with only dead targets is dead + - there are no cycles of unknown states on maybecycle edges + */ + #ifdef Z3DEBUG + bool is_subset(state_set set1, state_set set2) const; + bool is_disjoint(state_set set1, state_set set2) const; + bool check_invariant() const; + #endif + + /* + 'Core' functions that modify the plain graph, without + updating SCCs or propagating live/dead state information. + These are for internal use only. + */ + void add_state_core(state s); // unexplored + seen + void remove_state_core(state s); // unknown + seen -> seen + void mark_unknown_core(state s); // unexplored -> unknown + void mark_live_core(state s); // unknown -> live + void mark_dead_core(state s); // unknown -> dead + + void add_edge_core(state s1, state s2, bool maybecycle); + void remove_edge_core(state s1, state s2); + void rename_edge_core(state old1, state old2, state new1, state new2); + + state merge_states(state s1, state s2); + state merge_states(state_set& s_set); + + /* + Algorithmic search routines + - live state propagation + - dead state propagation + - cycle / strongly-connected component detection + */ + void mark_live_recursive(state s); + bool all_targets_dead(state s); + void mark_dead_recursive(state s); + state merge_all_cycles(state s); + +public: + state_graph(): + m_live(), m_dead(), m_unknown(), m_unexplored(), m_seen(), + m_state_ufind(), m_sources(), m_targets(), m_sources_maybecycle() + { + CASSERT("state_graph", check_invariant()); + } + + /* + Exposed methods + + These methods may be called in any order, as long as: + - states are added before edges are added between them + - outgoing edges are not added from a done state + - a done state is not marked as live + - edges are not added creating a cycle containing an edge with + maybecycle = false (this is not necessary for soundness, but + prevents completeness for successfully detecting dead states) + */ + void add_state(state s); + void add_edge(state s1, state s2, bool maybecycle); + void mark_live(state s); + void mark_done(state s); + + bool is_seen(state s) const; + bool is_live(state s) const; + bool is_dead(state s) const; + bool is_done(state s) const; + + unsigned get_size() const; + + /* + Pretty printing + */ + std::ostream& display(std::ostream& o) const; + +};