From e3ad4dbb49f9c7648c16c5565a4a66e508dd5e2b Mon Sep 17 00:00:00 2001 From: Nuno Lopes Date: Thu, 17 Oct 2024 10:23:12 +0100 Subject: [PATCH] close #900: improve support for tail calls --- CMakeLists.txt | 1 + ir/attrs.cpp | 117 +++++++++++++++++++++++++---------------------- ir/attrs.h | 5 +- ir/functions.cpp | 21 +++++++++ ir/functions.h | 33 +++++++++++++ ir/instr.cpp | 23 +++++----- ir/memory.cpp | 11 ----- ir/memory.h | 20 +------- ir/state.cpp | 6 +-- ir/state.h | 6 +-- 10 files changed, 139 insertions(+), 104 deletions(-) create mode 100644 ir/functions.cpp create mode 100644 ir/functions.h diff --git a/CMakeLists.txt b/CMakeLists.txt index fa14353b8..f8352753c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -94,6 +94,7 @@ set(IR_SRCS ir/constant.cpp ir/fast_math.cpp ir/function.cpp + ir/functions.cpp ir/globals.cpp ir/instr.cpp ir/memory.cpp diff --git a/ir/attrs.cpp b/ir/attrs.cpp index 1d9519a2f..2081f44a2 100644 --- a/ir/attrs.cpp +++ b/ir/attrs.cpp @@ -631,74 +631,81 @@ ostream& operator<<(std::ostream &os, const TailCallInfo &tci) { return os << str; } -void TailCallInfo::checkTailCall(const Instr &i, State &s) const { +void TailCallInfo::check(State &s, const Instr &i, + const vector &args) const { if (type == TailCallInfo::None) return; - bool preconditions_OK = true; + // Cannot access allocas, va_args, or byval arguments from the caller. + // Exception: alloca or byval arg may be passed to the callee as byval + for (const auto &arg : args) { + Pointer ptr(s.getMemory(), arg.val.value); + s.addUB(arg.val.non_poison.implies( + (ptr.isStackAllocated() || ptr.isByval()).implies(arg.byval) && + true // TODO: check for !var_args + )); + } - auto *callee = dynamic_cast(&i); - if (callee) { - for (const auto &[arg, attrs] : callee->getArgs()) { - bool callee_has_byval = attrs.has(ParamAttrs::ByVal); - if (dynamic_cast(arg) && !callee_has_byval) { - preconditions_OK = false; - break; - } - if (auto *input = dynamic_cast(arg)) { - bool caller_has_byval = input->hasAttribute(ParamAttrs::ByVal); - if (callee_has_byval != caller_has_byval) { - preconditions_OK = false; - break; + if (type != TailCallInfo::MustTail) + return; + + // additional rules for musttail + + auto *call = dynamic_cast(&i); + + // - The call must immediately precede a ret instruction, or a bitcast + // - The ret instruction must return the (possibly bitcasted) value produced + // by the call, undef/poison, or void. + + bool found_instr = false, found_ret = false; + const Value *val = &i; + for (auto &instr : s.getFn().bbOf(i).instrs()) { + if (&instr == val) { + assert(!found_instr); + found_instr = true; + continue; + } + + if (found_instr) { + if (auto *cast = isCast(ConversionOp::BitCast, instr)) { + if (&cast->getValue() != val) { + s.addUB(expr(false)); + return; } + val = cast; + continue; } - } - } else { - // Handling memcpy / memcmp et alia. - for (const auto &op : i.operands()) { - if (dynamic_cast(op)) { - preconditions_OK = false; - break; + if (auto *ret = dynamic_cast(&instr)) { + found_ret = true; + if (ret->getType().isVoid() && i.getType().isVoid()) + break; + auto *ret_val = ret->operands()[0]; + if (dynamic_cast(ret_val) || + dynamic_cast(ret_val) || + ret_val == val) + break; } + s.addUB(expr(false)); } } + ENSURE(found_instr); + if (!found_ret) + s.addUB(expr(false)); - if (callee && type == TailCallInfo::MustTail) { - bool callee_is_vararg = callee->getVarArgIdx() != -1u; - bool caller_is_vararg = s.getFn().isVarArgs(); - if (!has_same_calling_convention || (callee_is_vararg && !caller_is_vararg)) - preconditions_OK = false; - } - - if (preconditions_OK && type == TailCallInfo::MustTail) { - bool found = false; - const auto &instrs = s.getFn().bbOf(i).instrs(); - auto it = instrs.begin(); - for (auto e = instrs.end(); it != e; ++it) { - if (&*it == &i) { - found = true; - break; - } - } - assert(found); - - ++it; - auto &next_instr = *it; - if (auto *ret = dynamic_cast(&next_instr)) { - if (ret->getType().isVoid() && i.getType().isVoid()) - return; - auto *ret_val = ret->operands()[0]; - if (ret_val == &i) - return; - } + // The calling conventions of the caller and callee must match. + if (!has_same_calling_convention) + s.addUB(expr(false)); - preconditions_OK = false; + // The callee must be varargs iff the caller is varargs. + if (call) { + bool callee_is_vararg = call->getVarArgIdx() != -1u; + bool caller_is_vararg = s.getFn().isVarArgs(); + if (callee_is_vararg && !caller_is_vararg) + s.addUB(expr(false)); } - if (!preconditions_OK) { - // Preconditions unsatifisfied or refinement for musttail failed, hence UB. - s.addUB(expr(false)); - } + // TODO: + // - The return type must not undergo automatic conversion to an sret pointer. } } diff --git a/ir/attrs.h b/ir/attrs.h index 409bf7bca..2544f90df 100644 --- a/ir/attrs.h +++ b/ir/attrs.h @@ -3,6 +3,7 @@ // Copyright (c) 2018-present The Alive2 Authors. // Distributed under the MIT license that can be found in the LICENSE file. +#include "ir/functions.h" #include "smt/exprs.h" #include #include @@ -221,9 +222,9 @@ smt::expr isfpclass(const smt::expr &v, const Type &ty, uint16_t mask); struct TailCallInfo final { enum TailCallType { None, Tail, MustTail } type = None; // Determine if callee and caller have the same calling convention. - bool has_same_calling_convention = false; + bool has_same_calling_convention = true; - void checkTailCall(const Instr &i, State &s) const; + void check(State &s, const Instr &i, const std::vector &args) const; friend std::ostream& operator<<(std::ostream &os, const TailCallInfo &tci); }; diff --git a/ir/functions.cpp b/ir/functions.cpp new file mode 100644 index 000000000..17aee1891 --- /dev/null +++ b/ir/functions.cpp @@ -0,0 +1,21 @@ +// Copyright (c) 2018-present The Alive2 Authors. +// Distributed under the MIT license that can be found in the LICENSE file. + +#include "ir/functions.h" + +using namespace smt; + +namespace IR { + +expr PtrInput::implies(const PtrInput &rhs) const { + return implies_attrs(rhs) && val == rhs.val && idx == rhs.idx; +} + +expr PtrInput::implies_attrs(const PtrInput &rhs) const { + return byval == rhs.byval && + rhs.noread .implies(noread) && + rhs.nowrite .implies(nowrite) && + rhs.nocapture.implies(nocapture); +} + +} diff --git a/ir/functions.h b/ir/functions.h new file mode 100644 index 000000000..447dc08ff --- /dev/null +++ b/ir/functions.h @@ -0,0 +1,33 @@ +#pragma once + +// Copyright (c) 2018-present The Alive2 Authors. +// Distributed under the MIT license that can be found in the LICENSE file. + +#include "ir/state_value.h" +#include "smt/expr.h" + +namespace IR { + +struct PtrInput { + unsigned idx = 0; + StateValue val; + smt::expr byval = false; + smt::expr noread = false; + smt::expr nowrite = false; + smt::expr nocapture = false; + + PtrInput(unsigned idx, StateValue &&val, smt::expr &&byval, + smt::expr &&noread, smt::expr &&nowrite, smt::expr &&nocapture) : + idx(idx), val(std::move(val)), byval(std::move(byval)), + noread(std::move(noread)), nowrite(std::move(nowrite)), + nocapture(std::move(nocapture)) {} + + PtrInput(const StateValue &val) : val(val) {} + PtrInput(const smt::expr &val) : val(StateValue(smt::expr(val), true)) {} + + smt::expr implies(const PtrInput &rhs) const; + smt::expr implies_attrs(const PtrInput &rhs) const; + auto operator<=>(const PtrInput &rhs) const = default; +}; + +} diff --git a/ir/instr.cpp b/ir/instr.cpp index c05fc1c48..9d3572454 100644 --- a/ir/instr.cpp +++ b/ir/instr.cpp @@ -2359,8 +2359,7 @@ static void check_can_store(State &s, const expr &p0) { static void unpack_inputs(State &s, Value &argv, Type &ty, const ParamAttrs &argflag, StateValue value, StateValue value2, vector &inputs, - vector &ptr_inputs, - unsigned idx) { + vector &ptr_inputs, unsigned idx) { if (auto agg = ty.getAsAggregateType()) { for (unsigned i = 0, e = agg->numElementsConst(); i != e; ++i) { if (agg->isPadding(i)) @@ -2421,7 +2420,7 @@ StateValue FnCall::toSMT(State &s) const { auto &m = s.getMemory(); vector inputs; - vector ptr_inputs; + vector ptr_inputs; unsigned indirect_hash = 0; auto ptr = fnptr; @@ -2493,7 +2492,7 @@ StateValue FnCall::toSMT(State &s) const { !attrs.has(FnAttrs::WillReturn)) s.addGuardableUB(expr(false)); - tci.checkTailCall(*this, s); + tci.check(s, *this, ptr_inputs); auto get_alloc_ptr = [&]() -> Value& { for (auto &[arg, flags] : args) { @@ -4154,7 +4153,7 @@ StateValue Memset::toSMT(State &s) const { vptr = sv_ptr.value; } check_can_store(s, vptr); - tci.checkTailCall(*this, s); + tci.check(s, *this, { vptr }); s.getMemory().memset(vptr, s[*val].zextOrTrunc(8), vbytes, align, s.getUndefVars()); @@ -4216,7 +4215,7 @@ StateValue MemsetPattern::toSMT(State &s) const { auto &vbytes = s.getAndAddPoisonUB(*bytes, true).value; check_can_store(s, vptr); check_can_load(s, vpattern); - tci.checkTailCall(*this, s); + tci.check(s, *this, { vptr }); s.getMemory().memset_pattern(vptr, vpattern, vbytes, pattern_length); return {}; @@ -4344,7 +4343,7 @@ StateValue Memcpy::toSMT(State &s) const { check_can_load(s, vsrc); check_can_store(s, vdst); - tci.checkTailCall(*this, s); + tci.check(s, *this, { vsrc, vdst }); s.getMemory().memcpy(vdst, vsrc, vbytes, align_dst, align_src, move); return {}; @@ -4396,14 +4395,16 @@ void Memcmp::print(ostream &os) const { } StateValue Memcmp::toSMT(State &s) const { - auto &[vptr1, np1] = s[*ptr1]; - auto &[vptr2, np2] = s[*ptr2]; + auto &stptr1 = s[*ptr1]; + auto &stptr2 = s[*ptr2]; + auto &[vptr1, np1] = stptr1; + auto &[vptr2, np2] = stptr2; auto &vnum = s.getAndAddPoisonUB(*num).value; s.addGuardableUB((vnum != 0).implies(np1 && np2)); check_can_load(s, vptr1); check_can_load(s, vptr2); - tci.checkTailCall(*this, s); + tci.check(s, *this, { stptr1, stptr2 }); Pointer p1(s.getMemory(), vptr1), p2(s.getMemory(), vptr2); // memcmp can be optimized to load & icmps, and it requires this @@ -4501,7 +4502,7 @@ void Strlen::print(ostream &os) const { StateValue Strlen::toSMT(State &s) const { auto &eptr = s.getWellDefinedPtr(*ptr); check_can_load(s, eptr); - tci.checkTailCall(*this, s); + tci.check(s, *this, { eptr }); Pointer p(s.getMemory(), eptr); Type &ty = getType(); diff --git a/ir/memory.cpp b/ir/memory.cpp index 6485271a8..2ba387a6e 100644 --- a/ir/memory.cpp +++ b/ir/memory.cpp @@ -1657,17 +1657,6 @@ pair Memory::mkUndefInput(const ParamAttrs &attrs0) { return { std::move(ptr).release(), std::move(undef) }; } -expr Memory::PtrInput::implies(const PtrInput &rhs) const { - return implies_attrs(rhs) && val == rhs.val && idx == rhs.idx; -} - -expr Memory::PtrInput::implies_attrs(const PtrInput &rhs) const { - return byval == rhs.byval && - rhs.noread .implies(noread) && - rhs.nowrite .implies(nowrite) && - rhs.nocapture.implies(nocapture); -} - Memory::FnRetData Memory::FnRetData::mkIf(const expr &cond, const FnRetData &a, const FnRetData &b) { return { expr::mkIf(cond, a.size, b.size), diff --git a/ir/memory.h b/ir/memory.h index c7d68db95..2d585278e 100644 --- a/ir/memory.h +++ b/ir/memory.h @@ -4,6 +4,7 @@ // Distributed under the MIT license that can be found in the LICENSE file. #include "ir/attrs.h" +#include "ir/functions.h" #include "ir/pointer.h" #include "ir/state_value.h" #include "ir/type.h" @@ -273,25 +274,6 @@ class Memory { smt::expr mkInput(const char *name, const ParamAttrs &attrs); std::pair mkUndefInput(const ParamAttrs &attrs); - struct PtrInput { - unsigned idx; - StateValue val; - smt::expr byval; - smt::expr noread; - smt::expr nowrite; - smt::expr nocapture; - - PtrInput(unsigned idx, StateValue &&val, smt::expr &&byval, - smt::expr &&noread, smt::expr &&nowrite, smt::expr &&nocapture) : - idx(idx), val(std::move(val)), byval(std::move(byval)), - noread(std::move(noread)), nowrite(std::move(nowrite)), - nocapture(std::move(nocapture)) {} - - smt::expr implies(const PtrInput &rhs) const; - smt::expr implies_attrs(const PtrInput &rhs) const; - auto operator<=>(const PtrInput &rhs) const = default; - }; - struct FnRetData { smt::expr size; smt::expr align; diff --git a/ir/state.cpp b/ir/state.cpp index 8fd14c42f..32fe55609 100644 --- a/ir/state.cpp +++ b/ir/state.cpp @@ -891,7 +891,7 @@ expr State::FnCallInput::implies(const FnCallInput &rhs) const { expr State::FnCallInput::refinedBy( State &s, const string &callee, unsigned inaccessible_bid, const vector &args_nonptr2, - const vector &args_ptr2, + const vector &args_ptr2, const ValueAnalysis::FnCallRanges &fncall_ranges2, const Memory &m2, const SMTMemoryAccess &memaccess2, bool noret2, bool willret2) const { @@ -933,7 +933,7 @@ expr State::FnCallInput::refinedBy( if (memaccess2.canReadSomething().isTrue()) { bool argmemonly = memaccess2.canOnlyRead(MemoryAccess::Args).isTrue(); - vector dummy1, dummy2; + vector dummy1, dummy2; auto restrict_ptrs = argmemonly ? &args_ptr : nullptr; auto restrict_ptrs2 = argmemonly ? &args_ptr2 : nullptr; if (memaccess2.canOnlyRead(MemoryAccess::Inaccessible).isTrue()) { @@ -1004,7 +1004,7 @@ expr State::FnCallOutput::implies(const FnCallOutput &rhs, StateValue State::addFnCall(const string &name, vector &&inputs, - vector &&ptr_inputs, + vector &&ptr_inputs, const Type &out_type, StateValue &&ret_arg, const Type *ret_arg_ty, vector &&ret_args, const FnAttrs &attrs, unsigned indirect_call_hash) { diff --git a/ir/state.h b/ir/state.h index 689390489..c72154d76 100644 --- a/ir/state.h +++ b/ir/state.h @@ -195,7 +195,7 @@ class State { struct FnCallInput { std::vector args_nonptr; - std::vector args_ptr; + std::vector args_ptr; ValueAnalysis::FnCallRanges fncall_ranges; Memory m; SMTMemoryAccess memaccess; @@ -205,7 +205,7 @@ class State { smt::expr refinedBy(State &s, const std::string &callee, unsigned inaccessible_bid, const std::vector &args_nonptr, - const std::vector &args_ptr, + const std::vector &args_ptr, const ValueAnalysis::FnCallRanges &fncall_ranges, const Memory &m, const SMTMemoryAccess &memaccess, bool noret, bool willret) const; @@ -296,7 +296,7 @@ class State { StateValue addFnCall(const std::string &name, std::vector &&inputs, - std::vector &&ptr_inputs, + std::vector &&ptr_inputs, const Type &out_type, StateValue &&ret_arg, const Type *ret_arg_ty, std::vector &&ret_args, const FnAttrs &attrs,