Skip to content

Commit

Permalink
add accessors for implied values to API
Browse files Browse the repository at this point in the history
Signed-off-by: Nikolaj Bjorner <[email protected]>
  • Loading branch information
NikolajBjorner committed Jul 29, 2020
1 parent 4628cb8 commit 59d8895
Show file tree
Hide file tree
Showing 23 changed files with 347 additions and 2 deletions.
33 changes: 33 additions & 0 deletions src/api/api_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,39 @@ extern "C" {
Z3_CATCH_RETURN(nullptr);
}

Z3_ast Z3_API Z3_solver_get_implied_value(Z3_context c, Z3_solver s, Z3_ast e) {
Z3_TRY;
LOG_Z3_solver_get_implied_value(c, s, e);
RESET_ERROR_CODE();
init_solver(c, s);
expr_ref v = to_solver_ref(s)->get_implied_value(to_expr(e));
mk_c(c)->save_ast_trail(v);
RETURN_Z3(of_ast(v));
Z3_CATCH_RETURN(nullptr);
}

Z3_ast Z3_API Z3_solver_get_implied_lower(Z3_context c, Z3_solver s, Z3_ast e) {
Z3_TRY;
LOG_Z3_solver_get_implied_lower(c, s, e);
RESET_ERROR_CODE();
init_solver(c, s);
expr_ref v = to_solver_ref(s)->get_implied_lower_bound(to_expr(e));
mk_c(c)->save_ast_trail(v);
RETURN_Z3(of_ast(v));
Z3_CATCH_RETURN(nullptr);
}

Z3_ast Z3_API Z3_solver_get_implied_upper(Z3_context c, Z3_solver s, Z3_ast e) {
Z3_TRY;
LOG_Z3_solver_get_implied_upper(c, s, e);
RESET_ERROR_CODE();
init_solver(c, s);
expr_ref v = to_solver_ref(s)->get_implied_upper_bound(to_expr(e));
mk_c(c)->save_ast_trail(v);
RETURN_Z3(of_ast(v));
Z3_CATCH_RETURN(nullptr);
}

static Z3_lbool _solver_check(Z3_context c, Z3_solver s, unsigned num_assumptions, Z3_ast const assumptions[]) {
for (unsigned i = 0; i < num_assumptions; i++) {
if (!is_expr(to_ast(assumptions[i]))) {
Expand Down
17 changes: 15 additions & 2 deletions src/api/c++/z3++.h
Original file line number Diff line number Diff line change
Expand Up @@ -2379,12 +2379,25 @@ namespace z3 {
}
void add(expr const & e, char const * p) {
add(e, ctx().bool_const(p));
}
void add(expr_vector const& v) {
check_context(*this, v);
for (unsigned i = 0; i < v.size(); ++i)
add(v[i]);
}
// fails for some compilers:
// void add(expr_vector const& v) { check_context(*this, v); for (expr e : v) add(e); }
void from_file(char const* file) { Z3_solver_from_file(ctx(), m_solver, file); ctx().check_parser_error(); }
void from_string(char const* s) { Z3_solver_from_string(ctx(), m_solver, s); ctx().check_parser_error(); }

expr lower(expr const& e) {
Z3_ast r = Z3_solver_get_implied_lower(ctx(), m_solver, e); check_error(); return expr(ctx(), r);
}
expr upper(expr const& e) {
Z3_ast r = Z3_solver_get_implied_upper(ctx(), m_solver, e); check_error(); return expr(ctx(), r);
}
expr value(expr const& e) {
Z3_ast r = Z3_solver_get_implied_value(ctx(), m_solver, e); check_error(); return expr(ctx(), r);
}

check_result check() { Z3_lbool r = Z3_solver_check(ctx(), m_solver); check_error(); return to_check_result(r); }
check_result check(unsigned n, expr * const assumptions) {
array<Z3_ast> _assumptions(n);
Expand Down
15 changes: 15 additions & 0 deletions src/api/python/z3/z3.py
Original file line number Diff line number Diff line change
Expand Up @@ -6857,6 +6857,21 @@ def trail(self):
"""
return AstVector(Z3_solver_get_trail(self.ctx.ref(), self.solver), self.ctx)

def value(self, e):
"""Return value of term in solver, if any is given.
"""
return _to_expr_ref(Z3_solver_get_implied_value(self.ctx.ref(), self.solver, e.as_ast()), self.ctx)

def lower(self, e):
"""Return lower bound known to solver based on the last call.
"""
return _to_expr_ref(Z3_solver_get_implied_lower(self.ctx.ref(), self.solver, e.as_ast()), self.ctx)

def upper(self, e):
"""Return upper bound known to solver based on the last call.
"""
return _to_expr_ref(Z3_solver_get_implied_upper(self.ctx.ref(), self.solver, e.as_ast()), self.ctx)

def statistics(self):
"""Return statistics for the last `check()`.
Expand Down
28 changes: 28 additions & 0 deletions src/api/z3_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -6486,6 +6486,34 @@ extern "C" {
*/
void Z3_API Z3_solver_get_levels(Z3_context c, Z3_solver s, Z3_ast_vector literals, unsigned sz, unsigned levels[]);

/**
\brief retrieve implied value for expression, if any is implied by solver at search level.
The method works for expressions that are known to the solver state, such as Boolean and
arithmetical variables.
def_API('Z3_solver_get_implied_value', AST, (_in(CONTEXT), _in(SOLVER), _in(AST)))
*/
Z3_ast Z3_API Z3_solver_get_implied_value(Z3_context c, Z3_solver s, Z3_ast e);

/**
\brief retrieve implied lower bound value for arithmetic expression.
If a lower bound is implied at search level, the arithmetic expression returned
is a constant representing the bound.
def_API('Z3_solver_get_implied_lower', AST, (_in(CONTEXT), _in(SOLVER), _in(AST)))
*/
Z3_ast Z3_API Z3_solver_get_implied_lower(Z3_context c, Z3_solver s, Z3_ast e);

/**
\brief retrieve implied upper bound value for arithmetic expression.
If an upper bound is implied at search level, the arithmetic expression returned
is a constant representing the bound.
def_API('Z3_solver_get_implied_upper', AST, (_in(CONTEXT), _in(SOLVER), _in(AST)))
*/

Z3_ast Z3_API Z3_solver_get_implied_upper(Z3_context c, Z3_solver s, Z3_ast e);

/**
\brief Check whether the assertions in a given solver are consistent or not.
Expand Down
3 changes: 3 additions & 0 deletions src/muz/spacer/spacer_iuc_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ class iuc_solver : public solver {
expr_ref_vector cube(expr_ref_vector&, unsigned) override { return expr_ref_vector(m); }
void get_levels(ptr_vector<expr> const& vars, unsigned_vector& depth) override { m_solver.get_levels(vars, depth); }
expr_ref_vector get_trail() override { return m_solver.get_trail(); }
expr_ref get_implied_value(expr* e) override { return m_solver.get_implied_value(e); }
expr_ref get_implied_lower_bound(expr* e) override { return m_solver.get_implied_lower_bound(e); }
expr_ref get_implied_upper_bound(expr* e) override { return m_solver.get_implied_upper_bound(e); }

void push() override;
void pop(unsigned n) override;
Expand Down
3 changes: 3 additions & 0 deletions src/opt/opt_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ namespace opt {
void get_levels(ptr_vector<expr> const& vars, unsigned_vector& depth) override;
expr_ref_vector get_trail() override { return m_context.get_trail(); }
expr_ref_vector cube(expr_ref_vector&, unsigned) override { return expr_ref_vector(m); }
expr_ref get_implied_value(expr* e) override { return m_context.get_implied_value(e); }
expr_ref get_implied_lower_bound(expr* e) override { return m_context.get_implied_lower_bound(e); }
expr_ref get_implied_upper_bound(expr* e) override { return m_context.get_implied_upper_bound(e); }

void set_logic(symbol const& logic);

Expand Down
13 changes: 13 additions & 0 deletions src/sat/sat_solver/inc_sat_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,19 @@ class inc_sat_solver : public solver {
return nullptr;
}

// TODO
expr_ref get_implied_value(expr* e) override {
return expr_ref(e, m);
}

expr_ref get_implied_lower_bound(expr* e) override {
return expr_ref(e, m);
}

expr_ref get_implied_upper_bound(expr* e) override {
return expr_ref(e, m);
}

expr_ref_vector last_cube(bool is_sat) {
expr_ref_vector result(m);
result.push_back(is_sat ? m.mk_true() : m.mk_false());
Expand Down
28 changes: 28 additions & 0 deletions src/smt/smt_arith_value.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,37 @@ namespace smt {
return false;
}

expr_ref arith_value::get_lo(expr* e) const {
rational lo;
bool s = false;
if (a.is_int_real(e) && get_lo(e, lo, s) && !s) {
return expr_ref(a.mk_numeral(lo, m.get_sort(e)), m);
}
return expr_ref(e, m);
}

expr_ref arith_value::get_up(expr* e) const {
rational up;
bool s = false;
if (a.is_int_real(e) && get_up(e, up, s) && !s) {
return expr_ref(a.mk_numeral(up, m.get_sort(e)), m);
}
return expr_ref(e, m);
}

expr_ref arith_value::get_fixed(expr* e) const {
rational lo, up;
bool s = false;
if (a.is_int_real(e) && get_lo(e, lo, s) && !s && get_up(e, up, s) && !s && lo == up) {
return expr_ref(a.mk_numeral(lo, m.get_sort(e)), m);
}
return expr_ref(e, m);
}

final_check_status arith_value::final_check() {
family_id afid = a.get_family_id();
theory * th = m_ctx->get_theory(afid);
return th->final_check_eh();
}

};
3 changes: 3 additions & 0 deletions src/smt/smt_arith_value.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ namespace smt {
bool get_lo(expr* e, rational& lo, bool& strict) const;
bool get_up(expr* e, rational& up, bool& strict) const;
bool get_value(expr* e, rational& value) const;
expr_ref get_lo(expr* e) const;
expr_ref get_up(expr* e) const;
expr_ref get_fixed(expr* e) const;
final_check_status final_check();
};
};
43 changes: 43 additions & 0 deletions src/smt/smt_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ Revision History:
#include "smt/smt_model_checker.h"
#include "smt/smt_model_finder.h"
#include "smt/smt_parallel.h"
#include "smt/smt_arith_value.h"

namespace smt {

Expand Down Expand Up @@ -4555,6 +4556,48 @@ namespace smt {
TRACE("model", tout << *m_model << "\n";);
}

expr_ref context::get_implied_value(expr* e) {
pop_to_search_lvl();
if (m.is_bool(e)) {
if (b_internalized(e)) {
bool_var v = get_bool_var(e);
switch (get_assignment(get_bool_var(e))) {
case l_true: e = m.mk_true(); break;
case l_false: e = m.mk_false(); break;
default: break;
}
}
return expr_ref(e, m);
}

if (e_internalized(e)) {
enode* n = get_enode(e);
for (enode* r : *n) {
if (m.is_value(r->get_owner())) {
return expr_ref(r->get_owner(), m);
}
}
}

arith_value av(m);
av.init(this);
return av.get_fixed(e);
}

expr_ref context::get_implied_lower_bound(expr* e) {
pop_to_search_lvl();
arith_value av(m);
av.init(this);
return av.get_lo(e);
}

expr_ref context::get_implied_upper_bound(expr* e) {
pop_to_search_lvl();
arith_value av(m);
av.init(this);
return av.get_up(e);
}

};


Expand Down
7 changes: 7 additions & 0 deletions src/smt/smt_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,13 @@ namespace smt {
return get_bdata(v).get_theory();
}

expr_ref get_implied_value(expr* e);

expr_ref get_implied_lower_bound(expr* e);

expr_ref get_implied_upper_bound(expr* e);


friend class set_var_theory_trail;
void set_var_theory(bool_var v, theory_id tid);

Expand Down
25 changes: 25 additions & 0 deletions src/smt/smt_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,18 @@ namespace smt {
lbool find_mutexes(expr_ref_vector const& vars, vector<expr_ref_vector>& mutexes) {
return m_kernel.find_mutexes(vars, mutexes);
}

expr_ref get_implied_value(expr* e) {
return m_kernel.get_implied_value(e);
}

expr_ref get_implied_lower_bound(expr* e) {
return m_kernel.get_implied_lower_bound(e);
}

expr_ref get_implied_upper_bound(expr* e) {
return m_kernel.get_implied_upper_bound(e);
}

void get_model(model_ref & m) {
m_kernel.get_model(m);
Expand Down Expand Up @@ -412,5 +424,18 @@ namespace smt {
return m_imp->get_trail();
}

expr_ref kernel::get_implied_value(expr* e) {
return m_imp->get_implied_value(e);
}

expr_ref kernel::get_implied_lower_bound(expr* e) {
return m_imp->get_implied_lower_bound(e);
}

expr_ref kernel::get_implied_upper_bound(expr* e) {
return m_imp->get_implied_upper_bound(e);
}



};
11 changes: 11 additions & 0 deletions src/smt/smt_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,17 @@ namespace smt {
*/
expr_ref next_cube();

/**
\brief retrieve upper/lower bound for arithmetic term, if it is implied.
retrieve implied values if terms are fixed to a value.
*/

expr_ref get_implied_value(expr* e);

expr_ref get_implied_lower_bound(expr* e);

expr_ref get_implied_upper_bound(expr* e);

/**
\brief retrieve depth of variables from decision stack.
*/
Expand Down
12 changes: 12 additions & 0 deletions src/smt/smt_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,18 @@ namespace {
}
}

expr_ref get_implied_value(expr* e) override {
return m_context.get_implied_value(e);
}

expr_ref get_implied_lower_bound(expr* e) override {
return m_context.get_implied_lower_bound(e);
}

expr_ref get_implied_upper_bound(expr* e) override {
return m_context.get_implied_upper_bound(e);
}

bool fds_intersect(func_decl_set & pattern_fds, func_decl_set & assrtn_fds) {
for (func_decl * fd : pattern_fds) {
if (assrtn_fds.contains(fd))
Expand Down
1 change: 1 addition & 0 deletions src/smt/theory_lra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1724,6 +1724,7 @@ class theory_lra::imp {
is_sat = make_feasible();
}
final_check_status st = FC_DONE;

switch (is_sat) {
case l_true:
TRACE("arith", /*display(tout);*/
Expand Down
22 changes: 22 additions & 0 deletions src/solver/combined_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,28 @@ class combined_solver : public solver {
return m_solver1->get_scope_level();
}

expr_ref get_implied_value(expr* e) override {
if (m_use_solver1_results)
return m_solver1->get_implied_value(e);
else
return m_solver2->get_implied_value(e);
}

expr_ref get_implied_lower_bound(expr* e) override {
if (m_use_solver1_results)
return m_solver1->get_implied_lower_bound(e);
else
return m_solver2->get_implied_lower_bound(e);
}

expr_ref get_implied_upper_bound(expr* e) override {
if (m_use_solver1_results)
return m_solver1->get_implied_upper_bound(e);
else
return m_solver2->get_implied_upper_bound(e);
}


lbool get_consequences(expr_ref_vector const& asms, expr_ref_vector const& vars, expr_ref_vector& consequences) override {
switch_inc_mode();
m_use_solver1_results = false;
Expand Down
Loading

0 comments on commit 59d8895

Please sign in to comment.