Skip to content

Commit

Permalink
Set condition to use parentId/RayStaticId
Browse files Browse the repository at this point in the history
These two function parameters are used for developer GPURT logging.  In most
cases, these two variables are not used and would generate two more
vgprs for the indirect/continuation shaders. So guard them with
enableRayTracingCounter

So do a little refactoring for the  RayStaticIdOp processing
  • Loading branch information
jiaolu committed Oct 9, 2023
1 parent 79318ed commit 5f60cae
Show file tree
Hide file tree
Showing 7 changed files with 43 additions and 65 deletions.
6 changes: 0 additions & 6 deletions lgc/interface/lgc/GpurtDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -293,12 +293,6 @@ def GpurtSetParentIdOp : GpurtOp<"set.parent.id", [Memory<[(write InaccessibleMe
let summary = "Store TraceRay rayId";
}

def GpurtSetRayStaticIdOp : GpurtOp<"set.ray.static.id", [Memory<[(write InaccessibleMem)]>, WillReturn]> {
let arguments = (ins I32:$id);
let results = (outs);
let summary = "set a unique static ID for a ray";
}

def GpurtGetRayStaticIdOp : GpurtOp<"get.ray.static.id", [Memory<[(read InaccessibleMem)]>, WillReturn]> {
let arguments = (ins);
let results = (outs I32:$result);
Expand Down
42 changes: 1 addition & 41 deletions llpc/lower/LowerGpuRt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ static const char *LdsStack = "LdsStack";

namespace Llpc {
// =====================================================================================================================
LowerGpuRt::LowerGpuRt() : m_stack(nullptr), m_stackTy(nullptr), m_lowerStack(false), m_rayStaticId(nullptr) {
LowerGpuRt::LowerGpuRt() : m_stack(nullptr), m_stackTy(nullptr), m_lowerStack(false) {
}
// =====================================================================================================================
// Executes this SPIR-V lowering pass on the specified LLVM module.
Expand All @@ -62,7 +62,6 @@ PreservedAnalyses LowerGpuRt::run(Module &module, ModuleAnalysisManager &analysi
m_lowerStack = (m_entryPoint->getName().startswith("_ahit") || m_entryPoint->getName().startswith("_sect")) &&
(gfxip.major < 11);
createGlobalStack();
createRayStaticIdValue();

static auto visitor = llvm_dialects::VisitorBuilder<LowerGpuRt>()
.setStrategy(llvm_dialects::VisitorStrategy::ByFunctionDeclaration)
Expand All @@ -77,8 +76,6 @@ PreservedAnalyses LowerGpuRt::run(Module &module, ModuleAnalysisManager &analysi
.add(&LowerGpuRt::visitGetStaticFlags)
.add(&LowerGpuRt::visitGetTriangleCompressionMode)
.add(&LowerGpuRt::visitGetFlattenedGroupThreadId)
.add(&LowerGpuRt::visitSetRayStaticId)
.add(&LowerGpuRt::visitGetRayStaticId)
.build();

visitor.visit(*this, *m_module);
Expand Down Expand Up @@ -142,13 +139,6 @@ void LowerGpuRt::createGlobalStack() {
m_stack = ldsStack;
}

// =====================================================================================================================
// Create ray static ID value
void LowerGpuRt::createRayStaticIdValue() {
m_builder->SetInsertPointPastAllocas(m_entryPoint);
m_rayStaticId = m_builder->CreateAlloca(m_builder->getInt32Ty());
}

// =====================================================================================================================
// Visit "GpurtGetStackSizeOp" instruction
//
Expand Down Expand Up @@ -351,34 +341,4 @@ void LowerGpuRt::visitGetFlattenedGroupThreadId(GpurtGetFlattenedGroupThreadIdOp
m_funcsToLower.insert(inst.getCalledFunction());
}

// =====================================================================================================================
// Visit "GpurtSetRayStaticIdOp" instruction
//
// @param inst : The dialect instruction to process
void LowerGpuRt::visitSetRayStaticId(GpurtSetRayStaticIdOp &inst) {
m_builder->SetInsertPoint(&inst);

assert(m_rayStaticId);
auto rayStaticId = inst.getId();
auto storeInst = m_builder->CreateStore(rayStaticId, m_rayStaticId);

inst.replaceAllUsesWith(storeInst);
m_callsToLower.push_back(&inst);
m_funcsToLower.insert(inst.getCalledFunction());
}

// =====================================================================================================================
// Visit "GpurtGetRayStaticIdOp" instruction
//
// @param inst : The dialect instruction to process
void LowerGpuRt::visitGetRayStaticId(GpurtGetRayStaticIdOp &inst) {
m_builder->SetInsertPoint(&inst);

assert(m_rayStaticId);
auto rayStaticId = m_builder->CreateLoad(m_builder->getInt32Ty(), m_rayStaticId);
inst.replaceAllUsesWith(rayStaticId);
m_callsToLower.push_back(&inst);
m_funcsToLower.insert(inst.getCalledFunction());
}

} // namespace Llpc
5 changes: 0 additions & 5 deletions llpc/lower/LowerGpuRt.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,6 @@ class GpurtGetBoxSortHeuristicModeOp;
class GpurtGetStaticFlagsOp;
class GpurtGetTriangleCompressionModeOp;
class GpurtGetFlattenedGroupThreadIdOp;
class GpurtSetRayStaticIdOp;
class GpurtGetRayStaticIdOp;
} // namespace lgc

namespace llvm {
Expand Down Expand Up @@ -78,13 +76,10 @@ class LowerGpuRt : public SpirvLower, public llvm::PassInfoMixin<LowerGpuRt> {
void visitGetStaticFlags(lgc::GpurtGetStaticFlagsOp &inst);
void visitGetTriangleCompressionMode(lgc::GpurtGetTriangleCompressionModeOp &inst);
void visitGetFlattenedGroupThreadId(lgc::GpurtGetFlattenedGroupThreadIdOp &inst);
void visitSetRayStaticId(lgc::GpurtSetRayStaticIdOp &inst);
void visitGetRayStaticId(lgc::GpurtGetRayStaticIdOp &inst);
llvm::Value *m_stack; // Stack array to hold stack value
llvm::Type *m_stackTy; // Stack type
bool m_lowerStack; // If it is lowerStack
llvm::SmallVector<llvm::Instruction *> m_callsToLower; // Call instruction to lower
llvm::SmallSet<llvm::Function *, 4> m_funcsToLower; // Functions to lower
llvm::Value *m_rayStaticId; // Ray static ID value
};
} // namespace Llpc
4 changes: 2 additions & 2 deletions llpc/lower/llpcSpirvLowerRayQuery.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1262,15 +1262,15 @@ void SpirvLowerRayQuery::initGlobalVariable() {
// =====================================================================================================================
// Generate a static ID for current Trace Ray call
//
void SpirvLowerRayQuery::generateTraceRayStaticId() {
unsigned SpirvLowerRayQuery::generateTraceRayStaticId() {
Util::MetroHash64 hasher;
hasher.Update(m_nextTraceRayId++);
hasher.Update(m_module->getName());

MetroHash::Hash hash = {};
hasher.Finalize(hash.bytes);

m_builder->create<lgc::GpurtSetRayStaticIdOp>(m_builder->getInt32(MetroHash::compact32(&hash)));
return MetroHash::compact32(&hash);
}

// =====================================================================================================================
Expand Down
2 changes: 1 addition & 1 deletion llpc/lower/llpcSpirvLowerRayQuery.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ class SpirvLowerRayQuery : public SpirvLower, public llvm::PassInfoMixin<SpirvLo
void createGlobalLdsUsage();
void createGlobalRayQueryObj();
void initGlobalVariable();
void generateTraceRayStaticId();
unsigned generateTraceRayStaticId();
llvm::Value *createTransformMatrix(unsigned builtInId, llvm::Value *accelStruct, llvm::Value *instanceId,
llvm::Instruction *insertPos);
void eraseFunctionBlocks(llvm::Function *func);
Expand Down
46 changes: 36 additions & 10 deletions llpc/lower/llpcSpirvLowerRayTracing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ static unsigned TraceParamsTySize[] = {
8, // 15, hit attribute
1, // 16, parentId
9, // 17, HitTriangleVertexPositions
1, // 18, Payload
1, // 18, Payload,
1, // 19, RayStaticId
};

// =====================================================================================================================
Expand All @@ -115,12 +116,13 @@ void SpirvLowerRayTracing::processTraceRayCall(BaseTraceRayOp *inst) {

auto rayTracingContext = static_cast<RayTracingContext *>(m_context->getPipelineContext());

implCallArgs.push_back(m_traceParams[TraceParam::RayStaticId]);
implCallArgs.push_back(m_traceParams[TraceParam::ParentRayId]);
implCallArgs.push_back(m_dispatchRaysInfoDesc);
m_builder->SetInsertPoint(inst);

// Generate a unique static ID for each trace ray call
generateTraceRayStaticId();
m_builder->CreateStore(m_builder->getInt32(generateTraceRayStaticId()), m_traceParams[TraceParam::RayStaticId]);
auto newCall = m_builder->CreateNamedCall(mangledName, inst->getFunctionType()->getReturnType(), implCallArgs,
{Attribute::NoUnwind, Attribute::AlwaysInline});

Expand Down Expand Up @@ -160,13 +162,17 @@ void SpirvLowerRayTracing::processTraceRayCall(BaseTraceRayOp *inst) {
args.push_back(func->getArg(i));

Value *parentRayId = func->arg_end() - 2;
Value *rayStaticId = func->arg_end() - 3;

// RayGen shaders are non-recursive, initialize parent ray ID to -1 here.
if (m_shaderStage == ShaderStageRayTracingRayGen)
m_builder->CreateStore(m_builder->getInt32(InvalidValue), parentRayId);

Value *currentParentRayId = m_builder->CreateLoad(m_builder->getInt32Ty(), parentRayId);
args.push_back(currentParentRayId);
args.push_back(m_builder->create<lgc::GpurtGetRayStaticIdOp>());
if (m_context->getPipelineContext()->getRayTracingState()->enableRayTracingCounters) {
args.push_back(currentParentRayId);
args.push_back(m_builder->CreateLoad(m_builder->getInt32Ty(), rayStaticId));
}

CallInst *result = nullptr;
auto funcTy = getTraceRayFuncTy();
Expand Down Expand Up @@ -574,6 +580,7 @@ PreservedAnalyses SpirvLowerRayTracing::run(Module &module, ModuleAnalysisManage
.add(&SpirvLowerRayTracing::visitSetHitTriangleNodePointer)
.add(&SpirvLowerRayTracing::visitGetParentId)
.add(&SpirvLowerRayTracing::visitSetParentId)
.add(&SpirvLowerRayTracing::visitGetRayStaticId)
.add(&SpirvLowerRayTracing::visitDispatchRayIndex)
.build();

Expand Down Expand Up @@ -1600,10 +1607,12 @@ CallInst *SpirvLowerRayTracing::createTraceRay() {
m_builder->CreateStore(arg, traceRaysArgs[TraceRayLibFuncParam::TMax]);

// Parent ray ID and static ID for logging feature
arg = ++argIt;
m_builder->CreateStore(arg, m_traceParams[TraceParam::ParentRayId]);
arg = ++argIt;
m_builder->create<lgc::GpurtSetRayStaticIdOp>(arg);
if (m_context->getPipelineContext()->getRayTracingState()->enableRayTracingCounters) {
arg = ++argIt;
m_builder->CreateStore(arg, m_traceParams[TraceParam::ParentRayId]);
arg = ++argIt;
m_builder->CreateStore(arg, m_traceParams[TraceParam::RayStaticId]);
}

// Call TraceRay function from traceRays module
auto call = m_builder->CreateCall(traceRayFunc, traceRaysArgs);
Expand Down Expand Up @@ -1663,6 +1672,7 @@ void SpirvLowerRayTracing::initTraceParamsTy(unsigned attributeSize) {
m_builder->getInt32Ty(), // 16, parentId
StructType::get(*m_context, {floatx3Ty, floatx3Ty, floatx3Ty}), // 17, HitTriangleVertexPositions
payloadType, // 18, Payload
m_builder->getInt32Ty(), // 19, rayStaticId
};
TraceParamsTySize[TraceParam::HitAttributes] = attributeSize;
TraceParamsTySize[TraceParam::Payload] = payloadType->getArrayNumElements();
Expand Down Expand Up @@ -1892,8 +1902,10 @@ FunctionType *SpirvLowerRayTracing::getTraceRayFuncTy() {
};

// Add parent ray ID and static ID for logging feature.
argsTys.push_back(m_builder->getInt32Ty());
argsTys.push_back(m_builder->getInt32Ty());
if (m_context->getPipelineContext()->getRayTracingState()->enableRayTracingCounters) {
argsTys.push_back(m_builder->getInt32Ty()); // Parent Id
argsTys.push_back(m_builder->getInt32Ty()); // Ray Static Id
}

auto funcTy = FunctionType::get(retTy, argsTys, false);
return funcTy;
Expand Down Expand Up @@ -2544,6 +2556,20 @@ void SpirvLowerRayTracing::visitSetHitTriangleNodePointer(lgc::GpurtSetHitTriang
m_funcsToLower.insert(inst.getCalledFunction());
}

// =====================================================================================================================
// Visits "lgc.gpurt.get.ray.static.id" instructions
//
// @param inst : The instruction
void SpirvLowerRayTracing::visitGetRayStaticId(lgc::GpurtGetRayStaticIdOp &inst) {
m_builder->SetInsertPoint(&inst);

auto rayStaticId = m_builder->CreateLoad(m_builder->getInt32Ty(), m_traceParams[TraceParam::RayStaticId]);
inst.replaceAllUsesWith(rayStaticId);

m_callsToLower.push_back(&inst);
m_funcsToLower.insert(inst.getCalledFunction());
}

// =====================================================================================================================
// Visits "lgc.gpurt.get.parent.id" instructions
//
Expand Down
3 changes: 3 additions & 0 deletions llpc/lower/llpcSpirvLowerRayTracing.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ class GpurtSetTriangleIntersectionAttributesOp;
class GpurtSetHitTriangleNodePointerOp;
class GpurtGetParentIdOp;
class GpurtSetParentIdOp;
class GpurtGetRayStaticIdOp;
} // namespace lgc

namespace Llpc {
Expand All @@ -100,6 +101,7 @@ enum : unsigned {
ParentRayId, // Ray ID of the parent TraceRay call
HitTriangleVertexPositions, // Hit triangle vertex positions
Payload, // Payload
RayStaticId, // Ray static ID
Count // Count of the trace attributes
};
}
Expand Down Expand Up @@ -242,6 +244,7 @@ class SpirvLowerRayTracing : public SpirvLowerRayQuery {
void visitSetHitTriangleNodePointer(lgc::GpurtSetHitTriangleNodePointerOp &inst);
void visitGetParentId(lgc::GpurtGetParentIdOp &inst);
void visitSetParentId(lgc::GpurtSetParentIdOp &inst);
void visitGetRayStaticId(lgc::GpurtGetRayStaticIdOp &inst);
void visitDispatchRayIndex(lgc::rt::DispatchRaysIndexOp &inst);
void visitDispatchRaysDimensionsOp(lgc::rt::DispatchRaysDimensionsOp &inst);
void visitWorldRayOriginOp(lgc::rt::WorldRayOriginOp &inst);
Expand Down

0 comments on commit 5f60cae

Please sign in to comment.