From f3d9a360834fd8728e547605ed029c2492ca1aa4 Mon Sep 17 00:00:00 2001 From: Yong He Date: Thu, 9 Jan 2025 18:51:32 -0800 Subject: [PATCH] Fix `markNonContextParamsAsSideEffectFree`. --- source/slang/slang-ir-autodiff-rev.cpp | 1 + source/slang/slang-ir-autodiff-unzip.cpp | 2 +- source/slang/slang-ir-inst-defs.h | 2 + tests/autodiff/max-iters.slang | 81 ++++++++++++++++++++++++ tests/autodiff/property.slang | 48 ++++++++++++++ tests/autodiff/trivial-primal.slang | 41 ++++++++++++ 6 files changed, 174 insertions(+), 1 deletion(-) create mode 100644 tests/autodiff/max-iters.slang create mode 100644 tests/autodiff/property.slang create mode 100644 tests/autodiff/trivial-primal.slang diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 5ac4016d7d..65ce69877f 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -1359,6 +1359,7 @@ ParameterBlockTransposeInfo BackwardDiffTranscriberBase::splitAndTransposeParame auto ctxParam = builder->emitParam(as(diffFunc->getDataType())->getParamType(paramCount - 1)); builder->addNameHintDecoration(ctxParam, UnownedStringSlice("_s_diff_ctx")); + builder->addDecoration(ctxParam, kIROp_PrimalContextDecoration); result.primalFuncParams.add(ctxParam); result.propagateFuncParams.add(ctxParam); result.dOutParam = dOutParam; diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index 49c1d9ff7e..6bc428ad61 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -342,7 +342,7 @@ void markNonContextParamsAsSideEffectFree(IRBuilder* builder, IRFunc* func) { for (auto param : func->getParams()) { - if (!isIntermediateContextType(param->getDataType())) + if (!param->findDecorationImpl(kIROp_PrimalContextDecoration)) builder->addDecoration(param, kIROp_IgnoreSideEffectsDecoration); } } diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 38e5f8869f..0592879907 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -1033,6 +1033,8 @@ INST_RANGE(BindingQuery, GetRegisterIndex, GetRegisterSpace) INST(BackwardDerivativePrimalContextDecoration, BackwardDerivativePrimalContextDecoration, 1, 0) INST(BackwardDerivativePrimalReturnDecoration, BackwardDerivativePrimalReturnDecoration, 1, 0) + // Mark a parameter as autodiff primal context. + INST(PrimalContextDecoration, PrimalContextDecoration, 0, 0) INST(LoopCounterDecoration, loopCounterDecoration, 0, 0) INST(LoopCounterUpdateDecoration, loopCounterUpdateDecoration, 0, 0) diff --git a/tests/autodiff/max-iters.slang b/tests/autodiff/max-iters.slang new file mode 100644 index 0000000000..c83057b432 --- /dev/null +++ b/tests/autodiff/max-iters.slang @@ -0,0 +1,81 @@ +//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK):-output-using-type -d3d12 -use-dxil +//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK):-output-using-type -vk +//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK):-output-using-type -metal +//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK):-output-using-type -cuda +//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK):-output-using-type -cpu +//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK):-output-using-type -wgpu + +// Note: there is a bug in fxc compiler errorneously reporting infinite loop for this shader. +// Skipping d3d11 test to avoid the bug. +//DISABLE_TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK):-output-using-type -d3d11 + +struct GradientBuffer +{ + RWStructuredBuffer primal; + StructuredBuffer grad; + int strides[D]; + + int toIndex(int idx[D]) { + int result = 0; + for (int i = 0; i < D; ++i) + result += strides[i] * idx[i]; + return result; + } + + [Differentiable] + void write(int[D] idx, float v) { primal[toIndex(idx)] = detach(v); } + + [BackwardDerivativeOf(write)] + void write_bwd(int[D] idx, inout DifferentialPair d) { d = diffPair(d.p, grad[toIndex(idx)]); } + + [Differentiable] + void store(int context[D - 1], in float value[N]) + { + int idx[D]; + //[ForceUnroll] /* Using ForceUnroll instead of MaxIters makes it work */ + [MaxIters(2)] + for (int i = 0; i < D - 1; ++i) + idx[i] = context[i]; + [ForceUnroll] + for (int i = 0; i < N; i++) { + idx[D - 1] = i; + write(idx, value[i]); + } + } +} + +[Differentiable] +void test(GradientBuffer<2> buf, int[1] base, float[3] value) +{ + buf.store(base, value); +} + +float3 repro(RWStructuredBuffer primal, StructuredBuffer grad) +{ + float input[3]; + input[0] = input[1] = input[2] = 1.0f; + var result = diffPair(input); + GradientBuffer<2> buf = { primal, grad, {3, 1} }; + bwd_diff(test)(buf, { 1 }, result); + return float3(result.d[0], result.d[1], result.d[2]); +} + +//TEST_INPUT: set grad_in = ubuffer(data=[101.0 102.0 103.0 104.0], stride=4) +uniform StructuredBuffer grad_in; + +//TEST_INPUT: set grad_out = ubuffer(data=[0 0 0 0], stride=4) +uniform RWStructuredBuffer grad_out; + +//TEST_INPUT: set output = out ubuffer(data=[0 0 0 0], stride=4) +uniform RWStructuredBuffer output; + +[shader("compute")] +[numthreads(1,1,1)] +void computeMain() +{ + let result = repro(grad_out, grad_in); + // CHECK: 104.0 + output[0] = result.x; + output[1] = result.y; + output[2] = result.z; +} \ No newline at end of file diff --git a/tests/autodiff/property.slang b/tests/autodiff/property.slang new file mode 100644 index 0000000000..e15b9a75ad --- /dev/null +++ b/tests/autodiff/property.slang @@ -0,0 +1,48 @@ +//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK):-output-using-type +public struct ReadOnlyIndex +{ + private int _idx; + __init(int i) { _idx = i; } + public property int idx { get { return _idx; } } +} +struct GradientBuffer +{ + RWStructuredBuffer primal; + StructuredBuffer grad; + + [Differentiable] + void write(int idx, float v) { primal[idx] = detach(v); } + + [BackwardDerivativeOf(write)] + void write_bwd(int idx, inout DifferentialPair d) { d = diffPair(d.p, grad[idx]); } + + [Differentiable] + void store(ReadOnlyIndex idx, float v) { write(idx.idx, v); } +} +[Differentiable] +void test(GradientBuffer buf, ReadOnlyIndex b, float x) +{ + buf.store(b, x); +} +public float repro(RWStructuredBuffer primal, StructuredBuffer grad) +{ + DifferentialPair result = diffPair(1.0f); + GradientBuffer buf = { primal, grad }; + bwd_diff(test)(buf, ReadOnlyIndex(5), result); + return result.d; +} + +//TEST_INPUT: set output = out ubuffer(data=[0 0 0 0], stride=4) +RWStructuredBuffer output; + +//TEST_INPUT: set gPrimal = ubuffer(data=[0.0 1.0 2.0 3.0 4.0 5.0 6.0 7.0], stride=4) +RWStructuredBuffer gPrimal; +//TEST_INPUT: set gGrad = ubuffer(data=[0.0 1.0 2.0 3.0 4.0 5.0 6.0 7.0], stride=4) +StructuredBuffer gGrad; + +[numthreads(1,1,1)] +void computeMain() +{ + // CHECK: 5.0 + output[0] = repro(gPrimal, gGrad); +} \ No newline at end of file diff --git a/tests/autodiff/trivial-primal.slang b/tests/autodiff/trivial-primal.slang new file mode 100644 index 0000000000..d56c463997 --- /dev/null +++ b/tests/autodiff/trivial-primal.slang @@ -0,0 +1,41 @@ +//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK):-output-using-type + +struct GradientBuffer +{ + StructuredBuffer grads; + + [Differentiable] + void write(int idx, float value) { /* Discard write */ } + + [BackwardDerivativeOf(write)] + void write_bwd(int idx, inout DifferentialPair d) + { + d = diffPair(d.p, grads[idx]); + } +} + +[Differentiable] +void test(GradientBuffer dst, int idx, float v) +{ + dst.write(idx, v); +} + +//TEST_INPUT: set grad_in = ubuffer(data=[101.0 102.0 103.0 104.0], stride=4) +uniform StructuredBuffer grad_in; + +//TEST_INPUT: set grad_out = ubuffer(data=[0 0 0 0], stride=4) +uniform RWStructuredBuffer grad_out; + +//TEST_INPUT: set output = out ubuffer(data=[0 0 0 0], stride=4) +uniform RWStructuredBuffer output; + +[shader("compute")] +[numthreads(1,1,1)] +void computeMain() +{ + GradientBuffer grads = { grad_in }; + DifferentialPair result = diffPair(1.0f); + bwd_diff(test)(grads, 0, result); + // CHECK: 101.0 + output[0] = result.d; // Should return grad_in[0], but returns 0.0f instead +} \ No newline at end of file