From 5857236f2f61b2c6dfb113c2618e89b104182138 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Mon, 29 Nov 2021 19:41:30 -0800 Subject: [PATCH] introducing base namespace for user propagator --- src/api/api_solver.cpp | 18 ++++---- src/sat/sat_solver/inc_sat_solver.cpp | 18 ++++---- src/sat/smt/euf_solver.cpp | 11 +++-- src/sat/smt/euf_solver.h | 15 ++++--- src/sat/smt/user_solver.h | 35 +++++++-------- src/smt/smt_context.cpp | 14 +++--- src/smt/smt_context.h | 16 +++---- src/smt/smt_kernel.cpp | 28 ++++++------ src/smt/smt_kernel.h | 14 +++--- src/smt/smt_solver.cpp | 14 +++--- src/smt/user_propagator.cpp | 32 +++++++------- src/smt/user_propagator.h | 36 ++++++++-------- src/solver/solver.h | 49 +-------------------- src/tactic/user_propagator_base.h | 61 +++++++++++++++++++++++++++ 14 files changed, 189 insertions(+), 172 deletions(-) create mode 100644 src/tactic/user_propagator_base.h diff --git a/src/api/api_solver.cpp b/src/api/api_solver.cpp index 2de51c3b069..6f38d024687 100644 --- a/src/api/api_solver.cpp +++ b/src/api/api_solver.cpp @@ -866,7 +866,7 @@ extern "C" { Z3_CATCH_RETURN(nullptr); } - class api_context_obj : public solver::context_obj { + class api_context_obj : public user_propagator::context_obj { api::context* c; public: api_context_obj(api::context* c):c(c) {} @@ -883,9 +883,9 @@ extern "C" { Z3_TRY; RESET_ERROR_CODE(); init_solver(c, s); - solver::push_eh_t _push = push_eh; - solver::pop_eh_t _pop = pop_eh; - solver::fresh_eh_t _fresh = [=](void * user_ctx, ast_manager& m, solver::context_obj*& _ctx) { + user_propagator::push_eh_t _push = push_eh; + user_propagator::pop_eh_t _pop = pop_eh; + user_propagator::fresh_eh_t _fresh = [=](void * user_ctx, ast_manager& m, user_propagator::context_obj*& _ctx) { ast_context_params params; params.set_foreign_manager(&m); auto* ctx = alloc(api::context, ¶ms, false); @@ -902,7 +902,7 @@ extern "C" { Z3_fixed_eh fixed_eh) { Z3_TRY; RESET_ERROR_CODE(); - solver::fixed_eh_t _fixed = (void(*)(void*,solver::propagate_callback*,unsigned,expr*))fixed_eh; + user_propagator::fixed_eh_t _fixed = (void(*)(void*,user_propagator::callback*,unsigned,expr*))fixed_eh; to_solver_ref(s)->user_propagate_register_fixed(_fixed); Z3_CATCH; } @@ -913,7 +913,7 @@ extern "C" { Z3_final_eh final_eh) { Z3_TRY; RESET_ERROR_CODE(); - solver::final_eh_t _final = (bool(*)(void*,solver::propagate_callback*))final_eh; + user_propagator::final_eh_t _final = (bool(*)(void*,user_propagator::callback*))final_eh; to_solver_ref(s)->user_propagate_register_final(_final); Z3_CATCH; } @@ -924,7 +924,7 @@ extern "C" { Z3_eq_eh eq_eh) { Z3_TRY; RESET_ERROR_CODE(); - solver::eq_eh_t _eq = (void(*)(void*,solver::propagate_callback*,unsigned,unsigned))eq_eh; + user_propagator::eq_eh_t _eq = (void(*)(void*,user_propagator::callback*,unsigned,unsigned))eq_eh; to_solver_ref(s)->user_propagate_register_eq(_eq); Z3_CATCH; } @@ -935,7 +935,7 @@ extern "C" { Z3_eq_eh diseq_eh) { Z3_TRY; RESET_ERROR_CODE(); - solver::eq_eh_t _diseq = (void(*)(void*,solver::propagate_callback*,unsigned,unsigned))diseq_eh; + user_propagator::eq_eh_t _diseq = (void(*)(void*,user_propagator::callback*,unsigned,unsigned))diseq_eh; to_solver_ref(s)->user_propagate_register_diseq(_diseq); Z3_CATCH; } @@ -952,7 +952,7 @@ extern "C" { Z3_TRY; LOG_Z3_solver_propagate_consequence(c, s, num_fixed, fixed_ids, num_eqs, eq_lhs, eq_rhs, conseq); RESET_ERROR_CODE(); - reinterpret_cast(s)->propagate_cb(num_fixed, fixed_ids, num_eqs, eq_lhs, eq_rhs, to_expr(conseq)); + reinterpret_cast(s)->propagate_cb(num_fixed, fixed_ids, num_eqs, eq_lhs, eq_rhs, to_expr(conseq)); Z3_CATCH; } diff --git a/src/sat/sat_solver/inc_sat_solver.cpp b/src/sat/sat_solver/inc_sat_solver.cpp index e2ce4363774..ea81f0c4d77 100644 --- a/src/sat/sat_solver/inc_sat_solver.cpp +++ b/src/sat/sat_solver/inc_sat_solver.cpp @@ -661,25 +661,25 @@ class inc_sat_solver : public solver { void user_propagate_init( void* ctx, - solver::push_eh_t& push_eh, - solver::pop_eh_t& pop_eh, - solver::fresh_eh_t& fresh_eh) override { + user_propagator::push_eh_t& push_eh, + user_propagator::pop_eh_t& pop_eh, + user_propagator::fresh_eh_t& fresh_eh) override { ensure_euf()->user_propagate_init(ctx, push_eh, pop_eh, fresh_eh); } - void user_propagate_register_fixed(solver::fixed_eh_t& fixed_eh) override { + void user_propagate_register_fixed(user_propagator::fixed_eh_t& fixed_eh) override { ensure_euf()->user_propagate_register_fixed(fixed_eh); } - void user_propagate_register_final(solver::final_eh_t& final_eh) override { + void user_propagate_register_final(user_propagator::final_eh_t& final_eh) override { ensure_euf()->user_propagate_register_final(final_eh); } - void user_propagate_register_eq(solver::eq_eh_t& eq_eh) override { + void user_propagate_register_eq(user_propagator::eq_eh_t& eq_eh) override { ensure_euf()->user_propagate_register_eq(eq_eh); } - void user_propagate_register_diseq(solver::eq_eh_t& diseq_eh) override { + void user_propagate_register_diseq(user_propagator::eq_eh_t& diseq_eh) override { ensure_euf()->user_propagate_register_diseq(diseq_eh); } @@ -959,11 +959,11 @@ class inc_sat_solver : public solver { extract_asm2dep(asm2dep); sat::literal_vector const& core = m_solver.get_core(); TRACE("sat", - for (auto kv : m_dep2asm) { + for (auto const& kv : m_dep2asm) { tout << mk_pp(kv.m_key, m) << " |-> " << sat::literal(kv.m_value) << "\n"; } tout << "asm2fml: "; - for (auto kv : asm2fml) { + for (auto const& kv : asm2fml) { tout << mk_pp(kv.m_key, m) << " |-> " << mk_pp(kv.m_value, m) << "\n"; } tout << "core: "; for (auto c : core) tout << c << " "; tout << "\n"; diff --git a/src/sat/smt/euf_solver.cpp b/src/sat/smt/euf_solver.cpp index 2af8485c1d7..bbdd6ce0ebf 100644 --- a/src/sat/smt/euf_solver.cpp +++ b/src/sat/smt/euf_solver.cpp @@ -295,7 +295,7 @@ namespace euf { return; bool sign = l.sign(); m_egraph.set_value(n, sign ? l_false : l_true); - for (auto th : enode_th_vars(n)) + for (auto const& th : enode_th_vars(n)) m_id2solver[th.get_id()]->asserted(l); size_t* c = to_ptr(l); @@ -519,8 +519,7 @@ namespace euf { void solver::push() { si.push(); - scope s; - s.m_var_lim = m_var_trail.size(); + scope s(m_var_trail.size()); m_scopes.push_back(s); m_trail.push_scope(); for (auto* e : m_solvers) @@ -994,9 +993,9 @@ namespace euf { void solver::user_propagate_init( void* ctx, - ::solver::push_eh_t& push_eh, - ::solver::pop_eh_t& pop_eh, - ::solver::fresh_eh_t& fresh_eh) { + user_propagator::push_eh_t& push_eh, + user_propagator::pop_eh_t& pop_eh, + user_propagator::fresh_eh_t& fresh_eh) { m_user_propagator = alloc(user_solver::solver, *this); m_user_propagator->add(ctx, push_eh, pop_eh, fresh_eh); for (unsigned i = m_scopes.size(); i-- > 0; ) diff --git a/src/sat/smt/euf_solver.h b/src/sat/smt/euf_solver.h index 0df07faff89..d70206d119b 100644 --- a/src/sat/smt/euf_solver.h +++ b/src/sat/smt/euf_solver.h @@ -72,6 +72,7 @@ namespace euf { }; struct scope { unsigned m_var_lim; + scope(unsigned l) : m_var_lim(l) {} }; @@ -400,27 +401,27 @@ namespace euf { // user propagator void user_propagate_init( void* ctx, - ::solver::push_eh_t& push_eh, - ::solver::pop_eh_t& pop_eh, - ::solver::fresh_eh_t& fresh_eh); + user_propagator::push_eh_t& push_eh, + user_propagator::pop_eh_t& pop_eh, + user_propagator::fresh_eh_t& fresh_eh); bool watches_fixed(enode* n) const; void assign_fixed(enode* n, expr* val, unsigned sz, literal const* explain); void assign_fixed(enode* n, expr* val, literal_vector const& explain) { assign_fixed(n, val, explain.size(), explain.data()); } void assign_fixed(enode* n, expr* val, literal explain) { assign_fixed(n, val, 1, &explain); } - void user_propagate_register_final(::solver::final_eh_t& final_eh) { + void user_propagate_register_final(user_propagator::final_eh_t& final_eh) { check_for_user_propagator(); m_user_propagator->register_final(final_eh); } - void user_propagate_register_fixed(::solver::fixed_eh_t& fixed_eh) { + void user_propagate_register_fixed(user_propagator::fixed_eh_t& fixed_eh) { check_for_user_propagator(); m_user_propagator->register_fixed(fixed_eh); } - void user_propagate_register_eq(::solver::eq_eh_t& eq_eh) { + void user_propagate_register_eq(user_propagator::eq_eh_t& eq_eh) { check_for_user_propagator(); m_user_propagator->register_eq(eq_eh); } - void user_propagate_register_diseq(::solver::eq_eh_t& diseq_eh) { + void user_propagate_register_diseq(user_propagator::eq_eh_t& diseq_eh) { check_for_user_propagator(); m_user_propagator->register_diseq(diseq_eh); } diff --git a/src/sat/smt/user_solver.h b/src/sat/smt/user_solver.h index 2939a2f0a1d..275e33bbc58 100644 --- a/src/sat/smt/user_solver.h +++ b/src/sat/smt/user_solver.h @@ -21,11 +21,12 @@ Module Name: #include "sat/smt/sat_th.h" #include "solver/solver.h" +#include "tactic/user_propagator_base.h" namespace user_solver { - class solver : public euf::th_euf_solver, public ::solver::propagate_callback { + class solver : public euf::th_euf_solver, public user_propagator::callback { struct prop_info { unsigned_vector m_ids; @@ -47,15 +48,15 @@ namespace user_solver { }; void* m_user_context; - ::solver::push_eh_t m_push_eh; - ::solver::pop_eh_t m_pop_eh; - ::solver::fresh_eh_t m_fresh_eh; - ::solver::final_eh_t m_final_eh; - ::solver::fixed_eh_t m_fixed_eh; - ::solver::eq_eh_t m_eq_eh; - ::solver::eq_eh_t m_diseq_eh; - ::solver::context_obj* m_api_context { nullptr }; - unsigned m_qhead { 0 }; + user_propagator::push_eh_t m_push_eh; + user_propagator::pop_eh_t m_pop_eh; + user_propagator::fresh_eh_t m_fresh_eh; + user_propagator::final_eh_t m_final_eh; + user_propagator::fixed_eh_t m_fixed_eh; + user_propagator::eq_eh_t m_eq_eh; + user_propagator::eq_eh_t m_diseq_eh; + user_propagator::context_obj* m_api_context = nullptr; + unsigned m_qhead = 0; vector m_prop; unsigned_vector m_prop_lim; vector m_id2justification; @@ -91,9 +92,9 @@ namespace user_solver { */ void add( void* ctx, - ::solver::push_eh_t& push_eh, - ::solver::pop_eh_t& pop_eh, - ::solver::fresh_eh_t& fresh_eh) { + user_propagator::push_eh_t& push_eh, + user_propagator::pop_eh_t& pop_eh, + user_propagator::fresh_eh_t& fresh_eh) { m_user_context = ctx; m_push_eh = push_eh; m_pop_eh = pop_eh; @@ -102,10 +103,10 @@ namespace user_solver { unsigned add_expr(expr* e); - void register_final(::solver::final_eh_t& final_eh) { m_final_eh = final_eh; } - void register_fixed(::solver::fixed_eh_t& fixed_eh) { m_fixed_eh = fixed_eh; } - void register_eq(::solver::eq_eh_t& eq_eh) { m_eq_eh = eq_eh; } - void register_diseq(::solver::eq_eh_t& diseq_eh) { m_diseq_eh = diseq_eh; } + void register_final(user_propagator::final_eh_t& final_eh) { m_final_eh = final_eh; } + void register_fixed(user_propagator::fixed_eh_t& fixed_eh) { m_fixed_eh = fixed_eh; } + void register_eq(user_propagator::eq_eh_t& eq_eh) { m_eq_eh = eq_eh; } + void register_diseq(user_propagator::eq_eh_t& diseq_eh) { m_diseq_eh = diseq_eh; } bool has_fixed() const { return (bool)m_fixed_eh; } diff --git a/src/smt/smt_context.cpp b/src/smt/smt_context.cpp index ee3341e4366..48eb23b4f6b 100644 --- a/src/smt/smt_context.cpp +++ b/src/smt/smt_context.cpp @@ -198,7 +198,7 @@ namespace smt { return; ast_translation tr(src_ctx.m, m, false); auto* p = get_theory(m.mk_family_id("user_propagator")); - m_user_propagator = reinterpret_cast(p); + m_user_propagator = reinterpret_cast(p); SASSERT(m_user_propagator); for (unsigned i = 0; i < src_ctx.m_user_propagator->get_num_vars(); ++i) { app* e = src_ctx.m_user_propagator->get_expr(i); @@ -2886,11 +2886,11 @@ namespace smt { void context::user_propagate_init( void* ctx, - solver::push_eh_t& push_eh, - solver::pop_eh_t& pop_eh, - solver::fresh_eh_t& fresh_eh) { + user_propagator::push_eh_t& push_eh, + user_propagator::pop_eh_t& pop_eh, + user_propagator::fresh_eh_t& fresh_eh) { setup_context(false); - m_user_propagator = alloc(user_propagator, *this); + m_user_propagator = alloc(theory_user_propagator, *this); m_user_propagator->add(ctx, push_eh, pop_eh, fresh_eh); for (unsigned i = m_scopes.size(); i-- > 0; ) m_user_propagator->push_scope_eh(); @@ -3552,7 +3552,7 @@ namespace smt { parallel p(*this); return p(asms); } - lbool r; + lbool r = l_undef; do { pop_to_base_lvl(); expr_ref_vector asms(m, num_assumptions, assumptions); @@ -3573,7 +3573,7 @@ namespace smt { if (!check_preamble(true)) return l_undef; TRACE("before_search", display(tout);); setup_context(false); - lbool r; + lbool r = l_undef; do { pop_to_base_lvl(); expr_ref_vector asms(cube); diff --git a/src/smt/smt_context.h b/src/smt/smt_context.h index b5a19569ad5..20ee4762969 100644 --- a/src/smt/smt_context.h +++ b/src/smt/smt_context.h @@ -88,7 +88,7 @@ namespace smt { scoped_ptr m_qmanager; scoped_ptr m_model_generator; scoped_ptr m_relevancy_propagator; - user_propagator* m_user_propagator; + theory_user_propagator* m_user_propagator; random_gen m_random; bool m_flushing; // (debug support) true when flushing mutable unsigned m_lemma_id; @@ -1695,29 +1695,29 @@ namespace smt { */ void user_propagate_init( void* ctx, - solver::push_eh_t& push_eh, - solver::pop_eh_t& pop_eh, - solver::fresh_eh_t& fresh_eh); + user_propagator::push_eh_t& push_eh, + user_propagator::pop_eh_t& pop_eh, + user_propagator::fresh_eh_t& fresh_eh); - void user_propagate_register_final(solver::final_eh_t& final_eh) { + void user_propagate_register_final(user_propagator::final_eh_t& final_eh) { if (!m_user_propagator) throw default_exception("user propagator must be initialized"); m_user_propagator->register_final(final_eh); } - void user_propagate_register_fixed(solver::fixed_eh_t& fixed_eh) { + void user_propagate_register_fixed(user_propagator::fixed_eh_t& fixed_eh) { if (!m_user_propagator) throw default_exception("user propagator must be initialized"); m_user_propagator->register_fixed(fixed_eh); } - void user_propagate_register_eq(solver::eq_eh_t& eq_eh) { + void user_propagate_register_eq(user_propagator::eq_eh_t& eq_eh) { if (!m_user_propagator) throw default_exception("user propagator must be initialized"); m_user_propagator->register_eq(eq_eh); } - void user_propagate_register_diseq(solver::eq_eh_t& diseq_eh) { + void user_propagate_register_diseq(user_propagator::eq_eh_t& diseq_eh) { if (!m_user_propagator) throw default_exception("user propagator must be initialized"); m_user_propagator->register_diseq(diseq_eh); diff --git a/src/smt/smt_kernel.cpp b/src/smt/smt_kernel.cpp index ecd443a5560..4b321fd4aaa 100644 --- a/src/smt/smt_kernel.cpp +++ b/src/smt/smt_kernel.cpp @@ -223,25 +223,25 @@ namespace smt { void user_propagate_init( void* ctx, - solver::push_eh_t& push_eh, - solver::pop_eh_t& pop_eh, - solver::fresh_eh_t& fresh_eh) { + user_propagator::push_eh_t& push_eh, + user_propagator::pop_eh_t& pop_eh, + user_propagator::fresh_eh_t& fresh_eh) { m_kernel.user_propagate_init(ctx, push_eh, pop_eh, fresh_eh); } - void user_propagate_register_final(solver::final_eh_t& final_eh) { + void user_propagate_register_final(user_propagator::final_eh_t& final_eh) { m_kernel.user_propagate_register_final(final_eh); } - void user_propagate_register_fixed(solver::fixed_eh_t& fixed_eh) { + void user_propagate_register_fixed(user_propagator::fixed_eh_t& fixed_eh) { m_kernel.user_propagate_register_fixed(fixed_eh); } - void user_propagate_register_eq(solver::eq_eh_t& eq_eh) { + void user_propagate_register_eq(user_propagator::eq_eh_t& eq_eh) { m_kernel.user_propagate_register_eq(eq_eh); } - void user_propagate_register_diseq(solver::eq_eh_t& diseq_eh) { + void user_propagate_register_diseq(user_propagator::eq_eh_t& diseq_eh) { m_kernel.user_propagate_register_diseq(diseq_eh); } @@ -451,25 +451,25 @@ namespace smt { void kernel::user_propagate_init( void* ctx, - solver::push_eh_t& push_eh, - solver::pop_eh_t& pop_eh, - solver::fresh_eh_t& fresh_eh) { + user_propagator::push_eh_t& push_eh, + user_propagator::pop_eh_t& pop_eh, + user_propagator::fresh_eh_t& fresh_eh) { m_imp->user_propagate_init(ctx, push_eh, pop_eh, fresh_eh); } - void kernel::user_propagate_register_fixed(solver::fixed_eh_t& fixed_eh) { + void kernel::user_propagate_register_fixed(user_propagator::fixed_eh_t& fixed_eh) { m_imp->user_propagate_register_fixed(fixed_eh); } - void kernel::user_propagate_register_final(solver::final_eh_t& final_eh) { + void kernel::user_propagate_register_final(user_propagator::final_eh_t& final_eh) { m_imp->user_propagate_register_final(final_eh); } - void kernel::user_propagate_register_eq(solver::eq_eh_t& eq_eh) { + void kernel::user_propagate_register_eq(user_propagator::eq_eh_t& eq_eh) { m_imp->user_propagate_register_eq(eq_eh); } - void kernel::user_propagate_register_diseq(solver::eq_eh_t& diseq_eh) { + void kernel::user_propagate_register_diseq(user_propagator::eq_eh_t& diseq_eh) { m_imp->user_propagate_register_diseq(diseq_eh); } diff --git a/src/smt/smt_kernel.h b/src/smt/smt_kernel.h index 68fbca582b4..2259dd99731 100644 --- a/src/smt/smt_kernel.h +++ b/src/smt/smt_kernel.h @@ -289,17 +289,17 @@ namespace smt { */ void user_propagate_init( void* ctx, - solver::push_eh_t& push_eh, - solver::pop_eh_t& pop_eh, - solver::fresh_eh_t& fresh_eh); + user_propagator::push_eh_t& push_eh, + user_propagator::pop_eh_t& pop_eh, + user_propagator::fresh_eh_t& fresh_eh); - void user_propagate_register_fixed(solver::fixed_eh_t& fixed_eh); + void user_propagate_register_fixed(user_propagator::fixed_eh_t& fixed_eh); - void user_propagate_register_final(solver::final_eh_t& final_eh); + void user_propagate_register_final(user_propagator::final_eh_t& final_eh); - void user_propagate_register_eq(solver::eq_eh_t& eq_eh); + void user_propagate_register_eq(user_propagator::eq_eh_t& eq_eh); - void user_propagate_register_diseq(solver::eq_eh_t& diseq_eh); + void user_propagate_register_diseq(user_propagator::eq_eh_t& diseq_eh); /** diff --git a/src/smt/smt_solver.cpp b/src/smt/smt_solver.cpp index 930a731a1f4..9352e33f49d 100644 --- a/src/smt/smt_solver.cpp +++ b/src/smt/smt_solver.cpp @@ -215,25 +215,25 @@ namespace { void user_propagate_init( void* ctx, - solver::push_eh_t& push_eh, - solver::pop_eh_t& pop_eh, - solver::fresh_eh_t& fresh_eh) override { + user_propagator::push_eh_t& push_eh, + user_propagator::pop_eh_t& pop_eh, + user_propagator::fresh_eh_t& fresh_eh) override { m_context.user_propagate_init(ctx, push_eh, pop_eh, fresh_eh); } - void user_propagate_register_fixed(solver::fixed_eh_t& fixed_eh) override { + void user_propagate_register_fixed(user_propagator::fixed_eh_t& fixed_eh) override { m_context.user_propagate_register_fixed(fixed_eh); } - void user_propagate_register_final(solver::final_eh_t& final_eh) override { + void user_propagate_register_final(user_propagator::final_eh_t& final_eh) override { m_context.user_propagate_register_final(final_eh); } - void user_propagate_register_eq(solver::eq_eh_t& eq_eh) override { + void user_propagate_register_eq(user_propagator::eq_eh_t& eq_eh) override { m_context.user_propagate_register_eq(eq_eh); } - void user_propagate_register_diseq(solver::eq_eh_t& diseq_eh) override { + void user_propagate_register_diseq(user_propagator::eq_eh_t& diseq_eh) override { m_context.user_propagate_register_diseq(diseq_eh); } diff --git a/src/smt/user_propagator.cpp b/src/smt/user_propagator.cpp index 2590449d0f3..799ae23e2ab 100644 --- a/src/smt/user_propagator.cpp +++ b/src/smt/user_propagator.cpp @@ -22,15 +22,15 @@ Module Name: using namespace smt; -user_propagator::user_propagator(context& ctx): +theory_user_propagator::theory_user_propagator(context& ctx): theory(ctx, ctx.get_manager().mk_family_id("user_propagator")) {} -user_propagator::~user_propagator() { +theory_user_propagator::~theory_user_propagator() { dealloc(m_api_context); } -void user_propagator::force_push() { +void theory_user_propagator::force_push() { for (; m_num_scopes > 0; --m_num_scopes) { theory::push_scope_eh(); m_push_eh(m_user_context); @@ -38,7 +38,7 @@ void user_propagator::force_push() { } } -unsigned user_propagator::add_expr(expr* e) { +unsigned theory_user_propagator::add_expr(expr* e) { force_push(); enode* n = ensure_enode(e); if (is_attached_to_var(n)) @@ -48,7 +48,7 @@ unsigned user_propagator::add_expr(expr* e) { return v; } -void user_propagator::propagate_cb( +void theory_user_propagator::propagate_cb( unsigned num_fixed, unsigned const* fixed_ids, unsigned num_eqs, unsigned const* eq_lhs, unsigned const* eq_rhs, expr* conseq) { @@ -59,8 +59,8 @@ void user_propagator::propagate_cb( m_prop.push_back(prop_info(num_fixed, fixed_ids, num_eqs, eq_lhs, eq_rhs, expr_ref(conseq, m))); } -theory * user_propagator::mk_fresh(context * new_ctx) { - auto* th = alloc(user_propagator, *new_ctx); +theory * theory_user_propagator::mk_fresh(context * new_ctx) { + auto* th = alloc(theory_user_propagator, *new_ctx); void* ctx = m_fresh_eh(m_user_context, new_ctx->get_manager(), th->m_api_context); th->add(ctx, m_push_eh, m_pop_eh, m_fresh_eh); if ((bool)m_fixed_eh) th->register_fixed(m_fixed_eh); @@ -70,7 +70,7 @@ theory * user_propagator::mk_fresh(context * new_ctx) { return th; } -final_check_status user_propagator::final_check_eh() { +final_check_status theory_user_propagator::final_check_eh() { if (!(bool)m_final_eh) return FC_DONE; force_push(); @@ -81,7 +81,7 @@ final_check_status user_propagator::final_check_eh() { return done ? FC_DONE : FC_CONTINUE; } -void user_propagator::new_fixed_eh(theory_var v, expr* value, unsigned num_lits, literal const* jlits) { +void theory_user_propagator::new_fixed_eh(theory_var v, expr* value, unsigned num_lits, literal const* jlits) { if (!m_fixed_eh) return; force_push(); @@ -93,11 +93,11 @@ void user_propagator::new_fixed_eh(theory_var v, expr* value, unsigned num_lits, m_fixed_eh(m_user_context, this, v, value); } -void user_propagator::push_scope_eh() { +void theory_user_propagator::push_scope_eh() { ++m_num_scopes; } -void user_propagator::pop_scope_eh(unsigned num_scopes) { +void theory_user_propagator::pop_scope_eh(unsigned num_scopes) { unsigned n = std::min(num_scopes, m_num_scopes); m_num_scopes -= n; num_scopes -= n; @@ -110,11 +110,11 @@ void user_propagator::pop_scope_eh(unsigned num_scopes) { m_prop_lim.shrink(old_sz); } -bool user_propagator::can_propagate() { +bool theory_user_propagator::can_propagate() { return m_qhead < m_prop.size(); } -void user_propagator::propagate() { +void theory_user_propagator::propagate() { TRACE("user_propagate", tout << "propagating queue head: " << m_qhead << " prop queue: " << m_prop.size() << "\n"); if (m_qhead == m_prop.size()) return; @@ -130,8 +130,8 @@ void user_propagator::propagate() { for (auto const& p : prop.m_eqs) m_eqs.push_back(enode_pair(get_enode(p.first), get_enode(p.second))); DEBUG_CODE(for (auto const& p : m_eqs) VERIFY(p.first->get_root() == p.second->get_root());); - DEBUG_CODE(for (unsigned id : prop.m_ids) VERIFY(m_fixed.contains(id));); - DEBUG_CODE(for (literal lit : m_lits) VERIFY(ctx.get_assignment(lit) == l_true);); + DEBUG_CODE(for (unsigned id : prop.m_ids) VERIFY(m_fixed.contains(id));); + DEBUG_CODE(for (literal lit : m_lits) VERIFY(ctx.get_assignment(lit) == l_true);); TRACE("user_propagate", tout << "propagating: " << prop.m_conseq << "\n"); @@ -155,7 +155,7 @@ void user_propagator::propagate() { m_qhead = qhead; } -void user_propagator::collect_statistics(::statistics & st) const { +void theory_user_propagator::collect_statistics(::statistics & st) const { st.update("user-propagations", m_stats.m_num_propagations); st.update("user-watched", get_num_vars()); } diff --git a/src/smt/user_propagator.h b/src/smt/user_propagator.h index dcf41c3bb23..3e9db8cdd64 100644 --- a/src/smt/user_propagator.h +++ b/src/smt/user_propagator.h @@ -27,7 +27,7 @@ Module Name: #include "solver/solver.h" namespace smt { - class user_propagator : public theory, public solver::propagate_callback { + class theory_user_propagator : public theory, public user_propagator::callback { struct prop_info { unsigned_vector m_ids; @@ -49,14 +49,14 @@ namespace smt { }; void* m_user_context = nullptr; - solver::push_eh_t m_push_eh; - solver::pop_eh_t m_pop_eh; - solver::fresh_eh_t m_fresh_eh; - solver::final_eh_t m_final_eh; - solver::fixed_eh_t m_fixed_eh; - solver::eq_eh_t m_eq_eh; - solver::eq_eh_t m_diseq_eh; - solver::context_obj* m_api_context = nullptr; + user_propagator::push_eh_t m_push_eh; + user_propagator::pop_eh_t m_pop_eh; + user_propagator::fresh_eh_t m_fresh_eh; + user_propagator::final_eh_t m_final_eh; + user_propagator::fixed_eh_t m_fixed_eh; + user_propagator::eq_eh_t m_eq_eh; + user_propagator::eq_eh_t m_diseq_eh; + user_propagator::context_obj* m_api_context = nullptr; unsigned m_qhead = 0; uint_set m_fixed; vector m_prop; @@ -70,18 +70,18 @@ namespace smt { void force_push(); public: - user_propagator(context& ctx); + theory_user_propagator(context& ctx); - ~user_propagator() override; + ~theory_user_propagator() override; /* * \brief initial setup for user propagator. */ void add( void* ctx, - solver::push_eh_t& push_eh, - solver::pop_eh_t& pop_eh, - solver::fresh_eh_t& fresh_eh) { + user_propagator::push_eh_t& push_eh, + user_propagator::pop_eh_t& pop_eh, + user_propagator::fresh_eh_t& fresh_eh) { m_user_context = ctx; m_push_eh = push_eh; m_pop_eh = pop_eh; @@ -90,10 +90,10 @@ namespace smt { unsigned add_expr(expr* e); - void register_final(solver::final_eh_t& final_eh) { m_final_eh = final_eh; } - void register_fixed(solver::fixed_eh_t& fixed_eh) { m_fixed_eh = fixed_eh; } - void register_eq(solver::eq_eh_t& eq_eh) { m_eq_eh = eq_eh; } - void register_diseq(solver::eq_eh_t& diseq_eh) { m_diseq_eh = diseq_eh; } + void register_final(user_propagator::final_eh_t& final_eh) { m_final_eh = final_eh; } + void register_fixed(user_propagator::fixed_eh_t& fixed_eh) { m_fixed_eh = fixed_eh; } + void register_eq(user_propagator::eq_eh_t& eq_eh) { m_eq_eh = eq_eh; } + void register_diseq(user_propagator::eq_eh_t& diseq_eh) { m_diseq_eh = diseq_eh; } bool has_fixed() const { return (bool)m_fixed_eh; } diff --git a/src/solver/solver.h b/src/solver/solver.h index 550105167d9..90975ebf9a7 100644 --- a/src/solver/solver.h +++ b/src/solver/solver.h @@ -18,6 +18,7 @@ Module Name: --*/ #pragma once +#include "tactic/user_propagator_base.h" #include "solver/check_sat_result.h" #include "solver/progress_callback.h" #include "util/params.h" @@ -47,7 +48,7 @@ solver* mk_smt2_solver(ast_manager& m, params_ref const& p); - statistics - results based on check_sat_result API */ -class solver : public check_sat_result { +class solver : public check_sat_result, public user_propagator::base{ params_ref m_params; symbol m_cancel_backup_file; public: @@ -239,52 +240,6 @@ class solver : public check_sat_result { virtual expr_ref_vector cube(expr_ref_vector& vars, unsigned backtrack_level) = 0; - class propagate_callback { - public: - virtual ~propagate_callback() = default; - virtual void propagate_cb(unsigned num_fixed, unsigned const* fixed_ids, unsigned num_eqs, unsigned const* eq_lhs, unsigned const* eq_rhs, expr* conseq) = 0; - }; - class context_obj { - public: - virtual ~context_obj() {} - }; - typedef std::function final_eh_t; - typedef std::function fixed_eh_t; - typedef std::function eq_eh_t; - typedef std::function fresh_eh_t; - typedef std::function push_eh_t; - typedef std::function pop_eh_t; - - virtual void user_propagate_init( - void* ctx, - push_eh_t& push_eh, - pop_eh_t& pop_eh, - fresh_eh_t& fresh_eh) { - throw default_exception("user-propagators are only supported on the SMT solver"); - } - - - virtual void user_propagate_register_fixed(fixed_eh_t& fixed_eh) { - throw default_exception("user-propagators are only supported on the SMT solver"); - } - - virtual void user_propagate_register_final(final_eh_t& final_eh) { - throw default_exception("user-propagators are only supported on the SMT solver"); - } - - virtual void user_propagate_register_eq(eq_eh_t& eq_eh) { - throw default_exception("user-propagators are only supported on the SMT solver"); - } - - virtual void user_propagate_register_diseq(eq_eh_t& diseq_eh) { - throw default_exception("user-propagators are only supported on the SMT solver"); - } - - virtual unsigned user_propagate_register(expr* e) { - throw default_exception("user-propagators are only supported on the SMT solver"); - } - - /** \brief Display the content of this solver. */ diff --git a/src/tactic/user_propagator_base.h b/src/tactic/user_propagator_base.h new file mode 100644 index 00000000000..f7826e559c8 --- /dev/null +++ b/src/tactic/user_propagator_base.h @@ -0,0 +1,61 @@ + +#pragma once + +#include "ast/ast.h" + +namespace user_propagator { + + class callback { + public: + virtual ~callback() = default; + virtual void propagate_cb(unsigned num_fixed, unsigned const* fixed_ids, unsigned num_eqs, unsigned const* eq_lhs, unsigned const* eq_rhs, expr* conseq) = 0; + }; + + class context_obj { + public: + virtual ~context_obj() {} + }; + + typedef std::function final_eh_t; + typedef std::function fixed_eh_t; + typedef std::function eq_eh_t; + typedef std::function fresh_eh_t; + typedef std::function push_eh_t; + typedef std::function pop_eh_t; + + + class base { + public: + + virtual void user_propagate_init( + void* ctx, + push_eh_t& push_eh, + pop_eh_t& pop_eh, + fresh_eh_t& fresh_eh) { + throw default_exception("user-propagators are only supported on the SMT solver"); + } + + + virtual void user_propagate_register_fixed(fixed_eh_t& fixed_eh) { + throw default_exception("user-propagators are only supported on the SMT solver"); + } + + virtual void user_propagate_register_final(final_eh_t& final_eh) { + throw default_exception("user-propagators are only supported on the SMT solver"); + } + + virtual void user_propagate_register_eq(eq_eh_t& eq_eh) { + throw default_exception("user-propagators are only supported on the SMT solver"); + } + + virtual void user_propagate_register_diseq(eq_eh_t& diseq_eh) { + throw default_exception("user-propagators are only supported on the SMT solver"); + } + + virtual unsigned user_propagate_register(expr* e) { + throw default_exception("user-propagators are only supported on the SMT solver"); + } + + }; + +}