Skip to content

Commit

Permalink
extend solver callbacks with methods
Browse files Browse the repository at this point in the history
Signed-off-by: Nikolaj Bjorner <[email protected]>
  • Loading branch information
NikolajBjorner committed Aug 22, 2020
1 parent 080be7a commit 2d5b749
Show file tree
Hide file tree
Showing 13 changed files with 343 additions and 55 deletions.
42 changes: 37 additions & 5 deletions scripts/update_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,9 +337,26 @@ def Z3_set_error_handler(ctx, hndlr, _elems=Elementaries(_lib.Z3_set_error_handl
_elems.Check(ctx)
return ceh
def Z3_solver_propagate_init(ctx, s, user_ctx, push_eh, pop_eh, fixed_eh, fresh_eh, _elems = Elementaries(_lib.Z3_solver_propagate_init)):
_elems.f(ctx, s, user_ctx, push_eh, pop_eh, fixed_eh, fresh_eh)
_elems.Check(ctx)
def Z3_solver_propagate_init(ctx, s, user_ctx, push_eh, pop_eh, fresh_eh, _elems = Elementaries(_lib.Z3_solver_propagate_init)):
_elems.f(ctx, s, user_ctx, push_eh, pop_eh, fresh_eh)
_elems.Check(ctx)
def Z3_solver_propagate_final(ctx, s, final_eh, _elems = Elementaries(_lib.Z3_solver_propagate_final)):
_elems.f(ctx, s, final_eh)
_elems.Check(ctx)
def Z3_solver_propagate_fixed(ctx, s, fixed_eh, _elems = Elementaries(_lib.Z3_solver_propagate_fixed)):
_elems.f(ctx, s, fixed_eh)
_elems.Check(ctx)
def Z3_solver_propagate_eq(ctx, s, eq_eh, _elems = Elementaries(_lib.Z3_solver_propagate_eq)):
_elems.f(ctx, s, eq_eh)
_elems.Check(ctx)
def Z3_solver_propagate_diseq(ctx, s, diseq_eh, _elems = Elementaries(_lib.Z3_solver_propagate_diseq)):
_elems.f(ctx, s, diseq_eh)
_elems.Check(ctx)
""")

Expand Down Expand Up @@ -1826,11 +1843,26 @@ def _to_pystr(s):
push_eh_type = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
pop_eh_type = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_uint)
fixed_eh_type = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_uint, ctypes.c_void_p)
fresh_eh_type = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.c_void_p)
fixed_eh_type = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_uint, ctypes.c_void_p)
final_eh_type = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_void_p)
eq_eh_type = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_uint, ctypes.c_uint)
_lib.Z3_solver_propagate_init.restype = None
_lib.Z3_solver_propagate_init.argtypes = [ContextObj, SolverObj, ctypes.c_void_p, push_eh_type, pop_eh_type, fixed_eh_type, fresh_eh_type]
_lib.Z3_solver_propagate_init.argtypes = [ContextObj, SolverObj, ctypes.c_void_p, push_eh_type, pop_eh_type, fresh_eh_type]
_lib.Z3_solver_propagate_final.restype = None
_lib.Z3_solver_propagate_final.argtypes = [ContextObj, SolverObj, final_eh_type]
_lib.Z3_solver_propagate_fixed.restype = None
_lib.Z3_solver_propagate_fixed.argtypes = [ContextObj, SolverObj, fixed_eh_type]
_lib.Z3_solver_propagate_eq.restype = None
_lib.Z3_solver_propagate_eq.argtypes = [ContextObj, SolverObj, eq_eh_type]
_lib.Z3_solver_propagate_diseq.restype = None
_lib.Z3_solver_propagate_diseq.argtypes = [ContextObj, SolverObj, eq_eh_type]
"""
Expand Down
48 changes: 45 additions & 3 deletions src/api/api_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -892,19 +892,61 @@ extern "C" {
void* user_context,
Z3_push_eh push_eh,
Z3_pop_eh pop_eh,
Z3_fixed_eh fixed_eh,
Z3_fresh_eh fresh_eh) {
Z3_TRY;
RESET_ERROR_CODE();
init_solver(c, s);
std::function<void(void*)> _push = push_eh;
std::function<void(void*,unsigned)> _pop = pop_eh;
std::function<void(void*,solver::propagate_callback*,unsigned,expr*)> _fixed = (void(*)(void*,solver::propagate_callback*,unsigned,expr*))fixed_eh;
std::function<void*(void*)> _fresh = fresh_eh;
to_solver_ref(s)->user_propagate_init(user_context, _fixed, _push, _pop, _fresh);
to_solver_ref(s)->user_propagate_init(user_context, _push, _pop, _fresh);
Z3_CATCH;
}

void Z3_API Z3_solver_propagate_fixed(
Z3_context c,
Z3_solver s,
Z3_fixed_eh fixed_eh) {
Z3_TRY;
RESET_ERROR_CODE();
solver::fixed_eh_t _fixed = (void(*)(void*,solver::propagate_callback*,unsigned,expr*))fixed_eh;
to_solver_ref(s)->user_propagate_register_fixed(_fixed);
Z3_CATCH;
}

void Z3_API Z3_solver_propagate_final(
Z3_context c,
Z3_solver s,
Z3_final_eh final_eh) {
Z3_TRY;
RESET_ERROR_CODE();
solver::final_eh_t _final = (bool(*)(void*,solver::propagate_callback*))final_eh;
to_solver_ref(s)->user_propagate_register_final(_final);
Z3_CATCH;
}

void Z3_API Z3_solver_propagate_eq(
Z3_context c,
Z3_solver s,
Z3_eq_eh eq_eh) {
Z3_TRY;
RESET_ERROR_CODE();
solver::eq_eh_t _eq = (void(*)(void*,solver::propagate_callback*,unsigned,unsigned))eq_eh;
to_solver_ref(s)->user_propagate_register_eq(_eq);
Z3_CATCH;
}

void Z3_API Z3_solver_propagate_diseq(
Z3_context c,
Z3_solver s,
Z3_eq_eh diseq_eh) {
Z3_TRY;
RESET_ERROR_CODE();
solver::eq_eh_t _diseq = (void(*)(void*,solver::propagate_callback*,unsigned,unsigned))diseq_eh;
to_solver_ref(s)->user_propagate_register_diseq(_diseq);
Z3_CATCH;
}

unsigned Z3_API Z3_solver_propagate_register(Z3_context c, Z3_solver s, Z3_ast e) {
Z3_TRY;
LOG_Z3_solver_propagate_register(c, s, e);
Expand Down
1 change: 1 addition & 0 deletions src/api/ml/z3native.ml.pre
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ and literals = ptr
and constructor = ptr
and constructor_list = ptr
and solver = ptr
and solver_callback = ptr
and goal = ptr
and tactic = ptr
and params = ptr
Expand Down
102 changes: 83 additions & 19 deletions src/api/python/z3/z3.py
Original file line number Diff line number Diff line change
Expand Up @@ -10505,59 +10505,123 @@ def TransitiveClosure(f):
return FuncDeclRef(Z3_mk_transitive_closure(f.ctx_ref(), f.ast), f.ctx)


_user_propagate_bases = {}
class PropClosures:
# import thread
def __init__(self):
self.bases = {}
# self.lock = thread.Lock()

def get(self, ctx):
# self.lock.acquire()
r = self.bases[ctx]
# self.lock.release()
return r

def set(self, ctx, r):
# self.lock.acquire()
self.bases[ctx] = r
# self.lock.release()

def insert(self, r):
# self.lock.acquire()
id = len(self.bases) + 3
self.bases[id] = r
# self.lock.release()
return id

_prop_closures = PropClosures()

def user_prop_push(ctx):
_user_propagate_bases[ctx].push();
_prop_closures.get(ctx).push();

def user_prop_pop(ctx, num_scopes):
_user_propagate_bases[ctx].pop(num_scopes)
_prop_closures.get(ctx).pop(num_scopes)

def user_prop_fresh(ctx):
prop = _prop_closures.get(ctx)
new_prop = UsePropagateBase(None, prop.ctx)
_prop_closures.set(new_prop.id, new_prop.fresh())
return ctypes.c_void_p(new_prop.id)

def user_prop_fixed(ctx, cb, id, value):
prop = _user_propagate_bases[ctx]
prop = _prop_closures.get(ctx)
prop.cb = cb
prop.fixed(id, _to_expr_ref(ctypes.c_void_p(value), prop.ctx))
prop.cb = None

def user_prop_fresh(ctx):
prop = _user_propagate_bases[ctx]
new_prop = UsePropagateBase(None, prop.ctx)
_user_prop_bases[new_prop.id] = new_prop.fresh()
return ctypes.c_void_p(new_prop.id)

def user_prop_final(ctx, cb):
prop = _prop_closures.get(ctx)
prop.cb = cb
prop.final()
prop.cb = None

def user_prop_eq(ctx, cb, x, y):
prop = _prop_closures.get(ctx)
prop.cb = cb
prop.eq(x, y)
prop.cb = None

def user_prop_diseq(ctx, cb, x, y):
prop = _prop_closures.get(ctx)
prop.cb = cb
prop.diseq(x, y)
prop.cb = None

_user_prop_push = push_eh_type(user_prop_push)
_user_prop_pop = pop_eh_type(user_prop_pop)
_user_prop_fixed = fixed_eh_type(user_prop_fixed)
_user_prop_fresh = fresh_eh_type(user_prop_fresh)
_user_prop_fixed = fixed_eh_type(user_prop_fixed)
_user_prop_final = final_eh_type(user_prop_final)
_user_prop_eq = eq_eh_type(user_prop_eq)
_user_prop_diseq = eq_eh_type(user_prop_diseq)

class UserPropagateBase:

def __init__(self, s, ctx = None):
self.id = len(_user_propagate_bases) + 3
self.solver = s
self.solver = s
self.ctx = s.ctx if s is not None else ctx
self.cb = None
_user_propagate_bases[self.id] = self
self.id = _prop_closures.insert(self)
self.fixed = None
self.final = None
self.eq = None
self.diseq = None
if s:
Z3_solver_propagate_init(s.ctx.ref(),
s.solver,
ctypes.c_void_p(self.id),
_user_prop_push,
_user_prop_pop,
_user_prop_fixed,
_user_prop_fresh)


def add_fixed(self, fixed):
assert not self.fixed
Z3_solver_propagate_fixed(self.ctx.ref(), self.solver.solver, _user_prop_fixed)
self.fixed = fixed

def add_final(self, final):
assert not self.final
Z3_solver_propagate_final(self.ctx.ref(), self.solver.solver, _user_prop_final)
self.final = final

def add_eq(self, eq):
assert not self.eq
Z3_solver_propagate_eq(self.ctx.ref(), self.solver.solver, _user_prop_eq)
self.eq = eq

def add_diseq(self, diseq):
assert not self.diseq
Z3_solver_propagate_diseq(self.ctx.ref(), self.solver.solver, _user_prop_diseq)
self.diseq = diseq

def push(self):
raise Z3Exception("push has not been overwritten")

def pop(self, num_scopes):
raise Z3Exception("pop has not been overwritten")

def fixed(self, id, e):
raise Z3Exception("fixed has not been overwritten")

def fresh(self, prop_base):
def fresh(self):
raise Z3Exception("fresh has not been overwritten")

def add(self, e):
Expand Down
37 changes: 34 additions & 3 deletions src/api/z3_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1420,8 +1420,10 @@ typedef void Z3_error_handler(Z3_context c, Z3_error_code e);
*/
typedef void Z3_push_eh(void* ctx);
typedef void Z3_pop_eh(void* ctx, unsigned num_scopes);
typedef void Z3_fixed_eh(void* ctx, Z3_solver_callback cb, unsigned id, Z3_ast value);
typedef void* Z3_fresh_eh(void* ctx);
typedef void Z3_fixed_eh(void* ctx, Z3_solver_callback cb, unsigned id, Z3_ast value);
typedef void Z3_eq_eh(void* ctx, Z3_solver_callback cb, unsigned x, unsigned y);
typedef void Z3_final_eh(void* ctx, Z3_solver_callback cb);

/**
\brief A Goal is essentially a set of formulas.
Expand Down Expand Up @@ -6537,17 +6539,46 @@ extern "C" {
void* user_context,
Z3_push_eh push_eh,
Z3_pop_eh pop_eh,
Z3_fixed_eh fixed_eh,
Z3_fresh_eh fresh_eh);

/**
\brief register a callback for when an expression is bound to a fixed value.
The supported expression types are
- Booleans
- Bit-vectors
*/

void Z3_API Z3_solver_propagate_fixed(Z3_context c, Z3_solver s, Z3_fixed_eh fixed_eh);

/**
\brief register a callback on final check.
This provides freedom to the propagator to delay actions or implement a branch-and bound solver.
The final_eh callback takes as argument the original user_context that was used
when calling \c Z3_solver_propagate_init, and it takes a callback context for propagations.
If may use the callback context to invoke the \c Z3_solver_propagate_consequence function.
If the callback context gets used, the solver continues.
*/
void Z3_API Z3_solver_propagate_final(Z3_context c, Z3_solver s, Z3_final_eh final_eh);

/**
\brief register a callback on expression equalities.
*/
void Z3_API Z3_solver_propagate_eq(Z3_context c, Z3_solver s, Z3_eq_eh eq_eh);

/**
\brief register a callback on expression dis-equalities.
*/
void Z3_API Z3_solver_propagate_diseq(Z3_context c, Z3_solver s, Z3_eq_eh eq_eh);

/**
\brief register an expression to propagate on with the solver.
Only expressions of type Bool and type Bit-Vector can be registered for propagation.
def_API('Z3_solver_propagate_register', UINT, (_in(CONTEXT), _in(SOLVER), _in(AST)))
*/

unsigned Z3_API Z3_solver_propagate_register(Z3_context c, Z3_solver s, Z3_ast e);
unsigned Z3_API Z3_solver_propagate_register(Z3_context c, Z3_solver s, Z3_ast e);

/**
\brief propagate a consequence based on fixed values.
Expand Down
3 changes: 1 addition & 2 deletions src/smt/smt_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2951,13 +2951,12 @@ namespace smt {

void context::user_propagate_init(
void* ctx,
std::function<void(void*, solver::propagate_callback*, unsigned, expr*)>& fixed_eh,
std::function<void(void*)>& push_eh,
std::function<void(void*, unsigned)>& pop_eh,
std::function<void*(void*)>& fresh_eh) {
setup_context(m_fparams.m_auto_config);
m_user_propagator = alloc(user_propagator, *this);
m_user_propagator->add(ctx, fixed_eh, push_eh, pop_eh, fresh_eh);
m_user_propagator->add(ctx, push_eh, pop_eh, fresh_eh);
for (unsigned i = m_scopes.size(); i-- > 0; )
m_user_propagator->push_scope_eh();
register_plugin(m_user_propagator);
Expand Down
25 changes: 24 additions & 1 deletion src/smt/smt_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -1689,11 +1689,34 @@ namespace smt {
*/
void user_propagate_init(
void* ctx,
std::function<void(void*, solver::propagate_callback*, unsigned, expr*)>& fixed_eh,
std::function<void(void*)>& push_eh,
std::function<void(void*, unsigned)>& pop_eh,
std::function<void*(void*)>& fresh_eh);

void user_propagate_register_final(solver::final_eh_t& final_eh) {
if (!m_user_propagator)
throw default_exception("user propagator must be initialized");
m_user_propagator->register_final(final_eh);
}

void user_propagate_register_fixed(solver::fixed_eh_t& fixed_eh) {
if (!m_user_propagator)
throw default_exception("user propagator must be initialized");
m_user_propagator->register_fixed(fixed_eh);
}

void user_propagate_register_eq(solver::eq_eh_t& eq_eh) {
if (!m_user_propagator)
throw default_exception("user propagator must be initialized");
m_user_propagator->register_eq(eq_eh);
}

void user_propagate_register_diseq(solver::eq_eh_t& diseq_eh) {
if (!m_user_propagator)
throw default_exception("user propagator must be initialized");
m_user_propagator->register_diseq(diseq_eh);
}

unsigned user_propagate_register(expr* e) {
if (!m_user_propagator)
throw default_exception("user propagator must be initialized");
Expand Down
Loading

0 comments on commit 2d5b749

Please sign in to comment.