diff --git a/ir/instr.cpp b/ir/instr.cpp index 3fa147b97..1d3b33ba3 100644 --- a/ir/instr.cpp +++ b/ir/instr.cpp @@ -633,21 +633,39 @@ void FpBinOp::rauw(const Value &what, Value &with) { RAUW(rhs); } -void FpBinOp::print(ostream &os) const { - const char *str = nullptr; +const char* FpBinOp::getOpName() const { switch (op) { - case FAdd: str = "fadd "; break; - case FSub: str = "fsub "; break; - case FMul: str = "fmul "; break; - case FDiv: str = "fdiv "; break; - case FRem: str = "frem "; break; - case FMax: str = "fmax "; break; - case FMin: str = "fmin "; break; - case FMaximum: str = "fmaximum "; break; - case FMinimum: str = "fminimum "; break; - case CopySign: str = "copysign "; break; - } - os << getName() << " = " << str << fmath << *lhs << ", " << rhs->getName(); + case FAdd: return "fadd"; + case FSub: return "fsub"; + case FMul: return "fmul"; + case FDiv: return "fdiv"; + case FRem: return "frem"; + case FMax: return "fmax"; + case FMin: return "fmin"; + case FMaximum: return "fmaximum"; + case FMinimum: return "fminimum"; + case CopySign: return "copysign"; + } + UNREACHABLE(); +} + +bool FpBinOp::isCommutative() const { + switch (op) { + case FAdd: + case FMin: + case FMax: + case FMinimum: + case FMaximum: + return true; + default: + return false; + } + UNREACHABLE(); +} + +void FpBinOp::print(ostream &os) const { + os << getName() << " = " << getOpName() << ' ' << fmath + << *lhs << ", " << rhs->getName(); if (!rm.isDefault()) os << ", rounding=" << rm; if (!ex.ignore()) @@ -840,6 +858,60 @@ static StateValue fm_poison(State &s, expr a, const expr &ap, ty, fmath, rm, bitwise, flags_in_only, 1); } +static StateValue uf_float(State &s, const string &name, + const vector &args, + const expr &res, + FastMathFlags fmath = FastMathFlags(), + bool is_commutative = false, + bool is_partial = false) { + + vector arg_values; + arg_values.reserve(args.size()); + for (auto &arg : args) { + arg_values.push_back(arg.value); + } + + auto value = expr::mkUF(name, arg_values, res); + s.doesApproximation("uf_float", value); + if (is_commutative) { + // Commutative functions are encoded as + // op(x, y) = op'(x, y) & op'(y, x) + // where & is the bitwise and operator and op' is an uninterpreted function. + // This encoding comes from "SMT-based Translation Validation for Machine + // Learning Compiler" by Seongwon Bang, Seunghyeon Nam, Inwhan Chun, + // Ho Young Jhoo, and Juneyoung Lee + assert(args.size() == 2); + value = value & expr::mkUF(name, {arg_values[1], arg_values[0]}, res); + } + + AndExpr non_poison; + for (auto &arg : args) { + non_poison.add(arg.non_poison); + } + + auto poison_condition = [&](const char* suffix) { + auto np_name = name + ".np_" + suffix; + auto poison_uf = expr::mkUF(np_name, arg_values, false); + s.doesApproximation("uf_float", poison_uf); + if (is_commutative) { + assert(args.size() == 2); + poison_uf &= expr::mkUF(np_name, {arg_values[1], arg_values[0]}, false); + } + non_poison.add(poison_uf); + }; + + if (fmath.flags & FastMathFlags::NNaN) + poison_condition("nnan"); + if (fmath.flags & FastMathFlags::NInf) + poison_condition("ninf"); + + // Partial functions may produce poison for inputs where they are not defined. + if (is_partial) + poison_condition("partial"); + + return { std::move(value), non_poison() }; +} + StateValue FpBinOp::toSMT(State &s) const { function fn; bool bitwise = false; @@ -914,12 +986,23 @@ StateValue FpBinOp::toSMT(State &s) const { break; } - auto scalar = [&](const auto &a, const auto &b, const Type &ty) { + function scalar = + [&](const StateValue &a, const StateValue &b, const Type &ty) { return fm_poison(s, a.value, a.non_poison, b.value, b.non_poison, [&](auto &a, auto &b, auto &rm){ return fn(a, b, rm); }, ty, fmath, rm, bitwise); }; + if (config::is_uf_float()) { + scalar = [&](const StateValue &a, const StateValue &b, + const Type &ty) -> StateValue { + ostringstream name; + name << getOpName() << '.' << ty; + return uf_float(s, std::move(name).str(), {a, b}, a.value, + fmath, isCommutative()); + }; + } + auto &a = s[*lhs]; auto &b = s[*rhs]; @@ -1107,23 +1190,25 @@ void FpUnaryOp::rauw(const Value &what, Value &with) { RAUW(val); } -void FpUnaryOp::print(ostream &os) const { - const char *str = nullptr; +const char* FpUnaryOp::getOpName() const { switch (op) { - case FAbs: str = "fabs "; break; - case FNeg: str = "fneg "; break; - case Canonicalize: str = "canonicalize "; break; - case Ceil: str = "ceil "; break; - case Floor: str = "floor "; break; - case RInt: str = "rint "; break; - case NearbyInt: str = "nearbyint "; break; - case Round: str = "round "; break; - case RoundEven: str = "roundeven "; break; - case Trunc: str = "trunc "; break; - case Sqrt: str = "sqrt "; break; - } - - os << getName() << " = " << str << fmath << *val; + case FAbs: return "fabs"; + case FNeg: return "fneg"; + case Canonicalize: return "canonicalize"; + case Ceil: return "ceil"; + case Floor: return "floor"; + case RInt: return "rint"; + case NearbyInt: return "nearbyint"; + case Round: return "round"; + case RoundEven: return "roundeven"; + case Trunc: return "trunc"; + case Sqrt: return "sqrt"; + } + UNREACHABLE(); +} + +void FpUnaryOp::print(ostream &os) const { + os << getName() << " = " << getOpName() << ' ' << fmath << *val; if (!rm.isDefault()) os << ", rounding=" << rm; if (!ex.ignore()) @@ -1171,12 +1256,21 @@ StateValue FpUnaryOp::toSMT(State &s) const { break; } - auto scalar = [&](const StateValue &v, const Type &ty) { + function scalar = + [&](const StateValue &v, const Type &ty) { return fm_poison(s, v.value, v.non_poison, [fn](auto &v, auto &rm) {return fn(v, rm);}, ty, fmath, rm, bitwise, false); }; + if (config::is_uf_float()) { + scalar = [&](const StateValue &v, const Type &ty) -> StateValue { + ostringstream name; + name << getOpName() << '.' << ty; + return uf_float(s, std::move(name).str(), {v}, v.value, fmath, false); + }; + } + auto &v = s[*val]; if (getType().isVectorType()) { @@ -1413,14 +1507,17 @@ void FpTernaryOp::rauw(const Value &what, Value &with) { RAUW(c); } -void FpTernaryOp::print(ostream &os) const { - const char *str = nullptr; +const char* FpTernaryOp::getOpName() const { switch (op) { - case FMA: str = "fma "; break; - case MulAdd: str = "fmuladd "; break; + case FMA: return "fma"; + case MulAdd: return "fmuladd"; } + UNREACHABLE(); +} - os << getName() << " = " << str << fmath << *a << ", " << *b << ", " << *c; +void FpTernaryOp::print(ostream &os) const { + os << getName() << " = " << getOpName() << ' ' << fmath + << *a << ", " << *b << ", " << *c; if (!rm.isDefault()) os << ", rounding=" << rm; if (!ex.ignore()) @@ -1444,12 +1541,23 @@ StateValue FpTernaryOp::toSMT(State &s) const { break; } - auto scalar = [&](const StateValue &a, const StateValue &b, - const StateValue &c, const Type &ty) { + function scalar = + [&](const StateValue &a, const StateValue &b, + const StateValue &c, const Type &ty) { return fm_poison(s, a.value, a.non_poison, b.value, b.non_poison, c.value, c.non_poison, fn, ty, fmath, rm, false); }; + if (config::is_uf_float()) { + scalar = [&](const StateValue &a, const StateValue &b, + const StateValue &c, const Type &ty) -> StateValue { + ostringstream name; + name << getOpName() << '.' << ty; + return uf_float(s, std::move(name).str(), {a, b, c}, a.value, fmath); + }; + } + auto &av = s[*a]; auto &bv = s[*b]; auto &cv = s[*c]; @@ -1498,13 +1606,15 @@ void TestOp::rauw(const Value &what, Value &with) { RAUW(rhs); } -void TestOp::print(ostream &os) const { - const char *str = nullptr; +const char* TestOp::getOpName() const { switch (op) { - case Is_FPClass: str = "is.fpclass "; break; + case Is_FPClass: return "is.fpclass"; } + UNREACHABLE(); +} - os << getName() << " = " << str << *lhs << ", " << *rhs; +void TestOp::print(ostream &os) const { + os << getName() << " = " << getOpName() << ' ' << *lhs << ", " << *rhs; } StateValue TestOp::toSMT(State &s) const { @@ -1522,10 +1632,19 @@ StateValue TestOp::toSMT(State &s) const { break; } - auto scalar = [&](const StateValue &v, const Type &ty) -> StateValue { + function scalar = + [&](const StateValue &v, const Type &ty) -> StateValue { return { fn(v.value, ty), expr(v.non_poison) }; }; + if (config::is_uf_float()) { + scalar = [&](const StateValue &v, const Type &ty) -> StateValue { + ostringstream name; + name << getOpName() << '.' << ty; + return uf_float(s, std::move(name).str(), {v}, expr::mkUInt(0, 1)); + }; + } + if (getType().isVectorType()) { vector vals; auto ty = lhs->getType().getAsAggregateType(); @@ -1745,20 +1864,22 @@ void FpConversionOp::rauw(const Value &what, Value &with) { RAUW(val); } -void FpConversionOp::print(ostream &os) const { - const char *str = nullptr; +const char* FpConversionOp::getOpName() const { switch (op) { - case SIntToFP: str = "sitofp "; break; - case UIntToFP: str = "uitofp "; break; - case FPToSInt: str = "fptosi "; break; - case FPToUInt: str = "fptoui "; break; - case FPExt: str = "fpext "; break; - case FPTrunc: str = "fptrunc "; break; - case LRInt: str = "lrint "; break; - case LRound: str = "lround "; break; + case SIntToFP: return "sitofp"; + case UIntToFP: return "uitofp"; + case FPToSInt: return "fptosi"; + case FPToUInt: return "fptoui"; + case FPExt: return "fpext"; + case FPTrunc: return "fptrunc"; + case LRInt: return "lrint"; + case LRound: return "lround"; } + UNREACHABLE(); +} - os << getName() << " = " << str; +void FpConversionOp::print(ostream &os) const { + os << getName() << " = " << getOpName() << ' '; if (flags & NNEG) os << "nneg "; os << *val << print_type(getType(), " to ", ""); @@ -1840,8 +1961,9 @@ StateValue FpConversionOp::toSMT(State &s) const { break; } - auto scalar = [&](const StateValue &sv, const Type &from_type, - const Type &to_type) -> StateValue { + function scalar = + [&](const StateValue &sv, const Type &from_type, + const Type &to_type) -> StateValue { auto val = sv.value; if (from_type.isFloatType()) { @@ -1865,6 +1987,21 @@ StateValue FpConversionOp::toSMT(State &s) const { : std::move(ret.value), np()}; }; + if (config::is_uf_float()) { + scalar = [&](const StateValue &sv, const Type &from_type, + const Type &to_type) -> StateValue { + ostringstream name; + name << getOpName() << '.' << from_type << ".to." << to_type; + expr range = to_type.getDummyValue(true).value; + bool is_partial = (op == UIntToFP && (flags & NNEG)) || + op == FPToSInt || + op == FPToUInt; + + return uf_float(s, std::move(name).str(), {sv}, range, + FastMathFlags(), false, is_partial); + }; + } + if (getType().isVectorType()) { vector vals; auto ty = val->getType().getAsAggregateType(); @@ -2801,7 +2938,8 @@ StateValue FCmp::toSMT(State &s) const { auto &a_eval = s[*a]; auto &b_eval = s[*b]; - auto fn = [&](const auto &a, const auto &b, const Type &ty) -> StateValue { + function fn = + [&](const auto &a, const auto &b, const Type &ty) -> StateValue { auto cmp = [&](const expr &a, const expr &b, auto &rm) { switch (cond) { case OEQ: return a.foeq(b); @@ -2827,6 +2965,53 @@ StateValue FCmp::toSMT(State &s) const { return { val.toBVBool(), std::move(np) }; }; + if (config::is_uf_float()) { + fn = [&](const auto &a, const auto &b, const Type &ty) -> StateValue { + switch (cond) { + case TRUE: return {true, true}; + case FALSE: return {false, true}; + default: { + // All conditions are encoded using only 5 uninterpreted functions: + // oeq, ueq, olt, ult, ord + + StateValue lhs = a; + StateValue rhs = b; + + const char *name = nullptr; + bool negate = false; + bool commutative = false; + switch (cond) { + case OEQ: name = "oeq"; commutative = true; break; + case OGT: name = "olt"; std::swap(lhs, rhs); break; + case OGE: name = "ult"; negate = true; break; + case OLT: name = "olt"; break; + case OLE: name = "ult"; std::swap(lhs, rhs); negate = true; break; + case ONE: name = "ueq"; negate = true; commutative = true; break; + case ORD: name = "ord"; commutative = true; break; + case UEQ: name = "ueq"; commutative = true; break; + case UGT: name = "olt"; std::swap(lhs, rhs); break; + case UGE: name = "ult"; negate = true; break; + case ULT: name = "ult"; break; + case ULE: name = "olt"; negate = true; std::swap(lhs, rhs); break; + case UNE: name = "oeq"; negate = true; commutative = true; break; + case UNO: name = "ord"; negate = true; commutative = true; break; + default: UNREACHABLE(); + } + + ostringstream os; + os << name << '.' << ty; + auto value = uf_float(s, std::move(os).str(), {lhs, rhs}, + expr::mkUInt(0, 1), fmath, commutative); + + if (negate) { + value.value = ~value.value; + } + return value; + } + } + }; + } + if (auto agg = a->getType().getAsAggregateType()) { vector vals; for (unsigned i = 0, e = agg->numElementsConst(); i != e; ++i) { diff --git a/ir/instr.h b/ir/instr.h index 54029a5cb..a1d426d44 100644 --- a/ir/instr.h +++ b/ir/instr.h @@ -89,6 +89,8 @@ class FpBinOp final : public Instr { bool propagatesPoison() const override; bool hasSideEffects() const override; void rauw(const Value &what, Value &with) override; + const char* getOpName() const; + bool isCommutative() const; void print(std::ostream &os) const override; StateValue toSMT(State &s) const override; smt::expr getTypeConstraints(const Function &f) const override; @@ -155,6 +157,7 @@ class FpUnaryOp final : public Instr { bool propagatesPoison() const override; bool hasSideEffects() const override; void rauw(const Value &what, Value &with) override; + const char* getOpName() const; void print(std::ostream &os) const override; StateValue toSMT(State &s) const override; smt::expr getTypeConstraints(const Function &f) const override; @@ -244,6 +247,7 @@ class FpTernaryOp final : public Instr { bool propagatesPoison() const override; bool hasSideEffects() const override; void rauw(const Value &what, Value &with) override; + const char* getOpName() const; void print(std::ostream &os) const override; StateValue toSMT(State &s) const override; smt::expr getTypeConstraints(const Function &f) const override; @@ -269,6 +273,7 @@ class TestOp final : public Instr { bool propagatesPoison() const override; bool hasSideEffects() const override; void rauw(const Value &what, Value &with) override; + const char* getOpName() const; void print(std::ostream &os) const override; StateValue toSMT(State &s) const override; smt::expr getTypeConstraints(const Function &f) const override; @@ -333,6 +338,7 @@ class FpConversionOp final : public Instr { bool propagatesPoison() const override; bool hasSideEffects() const override; void rauw(const Value &what, Value &with) override; + const char* getOpName() const; void print(std::ostream &os) const override; StateValue toSMT(State &s) const override; smt::expr getTypeConstraints(const Function &f) const override; diff --git a/llvm_util/cmd_args_def.h b/llvm_util/cmd_args_def.h index f1c473df5..04e1798aa 100644 --- a/llvm_util/cmd_args_def.h +++ b/llvm_util/cmd_args_def.h @@ -29,6 +29,12 @@ if ((config::disallow_ub_exploitation = opt_disallow_ub_exploitation)) { config::disable_poison_input = true; } +if (opt_uf_float) { + config::fp_encoding_mode = config::FpEncodingMode::UninterpretedFunctions; +} else { + config::fp_encoding_mode = config::FpEncodingMode::FloatingPoint; +} + func_names.insert(opt_funcs.begin(), opt_funcs.end()); if (!report_dir_created && !opt_report_dir.empty()) { diff --git a/llvm_util/cmd_args_list.h b/llvm_util/cmd_args_list.h index b732ec92a..0ca0d5786 100644 --- a/llvm_util/cmd_args_list.h +++ b/llvm_util/cmd_args_list.h @@ -185,4 +185,9 @@ llvm::cl::opt opt_disallow_ub_exploitation( llvm::cl::desc("Disallow UB exploitation by optimizations (default=allow)"), llvm::cl::init(false), llvm::cl::cat(alive_cmdargs)); +llvm::cl::opt opt_uf_float( + LLVM_ARGS_PREFIX "uf-float", + llvm::cl::desc("Approximate floating point operations as uninterpreted functions"), + llvm::cl::init(false), llvm::cl::cat(alive_cmdargs)); + } diff --git a/tests/alive-tv/uf-float/add-assoc-fail.srctgt.ll b/tests/alive-tv/uf-float/add-assoc-fail.srctgt.ll new file mode 100644 index 000000000..33a3d5323 --- /dev/null +++ b/tests/alive-tv/uf-float/add-assoc-fail.srctgt.ll @@ -0,0 +1,14 @@ +; TEST-ARGS: --uf-float +; ERROR: Couldn't prove the correctness of the transformation + +define float @src(float noundef %x, float noundef %y, float noundef %z) { + %a = fadd float %x, %y + %b = fadd float %a, %z + ret float %b +} + +define float @tgt(float noundef %x, float noundef %y, float noundef %z) { + %a = fadd float %y, %z + %b = fadd float %x, %a + ret float %b +} \ No newline at end of file diff --git a/tests/alive-tv/uf-float/add-comm-double.srctgt.ll b/tests/alive-tv/uf-float/add-comm-double.srctgt.ll new file mode 100644 index 000000000..4981661a4 --- /dev/null +++ b/tests/alive-tv/uf-float/add-comm-double.srctgt.ll @@ -0,0 +1,11 @@ +; TEST-ARGS: --uf-float + +define double @src(double noundef %x, double noundef %y) { + %sum = fadd double %x, %y + ret double %sum +} + +define double @tgt(double noundef %x, double noundef %y) { + %sum = fadd double %y, %x + ret double %sum +} \ No newline at end of file diff --git a/tests/alive-tv/uf-float/add-comm-fp128.srctgt.ll b/tests/alive-tv/uf-float/add-comm-fp128.srctgt.ll new file mode 100644 index 000000000..deb6c3213 --- /dev/null +++ b/tests/alive-tv/uf-float/add-comm-fp128.srctgt.ll @@ -0,0 +1,11 @@ +; TEST-ARGS: --uf-float + +define fp128 @src(fp128 noundef %x, fp128 noundef %y) { + %sum = fadd fp128 %x, %y + ret fp128 %sum +} + +define fp128 @tgt(fp128 noundef %x, fp128 noundef %y) { + %sum = fadd fp128 %y, %x + ret fp128 %sum +} \ No newline at end of file diff --git a/tests/alive-tv/uf-float/add-comm-half.srctgt.ll b/tests/alive-tv/uf-float/add-comm-half.srctgt.ll new file mode 100644 index 000000000..731eae0d5 --- /dev/null +++ b/tests/alive-tv/uf-float/add-comm-half.srctgt.ll @@ -0,0 +1,11 @@ +; TEST-ARGS: --uf-float + +define half @src(half noundef %x, half noundef %y) { + %sum = fadd half %x, %y + ret half %sum +} + +define half @tgt(half noundef %x, half noundef %y) { + %sum = fadd half %y, %x + ret half %sum +} \ No newline at end of file diff --git a/tests/alive-tv/uf-float/add-comm-ninf.srctgt.ll b/tests/alive-tv/uf-float/add-comm-ninf.srctgt.ll new file mode 100644 index 000000000..c52c7779c --- /dev/null +++ b/tests/alive-tv/uf-float/add-comm-ninf.srctgt.ll @@ -0,0 +1,11 @@ +; TEST-ARGS: --uf-float + +define float @src(float noundef %x, float noundef %y) { + %sum = fadd ninf float %x, %y + ret float %sum +} + +define float @tgt(float noundef %x, float noundef %y) { + %sum = fadd ninf float %y, %x + ret float %sum +} \ No newline at end of file diff --git a/tests/alive-tv/uf-float/add-comm-nnan.srctgt.ll b/tests/alive-tv/uf-float/add-comm-nnan.srctgt.ll new file mode 100644 index 000000000..50d7a7c7e --- /dev/null +++ b/tests/alive-tv/uf-float/add-comm-nnan.srctgt.ll @@ -0,0 +1,11 @@ +; TEST-ARGS: --uf-float + +define float @src(float noundef %x, float noundef %y) { + %sum = fadd nnan float %x, %y + ret float %sum +} + +define float @tgt(float noundef %x, float noundef %y) { + %sum = fadd nnan float %y, %x + ret float %sum +} \ No newline at end of file diff --git a/tests/alive-tv/uf-float/add-comm.srctgt.ll b/tests/alive-tv/uf-float/add-comm.srctgt.ll new file mode 100644 index 000000000..7f9fa9728 --- /dev/null +++ b/tests/alive-tv/uf-float/add-comm.srctgt.ll @@ -0,0 +1,11 @@ +; TEST-ARGS: --uf-float + +define float @src(float noundef %x, float noundef %y) { + %sum = fadd float %x, %y + ret float %sum +} + +define float @tgt(float noundef %x, float noundef %y) { + %sum = fadd float %y, %x + ret float %sum +} \ No newline at end of file diff --git a/tests/alive-tv/uf-float/cmp-eq-comm.srctgt.ll b/tests/alive-tv/uf-float/cmp-eq-comm.srctgt.ll new file mode 100644 index 000000000..573ade1a4 --- /dev/null +++ b/tests/alive-tv/uf-float/cmp-eq-comm.srctgt.ll @@ -0,0 +1,11 @@ +; TEST-ARGS: --uf-float + +define i1 @src(float noundef %x, float noundef %y) { + %cmp = fcmp oeq float %x, %y + ret i1 %cmp +} + +define i1 @tgt(float noundef %x, float noundef %y) { + %cmp = fcmp oeq float %y, %x + ret i1 %cmp +} \ No newline at end of file diff --git a/tests/alive-tv/uf-float/cmp-eq-ne.srctgt.ll b/tests/alive-tv/uf-float/cmp-eq-ne.srctgt.ll new file mode 100644 index 000000000..779b655e0 --- /dev/null +++ b/tests/alive-tv/uf-float/cmp-eq-ne.srctgt.ll @@ -0,0 +1,12 @@ +; TEST-ARGS: --uf-float + +define i1 @src(float noundef %x, float noundef %y) { + %cmp = fcmp oeq float %x, %y + ret i1 %cmp +} + +define i1 @tgt(float noundef %x, float noundef %y) { + %cmp = fcmp une float %x, %y + %notcmp = xor i1 %cmp, 1 + ret i1 %notcmp +} \ No newline at end of file diff --git a/tests/alive-tv/uf-float/cmp-lt-gt.srctgt.ll b/tests/alive-tv/uf-float/cmp-lt-gt.srctgt.ll new file mode 100644 index 000000000..4b3ccdbbb --- /dev/null +++ b/tests/alive-tv/uf-float/cmp-lt-gt.srctgt.ll @@ -0,0 +1,11 @@ +; TEST-ARGS: --uf-float + +define i1 @src(float noundef %x, float noundef %y) { + %cmp = fcmp olt float %x, %y + ret i1 %cmp +} + +define i1 @tgt(float noundef %x, float noundef %y) { + %cmp = fcmp ogt float %y, %x + ret i1 %cmp +} \ No newline at end of file diff --git a/tests/alive-tv/uf-float/cmp-lt-le.srctgt.ll b/tests/alive-tv/uf-float/cmp-lt-le.srctgt.ll new file mode 100644 index 000000000..ad124452e --- /dev/null +++ b/tests/alive-tv/uf-float/cmp-lt-le.srctgt.ll @@ -0,0 +1,12 @@ +; TEST-ARGS: --uf-float + +define i1 @src(float noundef %x, float noundef %y) { + %cmp = fcmp olt float %x, %y + ret i1 %cmp +} + +define i1 @tgt(float noundef %x, float noundef %y) { + %cmp = fcmp ule float %y, %x + %notcmp = xor i1 %cmp, 1 + ret i1 %notcmp +} \ No newline at end of file diff --git a/tests/alive-tv/uf-float/cmp-olt-ult-fail.srctgt.ll b/tests/alive-tv/uf-float/cmp-olt-ult-fail.srctgt.ll new file mode 100644 index 000000000..b9f92c905 --- /dev/null +++ b/tests/alive-tv/uf-float/cmp-olt-ult-fail.srctgt.ll @@ -0,0 +1,12 @@ +; TEST-ARGS: --uf-float +; ERROR: Couldn't prove the correctness of the transformation + +define i1 @src(float noundef %x, float noundef %y) { + %cmp = fcmp olt float %x, %y + ret i1 %cmp +} + +define i1 @tgt(float noundef %x, float noundef %y) { + %cmp = fcmp ult float %x, %y + ret i1 %cmp +} \ No newline at end of file diff --git a/tests/alive-tv/uf-float/cmp-ord-uno.srctgt.ll b/tests/alive-tv/uf-float/cmp-ord-uno.srctgt.ll new file mode 100644 index 000000000..9e0d7a0fd --- /dev/null +++ b/tests/alive-tv/uf-float/cmp-ord-uno.srctgt.ll @@ -0,0 +1,12 @@ +; TEST-ARGS: --uf-float + +define i1 @src(float noundef %x, float noundef %y) { + %cmp = fcmp ord float %x, %y + %notcmp = xor i1 %cmp, 1 + ret i1 %notcmp +} + +define i1 @tgt(float noundef %x, float noundef %y) { + %cmp = fcmp uno float %y, %x + ret i1 %cmp +} \ No newline at end of file diff --git a/tests/alive-tv/uf-float/load-bitcast.srctgt.ll b/tests/alive-tv/uf-float/load-bitcast.srctgt.ll new file mode 100644 index 000000000..8ec87e261 --- /dev/null +++ b/tests/alive-tv/uf-float/load-bitcast.srctgt.ll @@ -0,0 +1,13 @@ +; TEST-ARGS: --uf-float + +define i32 @src(float noundef %x) { + %a = bitcast float %x to i32 + ret i32 %a +} + +define i32 @tgt(float noundef %x) { + %a = alloca float + store float %x, ptr %a + %b = load i32, ptr %a + ret i32 %b +} \ No newline at end of file diff --git a/tests/alive-tv/uf-float/load-store.srctgt.ll b/tests/alive-tv/uf-float/load-store.srctgt.ll new file mode 100644 index 000000000..1c397c404 --- /dev/null +++ b/tests/alive-tv/uf-float/load-store.srctgt.ll @@ -0,0 +1,12 @@ +; TEST-ARGS: --uf-float + +define float @src(float noundef %x) { + ret float %x +} + +define float @tgt(float noundef %x) { + %a = alloca float + store float %x, ptr %a + %b = load float, ptr %a + ret float %b +} \ No newline at end of file diff --git a/tests/alive-tv/uf-float/select.srctgt.ll b/tests/alive-tv/uf-float/select.srctgt.ll new file mode 100644 index 000000000..2dcc46005 --- /dev/null +++ b/tests/alive-tv/uf-float/select.srctgt.ll @@ -0,0 +1,17 @@ +; TEST-ARGS: --uf-float + +define float @src(float noundef %x, float noundef %y) { + %cmp = fcmp olt float %x, %y + %a = fsub float %y, %x + %b = fsub float %x, %y + %c = select i1 %cmp, float %a, float %b + ret float %c +} + +define float @tgt(float noundef %x, float noundef %y) { + %cmp = fcmp olt float %x, %y + %min = select i1 %cmp, float %x, float %y + %max = select i1 %cmp, float %y, float %x + %res = fsub float %max, %min + ret float %res +} \ No newline at end of file diff --git a/tests/alive-tv/uf-float/vector-binop.srctgt.ll b/tests/alive-tv/uf-float/vector-binop.srctgt.ll new file mode 100644 index 000000000..98c599077 --- /dev/null +++ b/tests/alive-tv/uf-float/vector-binop.srctgt.ll @@ -0,0 +1,14 @@ +; TEST-ARGS: --uf-float + +define float @src(float noundef %x, float noundef %y) { + %sum = fadd float %x, %y + ret float %sum +} + +define float @tgt(float noundef %x, float noundef %y) { + %vec1 = insertelement <2 x float> poison, float %x, i32 0 + %vec2 = insertelement <2 x float> poison, float %y, i32 0 + %vec3 = fadd <2 x float> %vec1, %vec2 + %sum = extractelement <2 x float> %vec3, i32 0 + ret float %sum +} \ No newline at end of file diff --git a/tests/alive-tv/uf-float/vector-cmp.srctgt.ll b/tests/alive-tv/uf-float/vector-cmp.srctgt.ll new file mode 100644 index 000000000..ddc62145b --- /dev/null +++ b/tests/alive-tv/uf-float/vector-cmp.srctgt.ll @@ -0,0 +1,14 @@ +; TEST-ARGS: --uf-float + +define i1 @src(float noundef %x, float noundef %y) { + %res = fcmp olt float %x, %y + ret i1 %res +} + +define i1 @tgt(float noundef %x, float noundef %y) { + %vec1 = insertelement <2 x float> poison, float %x, i32 0 + %vec2 = insertelement <2 x float> poison, float %y, i32 0 + %vec3 = fcmp olt <2 x float> %vec1, %vec2 + %res = extractelement <2 x i1> %vec3, i32 0 + ret i1 %res +} \ No newline at end of file diff --git a/tests/alive-tv/uf-float/vector-convert.srctgt.ll b/tests/alive-tv/uf-float/vector-convert.srctgt.ll new file mode 100644 index 000000000..05636d146 --- /dev/null +++ b/tests/alive-tv/uf-float/vector-convert.srctgt.ll @@ -0,0 +1,15 @@ +; TEST-ARGS: --uf-float + +define float @src(float noundef %x) { + %a = fptosi float %x to i32 + %res = sitofp i32 %a to float + ret float %res +} + +define float @tgt(float noundef %x) { + %vec1 = insertelement <4 x float> poison, float %x, i32 0 + %vec2 = fptosi <4 x float> %vec1 to <4 x i32> + %vec3 = sitofp <4 x i32> %vec2 to <4 x float> + %res = extractelement <4 x float> %vec3, i32 0 + ret float %res +} \ No newline at end of file diff --git a/tests/alive-tv/uf-float/vector-unop.srctgt.ll b/tests/alive-tv/uf-float/vector-unop.srctgt.ll new file mode 100644 index 000000000..09e70791d --- /dev/null +++ b/tests/alive-tv/uf-float/vector-unop.srctgt.ll @@ -0,0 +1,13 @@ +; TEST-ARGS: --uf-float + +define float @src(float noundef %x) { + %res = fneg float %x + ret float %res +} + +define float @tgt(float noundef %x) { + %vec1 = insertelement <2 x float> poison, float %x, i32 0 + %vec2 = fneg <2 x float> %vec1 + %res = extractelement <2 x float> %vec2, i32 0 + ret float %res +} \ No newline at end of file diff --git a/util/config.cpp b/util/config.cpp index 82ab5f66a..35b4dcb78 100644 --- a/util/config.cpp +++ b/util/config.cpp @@ -23,6 +23,7 @@ unsigned src_unroll_cnt = 0; unsigned tgt_unroll_cnt = 0; unsigned max_offset_bits = 64; unsigned max_sizet_bits = 64; +FpEncodingMode fp_encoding_mode = FpEncodingMode::FloatingPoint; ostream &dbg() { return *debug_os; @@ -32,4 +33,8 @@ void set_debug(ostream &os) { debug_os = &os; } +bool is_uf_float() { + return fp_encoding_mode == FpEncodingMode::UninterpretedFunctions; +} + } diff --git a/util/config.h b/util/config.h index 57f58c888..3a6488e15 100644 --- a/util/config.h +++ b/util/config.h @@ -44,7 +44,15 @@ extern unsigned max_offset_bits; // size and size of pointers (not to be confused with program pointer size). extern unsigned max_sizet_bits; +enum FpEncodingMode { + FloatingPoint, + UninterpretedFunctions +}; + +extern FpEncodingMode fp_encoding_mode; + std::ostream &dbg(); void set_debug(std::ostream &os); +bool is_uf_float(); }