diff --git a/src/muz/base/dl_rule.cpp b/src/muz/base/dl_rule.cpp index 9bd7a7adf9f..d0c872c3c20 100644 --- a/src/muz/base/dl_rule.cpp +++ b/src/muz/base/dl_rule.cpp @@ -778,7 +778,7 @@ namespace datalog { tail_neg.push_back(false); } - SASSERT(tail.size()==tail_neg.size()); + SASSERT(tail.size() == tail_neg.size()); rule_ref old_r = r; r = mk(head, tail.size(), tail.data(), tail_neg.data(), old_r->name()); r->set_accounting_parent_object(m_ctx, old_r); diff --git a/src/smt/smt_context.h b/src/smt/smt_context.h index 700c2f21109..fe9fe2054e9 100644 --- a/src/smt/smt_context.h +++ b/src/smt/smt_context.h @@ -1728,6 +1728,18 @@ namespace smt { throw default_exception("user propagator must be initialized"); return m_user_propagator->add_expr(e); } + + void user_propagate_register_declared(user_propagator::register_created_eh_t& r) { + if (!m_user_propagator) + throw default_exception("user propagator must be initialized"); + m_user_propagator->register_declared(r); + } + + func_decl* user_propagate_declare(symbol const& name, unsigned n, sort* const* domain, sort* range) { + if (!m_user_propagator) + throw default_exception("user propagator must be initialized"); + return m_user_propagator->declare(name, n, domain, range); + } bool watches_fixed(enode* n) const; diff --git a/src/smt/smt_kernel.cpp b/src/smt/smt_kernel.cpp index 4b321fd4aaa..1e97fa4653b 100644 --- a/src/smt/smt_kernel.cpp +++ b/src/smt/smt_kernel.cpp @@ -248,6 +248,14 @@ namespace smt { unsigned user_propagate_register(expr* e) { return m_kernel.user_propagate_register(e); } + + void user_propagate_register_declared(user_propagator::register_created_eh_t& r) { + m_kernel.user_propagate_register_declared(r); + } + + func_decl* user_propagate_declare(symbol const& name, unsigned n, sort* const* domain, sort* range) { + return m_kernel.user_propagate_declare(name, n, domain, range); + } }; @@ -477,4 +485,12 @@ namespace smt { return m_imp->user_propagate_register(e); } -}; + void kernel::user_propagate_register_declared(user_propagator::register_created_eh_t& r) { + m_imp->user_propagate_register_declared(r); + } + + func_decl* kernel::user_propagate_declare(symbol const& name, unsigned n, sort* const* domain, sort* range) { + return m_imp->user_propagate_declare(name, n, domain, range); + } + +}; \ No newline at end of file diff --git a/src/smt/smt_kernel.h b/src/smt/smt_kernel.h index 2259dd99731..7af8f2fe68e 100644 --- a/src/smt/smt_kernel.h +++ b/src/smt/smt_kernel.h @@ -301,12 +301,11 @@ namespace smt { void user_propagate_register_diseq(user_propagator::eq_eh_t& diseq_eh); - - /** - \brief register an expression to be tracked fro user propagation. - */ unsigned user_propagate_register(expr* e); + void user_propagate_register_declared(user_propagator::register_created_eh_t& r); + + func_decl* user_propagate_declare(symbol const& name, unsigned n, sort* const* domain, sort* range); /** \brief Return a reference to smt::context. diff --git a/src/smt/theory_user_propagator.cpp b/src/smt/theory_user_propagator.cpp index 2b50e07abe3..75983b49095 100644 --- a/src/smt/theory_user_propagator.cpp +++ b/src/smt/theory_user_propagator.cpp @@ -173,6 +173,26 @@ void theory_user_propagator::propagate() { m_qhead = qhead; } +func_decl* theory_user_propagator::declare(symbol const& name, unsigned n, sort* const* domain, sort* range) { + if (!m_created_eh) + throw default_exception("event handler for dynamic expressions has to be registered before functions can be created"); + // ensure that declaration plugin is registered with m. + if (!m.has_plugin(get_id())) + m.register_plugin(get_id(), alloc(user_propagator::plugin)); + + func_decl_info info(get_id(), user_propagator::plugin::kind_t::OP_USER_PROPAGATE); + return m.mk_func_decl(name, n, domain, range, info); +} + +bool theory_user_propagator::internalize_atom(app* atom, bool gate_ctx) { + return internalize_term(atom); +} + +bool theory_user_propagator::internalize_term(app* term) { + NOT_IMPLEMENTED_YET(); + return false; +} + void theory_user_propagator::collect_statistics(::statistics & st) const { st.update("user-propagations", m_stats.m_num_propagations); st.update("user-watched", get_num_vars()); diff --git a/src/smt/theory_user_propagator.h b/src/smt/theory_user_propagator.h index d007de6a014..0c5cdeef8e3 100644 --- a/src/smt/theory_user_propagator.h +++ b/src/smt/theory_user_propagator.h @@ -56,6 +56,8 @@ namespace smt { user_propagator::fixed_eh_t m_fixed_eh; user_propagator::eq_eh_t m_eq_eh; user_propagator::eq_eh_t m_diseq_eh; + user_propagator::register_created_eh_t m_created_eh; + user_propagator::context_obj* m_api_context = nullptr; unsigned m_qhead = 0; uint_set m_fixed; @@ -94,6 +96,8 @@ namespace smt { void register_fixed(user_propagator::fixed_eh_t& fixed_eh) { m_fixed_eh = fixed_eh; } void register_eq(user_propagator::eq_eh_t& eq_eh) { m_eq_eh = eq_eh; } void register_diseq(user_propagator::eq_eh_t& diseq_eh) { m_diseq_eh = diseq_eh; } + void register_declared(user_propagator::register_created_eh_t& created_eh) { m_created_eh = created_eh; } + func_decl* declare(symbol const& name, unsigned n, sort* const* domain, sort* range); bool has_fixed() const { return (bool)m_fixed_eh; } @@ -103,8 +107,8 @@ namespace smt { void new_fixed_eh(theory_var v, expr* value, unsigned num_lits, literal const* jlits); theory * mk_fresh(context * new_ctx) override; - bool internalize_atom(app * atom, bool gate_ctx) override { UNREACHABLE(); return false; } - bool internalize_term(app * term) override { UNREACHABLE(); return false; } + bool internalize_atom(app* atom, bool gate_ctx) override; + bool internalize_term(app* term) override; void new_eq_eh(theory_var v1, theory_var v2) override { if (m_eq_eh) m_eq_eh(m_user_context, this, v1, v2); } void new_diseq_eh(theory_var v1, theory_var v2) override { if (m_diseq_eh) m_diseq_eh(m_user_context, this, v1, v2); } bool use_diseqs() const override { return ((bool)m_diseq_eh); } diff --git a/src/tactic/core/elim_uncnstr_tactic.cpp b/src/tactic/core/elim_uncnstr_tactic.cpp index 915a0f3c01e..a5cbea932d3 100644 --- a/src/tactic/core/elim_uncnstr_tactic.cpp +++ b/src/tactic/core/elim_uncnstr_tactic.cpp @@ -38,6 +38,7 @@ class elim_uncnstr_tactic : public tactic { struct rw_cfg : public default_rewriter_cfg { bool m_produce_proofs; obj_hashtable & m_vars; + obj_hashtable& m_nonvars; ref m_mc; arith_util m_a_util; bv_util m_bv_util; @@ -49,10 +50,11 @@ class elim_uncnstr_tactic : public tactic { unsigned long long m_max_memory; unsigned m_max_steps; - rw_cfg(ast_manager & m, bool produce_proofs, obj_hashtable & vars, mc * _m, - unsigned long long max_memory, unsigned max_steps): + rw_cfg(ast_manager & m, bool produce_proofs, obj_hashtable & vars, obj_hashtable & nonvars, mc * _m, + unsigned long long max_memory, unsigned max_steps): m_produce_proofs(produce_proofs), m_vars(vars), + m_nonvars(nonvars), m_mc(_m), m_a_util(m), m_bv_util(m), @@ -73,7 +75,7 @@ class elim_uncnstr_tactic : public tactic { } bool uncnstr(expr * arg) const { - return m_vars.contains(arg); + return m_vars.contains(arg) && !m_nonvars.contains(arg); } bool uncnstr(unsigned num, expr * const * args) const { @@ -749,16 +751,17 @@ class elim_uncnstr_tactic : public tactic { class rw : public rewriter_tpl { rw_cfg m_cfg; public: - rw(ast_manager & m, bool produce_proofs, obj_hashtable & vars, mc * _m, + rw(ast_manager & m, bool produce_proofs, obj_hashtable & vars, obj_hashtable& nonvars, mc * _m, unsigned long long max_memory, unsigned max_steps): rewriter_tpl(m, produce_proofs, m_cfg), - m_cfg(m, produce_proofs, vars, _m, max_memory, max_steps) { + m_cfg(m, produce_proofs, vars, nonvars, _m, max_memory, max_steps) { } }; ast_manager & m_manager; ref m_mc; obj_hashtable m_vars; + obj_hashtable m_nonvars; scoped_ptr m_rw; unsigned m_num_elim_apps = 0; unsigned long long m_max_memory; @@ -774,12 +777,11 @@ class elim_uncnstr_tactic : public tactic { } void init_rw(bool produce_proofs) { - m_rw = alloc(rw, m(), produce_proofs, m_vars, m_mc.get(), m_max_memory, m_max_steps); + m_rw = alloc(rw, m(), produce_proofs, m_vars, m_nonvars, m_mc.get(), m_max_memory, m_max_steps); } void run(goal_ref const & g, goal_ref_buffer & result) { bool produce_proofs = g->proofs_enabled(); - TRACE("goal", g->display(tout);); tactic_report report("elim-uncnstr", *g); m_vars.reset(); @@ -890,6 +892,16 @@ class elim_uncnstr_tactic : public tactic { m_num_elim_apps = 0; } + unsigned user_propagate_register(expr* e) override { + m_nonvars.insert(e); + return 0; + } + + void user_propagate_clear() override { + m_nonvars.reset(); + } + + }; } diff --git a/src/tactic/user_propagator_base.h b/src/tactic/user_propagator_base.h index 899722c2a74..e5a2d282f5b 100644 --- a/src/tactic/user_propagator_base.h +++ b/src/tactic/user_propagator_base.h @@ -23,8 +23,33 @@ namespace user_propagator { typedef std::function fresh_eh_t; typedef std::function push_eh_t; typedef std::function pop_eh_t; + typedef std::function register_created_eh_t; + class plugin : public decl_plugin { + public: + + enum kind_t { OP_USER_PROPAGATE }; + + virtual ~plugin() {} + + virtual decl_plugin* mk_fresh() { return alloc(plugin); } + + family_id get_family_id() const { return m_family_id; } + + sort* mk_sort(decl_kind k, unsigned num_parameters, parameter const* parameters) override { + UNREACHABLE(); + return nullptr; + } + + func_decl* mk_func_decl(decl_kind k, unsigned num_parameters, parameter const* parameters, + unsigned arity, sort* const* domain, sort* range) { + UNREACHABLE(); + return nullptr; + } + + }; + class core { public: @@ -58,8 +83,25 @@ namespace user_propagator { throw default_exception("user-propagators are only supported on the SMT solver"); } + /** + * Create uninterpreted function for the user propagator. + * When expressions using the function are assigned values, generate a callback + * into a register_declared_eh(user_ctx, solver_ctx, declared_expr, declare_id) with arguments + * 1. context and callback context + * 2. declared_expr: expression using function that was declared at top. + * 3. declared_id: a unique identifier (unique within the current scope) to track the expression. + */ + virtual func_decl* user_propagate_declare(symbol const& name, unsigned n, sort* const* domain, sort* range) { + throw default_exception("user-propagators are only supported on the SMT solver"); + } + + virtual void user_propagate_register_created(register_created_eh_t& r) { + throw default_exception("user-propagators are only supported on the SMT solver"); + } + virtual void user_propagate_clear() { } + };