Skip to content

Commit

Permalink
[slang-rhi] refactor command encoding (#5487)
Browse files Browse the repository at this point in the history
* update render test to use new slang-rhi API

* update slang-rhi

---------

Co-authored-by: Yong He <[email protected]>
  • Loading branch information
skallweitNV and csyonghe authored Nov 11, 2024
1 parent 5ca37c3 commit c0d0611
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 119 deletions.
2 changes: 1 addition & 1 deletion external/slang-rhi
Submodule slang-rhi updated 246 files
186 changes: 68 additions & 118 deletions tools/render-test/render-test-main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,6 @@ struct ShaderOutputPlan
List<Item> items;
};

enum class PipelineType
{
Graphics,
Compute,
RayTracing,
};

class RenderTestApp
{
public:
Expand All @@ -95,12 +88,9 @@ class RenderTestApp
IDevice* device,
const Options& options,
const ShaderCompilerUtil::Input& input);
void runCompute(IComputePassEncoder* encoder);
void renderFrame(IRenderPassEncoder* encoder);
void renderFrameMesh(IRenderPassEncoder* encoder);
void finalize();

Result applyBinding(PipelineType pipelineType, IPassEncoder* encoder);
Result applyBinding(IShaderObject* rootObject);
void setProjectionMatrix(IShaderObject* rootObject);
Result writeBindingOutput(const String& fileName);

Expand All @@ -123,7 +113,6 @@ class RenderTestApp

IDevice* m_device;
ComPtr<ICommandQueue> m_queue;
ComPtr<ITransientResourceHeap> m_transientHeap;
ComPtr<IInputLayout> m_inputLayout;
ComPtr<IBuffer> m_vertexBuffer;
ComPtr<IShaderProgram> m_shaderProgram;
Expand Down Expand Up @@ -394,13 +383,14 @@ struct AssignValsFromLayoutContext
}

ComPtr<IShaderObject> shaderObject;
device->createShaderObject2(
device->createShaderObject(
slangSession,
slangType,
ShaderObjectContainerType::None,
shaderObject.writeRef());

SLANG_RETURN_ON_FAIL(assign(ShaderCursor(shaderObject), srcVal->contentVal));
shaderObject->finalize();
dstCursor.setObject(shaderObject);
return SLANG_OK;
}
Expand Down Expand Up @@ -507,48 +497,21 @@ SlangResult _assignVarsFromLayout(
return context.assign(rootCursor, layout.rootVal);
}

Result RenderTestApp::applyBinding(PipelineType pipelineType, IPassEncoder* encoder)
Result RenderTestApp::applyBinding(IShaderObject* rootObject)
{
auto slangReflection = (slang::ProgramLayout*)spGetReflection(
m_compilationOutput.output.getRequestForReflection());
ComPtr<slang::ISession> slangSession;
m_compilationOutput.output.m_requestForKernels->getSession(slangSession.writeRef());

switch (pipelineType)
{
case PipelineType::Compute:
{
IComputePassEncoder* computeEncoder = static_cast<IComputePassEncoder*>(encoder);
auto rootObject = computeEncoder->bindPipeline(m_pipeline);
SLANG_RETURN_ON_FAIL(_assignVarsFromLayout(
m_device,
slangSession,
rootObject,
m_compilationOutput.layout,
m_outputPlan,
slangReflection,
m_topLevelAccelerationStructure));
}
break;
case PipelineType::Graphics:
{
IRenderPassEncoder* renderEncoder = static_cast<IRenderPassEncoder*>(encoder);
auto rootObject = renderEncoder->bindPipeline(m_pipeline);
SLANG_RETURN_ON_FAIL(_assignVarsFromLayout(
m_device,
slangSession,
rootObject,
m_compilationOutput.layout,
m_outputPlan,
slangReflection,
m_topLevelAccelerationStructure));
setProjectionMatrix(rootObject);
}
break;
default:
throw "unknown pipeline type";
}
return SLANG_OK;
return _assignVarsFromLayout(
m_device,
slangSession,
rootObject,
m_compilationOutput.layout,
m_outputPlan,
slangReflection,
m_topLevelAccelerationStructure);
}

SlangResult RenderTestApp::initialize(
Expand Down Expand Up @@ -688,11 +651,6 @@ Result RenderTestApp::_initializeShaders(

void RenderTestApp::_initializeRenderPass()
{
ITransientResourceHeap::Desc transientHeapDesc = {};
transientHeapDesc.constantBufferSize = 4096 * 1024;
m_transientHeap = m_device->createTransientResourceHeap(transientHeapDesc);
SLANG_ASSERT(m_transientHeap);

m_queue = m_device->getQueue(QueueType::Graphics);
SLANG_ASSERT(m_queue);

Expand Down Expand Up @@ -783,21 +741,18 @@ void RenderTestApp::_initializeAccelerationStructure()

compactedSizeQuery->reset();

auto commandBuffer = m_transientHeap->createCommandBuffer();
auto passEncoder = commandBuffer->beginRayTracingPass();
auto encoder = m_queue->createCommandEncoder();
AccelerationStructureQueryDesc compactedSizeQueryDesc = {};
compactedSizeQueryDesc.queryPool = compactedSizeQuery;
compactedSizeQueryDesc.queryType = QueryType::AccelerationStructureCompactedSize;
passEncoder->buildAccelerationStructure(
encoder->buildAccelerationStructure(
buildDesc,
draftAS,
nullptr,
scratchBuffer,
1,
&compactedSizeQueryDesc);
passEncoder->end();
commandBuffer->close();
m_queue->submit(commandBuffer);
m_queue->submit(encoder->finish());
m_queue->waitOnHost();

uint64_t compactedSize = 0;
Expand All @@ -808,15 +763,12 @@ void RenderTestApp::_initializeAccelerationStructure()
finalDesc,
m_bottomLevelAccelerationStructure.writeRef());

commandBuffer = m_transientHeap->createCommandBuffer();
passEncoder = commandBuffer->beginRayTracingPass();
passEncoder->copyAccelerationStructure(
encoder = m_queue->createCommandEncoder();
encoder->copyAccelerationStructure(
m_bottomLevelAccelerationStructure,
draftAS,
AccelerationStructureCopyMode::Compact);
passEncoder->end();
commandBuffer->close();
m_queue->submit(commandBuffer);
m_queue->submit(encoder->finish());
m_queue->waitOnHost();
}

Expand Down Expand Up @@ -881,18 +833,15 @@ void RenderTestApp::_initializeAccelerationStructure()
createDesc,
m_topLevelAccelerationStructure.writeRef());

auto commandBuffer = m_transientHeap->createCommandBuffer();
auto passEncoder = commandBuffer->beginRayTracingPass();
passEncoder->buildAccelerationStructure(
auto encoder = m_queue->createCommandEncoder();
encoder->buildAccelerationStructure(
buildDesc,
m_topLevelAccelerationStructure,
nullptr,
scratchBuffer,
0,
nullptr);
passEncoder->end();
commandBuffer->close();
m_queue->submit(commandBuffer);
m_queue->submit(encoder->finish());
m_queue->waitOnHost();
}
}
Expand All @@ -906,36 +855,6 @@ void RenderTestApp::setProjectionMatrix(IShaderObject* rootObject)
.setData(info.identityProjectionMatrix, sizeof(float) * 16);
}

void RenderTestApp::renderFrameMesh(IRenderPassEncoder* encoder)
{
auto pipelineType = PipelineType::Graphics;
applyBinding(pipelineType, encoder);
encoder->drawMeshTasks(
m_options.computeDispatchSize[0],
m_options.computeDispatchSize[1],
m_options.computeDispatchSize[2]);
}

void RenderTestApp::renderFrame(IRenderPassEncoder* encoder)
{
auto pipelineType = PipelineType::Graphics;
applyBinding(pipelineType, encoder);

encoder->setVertexBuffer(0, m_vertexBuffer);

encoder->draw(3);
}

void RenderTestApp::runCompute(IComputePassEncoder* encoder)
{
auto pipelineType = PipelineType::Compute;
applyBinding(pipelineType, encoder);
encoder->dispatchCompute(
m_options.computeDispatchSize[0],
m_options.computeDispatchSize[1],
m_options.computeDispatchSize[2]);
}

void RenderTestApp::finalize()
{
m_compilationOutput.output.reset();
Expand Down Expand Up @@ -1002,15 +921,31 @@ Result RenderTestApp::writeScreen(const String& filename)

Result RenderTestApp::update()
{
auto commandBuffer = m_transientHeap->createCommandBuffer();
auto encoder = m_queue->createCommandEncoder();
if (m_options.shaderType == Options::ShaderProgramType::Compute)
{
auto passEncoder = commandBuffer->beginComputePass();
runCompute(passEncoder);
passEncoder->end();
auto rootObject = m_device->createRootShaderObject(m_pipeline);
applyBinding(rootObject);
rootObject->finalize();

encoder->beginComputePass();
ComputeState state;
state.pipeline = static_cast<IComputePipeline*>(m_pipeline.get());
state.rootObject = rootObject;
encoder->setComputeState(state);
encoder->dispatchCompute(
m_options.computeDispatchSize[0],
m_options.computeDispatchSize[1],
m_options.computeDispatchSize[2]);
encoder->endComputePass();
}
else
{
auto rootObject = m_device->createRootShaderObject(m_pipeline);
applyBinding(rootObject);
setProjectionMatrix(rootObject);
rootObject->finalize();

RenderPassColorAttachment colorAttachment = {};
colorAttachment.view = m_colorBufferView;
colorAttachment.loadOp = LoadOp::Clear;
Expand All @@ -1024,23 +959,38 @@ Result RenderTestApp::update()
renderPass.colorAttachmentCount = 1;
renderPass.depthStencilAttachment = &depthStencilAttachment;

auto passEncoder = commandBuffer->beginRenderPass(renderPass);
rhi::Viewport viewport = {};
viewport.maxZ = 1.0f;
viewport.extentX = (float)gWindowWidth;
viewport.extentY = (float)gWindowHeight;
passEncoder->setViewportAndScissor(viewport);
encoder->beginRenderPass(renderPass);

RenderState state;
state.pipeline = static_cast<IRenderPipeline*>(m_pipeline.get());
state.rootObject = rootObject;
state.viewports[0] = Viewport((float)gWindowWidth, (float)gWindowHeight);
state.viewportCount = 1;
state.scissorRects[0] = ScissorRect(gWindowWidth, gWindowHeight);
state.scissorRectCount = 1;

if (m_options.shaderType == Options::ShaderProgramType::GraphicsMeshCompute ||
m_options.shaderType == Options::ShaderProgramType::GraphicsTaskMeshCompute)
renderFrameMesh(passEncoder);
{
encoder->setRenderState(state);
encoder->drawMeshTasks(
m_options.computeDispatchSize[0],
m_options.computeDispatchSize[1],
m_options.computeDispatchSize[2]);
}
else
renderFrame(passEncoder);
passEncoder->end();
{
state.vertexBuffers[0] = m_vertexBuffer;
state.vertexBufferCount = 1;
encoder->setRenderState(state);
DrawArguments args;
args.vertexCount = 3;
encoder->draw(args);
}
encoder->endRenderPass();
}
commandBuffer->close();

m_startTicks = Process::getClockTick();
m_queue->submit(commandBuffer);
m_queue->submit(encoder->finish());
m_queue->waitOnHost();

// If we are in a mode where output is requested, we need to snapshot the back buffer here
Expand Down

0 comments on commit c0d0611

Please sign in to comment.