diff --git a/ir/memory.cpp b/ir/memory.cpp index 1300e4586..74674ebc2 100644 --- a/ir/memory.cpp +++ b/ir/memory.cpp @@ -1285,7 +1285,8 @@ void Memory::storeLambda(const Pointer &ptr, const expr &offset, if (bytes.eq(ptr.blockSizeAligned())) { blk.val = val_no_offset ? mk_block_if(cond, val, std::move(blk.val)) - : expr::mkLambda(offset, mk_block_if(cond, val, std::move(orig_val))); + : expr::mkLambda(offset, "#offset", + mk_block_if(cond, val, std::move(orig_val))); if (cond.isTrue()) { blk.undef.clear(); @@ -1297,7 +1298,7 @@ void Memory::storeLambda(const Pointer &ptr, const expr &offset, offset.ult((ptr + bytes).getShortOffset()); blk.val - = expr::mkLambda(offset, + = expr::mkLambda(offset, "#offset", mk_block_if(cond && offset_cond, val, std::move(orig_val))); } @@ -1362,7 +1363,7 @@ void Memory::record_store(const Pointer &p, const smt::expr &bytes) { has_stored_arg.store(bid.concat(offset + off_expr), true)); } } else { - expr var = expr::mkFreshVar("#bid_off", bid.concat(offset)); + expr var = expr::mkQVar(0, bid.concat(offset)); expr var_bid = var.extract(var.bits()-1, offset.bits()); expr var_off = var.trunc(offset.bits()); @@ -1371,7 +1372,7 @@ void Memory::record_store(const Pointer &p, const smt::expr &bytes) { has_stored_arg = expr::mkIf(is_local, has_stored_arg, - expr::mkLambda(var, + expr::mkLambda(var, "#bid_off", (bid == var_bid && var_off.uge(offset) && var_off.ult(offset + bytes_div)) || @@ -1402,13 +1403,14 @@ static expr mk_liveness_array() { } void Memory::mkNonPoisonAxioms(bool local) const { - expr offset = expr::mkVar("#axoff", Pointer::bitsShortOffset()); + expr offset = expr::mkQVar(0, Pointer::bitsShortOffset()); + const char *name = "#axoff"; unsigned bid = 0; for (auto &block : local ? local_block_val : non_local_block_val) { if (isInitialMemBlock(block.val, config::disallow_ub_exploitation)) state->addAxiom( - expr::mkForAll({ offset }, + expr::mkForAll(1, &offset, &name, !raw_load(local, bid, offset).isPoison())); ++bid; } @@ -1449,7 +1451,8 @@ void Memory::mkNonlocalValAxioms(bool skip_consts) const { if (!does_ptr_mem_access && !(num_sub_byte_bits && isAsmMode())) return; - expr offset = expr::mkVar("#axoff", Pointer::bitsShortOffset()); + expr offset = expr::mkQVar(0, Pointer::bitsShortOffset()); + const char *name = "#axoff"; for (unsigned i = 0, e = numNonlocals(); i != e; ++i) { if (always_noread(i, true)) @@ -1461,7 +1464,7 @@ void Memory::mkNonlocalValAxioms(bool skip_consts) const { // per the ABI. if (num_sub_byte_bits && isAsmMode()) { state->addAxiom( - expr::mkForAll({ offset }, mkSubByteZExtStoreCond(byte, byte))); + expr::mkForAll(1, &offset, &name, mkSubByteZExtStoreCond(byte, byte))); } if (!does_ptr_mem_access) @@ -1489,7 +1492,7 @@ void Memory::mkNonlocalValAxioms(bool skip_consts) const { bid_cond &= bid.ule(upperbid); state->addAxiom( - expr::mkForAll({ offset }, + expr::mkForAll(1, &offset, &name, byte.isPtr().implies(!loadedptr.isLocal(false) && !loadedptr.isNocapture(false) && std::move(bid_cond)))); @@ -1679,13 +1682,15 @@ void Memory::mkAxioms(const Memory &tgt) const { // tame down quadratic explosion in disjointness constraint with a quantifier. if (num_nonlocals > max_quadratic_disjoint) { auto bid_ty = expr::mkUInt(0, Pointer::bitsShortBid()); - expr bid1 = expr::mkFreshVar("#bid1", bid_ty); - expr bid2 = expr::mkFreshVar("#bid2", bid_ty); + expr bid1 = expr::mkQVar(0, bid_ty); + expr bid2 = expr::mkQVar(1, bid_ty); expr offset = expr::mkUInt(0, bits_for_offset); Pointer p1(tgt, Pointer::mkLongBid(bid1, false), offset); Pointer p2(tgt, Pointer::mkLongBid(bid2, false), offset); + expr vars[] = {bid1, bid2}; + const char *names[] = {"#bid1", "#bid2"}; state->addAxiom( - expr::mkForAll({bid1, bid2}, + expr::mkForAll(2, vars, names, bid1 == bid2 || disjoint(p1.getAddress(), p1.blockSizeAligned().zextOrTrunc(bits_ptr_address), @@ -2418,8 +2423,7 @@ void Memory::memset(const expr &p, const StateValue &val, const expr &bytesize, } store(ptr, to_store, undef_vars, align); } else { - expr offset - = expr::mkFreshVar("#off", expr::mkUInt(0, Pointer::bitsShortOffset())); + expr offset = expr::mkQVar(0, Pointer::bitsShortOffset()); storeLambda(ptr, offset, bytesize, {{0, raw_byte}}, undef_vars, align); } } @@ -2451,8 +2455,7 @@ void Memory::memset_pattern(const expr &ptr0, const expr &pattern0, for (unsigned i = 0; i < pattern_length; i += bytesz) { to_store.emplace_back(i * bytesz, std::move(bytes[i/bytesz])()); } - expr offset - = expr::mkFreshVar("#off", expr::mkUInt(0, Pointer::bitsShortOffset())); + expr offset = expr::mkQVar(0, Pointer::bitsShortOffset()); storeLambda(ptr, offset, bytesize, to_store, undef_vars, 1); } } @@ -2482,8 +2485,7 @@ void Memory::memcpy(const expr &d, const expr &s, const expr &bytesize, } store(dst, to_store, undef, align_dst); } else { - expr offset - = expr::mkFreshVar("#off", expr::mkUInt(0, Pointer::bitsShortOffset())); + expr offset = expr::mkQVar(0, Pointer::bitsShortOffset()); Pointer ptr_src = src + (offset - dst.getShortOffset()); set undef; storeLambda(dst, offset, bytesize, {{0, raw_load(ptr_src, undef)()}}, undef, diff --git a/smt/expr.cpp b/smt/expr.cpp index a2da79b3d..503570b3e 100644 --- a/smt/expr.cpp +++ b/smt/expr.cpp @@ -294,6 +294,15 @@ bool expr::isTernaryOp(expr &a, expr &b, expr &c, int z3op) const { return false; } +expr expr::mkQVar(unsigned n, const expr &type) { + C2(type); + return Z3_mk_bound(ctx(), n, type.sort()); +} + +expr expr::mkQVar(unsigned n, unsigned bits) { + return Z3_mk_bound(ctx(), n, mkBVSort(bits)); +} + expr expr::mkVar(const char *name, const expr &type) { C2(type); return ::mkVar(name, type.sort()); @@ -559,6 +568,7 @@ bool expr::isLambda(expr &body) const { } return false; } + expr expr::lambdaIdxType() const { C(); assert(Z3_get_quantifier_num_bound(ctx(), ast()) == 1); @@ -2142,7 +2152,23 @@ expr expr::mkForAll(const set &vars, expr &&val) { val()); } -expr expr::mkLambda(const expr &var, const expr &val) { +expr expr::mkForAll(unsigned num_vars, const expr *vars, const char **names, + expr &&val) { + if (num_vars == 0 || val.isConst() || !val.isValid()) + return std::move(val); + + const unsigned max_vars = 4; + ENSURE(num_vars <= max_vars); + Z3_sort sorts[max_vars]; + Z3_symbol syms[max_vars]; + for (unsigned i = 0; i < num_vars; ++i) { + sorts[i] = vars[i].sort(); + syms[i] = Z3_mk_string_symbol(ctx(), names[i]); + } + return Z3_mk_forall(ctx(), 0, 0, nullptr, num_vars, sorts, syms, val()); +} + +expr expr::mkLambda(const expr &var, const char *var_name, const expr &val) { C2(var, val); if (!val.vars().count(var)) @@ -2152,8 +2178,9 @@ expr expr::mkLambda(const expr &var, const expr &val) { if (val.isLoad(array, idx) && idx.eq(var)) return array; - auto ast = (Z3_app)var(); - return Z3_mk_lambda_const(ctx(), 1, &ast, val()); + auto sort = var.sort(); + auto name = Z3_mk_string_symbol(ctx(), var_name); + return Z3_mk_lambda(ctx(), 1, &sort, &name, val()); } expr expr::simplify() const { @@ -2243,6 +2270,9 @@ set expr::vars(const vector &exprs) { switch (Z3_get_ast_kind(ctx(), ast)) { case Z3_VAR_AST: + result.emplace(expr(ast)); + break; + case Z3_NUMERAL_AST: break; diff --git a/smt/expr.h b/smt/expr.h index 427f40fd0..078b227a6 100644 --- a/smt/expr.h +++ b/smt/expr.h @@ -87,6 +87,8 @@ class expr { static expr mkQuad(double n); static expr mkNaN(const expr &type); static expr mkNumber(const char *n, const expr &type); + static expr mkQVar(unsigned n, const expr &type); + static expr mkQVar(unsigned n, unsigned bits); static expr mkVar(const char *name, const expr &type); static expr mkVar(const char *name, unsigned bits, bool fresh = false); static expr mkBoolVar(const char *name); @@ -350,7 +352,9 @@ class expr { static expr mkIf(const expr &cond, const expr &then, const expr &els); static expr mkForAll(const std::set &vars, expr &&val); - static expr mkLambda(const expr &var, const expr &val); + static expr mkForAll(unsigned num_vars, const expr *vars, const char **names, + expr &&val); + static expr mkLambda(const expr &var, const char *var_name, const expr &val); expr simplify() const; expr simplifyNoTimeout() const;