Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
a-johnston committed Nov 21, 2024
1 parent e827b2d commit 6c3194b
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 49 deletions.
16 changes: 16 additions & 0 deletions drivers/metal/metal_objects.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
#ifndef METAL_OBJECTS_H
#define METAL_OBJECTS_H

#include "core/typedefs.h"
#import "metal_device_properties.h"
#import "metal_utils.h"
#import "pixel_formats.h"
Expand Down Expand Up @@ -120,12 +121,14 @@ enum class MDCommandBufferStateType {
Render,
Compute,
Blit,
Raytrace,
};

enum class MDPipelineType {
None,
Render,
Compute,
Raytrace,
};

class MDRenderPass;
Expand Down Expand Up @@ -317,6 +320,7 @@ class API_AVAILABLE(macos(11.0), ios(14.0)) MDCommandBuffer {

void _end_compute_dispatch();
void _end_blit();
void _end_raytrace();

#pragma mark - Render

Expand Down Expand Up @@ -486,6 +490,14 @@ class API_AVAILABLE(macos(11.0), ios(14.0)) MDCommandBuffer {
}
} blit;

// State specific to a raytracing pass
struct {
id<MTLAccelerationStructureCommandEncoder> encoder = nil;
_FORCE_INLINE_ void reset() {
encoder = nil;
}
} raytrace;

_FORCE_INLINE_ id<MTLCommandBuffer> get_command_buffer() const {
return commandBuffer;
}
Expand Down Expand Up @@ -538,6 +550,10 @@ class API_AVAILABLE(macos(11.0), ios(14.0)) MDCommandBuffer {
void compute_dispatch(uint32_t p_x_groups, uint32_t p_y_groups, uint32_t p_z_groups);
void compute_dispatch_indirect(RDD::BufferID p_indirect_buffer, uint64_t p_offset);

#pragma mark - Raytracing Commands

void raytrace_bind_uniform_set(RDD::UniformSetID p_uniform_set, RDD::ShaderID p_shader, uint32_t p_set_index);

MDCommandBuffer(id<MTLCommandQueue> p_queue, RenderingDeviceDriverMetal *p_device_driver) :
device_driver(p_device_driver), queue(p_queue) {
type = MDCommandBufferStateType::None;
Expand Down
55 changes: 41 additions & 14 deletions drivers/metal/metal_objects.mm
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@
return _end_compute_dispatch();
case MDCommandBufferStateType::Blit:
return _end_blit();
case MDCommandBufferStateType::Raytrace:
return _end_raytrace();
}
}

Expand All @@ -89,6 +91,8 @@
_end_compute_dispatch();
} else if (type == MDCommandBufferStateType::Blit) {
_end_blit();
} else if (type == MDCommandBufferStateType::Raytrace) {
_end_raytrace();
}

if (p->type == MDPipelineType::Render) {
Expand Down Expand Up @@ -158,6 +162,10 @@
compute.pipeline = (MDComputePipeline *)p;
compute.encoder = commandBuffer.computeCommandEncoder;
[compute.encoder setComputePipelineState:compute.pipeline->state];
} else if (p->type == MDPipelineType::Raytrace) {
DEV_ASSERT(type == MDCommandBufferStateType::None);
type = MDCommandBufferStateType::Raytrace;
raytrace.encoder = commandBuffer.accelerationStructureCommandEncoder;
}
}

Expand All @@ -171,6 +179,9 @@
case MDCommandBufferStateType::Compute:
_end_compute_dispatch();
break;
case MDCommandBufferStateType::Raytrace:
_end_raytrace();
break;
case MDCommandBufferStateType::Blit:
return blit.encoder;
}
Expand All @@ -181,20 +192,7 @@
}

void MDCommandBuffer::encodeRenderCommandEncoderWithDescriptor(MTLRenderPassDescriptor *p_desc, NSString *p_label) {
switch (type) {
case MDCommandBufferStateType::None:
break;
case MDCommandBufferStateType::Render:
render_end_pass();
break;
case MDCommandBufferStateType::Compute:
_end_compute_dispatch();
break;
case MDCommandBufferStateType::Blit:
_end_blit();
break;
}

end();
id<MTLRenderCommandEncoder> enc = [commandBuffer renderCommandEncoderWithDescriptor:p_desc];
if (p_label != nil) {
[enc pushDebugGroup:p_label];
Expand Down Expand Up @@ -990,6 +988,35 @@
type = MDCommandBufferStateType::None;
}

# pragma mark - Raytracing

void MDCommandBuffer::raytrace_bind_uniform_set(RDD::UniformSetID p_uniform_set, RDD::ShaderID p_shader, uint32_t p_set_index) {
DEV_ASSERT(type == MDCommandBufferStateType::Compute);

id<MTLComputeCommandEncoder> enc = compute.encoder;
id<MTLDevice> device = enc.device;

MDShader *shader = (MDShader *)(p_shader.id);
UniformSet const &set_info = shader->sets[p_set_index];

MDUniformSet *set = (MDUniformSet *)(p_uniform_set.id);
BoundUniformSet &bus = set->boundUniformSetForShader(shader, device);
bus.merge_into(compute.resource_usage);

uint32_t const *offset = set_info.offsets.getptr(RDD::SHADER_STAGE_COMPUTE);
if (offset) {
[enc setBuffer:bus.buffer offset:*offset atIndex:p_set_index];
}
}

void MDCommandBuffer::_end_raytrace() {
DEV_ASSERT(type == MDCommandBufferStateType::Raytrace);

[raytrace.encoder endEncoding];
raytrace.reset();
type = MDCommandBufferStateType::None;
}

void MDCommandBuffer::_end_blit() {
DEV_ASSERT(type == MDCommandBufferStateType::Blit);

Expand Down
12 changes: 5 additions & 7 deletions drivers/metal/rendering_device_driver_metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -382,12 +382,10 @@ class API_AVAILABLE(macos(11.0), ios(14.0)) RenderingDeviceDriverMetal : public

#pragma mark - Raytracing

struct BlasStructureInfo {
MTLAccelerationStructureGeometryDescriptor *mtl_structure = nullptr;
};

struct TlasStructureInfo {
MTLAccelerationStructureDescriptor *mtl_structure = nullptr;
struct AccelerationStructureInfo {
AccelerationStructureType type = ACCELERATION_STRUCTURE_TYPE_BLAS;
MTLAccelerationStructureGeometryDescriptor *blas_desc = nil;
MTLAccelerationStructureDescriptor *tlas_desc = nil;
};

virtual bool is_raytracing_supported() override final;
Expand Down Expand Up @@ -463,7 +461,7 @@ class API_AVAILABLE(macos(11.0), ios(14.0)) RenderingDeviceDriverMetal : public
size_t get_texel_buffer_alignment_for_format(RDD::DataFormat p_format) const;
size_t get_texel_buffer_alignment_for_format(MTLPixelFormat p_format) const;

using VersatileResource = VersatileResourceTemplate<BlasStructureInfo, TlasStructureInfo>;
using VersatileResource = VersatileResourceTemplate<AccelerationStructureInfo>;
PagedAllocator<VersatileResource, true> resources_allocator;

/******************/
Expand Down
55 changes: 27 additions & 28 deletions drivers/metal/rendering_device_driver_metal.mm
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@
/**************************************************************************/

#import "rendering_device_driver_metal.h"
#include "core/error/error_macros.h"
#include "vulkan/vulkan_core.h"
#include "core/templates/local_vector.h"
#include <Foundation/Foundation.h>
#include <cstdint>
Expand Down Expand Up @@ -1816,7 +1818,7 @@ void deserialize(BufReader &p_reader) {
}

ERR_FAIL_COND_V_MSG(!resources.atomic_counters.empty(), FAILED, "Atomic counters not supported");
ERR_FAIL_COND_V_MSG(!resources.acceleration_structures.empty(), FAILED, "Acceleration structures not supported");
// ERR_FAIL_COND_V_MSG(!resources.acceleration_structures.empty(), FAILED, "Acceleration structures not supported");
ERR_FAIL_COND_V_MSG(!resources.shader_record_buffers.empty(), FAILED, "Shader record buffers not supported");

if (stage == SHADER_STAGE_VERTEX && !resources.stage_inputs.empty()) {
Expand Down Expand Up @@ -3658,7 +3660,7 @@ bool isArrayTexture(MTLTextureType p_type) {
descriptor.transformationMatrixBuffer = rid::get(p_transform_buffer);
descriptor.transformationMatrixBufferOffset = p_transform_offset;
} else {
// TODO
// TODO: Transform vertices and add a new buffer? Require vertices to be pre-transformed?
}

descriptor.vertexBuffer = rid::get(p_vertex_buffer);
Expand All @@ -3675,37 +3677,42 @@ bool isArrayTexture(MTLTextureType p_type) {
descriptor.indexBufferOffset = p_index_offset_bytes / sizeof(uint32_t);
}

BlasStructureInfo *info = VersatileResource::allocate<BlasStructureInfo>(resources_allocator);
info->mtl_structure = descriptor;
AccelerationStructureInfo *info = VersatileResource::allocate<AccelerationStructureInfo>(resources_allocator);
info->blas_desc = descriptor;

// TODO: Add bounding box support if support is added for descriptors with specialized intersection functions

return RDD::AccelerationStructureID(info);
}

RDD::AccelerationStructureID RenderingDeviceDriverMetal::tlas_create(const LocalVector<RDD::AccelerationStructureID> &p_blases) {
// TODO: This is extremely naive use of the metal api

MTLPrimitiveAccelerationStructureDescriptor *descriptor = [MTLPrimitiveAccelerationStructureDescriptor descriptor];
NSMutableArray *mtl_geometry = [NSMutableArray arrayWithCapacity: p_blases.size()];

for (int i = 0; i < p_blases.size(); i++) {
mtl_geometry[i] = ((BlasStructureInfo *) p_blases[i].id)->mtl_structure;
mtl_geometry[i] = ((AccelerationStructureInfo *) p_blases[i].id)->blas_desc;
}

descriptor.geometryDescriptors = mtl_geometry;

TlasStructureInfo *info = VersatileResource::allocate<TlasStructureInfo>(resources_allocator);
info->mtl_structure = descriptor;
AccelerationStructureInfo *info = VersatileResource::allocate<AccelerationStructureInfo>(resources_allocator);
info->tlas_desc = descriptor;

return RDD::AccelerationStructureID(info);
}

void RenderingDeviceDriverMetal::acceleration_structure_free(RDD::AccelerationStructureID p_acceleration_structure) {
// TODO
AccelerationStructureInfo *info = (AccelerationStructureInfo *) p_acceleration_structure.id;
info->blas_desc = nullptr;
info->tlas_desc = nullptr;
}

// ----- PIPELINE -----

RDD::RaytracingPipelineID RenderingDeviceDriverMetal::raytracing_pipeline_create(ShaderID p_shader, VectorView<PipelineSpecializationConstant> p_specialization_constants) {
// TODO
return RaytracingPipelineID();
return RDD::RaytracingPipelineID(compute_pipeline_create(p_shader, p_specialization_constants).id);
}

void RenderingDeviceDriverMetal::raytracing_pipeline_free(RDD::RaytracingPipelineID p_pipeline) {
Expand All @@ -3715,35 +3722,27 @@ bool isArrayTexture(MTLTextureType p_type) {
// ----- COMMANDS -----

void RenderingDeviceDriverMetal::command_build_acceleration_structure(CommandBufferID p_cmd_buffer, AccelerationStructureID p_acceleration_structure) {
TlasStructureInfo *info = (TlasStructureInfo *) p_acceleration_structure.id;
AccelerationStructureInfo *info = (AccelerationStructureInfo *) p_acceleration_structure.id;

MTLAccelerationStructureSizes sizes = [device accelerationStructureSizesWithDescriptor:info->mtl_structure];
MTLAccelerationStructureSizes sizes = [device accelerationStructureSizesWithDescriptor:info->tlas_desc];
id<MTLAccelerationStructure> structure = [device newAccelerationStructureWithSize:sizes.accelerationStructureSize];
id<MTLBuffer> scratch = [device newBufferWithLength:sizes.buildScratchBufferSize options:MTLResourceStorageModePrivate];

id<MTLCommandBuffer> command_buffer = rid::get(p_cmd_buffer);
id<MTLAccelerationStructureCommandEncoder> command_encoder = [command_buffer accelerationStructureCommandEncoder];
MDCommandBuffer *cb = (MDCommandBuffer *)(p_cmd_buffer.id);

// TODO: Optional structure compaction pass
[command_encoder buildAccelerationStructure:structure descriptor:info->mtl_structure scratchBuffer:scratch scratchBufferOffset:0];
[command_encoder endEncoding];
[command_buffer commit];
// TODO: Optional structure compaction pass, especially for structures which are reused

[cb->raytrace.encoder buildAccelerationStructure:structure descriptor:info->tlas_desc scratchBuffer:scratch scratchBufferOffset:0];
}

void RenderingDeviceDriverMetal::command_bind_raytracing_pipeline(CommandBufferID p_cmd_buffer, RaytracingPipelineID p_pipeline) {
id<MTLCommandBuffer> command_buffer = rid::get(p_cmd_buffer);
id<MTLAccelerationStructureCommandEncoder> command_encoder = [command_buffer accelerationStructureCommandEncoder];
// TODO
[command_encoder endEncoding];
[command_buffer commit];
MDCommandBuffer *cb = (MDCommandBuffer *)(p_cmd_buffer.id);
cb->bind_pipeline((RDD::PipelineID) &p_pipeline);
}

void RenderingDeviceDriverMetal::command_bind_raytracing_uniform_set(CommandBufferID p_cmd_buffer, UniformSetID p_uniform_set, ShaderID p_shader, uint32_t p_set_index) {
id<MTLCommandBuffer> command_buffer = rid::get(p_cmd_buffer);
id<MTLAccelerationStructureCommandEncoder> command_encoder = [command_buffer accelerationStructureCommandEncoder];
// TODO
[command_encoder endEncoding];
[command_buffer commit];
MDCommandBuffer *cb = (MDCommandBuffer *)(p_cmd_buffer.id);
cb->raytrace_bind_uniform_set(p_uniform_set, p_shader, p_set_index);
}

void RenderingDeviceDriverMetal::command_raytracing_trace_rays(CommandBufferID p_cmd_buffer, RaytracingPipelineID p_pipeline, ShaderID p_shader, uint32_t p_width, uint32_t p_height) {
Expand Down

0 comments on commit 6c3194b

Please sign in to comment.