Skip to content

Commit

Permalink
Correct IR generation for no-diff pointer type (#5976)
Browse files Browse the repository at this point in the history
* Correct IR generation for no-diff pointer type

Close #5805 

There is an issue on checking whether a pointer type parameter
is no_diff, we should first check whether this parameter is
an Attribute type first, then check the data type.

In the back-propagate pass, for the pointer type parameter, we should
load this parameter to a temp variable, then pass it to the primal
function call. Otherwise, the temp variable will no be initialized,
which will cause the following calculation wrong.
  • Loading branch information
kaizhangNV authored Jan 2, 2025
1 parent e3b71cf commit d48cd13
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 10 deletions.
11 changes: 6 additions & 5 deletions source/slang/slang-ir-autodiff-rev.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -512,11 +512,12 @@ InstPair BackwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRF
{
// If primal parameter is mutable, we need to pass in a temp var.
auto tempVar = builder.emitVar(primalParamPtrType->getValueType());
if (primalParamPtrType->getOp() == kIROp_InOutType)
{
// If the primal parameter is inout, we need to set the initial value.
builder.emitStore(tempVar, primalArg);
}

// We also need to setup the initial value of the temp var, otherwise
// the temp var will be uninitialized which could cause undefined behavior
// in the primal function.
builder.emitStore(tempVar, primalArg);

primalArgs.add(tempVar);
}
else
Expand Down
3 changes: 3 additions & 0 deletions source/slang/slang-ir-autodiff-transcriber-base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,9 @@ IRType* AutoDiffTranscriberBase::tryGetDiffPairType(IRBuilder* builder, IRType*
// If this is a PtrType (out, inout, etc..), then create diff pair from
// value type and re-apply the appropropriate PtrType wrapper.
//
if (isNoDiffType(originalType))
return nullptr;

if (auto origPtrType = as<IRPtrTypeBase>(originalType))
{
if (auto diffPairValueType = tryGetDiffPairType(builder, origPtrType->getValueType()))
Expand Down
19 changes: 14 additions & 5 deletions source/slang/slang-ir-autodiff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,22 @@ static IRInst* _getDiffTypeWitnessFromPairType(

bool isNoDiffType(IRType* paramType)
{
while (auto ptrType = as<IRPtrTypeBase>(paramType))
paramType = ptrType->getValueType();
while (auto attrType = as<IRAttributedType>(paramType))
while (paramType)
{
if (attrType->findAttr<IRNoDiffAttr>())
if (auto attrType = as<IRAttributedType>(paramType))
{
return true;
if (attrType->findAttr<IRNoDiffAttr>())
return true;

paramType = attrType->getBaseType();
}
else if (auto ptrType = as<IRPtrTypeBase>(paramType))
{
paramType = ptrType->getValueType();
}
else
{
return false;
}
}
return false;
Expand Down
40 changes: 40 additions & 0 deletions tests/autodiff/nodiff-ptr.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@

[Differentiable]
float sumOfSquares(float x, float y, no_diff float4* test)
{
return x * x + y * y * (test->x + test->y + test->z);
}

//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type -compile-arg -skip-spirv-validation -emit-spirv-directly

//TEST_INPUT: set ptr = ubuffer(data=[1.0 2.0 3.0], stride=4)
uniform float* ptr;

//TEST_INPUT:ubuffer(data=[0.0 0.0 0.0 0.0 0.0], stride=4):out, name outputBuffer
RWStructuredBuffer<float> outputBuffer;

[shader("compute")]
[numthreads(1, 1, 1)]
void computeMain()
{
float4* testPtr = (float4*)ptr;

let result = sumOfSquares(2.0, 3.0, testPtr);

// Use forward differentiation to compute the gradient of the output w.r.t. x only.
let diffX = fwd_diff(sumOfSquares)(diffPair(2.0, 1.0), diffPair(3.0, 0.0), testPtr);

// Create a differentiable pair to pass in the primal value and to receive the gradient.
var dpX = diffPair(2.0);
var dpY = diffPair(3.0);

// Propagate the gradient of the output (1.0f) to the input parameters.
bwd_diff(sumOfSquares)(dpX, dpY, testPtr, 1.0);

outputBuffer[0] = result; // 2^2 + 3^2 * (1 + 2 + 3) = 58
outputBuffer[1] = diffX.d; // 2*x * dx + 2*y * dy * (1 + 2 + 3) = 4
outputBuffer[2] = diffX.p; // 2^2 + 3^2 * (1 + 2 + 3) = 58
outputBuffer[3] = dpX.d; // 2*x = 4

outputBuffer[4] = dpY.d; // 2*y * (1 + 2 +3) = 36
}
6 changes: 6 additions & 0 deletions tests/autodiff/nodiff-ptr.slang.expected.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
type: float
58.000000
4.000000
58.000000
4.000000
36.000000

0 comments on commit d48cd13

Please sign in to comment.