Skip to content

Commit

Permalink
[Auto-diff] Overhaul auto-diff type tracking + Overhaul dynamic dispa…
Browse files Browse the repository at this point in the history
…tch for differentiable functions (#5866)

* Overhauled the auto-diff system for dynamic dispatch

* More fixes

* remove intermediate dumps

* Update slang-ast-type.h

* More fixes + add a workaround for existential no-diff

* Update reverse-control-flow-3.slang

* remove dumps

* remove more dumps

* Delete working-reverse-control-flow-3.hlsl

* Cleanup comments + unused variables

* More comment cleanup

* Add support for lowering `DiffPairType(TypePack)` & `MakePair(MakeValuePack, MakeValuePack)`

* Fix array of issues in Falcor tests.

* Update slang-ir-autodiff-pairs.cpp

* More fixes for Falcor image tests

* Small fixups.

---------

Co-authored-by: Yong He <[email protected]>
  • Loading branch information
saipraveenb25 and csyonghe authored Jan 9, 2025
1 parent 6706c1a commit 87f00a3
Show file tree
Hide file tree
Showing 25 changed files with 1,927 additions and 526 deletions.
26 changes: 17 additions & 9 deletions source/slang/slang-check-decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9247,12 +9247,16 @@ void SemanticsDeclHeaderVisitor::checkDifferentiableCallableCommon(CallableDecl*
if (!decl->hasModifier<NoDiffThisAttribute>())
{
// Build decl-ref-type from interface.
auto interfaceType =
DeclRefType::create(getASTBuilder(), makeDeclRef(interfaceDecl));
auto thisType = DeclRefType::create(
m_astBuilder,
createDefaultSubstitutionsIfNeeded(
m_astBuilder,
this,
makeDeclRef(interfaceDecl->getThisTypeDecl())));

// If the interface is differentiable, make the this type a pair.
if (tryGetDifferentialType(getASTBuilder(), interfaceType))
reqDecl->diffThisType = getDifferentialPairType(interfaceType);
if (tryGetDifferentialType(getASTBuilder(), thisType))
reqDecl->diffThisType = getDifferentialPairType(thisType);
}

auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>();
Expand All @@ -9277,13 +9281,17 @@ void SemanticsDeclHeaderVisitor::checkDifferentiableCallableCommon(CallableDecl*
reqDecl->parentDecl = interfaceDecl;
if (!decl->hasModifier<NoDiffThisAttribute>())
{
// Build decl-ref-type from interface.
auto interfaceType =
DeclRefType::create(getASTBuilder(), makeDeclRef(interfaceDecl));
// Build decl-ref-type for this-type.
auto thisType = DeclRefType::create(
m_astBuilder,
createDefaultSubstitutionsIfNeeded(
m_astBuilder,
this,
makeDeclRef(interfaceDecl->getThisTypeDecl())));

// If the interface is differentiable, make the this type a pair.
if (tryGetDifferentialType(getASTBuilder(), interfaceType))
reqDecl->diffThisType = getDifferentialPairType(interfaceType);
if (tryGetDifferentialType(getASTBuilder(), thisType))
reqDecl->diffThisType = getDifferentialPairType(thisType);
}

auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>();
Expand Down
28 changes: 26 additions & 2 deletions source/slang/slang-emit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -815,7 +815,18 @@ Result linkAndOptimizeIR(
bool changed = false;
dumpIRIfEnabled(codeGenContext, irModule, "BEFORE-SPECIALIZE");
if (!codeGenContext->isSpecializationDisabled())
changed |= specializeModule(targetProgram, irModule, codeGenContext->getSink());
{
// Pre-autodiff, we will attempt to specialize as much as possible.
//
// Note: Lowered dynamic-dispatch code cannot be differentiated correctly due to
// missing information, so we defer that to after the auto-dff step.
//
SpecializationOptions specOptions;
specOptions.lowerWitnessLookups = false;
changed |=
specializeModule(targetProgram, irModule, codeGenContext->getSink(), specOptions);
}

if (codeGenContext->getSink()->getErrorCount() != 0)
return SLANG_FAIL;
dumpIRIfEnabled(codeGenContext, irModule, "AFTER-SPECIALIZE");
Expand Down Expand Up @@ -867,9 +878,20 @@ Result linkAndOptimizeIR(
reportCheckpointIntermediates(codeGenContext, sink, irModule);

// Finalization is always run so AD-related instructions can be removed,
// even the AD pass itself is not run.
// even if the AD pass itself is not run.
//
finalizeAutoDiffPass(targetProgram, irModule);
eliminateDeadCode(irModule, deadCodeEliminationOptions);

// After auto-diff, we can perform more aggressive specialization with dynamic-dispatch
// lowering.
//
if (!codeGenContext->isSpecializationDisabled())
{
SpecializationOptions specOptions;
specOptions.lowerWitnessLookups = true;
specializeModule(targetProgram, irModule, codeGenContext->getSink(), specOptions);
}

finalizeSpecialization(irModule);

Expand Down Expand Up @@ -930,6 +952,8 @@ Result linkAndOptimizeIR(

validateIRModuleIfEnabled(codeGenContext, irModule);

inferAnyValueSizeWhereNecessary(targetProgram, irModule);

// If we have any witness tables that are marked as `KeepAlive`,
// but are not used for dynamic dispatch, unpin them so we don't
// do unnecessary work to lower them.
Expand Down
143 changes: 119 additions & 24 deletions source/slang/slang-ir-autodiff-fwd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,40 @@ InstPair ForwardDiffTranscriber::transcribeReinterpret(IRBuilder* builder, IRIns
return InstPair(primalVal, diffVal);
}

InstPair ForwardDiffTranscriber::transcribeDifferentiableTypeAnnotation(
IRBuilder* builder,
IRInst* origInst)
{
auto primalAnnotation =
as<IRDifferentiableTypeAnnotation>(maybeCloneForPrimalInst(builder, origInst));

IRDifferentiableTypeAnnotation* annotation = as<IRDifferentiableTypeAnnotation>(origInst);

differentiableTypeConformanceContext.addTypeToDictionary(
(IRType*)primalAnnotation->getBaseType(),
primalAnnotation->getWitness());

auto diffType = differentiateType(builder, (IRType*)annotation->getBaseType());
if (!diffType)
return InstPair(primalAnnotation, nullptr);

auto diffTypeDiffWitness =
tryGetDifferentiableWitness(builder, diffType, DiffConformanceKind::Any);

IRInst* args[] = {diffType, diffTypeDiffWitness};

auto diffAnnotation = builder->emitIntrinsicInst(
builder->getVoidType(),
kIROp_DifferentiableTypeAnnotation,
2,
args);

builder->markInstAsPrimal(diffAnnotation);
builder->markInstAsPrimal(primalAnnotation);

return InstPair(primalAnnotation, diffAnnotation);
}

InstPair ForwardDiffTranscriber::transcribeVar(IRBuilder* builder, IRVar* origVar)
{
if (IRType* diffType = differentiateType(builder, origVar->getDataType()->getValueType()))
Expand Down Expand Up @@ -752,9 +786,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
auto pairValType = as<IRDifferentialPairTypeBase>(
pairPtrType ? pairPtrType->getValueType() : pairType);

auto diffType = differentiableTypeConformanceContext.getDiffTypeFromPairType(
&argBuilder,
pairValType);
auto diffType = differentiateType(&argBuilder, primalType);
if (auto ptrParamType = as<IRPtrTypeBase>(diffParamType))
{
// Create temp var to pass in/out arguments.
Expand Down Expand Up @@ -795,7 +827,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
if (diffArg)
{
auto newDiffVal = afterBuilder.emitDifferentialPairGetDifferential(
(IRType*)diffType,
(IRType*)as<IRPtrTypeBase>(diffType)->getValueType(),
newVal);
markDiffTypeInst(
&afterBuilder,
Expand Down Expand Up @@ -827,17 +859,72 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
}
}
}

{
// --WORKAROUND--
// This is a temporary workaround for a very specific case..
//
// If all the following are true:
// 1. the parameter type expects a differential pair,
// 2. the argument is derived from a no_diff type, and
// 3. the argument type is a run-time type (i.e. extract_existential_type),
// then we need to generate a differential 0, but the IR has no
// information on the diff witness.
//
// We will bypass the conformance system & brute-force the lookup for the interface
// keys, but the proper fix is to lower this key mapping during `no_diff` lowering.
//

// Condition 1
if (differentiableTypeConformanceContext.isDifferentiableType((originalParamType)))
{
// Condition 3
if (auto extractExistentialType = as<IRExtractExistentialType>(primalType))
{
// Condition 2
if (isNoDiffType(extractExistentialType->getOperand(0)->getDataType()))
{
// Force-differentiate the type (this will perform a search for the witness
// without going through the diff-type annotation list)
//
IRInst* witnessTable = nullptr;
auto diffType = differentiateExtractExistentialType(
&argBuilder,
extractExistentialType,
witnessTable);

auto pairType =
getOrCreateDiffPairType(&argBuilder, primalType, witnessTable);
auto zeroMethod = argBuilder.emitLookupInterfaceMethodInst(
differentiableTypeConformanceContext.sharedContext->zeroMethodType,
witnessTable,
differentiableTypeConformanceContext.sharedContext
->zeroMethodStructKey);
auto diffZero = argBuilder.emitCallInst(diffType, zeroMethod, 0, nullptr);
auto diffPair =
argBuilder.emitMakeDifferentialPair(pairType, primalArg, diffZero);

args.add(diffPair);
continue;
}
}
}
}

// Argument is not differentiable.
// Add original/primal argument.
args.add(primalArg);
}

IRType* diffReturnType = nullptr;
diffReturnType = tryGetDiffPairType(&argBuilder, origCall->getFullType());
auto primalReturnType =
(IRType*)findOrTranscribePrimalInst(&argBuilder, origCall->getFullType());

diffReturnType = tryGetDiffPairType(&argBuilder, primalReturnType);

if (!diffReturnType)
{
diffReturnType = (IRType*)findOrTranscribePrimalInst(&argBuilder, origCall->getFullType());
diffReturnType = primalReturnType;
}

auto callInst = argBuilder.emitCallInst(diffReturnType, diffCallee, args);
Expand Down Expand Up @@ -1035,18 +1122,16 @@ InstPair ForwardDiffTranscriber::transcribeSpecialize(
IRInst* diffBase = nullptr;
if (instMapD.tryGetValue(origSpecialize->getBase(), diffBase))
{
auto diffType = differentiateType(builder, origSpecialize->getFullType());
if (diffBase)
{
List<IRInst*> args;
for (UInt i = 0; i < primalSpecialize->getArgCount(); i++)
{
args.add(primalSpecialize->getArg(i));
}
auto diffSpecialize = builder->emitSpecializeInst(
builder->getTypeKind(),
diffBase,
args.getCount(),
args.getBuffer());
auto diffSpecialize =
builder->emitSpecializeInst(diffType, diffBase, args.getCount(), args.getBuffer());
return InstPair(primalSpecialize, diffSpecialize);
}
else
Expand Down Expand Up @@ -1572,7 +1657,24 @@ InstPair ForwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRFu
return InstPair(origFunc, fwdDecor->getForwardDerivativeFunc());
}

auto diffFunc = transcribeFuncHeaderImpl(inBuilder, origFunc);
IRFunc* diffFunc = nullptr;

// If we're transcribing a function as a 'value' (i.e. maybe embedded in a generic, keep the
// insert location unchanged). If we're transcribing it as a declaration, we should
// insert into the module.
//
auto origOuterGen = as<IRGeneric>(findOuterGeneric(origFunc));
if (!origOuterGen || findInnerMostGenericReturnVal(origOuterGen) != origFunc)
{
// Dealing with a declaration.. insert into module scope.
IRBuilder subBuilder = *inBuilder;
subBuilder.setInsertInto(inBuilder->getModule());
diffFunc = transcribeFuncHeaderImpl(&subBuilder, origFunc);
}
else
{
diffFunc = transcribeFuncHeaderImpl(inBuilder, origFunc);
}

if (auto outerGen = findOuterGeneric(diffFunc))
{
Expand Down Expand Up @@ -1605,7 +1707,6 @@ IRFunc* ForwardDiffTranscriber::transcribeFuncHeaderImpl(IRBuilder* inBuilder, I
IRBuilder builder = *inBuilder;

maybeMigrateDifferentiableDictionaryFromDerivativeFunc(inBuilder, origFunc);

differentiableTypeConformanceContext.setFunc(origFunc);

auto diffFunc = builder.createFunc();
Expand All @@ -1632,12 +1733,6 @@ IRFunc* ForwardDiffTranscriber::transcribeFuncHeaderImpl(IRBuilder* inBuilder, I

// Transfer checkpoint hint decorations
copyCheckpointHints(&builder, origFunc, diffFunc);

// Find and clone `DifferentiableTypeDictionaryDecoration` to the new diffFunc.
if (auto dictDecor = origFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>())
{
cloneDecoration(&cloneEnv, dictDecor, diffFunc, diffFunc->getModule());
}
return diffFunc;
}

Expand Down Expand Up @@ -2012,6 +2107,9 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
case kIROp_Reinterpret:
return transcribeReinterpret(builder, origInst);

case kIROp_DifferentiableTypeAnnotation:
return transcribeDifferentiableTypeAnnotation(builder, origInst);

// Differentiable insts that should have been lowered in a previous pass.
case kIROp_SwizzledStore:
{
Expand Down Expand Up @@ -2138,13 +2236,10 @@ InstPair ForwardDiffTranscriber::transcribeFuncParam(

if (as<IRDifferentialPairType>(diffPairType) || as<IRDifferentialPtrPairType>(diffPairType))
{
auto diffType = differentiateType(builder, (IRType*)origParam->getFullType());
return InstPair(
builder->emitDifferentialPairGetPrimal(diffPairParam),
builder->emitDifferentialPairGetDifferential(
(IRType*)differentiableTypeConformanceContext.getDiffTypeFromPairType(
builder,
as<IRDifferentialPairTypeBase>(diffPairType)),
diffPairParam));
builder->emitDifferentialPairGetDifferential(diffType, diffPairParam));
}
else if (auto pairPtrType = as<IRPtrTypeBase>(diffPairType))
{
Expand Down
2 changes: 2 additions & 0 deletions source/slang/slang-ir-autodiff-fwd.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase

InstPair transcribeReinterpret(IRBuilder* builder, IRInst* origInst);

InstPair transcribeDifferentiableTypeAnnotation(IRBuilder* builder, IRInst* origInst);

virtual IRFuncType* differentiateFunctionType(
IRBuilder* builder,
IRInst* func,
Expand Down
Loading

0 comments on commit 87f00a3

Please sign in to comment.