Skip to content

Commit

Permalink
prepare user propagator declared functions for likely Clemens use case
Browse files Browse the repository at this point in the history
  • Loading branch information
NikolajBjorner committed Dec 17, 2021
1 parent a288f90 commit 6cc9aa3
Show file tree
Hide file tree
Showing 8 changed files with 120 additions and 15 deletions.
2 changes: 1 addition & 1 deletion src/muz/base/dl_rule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -778,7 +778,7 @@ namespace datalog {
tail_neg.push_back(false);
}

SASSERT(tail.size()==tail_neg.size());
SASSERT(tail.size() == tail_neg.size());
rule_ref old_r = r;
r = mk(head, tail.size(), tail.data(), tail_neg.data(), old_r->name());
r->set_accounting_parent_object(m_ctx, old_r);
Expand Down
12 changes: 12 additions & 0 deletions src/smt/smt_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -1728,6 +1728,18 @@ namespace smt {
throw default_exception("user propagator must be initialized");
return m_user_propagator->add_expr(e);
}

void user_propagate_register_declared(user_propagator::register_created_eh_t& r) {
if (!m_user_propagator)
throw default_exception("user propagator must be initialized");
m_user_propagator->register_declared(r);
}

func_decl* user_propagate_declare(symbol const& name, unsigned n, sort* const* domain, sort* range) {
if (!m_user_propagator)
throw default_exception("user propagator must be initialized");
return m_user_propagator->declare(name, n, domain, range);
}

bool watches_fixed(enode* n) const;

Expand Down
18 changes: 17 additions & 1 deletion src/smt/smt_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,14 @@ namespace smt {
unsigned user_propagate_register(expr* e) {
return m_kernel.user_propagate_register(e);
}

void user_propagate_register_declared(user_propagator::register_created_eh_t& r) {
m_kernel.user_propagate_register_declared(r);
}

func_decl* user_propagate_declare(symbol const& name, unsigned n, sort* const* domain, sort* range) {
return m_kernel.user_propagate_declare(name, n, domain, range);
}

};

Expand Down Expand Up @@ -477,4 +485,12 @@ namespace smt {
return m_imp->user_propagate_register(e);
}

};
void kernel::user_propagate_register_declared(user_propagator::register_created_eh_t& r) {
m_imp->user_propagate_register_declared(r);
}

func_decl* kernel::user_propagate_declare(symbol const& name, unsigned n, sort* const* domain, sort* range) {
return m_imp->user_propagate_declare(name, n, domain, range);
}

};
7 changes: 3 additions & 4 deletions src/smt/smt_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -301,12 +301,11 @@ namespace smt {

void user_propagate_register_diseq(user_propagator::eq_eh_t& diseq_eh);


/**
\brief register an expression to be tracked fro user propagation.
*/
unsigned user_propagate_register(expr* e);

void user_propagate_register_declared(user_propagator::register_created_eh_t& r);

func_decl* user_propagate_declare(symbol const& name, unsigned n, sort* const* domain, sort* range);

/**
\brief Return a reference to smt::context.
Expand Down
20 changes: 20 additions & 0 deletions src/smt/theory_user_propagator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,26 @@ void theory_user_propagator::propagate() {
m_qhead = qhead;
}

func_decl* theory_user_propagator::declare(symbol const& name, unsigned n, sort* const* domain, sort* range) {
if (!m_created_eh)
throw default_exception("event handler for dynamic expressions has to be registered before functions can be created");
// ensure that declaration plugin is registered with m.
if (!m.has_plugin(get_id()))
m.register_plugin(get_id(), alloc(user_propagator::plugin));

func_decl_info info(get_id(), user_propagator::plugin::kind_t::OP_USER_PROPAGATE);
return m.mk_func_decl(name, n, domain, range, info);
}

bool theory_user_propagator::internalize_atom(app* atom, bool gate_ctx) {
return internalize_term(atom);
}

bool theory_user_propagator::internalize_term(app* term) {
NOT_IMPLEMENTED_YET();
return false;
}

void theory_user_propagator::collect_statistics(::statistics & st) const {
st.update("user-propagations", m_stats.m_num_propagations);
st.update("user-watched", get_num_vars());
Expand Down
8 changes: 6 additions & 2 deletions src/smt/theory_user_propagator.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ namespace smt {
user_propagator::fixed_eh_t m_fixed_eh;
user_propagator::eq_eh_t m_eq_eh;
user_propagator::eq_eh_t m_diseq_eh;
user_propagator::register_created_eh_t m_created_eh;

user_propagator::context_obj* m_api_context = nullptr;
unsigned m_qhead = 0;
uint_set m_fixed;
Expand Down Expand Up @@ -94,6 +96,8 @@ namespace smt {
void register_fixed(user_propagator::fixed_eh_t& fixed_eh) { m_fixed_eh = fixed_eh; }
void register_eq(user_propagator::eq_eh_t& eq_eh) { m_eq_eh = eq_eh; }
void register_diseq(user_propagator::eq_eh_t& diseq_eh) { m_diseq_eh = diseq_eh; }
void register_declared(user_propagator::register_created_eh_t& created_eh) { m_created_eh = created_eh; }
func_decl* declare(symbol const& name, unsigned n, sort* const* domain, sort* range);

bool has_fixed() const { return (bool)m_fixed_eh; }

Expand All @@ -103,8 +107,8 @@ namespace smt {
void new_fixed_eh(theory_var v, expr* value, unsigned num_lits, literal const* jlits);

theory * mk_fresh(context * new_ctx) override;
bool internalize_atom(app * atom, bool gate_ctx) override { UNREACHABLE(); return false; }
bool internalize_term(app * term) override { UNREACHABLE(); return false; }
bool internalize_atom(app* atom, bool gate_ctx) override;
bool internalize_term(app* term) override;
void new_eq_eh(theory_var v1, theory_var v2) override { if (m_eq_eh) m_eq_eh(m_user_context, this, v1, v2); }
void new_diseq_eh(theory_var v1, theory_var v2) override { if (m_diseq_eh) m_diseq_eh(m_user_context, this, v1, v2); }
bool use_diseqs() const override { return ((bool)m_diseq_eh); }
Expand Down
26 changes: 19 additions & 7 deletions src/tactic/core/elim_uncnstr_tactic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class elim_uncnstr_tactic : public tactic {
struct rw_cfg : public default_rewriter_cfg {
bool m_produce_proofs;
obj_hashtable<expr> & m_vars;
obj_hashtable<expr>& m_nonvars;
ref<mc> m_mc;
arith_util m_a_util;
bv_util m_bv_util;
Expand All @@ -49,10 +50,11 @@ class elim_uncnstr_tactic : public tactic {
unsigned long long m_max_memory;
unsigned m_max_steps;

rw_cfg(ast_manager & m, bool produce_proofs, obj_hashtable<expr> & vars, mc * _m,
unsigned long long max_memory, unsigned max_steps):
rw_cfg(ast_manager & m, bool produce_proofs, obj_hashtable<expr> & vars, obj_hashtable<expr> & nonvars, mc * _m,
unsigned long long max_memory, unsigned max_steps):
m_produce_proofs(produce_proofs),
m_vars(vars),
m_nonvars(nonvars),
m_mc(_m),
m_a_util(m),
m_bv_util(m),
Expand All @@ -73,7 +75,7 @@ class elim_uncnstr_tactic : public tactic {
}

bool uncnstr(expr * arg) const {
return m_vars.contains(arg);
return m_vars.contains(arg) && !m_nonvars.contains(arg);
}

bool uncnstr(unsigned num, expr * const * args) const {
Expand Down Expand Up @@ -749,16 +751,17 @@ class elim_uncnstr_tactic : public tactic {
class rw : public rewriter_tpl<rw_cfg> {
rw_cfg m_cfg;
public:
rw(ast_manager & m, bool produce_proofs, obj_hashtable<expr> & vars, mc * _m,
rw(ast_manager & m, bool produce_proofs, obj_hashtable<expr> & vars, obj_hashtable<expr>& nonvars, mc * _m,
unsigned long long max_memory, unsigned max_steps):
rewriter_tpl<rw_cfg>(m, produce_proofs, m_cfg),
m_cfg(m, produce_proofs, vars, _m, max_memory, max_steps) {
m_cfg(m, produce_proofs, vars, nonvars, _m, max_memory, max_steps) {
}
};

ast_manager & m_manager;
ref<mc> m_mc;
obj_hashtable<expr> m_vars;
obj_hashtable<expr> m_nonvars;
scoped_ptr<rw> m_rw;
unsigned m_num_elim_apps = 0;
unsigned long long m_max_memory;
Expand All @@ -774,12 +777,11 @@ class elim_uncnstr_tactic : public tactic {
}

void init_rw(bool produce_proofs) {
m_rw = alloc(rw, m(), produce_proofs, m_vars, m_mc.get(), m_max_memory, m_max_steps);
m_rw = alloc(rw, m(), produce_proofs, m_vars, m_nonvars, m_mc.get(), m_max_memory, m_max_steps);
}

void run(goal_ref const & g, goal_ref_buffer & result) {
bool produce_proofs = g->proofs_enabled();

TRACE("goal", g->display(tout););
tactic_report report("elim-uncnstr", *g);
m_vars.reset();
Expand Down Expand Up @@ -890,6 +892,16 @@ class elim_uncnstr_tactic : public tactic {
m_num_elim_apps = 0;
}

unsigned user_propagate_register(expr* e) override {
m_nonvars.insert(e);
return 0;
}

void user_propagate_clear() override {
m_nonvars.reset();
}


};
}

Expand Down
42 changes: 42 additions & 0 deletions src/tactic/user_propagator_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,33 @@ namespace user_propagator {
typedef std::function<void*(void*, ast_manager&, context_obj*&)> fresh_eh_t;
typedef std::function<void(void*)> push_eh_t;
typedef std::function<void(void*,unsigned)> pop_eh_t;
typedef std::function<void(void*, callback*, expr*, unsigned)> register_created_eh_t;


class plugin : public decl_plugin {
public:

enum kind_t { OP_USER_PROPAGATE };

virtual ~plugin() {}

virtual decl_plugin* mk_fresh() { return alloc(plugin); }

family_id get_family_id() const { return m_family_id; }

sort* mk_sort(decl_kind k, unsigned num_parameters, parameter const* parameters) override {
UNREACHABLE();
return nullptr;
}

func_decl* mk_func_decl(decl_kind k, unsigned num_parameters, parameter const* parameters,
unsigned arity, sort* const* domain, sort* range) {
UNREACHABLE();
return nullptr;
}

};

class core {
public:

Expand Down Expand Up @@ -58,8 +83,25 @@ namespace user_propagator {
throw default_exception("user-propagators are only supported on the SMT solver");
}

/**
* Create uninterpreted function for the user propagator.
* When expressions using the function are assigned values, generate a callback
* into a register_declared_eh(user_ctx, solver_ctx, declared_expr, declare_id) with arguments
* 1. context and callback context
* 2. declared_expr: expression using function that was declared at top.
* 3. declared_id: a unique identifier (unique within the current scope) to track the expression.
*/
virtual func_decl* user_propagate_declare(symbol const& name, unsigned n, sort* const* domain, sort* range) {
throw default_exception("user-propagators are only supported on the SMT solver");
}

virtual void user_propagate_register_created(register_created_eh_t& r) {
throw default_exception("user-propagators are only supported on the SMT solver");
}

virtual void user_propagate_clear() {
}


};

Expand Down

0 comments on commit 6cc9aa3

Please sign in to comment.