Skip to content

Commit

Permalink
Fix markNonContextParamsAsSideEffectFree.
Browse files Browse the repository at this point in the history
  • Loading branch information
csyonghe committed Jan 10, 2025
1 parent e8217c7 commit a6a63d8
Show file tree
Hide file tree
Showing 6 changed files with 174 additions and 1 deletion.
1 change: 1 addition & 0 deletions source/slang/slang-ir-autodiff-rev.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1359,6 +1359,7 @@ ParameterBlockTransposeInfo BackwardDiffTranscriberBase::splitAndTransposeParame
auto ctxParam =
builder->emitParam(as<IRFuncType>(diffFunc->getDataType())->getParamType(paramCount - 1));
builder->addNameHintDecoration(ctxParam, UnownedStringSlice("_s_diff_ctx"));
builder->addDecoration(ctxParam, kIROp_PrimalContextDecoration);
result.primalFuncParams.add(ctxParam);
result.propagateFuncParams.add(ctxParam);
result.dOutParam = dOutParam;
Expand Down
2 changes: 1 addition & 1 deletion source/slang/slang-ir-autodiff-unzip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ void markNonContextParamsAsSideEffectFree(IRBuilder* builder, IRFunc* func)
{
for (auto param : func->getParams())
{
if (!isIntermediateContextType(param->getDataType()))
if (!param->findDecorationImpl(kIROp_PrimalContextDecoration))
builder->addDecoration(param, kIROp_IgnoreSideEffectsDecoration);
}
}
Expand Down
2 changes: 2 additions & 0 deletions source/slang/slang-ir-inst-defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -1033,6 +1033,8 @@ INST_RANGE(BindingQuery, GetRegisterIndex, GetRegisterSpace)
INST(BackwardDerivativePrimalContextDecoration, BackwardDerivativePrimalContextDecoration, 1, 0)
INST(BackwardDerivativePrimalReturnDecoration, BackwardDerivativePrimalReturnDecoration, 1, 0)

// Mark a parameter as autodiff primal context.
INST(PrimalContextDecoration, PrimalContextDecoration, 0, 0)
INST(LoopCounterDecoration, loopCounterDecoration, 0, 0)
INST(LoopCounterUpdateDecoration, loopCounterUpdateDecoration, 0, 0)

Expand Down
81 changes: 81 additions & 0 deletions tests/autodiff/max-iters.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK):-output-using-type -d3d12 -use-dxil
//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK):-output-using-type -vk
//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK):-output-using-type -metal
//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK):-output-using-type -cuda
//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK):-output-using-type -cpu
//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK):-output-using-type -wgpu

// Note: there is a bug in fxc compiler errorneously reporting infinite loop for this shader.
// Skipping d3d11 test to avoid the bug.
//DISABLE_TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK):-output-using-type -d3d11

struct GradientBuffer<let D : int>
{
RWStructuredBuffer<float> primal;
StructuredBuffer<float> grad;
int strides[D];

int toIndex(int idx[D]) {
int result = 0;
for (int i = 0; i < D; ++i)
result += strides[i] * idx[i];
return result;
}

[Differentiable]
void write(int[D] idx, float v) { primal[toIndex(idx)] = detach(v); }

[BackwardDerivativeOf(write)]
void write_bwd(int[D] idx, inout DifferentialPair<float> d) { d = diffPair(d.p, grad[toIndex(idx)]); }

[Differentiable]
void store<let N : int>(int context[D - 1], in float value[N])
{
int idx[D];
//[ForceUnroll] /* Using ForceUnroll instead of MaxIters makes it work */
[MaxIters(2)]
for (int i = 0; i < D - 1; ++i)
idx[i] = context[i];
[ForceUnroll]
for (int i = 0; i < N; i++) {
idx[D - 1] = i;
write(idx, value[i]);
}
}
}

[Differentiable]
void test(GradientBuffer<2> buf, int[1] base, float[3] value)
{
buf.store(base, value);
}

float3 repro(RWStructuredBuffer<float> primal, StructuredBuffer<float> grad)
{
float input[3];
input[0] = input[1] = input[2] = 1.0f;
var result = diffPair(input);
GradientBuffer<2> buf = { primal, grad, {3, 1} };
bwd_diff(test)(buf, { 1 }, result);
return float3(result.d[0], result.d[1], result.d[2]);
}

//TEST_INPUT: set grad_in = ubuffer(data=[101.0 102.0 103.0 104.0], stride=4)
uniform StructuredBuffer<float> grad_in;

//TEST_INPUT: set grad_out = ubuffer(data=[0 0 0 0], stride=4)
uniform RWStructuredBuffer<float> grad_out;

//TEST_INPUT: set output = out ubuffer(data=[0 0 0 0], stride=4)
uniform RWStructuredBuffer<float> output;

[shader("compute")]
[numthreads(1,1,1)]
void computeMain()
{
let result = repro(grad_out, grad_in);
// CHECK: 104.0
output[0] = result.x;
output[1] = result.y;
output[2] = result.z;
}
48 changes: 48 additions & 0 deletions tests/autodiff/property.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK):-output-using-type
public struct ReadOnlyIndex
{
private int _idx;
__init(int i) { _idx = i; }
public property int idx { get { return _idx; } }
}
struct GradientBuffer
{
RWStructuredBuffer<float> primal;
StructuredBuffer<float> grad;

[Differentiable]
void write(int idx, float v) { primal[idx] = detach(v); }

[BackwardDerivativeOf(write)]
void write_bwd(int idx, inout DifferentialPair<float> d) { d = diffPair(d.p, grad[idx]); }

[Differentiable]
void store(ReadOnlyIndex idx, float v) { write(idx.idx, v); }
}
[Differentiable]
void test(GradientBuffer buf, ReadOnlyIndex b, float x)
{
buf.store(b, x);
}
public float repro(RWStructuredBuffer<float> primal, StructuredBuffer<float> grad)
{
DifferentialPair<float> result = diffPair(1.0f);
GradientBuffer buf = { primal, grad };
bwd_diff(test)(buf, ReadOnlyIndex(5), result);
return result.d;
}

//TEST_INPUT: set output = out ubuffer(data=[0 0 0 0], stride=4)
RWStructuredBuffer<float> output;

//TEST_INPUT: set gPrimal = ubuffer(data=[0.0 1.0 2.0 3.0 4.0 5.0 6.0 7.0], stride=4)
RWStructuredBuffer<float> gPrimal;
//TEST_INPUT: set gGrad = ubuffer(data=[0.0 1.0 2.0 3.0 4.0 5.0 6.0 7.0], stride=4)
StructuredBuffer<float> gGrad;

[numthreads(1,1,1)]
void computeMain()
{
// CHECK: 5.0
output[0] = repro(gPrimal, gGrad);
}
41 changes: 41 additions & 0 deletions tests/autodiff/trivial-primal.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK):-output-using-type

struct GradientBuffer
{
StructuredBuffer<float> grads;

[Differentiable]
void write(int idx, float value) { /* Discard write */ }

[BackwardDerivativeOf(write)]
void write_bwd(int idx, inout DifferentialPair<float> d)
{
d = diffPair(d.p, grads[idx]);
}
}

[Differentiable]
void test(GradientBuffer dst, int idx, float v)
{
dst.write(idx, v);
}

//TEST_INPUT: set grad_in = ubuffer(data=[101.0 102.0 103.0 104.0], stride=4)
uniform StructuredBuffer<float> grad_in;

//TEST_INPUT: set grad_out = ubuffer(data=[0 0 0 0], stride=4)
uniform RWStructuredBuffer<float> grad_out;

//TEST_INPUT: set output = out ubuffer(data=[0 0 0 0], stride=4)
uniform RWStructuredBuffer<float> output;

[shader("compute")]
[numthreads(1,1,1)]
void computeMain()
{
GradientBuffer grads = { grad_in };
DifferentialPair<float> result = diffPair(1.0f);
bwd_diff(test)(grads, 0, result);
// CHECK: 101.0
output[0] = result.d; // Should return grad_in[0], but returns 0.0f instead
}

0 comments on commit a6a63d8

Please sign in to comment.