From 87d22dc000aec1f62451d2da4120e5f3f0aa3248 Mon Sep 17 00:00:00 2001 From: kaizhangNV Date: Tue, 31 Dec 2024 15:09:29 -0800 Subject: [PATCH 1/6] Correct IR generation for no-diff pointer type There is an issue on checking whether a pointer type parameter is no_diff, we should first check whether this parameter is an Attribute type first, then check the data type. In the back-propagate pass, for the pointer type parameter, we should load this parameter to a temp variable, then pass it to the primal function call. Otherwise, the temp variable will no be initialized, which will cause the following calculation wrong. --- source/slang/slang-ir-autodiff-rev.cpp | 11 +++--- .../slang-ir-autodiff-transcriber-base.cpp | 3 ++ source/slang/slang-ir-autodiff.cpp | 6 +++ tests/autodiff/nodiff-ptr.slang | 37 +++++++++++++++++++ tests/autodiff/nodiff-ptr.slang.expected.txt | 9 +++++ 5 files changed, 61 insertions(+), 5 deletions(-) create mode 100644 tests/autodiff/nodiff-ptr.slang create mode 100644 tests/autodiff/nodiff-ptr.slang.expected.txt diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index f0ac428c7f..36093518ae 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -512,11 +512,12 @@ InstPair BackwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRF { // If primal parameter is mutable, we need to pass in a temp var. auto tempVar = builder.emitVar(primalParamPtrType->getValueType()); - if (primalParamPtrType->getOp() == kIROp_InOutType) - { - // If the primal parameter is inout, we need to set the initial value. - builder.emitStore(tempVar, primalArg); - } + + // We also need to setup the initial value of the temp var, otherwise + // the temp var will be uninitialized which could cause undefined behavior + // in the primal function. + builder.emitStore(tempVar, primalArg); + primalArgs.add(tempVar); } else diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index ada35689ca..1b3825a7d8 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -565,6 +565,9 @@ IRType* AutoDiffTranscriberBase::tryGetDiffPairType(IRBuilder* builder, IRType* // If this is a PtrType (out, inout, etc..), then create diff pair from // value type and re-apply the appropropriate PtrType wrapper. // + if (isNoDiffType(originalType)) + return nullptr; + if (auto origPtrType = as(originalType)) { if (auto diffPairValueType = tryGetDiffPairType(builder, origPtrType->getValueType())) diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 5c05b08117..530b431f35 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -126,6 +126,12 @@ static IRInst* _getDiffTypeWitnessFromPairType( bool isNoDiffType(IRType* paramType) { + if (auto attrType = as(paramType)) + { + if (attrType->findAttr()) + return true; + } + while (auto ptrType = as(paramType)) paramType = ptrType->getValueType(); while (auto attrType = as(paramType)) diff --git a/tests/autodiff/nodiff-ptr.slang b/tests/autodiff/nodiff-ptr.slang new file mode 100644 index 0000000000..aab6d5da8e --- /dev/null +++ b/tests/autodiff/nodiff-ptr.slang @@ -0,0 +1,37 @@ + +[Differentiable] +float sumOfSquares(float x, float y, no_diff float4* test) +{ + return x * x + y * y * (test->x + test->y + test->z); +} + +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[1.0 2.0 3.0 1.0 1.0 0.0 0.0 0.0], stride=4):out, name outputBuffer +RWStructuredBuffer outputBuffer; + +[shader("compute")] +[numthreads(1, 1, 1)] +void computeMain() +{ + float4* testPtr = &outputBuffer[0]; + + let result = sumOfSquares(2.0, 3.0, testPtr); + + // Use forward differentiation to compute the gradient of the output w.r.t. x only. + let diffX = fwd_diff(sumOfSquares)(diffPair(2.0, 1.0), diffPair(3.0, 0.0), testPtr); + + // Create a differentiable pair to pass in the primal value and to receive the gradient. + var dpX = diffPair(2.0); + var dpY = diffPair(3.0); + + // Propagate the gradient of the output (1.0f) to the input parameters. + bwd_diff(sumOfSquares)(dpX, dpY, testPtr, 1.0); + + outputBuffer[0].x = result; // 2^2 + 3^2 * (1 + 2 + 3) = 58 + outputBuffer[0].y = diffX.d; // 2*x * dx + 2*y * dy * (1 + 2 + 3) = 4 + outputBuffer[0].z = diffX.p; // 2^2 + 3^2 * (1 + 2 + 3) = 58 + outputBuffer[0].w = dpX.d; // 2*x = 4 + + outputBuffer[1].x = dpY.d; // 2*y * (1 + 2 +3) = 36 +} diff --git a/tests/autodiff/nodiff-ptr.slang.expected.txt b/tests/autodiff/nodiff-ptr.slang.expected.txt new file mode 100644 index 0000000000..53e1a28e68 --- /dev/null +++ b/tests/autodiff/nodiff-ptr.slang.expected.txt @@ -0,0 +1,9 @@ +type: float +58.000000 +4.000000 +58.000000 +4.000000 +36.000000 +0.000000 +0.000000 +0.000000 From 91b19433a566f134549845964e2e767adb234e13 Mon Sep 17 00:00:00 2001 From: kaizhangNV Date: Tue, 31 Dec 2024 18:49:53 -0800 Subject: [PATCH 2/6] address comment --- source/slang/slang-ir-autodiff.cpp | 9 +++++---- tests/autodiff/nodiff-ptr.slang | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 530b431f35..2c36ea721e 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -133,12 +133,13 @@ bool isNoDiffType(IRType* paramType) } while (auto ptrType = as(paramType)) - paramType = ptrType->getValueType(); - while (auto attrType = as(paramType)) { - if (attrType->findAttr()) + paramType = ptrType->getValueType(); + + if (auto attrType = as(paramType)) { - return true; + if (attrType->findAttr()) + return true; } } return false; diff --git a/tests/autodiff/nodiff-ptr.slang b/tests/autodiff/nodiff-ptr.slang index aab6d5da8e..73529c9397 100644 --- a/tests/autodiff/nodiff-ptr.slang +++ b/tests/autodiff/nodiff-ptr.slang @@ -5,7 +5,7 @@ float sumOfSquares(float x, float y, no_diff float4* test) return x * x + y * y * (test->x + test->y + test->z); } -//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type -compile-arg -skip-spirv-validation //TEST_INPUT:ubuffer(data=[1.0 2.0 3.0 1.0 1.0 0.0 0.0 0.0], stride=4):out, name outputBuffer RWStructuredBuffer outputBuffer; From ef8e4d457c8bfc532bd7b9ea565eb3921896c46a Mon Sep 17 00:00:00 2001 From: kaizhangNV Date: Thu, 2 Jan 2025 09:34:48 -0800 Subject: [PATCH 3/6] address comment --- source/slang/slang-ir-autodiff.cpp | 20 +++++++++++--------- tests/autodiff/nodiff-ptr.slang | 2 +- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 2c36ea721e..e699c2434f 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -126,20 +126,22 @@ static IRInst* _getDiffTypeWitnessFromPairType( bool isNoDiffType(IRType* paramType) { - if (auto attrType = as(paramType)) + for(;;) { - if (attrType->findAttr()) - return true; - } - - while (auto ptrType = as(paramType)) - { - paramType = ptrType->getValueType(); - if (auto attrType = as(paramType)) { if (attrType->findAttr()) return true; + + paramType = attrType->getBaseType(); + } + else if (auto ptrType = as(paramType)) + { + paramType = ptrType->getValueType(); + } + else + { + return false; } } return false; diff --git a/tests/autodiff/nodiff-ptr.slang b/tests/autodiff/nodiff-ptr.slang index 73529c9397..aa60394aca 100644 --- a/tests/autodiff/nodiff-ptr.slang +++ b/tests/autodiff/nodiff-ptr.slang @@ -5,7 +5,7 @@ float sumOfSquares(float x, float y, no_diff float4* test) return x * x + y * y * (test->x + test->y + test->z); } -//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type -compile-arg -skip-spirv-validation +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type -compile-arg -skip-spirv-validation -emit-spirv-directly //TEST_INPUT:ubuffer(data=[1.0 2.0 3.0 1.0 1.0 0.0 0.0 0.0], stride=4):out, name outputBuffer RWStructuredBuffer outputBuffer; From 90eb41a09a8d462a0189061410b31448acebd678 Mon Sep 17 00:00:00 2001 From: kaizhangNV Date: Thu, 2 Jan 2025 09:46:31 -0800 Subject: [PATCH 4/6] fix warning --- source/slang/slang-ir-autodiff.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index e699c2434f..766ce305c9 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -126,7 +126,7 @@ static IRInst* _getDiffTypeWitnessFromPairType( bool isNoDiffType(IRType* paramType) { - for(;;) + while(paramType) { if (auto attrType = as(paramType)) { From 447fa86e8610173650977b296cbae141fdc003af Mon Sep 17 00:00:00 2001 From: kaizhangNV Date: Thu, 2 Jan 2025 10:17:24 -0800 Subject: [PATCH 5/6] formatting --- source/slang/slang-ir-autodiff.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 766ce305c9..4edd8eabe6 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -126,7 +126,7 @@ static IRInst* _getDiffTypeWitnessFromPairType( bool isNoDiffType(IRType* paramType) { - while(paramType) + while (paramType) { if (auto attrType = as(paramType)) { From e46ebf3196591dbf3bbddf5f26b9cd786eaff6b2 Mon Sep 17 00:00:00 2001 From: kaizhangNV Date: Thu, 2 Jan 2025 11:35:34 -0800 Subject: [PATCH 6/6] correct the unit test --- tests/autodiff/nodiff-ptr.slang | 19 +++++++++++-------- tests/autodiff/nodiff-ptr.slang.expected.txt | 3 --- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/autodiff/nodiff-ptr.slang b/tests/autodiff/nodiff-ptr.slang index aa60394aca..d20abddacc 100644 --- a/tests/autodiff/nodiff-ptr.slang +++ b/tests/autodiff/nodiff-ptr.slang @@ -7,14 +7,17 @@ float sumOfSquares(float x, float y, no_diff float4* test) //TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type -compile-arg -skip-spirv-validation -emit-spirv-directly -//TEST_INPUT:ubuffer(data=[1.0 2.0 3.0 1.0 1.0 0.0 0.0 0.0], stride=4):out, name outputBuffer -RWStructuredBuffer outputBuffer; +//TEST_INPUT: set ptr = ubuffer(data=[1.0 2.0 3.0], stride=4) +uniform float* ptr; + +//TEST_INPUT:ubuffer(data=[0.0 0.0 0.0 0.0 0.0], stride=4):out, name outputBuffer +RWStructuredBuffer outputBuffer; [shader("compute")] [numthreads(1, 1, 1)] void computeMain() { - float4* testPtr = &outputBuffer[0]; + float4* testPtr = (float4*)ptr; let result = sumOfSquares(2.0, 3.0, testPtr); @@ -28,10 +31,10 @@ void computeMain() // Propagate the gradient of the output (1.0f) to the input parameters. bwd_diff(sumOfSquares)(dpX, dpY, testPtr, 1.0); - outputBuffer[0].x = result; // 2^2 + 3^2 * (1 + 2 + 3) = 58 - outputBuffer[0].y = diffX.d; // 2*x * dx + 2*y * dy * (1 + 2 + 3) = 4 - outputBuffer[0].z = diffX.p; // 2^2 + 3^2 * (1 + 2 + 3) = 58 - outputBuffer[0].w = dpX.d; // 2*x = 4 + outputBuffer[0] = result; // 2^2 + 3^2 * (1 + 2 + 3) = 58 + outputBuffer[1] = diffX.d; // 2*x * dx + 2*y * dy * (1 + 2 + 3) = 4 + outputBuffer[2] = diffX.p; // 2^2 + 3^2 * (1 + 2 + 3) = 58 + outputBuffer[3] = dpX.d; // 2*x = 4 - outputBuffer[1].x = dpY.d; // 2*y * (1 + 2 +3) = 36 + outputBuffer[4] = dpY.d; // 2*y * (1 + 2 +3) = 36 } diff --git a/tests/autodiff/nodiff-ptr.slang.expected.txt b/tests/autodiff/nodiff-ptr.slang.expected.txt index 53e1a28e68..959cc68e4d 100644 --- a/tests/autodiff/nodiff-ptr.slang.expected.txt +++ b/tests/autodiff/nodiff-ptr.slang.expected.txt @@ -4,6 +4,3 @@ type: float 58.000000 4.000000 36.000000 -0.000000 -0.000000 -0.000000