diff --git a/.github/workflows/ci-examples.sh b/.github/workflows/ci-examples.sh index 7372d70a7a..157f8c3041 100755 --- a/.github/workflows/ci-examples.sh +++ b/.github/workflows/ci-examples.sh @@ -170,6 +170,9 @@ function run_sample { pushd "$bin_dir" 1>/dev/null 2>&1 if [[ ! "$dry_run" = true ]]; then ./"$sample" "${args[@]}" || result=$? + if [[ -f ./"log-$sample.txt" ]]; then + cat ./"log-$sample.txt" + fi fi if [[ $result -eq 0 ]]; then summary=("${summary[@]}" " success") diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index e46f41e7a9..759d99994e 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -1,4 +1,6 @@ function(example dir) + cmake_parse_arguments(ARG "WIN32_EXECUTABLE" "" "" ${ARGN}) + set(debug_dir ${CMAKE_CURRENT_BINARY_DIR}/${dir}) file( @@ -30,6 +32,22 @@ function(example dir) ) endif() + # Libraries providing a main function that prints stack traces on exceptions + if(CMAKE_SYSTEM_NAME MATCHES "Windows") + # On Windows we have two different versions: main for "console applications" and + # WinMain for normal Windows applications. + if(${ARG_WIN32_EXECUTABLE}) + set(main_wrapper_libraries example-winmain) + else() + set(main_wrapper_libraries example-main) + endif() + # Add stack printing support + set(main_wrapper_libraries ${main_wrapper_libraries} stacktrace-windows) + set(main_wrapper_libraries ${main_wrapper_libraries} dbghelp.lib) + else() + set(main_wrapper_libraries example-main) + endif() + slang_add_target( ${dir} EXECUTABLE @@ -42,7 +60,9 @@ function(example dir) gfx-util platform $<$:CUDA::cuda_driver> + ${main_wrapper_libraries} EXTRA_COMPILE_DEFINITIONS_PRIVATE + SLANG_EXAMPLE_NAME=${dir} $<$:SLANG_ENABLE_XLIB> REQUIRED_BY all-examples OPTIONAL_REQUIRES ${copy_assets_target} copy-prebuilt-binaries @@ -68,6 +88,9 @@ if(SLANG_ENABLE_EXAMPLES) $<$:CUDA::cuda_driver> FOLDER examples ) + slang_add_target(example-main STATIC FOLDER examples) + slang_add_target(example-winmain STATIC FOLDER examples EXCLUDE_FROM_ALL) + slang_add_target(stacktrace-windows STATIC FOLDER examples EXCLUDE_FROM_ALL) add_custom_target( all-examples diff --git a/examples/autodiff-texture/main.cpp b/examples/autodiff-texture/main.cpp index d0c35d003f..d99f9f341f 100644 --- a/examples/autodiff-texture/main.cpp +++ b/examples/autodiff-texture/main.cpp @@ -823,4 +823,4 @@ struct AutoDiffTexture : public WindowedAppBase } }; -PLATFORM_UI_MAIN(innerMain) +EXAMPLE_MAIN(innerMain); diff --git a/examples/cpu-com-example/main.cpp b/examples/cpu-com-example/main.cpp index 382b3cacd0..6c67215b46 100644 --- a/examples/cpu-com-example/main.cpp +++ b/examples/cpu-com-example/main.cpp @@ -175,7 +175,7 @@ static SlangResult _innerMain(int argc, char** argv) return SLANG_OK; } -int main(int argc, char** argv) +int exampleMain(int argc, char** argv) { return SLANG_SUCCEEDED(_innerMain(argc, argv)) ? 0 : -1; } diff --git a/examples/cpu-hello-world/main.cpp b/examples/cpu-hello-world/main.cpp index 60a24fa8c1..76ca5af1de 100644 --- a/examples/cpu-hello-world/main.cpp +++ b/examples/cpu-hello-world/main.cpp @@ -217,7 +217,7 @@ static SlangResult _innerMain(int argc, char** argv) return SLANG_OK; } -int main(int argc, char** argv) +int exampleMain(int argc, char** argv) { return SLANG_SUCCEEDED(_innerMain(argc, argv)) ? 0 : -1; } diff --git a/examples/example-base/example-base.h b/examples/example-base/example-base.h index 6988d613be..9aabac8d44 100644 --- a/examples/example-base/example-base.h +++ b/examples/example-base/example-base.h @@ -10,6 +10,19 @@ void _Win32OutputDebugString(const char* str); #endif +#define SLANG_STRINGIFY(x) #x +#define SLANG_EXPAND_STRINGIFY(x) SLANG_STRINGIFY(x) + +#ifdef _WIN32 +#define EXAMPLE_MAIN(innerMain) \ + extern const char* const g_logFileName = \ + "log-" SLANG_EXPAND_STRINGIFY(SLANG_EXAMPLE_NAME) ".txt"; \ + PLATFORM_UI_MAIN(innerMain); + +#else +#define EXAMPLE_MAIN(innerMain) PLATFORM_UI_MAIN(innerMain) +#endif // _WIN32 + struct WindowedAppBase : public TestBase { protected: diff --git a/examples/example-main/main.cpp b/examples/example-main/main.cpp new file mode 100644 index 0000000000..46ffc7278d --- /dev/null +++ b/examples/example-main/main.cpp @@ -0,0 +1,32 @@ +#include "../stacktrace-windows/common.h" + +#include +#include + +extern int exampleMain(int argc, char** argv); + +#if defined(_WIN32) + +#include + +int main(int argc, char** argv) +{ + __try + { + return exampleMain(argc, argv); + } + __except (exceptionFilter(stdout, GetExceptionInformation())) + { + ::exit(1); + } +} + +#else // defined(_WIN32) + +int main(int argc, char** argv) +{ + // TODO: Catch exception and print stack trace also on non-Windows platforms. + return exampleMain(argc, argv); +} + +#endif diff --git a/examples/example-winmain/main.cpp b/examples/example-winmain/main.cpp new file mode 100644 index 0000000000..8094e7fc43 --- /dev/null +++ b/examples/example-winmain/main.cpp @@ -0,0 +1,28 @@ +#include "../stacktrace-windows/common.h" + +#include +#include +#include + +extern int exampleMain(int argc, char** argv); +extern const char* const g_logFileName; + +int WinMain( + HINSTANCE /* instance */, + HINSTANCE /* prevInstance */, + LPSTR /* commandLine */, + int /*showCommand*/) + +{ + FILE* logFile = fopen(g_logFileName, "w"); + __try + { + int argc = 0; + char** argv = nullptr; + return exampleMain(argc, argv); + } + __except (exceptionFilter(logFile, GetExceptionInformation())) + { + ::exit(1); + } +} diff --git a/examples/gpu-printing/main.cpp b/examples/gpu-printing/main.cpp index bbc300dba4..27a77a82b5 100644 --- a/examples/gpu-printing/main.cpp +++ b/examples/gpu-printing/main.cpp @@ -152,7 +152,7 @@ struct ExampleProgram : public TestBase } }; -int main(int argc, char* argv[]) +int exampleMain(int argc, char** argv) { ExampleProgram app; if (SLANG_FAILED(app.execute(argc, argv))) diff --git a/examples/hello-world/main.cpp b/examples/hello-world/main.cpp index 7e84d83e5c..fbf67569dd 100644 --- a/examples/hello-world/main.cpp +++ b/examples/hello-world/main.cpp @@ -66,7 +66,8 @@ struct HelloWorldExample : public TestBase ~HelloWorldExample(); }; -int main(int argc, char* argv[]) + +int exampleMain(int argc, char** argv) { initDebugCallback(); HelloWorldExample example; @@ -80,7 +81,11 @@ int main(int argc, char* argv[]) int HelloWorldExample::run() { - RETURN_ON_FAIL(initVulkanInstanceAndDevice()); + // If VK failed to initialize, skip running but return success anyway. + // This allows our automated testing to distinguish between essential failures and the + // case where the application is just not supported. + if (int result = initVulkanInstanceAndDevice()) + return (vkAPI.device == VK_NULL_HANDLE) ? 0 : result; RETURN_ON_FAIL(createComputePipelineFromShader()); RETURN_ON_FAIL(createInOutBuffers()); RETURN_ON_FAIL(dispatchCompute()); @@ -511,6 +516,9 @@ int HelloWorldExample::printComputeResults() HelloWorldExample::~HelloWorldExample() { + if (vkAPI.device == VK_NULL_HANDLE) + return; + vkAPI.vkDestroyPipeline(vkAPI.device, pipeline, nullptr); for (int i = 0; i < 3; i++) { diff --git a/examples/model-viewer/main.cpp b/examples/model-viewer/main.cpp index ecca818f18..8bbc8ec88c 100644 --- a/examples/model-viewer/main.cpp +++ b/examples/model-viewer/main.cpp @@ -969,4 +969,4 @@ struct ModelViewer : WindowedAppBase // This macro instantiates an appropriate main function to // run the application defined above. -PLATFORM_UI_MAIN(innerMain) +EXAMPLE_MAIN(innerMain); diff --git a/examples/nv-aftermath-example/main.cpp b/examples/nv-aftermath-example/main.cpp index 9d85f1ff4f..ed6db43a2b 100644 --- a/examples/nv-aftermath-example/main.cpp +++ b/examples/nv-aftermath-example/main.cpp @@ -599,4 +599,4 @@ void AftermathCrashExample::renderFrame(int frameBufferIndex) // This macro instantiates an appropriate main function to // run the application defined above. -PLATFORM_UI_MAIN(innerMain) +EXAMPLE_MAIN(innerMain) diff --git a/examples/platform-test/main.cpp b/examples/platform-test/main.cpp index 159e26c553..865e4eab79 100644 --- a/examples/platform-test/main.cpp +++ b/examples/platform-test/main.cpp @@ -122,4 +122,4 @@ struct PlatformTest : public WindowedAppBase // This macro instantiates an appropriate main function to // run the application defined above. -PLATFORM_UI_MAIN(innerMain) +EXAMPLE_MAIN(innerMain); diff --git a/examples/ray-tracing-pipeline/main.cpp b/examples/ray-tracing-pipeline/main.cpp index 0fbb857f02..a3d468db1f 100644 --- a/examples/ray-tracing-pipeline/main.cpp +++ b/examples/ray-tracing-pipeline/main.cpp @@ -382,6 +382,8 @@ struct RayTracing : public WindowedAppBase asDraftBufferDesc.sizeInBytes = (size_t)accelerationStructurePrebuildInfo.resultDataMaxSize; ComPtr draftBuffer = gDevice->createBufferResource(asDraftBufferDesc); + if (!draftBuffer) + return SLANG_FAIL; IBufferResource::Desc scratchBufferDesc; scratchBufferDesc.type = IResource::Type::Buffer; scratchBufferDesc.defaultState = ResourceState::UnorderedAccess; @@ -389,6 +391,8 @@ struct RayTracing : public WindowedAppBase (size_t)accelerationStructurePrebuildInfo.scratchDataSize; ComPtr scratchBuffer = gDevice->createBufferResource(scratchBufferDesc); + if (!scratchBuffer) + return SLANG_FAIL; // Build acceleration structure. ComPtr compactedSizeQuery; @@ -708,4 +712,4 @@ struct RayTracing : public WindowedAppBase // This macro instantiates an appropriate main function to // run the application defined above. -PLATFORM_UI_MAIN(innerMain) +EXAMPLE_MAIN(innerMain); diff --git a/examples/ray-tracing/main.cpp b/examples/ray-tracing/main.cpp index 11529a7538..6a0abf8b48 100644 --- a/examples/ray-tracing/main.cpp +++ b/examples/ray-tracing/main.cpp @@ -373,12 +373,16 @@ struct RayTracing : public WindowedAppBase asDraftBufferDesc.defaultState = ResourceState::AccelerationStructure; asDraftBufferDesc.sizeInBytes = accelerationStructurePrebuildInfo.resultDataMaxSize; ComPtr draftBuffer = gDevice->createBufferResource(asDraftBufferDesc); + if (!draftBuffer) + return SLANG_FAIL; IBufferResource::Desc scratchBufferDesc; scratchBufferDesc.type = IResource::Type::Buffer; scratchBufferDesc.defaultState = ResourceState::UnorderedAccess; scratchBufferDesc.sizeInBytes = accelerationStructurePrebuildInfo.scratchDataSize; ComPtr scratchBuffer = gDevice->createBufferResource(scratchBufferDesc); + if (!scratchBuffer) + return SLANG_FAIL; // Build acceleration structure. ComPtr compactedSizeQuery; @@ -672,4 +676,4 @@ struct RayTracing : public WindowedAppBase // This macro instantiates an appropriate main function to // run the application defined above. -PLATFORM_UI_MAIN(innerMain) +EXAMPLE_MAIN(innerMain); diff --git a/examples/reflection-api/main.cpp b/examples/reflection-api/main.cpp index 5c157b7976..c072c641b9 100644 --- a/examples/reflection-api/main.cpp +++ b/examples/reflection-api/main.cpp @@ -1469,7 +1469,7 @@ struct ExampleProgram : public TestBase } }; -int main(int argc, char* argv[]) +int exampleMain(int argc, char** argv) { ExampleProgram app; if (SLANG_FAILED(app.execute(argc, argv))) diff --git a/examples/shader-object/main.cpp b/examples/shader-object/main.cpp index f5c02141f2..1010cdcb91 100644 --- a/examples/shader-object/main.cpp +++ b/examples/shader-object/main.cpp @@ -131,7 +131,7 @@ Result loadShaderProgram( } // Main body of the example. -int main(int argc, char* argv[]) +int exampleMain(int argc, char** argv) { testBase.parseOption(argc, argv); diff --git a/examples/shader-toy/main.cpp b/examples/shader-toy/main.cpp index 185d182461..42054beaeb 100644 --- a/examples/shader-toy/main.cpp +++ b/examples/shader-toy/main.cpp @@ -408,4 +408,4 @@ struct ShaderToyApp : public WindowedAppBase // This macro instantiates an appropriate main function to // run the application defined above. -PLATFORM_UI_MAIN(innerMain) +EXAMPLE_MAIN(innerMain); diff --git a/examples/stacktrace-windows/common.cpp b/examples/stacktrace-windows/common.cpp new file mode 100644 index 0000000000..b07f78d0a4 --- /dev/null +++ b/examples/stacktrace-windows/common.cpp @@ -0,0 +1,201 @@ +#include "common.h" + +#include +#include +#include +#include + +// dbghelp.h needs to be included after windows.h +#include + +#define SLANG_EXAMPLE_LOG_ERROR(...) \ + fprintf(file, "error: %s: %d: ", __FILE__, __LINE__); \ + print(file, __VA_ARGS__); \ + fprintf(file, "\n"); + +static void print(FILE* /* file */) {} +static void print(FILE* file, unsigned int n) +{ + fprintf(file, "%u", n); +} + + +static bool getModuleFileNameAtAddress(FILE* file, DWORD64 const address, std::string& fileName) +{ + HMODULE module = NULL; + { + BOOL result = GetModuleHandleEx( + GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT, + (LPCTSTR)address, + &module); + if (result == 0) + { + SLANG_EXAMPLE_LOG_ERROR(GetLastError()); + return false; + } + if (module == NULL) + { + SLANG_EXAMPLE_LOG_ERROR(); + return false; + } + } + + std::vector buffer(1U << 8U); + uint32_t constexpr maxBufferSize = 1U << 20; + while (buffer.size() < maxBufferSize) + { + DWORD result = GetModuleFileNameA(module, buffer.data(), buffer.size()); + if (result == 0) + { + SLANG_EXAMPLE_LOG_ERROR(GetLastError()); + return false; + } + else if (result == ERROR_INSUFFICIENT_BUFFER) + { + buffer.resize(buffer.size() << 1U); + } + else + { + break; + } + } + if (buffer.size() == maxBufferSize) + { + SLANG_EXAMPLE_LOG_ERROR(); + return false; + } + + fileName = std::string(buffer.data(), buffer.data() + buffer.size()); + return true; +} + +// NOTE: This function is not thread-safe, due to usage of StackWalk64 and static buffers. +static bool printStack(FILE* file, HANDLE process, HANDLE thread, CONTEXT const& context) +{ +#if defined(_M_AMD64) + DWORD constexpr machineType = IMAGE_FILE_MACHINE_AMD64; +#else +#error Unsupported machine type +#endif + + static char symbolBuffer[sizeof(SYMBOL_INFO) + MAX_SYM_NAME * sizeof(TCHAR)]; + + // StackWalk64 may modify the context record + CONTEXT contextCopy; + memcpy(&contextCopy, &context, sizeof(CONTEXT)); + + STACKFRAME64 frame = {}; + constexpr uint32_t maxFrameCount = 1U << 10; + uint32_t frameIndex = 0U; + while (frameIndex < maxFrameCount) + { + // Use the default routine + PREAD_PROCESS_MEMORY_ROUTINE64 readMemoryRoutine = NULL; + // Not sure what this is for, but documentation says most callers can pass NULL + PTRANSLATE_ADDRESS_ROUTINE64 translateAddressRoutine = NULL; + { + BOOL result = StackWalk64( + machineType, + process, + thread, + &frame, + &contextCopy, + readMemoryRoutine, + SymFunctionTableAccess64, + SymGetModuleBase64, + translateAddressRoutine); + if (result == FALSE) + break; + } + + PSYMBOL_INFO maybeSymbol = (PSYMBOL_INFO)symbolBuffer; + { + maybeSymbol->SizeOfStruct = sizeof(SYMBOL_INFO); + maybeSymbol->MaxNameLen = MAX_SYM_NAME; + DWORD64 address = frame.AddrPC.Offset; + // Not required, we want to look up the symbol exactly at the address + PDWORD64 displacement = NULL; + BOOL result = SymFromAddr(process, address, displacement, maybeSymbol); + if (result == FALSE) + { + SLANG_EXAMPLE_LOG_ERROR(GetLastError()); + maybeSymbol = NULL; + } + } + + fprintf(file, "%u", frameIndex); + + std::string moduleFileName; + if (getModuleFileNameAtAddress(file, frame.AddrPC.Offset, moduleFileName)) + fprintf(file, ": %s", moduleFileName.c_str()); + + if (maybeSymbol) + { + PSYMBOL_INFO& symbol = maybeSymbol; + + IMAGEHLP_LINE64 line = {}; + line.SizeOfStruct = sizeof(IMAGEHLP_LINE64); + + DWORD displacement; + if (SymGetLineFromAddr64(process, frame.AddrPC.Offset, &displacement, &line)) + { + fprintf(file, ": %s: %s: %lu", symbol->Name, line.FileName, line.LineNumber); + } + else + { + fprintf(file, ": %s", symbol->Name); + } + + fprintf(file, ": 0x%.16" PRIXPTR, symbol->Address); + } + fprintf(file, "\n"); + + frameIndex++; + } + + return frameIndex < maxFrameCount; +} + +int exceptionFilter(FILE* logFile, _EXCEPTION_POINTERS* exception) +{ + FILE* file = logFile ? logFile : stdout; + fprintf( + file, + "error: Exception 0x%x occurred. Stack trace:\n", + exception->ExceptionRecord->ExceptionCode); + + HANDLE process = GetCurrentProcess(); + HANDLE thread = GetCurrentThread(); + + bool symbolsLoaded = false; + { + // The default search paths should suffice + PCSTR symbolFileSearchPath = NULL; + BOOL loadSymbolsOfLoadedModules = TRUE; + BOOL result = SymInitialize(process, symbolFileSearchPath, loadSymbolsOfLoadedModules); + if (result == FALSE) + { + fprintf(file, "warning: Failed to load symbols\n"); + } + else + { + symbolsLoaded = true; + } + } + + if (!printStack(file, process, thread, *exception->ContextRecord)) + { + fprintf(file, "warning: Failed to print complete stack trace!\n"); + } + + if (symbolsLoaded) + { + BOOL result = SymCleanup(process); + if (result == FALSE) + { + SLANG_EXAMPLE_LOG_ERROR(GetLastError()); + } + } + + return EXCEPTION_EXECUTE_HANDLER; +} diff --git a/examples/stacktrace-windows/common.h b/examples/stacktrace-windows/common.h new file mode 100644 index 0000000000..0f375c4314 --- /dev/null +++ b/examples/stacktrace-windows/common.h @@ -0,0 +1,4 @@ +#pragma once +#include + +int exceptionFilter(FILE* logFile, struct _EXCEPTION_POINTERS* exception); diff --git a/examples/triangle/main.cpp b/examples/triangle/main.cpp index f757b59c70..6fd36f72d7 100644 --- a/examples/triangle/main.cpp +++ b/examples/triangle/main.cpp @@ -405,4 +405,4 @@ struct HelloWorld : public WindowedAppBase // This macro instantiates an appropriate main function to // run the application defined above. -PLATFORM_UI_MAIN(innerMain) +EXAMPLE_MAIN(innerMain); diff --git a/examples/wgpu-slang-wasm/README.md b/examples/wgpu-slang-wasm/README.md new file mode 100644 index 0000000000..f0dd5d6a64 --- /dev/null +++ b/examples/wgpu-slang-wasm/README.md @@ -0,0 +1,16 @@ +# Simple WebGPU example using Slang WebAssembly library + +## Description + +This is a simple example showing how WebGPU applications can use the slang-wasm library to compile slang shaders at runtime to WGSL. +The resulting application shows a green triangle rendered on a black background. + +## Instructions + +Follow the WebAssembly build instructions in `docs/building.md` to produce `slang-wasm.js` and `slang-wasm.wasm`, and place these files in this directory. + +Start a web server, for example by running the following command in this directory: + + $ python -m http.server + +Finally, visit `http://localhost:8000/` to see the application running in your browser. \ No newline at end of file diff --git a/examples/wgpu-slang-wasm/example.js b/examples/wgpu-slang-wasm/example.js new file mode 100644 index 0000000000..9e554bb447 --- /dev/null +++ b/examples/wgpu-slang-wasm/example.js @@ -0,0 +1,158 @@ +"use strict"; + +let Example = { + initialize: async function (slang, canvas) { + async function render(shaders) { + if (!navigator.gpu) { + throw new Error("WebGPU not supported on this browser."); + } + const adapter = await navigator.gpu.requestAdapter(); + if (!adapter) { + throw new Error("No appropriate GPUAdapter found."); + } + const device = await adapter.requestDevice(); + const context = canvas.getContext("webgpu"); + const canvasFormat = navigator.gpu.getPreferredCanvasFormat(); + context.configure({ + device: device, + format: canvasFormat, + }); + + const vertexBufferLayout = { + arrayStride: 8, + attributes: [{ + format: "float32x2", + offset: 0, + shaderLocation: 0, + }], + }; + + const pipeline = device.createRenderPipeline({ + label: "Pipeline", + layout: "auto", + vertex: { + module: device.createShaderModule({ + label: "Vertex shader module", + code: shaders.vertex + }), + entryPoint: "vertexMain", + buffers: [vertexBufferLayout] + }, + fragment: { + module: device.createShaderModule({ + label: "Fragment shader module", + code: shaders.fragment + }), + entryPoint: "fragmentMain", + targets: [{ + format: canvasFormat + }] + } + }); + + const vertices = new Float32Array([ + 0.0, -0.8, + +0.8, +0.8, + -0.8, +0.8, + ]); + const vertexBuffer = device.createBuffer({ + label: "Triangle vertices", + size: vertices.byteLength, + usage: GPUBufferUsage.VERTEX | GPUBufferUsage.COPY_DST, + }); + const bufferOffset = 0; + device.queue.writeBuffer(vertexBuffer, bufferOffset, vertices); + + const encoder = device.createCommandEncoder(); + const pass = encoder.beginRenderPass({ + colorAttachments: [{ + view: context.getCurrentTexture().createView(), + loadOp: "clear", + clearValue: { r: 0, g: 0, b: 0, a: 1 }, + storeOp: "store", + }] + }); + pass.setPipeline(pipeline); + const vertexBufferSlot = 0; + pass.setVertexBuffer(vertexBufferSlot, vertexBuffer); + pass.draw(vertices.length / 2); + pass.end(); + const commandBuffer = encoder.finish(); + device.queue.submit([commandBuffer]); + } + + const slangCode = await fetch("shader.slang").then(r => r.text()); + + var wasmCompileTarget = null; + var compileTargetMap = slang.module.getCompileTargets(); + for (var i = 0; i < compileTargetMap.length; i++) { + var target = compileTargetMap[i]; + if(target.name == "WGSL") { + wasmCompileTarget = target.value; + } + } + if (wasmCompileTarget === null) { + throw new Error("Slang/WASM module doesn't support WGSL compile target."); + } + + var slangSession = slang.globalSession.createSession(wasmCompileTarget); + if (!slangSession) { + throw new Error("Failed to create global Slang session."); + } + + var wgslShaders = null; + try { + var module = slangSession.loadModuleFromSource( + slangCode, "shader", '/shader.slang' + ); + var vertexEntryPoint = module.findAndCheckEntryPoint( + "vertexMain", slang.constants.STAGE_VERTEX + ); + var fragmentEntryPoint = module.findAndCheckEntryPoint( + "fragmentMain", slang.constants.STAGE_FRAGMENT + ); + var linkedProgram = slangSession.createCompositeComponentType([ + module, vertexEntryPoint, fragmentEntryPoint + ]).link(); + wgslShaders = { + vertex: linkedProgram.getEntryPointCode( + 0 /* entryPointIndex */, 0 /* targetIndex */ + ), + fragment: linkedProgram.getEntryPointCode( + 1 /* entryPointIndex */, 0 /* targetIndex */ + ), + }; + } finally { + if (slangSession) { + slangSession.delete(); + } + } + + if (!wgslShaders) { + throw new Error("Failed to compile WGSL shaders."); + } + + render(wgslShaders); + } +} + +var Module = { + onRuntimeInitialized: function() { + const canvas = document.querySelector("canvas"); + + var globalSlangSession = Module.createGlobalSession(); + if (!globalSlangSession) { + throw new Error("Failed to create global Slang session."); + } + + const slang = { + module: Module, + globalSession: globalSlangSession, + constants: { + STAGE_VERTEX: 1, + STAGE_FRAGMENT: 5, + }, + }; + Example.initialize(slang, canvas); + }, +}; diff --git a/examples/wgpu-slang-wasm/index.html b/examples/wgpu-slang-wasm/index.html new file mode 100644 index 0000000000..b03d8366b9 --- /dev/null +++ b/examples/wgpu-slang-wasm/index.html @@ -0,0 +1,12 @@ + + + WebGPU Triangle using Slang WASM + + + +
+ +
+ + + diff --git a/examples/wgpu-slang-wasm/shader.slang b/examples/wgpu-slang-wasm/shader.slang new file mode 100644 index 0000000000..0721a1fffb --- /dev/null +++ b/examples/wgpu-slang-wasm/shader.slang @@ -0,0 +1,28 @@ +struct VertexStageInput +{ + float4 position : POSITION0; +}; + +struct VertexStageOutput +{ + float4 positionClipSpace : SV_POSITION; +}; + +struct FragmentStageOutput +{ + float4 color : SV_TARGET; +}; + +VertexStageOutput vertexMain(VertexStageInput input) : SV_Position +{ + VertexStageOutput output; + output.positionClipSpace = float4(input.position.xy, 1); + return output; +} + +FragmentStageOutput fragmentMain() : SV_Target +{ + FragmentStageOutput output; + output.color = float4(0, 1, 0, 1); + return output; +} diff --git a/source/compiler-core/slang-downstream-compiler.h b/source/compiler-core/slang-downstream-compiler.h index c0cc868d10..82aaef1076 100644 --- a/source/compiler-core/slang-downstream-compiler.h +++ b/source/compiler-core/slang-downstream-compiler.h @@ -337,6 +337,9 @@ class IDownstreamCompiler : public ICastable /// Validate and return the result virtual SLANG_NO_THROW SlangResult SLANG_MCALL validate(const uint32_t* contents, int contentsSize) = 0; + /// Disassemble and print to stdout + virtual SLANG_NO_THROW SlangResult SLANG_MCALL + disassemble(const uint32_t* contents, int contentsSize) = 0; /// True if underlying compiler uses file system to communicate source virtual SLANG_NO_THROW bool SLANG_MCALL isFileBased() = 0; @@ -374,6 +377,13 @@ class DownstreamCompilerBase : public ComBaseObject, public IDownstreamCompiler SLANG_UNUSED(contentsSize); return SLANG_FAIL; } + virtual SLANG_NO_THROW SlangResult SLANG_MCALL + disassemble(const uint32_t* contents, int contentsSize) SLANG_OVERRIDE + { + SLANG_UNUSED(contents); + SLANG_UNUSED(contentsSize); + return SLANG_FAIL; + } DownstreamCompilerBase(const Desc& desc) : m_desc(desc) diff --git a/source/compiler-core/slang-glslang-compiler.cpp b/source/compiler-core/slang-glslang-compiler.cpp index 5550ac8ade..b619f468f3 100644 --- a/source/compiler-core/slang-glslang-compiler.cpp +++ b/source/compiler-core/slang-glslang-compiler.cpp @@ -47,6 +47,8 @@ class GlslangDownstreamCompiler : public DownstreamCompilerBase SLANG_OVERRIDE; virtual SLANG_NO_THROW SlangResult SLANG_MCALL validate(const uint32_t* contents, int contentsSize) SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL + disassemble(const uint32_t* contents, int contentsSize) SLANG_OVERRIDE; /// Must be called before use SlangResult init(ISlangSharedLibrary* library); @@ -63,6 +65,7 @@ class GlslangDownstreamCompiler : public DownstreamCompilerBase glslang_CompileFunc_1_1 m_compile_1_1 = nullptr; glslang_CompileFunc_1_2 m_compile_1_2 = nullptr; glslang_ValidateSPIRVFunc m_validate = nullptr; + glslang_DisassembleSPIRVFunc m_disassemble = nullptr; ComPtr m_sharedLibrary; @@ -75,7 +78,8 @@ SlangResult GlslangDownstreamCompiler::init(ISlangSharedLibrary* library) m_compile_1_1 = (glslang_CompileFunc_1_1)library->findFuncByName("glslang_compile_1_1"); m_compile_1_2 = (glslang_CompileFunc_1_2)library->findFuncByName("glslang_compile_1_2"); m_validate = (glslang_ValidateSPIRVFunc)library->findFuncByName("glslang_validateSPIRV"); - + m_disassemble = + (glslang_DisassembleSPIRVFunc)library->findFuncByName("glslang_disassembleSPIRV"); if (m_compile_1_0 == nullptr && m_compile_1_1 == nullptr && m_compile_1_2 == nullptr) { @@ -305,6 +309,20 @@ SlangResult GlslangDownstreamCompiler::validate(const uint32_t* contents, int co return SLANG_FAIL; } +SlangResult GlslangDownstreamCompiler::disassemble(const uint32_t* contents, int contentsSize) +{ + if (m_disassemble == nullptr) + { + return SLANG_FAIL; + } + + if (m_disassemble(contents, contentsSize)) + { + return SLANG_OK; + } + return SLANG_FAIL; +} + bool GlslangDownstreamCompiler::canConvert(const ArtifactDesc& from, const ArtifactDesc& to) { // Can only disassemble blobs that are SPIR-V diff --git a/source/compiler-core/slang-json-rpc-connection.cpp b/source/compiler-core/slang-json-rpc-connection.cpp index 79ffc95831..fbbd23d75a 100644 --- a/source/compiler-core/slang-json-rpc-connection.cpp +++ b/source/compiler-core/slang-json-rpc-connection.cpp @@ -273,6 +273,10 @@ SlangResult JSONRPCConnection::sendCall( SlangResult JSONRPCConnection::waitForResult(Int timeOutInMs) { + // Invalidate m_jsonRoot before waitForResult, because when waitForResult fail, + // we don't want to use the result from the previous read. + m_jsonRoot.reset(); + SLANG_RETURN_ON_FAIL(m_connection->waitForResult(timeOutInMs)); return tryReadMessage(); } diff --git a/source/slang-glslang/slang-glslang.cpp b/source/slang-glslang/slang-glslang.cpp index 4abcada624..ca45be05b0 100644 --- a/source/slang-glslang/slang-glslang.cpp +++ b/source/slang-glslang/slang-glslang.cpp @@ -182,6 +182,37 @@ extern "C" return tools.Validate(contents, contentsSize, options); } +// Disassemble the given SPIRV-ASM instructions. +extern "C" +#ifdef _MSC_VER + _declspec(dllexport) +#else + __attribute__((__visibility__("default"))) +#endif + bool glslang_disassembleSPIRV(const uint32_t* contents, int contentsSize) +{ + static const auto kDefaultEnvironment = SPV_ENV_UNIVERSAL_1_5; + + uint32_t options = SPV_BINARY_TO_TEXT_OPTION_NONE; + options |= SPV_BINARY_TO_TEXT_OPTION_COMMENT; + options |= SPV_BINARY_TO_TEXT_OPTION_PRINT; + options |= SPV_BINARY_TO_TEXT_OPTION_COLOR; + + spv_diagnostic diagnostic = nullptr; + spv_context context = spvContextCreate(kDefaultEnvironment); + spv_result_t error = + spvBinaryToText(context, contents, contentsSize, options, nullptr, &diagnostic); + spvContextDestroy(context); + if (error) + { + spvDiagnosticPrint(diagnostic); + spvDiagnosticDestroy(diagnostic); + return false; + } + + return true; +} + // Apply the SPIRV-Tools optimizer to generated SPIR-V based on the desired optimization level // TODO: add flag for optimizing SPIR-V size as well static void glslang_optimizeSPIRV( diff --git a/source/slang-glslang/slang-glslang.h b/source/slang-glslang/slang-glslang.h index cfeed975ae..7b955c5af7 100644 --- a/source/slang-glslang/slang-glslang.h +++ b/source/slang-glslang/slang-glslang.h @@ -156,5 +156,6 @@ typedef int (*glslang_CompileFunc_1_0)(glslang_CompileRequest_1_0* request); typedef int (*glslang_CompileFunc_1_1)(glslang_CompileRequest_1_1* request); typedef int (*glslang_CompileFunc_1_2)(glslang_CompileRequest_1_2* request); typedef bool (*glslang_ValidateSPIRVFunc)(const uint32_t* contents, int contentsSize); +typedef bool (*glslang_DisassembleSPIRVFunc)(const uint32_t* contents, int contentsSize); #endif diff --git a/source/slang-llvm/slang-llvm.cpp b/source/slang-llvm/slang-llvm.cpp index 990f99f747..86565a4852 100644 --- a/source/slang-llvm/slang-llvm.cpp +++ b/source/slang-llvm/slang-llvm.cpp @@ -140,6 +140,13 @@ class LLVMDownstreamCompiler : public ComBaseObject, public IDownstreamCompiler SLANG_UNUSED(contentsSize); return SLANG_FAIL; } + virtual SLANG_NO_THROW SlangResult SLANG_MCALL + disassemble(const uint32_t* contents, int contentsSize) SLANG_OVERRIDE + { + SLANG_UNUSED(contents); + SLANG_UNUSED(contentsSize); + return SLANG_FAIL; + } LLVMDownstreamCompiler() : m_desc( diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index e84f03da9f..bf53da88bc 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -1,3 +1,5 @@ +//public module core; + // Slang `core` library // Aliases for base types diff --git a/source/slang/glsl.meta.slang b/source/slang/glsl.meta.slang index 2ff71a74e3..ba26b5d84a 100644 --- a/source/slang/glsl.meta.slang +++ b/source/slang/glsl.meta.slang @@ -9080,7 +9080,7 @@ public vector fwidthCoarse(vector p) [__NoSideEffect] [__GLSLRequireShaderInputParameter(0)] [require(glsl_spirv, fragmentprocessing)] -public float interpolateAtCentroid(__ref float interpolant) +public float interpolateAtCentroid(__constref float interpolant) { __target_switch { @@ -9099,7 +9099,7 @@ __generic [__NoSideEffect] [__GLSLRequireShaderInputParameter(0)] [require(glsl_spirv, fragmentprocessing)] -public vector interpolateAtCentroid(__ref vector interpolant) +public vector interpolateAtCentroid(__constref vector interpolant) { __target_switch { @@ -9118,7 +9118,7 @@ public vector interpolateAtCentroid(__ref vector interpolant [__NoSideEffect] [__GLSLRequireShaderInputParameter(0)] [require(glsl_spirv, fragmentprocessing)] -public float interpolateAtSample(__ref float interpolant, int sample) +public float interpolateAtSample(__constref float interpolant, int sample) { __target_switch { @@ -9137,7 +9137,7 @@ __generic [__NoSideEffect] [__GLSLRequireShaderInputParameter(0)] [require(glsl_spirv, fragmentprocessing)] -public vector interpolateAtSample(__ref vector interpolant, int sample) +public vector interpolateAtSample(__constref vector interpolant, int sample) { __target_switch { @@ -9156,7 +9156,7 @@ public vector interpolateAtSample(__ref vector interpolant, [__NoSideEffect] [__GLSLRequireShaderInputParameter(0)] [require(glsl_spirv, fragmentprocessing)] -public float interpolateAtOffset(__ref float interpolant, vec2 offset) +public float interpolateAtOffset(__constref float interpolant, vec2 offset) { __target_switch { @@ -9175,7 +9175,7 @@ __generic [__NoSideEffect] [__GLSLRequireShaderInputParameter(0)] [require(glsl_spirv, fragmentprocessing)] -public vector interpolateAtOffset(__ref vector interpolant, vec2 offset) +public vector interpolateAtOffset(__constref vector interpolant, vec2 offset) { __target_switch { diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index a558affeaa..58ab10e241 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -8101,89 +8101,139 @@ RasterizerOrderedStructuredBuffer __getEquivalentStructuredBuffer(Rasteriz // Attribute evaluation +T __EvaluateAttributeAtCentroid(__constref T x) +{ + __target_switch + { + case hlsl: __intrinsic_asm "EvaluateAttributeAtCentroid"; + case glsl: __intrinsic_asm "interpolateAtCentroid"; + } +} + // TODO: The matrix cases of these functions won't actuall work // when compiled to GLSL, since they only support scalar/vector // TODO: Should these be constrains to `__BuiltinFloatingPointType`? // TODO: SPIRV-direct does not support non-floating-point types. +/// Interpolates vertex attribute at centroid position. +/// @param x The vertex attribute to interpolate. +/// @return The interpolated attribute value. +/// @remarks `x` must be a direct reference to a fragment shader varying input. +/// @category interpolation Vertex Interpolation Functions __generic [__readNone] -[require(glsl_spirv, fragmentprocessing)] -T EvaluateAttributeAtCentroid(T x) +[__unsafeForceInlineEarly] +[require(glsl_hlsl_spirv, fragmentprocessing)] +T EvaluateAttributeAtCentroid(__constref T x) { __target_switch { - case glsl: __intrinsic_asm "interpolateAtCentroid"; + case hlsl: + case glsl: + return __EvaluateAttributeAtCentroid(__ResolveVaryingInputRef(x)); case spirv: return spirv_asm { - OpExtInst $$T result glsl450 InterpolateAtCentroid $x + OpCapability InterpolationFunction; + OpExtInst $$T result glsl450 InterpolateAtCentroid $__ResolveVaryingInputRef(x) }; } } __generic [__readNone] -[require(glsl_spirv, fragmentprocessing)] -vector EvaluateAttributeAtCentroid(vector x) +[__unsafeForceInlineEarly] +[require(glsl_hlsl_spirv, fragmentprocessing)] +vector EvaluateAttributeAtCentroid(__constref vector x) { __target_switch { - case glsl: __intrinsic_asm "interpolateAtCentroid"; + case hlsl: + case glsl: + return __EvaluateAttributeAtCentroid(__ResolveVaryingInputRef(x)); case spirv: return spirv_asm { - OpExtInst $$vector result glsl450 InterpolateAtCentroid $x + OpCapability InterpolationFunction; + OpExtInst $$vector result glsl450 InterpolateAtCentroid $__ResolveVaryingInputRef(x) }; } } __generic [__readNone] -[require(glsl_spirv, fragmentprocessing)] -matrix EvaluateAttributeAtCentroid(matrix x) +[__unsafeForceInlineEarly] +[require(glsl_hlsl_spirv, fragmentprocessing)] +matrix EvaluateAttributeAtCentroid(__constref matrix x) { __target_switch { - case glsl: __intrinsic_asm "interpolateAtCentroid"; + case hlsl: + case glsl: + return __EvaluateAttributeAtCentroid(__ResolveVaryingInputRef(x)); default: MATRIX_MAP_UNARY(T, N, M, EvaluateAttributeAtCentroid, x); } } +T __EvaluateAttributeAtSample(__constref T x, uint sampleIndex) +{ + __target_switch + { + case hlsl: __intrinsic_asm "EvaluateAttributeAtSample"; + case glsl: __intrinsic_asm "interpolateAtSample"; + } +} + +/// Interpolates vertex attribute at the current fragment sample position. +/// @param x The vertex attribute to interpolate. +/// @return The interpolated attribute value. +/// @remarks `x` must be a direct reference to a fragment shader varying input. +/// @category interpolation Vertex Interpolation Functions __generic [__readNone] -[require(glsl_spirv, fragmentprocessing)] -T EvaluateAttributeAtSample(T x, uint sampleindex) +[__unsafeForceInlineEarly] +[require(glsl_hlsl_spirv, fragmentprocessing)] +T EvaluateAttributeAtSample(__constref T x, uint sampleindex) { __target_switch { - case glsl: __intrinsic_asm "interpolateAtSample($0, int($1))"; + case hlsl: + case glsl: + return __EvaluateAttributeAtSample(__ResolveVaryingInputRef(x), sampleindex); case spirv: return spirv_asm { - OpExtInst $$T result glsl450 InterpolateAtSample $x $sampleindex + OpCapability InterpolationFunction; + OpExtInst $$T result glsl450 InterpolateAtSample $__ResolveVaryingInputRef(x) $sampleindex }; } } __generic [__readNone] -[require(glsl_spirv, fragmentprocessing)] -vector EvaluateAttributeAtSample(vector x, uint sampleindex) +[__unsafeForceInlineEarly] +[require(glsl_hlsl_spirv, fragmentprocessing)] +vector EvaluateAttributeAtSample(__constref vector x, uint sampleindex) { __target_switch { - case glsl: __intrinsic_asm "interpolateAtSample($0, int($1))"; + case hlsl: + case glsl: + return __EvaluateAttributeAtSample(__ResolveVaryingInputRef(x), sampleindex); case spirv: return spirv_asm { - OpExtInst $$vector result glsl450 InterpolateAtSample $x $sampleindex + OpCapability InterpolationFunction; + OpExtInst $$vector result glsl450 InterpolateAtSample $__ResolveVaryingInputRef(x) $sampleindex }; } } __generic [__readNone] -[require(glsl_spirv, fragmentprocessing)] -matrix EvaluateAttributeAtSample(matrix x, uint sampleindex) +[__unsafeForceInlineEarly] +[require(glsl_hlsl_spirv, fragmentprocessing)] +matrix EvaluateAttributeAtSample(__constref matrix x, uint sampleindex) { __target_switch { - case glsl: __intrinsic_asm "interpolateAtSample($0, int($1))"; + case hlsl: + case glsl: + return __EvaluateAttributeAtSample(__ResolveVaryingInputRef(x), sampleindex); default: matrix result; for(int i = 0; i < N; ++i) @@ -8194,21 +8244,59 @@ matrix EvaluateAttributeAtSample(matrix x, uint sampleindex) } } +T __EvaluateAttributeSnapped(__constref T x, int2 offset) +{ + __target_switch + { + case hlsl: __intrinsic_asm "EvaluateAttributeSnapped"; + case glsl: __intrinsic_asm "EvaluateAttributeSnapped"; + } +} + +/// Interpolates vertex attribute at the specified subpixel offset. +/// @param x The vertex attribute to interpolate. +/// @param offset The subpixel offset. Each component is a 4-bit signed integer in range [-8, 7]. +/// @return The interpolated attribute value. +/// @remarks `x` must be a direct reference to a fragment shader varying input. +/// +/// The valid values of each component of `offset` are: +/// +/// - 1000 = -0.5f (-8 / 16) +/// - 1001 = -0.4375f (-7 / 16) +/// - 1010 = -0.375f (-6 / 16) +/// - 1011 = -0.3125f (-5 / 16) +/// - 1100 = -0.25f (-4 / 16) +/// - 1101 = -0.1875f (-3 / 16) +/// - 1110 = -0.125f (-2 / 16) +/// - 1111 = -0.0625f (-1 / 16) +/// - 0000 = 0.0f ( 0 / 16) +/// - 0001 = 0.0625f ( 1 / 16) +/// - 0010 = 0.125f ( 2 / 16) +/// - 0011 = 0.1875f ( 3 / 16) +/// - 0100 = 0.25f ( 4 / 16) +/// - 0101 = 0.3125f ( 5 / 16) +/// - 0110 = 0.375f ( 6 / 16) +/// - 0111 = 0.4375f ( 7 / 16) +/// @category interpolation Vertex Interpolation Functions __generic [__readNone] -[require(glsl_spirv, fragmentprocessing)] -T EvaluateAttributeSnapped(T x, int2 offset) +[__unsafeForceInlineEarly] +[require(glsl_hlsl_spirv, fragmentprocessing)] +T EvaluateAttributeSnapped(__constref T x, int2 offset) { __target_switch { - case glsl: __intrinsic_asm "interpolateAtOffset($0, vec2($1) / 16.0f)"; + case hlsl: + case glsl: + return __EvaluateAttributeSnapped(__ResolveVaryingInputRef(x), offset); case spirv: { const float2 tmp = float2(16.f, 16.f); return spirv_asm { + OpCapability InterpolationFunction; %foffset:$$float2 = OpConvertSToF $offset; %offsetdiv16:$$float2 = OpFDiv %foffset $tmp; - result:$$T = OpExtInst glsl450 InterpolateAtOffset $x %offsetdiv16 + result:$$T = OpExtInst glsl450 InterpolateAtOffset $__ResolveVaryingInputRef(x) %offsetdiv16 }; } } @@ -8216,19 +8304,23 @@ T EvaluateAttributeSnapped(T x, int2 offset) __generic [__readNone] -[require(glsl_spirv, fragmentprocessing)] -vector EvaluateAttributeSnapped(vector x, int2 offset) +[__unsafeForceInlineEarly] +[require(glsl_hlsl_spirv, fragmentprocessing)] +vector EvaluateAttributeSnapped(__constref vector x, int2 offset) { __target_switch { - case glsl: __intrinsic_asm "interpolateAtOffset($0, vec2($1) / 16.0f)"; + case hlsl: + case glsl: + return __EvaluateAttributeSnapped(__ResolveVaryingInputRef(x), offset); case spirv: { const float2 tmp = float2(16.f, 16.f); return spirv_asm { + OpCapability InterpolationFunction; %foffset:$$float2 = OpConvertSToF $offset; %offsetdiv16:$$float2 = OpFDiv %foffset $tmp; - result:$$vector = OpExtInst glsl450 InterpolateAtOffset $x %offsetdiv16 + result:$$vector = OpExtInst glsl450 InterpolateAtOffset $__ResolveVaryingInputRef(x) %offsetdiv16 }; } } @@ -8236,12 +8328,15 @@ vector EvaluateAttributeSnapped(vector x, int2 offset) __generic [__readNone] -[require(glsl_spirv, fragmentprocessing)] -matrix EvaluateAttributeSnapped(matrix x, int2 offset) +[__unsafeForceInlineEarly] +[require(glsl_hlsl_spirv, fragmentprocessing)] +matrix EvaluateAttributeSnapped(__constref matrix x, int2 offset) { __target_switch { - case glsl: __intrinsic_asm "interpolateAtOffset($0, vec2($1) / 16.0f)"; + case hlsl: + case glsl: + return __EvaluateAttributeSnapped(__ResolveVaryingInputRef(x), offset); default: matrix result; for(int i = 0; i < N; ++i) @@ -9243,8 +9338,16 @@ matrix fwidth(matrix x) } } +__intrinsic_op($(kIROp_ResolveVaryingInputRef)) +Ref __ResolveVaryingInputRef(__constref T attribute); + __intrinsic_op($(kIROp_GetPerVertexInputArray)) -Array __GetPerVertexInputArray(T attribute); +Ref> __GetPerVertexInputArray(__constref T attribute); + +T __GetAttributeAtVertex(__constref T attribute, uint vertexIndex) +{ + __intrinsic_asm "GetAttributeAtVertex"; +} /// Get the value of a vertex attribute at a specific vertex. /// @@ -9265,15 +9368,15 @@ __glsl_extension(GL_EXT_fragment_shader_barycentric) [require(glsl_hlsl_spirv, getattributeatvertex)] [KnownBuiltin("GetAttributeAtVertex")] [__unsafeForceInlineEarly] -T GetAttributeAtVertex(T attribute, uint vertexIndex) +T GetAttributeAtVertex(__constref T attribute, uint vertexIndex) { __target_switch { case hlsl: - __intrinsic_asm "GetAttributeAtVertex"; + return __GetAttributeAtVertex(__ResolveVaryingInputRef(attribute), vertexIndex); case glsl: case spirv: - return __GetPerVertexInputArray(attribute)[vertexIndex]; + return __GetPerVertexInputArray(__ResolveVaryingInputRef(attribute))[vertexIndex]; } } @@ -9294,20 +9397,16 @@ __generic __glsl_version(450) __glsl_extension(GL_EXT_fragment_shader_barycentric) [require(glsl_hlsl_spirv, getattributeatvertex)] -vector GetAttributeAtVertex(vector attribute, uint vertexIndex) +[__unsafeForceInlineEarly] +vector GetAttributeAtVertex(__constref vector attribute, uint vertexIndex) { __target_switch { case hlsl: - __intrinsic_asm "GetAttributeAtVertex"; - case glsl: - __intrinsic_asm "$0[$1]"; + return __GetAttributeAtVertex(__ResolveVaryingInputRef(attribute), vertexIndex); + case glsl: case spirv: - return spirv_asm { - %_ptr_Input_vectorT = OpTypePointer Input $$vector; - %addr = OpAccessChain %_ptr_Input_vectorT $attribute $vertexIndex; - result:$$vector = OpLoad %addr; - }; + return __GetPerVertexInputArray(__ResolveVaryingInputRef(attribute))[vertexIndex]; } } @@ -9328,20 +9427,16 @@ __generic __glsl_version(450) __glsl_extension(GL_EXT_fragment_shader_barycentric) [require(glsl_hlsl_spirv, getattributeatvertex)] -matrix GetAttributeAtVertex(matrix attribute, uint vertexIndex) +[__unsafeForceInlineEarly] +matrix GetAttributeAtVertex(__constref matrix attribute, uint vertexIndex) { __target_switch { case hlsl: - __intrinsic_asm "GetAttributeAtVertex"; - case glsl: - __intrinsic_asm "$0[$1]"; + return __GetAttributeAtVertex(__ResolveVaryingInputRef(attribute), vertexIndex); + case glsl: case spirv: - return spirv_asm { - %_ptr_Input_matrixT = OpTypePointer Input $$matrix; - %addr = OpAccessChain %_ptr_Input_matrixT $attribute $vertexIndex; - result:$$matrix = OpLoad %addr; - }; + return __GetPerVertexInputArray(__ResolveVaryingInputRef(attribute))[vertexIndex]; } } diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h index e8886a59aa..911455f17e 100644 --- a/source/slang/slang-ast-decl.h +++ b/source/slang/slang-ast-decl.h @@ -461,6 +461,8 @@ class ModuleDecl : public NamespaceDeclBase /// `__implementing` etc. bool isInLegacyLanguage = true; + DeclVisibility defaultVisibility = DeclVisibility::Internal; + SLANG_UNREFLECTED /// Map a type to the list of extensions of that type (if any) declared in this module diff --git a/source/slang/slang-ast-dump.cpp b/source/slang/slang-ast-dump.cpp index b77003bff0..85d2d0d9fc 100644 --- a/source/slang/slang-ast-dump.cpp +++ b/source/slang/slang-ast-dump.cpp @@ -460,6 +460,26 @@ struct ASTDumpContext } } + void dump(DeclVisibility vis) + { + switch (vis) + { + case DeclVisibility::Private: + m_writer->emit("private"); + break; + case DeclVisibility::Internal: + m_writer->emit("internal"); + break; + case DeclVisibility::Public: + m_writer->emit("public"); + break; + default: + m_writer->emit(String((int)vis).getUnownedSlice()); + break; + } + } + + void dump(const QualType& qualType) { if (qualType.isLeftValue) diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index ce3f1e64c1..3667a36ba2 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -3186,6 +3186,11 @@ void SemanticsDeclVisitorBase::checkModule(ModuleDecl* moduleDecl) } } + if (moduleDecl->findModifier()) + { + moduleDecl->defaultVisibility = DeclVisibility::Public; + } + // We need/want to visit any `import` declarations before // anything else, to make sure that scoping works. // @@ -12604,8 +12609,10 @@ DeclVisibility getDeclVisibility(Decl* decl) } auto defaultVis = DeclVisibility::Default; if (auto parentModule = getModuleDecl(decl)) - defaultVis = - parentModule->isInLegacyLanguage ? DeclVisibility::Public : DeclVisibility::Internal; + { + defaultVis = parentModule->isInLegacyLanguage ? DeclVisibility::Public + : parentModule->defaultVisibility; + } // Members of other agg type decls will have their default visibility capped to the parents'. if (as(decl)) @@ -12790,10 +12797,6 @@ void diagnoseCapabilityProvenance( auto moduleDecl = getModuleDecl(declToPrint); if (thisModule != moduleDecl) break; - if (moduleDecl && moduleDecl->isInLegacyLanguage) - continue; - if (getDeclVisibility(declToPrint) == DeclVisibility::Public) - break; } if (previousDecl == declToPrint) break; diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 83b668a336..95d5a2a7c0 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -4099,6 +4099,15 @@ Expr* SemanticsVisitor::maybeDereference(Expr* inExpr, CheckBaseContext checkBas elementType = QualType(ptrType->getValueType()); elementType.isLeftValue = true; } + else + { + auto newExpr = maybeOpenRef(expr); + if (newExpr != expr) + { + expr = newExpr; + continue; + } + } if (elementType.type) { auto derefExpr = m_astBuilder->create(); @@ -4108,9 +4117,10 @@ Expr* SemanticsVisitor::maybeDereference(Expr* inExpr, CheckBaseContext checkBas expr = derefExpr; continue; } - // Default case: just use the expression as-is - return expr; + break; } + // Default case: just use the expression as-is + return expr; } Expr* SemanticsVisitor::CheckMatrixSwizzleExpr( diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 9d08d5b73c..821a895bc7 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -2562,11 +2562,7 @@ DIAGNOSTIC( Internal, serialDebugVerificationFailed, "Verification of serial debug information failed.") -DIAGNOSTIC( - 99999, - Internal, - spirvValidationFailed, - "Validation of generated SPIR-V failed. SPIRV generated: \n$0") +DIAGNOSTIC(99999, Internal, spirvValidationFailed, "Validation of generated SPIR-V failed.") DIAGNOSTIC( 99999, diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index dd8b821a16..d4b3cac275 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -1529,6 +1529,30 @@ bool CLikeSourceEmitter::shouldFoldInstIntoUseSites(IRInst* inst) return true; } + if (auto load = as(inst)) + { + // Loads from a constref global param should always be folded. + auto ptrType = load->getPtr()->getDataType(); + if (load->getPtr()->getOp() == kIROp_GlobalParam) + { + if (ptrType->getOp() == kIROp_ConstRefType) + return true; + if (auto ptrTypeBase = as(ptrType)) + { + auto addrSpace = ptrTypeBase->getAddressSpace(); + switch (addrSpace) + { + case Slang::AddressSpace::Uniform: + case Slang::AddressSpace::Input: + case Slang::AddressSpace::BuiltinInput: + return true; + default: + break; + } + } + } + } + // Always hold if inst is a call into an [__alwaysFoldIntoUseSite] function. if (auto call = as(inst)) { @@ -4701,9 +4725,21 @@ void CLikeSourceEmitter::emitGlobalParam(IRGlobalParam* varDecl) auto rawType = varDecl->getDataType(); auto varType = rawType; - if (auto outType = as(varType)) + if (auto ptrType = as(varType)) { - varType = outType->getValueType(); + switch (ptrType->getAddressSpace()) + { + case AddressSpace::Input: + case AddressSpace::Output: + case AddressSpace::BuiltinInput: + case AddressSpace::BuiltinOutput: + varType = ptrType->getValueType(); + break; + default: + if (as(ptrType)) + varType = ptrType->getValueType(); + break; + } } if (as(varType)) return; diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp index f91c4d06ed..b0d1fbb4cc 100644 --- a/source/slang/slang-emit-cpp.cpp +++ b/source/slang/slang-emit-cpp.cpp @@ -310,6 +310,7 @@ SlangResult CPPSourceEmitter::calcTypeName(IRType* type, CodeGenTarget target, S } case kIROp_NativePtrType: case kIROp_PtrType: + case kIROp_ConstRefType: { auto elementType = (IRType*)type->getOperand(0); SLANG_RETURN_ON_FAIL(calcTypeName(elementType, target, out)); diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp index 17c20d0642..326eef8b47 100644 --- a/source/slang/slang-emit-glsl.cpp +++ b/source/slang/slang-emit-glsl.cpp @@ -3257,6 +3257,11 @@ void GLSLSourceEmitter::emitSimpleTypeImpl(IRType* type) emitSimpleTypeImpl(cast(type)->getElementType()); return; } + case kIROp_ConstRefType: + { + emitSimpleTypeImpl(as(type)->getValueType()); + return; + } default: break; } @@ -3562,15 +3567,18 @@ void GLSLSourceEmitter::emitMatrixLayoutModifiersImpl(IRType* varType) // auto matrixType = as(unwrapArray(varType)); - if (matrixType) { + auto layout = getIntVal(matrixType->getLayout()); + if (layout == getTargetProgram()->getOptionSet().getMatrixLayoutMode()) + return; + // Reminder: the meaning of row/column major layout // in our semantics is the *opposite* of what GLSL // calls them, because what they call "columns" // are what we call "rows." // - switch (getIntVal(matrixType->getLayout())) + switch (layout) { case SLANG_MATRIX_LAYOUT_COLUMN_MAJOR: m_writer->emit("layout(row_major)\n"); diff --git a/source/slang/slang-emit-hlsl.cpp b/source/slang/slang-emit-hlsl.cpp index 40d6f75d99..83eec17b48 100644 --- a/source/slang/slang-emit-hlsl.cpp +++ b/source/slang/slang-emit-hlsl.cpp @@ -1180,6 +1180,34 @@ void HLSLSourceEmitter::emitSimpleValueImpl(IRInst* inst) Super::emitSimpleValueImpl(inst); } +void HLSLSourceEmitter::emitSimpleTypeAndDeclaratorImpl(IRType* type, DeclaratorInfo* declarator) +{ + if (declarator) + { + // HLSL only allow matrix layout modifier when declaring a variable or struct field. + if (auto matType = as(type)) + { + auto matrixLayout = getIntVal(matType->getLayout()); + if (getTargetProgram()->getOptionSet().getMatrixLayoutMode() != + (MatrixLayoutMode)matrixLayout) + { + switch (matrixLayout) + { + case SLANG_MATRIX_LAYOUT_COLUMN_MAJOR: + m_writer->emit("column_major "); + break; + case SLANG_MATRIX_LAYOUT_ROW_MAJOR: + m_writer->emit("row_major "); + break; + default: + break; + } + } + } + } + Super::emitSimpleTypeAndDeclaratorImpl(type, declarator); +} + void HLSLSourceEmitter::emitSimpleTypeImpl(IRType* type) { switch (type->getOp()) @@ -1313,6 +1341,11 @@ void HLSLSourceEmitter::emitSimpleTypeImpl(IRType* type) emitSimpleTypeImpl(cast(type)->getElementType()); return; } + case kIROp_ConstRefType: + { + emitSimpleTypeImpl(as(type)->getValueType()); + return; + } default: break; } @@ -1671,28 +1704,6 @@ void HLSLSourceEmitter::emitVarDecorationsImpl(IRInst* varDecl) } } -void HLSLSourceEmitter::emitMatrixLayoutModifiersImpl(IRType* type) -{ - auto matType = as(type); - if (!matType) - return; - auto matrixLayout = getIntVal(matType->getLayout()); - if (getTargetProgram()->getOptionSet().getMatrixLayoutMode() != (MatrixLayoutMode)matrixLayout) - { - switch (matrixLayout) - { - case SLANG_MATRIX_LAYOUT_COLUMN_MAJOR: - m_writer->emit("column_major "); - break; - case SLANG_MATRIX_LAYOUT_ROW_MAJOR: - m_writer->emit("row_major "); - break; - default: - break; - } - } -} - void HLSLSourceEmitter::handleRequiredCapabilitiesImpl(IRInst* inst) { if (inst->findDecoration()) diff --git a/source/slang/slang-emit-hlsl.h b/source/slang/slang-emit-hlsl.h index b2e2ca05a5..6b99a7f50a 100644 --- a/source/slang/slang-emit-hlsl.h +++ b/source/slang/slang-emit-hlsl.h @@ -55,11 +55,12 @@ class HLSLSourceEmitter : public CLikeSourceEmitter IRPackOffsetDecoration* decoration) SLANG_OVERRIDE; virtual void emitMeshShaderModifiersImpl(IRInst* varInst) SLANG_OVERRIDE; + virtual void emitSimpleTypeAndDeclaratorImpl(IRType* type, DeclaratorInfo* declarator) + SLANG_OVERRIDE; virtual void emitSimpleTypeImpl(IRType* type) SLANG_OVERRIDE; virtual void emitVectorTypeNameImpl(IRType* elementType, IRIntegerValue elementCount) SLANG_OVERRIDE; virtual void emitVarDecorationsImpl(IRInst* varDecl) SLANG_OVERRIDE; - virtual void emitMatrixLayoutModifiersImpl(IRType* varType) SLANG_OVERRIDE; virtual void emitParamTypeModifier(IRType* type) SLANG_OVERRIDE { emitMatrixLayoutModifiersImpl(type); diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 6b74648420..f5599289a2 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -12,7 +12,6 @@ #include "slang-ir-util.h" #include "slang-ir.h" #include "slang-lookup-spirv.h" -#include "slang-spirv-val.h" #include "spirv/unified1/spirv.h" #include diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index bac8fc1ddd..45dd683421 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -42,6 +42,7 @@ #include "slang-ir-entry-point-uniforms.h" #include "slang-ir-explicit-global-context.h" #include "slang-ir-explicit-global-init.h" +#include "slang-ir-fix-entrypoint-callsite.h" #include "slang-ir-fuse-satcoop.h" #include "slang-ir-glsl-legalize.h" #include "slang-ir-glsl-liveness.h" @@ -76,6 +77,7 @@ #include "slang-ir-pytorch-cpp-binding.h" #include "slang-ir-redundancy-removal.h" #include "slang-ir-resolve-texture-format.h" +#include "slang-ir-resolve-varying-input-ref.h" #include "slang-ir-restructure-scoping.h" #include "slang-ir-restructure.h" #include "slang-ir-sccp.h" @@ -104,7 +106,6 @@ #include "slang-legalize-types.h" #include "slang-lower-to-ir.h" #include "slang-mangle.h" -#include "slang-spirv-val.h" #include "slang-syntax.h" #include "slang-type-layout.h" #include "slang-visitor.h" @@ -315,6 +316,7 @@ struct RequiredLoweringPassSet bool glslSSBO; bool byteAddressBuffer; bool dynamicResource; + bool resolveVaryingInputRef; }; // Scan the IR module and determine which lowering/legalization passes are needed based @@ -424,6 +426,9 @@ void calcRequiredLoweringPassSet( case kIROp_DynamicResourceType: result.dynamicResource = true; break; + case kIROp_ResolveVaryingInputRef: + result.resolveVaryingInputRef = true; + break; } if (!result.generics || !result.existentialTypeLayout) { @@ -592,6 +597,11 @@ Result linkAndOptimizeIR( if (requiredLoweringPassSet.glslGlobalVar) translateGLSLGlobalVar(codeGenContext, irModule); + if (requiredLoweringPassSet.resolveVaryingInputRef) + resolveVaryingInputRef(irModule); + + fixEntryPointCallsites(irModule); + // Replace any global constants with their values. // replaceGlobalConstants(irModule); @@ -1940,18 +1950,16 @@ SlangResult emitSPIRVForEntryPointsDirectly( ArtifactUtil::createArtifactForCompileTarget(asExternal(codeGenContext->getTargetFormat())); artifact->addRepresentationUnknown(ListBlob::moveCreate(spirv)); -#if 0 - // Dump the unoptimized SPIRV after lowering from slang IR -> SPIRV - String err; String dis; - disassembleSPIRV(spirv, err, dis); - printf("%s", dis.begin()); -#endif - IDownstreamCompiler* compiler = codeGenContext->getSession()->getOrLoadDownstreamCompiler( PassThroughMode::SpirvOpt, codeGenContext->getSink()); if (compiler) { +#if 0 + // Dump the unoptimized SPIRV after lowering from slang IR -> SPIRV + compiler->disassemble((uint32_t*)spirv.getBuffer(), int(spirv.getCount() / 4)); +#endif + if (!codeGenContext->shouldSkipSPIRVValidation()) { StringBuilder runSpirvValEnvVar; @@ -1964,13 +1972,10 @@ SlangResult emitSPIRVForEntryPointsDirectly( (uint32_t*)spirv.getBuffer(), int(spirv.getCount() / 4)))) { - String err; - String dis; - disassembleSPIRV(spirv, err, dis); + compiler->disassemble((uint32_t*)spirv.getBuffer(), int(spirv.getCount() / 4)); codeGenContext->getSink()->diagnoseWithoutSourceView( SourceLoc{}, - Diagnostics::spirvValidationFailed, - dis); + Diagnostics::spirvValidationFailed); } } } diff --git a/source/slang/slang-ir-fix-entrypoint-callsite.cpp b/source/slang/slang-ir-fix-entrypoint-callsite.cpp new file mode 100644 index 0000000000..7390f3a7f1 --- /dev/null +++ b/source/slang/slang-ir-fix-entrypoint-callsite.cpp @@ -0,0 +1,101 @@ +#include "slang-ir-fix-entrypoint-callsite.h" + +#include "slang-ir-clone.h" +#include "slang-ir-insts.h" +#include "slang-ir-util.h" + +namespace Slang +{ +// If the entrypoint is called by some other function, we need to clone the +// entrypoint and replace the callsites to call the cloned entrypoint instead. +// This is because we will be modifying the signature of the entrypoint during +// entrypoint legalization to rewrite the way system values are passed in. +// By replacing the callsites to call the cloned entrypoint that act as ordinary +// functions, we will no longer need to worry about changing the callsites when we +// legalize the entry-points. +// +void fixEntryPointCallsites(IRFunc* entryPoint) +{ + IRFunc* clonedEntryPointForCall = nullptr; + auto ensureClonedEntryPointForCall = [&]() -> IRFunc* + { + if (clonedEntryPointForCall) + return clonedEntryPointForCall; + IRCloneEnv cloneEnv; + IRBuilder builder(entryPoint); + builder.setInsertBefore(entryPoint); + clonedEntryPointForCall = (IRFunc*)cloneInst(&cloneEnv, &builder, entryPoint); + // Remove entrypoint and linkage decorations from the cloned callee. + List decorsToRemove; + for (auto decor : clonedEntryPointForCall->getDecorations()) + { + switch (decor->getOp()) + { + case kIROp_ExportDecoration: + case kIROp_UserExternDecoration: + case kIROp_HLSLExportDecoration: + case kIROp_EntryPointDecoration: + case kIROp_LayoutDecoration: + case kIROp_NumThreadsDecoration: + case kIROp_ImportDecoration: + case kIROp_ExternCDecoration: + case kIROp_ExternCppDecoration: + decorsToRemove.add(decor); + break; + } + } + for (auto decor : decorsToRemove) + decor->removeAndDeallocate(); + return clonedEntryPointForCall; + }; + traverseUses( + entryPoint, + [&](IRUse* use) + { + auto user = use->getUser(); + auto call = as(user); + if (!call) + return; + auto callee = ensureClonedEntryPointForCall(); + call->setOperand(0, callee); + + // Fix up argument types: if the callee entrypoint is expecting a constref + // and the caller is passing a value, we need to wrap the value in a temporary var + // and pass the temporary var. + // + auto funcType = as(callee->getDataType()); + SLANG_ASSERT(funcType); + IRBuilder builder(call); + builder.setInsertBefore(call); + List params; + for (auto param : callee->getParams()) + params.add(param); + if ((UInt)params.getCount() != call->getArgCount()) + return; + for (UInt i = 0; i < call->getArgCount(); i++) + { + auto paramType = params[i]->getDataType(); + auto arg = call->getArg(i); + if (auto refType = as(paramType)) + { + if (!as(arg->getDataType())) + { + auto tempVar = builder.emitVar(refType->getValueType()); + builder.emitStore(tempVar, arg); + call->setArg(i, tempVar); + } + } + } + }); +} + +void fixEntryPointCallsites(IRModule* module) +{ + for (auto globalInst : module->getGlobalInsts()) + { + if (globalInst->findDecoration()) + fixEntryPointCallsites((IRFunc*)globalInst); + } +} + +} // namespace Slang diff --git a/source/slang/slang-ir-fix-entrypoint-callsite.h b/source/slang/slang-ir-fix-entrypoint-callsite.h new file mode 100644 index 0000000000..493d67a774 --- /dev/null +++ b/source/slang/slang-ir-fix-entrypoint-callsite.h @@ -0,0 +1,9 @@ +#pragma once + +#include "slang-ir.h" + +namespace Slang +{ +void fixEntryPointCallsites(IRModule* module); + +} // namespace Slang diff --git a/source/slang/slang-ir-glsl-legalize.cpp b/source/slang/slang-ir-glsl-legalize.cpp index 09bf245df4..39f9703191 100644 --- a/source/slang/slang-ir-glsl-legalize.cpp +++ b/source/slang/slang-ir-glsl-legalize.cpp @@ -67,11 +67,7 @@ struct ScalarizedValImpl : RefObject }; struct ScalarizedTupleValImpl; struct ScalarizedTypeAdapterValImpl; - -struct ScalarizedArrayIndexValImpl : ScalarizedValImpl -{ - Index index; -}; +struct ScalarizedArrayIndexValImpl; struct ScalarizedVal { @@ -132,15 +128,12 @@ struct ScalarizedVal result.impl = (ScalarizedValImpl*)impl; return result; } - static ScalarizedVal scalarizedArrayIndex(IRInst* irValue, Index index) + static ScalarizedVal scalarizedArrayIndex(ScalarizedArrayIndexValImpl* impl) { ScalarizedVal result; result.flavor = Flavor::arrayIndex; - auto impl = new ScalarizedArrayIndexValImpl; - impl->index = index; - - result.irValue = irValue; - result.impl = impl; + result.irValue = nullptr; + result.impl = (ScalarizedValImpl*)impl; return result; } @@ -151,8 +144,6 @@ struct ScalarizedVal RefPtr impl; }; -IRInst* materializeValue(IRBuilder* builder, ScalarizedVal const& val); - // This is the case for a value that is a "tuple" of other values struct ScalarizedTupleValImpl : ScalarizedValImpl { @@ -175,6 +166,36 @@ struct ScalarizedTypeAdapterValImpl : ScalarizedValImpl IRType* pretendType; // the type this value pretends to have }; +struct ScalarizedArrayIndexValImpl : ScalarizedValImpl +{ + ScalarizedVal arrayVal; + Index index; + IRType* elementType; +}; + +ScalarizedVal extractField( + IRBuilder* builder, + ScalarizedVal const& val, + UInt fieldIndex, // Pass ~0 in to search for the index via the key + IRStructKey* fieldKey); +ScalarizedVal adaptType(IRBuilder* builder, IRInst* val, IRType* toType, IRType* fromType); +ScalarizedVal adaptType( + IRBuilder* builder, + ScalarizedVal const& val, + IRType* toType, + IRType* fromType); +IRInst* materializeValue(IRBuilder* builder, ScalarizedVal const& val); +ScalarizedVal getSubscriptVal( + IRBuilder* builder, + IRType* elementType, + ScalarizedVal val, + IRInst* indexVal); +ScalarizedVal getSubscriptVal( + IRBuilder* builder, + IRType* elementType, + ScalarizedVal val, + UInt index); + struct GlobalVaryingDeclarator { enum class Flavor @@ -1303,6 +1324,22 @@ ScalarizedVal createSimpleGLSLGlobalVarying( } } + AddressSpace addrSpace = AddressSpace::Uniform; + IROp ptrOpCode = kIROp_PtrType; + switch (kind) + { + case LayoutResourceKind::VaryingInput: + addrSpace = systemValueInfo ? AddressSpace::BuiltinInput : AddressSpace::Input; + break; + case LayoutResourceKind::VaryingOutput: + addrSpace = systemValueInfo ? AddressSpace::BuiltinOutput : AddressSpace::Output; + ptrOpCode = kIROp_OutType; + break; + default: + break; + } + + // If we have a declarator, we just use the normal logic, as that seems to work correctly // if (systemValueInfo && systemValueInfo->arrayIndex >= 0 && declarator == nullptr) @@ -1339,9 +1376,7 @@ ScalarizedVal createSimpleGLSLGlobalVarying( // Set the array size to 0, to mean it is unsized auto arrayType = builder->getArrayType(type, 0); - IRType* paramType = kind == LayoutResourceKind::VaryingOutput - ? (IRType*)builder->getOutType(arrayType) - : arrayType; + IRType* paramType = builder->getPtrType(ptrOpCode, arrayType, addrSpace); auto globalParam = addGlobalParam(builder->getModule(), paramType); moveValueBefore(globalParam, builder->getFunc()); @@ -1371,9 +1406,12 @@ ScalarizedVal createSimpleGLSLGlobalVarying( semanticGlobal->addIndex(systemValueInfo->arrayIndex); // Make it an array index - ScalarizedVal val = ScalarizedVal::scalarizedArrayIndex( - semanticGlobal->globalParam, - systemValueInfo->arrayIndex); + ScalarizedVal val = ScalarizedVal::address(semanticGlobal->globalParam); + RefPtr arrayImpl = new ScalarizedArrayIndexValImpl(); + arrayImpl->arrayVal = val; + arrayImpl->index = systemValueInfo->arrayIndex; + arrayImpl->elementType = type; + val = ScalarizedVal::scalarizedArrayIndex(arrayImpl); // We need to make this access, an array access to the global if (auto fromType = systemValueInfo->requiredType) @@ -1466,14 +1504,14 @@ ScalarizedVal createSimpleGLSLGlobalVarying( // like our IR function parameters, and need a wrapper // `Out<...>` type to represent outputs. // - bool isOutput = (kind == LayoutResourceKind::VaryingOutput); - IRType* paramType = isOutput ? builder->getOutType(type) : type; + + // Non system value varying inputs shall be passed as pointers. + IRType* paramType = builder->getPtrType(ptrOpCode, type, addrSpace); auto globalParam = addGlobalParam(builder->getModule(), paramType); moveValueBefore(globalParam, builder->getFunc()); - ScalarizedVal val = - isOutput ? ScalarizedVal::address(globalParam) : ScalarizedVal::value(globalParam); + ScalarizedVal val = ScalarizedVal::address(globalParam); if (systemValueInfo) { @@ -1958,10 +1996,10 @@ ScalarizedVal adaptType( break; case ScalarizedVal::Flavor::arrayIndex: { - auto element = builder->emitElementExtract( - val.irValue, - as(val.impl)->index); - return adaptType(builder, element, toType, fromType); + auto arrayImpl = as(val.impl); + auto elementVal = + getSubscriptVal(builder, fromType, arrayImpl->arrayVal, arrayImpl->index); + return adaptType(builder, elementVal, toType, fromType); } break; default: @@ -1970,8 +2008,6 @@ ScalarizedVal adaptType( } } -IRInst* materializeValue(IRBuilder* builder, ScalarizedVal const& val); - void assign( IRBuilder* builder, ScalarizedVal const& left, @@ -1988,16 +2024,12 @@ void assign( // Determine the index auto leftArrayIndexVal = as(left.impl); - const auto arrayIndex = leftArrayIndexVal->index; - - auto arrayIndexInst = builder->getIntValue(builder->getIntType(), arrayIndex); - - // Store to the index - auto address = builder->emitElementAddress( - builder->getPtrType(right.irValue->getFullType()), - left.irValue, - arrayIndexInst); - builder->emitStore(address, rhs); + auto leftVal = getSubscriptVal( + builder, + leftArrayIndexVal->elementType, + leftArrayIndexVal->arrayVal, + leftArrayIndexVal->index); + builder->emitStore(leftVal.irValue, rhs); break; } @@ -2236,10 +2268,10 @@ IRInst* materializeValue(IRBuilder* builder, ScalarizedVal const& val) case ScalarizedVal::Flavor::arrayIndex: { - auto element = builder->emitElementExtract( - val.irValue, - as(val.impl)->index); - return element; + auto impl = as(val.impl); + auto elementVal = + getSubscriptVal(builder, impl->elementType, impl->arrayVal, impl->index); + return materializeValue(builder, elementVal); } case ScalarizedVal::Flavor::tuple: { @@ -2735,9 +2767,9 @@ IRInst* getOrCreatePerVertexInputArray(GLSLLegalizationContext* context, IRInst* IRBuilder builder(inputVertexAttr); builder.setInsertBefore(inputVertexAttr); auto arrayType = builder.getArrayType( - inputVertexAttr->getDataType(), + tryGetPointedToType(&builder, inputVertexAttr->getDataType()), builder.getIntValue(builder.getIntType(), 3)); - arrayInst = builder.createGlobalParam(arrayType); + arrayInst = builder.createGlobalParam(builder.getPtrType(arrayType, AddressSpace::Input)); context->mapVertexInputToPerVertexArray[inputVertexAttr] = arrayInst; builder.addDecoration(arrayInst, kIROp_PerVertexDecoration); @@ -2765,6 +2797,27 @@ void tryReplaceUsesOfStageInput( { case ScalarizedVal::Flavor::value: { + traverseUses( + originalVal, + [&](IRUse* use) + { + auto user = use->getUser(); + IRBuilder builder(user); + builder.setInsertBefore(user); + builder.replaceOperand(use, val.irValue); + }); + } + break; + case ScalarizedVal::Flavor::address: + { + bool needMaterialize = false; + if (as(val.irValue->getDataType())) + { + if (!as(originalVal->getDataType())) + { + needMaterialize = true; + } + } traverseUses( originalVal, [&](IRUse* use) @@ -2775,16 +2828,79 @@ void tryReplaceUsesOfStageInput( auto arrayInst = getOrCreatePerVertexInputArray(context, val.irValue); user->replaceUsesWith(arrayInst); user->removeAndDeallocate(); + return; + } + IRBuilder builder(user); + builder.setInsertBefore(user); + if (needMaterialize) + { + auto materializedVal = materializeValue(&builder, val); + builder.replaceOperand(use, materializedVal); } else { - IRBuilder builder(user); - builder.setInsertBefore(user); builder.replaceOperand(use, val.irValue); } }); } break; + case ScalarizedVal::Flavor::typeAdapter: + { + traverseUses( + originalVal, + [&](IRUse* use) + { + auto user = use->getUser(); + IRBuilder builder(user); + builder.setInsertBefore(user); + auto typeAdapter = as(val.impl); + auto materializedInner = materializeValue(&builder, typeAdapter->val); + auto adapted = adaptType( + &builder, + materializedInner, + typeAdapter->pretendType, + typeAdapter->actualType); + if (user->getOp() == kIROp_Load) + { + user->replaceUsesWith(adapted.irValue); + user->removeAndDeallocate(); + } + else + { + use->set(adapted.irValue); + } + }); + } + break; + case ScalarizedVal::Flavor::arrayIndex: + { + traverseUses( + originalVal, + [&](IRUse* use) + { + auto arrayIndexImpl = as(val.impl); + auto user = use->getUser(); + IRBuilder builder(user); + builder.setInsertBefore(user); + auto subscriptVal = getSubscriptVal( + &builder, + arrayIndexImpl->elementType, + arrayIndexImpl->arrayVal, + arrayIndexImpl->index); + builder.setInsertBefore(user); + auto materializedInner = materializeValue(&builder, subscriptVal); + if (user->getOp() == kIROp_Load) + { + user->replaceUsesWith(materializedInner); + user->removeAndDeallocate(); + } + else + { + use->set(materializedInner); + } + }); + break; + } case ScalarizedVal::Flavor::tuple: { auto tupleVal = as(val.impl); @@ -2793,22 +2909,36 @@ void tryReplaceUsesOfStageInput( [&](IRUse* use) { auto user = use->getUser(); - if (auto fieldExtract = as(user)) + switch (user->getOp()) { - auto fieldKey = fieldExtract->getField(); - ScalarizedVal fieldVal; - for (auto element : tupleVal->elements) + case kIROp_FieldExtract: + case kIROp_FieldAddress: { - if (element.key == fieldKey) + auto fieldKey = user->getOperand(1); + ScalarizedVal fieldVal; + for (auto element : tupleVal->elements) { - fieldVal = element.val; - break; + if (element.key == fieldKey) + { + fieldVal = element.val; + break; + } + } + if (fieldVal.flavor != ScalarizedVal::Flavor::none) + { + tryReplaceUsesOfStageInput(context, fieldVal, user); } } - if (fieldVal.flavor != ScalarizedVal::Flavor::none) + break; + case kIROp_Load: { - tryReplaceUsesOfStageInput(context, fieldVal, user); + IRBuilder builder(user); + builder.setInsertBefore(user); + auto materializedVal = materializeTupleValue(&builder, val); + user->replaceUsesWith(materializedVal); + user->removeAndDeallocate(); } + break; } }); } @@ -3066,7 +3196,7 @@ void legalizeEntryPointParameterForGLSL( // We are going to create a local variable of the appropriate // type, which will replace the parameter, along with // one or more global variables for the actual input/output. - + setInsertAfterOrdinaryInst(builder, pp); auto localVariable = builder->emitVar(valueType); auto localVal = ScalarizedVal::address(localVariable); @@ -3135,6 +3265,73 @@ void legalizeEntryPointParameterForGLSL( assign(&terminatorBuilder, globalOutputVal, localVal); } } + else if (auto ptrType = as(paramType)) + { + // This is the case where the parameter is passed by const + // reference. We simply replace existing uses of the parameter + // with the real global variable. + SLANG_ASSERT( + ptrType->getOp() == kIROp_ConstRefType || + ptrType->getAddressSpace() == AddressSpace::Input || + ptrType->getAddressSpace() == AddressSpace::BuiltinInput); + + auto globalValue = createGLSLGlobalVaryings( + context, + codeGenContext, + builder, + valueType, + paramLayout, + LayoutResourceKind::VaryingInput, + stage, + pp); + tryReplaceUsesOfStageInput(context, globalValue, pp); + for (auto dec : pp->getDecorations()) + { + if (dec->getOp() != kIROp_GlobalVariableShadowingGlobalParameterDecoration) + continue; + auto globalVar = dec->getOperand(0); + auto key = dec->getOperand(1); + IRInst* realGlobalVar = nullptr; + if (globalValue.flavor != ScalarizedVal::Flavor::tuple) + continue; + if (auto tupleVal = as(globalValue.impl)) + { + for (auto elem : tupleVal->elements) + { + if (elem.key == key) + { + realGlobalVar = elem.val.irValue; + break; + } + } + } + SLANG_ASSERT(realGlobalVar); + + // Remove all stores into the global var introduced during + // the initial glsl global var translation pass since we are + // going to replace the global var with a pointer to the real + // input, and it makes no sense to store values into such real + // input locations. + traverseUses( + globalVar, + [&](IRUse* use) + { + auto user = use->getUser(); + if (auto store = as(user)) + { + if (store->getPtrUse() == use) + { + store->removeAndDeallocate(); + } + } + }); + // we will be replacing uses of `globalVarToReplace`. We need + // globalVarToReplaceNextUse to catch the next use before it is removed from the + // list of uses. + globalVar->replaceUsesWith(realGlobalVar); + globalVar->removeAndDeallocate(); + } + } else { // This is the "easy" case where the parameter wasn't @@ -3451,6 +3648,7 @@ ScalarizedVal legalizeEntryPointReturnValueForGLSL( return result; } + void legalizeEntryPointForGLSL( Session* session, IRModule* module, @@ -3554,12 +3752,12 @@ void legalizeEntryPointForGLSL( // and turn them into global variables. if (auto firstBlock = func->getFirstBlock()) { - // Any initialization code we insert for parameters needs - // to be at the start of the "ordinary" instructions in the block: - builder.setInsertBefore(firstBlock->getFirstOrdinaryInst()); - for (auto pp = firstBlock->getFirstParam(); pp; pp = pp->getNextParam()) { + // Any initialization code we insert for parameters needs + // to be at the start of the "ordinary" instructions in the block: + builder.setInsertBefore(firstBlock->getFirstOrdinaryInst()); + // We assume that the entry-point parameters will all have // layout information attached to them, which is kept up-to-date // by any transformations affecting the parameter list. @@ -3606,11 +3804,11 @@ void legalizeEntryPointForGLSL( { auto type = value.globalParam->getDataType(); - // Strip out if there is one - auto outType = as(type); - if (outType) + // Strip ptr if there is one. + auto ptrType = as(type); + if (ptrType) { - type = outType->getValueType(); + type = ptrType->getValueType(); } // Get the array type @@ -3627,10 +3825,13 @@ void legalizeEntryPointForGLSL( auto elementCountInst = builder.getIntValue(builder.getIntType(), value.maxIndex + 1); IRType* sizedArrayType = builder.getArrayType(elementType, elementCountInst); - // Re-add out if there was one on the input - if (outType) + // Re-add ptr if there was one on the input + if (ptrType) { - sizedArrayType = builder.getOutType(sizedArrayType); + sizedArrayType = builder.getPtrType( + ptrType->getOp(), + sizedArrayType, + ptrType->getAddressSpace()); } // Change the globals type diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 574b755108..27003f6a79 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -740,7 +740,8 @@ INST(GetVulkanRayTracingPayloadLocation, GetVulkanRayTracingPayloadLocation, 1, INST(GetLegalizedSPIRVGlobalParamAddr, GetLegalizedSPIRVGlobalParamAddr, 1, 0) -INST(GetPerVertexInputArray, GetPerVertexInputArray, 1, 0) +INST(GetPerVertexInputArray, GetPerVertexInputArray, 1, HOISTABLE) +INST(ResolveVaryingInputRef, ResolveVaryingInputRef, 1, HOISTABLE) INST(ForceVarIntoStructTemporarily, ForceVarIntoStructTemporarily, 1, 0) INST(MetalAtomicCast, MetalAtomicCast, 1, 0) diff --git a/source/slang/slang-ir-legalize-varying-params.cpp b/source/slang/slang-ir-legalize-varying-params.cpp index 025bcf1b85..33f3944fd9 100644 --- a/source/slang/slang-ir-legalize-varying-params.cpp +++ b/source/slang/slang-ir-legalize-varying-params.cpp @@ -651,7 +651,9 @@ struct EntryPointVaryingParamLegalizeContext // The materialized value can be used to completely // replace the original parameter. // - param->replaceUsesWith(materialized); + auto localVar = builder.emitVar(materialized->getDataType()); + builder.emitStore(localVar, materialized); + param->replaceUsesWith(localVar); param->removeAndDeallocate(); } @@ -1475,4 +1477,71 @@ void legalizeEntryPointVaryingParamsForCUDA(IRModule* module, DiagnosticSink* si context.processModule(module, sink); } +void depointerizeInputParams(IRFunc* entryPointFunc) +{ + List workList; + List modifiedParamIndices; + Index i = 0; + for (auto param : entryPointFunc->getParams()) + { + if (auto constRefType = as(param->getFullType())) + { + switch (constRefType->getValueType()->getOp()) + { + case kIROp_VerticesType: + case kIROp_IndicesType: + case kIROp_PrimitivesType: + continue; + default: + break; + } + workList.add(param); + modifiedParamIndices.add(i); + } + else if (auto ptrType = as(param->getFullType())) + { + switch (ptrType->getAddressSpace()) + { + case AddressSpace::Input: + case AddressSpace::BuiltinInput: + workList.add(param); + modifiedParamIndices.add(i); + break; + } + } + i++; + } + for (auto param : workList) + { + auto valueType = as(param->getDataType())->getValueType(); + IRBuilder builder(param); + setInsertBeforeOrdinaryInst(&builder, param); + auto var = builder.emitVar(valueType); + param->replaceUsesWith(var); + param->setFullType(valueType); + builder.emitStore(var, param); + } + + fixUpFuncType(entryPointFunc); + + // Fix up callsites of the entrypoint func. + for (auto use = entryPointFunc->firstUse; use; use = use->nextUse) + { + auto call = as(use->getUser()); + if (!call) + continue; + IRBuilder builder(call); + builder.setInsertBefore(call); + for (auto paramIndex : modifiedParamIndices) + { + auto arg = call->getArg(paramIndex); + auto ptrType = as(arg->getDataType()); + if (!ptrType) + continue; + auto val = builder.emitLoad(arg); + call->setArg(paramIndex, val); + } + } +} + } // namespace Slang diff --git a/source/slang/slang-ir-legalize-varying-params.h b/source/slang/slang-ir-legalize-varying-params.h index 7604cb2458..efd61e87cf 100644 --- a/source/slang/slang-ir-legalize-varying-params.h +++ b/source/slang/slang-ir-legalize-varying-params.h @@ -18,6 +18,7 @@ void legalizeEntryPointVaryingParamsForCPU(IRModule* module, DiagnosticSink* sin void legalizeEntryPointVaryingParamsForCUDA(IRModule* module, DiagnosticSink* sink); +void depointerizeInputParams(IRFunc* entryPoint); // (#4375) Once `slang-ir-metal-legalize.cpp` is merged with // `slang-ir-legalize-varying-params.cpp`, move the following diff --git a/source/slang/slang-ir-lower-buffer-element-type.cpp b/source/slang/slang-ir-lower-buffer-element-type.cpp index 7f67c92546..74e84f1eed 100644 --- a/source/slang/slang-ir-lower-buffer-element-type.cpp +++ b/source/slang/slang-ir-lower-buffer-element-type.cpp @@ -412,6 +412,28 @@ struct LoweredElementTypeContext return 4; } + bool shouldLowerMatrixType(IRMatrixType* matrixType, TypeLoweringConfig config) + { + // For spirv, we always want to lower all matrix types, because SPIRV does not support + // specifying matrix layout/stride if the matrix type is used in places other than + // defining a struct field. This means that if a matrix is used to define a varying + // parameter, we always want to wrap it in a struct. + // + if (target->shouldEmitSPIRVDirectly()) + { + return true; + } + + if (getIntVal(matrixType->getLayout()) == defaultMatrixLayout && + config.layoutRule->ruleName == IRTypeLayoutRuleName::Natural) + { + // For other targets, we only lower the matrix types if they differ from the default + // matrix layout. + return false; + } + return true; + } + LoweredElementTypeInfo getLoweredTypeInfoImpl(IRType* type, TypeLoweringConfig config) { IRBuilder builder(type); @@ -422,18 +444,10 @@ struct LoweredElementTypeContext if (auto matrixType = as(type)) { - // For spirv, we always want to lower all matrix types, because matrix types - // are considered abstract types. - if (!target->shouldEmitSPIRVDirectly()) + if (!shouldLowerMatrixType(matrixType, config)) { - // For other targets, we only lower the matrix types if they differ from the default - // matrix layout. - if (getIntVal(matrixType->getLayout()) == defaultMatrixLayout && - config.layoutRule->ruleName == IRTypeLayoutRuleName::Natural) - { - info.loweredType = type; - return info; - } + info.loweredType = type; + return info; } auto loweredType = builder.createStructType(); @@ -859,27 +873,24 @@ struct LoweredElementTypeContext { IRType* elementType = nullptr; - if (options.lowerBufferPointer) + if (auto ptrType = as(globalInst)) { - if (auto ptrType = as(globalInst)) + switch (ptrType->getAddressSpace()) { - switch (ptrType->getAddressSpace()) - { - case AddressSpace::UserPointer: - case AddressSpace::Input: - case AddressSpace::Output: - elementType = ptrType->getValueType(); - break; - } + case AddressSpace::UserPointer: + if (!options.lowerBufferPointer) + continue; + [[fallthrough]]; + case AddressSpace::Input: + case AddressSpace::Output: + elementType = ptrType->getValueType(); + break; } } - else - { - if (auto structBuffer = as(globalInst)) - elementType = structBuffer->getElementType(); - else if (auto constBuffer = as(globalInst)) - elementType = constBuffer->getElementType(); - } + if (auto structBuffer = as(globalInst)) + elementType = structBuffer->getElementType(); + else if (auto constBuffer = as(globalInst)) + elementType = constBuffer->getElementType(); if (as(globalInst)) continue; if (!as(elementType) && !as(elementType) && diff --git a/source/slang/slang-ir-metal-legalize.cpp b/source/slang/slang-ir-metal-legalize.cpp index 835041a592..ce5b34c3e6 100644 --- a/source/slang/slang-ir-metal-legalize.cpp +++ b/source/slang/slang-ir-metal-legalize.cpp @@ -1924,6 +1924,7 @@ struct LegalizeMetalEntryPointContext void legalizeEntryPointForMetal(EntryPointInfo entryPoint) { // Input Parameter Legalize + depointerizeInputParams(entryPoint.entryPointFunc); hoistEntryPointParameterFromStruct(entryPoint); packStageInParameters(entryPoint); flattenInputParameters(entryPoint); diff --git a/source/slang/slang-ir-resolve-varying-input-ref.cpp b/source/slang/slang-ir-resolve-varying-input-ref.cpp new file mode 100644 index 0000000000..0707c566f8 --- /dev/null +++ b/source/slang/slang-ir-resolve-varying-input-ref.cpp @@ -0,0 +1,92 @@ +#include "slang-ir-resolve-varying-input-ref.h" + +#include "slang-ir-insts.h" +#include "slang-ir-util.h" + +namespace Slang +{ +void resolveVaryingInputRef(IRFunc* func) +{ + List toRemove; + for (auto bb = func->getFirstBlock(); bb; bb = bb->getNextBlock()) + { + for (auto inst : bb->getChildren()) + { + switch (inst->getOp()) + { + case kIROp_ResolveVaryingInputRef: + { + // Resolve a reference to varying input to the actual global param + // representing the varying input. + auto operand = inst->getOperand(0); + List accessChain; + List types; + auto rootAddr = getRootAddr(operand, accessChain, &types); + if (rootAddr->getOp() == kIROp_Param || rootAddr->getOp() == kIROp_GlobalParam) + { + // If the referred operand is already a global param, use it directly. + inst->replaceUsesWith(operand); + toRemove.add(inst); + break; + } + // If the referred operand is a local var, + // and there is a store(var, load(globalParam)), + // replace `inst` with `globalParam`. + IRInst* srcPtr = nullptr; + for (auto use = rootAddr->firstUse; use; use = use->nextUse) + { + auto user = use->getUser(); + if (auto store = as(user)) + { + if (store->getPtrUse() == use) + { + if (auto load = as(store->getVal())) + { + auto ptr = load->getPtr(); + if (ptr->getOp() == kIROp_Param || + ptr->getOp() == kIROp_GlobalParam) + { + if (!srcPtr) + srcPtr = ptr; + else + { + srcPtr = nullptr; + break; + } + } + } + } + } + } + if (srcPtr) + { + IRBuilder builder(inst); + builder.setInsertBefore(inst); + auto resolvedPtr = builder.emitElementAddress( + srcPtr, + accessChain.getArrayView(), + types.getArrayView()); + inst->replaceUsesWith(resolvedPtr); + toRemove.add(inst); + } + } + break; + } + } + } + for (auto inst : toRemove) + { + inst->removeAndDeallocate(); + } +} + +void resolveVaryingInputRef(IRModule* module) +{ + for (auto globalInst : module->getGlobalInsts()) + { + if (globalInst->findDecoration()) + resolveVaryingInputRef((IRFunc*)globalInst); + } +} + +} // namespace Slang diff --git a/source/slang/slang-ir-resolve-varying-input-ref.h b/source/slang/slang-ir-resolve-varying-input-ref.h new file mode 100644 index 0000000000..5cbff0f8c5 --- /dev/null +++ b/source/slang/slang-ir-resolve-varying-input-ref.h @@ -0,0 +1,10 @@ +#pragma once + +#include "slang-ir.h" + +namespace Slang +{ +void resolveVaryingInputRef(IRFunc* func); +void resolveVaryingInputRef(IRModule* module); + +} // namespace Slang diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index 2757538a67..50dfa2c6a3 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -125,6 +125,12 @@ struct SpecializationContext } } + if (isWrapperType(inst)) + { + // For all the wrapper type, we need to make sure the operands are fully specialized. + return areAllOperandsFullySpecialized(inst); + } + // The default case is that a global value is always specialized. if (inst->getParent() == module->getModuleInst()) { diff --git a/source/slang/slang-ir-translate-glsl-global-var.cpp b/source/slang/slang-ir-translate-glsl-global-var.cpp index 65cb8f64fd..a44e16a7ce 100644 --- a/source/slang/slang-ir-translate-glsl-global-var.cpp +++ b/source/slang/slang-ir-translate-glsl-global-var.cpp @@ -122,7 +122,8 @@ struct GlobalVarTranslationContext // Add an entry point parameter for all the inputs. auto firstBlock = entryPointFunc->getFirstBlock(); builder.setInsertInto(firstBlock); - auto inputParam = builder.emitParam(inputStructType); + auto inputParam = builder.emitParam( + builder.getPtrType(kIROp_ConstRefType, inputStructType, AddressSpace::Input)); builder.addLayoutDecoration(inputParam, paramLayout); // Initialize all global variables. @@ -133,7 +134,8 @@ struct GlobalVarTranslationContext auto inputType = cast(input->getDataType())->getValueType(); builder.emitStore( input, - builder.emitFieldExtract(inputType, inputParam, inputKeys[i])); + builder + .emitFieldExtract(inputType, builder.emitLoad(inputParam), inputKeys[i])); // Relate "global variable" to a "global parameter" for use later in compilation // to resolve a "global variable" shadowing a "global parameter" relationship. builder.addGlobalVariableShadowingGlobalParameterDecoration( diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index 7788a50d5d..c753600a7c 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -277,6 +277,31 @@ bool isSimpleHLSLDataType(IRInst* inst) return true; } +bool isWrapperType(IRInst* inst) +{ + switch (inst->getOp()) + { + case kIROp_ArrayType: + case kIROp_TextureType: + case kIROp_VectorType: + case kIROp_MatrixType: + case kIROp_PtrType: + case kIROp_RefType: + case kIROp_ConstRefType: + case kIROp_HLSLStructuredBufferType: + case kIROp_HLSLRWStructuredBufferType: + case kIROp_HLSLRasterizerOrderedStructuredBufferType: + case kIROp_HLSLAppendStructuredBufferType: + case kIROp_HLSLConsumeStructuredBufferType: + case kIROp_TupleType: + case kIROp_OptionalType: + case kIROp_TypePack: + return true; + default: + return false; + } +} + SourceLoc findFirstUseLoc(IRInst* inst) { for (auto use = inst->firstUse; use; use = use->nextUse) diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index 9a712ba961..e23aeb6180 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -104,6 +104,8 @@ bool isSimpleDataType(IRType* type); bool isSimpleHLSLDataType(IRInst* inst); +bool isWrapperType(IRInst* inst); + SourceLoc findFirstUseLoc(IRInst* inst); inline bool isChildInstOf(IRInst* inst, IRInst* parent) diff --git a/source/slang/slang-ir-vk-invert-y.cpp b/source/slang/slang-ir-vk-invert-y.cpp index e7fc811449..70f2584acc 100644 --- a/source/slang/slang-ir-vk-invert-y.cpp +++ b/source/slang/slang-ir-vk-invert-y.cpp @@ -104,10 +104,15 @@ void rcpWOfPositionInput(IRModule* module) [&](IRUse* use) { // Get the inverted vector. - builder.setInsertBefore(use->getUser()); - auto invertedVal = _invertWOfVector(builder, globalInst); - // Replace original uses with the invertex vector. - builder.replaceOperand(use, invertedVal); + auto user = use->getUser(); + if (user->getOp() == kIROp_Load) + { + builder.setInsertBefore(user); + auto val = builder.emitLoad(globalInst); + auto invertedVal = _invertWOfVector(builder, val); + user->replaceUsesWith(invertedVal); + user->removeAndDeallocate(); + } }); } } diff --git a/source/slang/slang-ir-wgsl-legalize.cpp b/source/slang/slang-ir-wgsl-legalize.cpp index 907c2b8ba6..f76a0541c2 100644 --- a/source/slang/slang-ir-wgsl-legalize.cpp +++ b/source/slang/slang-ir-wgsl-legalize.cpp @@ -1362,6 +1362,9 @@ struct LegalizeWGSLEntryPointContext void legalizeEntryPointForWGSL(EntryPointInfo entryPoint) { + // If the entrypoint is receiving varying inputs as a pointer, turn it into a value. + depointerizeInputParams(entryPoint.entryPointFunc); + // Input Parameter Legalize flattenInputParameters(entryPoint); diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index ff1cd49eaa..daeaca67bf 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -8243,6 +8243,8 @@ bool IRInst::mightHaveSideEffects(SideEffectAnalysisOptions options) case kIROp_GetStringHash: case kIROp_AllocateOpaqueHandle: case kIROp_GetArrayLength: + case kIROp_ResolveVaryingInputRef: + case kIROp_GetPerVertexInputArray: return false; case kIROp_ForwardDifferentiate: diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 864d41ed76..fbd05f2ab0 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -2767,14 +2767,15 @@ ParameterDirection getParameterDirection(VarDeclBase* paramDecl) /// ParameterDirection getThisParamDirection(Decl* parentDecl, ParameterDirection defaultDirection) { - auto parentParent = getParentDecl(parentDecl); + auto parentParent = getParentAggTypeDecl(parentDecl); + // The `this` parameter for a `class` is always `in`. if (as(parentParent)) { return kParameterDirection_In; } - if (parentParent->findModifier()) + if (parentParent && parentParent->findModifier()) { if (parentDecl->hasModifier()) return kParameterDirection_Ref; @@ -2982,6 +2983,9 @@ struct IRLoweringParameterInfo // The direction (`in` vs `out` vs `in out`) ParameterDirection direction; + // The direction declared in user code. + ParameterDirection declaredDirection = ParameterDirection::kParameterDirection_In; + // The variable/parameter declaration for // this parameter (if any) VarDeclBase* decl = nullptr; @@ -3005,6 +3009,7 @@ IRLoweringParameterInfo getParameterInfo( info.type = getParamType(context->astBuilder, paramDecl); info.decl = paramDecl.getDecl(); info.direction = getParameterDirection(paramDecl.getDecl()); + info.declaredDirection = info.direction; info.isThisParam = false; return info; } @@ -3051,6 +3056,7 @@ void addThisParameter(ParameterDirection direction, Type* type, ParameterLists* info.type = type; info.decl = nullptr; info.direction = direction; + info.declaredDirection = direction; info.isThisParam = true; ioParameterLists->params.add(info); @@ -3064,10 +3070,22 @@ void maybeAddReturnDestinationParam(ParameterLists* ioParameterLists, Type* resu info.type = resultType; info.decl = nullptr; info.direction = kParameterDirection_Ref; + info.declaredDirection = info.direction; info.isReturnDestination = true; ioParameterLists->params.add(info); } } + +void makeVaryingInputParamConstRef(IRLoweringParameterInfo& paramInfo) +{ + if (paramInfo.direction != kParameterDirection_In) + return; + if (paramInfo.decl->findModifier()) + return; + if (as(paramInfo.type)) + return; + paramInfo.direction = kParameterDirection_ConstRef; +} // // And here is our function that will do the recursive walk: void collectParameterLists( @@ -3137,13 +3155,31 @@ void collectParameterLists( // if (auto callableDeclRef = declRef.as()) { + // We need a special case here when lowering the varying parameters of an entrypoint + // function. Due to the existence of `EvaluateAttributeAtSample` and friends, we need to + // always lower the varying inputs as `__constref` parameters so we can pass pointers to + // these intrinsics. + // This means that although these parameters are declared as "in" parameters in the source, + // we will actually treat them as __constref parameters when lowering to IR. A complication + // result from this is that if the original source code actually modifies the input + // parameter we still need to create a local var to hold the modified value. In the future + // when we are able to update our language spec to always assume input parameters are + // immutable, then we can remove this adhoc logic of introducing temporary variables. For + // For now we will rely on a follow up pass to remove unnecessary temporary variables if + // we can determine that they are never actually writtten to by the user. + // + bool lowerVaryingInputAsConstRef = declRef.getDecl()->hasModifier(); + // Don't collect parameters from the outer scope if // we are in a `static` context. if (mode == kParameterListCollectMode_Default) { for (auto paramDeclRef : getParameters(context->astBuilder, callableDeclRef)) { - ioParameterLists->params.add(getParameterInfo(context, paramDeclRef)); + auto paramInfo = getParameterInfo(context, paramDeclRef); + if (lowerVaryingInputAsConstRef) + makeVaryingInputParamConstRef(paramInfo); + ioParameterLists->params.add(paramInfo); } maybeAddReturnDestinationParam( ioParameterLists, @@ -5623,9 +5659,7 @@ struct RValueExprLoweringVisitor : public ExprLoweringVisitorBaseinnerExpr); - auto builder = getBuilder(); - auto irLoad = builder->emitLoad(inner.val); - return LoweredValInfo::simple(irLoad); + return LoweredValInfo::ptr(inner.val); } }; @@ -9980,6 +10014,22 @@ struct DeclLoweringVisitor : DeclVisitor if (paramInfo.isReturnDestination) subContext->returnDestination = paramVal; + if (paramInfo.declaredDirection == kParameterDirection_In && + paramInfo.direction == kParameterDirection_ConstRef) + { + // If the parameter is originally declared as "in", but we are + // lowering it as constref for any reason (e.g. it is a varying input), + // then we need to emit a local variable to hold the original value, so + // that we can still generate correct code when the user trys to mutate + // the variable. + // The local variable introduced here is cleaned up by the SSA pass, if + // we can determine that there are no actual writes into the local var. + auto irLocal = + subBuilder->emitVar(tryGetPointedToType(subBuilder, irParamType)); + auto localVal = LoweredValInfo::ptr(irLocal); + assign(subContext, localVal, paramVal); + paramVal = localVal; + } // TODO: We might want to copy the pointed-to value into // a temporary at the start of the function, and then copy // back out at the end, so that we don't have to worry @@ -10999,6 +11049,16 @@ static void lowerFrontEndEntryPointToIR( auto entryPointFuncDecl = entryPoint->getFuncDecl(); + if (!entryPointFuncDecl->findModifier()) + { + // If the entry point doesn't have an explicit `[shader("...")]` attribute, + // then we make sure to add one here, so the lowering logic knows it is an + // entry point. + auto entryPointAttr = context->astBuilder->create(); + entryPointAttr->capabilitySet = entryPoint->getProfile().getCapabilityName(); + addModifier(entryPointFuncDecl, entryPointAttr); + } + auto builder = context->irBuilder; builder->setInsertInto(builder->getModule()->getModuleInst()); diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index 22491c848b..c275a868b5 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -7531,12 +7531,6 @@ static IRFloatingPointValue _foldFloatPrefixOp(TokenType tokenType, IRFloatingPo static std::optional parseSPIRVAsmOperand(Parser* parser) { - const auto slangIdentOperand = [&](auto flavor) - { - auto token = parser->tokenReader.peekToken(); - return SPIRVAsmOperand{flavor, token, parseAtomicExpr(parser)}; - }; - const auto slangTypeExprOperand = [&](auto flavor) { auto tok = parser->tokenReader.peekToken(); @@ -7673,12 +7667,13 @@ static std::optional parseSPIRVAsmOperand(Parser* parser) // A &foo variable reference (for the address of foo) else if (AdvanceIf(parser, TokenType::OpBitAnd)) { - return slangIdentOperand(SPIRVAsmOperand::SlangValueAddr); + Expr* expr = parsePostfixExpr(parser); + return SPIRVAsmOperand{SPIRVAsmOperand::SlangValueAddr, Token{}, expr}; } // A $foo variable else if (AdvanceIf(parser, TokenType::Dollar)) { - Expr* expr = parseAtomicExpr(parser); + Expr* expr = parsePostfixExpr(parser); return SPIRVAsmOperand{SPIRVAsmOperand::SlangValue, Token{}, expr}; } // A $$foo type diff --git a/source/slang/slang-spirv-val.cpp b/source/slang/slang-spirv-val.cpp deleted file mode 100644 index e2b4da46c8..0000000000 --- a/source/slang/slang-spirv-val.cpp +++ /dev/null @@ -1,41 +0,0 @@ -#include "slang-spirv-val.h" - -namespace Slang -{ - -SlangResult disassembleSPIRV(const List& spirv, String& outErr, String& outDis) -{ - // Set up our process - CommandLine commandLine; - commandLine.m_executableLocation.setName("spirv-dis"); - commandLine.addArg("--comment"); - commandLine.addArg("--color"); - RefPtr p; - - // If we failed to even start the process, then validation isn't available - SLANG_RETURN_ON_FAIL(Process::create(commandLine, 0, p)); - const auto in = p->getStream(StdStreamType::In); - const auto out = p->getStream(StdStreamType::Out); - const auto err = p->getStream(StdStreamType::ErrorOut); - - List outData; - List outErrData; - SLANG_RETURN_ON_FAIL( - StreamUtil::readAndWrite(in, spirv.getArrayView(), out, outData, err, outErrData)); - - SLANG_RETURN_ON_FAIL(p->waitForTermination(10)); - - outDis = String( - reinterpret_cast(outData.begin()), - reinterpret_cast(outData.end())); - - outErr = String( - reinterpret_cast(outErrData.begin()), - reinterpret_cast(outErrData.end())); - - const auto ret = p->getReturnValue(); - return ret == 0 ? SLANG_OK : SLANG_FAIL; -} - - -} // namespace Slang diff --git a/source/slang/slang-spirv-val.h b/source/slang/slang-spirv-val.h deleted file mode 100644 index 01e111f91a..0000000000 --- a/source/slang/slang-spirv-val.h +++ /dev/null @@ -1,10 +0,0 @@ -#pragma once - -#include "slang-compiler.h" - -#include - -namespace Slang -{ -SlangResult disassembleSPIRV(const List& spirv, String& outErr, String& outDis); -} diff --git a/tests/bugs/gh-5776.slang b/tests/bugs/gh-5776.slang new file mode 100644 index 0000000000..625a7b5ccb --- /dev/null +++ b/tests/bugs/gh-5776.slang @@ -0,0 +1,86 @@ +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -dx12 -profile sm_6_0 -use-dxil -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -cuda -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -cpu -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -wgpu -output-using-type + + +interface IFoo +{ + associatedtype FooType : IFoo; +} + +extension float : IFoo +{ + typedef float FooType; +} + +__generic +extension Array : IFoo +{ + typedef Array FooType; +} + +__generic +extension vector : IFoo +{ + typedef vector FooType; +} + +__generic +extension matrix : IFoo +{ + typedef matrix FooType; +} + +struct WrappedBuffer +{ + StructuredBuffer buffer; + int shape; + + T get(int idx) { return buffer[idx]; } +} + + +struct GradInBuffer +{ + WrappedBuffer wrapBuffer; +} + +struct CallData +{ + GradInBuffer grad_in1; + GradInBuffer> grad_in2; + GradInBuffer grad_in3; +} + + +//TEST_INPUT: set call_data.grad_in1.wrapBuffer.buffer = ubuffer(data=[1.0 2.0 3.0 4.0], stride=4); +//TEST_INPUT: set call_data.grad_in2.wrapBuffer.buffer = ubuffer(data=[5.0 6.0 7.0 8.0], stride=4); +//TEST_INPUT: set call_data.grad_in3.wrapBuffer.buffer = ubuffer(data=[1.0 2.0 3.0 4.0 5.0 6.0 7.0 8.0], stride=4); +ParameterBlock call_data; + + +//TEST_INPUT:ubuffer(data=[0.0 0.0 0.0 0.0 0.0 0.0 0.0], stride=4):out, name outputBuffer +RWStructuredBuffer outputBuffer; + + +[shader("compute")] +[numthreads(1, 1, 1)] +void computeMain() +{ + float[2] data1 = call_data.grad_in1.wrapBuffer.buffer[0]; + float[2] data2 = call_data.grad_in1.wrapBuffer.get(1); + outputBuffer[0] = data1[0]; + outputBuffer[1] = data2[0]; + + vector data3 = call_data.grad_in2.wrapBuffer.buffer[0]; + vector data4 = call_data.grad_in2.wrapBuffer.get(1); + outputBuffer[2] = data3[0]; + outputBuffer[3] = data4[0]; + + float2x2 data5 = call_data.grad_in3.wrapBuffer.buffer[0]; + float2x2 data6 = call_data.grad_in3.wrapBuffer.get(1); + outputBuffer[4] = data5[0][0]; + outputBuffer[5] = data6[0][0]; +} diff --git a/tests/bugs/gh-5776.slang.expected.txt b/tests/bugs/gh-5776.slang.expected.txt new file mode 100644 index 0000000000..ffde6889ee --- /dev/null +++ b/tests/bugs/gh-5776.slang.expected.txt @@ -0,0 +1,7 @@ +type: float +1.000000 +3.000000 +5.000000 +7.000000 +1.000000 +5.000000 diff --git a/tests/bugs/gh-841.slang b/tests/bugs/gh-841.slang index ba746984b8..5f7e0c81fe 100644 --- a/tests/bugs/gh-841.slang +++ b/tests/bugs/gh-841.slang @@ -11,8 +11,8 @@ struct RasterVertex float4 c : COLOR; // Make sure that the input value in location 1 is decorated as Flat - // SPV-DAG: [[VAL:%[_A-Za-z0-9]+]] = OpVariable {{.*}} Input - // SPV-DAG: OpDecorate [[VAL]] Location 1 + // SPV-DAG: OpDecorate [[VAL:%[_A-Za-z0-9]+]] Location 1 + // SPV-DAG: [[VAL]] = OpVariable {{.*}} Input // SPV-DAG: OpDecorate [[VAL]] Flat // // Likewise for GLSL diff --git a/tests/bugs/vk-structured-buffer-load.hlsl b/tests/bugs/vk-structured-buffer-load.hlsl index d9e54d9253..ac8a86a5c1 100644 --- a/tests/bugs/vk-structured-buffer-load.hlsl +++ b/tests/bugs/vk-structured-buffer-load.hlsl @@ -1,4 +1,9 @@ //TEST:CROSS_COMPILE: -profile glsl_460+GL_NV_ray_tracing -entry HitMain -stage closesthit -target spirv-assembly +//TEST:SIMPLE(filecheck=DXIL): -target dxil -entry HitMain -stage closesthit -profile sm_6_5 +//TEST:SIMPLE(filecheck=SPV): -target spirv + +// DXIL: define void @ +// SPV: OpEntryPoint #define USE_RCP 0 diff --git a/tests/cross-compile/glsl-generic-in.slang b/tests/cross-compile/glsl-generic-in.slang index a743c32cb1..6bf2d28fb8 100644 --- a/tests/cross-compile/glsl-generic-in.slang +++ b/tests/cross-compile/glsl-generic-in.slang @@ -1,9 +1,9 @@ //TEST:SIMPLE(filecheck=CHECK): -target spirv-assembly -entry main -profile vs_5_0 -emit-spirv-directly //TEST:SIMPLE(filecheck=CHECK): -target spirv-assembly -entry main -profile vs_5_0 -emit-spirv-via-glsl -// CHECK: vIn_field_v0{{.*}} = OpVariable %_ptr_Input_v4float Input -// CHECK: %vIn_field_v1{{.*}}= OpVariable %_ptr_Input_v2float Input -// CHECK: %vIn_p0{{.*}}= OpVariable %_ptr_Input_v3float Input +// CHECK-DAG: vIn_field_v0{{.*}} = OpVariable %_ptr_Input_v4float Input +// CHECK-DAG: %vIn_field_v1{{.*}}= OpVariable %_ptr_Input_v2float Input +// CHECK-DAG: %vIn_p0{{.*}}= OpVariable %_ptr_Input_v3float Input interface IField { diff --git a/tests/diagnostics/illegal-func-decl.slang b/tests/diagnostics/illegal-func-decl.slang index 0ec73dc277..64199a7913 100644 --- a/tests/diagnostics/illegal-func-decl.slang +++ b/tests/diagnostics/illegal-func-decl.slang @@ -6,9 +6,9 @@ //TEST:COMPILE: tests/diagnostics/illegal-func-decl-module.slang -o tests/diagnostics/illegal-func-decl-module.slang-module //DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK1): -r tests/diagnostics/illegal-func-decl-module.slang-module -DTEST1 -target spirv -o illegal-func-decl.spv -//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK2): -r tests/diagnostics/illegal-func-decl-module.slang-module -DTEST2 -target spirv -o illegal-func-decl.spv +//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK2): -r tests/diagnostics/illegal-func-decl-module.slang-module -DTEST2 -target spirv -o illegal-func-decl.spv -skip-spirv-validation //DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK3): -r tests/diagnostics/illegal-func-decl-module.slang-module -DTEST3 -target spirv -o illegal-func-decl.spv -//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK4): -r tests/diagnostics/illegal-func-decl-module.slang-module -DTEST4 -target spirv -o illegal-func-decl.spv +//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK4): -r tests/diagnostics/illegal-func-decl-module.slang-module -DTEST4 -target spirv -o illegal-func-decl.spv -skip-spirv-validation #ifdef TEST1 // CHECK1: ([[# @LINE+1]]): error 45001: unresolved external symbol 'libraryFunction'. diff --git a/tests/expected-example-failure-github.txt b/tests/expected-example-failure-github.txt index ee3ad2fa12..7801ee564d 100644 --- a/tests/expected-example-failure-github.txt +++ b/tests/expected-example-failure-github.txt @@ -11,6 +11,5 @@ macos:aarch64:(debug|release):hello-world # See issue 5520 macos:aarch64:(debug|release):model-viewer # See issue 5520 macos:aarch64:(debug|release):ray-tracing # See issue 5520 macos:aarch64:(debug|release):ray-tracing-pipeline # See issue 5520 -windows:x86_64:debug:hello-world # See issue 5520 -windows:x86_64:debug:ray-tracing # See issue 5520 -windows:x86_64:debug:ray-tracing-pipeline # See issue 5520 +windows:x86_64:debug:ray-tracing # See issue 5988 +windows:x86_64:debug:ray-tracing-pipeline # See issue 5988 diff --git a/tests/glsl-intrinsic/fragment-processing/fragment-processing.slang b/tests/glsl-intrinsic/fragment-processing/fragment-processing.slang index 909679bbe4..c69752cd0d 100644 --- a/tests/glsl-intrinsic/fragment-processing/fragment-processing.slang +++ b/tests/glsl-intrinsic/fragment-processing/fragment-processing.slang @@ -76,8 +76,8 @@ bool testFragmentProcessingDerivativeFunctionsVector() } bool testFragmentProcessingInterpolateFunctions() { -// CHECK_SPV: {{.*}} = OpExtInst {{.*}} {{.*}} InterpolateAtCentroid %inDataV1 -// CHECK_GLSL: interpolateAtCentroid{{.*}}inDataV1 +// CHECK_SPV-DAG: {{.*}} = OpExtInst {{.*}} {{.*}} InterpolateAtCentroid %inDataV1 +// CHECK_GLSL-DAG: interpolateAtCentroid{{.*}}inDataV1 // CHECK_SPV: {{.*}} = OpExtInst {{.*}} {{.*}} InterpolateAtSample %inDataV1 {{.*}} // CHECK_GLSL: interpolateAtSample{{.*}}inDataV1 // CHECK_SPV: {{.*}} = OpExtInst {{.*}} {{.*}} InterpolateAtOffset %inDataV1 {{.*}} diff --git a/tests/glsl/matrix-mul.slang b/tests/glsl/matrix-mul.slang index 156673b87c..3bdd1cb8de 100644 --- a/tests/glsl/matrix-mul.slang +++ b/tests/glsl/matrix-mul.slang @@ -1,6 +1,6 @@ //TEST:SIMPLE(filecheck=SPIRV): -target spirv -stage vertex -entry main -allow-glsl -emit-spirv-directly //TEST:SIMPLE(filecheck=SPIRV): -target spirv -stage vertex -entry main -allow-glsl -//TEST:SIMPLE(filecheck=METAL): -target metal -stage vertex -entry main -allow-glsl +//TEST:SIMPLE(filecheck=METAL): -target metal -stage vertex -entry main -allow-glsl -matrix-layout-row-major #version 310 es layout(location = 0) in highp vec4 a_position; diff --git a/tests/hlsl-intrinsic/fragment-interpolate.slang b/tests/hlsl-intrinsic/fragment-interpolate.slang new file mode 100644 index 0000000000..f64e4e13b9 --- /dev/null +++ b/tests/hlsl-intrinsic/fragment-interpolate.slang @@ -0,0 +1,17 @@ +//TEST:SIMPLE(filecheck=CHECK_HLSL): -target hlsl -stage fragment -entry main +//TEST:SIMPLE(filecheck=CHECK_SPV): -target spirv -emit-spirv-directly -stage fragment -entry main + +struct VertexOut +{ + float4 pos : SV_Position; + float3 color; +} + +// CHECK_SPV: %v_color = OpVariable %_ptr_Input_v3float Input +// CHECK_SPV: %{{.*}} = OpExtInst %v3float %{{.*}} InterpolateAtCentroid %v_color +// CHECK_HLSL: EvaluateAttributeAtCentroid(v_0.color_0) + +float4 main(VertexOut v) : SV_Target +{ + return float4(EvaluateAttributeAtCentroid(v.color), 1.0); +} \ No newline at end of file diff --git a/tests/language-feature/capability/capability3.slang b/tests/language-feature/capability/capability3.slang index 67099a1dad..4091b9c936 100644 --- a/tests/language-feature/capability/capability3.slang +++ b/tests/language-feature/capability/capability3.slang @@ -1,5 +1,5 @@ //TEST:SIMPLE(filecheck=CHECK): -target spirv -emit-spirv-directly -entry main -stage compute -//TEST:SIMPLE(filecheck=CHECK_IGNORE_CAPS): -target spirv -emit-spirv-directly -entry main -stage compute -ignore-capabilities +//TEST:SIMPLE(filecheck=CHECK_IGNORE_CAPS): -target spirv -emit-spirv-directly -entry main -stage compute -ignore-capabilities -skip-spirv-validation // CHECK_IGNORE_CAPS-NOT: error 36108 // Test that capabilities can be declared on module. @@ -39,4 +39,4 @@ void main() { use1(); use2(); -} \ No newline at end of file +} diff --git a/tests/language-feature/capability/capability7.slang b/tests/language-feature/capability/capability7.slang index 21f3d68e41..011112a34a 100644 --- a/tests/language-feature/capability/capability7.slang +++ b/tests/language-feature/capability/capability7.slang @@ -1,5 +1,5 @@ //TEST:SIMPLE(filecheck=CHECK): -target glsl -entry computeMain -stage compute -profile sm_5_0 -//TEST:SIMPLE(filecheck=CHECK_IGNORE_CAPS): -target glsl -emit-spirv-directly -entry computeMain -stage compute -profile sm_5_0 -ignore-capabilities +//TEST:SIMPLE(filecheck=CHECK_IGNORE_CAPS): -target glsl -emit-spirv-directly -entry computeMain -stage compute -profile sm_5_0 -ignore-capabilities -skip-spirv-validation // Test that we diagnose simplified capabilities // CHECK_IGNORE_CAPS-NOT: error 36104 diff --git a/tests/language-feature/capability/conflicting-profile-stage-for-entry-point.slang b/tests/language-feature/capability/conflicting-profile-stage-for-entry-point.slang index 9cc06347f7..a888e7cf96 100644 --- a/tests/language-feature/capability/conflicting-profile-stage-for-entry-point.slang +++ b/tests/language-feature/capability/conflicting-profile-stage-for-entry-point.slang @@ -3,7 +3,7 @@ //TEST:SIMPLE(filecheck=CHECK): -target spirv -entry vsmain -profile vs_6_0 //TEST:SIMPLE(filecheck=CHECK): -target spirv -entry psmain -profile vs_6_0 -ignore-capabilities -//TEST:SIMPLE(filecheck=CHECK): -target spirv -entry vsmain -profile ps_6_0 -ignore-capabilities +//TEST:SIMPLE(filecheck=CHECK): -target spirv -entry vsmain -profile ps_6_0 -ignore-capabilities -skip-spirv-validation // CHECK_ERROR: warning 36112 // CHECK-NOT: warning 36112 @@ -33,4 +33,4 @@ VSOutput vsmain(float3 PositionOS : POSITION, float3 Color : COLOR0) float4 psmain(VSOutput input) : SV_TARGET { return float4(input.Color, 1); -} \ No newline at end of file +} diff --git a/tests/pipeline/rasterization/get-attribute-at-vertex-nv.slang b/tests/pipeline/rasterization/get-attribute-at-vertex-nv.slang deleted file mode 100644 index d7bdbc69c9..0000000000 --- a/tests/pipeline/rasterization/get-attribute-at-vertex-nv.slang +++ /dev/null @@ -1,17 +0,0 @@ -// get-attribute-at-vertex.slang - -// Basic test for `GetAttributeAtVertex` function - -//TEST:CROSS_COMPILE:-target dxil -capability GL_NV_fragment_shader_barycentric -entry main -stage fragment -profile sm_6_1 -//TEST:CROSS_COMPILE:-target spirv -capability GL_NV_fragment_shader_barycentric -entry main -stage fragment -profile glsl_450 - -[shader("fragment")] -void main( - pervertex float4 color : COLOR, - float3 bary : SV_Barycentrics, - out float4 result : SV_Target) -{ - result = bary.x * GetAttributeAtVertex(color, 0) - + bary.y * GetAttributeAtVertex(color, 1) - + bary.z * GetAttributeAtVertex(color, 2); -} diff --git a/tests/pipeline/rasterization/get-attribute-at-vertex-nv.slang.glsl b/tests/pipeline/rasterization/get-attribute-at-vertex-nv.slang.glsl deleted file mode 100644 index 820918d8bc..0000000000 --- a/tests/pipeline/rasterization/get-attribute-at-vertex-nv.slang.glsl +++ /dev/null @@ -1,20 +0,0 @@ -// get-attribute-at-vertex.slang.glsl -//TEST_IGNORE_FILE: - -#version 450 -#extension GL_EXT_fragment_shader_barycentric : require -layout(row_major) uniform; -layout(row_major) buffer; - -pervertexEXT layout(location = 0) -in vec4 color_0[3]; - -layout(location = 0) -out vec4 result_0; - -void main() -{ - result_0 = gl_BaryCoordEXT.x * ((color_0)[(0U)]) + gl_BaryCoordEXT.y * ((color_0)[(1U)]) + gl_BaryCoordEXT.z * ((color_0)[(2U)]); - return; -} - diff --git a/tests/pipeline/rasterization/get-attribute-at-vertex-nv.slang.hlsl b/tests/pipeline/rasterization/get-attribute-at-vertex-nv.slang.hlsl deleted file mode 100644 index a6b45eab4f..0000000000 --- a/tests/pipeline/rasterization/get-attribute-at-vertex-nv.slang.hlsl +++ /dev/null @@ -1,16 +0,0 @@ -// get-attribute-at-vertex.slang.hlsl - -//TEST_IGNORE_FILE: - -#pragma warning(disable: 3557) - -[shader("pixel")] -void main( - nointerpolation vector color_0 : COLOR, - vector bary_0 : SV_BARYCENTRICS, - out vector result_0 : SV_TARGET) -{ - result_0 = bary_0.x * GetAttributeAtVertex(color_0, 0U) - + bary_0.y * GetAttributeAtVertex(color_0, 1U) - + bary_0.z * GetAttributeAtVertex(color_0, 2U); -} diff --git a/tests/pipeline/rasterization/get-attribute-at-vertex.slang b/tests/pipeline/rasterization/get-attribute-at-vertex.slang index 9ae347a3a3..c334200fb7 100644 --- a/tests/pipeline/rasterization/get-attribute-at-vertex.slang +++ b/tests/pipeline/rasterization/get-attribute-at-vertex.slang @@ -2,8 +2,6 @@ // Basic test for `GetAttributeAtVertex` function -//TEST:CROSS_COMPILE:-target dxil -entry main -stage fragment -profile sm_6_1 -//TEST:CROSS_COMPILE:-target spirv -entry main -stage fragment -profile glsl_450+GL_EXT_fragment_shader_barycentric //TEST:SIMPLE(filecheck=CHECK):-emit-spirv-directly -target spirv -entry main -stage fragment -profile glsl_450+GL_EXT_fragment_shader_barycentric // CHECK: OpCapability FragmentBarycentricKHR diff --git a/tests/pipeline/rasterization/varying-to-inout.slang b/tests/pipeline/rasterization/varying-to-inout.slang new file mode 100644 index 0000000000..7a54fd82fe --- /dev/null +++ b/tests/pipeline/rasterization/varying-to-inout.slang @@ -0,0 +1,22 @@ +// Test passing a varying parameter direclty to an inout parameter. + +//TEST:SIMPLE(filecheck=CHECK):-target spirv -entry main -stage fragment + +// CHECK: OpEntryPoint Fragment %main "main" +struct PS_IN +{ + float3 pos : SV_Position; + float4 color : COLOR; +} + +void test(inout PS_IN v) +{ + v.color = v.color + v.pos.x; +} + +[shader("fragment")] +float4 main(PS_IN psIn):SV_Target +{ + test(psIn); + return psIn.color; +} diff --git a/tests/spirv/array-uniform-param.slang b/tests/spirv/array-uniform-param.slang index 235e85bbd0..672543b9a5 100644 --- a/tests/spirv/array-uniform-param.slang +++ b/tests/spirv/array-uniform-param.slang @@ -1,8 +1,10 @@ // array-uniform-param.slang -//TESTD:SIMPLE:-target spirv -entry computeMain -stage compute -emit-spirv-directly -force-glsl-scalar-layout +//TEST:SIMPLE(filecheck=CHECK):-target spirv -entry computeMain -stage compute -emit-spirv-directly -force-glsl-scalar-layout //TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-vk -compute -output-using-type +// CHECK: OpEntryPoint + // Test direct SPIR-V emit on arrays in uniforms. //TEST_INPUT:set resultBuffer = out ubuffer(data=[0 0 0 0], stride=4) diff --git a/tests/spirv/matrix-vertex-input.slang b/tests/spirv/matrix-vertex-input.slang index fc4af8c614..b6277bead6 100644 --- a/tests/spirv/matrix-vertex-input.slang +++ b/tests/spirv/matrix-vertex-input.slang @@ -1,23 +1,84 @@ -//TEST:SIMPLE(filecheck=CHECK): -target spirv -// CHECK: OpVectorTimesMatrix +//TEST(compute):COMPARE_RENDER_COMPUTE(filecheck-buffer=ROWMAJOR): -vk -output-using-type +//TEST(compute):COMPARE_RENDER_COMPUTE(filecheck-buffer=ROWMAJOR): -d3d11 -output-using-type -struct Vertex +//TEST(compute):COMPARE_RENDER_COMPUTE(filecheck-buffer=COLMAJOR): -vk -output-using-type -emit-spirv-directly -xslang -DCOLUMN_MAJOR +//TEST(compute):COMPARE_RENDER_COMPUTE(filecheck-buffer=COLMAJOR): -d3d11 -output-using-type -xslang -DCOLUMN_MAJOR + +// Check that row_major and column_major matrix typed vertex input are correctly handled. + +//TEST_INPUT: Texture2D(size=4, content = one):name t +//TEST_INPUT: Sampler:name s +//TEST_INPUT: ubuffer(data=[0], stride=4):out, name outputBuffer + +Texture2D t; +SamplerState s; +RWStructuredBuffer outputBuffer; + +cbuffer Uniforms { - float4x4 m; - float4 pos; + float4x4 modelViewProjection; } -struct VertexOut +struct AssembledVertex +{ + float3 position; + float3 color; + float2 uv; +#ifdef COLUMN_MAJOR + column_major float4x4 m; +#else + row_major float4x4 m; +#endif +}; + +struct CoarseVertex +{ + float3 color; +}; + +struct Fragment +{ + float4 color; +}; + +// Vertex Shader + +struct VertexStageInput +{ + AssembledVertex assembledVertex : A; +}; + +struct VertexStageOutput +{ + CoarseVertex coarseVertex : CoarseVertex; + float4 sv_position : SV_Position; +}; + +VertexStageOutput vertexMain(VertexStageInput input) { - float4 pos : SV_Position; - float4 color; + VertexStageOutput output; + output.coarseVertex.color = input.assembledVertex.m[1][2]; + output.sv_position = mul(modelViewProjection, float4(input.assembledVertex.position, 1.0)); + return output; } -[shader("vertex")] -VertexOut vertMain(Vertex v) +struct FragmentStageInput { - VertexOut o; - o.pos = mul(v.m, v.pos); - o.color = v.pos; - return o; -} \ No newline at end of file + CoarseVertex coarseVertex : CoarseVertex; +}; + +struct FragmentStageOutput +{ + Fragment fragment : SV_Target; +}; + +FragmentStageOutput fragmentMain(FragmentStageInput input) +{ + FragmentStageOutput output; + float3 color = input.coarseVertex.color; + output.fragment.color = float4(color, 1.0); + outputBuffer[0] = color.x; + // ROWMAJOR: 7.0 + // COLMAJOR: 10.0 + return output; +} diff --git a/tests/spirv/nested-entrypoint.slang b/tests/spirv/nested-entrypoint.slang new file mode 100644 index 0000000000..28e9b9c4a4 --- /dev/null +++ b/tests/spirv/nested-entrypoint.slang @@ -0,0 +1,19 @@ +//TEST:SIMPLE(filecheck=CHECK): -target spirv -fvk-use-entrypoint-name + +// CHECK: OpEntryPoint + +RWStructuredBuffer output; + +[shader("compute")] +[numthreads(1,1,1)] +void innerMain(int id : SV_DispatchThreadID) +{ + output[id] = id; +} + +[shader("compute")] +[numthreads(1,1,1)] +void outerMain(int id : SV_DispatchThreadID) +{ + innerMain(id); +} \ No newline at end of file diff --git a/tests/spirv/optional-vertex-output.slang b/tests/spirv/optional-vertex-output.slang index df15befa22..7baf02d0b6 100644 --- a/tests/spirv/optional-vertex-output.slang +++ b/tests/spirv/optional-vertex-output.slang @@ -20,7 +20,10 @@ struct VSOut { VSOut vertMain(VIn i) { VSOut o; - o.a = i.inA; + if (i.inA.hasValue) + o.a = i.inA; + else + o.a = 0.0; o.outputValues = { true, false, true }; return o; } \ No newline at end of file diff --git a/tests/vkray/anyhit.slang b/tests/vkray/anyhit.slang index 45d35b1fae..8f5a6e597f 100644 --- a/tests/vkray/anyhit.slang +++ b/tests/vkray/anyhit.slang @@ -57,7 +57,6 @@ void main( // SPIRV: OpEntryPoint // SPIRV: BuiltIn HitTriangleVertexPositionsKHR // SPIRV: OpTypePointer HitAttribute{{NV|KHR}} -// SPIRV: OpTypePointer HitAttribute{{NV|KHR}} // SPIRV: OpVariable{{.*}}HitAttribute{{NV|KHR}} // SPIRV: OpIgnoreIntersectionKHR // SPIRV: OpTerminateRayKHR @@ -70,7 +69,6 @@ void main( // GL_SPIRV: OpEntryPoint // GL_SPIRV: BuiltIn HitTriangleVertexPositionsKHR // GL_SPIRV-DAG: OpTypePointer HitAttribute{{NV|KHR}} -// GL_SPIRV-DAG: OpTypePointer HitAttribute{{NV|KHR}} // GL_SPIRV: OpTerminateRayKHR // GL_SPIRV: OpIgnoreIntersectionKHR // GL_SPIRV-DAG: %{{.*}} = OpAccessChain %{{.*}} %{{.*}} %{{.*}} diff --git a/tests/vkray/anyhit.slang.glsl b/tests/vkray/anyhit.slang.glsl index 8255599b93..4d2e5a0dd7 100644 --- a/tests/vkray/anyhit.slang.glsl +++ b/tests/vkray/anyhit.slang.glsl @@ -8,7 +8,7 @@ struct Params_0 }; layout(binding = 0) -layout(std140) uniform _S1 +layout(std140) uniform block_Params_0 { int mode_0; }gParams_0; @@ -23,20 +23,21 @@ struct SphereHitAttributes_0 vec3 normal_0; }; -hitAttributeEXT SphereHitAttributes_0 _S2; +hitAttributeEXT SphereHitAttributes_0 _S1; struct ShadowRay_0 { vec4 hitDistance_0; + vec3 dummyOut_0; }; -rayPayloadInEXT ShadowRay_0 _S3; +rayPayloadInEXT ShadowRay_0 _S2; void main() { if(gParams_0.mode_0 != 0) { - if((textureLod(sampler2D(gParams_alphaMap_0,gParams_sampler_0), (_S2.normal_0.xy), (0.0)).x) > 0.0) + if((textureLod(sampler2D(gParams_alphaMap_0,gParams_sampler_0), (_S1.normal_0.xy), (0.0)).x) > 0.0) { terminateRayEXT;; } @@ -45,6 +46,14 @@ void main() ignoreIntersectionEXT;; } } + + vec3 _S3 = (gl_HitTriangleVertexPositionsEXT[(0U)]); + _S2.dummyOut_0 = _S3; + vec3 _S4 = (gl_HitTriangleVertexPositionsEXT[(1U)]); + vec3 _S5 = _S3 + _S4; + _S2.dummyOut_0 = _S5; + vec3 _S6 = (gl_HitTriangleVertexPositionsEXT[(2U)]); + _S2.dummyOut_0 = _S5 + _S6; return; } diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index dfa19d42ff..1f2c3033ab 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -235,6 +235,8 @@ if(SLANG_ENABLE_TESTS) slang-llvm copy-webgpu_dawn copy-slang-tint + # See issue 5305. + copy-prebuilt-binaries FOLDER test DEBUG_DIR ${slang_SOURCE_DIR} ) diff --git a/tools/gfx/d3d12/d3d12-device.cpp b/tools/gfx/d3d12/d3d12-device.cpp index cb54049611..35192fd7d9 100644 --- a/tools/gfx/d3d12/d3d12-device.cpp +++ b/tools/gfx/d3d12/d3d12-device.cpp @@ -2231,6 +2231,7 @@ Result DeviceImpl::createAccelerationStructure( IAccelerationStructure** outAS) { #if SLANG_GFX_HAS_DXR_SUPPORT + assert(desc.buffer != nullptr); RefPtr result = new AccelerationStructureImpl(); result->m_device5 = m_device5; result->m_buffer = static_cast(desc.buffer); diff --git a/tools/platform/window.h b/tools/platform/window.h index 4ff9e245f6..654f0daab4 100644 --- a/tools/platform/window.h +++ b/tools/platform/window.h @@ -237,33 +237,19 @@ class Application #define GFX_DUMP_LEAK _CrtDumpMemoryLeaks(); #endif #endif + +#endif + #ifndef GFX_DUMP_LEAK #define GFX_DUMP_LEAK #endif -#define PLATFORM_UI_MAIN(APPLICATION_ENTRY) \ - int __stdcall wWinMain( \ - void* /*instance*/, \ - void* /* prevInstance */, \ - void* /* commandLine */, \ - int /*showCommand*/ \ - ) \ - { \ - platform::Application::init(); \ - auto result = APPLICATION_ENTRY(0, nullptr); \ - platform::Application::dispose(); \ - GFX_DUMP_LEAK \ - return result; \ - } - -#else #define PLATFORM_UI_MAIN(APPLICATION_ENTRY) \ - int main(int argc, char** argv) \ + int exampleMain(int argc, char** argv) \ { \ platform::Application::init(); \ auto rs = APPLICATION_ENTRY(argc, argv); \ platform::Application::dispose(); \ + GFX_DUMP_LEAK \ return rs; \ } - -#endif diff --git a/tools/render-test/render-test-main.cpp b/tools/render-test/render-test-main.cpp index 5907be66d5..b1f9575512 100644 --- a/tools/render-test/render-test-main.cpp +++ b/tools/render-test/render-test-main.cpp @@ -45,12 +45,16 @@ struct Vertex float position[3]; float color[3]; float uv[2]; + float customData0[4]; + float customData1[4]; + float customData2[4]; + float customData3[4]; }; static const Vertex kVertexData[] = { - {{0, 0, 0.5}, {1, 0, 0}, {0, 0}}, - {{0, 1, 0.5}, {0, 0, 1}, {1, 0}}, - {{1, 0, 0.5}, {0, 1, 0}, {1, 1}}, + {{0, 0, 0.5}, {1, 0, 0}, {0, 0}, {1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}, {13, 14, 15, 16}}, + {{0, 1, 0.5}, {0, 0, 1}, {1, 0}, {1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}, {13, 14, 15, 16}}, + {{1, 0, 0.5}, {0, 1, 0}, {1, 1}, {1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}, {13, 14, 15, 16}}, }; static const int kVertexCount = SLANG_COUNT_OF(kVertexData); @@ -614,6 +618,10 @@ SlangResult RenderTestApp::initialize( {"A", 0, Format::R32G32B32_FLOAT, offsetof(Vertex, position)}, {"A", 1, Format::R32G32B32_FLOAT, offsetof(Vertex, color)}, {"A", 2, Format::R32G32_FLOAT, offsetof(Vertex, uv)}, + {"A", 3, Format::R32G32B32A32_FLOAT, offsetof(Vertex, customData0)}, + {"A", 4, Format::R32G32B32A32_FLOAT, offsetof(Vertex, customData1)}, + {"A", 5, Format::R32G32B32A32_FLOAT, offsetof(Vertex, customData2)}, + {"A", 6, Format::R32G32B32A32_FLOAT, offsetof(Vertex, customData3)}, }; ComPtr inputLayout; diff --git a/tools/slang-test/slang-test-main.cpp b/tools/slang-test/slang-test-main.cpp index 0987cbf804..a8127c3a82 100644 --- a/tools/slang-test/slang-test-main.cpp +++ b/tools/slang-test/slang-test-main.cpp @@ -47,6 +47,20 @@ using namespace Slang; +#if defined(_WIN32) +// https://devblogs.microsoft.com/directx/gettingstarted-dx12agility/#2.-set-agility-sdk-parameters + +extern "C" +{ + __declspec(dllexport) extern const uint32_t D3D12SDKVersion = 711; +} + +extern "C" +{ + __declspec(dllexport) extern const char* D3D12SDKPath = u8".\\D3D12\\"; +} +#endif + // Options for a particular test struct TestOptions {