Skip to content

Commit

Permalink
try to copy UB conditions from src to tgt
Browse files Browse the repository at this point in the history
this is in preparation for subsequent work to take advantage
of implied conditions from UB
  • Loading branch information
nunoplopes committed Dec 28, 2024
1 parent e2051fa commit 4bbdd96
Show file tree
Hide file tree
Showing 7 changed files with 253 additions and 58 deletions.
217 changes: 170 additions & 47 deletions ir/state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,38 +102,27 @@ void State::ValueAnalysis::meet_with(const State::ValueAnalysis &other) {
non_poison_vals = intersect_set(non_poison_vals, other.non_poison_vals);
non_undef_vals = intersect_set(non_undef_vals, other.non_undef_vals);
unused_vars = intersect_set(unused_vars, other.unused_vars);
ranges_fn_calls.meet_with(other.ranges_fn_calls);
}

for (auto &[fn, pair] : other.ranges_fn_calls) {
auto &[calls, access] = pair;
auto [I, inserted] = ranges_fn_calls.try_emplace(fn, pair);
if (inserted) {
I->second.first.emplace(0);
} else {
I->second.first.insert(calls.begin(), calls.end());
I->second.second |= access;
}
}

for (auto &[fn, pair] : ranges_fn_calls) {
auto &[calls, access] = pair;
if (!other.ranges_fn_calls.count(fn))
calls.emplace(0);
}
void State::ValueAnalysis::clear_smt() {
non_poison_vals.clear();
non_undef_vals = decltype(non_undef_vals)();
unused_vars = decltype(unused_vars)();
}

void State::ValueAnalysis::FnCallRanges::inc(const string &name,
const SMTMemoryAccess &access) {
if (access.canWriteSomething().isFalse())
return;
bool canwrite = !access.canWriteSomething().isFalse();

auto [I, inserted] = try_emplace(name);
if (inserted) {
I->second.first.emplace(1);
I->second.first.emplace(1, canwrite);
I->second.second = access;
} else {
set<unsigned> new_set;
for (unsigned n : I->second.first) {
new_set.emplace(n+1);
set<pair<unsigned, bool>> new_set;
for (auto [n, writes0] : I->second.first) {
new_set.emplace(n+1, writes0 | canwrite);
}
I->second.first = std::move(new_set);
I->second.second |= access;
Expand Down Expand Up @@ -177,14 +166,13 @@ State::ValueAnalysis::FnCallRanges::overlaps(const string &callee,

for (auto &[fn, pair] : *this) {
auto &[calls, access] = pair;
assert(!access.canWriteSomething().isFalse());

if (skip(fn, access))
continue;

auto I = other.find(fn);
if (I == other.end()) {
if (calls.count(0))
if (calls.count({0, true}))
continue;
return false;
}
Expand All @@ -194,7 +182,10 @@ State::ValueAnalysis::FnCallRanges::overlaps(const string &callee,
if ((access | I->second.second).canReadSomething().isFalse())
continue;

if (intersect_set(calls, I->second.first).empty())
auto set = intersect_set(calls, I->second.first);
// must only have write accesses
assert(ranges::all_of(set, [](auto &p) { return p.second; }));
if (set.empty())
return false;
}

Expand All @@ -203,10 +194,31 @@ State::ValueAnalysis::FnCallRanges::overlaps(const string &callee,
if (skip(fn, access))
continue;

if (!calls.count(0) && !count(fn))
if (!calls.count({0, true}) && !count(fn))
return false;
}

return true;
}

bool State::ValueAnalysis::FnCallRanges::isLargerThanInclReads(
const FnCallRanges &other) const {
for (auto &[fn, pair] : *this) {
auto &[calls, access] = pair;
auto I = other.find(fn);
if (I == other.end())
continue;

auto first_val = calls.begin()->first;
auto other_last_val = I->second.first.rbegin()->first;
if (first_val < other_last_val)
return false;
}

for (auto &[fn, pair] : other) {
if (!count(fn))
return false;
}
return true;
}

Expand All @@ -220,6 +232,41 @@ State::ValueAnalysis::FnCallRanges::project(const string &name) const {
return ranges;
}

void State::ValueAnalysis::FnCallRanges::keep_only_writes() {
for (auto I = begin(); I != end(); ) {
auto &[calls, access] = I->second;
for (auto II = calls.begin(); II != calls.end(); ) {
if (!II->second)
II = calls.erase(II);
else
++II;
}
if (calls.empty())
I = erase(I);
else
++I;
}
}

void State::ValueAnalysis::FnCallRanges::meet_with(const FnCallRanges &other) {
for (auto &[fn, pair] : other) {
auto &[calls, access] = pair;
auto [I, inserted] = try_emplace(fn, pair);
if (inserted) {
I->second.first.emplace(0, true);
} else {
I->second.first.insert(calls.begin(), calls.end());
I->second.second |= access;
}
}

for (auto &[fn, pair] : *this) {
auto &[calls, access] = pair;
if (!other.count(fn))
calls.emplace(0, true);
}
}

State::VarArgsData
State::VarArgsData::mkIf(const expr &cond, VarArgsData &&then,
VarArgsData &&els) {
Expand Down Expand Up @@ -267,9 +314,9 @@ const State::ValTy& State::exec(const Value &v) {
domain.noreturn = true;
auto val = v.toSMT(*this);

auto value_ub = domain.UB();
auto value_ub = domain.UB;
if (config::disallow_ub_exploitation)
value_ub &= !guardable_ub();
value_ub.add(!guardable_ub());

auto [I, inserted]
= values.try_emplace(&v, ValTy{std::move(val), domain.noreturn,
Expand Down Expand Up @@ -720,13 +767,64 @@ void State::cleanupPredecessorData() {
predecessor_data.clear();
}

void State::copyUBFrom(const BasicBlock &bb) {
if (config::disallow_ub_exploitation)
return;

// Time-travel UB: anything that happens before a possibly non-returning call
// can be moved up to the entry of the BB.
const Value *before_call = nullptr;
for (auto &i : bb.instrs()) {
if (auto *call = dynamic_cast<const FnCall*>(&i)) {
if (!call->hasAttribute(FnAttrs::WillReturn))
break;
}
before_call = &i;
}
if (!before_call)
return;

auto src_val_I = src_state->values.find(before_call);
assert(src_val_I != src_state->values.end());
domain.UB.add(src_val_I->second.domain);
}

void State::copyUBFromBB(
const unordered_map<const BasicBlock*, BasicBlockInfo> &tgt_data) {
auto I = src_bb_paths.find(domain.path);
if (I == src_bb_paths.end())
return;

for (auto *src_bb : I->second) {
bool all_paths_ok = true;
for (auto &[_, src_data] : src_state->predecessor_data.at(src_bb)) {
auto I = ranges::find_if(tgt_data, [&](const auto &p) {
return is_eq(p.second.path <=> src_data.path);
});
if (I == tgt_data.end() ||
!I->second.analysis.ranges_fn_calls.isLargerThanInclReads(
src_data.analysis.ranges_fn_calls)) {
all_paths_ok = false;
break;
}
}
if (all_paths_ok)
copyUBFrom(*src_bb);
}
}

bool State::startBB(const BasicBlock &bb) {
assert(undef_vars.empty());
ENSURE(seen_bbs.emplace(&bb).second);
current_bb = &bb;

if (&f.getFirstBB() == &bb)
if (&f.getFirstBB() == &bb) {
if (src_state) {
copyUBFromBB({});
copyUBFrom(src_state->f.getFirstBB());
}
return true;
}

auto I = predecessor_data.find(&bb);
if (I == predecessor_data.end())
Expand All @@ -736,10 +834,12 @@ bool State::startBB(const BasicBlock &bb) {
throw_oom_exception();

DisjointExpr<Memory> in_memory;
DisjointExpr<expr> UB;
DisjointExpr<AndExpr> UB;
DisjointExpr<VarArgsData> var_args_in;
OrExpr path;

domain.UB = AndExpr();

bool isFirst = true;
for (auto &[src, data] : I->second) {
path.add(data.path);
Expand All @@ -753,25 +853,29 @@ bool State::startBB(const BasicBlock &bb) {
data.undef_vars.clear();

if (isFirst)
analysis = std::move(data.analysis);
else {
analysis = data.analysis;
else
analysis.meet_with(data.analysis);
data.analysis = {};
}

if (isSource())
data.analysis.clear_smt();
isFirst = false;
}
assert(!isFirst);

domain.path = std::move(path)();
domain.UB = *std::move(UB)();
memory = *std::move(in_memory)();
var_args_data = *std::move(var_args_in)();
domain.UB.add(std::move(UB).factor());
domain.path = std::move(path)();
memory = *std::move(in_memory)();
var_args_data = *std::move(var_args_in)();

if (src_state)
copyUBFromBB(I->second);

return domain;
}

void State::addJump(expr &&cond, const BasicBlock &dst0, bool always_jump) {
always_jump |= cond.isTrue();
always_jump = always_jump || cond.isTrue();

cond &= domain.path;
if (cond.isFalse() || !domain)
Expand All @@ -792,13 +896,13 @@ void State::addJump(expr &&cond, const BasicBlock &dst0, bool always_jump) {
data.analysis = analysis;
data.var_args = var_args_data;
}
data.UB.add(domain.UB(), cond);
data.UB.add(domain.UB, cond);
data.path.add(std::move(cond));
data.undef_vars.insert(undef_vars.begin(), undef_vars.end());
data.undef_vars.insert(domain.undef_vars.begin(), domain.undef_vars.end());

if (always_jump)
addUB(expr(false));
domain.path = false;
}

void State::addJump(const BasicBlock &dst) {
Expand All @@ -821,7 +925,7 @@ void State::addReturn(StateValue &&val) {
return_undef_vars.insert(undef_vars.begin(), undef_vars.end());
return_undef_vars.insert(domain.undef_vars.begin(), domain.undef_vars.end());
undef_vars.clear();
addUB(expr(false));
domain.path = false;
}

void State::addAxiom(AndExpr &&ands) {
Expand Down Expand Up @@ -1099,14 +1203,16 @@ State::addFnCall(const string &name, vector<StateValue> &&inputs,
= inaccessiblemem_bids.try_emplace(name, inaccessiblemem_bids.size())
.first->second;

State::ValueAnalysis::FnCallRanges call_ranges;
ValueAnalysis::FnCallRanges call_ranges;
if (!memaccess.canRead(MemoryAccess::Inaccessible).isFalse() ||
!memaccess.canRead(MemoryAccess::Errno).isFalse() ||
!memaccess.canRead(MemoryAccess::Other).isFalse())
call_ranges = memaccess.canOnlyRead(MemoryAccess::Inaccessible).isTrue()
? analysis.ranges_fn_calls.project(name)
: analysis.ranges_fn_calls;

call_ranges.keep_only_writes();

if (ret_arg_ty && (*ret_arg_ty == out_type).isFalse()) {
ret_arg = out_type.fromInt(ret_arg_ty->toInt(*this, std::move(ret_arg)));
}
Expand Down Expand Up @@ -1350,7 +1456,7 @@ expr State::sinkDomain(bool include_ub) const {

OrExpr ret;
for (auto &[src, data] : I->second) {
ret.add(data.path() && (include_ub ? *data.UB() : true));
ret.add(data.path() && (include_ub ? data.UB.factor()() : true));
}
return ret();
}
Expand All @@ -1363,8 +1469,7 @@ const StateValue& State::returnValCached() {

Memory& State::returnMemory() {
if (auto *m = get_if<DisjointExpr<Memory>>(&return_memory)) {
auto val = std::move(*m)();
return_memory = val ? *std::move(val) : memory.dup();
return_memory = *std::move(*m)();
}
return get<Memory>(return_memory);
}
Expand All @@ -1376,7 +1481,7 @@ expr State::getJumpCond(const BasicBlock &src, const BasicBlock &dst) const {

auto J = I->second.find(&src);
return J == I->second.end() ? expr(false)
: J->second.path() && *J->second.UB();
: J->second.path() && J->second.UB.factor()();
}

void State::addGlobalVarBid(const string &glbvar, unsigned bid) {
Expand Down Expand Up @@ -1419,6 +1524,15 @@ void State::syncSEdataWithSrc(State &src) {
fn_call_data = std::move(src.fn_call_data);
inaccessiblemem_bids = std::move(src.inaccessiblemem_bids);
memory.syncWithSrc(src.returnMemory());

src_state = &src;
for (auto &[bb, srcs] : src.predecessor_data) {
OrExpr path;
for (auto &[src, data] : srcs) {
path.add(data.path);
}
src_bb_paths[std::move(path)()].emplace_back(bb);
}
}

void State::mkAxioms(State &tgt) {
Expand All @@ -1439,4 +1553,13 @@ void State::mkAxioms(State &tgt) {
}
}

void State::cleanup() {
src_bb_paths.clear();
undef_vars.clear();
fn_call_data.clear();
domain = {};
analysis = {};
var_args_data = {};
}

}
Loading

0 comments on commit 4bbdd96

Please sign in to comment.