Skip to content

Commit

Permalink
generate ksl vk shader code
Browse files Browse the repository at this point in the history
  • Loading branch information
fabmax committed Jan 7, 2025
1 parent f199497 commit ab1c577
Show file tree
Hide file tree
Showing 20 changed files with 113 additions and 133 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,22 @@ open class GlslGenerator(val hints: Hints) : KslGenerator() {
return GlslGeneratorOutput.shaderOutput(
generateVertexSrc(vertexStage, pipeline),
generateFragmentSrc(fragmentStage, pipeline)
)
).also {
if (program.dumpCode) {
it.dump()
}
}
}

override fun generateComputeProgram(program: KslProgram, pipeline: ComputePipeline): GlslGeneratorOutput {
val computeStage = checkNotNull(program.computeStage) {
"KslProgram computeStage is missing"
}
return GlslGeneratorOutput.computeOutput(generateComputeSrc(computeStage, pipeline))
return GlslGeneratorOutput.computeOutput(generateComputeSrc(computeStage, pipeline)).also {
if (program.dumpCode) {
it.dump()
}
}
}

private fun generateVertexSrc(vertexStage: KslVertexStage, pipeline: DrawPipeline): String {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,6 @@ abstract class RenderBackendGl(val numSamples: Int, internal val gl: GlApi, inte

override fun generateKslShader(shader: KslShader, pipeline: DrawPipeline): ShaderCodeGl {
val src = GlslGenerator(glslGeneratorHints).generateProgram(shader.program, pipeline)
if (shader.program.dumpCode) {
src.dump()
}
return ShaderCodeGl(src.vertexSrc, src.fragmentSrc)
}

Expand All @@ -125,9 +122,6 @@ abstract class RenderBackendGl(val numSamples: Int, internal val gl: GlApi, inte
logW { "Compute shaders require OpenGL 4.3 or higher" }
}
val src = GlslGenerator(glslGeneratorHints).generateComputeProgram(shader.program, pipeline)
if (shader.program.dumpCode) {
src.dump()
}
return ComputeShaderCodeGl(src.computeSrc)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@ class WgslGenerator : KslGenerator() {
return WgslGeneratorOutput.shaderOutput(
generateVertexSrc(vertexStage, pipeline),
generateFragmentSrc(fragmentStage, pipeline)
)
).also {
if (program.dumpCode) {
it.dump()
}
}
}

override fun generateComputeProgram(program: KslProgram, pipeline: ComputePipeline): WgslGeneratorOutput {
Expand All @@ -39,7 +43,11 @@ class WgslGenerator : KslGenerator() {
}

generatorState = GeneratorState(pipeline.bindGroupLayouts, null)
return WgslGeneratorOutput.computeOutput(generateComputeSrc(computeStage, pipeline))
return WgslGeneratorOutput.computeOutput(generateComputeSrc(computeStage, pipeline)).also {
if (program.dumpCode) {
it.dump()
}
}
}

private fun generateVertexSrc(vertexStage: KslVertexStage, pipeline: DrawPipeline): String {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,5 @@ class LongHashBuilder {

inline fun LongHash(block: LongHashBuilder.() -> Unit): LongHash = LongHashBuilder().apply(block).build()

// intentionally not a value class to avoid continuous boxing when used as a map key
data class LongHash(val hash: Long)
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package de.fabmax.kool.pipeline.backend.vk.util
package de.fabmax.kool.pipeline.backend.vk

import de.fabmax.kool.pipeline.ShaderStage
import de.fabmax.kool.pipeline.TexFormat
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ import de.fabmax.kool.util.memStack
import org.lwjgl.vulkan.KHRSwapchain.VK_IMAGE_LAYOUT_PRESENT_SRC_KHR
import org.lwjgl.vulkan.VK10.*

class OnScreenRenderPass(val swapChain: SwapChain) :
VkRenderPass(swapChain.backend, swapChain.extent.width(), swapChain.extent.height(), listOf(swapChain.imageFormat))
class OnScreenRenderPass(val swapchain: Swapchain) :
VkRenderPass(swapchain.backend, swapchain.extent.width(), swapchain.extent.height(), listOf(swapchain.imageFormat))
{

override val vkRenderPass: Long
Expand All @@ -15,7 +15,7 @@ class OnScreenRenderPass(val swapChain: SwapChain) :
memStack {
val attachments = callocVkAttachmentDescriptionN(3) {
this[0]
.format(swapChain.imageFormat)
.format(swapchain.imageFormat)
.samples(physicalDevice.msaaSamples)
.loadOp(VK_ATTACHMENT_LOAD_OP_CLEAR)
.storeOp(VK_ATTACHMENT_STORE_OP_STORE)
Expand All @@ -33,7 +33,7 @@ class OnScreenRenderPass(val swapChain: SwapChain) :
.initialLayout(VK_IMAGE_LAYOUT_UNDEFINED)
.finalLayout(VK_IMAGE_LAYOUT_DEPTH_STENCIL_ATTACHMENT_OPTIMAL)
this[2]
.format(swapChain.imageFormat)
.format(swapchain.imageFormat)
.samples(VK_SAMPLE_COUNT_1_BIT)
.loadOp(VK_ATTACHMENT_LOAD_OP_DONT_CARE)
.storeOp(VK_ATTACHMENT_STORE_OP_STORE)
Expand Down Expand Up @@ -83,7 +83,7 @@ class OnScreenRenderPass(val swapChain: SwapChain) :
vkRenderPass = checkCreateLongPtr { vkCreateRenderPass(logicalDevice.vkDevice, renderPassInfo, null, it) }
}

swapChain.addDependingResource(this)
swapchain.addDependingResource(this)
logD { "Created render pass" }
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class RenderBackendVk(val ctx: Lwjgl3Context) : RenderBackendJvm {
val physicalDevice: PhysicalDevice
val logicalDevice: LogicalDevice
val memManager: MemoryManager
var swapchain: SwapChain
var swapchain: Swapchain
val commandPool: CommandPool
val commandBuffer: VkCommandBuffer

Expand All @@ -54,8 +54,6 @@ class RenderBackendVk(val ctx: Lwjgl3Context) : RenderBackendJvm {
//val transferCommandPool: CommandPool
//val renderLoop: RenderLoop

// private val shaderCodes = mutableMapOf<String, ShaderCodeImplVk>()

// private val vkScene = KoolVkScene()

//private val semaPool: SemaphorePool
Expand All @@ -79,7 +77,7 @@ class RenderBackendVk(val ctx: Lwjgl3Context) : RenderBackendJvm {
deviceName = physicalDevice.deviceName

memManager = MemoryManager(this)
swapchain = SwapChain(this)
swapchain = Swapchain(this)
commandPool = CommandPool(this, logicalDevice.graphicsQueue)

val buffers = commandPool.createCommandBuffers(1)
Expand Down Expand Up @@ -125,27 +123,13 @@ class RenderBackendVk(val ctx: Lwjgl3Context) : RenderBackendJvm {
}

override fun generateKslShader(shader: KslShader, pipeline: DrawPipeline): ShaderCode {
TODO()
// val src = KslGlslGeneratorVk().generateProgram(shader.program, pipeline)
// if (shader.program.dumpCode) {
// src.dump()
// }
// val codeKey = src.vertexSrc + src.fragmentSrc
// return shaderCodes.getOrPut(codeKey) {
// ShaderCodeImplVk.vkCodeFromSource(src.vertexSrc, src.fragmentSrc)
// }
val src = KslGlslGeneratorVk().generateProgram(shader.program, pipeline)
return ShaderCodeVk.drawShaderCode(src.vertexSrc, src.fragmentSrc)
}

override fun generateKslComputeShader(shader: KslComputeShader, pipeline: ComputePipeline): ComputeShaderCode {
TODO()
// val src = KslGlslGeneratorVk().generateComputeProgram(shader.program, pipeline)
// if (shader.program.dumpCode) {
// src.dump()
// }
// val codeKey = src.computeSrc
// return shaderCodes.getOrPut(codeKey) {
// ShaderCodeImplVk.vkComputeCodeFromSource(src.computeSrc)
// }
val src = KslGlslGeneratorVk().generateComputeProgram(shader.program, pipeline)
return ShaderCodeVk.computeShaderCode(src.computeSrc)
}

override fun renderFrame(ctx: KoolContext) {
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package de.fabmax.kool.pipeline.backend.vk

import de.fabmax.kool.pipeline.ComputeShaderCode
import de.fabmax.kool.pipeline.ShaderCode
import de.fabmax.kool.pipeline.backend.vk.pipeline.ShaderStage
import de.fabmax.kool.util.LongHash
import de.fabmax.kool.util.logE
import org.lwjgl.vulkan.VK10.*

class ShaderCodeVk(val stages: List<ShaderStage>): ShaderCode, ComputeShaderCode {

override val hash: LongHash = LongHash {
stages.forEach { this += it.hash }
}

companion object {
private val shaderCache = mutableMapOf<ShaderKey, ShaderStage>()

fun drawShaderCode(vertShaderSrc: String, fragShaderSrc: String): ShaderCodeVk {
try {
val vertexStage = shaderCache.getOrPut(ShaderKey(vertShaderSrc, VK_SHADER_STAGE_VERTEX_BIT)) {
ShaderStage.fromSource("vertShader", vertShaderSrc, VK_SHADER_STAGE_VERTEX_BIT)
}
val fragmentStage = shaderCache.getOrPut(ShaderKey(fragShaderSrc, VK_SHADER_STAGE_FRAGMENT_BIT)) {
ShaderStage.fromSource("fragShader", fragShaderSrc, VK_SHADER_STAGE_FRAGMENT_BIT)
}
return ShaderCodeVk(listOf(vertexStage, fragmentStage))
} catch (e: Exception) {
logE { "Compilation failed: $e" }
dumpCode(listOf("Vertex shader" to vertShaderSrc, "Fragment shader" to fragShaderSrc))
throw RuntimeException(e)
}
}

fun computeShaderCode(computeShaderSrc: String): ShaderCodeVk {
try {
val computeStage = shaderCache.getOrPut(ShaderKey(computeShaderSrc, VK_SHADER_STAGE_COMPUTE_BIT)) {
ShaderStage.fromSource("computeShader", computeShaderSrc, VK_SHADER_STAGE_COMPUTE_BIT)
}
return ShaderCodeVk(listOf(computeStage))
} catch (e: Exception) {
logE { "Compilation failed: $e" }
dumpCode(listOf("Compute shader" to computeShaderSrc))
throw RuntimeException(e)
}
}

private fun dumpCode(sources: List<Pair<String, String>>) {
sources.forEach { (name, source) ->
println("$name:\n\n")
source.lines().forEachIndexed { i, l ->
println(String.format("%3d: %s", i+1, l))
}
}
}
}

private data class ShaderKey(val src: String, val stageBit: Int)
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package de.fabmax.kool.pipeline.backend.vk.util
package de.fabmax.kool.pipeline.backend.vk

import de.fabmax.kool.util.logE
import org.lwjgl.util.shaderc.Shaderc
Expand All @@ -7,16 +7,19 @@ import java.nio.ByteBuffer
object Shaderc {
private val compiler = Shaderc.shaderc_compiler_initialize()

fun compileVertexShader(src: String, name: String = "shader.vert", entryPoint: String = "main") =
fun compileVertexShader(src: String, name: String, entryPoint: String) =
compileShader(src, Shaderc.shaderc_vertex_shader, name, entryPoint)

fun compileFragmentShader(src: String, name: String = "shader.frag", entryPoint: String = "main") =
fun compileFragmentShader(src: String, name: String, entryPoint: String) =
compileShader(src, Shaderc.shaderc_fragment_shader, name, entryPoint)

fun compileComputeShader(src: String, name: String, entryPoint: String) =
compileShader(src, Shaderc.shaderc_compute_shader, name, entryPoint)

private fun compileShader(src: String, shaderKind: Int, fName: String, entryPoint: String): CompileResult {
val options = Shaderc.shaderc_compile_options_initialize()
Shaderc.shaderc_compile_options_set_optimization_level(options, Shaderc.shaderc_optimization_level_performance)
Shaderc.shaderc_compile_options_set_target_env(options, Shaderc.shaderc_target_env_webgpu, 0)
Shaderc.shaderc_compile_options_set_target_env(options, Shaderc.shaderc_target_env_vulkan, 0)

val result = Shaderc.shaderc_compile_into_spv(compiler, src, shaderKind, fName, entryPoint, options)
val status = Shaderc.shaderc_result_get_compilation_status(result)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import org.lwjgl.vulkan.KHRSwapchain.*
import org.lwjgl.vulkan.VK10.*
import org.lwjgl.vulkan.VkExtent2D

class SwapChain(val backend: RenderBackendVk) : VkResource() {
class Swapchain(val backend: RenderBackendVk) : VkResource() {

private val physicalDevice: PhysicalDevice get() = backend.physicalDevice
private val logicalDevice: LogicalDevice get() = backend.logicalDevice
Expand Down Expand Up @@ -72,7 +72,7 @@ class SwapChain(val backend: RenderBackendVk) : VkResource() {
}

imageFormat = surfaceFormat.format()
this@SwapChain.extent.set(extent)
this@Swapchain.extent.set(extent)

val imgs = enumerateLongs { cnt, imgs ->
vkGetSwapchainImagesKHR(logicalDevice.vkDevice, vkSwapchain.handle, cnt, imgs)
Expand All @@ -88,7 +88,7 @@ class SwapChain(val backend: RenderBackendVk) : VkResource() {
}
}

renderPass = OnScreenRenderPass(this@SwapChain)
renderPass = OnScreenRenderPass(this@Swapchain)

val (cImage, cImageView) = createColorResources()
colorImage = cImage.also { addDependingResource(it) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ package de.fabmax.kool.pipeline.backend.vk

import de.fabmax.kool.math.getNumMipLevels
import de.fabmax.kool.pipeline.*
import de.fabmax.kool.pipeline.backend.vk.util.vkBytesPerPx
import de.fabmax.kool.pipeline.backend.vk.util.vkFormat
import de.fabmax.kool.util.Uint8Buffer
import de.fabmax.kool.util.logE
import de.fabmax.kool.util.logW
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package de.fabmax.kool.pipeline.backend.vk

import de.fabmax.kool.pipeline.*
import de.fabmax.kool.pipeline.backend.vk.util.vkFormat
import de.fabmax.kool.platform.Lwjgl3Context
import de.fabmax.kool.util.BaseReleasable
import de.fabmax.kool.util.launchDelayed
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package de.fabmax.kool.pipeline.backend.vk

import de.fabmax.kool.pipeline.*
import de.fabmax.kool.pipeline.backend.vk.util.vkFormat
import de.fabmax.kool.platform.Lwjgl3Context
import de.fabmax.kool.util.BaseReleasable
import de.fabmax.kool.util.launchDelayed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ interface VkScene {

fun onLoad(sys: VkSystem)

fun onSwapChainCreated(swapChain: SwapChain)
fun onSwapChainCreated(swapChain: Swapchain)

fun onDrawFrame(swapChain: SwapChain, imageIndex: Int, fence: LongBuffer, waitSema: LongBuffer, signalSema: LongBuffer)
fun onDrawFrame(swapChain: Swapchain, imageIndex: Int, fence: LongBuffer, waitSema: LongBuffer, signalSema: LongBuffer)

fun onDestroy(sys: VkSystem)

Expand Down
Loading

0 comments on commit ab1c577

Please sign in to comment.