diff --git a/ir/function.h b/ir/function.h index bb04c65c1..faa33dff6 100644 --- a/ir/function.h +++ b/ir/function.h @@ -119,6 +119,7 @@ class Function final { auto& getFnAttrs() { return attrs; } auto& getFnAttrs() const { return attrs; } + bool has(FnAttrs::Attribute a) const { return attrs.has(a); } smt::expr getTypeConstraints() const; void fixupTypes(const smt::Model &m); diff --git a/ir/instr.cpp b/ir/instr.cpp index 5376f1aee..c7edb6096 100644 --- a/ir/instr.cpp +++ b/ir/instr.cpp @@ -2403,7 +2403,7 @@ StateValue FnCall::toSMT(State &s) const { fnName_mangled << '!' << getType(); // Callee must return if caller must return - if (s.getFn().getFnAttrs().has(FnAttrs::WillReturn) && + if (s.getFn().has(FnAttrs::WillReturn) && !attrs.has(FnAttrs::WillReturn)) s.addGuardableUB(expr(false)); @@ -2435,7 +2435,7 @@ StateValue FnCall::toSMT(State &s) const { check_can_store(s, allocptr); Pointer ptr_old(m, allocptr); - if (s.getFn().getFnAttrs().has(FnAttrs::NoFree)) + if (s.getFn().has(FnAttrs::NoFree)) s.addGuardableUB(ptr_old.isNull() || ptr_old.isLocal()); m.copy(ptr_old, Pointer(m, p_new)); @@ -2466,7 +2466,7 @@ StateValue FnCall::toSMT(State &s) const { if (!hasAttribute(FnAttrs::NoFree)) { m.free(allocptr, false); - if (s.getFn().getFnAttrs().has(FnAttrs::NoFree)) { + if (s.getFn().has(FnAttrs::NoFree)) { Pointer ptr(m, allocptr); s.addGuardableUB(ptr.isNull() || ptr.isLocal()); } diff --git a/ir/memory.cpp b/ir/memory.cpp index 38585aa93..048eba136 100644 --- a/ir/memory.cpp +++ b/ir/memory.cpp @@ -2048,9 +2048,26 @@ expr Memory::ptr2int(const expr &ptr) const { return p.getAddress(); } +Pointer Memory::searchPointer(const expr &val0) const { + DisjointExpr ret; + expr val = val0.zextOrTrunc(bits_program_pointer); + + auto add = [&](unsigned limit, bool local) { + for (unsigned i = 0; i != limit; ++i) { + Pointer p(*this, i, local); + Pointer p_end = p + p.blockSize(); + ret.add((p + (val - p.getAddress())).release(), + val.uge(p.getAddress()) && val.ult(p_end.getAddress())); + } + }; + add(numLocals(), true); + add(numNonlocals(), false); + return Pointer(*this, *std::move(ret)()); +} + expr Memory::int2ptr(const expr &val0) const { assert(!memory_unused() && observesAddresses()); - if (state->getFn().getFnAttrs().has(FnAttrs::Asm)) { + if (state->getFn().has(FnAttrs::Asm)) { DisjointExpr ret(Pointer::mkNullPointer(*this).release()); expr val = val0; OrExpr domain; @@ -2100,21 +2117,7 @@ expr Memory::int2ptr(const expr &val0) const { if (processed_all) return std::move(ret)()->simplify(); - val = val.simplify(); - - expr valx = val.zextOrTrunc(bits_program_pointer); - - auto add = [&](unsigned limit, bool local) { - for (unsigned i = 0; i != limit; ++i) { - Pointer p(*this, i, local); - Pointer p_end = p + p.blockSize(); - ret.add((p + (valx - p.getAddress())).release(), - valx.uge(p.getAddress()) && valx.ult(p_end.getAddress())); - } - }; - add(numLocals(), true); - add(numNonlocals(), false); - return *std::move(ret)(); + return searchPointer(val.simplify()).release(); } expr null = Pointer::mkNullPointer(*this).release(); @@ -2379,6 +2382,7 @@ void Memory::print(ostream &os, const Model &m) const { P("size", p.blockSize()); P("align", expr::mkInt(1, 64) << p.blockAlignment().zextOrTrunc(64)); P("alloc type", p.getAllocType()); + P("alive", p.isBlockAlive()); if (observesAddresses()) P("address", p.getAddress()); if (!local && is_constglb(bid)) diff --git a/ir/memory.h b/ir/memory.h index c97aed256..0d0010f0b 100644 --- a/ir/memory.h +++ b/ir/memory.h @@ -330,6 +330,7 @@ class Memory { smt::expr ptr2int(const smt::expr &ptr) const; smt::expr int2ptr(const smt::expr &val) const; + Pointer searchPointer(const smt::expr &val) const; std::tuple> refined(const Memory &other, bool fncall, diff --git a/ir/pointer.cpp b/ir/pointer.cpp index 81542a113..844cb5672 100644 --- a/ir/pointer.cpp +++ b/ir/pointer.cpp @@ -2,6 +2,7 @@ // Distributed under the MIT license that can be found in the LICENSE file. #include "ir/pointer.h" +#include "ir/function.h" #include "ir/memory.h" #include "ir/globals.h" #include "ir/state.h" @@ -490,7 +491,7 @@ void Pointer::isDisjointOrEqual(const expr &len1, const Pointer &ptr2, expr Pointer::isBlockAlive() const { // NULL block is dead - if (has_null_block && !null_is_dereferenceable & getBid().isZero()) + if (has_null_block && !null_is_dereferenceable && getBid().isZero()) return false; auto bid = getShortBid(); @@ -518,6 +519,8 @@ expr Pointer::isHeapAllocated() const { } expr Pointer::refined(const Pointer &other) const { + bool is_asm = other.m.state->getFn().has(FnAttrs::Asm); + // This refers to a block that was malloc'ed within the function expr local = other.isLocal(); local &= getAllocType() == other.getAllocType(); @@ -528,13 +531,19 @@ expr Pointer::refined(const Pointer &other) const { // TODO: this induces an infinite loop //local &= block_refined(other); + expr nonlocal = is_asm ? getAddress() == other.getAddress() : *this == other; + + Pointer other_deref + = is_asm ? other.m.searchPointer(other.getAddress()) : other; + return expr::mkIf(isNull(), other.isNull(), - expr::mkIf(isLocal(), std::move(local), *this == other) && - isBlockAlive().implies(other.isBlockAlive())); + expr::mkIf(isLocal(), std::move(local), nonlocal) && + isBlockAlive().implies(other_deref.isBlockAlive())); } expr Pointer::fninputRefined(const Pointer &other, set &undef, const expr &byval_bytes) const { + bool is_asm = other.m.state->getFn().has(FnAttrs::Asm); expr size = blockSizeOffsetT(); expr off = getOffsetSizet(); expr size2 = other.blockSizeOffsetT(); @@ -563,9 +572,14 @@ expr Pointer::fninputRefined(const Pointer &other, set &undef, // TODO: this induces an infinite loop // block_refined(other); + expr nonlocal = is_asm ? getAddress() == other.getAddress() : *this == other; + + Pointer other_deref + = is_asm ? other.m.searchPointer(other.getAddress()) : other; + return expr::mkIf(isNull(), other.isNull(), - expr::mkIf(isLocal(), local, *this == other) && - isBlockAlive().implies(other.isBlockAlive())); + expr::mkIf(isLocal(), local, nonlocal) && + isBlockAlive().implies(other_deref.isBlockAlive())); } expr Pointer::isWritable() const { diff --git a/tests/alive-tv/asm/ptr-refinement.srctgt.ll b/tests/alive-tv/asm/ptr-refinement.srctgt.ll new file mode 100644 index 000000000..4b40b3be4 --- /dev/null +++ b/tests/alive-tv/asm/ptr-refinement.srctgt.ll @@ -0,0 +1,14 @@ +; TEST-ARGS: -tgt-is-asm +; SKIP-IDENTITY + +define ptr @src(ptr %0, ptr %1) { + %3 = ptrtoint ptr %0 to i64 + %4 = ptrtoint ptr %1 to i64 + %5 = sub i64 %4, %3 + %6 = getelementptr i8, ptr %0, i64 %5 + ret ptr %6 +} + +define ptr @tgt(ptr %0, ptr %1) { + ret ptr %1 +} diff --git a/tools/transform.cpp b/tools/transform.cpp index 5460e4fd4..a4b8a7d9f 100644 --- a/tools/transform.cpp +++ b/tools/transform.cpp @@ -99,7 +99,7 @@ void tools::print_model_val(ostream &os, const State &st, const Model &m, using print_var_val_ty = function; -static bool error(Errors &errs, const State &src_state, const State &tgt_state, +static bool error(Errors &errs, State &src_state, State &tgt_state, const Result &r, Solver &solver, const Value *var, const char *msg, bool check_each_var, print_var_val_ty print_var_val) { @@ -185,7 +185,7 @@ static bool error(Errors &errs, const State &src_state, const State &tgt_state, // this *may* be a pointer if (bw == Pointer::totalBits()) { - Pointer p(src_state.getMemory(), var); + Pointer p(src_state.returnMemory(), var); reduce(p.getOffset()); } } @@ -304,7 +304,7 @@ static bool error(Errors &errs, const State &src_state, const State &tgt_state, s << '\n'; } - st->getMemory().print(s, m); + st->returnMemory().print(s, m); } print_var_val(s, m);