Skip to content

Commit

Permalink
ir: add syntactic support for scalable vectors
Browse files Browse the repository at this point in the history
As a first step towards verifying programs with scalable vectors, add
syntactic support with vscale implicitly assumed to be 1. This
assumption should be conservatively correct for the purposes of
verification.
  • Loading branch information
artagnon committed Nov 28, 2024
1 parent 708948d commit dcb24fe
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 10 deletions.
28 changes: 20 additions & 8 deletions ir/type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1077,14 +1077,15 @@ void ArrayType::print(ostream &os) const {
}
}


VectorType::VectorType(string &&name, unsigned elements, Type &elementTy)
: AggregateType(std::move(name), false) {
assert(elements != 0);
this->elements = elements;
VectorType::VectorType(string &&name, unsigned minElts, Type &elementTy,
bool isScalableTy)
: AggregateType(std::move(name), false) {
assert(minElts != 0);
this->elements = minElts;
this->isScalableTy = isScalableTy;
defined = true;
children.resize(elements, &elementTy);
is_padding.resize(elements, false);
children.resize(minElts, &elementTy);
is_padding.resize(minElts, false);
}

StateValue VectorType::extract(const StateValue &vector,
Expand Down Expand Up @@ -1157,14 +1158,25 @@ bool VectorType::isVectorType() const {
return true;
}

expr VectorType::operator==(const VectorType &rhs) const {
expr res = this->AggregateType::operator==(rhs);
res &= isScalable() == rhs.isScalable();
return res;
}

bool VectorType::isScalable() const {
return isScalableTy;
}

expr VectorType::enforceVectorType(
const function<expr(const Type&)> &enforceElem) const {
return enforceElem(*children[0]);
}

void VectorType::print(ostream &os) const {
if (elements)
os << '<' << elements << " x " << *children[0] << '>';
os << '<' << (isScalable() ? "vscale x " : "") << elements << " x "
<< *children[0] << '>';
}


Expand Down
9 changes: 8 additions & 1 deletion ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -332,9 +332,12 @@ class ArrayType final : public AggregateType {


class VectorType final : public AggregateType {
bool isScalableTy = false;

public:
VectorType(std::string &&name) : AggregateType(std::move(name)) {}
VectorType(std::string &&name, unsigned elements, Type &elementTy);
VectorType(std::string &&name, unsigned minElts, Type &elementTy,
bool isScalableTy = false);

IR::StateValue extract(const IR::StateValue &vector,
const smt::expr &index) const;
Expand All @@ -344,6 +347,10 @@ class VectorType final : public AggregateType {
smt::expr getTypeConstraints() const override;
smt::expr scalarSize() const override;
bool isVectorType() const override;
smt::expr operator==(const VectorType &rhs) const;

// TODO: handle vscale values other than 1.
bool isScalable() const;
smt::expr enforceVectorType(
const std::function<smt::expr(const Type&)> &enforceElem) const override;
void print(std::ostream &os) const override;
Expand Down
14 changes: 13 additions & 1 deletion llvm_util/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,6 @@ Type* llvm_type2alive(const llvm::Type *ty) {
}
return cache.get();
}
// TODO: non-fixed sized vectors
case llvm::Type::FixedVectorTyID: {
auto &cache = type_cache[ty];
if (!cache) {
Expand All @@ -212,6 +211,19 @@ Type* llvm_type2alive(const llvm::Type *ty) {
}
return cache.get();
}
case llvm::Type::ScalableVectorTyID: {
auto &cache = type_cache[ty];
if (!cache) {
auto vty = cast<llvm::VectorType>(ty);
auto minelems = vty->getElementCount().getKnownMinValue();
auto ety = llvm_type2alive(vty->getElementType());
if (!ety || minelems > 1024)
return nullptr;
cache = make_unique<VectorType>("ty_" + to_string(type_id_counter++),
minelems, *ety, true);
}
return cache.get();
}
case llvm::Type::ArrayTyID: {
auto &cache = type_cache[ty];
if (!cache) {
Expand Down
14 changes: 14 additions & 0 deletions tests/alive-tv/vector/vscale/poison.srctgt.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
define i32 @src(i32 %a) {
%poison = add nsw i32 2147483647, 100
%v = insertelement <vscale x 2 x i32> undef, i32 %a, i32 0
%v2 = insertelement <vscale x 2 x i32> %v, i32 %poison, i32 1
%w = extractelement <vscale x 2 x i32> %v2, i32 0
ret i32 %w
}

define i32 @tgt(i32 %a) {
%poison = add nsw i32 2147483647, 100
ret i32 %poison
}

; ERROR: Target is more poisonous than source
12 changes: 12 additions & 0 deletions tests/alive-tv/vector/vscale/rem.srctgt.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
define <vscale x 2 x i8> @src(<vscale x 2 x i8> %x) {
%rem.i = srem <vscale x 2 x i8> %x, splat(i8 2)
%cmp.i = icmp slt <vscale x 2 x i8> %rem.i, zeroinitializer
%add.i = select <vscale x 2 x i1> %cmp.i, <vscale x 2 x i8> splat(i8 2), <vscale x 2 x i8> zeroinitializer
ret <vscale x 2 x i8> %add.i
}

define <vscale x 2 x i8> @tgt(<vscale x 2 x i8> %x) {
%rem.i = srem <vscale x 2 x i8> %x, splat(i8 2)
%tmp1 = and <vscale x 2 x i8> %rem.i, splat(i8 2)
ret <vscale x 2 x i8> %tmp1
}
10 changes: 10 additions & 0 deletions tests/alive-tv/vector/vscale/typecheck.srctgt.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
define <vscale x 2 x ptr> @src() {
%a = getelementptr i8, ptr undef, <vscale x 2 x i64> zeroinitializer
ret <vscale x 2 x ptr> %a
}

define <2 x ptr> @tgt() {
ret <2 x ptr> undef
}

; CHECK: ERROR: program doesn't type check!
8 changes: 8 additions & 0 deletions tests/alive-tv/vector/vscale/undef.srctgt.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
define <vscale x 2 x ptr> @src() {
%a = getelementptr i8, ptr undef, <vscale x 2 x i64> zeroinitializer
ret <vscale x 2 x ptr> %a
}

define <vscale x 2 x ptr> @tgt() {
ret <vscale x 2 x ptr> undef
}

0 comments on commit dcb24fe

Please sign in to comment.