Skip to content

Commit

Permalink
metal wip
Browse files Browse the repository at this point in the history
  • Loading branch information
a-johnston committed Nov 18, 2024
1 parent 736681f commit d980d88
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 9 deletions.
18 changes: 18 additions & 0 deletions drivers/metal/rendering_device_driver_metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,24 @@ class API_AVAILABLE(macos(11.0), ios(14.0)) RenderingDeviceDriverMetal : public

virtual PipelineID compute_pipeline_create(ShaderID p_shader, VectorView<PipelineSpecializationConstant> p_specialization_constants) override final;

#pragma mark - Raytracing

virtual RDD::AccelerationStructureID blas_create(BufferID p_vertex_buffer, uint64_t p_vertex_offset, VertexFormatID p_vertex_format, uint32_t p_vertex_count, BufferID p_index_buffer, IndexBufferFormat p_index_format, uint64_t p_index_offset_bytes, uint32_t p_index_count, BufferID p_transform_buffer, uint64_t p_transform_offset) override final;
virtual RDD::AccelerationStructureID tlas_create(const LocalVector<AccelerationStructureID> &p_blases) override final;
virtual void acceleration_structure_free(AccelerationStructureID p_acceleration_structure) override final;

// ----- PIPELINE -----

virtual RaytracingPipelineID raytracing_pipeline_create(ShaderID p_shader, VectorView<PipelineSpecializationConstant> p_specialization_constants) override final;
virtual void raytracing_pipeline_free(RDD::RaytracingPipelineID p_pipeline) override final;

// ----- COMMANDS -----

virtual void command_build_acceleration_structure(CommandBufferID p_cmd_buffer, AccelerationStructureID p_acceleration_structure) override final;
virtual void command_bind_raytracing_pipeline(CommandBufferID p_cmd_buffer, RaytracingPipelineID p_pipeline) override final;
virtual void command_bind_raytracing_uniform_set(CommandBufferID p_cmd_buffer, UniformSetID p_uniform_set, ShaderID p_shader, uint32_t p_set_index) override final;
virtual void command_raytracing_trace_rays(CommandBufferID p_cmd_buffer, RaytracingPipelineID p_pipeline, ShaderID p_shader, uint32_t p_width, uint32_t p_height) override final;

#pragma mark - Queries

// ----- TIMESTAMP -----
Expand Down
67 changes: 58 additions & 9 deletions drivers/metal/rendering_device_driver_metal.mm
Original file line number Diff line number Diff line change
Expand Up @@ -1517,7 +1517,7 @@ struct API_AVAILABLE(macos(11.0), ios(14.0)) ShaderBinaryData {
uint32_t vertex_input_mask = UINT32_MAX;
uint32_t fragment_output_mask = UINT32_MAX;
uint32_t spirv_specialization_constants_ids_mask = UINT32_MAX;
uint32_t is_compute = UINT32_MAX;
uint32_t pipeline_type = (uint32_t) RenderingDevice::PipelineType::RASTERIZATION;
uint32_t needs_view_mask_buffer = UINT32_MAX;
ComputeSize compute_local_size;
PushConstantData push_constant;
Expand All @@ -1538,7 +1538,7 @@ size_t serialize_size() const {
size += sizeof(uint32_t); // vertex_input_mask
size += sizeof(uint32_t); // fragment_output_mask
size += sizeof(uint32_t); // spirv_specialization_constants_ids_mask
size += sizeof(uint32_t); // is_compute
size += sizeof(uint32_t); // pipeline_type
size += sizeof(uint32_t); // needs_view_mask_buffer
size += compute_local_size.serialize_size(); // compute_local_size
size += push_constant.serialize_size(); // push_constant
Expand All @@ -1563,7 +1563,7 @@ void serialize(BufWriter &p_writer) const {
p_writer.write(vertex_input_mask);
p_writer.write(fragment_output_mask);
p_writer.write(spirv_specialization_constants_ids_mask);
p_writer.write(is_compute);
p_writer.write(pipeline_type);
p_writer.write(needs_view_mask_buffer);
p_writer.write(compute_local_size);
p_writer.write(push_constant);
Expand All @@ -1578,7 +1578,7 @@ void deserialize(BufReader &p_reader) {
p_reader.read(vertex_input_mask);
p_reader.read(fragment_output_mask);
p_reader.read(spirv_specialization_constants_ids_mask);
p_reader.read(is_compute);
p_reader.read(pipeline_type);
p_reader.read(needs_view_mask_buffer);
p_reader.read(compute_local_size);
p_reader.read(push_constant);
Expand Down Expand Up @@ -1617,10 +1617,14 @@ void deserialize(BufReader &p_reader) {
ShaderStage stage_flag = (ShaderStage)(1 << p_spirv[i].shader_stage);

if (p_spirv[i].shader_stage == SHADER_STAGE_COMPUTE) {
r_reflection.is_compute = true;
r_reflection.pipeline_type = PipelineType::COMPUTE;
ERR_FAIL_COND_V_MSG(p_spirv.size() != 1, FAILED,
"Compute shaders can only receive one stage, dedicated to compute.");
}
if (p_spirv[i].shader_stage == SHADER_STAGE_RAYGEN || p_spirv[i].shader_stage == SHADER_STAGE_MISS || p_spirv[i].shader_stage == SHADER_STAGE_CLOSEST_HIT) {
r_reflection.pipeline_type = PipelineType::RAYTRACING;
}

ERR_FAIL_COND_V_MSG(r_reflection.stages.has_flag(stage_flag), FAILED,
"Stage " + String(SHADER_STAGE_NAMES[p_spirv[i].shader_stage]) + " submitted more than once.");

Expand All @@ -1629,7 +1633,7 @@ void deserialize(BufReader &p_reader) {

Compiler compiler(std::move(pir));

if (r_reflection.is_compute) {
if (r_reflection.pipeline_type == PipelineType::COMPUTE) {
r_reflection.compute_local_size[0] = compiler.get_execution_mode_argument(spv::ExecutionModeLocalSize, 0);
r_reflection.compute_local_size[1] = compiler.get_execution_mode_argument(spv::ExecutionModeLocalSize, 1);
r_reflection.compute_local_size[2] = compiler.get_execution_mode_argument(spv::ExecutionModeLocalSize, 2);
Expand Down Expand Up @@ -1926,7 +1930,7 @@ void deserialize(BufReader &p_reader) {
.y = spirv_data.compute_local_size[1],
.z = spirv_data.compute_local_size[2],
};
bin_data.is_compute = spirv_data.is_compute;
bin_data.pipeline_type = (uint32_t) spirv_data.pipeline_type;
bin_data.push_constant.size = spirv_data.push_constant_size;
bin_data.push_constant.stages = (ShaderStageUsage)(uint8_t)spirv_data.push_constant_stages;
bin_data.needs_view_mask_buffer = shader_meta.has_multiview ? 1 : 0;
Expand Down Expand Up @@ -2478,7 +2482,7 @@ void deserialize(BufReader &p_reader) {
}

MDShader *shader = nullptr;
if (binary_data.is_compute) {
if (binary_data.pipeline_type == (uint32_t) PipelineType::COMPUTE) {
MDComputeShader *cs = new MDComputeShader(binary_data.shader_name, uniform_sets, libraries[ShaderStage::SHADER_STAGE_COMPUTE]);

uint32_t *binding = binary_data.push_constant.msl_binding.getptr(SHADER_STAGE_COMPUTE);
Expand Down Expand Up @@ -2520,7 +2524,7 @@ void deserialize(BufReader &p_reader) {

r_shader_desc.vertex_input_mask = binary_data.vertex_input_mask;
r_shader_desc.fragment_output_mask = binary_data.fragment_output_mask;
r_shader_desc.is_compute = binary_data.is_compute;
r_shader_desc.pipeline_type = (RenderingDevice::PipelineType) binary_data.pipeline_type;
r_shader_desc.compute_local_size[0] = binary_data.compute_local_size.x;
r_shader_desc.compute_local_size[1] = binary_data.compute_local_size.y;
r_shader_desc.compute_local_size[2] = binary_data.compute_local_size.z;
Expand Down Expand Up @@ -3623,6 +3627,51 @@ bool isArrayTexture(MTLTextureType p_type) {
return PipelineID(pipeline);
}

#pragma mark - Raytracing

RDD::AccelerationStructureID RenderingDeviceDriverMetal::blas_create(BufferID p_vertex_buffer, uint64_t p_vertex_offset, VertexFormatID p_vertex_format, uint32_t p_vertex_count, BufferID p_index_buffer, IndexBufferFormat p_index_format, uint64_t p_index_offset_bytes, uint32_t p_index_count, RDD::BufferID p_transform_buffer, uint64_t p_transform_offset) {
// TODO
return RDD::AccelerationStructureID();
}

RDD::AccelerationStructureID RenderingDeviceDriverMetal::tlas_create(const LocalVector<RDD::AccelerationStructureID> &p_blases) {
// TODO
return RDD::AccelerationStructureID();
}

void RenderingDeviceDriverMetal::acceleration_structure_free(RDD::AccelerationStructureID p_acceleration_structure) {
// TODO
}

// ----- PIPELINE -----

RDD::RaytracingPipelineID RenderingDeviceDriverMetal::raytracing_pipeline_create(ShaderID p_shader, VectorView<PipelineSpecializationConstant> p_specialization_constants) {
// TODO
return RaytracingPipelineID();
}

void RenderingDeviceDriverMetal::raytracing_pipeline_free(RDD::RaytracingPipelineID p_pipeline) {
// TODO
}

// ----- COMMANDS -----

void RenderingDeviceDriverMetal::command_build_acceleration_structure(CommandBufferID p_cmd_buffer, AccelerationStructureID p_acceleration_structure) {
// TODO
}

void RenderingDeviceDriverMetal::command_bind_raytracing_pipeline(CommandBufferID p_cmd_buffer, RaytracingPipelineID p_pipeline) {
// TODO
}

void RenderingDeviceDriverMetal::command_bind_raytracing_uniform_set(CommandBufferID p_cmd_buffer, UniformSetID p_uniform_set, ShaderID p_shader, uint32_t p_set_index) {
// TODO
}

void RenderingDeviceDriverMetal::command_raytracing_trace_rays(CommandBufferID p_cmd_buffer, RaytracingPipelineID p_pipeline, ShaderID p_shader, uint32_t p_width, uint32_t p_height) {
// TODO
}

#pragma mark - Queries

// ----- TIMESTAMP -----
Expand Down

0 comments on commit d980d88

Please sign in to comment.