Skip to content

Commit

Permalink
vulkan: optimize depth downsampling
Browse files Browse the repository at this point in the history
Reduces the depth downsampling pass from 900us to 370us on an Intel iGPU with 32-thread-wide waves.

See the comments for a more complete understanding of the details.
  • Loading branch information
Ryp committed Aug 9, 2023
1 parent aaf6da5 commit 7fff800
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 25 deletions.
71 changes: 67 additions & 4 deletions src/renderer/shader/tiled_lighting/tile_depth_downsample.comp.hlsl
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
#include "lib/base.hlsl"

#include "lib/morton.hlsl"
#include "tile_depth_downsample.share.hlsl"

// Input
#define INTERLOCKED 0
#define MORTON 1

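// NOTE: The LDS reduction path derives a per-wave index from the group index, which is only valid
// when the pipeline is created with VK_PIPELINE_SHADER_STAGE_CREATE_REQUIRE_FULL_SUBGROUPS_BIT.
// See: https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VkPipelineShaderStageCreateFlagBits.html
// See also: https://gitlab.freedesktop.org/mesa/mesa/-/issues/9039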

VK_PUSH_CONSTANT_HELPER(TileDepthConstants) consts;

VK_BINDING(0, 0) SamplerState Sampler;
Expand All @@ -12,9 +20,18 @@ VK_BINDING(0, 3) RWTexture2D<float2> TileDepthMax;

static const uint ThreadCount = TileDepthThreadCountX * TileDepthThreadCountY;

groupshared float2 lds_depth_min_max[ThreadCount];
#if INTERLOCKED
groupshared uint lds_depth_min;
groupshared uint lds_depth_max;
#else
groupshared float2 lds_depth_min_max[ThreadCount / MinWaveLaneCount];
#endif

#if MORTON
[numthreads(ThreadCount, 1, 1)]
#else
[numthreads(TileDepthThreadCountX, TileDepthThreadCountY, 1)]
#endif
void main(uint3 gtid : SV_GroupThreadID,
uint3 gid : SV_GroupID,
uint3 dtid : SV_DispatchThreadID,
Expand All @@ -23,17 +40,57 @@ void main(uint3 gtid : SV_GroupThreadID,
float depth_min_cs = 1.0;
float depth_max_cs = 0.0;

#if INTERLOCKED
if (gi == 0)
{
lds_depth_min = asuint(depth_min_cs);
lds_depth_max = asuint(depth_max_cs);
}

GroupMemoryBarrierWithGroupSync();
#endif

#if MORTON
const uint2 local_position_ts = gid.xy * uint2(TileDepthThreadCountX, TileDepthThreadCountY) * 2;
const uint2 position_ts = local_position_ts + decode_morton_2d(gi) * 2 + 1;
#else
const uint2 position_ts = dtid.xy * 2 + 1;
#endif
const float2 position_uv = (float2)position_ts * consts.extent_ts_inv;

// Use a single thread to process 4 texels first
const float4 quad_depth = SceneDepth.GatherRed(Sampler, position_uv);

depth_min_cs = min(min(quad_depth.x, quad_depth.y), min(quad_depth.z, quad_depth.w));
depth_max_cs = max(max(quad_depth.x, quad_depth.y), max(quad_depth.z, quad_depth.w));

lds_depth_min_max[gi] = float2(depth_min_cs, depth_max_cs);
// Reduce as much as we can with wave intrinsics
depth_min_cs = WaveActiveMin(depth_min_cs);
depth_max_cs = WaveActiveMax(depth_max_cs);

// Write wave-wide min/max results to LDS
#if INTERLOCKED
if (WaveIsFirstLane())
{
InterlockedMin(lds_depth_min, asuint(depth_min_cs));
InterlockedMax(lds_depth_max, asuint(depth_max_cs));
}

GroupMemoryBarrierWithGroupSync();
#else
uint wave_lane_count = WaveGetLaneCount();

if (WaveIsFirstLane())
{
// NOTE: Only valid with VK_PIPELINE_SHADER_STAGE_CREATE_REQUIRE_FULL_SUBGROUPS_BIT
uint wave_index = gi / wave_lane_count;
lds_depth_min_max[wave_index] = float2(depth_min_cs, depth_max_cs);
}

// Reduce using LDS
uint active_threads = ThreadCount / wave_lane_count;

for (uint threads = ThreadCount / 2; threads > 0; threads /= 2)
for (uint threads = active_threads / 2; threads > 0; threads /= 2)
{
GroupMemoryBarrierWithGroupSync();

Expand All @@ -47,10 +104,16 @@ void main(uint3 gtid : SV_GroupThreadID,
lds_depth_min_max[gi] = float2(depth_min_cs, depth_max_cs);
}
}
#endif

if (gi == 0)
{
#if INTERLOCKED
TileDepthMin[gid.xy] = asfloat(lds_depth_min);
TileDepthMax[gid.xy] = asfloat(lds_depth_max);
#else
TileDepthMin[gid.xy] = depth_min_cs;
TileDepthMax[gid.xy] = depth_max_cs;
#endif
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@

#include "shared_types.hlsl"

// NOTE: Should always be a power of two
// Lower bound on the wave (subgroup) lane count: the tile depth downsample shader sizes its
// per-wave LDS array as ThreadCount / MinWaveLaneCount, and the host side asserts that the
// device subgroup size is at least this value.
static const hlsl_uint MinWaveLaneCount = 8;

// Thread group dimensions of the tile depth downsample pass; the shader's total thread count
// is TileDepthThreadCountX * TileDepthThreadCountY.
static const hlsl_uint TileDepthThreadCountX = 8;
static const hlsl_uint TileDepthThreadCountY = 8;

Expand Down
4 changes: 3 additions & 1 deletion src/renderer/vulkan/Backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -504,9 +504,10 @@ bool vulkan_check_physical_device(IWindow* window,
Assert(device_features2.features.shaderClipDistance == VK_TRUE); // This is just checked, not enabled
Assert(device_features2.features.fillModeNonSolid == VK_TRUE);
Assert(device_features2.features.geometryShader == VK_TRUE);
Assert(device_vulkan12_features.shaderSampledImageArrayNonUniformIndexing == VK_TRUE);
Assert(device_vulkan13_features.synchronization2 == VK_TRUE);
Assert(device_vulkan13_features.dynamicRendering == VK_TRUE);
Assert(device_vulkan12_features.shaderSampledImageArrayNonUniformIndexing == VK_TRUE);
Assert(device_vulkan13_features.computeFullSubgroups == VK_TRUE);
Assert(primitive_restart_feature.primitiveTopologyListRestart == VK_TRUE);
Assert(index_uint8_feature.indexTypeUint8 == VK_TRUE);

Expand Down Expand Up @@ -752,6 +753,7 @@ void vulkan_create_logical_device(ReaperRoot& root,
device_vulkan13_features.pNext = &primitive_restart_feature;
device_vulkan13_features.synchronization2 = VK_TRUE;
device_vulkan13_features.dynamicRendering = VK_TRUE;
device_vulkan13_features.computeFullSubgroups = VK_TRUE;

VkPhysicalDeviceVulkan12Features device_features_1_2 = {};
device_features_1_2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES;
Expand Down
41 changes: 25 additions & 16 deletions src/renderer/vulkan/Pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,21 +97,18 @@ const char* default_entry_point()
return "main";
}

VkPipeline create_compute_pipeline(VkDevice device,
VkPipelineLayout pipeline_layout,
VkShaderModule compute_shader,
VkSpecializationInfo* specialization_info)
VkPipeline create_compute_pipeline(VkDevice device, VkPipelineLayout pipeline_layout,
const VkPipelineShaderStageCreateInfo& shader_stage_create_info)
{
VkPipelineShaderStageCreateInfo shaderStage =
default_pipeline_shader_stage_create_info(VK_SHADER_STAGE_COMPUTE_BIT, compute_shader, specialization_info);

VkComputePipelineCreateInfo pipelineCreateInfo = {VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO,
nullptr,
0,
shaderStage,
pipeline_layout,
VK_NULL_HANDLE, // do not care about pipeline derivatives
0};
VkComputePipelineCreateInfo pipelineCreateInfo = {
.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO,
.pNext = nullptr,
.flags = 0,
.stage = shader_stage_create_info,
.layout = pipeline_layout,
.basePipelineHandle = VK_NULL_HANDLE, // do not care about pipeline derivatives
.basePipelineIndex = 0,
};

VkPipeline pipeline = VK_NULL_HANDLE;
VkPipelineCache cache = VK_NULL_HANDLE;
Expand All @@ -121,14 +118,26 @@ VkPipeline create_compute_pipeline(VkDevice device,
return pipeline;
}

// Convenience overload: wraps the shader module (and optional specialization info) in a default
// compute shader stage description, then forwards to the overload taking the stage create info.
VkPipeline create_compute_pipeline(VkDevice device,
                                   VkPipelineLayout pipeline_layout,
                                   VkShaderModule compute_shader,
                                   VkSpecializationInfo* specialization_info)
{
    return create_compute_pipeline(
        device, pipeline_layout,
        default_pipeline_shader_stage_create_info(VK_SHADER_STAGE_COMPUTE_BIT, compute_shader, specialization_info));
}

VkPipelineShaderStageCreateInfo
default_pipeline_shader_stage_create_info(VkShaderStageFlagBits stage_bit, VkShaderModule shader_module,
const VkSpecializationInfo* specialization_info)
const VkSpecializationInfo* specialization_info,
VkPipelineShaderStageCreateFlags flags)
{
return VkPipelineShaderStageCreateInfo{
.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
.pNext = nullptr,
.flags = VK_FLAGS_NONE,
.flags = flags,
.stage = stage_bit,
.module = shader_module,
.pName = default_entry_point(),
Expand Down
7 changes: 6 additions & 1 deletion src/renderer/vulkan/Pipeline.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,17 @@ VkPipelineLayout create_pipeline_layout(
VkDevice device, nonstd::span<const VkDescriptorSetLayout> descriptor_set_layouts,
nonstd::span<const VkPushConstantRange> push_constant_ranges = nonstd::span<const VkPushConstantRange>());

// Creates a compute pipeline for the given layout from a caller-provided shader stage description.
// Use this overload when the stage needs non-default settings such as stage create flags.
VkPipeline create_compute_pipeline(VkDevice device, VkPipelineLayout pipeline_layout,
const VkPipelineShaderStageCreateInfo& shader_stage_create_info);

// FIXME Deprecated
// Builds a default compute shader stage from the shader module and optional specialization info,
// then forwards to the overload above.
VkPipeline create_compute_pipeline(VkDevice device, VkPipelineLayout pipeline_layout, VkShaderModule compute_shader,
VkSpecializationInfo* specialization_info = nullptr);

VkPipelineShaderStageCreateInfo
default_pipeline_shader_stage_create_info(VkShaderStageFlagBits stage_bit, VkShaderModule shader_module,
const VkSpecializationInfo* specialization_info = nullptr);
const VkSpecializationInfo* specialization_info = nullptr,
VkPipelineShaderStageCreateFlags flags = VK_FLAGS_NONE);

VkPipelineColorBlendAttachmentState default_pipeline_color_blend_attachment_state();
VkPipelineRenderingCreateInfo default_pipeline_rendering_create_info();
Expand Down
11 changes: 8 additions & 3 deletions src/renderer/vulkan/renderpass/TiledRasterPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,16 @@ TiledRasterResources create_tiled_raster_pass_resources(ReaperRoot& root, Vulkan

const VkPushConstantRange pushConstantRange = {VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(TileDepthConstants)};

VkPipelineLayout pipeline_layout = create_pipeline_layout(
const VkPipelineLayout pipeline_layout = create_pipeline_layout(
backend.device, nonstd::span(&descriptor_set_layout, 1), nonstd::span(&pushConstantRange, 1));

VkPipeline pipeline =
create_compute_pipeline(backend.device, pipeline_layout, shader_modules.tile_depth_downsample_cs);
const VkPipelineShaderStageCreateInfo shader_stage = default_pipeline_shader_stage_create_info(
VK_SHADER_STAGE_COMPUTE_BIT, shader_modules.tile_depth_downsample_cs, nullptr,
VK_PIPELINE_SHADER_STAGE_CREATE_REQUIRE_FULL_SUBGROUPS_BIT);

VkPipeline pipeline = create_compute_pipeline(backend.device, pipeline_layout, shader_stage);

Assert(backend.physicalDeviceInfo.subgroup_size >= MinWaveLaneCount);

resources.tile_depth_descriptor_set_layout = descriptor_set_layout;
resources.tile_depth_pipeline_layout = pipeline_layout;
Expand Down

0 comments on commit 7fff800

Please sign in to comment.