Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

EXAMPLE 4: Compute Shaders #4

Open
wants to merge 1 commit into
base: 3-hello-uniform-buffers
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 107 additions & 29 deletions samples/sample/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,12 @@ namespace {
if (buttons[2] || (buttons[0] && buttons[1])) {
camera->pan(-dX * 0.002f, dY * -0.002f);
memcpy(mappedCameraView, &camera->view(), sizeof(glm::mat4));
} else if (buttons[0]) {
}
else if (buttons[0]) {
camera->rotate(dX * -0.01f, dY * -0.01f);
memcpy(mappedCameraView, &camera->view(), sizeof(glm::mat4));
} else if (buttons[1]) {
}
else if (buttons[1]) {
camera->zoom(dY * -0.005f);
memcpy(mappedCameraView, &camera->view(), sizeof(glm::mat4));
}
Expand All @@ -48,8 +50,8 @@ namespace {
}

struct Vertex {
float position[3];
float color[3];
glm::vec4 position;
glm::vec4 color;
};

struct CameraUBO {
Expand Down Expand Up @@ -125,19 +127,22 @@ VkDescriptorSetLayout CreateDescriptorSetLayout(std::vector<VkDescriptorSetLayou

VkDescriptorPool CreateDescriptorPool() {
// Info for the types of descriptors that can be allocated from this pool
VkDescriptorPoolSize poolSizes[2];
VkDescriptorPoolSize poolSizes[3];
poolSizes[0].type = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER;
poolSizes[0].descriptorCount = 1;

poolSizes[1].type = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER;
poolSizes[1].descriptorCount = 1;

poolSizes[2].type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
poolSizes[2].descriptorCount = 1;

VkDescriptorPoolCreateInfo descriptorPoolInfo = {};
descriptorPoolInfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO;
descriptorPoolInfo.pNext = nullptr;
descriptorPoolInfo.poolSizeCount = 2;
descriptorPoolInfo.poolSizeCount = 3;
descriptorPoolInfo.pPoolSizes = poolSizes;
descriptorPoolInfo.maxSets = 2;
descriptorPoolInfo.maxSets = 3;

VkDescriptorPool descriptorPool;
vkCreateDescriptorPool(device->GetVulkanDevice(), &descriptorPoolInfo, nullptr, &descriptorPool);
Expand Down Expand Up @@ -172,7 +177,7 @@ VkPipelineLayout CreatePipelineLayout(std::vector<VkDescriptorSetLayout> descrip
return pipelineLayout;
}

VkPipeline CreatePipeline(VkPipelineLayout pipelineLayout, VkRenderPass renderPass, unsigned int subpass) {
VkPipeline CreateGraphicsPipeline(VkPipelineLayout pipelineLayout, VkRenderPass renderPass, unsigned int subpass) {
VkShaderModule vertShaderModule = createShaderModule("sample/shaders/shader.vert.spv", device->GetVulkanDevice());
VkShaderModule fragShaderModule = createShaderModule("sample/shaders/shader.frag.spv", device->GetVulkanDevice());

Expand Down Expand Up @@ -321,8 +326,32 @@ VkPipeline CreatePipeline(VkPipelineLayout pipelineLayout, VkRenderPass renderPa
}

// No need for the shader modules anymore
vkDestroyShaderModule(device->GetVulkanDevice(), vertShaderModule, nullptr);
vkDestroyShaderModule(device->GetVulkanDevice(), fragShaderModule, nullptr);
vkDestroyShaderModule(device->GetVulkanDevice(), vertShaderModule, nullptr);
vkDestroyShaderModule(device->GetVulkanDevice(), fragShaderModule, nullptr);

return pipeline;
}

VkPipeline CreateComputePipeline(VkPipelineLayout pipelineLayout) {
VkShaderModule compShaderModule = createShaderModule("sample/shaders/shader.comp.spv", device->GetVulkanDevice());

VkPipelineShaderStageCreateInfo compShaderStageInfo = {};
compShaderStageInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
compShaderStageInfo.stage = VK_SHADER_STAGE_COMPUTE_BIT;
compShaderStageInfo.module = compShaderModule;
compShaderStageInfo.pName = "main";

VkComputePipelineCreateInfo pipelineInfo = {};
pipelineInfo.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
pipelineInfo.stage = compShaderStageInfo;
pipelineInfo.layout = pipelineLayout;

VkPipeline pipeline;
if (vkCreateComputePipelines(device->GetVulkanDevice(), VK_NULL_HANDLE, 1, &pipelineInfo, nullptr, &pipeline) != VK_SUCCESS) {
throw std::runtime_error("Failed to create pipeline");
}

vkDestroyShaderModule(device->GetVulkanDevice(), compShaderModule, nullptr);

return pipeline;
}
Expand All @@ -349,7 +378,7 @@ std::vector<VkFramebuffer> CreateFrameBuffers(VkRenderPass renderPass) {
}

int main(int argc, char** argv) {
static constexpr char* applicationName = "Hello Uniform Buffers";
static constexpr char* applicationName = "Hello Compute";
InitializeWindow(640, 480, applicationName);

unsigned int glfwExtensionCount = 0;
Expand All @@ -362,9 +391,9 @@ int main(int argc, char** argv) {
throw std::runtime_error("Failed to create window surface");
}

instance->PickPhysicalDevice({ VK_KHR_SWAPCHAIN_EXTENSION_NAME }, QueueFlagBit::GraphicsBit | QueueFlagBit::TransferBit | QueueFlagBit::PresentBit, surface);
instance->PickPhysicalDevice({ VK_KHR_SWAPCHAIN_EXTENSION_NAME }, QueueFlagBit::GraphicsBit | QueueFlagBit::TransferBit | QueueFlagBit::ComputeBit | QueueFlagBit::PresentBit, surface);

device = instance->CreateDevice(QueueFlagBit::GraphicsBit | QueueFlagBit::TransferBit | QueueFlagBit::PresentBit);
device = instance->CreateDevice(QueueFlagBit::GraphicsBit | QueueFlagBit::TransferBit | QueueFlagBit::ComputeBit | QueueFlagBit::PresentBit);
swapchain = device->CreateSwapChain(surface);

VkCommandPoolCreateInfo poolInfo = {};
Expand All @@ -385,9 +414,9 @@ int main(int argc, char** argv) {
modelTransforms.modelMatrix = glm::rotate(glm::mat4(1.f), static_cast<float>(15 * M_PI / 180), glm::vec3(0.f, 0.f, 1.f));

std::vector<Vertex> vertices = {
{ { 0.5f, 0.5f, 0.0f }, { 0.0f, 1.0f, 0.0f } },
{ { -0.5f, 0.5f, 0.0f }, { 0.0f, 0.0f, 1.0f } },
{ { 0.0f, -0.5f, 0.0f }, { 1.0f, 0.0f, 0.0f } }
{ { 0.5f, 0.5f, 0.0f, 1.f },{ 0.0f, 1.0f, 0.0f, 1.f } },
{ { -0.5f, 0.5f, 0.0f, 1.f },{ 0.0f, 0.0f, 1.0f, 1.f } },
{ { 0.0f, -0.5f, 0.0f, 1.f },{ 1.0f, 0.0f, 0.0f, 1.f } }
};

std::vector<unsigned int> indices = { 0, 1, 2 };
Expand All @@ -396,7 +425,7 @@ int main(int argc, char** argv) {
unsigned int indexBufferSize = static_cast<uint32_t>(indices.size() * sizeof(indices[0]));

// Create vertex and index buffers
VkBuffer vertexBuffer = CreateBuffer(device, VK_BUFFER_USAGE_VERTEX_BUFFER_BIT, vertexBufferSize);
VkBuffer vertexBuffer = CreateBuffer(device, VK_BUFFER_USAGE_VERTEX_BUFFER_BIT | VK_BUFFER_USAGE_STORAGE_BUFFER_BIT, vertexBufferSize);
VkBuffer indexBuffer = CreateBuffer(device, VK_BUFFER_USAGE_INDEX_BUFFER_BIT, indexBufferSize);
unsigned int vertexBufferOffsets[2];
VkDeviceMemory vertexBufferMemory = AllocateMemoryForBuffers(device, { vertexBuffer, indexBuffer }, VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT, vertexBufferOffsets);
Expand Down Expand Up @@ -433,6 +462,10 @@ int main(int argc, char** argv) {

VkDescriptorPool descriptorPool = CreateDescriptorPool();

VkDescriptorSetLayout computeSetLayout = CreateDescriptorSetLayout({
{ 0, VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 1, VK_SHADER_STAGE_COMPUTE_BIT, nullptr },
});

VkDescriptorSetLayout cameraSetLayout = CreateDescriptorSetLayout({
{ 0, VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, 1, VK_SHADER_STAGE_VERTEX_BIT, nullptr },
});
Expand All @@ -441,11 +474,25 @@ int main(int argc, char** argv) {
{ 0, VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, 1, VK_SHADER_STAGE_VERTEX_BIT, nullptr },
});

VkDescriptorSet computeSet = CreateDescriptorSet(descriptorPool, computeSetLayout);
VkDescriptorSet cameraSet = CreateDescriptorSet(descriptorPool, cameraSetLayout);
VkDescriptorSet modelSet = CreateDescriptorSet(descriptorPool, modelSetLayout);

// Initialize descriptor sets
{
VkDescriptorBufferInfo computeBufferInfo = {};
computeBufferInfo.buffer = vertexBuffer;
computeBufferInfo.offset = 0;
computeBufferInfo.range = vertexBufferSize;

VkWriteDescriptorSet writeComputeInfo = {};
writeComputeInfo.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
writeComputeInfo.dstSet = computeSet;
writeComputeInfo.dstBinding = 0;
writeComputeInfo.descriptorCount = 1;
writeComputeInfo.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
writeComputeInfo.pBufferInfo = &computeBufferInfo;

VkDescriptorBufferInfo cameraBufferInfo = {};
cameraBufferInfo.buffer = cameraBuffer;
cameraBufferInfo.offset = 0;
Expand All @@ -472,14 +519,16 @@ int main(int argc, char** argv) {
writeModelInfo.descriptorType = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER;
writeModelInfo.pBufferInfo = &modelBufferInfo;

VkWriteDescriptorSet writeDescriptorSets[] = { writeCameraInfo, writeModelInfo };
VkWriteDescriptorSet writeDescriptorSets[] = { writeComputeInfo, writeCameraInfo, writeModelInfo };

vkUpdateDescriptorSets(device->GetVulkanDevice(), 2, writeDescriptorSets, 0, nullptr);
vkUpdateDescriptorSets(device->GetVulkanDevice(), 3, writeDescriptorSets, 0, nullptr);
}

VkRenderPass renderPass = CreateRenderPass();
VkPipelineLayout pipelineLayout = CreatePipelineLayout({ cameraSetLayout, modelSetLayout });
VkPipeline pipeline = CreatePipeline(pipelineLayout, renderPass, 0);
VkPipelineLayout computePipelineLayout = CreatePipelineLayout({ computeSetLayout });
VkPipelineLayout graphicsPipelineLayout = CreatePipelineLayout({ cameraSetLayout, modelSetLayout });
VkPipeline computePipeline = CreateComputePipeline(computePipelineLayout);
VkPipeline graphicsPipeline = CreateGraphicsPipeline(graphicsPipelineLayout, renderPass, 0);

// Create one framebuffer for each frame of the swap chain
std::vector<VkFramebuffer> frameBuffers = CreateFrameBuffers(renderPass);
Expand Down Expand Up @@ -518,16 +567,44 @@ int main(int argc, char** argv) {
renderPassInfo.clearValueCount = 1;
renderPassInfo.pClearValues = &clearColor;

// Bind the compute pipeline
vkCmdBindPipeline(commandBuffers[i], VK_PIPELINE_BIND_POINT_COMPUTE, computePipeline);

// Bind descriptor sets for compute
vkCmdBindDescriptorSets(commandBuffers[i], VK_PIPELINE_BIND_POINT_COMPUTE, computePipelineLayout, 0, 1, &computeSet, 0, nullptr);

// Dispatch the compute kernel, with one thread for each vertex
vkCmdDispatch(commandBuffers[i], vertices.size(), 1, 1);

// Define a memory barrier to transition the vertex buffer from a compute storage object to a vertex input
VkBufferMemoryBarrier computeToVertexBarrier = {};
computeToVertexBarrier.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_BARRIER;
computeToVertexBarrier.srcAccessMask = VK_ACCESS_SHADER_WRITE_BIT | VK_ACCESS_SHADER_READ_BIT;
computeToVertexBarrier.dstAccessMask = VK_ACCESS_VERTEX_ATTRIBUTE_READ_BIT;
computeToVertexBarrier.srcQueueFamilyIndex = device->GetQueueIndex(QueueFlags::Compute);
computeToVertexBarrier.dstQueueFamilyIndex = device->GetQueueIndex(QueueFlags::Graphics);
computeToVertexBarrier.buffer = vertexBuffer;
computeToVertexBarrier.offset = 0;
computeToVertexBarrier.size = vertexBufferSize;

vkCmdPipelineBarrier(commandBuffers[i],
VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
VK_PIPELINE_STAGE_VERTEX_INPUT_BIT,
0,
0, nullptr,
1, &computeToVertexBarrier,
0, nullptr);

vkCmdBeginRenderPass(commandBuffers[i], &renderPassInfo, VK_SUBPASS_CONTENTS_INLINE);

// Bind camera descriptor set
vkCmdBindDescriptorSets(commandBuffers[i], VK_PIPELINE_BIND_POINT_GRAPHICS, pipelineLayout, 0, 1, &cameraSet, 0, nullptr);
vkCmdBindDescriptorSets(commandBuffers[i], VK_PIPELINE_BIND_POINT_GRAPHICS, graphicsPipelineLayout, 0, 1, &cameraSet, 0, nullptr);

// Bind the graphics pipeline
vkCmdBindPipeline(commandBuffers[i], VK_PIPELINE_BIND_POINT_GRAPHICS, pipeline);
vkCmdBindPipeline(commandBuffers[i], VK_PIPELINE_BIND_POINT_GRAPHICS, graphicsPipeline);

// Bind model descriptor set
vkCmdBindDescriptorSets(commandBuffers[i], VK_PIPELINE_BIND_POINT_GRAPHICS, pipelineLayout, 1, 1, &modelSet, 0, nullptr);
vkCmdBindDescriptorSets(commandBuffers[i], VK_PIPELINE_BIND_POINT_GRAPHICS, graphicsPipelineLayout, 1, 1, &modelSet, 0, nullptr);

VkDeviceSize offsets[1] = { 0 };
vkCmdBindVertexBuffers(commandBuffers[i], 0, 1, &vertexBuffer, offsets);
Expand All @@ -538,9 +615,6 @@ int main(int argc, char** argv) {
// Draw indexed triangle
vkCmdDrawIndexed(commandBuffers[i], 3, 1, 0, 0, 1);

// Draw
vkCmdDraw(commandBuffers[i], 3, 1, 0, 0);

vkCmdEndRenderPass(commandBuffers[i]);

if (vkEndCommandBuffer(commandBuffers[i]) != VK_SUCCESS) {
Expand Down Expand Up @@ -598,12 +672,16 @@ int main(int argc, char** argv) {
vkDestroyBuffer(device->GetVulkanDevice(), modelBuffer, nullptr);
vkFreeMemory(device->GetVulkanDevice(), uniformBufferMemory, nullptr);

vkDestroyDescriptorSetLayout(device->GetVulkanDevice(), computeSetLayout, nullptr);
vkDestroyDescriptorSetLayout(device->GetVulkanDevice(), cameraSetLayout, nullptr);
vkDestroyDescriptorSetLayout(device->GetVulkanDevice(), modelSetLayout, nullptr);
vkDestroyDescriptorPool(device->GetVulkanDevice(), descriptorPool, nullptr);

vkDestroyPipeline(device->GetVulkanDevice(), pipeline, nullptr);
vkDestroyPipelineLayout(device->GetVulkanDevice(), pipelineLayout, nullptr);
vkDestroyPipeline(device->GetVulkanDevice(), computePipeline, nullptr);
vkDestroyPipelineLayout(device->GetVulkanDevice(), computePipelineLayout, nullptr);

vkDestroyPipeline(device->GetVulkanDevice(), graphicsPipeline, nullptr);
vkDestroyPipelineLayout(device->GetVulkanDevice(), graphicsPipelineLayout, nullptr);
vkDestroyRenderPass(device->GetVulkanDevice(), renderPass, nullptr);
vkFreeCommandBuffers(device->GetVulkanDevice(), commandPool, static_cast<uint32_t>(commandBuffers.size()), commandBuffers.data());
for (size_t i = 0; i < frameBuffers.size(); i++) {
Expand Down
22 changes: 22 additions & 0 deletions samples/sample/shaders/shader.comp
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#version 450
#extension GL_ARB_separate_shader_objects : enable

struct Vertex {
vec4 position;
vec4 color;
};

layout(set = 0, binding = 0) buffer Vertices {
Vertex vertices[];
};

void main() {
uint index = gl_GlobalInvocationID.x;
const float a = 0.05 * 3.14159 / 180.0;
vertices[index].position *= mat4(
cos(a), -sin(a), 0.0, 0.0,
sin(a), cos(a), 0.0, 0.0,
0.0, 0.0, 1.0, 0.0,
0.0, 0.0, 0.0, 1.0
);
}
4 changes: 4 additions & 0 deletions samples/sample/vulkan_device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ VkQueue VulkanDevice::GetQueue(QueueFlags flag) {
return queues[flag];
}

unsigned int VulkanDevice::GetQueueIndex(QueueFlags flag) {
return GetInstance()->GetQueueFamilyIndices()[flag];
}

VulkanSwapChain* VulkanDevice::CreateSwapChain(VkSurfaceKHR surface) {
return new VulkanSwapChain(this, surface);
}
Expand Down
2 changes: 1 addition & 1 deletion samples/sample/vulkan_device.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class VulkanDevice {
VulkanInstance* GetInstance();
VkDevice GetVulkanDevice();
VkQueue GetQueue(QueueFlags flag);
unsigned int GetQueueIndex(QueueFlags flag);
~VulkanDevice();

private:
Expand All @@ -24,5 +25,4 @@ class VulkanDevice {
VulkanInstance* instance;
VkDevice vkDevice;
Queues queues;
QueueFamilyIndices queueIndices;
};