Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RISCV] Support __riscv_v_fixed_vlen for vbool types. #76551

Merged
merged 3 commits into from
Jan 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions clang/docs/ReleaseNotes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -975,6 +975,8 @@ RISC-V Support
- Default ABI with F but without D was changed to ilp32f for RV32 and to lp64f
for RV64.

- ``__attribute__((rvv_vector_bits(N))) is now supported for RVV vbool*_t types.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line breaks the documentation build (missing ending ``).


CUDA/HIP Language Changes
^^^^^^^^^^^^^^^^^^^^^^^^^

Expand Down
3 changes: 3 additions & 0 deletions clang/include/clang/AST/Type.h
Original file line number Diff line number Diff line change
Expand Up @@ -3492,6 +3492,9 @@ enum class VectorKind {

/// is RISC-V RVV fixed-length data vector
RVVFixedLengthData,

/// is RISC-V RVV fixed-length mask vector
RVVFixedLengthMask,
};

/// Represents a GCC generic vector type. This type is created using
Expand Down
5 changes: 4 additions & 1 deletion clang/include/clang/Basic/AttrDocs.td
Original file line number Diff line number Diff line change
Expand Up @@ -2415,7 +2415,10 @@ only be a power of 2 between 64 and 65536.
For types where LMUL!=1, ``__riscv_v_fixed_vlen`` needs to be scaled by the LMUL
of the type before passing to the attribute.

``vbool*_t`` types are not supported at this time.
For ``vbool*_t`` types, ``__riscv_v_fixed_vlen`` needs to be divided by the
number from the type name. For example, ``vbool8_t`` needs to use
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

__riscv_v_fixed_vlen needs to be divided by the number from the type name.

Can this be done by compiler?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It could, but the intention was that the value passed to the attribute should be the size of the type. That's the way LMUL is handled. LMUL 2 needs to pass 2*__riscv_v_fixed_vlen.

``__riscv_v_fixed_vlen`` / 8. If the resulting value is not a multiple of 8,
the type is not supported for that value of ``__riscv_v_fixed_vlen``.
}];
}

Expand Down
20 changes: 16 additions & 4 deletions clang/lib/AST/ASTContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1938,7 +1938,8 @@ TypeInfo ASTContext::getTypeInfoImpl(const Type *T) const {
else if (VT->getVectorKind() == VectorKind::SveFixedLengthPredicate)
// Adjust the alignment for fixed-length SVE predicates.
Align = 16;
else if (VT->getVectorKind() == VectorKind::RVVFixedLengthData)
else if (VT->getVectorKind() == VectorKind::RVVFixedLengthData ||
VT->getVectorKind() == VectorKind::RVVFixedLengthMask)
// Adjust the alignment for fixed-length RVV vectors.
Align = std::min<unsigned>(64, Width);
break;
Expand Down Expand Up @@ -9404,7 +9405,9 @@ bool ASTContext::areCompatibleVectorTypes(QualType FirstVec,
Second->getVectorKind() != VectorKind::SveFixedLengthData &&
Second->getVectorKind() != VectorKind::SveFixedLengthPredicate &&
First->getVectorKind() != VectorKind::RVVFixedLengthData &&
Second->getVectorKind() != VectorKind::RVVFixedLengthData)
Second->getVectorKind() != VectorKind::RVVFixedLengthData &&
First->getVectorKind() != VectorKind::RVVFixedLengthMask &&
Second->getVectorKind() != VectorKind::RVVFixedLengthMask)
return true;

return false;
Expand Down Expand Up @@ -9510,8 +9513,11 @@ static uint64_t getRVVTypeSize(ASTContext &Context, const BuiltinType *Ty) {

ASTContext::BuiltinVectorTypeInfo Info = Context.getBuiltinVectorTypeInfo(Ty);

uint64_t EltSize = Context.getTypeSize(Info.ElementType);
uint64_t MinElts = Info.EC.getKnownMinValue();
unsigned EltSize = Context.getTypeSize(Info.ElementType);
if (Info.ElementType == Context.BoolTy)
EltSize = 1;

unsigned MinElts = Info.EC.getKnownMinValue();
return VScale->first * MinElts * EltSize;
}

Expand All @@ -9525,6 +9531,12 @@ bool ASTContext::areCompatibleRVVTypes(QualType FirstType,
auto IsValidCast = [this](QualType FirstType, QualType SecondType) {
if (const auto *BT = FirstType->getAs<BuiltinType>()) {
if (const auto *VT = SecondType->getAs<VectorType>()) {
if (VT->getVectorKind() == VectorKind::RVVFixedLengthMask) {
BuiltinVectorTypeInfo Info = getBuiltinVectorTypeInfo(BT);
return FirstType->isRVVVLSBuiltinType() &&
Info.ElementType == BoolTy &&
getTypeSize(SecondType) == getRVVTypeSize(*this, BT);
}
if (VT->getVectorKind() == VectorKind::RVVFixedLengthData ||
VT->getVectorKind() == VectorKind::Generic)
return FirstType->isRVVVLSBuiltinType() &&
Expand Down
25 changes: 17 additions & 8 deletions clang/lib/AST/ItaniumMangle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3994,7 +3994,8 @@ void CXXNameMangler::mangleAArch64FixedSveVectorType(
}

void CXXNameMangler::mangleRISCVFixedRVVVectorType(const VectorType *T) {
assert(T->getVectorKind() == VectorKind::RVVFixedLengthData &&
assert((T->getVectorKind() == VectorKind::RVVFixedLengthData ||
T->getVectorKind() == VectorKind::RVVFixedLengthMask) &&
"expected fixed-length RVV vector!");

QualType EltType = T->getElementType();
Expand All @@ -4009,7 +4010,10 @@ void CXXNameMangler::mangleRISCVFixedRVVVectorType(const VectorType *T) {
TypeNameOS << "int8";
break;
case BuiltinType::UChar:
TypeNameOS << "uint8";
if (T->getVectorKind() == VectorKind::RVVFixedLengthData)
TypeNameOS << "uint8";
else
TypeNameOS << "bool";
break;
case BuiltinType::Short:
TypeNameOS << "int16";
Expand Down Expand Up @@ -4048,12 +4052,16 @@ void CXXNameMangler::mangleRISCVFixedRVVVectorType(const VectorType *T) {
auto VScale = getASTContext().getTargetInfo().getVScaleRange(
getASTContext().getLangOpts());
unsigned VLen = VScale->first * llvm::RISCV::RVVBitsPerBlock;
TypeNameOS << 'm';
if (VecSizeInBits >= VLen)
TypeNameOS << (VecSizeInBits / VLen);
else
TypeNameOS << 'f' << (VLen / VecSizeInBits);

if (T->getVectorKind() == VectorKind::RVVFixedLengthData) {
TypeNameOS << 'm';
if (VecSizeInBits >= VLen)
TypeNameOS << (VecSizeInBits / VLen);
else
TypeNameOS << 'f' << (VLen / VecSizeInBits);
} else {
TypeNameOS << (VLen / VecSizeInBits);
}
TypeNameOS << "_t";

Out << "9__RVV_VLSI" << 'u' << TypeNameStr.size() << TypeNameStr << "Lj"
Expand Down Expand Up @@ -4093,7 +4101,8 @@ void CXXNameMangler::mangleType(const VectorType *T) {
T->getVectorKind() == VectorKind::SveFixedLengthPredicate) {
mangleAArch64FixedSveVectorType(T);
return;
} else if (T->getVectorKind() == VectorKind::RVVFixedLengthData) {
} else if (T->getVectorKind() == VectorKind::RVVFixedLengthData ||
T->getVectorKind() == VectorKind::RVVFixedLengthMask) {
mangleRISCVFixedRVVVectorType(T);
return;
}
Expand Down
3 changes: 3 additions & 0 deletions clang/lib/AST/JSONNodeDumper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,9 @@ void JSONNodeDumper::VisitVectorType(const VectorType *VT) {
case VectorKind::RVVFixedLengthData:
JOS.attribute("vectorKind", "fixed-length rvv data vector");
break;
case VectorKind::RVVFixedLengthMask:
JOS.attribute("vectorKind", "fixed-length rvv mask vector");
break;
}
}

Expand Down
3 changes: 3 additions & 0 deletions clang/lib/AST/TextNodeDumper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1613,6 +1613,9 @@ void TextNodeDumper::VisitVectorType(const VectorType *T) {
case VectorKind::RVVFixedLengthData:
OS << " fixed-length rvv data vector";
break;
case VectorKind::RVVFixedLengthMask:
OS << " fixed-length rvv mask vector";
break;
}
OS << " " << T->getNumElements();
}
Expand Down
15 changes: 14 additions & 1 deletion clang/lib/AST/Type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2479,6 +2479,9 @@ bool Type::isRVVVLSBuiltinType() const {
IsFP, IsBF) \
case BuiltinType::Id: \
return NF == 1;
#define RVV_PREDICATE_TYPE(Name, Id, SingletonId, NumEls) \
case BuiltinType::Id: \
return true;
#include "clang/Basic/RISCVVTypes.def"
default:
return false;
Expand All @@ -2491,7 +2494,17 @@ QualType Type::getRVVEltType(const ASTContext &Ctx) const {
assert(isRVVVLSBuiltinType() && "unsupported type!");

const BuiltinType *BTy = castAs<BuiltinType>();
return Ctx.getBuiltinVectorTypeInfo(BTy).ElementType;

switch (BTy->getKind()) {
#define RVV_PREDICATE_TYPE(Name, Id, SingletonId, NumEls) \
case BuiltinType::Id: \
return Ctx.UnsignedCharTy;
default:
return Ctx.getBuiltinVectorTypeInfo(BTy).ElementType;
#include "clang/Basic/RISCVVTypes.def"
}

llvm_unreachable("Unhandled type");
}

bool QualType::isPODType(const ASTContext &Context) const {
Expand Down
2 changes: 2 additions & 0 deletions clang/lib/AST/TypePrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,7 @@ void TypePrinter::printVectorBefore(const VectorType *T, raw_ostream &OS) {
printBefore(T->getElementType(), OS);
break;
case VectorKind::RVVFixedLengthData:
case VectorKind::RVVFixedLengthMask:
// FIXME: We prefer to print the size directly here, but have no way
// to get the size of the type.
OS << "__attribute__((__riscv_rvv_vector_bits__(";
Expand Down Expand Up @@ -773,6 +774,7 @@ void TypePrinter::printDependentVectorBefore(
printBefore(T->getElementType(), OS);
break;
case VectorKind::RVVFixedLengthData:
case VectorKind::RVVFixedLengthMask:
// FIXME: We prefer to print the size directly here, but have no way
// to get the size of the type.
OS << "__attribute__((__riscv_rvv_vector_bits__(";
Expand Down
21 changes: 15 additions & 6 deletions clang/lib/CodeGen/Targets/RISCV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,20 +318,28 @@ ABIArgInfo RISCVABIInfo::coerceVLSVector(QualType Ty) const {
assert(Ty->isVectorType() && "expected vector type!");

const auto *VT = Ty->castAs<VectorType>();
assert(VT->getVectorKind() == VectorKind::RVVFixedLengthData &&
"Unexpected vector kind");

assert(VT->getElementType()->isBuiltinType() && "expected builtin type!");

auto VScale =
getContext().getTargetInfo().getVScaleRange(getContext().getLangOpts());

unsigned NumElts = VT->getNumElements();
llvm::Type *EltType;
if (VT->getVectorKind() == VectorKind::RVVFixedLengthMask) {
NumElts *= 8;
EltType = llvm::Type::getInt1Ty(getVMContext());
} else {
assert(VT->getVectorKind() == VectorKind::RVVFixedLengthData &&
"Unexpected vector kind");
EltType = CGT.ConvertType(VT->getElementType());
}

// The MinNumElts is simplified from equation:
// NumElts / VScale =
// (EltSize * NumElts / (VScale * RVVBitsPerBlock))
// * (RVVBitsPerBlock / EltSize)
llvm::ScalableVectorType *ResType =
llvm::ScalableVectorType::get(CGT.ConvertType(VT->getElementType()),
VT->getNumElements() / VScale->first);
llvm::ScalableVectorType::get(EltType, NumElts / VScale->first);
return ABIArgInfo::getDirect(ResType);
}

Expand Down Expand Up @@ -431,7 +439,8 @@ ABIArgInfo RISCVABIInfo::classifyArgumentType(QualType Ty, bool IsFixed,
}

if (const VectorType *VT = Ty->getAs<VectorType>())
if (VT->getVectorKind() == VectorKind::RVVFixedLengthData)
if (VT->getVectorKind() == VectorKind::RVVFixedLengthData ||
VT->getVectorKind() == VectorKind::RVVFixedLengthMask)
return coerceVLSVector(Ty);

// Aggregates which are <= 2*XLen will be passed in registers if possible,
Expand Down
6 changes: 4 additions & 2 deletions clang/lib/Sema/SemaExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11167,7 +11167,8 @@ QualType Sema::CheckVectorOperands(ExprResult &LHS, ExprResult &RHS,
if (VecType->getVectorKind() == VectorKind::SveFixedLengthData ||
VecType->getVectorKind() == VectorKind::SveFixedLengthPredicate)
return true;
if (VecType->getVectorKind() == VectorKind::RVVFixedLengthData) {
if (VecType->getVectorKind() == VectorKind::RVVFixedLengthData ||
VecType->getVectorKind() == VectorKind::RVVFixedLengthMask) {
SVEorRVV = 1;
return true;
}
Expand Down Expand Up @@ -11198,7 +11199,8 @@ QualType Sema::CheckVectorOperands(ExprResult &LHS, ExprResult &RHS,
SecondVecType->getVectorKind() ==
VectorKind::SveFixedLengthPredicate)
return true;
if (SecondVecType->getVectorKind() == VectorKind::RVVFixedLengthData) {
if (SecondVecType->getVectorKind() == VectorKind::RVVFixedLengthData ||
SecondVecType->getVectorKind() == VectorKind::RVVFixedLengthMask) {
SVEorRVV = 1;
return true;
}
Expand Down
21 changes: 15 additions & 6 deletions clang/lib/Sema/SemaType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8542,21 +8542,30 @@ static void HandleRISCVRVVVectorBitsTypeAttr(QualType &CurType,

ASTContext::BuiltinVectorTypeInfo Info =
S.Context.getBuiltinVectorTypeInfo(CurType->castAs<BuiltinType>());
unsigned EltSize = S.Context.getTypeSize(Info.ElementType);
unsigned MinElts = Info.EC.getKnownMinValue();

VectorKind VecKind = VectorKind::RVVFixedLengthData;
unsigned ExpectedSize = VScale->first * MinElts;
QualType EltType = CurType->getRVVEltType(S.Context);
unsigned EltSize = S.Context.getTypeSize(EltType);
unsigned NumElts;
if (Info.ElementType == S.Context.BoolTy) {
NumElts = VecSize / S.Context.getCharWidth();
VecKind = VectorKind::RVVFixedLengthMask;
} else {
ExpectedSize *= EltSize;
NumElts = VecSize / EltSize;
}

// The attribute vector size must match -mrvv-vector-bits.
unsigned ExpectedSize = VScale->first * MinElts * EltSize;
if (VecSize != ExpectedSize) {
if (ExpectedSize % 8 != 0 || VecSize != ExpectedSize) {
S.Diag(Attr.getLoc(), diag::err_attribute_bad_rvv_vector_size)
<< VecSize << ExpectedSize;
Attr.setInvalid();
return;
}

VectorKind VecKind = VectorKind::RVVFixedLengthData;
VecSize /= EltSize;
CurType = S.Context.getVectorType(Info.ElementType, VecSize, VecKind);
CurType = S.Context.getVectorType(EltType, NumElts, VecKind);
}

/// Handle OpenCL Access Qualifier Attribute.
Expand Down
Loading
Loading