Skip to content

Commit

Permalink
Add IDifferentiablePtrType support for arrays (#5576)
Browse files Browse the repository at this point in the history
* Add `IDifferentiablePtrType` support for arrays

- Also fixes an issue with spirv-emit of constructors that contain references to global params

* Fix GLSL legalization for arrays of resource types
  • Loading branch information
saipraveenb25 authored Nov 18, 2024
1 parent 05903f7 commit ec5e019
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 4 deletions.
6 changes: 6 additions & 0 deletions source/slang/diff.meta.slang
Original file line number Diff line number Diff line change
Expand Up @@ -1241,6 +1241,12 @@ extension Array<T, N> : IDifferentiable
}
}

__generic<T : IDifferentiablePtrType, let N : int>
extension Array<T, N> : IDifferentiablePtrType
{
typedef Array<T.Differential, N> Differential;
}

__generic<each T : IDifferentiable>
extension Tuple<T> : IDifferentiable
{
Expand Down
7 changes: 5 additions & 2 deletions source/slang/slang-ir-autodiff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1275,7 +1275,9 @@ IRInst* DifferentiableTypeConformanceContext::buildArrayWitness(
IRWitnessTable* table = nullptr;
if (target == DiffConformanceKind::Value)
{
SLANG_ASSERT(isDifferentiableValueType((IRType*)arrayType));
if (!isDifferentiableValueType((IRType*)arrayType))
return nullptr;

auto innerWitness = tryGetDifferentiableWitness(
builder,
as<IRArrayTypeBase>(arrayType)->getElementType(),
Expand Down Expand Up @@ -1360,7 +1362,8 @@ IRInst* DifferentiableTypeConformanceContext::buildArrayWitness(
}
else if (target == DiffConformanceKind::Ptr)
{
SLANG_ASSERT(isDifferentiablePtrType((IRType*)arrayType));
if (!isDifferentiablePtrType((IRType*)arrayType))
return nullptr;

table = builder->createWitnessTable(
sharedContext->differentiablePtrInterfaceType,
Expand Down
3 changes: 3 additions & 0 deletions source/slang/slang-ir-specialize-resources.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1308,6 +1308,9 @@ bool specializeResourceUsage(CodeGenContext* codeGenContext, IRModule* irModule)

bool isIllegalGLSLParameterType(IRType* type)
{
if (auto arrayType = as<IRArrayTypeBase>(type))
return isIllegalGLSLParameterType(arrayType->getElementType());

if (as<IRParameterGroupType>(type))
return true;
if (as<IRHLSLStructuredBufferTypeBase>(type))
Expand Down
10 changes: 8 additions & 2 deletions source/slang/slang-ir-spirv-legalize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1465,15 +1465,21 @@ struct SPIRVLegalizationContext : public SourceEmitterBase

void maybeHoistConstructInstToGlobalScope(IRInst* inst)
{
// If all of the operands to this instruction are global, we can hoist
// this constructor to be a global too. This is important to make sure
// If all of the operands to this instruction are global, and are not global
// variables, we can hoist this constructor to be a global too.
// This is important to make sure
// that vectors made of constant components end up being emitted as
// constant vectors (using OpConstantComposite).
UIndex opIndex = 0;
for (auto operand = inst->getOperands(); opIndex < inst->getOperandCount();
operand++, opIndex++)
{
if (operand->get()->getParent() != m_module->getModuleInst())
return;

if (as<IRGlobalParam>(operand->get()))
return;
}
inst->insertAtEnd(m_module->getModuleInst());
}

Expand Down
59 changes: 59 additions & 0 deletions tests/autodiff/diff-ptr-type-array.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj -output-using-type
//DISABLE_TEST(compute):COMPARE_COMPUTE:-wgpu

//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;

// ----- MyPtrType definition -----
struct MyPtrType : IDifferentiablePtrType
{
typealias Differential = MyPtrType;

RWStructuredBuffer<float> buffer;
uint offset;

float load(uint idx) { return buffer[offset + idx]; }
void accumulate(uint idx, float value) { buffer[offset + idx] += value; }
}

[BackwardDerivative(load_bwd)]
float load(MyPtrType[2] b, uint idx)
{
return b[1].load(idx);
}

void load_bwd(DifferentialPtrPair<MyPtrType[2]> b, uint idx, float grad)
{
b.d[1].accumulate(idx, grad);
}

// ------
[Differentiable]
float reduce(MyPtrType a)
{
return load( { a, a }, 0) + load( { a, a }, 1);
}

[Differentiable]
float test(MyPtrType b)
{
return reduce(b);
}

[numthreads(1, 1, 1)]
void computeMain(uint id: SV_DispatchThreadID)
{
outputBuffer[0] = 1; // CHECK: 1
outputBuffer[1] = 2; // CHECK: 2

// Denote the first two elements in the buffer as the primal buffer and the last two elements
// for the derivative.
var b = DifferentialPtrPair<MyPtrType>( { outputBuffer, 0 }, { outputBuffer, 2 });

bwd_diff(test)(b, 1.5f);

// Check locations [2] and [3] in the buffer
// CHECK: 1.5
// CHECK: 1.5
}

0 comments on commit ec5e019

Please sign in to comment.