Skip to content

Commit

Permalink
close #900: improve support for tail calls
Browse files Browse the repository at this point in the history
  • Loading branch information
nunoplopes committed Oct 17, 2024
1 parent 40d1b5b commit e3ad4db
Show file tree
Hide file tree
Showing 10 changed files with 139 additions and 104 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ set(IR_SRCS
ir/constant.cpp
ir/fast_math.cpp
ir/function.cpp
ir/functions.cpp
ir/globals.cpp
ir/instr.cpp
ir/memory.cpp
Expand Down
117 changes: 62 additions & 55 deletions ir/attrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -631,74 +631,81 @@ ostream& operator<<(std::ostream &os, const TailCallInfo &tci) {
return os << str;
}

void TailCallInfo::checkTailCall(const Instr &i, State &s) const {
void TailCallInfo::check(State &s, const Instr &i,
const vector<PtrInput> &args) const {
if (type == TailCallInfo::None)
return;

bool preconditions_OK = true;
// Cannot access allocas, va_args, or byval arguments from the caller.
// Exception: alloca or byval arg may be passed to the callee as byval
for (const auto &arg : args) {
Pointer ptr(s.getMemory(), arg.val.value);
s.addUB(arg.val.non_poison.implies(
(ptr.isStackAllocated() || ptr.isByval()).implies(arg.byval) &&
true // TODO: check for !var_args
));
}

auto *callee = dynamic_cast<const FnCall *>(&i);
if (callee) {
for (const auto &[arg, attrs] : callee->getArgs()) {
bool callee_has_byval = attrs.has(ParamAttrs::ByVal);
if (dynamic_cast<const Alloc *>(arg) && !callee_has_byval) {
preconditions_OK = false;
break;
}
if (auto *input = dynamic_cast<const Input *>(arg)) {
bool caller_has_byval = input->hasAttribute(ParamAttrs::ByVal);
if (callee_has_byval != caller_has_byval) {
preconditions_OK = false;
break;
if (type != TailCallInfo::MustTail)
return;

// additional rules for musttail

auto *call = dynamic_cast<const FnCall*>(&i);

// - The call must immediately precede a ret instruction, or a bitcast
// - The ret instruction must return the (possibly bitcasted) value produced
// by the call, undef/poison, or void.

bool found_instr = false, found_ret = false;
const Value *val = &i;
for (auto &instr : s.getFn().bbOf(i).instrs()) {
if (&instr == val) {
assert(!found_instr);
found_instr = true;
continue;
}

if (found_instr) {
if (auto *cast = isCast(ConversionOp::BitCast, instr)) {
if (&cast->getValue() != val) {
s.addUB(expr(false));
return;
}
val = cast;
continue;
}
}
} else {
// Handling memcpy / memcmp et alia.
for (const auto &op : i.operands()) {
if (dynamic_cast<const Alloc *>(op)) {
preconditions_OK = false;
break;
if (auto *ret = dynamic_cast<const Return*>(&instr)) {
found_ret = true;
if (ret->getType().isVoid() && i.getType().isVoid())
break;
auto *ret_val = ret->operands()[0];
if (dynamic_cast<UndefValue*>(ret_val) ||
dynamic_cast<PoisonValue*>(ret_val) ||
ret_val == val)
break;
}
s.addUB(expr(false));
}
}
ENSURE(found_instr);
if (!found_ret)
s.addUB(expr(false));

if (callee && type == TailCallInfo::MustTail) {
bool callee_is_vararg = callee->getVarArgIdx() != -1u;
bool caller_is_vararg = s.getFn().isVarArgs();
if (!has_same_calling_convention || (callee_is_vararg && !caller_is_vararg))
preconditions_OK = false;
}

if (preconditions_OK && type == TailCallInfo::MustTail) {
bool found = false;
const auto &instrs = s.getFn().bbOf(i).instrs();
auto it = instrs.begin();
for (auto e = instrs.end(); it != e; ++it) {
if (&*it == &i) {
found = true;
break;
}
}
assert(found);

++it;
auto &next_instr = *it;
if (auto *ret = dynamic_cast<const Return *>(&next_instr)) {
if (ret->getType().isVoid() && i.getType().isVoid())
return;
auto *ret_val = ret->operands()[0];
if (ret_val == &i)
return;
}
// The calling conventions of the caller and callee must match.
if (!has_same_calling_convention)
s.addUB(expr(false));

preconditions_OK = false;
// The callee must be varargs iff the caller is varargs.
if (call) {
bool callee_is_vararg = call->getVarArgIdx() != -1u;
bool caller_is_vararg = s.getFn().isVarArgs();
if (callee_is_vararg && !caller_is_vararg)
s.addUB(expr(false));
}

if (!preconditions_OK) {
// Preconditions unsatifisfied or refinement for musttail failed, hence UB.
s.addUB(expr(false));
}
// TODO:
// - The return type must not undergo automatic conversion to an sret pointer.
}

}
5 changes: 3 additions & 2 deletions ir/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// Copyright (c) 2018-present The Alive2 Authors.
// Distributed under the MIT license that can be found in the LICENSE file.

#include "ir/functions.h"
#include "smt/exprs.h"
#include <optional>
#include <ostream>
Expand Down Expand Up @@ -221,9 +222,9 @@ smt::expr isfpclass(const smt::expr &v, const Type &ty, uint16_t mask);
struct TailCallInfo final {
enum TailCallType { None, Tail, MustTail } type = None;
// Determine if callee and caller have the same calling convention.
bool has_same_calling_convention = false;
bool has_same_calling_convention = true;

void checkTailCall(const Instr &i, State &s) const;
void check(State &s, const Instr &i, const std::vector<PtrInput> &args) const;
friend std::ostream& operator<<(std::ostream &os, const TailCallInfo &tci);
};

Expand Down
21 changes: 21 additions & 0 deletions ir/functions.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Copyright (c) 2018-present The Alive2 Authors.
// Distributed under the MIT license that can be found in the LICENSE file.

#include "ir/functions.h"

using namespace smt;

namespace IR {

expr PtrInput::implies(const PtrInput &rhs) const {
return implies_attrs(rhs) && val == rhs.val && idx == rhs.idx;
}

expr PtrInput::implies_attrs(const PtrInput &rhs) const {
return byval == rhs.byval &&
rhs.noread .implies(noread) &&
rhs.nowrite .implies(nowrite) &&
rhs.nocapture.implies(nocapture);
}

}
33 changes: 33 additions & 0 deletions ir/functions.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#pragma once

// Copyright (c) 2018-present The Alive2 Authors.
// Distributed under the MIT license that can be found in the LICENSE file.

#include "ir/state_value.h"
#include "smt/expr.h"

namespace IR {

struct PtrInput {
unsigned idx = 0;
StateValue val;
smt::expr byval = false;
smt::expr noread = false;
smt::expr nowrite = false;
smt::expr nocapture = false;

PtrInput(unsigned idx, StateValue &&val, smt::expr &&byval,
smt::expr &&noread, smt::expr &&nowrite, smt::expr &&nocapture) :
idx(idx), val(std::move(val)), byval(std::move(byval)),
noread(std::move(noread)), nowrite(std::move(nowrite)),
nocapture(std::move(nocapture)) {}

PtrInput(const StateValue &val) : val(val) {}
PtrInput(const smt::expr &val) : val(StateValue(smt::expr(val), true)) {}

smt::expr implies(const PtrInput &rhs) const;
smt::expr implies_attrs(const PtrInput &rhs) const;
auto operator<=>(const PtrInput &rhs) const = default;
};

}
23 changes: 12 additions & 11 deletions ir/instr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2359,8 +2359,7 @@ static void check_can_store(State &s, const expr &p0) {
static void unpack_inputs(State &s, Value &argv, Type &ty,
const ParamAttrs &argflag, StateValue value,
StateValue value2, vector<StateValue> &inputs,
vector<Memory::PtrInput> &ptr_inputs,
unsigned idx) {
vector<PtrInput> &ptr_inputs, unsigned idx) {
if (auto agg = ty.getAsAggregateType()) {
for (unsigned i = 0, e = agg->numElementsConst(); i != e; ++i) {
if (agg->isPadding(i))
Expand Down Expand Up @@ -2421,7 +2420,7 @@ StateValue FnCall::toSMT(State &s) const {
auto &m = s.getMemory();

vector<StateValue> inputs;
vector<Memory::PtrInput> ptr_inputs;
vector<PtrInput> ptr_inputs;

unsigned indirect_hash = 0;
auto ptr = fnptr;
Expand Down Expand Up @@ -2493,7 +2492,7 @@ StateValue FnCall::toSMT(State &s) const {
!attrs.has(FnAttrs::WillReturn))
s.addGuardableUB(expr(false));

tci.checkTailCall(*this, s);
tci.check(s, *this, ptr_inputs);

auto get_alloc_ptr = [&]() -> Value& {
for (auto &[arg, flags] : args) {
Expand Down Expand Up @@ -4154,7 +4153,7 @@ StateValue Memset::toSMT(State &s) const {
vptr = sv_ptr.value;
}
check_can_store(s, vptr);
tci.checkTailCall(*this, s);
tci.check(s, *this, { vptr });

s.getMemory().memset(vptr, s[*val].zextOrTrunc(8), vbytes, align,
s.getUndefVars());
Expand Down Expand Up @@ -4216,7 +4215,7 @@ StateValue MemsetPattern::toSMT(State &s) const {
auto &vbytes = s.getAndAddPoisonUB(*bytes, true).value;
check_can_store(s, vptr);
check_can_load(s, vpattern);
tci.checkTailCall(*this, s);
tci.check(s, *this, { vptr });

s.getMemory().memset_pattern(vptr, vpattern, vbytes, pattern_length);
return {};
Expand Down Expand Up @@ -4344,7 +4343,7 @@ StateValue Memcpy::toSMT(State &s) const {

check_can_load(s, vsrc);
check_can_store(s, vdst);
tci.checkTailCall(*this, s);
tci.check(s, *this, { vsrc, vdst });

s.getMemory().memcpy(vdst, vsrc, vbytes, align_dst, align_src, move);
return {};
Expand Down Expand Up @@ -4396,14 +4395,16 @@ void Memcmp::print(ostream &os) const {
}

StateValue Memcmp::toSMT(State &s) const {
auto &[vptr1, np1] = s[*ptr1];
auto &[vptr2, np2] = s[*ptr2];
auto &stptr1 = s[*ptr1];
auto &stptr2 = s[*ptr2];
auto &[vptr1, np1] = stptr1;
auto &[vptr2, np2] = stptr2;
auto &vnum = s.getAndAddPoisonUB(*num).value;
s.addGuardableUB((vnum != 0).implies(np1 && np2));

check_can_load(s, vptr1);
check_can_load(s, vptr2);
tci.checkTailCall(*this, s);
tci.check(s, *this, { stptr1, stptr2 });

Pointer p1(s.getMemory(), vptr1), p2(s.getMemory(), vptr2);
// memcmp can be optimized to load & icmps, and it requires this
Expand Down Expand Up @@ -4501,7 +4502,7 @@ void Strlen::print(ostream &os) const {
StateValue Strlen::toSMT(State &s) const {
auto &eptr = s.getWellDefinedPtr(*ptr);
check_can_load(s, eptr);
tci.checkTailCall(*this, s);
tci.check(s, *this, { eptr });

Pointer p(s.getMemory(), eptr);
Type &ty = getType();
Expand Down
11 changes: 0 additions & 11 deletions ir/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1657,17 +1657,6 @@ pair<expr, expr> Memory::mkUndefInput(const ParamAttrs &attrs0) {
return { std::move(ptr).release(), std::move(undef) };
}

expr Memory::PtrInput::implies(const PtrInput &rhs) const {
return implies_attrs(rhs) && val == rhs.val && idx == rhs.idx;
}

expr Memory::PtrInput::implies_attrs(const PtrInput &rhs) const {
return byval == rhs.byval &&
rhs.noread .implies(noread) &&
rhs.nowrite .implies(nowrite) &&
rhs.nocapture.implies(nocapture);
}

Memory::FnRetData Memory::FnRetData::mkIf(const expr &cond, const FnRetData &a,
const FnRetData &b) {
return { expr::mkIf(cond, a.size, b.size),
Expand Down
20 changes: 1 addition & 19 deletions ir/memory.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
// Distributed under the MIT license that can be found in the LICENSE file.

#include "ir/attrs.h"
#include "ir/functions.h"
#include "ir/pointer.h"
#include "ir/state_value.h"
#include "ir/type.h"
Expand Down Expand Up @@ -273,25 +274,6 @@ class Memory {
smt::expr mkInput(const char *name, const ParamAttrs &attrs);
std::pair<smt::expr, smt::expr> mkUndefInput(const ParamAttrs &attrs);

struct PtrInput {
unsigned idx;
StateValue val;
smt::expr byval;
smt::expr noread;
smt::expr nowrite;
smt::expr nocapture;

PtrInput(unsigned idx, StateValue &&val, smt::expr &&byval,
smt::expr &&noread, smt::expr &&nowrite, smt::expr &&nocapture) :
idx(idx), val(std::move(val)), byval(std::move(byval)),
noread(std::move(noread)), nowrite(std::move(nowrite)),
nocapture(std::move(nocapture)) {}

smt::expr implies(const PtrInput &rhs) const;
smt::expr implies_attrs(const PtrInput &rhs) const;
auto operator<=>(const PtrInput &rhs) const = default;
};

struct FnRetData {
smt::expr size;
smt::expr align;
Expand Down
6 changes: 3 additions & 3 deletions ir/state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -891,7 +891,7 @@ expr State::FnCallInput::implies(const FnCallInput &rhs) const {
expr State::FnCallInput::refinedBy(
State &s, const string &callee, unsigned inaccessible_bid,
const vector<StateValue> &args_nonptr2,
const vector<Memory::PtrInput> &args_ptr2,
const vector<PtrInput> &args_ptr2,
const ValueAnalysis::FnCallRanges &fncall_ranges2,
const Memory &m2, const SMTMemoryAccess &memaccess2, bool noret2,
bool willret2) const {
Expand Down Expand Up @@ -933,7 +933,7 @@ expr State::FnCallInput::refinedBy(

if (memaccess2.canReadSomething().isTrue()) {
bool argmemonly = memaccess2.canOnlyRead(MemoryAccess::Args).isTrue();
vector<Memory::PtrInput> dummy1, dummy2;
vector<PtrInput> dummy1, dummy2;
auto restrict_ptrs = argmemonly ? &args_ptr : nullptr;
auto restrict_ptrs2 = argmemonly ? &args_ptr2 : nullptr;
if (memaccess2.canOnlyRead(MemoryAccess::Inaccessible).isTrue()) {
Expand Down Expand Up @@ -1004,7 +1004,7 @@ expr State::FnCallOutput::implies(const FnCallOutput &rhs,

StateValue
State::addFnCall(const string &name, vector<StateValue> &&inputs,
vector<Memory::PtrInput> &&ptr_inputs,
vector<PtrInput> &&ptr_inputs,
const Type &out_type, StateValue &&ret_arg,
const Type *ret_arg_ty, vector<StateValue> &&ret_args,
const FnAttrs &attrs, unsigned indirect_call_hash) {
Expand Down
Loading

0 comments on commit e3ad4db

Please sign in to comment.