From 4bbdd96fe535f97844eb6649aa983bd2d5a9928d Mon Sep 17 00:00:00 2001 From: Nuno Lopes Date: Sat, 28 Dec 2024 10:03:34 +0000 Subject: [PATCH] try to copy UB conditions from src to tgt this is in preparation for subsequent work to take advantage of implied conditions from UB --- ir/state.cpp | 217 +++++++++++++++++++++++++++++++++---------- ir/state.h | 21 ++++- smt/expr.cpp | 2 +- smt/exprs.cpp | 53 +++++++++++ smt/exprs.h | 5 + tools/alive-exec.cpp | 2 +- tools/transform.cpp | 11 ++- 7 files changed, 253 insertions(+), 58 deletions(-) diff --git a/ir/state.cpp b/ir/state.cpp index 87c2bed4a..61655cfc7 100644 --- a/ir/state.cpp +++ b/ir/state.cpp @@ -102,38 +102,27 @@ void State::ValueAnalysis::meet_with(const State::ValueAnalysis &other) { non_poison_vals = intersect_set(non_poison_vals, other.non_poison_vals); non_undef_vals = intersect_set(non_undef_vals, other.non_undef_vals); unused_vars = intersect_set(unused_vars, other.unused_vars); + ranges_fn_calls.meet_with(other.ranges_fn_calls); +} - for (auto &[fn, pair] : other.ranges_fn_calls) { - auto &[calls, access] = pair; - auto [I, inserted] = ranges_fn_calls.try_emplace(fn, pair); - if (inserted) { - I->second.first.emplace(0); - } else { - I->second.first.insert(calls.begin(), calls.end()); - I->second.second |= access; - } - } - - for (auto &[fn, pair] : ranges_fn_calls) { - auto &[calls, access] = pair; - if (!other.ranges_fn_calls.count(fn)) - calls.emplace(0); - } +void State::ValueAnalysis::clear_smt() { + non_poison_vals.clear(); + non_undef_vals = decltype(non_undef_vals)(); + unused_vars = decltype(unused_vars)(); } void State::ValueAnalysis::FnCallRanges::inc(const string &name, const SMTMemoryAccess &access) { - if (access.canWriteSomething().isFalse()) - return; + bool canwrite = !access.canWriteSomething().isFalse(); auto [I, inserted] = try_emplace(name); if (inserted) { - I->second.first.emplace(1); + I->second.first.emplace(1, canwrite); I->second.second = access; } else { - set new_set; - for (unsigned n : I->second.first) { - new_set.emplace(n+1); + set> new_set; + for (auto [n, writes0] : I->second.first) { + new_set.emplace(n+1, writes0 | canwrite); } I->second.first = std::move(new_set); I->second.second |= access; @@ -177,14 +166,13 @@ State::ValueAnalysis::FnCallRanges::overlaps(const string &callee, for (auto &[fn, pair] : *this) { auto &[calls, access] = pair; - assert(!access.canWriteSomething().isFalse()); if (skip(fn, access)) continue; auto I = other.find(fn); if (I == other.end()) { - if (calls.count(0)) + if (calls.count({0, true})) continue; return false; } @@ -194,7 +182,10 @@ State::ValueAnalysis::FnCallRanges::overlaps(const string &callee, if ((access | I->second.second).canReadSomething().isFalse()) continue; - if (intersect_set(calls, I->second.first).empty()) + auto set = intersect_set(calls, I->second.first); + // must only have write accesses + assert(ranges::all_of(set, [](auto &p) { return p.second; })); + if (set.empty()) return false; } @@ -203,10 +194,31 @@ State::ValueAnalysis::FnCallRanges::overlaps(const string &callee, if (skip(fn, access)) continue; - if (!calls.count(0) && !count(fn)) + if (!calls.count({0, true}) && !count(fn)) + return false; + } + + return true; +} + +bool State::ValueAnalysis::FnCallRanges::isLargerThanInclReads( + const FnCallRanges &other) const { + for (auto &[fn, pair] : *this) { + auto &[calls, access] = pair; + auto I = other.find(fn); + if (I == other.end()) + continue; + + auto first_val = calls.begin()->first; + auto other_last_val = I->second.first.rbegin()->first; + if (first_val < other_last_val) return false; } + for (auto &[fn, pair] : other) { + if (!count(fn)) + return false; + } return true; } @@ -220,6 +232,41 @@ State::ValueAnalysis::FnCallRanges::project(const string &name) const { return ranges; } +void State::ValueAnalysis::FnCallRanges::keep_only_writes() { + for (auto I = begin(); I != end(); ) { + auto &[calls, access] = I->second; + for (auto II = calls.begin(); II != calls.end(); ) { + if (!II->second) + II = calls.erase(II); + else + ++II; + } + if (calls.empty()) + I = erase(I); + else + ++I; + } +} + +void State::ValueAnalysis::FnCallRanges::meet_with(const FnCallRanges &other) { + for (auto &[fn, pair] : other) { + auto &[calls, access] = pair; + auto [I, inserted] = try_emplace(fn, pair); + if (inserted) { + I->second.first.emplace(0, true); + } else { + I->second.first.insert(calls.begin(), calls.end()); + I->second.second |= access; + } + } + + for (auto &[fn, pair] : *this) { + auto &[calls, access] = pair; + if (!other.count(fn)) + calls.emplace(0, true); + } +} + State::VarArgsData State::VarArgsData::mkIf(const expr &cond, VarArgsData &&then, VarArgsData &&els) { @@ -267,9 +314,9 @@ const State::ValTy& State::exec(const Value &v) { domain.noreturn = true; auto val = v.toSMT(*this); - auto value_ub = domain.UB(); + auto value_ub = domain.UB; if (config::disallow_ub_exploitation) - value_ub &= !guardable_ub(); + value_ub.add(!guardable_ub()); auto [I, inserted] = values.try_emplace(&v, ValTy{std::move(val), domain.noreturn, @@ -720,13 +767,64 @@ void State::cleanupPredecessorData() { predecessor_data.clear(); } +void State::copyUBFrom(const BasicBlock &bb) { + if (config::disallow_ub_exploitation) + return; + + // Time-travel UB: anything that happens before a possibly non-returning call + // can be moved up to the entry of the BB. + const Value *before_call = nullptr; + for (auto &i : bb.instrs()) { + if (auto *call = dynamic_cast(&i)) { + if (!call->hasAttribute(FnAttrs::WillReturn)) + break; + } + before_call = &i; + } + if (!before_call) + return; + + auto src_val_I = src_state->values.find(before_call); + assert(src_val_I != src_state->values.end()); + domain.UB.add(src_val_I->second.domain); +} + +void State::copyUBFromBB( + const unordered_map &tgt_data) { + auto I = src_bb_paths.find(domain.path); + if (I == src_bb_paths.end()) + return; + + for (auto *src_bb : I->second) { + bool all_paths_ok = true; + for (auto &[_, src_data] : src_state->predecessor_data.at(src_bb)) { + auto I = ranges::find_if(tgt_data, [&](const auto &p) { + return is_eq(p.second.path <=> src_data.path); + }); + if (I == tgt_data.end() || + !I->second.analysis.ranges_fn_calls.isLargerThanInclReads( + src_data.analysis.ranges_fn_calls)) { + all_paths_ok = false; + break; + } + } + if (all_paths_ok) + copyUBFrom(*src_bb); + } +} + bool State::startBB(const BasicBlock &bb) { assert(undef_vars.empty()); ENSURE(seen_bbs.emplace(&bb).second); current_bb = &bb; - if (&f.getFirstBB() == &bb) + if (&f.getFirstBB() == &bb) { + if (src_state) { + copyUBFromBB({}); + copyUBFrom(src_state->f.getFirstBB()); + } return true; + } auto I = predecessor_data.find(&bb); if (I == predecessor_data.end()) @@ -736,10 +834,12 @@ bool State::startBB(const BasicBlock &bb) { throw_oom_exception(); DisjointExpr in_memory; - DisjointExpr UB; + DisjointExpr UB; DisjointExpr var_args_in; OrExpr path; + domain.UB = AndExpr(); + bool isFirst = true; for (auto &[src, data] : I->second) { path.add(data.path); @@ -753,25 +853,29 @@ bool State::startBB(const BasicBlock &bb) { data.undef_vars.clear(); if (isFirst) - analysis = std::move(data.analysis); - else { + analysis = data.analysis; + else analysis.meet_with(data.analysis); - data.analysis = {}; - } + + if (isSource()) + data.analysis.clear_smt(); isFirst = false; } assert(!isFirst); - domain.path = std::move(path)(); - domain.UB = *std::move(UB)(); - memory = *std::move(in_memory)(); - var_args_data = *std::move(var_args_in)(); + domain.UB.add(std::move(UB).factor()); + domain.path = std::move(path)(); + memory = *std::move(in_memory)(); + var_args_data = *std::move(var_args_in)(); + + if (src_state) + copyUBFromBB(I->second); return domain; } void State::addJump(expr &&cond, const BasicBlock &dst0, bool always_jump) { - always_jump |= cond.isTrue(); + always_jump = always_jump || cond.isTrue(); cond &= domain.path; if (cond.isFalse() || !domain) @@ -792,13 +896,13 @@ void State::addJump(expr &&cond, const BasicBlock &dst0, bool always_jump) { data.analysis = analysis; data.var_args = var_args_data; } - data.UB.add(domain.UB(), cond); + data.UB.add(domain.UB, cond); data.path.add(std::move(cond)); data.undef_vars.insert(undef_vars.begin(), undef_vars.end()); data.undef_vars.insert(domain.undef_vars.begin(), domain.undef_vars.end()); if (always_jump) - addUB(expr(false)); + domain.path = false; } void State::addJump(const BasicBlock &dst) { @@ -821,7 +925,7 @@ void State::addReturn(StateValue &&val) { return_undef_vars.insert(undef_vars.begin(), undef_vars.end()); return_undef_vars.insert(domain.undef_vars.begin(), domain.undef_vars.end()); undef_vars.clear(); - addUB(expr(false)); + domain.path = false; } void State::addAxiom(AndExpr &&ands) { @@ -1099,7 +1203,7 @@ State::addFnCall(const string &name, vector &&inputs, = inaccessiblemem_bids.try_emplace(name, inaccessiblemem_bids.size()) .first->second; - State::ValueAnalysis::FnCallRanges call_ranges; + ValueAnalysis::FnCallRanges call_ranges; if (!memaccess.canRead(MemoryAccess::Inaccessible).isFalse() || !memaccess.canRead(MemoryAccess::Errno).isFalse() || !memaccess.canRead(MemoryAccess::Other).isFalse()) @@ -1107,6 +1211,8 @@ State::addFnCall(const string &name, vector &&inputs, ? analysis.ranges_fn_calls.project(name) : analysis.ranges_fn_calls; + call_ranges.keep_only_writes(); + if (ret_arg_ty && (*ret_arg_ty == out_type).isFalse()) { ret_arg = out_type.fromInt(ret_arg_ty->toInt(*this, std::move(ret_arg))); } @@ -1350,7 +1456,7 @@ expr State::sinkDomain(bool include_ub) const { OrExpr ret; for (auto &[src, data] : I->second) { - ret.add(data.path() && (include_ub ? *data.UB() : true)); + ret.add(data.path() && (include_ub ? data.UB.factor()() : true)); } return ret(); } @@ -1363,8 +1469,7 @@ const StateValue& State::returnValCached() { Memory& State::returnMemory() { if (auto *m = get_if>(&return_memory)) { - auto val = std::move(*m)(); - return_memory = val ? *std::move(val) : memory.dup(); + return_memory = *std::move(*m)(); } return get(return_memory); } @@ -1376,7 +1481,7 @@ expr State::getJumpCond(const BasicBlock &src, const BasicBlock &dst) const { auto J = I->second.find(&src); return J == I->second.end() ? expr(false) - : J->second.path() && *J->second.UB(); + : J->second.path() && J->second.UB.factor()(); } void State::addGlobalVarBid(const string &glbvar, unsigned bid) { @@ -1419,6 +1524,15 @@ void State::syncSEdataWithSrc(State &src) { fn_call_data = std::move(src.fn_call_data); inaccessiblemem_bids = std::move(src.inaccessiblemem_bids); memory.syncWithSrc(src.returnMemory()); + + src_state = &src; + for (auto &[bb, srcs] : src.predecessor_data) { + OrExpr path; + for (auto &[src, data] : srcs) { + path.add(data.path); + } + src_bb_paths[std::move(path)()].emplace_back(bb); + } } void State::mkAxioms(State &tgt) { @@ -1439,4 +1553,13 @@ void State::mkAxioms(State &tgt) { } } +void State::cleanup() { + src_bb_paths.clear(); + undef_vars.clear(); + fn_call_data.clear(); + domain = {}; + analysis = {}; + var_args_data = {}; +} + } diff --git a/ir/state.h b/ir/state.h index f92db7e1e..55b3d2346 100644 --- a/ir/state.h +++ b/ir/state.h @@ -69,7 +69,7 @@ class State { struct ValTy { StateValue val; smt::expr return_domain; - smt::expr domain; + smt::AndExpr domain; std::set undef_vars; }; @@ -94,18 +94,24 @@ class State { // Possible number of calls per function name that occurred so far // This is an over-approximation, union over all predecessors struct FnCallRanges - : public std::map, - SMTMemoryAccess>> { + : public std::map>, + SMTMemoryAccess>> { void inc(const std::string &name, const SMTMemoryAccess &access); bool overlaps(const std::string &callee, const SMTMemoryAccess &call_access, const FnCallRanges &other) const; + bool isLargerThanInclReads(const FnCallRanges &other) const; // remove all ranges but name FnCallRanges project(const std::string &name) const; + void keep_only_writes(); + void meet_with(const FnCallRanges &other); }; FnCallRanges ranges_fn_calls; void meet_with(const ValueAnalysis &other); + void clear_smt(); }; struct VarArgsEntry { @@ -134,7 +140,7 @@ class State { struct BasicBlockInfo { smt::OrExpr path; - smt::DisjointExpr UB; + smt::DisjointExpr UB; smt::DisjointExpr mem; std::set undef_vars; ValueAnalysis analysis; @@ -148,6 +154,9 @@ class State { smt::AndExpr precondition; smt::AndExpr axioms; + State *src_state = nullptr; + std::map> src_bb_paths; + // for -disallow-ub-exploitation smt::OrExpr unreachable_paths; @@ -183,6 +192,9 @@ class State { unsigned i_tmp_values = 0; // next available position in tmp_values void check_enough_tmp_slots(); + void copyUBFrom(const BasicBlock &bb); + void copyUBFromBB( + const std::unordered_map &tgt_data); // return_domain: a boolean expression describing return condition smt::OrExpr return_domain; @@ -367,6 +379,7 @@ class State { void syncSEdataWithSrc(State &src); void mkAxioms(State &tgt); + void cleanup(); private: smt::expr strip_undef_and_add_ub(const Value &val, const smt::expr &e, diff --git a/smt/expr.cpp b/smt/expr.cpp index 503570b3e..744711ca4 100644 --- a/smt/expr.cpp +++ b/smt/expr.cpp @@ -28,7 +28,7 @@ using namespace util; // helpers to check if all input arguments are non-null #define C(...) \ - if (!isValid() || !expr::allValid( __VA_ARGS__)) [[unlikely]] \ + if (!isValid() || !expr::allValid(__VA_ARGS__)) [[unlikely]] \ return {} #define C2(...) \ diff --git a/smt/exprs.cpp b/smt/exprs.cpp index 2c0a4d48c..cb6c628b4 100644 --- a/smt/exprs.cpp +++ b/smt/exprs.cpp @@ -184,6 +184,59 @@ DisjointExpr::DisjointExpr(const expr &e, unsigned depth_limit) { } while (!worklist.empty()); } +// factor the common terms out +template<> AndExpr DisjointExpr::factor() const { + assert(!vals.empty()); + if (vals.size() == 1) + return vals.begin()->first; + + AndExpr ret; + vector> vals2; + vals2.insert(vals2.end(), vals.begin(), vals.end()); + + vector::iterator, set::iterator>> its; + its.reserve(vals2.size()); + for (auto &v : vals2) { + its.emplace_back(v.first.exprs.begin(), v.first.exprs.end()); + } + + auto &it0 = its[0].first; + while (it0 != its[0].second) { + const expr &e0 = *it0; + bool in_all = true; + for (unsigned i = 1, e = its.size(); i != e; ++i) { + auto &it2 = its[i].first; + if (it2 == its[i].second) { + goto end; + } + auto cmp = *it2 <=> e0; + if (cmp < 0) { + ++it2; + --i; // repeate this AndExpr + } else if (cmp > 0) { + ++it0; + in_all = false; + break; + } + } + if (in_all) { + ret.add(e0); + unsigned i = 0; + for (auto &v : vals2) { + auto &it = its[i++].first; + it = v.first.exprs.erase(it); + } + } + } +end: + DisjointExpr leftovers; + for (auto &[v, domain] : vals2) { + leftovers.add(std::move(v)(), std::move(domain)); + } + ret.add(*std::move(leftovers)()); + return ret; +} + void FunctionExpr::add(const expr &key, expr &&val) { ENSURE(fn.emplace(key, std::move(val)).second); diff --git a/smt/exprs.h b/smt/exprs.h index 546cd2dc3..f258ad6da 100644 --- a/smt/exprs.h +++ b/smt/exprs.h @@ -34,7 +34,9 @@ class AndExpr { expr operator()() const; operator bool() const; bool isTrue() const { return exprs.empty(); } + auto operator<=>(const AndExpr&) const = default; friend std::ostream &operator<<(std::ostream &os, const AndExpr &e); + template friend class DisjointExpr; }; @@ -47,6 +49,7 @@ class OrExpr { void add(const OrExpr &other); expr operator()() const; bool empty() const { return exprs.empty(); } + auto operator<=>(const OrExpr&) const = default; friend std::ostream &operator<<(std::ostream &os, const OrExpr &e); }; @@ -121,6 +124,8 @@ class DisjointExpr { std::optional operator()() && { return std::move(*this).mk({}); } + T factor() const; + expr domain() const { OrExpr ret; for (auto &[val, domain] : vals) { diff --git a/tools/alive-exec.cpp b/tools/alive-exec.cpp index ffefb3d48..9331e0185 100644 --- a/tools/alive-exec.cpp +++ b/tools/alive-exec.cpp @@ -156,7 +156,7 @@ optional exec(llvm::Function &F, continue; } - solver.add(val.domain); + solver.add(val.domain()); r = solver.check(); if (error(r)) return {}; diff --git a/tools/transform.cpp b/tools/transform.cpp index f4f014a99..f502ba54e 100644 --- a/tools/transform.cpp +++ b/tools/transform.cpp @@ -295,7 +295,7 @@ static bool error(Errors &errs, State &src_state, State &tgt_state, if (m.eval(val.return_domain).isFalse()) { s << *var << " = function did not return!\n"; break; - } else if (m.eval(val.domain).isFalse()) { + } else if (m.eval(val.domain()).isFalse()) { s << "Function " << call->getFnName() << " triggered UB\n"; break; } else if (var->isVoid()) { @@ -304,8 +304,7 @@ static bool error(Errors &errs, State &src_state, State &tgt_state, } } - if (!dynamic_cast(var) && // domain always false after exec - m.eval(val.domain).isFalse()) { + if (m.eval(val.domain()).isFalse()) { s << *var << " = UB triggered!\n"; break; } @@ -497,8 +496,8 @@ check_refinement(Errors &errs, const Transform &t, State &src_state, State &tgt_state, const Value *var, const Type &type, const State::ValTy &ap, const State::ValTy &bp, bool check_each_var) { - auto &fndom_a = ap.domain; - auto &fndom_b = bp.domain; + auto fndom_a = ap.domain(); + auto fndom_b = bp.domain(); auto &retdom_a = ap.return_domain; auto &retdom_b = bp.return_domain; auto &a = ap.val; @@ -1369,7 +1368,9 @@ pair, unique_ptr> TransformVerify::exec() const { auto tgt_state = make_unique(t.tgt, false); sym_exec(*src_state); tgt_state->syncSEdataWithSrc(*src_state); + src_state->cleanup(); sym_exec(*tgt_state); + tgt_state->cleanup(); src_state->mkAxioms(*tgt_state); return { std::move(src_state), std::move(tgt_state) };