Skip to content

Commit

Permalink
ir: add syntactic support for vscale, vscale_range
Browse files Browse the repository at this point in the history
In an effort to support checking of programs with scalable vectors, add
syntactic support for vscale and vscale_range, making the vscale an
smt::expr in State. For the moment, we fail to type-check any program
using scalable vectors.
  • Loading branch information
artagnon committed Dec 9, 2024
1 parent 08285b3 commit 0aa59bd
Show file tree
Hide file tree
Showing 33 changed files with 361 additions and 98 deletions.
4 changes: 4 additions & 0 deletions ir/attrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,10 @@ ostream& operator<<(ostream &os, const FnAttrs &attr) {
os << ", " << attr.allocsize_1;
os << ')';
}
if (attr.vscaleRange) {
auto [low, high] = *attr.vscaleRange;
os << " vscale_range(" << low << ", " << high << ')';
}

attr.fp_denormal.print(os);
if (attr.fp_denormal32)
Expand Down
2 changes: 2 additions & 0 deletions ir/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ class FnAttrs final {
AllocSize = 1 << 12, ZeroExt = 1<<13,
SignExt = 1<<14, NoFPClass = 1<<15, Asm = 1<<16 };

std::optional<std::pair<uint16_t, uint16_t>> vscaleRange;

FnAttrs(unsigned bits = None) : bits(bits) {}

bool has(Attribute a) const { return (bits & a) != 0; }
Expand Down
22 changes: 11 additions & 11 deletions ir/constant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@ StateValue IntConst::toSMT(State &s) const {
return { expr::mkInt(get<string>(val).c_str(), bits()), true };
}

expr IntConst::getTypeConstraints() const {
expr IntConst::getTypeConstraints(const Function &f) const {
unsigned min_bits = 0;
if (auto v = get_if<int64_t>(&val))
min_bits = (*v >= 0 ? 63 : 64) - num_sign_bits(*v);

return Value::getTypeConstraints() &&
return Value::getTypeConstraints(f) &&
getType().enforceIntType() &&
getType().sizeVar().uge(min_bits);
}
Expand Down Expand Up @@ -86,8 +86,8 @@ FloatConst::FloatConst(Type &type, string val, bool bit_value)
: Constant(type, bit_value ? int_to_readable_float(type, val) : val),
val(std::move(val)), bit_value(bit_value) {}

expr FloatConst::getTypeConstraints() const {
return Value::getTypeConstraints() &&
expr FloatConst::getTypeConstraints(const Function &f) const {
return Value::getTypeConstraints(f) &&
getType().enforceFloatType();
}

Expand All @@ -108,8 +108,8 @@ StateValue ConstantInput::toSMT(State &s) const {
return { expr::mkVar(getName().c_str(), type), true };
}

expr ConstantInput::getTypeConstraints() const {
return Value::getTypeConstraints() &&
expr ConstantInput::getTypeConstraints(const Function &f) const {
return Value::getTypeConstraints(f) &&
(getType().enforceIntType() || getType().enforceFloatType());
}

Expand Down Expand Up @@ -157,8 +157,8 @@ StateValue ConstantBinOp::toSMT(State &s) const {
return { std::move(val), ap && bp };
}

expr ConstantBinOp::getTypeConstraints() const {
return Value::getTypeConstraints() &&
expr ConstantBinOp::getTypeConstraints(const Function &f) const {
return Value::getTypeConstraints(f) &&
getType().enforceIntType() &&
getType() == lhs.getType() &&
getType() == rhs.getType();
Expand Down Expand Up @@ -210,10 +210,10 @@ StateValue ConstantFn::toSMT(State &s) const {
return { std::move(r), true };
}

expr ConstantFn::getTypeConstraints() const {
expr r = Value::getTypeConstraints();
expr ConstantFn::getTypeConstraints(const Function &f) const {
expr r = Value::getTypeConstraints(f);
for (auto a : args) {
r &= a->getTypeConstraints();
r &= a->getTypeConstraints(f);
}

Type &ty = getType();
Expand Down
10 changes: 5 additions & 5 deletions ir/constant.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class IntConst final : public Constant {
IntConst(Type &type, int64_t val);
IntConst(Type &type, std::string &&val);
StateValue toSMT(State &s) const override;
smt::expr getTypeConstraints() const override;
smt::expr getTypeConstraints(const Function &f) const override;
auto getInt() const { return std::get_if<int64_t>(&val); }
};

Expand All @@ -38,7 +38,7 @@ class FloatConst final : public Constant {
FloatConst(Type &type, std::string val, bool bit_value);

StateValue toSMT(State &s) const override;
smt::expr getTypeConstraints() const override;
smt::expr getTypeConstraints(const Function &f) const override;
};


Expand All @@ -47,7 +47,7 @@ class ConstantInput final : public Constant {
ConstantInput(Type &type, std::string &&name)
: Constant(type, std::move(name)) {}
StateValue toSMT(State &s) const override;
smt::expr getTypeConstraints() const override;
smt::expr getTypeConstraints(const Function &f) const override;
};


Expand All @@ -62,7 +62,7 @@ class ConstantBinOp final : public Constant {
public:
ConstantBinOp(Type &type, Constant &lhs, Constant &rhs, Op op);
StateValue toSMT(State &s) const override;
smt::expr getTypeConstraints() const override;
smt::expr getTypeConstraints(const Function &f) const override;
};


Expand All @@ -73,7 +73,7 @@ class ConstantFn final : public Constant {
public:
ConstantFn(Type &type, std::string_view name, std::vector<Value*> &&args);
StateValue toSMT(State &s) const override;
smt::expr getTypeConstraints() const override;
smt::expr getTypeConstraints(const Function &f) const override;
};

struct ConstantFnException {
Expand Down
2 changes: 1 addition & 1 deletion ir/function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ expr Function::getTypeConstraints() const {
}
for (auto &l : { getConstants(), getInputs(), getUndefs() }) {
for (auto &v : l) {
t &= v.getTypeConstraints();
t &= v.getTypeConstraints(*this);
}
}
return t;
Expand Down
56 changes: 28 additions & 28 deletions ir/instr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ uint64_t getGlobalVarSize(const IR::Value *V) {

namespace IR {

expr Instr::getTypeConstraints() const {
expr Instr::getTypeConstraints(const Function &f) const {
UNREACHABLE();
return {};
}
Expand Down Expand Up @@ -596,7 +596,7 @@ expr BinOp::getTypeConstraints(const Function &f) const {
getType() == rhs->getType();
break;
}
return Value::getTypeConstraints() && std::move(instrconstr);
return Value::getTypeConstraints(f) && std::move(instrconstr);
}

unique_ptr<Instr> BinOp::dup(Function &f, const string &suffix) const {
Expand Down Expand Up @@ -958,7 +958,7 @@ StateValue FpBinOp::toSMT(State &s) const {
}

expr FpBinOp::getTypeConstraints(const Function &f) const {
return Value::getTypeConstraints() &&
return Value::getTypeConstraints(f) &&
getType().enforceFloatOrVectorType() &&
getType() == lhs->getType() &&
getType() == rhs->getType();
Expand Down Expand Up @@ -1086,7 +1086,7 @@ expr UnaryOp::getTypeConstraints(const Function &f) const {
break;
}

return Value::getTypeConstraints() && std::move(instrconstr);
return Value::getTypeConstraints(f) && std::move(instrconstr);
}

static Value* dup_aggregate(Function &f, Value *val) {
Expand Down Expand Up @@ -1213,7 +1213,7 @@ StateValue FpUnaryOp::toSMT(State &s) const {
}

expr FpUnaryOp::getTypeConstraints(const Function &f) const {
return Value::getTypeConstraints() &&
return Value::getTypeConstraints(f) &&
getType() == val->getType() &&
getType().enforceFloatOrVectorType();
}
Expand Down Expand Up @@ -1286,7 +1286,7 @@ StateValue UnaryReductionOp::toSMT(State &s) const {
}

expr UnaryReductionOp::getTypeConstraints(const Function &f) const {
return Value::getTypeConstraints() &&
return Value::getTypeConstraints(f) &&
getType().enforceIntType() &&
val->getType().enforceVectorType(
[this](auto &scalar) { return scalar == getType(); });
Expand Down Expand Up @@ -1405,7 +1405,7 @@ expr TernaryOp::getTypeConstraints(const Function &f) const {
getType().enforceIntOrVectorType();
break;
}
return Value::getTypeConstraints() && instrconstr;
return Value::getTypeConstraints(f) && instrconstr;
}

unique_ptr<Instr> TernaryOp::dup(Function &f, const string &suffix) const {
Expand Down Expand Up @@ -1486,7 +1486,7 @@ StateValue FpTernaryOp::toSMT(State &s) const {
}

expr FpTernaryOp::getTypeConstraints(const Function &f) const {
return Value::getTypeConstraints() &&
return Value::getTypeConstraints(f) &&
getType() == a->getType() &&
getType() == b->getType() &&
getType() == c->getType() &&
Expand Down Expand Up @@ -1557,7 +1557,7 @@ StateValue TestOp::toSMT(State &s) const {
}

expr TestOp::getTypeConstraints(const Function &f) const {
return Value::getTypeConstraints() &&
return Value::getTypeConstraints(f) &&
lhs->getType().enforceFloatOrVectorType() &&
rhs->getType().enforceIntType(32) &&
getType().enforceIntOrVectorType(1) &&
Expand Down Expand Up @@ -1721,7 +1721,7 @@ expr ConversionOp::getTypeConstraints(const Function &f) const {
break;
}

c &= Value::getTypeConstraints();
c &= Value::getTypeConstraints(f);
if (op != BitCast)
c &= getType().enforceVectorTypeEquiv(val->getType());
return c;
Expand Down Expand Up @@ -1965,7 +1965,7 @@ expr FpConversionOp::getTypeConstraints(const Function &f) const {
val->getType().scalarSize().ugt(getType().scalarSize());
break;
}
return Value::getTypeConstraints() && c;
return Value::getTypeConstraints(f) && c;
}

unique_ptr<Instr> FpConversionOp::dup(Function &f, const string &suffix) const {
Expand Down Expand Up @@ -2027,7 +2027,7 @@ StateValue Select::toSMT(State &s) const {
}

expr Select::getTypeConstraints(const Function &f) const {
return Value::getTypeConstraints() &&
return Value::getTypeConstraints(f) &&
cond->getType().enforceIntOrVectorType(1) &&
getType().enforceVectorTypeIff(cond->getType()) &&
(fmath.isNone() ? expr(true) : getType().enforceFloatOrVectorType()) &&
Expand Down Expand Up @@ -2080,7 +2080,7 @@ StateValue ExtractValue::toSMT(State &s) const {
}

expr ExtractValue::getTypeConstraints(const Function &f) const {
auto c = Value::getTypeConstraints() &&
auto c = Value::getTypeConstraints(f) &&
val->getType().enforceAggregateType();

Type *type = &val->getType();
Expand Down Expand Up @@ -2172,7 +2172,7 @@ StateValue InsertValue::toSMT(State &s) const {
}

expr InsertValue::getTypeConstraints(const Function &f) const {
auto c = Value::getTypeConstraints() &&
auto c = Value::getTypeConstraints(f) &&
val->getType().enforceAggregateType() &&
val->getType() == getType();

Expand Down Expand Up @@ -2646,7 +2646,7 @@ StateValue FnCall::toSMT(State &s) const {

expr FnCall::getTypeConstraints(const Function &f) const {
// TODO : also need to name each arg type smt var uniquely
expr ret = Value::getTypeConstraints();
expr ret = Value::getTypeConstraints(f);
if (fnptr)
ret &= fnptr->getType().enforcePtrType();
return ret;
Expand Down Expand Up @@ -2809,7 +2809,7 @@ StateValue ICmp::toSMT(State &s) const {
}

expr ICmp::getTypeConstraints(const Function &f) const {
return Value::getTypeConstraints() &&
return Value::getTypeConstraints(f) &&
getType().enforceIntOrVectorType(1) &&
getType().enforceVectorTypeEquiv(a->getType()) &&
a->getType().enforceIntOrPtrOrVectorType() &&
Expand Down Expand Up @@ -2908,7 +2908,7 @@ StateValue FCmp::toSMT(State &s) const {
}

expr FCmp::getTypeConstraints(const Function &f) const {
return Value::getTypeConstraints() &&
return Value::getTypeConstraints(f) &&
getType().enforceIntOrVectorType(1) &&
getType().enforceVectorTypeEquiv(a->getType()) &&
a->getType().enforceFloatOrVectorType() &&
Expand Down Expand Up @@ -2968,7 +2968,7 @@ StateValue Freeze::toSMT(State &s) const {
}

expr Freeze::getTypeConstraints(const Function &f) const {
return Value::getTypeConstraints() &&
return Value::getTypeConstraints(f) &&
getType() == val->getType();
}

Expand Down Expand Up @@ -3080,7 +3080,7 @@ StateValue Phi::toSMT(State &s) const {
}

expr Phi::getTypeConstraints(const Function &f) const {
auto c = Value::getTypeConstraints();
auto c = Value::getTypeConstraints(f);
for (auto &[val, bb] : values) {
c &= val->getType() == getType();
}
Expand Down Expand Up @@ -3324,7 +3324,7 @@ StateValue Return::toSMT(State &s) const {
}

expr Return::getTypeConstraints(const Function &f) const {
return Value::getTypeConstraints() &&
return Value::getTypeConstraints(f) &&
getType() == val->getType() &&
f.getType() == getType();
}
Expand Down Expand Up @@ -3711,7 +3711,7 @@ StateValue Alloc::toSMT(State &s) const {
}

expr Alloc::getTypeConstraints(const Function &f) const {
return Value::getTypeConstraints() &&
return Value::getTypeConstraints(f) &&
getType().enforcePtrType() &&
size->getType().enforceIntType();
}
Expand Down Expand Up @@ -3967,7 +3967,7 @@ StateValue GEP::toSMT(State &s) const {
}

expr GEP::getTypeConstraints(const Function &f) const {
auto c = Value::getTypeConstraints() &&
auto c = Value::getTypeConstraints(f) &&
getType().enforceVectorTypeIff(ptr->getType()) &&
getType().enforcePtrOrVectorType();
for (auto &[sz, idx] : idxs) {
Expand Down Expand Up @@ -4052,7 +4052,7 @@ StateValue PtrMask::toSMT(State &s) const {
}

expr PtrMask::getTypeConstraints(const Function &f) const {
return Value::getTypeConstraints() &&
return Value::getTypeConstraints(f) &&
ptr->getType().enforcePtrOrVectorType() &&
getType() == ptr->getType() &&
mask->getType().enforceIntOrVectorType() &&
Expand Down Expand Up @@ -4101,7 +4101,7 @@ StateValue Load::toSMT(State &s) const {
}

expr Load::getTypeConstraints(const Function &f) const {
return Value::getTypeConstraints() &&
return Value::getTypeConstraints(f) &&
ptr->getType().enforcePtrType();
}

Expand Down Expand Up @@ -4583,7 +4583,7 @@ StateValue Strlen::toSMT(State &s) const {
}

expr Strlen::getTypeConstraints(const Function &f) const {
return Value::getTypeConstraints() &&
return Value::getTypeConstraints(f) &&
ptr->getType().enforcePtrType() &&
getType().enforceIntType();
}
Expand Down Expand Up @@ -4886,7 +4886,7 @@ StateValue ExtractElement::toSMT(State &s) const {
}

expr ExtractElement::getTypeConstraints(const Function &f) const {
return Value::getTypeConstraints() &&
return Value::getTypeConstraints(f) &&
v->getType().enforceVectorType([&](auto &ty)
{ return ty == getType(); }) &&
idx->getType().enforceIntType();
Expand Down Expand Up @@ -4929,7 +4929,7 @@ StateValue InsertElement::toSMT(State &s) const {
}

expr InsertElement::getTypeConstraints(const Function &f) const {
return Value::getTypeConstraints() &&
return Value::getTypeConstraints(f) &&
getType() == v->getType() &&
v->getType().enforceVectorType([&](auto &ty)
{ return ty == e->getType(); }) &&
Expand Down Expand Up @@ -4984,7 +4984,7 @@ StateValue ShuffleVector::toSMT(State &s) const {
}

expr ShuffleVector::getTypeConstraints(const Function &f) const {
return Value::getTypeConstraints() &&
return Value::getTypeConstraints(f) &&
getType().enforceVectorTypeSameChildTy(v1->getType()) &&
getType().getAsAggregateType()->numElements() == mask.size() &&
v1->getType().enforceVectorType() &&
Expand Down
3 changes: 1 addition & 2 deletions ir/instr.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ class Instr : public Value {
virtual bool propagatesPoison() const = 0;
virtual bool hasSideEffects() const = 0;
virtual bool isTerminator() const;
smt::expr getTypeConstraints() const override;
virtual smt::expr getTypeConstraints(const Function &f) const = 0;
smt::expr getTypeConstraints(const Function &f) const override;
virtual std::unique_ptr<Instr> dup(Function &f,
const std::string &suffix) const = 0;
};
Expand Down
Loading

0 comments on commit 0aa59bd

Please sign in to comment.