Skip to content

Commit

Permalink
raytracing: Initial Vulkan support
Browse files Browse the repository at this point in the history
- Vulkan implementations in `RenderingDeviceDriverVulkan`
- Raytracing instruction list in `RenderingDeviceGraph`
- Functions to create acceleration structures and raytracing pipelines
  in `RenderingDevice`
- Raygen, Miss, and ClosestHit shader stages support
  • Loading branch information
Fahien committed Nov 16, 2024
1 parent ec6a1c0 commit 2e1d952
Show file tree
Hide file tree
Showing 18 changed files with 1,929 additions and 35 deletions.
549 changes: 547 additions & 2 deletions drivers/vulkan/rendering_device_driver_vulkan.cpp

Large diffs are not rendered by default.

83 changes: 82 additions & 1 deletion drivers/vulkan/rendering_device_driver_vulkan.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,17 @@ class RenderingDeviceDriverVulkan : public RenderingDeviceDriver {
bool storage_input_output_16 = false;
};

struct RaytracingCapabilities {
bool buffer_device_address_support = false;
bool acceleration_structure_support = false;
bool raytracing_pipeline_support = false;
uint32_t shader_group_handle_size = 0;
uint32_t shader_group_handle_alignment = 0;
uint32_t shader_group_handle_size_aligned = 0;
uint32_t shader_group_base_alignment = 0;
bool validation = false;
};

struct DeviceFunctions {
PFN_vkCreateSwapchainKHR CreateSwapchainKHR = nullptr;
PFN_vkDestroySwapchainKHR DestroySwapchainKHR = nullptr;
Expand All @@ -116,6 +127,10 @@ class RenderingDeviceDriverVulkan : public RenderingDeviceDriver {

// Debug device fault.
PFN_vkGetDeviceFaultInfoEXT GetDeviceFaultInfoEXT = nullptr;

// Raytracing extensions.
PFN_vkCreateAccelerationStructureKHR CreateAccelerationStructureKHR = nullptr;
PFN_vkCreateRayTracingPipelinesKHR CreateRaytracingPipelinesKHR = nullptr;
};
// Debug marker extensions.
VkDebugReportObjectTypeEXT _convert_to_debug_report_objectType(VkObjectType p_object_type);
Expand All @@ -138,6 +153,7 @@ class RenderingDeviceDriverVulkan : public RenderingDeviceDriver {
VRSCapabilities vrs_capabilities;
ShaderCapabilities shader_capabilities;
StorageBufferCapabilities storage_buffer_capabilities;
RaytracingCapabilities raytracing_capabilities;
bool pipeline_cache_control_support = false;
bool device_fault_support = false;
#if defined(VK_TRACK_DEVICE_MEMORY)
Expand Down Expand Up @@ -198,6 +214,10 @@ class RenderingDeviceDriverVulkan : public RenderingDeviceDriver {
VkBufferView vk_view = VK_NULL_HANDLE; // For texel buffers.
};

private:
VkDeviceAddress _buffer_get_device_address(BufferID p_buffer);

public:
virtual BufferID buffer_create(uint64_t p_size, BitField<BufferUsageBits> p_usage, MemoryAllocationType p_allocation_type) override final;
virtual bool buffer_set_texel_format(BufferID p_buffer, DataFormat p_format) override final;
virtual void buffer_free(BufferID p_buffer) override final;
Expand Down Expand Up @@ -424,7 +444,7 @@ class RenderingDeviceDriverVulkan : public RenderingDeviceDriver {
uint64_t vertex_input_mask = 0;
uint32_t fragment_output_mask = 0;
uint32_t specialization_constants_count = 0;
uint32_t is_compute = 0;
PipelineType pipeline_type = PipelineType::RASTERIZATION;
uint32_t compute_local_size[3] = {};
uint32_t set_count = 0;
uint32_t push_constant_size = 0;
Expand All @@ -434,11 +454,28 @@ class RenderingDeviceDriverVulkan : public RenderingDeviceDriver {
};
};

struct RaytracingShaderRegions {
VkStridedDeviceAddressRegionKHR raygen;
uint32_t raygen_count = 0;
VkStridedDeviceAddressRegionKHR miss;
uint32_t miss_count = 0;
VkStridedDeviceAddressRegionKHR closest_hit;
uint32_t closest_hit_count = 0;
VkStridedDeviceAddressRegionKHR call;
uint32_t group_count = 0;

// Size of one shader group handle
LocalVector<uint8_t> handles_data;
};

struct ShaderInfo {
VkShaderStageFlags vk_push_constant_stages = 0;
TightLocalVector<VkPipelineShaderStageCreateInfo> vk_stages_create_info;
TightLocalVector<VkRayTracingShaderGroupCreateInfoKHR> vk_groups_create_info;
TightLocalVector<VkDescriptorSetLayout> vk_descriptor_set_layouts;
VkPipelineLayout vk_pipeline_layout = VK_NULL_HANDLE;
RaytracingShaderRegions regions;
BufferID sbt_buffer;
};

public:
Expand Down Expand Up @@ -626,6 +663,50 @@ class RenderingDeviceDriverVulkan : public RenderingDeviceDriver {

virtual PipelineID compute_pipeline_create(ShaderID p_shader, VectorView<PipelineSpecializationConstant> p_specialization_constants) override final;

/********************/
/**** RAYTRACING ****/
/********************/
struct AccelerationStructureInfo {
VkAccelerationStructureKHR vk_acceleration_structure = VK_NULL_HANDLE;
// Buffer used for the structure
RDD::BufferID buffer;
// Buffer used for building the structure
RDD::BufferID scratch_buffer;
// Buffer used for instances in a TLAS
RDD::BufferID instances_buffer;

// Required for building
VkAccelerationStructureGeometryKHR geometry;
LocalVector<VkAccelerationStructureInstanceKHR> instances;
VkAccelerationStructureBuildGeometryInfoKHR build_info;
VkAccelerationStructureBuildRangeInfoKHR range_info;
};

virtual AccelerationStructureID blas_create(BufferID p_vertex_buffer, uint64_t p_vertex_offset, VertexFormatID p_vertex_format, uint32_t p_vertex_count, BufferID p_index_buffer, IndexBufferFormat p_index_format, uint64_t p_index_offset, uint32_t p_index_count, BufferID p_transform_buffer, uint64_t p_transform_offset) override final;
virtual AccelerationStructureID tlas_create(const LocalVector<AccelerationStructureID> &p_blases) override final;
virtual void acceleration_structure_free(AccelerationStructureID p_acceleration_structure) override final;

private:
void _acceleration_structure_create(VkAccelerationStructureTypeKHR p_type, VkAccelerationStructureBuildSizesInfoKHR p_size_info, AccelerationStructureInfo *r_accel_info);

public:
// ----- PIPELINE -----

struct RaytracingPipelineInfo {
VkPipeline vk_pipeline;
};

virtual RaytracingPipelineID raytracing_pipeline_create(ShaderID p_shader, VectorView<PipelineSpecializationConstant> p_specialization_constants) override final;
virtual void raytracing_pipeline_free(RaytracingPipelineID p_pipeline) override final;

// ----- COMMANDS -----

virtual void command_build_acceleration_structure(CommandBufferID p_cmd_buffer, AccelerationStructureID p_acceleration_structure) override final;
virtual void command_bind_raytracing_pipeline(CommandBufferID p_cmd_buffer, RaytracingPipelineID p_pipeline) override final;
virtual void command_bind_raytracing_uniform_set(CommandBufferID p_cmd_buffer, UniformSetID p_uniform_set, ShaderID p_shader, uint32_t p_set_index) override final;
virtual void command_raytracing_trace_rays(CommandBufferID p_cmd_buffer, RaytracingPipelineID p_pipeline, ShaderID p_shader, uint32_t p_width, uint32_t p_height) override final;

public:
/*****************/
/**** QUERIES ****/
/*****************/
Expand Down
5 changes: 4 additions & 1 deletion editor/plugins/shader_file_editor_plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,10 @@ ShaderFileEditor::ShaderFileEditor() {
"Fragment",
"TessControl",
"TessEval",
"Compute"
"Compute",
"Raygen",
"Miss",
"ClosestHit",
};

stage_hb = memnew(HBoxContainer);
Expand Down
33 changes: 33 additions & 0 deletions gles3_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ class GLES3HeaderStruct:
def __init__(self):
self.vertex_lines = []
self.fragment_lines = []
self.raygen_lines = []
self.miss_lines = []
self.closest_hit_lines = []
self.uniforms = []
self.fbos = []
self.texunits = []
Expand All @@ -25,6 +28,9 @@ def __init__(self):
self.line_offset = 0
self.vertex_offset = 0
self.fragment_offset = 0
self.raygen_offset = 0
self.miss_offset = 0
self.closest_hit_offset = 0
self.variant_defines = []
self.variant_names = []
self.specialization_names = []
Expand Down Expand Up @@ -88,6 +94,27 @@ def include_file_in_gles3_header(filename: str, header_data: GLES3HeaderStruct,
header_data.fragment_offset = header_data.line_offset
continue

if line.find("#[raygen]") != -1:
header_data.reading = "raygen"
line = fs.readline()
header_data.line_offset += 1
header_data.raygen_offset = header_data.line_offset
continue

if line.find("#[miss]") != -1:
header_data.reading = "miss"
line = fs.readline()
header_data.line_offset += 1
header_data.miss_offset = header_data.line_offset
continue

if line.find("#[closest_hit]") != -1:
header_data.reading = "closest_hit"
line = fs.readline()
header_data.line_offset += 1
header_data.closest_hit_offset = header_data.line_offset
continue

while line.find("#include ") != -1:
includeline = line.replace("#include ", "").strip()[1:-1]

Expand Down Expand Up @@ -182,6 +209,12 @@ def include_file_in_gles3_header(filename: str, header_data: GLES3HeaderStruct,
header_data.vertex_lines += [line]
if header_data.reading == "fragment":
header_data.fragment_lines += [line]
if header_data.reading == "raygen":
header_data.raygen_lines += [line]
if header_data.reading == "miss":
header_data.miss_lines += [line]
if header_data.reading == "closest_hit":
header_data.closest_hit_lines += [line]

line = fs.readline()
header_data.line_offset += 1
Expand Down
57 changes: 56 additions & 1 deletion glsl_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,25 @@ def __init__(self):
self.vertex_lines = []
self.fragment_lines = []
self.compute_lines = []
self.raygen_lines = []
self.miss_lines = []
self.closest_hit_lines = []

self.vertex_included_files = []
self.fragment_included_files = []
self.compute_included_files = []
self.raygen_included_files = []
self.miss_included_files = []
self.closest_hit_included_files = []

self.reading = ""
self.line_offset = 0
self.vertex_offset = 0
self.fragment_offset = 0
self.compute_offset = 0
self.raygen_offset = 0
self.miss_offset = 0
self.closest_hit_offset = 0


def include_file_in_rd_header(filename: str, header_data: RDHeaderStruct, depth: int) -> RDHeaderStruct:
Expand Down Expand Up @@ -53,6 +62,27 @@ def include_file_in_rd_header(filename: str, header_data: RDHeaderStruct, depth:
header_data.compute_offset = header_data.line_offset
continue

if line.find("#[raygen]") != -1:
header_data.reading = "raygen"
line = fs.readline()
header_data.line_offset += 1
header_data.raygen_offset = header_data.line_offset
continue

if line.find("#[miss]") != -1:
header_data.reading = "miss"
line = fs.readline()
header_data.line_offset += 1
header_data.miss_offset = header_data.line_offset
continue

if line.find("#[closest_hit]") != -1:
header_data.reading = "closest_hit"
line = fs.readline()
header_data.line_offset += 1
header_data.closest_hit_offset = header_data.line_offset
continue

while line.find("#include ") != -1:
includeline = line.replace("#include ", "").strip()[1:-1]

Expand All @@ -74,6 +104,18 @@ def include_file_in_rd_header(filename: str, header_data: RDHeaderStruct, depth:
header_data.compute_included_files += [included_file]
if include_file_in_rd_header(included_file, header_data, depth + 1) is None:
print_error(f'In file "{filename}": #include "{includeline}" could not be found!"')
elif included_file not in header_data.raygen_included_files and header_data.reading == "raygen":
header_data.raygen_included_files += [included_file]
if include_file_in_rd_header(included_file, header_data, depth + 1) is None:
print_error(f'In file "{filename}": #include "{includeline}" could not be found!"')
elif included_file not in header_data.miss_included_files and header_data.reading == "miss":
header_data.miss_included_files += [included_file]
if include_file_in_rd_header(included_file, header_data, depth + 1) is None:
print_error(f'In file "{filename}": #include "{includeline}" could not be found!"')
elif included_file not in header_data.closest_hit_included_files and header_data.reading == "closest_hit":
header_data.closest_hit_included_files += [included_file]
if include_file_in_rd_header(included_file, header_data, depth + 1) is None:
print_error(f'In file "{filename}": #include "{includeline}" could not be found!"')

line = fs.readline()

Expand All @@ -85,6 +127,12 @@ def include_file_in_rd_header(filename: str, header_data: RDHeaderStruct, depth:
header_data.fragment_lines += [line]
if header_data.reading == "compute":
header_data.compute_lines += [line]
if header_data.reading == "raygen":
header_data.raygen_lines += [line]
if header_data.reading == "miss":
header_data.miss_lines += [line]
if header_data.reading == "closest_hit":
header_data.closest_hit_lines += [line]

line = fs.readline()
header_data.line_offset += 1
Expand All @@ -109,7 +157,14 @@ def build_rd_header(
out_file_ifdef = out_file_base.replace(".", "_").upper()
out_file_class = out_file_base.replace(".glsl.gen.h", "").title().replace("_", "").replace(".", "") + "ShaderRD"

if header_data.compute_lines:
if header_data.raygen_lines:
body_parts = [
"static const char _raygen_code[] = {\n%s\n\t\t};" % to_raw_cstring(header_data.raygen_lines),
"static const char _miss_code[] = {\n%s\n\t\t};" % to_raw_cstring(header_data.miss_lines),
"static const char _closest_hit_code[] = {\n%s\n\t\t};" % to_raw_cstring(header_data.closest_hit_lines),
f'setup_raytracing(_raygen_code, _miss_code, _closest_hit_code, "{out_file_class}");',
]
elif header_data.compute_lines:
body_parts = [
"static const char _compute_code[] = {\n%s\n\t\t};" % to_raw_cstring(header_data.compute_lines),
f'setup(nullptr, nullptr, _compute_code, "{out_file_class}");',
Expand Down
5 changes: 4 additions & 1 deletion modules/glslang/register_types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,10 @@ static Vector<uint8_t> _compile_shader_glsl(RenderingDevice::ShaderStage p_stage
EShLangFragment,
EShLangTessControl,
EShLangTessEvaluation,
EShLangCompute
EShLangCompute,
EShLangRayGen,
EShLangMiss,
EShLangClosestHit,
};

int ClientInputSemanticsVersion = 100; // maps to, say, #define VULKAN 100
Expand Down
Loading

0 comments on commit 2e1d952

Please sign in to comment.