Skip to content

Commit

Permalink
cleanup Binding (#107)
Browse files Browse the repository at this point in the history
* cleanup Binding

* fix Binding constructors with ComPtr

---------

Co-authored-by: Simon Kallweit <[email protected]>
  • Loading branch information
westlicht and skallweitNV authored Nov 15, 2024
1 parent aa89fde commit 1b6b405
Show file tree
Hide file tree
Showing 8 changed files with 80 additions and 49 deletions.
41 changes: 28 additions & 13 deletions include/slang-rhi.h
Original file line number Diff line number Diff line change
Expand Up @@ -1088,31 +1088,46 @@ enum class BindingType
TextureView,
Sampler,
CombinedTextureSampler,
CombinedTextureViewSampler,
AccelerationStructure,
};

struct Binding
{
BindingType type;
ComPtr<IResource> resource;
ComPtr<IResource> resource2;
BindingType type = BindingType::Unknown;
IResource* resource = nullptr;
IResource* resource2 = nullptr;
union
{
BufferRange bufferRange;
};

// clang-format off
Binding() : type(BindingType::Unknown) {}
Binding(ComPtr<IBuffer> buffer, const BufferRange& range = kEntireBuffer) : type(BindingType::Buffer), resource(buffer), bufferRange(range) {}
Binding(IBuffer* buffer, const BufferRange& range = kEntireBuffer) : Binding(ComPtr<IBuffer>(buffer), range) {}
Binding(ComPtr<IBuffer> buffer, ComPtr<IBuffer> counter, const BufferRange& range = kEntireBuffer) : type(BindingType::BufferWithCounter), resource(buffer), resource2(counter), bufferRange(range) {}
Binding(ComPtr<ITexture> texture) : type(BindingType::Texture), resource(texture) {}
Binding(ComPtr<ITextureView> textureView) : type(BindingType::TextureView), resource(textureView) {}
Binding(ComPtr<ISampler> sampler) : type(BindingType::Sampler) , resource(sampler) {}
Binding(ComPtr<ITextureView> textureView, ComPtr<ISampler> sampler) : type(BindingType::CombinedTextureSampler), resource(textureView), resource2(sampler) {}
Binding(ComPtr<ITexture> texture, ComPtr<ISampler> sampler) : type(BindingType::CombinedTextureSampler) , resource(texture), resource2(sampler) {}
Binding(ComPtr<IAccelerationStructure> as) : type(BindingType::AccelerationStructure) , resource(as) {}
Binding(IAccelerationStructure* as) : Binding(ComPtr<IAccelerationStructure>(as)) {}

Binding(IBuffer* buffer, const BufferRange& range = kEntireBuffer) : type(BindingType::Buffer), resource(buffer), bufferRange(range) {}
Binding(const ComPtr<IBuffer>& buffer, const BufferRange& range = kEntireBuffer) : type(BindingType::Buffer), resource(buffer.get()), bufferRange(range) {}

Binding(IBuffer* buffer, IBuffer* counter, const BufferRange& range = kEntireBuffer) : type(BindingType::BufferWithCounter), resource(buffer), resource2(counter), bufferRange(range) {}
Binding(const ComPtr<IBuffer>& buffer, const ComPtr<IBuffer>& counter, const BufferRange& range = kEntireBuffer) : type(BindingType::BufferWithCounter), resource(buffer.get()), resource2(counter.get()), bufferRange(range) {}

Binding(ITexture* texture) : type(BindingType::Texture), resource(texture) {}
Binding(const ComPtr<ITexture>& texture) : type(BindingType::Texture), resource(texture.get()) {}

Binding(ITextureView* textureView) : type(BindingType::TextureView), resource(textureView) {}
Binding(const ComPtr<ITextureView>& textureView) : type(BindingType::TextureView), resource(textureView.get()) {}

Binding(ISampler* sampler) : type(BindingType::Sampler) , resource(sampler) {}
Binding(const ComPtr<ISampler>& sampler) : type(BindingType::Sampler) , resource(sampler.get()) {}

Binding(ITexture* texture, ISampler* sampler) : type(BindingType::CombinedTextureSampler), resource(texture), resource2(sampler) {}
Binding(const ComPtr<ITexture>& texture, const ComPtr<ISampler>& sampler) : type(BindingType::CombinedTextureSampler), resource(texture.get()), resource2(sampler.get()) {}

Binding(ITextureView* textureView, ISampler* sampler) : type(BindingType::CombinedTextureViewSampler) , resource(textureView), resource2(sampler) {}
Binding(const ComPtr<ITextureView>& textureView, const ComPtr<ISampler>& sampler) : type(BindingType::CombinedTextureViewSampler) , resource(textureView.get()), resource2(sampler.get()) {}

Binding(IAccelerationStructure* as) : type(BindingType::AccelerationStructure), resource(as) {}
Binding(const ComPtr<IAccelerationStructure>& as) : type(BindingType::AccelerationStructure), resource(as.get()) {}
// clang-format on
};

Expand Down
10 changes: 7 additions & 3 deletions src/cpu/cpu-shader-object.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ Result ShaderObjectImpl::setBinding(ShaderOffset const& offset, Binding binding)
{
case BindingType::Buffer:
{
BufferImpl* buffer = checked_cast<BufferImpl*>(binding.resource.get());
BufferImpl* buffer = checked_cast<BufferImpl*>(binding.resource);
const BufferDesc& desc = buffer->m_desc;
BufferRange range = buffer->resolveBufferRange(binding.bufferRange);
m_resources[viewIndex] = buffer;
Expand All @@ -182,12 +182,12 @@ Result ShaderObjectImpl::setBinding(ShaderOffset const& offset, Binding binding)
}
case BindingType::Texture:
{
TextureImpl* texture = checked_cast<TextureImpl*>(binding.resource.get());
TextureImpl* texture = checked_cast<TextureImpl*>(binding.resource);
return setBinding(offset, m_device->createTextureView(texture, {}));
}
case BindingType::TextureView:
{
auto textureView = checked_cast<TextureViewImpl*>(binding.resource.get());
auto textureView = checked_cast<TextureViewImpl*>(binding.resource);
m_resources[viewIndex] = textureView;
slang_prelude::IRWTexture* textureObj = textureView;
SLANG_RETURN_ON_FAIL(setData(offset, &textureObj, sizeof(textureObj)));
Expand All @@ -201,6 +201,10 @@ Result ShaderObjectImpl::setBinding(ShaderOffset const& offset, Binding binding)
{
break;
}
case BindingType::CombinedTextureViewSampler:
{
break;
}
case BindingType::AccelerationStructure:
{
break;
Expand Down
6 changes: 3 additions & 3 deletions src/cuda/cuda-shader-object.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ Result ShaderObjectImpl::setBinding(ShaderOffset const& offset, Binding binding)
{
case BindingType::Buffer:
{
BufferImpl* buffer = checked_cast<BufferImpl*>(binding.resource.get());
BufferImpl* buffer = checked_cast<BufferImpl*>(binding.resource);
const BufferDesc& desc = buffer->m_desc;
BufferRange range = buffer->resolveBufferRange(binding.bufferRange);
m_resources[viewIndex] = buffer;
Expand All @@ -222,7 +222,7 @@ Result ShaderObjectImpl::setBinding(ShaderOffset const& offset, Binding binding)
}
case BindingType::Texture:
{
TextureImpl* texture = checked_cast<TextureImpl*>(binding.resource.get());
TextureImpl* texture = checked_cast<TextureImpl*>(binding.resource);
m_resources[viewIndex] = texture;
switch (bindingRange.bindingType)
{
Expand All @@ -237,7 +237,7 @@ Result ShaderObjectImpl::setBinding(ShaderOffset const& offset, Binding binding)
}
case BindingType::TextureView:
{
TextureViewImpl* textureView = checked_cast<TextureViewImpl*>(binding.resource.get());
TextureViewImpl* textureView = checked_cast<TextureViewImpl*>(binding.resource);
m_resources[viewIndex] = textureView;
TextureImpl* texture = textureView->m_texture;
switch (bindingRange.bindingType)
Expand Down
10 changes: 6 additions & 4 deletions src/d3d11/d3d11-shader-object.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ Result ShaderObjectImpl::setBinding(ShaderOffset const& offset, Binding binding)
{
case BindingType::Buffer:
{
BufferImpl* buffer = checked_cast<BufferImpl*>(binding.resource.get());
BufferImpl* buffer = checked_cast<BufferImpl*>(binding.resource);
BufferRange bufferRange = buffer->resolveBufferRange(binding.bufferRange);
m_resources.emplace(buffer);
if (D3DUtil::isUAVBinding(bindingRange.bindingType))
Expand All @@ -73,12 +73,12 @@ Result ShaderObjectImpl::setBinding(ShaderOffset const& offset, Binding binding)
}
case BindingType::Texture:
{
TextureImpl* texture = checked_cast<TextureImpl*>(binding.resource.get());
TextureImpl* texture = checked_cast<TextureImpl*>(binding.resource);
return setBinding(offset, m_device->createTextureView(texture, {}));
}
case BindingType::TextureView:
{
TextureViewImpl* textureView = checked_cast<TextureViewImpl*>(binding.resource.get());
TextureViewImpl* textureView = checked_cast<TextureViewImpl*>(binding.resource);
m_resources.emplace(textureView);
if (D3DUtil::isUAVBinding(bindingRange.bindingType))
{
Expand All @@ -91,10 +91,12 @@ Result ShaderObjectImpl::setBinding(ShaderOffset const& offset, Binding binding)
break;
}
case BindingType::Sampler:
m_samplers[bindingIndex] = checked_cast<SamplerImpl*>(binding.resource.get());
m_samplers[bindingIndex] = checked_cast<SamplerImpl*>(binding.resource);
break;
case BindingType::CombinedTextureSampler:
break;
case BindingType::CombinedTextureViewSampler:
break;
case BindingType::AccelerationStructure:
break;
}
Expand Down
22 changes: 14 additions & 8 deletions src/d3d12/d3d12-shader-object.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -765,8 +765,8 @@ Result ShaderObjectImpl::setBinding(ShaderOffset const& offset, Binding binding)
case BindingType::Buffer:
case BindingType::BufferWithCounter:
{
BufferImpl* buffer = checked_cast<BufferImpl*>(binding.resource.get());
BufferImpl* counterBuffer = checked_cast<BufferImpl*>(binding.resource2.get());
BufferImpl* buffer = checked_cast<BufferImpl*>(binding.resource);
BufferImpl* counterBuffer = checked_cast<BufferImpl*>(binding.resource2);
BufferRange bufferRange = buffer->resolveBufferRange(binding.bufferRange);
boundResource.type = BoundResourceType::Buffer;
boundResource.resource = buffer;
Expand Down Expand Up @@ -811,12 +811,12 @@ Result ShaderObjectImpl::setBinding(ShaderOffset const& offset, Binding binding)
}
case BindingType::Texture:
{
TextureImpl* texture = checked_cast<TextureImpl*>(binding.resource.get());
TextureImpl* texture = checked_cast<TextureImpl*>(binding.resource);
return setBinding(offset, m_device->createTextureView(texture, {}));
}
case BindingType::TextureView:
{
TextureViewImpl* textureView = checked_cast<TextureViewImpl*>(binding.resource.get());
TextureViewImpl* textureView = checked_cast<TextureViewImpl*>(binding.resource);
boundResource.type = BoundResourceType::TextureView;
boundResource.resource = textureView;
D3D12Descriptor descriptor;
Expand All @@ -843,7 +843,7 @@ Result ShaderObjectImpl::setBinding(ShaderOffset const& offset, Binding binding)
}
case BindingType::Sampler:
{
SamplerImpl* sampler = checked_cast<SamplerImpl*>(binding.resource.get());
SamplerImpl* sampler = checked_cast<SamplerImpl*>(binding.resource);
d3dDevice->CopyDescriptorsSimple(
1,
m_descriptorSet.samplerTable.getCpuHandle(bindingIndex),
Expand All @@ -854,8 +854,14 @@ Result ShaderObjectImpl::setBinding(ShaderOffset const& offset, Binding binding)
}
case BindingType::CombinedTextureSampler:
{
TextureViewImpl* textureView = checked_cast<TextureViewImpl*>(binding.resource.get());
SamplerImpl* sampler = checked_cast<SamplerImpl*>(binding.resource2.get());
TextureImpl* texture = checked_cast<TextureImpl*>(binding.resource);
SamplerImpl* sampler = checked_cast<SamplerImpl*>(binding.resource2);
return setBinding(offset, Binding(m_device->createTextureView(texture, {}), sampler));
}
case BindingType::CombinedTextureViewSampler:
{
TextureViewImpl* textureView = checked_cast<TextureViewImpl*>(binding.resource);
SamplerImpl* sampler = checked_cast<SamplerImpl*>(binding.resource2);
boundResource.type = BoundResourceType::TextureView;
boundResource.resource = textureView;
boundResource.requiredState = ResourceState::ShaderResource;
Expand All @@ -875,7 +881,7 @@ Result ShaderObjectImpl::setBinding(ShaderOffset const& offset, Binding binding)
}
case BindingType::AccelerationStructure:
{
AccelerationStructureImpl* as = checked_cast<AccelerationStructureImpl*>(binding.resource.get());
AccelerationStructureImpl* as = checked_cast<AccelerationStructureImpl*>(binding.resource);
boundResource.type = BoundResourceType::AccelerationStructure;
boundResource.resource = as;
if (bindingRange.isRootParameter)
Expand Down
8 changes: 4 additions & 4 deletions src/metal/metal-shader-object.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,19 +64,19 @@ Result ShaderObjectImpl::setBinding(ShaderOffset const& offset, Binding binding)
switch (binding.type)
{
case BindingType::Buffer:
m_buffers[bindingIndex] = checked_cast<BufferImpl*>(binding.resource.get());
m_buffers[bindingIndex] = checked_cast<BufferImpl*>(binding.resource);
m_bufferOffsets[bindingIndex] = binding.bufferRange.offset;
break;
case BindingType::Texture:
{
TextureImpl* texture = checked_cast<TextureImpl*>(binding.resource.get());
TextureImpl* texture = checked_cast<TextureImpl*>(binding.resource);
return setBinding(offset, m_device->createTextureView(texture, {}));
}
case BindingType::TextureView:
m_textureViews[bindingIndex] = checked_cast<TextureViewImpl*>(binding.resource.get());
m_textureViews[bindingIndex] = checked_cast<TextureViewImpl*>(binding.resource);
break;
case BindingType::Sampler:
m_samplers[bindingIndex] = checked_cast<SamplerImpl*>(binding.resource.get());
m_samplers[bindingIndex] = checked_cast<SamplerImpl*>(binding.resource);
break;
}

Expand Down
24 changes: 14 additions & 10 deletions src/vulkan/vk-shader-object.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ Result ShaderObjectImpl::setBinding(ShaderOffset const& offset, Binding binding)
{
case BindingType::Buffer:
{
BufferImpl* buffer = checked_cast<BufferImpl*>(binding.resource.get());
BufferImpl* buffer = checked_cast<BufferImpl*>(binding.resource);
slot.type = BindingType::Buffer;
slot.resource = buffer;
slot.format = slot.format != Format::Unknown ? slot.format : buffer->m_desc.format;
Expand All @@ -112,13 +112,13 @@ Result ShaderObjectImpl::setBinding(ShaderOffset const& offset, Binding binding)
}
case BindingType::Texture:
{
TextureImpl* texture = checked_cast<TextureImpl*>(binding.resource.get());
TextureImpl* texture = checked_cast<TextureImpl*>(binding.resource);
return setBinding(offset, m_device->createTextureView(texture, {}));
}
case BindingType::TextureView:
{
slot.type = BindingType::TextureView;
slot.resource = checked_cast<TextureViewImpl*>(binding.resource.get());
slot.resource = checked_cast<TextureViewImpl*>(binding.resource);
switch (bindingRange.bindingType)
{
case slang::BindingType::Texture:
Expand All @@ -131,20 +131,24 @@ Result ShaderObjectImpl::setBinding(ShaderOffset const& offset, Binding binding)
break;
}
case BindingType::Sampler:
m_samplers[bindingIndex] = checked_cast<SamplerImpl*>(binding.resource.get());
m_samplers[bindingIndex] = checked_cast<SamplerImpl*>(binding.resource);
break;
case BindingType::CombinedTextureSampler:
{
TextureImpl* texture = checked_cast<TextureImpl*>(binding.resource.get());
m_combinedTextureSamplers[bindingIndex] = CombinedTextureSamplerSlot{
checked_cast<TextureViewImpl*>(m_device->createTextureView(texture, {}).get()),
checked_cast<SamplerImpl*>(binding.resource2.get())
};
TextureImpl* texture = checked_cast<TextureImpl*>(binding.resource);
SamplerImpl* sampler = checked_cast<SamplerImpl*>(binding.resource2);
return setBinding(offset, Binding(m_device->createTextureView(texture, {}), sampler));
}
case BindingType::CombinedTextureViewSampler:
{
TextureViewImpl* textureView = checked_cast<TextureViewImpl*>(binding.resource);
SamplerImpl* sampler = checked_cast<SamplerImpl*>(binding.resource2);
m_combinedTextureSamplers[bindingIndex] = CombinedTextureSamplerSlot{textureView, sampler};
break;
}
case BindingType::AccelerationStructure:
slot.type = BindingType::AccelerationStructure;
slot.resource = checked_cast<AccelerationStructureImpl*>(binding.resource.get());
slot.resource = checked_cast<AccelerationStructureImpl*>(binding.resource);
break;
}

Expand Down
8 changes: 4 additions & 4 deletions src/wgpu/wgpu-shader-object.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ Result ShaderObjectImpl::setBinding(ShaderOffset const& offset, Binding binding)
{
case BindingType::Buffer:
{
BufferImpl* buffer = checked_cast<BufferImpl*>(binding.resource.get());
BufferImpl* buffer = checked_cast<BufferImpl*>(binding.resource);
ResourceSlot slot;
slot.type = BindingType::Buffer;
slot.resource = buffer;
Expand All @@ -99,19 +99,19 @@ Result ShaderObjectImpl::setBinding(ShaderOffset const& offset, Binding binding)
}
case BindingType::Texture:
{
TextureImpl* texture = checked_cast<TextureImpl*>(binding.resource.get());
TextureImpl* texture = checked_cast<TextureImpl*>(binding.resource);
return setBinding(offset, m_device->createTextureView(texture, {}));
}
case BindingType::TextureView:
{
ResourceSlot slot;
slot.type = BindingType::TextureView;
slot.resource = checked_cast<TextureViewImpl*>(binding.resource.get());
slot.resource = checked_cast<TextureViewImpl*>(binding.resource);
m_resources[bindingIndex] = slot;
break;
}
case BindingType::Sampler:
m_samplers[bindingIndex] = checked_cast<SamplerImpl*>(binding.resource.get());
m_samplers[bindingIndex] = checked_cast<SamplerImpl*>(binding.resource);
break;
}

Expand Down

0 comments on commit 1b6b405

Please sign in to comment.