Skip to content

Commit

Permalink
[Backport to 16] fpbuiltin-max-error support
Browse files Browse the repository at this point in the history
Changes were cherry-picked from the following commit: c6fe12b
Also cherry-picked fixes from:

This changes add SPIR-V translator support for the SPIR-V extension documented here: KhronosGroup/SPIRV-Registry#193. This extension adds one decoration to represent maximum error for FP operations and adds the related Capability. SPIRV Headers support for representing this in SPIR-V: KhronosGroup/SPIRV-Headers#363

intel/llvm#8134 added a new call-site attribute associated with FP builtin intrinsics. This attribute is named 'fpbuiltin-max-error'. Following example shows how this extension is supported in the translator. The input LLVM IR uses new LLVM builtin calls to represent FP operations. An attribute named 'fpbuiltin-max-error' is used to represent the max-error allowed in the FP operation. Example
Input LLVM:
%t6 = call float @llvm.fpbuiltin.sin.f32(float %f1) KhronosGroup#2 attributes KhronosGroup#2 = { "fpbuiltin-max-error"="2.5" }

This is translated into a SPIR-V instruction (for add/sub/mul/div/rem) and OpenCl extended instruction for other instructions. A decoration to represent the max-error is attached to the SPIR-V instruction.

SPIR-V code:
4 Decorate 97 FPMaxErrorDecorationINTEL 1075838976 6 ExtInst 2 97 1 sin 88

No new support is added to support translating this SPIR_V back to LLVM. Existing support is used. The decoration is translated back into named metadata associated with the LLVM instruction. This can be readily consumed by backends.

Based on input from @andykaylor, we emit attributes when the FP operation is translated back to a call to a builtin function and emit metadata otherwise.

Translated LLVM code for basic math functions (add/sub/mul/div/rem): %t6 = fmul float %f1, %f2, !fpbuiltin-max-error !7 !7 = !{!"2.500000"}

Translated LLVM code for other math functions:
%t6 = call spir_func float @_Z3sinf(float %f1) KhronosGroup#3
attributes KhronosGroup#3 = { "fpbuiltin-max-error"="4.000000" }
  • Loading branch information
MiloszSkobejko committed Mar 20, 2024
1 parent 101de01 commit 09f4941
Show file tree
Hide file tree
Showing 11 changed files with 500 additions and 22 deletions.
1 change: 1 addition & 0 deletions include/LLVMSPIRVExtensions.inc
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,5 @@ EXT(SPV_INTEL_tensor_float32_conversion) // TODO: to remove old extension
EXT(SPV_INTEL_tensor_float32_rounding)
EXT(SPV_EXT_relaxed_printf_string_address_space)
EXT(SPV_INTEL_fpga_argument_interfaces)
EXT(SPV_INTEL_fp_max_error)
EXT(SPV_INTEL_cache_controls)
45 changes: 24 additions & 21 deletions lib/SPIRV/SPIRVBuiltinHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,9 @@ BuiltinCallMutator::BuiltinCallMutator(
CallInst *CI, std::string FuncName, ManglingRules Rules,
std::function<std::string(StringRef)> NameMapFn)
: CI(CI), FuncName(FuncName),
Attrs(CI->getCalledFunction()->getAttributes()), ReturnTy(CI->getType()),
Args(CI->args()), Rules(Rules), Builder(CI) {
Attrs(CI->getCalledFunction()->getAttributes()),
CallAttrs(CI->getAttributes()), ReturnTy(CI->getType()), Args(CI->args()),
Rules(Rules), Builder(CI) {
bool DidDemangle = getParameterTypes(CI->getCalledFunction(), PointerTypes,
std::move(NameMapFn));
if (!DidDemangle) {
Expand All @@ -78,8 +79,8 @@ BuiltinCallMutator::BuiltinCallMutator(
BuiltinCallMutator::BuiltinCallMutator(BuiltinCallMutator &&Other)
: CI(Other.CI), FuncName(std::move(Other.FuncName)),
MutateRet(std::move(Other.MutateRet)), Attrs(Other.Attrs),
ReturnTy(Other.ReturnTy), Args(std::move(Other.Args)),
PointerTypes(std::move(Other.PointerTypes)),
CallAttrs(Other.CallAttrs), ReturnTy(Other.ReturnTy),
Args(std::move(Other.Args)), PointerTypes(std::move(Other.PointerTypes)),
Rules(std::move(Other.Rules)), Builder(CI) {
// Clear the other's CI instance so that it knows not to construct the actual
// call.
Expand All @@ -102,6 +103,13 @@ Value *BuiltinCallMutator::doConversion() {
CallInst *NewCall =
Builder.Insert(addCallInst(CI->getModule(), FuncName, ReturnTy, Args,
&Attrs, nullptr, Mangler.get()));
NewCall->copyMetadata(*CI);
NewCall->setAttributes(CallAttrs);
NewCall->setTailCall(CI->isTailCall());
if (CI->hasFnAttr("fpbuiltin-max-error")) {
auto Attr = CI->getFnAttr("fpbuiltin-max-error");
NewCall->addFnAttr(Attr);
}
Value *Result = MutateRet ? MutateRet(Builder, NewCall) : NewCall;
Result->takeName(CI);
if (!CI->getType()->isVoidTy())
Expand All @@ -116,6 +124,8 @@ BuiltinCallMutator &BuiltinCallMutator::setArgs(ArrayRef<Value *> NewArgs) {
// Retain only the function attributes, not any parameter attributes.
Attrs = AttributeList::get(CI->getContext(), Attrs.getFnAttrs(),
Attrs.getRetAttrs(), {});
CallAttrs = AttributeList::get(CI->getContext(), CallAttrs.getFnAttrs(),
CallAttrs.getRetAttrs(), {});
Args.clear();
PointerTypes.clear();
for (Value *Arg : NewArgs) {
Expand Down Expand Up @@ -169,6 +179,8 @@ BuiltinCallMutator &BuiltinCallMutator::insertArg(unsigned Index,
PointerTypes.insert(PointerTypes.begin() + Index, Arg.second);
moveAttributes(CI->getContext(), Attrs, Index, Args.size() - Index,
Index + 1);
moveAttributes(CI->getContext(), CallAttrs, Index, Args.size() - Index,
Index + 1);
return *this;
}

Expand All @@ -177,30 +189,21 @@ BuiltinCallMutator &BuiltinCallMutator::replaceArg(unsigned Index,
Args[Index] = Arg.first;
PointerTypes[Index] = Arg.second;
Attrs = Attrs.removeParamAttributes(CI->getContext(), Index);
CallAttrs = CallAttrs.removeParamAttributes(CI->getContext(), Index);
return *this;
}

BuiltinCallMutator &BuiltinCallMutator::removeArg(unsigned Index) {
// If the argument being dropped is the last one, there is nothing to move, so
// just remove the attributes.
auto &Ctx = CI->getContext();
if (Index == Args.size() - 1) {
// TODO: Remove this workaround when LLVM fixes
// https://github.com/llvm/llvm-project/issues/59746 on
// AttributeList::removeParamAttributes function.
// AttributeList::removeParamAttributes function sets attribute at
// specified index empty so that return value of
// AttributeList::getNumAttrSet() keeps unchanged after that call. When call
// BuiltinCallMutator::removeArg function, there is assert failure on
// BuiltinCallMutator::doConversion() since new CallInst removed arg but
// still holds attribute of that removed arg.
SmallVector<AttributeSet, 4> ArgAttrs;
for (unsigned I = 0; I < Index; ++I)
ArgAttrs.push_back(Attrs.getParamAttrs(I));
Attrs = AttributeList::get(CI->getContext(), Attrs.getFnAttrs(),
Attrs.getRetAttrs(), ArgAttrs);
} else
moveAttributes(CI->getContext(), Attrs, Index + 1, Args.size() - Index - 1,
Index);
Attrs = Attrs.removeParamAttributes(Ctx, Index);
CallAttrs = CallAttrs.removeParamAttributes(Ctx, Index);
} else {
moveAttributes(Ctx, Attrs, Index + 1, Args.size() - Index - 1, Index);
moveAttributes(Ctx, CallAttrs, Index + 1, Args.size() - Index - 1, Index);
}
Args.erase(Args.begin() + Index);
PointerTypes.erase(PointerTypes.begin() + Index);
return *this;
Expand Down
4 changes: 3 additions & 1 deletion lib/SPIRV/SPIRVBuiltinHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,10 @@ class BuiltinCallMutator {
// the new instruction is created.
std::function<llvm::Value *(llvm::IRBuilder<> &, llvm::CallInst *)> MutateRet;
typedef decltype(MutateRet) MutateRetFuncTy;
// The attribute list for the new call instruction.
// The attribute list for the new called function.
llvm::AttributeList Attrs;
// The attribute list for the new call instruction.
llvm::AttributeList CallAttrs;
// The return type for the new call instruction.
llvm::Type *ReturnTy;
// The arguments for the new call instruction.
Expand Down
41 changes: 41 additions & 0 deletions lib/SPIRV/SPIRVReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3790,7 +3790,48 @@ void SPIRVToLLVM::transDecorationsToMetadata(SPIRVValue *BV, Value *V) {
SetDecorationsMetadata(I);
}

namespace {

static float convertSPIRVWordToFloat(SPIRVWord Spir) {
union {
float F;
SPIRVWord Spir;
} FPMaxError;
FPMaxError.Spir = Spir;
return FPMaxError.F;
}

static bool transFPMaxErrorDecoration(SPIRVValue *BV, Value *V,
LLVMContext *Context) {
SPIRVWord ID;
if (Instruction *I = dyn_cast<Instruction>(V))
if (BV->hasDecorate(DecorationFPMaxErrorDecorationINTEL, 0, &ID)) {
auto Literals =
BV->getDecorationLiterals(DecorationFPMaxErrorDecorationINTEL);
assert(Literals.size() == 1 &&
"FP Max Error decoration shall have 1 operand");
auto F = convertSPIRVWordToFloat(Literals[0]);
if (CallInst *CI = dyn_cast<CallInst>(I)) {
// Add attribute
auto A = llvm::Attribute::get(*Context, "fpbuiltin-max-error",
std::to_string(F));
CI->addFnAttr(A);
} else {
// Add metadata
MDNode *N =
MDNode::get(*Context, MDString::get(*Context, std::to_string(F)));
I->setMetadata("fpbuiltin-max-error", N);
}
return true;
}
return false;
}
} // namespace

bool SPIRVToLLVM::transDecoration(SPIRVValue *BV, Value *V) {
if (transFPMaxErrorDecoration(BV, V, Context))
return true;

if (!transAlign(BV, V))
return false;

Expand Down
153 changes: 153 additions & 0 deletions lib/SPIRV/SPIRVWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,19 @@ using namespace llvm;
using namespace SPIRV;
using namespace OCLUtil;

namespace {

static SPIRVWord convertFloatToSPIRVWord(float F) {
union {
float F;
SPIRVWord Spir;
} FPMaxError;
FPMaxError.F = F;
return FPMaxError.Spir;
}

} // namespace

namespace SPIRV {

static void foreachKernelArgMD(
Expand Down Expand Up @@ -3499,6 +3512,26 @@ bool LLVMToSPIRVBase::isKnownIntrinsic(Intrinsic::ID Id) {
}
}

// Add decoration if needed
SPIRVInstruction *addFPBuiltinDecoration(SPIRVModule *BM, IntrinsicInst *II,
SPIRVInstruction *I) {
const bool AllowFPMaxError =
BM->isAllowedToUseExtension(ExtensionID::SPV_INTEL_fp_max_error);
assert(II->getCalledFunction()->getName().startswith("llvm.fpbuiltin"));
// Add a new decoration for llvm.builtin intrinsics, if needed
if (AllowFPMaxError)
if (II->getAttributes().hasFnAttr("fpbuiltin-max-error")) {
double F = 0.0;
II->getAttributes()
.getFnAttr("fpbuiltin-max-error")
.getValueAsString()
.getAsDouble(F);
I->addDecorate(DecorationFPMaxErrorDecorationINTEL,
convertFloatToSPIRVWord(F));
}
return I;
}

// Performs mapping of LLVM IR rounding mode to SPIR-V rounding mode
// Value *V is metadata <rounding mode> argument of
// llvm.experimental.constrained.* intrinsics
Expand Down Expand Up @@ -4215,6 +4248,8 @@ SPIRVValue *LLVMToSPIRVBase::transIntrinsicInst(IntrinsicInst *II,
}

default:
if (auto *BVar = transFPBuiltinIntrinsicInst(II, BB))
return BVar;
if (BM->isUnknownIntrinsicAllowed(II))
return BM->addCallInst(
transFunctionDecl(II->getCalledFunction()),
Expand All @@ -4230,6 +4265,124 @@ SPIRVValue *LLVMToSPIRVBase::transIntrinsicInst(IntrinsicInst *II,
return nullptr;
}

LLVMToSPIRVBase::FPBuiltinType
LLVMToSPIRVBase::getFPBuiltinType(IntrinsicInst *II, StringRef &OpName) {
StringRef Name = II->getCalledFunction()->getName();
if (!Name.startswith("llvm.fpbuiltin"))
return FPBuiltinType::UNKNOWN;
Name.consume_front("llvm.fpbuiltin.");
OpName = Name.split('.').first;
FPBuiltinType Type =
StringSwitch<FPBuiltinType>(OpName)
.Cases("fadd", "fsub", "fmul", "fdiv", "frem",
FPBuiltinType::REGULAR_MATH)
.Cases("sin", "cos", "tan", FPBuiltinType::EXT_1OPS)
.Cases("sinh", "cosh", "tanh", FPBuiltinType::EXT_1OPS)
.Cases("asin", "acos", "atan", FPBuiltinType::EXT_1OPS)
.Cases("asinh", "acosh", "atanh", FPBuiltinType::EXT_1OPS)
.Cases("exp", "exp2", "exp10", "expm1", FPBuiltinType::EXT_1OPS)
.Cases("log", "log2", "log10", "log1p", FPBuiltinType::EXT_1OPS)
.Cases("sqrt", "rsqrt", "erf", "erfc", FPBuiltinType::EXT_1OPS)
.Cases("atan2", "pow", "hypot", "ldexp", FPBuiltinType::EXT_2OPS)
.Case("sincos", FPBuiltinType::EXT_3OPS)
.Default(FPBuiltinType::UNKNOWN);
return Type;
}

SPIRVValue *LLVMToSPIRVBase::transFPBuiltinIntrinsicInst(IntrinsicInst *II,
SPIRVBasicBlock *BB) {
StringRef OpName;
auto FPBuiltinTypeVal = getFPBuiltinType(II, OpName);
if (FPBuiltinTypeVal == FPBuiltinType::UNKNOWN)
return nullptr;
switch (FPBuiltinTypeVal) {
case FPBuiltinType::REGULAR_MATH: {
auto BinOp = StringSwitch<Op>(OpName)
.Case("fadd", OpFAdd)
.Case("fsub", OpFSub)
.Case("fmul", OpFMul)
.Case("fdiv", OpFDiv)
.Case("frem", OpFRem)
.Default(OpUndef);
auto *BI = BM->addBinaryInst(BinOp, transType(II->getType()),
transValue(II->getArgOperand(0), BB),
transValue(II->getArgOperand(1), BB), BB);
return addFPBuiltinDecoration(BM, II, BI);
}
case FPBuiltinType::EXT_1OPS: {
if (!checkTypeForSPIRVExtendedInstLowering(II, BM))
break;
SPIRVType *STy = transType(II->getType());
std::vector<SPIRVValue *> Ops(1, transValue(II->getArgOperand(0), BB));
auto ExtOp = StringSwitch<SPIRVWord>(OpName)
.Case("sin", OpenCLLIB::Sin)
.Case("cos", OpenCLLIB::Cos)
.Case("tan", OpenCLLIB::Tan)
.Case("sinh", OpenCLLIB::Sinh)
.Case("cosh", OpenCLLIB::Cosh)
.Case("tanh", OpenCLLIB::Tanh)
.Case("asin", OpenCLLIB::Asin)
.Case("acos", OpenCLLIB::Acos)
.Case("atan", OpenCLLIB::Atan)
.Case("asinh", OpenCLLIB::Asinh)
.Case("acosh", OpenCLLIB::Acosh)
.Case("atanh", OpenCLLIB::Atanh)
.Case("exp", OpenCLLIB::Exp)
.Case("exp2", OpenCLLIB::Exp2)
.Case("exp10", OpenCLLIB::Exp10)
.Case("expm1", OpenCLLIB::Expm1)
.Case("log", OpenCLLIB::Log)
.Case("log2", OpenCLLIB::Log2)
.Case("log10", OpenCLLIB::Log10)
.Case("log1p", OpenCLLIB::Log1p)
.Case("sqrt", OpenCLLIB::Sqrt)
.Case("rsqrt", OpenCLLIB::Rsqrt)
.Case("erf", OpenCLLIB::Erf)
.Case("erfc", OpenCLLIB::Erfc)
.Default(SPIRVWORD_MAX);
assert(ExtOp != SPIRVWORD_MAX);
auto *BI = BM->addExtInst(STy, BM->getExtInstSetId(SPIRVEIS_OpenCL), ExtOp,
Ops, BB);
return addFPBuiltinDecoration(BM, II, BI);
}
case FPBuiltinType::EXT_2OPS: {
if (!checkTypeForSPIRVExtendedInstLowering(II, BM))
break;
SPIRVType *STy = transType(II->getType());
std::vector<SPIRVValue *> Ops{transValue(II->getArgOperand(0), BB),
transValue(II->getArgOperand(1), BB)};
auto ExtOp = StringSwitch<SPIRVWord>(OpName)
.Case("atan2", OpenCLLIB::Atan2)
.Case("hypot", OpenCLLIB::Hypot)
.Case("pow", OpenCLLIB::Pow)
.Case("ldexp", OpenCLLIB::Ldexp)
.Default(SPIRVWORD_MAX);
assert(ExtOp != SPIRVWORD_MAX);
auto *BI = BM->addExtInst(STy, BM->getExtInstSetId(SPIRVEIS_OpenCL), ExtOp,
Ops, BB);
return addFPBuiltinDecoration(BM, II, BI);
}
case FPBuiltinType::EXT_3OPS: {
if (!checkTypeForSPIRVExtendedInstLowering(II, BM))
break;
SPIRVType *STy = transType(II->getType());
std::vector<SPIRVValue *> Ops{transValue(II->getArgOperand(0), BB),
transValue(II->getArgOperand(1), BB),
transValue(II->getArgOperand(2), BB)};
auto ExtOp = StringSwitch<SPIRVWord>(OpName)
.Case("sincos", OpenCLLIB::Sincos)
.Default(SPIRVWORD_MAX);
assert(ExtOp != SPIRVWORD_MAX);
auto *BI = BM->addExtInst(STy, BM->getExtInstSetId(SPIRVEIS_OpenCL), ExtOp,
Ops, BB);
return addFPBuiltinDecoration(BM, II, BI);
}
default:
return nullptr;
}
return nullptr;
}

SPIRVValue *LLVMToSPIRVBase::transFenceInst(FenceInst *FI,
SPIRVBasicBlock *BB) {
SPIRVWord MemorySemantics;
Expand Down
10 changes: 10 additions & 0 deletions lib/SPIRV/SPIRVWriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,16 @@ class LLVMToSPIRVBase : protected BuiltinCallHelper {
bool transWorkItemBuiltinCallsToVariables();
bool isKnownIntrinsic(Intrinsic::ID Id);
SPIRVValue *transIntrinsicInst(IntrinsicInst *Intrinsic, SPIRVBasicBlock *BB);
enum class FPBuiltinType {
REGULAR_MATH,
EXT_1OPS,
EXT_2OPS,
EXT_3OPS,
UNKNOWN
};
FPBuiltinType getFPBuiltinType(IntrinsicInst *II, StringRef &);
SPIRVValue *transFPBuiltinIntrinsicInst(IntrinsicInst *II,
SPIRVBasicBlock *BB);
SPIRVValue *transFenceInst(FenceInst *FI, SPIRVBasicBlock *BB);
SPIRVValue *transCallInst(CallInst *Call, SPIRVBasicBlock *BB);
SPIRVValue *transDirectCallInst(CallInst *Call, SPIRVBasicBlock *BB);
Expand Down
2 changes: 2 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVDecorate.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,8 @@ class SPIRVDecorate : public SPIRVDecorateGeneric {
case DecorationMMHostInterfaceMaxBurstINTEL:
case DecorationMMHostInterfaceWaitRequestINTEL:
return ExtensionID::SPV_INTEL_fpga_argument_interfaces;
case DecorationFPMaxErrorDecorationINTEL:
return ExtensionID::SPV_INTEL_fp_max_error;
case internal::DecorationCacheControlLoadINTEL:
case internal::DecorationCacheControlStoreINTEL:
return ExtensionID::SPV_INTEL_cache_controls;
Expand Down
2 changes: 2 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVEnum.h
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,8 @@ template <> inline void SPIRVMap<Decoration, SPIRVCapVec>::init() {
{CapabilityFPGAArgumentInterfacesINTEL});
ADD_VEC_INIT(DecorationStableKernelArgumentINTEL,
{CapabilityFPGAArgumentInterfacesINTEL});
ADD_VEC_INIT(DecorationFPMaxErrorDecorationINTEL,
{CapabilityFPMaxErrorINTEL});
}

template <> inline void SPIRVMap<BuiltIn, SPIRVCapVec>::init() {
Expand Down
2 changes: 2 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ template <> inline void SPIRVMap<Decoration, std::string>::init() {
add(DecorationMMHostInterfaceWaitRequestINTEL,
"MMHostInterfaceWaitRequestINTEL");
add(DecorationStableKernelArgumentINTEL, "StableKernelArgumentINTEL");
add(DecorationFPMaxErrorDecorationINTEL, "FPMaxErrorDecorationINTEL");

// From spirv_internal.hpp
add(internal::DecorationCallableFunctionINTEL, "CallableFunctionINTEL");
Expand Down Expand Up @@ -618,6 +619,7 @@ template <> inline void SPIRVMap<Capability, std::string>::init() {
add(CapabilityRuntimeAlignedAttributeINTEL, "RuntimeAlignedAttributeINTEL");
add(CapabilityMax, "Max");
add(CapabilityFPGAArgumentInterfacesINTEL, "FPGAArgumentInterfacesINTEL");
add(CapabilityFPMaxErrorINTEL, "FPMaxErrorINTEL");

// From spirv_internal.hpp
add(internal::CapabilityFastCompositeINTEL, "FastCompositeINTEL");
Expand Down
Loading

0 comments on commit 09f4941

Please sign in to comment.