From 0c93c7aa08ca189f38dde19f105e81362641bb2d Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Tue, 18 Aug 2020 10:30:10 -0700 Subject: [PATCH] adding user propagation to API Signed-off-by: Nikolaj Bjorner --- src/api/api_solver.cpp | 32 ++++++++++++++++++++++++++++++++ src/api/z3_api.h | 37 +++++++++++++++++++++++++++++++++++++ src/smt/smt_context.cpp | 2 +- src/smt/smt_context.h | 15 ++++++++++++++- src/smt/smt_kernel.cpp | 25 ++++++++++++++++++++----- src/smt/smt_kernel.h | 17 +++++++++++++++-- src/smt/smt_solver.cpp | 11 +++++++++-- src/solver/solver.h | 6 +++++- 8 files changed, 133 insertions(+), 12 deletions(-) diff --git a/src/api/api_solver.cpp b/src/api/api_solver.cpp index a2a0e87385c..4cd4462b9d8 100644 --- a/src/api/api_solver.cpp +++ b/src/api/api_solver.cpp @@ -886,5 +886,37 @@ extern "C" { Z3_CATCH_RETURN(nullptr); } + void Z3_API Z3_solver_propagate_init( + Z3_context c, + Z3_solver s, + void* user_context, + Z3_push_eh push_eh, + Z3_pop_eh pop_eh, + Z3_fixed_eh fixed_eh) { + Z3_TRY; + RESET_ERROR_CODE(); + init_solver(c, s); + std::function _push = push_eh; + std::function _pop = pop_eh; + std::function _fixed = [&](void* ctx, unsigned id, expr* e) { fixed_eh(ctx, id, of_ast(e)); }; + to_solver_ref(s)->user_propagate_init(user_context, _fixed, _push, _pop); + Z3_CATCH; + } + + unsigned Z3_API Z3_solver_propagate_register(Z3_context c, Z3_solver s, Z3_ast e) { + Z3_TRY; + LOG_Z3_solver_propagate_register(c, s, e); + RESET_ERROR_CODE(); + return to_solver_ref(s)->user_propagate_register(to_expr(e)); + Z3_CATCH_RETURN(0); + } + + void Z3_API Z3_solver_propagate_consequence(Z3_context c, Z3_solver s, unsigned sz, unsigned const* ids, Z3_ast conseq) { + Z3_TRY; + LOG_Z3_solver_propagate_consequence(c, s, sz, ids, conseq); + RESET_ERROR_CODE(); + to_solver_ref(s)->user_propagate_consequence(sz, ids, to_expr(conseq)); + Z3_CATCH; + } }; diff --git a/src/api/z3_api.h b/src/api/z3_api.h index 63c7e71e89c..b012ef66d5b 100644 --- a/src/api/z3_api.h +++ b/src/api/z3_api.h @@ -6514,6 +6514,43 @@ extern "C" { Z3_ast Z3_API Z3_solver_get_implied_upper(Z3_context c, Z3_solver s, Z3_ast e); + + /** + \brief register a user-properator with the solver. + */ + + typedef void Z3_push_eh(void* ctx); + typedef void Z3_pop_eh(void* ctx, unsigned num_scopes); + typedef void Z3_fixed_eh(void* ctx, unsigned id, Z3_ast value); + + void Z3_API Z3_solver_propagate_init( + Z3_context c, + Z3_solver s, + void* user_context, + Z3_push_eh push_eh, + Z3_pop_eh pop_eh, + Z3_fixed_eh fixed_eh); + + /** + \brief register an expression to propagate on with the solver. + Only expressions of type Bool and type Bit-Vector can be registered for propagation. + + def_API('Z3_solver_propagate_register', UINT, (_in(CONTEXT), _in(SOLVER), _in(AST))) + */ + + unsigned Z3_API Z3_solver_propagate_register(Z3_context c, Z3_solver s, Z3_ast e); + + /** + \brief propagate a consequence based on fixed values. + This is a callback a client may invoke during the fixed_eh callback. + The callback adds a propagation consequence based on the fixed values of the + \c ids. + + def_API('Z3_solver_propagate_consequence', VOID, (_in(CONTEXT), _in(SOLVER), _in(UINT), _in_array(2, UINT), _in(AST))) + */ + + void Z3_API Z3_solver_propagate_consequence(Z3_context c, Z3_solver, unsigned sz, unsigned const* ids, Z3_ast conseq); + /** \brief Check whether the assertions in a given solver are consistent or not. diff --git a/src/smt/smt_context.cpp b/src/smt/smt_context.cpp index b869be69e49..49c64502804 100644 --- a/src/smt/smt_context.cpp +++ b/src/smt/smt_context.cpp @@ -2878,7 +2878,7 @@ namespace smt { } } - void context::register_user_propagator( + void context::user_propagate_init( void* ctx, std::function& fixed_eh, std::function& push_eh, diff --git a/src/smt/smt_context.h b/src/smt/smt_context.h index e425e025b70..c62341de38d 100644 --- a/src/smt/smt_context.h +++ b/src/smt/smt_context.h @@ -1682,12 +1682,25 @@ namespace smt { /* * user-propagator */ - void register_user_propagator( + void user_propagate_init( void* ctx, std::function& fixed_eh, std::function& push_eh, std::function& pop_eh); + unsigned user_propagate_register(expr* e) { + if (!m_user_propagator) + throw default_exception("user propagator must be initialized"); + return m_user_propagator->add_expr(e); + } + + void user_propagate_consequence(unsigned sz, unsigned const* ids, expr* conseq) { + if (!m_user_propagator) + throw default_exception("user propagator must be initialized"); + m_user_propagator->add_propagation(sz, ids, conseq); + } + + bool watches_fixed(enode* n) const; void assign_fixed(enode* n, expr* val, unsigned sz, literal const* explain); diff --git a/src/smt/smt_kernel.cpp b/src/smt/smt_kernel.cpp index 4f134ecac89..40fe2cfd370 100644 --- a/src/smt/smt_kernel.cpp +++ b/src/smt/smt_kernel.cpp @@ -233,12 +233,20 @@ namespace smt { m_kernel.updt_params(p); } - void register_user_propagator( + void user_propagate_init( void* ctx, std::function& fixed_eh, std::function& push_eh, std::function& pop_eh) { - m_kernel.register_user_propagator(ctx, fixed_eh, push_eh, pop_eh); + m_kernel.user_propagate_init(ctx, fixed_eh, push_eh, pop_eh); + } + + unsigned user_propagate_register(expr* e) { + return m_kernel.user_propagate_register(e); + } + + void user_propagate_consequence(unsigned sz, unsigned const* ids, expr* conseq) { + m_kernel.user_propagate_consequence(sz, ids, conseq); } }; @@ -327,7 +335,6 @@ namespace smt { return m_imp->check(cube, clauses); } - lbool kernel::get_consequences(expr_ref_vector const& assumptions, expr_ref_vector const& vars, expr_ref_vector& conseq, expr_ref_vector& unfixed) { return m_imp->get_consequences(assumptions, vars, conseq, unfixed); } @@ -453,12 +460,20 @@ namespace smt { return m_imp->get_implied_upper_bound(e); } - void kernel::register_user_propagator( + void kernel::user_propagate_init( void* ctx, std::function& fixed_eh, std::function& push_eh, std::function& pop_eh) { - m_imp->register_user_propagator(ctx, fixed_eh, push_eh, pop_eh); + m_imp->user_propagate_init(ctx, fixed_eh, push_eh, pop_eh); + } + + unsigned kernel::user_propagate_register(expr* e) { + return m_imp->user_propagate_register(e); + } + + void kernel::user_propagate_consequence(unsigned sz, unsigned const* ids, expr* conseq) { + m_imp->user_propagate_consequence(sz, ids, conseq); } diff --git a/src/smt/smt_kernel.h b/src/smt/smt_kernel.h index c14ba66bf16..7fdd14fc14f 100644 --- a/src/smt/smt_kernel.h +++ b/src/smt/smt_kernel.h @@ -285,14 +285,27 @@ namespace smt { static void collect_param_descrs(param_descrs & d); /** - \brief register a user-propagator "theory" + \brief initialize a user-propagator "theory" */ - void register_user_propagator( + void user_propagate_init( void* ctx, std::function& fixed_eh, std::function& push_eh, std::function& pop_eh); + /** + \brief register an expression to be tracked fro user propagation. + */ + unsigned user_propagate_register(expr* e); + + + /** + \brief accept a user-propagation callback (issued during fixed_he). + */ + + void user_propagate_consequence(unsigned sz, unsigned const* ids, expr* conseq); + + /** \brief Return a reference to smt::context. This is a temporary hack to support user theories. diff --git a/src/smt/smt_solver.cpp b/src/smt/smt_solver.cpp index 62e7947d725..9868840e674 100644 --- a/src/smt/smt_solver.cpp +++ b/src/smt/smt_solver.cpp @@ -208,14 +208,21 @@ namespace { return m_context.get_trail(); } - void register_user_propagator( + void user_propagate_init( void* ctx, std::function& fixed_eh, std::function& push_eh, std::function& pop_eh) override { - m_context.register_user_propagator(ctx, fixed_eh, push_eh, pop_eh); + m_context.user_propagate_init(ctx, fixed_eh, push_eh, pop_eh); } + unsigned user_propagate_register(expr* e) override { + return m_context.user_propagate_register(e); + } + + void user_propagate_consequence(unsigned sz, unsigned const* ids, expr* conseq) { + m_context.user_propagate_consequence(sz, ids, conseq); + } struct scoped_minimize_core { smt_solver& s; diff --git a/src/solver/solver.h b/src/solver/solver.h index 6adefc6e9f3..659bf249b32 100644 --- a/src/solver/solver.h +++ b/src/solver/solver.h @@ -238,7 +238,7 @@ class solver : public check_sat_result { virtual expr_ref get_implied_upper_bound(expr* e) = 0; - virtual void register_user_propagator( + virtual void user_propagate_init( void* ctx, std::function& fixed_eh, std::function& push_eh, @@ -246,6 +246,10 @@ class solver : public check_sat_result { throw default_exception("user-propagators are only supported on the SMT solver"); } + virtual unsigned user_propagate_register(expr* e) { return 0; } + + virtual void user_propagate_consequence(unsigned sz, unsigned const* ids, expr* conseq) {} + /** \brief Display the content of this solver.