From 0aa59bd45b4ca056e2131dfadf358e4d286e636e Mon Sep 17 00:00:00 2001 From: Ramkumar Ramachandra Date: Sat, 7 Dec 2024 18:11:36 +0000 Subject: [PATCH] ir: add syntactic support for vscale, vscale_range In an effort to support checking of programs with scalable vectors, add syntactic support for vscale and vscale_range, making the vscale an smt::expr in State. For the moment, we fail to type-check any program using scalable vectors. --- ir/attrs.cpp | 4 ++ ir/attrs.h | 2 + ir/constant.cpp | 22 +++--- ir/constant.h | 10 +-- ir/function.cpp | 2 +- ir/instr.cpp | 56 +++++++-------- ir/instr.h | 3 +- ir/memory.h | 2 +- ir/precondition.cpp | 10 +-- ir/precondition.h | 6 +- ir/state.cpp | 15 +++- ir/state.h | 6 ++ ir/type.cpp | 71 +++++++++++++------ ir/type.h | 24 ++++--- ir/value.cpp | 10 +-- ir/value.h | 5 +- llvm_util/llvm2alive.cpp | 7 ++ llvm_util/utils.cpp | 14 +++- smt/expr.h | 6 ++ .../vscale/dse-scalable-fixed-neg.srctgt.ll | 15 ++++ .../vscale/dse-scalable-fixed.srctgt.ll | 15 ++++ .../dse-scalable-scalable-neg.srctgt.ll | 15 ++++ .../vscale/dse-scalable-scalable.srctgt.ll | 15 ++++ .../vector/vscale/inbounds-poison.srctgt.ll | 12 ++++ .../insert-extract-constvscale.srctgt.ll | 13 ++++ .../vector/vscale/insert-extract.srctgt.ll | 13 ++++ .../vscale/out-of-bounds-poison.srctgt.ll | 12 ++++ .../vscale/poison-constvscale.srctgt.ll | 16 +++++ .../vector/vscale/rem-constvscale.srctgt.ll | 16 +++++ tests/alive-tv/vector/vscale/rem.srctgt.ll | 16 +++++ .../typecheck-missing-vscale-range.srctgt.ll | 12 ++++ .../typecheck-scalable-non-scalable.srctgt.ll | 12 ++++ tools/transform.cpp | 2 +- 33 files changed, 361 insertions(+), 98 deletions(-) create mode 100644 tests/alive-tv/vector/vscale/dse-scalable-fixed-neg.srctgt.ll create mode 100644 tests/alive-tv/vector/vscale/dse-scalable-fixed.srctgt.ll create mode 100644 tests/alive-tv/vector/vscale/dse-scalable-scalable-neg.srctgt.ll create mode 100644 tests/alive-tv/vector/vscale/dse-scalable-scalable.srctgt.ll create mode 100644 tests/alive-tv/vector/vscale/inbounds-poison.srctgt.ll create mode 100644 tests/alive-tv/vector/vscale/insert-extract-constvscale.srctgt.ll create mode 100644 tests/alive-tv/vector/vscale/insert-extract.srctgt.ll create mode 100644 tests/alive-tv/vector/vscale/out-of-bounds-poison.srctgt.ll create mode 100644 tests/alive-tv/vector/vscale/poison-constvscale.srctgt.ll create mode 100644 tests/alive-tv/vector/vscale/rem-constvscale.srctgt.ll create mode 100644 tests/alive-tv/vector/vscale/rem.srctgt.ll create mode 100644 tests/alive-tv/vector/vscale/typecheck-missing-vscale-range.srctgt.ll create mode 100644 tests/alive-tv/vector/vscale/typecheck-scalable-non-scalable.srctgt.ll diff --git a/ir/attrs.cpp b/ir/attrs.cpp index ee1cdd5e7..d9e1184a8 100644 --- a/ir/attrs.cpp +++ b/ir/attrs.cpp @@ -148,6 +148,10 @@ ostream& operator<<(ostream &os, const FnAttrs &attr) { os << ", " << attr.allocsize_1; os << ')'; } + if (attr.vscaleRange) { + auto [low, high] = *attr.vscaleRange; + os << " vscale_range(" << low << ", " << high << ')'; + } attr.fp_denormal.print(os); if (attr.fp_denormal32) diff --git a/ir/attrs.h b/ir/attrs.h index dbb56826b..f451cfd84 100644 --- a/ir/attrs.h +++ b/ir/attrs.h @@ -137,6 +137,8 @@ class FnAttrs final { AllocSize = 1 << 12, ZeroExt = 1<<13, SignExt = 1<<14, NoFPClass = 1<<15, Asm = 1<<16 }; + std::optional> vscaleRange; + FnAttrs(unsigned bits = None) : bits(bits) {} bool has(Attribute a) const { return (bits & a) != 0; } diff --git a/ir/constant.cpp b/ir/constant.cpp index 5a3dee67c..1dba7b9b2 100644 --- a/ir/constant.cpp +++ b/ir/constant.cpp @@ -35,12 +35,12 @@ StateValue IntConst::toSMT(State &s) const { return { expr::mkInt(get(val).c_str(), bits()), true }; } -expr IntConst::getTypeConstraints() const { +expr IntConst::getTypeConstraints(const Function &f) const { unsigned min_bits = 0; if (auto v = get_if(&val)) min_bits = (*v >= 0 ? 63 : 64) - num_sign_bits(*v); - return Value::getTypeConstraints() && + return Value::getTypeConstraints(f) && getType().enforceIntType() && getType().sizeVar().uge(min_bits); } @@ -86,8 +86,8 @@ FloatConst::FloatConst(Type &type, string val, bool bit_value) : Constant(type, bit_value ? int_to_readable_float(type, val) : val), val(std::move(val)), bit_value(bit_value) {} -expr FloatConst::getTypeConstraints() const { - return Value::getTypeConstraints() && +expr FloatConst::getTypeConstraints(const Function &f) const { + return Value::getTypeConstraints(f) && getType().enforceFloatType(); } @@ -108,8 +108,8 @@ StateValue ConstantInput::toSMT(State &s) const { return { expr::mkVar(getName().c_str(), type), true }; } -expr ConstantInput::getTypeConstraints() const { - return Value::getTypeConstraints() && +expr ConstantInput::getTypeConstraints(const Function &f) const { + return Value::getTypeConstraints(f) && (getType().enforceIntType() || getType().enforceFloatType()); } @@ -157,8 +157,8 @@ StateValue ConstantBinOp::toSMT(State &s) const { return { std::move(val), ap && bp }; } -expr ConstantBinOp::getTypeConstraints() const { - return Value::getTypeConstraints() && +expr ConstantBinOp::getTypeConstraints(const Function &f) const { + return Value::getTypeConstraints(f) && getType().enforceIntType() && getType() == lhs.getType() && getType() == rhs.getType(); @@ -210,10 +210,10 @@ StateValue ConstantFn::toSMT(State &s) const { return { std::move(r), true }; } -expr ConstantFn::getTypeConstraints() const { - expr r = Value::getTypeConstraints(); +expr ConstantFn::getTypeConstraints(const Function &f) const { + expr r = Value::getTypeConstraints(f); for (auto a : args) { - r &= a->getTypeConstraints(); + r &= a->getTypeConstraints(f); } Type &ty = getType(); diff --git a/ir/constant.h b/ir/constant.h index b7a2f014c..94e562fd7 100644 --- a/ir/constant.h +++ b/ir/constant.h @@ -26,7 +26,7 @@ class IntConst final : public Constant { IntConst(Type &type, int64_t val); IntConst(Type &type, std::string &&val); StateValue toSMT(State &s) const override; - smt::expr getTypeConstraints() const override; + smt::expr getTypeConstraints(const Function &f) const override; auto getInt() const { return std::get_if(&val); } }; @@ -38,7 +38,7 @@ class FloatConst final : public Constant { FloatConst(Type &type, std::string val, bool bit_value); StateValue toSMT(State &s) const override; - smt::expr getTypeConstraints() const override; + smt::expr getTypeConstraints(const Function &f) const override; }; @@ -47,7 +47,7 @@ class ConstantInput final : public Constant { ConstantInput(Type &type, std::string &&name) : Constant(type, std::move(name)) {} StateValue toSMT(State &s) const override; - smt::expr getTypeConstraints() const override; + smt::expr getTypeConstraints(const Function &f) const override; }; @@ -62,7 +62,7 @@ class ConstantBinOp final : public Constant { public: ConstantBinOp(Type &type, Constant &lhs, Constant &rhs, Op op); StateValue toSMT(State &s) const override; - smt::expr getTypeConstraints() const override; + smt::expr getTypeConstraints(const Function &f) const override; }; @@ -73,7 +73,7 @@ class ConstantFn final : public Constant { public: ConstantFn(Type &type, std::string_view name, std::vector &&args); StateValue toSMT(State &s) const override; - smt::expr getTypeConstraints() const override; + smt::expr getTypeConstraints(const Function &f) const override; }; struct ConstantFnException { diff --git a/ir/function.cpp b/ir/function.cpp index b5a55990f..988ea0507 100644 --- a/ir/function.cpp +++ b/ir/function.cpp @@ -165,7 +165,7 @@ expr Function::getTypeConstraints() const { } for (auto &l : { getConstants(), getInputs(), getUndefs() }) { for (auto &v : l) { - t &= v.getTypeConstraints(); + t &= v.getTypeConstraints(*this); } } return t; diff --git a/ir/instr.cpp b/ir/instr.cpp index 2c9b97aa0..502e999d3 100644 --- a/ir/instr.cpp +++ b/ir/instr.cpp @@ -102,7 +102,7 @@ uint64_t getGlobalVarSize(const IR::Value *V) { namespace IR { -expr Instr::getTypeConstraints() const { +expr Instr::getTypeConstraints(const Function &f) const { UNREACHABLE(); return {}; } @@ -596,7 +596,7 @@ expr BinOp::getTypeConstraints(const Function &f) const { getType() == rhs->getType(); break; } - return Value::getTypeConstraints() && std::move(instrconstr); + return Value::getTypeConstraints(f) && std::move(instrconstr); } unique_ptr BinOp::dup(Function &f, const string &suffix) const { @@ -958,7 +958,7 @@ StateValue FpBinOp::toSMT(State &s) const { } expr FpBinOp::getTypeConstraints(const Function &f) const { - return Value::getTypeConstraints() && + return Value::getTypeConstraints(f) && getType().enforceFloatOrVectorType() && getType() == lhs->getType() && getType() == rhs->getType(); @@ -1086,7 +1086,7 @@ expr UnaryOp::getTypeConstraints(const Function &f) const { break; } - return Value::getTypeConstraints() && std::move(instrconstr); + return Value::getTypeConstraints(f) && std::move(instrconstr); } static Value* dup_aggregate(Function &f, Value *val) { @@ -1213,7 +1213,7 @@ StateValue FpUnaryOp::toSMT(State &s) const { } expr FpUnaryOp::getTypeConstraints(const Function &f) const { - return Value::getTypeConstraints() && + return Value::getTypeConstraints(f) && getType() == val->getType() && getType().enforceFloatOrVectorType(); } @@ -1286,7 +1286,7 @@ StateValue UnaryReductionOp::toSMT(State &s) const { } expr UnaryReductionOp::getTypeConstraints(const Function &f) const { - return Value::getTypeConstraints() && + return Value::getTypeConstraints(f) && getType().enforceIntType() && val->getType().enforceVectorType( [this](auto &scalar) { return scalar == getType(); }); @@ -1405,7 +1405,7 @@ expr TernaryOp::getTypeConstraints(const Function &f) const { getType().enforceIntOrVectorType(); break; } - return Value::getTypeConstraints() && instrconstr; + return Value::getTypeConstraints(f) && instrconstr; } unique_ptr TernaryOp::dup(Function &f, const string &suffix) const { @@ -1486,7 +1486,7 @@ StateValue FpTernaryOp::toSMT(State &s) const { } expr FpTernaryOp::getTypeConstraints(const Function &f) const { - return Value::getTypeConstraints() && + return Value::getTypeConstraints(f) && getType() == a->getType() && getType() == b->getType() && getType() == c->getType() && @@ -1557,7 +1557,7 @@ StateValue TestOp::toSMT(State &s) const { } expr TestOp::getTypeConstraints(const Function &f) const { - return Value::getTypeConstraints() && + return Value::getTypeConstraints(f) && lhs->getType().enforceFloatOrVectorType() && rhs->getType().enforceIntType(32) && getType().enforceIntOrVectorType(1) && @@ -1721,7 +1721,7 @@ expr ConversionOp::getTypeConstraints(const Function &f) const { break; } - c &= Value::getTypeConstraints(); + c &= Value::getTypeConstraints(f); if (op != BitCast) c &= getType().enforceVectorTypeEquiv(val->getType()); return c; @@ -1965,7 +1965,7 @@ expr FpConversionOp::getTypeConstraints(const Function &f) const { val->getType().scalarSize().ugt(getType().scalarSize()); break; } - return Value::getTypeConstraints() && c; + return Value::getTypeConstraints(f) && c; } unique_ptr FpConversionOp::dup(Function &f, const string &suffix) const { @@ -2027,7 +2027,7 @@ StateValue Select::toSMT(State &s) const { } expr Select::getTypeConstraints(const Function &f) const { - return Value::getTypeConstraints() && + return Value::getTypeConstraints(f) && cond->getType().enforceIntOrVectorType(1) && getType().enforceVectorTypeIff(cond->getType()) && (fmath.isNone() ? expr(true) : getType().enforceFloatOrVectorType()) && @@ -2080,7 +2080,7 @@ StateValue ExtractValue::toSMT(State &s) const { } expr ExtractValue::getTypeConstraints(const Function &f) const { - auto c = Value::getTypeConstraints() && + auto c = Value::getTypeConstraints(f) && val->getType().enforceAggregateType(); Type *type = &val->getType(); @@ -2172,7 +2172,7 @@ StateValue InsertValue::toSMT(State &s) const { } expr InsertValue::getTypeConstraints(const Function &f) const { - auto c = Value::getTypeConstraints() && + auto c = Value::getTypeConstraints(f) && val->getType().enforceAggregateType() && val->getType() == getType(); @@ -2646,7 +2646,7 @@ StateValue FnCall::toSMT(State &s) const { expr FnCall::getTypeConstraints(const Function &f) const { // TODO : also need to name each arg type smt var uniquely - expr ret = Value::getTypeConstraints(); + expr ret = Value::getTypeConstraints(f); if (fnptr) ret &= fnptr->getType().enforcePtrType(); return ret; @@ -2809,7 +2809,7 @@ StateValue ICmp::toSMT(State &s) const { } expr ICmp::getTypeConstraints(const Function &f) const { - return Value::getTypeConstraints() && + return Value::getTypeConstraints(f) && getType().enforceIntOrVectorType(1) && getType().enforceVectorTypeEquiv(a->getType()) && a->getType().enforceIntOrPtrOrVectorType() && @@ -2908,7 +2908,7 @@ StateValue FCmp::toSMT(State &s) const { } expr FCmp::getTypeConstraints(const Function &f) const { - return Value::getTypeConstraints() && + return Value::getTypeConstraints(f) && getType().enforceIntOrVectorType(1) && getType().enforceVectorTypeEquiv(a->getType()) && a->getType().enforceFloatOrVectorType() && @@ -2968,7 +2968,7 @@ StateValue Freeze::toSMT(State &s) const { } expr Freeze::getTypeConstraints(const Function &f) const { - return Value::getTypeConstraints() && + return Value::getTypeConstraints(f) && getType() == val->getType(); } @@ -3080,7 +3080,7 @@ StateValue Phi::toSMT(State &s) const { } expr Phi::getTypeConstraints(const Function &f) const { - auto c = Value::getTypeConstraints(); + auto c = Value::getTypeConstraints(f); for (auto &[val, bb] : values) { c &= val->getType() == getType(); } @@ -3324,7 +3324,7 @@ StateValue Return::toSMT(State &s) const { } expr Return::getTypeConstraints(const Function &f) const { - return Value::getTypeConstraints() && + return Value::getTypeConstraints(f) && getType() == val->getType() && f.getType() == getType(); } @@ -3711,7 +3711,7 @@ StateValue Alloc::toSMT(State &s) const { } expr Alloc::getTypeConstraints(const Function &f) const { - return Value::getTypeConstraints() && + return Value::getTypeConstraints(f) && getType().enforcePtrType() && size->getType().enforceIntType(); } @@ -3967,7 +3967,7 @@ StateValue GEP::toSMT(State &s) const { } expr GEP::getTypeConstraints(const Function &f) const { - auto c = Value::getTypeConstraints() && + auto c = Value::getTypeConstraints(f) && getType().enforceVectorTypeIff(ptr->getType()) && getType().enforcePtrOrVectorType(); for (auto &[sz, idx] : idxs) { @@ -4052,7 +4052,7 @@ StateValue PtrMask::toSMT(State &s) const { } expr PtrMask::getTypeConstraints(const Function &f) const { - return Value::getTypeConstraints() && + return Value::getTypeConstraints(f) && ptr->getType().enforcePtrOrVectorType() && getType() == ptr->getType() && mask->getType().enforceIntOrVectorType() && @@ -4101,7 +4101,7 @@ StateValue Load::toSMT(State &s) const { } expr Load::getTypeConstraints(const Function &f) const { - return Value::getTypeConstraints() && + return Value::getTypeConstraints(f) && ptr->getType().enforcePtrType(); } @@ -4583,7 +4583,7 @@ StateValue Strlen::toSMT(State &s) const { } expr Strlen::getTypeConstraints(const Function &f) const { - return Value::getTypeConstraints() && + return Value::getTypeConstraints(f) && ptr->getType().enforcePtrType() && getType().enforceIntType(); } @@ -4886,7 +4886,7 @@ StateValue ExtractElement::toSMT(State &s) const { } expr ExtractElement::getTypeConstraints(const Function &f) const { - return Value::getTypeConstraints() && + return Value::getTypeConstraints(f) && v->getType().enforceVectorType([&](auto &ty) { return ty == getType(); }) && idx->getType().enforceIntType(); @@ -4929,7 +4929,7 @@ StateValue InsertElement::toSMT(State &s) const { } expr InsertElement::getTypeConstraints(const Function &f) const { - return Value::getTypeConstraints() && + return Value::getTypeConstraints(f) && getType() == v->getType() && v->getType().enforceVectorType([&](auto &ty) { return ty == e->getType(); }) && @@ -4984,7 +4984,7 @@ StateValue ShuffleVector::toSMT(State &s) const { } expr ShuffleVector::getTypeConstraints(const Function &f) const { - return Value::getTypeConstraints() && + return Value::getTypeConstraints(f) && getType().enforceVectorTypeSameChildTy(v1->getType()) && getType().getAsAggregateType()->numElements() == mask.size() && v1->getType().enforceVectorType() && diff --git a/ir/instr.h b/ir/instr.h index 53efe0b75..768f23c94 100644 --- a/ir/instr.h +++ b/ir/instr.h @@ -23,8 +23,7 @@ class Instr : public Value { virtual bool propagatesPoison() const = 0; virtual bool hasSideEffects() const = 0; virtual bool isTerminator() const; - smt::expr getTypeConstraints() const override; - virtual smt::expr getTypeConstraints(const Function &f) const = 0; + smt::expr getTypeConstraints(const Function &f) const override; virtual std::unique_ptr dup(Function &f, const std::string &suffix) const = 0; }; diff --git a/ir/memory.h b/ir/memory.h index f41857627..6efcfccbd 100644 --- a/ir/memory.h +++ b/ir/memory.h @@ -7,11 +7,11 @@ #include "ir/functions.h" #include "ir/pointer.h" #include "ir/state_value.h" -#include "ir/type.h" #include "smt/expr.h" #include "smt/exprs.h" #include "util/spaceship.h" #include +#include #include #include #include diff --git a/ir/precondition.cpp b/ir/precondition.cpp index 7b66b2a47..0a9959791 100644 --- a/ir/precondition.cpp +++ b/ir/precondition.cpp @@ -14,7 +14,7 @@ using namespace util; namespace IR { -expr Predicate::getTypeConstraints() const { +expr Predicate::getTypeConstraints(const Function &f) const { return true; } @@ -145,10 +145,10 @@ expr FnPred::toSMT(State &s) const { return r; } -expr FnPred::getTypeConstraints() const { +expr FnPred::getTypeConstraints(const Function &f) const { expr r(true); for (auto a : args) { - r &= a->getTypeConstraints(); + r &= a->getTypeConstraints(f); } switch (fn) { case AddNSW: @@ -210,8 +210,8 @@ expr CmpPred::toSMT(State &s) const { return { ap && bp && std::move(r) }; } -expr CmpPred::getTypeConstraints() const { - return lhs.getTypeConstraints() && +expr CmpPred::getTypeConstraints(const Function &f) const { + return lhs.getTypeConstraints(f) && lhs.getType().enforceIntType() && lhs.getType() == rhs.getType(); } diff --git a/ir/precondition.h b/ir/precondition.h index 808d726bd..7e3877a86 100644 --- a/ir/precondition.h +++ b/ir/precondition.h @@ -17,7 +17,7 @@ class Predicate { public: virtual void print(std::ostream &os) const = 0; virtual smt::expr toSMT(State &s) const = 0; - virtual smt::expr getTypeConstraints() const; + virtual smt::expr getTypeConstraints(const Function &f) const; virtual void fixupTypes(const smt::Model &m); virtual ~Predicate() {} }; @@ -49,7 +49,7 @@ class FnPred final : public Predicate { FnPred(std::string_view name, std::vector &&args); void print(std::ostream &os) const override; smt::expr toSMT(State &s) const override; - smt::expr getTypeConstraints() const override; + smt::expr getTypeConstraints(const Function &f) const override; void fixupTypes(const smt::Model &m) override; }; @@ -73,7 +73,7 @@ class CmpPred final : public Predicate { void print(std::ostream &os) const override; smt::expr toSMT(State &s) const override; - smt::expr getTypeConstraints() const override; + smt::expr getTypeConstraints(const Function &f) const override; void fixupTypes(const smt::Model &m) override; }; diff --git a/ir/state.cpp b/ir/state.cpp index 3e940eb65..da6f6c894 100644 --- a/ir/state.cpp +++ b/ir/state.cpp @@ -255,8 +255,21 @@ State::State(const Function &f, bool source) : f(f), source(source), memory(*this), fp_rounding_mode(expr::mkVar("fp_rounding_mode", 3)), fp_denormal_mode(expr::mkVar("fp_denormal_mode", 2)), + vscale_data(vscaleFromAttr(f.getFnAttrs().vscaleRange)), return_val(DisjointExpr(f.getType().getDummyValue(false))) {} +expr State::vscaleFromAttr( + std::optional> vscaleAttr) { + if (vscaleAttr) { + auto [low, high] = *vscaleAttr; + unsigned r = 0; + for (unsigned i = ilog2(low); i <= ilog2(high); ++i) + r |= 1 << i; + return expr::mkUInt(r, var_vector_elements); + } + return expr::mkVscaleMin(); +} + void State::resetGlobals() { Memory::resetGlobals(); } @@ -697,7 +710,7 @@ bool State::isAsmMode() const { expr State::getPath(BasicBlock &bb) const { if (&f.getFirstBB() == &bb) return true; - + auto I = predecessor_data.find(&bb); if (I == predecessor_data.end()) return false; // Block is unreachable diff --git a/ir/state.h b/ir/state.h index 37ce7ff0b..49edbc067 100644 --- a/ir/state.h +++ b/ir/state.h @@ -182,6 +182,9 @@ class State { std::array tmp_values; unsigned i_tmp_values = 0; // next available position in tmp_values + // for scalable vectors + smt::expr vscale_data; + void check_enough_tmp_slots(); // return_domain: a boolean expression describing return condition @@ -303,6 +306,9 @@ class State { unsigned indirect_call_hash); auto& getVarArgsData() { return var_args_data.data; } + smt::expr getVscale() const { return vscale_data; } + static smt::expr + vscaleFromAttr(std::optional> vscaleAttr); void doesApproximation(std::string &&name, std::optional e = {}); auto& getApproximations() const { return used_approximations; } diff --git a/ir/type.cpp b/ir/type.cpp index 404fae097..4a3ebaca1 100644 --- a/ir/type.cpp +++ b/ir/type.cpp @@ -2,6 +2,8 @@ // Distributed under the MIT license that can be found in the LICENSE file. #include "ir/type.h" +#include "ir/attrs.h" +#include "ir/function.h" #include "ir/globals.h" #include "ir/state.h" #include "smt/solver.h" @@ -17,8 +19,6 @@ using namespace std; static constexpr unsigned var_type_bits = 3; static constexpr unsigned var_bw_bits = 11; -static constexpr unsigned var_vector_elements = 16; - namespace IR { @@ -297,7 +297,7 @@ StateValue VoidType::getDummyValue(bool non_poison) const { return { false, non_poison }; } -expr VoidType::getTypeConstraints() const { +expr VoidType::getTypeConstraints(const Function &f) const { return true; } @@ -341,7 +341,7 @@ StateValue IntType::getDummyValue(bool non_poison) const { return { expr::mkUInt(0, bits()), non_poison }; } -expr IntType::getTypeConstraints() const { +expr IntType::getTypeConstraints(const Function &f) const { // since size cannot be unbounded, limit it between 1 and 64 bits if undefined auto bw = sizeVar(); auto r = bw != 0; @@ -570,7 +570,7 @@ StateValue FloatType::getDummyValue(bool non_poison) const { return { expr::mkUInt(0, bits()), non_poison }; } -expr FloatType::getTypeConstraints() const { +expr FloatType::getTypeConstraints(const Function &f) const { if (defined) return true; @@ -675,7 +675,7 @@ StateValue PtrType::getDummyValue(bool non_poison) const { return { expr::mkUInt(0, bits()), non_poison }; } -expr PtrType::getTypeConstraints() const { +expr PtrType::getTypeConstraints(const Function &f) const { return sizeVar() == bits(); } @@ -894,10 +894,10 @@ StateValue AggregateType::getDummyValue(bool non_poison) const { return aggregateVals(vals); } -expr AggregateType::getTypeConstraints() const { +expr AggregateType::getTypeConstraints(const Function &f) const { expr r(true), elems = numElements(); for (unsigned i = 0, e = children.size(); i != e; ++i) { - r &= elems.ugt(i).implies(children[i]->getTypeConstraints()); + r &= elems.ugt(i).implies(children[i]->getTypeConstraints(f)); } if (!defined) r &= elems.ule(4); @@ -1077,11 +1077,12 @@ void ArrayType::print(ostream &os) const { } } - -VectorType::VectorType(string &&name, unsigned elements, Type &elementTy) - : AggregateType(std::move(name), false) { - assert(elements != 0); - this->elements = elements; +VectorType::VectorType(string &&name, unsigned minElems, Type &elementTy, + bool isScalableTy) + : AggregateType(std::move(name), false) { + assert(minElems != 0); + this->isScalableTy = isScalableTy; + this->elements = minElems; defined = true; children.resize(elements, &elementTy); is_padding.resize(elements, false); @@ -1133,9 +1134,19 @@ StateValue VectorType::update(const StateValue &vector, (vector.non_poison & mask_np) | np_shifted}); } -expr VectorType::getTypeConstraints() const { +expr VectorType::getTypeConstraints(const Function &f) const { + auto vscaleAttr = f.getFnAttrs().vscaleRange; + if (isScalable()) { + // TODO: if we don't have a vscale_range on the function, fail the type + // check for now. + // If we don't havethe underlying storage for the high range of the vscale, + // fail the type check. + if (!vscaleAttr || vscaleAttr->second > var_vector_max_vscale) + return false; + } + auto &elementTy = *children[0]; - expr r = AggregateType::getTypeConstraints() && + expr r = AggregateType::getTypeConstraints(f) && (elementTy.enforceIntType() || elementTy.enforceFloatType() || elementTy.enforcePtrType()) && @@ -1146,6 +1157,9 @@ expr VectorType::getTypeConstraints() const { r &= numElements().ugt(i).implies(elementTy == *children[i]); } + // TODO: remove once scalable vectors are supported. + r &= !isScalable(); + return r; } @@ -1157,6 +1171,16 @@ bool VectorType::isVectorType() const { return true; } +expr VectorType::operator==(const VectorType &rhs) const { + expr res = this->AggregateType::operator==(rhs); + res &= isScalable() == rhs.isScalable(); + return res; +} + +bool VectorType::isScalable() const { + return isScalableTy; +} + expr VectorType::enforceVectorType( const function &enforceElem) const { return enforceElem(*children[0]); @@ -1164,7 +1188,8 @@ expr VectorType::enforceVectorType( void VectorType::print(ostream &os) const { if (elements) - os << '<' << elements << " x " << *children[0] << '>'; + os << '<' << (isScalable() ? "vscale x " : "") << elements << " x " + << *children[0] << '>'; } @@ -1258,14 +1283,14 @@ StateValue SymbolicType::getDummyValue(bool non_poison) const { DISPATCH(getDummyValue(non_poison), UNREACHABLE()); } -expr SymbolicType::getTypeConstraints() const { +expr SymbolicType::getTypeConstraints(const Function &fn) const { expr c(false); - if (i) c |= isInt() && i->getTypeConstraints(); - if (f) c |= isFloat() && f->getTypeConstraints(); - if (p) c |= isPtr() && p->getTypeConstraints(); - if (a) c |= isArray() && a->getTypeConstraints(); - if (v) c |= isVector() && v->getTypeConstraints(); - if (s) c |= isStruct() && s->getTypeConstraints(); + if (i) c |= isInt() && i->getTypeConstraints(fn); + if (f) c |= isFloat() && f->getTypeConstraints(fn); + if (p) c |= isPtr() && p->getTypeConstraints(fn); + if (a) c |= isArray() && a->getTypeConstraints(fn); + if (v) c |= isVector() && v->getTypeConstraints(fn); + if (s) c |= isStruct() && s->getTypeConstraints(fn); return c; } diff --git a/ir/type.h b/ir/type.h index b2b02c9e7..9c9bf21b5 100644 --- a/ir/type.h +++ b/ir/type.h @@ -26,6 +26,7 @@ class SymbolicType; class VectorType; class VoidType; class State; +class Function; struct StateValue; class Type { @@ -49,7 +50,7 @@ class Type { // to use when one needs the corresponding SMT type virtual IR::StateValue getDummyValue(bool non_poison) const = 0; - virtual smt::expr getTypeConstraints() const = 0; + virtual smt::expr getTypeConstraints(const Function &f) const = 0; virtual smt::expr sizeVar() const; virtual smt::expr scalarSize() const; smt::expr operator==(const Type &rhs) const; @@ -137,7 +138,7 @@ class VoidType final : public Type { VoidType() : Type("void") {} unsigned bits() const override; IR::StateValue getDummyValue(bool non_poison) const override; - smt::expr getTypeConstraints() const override; + smt::expr getTypeConstraints(const Function &f) const override; void fixup(const smt::Model &m) override; std::pair refines(State &src_s, State &tgt_s, const StateValue &src, @@ -162,7 +163,7 @@ class IntType final : public Type { unsigned maxSubBitAccess() const; unsigned bits() const override; IR::StateValue getDummyValue(bool non_poison) const override; - smt::expr getTypeConstraints() const override; + smt::expr getTypeConstraints(const Function &f) const override; smt::expr sizeVar() const override; smt::expr operator==(const IntType &rhs) const; void fixup(const smt::Model &m) override; @@ -209,7 +210,7 @@ class FloatType final : public Type { smt::expr isNaN(const smt::expr &v, bool signalling) const; IR::StateValue getDummyValue(bool non_poison) const override; - smt::expr getTypeConstraints() const override; + smt::expr getTypeConstraints(const Function &f) const override; smt::expr sizeVar() const override; smt::expr operator==(const FloatType &rhs) const; void fixup(const smt::Model &m) override; @@ -239,7 +240,7 @@ class PtrType final : public Type { unsigned bits() const override; unsigned np_bits(bool fromInt) const override; IR::StateValue getDummyValue(bool non_poison) const override; - smt::expr getTypeConstraints() const override; + smt::expr getTypeConstraints(const Function &f) const override; smt::expr sizeVar() const override; smt::expr operator==(const PtrType &rhs) const; void fixup(const smt::Model &m) override; @@ -293,7 +294,7 @@ class AggregateType : public Type { unsigned np_bits(bool fromInt) const override; // Padding is filled with poison regardless of non_poison. IR::StateValue getDummyValue(bool non_poison) const override; - smt::expr getTypeConstraints() const override; + smt::expr getTypeConstraints(const Function &f) const override; smt::expr sizeVar() const override; smt::expr operator==(const AggregateType &rhs) const; void fixup(const smt::Model &m) override; @@ -332,18 +333,23 @@ class ArrayType final : public AggregateType { class VectorType final : public AggregateType { + bool isScalableTy = false; + public: VectorType(std::string &&name) : AggregateType(std::move(name)) {} - VectorType(std::string &&name, unsigned elements, Type &elementTy); + VectorType(std::string &&name, unsigned minElems, Type &elementTy, + bool isScalableTy = false); IR::StateValue extract(const IR::StateValue &vector, const smt::expr &index) const; IR::StateValue update(const IR::StateValue &vector, const IR::StateValue &val, const smt::expr &idx) const; - smt::expr getTypeConstraints() const override; + smt::expr getTypeConstraints(const Function &f) const override; smt::expr scalarSize() const override; bool isVectorType() const override; + smt::expr operator==(const VectorType &rhs) const; + bool isScalable() const; smt::expr enforceVectorType( const std::function &enforceElem) const override; void print(std::ostream &os) const override; @@ -384,7 +390,7 @@ class SymbolicType final : public Type { unsigned bits() const override; unsigned np_bits(bool fromInt) const override; IR::StateValue getDummyValue(bool non_poison) const override; - smt::expr getTypeConstraints() const override; + smt::expr getTypeConstraints(const Function &f) const override; smt::expr sizeVar() const override; smt::expr scalarSize() const override; smt::expr operator==(const Type &rhs) const; diff --git a/ir/value.cpp b/ir/value.cpp index dfb98dab5..03a8631d9 100644 --- a/ir/value.cpp +++ b/ir/value.cpp @@ -22,8 +22,8 @@ void Value::rauw(const Value &what, Value &with) { UNREACHABLE(); } -expr Value::getTypeConstraints() const { - return getType().getTypeConstraints(); +expr Value::getTypeConstraints(const Function &f) const { + return getType().getTypeConstraints(f); } void Value::fixupTypes(const Model &m) { @@ -163,8 +163,8 @@ void AggregateValue::rauw(const Value &what, Value &with) { setName(agg_str(getType(), vals)); } -expr AggregateValue::getTypeConstraints() const { - expr r = Value::getTypeConstraints(); +expr AggregateValue::getTypeConstraints(const Function &f) const { + expr r = Value::getTypeConstraints(f); vector types; for (auto *val : vals) { types.emplace_back(&val->getType()); @@ -172,7 +172,7 @@ expr AggregateValue::getTypeConstraints() const { // Instr's type constraints are already generated by BasicBlock's // getTypeConstraints() continue; - r &= val->getTypeConstraints(); + r &= val->getTypeConstraints(f); } return r && getType().enforceAggregateType(&types); } diff --git a/ir/value.h b/ir/value.h index f00afeaee..633906dde 100644 --- a/ir/value.h +++ b/ir/value.h @@ -16,6 +16,7 @@ namespace smt { class Model; } namespace IR { class VoidValue; +class Function; class Value { @@ -37,7 +38,7 @@ class Value { virtual void rauw(const Value &what, Value &with); virtual void print(std::ostream &os) const = 0; virtual StateValue toSMT(State &s) const = 0; - virtual smt::expr getTypeConstraints() const; + virtual smt::expr getTypeConstraints(const Function &f) const; void fixupTypes(const smt::Model &m); static VoidValue voidVal; @@ -109,7 +110,7 @@ class AggregateValue final : public Value { AggregateValue(Type &type, std::vector &&vals); auto& getVals() const { return vals; } void rauw(const Value &what, Value &with) override; - smt::expr getTypeConstraints() const override; + smt::expr getTypeConstraints(const Function &f) const override; void print(std::ostream &os) const override; StateValue toSMT(State &s) const override; }; diff --git a/llvm_util/llvm2alive.cpp b/llvm_util/llvm2alive.cpp index a029025c0..061d7d43b 100644 --- a/llvm_util/llvm2alive.cpp +++ b/llvm_util/llvm2alive.cpp @@ -1737,6 +1737,13 @@ class llvm2alive_ : public llvm::InstVisitor> { attrs.set(FnAttrs::NullPointerIsValid); break; + case llvm::Attribute::VScaleRange: { + auto l = llvmattr.getVScaleRangeMin(); + auto r = llvmattr.getVScaleRangeMax().value_or(l); + attrs.vscaleRange = {l, r}; + break; + } + default: break; } diff --git a/llvm_util/utils.cpp b/llvm_util/utils.cpp index a359edea0..de4e059ed 100644 --- a/llvm_util/utils.cpp +++ b/llvm_util/utils.cpp @@ -198,7 +198,6 @@ Type* llvm_type2alive(const llvm::Type *ty) { } return cache.get(); } - // TODO: non-fixed sized vectors case llvm::Type::FixedVectorTyID: { auto &cache = type_cache[ty]; if (!cache) { @@ -212,6 +211,19 @@ Type* llvm_type2alive(const llvm::Type *ty) { } return cache.get(); } + case llvm::Type::ScalableVectorTyID: { + auto &cache = type_cache[ty]; + if (!cache) { + auto vty = cast(ty); + auto minelems = vty->getElementCount().getKnownMinValue(); + auto ety = llvm_type2alive(vty->getElementType()); + if (!ety || minelems > 1024) + return nullptr; + cache = make_unique("ty_" + to_string(type_id_counter++), + minelems, *ety, true); + } + return cache.get(); + } case llvm::Type::ArrayTyID: { auto &cache = type_cache[ty]; if (!cache) { diff --git a/smt/expr.h b/smt/expr.h index daae4add1..76e87ae2c 100644 --- a/smt/expr.h +++ b/smt/expr.h @@ -12,6 +12,9 @@ #include #include +static constexpr unsigned var_vector_elements = 16; +static constexpr unsigned var_vector_max_vscale = 16; + typedef struct _Z3_context* Z3_context; typedef struct _Z3_func_decl* Z3_decl; typedef struct _Z3_app* Z3_app; @@ -97,6 +100,9 @@ class expr { static expr mkQuadVar(const char *name); static expr mkFreshVar(const char *prefix, const expr &type); + // vscale-specific functions + static expr mkVscaleMin() { return expr::mkUInt(1, var_vector_elements); } + // return a constant value of the given type static expr some(const expr &type); diff --git a/tests/alive-tv/vector/vscale/dse-scalable-fixed-neg.srctgt.ll b/tests/alive-tv/vector/vscale/dse-scalable-fixed-neg.srctgt.ll new file mode 100644 index 000000000..9a29b9f35 --- /dev/null +++ b/tests/alive-tv/vector/vscale/dse-scalable-fixed-neg.srctgt.ll @@ -0,0 +1,15 @@ +; SKIP-IDENTITY + +define void @src(ptr %ptr) vscale_range(1, 2) { + %gep.ptr.16 = getelementptr i64, ptr %ptr, i64 16 + store <2 x i64> zeroinitializer, ptr %gep.ptr.16 + store zeroinitializer, ptr %ptr + ret void +} + +define void @tgt(ptr %ptr) vscale_range(1, 2) { + store zeroinitializer, ptr %ptr + ret void +} + +; ERROR: program doesn't type check! diff --git a/tests/alive-tv/vector/vscale/dse-scalable-fixed.srctgt.ll b/tests/alive-tv/vector/vscale/dse-scalable-fixed.srctgt.ll new file mode 100644 index 000000000..c405dea46 --- /dev/null +++ b/tests/alive-tv/vector/vscale/dse-scalable-fixed.srctgt.ll @@ -0,0 +1,15 @@ +; SKIP-IDENTITY + +define void @src(ptr %ptr) vscale_range(1, 2) { + %gep.ptr.2 = getelementptr i64, ptr %ptr, i64 2 + store <2 x i64> zeroinitializer, ptr %gep.ptr.2 + store zeroinitializer, ptr %ptr + ret void +} + +define void @tgt(ptr %ptr) vscale_range(1, 2) { + store zeroinitializer, ptr %ptr + ret void +} + +; ERROR: program doesn't type check! diff --git a/tests/alive-tv/vector/vscale/dse-scalable-scalable-neg.srctgt.ll b/tests/alive-tv/vector/vscale/dse-scalable-scalable-neg.srctgt.ll new file mode 100644 index 000000000..be4b207d0 --- /dev/null +++ b/tests/alive-tv/vector/vscale/dse-scalable-scalable-neg.srctgt.ll @@ -0,0 +1,15 @@ +; SKIP-IDENTITY + +define void @src(ptr %ptr) vscale_range(1, 4) { + %gep.ptr.8 = getelementptr i64, ptr %ptr, i64 8 + store zeroinitializer, ptr %gep.ptr.8 + store zeroinitializer, ptr %ptr + ret void +} + +define void @tgt(ptr %ptr) vscale_range(1, 4) { + store zeroinitializer, ptr %ptr + ret void +} + +; ERROR: program doesn't type check! diff --git a/tests/alive-tv/vector/vscale/dse-scalable-scalable.srctgt.ll b/tests/alive-tv/vector/vscale/dse-scalable-scalable.srctgt.ll new file mode 100644 index 000000000..95ad3d6b7 --- /dev/null +++ b/tests/alive-tv/vector/vscale/dse-scalable-scalable.srctgt.ll @@ -0,0 +1,15 @@ +; SKIP-IDENTITY + +define void @src(ptr %ptr) vscale_range(1, 4) { + %gep.ptr.2 = getelementptr i64, ptr %ptr, i64 2 + store zeroinitializer, ptr %gep.ptr.2 + store zeroinitializer, ptr %ptr + ret void +} + +define void @tgt(ptr %ptr) vscale_range(1, 4) { + store zeroinitializer, ptr %ptr + ret void +} + +; ERROR: program doesn't type check! diff --git a/tests/alive-tv/vector/vscale/inbounds-poison.srctgt.ll b/tests/alive-tv/vector/vscale/inbounds-poison.srctgt.ll new file mode 100644 index 000000000..35357765b --- /dev/null +++ b/tests/alive-tv/vector/vscale/inbounds-poison.srctgt.ll @@ -0,0 +1,12 @@ +; SKIP-IDENTITY + +define @src( %a) vscale_range(2, 4) { + %v = insertelement %a, i8 -1, i64 2 + ret %v +} + +define @tgt( %a) vscale_range(2, 4) { + ret poison +} + +; ERROR: program doesn't type check! diff --git a/tests/alive-tv/vector/vscale/insert-extract-constvscale.srctgt.ll b/tests/alive-tv/vector/vscale/insert-extract-constvscale.srctgt.ll new file mode 100644 index 000000000..675d811fe --- /dev/null +++ b/tests/alive-tv/vector/vscale/insert-extract-constvscale.srctgt.ll @@ -0,0 +1,13 @@ +; SKIP-IDENTITY + +define i8 @src( %a) vscale_range(4, 4) { + %v = insertelement %a, i8 -1, i64 2 + %r = extractelement %v, i64 2 + ret i8 %r +} + +define i8 @tgt( %a) vscale_range(4, 4) { + ret i8 -1 +} + +; ERROR: program doesn't type check! diff --git a/tests/alive-tv/vector/vscale/insert-extract.srctgt.ll b/tests/alive-tv/vector/vscale/insert-extract.srctgt.ll new file mode 100644 index 000000000..4b9301679 --- /dev/null +++ b/tests/alive-tv/vector/vscale/insert-extract.srctgt.ll @@ -0,0 +1,13 @@ +; SKIP-IDENTITY + +define i8 @src( %a) vscale_range(2, 4) { + %v = insertelement %a, i8 -1, i64 2 + %r = extractelement %v, i64 2 + ret i8 %r +} + +define i8 @tgt( %a) vscale_range(2, 4) { + ret i8 -1 +} + +; ERROR: program doesn't type check! diff --git a/tests/alive-tv/vector/vscale/out-of-bounds-poison.srctgt.ll b/tests/alive-tv/vector/vscale/out-of-bounds-poison.srctgt.ll new file mode 100644 index 000000000..7cf8207a0 --- /dev/null +++ b/tests/alive-tv/vector/vscale/out-of-bounds-poison.srctgt.ll @@ -0,0 +1,12 @@ +; SKIP-IDENTITY + +define @src( %a) vscale_range(1, 2) { + %v = insertelement %a, i8 -2, i64 3 + ret %v +} + +define @tgt( %a) vscale_range(1, 2) { + ret poison +} + +; ERROR: program doesn't type check! diff --git a/tests/alive-tv/vector/vscale/poison-constvscale.srctgt.ll b/tests/alive-tv/vector/vscale/poison-constvscale.srctgt.ll new file mode 100644 index 000000000..41fd0cd93 --- /dev/null +++ b/tests/alive-tv/vector/vscale/poison-constvscale.srctgt.ll @@ -0,0 +1,16 @@ +; SKIP-IDENTITY + +define i32 @src(i32 %a) vscale_range(4, 4) { + %poison = add nsw i32 2147483647, 100 + %v = insertelement poison, i32 %a, i64 0 + %v2 = insertelement %v, i32 %poison, i64 1 + %w = extractelement %v2, i64 0 + ret i32 %w +} + +define i32 @tgt(i32 %a) vscale_range(4, 4) { + %poison = add nsw i32 2147483647, 100 + ret i32 %poison +} + +; ERROR: program doesn't type check! diff --git a/tests/alive-tv/vector/vscale/rem-constvscale.srctgt.ll b/tests/alive-tv/vector/vscale/rem-constvscale.srctgt.ll new file mode 100644 index 000000000..60fe126da --- /dev/null +++ b/tests/alive-tv/vector/vscale/rem-constvscale.srctgt.ll @@ -0,0 +1,16 @@ +; SKIP-IDENTITY + +define @src( %x) vscale_range(2, 2) { + %rem.i = srem %x, splat(i8 2) + %cmp.i = icmp slt %rem.i, zeroinitializer + %add.i = select %cmp.i, splat(i8 2), zeroinitializer + ret %add.i +} + +define @tgt( %x) vscale_range(2, 2) { + %rem.i = srem %x, splat(i8 2) + %tmp1 = and %rem.i, splat(i8 2) + ret %tmp1 +} + +; ERROR: program doesn't type check! diff --git a/tests/alive-tv/vector/vscale/rem.srctgt.ll b/tests/alive-tv/vector/vscale/rem.srctgt.ll new file mode 100644 index 000000000..c481cbe3f --- /dev/null +++ b/tests/alive-tv/vector/vscale/rem.srctgt.ll @@ -0,0 +1,16 @@ +; SKIP-IDENTITY + +define @src( %x) vscale_range(1, 2) { + %rem.i = srem %x, splat(i8 2) + %cmp.i = icmp slt %rem.i, zeroinitializer + %add.i = select %cmp.i, splat(i8 2), zeroinitializer + ret %add.i +} + +define @tgt( %x) vscale_range(1, 2) { + %rem.i = srem %x, splat(i8 2) + %tmp1 = and %rem.i, splat(i8 2) + ret %tmp1 +} + +; ERROR: program doesn't type check! diff --git a/tests/alive-tv/vector/vscale/typecheck-missing-vscale-range.srctgt.ll b/tests/alive-tv/vector/vscale/typecheck-missing-vscale-range.srctgt.ll new file mode 100644 index 000000000..3c70ae8dd --- /dev/null +++ b/tests/alive-tv/vector/vscale/typecheck-missing-vscale-range.srctgt.ll @@ -0,0 +1,12 @@ +; SKIP-IDENTITY + +define @src( %a) { + %v = insertelement %a, i8 -2, i64 3 + ret %v +} + +define @tgt( %a) { + ret poison +} + +; ERROR: program doesn't type check! diff --git a/tests/alive-tv/vector/vscale/typecheck-scalable-non-scalable.srctgt.ll b/tests/alive-tv/vector/vscale/typecheck-scalable-non-scalable.srctgt.ll new file mode 100644 index 000000000..e700e7fdd --- /dev/null +++ b/tests/alive-tv/vector/vscale/typecheck-scalable-non-scalable.srctgt.ll @@ -0,0 +1,12 @@ +; SKIP-IDENTITY + +define @src( %a) vscale_range(1, 2) { + %v = insertelement %a, i8 -2, i64 3 + ret %v +} + +define <1 x i8> @tgt(<1 x i8> %a) vscale_range(1, 2) { + ret <1 x i8> poison +} + +; ERROR: program doesn't type check! diff --git a/tools/transform.cpp b/tools/transform.cpp index ad45c6511..6d6a894d7 100644 --- a/tools/transform.cpp +++ b/tools/transform.cpp @@ -1497,7 +1497,7 @@ TypingAssignments TransformVerify::getTypings() const { auto c = t.src.getTypeConstraints() && t.tgt.getTypeConstraints(); if (t.precondition) - c &= t.precondition->getTypeConstraints(); + c &= t.precondition->getTypeConstraints(t.src); // return type c &= t.src.getType() == t.tgt.getType();