diff --git a/lib/SPIRV/SPIRVBuiltinHelper.cpp b/lib/SPIRV/SPIRVBuiltinHelper.cpp index a648cbac6d..a184f409e1 100644 --- a/lib/SPIRV/SPIRVBuiltinHelper.cpp +++ b/lib/SPIRV/SPIRVBuiltinHelper.cpp @@ -62,8 +62,9 @@ BuiltinCallMutator::BuiltinCallMutator( CallInst *CI, std::string FuncName, ManglingRules Rules, std::function NameMapFn) : CI(CI), FuncName(FuncName), - Attrs(CI->getCalledFunction()->getAttributes()), ReturnTy(CI->getType()), - Args(CI->args()), Rules(Rules), Builder(CI) { + Attrs(CI->getCalledFunction()->getAttributes()), + CallAttrs(CI->getAttributes()), ReturnTy(CI->getType()), Args(CI->args()), + Rules(Rules), Builder(CI) { bool DidDemangle = getParameterTypes(CI->getCalledFunction(), PointerTypes, std::move(NameMapFn)); if (!DidDemangle) { @@ -78,8 +79,8 @@ BuiltinCallMutator::BuiltinCallMutator( BuiltinCallMutator::BuiltinCallMutator(BuiltinCallMutator &&Other) : CI(Other.CI), FuncName(std::move(Other.FuncName)), MutateRet(std::move(Other.MutateRet)), Attrs(Other.Attrs), - ReturnTy(Other.ReturnTy), Args(std::move(Other.Args)), - PointerTypes(std::move(Other.PointerTypes)), + CallAttrs(Other.CallAttrs), ReturnTy(Other.ReturnTy), + Args(std::move(Other.Args)), PointerTypes(std::move(Other.PointerTypes)), Rules(std::move(Other.Rules)), Builder(CI) { // Clear the other's CI instance so that it knows not to construct the actual // call. @@ -103,7 +104,8 @@ Value *BuiltinCallMutator::doConversion() { Builder.Insert(addCallInst(CI->getModule(), FuncName, ReturnTy, Args, &Attrs, nullptr, Mangler.get())); NewCall->copyMetadata(*CI); - NewCall->setAttributes(CI->getAttributes()); + NewCall->setAttributes(CallAttrs); + NewCall->setTailCall(CI->isTailCall()); Value *Result = MutateRet ? MutateRet(Builder, NewCall) : NewCall; Result->takeName(CI); if (!CI->getType()->isVoidTy()) @@ -118,6 +120,8 @@ BuiltinCallMutator &BuiltinCallMutator::setArgs(ArrayRef NewArgs) { // Retain only the function attributes, not any parameter attributes. Attrs = AttributeList::get(CI->getContext(), Attrs.getFnAttrs(), Attrs.getRetAttrs(), {}); + CallAttrs = AttributeList::get(CI->getContext(), CallAttrs.getFnAttrs(), + CallAttrs.getRetAttrs(), {}); Args.clear(); PointerTypes.clear(); for (Value *Arg : NewArgs) { @@ -171,6 +175,8 @@ BuiltinCallMutator &BuiltinCallMutator::insertArg(unsigned Index, PointerTypes.insert(PointerTypes.begin() + Index, Arg.second); moveAttributes(CI->getContext(), Attrs, Index, Args.size() - Index, Index + 1); + moveAttributes(CI->getContext(), CallAttrs, Index, Args.size() - Index, + Index + 1); return *this; } @@ -179,30 +185,21 @@ BuiltinCallMutator &BuiltinCallMutator::replaceArg(unsigned Index, Args[Index] = Arg.first; PointerTypes[Index] = Arg.second; Attrs = Attrs.removeParamAttributes(CI->getContext(), Index); + CallAttrs = CallAttrs.removeParamAttributes(CI->getContext(), Index); return *this; } BuiltinCallMutator &BuiltinCallMutator::removeArg(unsigned Index) { // If the argument being dropped is the last one, there is nothing to move, so // just remove the attributes. + auto &Ctx = CI->getContext(); if (Index == Args.size() - 1) { - // TODO: Remove this workaround when LLVM fixes - // https://github.com/llvm/llvm-project/issues/59746 on - // AttributeList::removeParamAttributes function. - // AttributeList::removeParamAttributes function sets attribute at - // specified index empty so that return value of - // AttributeList::getNumAttrSet() keeps unchanged after that call. When call - // BuiltinCallMutator::removeArg function, there is assert failure on - // BuiltinCallMutator::doConversion() since new CallInst removed arg but - // still holds attribute of that removed arg. - SmallVector ArgAttrs; - for (unsigned I = 0; I < Index; ++I) - ArgAttrs.push_back(Attrs.getParamAttrs(I)); - Attrs = AttributeList::get(CI->getContext(), Attrs.getFnAttrs(), - Attrs.getRetAttrs(), ArgAttrs); - } else - moveAttributes(CI->getContext(), Attrs, Index + 1, Args.size() - Index - 1, - Index); + Attrs = Attrs.removeParamAttributes(Ctx, Index); + CallAttrs = CallAttrs.removeParamAttributes(Ctx, Index); + } else { + moveAttributes(Ctx, Attrs, Index + 1, Args.size() - Index - 1, Index); + moveAttributes(Ctx, CallAttrs, Index + 1, Args.size() - Index - 1, Index); + } Args.erase(Args.begin() + Index); PointerTypes.erase(PointerTypes.begin() + Index); return *this; diff --git a/lib/SPIRV/SPIRVBuiltinHelper.h b/lib/SPIRV/SPIRVBuiltinHelper.h index da5ae0283a..e4b955e82e 100644 --- a/lib/SPIRV/SPIRVBuiltinHelper.h +++ b/lib/SPIRV/SPIRVBuiltinHelper.h @@ -74,8 +74,10 @@ class BuiltinCallMutator { // the new instruction is created. std::function &, llvm::CallInst *)> MutateRet; typedef decltype(MutateRet) MutateRetFuncTy; - // The attribute list for the new call instruction. + // The attribute list for the new called function. llvm::AttributeList Attrs; + // The attribute list for the new call instruction. + llvm::AttributeList CallAttrs; // The return type for the new call instruction. llvm::Type *ReturnTy; // The arguments for the new call instruction.