Skip to content

Commit

Permalink
Support applying the waterfall.end intrinsic to struct types
Browse files Browse the repository at this point in the history
The waterfall intrinsics can only accept simple types. The backend
support has been merged. This is a front-end update. We extend
`CreateMapToInt32` to support structs and rename it to
`CreateMapToSimpleType`, `MapToInt32Func` to `MapToSimpleTypeFunc`.

The implementation of `createWaterfallLoop` then boils down to a single
call of the new `CreateMapToSimpleType` method.
  • Loading branch information
xuechen417 committed Oct 7, 2023
1 parent 6680b1e commit 79318ed
Show file tree
Hide file tree
Showing 8 changed files with 255 additions and 89 deletions.
2 changes: 1 addition & 1 deletion .clang-tidy
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ CheckOptions:
- { key: readability-identifier-naming.MemberPrefix, value: m_ }
- { key: readability-identifier-naming.MemberRemovePrefixes, value: 'p,b,pfn,m_p,m_b' }
- { key: readability-identifier-naming.MethodCase, value: camelBack }
- { key: readability-identifier-naming.MethodIgnoredRegexp, value: '^Create$|^CreateACos$|^CreateACosh$|^CreateASin$|^CreateASinh$|^CreateATan$|^CreateATan2$|^CreateATanh$|^CreateBarrier$|^CreateBinaryIntrinsic$|^CreateCosh$|^CreateCrossProduct$|^CreateCubeFace.*$|^CreateDemoteToHelperInvocation$|^CreateDerivative$|^CreateDeterminant$|^CreateDotProduct$|^CreateEmitVertex$|^CreateEndPrimitive$|^CreateExp$|^CreateExtract.*$|^CreateFaceForward$|^CreateFClamp$|^CreateFindSMsb$|^CreateFma$|^CreateFMax$|^CreateFMax3$|^CreateFMid3$|^CreateFMin$|^CreateFMin3$|^CreateFMod$|^CreateFpTruncWithRounding$|^CreateFract$|^CreateFSign$|^CreateGet.*$|^CreateImage.*$|^CreateIndexDescPtr$|^CreateInsertBitField$|^CreateIntrinsic$|^CreateInverseSqrt$|^CreateIs.*$|^CreateKill$|^CreateLdexp$|^CreateLoad.*$|^CreateLog$|^CreateMapToInt32$|^CreateMatrix.*$|^CreateNormalizeVector$|^CreateOuterProduct$|^CreatePower$|^CreateQuantizeToFp16$|^CreateRead.*$|^CreateReflect$|^CreateRefract$|^CreateSAbs$|^CreateSinh$|^CreateSMod$|^CreateSmoothStep$|^CreateSSign$|^CreateSubgroup.*$|^CreateTan$|^CreateTanh$|^CreateTransposeMatrix$|^CreateUnaryIntrinsic$|^CreateVectorTimesMatrix$|^CreateWrite.*Output$|^Serialize$|^Merge$|^Destroy$|^ConvertColorBufferFormatToExportFormat$|^BuildShaderModule$|^BuildGraphicsPipeline$|^BuildComputePipeline$|^IsVertexFormatSupported$|^DumpSpirvBinary$|^BeginPipelineDump$|^EndPipelineDump$|^DumpPipelineBinary$|^DumpPipelineExtraInfo$|^GetShaderHash$|^GetPipelineHash$|^GetPipelineName$|^CreateShaderCache$|^ReadFromBuffer$|^GetSectionIndex$|^GetSymbolsBySectionIndex$|^GetSectionData$' }
- { key: readability-identifier-naming.MethodIgnoredRegexp, value: '^Create$|^CreateACos$|^CreateACosh$|^CreateASin$|^CreateASinh$|^CreateATan$|^CreateATan2$|^CreateATanh$|^CreateBarrier$|^CreateBinaryIntrinsic$|^CreateCosh$|^CreateCrossProduct$|^CreateCubeFace.*$|^CreateDemoteToHelperInvocation$|^CreateDerivative$|^CreateDeterminant$|^CreateDotProduct$|^CreateEmitVertex$|^CreateEndPrimitive$|^CreateExp$|^CreateExtract.*$|^CreateFaceForward$|^CreateFClamp$|^CreateFindSMsb$|^CreateFma$|^CreateFMax$|^CreateFMax3$|^CreateFMid3$|^CreateFMin$|^CreateFMin3$|^CreateFMod$|^CreateFpTruncWithRounding$|^CreateFract$|^CreateFSign$|^CreateGet.*$|^CreateImage.*$|^CreateIndexDescPtr$|^CreateInsertBitField$|^CreateIntrinsic$|^CreateInverseSqrt$|^CreateIs.*$|^CreateKill$|^CreateLdexp$|^CreateLoad.*$|^CreateLog$|^CreateMapToSimpleType$|^CreateMatrix.*$|^CreateNormalizeVector$|^CreateOuterProduct$|^CreatePower$|^CreateQuantizeToFp16$|^CreateRead.*$|^CreateReflect$|^CreateRefract$|^CreateSAbs$|^CreateSinh$|^CreateSMod$|^CreateSmoothStep$|^CreateSSign$|^CreateSubgroup.*$|^CreateTan$|^CreateTanh$|^CreateTransposeMatrix$|^CreateUnaryIntrinsic$|^CreateVectorTimesMatrix$|^CreateWrite.*Output$|^Serialize$|^Merge$|^Destroy$|^ConvertColorBufferFormatToExportFormat$|^BuildShaderModule$|^BuildGraphicsPipeline$|^BuildComputePipeline$|^IsVertexFormatSupported$|^DumpSpirvBinary$|^BeginPipelineDump$|^EndPipelineDump$|^DumpPipelineBinary$|^DumpPipelineExtraInfo$|^GetShaderHash$|^GetPipelineHash$|^GetPipelineName$|^CreateShaderCache$|^ReadFromBuffer$|^GetSectionIndex$|^GetSymbolsBySectionIndex$|^GetSectionData$' }
- { key: readability-identifier-naming.FunctionIgnoredRegexp, value: 'EnableOuts|EnableErrs' }
- { key: readability-identifier-naming.FunctionCase, value: camelBack }
- { key: readability-identifier-naming.TypeCase, value: CamelCase }
Expand Down
118 changes: 90 additions & 28 deletions lgc/builder/BuilderBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,16 +153,17 @@ Value *BuilderBase::CreateAddByteOffset(Value *pointer, Value *byteOffset, const
}

// =====================================================================================================================
// Create a map to i32 function. Many AMDGCN intrinsics only take i32's, so we need to massage input data into an i32
// to allow us to call these intrinsics. This helper takes a function pointer, massage arguments, and passthrough
// arguments and massages the mappedArgs into i32's before calling the function pointer. Note that all massage
// arguments must have the same type.
// Create a map to simple type function. Many AMDGCN intrinsics only accept simple types, so we need to massage
// input data into the simple type selected by the MapToSimpleMode argument to allow us to call these intrinsics.
// This helper takes a function pointer, massage arguments, and passthrough arguments, and massages the mappedArgs
// into that simple type before calling the function pointer. Note that all massage arguments must have the same
// type.
//
// @param mapFunc : The function to call on each provided i32.
// @param mappedArgs : The arguments to be massaged into i32's and passed to function.
// @param mapFunc : The function to call on each provided simple type mode.
// @param mappedArgs : The arguments to be massaged into simple type mode and passed to function.
// @param passthroughArgs : The arguments to be passed through as is (no massaging).
Value *BuilderBase::CreateMapToInt32(MapToInt32Func mapFunc, ArrayRef<Value *> mappedArgs,
ArrayRef<Value *> passthroughArgs) {
// @param simpleMode : The argument that specifies the simple type mode to map to
Value *BuilderBase::CreateMapToSimpleType(MapToSimpleTypeFunc mapFunc, ArrayRef<Value *> mappedArgs,
ArrayRef<Value *> passthroughArgs, MapToSimpleMode simpleMode) {
// We must have at least one argument to massage.
assert(mappedArgs.size() > 0);

Expand All @@ -172,35 +173,60 @@ Value *BuilderBase::CreateMapToInt32(MapToInt32Func mapFunc, ArrayRef<Value *> m
for (unsigned i = 1; i < mappedArgs.size(); i++)
assert(mappedArgs[i]->getType() == type);

if (mappedArgs[0]->getType()->isVectorTy()) {
// For vectors we extract each vector component and map them individually.
const unsigned compCount = cast<FixedVectorType>(type)->getNumElements();

SmallVector<Value *, 4> results;

for (unsigned i = 0; i < compCount; i++) {
if (type->isStructTy()) {
assert(simpleMode == MapToSimpleMode::SimpleVector);
// For struct we extract each member and map them individually.
const unsigned memberCount = type->getStructNumElements();
SmallVector<Value *> results;
for (unsigned i = 0; i < memberCount; ++i) {
SmallVector<Value *, 4> newMappedArgs;

for (Value *const mappedArg : mappedArgs)
newMappedArgs.push_back(CreateExtractElement(mappedArg, i));
newMappedArgs.push_back(CreateExtractValue(mappedArg, i));

results.push_back(CreateMapToInt32(mapFunc, newMappedArgs, passthroughArgs));
results.push_back(CreateMapToSimpleType(mapFunc, newMappedArgs, passthroughArgs, MapToSimpleMode::SimpleVector));
}

Value *result = PoisonValue::get(FixedVectorType::get(results[0]->getType(), compCount));

for (unsigned i = 0; i < compCount; i++)
result = CreateInsertElement(result, results[i], i);
Value *result = PoisonValue::get(type);
for (unsigned i = 0; i < memberCount; ++i)
result = CreateInsertValue(result, results[i], i);

return result;
}
if (type->isVectorTy()) {
if (simpleMode == MapToSimpleMode::Int32) {
// For vectors we extract each vector component and map them individually.
const unsigned compCount = cast<FixedVectorType>(type)->getNumElements();

SmallVector<Value *, 4> results;

for (unsigned i = 0; i < compCount; i++) {
SmallVector<Value *, 4> newMappedArgs;

for (Value *const mappedArg : mappedArgs)
newMappedArgs.push_back(CreateExtractElement(mappedArg, i));

results.push_back(CreateMapToSimpleType(mapFunc, newMappedArgs, passthroughArgs));
}

Value *result = PoisonValue::get(FixedVectorType::get(results[0]->getType(), compCount));

for (unsigned i = 0; i < compCount; i++)
result = CreateInsertElement(result, results[i], i);

return result;
} else if (simpleMode == MapToSimpleMode::SimpleVector) {
return mapFunc(*this, mappedArgs, passthroughArgs);
} else {
llvm_unreachable("Unhandled simple mode");
}
}
if (type->isIntegerTy() && type->getIntegerBitWidth() == 1) {
SmallVector<Value *, 4> newMappedArgs;

for (Value *const mappedArg : mappedArgs)
newMappedArgs.push_back(CreateZExt(mappedArg, getInt32Ty()));

Value *const result = CreateMapToInt32(mapFunc, newMappedArgs, passthroughArgs);
Value *const result = CreateMapToSimpleType(mapFunc, newMappedArgs, passthroughArgs);
return CreateTrunc(result, getInt1Ty());
}
if (type->isIntegerTy() && type->getIntegerBitWidth() < 32) {
Expand All @@ -214,7 +240,7 @@ Value *BuilderBase::CreateMapToInt32(MapToInt32Func mapFunc, ArrayRef<Value *> m
newMappedArgs.push_back(CreateBitCast(newMappedArg, getInt32Ty()));
}

Value *const result = CreateMapToInt32(mapFunc, newMappedArgs, passthroughArgs);
Value *const result = CreateMapToSimpleType(mapFunc, newMappedArgs, passthroughArgs);
return CreateExtractElement(CreateBitCast(result, vectorType), static_cast<uint64_t>(0));
}
if (type->getPrimitiveSizeInBits() == 64) {
Expand All @@ -231,7 +257,7 @@ Value *BuilderBase::CreateMapToInt32(MapToInt32Func mapFunc, ArrayRef<Value *> m
for (Value *const castMappedArg : castMappedArgs)
newMappedArgs.push_back(CreateExtractElement(castMappedArg, i));

Value *const resultComp = CreateMapToInt32(mapFunc, newMappedArgs, passthroughArgs);
Value *const resultComp = CreateMapToSimpleType(mapFunc, newMappedArgs, passthroughArgs);

result = CreateInsertElement(result, resultComp, i);
}
Expand All @@ -244,7 +270,7 @@ Value *BuilderBase::CreateMapToInt32(MapToInt32Func mapFunc, ArrayRef<Value *> m
for (Value *const mappedArg : mappedArgs)
newMappedArgs.push_back(CreateBitCast(mappedArg, getIntNTy(mappedArg->getType()->getPrimitiveSizeInBits())));

Value *const result = CreateMapToInt32(mapFunc, newMappedArgs, passthroughArgs);
Value *const result = CreateMapToSimpleType(mapFunc, newMappedArgs, passthroughArgs);
return CreateBitCast(result, type);
}
if (type->isIntegerTy(32))
Expand All @@ -266,7 +292,7 @@ Value *BuilderBase::CreateInlineAsmSideEffect(Value *const value) {
return builder.CreateCall(inlineAsm, value);
};

return CreateMapToInt32(mapFunc, value, {});
return CreateMapToSimpleType(mapFunc, value, {});
}

// =====================================================================================================================
Expand All @@ -281,5 +307,41 @@ Value *BuilderBase::CreateSetInactive(Value *active, Value *inactive) {
return builder.CreateIntrinsic(Intrinsic::amdgcn_set_inactive, active->getType(), {active, inactive});
};

return CreateMapToInt32(mapFunc, {active, inactive}, {});
return CreateMapToSimpleType(mapFunc, {active, inactive}, {});
}

// =====================================================================================================================
// Close a waterfall loop by feeding the result of the given instruction through a waterfall.end intrinsic.
//
// A void-typed instruction (for example a store) has no result to forward, so it is returned untouched and no code
// is emitted. Otherwise the insert point is moved to just after the instruction, its debug location is reused, and
// the waterfall.end call is created there. A vector of i8 is bitcast to i32 (or a vector of i32) around the
// intrinsic call and bitcast back to the original type afterwards.
//
// @param nonUniform : The instruction whose result ends the waterfall loop; it must be an Instruction.
// @param waterfallBegin : The waterfall begin intrinsic; passed as the first operand of waterfall.end.
// @returns : The instruction producing the final value (of the same type as nonUniform), or nonUniform itself when
//            it is void-typed.
Instruction *BuilderBase::CreateWaterfallEnd(Value *nonUniform, Value *waterfallBegin) {
  Instruction *const sourceInst = cast<Instruction>(nonUniform);
  Type *const sourceTy = sourceInst->getType();

  // Nothing to end for an instruction without a result (such as a store).
  if (sourceTy->isVoidTy())
    return sourceInst;

  // Emit everything immediately after the non-uniform instruction, inheriting its debug location.
  SetInsertPoint(sourceInst->getNextNode());
  SetCurrentDebugLocation(sourceInst->getDebugLoc());

  Instruction *endOperand = sourceInst;
  Type *endTy = sourceTy;

  auto *const fixedVecTy = dyn_cast<FixedVectorType>(sourceTy);
  if (fixedVecTy && fixedVecTy->getElementType()->isIntegerTy(8)) {
    // ISel does not like waterfall.end with vector of i8 type, so repack the bytes as i32 lanes.
    const unsigned byteCount = fixedVecTy->getNumElements();
    assert((byteCount % 4) == 0);
    if (byteCount != 4)
      endTy = FixedVectorType::get(getInt32Ty(), byteCount / 4);
    else
      endTy = getInt32Ty();
    endOperand = cast<Instruction>(CreateBitCast(endOperand, endTy));
  }

  Instruction *endResult =
      CreateIntrinsic(Intrinsic::amdgcn_waterfall_end, endTy, {waterfallBegin, endOperand}, nullptr);

  // Undo the i8-vector repacking so that the caller sees the original type.
  if (endTy != sourceTy)
    endResult = cast<Instruction>(CreateBitCast(endResult, sourceTy));

  return endResult;
}
42 changes: 10 additions & 32 deletions lgc/builder/BuilderImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -679,38 +679,16 @@ Instruction *BuilderImpl::createWaterfallLoop(Instruction *nonUniformInst, Array
}
}

Instruction *resultValue = nonUniformInst;

// End the waterfall loop (as long as nonUniformInst is not a store with no result).
if (!nonUniformInst->getType()->isVoidTy()) {
SetInsertPoint(nonUniformInst->getNextNode());
SetCurrentDebugLocation(nonUniformInst->getDebugLoc());

Use *useOfNonUniformInst = nullptr;
Type *waterfallEndTy = resultValue->getType();
if (auto vecTy = dyn_cast<FixedVectorType>(waterfallEndTy)) {
if (vecTy->getElementType()->isIntegerTy(8)) {
// ISel does not like waterfall.end with vector of i8 type, so cast if necessary.
assert((vecTy->getNumElements() % 4) == 0);
waterfallEndTy = getInt32Ty();
if (vecTy->getNumElements() != 4)
waterfallEndTy = FixedVectorType::get(getInt32Ty(), vecTy->getNumElements() / 4);
resultValue = cast<Instruction>(CreateBitCast(resultValue, waterfallEndTy, instName));
useOfNonUniformInst = &resultValue->getOperandUse(0);
}
}
resultValue = CreateIntrinsic(Intrinsic::amdgcn_waterfall_end, waterfallEndTy, {waterfallBegin, resultValue},
nullptr, instName);
if (!useOfNonUniformInst)
useOfNonUniformInst = &resultValue->getOperandUse(1);
if (waterfallEndTy != nonUniformInst->getType())
resultValue = cast<Instruction>(CreateBitCast(resultValue, nonUniformInst->getType(), instName));

// Replace all uses of nonUniformInst with the result of this code.
*useOfNonUniformInst = PoisonValue::get(nonUniformInst->getType());
nonUniformInst->replaceAllUsesWith(resultValue);
*useOfNonUniformInst = nonUniformInst;
}
if (nonUniformInst->getType()->isVoidTy())
return nonUniformInst;

auto mapFunc = [](BuilderBase &builder, ArrayRef<Value *> mappedArgs, ArrayRef<Value *> passthroughArgs) -> Value * {
return builder.CreateWaterfallEnd(mappedArgs[0], passthroughArgs[0]);
};

SetInsertPoint(nonUniformInst->getNextNode());
auto resultValue =
cast<Instruction>(CreateMapToSimpleType(mapFunc, nonUniformInst, waterfallBegin, MapToSimpleMode::SimpleVector));

return resultValue;
#endif
Expand Down
Loading

0 comments on commit 79318ed

Please sign in to comment.