Skip to content

Commit

Permalink
[RISCV] Support __riscv_v_fixed_vlen for vbool types. (llvm#76551)
Browse files Browse the repository at this point in the history
This adopts a similar behavior to AArch64 SVE, where bool vectors are
represented as a vector of chars with 1/8 the number of elements. This
ensures the vector always occupies a power of 2 number of bytes.

A consequence of this is that vbool64_t, vbool32_t, and vool16_t can
only be used with a vector length that guarantees at least 8 bits.
  • Loading branch information
topperc authored and tstellar committed Feb 14, 2024
1 parent f5e0ed7 commit 7227952
Show file tree
Hide file tree
Showing 20 changed files with 1,065 additions and 34 deletions.
2 changes: 2 additions & 0 deletions clang/docs/ReleaseNotes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1227,6 +1227,8 @@ RISC-V Support
- Default ABI with F but without D was changed to ilp32f for RV32 and to lp64f
for RV64.

- ``__attribute__((rvv_vector_bits(N))) is now supported for RVV vbool*_t types.
CUDA/HIP Language Changes
^^^^^^^^^^^^^^^^^^^^^^^^^
Expand Down
3 changes: 3 additions & 0 deletions clang/include/clang/AST/Type.h
Original file line number Diff line number Diff line change
Expand Up @@ -3495,6 +3495,9 @@ enum class VectorKind {

/// is RISC-V RVV fixed-length data vector
RVVFixedLengthData,

/// is RISC-V RVV fixed-length mask vector
RVVFixedLengthMask,
};

/// Represents a GCC generic vector type. This type is created using
Expand Down
5 changes: 4 additions & 1 deletion clang/include/clang/Basic/AttrDocs.td
Original file line number Diff line number Diff line change
Expand Up @@ -2424,7 +2424,10 @@ only be a power of 2 between 64 and 65536.
For types where LMUL!=1, ``__riscv_v_fixed_vlen`` needs to be scaled by the LMUL
of the type before passing to the attribute.

``vbool*_t`` types are not supported at this time.
For ``vbool*_t`` types, ``__riscv_v_fixed_vlen`` needs to be divided by the
number from the type name. For example, ``vbool8_t`` needs to use
``__riscv_v_fixed_vlen`` / 8. If the resulting value is not a multiple of 8,
the type is not supported for that value of ``__riscv_v_fixed_vlen``.
}];
}

Expand Down
20 changes: 16 additions & 4 deletions clang/lib/AST/ASTContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1945,7 +1945,8 @@ TypeInfo ASTContext::getTypeInfoImpl(const Type *T) const {
else if (VT->getVectorKind() == VectorKind::SveFixedLengthPredicate)
// Adjust the alignment for fixed-length SVE predicates.
Align = 16;
else if (VT->getVectorKind() == VectorKind::RVVFixedLengthData)
else if (VT->getVectorKind() == VectorKind::RVVFixedLengthData ||
VT->getVectorKind() == VectorKind::RVVFixedLengthMask)
// Adjust the alignment for fixed-length RVV vectors.
Align = std::min<unsigned>(64, Width);
break;
Expand Down Expand Up @@ -9416,7 +9417,9 @@ bool ASTContext::areCompatibleVectorTypes(QualType FirstVec,
Second->getVectorKind() != VectorKind::SveFixedLengthData &&
Second->getVectorKind() != VectorKind::SveFixedLengthPredicate &&
First->getVectorKind() != VectorKind::RVVFixedLengthData &&
Second->getVectorKind() != VectorKind::RVVFixedLengthData)
Second->getVectorKind() != VectorKind::RVVFixedLengthData &&
First->getVectorKind() != VectorKind::RVVFixedLengthMask &&
Second->getVectorKind() != VectorKind::RVVFixedLengthMask)
return true;

return false;
Expand Down Expand Up @@ -9522,8 +9525,11 @@ static uint64_t getRVVTypeSize(ASTContext &Context, const BuiltinType *Ty) {

ASTContext::BuiltinVectorTypeInfo Info = Context.getBuiltinVectorTypeInfo(Ty);

uint64_t EltSize = Context.getTypeSize(Info.ElementType);
uint64_t MinElts = Info.EC.getKnownMinValue();
unsigned EltSize = Context.getTypeSize(Info.ElementType);
if (Info.ElementType == Context.BoolTy)
EltSize = 1;

unsigned MinElts = Info.EC.getKnownMinValue();
return VScale->first * MinElts * EltSize;
}

Expand All @@ -9537,6 +9543,12 @@ bool ASTContext::areCompatibleRVVTypes(QualType FirstType,
auto IsValidCast = [this](QualType FirstType, QualType SecondType) {
if (const auto *BT = FirstType->getAs<BuiltinType>()) {
if (const auto *VT = SecondType->getAs<VectorType>()) {
if (VT->getVectorKind() == VectorKind::RVVFixedLengthMask) {
BuiltinVectorTypeInfo Info = getBuiltinVectorTypeInfo(BT);
return FirstType->isRVVVLSBuiltinType() &&
Info.ElementType == BoolTy &&
getTypeSize(SecondType) == getRVVTypeSize(*this, BT);
}
if (VT->getVectorKind() == VectorKind::RVVFixedLengthData ||
VT->getVectorKind() == VectorKind::Generic)
return FirstType->isRVVVLSBuiltinType() &&
Expand Down
25 changes: 17 additions & 8 deletions clang/lib/AST/ItaniumMangle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3994,7 +3994,8 @@ void CXXNameMangler::mangleAArch64FixedSveVectorType(
}

void CXXNameMangler::mangleRISCVFixedRVVVectorType(const VectorType *T) {
assert(T->getVectorKind() == VectorKind::RVVFixedLengthData &&
assert((T->getVectorKind() == VectorKind::RVVFixedLengthData ||
T->getVectorKind() == VectorKind::RVVFixedLengthMask) &&
"expected fixed-length RVV vector!");

QualType EltType = T->getElementType();
Expand All @@ -4009,7 +4010,10 @@ void CXXNameMangler::mangleRISCVFixedRVVVectorType(const VectorType *T) {
TypeNameOS << "int8";
break;
case BuiltinType::UChar:
TypeNameOS << "uint8";
if (T->getVectorKind() == VectorKind::RVVFixedLengthData)
TypeNameOS << "uint8";
else
TypeNameOS << "bool";
break;
case BuiltinType::Short:
TypeNameOS << "int16";
Expand Down Expand Up @@ -4048,12 +4052,16 @@ void CXXNameMangler::mangleRISCVFixedRVVVectorType(const VectorType *T) {
auto VScale = getASTContext().getTargetInfo().getVScaleRange(
getASTContext().getLangOpts());
unsigned VLen = VScale->first * llvm::RISCV::RVVBitsPerBlock;
TypeNameOS << 'm';
if (VecSizeInBits >= VLen)
TypeNameOS << (VecSizeInBits / VLen);
else
TypeNameOS << 'f' << (VLen / VecSizeInBits);

if (T->getVectorKind() == VectorKind::RVVFixedLengthData) {
TypeNameOS << 'm';
if (VecSizeInBits >= VLen)
TypeNameOS << (VecSizeInBits / VLen);
else
TypeNameOS << 'f' << (VLen / VecSizeInBits);
} else {
TypeNameOS << (VLen / VecSizeInBits);
}
TypeNameOS << "_t";

Out << "9__RVV_VLSI" << 'u' << TypeNameStr.size() << TypeNameStr << "Lj"
Expand Down Expand Up @@ -4093,7 +4101,8 @@ void CXXNameMangler::mangleType(const VectorType *T) {
T->getVectorKind() == VectorKind::SveFixedLengthPredicate) {
mangleAArch64FixedSveVectorType(T);
return;
} else if (T->getVectorKind() == VectorKind::RVVFixedLengthData) {
} else if (T->getVectorKind() == VectorKind::RVVFixedLengthData ||
T->getVectorKind() == VectorKind::RVVFixedLengthMask) {
mangleRISCVFixedRVVVectorType(T);
return;
}
Expand Down
3 changes: 3 additions & 0 deletions clang/lib/AST/JSONNodeDumper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,9 @@ void JSONNodeDumper::VisitVectorType(const VectorType *VT) {
case VectorKind::RVVFixedLengthData:
JOS.attribute("vectorKind", "fixed-length rvv data vector");
break;
case VectorKind::RVVFixedLengthMask:
JOS.attribute("vectorKind", "fixed-length rvv mask vector");
break;
}
}

Expand Down
3 changes: 3 additions & 0 deletions clang/lib/AST/TextNodeDumper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1623,6 +1623,9 @@ void TextNodeDumper::VisitVectorType(const VectorType *T) {
case VectorKind::RVVFixedLengthData:
OS << " fixed-length rvv data vector";
break;
case VectorKind::RVVFixedLengthMask:
OS << " fixed-length rvv mask vector";
break;
}
OS << " " << T->getNumElements();
}
Expand Down
15 changes: 14 additions & 1 deletion clang/lib/AST/Type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2479,6 +2479,9 @@ bool Type::isRVVVLSBuiltinType() const {
IsFP, IsBF) \
case BuiltinType::Id: \
return NF == 1;
#define RVV_PREDICATE_TYPE(Name, Id, SingletonId, NumEls) \
case BuiltinType::Id: \
return true;
#include "clang/Basic/RISCVVTypes.def"
default:
return false;
Expand All @@ -2491,7 +2494,17 @@ QualType Type::getRVVEltType(const ASTContext &Ctx) const {
assert(isRVVVLSBuiltinType() && "unsupported type!");

const BuiltinType *BTy = castAs<BuiltinType>();
return Ctx.getBuiltinVectorTypeInfo(BTy).ElementType;

switch (BTy->getKind()) {
#define RVV_PREDICATE_TYPE(Name, Id, SingletonId, NumEls) \
case BuiltinType::Id: \
return Ctx.UnsignedCharTy;
default:
return Ctx.getBuiltinVectorTypeInfo(BTy).ElementType;
#include "clang/Basic/RISCVVTypes.def"
}

llvm_unreachable("Unhandled type");
}

bool QualType::isPODType(const ASTContext &Context) const {
Expand Down
2 changes: 2 additions & 0 deletions clang/lib/AST/TypePrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,7 @@ void TypePrinter::printVectorBefore(const VectorType *T, raw_ostream &OS) {
printBefore(T->getElementType(), OS);
break;
case VectorKind::RVVFixedLengthData:
case VectorKind::RVVFixedLengthMask:
// FIXME: We prefer to print the size directly here, but have no way
// to get the size of the type.
OS << "__attribute__((__riscv_rvv_vector_bits__(";
Expand Down Expand Up @@ -773,6 +774,7 @@ void TypePrinter::printDependentVectorBefore(
printBefore(T->getElementType(), OS);
break;
case VectorKind::RVVFixedLengthData:
case VectorKind::RVVFixedLengthMask:
// FIXME: We prefer to print the size directly here, but have no way
// to get the size of the type.
OS << "__attribute__((__riscv_rvv_vector_bits__(";
Expand Down
21 changes: 15 additions & 6 deletions clang/lib/CodeGen/Targets/RISCV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -321,20 +321,28 @@ ABIArgInfo RISCVABIInfo::coerceVLSVector(QualType Ty) const {
assert(Ty->isVectorType() && "expected vector type!");

const auto *VT = Ty->castAs<VectorType>();
assert(VT->getVectorKind() == VectorKind::RVVFixedLengthData &&
"Unexpected vector kind");

assert(VT->getElementType()->isBuiltinType() && "expected builtin type!");

auto VScale =
getContext().getTargetInfo().getVScaleRange(getContext().getLangOpts());

unsigned NumElts = VT->getNumElements();
llvm::Type *EltType;
if (VT->getVectorKind() == VectorKind::RVVFixedLengthMask) {
NumElts *= 8;
EltType = llvm::Type::getInt1Ty(getVMContext());
} else {
assert(VT->getVectorKind() == VectorKind::RVVFixedLengthData &&
"Unexpected vector kind");
EltType = CGT.ConvertType(VT->getElementType());
}

// The MinNumElts is simplified from equation:
// NumElts / VScale =
// (EltSize * NumElts / (VScale * RVVBitsPerBlock))
// * (RVVBitsPerBlock / EltSize)
llvm::ScalableVectorType *ResType =
llvm::ScalableVectorType::get(CGT.ConvertType(VT->getElementType()),
VT->getNumElements() / VScale->first);
llvm::ScalableVectorType::get(EltType, NumElts / VScale->first);
return ABIArgInfo::getDirect(ResType);
}

Expand Down Expand Up @@ -437,7 +445,8 @@ ABIArgInfo RISCVABIInfo::classifyArgumentType(QualType Ty, bool IsFixed,
}

if (const VectorType *VT = Ty->getAs<VectorType>())
if (VT->getVectorKind() == VectorKind::RVVFixedLengthData)
if (VT->getVectorKind() == VectorKind::RVVFixedLengthData ||
VT->getVectorKind() == VectorKind::RVVFixedLengthMask)
return coerceVLSVector(Ty);

// Aggregates which are <= 2*XLen will be passed in registers if possible,
Expand Down
6 changes: 4 additions & 2 deletions clang/lib/Sema/SemaExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11142,7 +11142,8 @@ QualType Sema::CheckVectorOperands(ExprResult &LHS, ExprResult &RHS,
if (VecType->getVectorKind() == VectorKind::SveFixedLengthData ||
VecType->getVectorKind() == VectorKind::SveFixedLengthPredicate)
return true;
if (VecType->getVectorKind() == VectorKind::RVVFixedLengthData) {
if (VecType->getVectorKind() == VectorKind::RVVFixedLengthData ||
VecType->getVectorKind() == VectorKind::RVVFixedLengthMask) {
SVEorRVV = 1;
return true;
}
Expand Down Expand Up @@ -11173,7 +11174,8 @@ QualType Sema::CheckVectorOperands(ExprResult &LHS, ExprResult &RHS,
SecondVecType->getVectorKind() ==
VectorKind::SveFixedLengthPredicate)
return true;
if (SecondVecType->getVectorKind() == VectorKind::RVVFixedLengthData) {
if (SecondVecType->getVectorKind() == VectorKind::RVVFixedLengthData ||
SecondVecType->getVectorKind() == VectorKind::RVVFixedLengthMask) {
SVEorRVV = 1;
return true;
}
Expand Down
21 changes: 15 additions & 6 deletions clang/lib/Sema/SemaType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8646,21 +8646,30 @@ static void HandleRISCVRVVVectorBitsTypeAttr(QualType &CurType,

ASTContext::BuiltinVectorTypeInfo Info =
S.Context.getBuiltinVectorTypeInfo(CurType->castAs<BuiltinType>());
unsigned EltSize = S.Context.getTypeSize(Info.ElementType);
unsigned MinElts = Info.EC.getKnownMinValue();

VectorKind VecKind = VectorKind::RVVFixedLengthData;
unsigned ExpectedSize = VScale->first * MinElts;
QualType EltType = CurType->getRVVEltType(S.Context);
unsigned EltSize = S.Context.getTypeSize(EltType);
unsigned NumElts;
if (Info.ElementType == S.Context.BoolTy) {
NumElts = VecSize / S.Context.getCharWidth();
VecKind = VectorKind::RVVFixedLengthMask;
} else {
ExpectedSize *= EltSize;
NumElts = VecSize / EltSize;
}

// The attribute vector size must match -mrvv-vector-bits.
unsigned ExpectedSize = VScale->first * MinElts * EltSize;
if (VecSize != ExpectedSize) {
if (ExpectedSize % 8 != 0 || VecSize != ExpectedSize) {
S.Diag(Attr.getLoc(), diag::err_attribute_bad_rvv_vector_size)
<< VecSize << ExpectedSize;
Attr.setInvalid();
return;
}

VectorKind VecKind = VectorKind::RVVFixedLengthData;
VecSize /= EltSize;
CurType = S.Context.getVectorType(Info.ElementType, VecSize, VecKind);
CurType = S.Context.getVectorType(EltType, NumElts, VecKind);
}

/// Handle OpenCL Access Qualifier Attribute.
Expand Down
Loading

0 comments on commit 7227952

Please sign in to comment.