From 87f00a36a123e36b415eeea82e02a8366cc5b881 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Fri, 10 Jan 2025 03:16:24 +0530 Subject: [PATCH] [Auto-diff] Overhaul auto-diff type tracking + Overhaul dynamic dispatch for differentiable functions (#5866) * Overhauled the auto-diff system for dynamic dispatch * More fixes * remove intermediate dumps * Update slang-ast-type.h * More fixes + add a workaround for existential no-diff * Update reverse-control-flow-3.slang * remove dumps * remove more dumps * Delete working-reverse-control-flow-3.hlsl * Cleanup comments + unused variables * More comment cleanup * Add support for lowering `DiffPairType(TypePack)` & `MakePair(MakeValuePack, MakeValuePack)` * Fix array of issues in Falcor tests. * Update slang-ir-autodiff-pairs.cpp * More fixes for Falcor image tests * Small fixups. --------- Co-authored-by: Yong He --- source/slang/slang-check-decl.cpp | 26 +- source/slang/slang-emit.cpp | 28 +- source/slang/slang-ir-autodiff-fwd.cpp | 143 +- source/slang/slang-ir-autodiff-fwd.h | 2 + source/slang/slang-ir-autodiff-pairs.cpp | 95 +- .../slang/slang-ir-autodiff-primal-hoist.cpp | 141 +- source/slang/slang-ir-autodiff-rev.cpp | 49 +- .../slang-ir-autodiff-transcriber-base.cpp | 112 +- source/slang/slang-ir-autodiff-transpose.h | 210 +-- source/slang/slang-ir-autodiff-unzip.cpp | 16 + source/slang/slang-ir-autodiff-unzip.h | 91 +- source/slang/slang-ir-autodiff.cpp | 1177 ++++++++++++++--- source/slang/slang-ir-autodiff.h | 73 +- source/slang/slang-ir-inst-defs.h | 3 + source/slang/slang-ir-insts.h | 12 + source/slang/slang-ir-link.cpp | 3 +- source/slang/slang-ir-lower-generics.cpp | 2 - source/slang/slang-ir-specialize.cpp | 18 +- source/slang/slang-ir-specialize.h | 14 +- source/slang/slang-lower-to-ir.cpp | 75 +- .../autodiff/dynamic-dispatch-material.slang | 142 ++ ...namic-dispatch-material.slang.expected.txt | 6 + .../autodiff/material2/DiffuseMaterial.slang | 5 + tests/autodiff/material2/GlossyMaterial.slang | 5 + .../material2/MxLayeredMaterial.slang | 5 + 25 files changed, 1927 insertions(+), 526 deletions(-) create mode 100644 tests/autodiff/dynamic-dispatch-material.slang create mode 100644 tests/autodiff/dynamic-dispatch-material.slang.expected.txt diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 3667a36ba2..04d5b7a75b 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -9247,12 +9247,16 @@ void SemanticsDeclHeaderVisitor::checkDifferentiableCallableCommon(CallableDecl* if (!decl->hasModifier()) { // Build decl-ref-type from interface. - auto interfaceType = - DeclRefType::create(getASTBuilder(), makeDeclRef(interfaceDecl)); + auto thisType = DeclRefType::create( + m_astBuilder, + createDefaultSubstitutionsIfNeeded( + m_astBuilder, + this, + makeDeclRef(interfaceDecl->getThisTypeDecl()))); // If the interface is differentiable, make the this type a pair. - if (tryGetDifferentialType(getASTBuilder(), interfaceType)) - reqDecl->diffThisType = getDifferentialPairType(interfaceType); + if (tryGetDifferentialType(getASTBuilder(), thisType)) + reqDecl->diffThisType = getDifferentialPairType(thisType); } auto reqRef = m_astBuilder->create(); @@ -9277,13 +9281,17 @@ void SemanticsDeclHeaderVisitor::checkDifferentiableCallableCommon(CallableDecl* reqDecl->parentDecl = interfaceDecl; if (!decl->hasModifier()) { - // Build decl-ref-type from interface. - auto interfaceType = - DeclRefType::create(getASTBuilder(), makeDeclRef(interfaceDecl)); + // Build decl-ref-type for this-type. + auto thisType = DeclRefType::create( + m_astBuilder, + createDefaultSubstitutionsIfNeeded( + m_astBuilder, + this, + makeDeclRef(interfaceDecl->getThisTypeDecl()))); // If the interface is differentiable, make the this type a pair. - if (tryGetDifferentialType(getASTBuilder(), interfaceType)) - reqDecl->diffThisType = getDifferentialPairType(interfaceType); + if (tryGetDifferentialType(getASTBuilder(), thisType)) + reqDecl->diffThisType = getDifferentialPairType(thisType); } auto reqRef = m_astBuilder->create(); diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index b9217de416..cd1b177b27 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -815,7 +815,18 @@ Result linkAndOptimizeIR( bool changed = false; dumpIRIfEnabled(codeGenContext, irModule, "BEFORE-SPECIALIZE"); if (!codeGenContext->isSpecializationDisabled()) - changed |= specializeModule(targetProgram, irModule, codeGenContext->getSink()); + { + // Pre-autodiff, we will attempt to specialize as much as possible. + // + // Note: Lowered dynamic-dispatch code cannot be differentiated correctly due to + // missing information, so we defer that to after the auto-dff step. + // + SpecializationOptions specOptions; + specOptions.lowerWitnessLookups = false; + changed |= + specializeModule(targetProgram, irModule, codeGenContext->getSink(), specOptions); + } + if (codeGenContext->getSink()->getErrorCount() != 0) return SLANG_FAIL; dumpIRIfEnabled(codeGenContext, irModule, "AFTER-SPECIALIZE"); @@ -867,9 +878,20 @@ Result linkAndOptimizeIR( reportCheckpointIntermediates(codeGenContext, sink, irModule); // Finalization is always run so AD-related instructions can be removed, - // even the AD pass itself is not run. + // even if the AD pass itself is not run. // finalizeAutoDiffPass(targetProgram, irModule); + eliminateDeadCode(irModule, deadCodeEliminationOptions); + + // After auto-diff, we can perform more aggressive specialization with dynamic-dispatch + // lowering. + // + if (!codeGenContext->isSpecializationDisabled()) + { + SpecializationOptions specOptions; + specOptions.lowerWitnessLookups = true; + specializeModule(targetProgram, irModule, codeGenContext->getSink(), specOptions); + } finalizeSpecialization(irModule); @@ -930,6 +952,8 @@ Result linkAndOptimizeIR( validateIRModuleIfEnabled(codeGenContext, irModule); + inferAnyValueSizeWhereNecessary(targetProgram, irModule); + // If we have any witness tables that are marked as `KeepAlive`, // but are not used for dynamic dispatch, unpin them so we don't // do unnecessary work to lower them. diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 9f26f9d552..30c14f706a 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -160,6 +160,40 @@ InstPair ForwardDiffTranscriber::transcribeReinterpret(IRBuilder* builder, IRIns return InstPair(primalVal, diffVal); } +InstPair ForwardDiffTranscriber::transcribeDifferentiableTypeAnnotation( + IRBuilder* builder, + IRInst* origInst) +{ + auto primalAnnotation = + as(maybeCloneForPrimalInst(builder, origInst)); + + IRDifferentiableTypeAnnotation* annotation = as(origInst); + + differentiableTypeConformanceContext.addTypeToDictionary( + (IRType*)primalAnnotation->getBaseType(), + primalAnnotation->getWitness()); + + auto diffType = differentiateType(builder, (IRType*)annotation->getBaseType()); + if (!diffType) + return InstPair(primalAnnotation, nullptr); + + auto diffTypeDiffWitness = + tryGetDifferentiableWitness(builder, diffType, DiffConformanceKind::Any); + + IRInst* args[] = {diffType, diffTypeDiffWitness}; + + auto diffAnnotation = builder->emitIntrinsicInst( + builder->getVoidType(), + kIROp_DifferentiableTypeAnnotation, + 2, + args); + + builder->markInstAsPrimal(diffAnnotation); + builder->markInstAsPrimal(primalAnnotation); + + return InstPair(primalAnnotation, diffAnnotation); +} + InstPair ForwardDiffTranscriber::transcribeVar(IRBuilder* builder, IRVar* origVar) { if (IRType* diffType = differentiateType(builder, origVar->getDataType()->getValueType())) @@ -752,9 +786,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig auto pairValType = as( pairPtrType ? pairPtrType->getValueType() : pairType); - auto diffType = differentiableTypeConformanceContext.getDiffTypeFromPairType( - &argBuilder, - pairValType); + auto diffType = differentiateType(&argBuilder, primalType); if (auto ptrParamType = as(diffParamType)) { // Create temp var to pass in/out arguments. @@ -795,7 +827,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig if (diffArg) { auto newDiffVal = afterBuilder.emitDifferentialPairGetDifferential( - (IRType*)diffType, + (IRType*)as(diffType)->getValueType(), newVal); markDiffTypeInst( &afterBuilder, @@ -827,17 +859,72 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig } } } + + { + // --WORKAROUND-- + // This is a temporary workaround for a very specific case.. + // + // If all the following are true: + // 1. the parameter type expects a differential pair, + // 2. the argument is derived from a no_diff type, and + // 3. the argument type is a run-time type (i.e. extract_existential_type), + // then we need to generate a differential 0, but the IR has no + // information on the diff witness. + // + // We will bypass the conformance system & brute-force the lookup for the interface + // keys, but the proper fix is to lower this key mapping during `no_diff` lowering. + // + + // Condition 1 + if (differentiableTypeConformanceContext.isDifferentiableType((originalParamType))) + { + // Condition 3 + if (auto extractExistentialType = as(primalType)) + { + // Condition 2 + if (isNoDiffType(extractExistentialType->getOperand(0)->getDataType())) + { + // Force-differentiate the type (this will perform a search for the witness + // without going through the diff-type annotation list) + // + IRInst* witnessTable = nullptr; + auto diffType = differentiateExtractExistentialType( + &argBuilder, + extractExistentialType, + witnessTable); + + auto pairType = + getOrCreateDiffPairType(&argBuilder, primalType, witnessTable); + auto zeroMethod = argBuilder.emitLookupInterfaceMethodInst( + differentiableTypeConformanceContext.sharedContext->zeroMethodType, + witnessTable, + differentiableTypeConformanceContext.sharedContext + ->zeroMethodStructKey); + auto diffZero = argBuilder.emitCallInst(diffType, zeroMethod, 0, nullptr); + auto diffPair = + argBuilder.emitMakeDifferentialPair(pairType, primalArg, diffZero); + + args.add(diffPair); + continue; + } + } + } + } + // Argument is not differentiable. // Add original/primal argument. args.add(primalArg); } IRType* diffReturnType = nullptr; - diffReturnType = tryGetDiffPairType(&argBuilder, origCall->getFullType()); + auto primalReturnType = + (IRType*)findOrTranscribePrimalInst(&argBuilder, origCall->getFullType()); + + diffReturnType = tryGetDiffPairType(&argBuilder, primalReturnType); if (!diffReturnType) { - diffReturnType = (IRType*)findOrTranscribePrimalInst(&argBuilder, origCall->getFullType()); + diffReturnType = primalReturnType; } auto callInst = argBuilder.emitCallInst(diffReturnType, diffCallee, args); @@ -1035,6 +1122,7 @@ InstPair ForwardDiffTranscriber::transcribeSpecialize( IRInst* diffBase = nullptr; if (instMapD.tryGetValue(origSpecialize->getBase(), diffBase)) { + auto diffType = differentiateType(builder, origSpecialize->getFullType()); if (diffBase) { List args; @@ -1042,11 +1130,8 @@ InstPair ForwardDiffTranscriber::transcribeSpecialize( { args.add(primalSpecialize->getArg(i)); } - auto diffSpecialize = builder->emitSpecializeInst( - builder->getTypeKind(), - diffBase, - args.getCount(), - args.getBuffer()); + auto diffSpecialize = + builder->emitSpecializeInst(diffType, diffBase, args.getCount(), args.getBuffer()); return InstPair(primalSpecialize, diffSpecialize); } else @@ -1572,7 +1657,24 @@ InstPair ForwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRFu return InstPair(origFunc, fwdDecor->getForwardDerivativeFunc()); } - auto diffFunc = transcribeFuncHeaderImpl(inBuilder, origFunc); + IRFunc* diffFunc = nullptr; + + // If we're transcribing a function as a 'value' (i.e. maybe embedded in a generic, keep the + // insert location unchanged). If we're transcribing it as a declaration, we should + // insert into the module. + // + auto origOuterGen = as(findOuterGeneric(origFunc)); + if (!origOuterGen || findInnerMostGenericReturnVal(origOuterGen) != origFunc) + { + // Dealing with a declaration.. insert into module scope. + IRBuilder subBuilder = *inBuilder; + subBuilder.setInsertInto(inBuilder->getModule()); + diffFunc = transcribeFuncHeaderImpl(&subBuilder, origFunc); + } + else + { + diffFunc = transcribeFuncHeaderImpl(inBuilder, origFunc); + } if (auto outerGen = findOuterGeneric(diffFunc)) { @@ -1605,7 +1707,6 @@ IRFunc* ForwardDiffTranscriber::transcribeFuncHeaderImpl(IRBuilder* inBuilder, I IRBuilder builder = *inBuilder; maybeMigrateDifferentiableDictionaryFromDerivativeFunc(inBuilder, origFunc); - differentiableTypeConformanceContext.setFunc(origFunc); auto diffFunc = builder.createFunc(); @@ -1632,12 +1733,6 @@ IRFunc* ForwardDiffTranscriber::transcribeFuncHeaderImpl(IRBuilder* inBuilder, I // Transfer checkpoint hint decorations copyCheckpointHints(&builder, origFunc, diffFunc); - - // Find and clone `DifferentiableTypeDictionaryDecoration` to the new diffFunc. - if (auto dictDecor = origFunc->findDecoration()) - { - cloneDecoration(&cloneEnv, dictDecor, diffFunc, diffFunc->getModule()); - } return diffFunc; } @@ -2012,6 +2107,9 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_Reinterpret: return transcribeReinterpret(builder, origInst); + case kIROp_DifferentiableTypeAnnotation: + return transcribeDifferentiableTypeAnnotation(builder, origInst); + // Differentiable insts that should have been lowered in a previous pass. case kIROp_SwizzledStore: { @@ -2138,13 +2236,10 @@ InstPair ForwardDiffTranscriber::transcribeFuncParam( if (as(diffPairType) || as(diffPairType)) { + auto diffType = differentiateType(builder, (IRType*)origParam->getFullType()); return InstPair( builder->emitDifferentialPairGetPrimal(diffPairParam), - builder->emitDifferentialPairGetDifferential( - (IRType*)differentiableTypeConformanceContext.getDiffTypeFromPairType( - builder, - as(diffPairType)), - diffPairParam)); + builder->emitDifferentialPairGetDifferential(diffType, diffPairParam)); } else if (auto pairPtrType = as(diffPairType)) { diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h index 39e195464c..09b3f14b8d 100644 --- a/source/slang/slang-ir-autodiff-fwd.h +++ b/source/slang/slang-ir-autodiff-fwd.h @@ -94,6 +94,8 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase InstPair transcribeReinterpret(IRBuilder* builder, IRInst* origInst); + InstPair transcribeDifferentiableTypeAnnotation(IRBuilder* builder, IRInst* origInst); + virtual IRFuncType* differentiateFunctionType( IRBuilder* builder, IRInst* func, diff --git a/source/slang/slang-ir-autodiff-pairs.cpp b/source/slang/slang-ir-autodiff-pairs.cpp index a49a2f7625..c732263f0c 100644 --- a/source/slang/slang-ir-autodiff-pairs.cpp +++ b/source/slang/slang-ir-autodiff-pairs.cpp @@ -13,7 +13,6 @@ struct DiffPairLoweringPass : InstPassBase IRInst* lowerPairType(IRBuilder* builder, IRType* pairType) { - builder->setInsertBefore(pairType); auto loweredPairType = pairBuilder->lowerDiffPairType(builder, pairType); return loweredPairType; } @@ -22,26 +21,81 @@ struct DiffPairLoweringPass : InstPassBase { if (auto makePairInst = as(inst)) { - bool isTrivial = false; auto pairType = as(makePairInst->getDataType()); - if (auto loweredPairType = lowerPairType(builder, pairType)) + builder->setInsertBefore(makePairInst); + if (auto loweredPairType = (IRType*)lowerPairType(builder, pairType)) { - builder->setInsertBefore(makePairInst); - IRInst* result = nullptr; - if (isTrivial) + if (isRuntimeType(pairType->getValueType())) { - result = makePairInst->getPrimalValue(); + auto result = pairBuilder->emitExistentialMakePair( + builder, + loweredPairType, + makePairInst->getPrimalValue(), + makePairInst->getDifferentialValue()); + + makePairInst->replaceUsesWith(result); + makePairInst->removeAndDeallocate(); + return result; + } + else if (auto typePack = as(pairType->getValueType())) + { + // TODO: Do we need to flatten the packs here? + + // If the type is a type pack, then the value must be in + // MakePair(MakeValuePack(p_0, p_1, ...), MakeValuePack(d_0, d_1, ...)) form + // Convert it to MakeValuePack(MakePair(p_0, d_0), MakePair(p_1, d_1), ...) + // and lower each MakePair. + // + + // Primal pack + auto primalValue = as(makePairInst->getPrimalValue()); + SLANG_ASSERT(primalValue); + + // Differential pack + auto diffValue = as(makePairInst->getDifferentialValue()); + SLANG_ASSERT(diffValue); + + // Expect the lowered pair type to be a type pack of pair types. + SLANG_ASSERT(as(loweredPairType)); + + List newValues; + for (UInt i = 0; i < typePack->getOperandCount(); i++) + { + auto primalElement = primalValue->getOperand(i); + auto diffElement = diffValue->getOperand(i); + + auto loweredElementPairType = (IRType*)loweredPairType->getOperand(i); + + IRInst* operands[] = {primalElement, diffElement}; + + auto loweredMakePair = + builder->emitMakeStruct((IRType*)loweredElementPairType, 2, operands); + + newValues.add(loweredMakePair); + } + + auto newPack = builder->emitMakeValuePack( + loweredPairType, + newValues.getCount(), + newValues.getBuffer()); + + makePairInst->replaceUsesWith(newPack); + makePairInst->removeAndDeallocate(); + return newPack; } else { + IRInst* result = nullptr; + IRInst* operands[2] = { makePairInst->getPrimalValue(), makePairInst->getDifferentialValue()}; result = builder->emitMakeStruct((IRType*)(loweredPairType), 2, operands); + + makePairInst->replaceUsesWith(result); + makePairInst->removeAndDeallocate(); + return result; } - makePairInst->replaceUsesWith(result); - makePairInst->removeAndDeallocate(); - return result; } } @@ -58,12 +112,14 @@ struct DiffPairLoweringPass : InstPassBase pairType = pairPtrType->getValueType(); } - if (lowerPairType(builder, pairType)) + builder->setInsertBefore(getDiffInst); + if (auto loweredType = lowerPairType(builder, pairType)) { - builder->setInsertBefore(getDiffInst); IRInst* diffFieldExtract = nullptr; - diffFieldExtract = - pairBuilder->emitDiffFieldAccess(builder, getDiffInst->getBase()); + diffFieldExtract = pairBuilder->emitDiffFieldAccess( + builder, + (IRType*)loweredType, + getDiffInst->getBase()); getDiffInst->replaceUsesWith(diffFieldExtract); getDiffInst->removeAndDeallocate(); return diffFieldExtract; @@ -77,13 +133,14 @@ struct DiffPairLoweringPass : InstPassBase pairType = pairPtrType->getValueType(); } - if (lowerPairType(builder, pairType)) + builder->setInsertBefore(getPrimalInst); + if (auto loweredType = lowerPairType(builder, pairType)) { - builder->setInsertBefore(getPrimalInst); - IRInst* primalFieldExtract = nullptr; - primalFieldExtract = - pairBuilder->emitPrimalFieldAccess(builder, getPrimalInst->getBase()); + primalFieldExtract = pairBuilder->emitPrimalFieldAccess( + builder, + (IRType*)loweredType, + getPrimalInst->getBase()); getPrimalInst->replaceUsesWith(primalFieldExtract); getPrimalInst->removeAndDeallocate(); return primalFieldExtract; diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp index a3f6079acb..ef51611049 100644 --- a/source/slang/slang-ir-autodiff-primal-hoist.cpp +++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp @@ -344,8 +344,18 @@ RefPtr AutodiffCheckpointPolicyBase::processFunc( continue; } + // General case: we'll add all primal operands to the work list. addPrimalOperandsToWorkList(child); + // Also add type annotations to the list, since these have to be made available to the + // function context. + // + if (as(child)) + { + checkpointInfo->recomputeSet.add(child); + addPrimalOperandsToWorkList(child); + } + // We'll be conservative with the decorations we consider as differential uses // of a primal inst, in order to avoid weird behaviour with some decorations // @@ -1333,7 +1343,7 @@ struct UseChain return result; } - void replace(IRBuilder* builder, IRInst* inst) + void replace(IROutOfOrderCloneContext* ctx, IRBuilder* builder, IRInst* inst) { SLANG_ASSERT(chain.getCount() > 0); @@ -1345,30 +1355,27 @@ struct UseChain return; } - IRCloneEnv env; - // Pop the last use, which is the base use that needs to be replaced. auto baseUse = chain.getLast(); chain.removeLast(); // Ensure that replacement inst is set as mapping for the baseUse. - env.mapOldValToNew[baseUse->get()] = inst; - - auto lastInstInChain = inst; + ctx->cloneEnv.mapOldValToNew[baseUse->get()] = inst; IRBuilder chainBuilder(builder->getModule()); setInsertAfterOrdinaryInst(&chainBuilder, inst); chain.reverse(); + chain.removeLast(); // Clone the rest of the chain. for (auto& use : chain) { - lastInstInChain = cloneInst(&env, &chainBuilder, use->get()); + ctx->cloneInstOutOfOrder(&chainBuilder, use->get()); } - // Replace the base use. - builder->replaceOperand(chain.getLast(), lastInstInChain); + // We won't actually replace the final use, because if there are multiple chains + // it can cause problems. The parent UseGraph will handle that. chain.clear(); } @@ -1380,13 +1387,93 @@ struct UseChain } }; +struct UseGraph +{ + // Set of linear paths to the base use. + // Note that some nodes may be common to multiple paths. + // + OrderedDictionary> chainSets; + + static UseGraph from( + IRInst* baseInst, + Func isRelevantUse, + Func passthroughInst) + { + UseGraph result; + for (auto use = baseInst->firstUse; use;) + { + auto nextUse = use->nextUse; + + auto chains = UseChain::from(use, isRelevantUse, passthroughInst); + for (auto& chain : chains) + { + auto finalUse = chain.chain.getFirst(); + + if (!result.chainSets.containsKey(finalUse)) + { + result.chainSets[finalUse] = List(); + } + + result.chainSets[finalUse].getValue().add(chain); + } + + use = nextUse; + } + return result; + } + + void replace(IRBuilder* builder, IRUse* use, IRInst* inst) + { + // Since we may have common nodes, we will use an out-of-order cloning context + // that can retroactively correct the uses as needed. + // + IROutOfOrderCloneContext ctx; + List chains = chainSets[use]; + for (auto chain : chains) + { + chain.replace(&ctx, builder, inst); + } + + if (!isTrivial()) + { + builder->setInsertBefore(use->getUser()); + auto lastInstInChain = ctx.cloneInstOutOfOrder(builder, use->get()); + + // Replace the base use. + builder->replaceOperand(use, lastInstInChain); + } + } + + bool isTrivial() + { + // We're trivial if there's only one chain, and it has only one use. + if (chainSets.getCount() != 1) + return false; + + auto& chain = chainSets.getFirst().value; + return chain.getCount() == 1; + } + + List getUniqueUses() const + { + List result; + + for (auto& pair : chainSets) + { + result.add(pair.key); + } + + return result; + } +}; + // Trim defBlockIndices based on the indices of out of scope uses. // static List maybeTrimIndices( const List& defBlockIndices, const Dictionary>& indexedBlockInfo, - const List& outOfScopeUses) + const List& outOfScopeUses) { // Go through uses, lookup the defBlockIndices, and remove any indices if they // are not present in any of the uses. (This is sort of slow...) @@ -1397,7 +1484,7 @@ static List maybeTrimIndices( bool found = false; for (const auto& use : outOfScopeUses) { - auto useInst = use.getUser(); + auto useInst = use->getUser(); auto useBlock = useInst->getParent(); auto useBlockIndices = indexedBlockInfo.getValue(as(useBlock)); if (useBlockIndices.contains(index)) @@ -1419,7 +1506,8 @@ bool canInstBeStored(IRInst* inst) // stored into variables or context structs as normal values. // if (as(inst->getDataType()) || as(inst->getDataType()) || - as(inst->getDataType()) || as(inst->getDataType())) + as(inst->getDataType()) || as(inst->getDataType()) || + !inst->getDataType()) return false; return true; @@ -1577,6 +1665,9 @@ RefPtr ensurePrimalAvailability( // auto isPassthroughInst = [&](IRInst* inst) { + if (as(inst)) + return false; + if (!canInstBeStored(inst)) return true; @@ -1590,16 +1681,9 @@ RefPtr ensurePrimalAvailability( return false; }; - List outOfScopeUses; - for (auto use = instToStore->firstUse; use;) - { - auto nextUse = use->nextUse; + UseGraph useGraph = UseGraph::from(instToStore, isRelevantUse, isPassthroughInst); - List useChains = UseChain::from(use, isRelevantUse, isPassthroughInst); - outOfScopeUses.addRange(useChains); - - use = nextUse; - } + List outOfScopeUses = useGraph.getUniqueUses(); if (outOfScopeUses.getCount() == 0) { @@ -1659,10 +1743,10 @@ RefPtr ensurePrimalAvailability( for (auto use : outOfScopeUses) { - setInsertBeforeOrdinaryInst(&builder, getInstInBlock(use.getUser())); + setInsertBeforeOrdinaryInst(&builder, getInstInBlock(use->getUser())); List& useBlockIndices = - indexedBlockInfo[getBlock(use.getUser())]; + indexedBlockInfo[getBlock(use->getUser())]; IRInst* loadAddr = emitIndexedLoadAddressForVar( &builder, @@ -1670,7 +1754,8 @@ RefPtr ensurePrimalAvailability( defBlock, defBlockIndices, useBlockIndices); - use.replace(&builder, loadAddr); + + useGraph.replace(&builder, use, loadAddr); } if (!isRecomputeInst) @@ -1729,11 +1814,13 @@ RefPtr ensurePrimalAvailability( for (auto use : outOfScopeUses) { + // TODO: Prevent terminator insts from being treated as passthrough.. List useBlockIndices = - indexedBlockInfo[getBlock(use.getUser())]; - setInsertBeforeOrdinaryInst(&builder, getInstInBlock(use.getUser())); - use.replace( + indexedBlockInfo[getBlock(use->getUser())]; + setInsertBeforeOrdinaryInst(&builder, getInstInBlock(use->getUser())); + useGraph.replace( &builder, + use, loadIndexedValue( &builder, localVar, diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 36093518ae..5ac4016d7d 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -203,13 +203,23 @@ IRFuncType* BackwardDiffPropagateTranscriber::differentiateFunctionType( IRInst* func, IRFuncType* funcType) { - IRType* intermediateType = - builder->getBackwardDiffIntermediateContextType(maybeFindOuterGeneric(func)); + IRType* intermediateType = nullptr; if (auto outerGeneric = findOuterGeneric(builder->getInsertLoc().getParent())) { + intermediateType = + builder->getBackwardDiffIntermediateContextType(maybeFindOuterGeneric(func)); intermediateType = (IRType*)specializeWithGeneric(*builder, intermediateType, as(outerGeneric)); } + else if (as(func)) + { + intermediateType = nullptr; + } + else + { + intermediateType = + builder->getBackwardDiffIntermediateContextType(maybeFindOuterGeneric(func)); + } return differentiateFunctionTypeImpl(builder, funcType, intermediateType); } @@ -382,14 +392,7 @@ InstPair BackwardDiffTranscriberBase::transcribeFuncHeaderImpl( IRFunc* primalFunc = origFunc; maybeMigrateDifferentiableDictionaryFromDerivativeFunc(inBuilder, origFunc); - - // The original func may not have a type dictionary if it is not originally marked as - // differentiable, in this case we would have already pulled the necessary types from - // the user-provided derivative function, so we are still fine. - if (origFunc->findDecoration()) - { - differentiableTypeConformanceContext.setFunc(origFunc); - } + differentiableTypeConformanceContext.setFunc(origFunc); auto diffFunc = builder.createFunc(); @@ -414,12 +417,7 @@ InstPair BackwardDiffTranscriberBase::transcribeFuncHeaderImpl( // Mark the generated derivative function itself as differentiable. builder.addBackwardDifferentiableDecoration(diffFunc); - // Find and clone `DifferentiableTypeDictionaryDecoration` to the new diffFunc. - if (auto dictDecor = origFunc->findDecoration()) - { - builder.setInsertBefore(diffFunc->getFirstDecorationOrChild()); - cloneInst(&cloneEnv, &builder, dictDecor); - } + copyOriginalDecorations(origFunc, diffFunc); builder.addFloatingModeOverrideDecoration(diffFunc, FloatingPointMode::Fast); return InstPair(primalFunc, diffFunc); @@ -446,7 +444,24 @@ void BackwardDiffTranscriberBase::addTranscribedFuncDecoration( InstPair BackwardDiffTranscriberBase::transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) { - auto result = transcribeFuncHeaderImpl(inBuilder, origFunc); + InstPair result; + + // If we're transcribing a function as a 'value' (i.e. maybe embedded in a generic, keep the + // insert location unchanges). If we're transcribing it as a declaration, we should + // insert into the module. + // + auto origOuterGen = as(findOuterGeneric(origFunc)); + if (!origOuterGen || !(findInnerMostGenericReturnVal(origOuterGen) == origFunc)) + { + // Dealing with a declaration.. insert into module scope. + IRBuilder subBuilder = *inBuilder; + subBuilder.setInsertInto(inBuilder->getModule()); + result = transcribeFuncHeaderImpl(&subBuilder, origFunc); + } + else + { + result = transcribeFuncHeaderImpl(inBuilder, origFunc); + } FuncBodyTranscriptionTask task; task.originalFunc = as(result.primal); diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index 1b3825a7d8..38a7a18bbd 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -256,7 +256,7 @@ IRType* AutoDiffTranscriberBase::differentiateType(IRBuilder* builder, IRType* o return nullptr; // Special-case for differentiable existential types. - if (as(origType) || as(origType)) + if (as(origType)) { if (differentiableTypeConformanceContext.lookUpConformanceForType( origType, @@ -269,6 +269,10 @@ IRType* AutoDiffTranscriberBase::differentiateType(IRBuilder* builder, IRType* o else return nullptr; } + else if (as(origType)) + { + SLANG_UNEXPECTED("unexpected associated type during auto-diff"); + } auto primalType = lookupPrimalInst(builder, origType, origType); if (primalType->getOp() == kIROp_Param && primalType->getParent() && @@ -324,9 +328,7 @@ IRType* AutoDiffTranscriberBase::_differentiateTypeImpl(IRBuilder* builder, IRTy auto primalPairType = as(primalType); return getOrCreateDiffPairType( builder, - differentiableTypeConformanceContext.getDiffTypeFromPairType( - builder, - primalPairType), + differentiateType(builder, primalPairType->getValueType()), differentiableTypeConformanceContext.getDiffTypeWitnessFromPairType( builder, primalPairType)); @@ -336,9 +338,7 @@ IRType* AutoDiffTranscriberBase::_differentiateTypeImpl(IRBuilder* builder, IRTy { auto primalPairType = as(primalType); return builder->getDifferentialPairUserCodeType( - (IRType*)differentiableTypeConformanceContext.getDiffTypeFromPairType( - builder, - primalPairType), + differentiateType(builder, primalPairType->getValueType()), differentiableTypeConformanceContext.getDiffTypeWitnessFromPairType( builder, primalPairType)); @@ -406,6 +406,7 @@ bool AutoDiffTranscriberBase::isExistentialType(IRType* type) case kIROp_ExtractExistentialType: case kIROp_InterfaceType: case kIROp_AssociatedType: + case kIROp_LookupWitness: return true; default: return false; @@ -460,47 +461,34 @@ void AutoDiffTranscriberBase::maybeMigrateDifferentiableDictionaryFromDerivative IRBuilder* builder, IRInst* origFunc) { - auto decor = origFunc->findDecoration(); - if (decor) - return; - // A differentiable func must have `IRDifferentiableTypeDictionaryDecoration`, except it has a - // `IRUserDefinedBackwardDerivativeDecoration`. - auto udfDecor = origFunc->findDecoration(); - SLANG_RELEASE_ASSERT(udfDecor); - // We need to migrate the dictionary from the backward derivative func so we can properly - // differentiate the function header. - IRBuilder subBuilder = *builder; - subBuilder.setInsertBefore(origFunc); - - auto derivative = udfDecor->getBackwardDerivativeFunc(); - if (auto specialize = as(derivative)) - { - auto derivativeGeneric = cast(specialize->getBase()); - GenericChildrenMigrationContext migrationContext; - migrationContext.init( - derivativeGeneric, - cast(findOuterGeneric(origFunc)), - origFunc); - auto derivativeFunc = findGenericReturnVal(derivativeGeneric); - auto derivativeBlock = cast(derivativeFunc->getParent()); - for (auto dInst = derivativeBlock->getFirstOrdinaryInst(); dInst != derivativeFunc; - dInst = dInst->getNextInst()) - { - migrationContext.cloneInst(&subBuilder, dInst); - } - auto udfDictDecor = - derivativeFunc->findDecoration(); - SLANG_RELEASE_ASSERT(udfDictDecor); - subBuilder.setInsertBefore(origFunc->getFirstDecorationOrChild()); - migrationContext.cloneInst(&subBuilder, udfDictDecor); - eliminateDeadCode(origFunc->getParent()); - } - else + // There's one corner case where our function may not have the differentiable type annotations. + // If the function is not declared differentiable, but has a custom derivative, we need to copy + // over any IRDifferentiableTypeAnnotation insts + if (auto udfDecor = origFunc->findDecoration()) { - auto udfDictDecor = derivative->findDecoration(); - if (udfDictDecor) + // We need to migrate the dictionary from the backward derivative func so we can properly + // differentiate the function header. + IRBuilder subBuilder = *builder; + subBuilder.setInsertBefore(origFunc); + + auto derivative = udfDecor->getBackwardDerivativeFunc(); + if (auto specialize = as(derivative)) { - cloneDecoration(udfDictDecor, origFunc); + auto derivativeGeneric = cast(specialize->getBase()); + + GenericChildrenMigrationContext migrationContext; + migrationContext.init( + derivativeGeneric, + cast(findOuterGeneric(origFunc)), + origFunc); + auto derivativeFunc = findGenericReturnVal(derivativeGeneric); + auto derivativeBlock = cast(derivativeFunc->getParent()); + for (auto dInst = derivativeBlock->getFirstOrdinaryInst(); dInst != derivativeFunc; + dInst = dInst->getNextInst()) + { + migrationContext.cloneInst(&subBuilder, dInst); + } + eliminateDeadCode(origFunc->getParent()); } } } @@ -575,8 +563,8 @@ IRType* AutoDiffTranscriberBase::tryGetDiffPairType(IRBuilder* builder, IRType* else return nullptr; } - auto diffType = differentiateType(builder, originalType); - if (diffType) + + if (tryGetDifferentiableWitness(builder, originalType, DiffConformanceKind::Any)) return (IRType*)getOrCreateDiffPairType(builder, originalType); return nullptr; } @@ -690,6 +678,15 @@ InstPair AutoDiffTranscriberBase::transcribeLookupInterfaceMethod( return InstPair(primal, diffWitness); } } + else if (as(lookupInst->getDataType())) + { + if (auto diffType = differentiableTypeConformanceContext.getDifferentialForType( + builder, + (IRType*)primalType)) + { + return InstPair(primal, diffType); + } + } auto decor = lookupInst->getRequirementKey()->findDecorationImpl( getInterfaceRequirementDerivativeDecorationOp()); @@ -997,8 +994,15 @@ InstPair AutoDiffTranscriberBase::transcribeGeneric(IRBuilder* inBuilder, IRGene if (auto innerFunc = as(innerVal)) { maybeMigrateDifferentiableDictionaryFromDerivativeFunc(inBuilder, innerFunc); - if (!innerFunc->findDecoration()) + // Is our function differentiable? + if (!(innerFunc->findDecoration() || + innerFunc->findDecoration() || + innerFunc->findDecoration() || + innerFunc->findDecoration())) + { return InstPair(origGeneric, nullptr); + } + differentiableTypeConformanceContext.setFunc(innerFunc); } else if (const auto funcType = as(innerVal)) @@ -1027,7 +1031,14 @@ InstPair AutoDiffTranscriberBase::transcribeGeneric(IRBuilder* inBuilder, IRGene IRType* diffType = nullptr; if (primalType) { - diffType = (IRType*)findOrTranscribeDiffInst(&builder, primalType); + if (as(primalType)) + { + diffType = builder.getGenericKind(); + } + else + { + diffType = (IRType*)findOrTranscribeDiffInst(&builder, primalType); + } } diffGeneric->setFullType(diffType); @@ -1110,7 +1121,6 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst mapPrimalInst(origInst, pair.primal); mapDifferentialInst(origInst, pair.differential); - if (pair.primal != pair.differential && !pair.primal->findDecoration() && !as(pair.primal)) diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index c84fd778c1..a5ed5814cb 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -623,8 +623,9 @@ struct DiffTransposePass { varInst->insertAtEnd(firstRevDiffBlock); - auto dzero = - emitDZeroOfDiffInstType(&builder, ptrPrimalType->getValueType()); + auto dzero = diffTypeContext.emitDZeroOfDiffInstType( + &builder, + ptrPrimalType->getValueType()); builder.emitStore(varInst, dzero); } else @@ -726,7 +727,9 @@ struct DiffTransposePass auto gradValue = builder->emitLoad(accVar); builder->emitStore( accVar, - emitDZeroOfDiffInstType(builder, tryGetPrimalTypeFromDiffInst(fwdInst))); + diffTypeContext.emitDZeroOfDiffInstType( + builder, + tryGetPrimalTypeFromDiffInst(fwdInst))); return gradValue; } @@ -760,7 +763,7 @@ struct DiffTransposePass auto primalType = tryGetPrimalTypeFromDiffInst(fwdInst); auto diffType = fwdInst->getDataType(); - auto zero = emitDZeroOfDiffInstType(&tempVarBuilder, primalType); + auto zero = diffTypeContext.emitDZeroOfDiffInstType(&tempVarBuilder, primalType); // Emit a var in the top-level differential block to hold the gradient, // and initialize it. @@ -925,8 +928,9 @@ struct DiffTransposePass } else { - phiParamRevGradInsts.add( - emitDZeroOfDiffInstType(&builder, tryGetPrimalTypeFromDiffInst(param))); + phiParamRevGradInsts.add(diffTypeContext.emitDZeroOfDiffInstType( + &builder, + tryGetPrimalTypeFromDiffInst(param))); } } else @@ -1177,7 +1181,8 @@ struct DiffTransposePass auto pairType = as(arg->getDataType()); auto var = builder->emitVar(arg->getDataType()); - auto diffZero = emitDZeroOfDiffInstType(builder, pairType->getValueType()); + auto diffZero = + diffTypeContext.emitDZeroOfDiffInstType(builder, pairType->getValueType()); // Initialize this var to (arg.primal, 0). builder->emitStore( @@ -1236,7 +1241,13 @@ struct DiffTransposePass argRequiresLoad.add(false); } - auto revFnType = builder->getFuncType(argTypes, builder->getVoidType()); + + auto revFnType = + this->autodiffContext->transcriberSet.propagateTranscriber->differentiateFunctionType( + builder, + getResolvedInstForDecorations(baseFn), + baseFnType); + IRInst* revCallee = nullptr; if (getResolvedInstForDecorations(baseFn)->getOp() == kIROp_LookupWitness) { @@ -1615,7 +1626,7 @@ struct DiffTransposePass SLANG_ASSERT(primalType); // Clear the value at the differential address, by setting to 0. - IRInst* emptyVal = emitDZeroOfDiffInstType(builder, primalType); + IRInst* emptyVal = diffTypeContext.emitDZeroOfDiffInstType(builder, primalType); builder->emitStore(fwdStore->getPtr(), emptyVal); if (auto diffPairType = as(revVal->getDataType())) @@ -2071,7 +2082,7 @@ struct DiffTransposePass auto primalElementTypeDecor = updateInst->findDecoration(); SLANG_RELEASE_ASSERT(primalElementTypeDecor); - auto diffZero = emitDZeroOfDiffInstType( + auto diffZero = diffTypeContext.emitDZeroOfDiffInstType( builder, (IRType*)primalElementTypeDecor->getPrimalElementType()); SLANG_ASSERT(diffZero); @@ -2350,16 +2361,18 @@ struct DiffTransposePass { auto primalCondition = fwdInst->getOperand(0); - auto leftZero = - emitDZeroOfDiffInstType(builder, tryGetPrimalTypeFromDiffInst(fwdInst->getOperand(1))); + auto leftZero = diffTypeContext.emitDZeroOfDiffInstType( + builder, + tryGetPrimalTypeFromDiffInst(fwdInst->getOperand(1))); auto leftGradientInst = builder->emitIntrinsicInst( fwdInst->getOperand(1)->getDataType(), kIROp_Select, 3, List(primalCondition, revValue, leftZero).getBuffer()); - auto rightZero = - emitDZeroOfDiffInstType(builder, tryGetPrimalTypeFromDiffInst(fwdInst->getOperand(2))); + auto rightZero = diffTypeContext.emitDZeroOfDiffInstType( + builder, + tryGetPrimalTypeFromDiffInst(fwdInst->getOperand(2))); auto rightGradientInst = builder->emitIntrinsicInst( fwdInst->getOperand(2)->getDataType(), kIROp_Select, @@ -2527,7 +2540,8 @@ struct DiffTransposePass List zeroElements; for (Index i = 0; i < elementCount; ++i) { - auto zeroElement = emitDZeroOfDiffInstType(builder, primalElementTypes[i]); + auto zeroElement = + diffTypeContext.emitDZeroOfDiffInstType(builder, primalElementTypes[i]); elementGrads.add(zeroElement); zeroElements.add(zeroElement); } @@ -2537,8 +2551,11 @@ struct DiffTransposePass if (elementGrads[i] == zeroElements[i]) elementGrads[i] = grad; else - elementGrads[i] = - emitDAddOfDiffInstType(builder, primalElementTypes[i], elementGrads[i], grad); + elementGrads[i] = diffTypeContext.emitDAddOfDiffInstType( + builder, + primalElementTypes[i], + elementGrads[i], + grad); }; for (auto gradient : gradients) @@ -2624,7 +2641,7 @@ struct DiffTransposePass gradient.targetInst, builder->emitMakeDifferentialPairUserCode( baseType, - emitDZeroOfDiffInstType(builder, baseType->getValueType()), + diffTypeContext.emitDZeroOfDiffInstType(builder, baseType->getValueType()), gradient.revGradInst), gradient.fwdGradInst)); } @@ -2640,7 +2657,9 @@ struct DiffTransposePass builder->emitMakeDifferentialPairUserCode( baseType, gradient.revGradInst, - emitDZeroOfDiffInstType(builder, fwdGetPrimal->getFullType())), + diffTypeContext.emitDZeroOfDiffInstType( + builder, + fwdGetPrimal->getFullType())), gradient.fwdGradInst)); } } @@ -2694,7 +2713,7 @@ struct DiffTransposePass (IRType*)diffTypeContext.getDifferentialForType(builder, aggPrimalType)); // Initialize with T.dzero() - auto zeroValueInst = emitDZeroOfDiffInstType(builder, aggPrimalType); + auto zeroValueInst = diffTypeContext.emitDZeroOfDiffInstType(builder, aggPrimalType); builder->emitStore(revGradVar, zeroValueInst); @@ -2764,7 +2783,7 @@ struct DiffTransposePass (IRType*)diffTypeContext.getDifferentialForType(builder, aggPrimalType)); // Initialize with T.dzero() - auto zeroValueInst = emitDZeroOfDiffInstType(builder, aggPrimalType); + auto zeroValueInst = diffTypeContext.emitDZeroOfDiffInstType(builder, aggPrimalType); builder->emitStore(revGradVar, zeroValueInst); @@ -2839,8 +2858,11 @@ struct DiffTransposePass continue; } - currentValue = - emitDAddOfDiffInstType(builder, aggPrimalType, currentValue, gradient.revGradInst); + currentValue = diffTypeContext.emitDAddOfDiffInstType( + builder, + aggPrimalType, + currentValue, + gradient.revGradInst); } return RevGradient( @@ -2919,7 +2941,7 @@ struct DiffTransposePass if (aggDiffType != nullptr) { // If type is non-null/non-void, call T.dzero() to produce a 0 gradient. - return emitDZeroOfDiffInstType(builder, aggPrimalType); + return diffTypeContext.emitDZeroOfDiffInstType(builder, aggPrimalType); } else { @@ -2951,146 +2973,6 @@ struct DiffTransposePass return nullptr; } - IRInst* emitDZeroOfDiffInstType(IRBuilder* builder, IRType* primalType) - { - if (auto arrayType = as(primalType)) - { - auto diffElementType = (IRType*)diffTypeContext.getDifferentialForType( - builder, - arrayType->getElementType()); - SLANG_RELEASE_ASSERT(diffElementType); - auto diffArrayType = - builder->getArrayType(diffElementType, arrayType->getElementCount()); - auto diffElementZero = emitDZeroOfDiffInstType(builder, arrayType->getElementType()); - return builder->emitMakeArrayFromElement(diffArrayType, diffElementZero); - } - else if (auto diffPairUserType = as(primalType)) - { - auto primalZero = emitDZeroOfDiffInstType(builder, diffPairUserType->getValueType()); - auto diffZero = primalZero; - auto diffType = primalZero->getFullType(); - auto diffWitness = - diffTypeContext.getDiffTypeWitnessFromPairType(builder, diffPairUserType); - auto diffDiffPairType = builder->getDifferentialPairUserCodeType(diffType, diffWitness); - return builder->emitMakeDifferentialPairUserCode( - diffDiffPairType, - primalZero, - diffZero); - } - else if (as(primalType) || as(primalType)) - { - // Pack a null value into an existential type. - auto existentialZero = builder->emitMakeExistential( - autodiffContext->differentiableInterfaceType, - diffTypeContext.emitNullDifferential(builder), - autodiffContext->nullDifferentialWitness); - - return existentialZero; - } - - auto zeroMethod = diffTypeContext.getZeroMethodForType(builder, primalType); - - // Should exist. - SLANG_ASSERT(zeroMethod); - - return builder->emitCallInst( - (IRType*)diffTypeContext.getDifferentialForType(builder, primalType), - zeroMethod, - List()); - } - - IRInst* emitDAddForExistentialType( - IRBuilder* builder, - IRType* primalType, - IRInst* op1, - IRInst* op2) - { - auto existentialDAddFunc = diffTypeContext.getOrCreateExistentialDAddMethod(); - - // Should exist. - SLANG_ASSERT(existentialDAddFunc); - - return builder->emitCallInst( - (IRType*)diffTypeContext.getDifferentialForType(builder, primalType), - existentialDAddFunc, - List({op1, op2})); - } - - IRInst* emitDAddOfDiffInstType(IRBuilder* builder, IRType* primalType, IRInst* op1, IRInst* op2) - { - if (auto arrayType = as(primalType)) - { - auto diffElementType = (IRType*)diffTypeContext.getDifferentialForType( - builder, - arrayType->getElementType()); - SLANG_RELEASE_ASSERT(diffElementType); - auto arraySize = arrayType->getElementCount(); - - if (auto constArraySize = as(arraySize)) - { - List args; - for (IRIntegerValue i = 0; i < constArraySize->getValue(); i++) - { - auto index = builder->getIntValue(builder->getIntType(), i); - auto op1Val = builder->emitElementExtract(diffElementType, op1, index); - auto op2Val = builder->emitElementExtract(diffElementType, op2, index); - args.add(emitDAddOfDiffInstType( - builder, - arrayType->getElementType(), - op1Val, - op2Val)); - } - auto diffArrayType = - builder->getArrayType(diffElementType, arrayType->getElementCount()); - return builder->emitMakeArray( - diffArrayType, - (UInt)args.getCount(), - args.getBuffer()); - } - else - { - // TODO: insert a runtime loop here. - SLANG_UNIMPLEMENTED_X("dadd of dynamic array."); - } - } - else if (auto diffPairUserType = as(primalType)) - { - auto diffType = - (IRType*)diffTypeContext.getDiffTypeFromPairType(builder, diffPairUserType); - auto diffWitness = - diffTypeContext.getDiffTypeWitnessFromPairType(builder, diffPairUserType); - - auto primal1 = builder->emitDifferentialPairGetPrimalUserCode(op1); - auto primal2 = builder->emitDifferentialPairGetPrimalUserCode(op2); - auto primal = - emitDAddOfDiffInstType(builder, diffPairUserType->getValueType(), primal1, primal2); - - auto diff1 = builder->emitDifferentialPairGetDifferentialUserCode(diffType, op1); - auto diff2 = builder->emitDifferentialPairGetDifferentialUserCode(diffType, op2); - auto diff = emitDAddOfDiffInstType(builder, diffType, diff1, diff2); - - auto diffDiffPairType = builder->getDifferentialPairUserCodeType(diffType, diffWitness); - return builder->emitMakeDifferentialPairUserCode(diffDiffPairType, primal, diff); - } - else if (as(primalType) || as(primalType)) - { - // If our type is existential, we need to handle the case where - // one or both of our operands are null-type. - // - return emitDAddForExistentialType(builder, primalType, op1, op2); - } - - auto addMethod = diffTypeContext.getAddMethodForType(builder, primalType); - - // Should exist. - SLANG_ASSERT(addMethod); - - return builder->emitCallInst( - (IRType*)diffTypeContext.getDifferentialForType(builder, primalType), - addMethod, - List(op1, op2)); - } - void addRevGradientForFwdInst(IRInst* fwdInst, RevGradient assignment) { if (!hasRevGradients(fwdInst)) diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index 9ee2cb4d2a..49c1d9ff7e 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -93,6 +93,22 @@ struct ExtractPrimalFuncContext as(findOuterGeneric(destFunc)), destFunc); + if (auto origGeneric = as(findOuterGeneric(originalFunc))) + { + // Clone in everything else except the return value. + IRBuilder subBuilder(destFunc); + builder.setInsertAfter(findOuterGeneric(destFunc)->getFirstBlock()->getLastParam()); + + // Clone in any hoistable insts. + for (auto child = origGeneric->getFirstBlock()->getFirstOrdinaryInst(); child; + child = child->getNextInst()) + { + if ((child != originalFunc) && !as(child) && + !as(child)) + migrationContext.cloneInst(&subBuilder, child); + } + } + originalFuncType = as(originalFunc->getDataType()); SLANG_RELEASE_ASSERT(originalFuncType); diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h index 31c4dbf91e..556fb58a8a 100644 --- a/source/slang/slang-ir-autodiff-unzip.h +++ b/source/slang/slang-ir-autodiff-unzip.h @@ -201,49 +201,61 @@ struct DiffUnzipPass return nullptr; } - InstPair splitCall(IRBuilder* primalBuilder, IRBuilder* diffBuilder, IRCall* mixedCall) + IRInst* getIntermediateType(IRBuilder* builder, IRInst* baseFn) { - IRBuilder globalBuilder(autodiffContext->moduleInst->getModule()); - - auto fwdCalleeType = mixedCall->getCallee()->getDataType(); - auto baseFn = _getOriginalFunc(mixedCall); - SLANG_RELEASE_ASSERT(baseFn); - - auto primalFuncType = autodiffContext->transcriberSet.primalTranscriber->transcribe( - primalBuilder, - baseFn->getDataType()); - - IRInst* intermediateType = nullptr; - - if (auto specialize = as(baseFn)) + if (as(baseFn)) + { + return builder->getVoidType(); + } + else if (auto specialize = as(baseFn)) { + if (as(specialize->getBase())) + return builder->getVoidType(); + auto func = findSpecializeReturnVal(specialize); - auto outerGen = findOuterGeneric(func); - if (func->getOp() == kIROp_LookupWitness) + if (as(func)) { // An interface method won't have intermediate type. - intermediateType = primalBuilder->getVoidType(); + return builder->getVoidType(); } else { - intermediateType = primalBuilder->getBackwardDiffIntermediateContextType(outerGen); + auto outerGen = findOuterGeneric(func); + auto innerIntermediateType = + builder->getBackwardDiffIntermediateContextType(outerGen); + List args; for (UInt i = 0; i < specialize->getArgCount(); i++) args.add(specialize->getArg(i)); - intermediateType = primalBuilder->emitSpecializeInst( - primalBuilder->getTypeKind(), - intermediateType, + + return builder->emitSpecializeInst( + builder->getTypeKind(), + innerIntermediateType, args.getCount(), args.getBuffer()); } } else { - if (baseFn->getOp() == kIROp_LookupWitness) - intermediateType = primalBuilder->getVoidType(); - else - intermediateType = primalBuilder->getBackwardDiffIntermediateContextType(baseFn); + return builder->getBackwardDiffIntermediateContextType(baseFn); } + } + + InstPair splitCall(IRBuilder* primalBuilder, IRBuilder* diffBuilder, IRCall* mixedCall) + { + IRBuilder globalBuilder(autodiffContext->moduleInst->getModule()); + + auto fwdCalleeType = mixedCall->getCallee()->getDataType(); + auto baseFn = _getOriginalFunc(mixedCall); + SLANG_RELEASE_ASSERT(baseFn); + + auto primalFuncType = + autodiffContext->transcriberSet.primalTranscriber->differentiateFunctionType( + primalBuilder, + baseFn, + as(baseFn->getDataType())); + + IRInst* intermediateType = getIntermediateType(primalBuilder, baseFn); IRVar* intermediateVar = nullptr; if (!as(intermediateType)) @@ -314,8 +326,8 @@ struct DiffUnzipPass auto arg = mixedCall->getArg(ii); // Depending on the type and direction of each argument, - // we might need to prepare a different value for the transposition logic to produce the - // correct final argument in the propagate function call. + // we might need to prepare a different value for the transposition logic to produce + // the correct final argument in the propagate function call. if (isRelevantDifferentialPair(arg->getDataType())) { auto primalArg = lookupPrimalInst(arg); @@ -328,13 +340,13 @@ struct DiffUnzipPass if (const auto outType = as(primalParamType)) { - // For `out` parameters that expects an input derivative to propagate through, - // we insert a `LoadReverseGradient` inst here to signify the logic in - // `transposeStore` that this argument should actually be the currently + // For `out` parameters that expects an input derivative to propagate + // through, we insert a `LoadReverseGradient` inst here to signify the logic + // in `transposeStore` that this argument should actually be the currently // accumulated derivative on this variable. The end purpose is that we will // generate a load(diffArg) in the final transposed code and use that as the - // argument for the call, but we can't just emit a normal load inst here because - // the transposition logic will turn loads into stores. + // argument for the call, but we can't just emit a normal load inst here + // because the transposition logic will turn loads into stores. auto outDiffType = cast(diffArg->getDataType())->getValueType(); auto gradArg = diffBuilder->emitLoadReverseGradient(outDiffType, diffArg); diffBuilder->markInstAsDifferential(gradArg, primalArg->getDataType()); @@ -342,23 +354,24 @@ struct DiffUnzipPass } else if (const auto inoutType = as(primalParamType)) { - // Since arg is split into separate vars, we need a new temp var that represents - // the remerged diff pair. + // Since arg is split into separate vars, we need a new temp var that + // represents the remerged diff pair. auto diffPairType = as( as(arg->getDataType())->getValueType()); auto primalValueType = diffPairType->getValueType(); // We can't simply reuse primalArg for an inout parameter since this will // represent the value after the primal call which can potentially alter - // primalArg. Therefore, we will find the first store into primalArg, and create - // a temp var holding that value (i.e. value prior to primal call) + // primalArg. Therefore, we will find the first store into primalArg, and + // create a temp var holding that value (i.e. value prior to primal call) // auto storeUse = findUniqueStoredVal(cast(primalArg)); auto storeInst = cast(storeUse->getUser()); auto storedVal = storeInst->getVal(); - // Emit the temp var into the primal blocks since it's holding a primal value. + // Emit the temp var into the primal blocks since it's holding a primal + // value. auto tempPrimalVar = primalBuilder->emitVar(primalValueType); primalBuilder->emitStore(tempPrimalVar, storedVal); @@ -407,8 +420,8 @@ struct DiffUnzipPass // For pure 'in' type. Simply re-use the original argument inst. // // For 'out' type parameters, it doesn't really matter what we pass in here, - // since the tranposition logic will discard the argument anyway (we'll pass in - // the old arg, just to keep the number of arguments consistent) + // since the tranposition logic will discard the argument anyway (we'll pass + // in the old arg, just to keep the number of arguments consistent) // diffArgs.add(arg); } diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 4edd8eabe6..7507e2fac1 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -230,9 +230,6 @@ IRInst* DifferentialPairTypeBuilder::emitFieldAccessor( } else if (auto specializedType = as(pairType)) { - // TODO: Stopped here -> The type being emitted is incorrect. don't emit the generic's - // type, emit the specialization type. - // auto genericType = findInnerMostGenericReturnVal(as(specializedType->getBase())); if (auto genericBasePairStructType = as(genericType)) { @@ -263,14 +260,142 @@ IRInst* DifferentialPairTypeBuilder::emitFieldAccessor( return nullptr; } -IRInst* DifferentialPairTypeBuilder::emitPrimalFieldAccess(IRBuilder* builder, IRInst* baseInst) +bool isExistentialOrRuntimeInst(IRInst* inst) { - return emitFieldAccessor(builder, baseInst, this->globalPrimalKey); + if (auto lookup = as(inst)) + { + return isExistentialOrRuntimeInst(lookup->getWitnessTable()); + } + + return as(inst) || as(inst) || + as(inst) || as(inst->getDataType()); } -IRInst* DifferentialPairTypeBuilder::emitDiffFieldAccess(IRBuilder* builder, IRInst* baseInst) +bool isRuntimeType(IRType* type) { - return emitFieldAccessor(builder, baseInst, this->globalDiffKey); + if (as(type)) + return true; + + if (auto lookup = as(type)) + { + return isExistentialOrRuntimeInst(lookup->getWitnessTable()); + } + + return false; +} + +IRInst* getExistentialBaseWitnessTable(IRBuilder* builder, IRType* type) +{ + if (auto lookupWitnessMethod = as(type)) + { + return lookupWitnessMethod->getWitnessTable(); + } + else if (auto extractExistentialType = as(type)) + { + return builder->emitExtractExistentialWitnessTable(extractExistentialType->getOperand(0)); + } + else + { + SLANG_UNEXPECTED("Unexpected existential type"); + } +} + +IRInst* getCacheKey(IRBuilder* builder, IRInst* primalType) +{ + if (auto lookupWitness = as(primalType)) + return lookupWitness->getRequirementKey(); + else if (auto extractExistentialType = as(primalType)) + { + auto interfaceType = extractExistentialType->getOperand(0)->getDataType(); + + // We will cache on the interface's this-type, since the interface type itself can be + // deallocated during the lowering process. + // + return builder->getThisType(interfaceType); + } + + return primalType; +} + +IRInst* DifferentialPairTypeBuilder::emitExistentialMakePair( + IRBuilder* builder, + IRInst* pairType, + IRInst* primalInst, + IRInst* diffInst) +{ + auto baseWitness = getExistentialBaseWitnessTable(builder, (IRType*)pairType); + + auto pairTypeKey = cast(pairType)->getRequirementKey(); + auto makePairKey = makePairKeyMap[pairTypeKey]; + + auto makePairMethod = builder->emitLookupInterfaceMethodInst( + makePairFuncTypeMap[makePairKey], + baseWitness, + makePairKey); + + List args; + args.add(primalInst); + args.add(diffInst); + + auto makePairVal = builder->emitCallInst((IRType*)pairType, makePairMethod, args); + + return makePairVal; +} + +IRInst* DifferentialPairTypeBuilder::emitPrimalFieldAccess( + IRBuilder* builder, + IRType* loweredPairType, + IRInst* baseInst) +{ + if (isRuntimeType(loweredPairType)) + { + auto baseWitness = getExistentialBaseWitnessTable(builder, (IRType*)loweredPairType); + + auto pairTypeKey = cast(loweredPairType)->getRequirementKey(); + auto getPrimalKey = getPrimalKeyMap[pairTypeKey]; + + auto primalFieldMethod = builder->emitLookupInterfaceMethodInst( + getPrimalFuncTypeMap[getPrimalKey], + baseWitness, + getPrimalKey); + + auto primalFieldVal = + builder->emitCallInst(primalTypeMap[loweredPairType], primalFieldMethod, baseInst); + + return primalFieldVal; + } + else + { + return emitFieldAccessor(builder, baseInst, this->globalPrimalKey); + } +} + +IRInst* DifferentialPairTypeBuilder::emitDiffFieldAccess( + IRBuilder* builder, + IRType* loweredPairType, + IRInst* baseInst) +{ + if (isRuntimeType(loweredPairType)) + { + auto baseWitness = getExistentialBaseWitnessTable(builder, (IRType*)loweredPairType); + + auto pairTypeKey = cast(loweredPairType)->getRequirementKey(); + auto getDiffKey = getDiffKeyMap[pairTypeKey]; + + auto diffFieldMethod = builder->emitLookupInterfaceMethodInst( + getDiffFuncTypeMap[getDiffKey], + baseWitness, + getDiffKey); + + auto diffFieldVal = + builder->emitCallInst(diffTypeMap[loweredPairType], diffFieldMethod, baseInst); + + return diffFieldVal; + } + else + { + return emitFieldAccessor(builder, baseInst, this->globalDiffKey); + } } IRStructKey* DifferentialPairTypeBuilder::_getOrCreateDiffStructKey() @@ -307,6 +432,380 @@ IRStructKey* DifferentialPairTypeBuilder::_getOrCreatePrimalStructKey() return this->globalPrimalKey; } +IRInst* DifferentialPairTypeBuilder::getOrCreateCommonDiffPairInterface(IRBuilder* builder) +{ + if (!this->commonDiffPairInterface) + { + this->commonDiffPairInterface = builder->createInterfaceType(0, nullptr); + builder->addNameHintDecoration( + this->commonDiffPairInterface, + UnownedStringSlice("IDiffPair")); + } + + return this->commonDiffPairInterface; +} + +IRInst* DifferentialPairTypeBuilder::_createDiffPairInterfaceRequirement( + IRType* origBaseType, + IRType*) +{ + // We will create an interface requirement for the type's pair & then create implementations in + // all the implementing witness tables. + // + + IRBuilder builder(sharedContext->moduleInst); + + // Find the right interface to put the requirement in. + IRInterfaceType* interfaceType = nullptr; + + // Find the effective type to put in the requirement entry + // for the base type + // + IRType* requirementBaseType = nullptr; + + // Requirement key (only used for associated types) + // + IRInst* requirementKey = nullptr; + + // Add a name hint to the key. + StringBuilder nameBuilderReqKey; + nameBuilderReqKey << "DiffPair_Req_"; + + if (auto lookup = as(origBaseType)) + { + interfaceType = + cast(cast(lookup->getWitnessTable()->getDataType()) + ->getConformanceType()); + + requirementBaseType = + cast(findInterfaceRequirement(interfaceType, lookup->getRequirementKey())); + + requirementKey = lookup->getRequirementKey(); + + if (auto nameHint = lookup->getRequirementKey()->findDecoration()) + { + nameBuilderReqKey << nameHint->getName(); + } + else + { + nameBuilderReqKey << "unknown_assoc_type"; + } + } + else if (auto extractType = as(origBaseType)) + { + auto existentialType = extractType->getOperand(0); + interfaceType = cast(existentialType->getDataType()); + requirementBaseType = builder.getThisType(interfaceType); + + requirementKey = nullptr; + + if (auto nameHint = interfaceType->findDecoration()) + { + nameBuilderReqKey << nameHint->getName(); + } + else + { + nameBuilderReqKey << "unknown_interface_type"; + } + } + else + { + SLANG_UNEXPECTED("Unexpected type for differential pair interface requirement"); + } + + auto diffPairInterfaceType = + cast(getOrCreateCommonDiffPairInterface(&builder)); + + // Add 4 requirements to the interface: + // the associated pair type, getPrimal, getDiff & makePair + // + builder.setInsertInto(interfaceType); + IRStructKey* diffPairRequirementKey = builder.createStructKey(); + IRStructKey* getPrimalRequirementKey = builder.createStructKey(); + IRStructKey* getDiffRequirementKey = builder.createStructKey(); + IRStructKey* makePairRequirementKey = builder.createStructKey(); + + makePairKeyMap[diffPairRequirementKey] = makePairRequirementKey; + getPrimalKeyMap[diffPairRequirementKey] = getPrimalRequirementKey; + getDiffKeyMap[diffPairRequirementKey] = getDiffRequirementKey; + + List entries; + + // Add all the old requirements to the new interface. + for (UInt i = 0; i < interfaceType->getOperandCount(); i++) + entries.add(interfaceType->getOperand(i)); + + // + // Create the new requirement entries. + // + + { + // Create & insert the requirement key. + List constraintTypes; + constraintTypes.add(diffPairInterfaceType); + auto entry = builder.createInterfaceRequirementEntry( + diffPairRequirementKey, + builder.getAssociatedType(constraintTypes.getArrayView())); + + builder.addNameHintDecoration(diffPairRequirementKey, nameBuilderReqKey.getUnownedSlice()); + entries.add(entry); + } + + { + // Create & insert the getPrimal requirement. + + List paramTypes; + List paramConstraintTypes; + paramConstraintTypes.add(diffPairInterfaceType); + paramTypes.add(builder.getAssociatedType(paramConstraintTypes.getArrayView())); + + auto entryFuncType = builder.getFuncType(paramTypes, requirementBaseType); + auto entry = + builder.createInterfaceRequirementEntry(getPrimalRequirementKey, entryFuncType); + + getPrimalFuncTypeMap[getPrimalRequirementKey] = entryFuncType; + + StringBuilder entryNameBuilder; + entryNameBuilder << nameBuilderReqKey.getUnownedSlice() << "_getPrimal"; + builder.addNameHintDecoration(entry, entryNameBuilder.getUnownedSlice()); + + entries.add(entry); + } + + { + // Create & insert the getDiff requirement. + + List paramTypes; + List paramConstraintTypes; + paramConstraintTypes.add(diffPairInterfaceType); + paramTypes.add(builder.getAssociatedType(paramConstraintTypes.getArrayView())); + + List resultConstraintTypes; + resultConstraintTypes.add(sharedContext->differentiableInterfaceType); + auto resultType = builder.getAssociatedType(resultConstraintTypes.getArrayView()); + + auto entryFuncType = builder.getFuncType(paramTypes, resultType); + auto entry = builder.createInterfaceRequirementEntry(getDiffRequirementKey, entryFuncType); + + getDiffFuncTypeMap[getDiffRequirementKey] = entryFuncType; + + StringBuilder entryNameBuilder; + entryNameBuilder << nameBuilderReqKey.getUnownedSlice() << "_getDiff"; + builder.addNameHintDecoration(entry, entryNameBuilder.getUnownedSlice()); + + entries.add(entry); + } + + { + // Create & insert the makePair requirement. + + List paramTypes; + paramTypes.add(requirementBaseType); + + List paramConstraintTypes; + paramConstraintTypes.add(sharedContext->differentiableInterfaceType); + paramTypes.add(builder.getAssociatedType(paramConstraintTypes.getArrayView())); + + List resultConstraintTypes; + resultConstraintTypes.add(diffPairInterfaceType); + auto entryFuncType = builder.getFuncType( + paramTypes, + builder.getAssociatedType(resultConstraintTypes.getArrayView())); + auto entry = builder.createInterfaceRequirementEntry(makePairRequirementKey, entryFuncType); + + makePairFuncTypeMap[makePairRequirementKey] = entryFuncType; + + StringBuilder entryNameBuilder; + entryNameBuilder << nameBuilderReqKey.getUnownedSlice() << "_makePair"; + builder.addNameHintDecoration(entry, entryNameBuilder.getUnownedSlice()); + + entries.add(entry); + } + + { + // Create the new interface type. + + auto newInterfaceType = + builder.createInterfaceType(entries.getCount(), entries.getBuffer()); + + // Transfer decorations from the old interface to the new one. + interfaceType->transferDecorationsTo(newInterfaceType); + interfaceType->replaceUsesWith(newInterfaceType); + + // Replace the interface maps in the caches. + if (this->pairTypeCache.containsKey(interfaceType)) + this->pairTypeCache[newInterfaceType] = this->pairTypeCache[interfaceType]; + + if (this->existentialPairTypeCache.containsKey(interfaceType)) + this->existentialPairTypeCache[newInterfaceType] = + this->existentialPairTypeCache[interfaceType]; + + interfaceType->removeAndDeallocate(); + interfaceType = newInterfaceType; + } + + // + // Implement the requirements in all the witness tables. + // + + // Collect all witness tables of the given interfaceType. + List concreteWitnessTables; + auto witnessTableType = builder.getWitnessTableType(interfaceType); + for (auto use = witnessTableType->firstUse; use; use = use->nextUse) + { + if (auto witnessTable = as(use->getUser())) + { + if (use->getUser()->getFullType() == witnessTableType) + concreteWitnessTables.add(witnessTable); + } + } + + DifferentiableTypeConformanceContext ctx(sharedContext); + ctx.buildGlobalWitnessDictionary(); + + for (auto concreteWitnessTable : concreteWitnessTables) + { + IRType* concretePrimalType = nullptr; + + // What requirement are we trying to satisfy? + if (as(requirementBaseType)) + { + // For this types, we should lower the concrete type of the witness table itself. + concretePrimalType = concreteWitnessTable->getConcreteType(); + } + else if (as(requirementBaseType)) + { + // For associated types, look it up in the witness table. + concretePrimalType = + (IRType*)findWitnessTableEntry(concreteWitnessTable, requirementKey); + } + else + { + // We shouldn't see any other case here. + SLANG_UNEXPECTED("Unexpected requirement base type"); + } + + // Create the pair type. + auto witness = ctx.tryGetDifferentiableWitness( + &builder, + concretePrimalType, + DiffConformanceKind::Value); + + // Really should not see a case where the original interface is differentiable, but + // we can't find the witness table. + // + SLANG_ASSERT(witness); + + auto concretePairType = builder.getDifferentialPairType( + concretePrimalType, + witness); // TODO: Need to handle the other conformance kinds + auto concreteDiffType = + (IRType*)_getDiffTypeFromPairType(sharedContext, &builder, concretePairType); + + auto loweredStructType = (IRType*)lowerDiffPairType(&builder, concretePairType); + + // Create an (empty) witness table for loweredStuctType : IDiffPair_... + // This is just so that there is a bound on the any-value-size for each group of pair types. + // + auto witnessTable = builder.createWitnessTable(diffPairInterfaceType, loweredStructType); + builder.addKeepAliveDecoration(witnessTable); + + builder.setInsertInto(concreteWitnessTable); + + // Create the associated type entry. + { + builder.createWitnessTableEntry( + concreteWitnessTable, + diffPairRequirementKey, + loweredStructType); + } + + // Create the getPrimal method. + { + auto primalMethod = builder.createFunc(); + + StringBuilder nameBuilder; + getTypeNameHint(nameBuilder, loweredStructType); + nameBuilder << "_getPrimal"; + builder.addNameHintDecoration(primalMethod, nameBuilder.getUnownedSlice()); + + primalMethod->setFullType(builder.getFuncType( + List({(IRType*)loweredStructType}), + concretePrimalType)); + + builder.setInsertInto(primalMethod); + auto block = builder.emitBlock(); + builder.setInsertInto(block); + auto param = builder.emitParam((IRType*)loweredStructType); + builder.emitReturn( + builder.emitFieldExtract(concretePrimalType, param, _getOrCreatePrimalStructKey())); + + builder.setInsertInto(concreteWitnessTable); + builder.createWitnessTableEntry( + concreteWitnessTable, + getPrimalRequirementKey, + primalMethod); + } + + // Create the getDiff method. + { + auto diffMethod = builder.createFunc(); + + StringBuilder nameBuilder; + getTypeNameHint(nameBuilder, loweredStructType); + nameBuilder << "_getDiff"; + builder.addNameHintDecoration(diffMethod, nameBuilder.getUnownedSlice()); + + diffMethod->setFullType( + builder.getFuncType(List({(IRType*)loweredStructType}), concreteDiffType)); + + builder.setInsertInto(diffMethod); + auto block = builder.emitBlock(); + builder.setInsertInto(block); + auto param = builder.emitParam((IRType*)loweredStructType); + builder.emitReturn( + builder.emitFieldExtract(concreteDiffType, param, _getOrCreateDiffStructKey())); + + builder.setInsertInto(concreteWitnessTable); + builder.createWitnessTableEntry( + concreteWitnessTable, + getDiffRequirementKey, + diffMethod); + } + + // Create the makePair method. + { + auto makePairMethod = builder.createFunc(); + + StringBuilder nameBuilder; + getTypeNameHint(nameBuilder, loweredStructType); + nameBuilder << "_makePair"; + builder.addNameHintDecoration(makePairMethod, nameBuilder.getUnownedSlice()); + + makePairMethod->setFullType(builder.getFuncType( + List({concretePrimalType, concreteDiffType}), + (IRType*)loweredStructType)); + + builder.setInsertInto(makePairMethod); + auto block = builder.emitBlock(); + builder.setInsertInto(block); + auto param1 = builder.emitParam(concretePrimalType); + auto param2 = builder.emitParam(concreteDiffType); + List args = {param1, param2}; + auto pair = builder.emitMakeStruct((IRType*)loweredStructType, args); + builder.emitReturn(pair); + + builder.setInsertInto(concreteWitnessTable); + builder.createWitnessTableEntry( + concreteWitnessTable, + makePairRequirementKey, + makePairMethod); + } + } + + return diffPairRequirementKey; +} + IRInst* DifferentialPairTypeBuilder::_createDiffPairType(IRType* origBaseType, IRType* diffType) { switch (origBaseType->getOp()) @@ -333,6 +832,7 @@ IRInst* DifferentialPairTypeBuilder::_createDiffPairType(IRType* origBaseType, I return pairStructType; } + IRInst* DifferentialPairTypeBuilder::lowerDiffPairType(IRBuilder* builder, IRType* originalPairType) { IRInst* result = nullptr; @@ -352,26 +852,119 @@ IRInst* DifferentialPairTypeBuilder::lowerDiffPairType(IRBuilder* builder, IRTyp // purposes. auto primalType = pairType->getValueType(); - if (pairTypeCache.tryGetValue(primalType, result)) - return result; - if (!pairType) + + if (isRuntimeType(primalType)) { - result = originalPairType; + // Existential case. + auto cacheKey = getCacheKey(builder, primalType); + auto diffType = _getDiffTypeFromPairType(sharedContext, builder, pairType); + + IRInst* pairReqKey = nullptr; + if (!existentialPairTypeCache.tryGetValue(cacheKey, pairReqKey)) + { + pairReqKey = _createDiffPairInterfaceRequirement(primalType, (IRType*)diffType); + existentialPairTypeCache.add(cacheKey, pairReqKey); + } + + auto baseWitnessTable = getExistentialBaseWitnessTable(builder, primalType); + result = builder->emitLookupInterfaceMethodInst( + builder->getTypeKind(), + baseWitnessTable, + pairReqKey); + + primalTypeMap[result] = primalType; + diffTypeMap[result] = (IRType*)diffType; + return result; } - if (as(primalType)) + else if (auto typePack = as(primalType)) { - result = nullptr; - return result; + // Lower DiffPair(TypePack(a_0, a_1, ...), MakeWitnessPack(w_0, w_1, ...)) as + // TypePack(DiffPair(a_0, w_0), DiffPair(a_1, w_1), ...) + // + auto cacheKey = primalType; + if (pairTypeCache.tryGetValue(cacheKey, result)) + return result; + + auto packWitness = pairType->getWitness(); + + // Right now we only support concrete witness tables for type packs. + auto concretePackWitness = as(packWitness); + SLANG_ASSERT(concretePackWitness); + + // Get diff type pack. + IRTypePack* diffTypePack = nullptr; + + if (concretePackWitness->getConformanceType() == + this->sharedContext->differentiableInterfaceType) + diffTypePack = as(findWitnessTableEntry( + concretePackWitness, + this->sharedContext->differentialAssocTypeStructKey)); + else if ( + concretePackWitness->getConformanceType() == + this->sharedContext->differentiablePtrInterfaceType) + diffTypePack = as(findWitnessTableEntry( + concretePackWitness, + this->sharedContext->differentialAssocRefTypeStructKey)); + else + SLANG_UNEXPECTED("Unexpected witness table"); + + SLANG_ASSERT(diffTypePack); + + List args; + for (UInt i = 0; i < typePack->getOperandCount(); i++) + { + auto type = (IRType*)typePack->getOperand(i); + auto diffType = (IRType*)typePack->getOperand(i); + + if (pairTypeCache.tryGetValue(type, result)) + { + args.add((IRType*)result); + continue; + } + + // Lower the diff pair type. + auto loweredPairType = (IRType*)_createDiffPairType(type, diffType); + + pairTypeCache.add(type, loweredPairType); + args.add(loweredPairType); + } + + auto loweredTypePack = builder->getTypePack(args.getCount(), args.getBuffer()); + // TODO: Unify the cache between the three cases. + pairTypeCache.add(cacheKey, loweredTypePack); + + return loweredTypePack; } + else + { + auto cacheKey = primalType; + if (pairTypeCache.tryGetValue(primalType, result)) + return result; - auto diffType = _getDiffTypeFromPairType(sharedContext, builder, pairType); - if (!diffType) - return result; - result = _createDiffPairType(pairType->getValueType(), (IRType*)diffType); - pairTypeCache.add(primalType, result); + if (as(primalType)) + { + result = nullptr; + return result; + } + + if (as(primalType) || as(primalType)) + { + List constraintTypes; + constraintTypes.add(this->commonDiffPairInterface); + return builder->getAssociatedType(constraintTypes.getArrayView()); + } + + auto diffType = _getDiffTypeFromPairType(sharedContext, builder, pairType); + if (!diffType) + return result; + + // Concrete case. + result = _createDiffPairType(primalType, (IRType*)diffType); + pairTypeCache.add(cacheKey, result); - return result; + return result; + } } IRInterfaceType* findDifferentiableRefInterface(IRModuleInst* moduleInst) @@ -550,6 +1143,13 @@ IRInterfaceType* DifferentiableTypeConformanceContext::getConformanceTypeFromWit auto innerWitnessTableType = cast(operand); return cast(innerWitnessTableType->getConformanceType()); } + else if (auto genericWitness = as(witness)) + { + // This is a generic witness table. + auto innerWitness = getGenericReturnVal(genericWitness); + SLANG_ASSERT(as(innerWitness->getDataType())); + return getConformanceTypeFromWitness(innerWitness); + } else { SLANG_UNEXPECTED("Unexpected witness type"); @@ -558,81 +1158,134 @@ IRInterfaceType* DifferentiableTypeConformanceContext::getConformanceTypeFromWit return diffInterfaceType; } +List DifferentiableTypeConformanceContext::getAnnotations( + IRGlobalValueWithCode* code) +{ + // Scan function for all IRDifferentiableTypeAnnotation insts. + List annotations; + for (auto block : code->getBlocks()) + { + for (auto child : block->getChildren()) + { + if (auto annotation = as(child)) + { + annotations.add(annotation); + } + } + } + + return annotations; +} + +List DifferentiableTypeConformanceContext::getAnnotations( + IRModuleInst* module) +{ + // Scan module for all IRDifferentiableTypeAnnotation insts. + List annotations; + for (auto globalInst : module->getGlobalInsts()) + { + if (auto annotation = as(globalInst)) + { + annotations.add(annotation); + } + } + + return annotations; +} + void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func) { parentFunc = func; + List annotations = getAnnotations(func); - auto decor = func->findDecoration(); - SLANG_RELEASE_ASSERT(decor); - - // Build lookup dictionary for type witnesses. - for (auto child = decor->getFirstChild(); child; child = child->next) + // Go up the parents of func & add the annotations of any IRGeneric or IRModule parent: + IRInst* parent = func; + while (parent) { - if (auto item = as(child)) + if (auto upperFunc = as(parent)) { - IRInterfaceType* diffInterfaceType = getConformanceTypeFromWitness(item->getWitness()); + // TODO: Cache this. + auto parentAnnotations = getAnnotations(upperFunc); + annotations.addRange(parentAnnotations); + } + else if (auto module = as(parent)) + { + // TODO: Cache this. + auto parentAnnotations = getAnnotations(module); + annotations.addRange(parentAnnotations); + } + parent = parent->getParent(); + } - SLANG_ASSERT( - diffInterfaceType == sharedContext->differentiableInterfaceType || - diffInterfaceType == sharedContext->differentiablePtrInterfaceType); + for (auto item : annotations) + { + IRInterfaceType* diffInterfaceType = getConformanceTypeFromWitness(item->getWitness()); - auto existingItem = - differentiableTypeWitnessDictionary.tryGetValue(item->getConcreteType()); - if (existingItem) - { - *existingItem = item->getWitness(); - } - else - { - auto witness = item->getWitness(); + SLANG_ASSERT( + diffInterfaceType == sharedContext->differentiableInterfaceType || + diffInterfaceType == sharedContext->differentiablePtrInterfaceType); - // Also register the type's differential type with the same witness. - auto concreteType = item->getConcreteType(); - IRBuilder subBuilder(item->getConcreteType()); - if (as(concreteType) || as(concreteType)) + auto existingItem = differentiableTypeWitnessDictionary.tryGetValue(item->getBaseType()); + if (existingItem) + { + *existingItem = item->getWitness(); + } + else + { + auto witness = item->getWitness(); + + // Also register the type's differential type with the same witness. + auto concreteType = item->getBaseType(); + IRBuilder subBuilder(item->getBaseType()); + if (as(concreteType) || as(concreteType)) + { + // For tuple types with concrete element types, + // register the differential type for each element, but don't register for the + // tuple/typepack itself. + if (auto witnessPack = as(witness)) { - // For tuple types with concrete element types, - // register the differential type for each element, but don't register for the - // tuple/typepack itself. - if (auto witnessPack = as(witness)) + + for (UInt i = 0; i < concreteType->getOperandCount(); i++) { + auto element = concreteType->getOperand(i); + auto elementWitness = witnessPack->getOperand(i); - for (UInt i = 0; i < concreteType->getOperandCount(); i++) - { - auto element = concreteType->getOperand(i); - auto elementWitness = witnessPack->getOperand(i); - - if (diffInterfaceType == sharedContext->differentiableInterfaceType) - addTypeToDictionary((IRType*)element, elementWitness); - else if ( - diffInterfaceType == sharedContext->differentiablePtrInterfaceType) - addTypeToDictionary((IRType*)element, elementWitness); - } - return; + if (diffInterfaceType == sharedContext->differentiableInterfaceType) + addTypeToDictionary((IRType*)element, elementWitness); + else if (diffInterfaceType == sharedContext->differentiablePtrInterfaceType) + addTypeToDictionary((IRType*)element, elementWitness); } + return; } + } - addTypeToDictionary((IRType*)item->getConcreteType(), item->getWitness()); + addTypeToDictionary((IRType*)item->getBaseType(), item->getWitness()); - if (!as(item->getConcreteType())) - { - addTypeToDictionary( - (IRType*)_lookupWitness( - &subBuilder, - item->getWitness(), - sharedContext->differentialAssocTypeStructKey, - subBuilder.getTypeKind()), - item->getWitness()); - } + // TODO: Is this really needed? + if (!as(item->getBaseType()) && + !as(item->getBaseType())) + { + addTypeToDictionary( + (IRType*)_lookupWitness( + &subBuilder, + item->getWitness(), + sharedContext->differentialAssocTypeStructKey, + subBuilder.getTypeKind()), + item->getWitness()); + } - if (auto diffPairType = as(item->getConcreteType())) - { - // For differential pair types, register the differential type as well. - IRBuilder builder(diffPairType); - builder.setInsertAfter(diffPairType->getWitness()); + // TODO: Is this really needed? + if (auto diffPairType = as(item->getBaseType())) + { + // For differential pair types, register the differential type as well. + IRBuilder builder(diffPairType); + builder.setInsertAfter(diffPairType->getWitness()); - // TODO(sai): lot of this logic is duplicated. need to refactor. + // TODO(sai): lot of this logic is duplicated. need to refactor. + if (!as(diffPairType->getValueType()) && + !as(diffPairType->getValueType())) + { auto diffType = (diffInterfaceType == sharedContext->differentiableInterfaceType) ? _lookupWitness( @@ -665,12 +1318,28 @@ void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func) } } +IRWitnessTable* findGlobalWitness(IRInterfaceType* interface, IRInst* type) +{ + for (auto use = type->firstUse; use; use = use->nextUse) + { + if (auto witnessTable = as(use->getUser())) + { + if (witnessTable->getConcreteType() == type && + witnessTable->getConformanceType() == interface) + return witnessTable; + } + } + + return nullptr; +} + IRInst* DifferentiableTypeConformanceContext::lookUpConformanceForType( IRInst* type, DiffConformanceKind kind) { IRInst* foundResult = nullptr; differentiableTypeWitnessDictionary.tryGetValue(type, foundResult); + if (!foundResult) return nullptr; @@ -791,8 +1460,8 @@ IRInst* DifferentiableTypeConformanceContext::tryExtractConformanceFromInterface return nullptr; } -// Given an interface type, return the lookup path from a witness table of `type` to a witness table -// of `supType`. +// Given an interface type, return the lookup path from a witness table of `type` to a witness +// table of `supType`. static bool _findInterfaceLookupPathImpl( HashSet& processedTypes, IRInterfaceType* supType, @@ -967,6 +1636,11 @@ void DifferentiableTypeConformanceContext::buildGlobalWitnessDictionary() { addTypeToDictionary(pairType->getValueType(), pairType->getWitness()); } + + if (auto annotation = as(globalInst)) + { + addTypeToDictionary((IRType*)annotation->getBaseType(), annotation->getWitness()); + } } } @@ -1071,6 +1745,20 @@ IRType* DifferentiableTypeConformanceContext::differentiateType( } } +IRType* getAssociatedTypeForKey(IRInst* key) +{ + for (auto use = key->firstUse; use; use = use->nextUse) + { + if (auto interfaceReq = as(key)) + { + if (auto assocType = as(interfaceReq->getRequirementVal())) + return assocType; + } + } + + return nullptr; +} + IRInst* DifferentiableTypeConformanceContext::tryGetDifferentiableWitness( IRBuilder* builder, IRInst* primalType, @@ -1118,8 +1806,9 @@ IRInst* DifferentiableTypeConformanceContext::tryGetDifferentiableWitness( } else if (auto lookup = as(primalType)) { - // For types that are lookups from a table, we can simply lookup the witness from the same - // table + // Trivial cases: For types that are lookups from a table, we can simply lookup the + // witness from the same table + // if (lookup->getRequirementKey() == sharedContext->differentialAssocTypeStructKey) { witness = builder->emitLookupInterfaceMethodInst( @@ -1203,8 +1892,8 @@ IRInst* DifferentiableTypeConformanceContext::buildDifferentiablePairWitness( auto p0 = b.emitParam(diffDiffPairType); auto p1 = b.emitParam(diffDiffPairType); - // Since we are already dealing with a DiffPair.Differnetial type, we know that value - // type == diff type. + // Since we are already dealing with a DiffPair.Differnetial type, we know that + // value type == diff type. auto innerAdd = _lookupWitness( &b, innerWitness, @@ -1325,8 +2014,8 @@ IRInst* DifferentiableTypeConformanceContext::buildArrayWitness( auto p0 = b.emitParam(diffArrayType); auto p1 = b.emitParam(diffArrayType); - // Since we are already dealing with a DiffPair.Differnetial type, we know that value - // type == diff type. + // Since we are already dealing with a DiffPair.Differnetial type, we know that + // value type == diff type. auto innerAdd = _lookupWitness( &b, innerWitness, @@ -1566,6 +2255,143 @@ IRInst* DifferentiableTypeConformanceContext::buildExtractExistensialTypeWitness return nullptr; } +IRInst* DifferentiableTypeConformanceContext::emitDAddOfDiffInstType( + IRBuilder* builder, + IRType* primalType, + IRInst* op1, + IRInst* op2) +{ + if (auto arrayType = as(primalType)) + { + // TODO: This case should really not be necessary anymore + auto diffElementType = + (IRType*)this->getDifferentialForType(builder, arrayType->getElementType()); + SLANG_RELEASE_ASSERT(diffElementType); + auto arraySize = arrayType->getElementCount(); + + if (auto constArraySize = as(arraySize)) + { + List args; + for (IRIntegerValue i = 0; i < constArraySize->getValue(); i++) + { + auto index = builder->getIntValue(builder->getIntType(), i); + auto op1Val = builder->emitElementExtract(diffElementType, op1, index); + auto op2Val = builder->emitElementExtract(diffElementType, op2, index); + args.add( + emitDAddOfDiffInstType(builder, arrayType->getElementType(), op1Val, op2Val)); + } + auto diffArrayType = + builder->getArrayType(diffElementType, arrayType->getElementCount()); + return builder->emitMakeArray(diffArrayType, (UInt)args.getCount(), args.getBuffer()); + } + else + { + // TODO: insert a runtime loop here. + SLANG_UNIMPLEMENTED_X("dadd of dynamic array."); + } + } + else if (auto diffPairUserType = as(primalType)) + { + // TODO: This case should really not be necessary anymore + auto diffType = (IRType*)this->getDiffTypeFromPairType(builder, diffPairUserType); + auto diffWitness = this->getDiffTypeWitnessFromPairType(builder, diffPairUserType); + + auto primal1 = builder->emitDifferentialPairGetPrimalUserCode(op1); + auto primal2 = builder->emitDifferentialPairGetPrimalUserCode(op2); + auto primal = + emitDAddOfDiffInstType(builder, diffPairUserType->getValueType(), primal1, primal2); + + auto diff1 = builder->emitDifferentialPairGetDifferentialUserCode(diffType, op1); + auto diff2 = builder->emitDifferentialPairGetDifferentialUserCode(diffType, op2); + auto diff = emitDAddOfDiffInstType(builder, diffType, diff1, diff2); + + auto diffDiffPairType = builder->getDifferentialPairUserCodeType(diffType, diffWitness); + return builder->emitMakeDifferentialPairUserCode(diffDiffPairType, primal, diff); + } + else if (as(primalType)) + { + // If our type is existential, we need to handle the case where + // one or both of our operands are null-type. + // + return emitDAddForExistentialType(builder, primalType, op1, op2); + } + else if (as(primalType)) + { + // Should not happen. associated type does not have any additional info, we can't + // lookup the necessary methods. + // + SLANG_UNEXPECTED("unexpected associated type during transposition"); + } + + auto addMethod = this->getAddMethodForType(builder, primalType); + + // Should exist. + SLANG_ASSERT(addMethod); + + return builder->emitCallInst( + (IRType*)this->getDifferentialForType(builder, primalType), + addMethod, + List(op1, op2)); +} + +IRInst* DifferentiableTypeConformanceContext::emitDAddForExistentialType( + IRBuilder* builder, + IRType* primalType, + IRInst* op1, + IRInst* op2) +{ + return builder->emitCallInst( + (IRType*)this->getDifferentialForType(builder, primalType), + this->getOrCreateExistentialDAddMethod(), + List({op1, op2})); +} + +IRInst* DifferentiableTypeConformanceContext::emitDZeroOfDiffInstType( + IRBuilder* builder, + IRType* primalType) +{ + if (auto arrayType = as(primalType)) + { + // TODO: This case should really not be necessary anymore + auto diffElementType = + (IRType*)this->getDifferentialForType(builder, arrayType->getElementType()); + SLANG_RELEASE_ASSERT(diffElementType); + auto diffArrayType = builder->getArrayType(diffElementType, arrayType->getElementCount()); + auto diffElementZero = emitDZeroOfDiffInstType(builder, arrayType->getElementType()); + return builder->emitMakeArrayFromElement(diffArrayType, diffElementZero); + } + else if (auto diffPairUserType = as(primalType)) + { + // TODO: This case should really not be necessary anymore. + auto primalZero = emitDZeroOfDiffInstType(builder, diffPairUserType->getValueType()); + auto diffZero = primalZero; + auto diffType = primalZero->getFullType(); + auto diffWitness = this->getDiffTypeWitnessFromPairType(builder, diffPairUserType); + auto diffDiffPairType = builder->getDifferentialPairUserCodeType(diffType, diffWitness); + return builder->emitMakeDifferentialPairUserCode(diffDiffPairType, primalZero, diffZero); + } + else if (as(primalType) || as(primalType)) + { + // Pack a null value into an existential type. + auto existentialZero = builder->emitMakeExistential( + this->sharedContext->differentiableInterfaceType, + this->emitNullDifferential(builder), + this->sharedContext->nullDifferentialWitness); + + return existentialZero; + } + + auto zeroMethod = this->getZeroMethodForType(builder, primalType); + + // Should exist. + SLANG_ASSERT(zeroMethod); + + return builder->emitCallInst( + (IRType*)this->getDifferentialForType(builder, primalType), + zeroMethod, + List()); +} + void copyCheckpointHints( IRBuilder* builder, IRGlobalValueWithCode* oldInst, @@ -1883,6 +2709,7 @@ struct AutoDiffPass : public InstPassBase { bool result = false; OrderedHashSet loweredIntermediateTypes; + Dictionary typeToBwdFuncMap; // Replace all `BackwardDiffIntermediateContextType` insts with the struct type // that we generated during backward diff pass. @@ -1906,6 +2733,38 @@ struct AutoDiffPass : public InstPassBase if (type) { loweredIntermediateTypes.add(type); + + auto func = differentiateInst->getFunc(); + + if (auto spec = as(func)) + func = spec->getBase(); + + if (auto generic = as(func)) + { + func = + cast(findGenericReturnVal(generic)); + + auto bwdFuncDecor = func->findDecoration< + IRBackwardDerivativePropagateDecoration>(); + + typeToBwdFuncMap.add( + type, + cast( + as( + bwdFuncDecor->getBackwardDerivativePropagateFunc()) + ->getBase())); + } + else + { + auto bwdFuncDecor = func->findDecoration< + IRBackwardDerivativePropagateDecoration>(); + + typeToBwdFuncMap.add( + type, + cast( + bwdFuncDecor->getBackwardDerivativePropagateFunc())); + } + inst->replaceUsesWith(type); inst->removeAndDeallocate(); changed = true; @@ -1922,7 +2781,9 @@ struct AutoDiffPass : public InstPassBase } // Now we generate the differential type for the intermediate context type // to allow higher order differentiation. - generateDifferentialImplementationForContextType(loweredIntermediateTypes); + generateDifferentialImplementationForContextType( + loweredIntermediateTypes, + typeToBwdFuncMap); return result; } @@ -1977,22 +2838,13 @@ struct AutoDiffPass : public InstPassBase IRInst* addMethod = nullptr; }; - // Register the differential type for an intermediate context type to the derivative functions - // that uses the type. + // Register the differential type for an intermediate context type to the derivative + // functions that uses the type. void registerDiffContextType( IRBuilder& builder, - IRDifferentiableTypeDictionaryDecoration* diffDecor, OrderedDictionary& diffTypes, IRInst* origType) { - HashSet registeredType; - for (auto entry : diffDecor->getChildren()) - { - if (auto e = as(entry)) - { - registeredType.add(e->getOperand(0)); - } - } // Use a work list to recursively walk through all sub fields of the struct type. List wlist; wlist.add(origType); @@ -2002,10 +2854,13 @@ struct AutoDiffPass : public InstPassBase IntermediateContextTypeDifferentialInfo diffInfo; if (!diffTypes.tryGetValue(t, diffInfo)) continue; - if (registeredType.add(t)) - builder.addDifferentiableTypeEntry(diffDecor, t, diffInfo.diffWitness); - else - continue; + + IRInst* args[] = {t, diffInfo.diffWitness}; + builder.emitIntrinsicInst( + builder.getVoidType(), + kIROp_DifferentiableTypeAnnotation, + 2, + args); if (auto structType = as(getResolvedInstForDecorations(t))) { @@ -2017,7 +2872,9 @@ struct AutoDiffPass : public InstPassBase } } - void generateDifferentialImplementationForContextType(OrderedHashSet& contextTypes) + void generateDifferentialImplementationForContextType( + OrderedHashSet& contextTypes, + Dictionary typeToBwdFuncMap) { // First we are going to topology sort all intermediate context types. OrderedHashSet sortedContextTypes; @@ -2043,6 +2900,10 @@ struct AutoDiffPass : public InstPassBase IRBuilder builder(module); for (auto t : sortedContextTypes) { + auto func = typeToBwdFuncMap[t]; + DifferentiableTypeConformanceContext ctx(this->autodiffContext); + ctx.setFunc(func); + if (t->getOp() == kIROp_Generic || t->getOp() == kIROp_StructType) { // For generics/struct types, we will generate a new generic/struct type @@ -2050,7 +2911,7 @@ struct AutoDiffPass : public InstPassBase SLANG_RELEASE_ASSERT(t->getParent() && t->getParent()->getOp() == kIROp_Module); builder.setInsertBefore(t); - auto diffInfo = fillDifferentialTypeImplementation(diffTypes, t); + auto diffInfo = fillDifferentialTypeImplementation(&ctx, diffTypes, t); diffTypes[t] = diffInfo; } else if (auto specialize = as(t)) @@ -2085,30 +2946,29 @@ struct AutoDiffPass : public InstPassBase // function without a intermediate-type via an interface. SLANG_RELEASE_ASSERT(diffTypes.containsKey(t)); } - } - // Register the differential types into the conformance dictionaries of the functions that - // uses them. - for (auto t : diffTypes) - { + if (!diffTypes.containsKey(t)) + continue; + + // If we created a new differential type, we need to place into the contexts of all + // functions that use it. + // HashSet registeredFuncs; - for (auto use = t.key->firstUse; use; use = use->nextUse) + for (auto use = t->firstUse; use; use = use->nextUse) { auto parentFunc = getParentFunc(use->getUser()); if (!parentFunc) continue; if (!registeredFuncs.add(parentFunc)) continue; - if (auto dictDecor = - parentFunc->findDecoration()) - { - registerDiffContextType(builder, dictDecor, diffTypes, t.key); - } + + registerDiffContextType(builder, diffTypes, t); } } } IntermediateContextTypeDifferentialInfo fillDifferentialTypeImplementationForStruct( + DifferentiableTypeConformanceContext* ctx, OrderedDictionary& diffTypes, IRStructType* originalType, IRStructType* diffType) @@ -2122,6 +2982,7 @@ struct AutoDiffPass : public InstPassBase // Generate the fields for all differentiable members of the original struct type. struct FieldInfo { + IRType* primalType; IRStructField* field; IRInst* witness; }; @@ -2130,30 +2991,30 @@ struct AutoDiffPass : public InstPassBase for (auto field : originalType->getFields()) { IRInst* diffFieldWitness = nullptr; - if (auto diffDecor = - field->findDecoration()) - { - diffFieldWitness = diffDecor->getDifferentialWitness(); - } - else + + diffFieldWitness = ctx->tryGetDifferentiableWitness( + &builder, + field->getFieldType(), + DiffConformanceKind::Value); + + if (!diffFieldWitness) { IntermediateContextTypeDifferentialInfo diffFieldTypeInfo; diffTypes.tryGetValue(field->getFieldType(), diffFieldTypeInfo); diffFieldWitness = diffFieldTypeInfo.diffWitness; } + if (diffFieldWitness) { FieldInfo info; IRBuilder keyBuilder = builder; keyBuilder.setInsertBefore(maybeFindOuterGeneric(originalType)); auto diffKey = keyBuilder.createStructKey(); - auto diffFieldType = _lookupWitness( - &keyBuilder, - diffFieldWitness, - autodiffContext->differentialAssocTypeStructKey, - builder.getTypeKind()); + auto diffFieldType = ctx->getDifferentialForType(&builder, field->getFieldType()); + info.field = builder.createStructField(diffType, diffKey, (IRType*)diffFieldType); info.witness = diffFieldWitness; + info.primalType = field->getFieldType(); builder.addDecoration(field->getKey(), kIROp_DerivativeMemberDecoration, diffKey); builder.addDecoration(diffKey, kIROp_DerivativeMemberDecoration, diffKey); diffFields.add(info); @@ -2172,16 +3033,10 @@ struct AutoDiffPass : public InstPassBase builder.setInsertInto(zeroMethod); builder.emitBlock(); List fieldVals; + for (auto info : diffFields) { - auto innerZeroMethod = _lookupWitness( - &builder, - info.witness, - autodiffContext->zeroMethodStructKey, - autodiffContext->zeroMethodType); - IRInst* val = - builder.emitCallInst(info.field->getFieldType(), innerZeroMethod, 0, nullptr); - fieldVals.add(val); + fieldVals.add(ctx->emitDZeroOfDiffInstType(&builder, info.primalType)); } builder.emitReturn(builder.emitMakeStruct(diffType, fieldVals)); } @@ -2203,20 +3058,15 @@ struct AutoDiffPass : public InstPassBase List fieldVals; for (auto info : diffFields) { - auto innerAddMethod = _lookupWitness( - &builder, - info.witness, - autodiffContext->addMethodStructKey, - autodiffContext->addMethodType); IRInst* args[2] = { builder .emitFieldExtract(info.field->getFieldType(), param1, info.field->getKey()), builder .emitFieldExtract(info.field->getFieldType(), param2, info.field->getKey()), }; - IRInst* val = - builder.emitCallInst(info.field->getFieldType(), innerAddMethod, 2, args); - fieldVals.add(val); + + fieldVals.add( + ctx->emitDAddOfDiffInstType(&builder, info.primalType, args[0], args[1])); } builder.emitReturn(builder.emitMakeStruct(diffType, fieldVals)); } @@ -2265,6 +3115,7 @@ struct AutoDiffPass : public InstPassBase } IntermediateContextTypeDifferentialInfo fillDifferentialTypeImplementation( + DifferentiableTypeConformanceContext* ctx, OrderedDictionary& diffTypes, IRInst* originalType) { @@ -2274,6 +3125,7 @@ struct AutoDiffPass : public InstPassBase builder.setInsertBefore(originalType); auto diffType = builder.createStructType(); return fillDifferentialTypeImplementationForStruct( + ctx, diffTypes, as(originalType), as(diffType)); @@ -2286,7 +3138,7 @@ struct AutoDiffPass : public InstPassBase auto structType = as(findGenericReturnVal(genType)); SLANG_RELEASE_ASSERT(structType); - auto innerResult = fillDifferentialTypeImplementation(diffTypes, structType); + auto innerResult = fillDifferentialTypeImplementation(ctx, diffTypes, structType); IRBuilder builder(originalType); builder.setInsertBefore(originalType); @@ -2421,7 +3273,8 @@ struct AutoDiffPass : public InstPassBase { bool changed = false; List autoDiffWorkList; - // Collect all `ForwardDifferentiate`/`BackwardDifferentiate` insts from the call graph. + // Collect all `ForwardDifferentiate`/`BackwardDifferentiate` insts from the call + // graph. processAllReachableInsts( [&](IRInst* inst) { @@ -2438,6 +3291,7 @@ struct AutoDiffPass : public InstPassBase case kIROp_Func: case kIROp_Specialize: case kIROp_LookupWitness: + case kIROp_Generic: if (auto innerFunc = as(getResolvedInstForDecorations(inst->getOperand(0)))) { @@ -2519,8 +3373,8 @@ struct AutoDiffPass : public InstPassBase } // Run transcription logic to generate the body of forward/backward derivatives - // functions. While doing so, we may discover new functions to differentiate, so we keep - // running until the worklist goes dry. + // functions. While doing so, we may discover new functions to differentiate, so we + // keep running until the worklist goes dry. List autodiffCleanupList; while (autodiffContext->followUpFunctionsToTranscribe.getCount() != 0) { @@ -2582,10 +3436,10 @@ struct AutoDiffPass : public InstPassBase hasChanges = true; // We have done transcribing the functions, now it is time to demote all - // DifferentialPair types and their operations down to DifferentialPairUserCodeType and - // *UserCode operations so they can be treated just like normal types with no special - // semantics in future processing, and won't be confused with the semantics of a - // DifferentialPair type during future autodiff code gen. + // DifferentialPair types and their operations down to DifferentialPairUserCodeType + // and *UserCode operations so they can be treated just like normal types with no + // special semantics in future processing, and won't be confused with the semantics + // of a DifferentialPair type during future autodiff code gen. rewriteDifferentialPairToUserCode(module); hasChanges |= changed; @@ -2693,8 +3547,8 @@ void checkAutodiffPatterns(TargetProgram* target, IRModule* module, DiagnosticSi if (func->sourceLoc.isValid() && // Don't diagnose for synthesized functions func->findDecoration()) { - // If we don't have any side-effect behavior, we should warn (note: read-none is a - // stronger guarantee than no-side-effect) + // If we don't have any side-effect behavior, we should warn (note: read-none is + // a stronger guarantee than no-side-effect) // if (func->findDecoration() || func->findDecoration()) @@ -2759,6 +3613,27 @@ void removeDetachInsts(IRModule* module) pass.processModule(); } + +struct RemoveTypeAnnotationInstsPass : InstPassBase +{ + RemoveTypeAnnotationInstsPass(IRModule* module) + : InstPassBase(module) + { + } + void processModule() + { + processInstsOfType( + kIROp_DifferentiableTypeAnnotation, + [&](IRDifferentiableTypeAnnotation* annotation) { annotation->removeAndDeallocate(); }); + } +}; + +void removeTypeAnnotations(IRModule* module) +{ + RemoveTypeAnnotationInstsPass pass(module); + pass.processModule(); +} + struct LowerNullCheckPass : InstPassBase { LowerNullCheckPass(IRModule* module, AutoDiffSharedContext* context) @@ -2841,6 +3716,8 @@ bool finalizeAutoDiffPass(TargetProgram* target, IRModule* module) removeDetachInsts(module); + removeTypeAnnotations(module); + lowerNullCheckInsts(module, &autodiffContext); stripNoDiffTypeAttribute(module); diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h index 2b03f39238..433b6093fd 100644 --- a/source/slang/slang-ir-autodiff.h +++ b/source/slang/slang-ir-autodiff.h @@ -221,6 +221,8 @@ struct DifferentiableTypeConformanceContext IRGlobalValueWithCode* parentFunc = nullptr; OrderedDictionary differentiableTypeWitnessDictionary; + Dictionary> annotationCache; + IRFunc* existentialDAddFunc = nullptr; DifferentiableTypeConformanceContext(AutoDiffSharedContext* shared) @@ -235,6 +237,10 @@ struct DifferentiableTypeConformanceContext void setFunc(IRGlobalValueWithCode* func); + List getAnnotations(IRGlobalValueWithCode* inst); + + List getAnnotations(IRModuleInst* inst); + void buildGlobalWitnessDictionary(); // Lookup a witness table for the concreteType. One should exist if concreteType @@ -445,6 +451,20 @@ struct DifferentiableTypeConformanceContext IRBuilder* builder, IRExtractExistentialType* extractExistentialType, DiffConformanceKind target); + + IRInst* emitDAddOfDiffInstType( + IRBuilder* builder, + IRType* primalType, + IRInst* op1, + IRInst* op2); + + IRInst* emitDAddForExistentialType( + IRBuilder* builder, + IRType* primalType, + IRInst* op1, + IRInst* op2); + + IRInst* emitDZeroOfDiffInstType(IRBuilder* builder, IRType* primalType); }; @@ -461,9 +481,15 @@ struct DifferentialPairTypeBuilder IRInst* emitFieldAccessor(IRBuilder* builder, IRInst* baseInst, IRStructKey* key); - IRInst* emitPrimalFieldAccess(IRBuilder* builder, IRInst* baseInst); + IRInst* emitPrimalFieldAccess(IRBuilder* builder, IRType* loweredPairType, IRInst* baseInst); + + IRInst* emitDiffFieldAccess(IRBuilder* builder, IRType* loweredPairType, IRInst* baseInst); - IRInst* emitDiffFieldAccess(IRBuilder* builder, IRInst* baseInst); + IRInst* emitExistentialMakePair( + IRBuilder* builder, + IRInst* type, + IRInst* primalInst, + IRInst* diffInst); IRStructKey* _getOrCreateDiffStructKey(); @@ -471,17 +497,52 @@ struct DifferentialPairTypeBuilder IRInst* _createDiffPairType(IRType* origBaseType, IRType* diffType); + IRInst* _createDiffPairInterfaceRequirement(IRType* origBaseType, IRType* diffType); + IRInst* lowerDiffPairType(IRBuilder* builder, IRType* originalPairType); + IRInst* getOrCreateCommonDiffPairInterface(IRBuilder* builder); + struct PairStructKey { IRInst* originalType; IRInst* diffType; }; - // Cache from `IRDifferentialPairType` to materialized struct type. + // Cache from pair types to lowered type. Dictionary pairTypeCache; + // Cache from existential pair types to their lowered interface keys. + // We use a different cache because an interface type can have + // a regular pair for the pair of interface types, as well as an + // interface key for the associated pair types used for its implementations + // + Dictionary existentialPairTypeCache; + + // Cache for any interface requirement keys (generated for existential + // pair types) + // + Dictionary assocPairTypeKeyMap; + Dictionary makePairKeyMap; + Dictionary getPrimalKeyMap; + Dictionary getDiffKeyMap; + + // More caches for easier lookups of the types associated with the + // keys. (avoid having to keep recomputing or performing complicated + // lookups) + // + Dictionary makePairFuncTypeMap; + Dictionary getPrimalFuncTypeMap; + Dictionary getDiffFuncTypeMap; + + // Even more caches for easier access to original primal/diff types + // (Only used for existential pair types). For regular pair types, + // these are easy to find right on the type itself. + // + Dictionary primalTypeMap; + Dictionary diffTypeMap; + + IRStructKey* globalPrimalKey = nullptr; IRStructKey* globalDiffKey = nullptr; @@ -491,6 +552,8 @@ struct DifferentialPairTypeBuilder List generatedTypeList; AutoDiffSharedContext* sharedContext = nullptr; + + IRInterfaceType* commonDiffPairInterface = nullptr; }; void stripAutoDiffDecorations(IRModule* module); @@ -551,6 +614,10 @@ inline bool isRelevantDifferentialPair(IRType* type) return false; } +bool isRuntimeType(IRType* type); + +IRInst* getExistentialBaseWitnessTable(IRBuilder* builder, IRType* type); + UIndex addPhiOutputArg( IRBuilder* builder, IRBlock* block, diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 88a9ac5e33..38e5f8869f 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -1282,6 +1282,9 @@ INST(ExistentialTypeSpecializationDictionary, ExistentialTypeSpecializationDicti /* Differentiable Type Dictionary */ INST(DifferentiableTypeDictionaryItem, DifferentiableTypeDictionaryItem, 0, 0) +/* Differentiable Type Annotation (for run-time types)*/ +INST(DifferentiableTypeAnnotation, DifferentiableTypeAnnotation, 2, HOISTABLE) + INST(BeginFragmentShaderInterlock, BeginFragmentShaderInterlock, 0, 0) INST(EndFragmentShaderInterlock, BeginFragmentShaderInterlock, 0, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 53adce87a8..a288bca97a 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -1334,6 +1334,18 @@ struct IRPrimalSubstitute : IRInst IR_LEAF_ISA(PrimalSubstitute) }; +struct IRDifferentiableTypeAnnotation : IRInst +{ + enum + { + kOp = kIROp_DifferentiableTypeAnnotation + }; + IRInst* getBaseType() { return getOperand(0); } + IRInst* getWitness() { return getOperand(1); } + + IR_LEAF_ISA(DifferentiableTypeAnnotation) +}; + struct IRDispatchKernel : IRInst { enum diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp index 28bb63a87b..d60903cfc2 100644 --- a/source/slang/slang-ir-link.cpp +++ b/source/slang/slang-ir-link.cpp @@ -1886,7 +1886,8 @@ LinkedIR linkIR(CodeGenContext* codeGenContext) { // We need to copy over exported symbols, // and any global parameters if preserve-params option is set. - if (_isHLSLExported(inst) || shouldCopyGlobalParams && as(inst)) + if (_isHLSLExported(inst) || shouldCopyGlobalParams && as(inst) || + as(inst)) { auto cloned = cloneValue(context, inst); if (!cloned->findDecorationImpl(kIROp_KeepAliveDecoration)) diff --git a/source/slang/slang-ir-lower-generics.cpp b/source/slang/slang-ir-lower-generics.cpp index 4fd162e53b..b01abcbc55 100644 --- a/source/slang/slang-ir-lower-generics.cpp +++ b/source/slang/slang-ir-lower-generics.cpp @@ -228,8 +228,6 @@ void lowerGenerics(TargetProgram* targetProgram, IRModule* module, DiagnosticSin checkTypeConformanceExists(&sharedContext); - inferAnyValueSizeWhereNecessary(targetProgram, module); - // Replace all `makeExistential` insts with `makeExistentialWithRTTI` // before making any other changes. This is necessary because a parameter of // generic type will be lowered into `AnyValueType`, and after that we can no longer diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index 50dfa2c6a3..40cd40758a 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -51,15 +51,17 @@ struct SpecializationContext IRModule* module; DiagnosticSink* sink; TargetProgram* targetProgram; + SpecializationOptions options; bool changed = false; - SpecializationContext(IRModule* inModule, TargetProgram* target) + SpecializationContext(IRModule* inModule, TargetProgram* target, SpecializationOptions options) : workList(*inModule->getContainerPool().getList()) , workListSet(*inModule->getContainerPool().getHashSet()) , cleanInsts(*inModule->getContainerPool().getHashSet()) , module(inModule) , targetProgram(target) + , options(options) { } ~SpecializationContext() @@ -1102,7 +1104,11 @@ struct SpecializationContext // Now we consider lower lookupWitnessMethod insts into dynamic dispatch calls, // which may open up more specialization opportunities. // - iterChanged = lowerWitnessLookup(module, sink); + if (options.lowerWitnessLookups) + { + iterChanged = lowerWitnessLookup(module, sink); + } + if (!iterChanged || sink->getErrorCount()) break; } @@ -2882,10 +2888,14 @@ struct SpecializationContext } }; -bool specializeModule(TargetProgram* target, IRModule* module, DiagnosticSink* sink) +bool specializeModule( + TargetProgram* target, + IRModule* module, + DiagnosticSink* sink, + SpecializationOptions options) { SLANG_PROFILE; - SpecializationContext context(module, target); + SpecializationContext context(module, target, options); context.sink = sink; context.processModule(); return context.changed; diff --git a/source/slang/slang-ir-specialize.h b/source/slang/slang-ir-specialize.h index 734c76427c..72f2c6130f 100644 --- a/source/slang/slang-ir-specialize.h +++ b/source/slang/slang-ir-specialize.h @@ -7,8 +7,20 @@ struct IRModule; class DiagnosticSink; class TargetProgram; +struct SpecializationOptions +{ + // Option that allows specializeModule to generate dynamic-dispatch code + // wherever possible to open up more specialization opportunities. + // + bool lowerWitnessLookups = false; +}; + /// Specialize generic and interface-based code to use concrete types. -bool specializeModule(TargetProgram* target, IRModule* module, DiagnosticSink* sink); +bool specializeModule( + TargetProgram* target, + IRModule* module, + DiagnosticSink* sink, + SpecializationOptions options); void finalizeSpecialization(IRModule* module); diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 011ea6bc76..e82fc03fde 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -592,11 +592,21 @@ struct IRGenContext // The element index if we are inside an `expand` expression. IRInst* expandIndex = nullptr; + // Callback function to call when after lowering a type. + std::function lowerTypeCallback = + nullptr; + explicit IRGenContext(SharedIRGenContext* inShared, ASTBuilder* inAstBuilder) : shared(inShared), astBuilder(inAstBuilder), env(&inShared->globalEnv), irBuilder(nullptr) { } + void registerTypeCallback( + std::function callback) + { + lowerTypeCallback = callback; + } + void setGlobalValue(Decl* decl, LoweredValInfo value) { shared->setGlobalValue(decl, value); } void setValue(Decl* decl, LoweredValInfo value) { env->mapDeclToValue[decl] = value; } @@ -2202,7 +2212,12 @@ IRType* lowerType(IRGenContext* context, Type* type) { ValLoweringVisitor visitor; visitor.context = context; - return (IRType*)getSimpleVal(context, visitor.dispatchType(type)); + IRType* loweredType = (IRType*)getSimpleVal(context, visitor.dispatchType(type)); + + if (context->lowerTypeCallback && loweredType) + context->lowerTypeCallback(context, type, loweredType); + + return loweredType; } void addVarDecorations(IRGenContext* context, IRInst* inst, Decl* decl) @@ -8105,6 +8120,7 @@ struct DeclLoweringVisitor : DeclVisitor subContextStorage.thisTypeWitness = outerContext->thisTypeWitness; subContextStorage.returnDestination = LoweredValInfo(); + subContextStorage.lowerTypeCallback = nullptr; } IRBuilder* getBuilder() { return &subBuilderStorage; } @@ -8629,7 +8645,7 @@ struct DeclLoweringVisitor : DeclVisitor auto finalVal = finishOuterGenerics(subBuilder, irInterface, outerGeneric); // Add `irInterface` to decl mapping now to prevent cyclic lowering. - context->setValue(decl, LoweredValInfo::simple(finalVal)); + context->setGlobalValue(decl, LoweredValInfo::simple(finalVal)); subBuilder->setInsertBefore(irInterface); @@ -8783,7 +8799,6 @@ struct DeclLoweringVisitor : DeclVisitor } } - addNameHint(context, irInterface, decl); addLinkageDecoration(context, irInterface, decl); if (auto anyValueSizeAttr = decl->findModifier()) @@ -9910,6 +9925,48 @@ struct DeclLoweringVisitor : DeclVisitor else outerGeneric = emitOuterGenerics(subContext, decl, decl); + // If our function is differentiable, register a callback so the derivative + // annotations for types can be lowered. + // + if (auto diffAttr = decl->findModifier()) + { + auto diffTypeWitnessMap = diffAttr->getMapTypeToIDifferentiableWitness(); + OrderedDictionary resolveddiffTypeWitnessMap; + + // Go through each entry in the map and resolve the key. + for (auto& entry : diffTypeWitnessMap) + { + auto resolvedKey = as(entry.key->resolve()); + resolveddiffTypeWitnessMap[resolvedKey] = + as(as(entry.value)->resolve()); + } + + subContext->registerTypeCallback( + [=](IRGenContext* context, Type* type, IRType* irType) + { + if (!as(type)) + return irType; + + DeclRefBase* declRefBase = as(type)->getDeclRefBase(); + if (resolveddiffTypeWitnessMap.containsKey(declRefBase)) + { + auto irWitness = + lowerVal(subContext, resolveddiffTypeWitnessMap[declRefBase]).val; + if (irWitness) + { + IRInst* args[] = {irType, irWitness}; + context->irBuilder->emitIntrinsicInst( + context->irBuilder->getVoidType(), + kIROp_DifferentiableTypeAnnotation, + 2, + args); + } + } + + return irType; + }); + } + FuncDeclBaseTypeInfo info; _lowerFuncDeclBaseTypeInfo( subContext, @@ -10220,6 +10277,8 @@ struct DeclLoweringVisitor : DeclVisitor } } + subContext->registerTypeCallback(nullptr); + getBuilder()->addHighLevelDeclDecoration(irFunc, decl); addSpecializedForTargetDecorations(irFunc, decl); @@ -10467,16 +10526,6 @@ struct DeclLoweringVisitor : DeclVisitor } } - if (auto diffAttr = decl->findModifier()) - { - if (decl->body) - { - subContext->irBuilder->setInsertInto(irFunc->getParent()); - lowerDifferentiableAttribute(subContext, irFunc, diffAttr); - subContext->irBuilder->setInsertInto(irFunc); - } - } - // For convenience, ensure that any additional global // values that were emitted while outputting the function // body appear before the function itself in the list diff --git a/tests/autodiff/dynamic-dispatch-material.slang b/tests/autodiff/dynamic-dispatch-material.slang new file mode 100644 index 0000000000..1185a92e7f --- /dev/null +++ b/tests/autodiff/dynamic-dispatch-material.slang @@ -0,0 +1,142 @@ +// Test calling differentiable function through dynamic dispatch. + +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +//TEST_INPUT: set g_materials = new StructuredBuffer[new MaterialDataBlob{new MaterialHeader{[0, 0, 0, 0]}, new MaterialPayload{[1.0, 1.2, 0.3, 0.5]}}]; +RWStructuredBuffer g_materials; + +public struct ShadingInput +{ + public float scale; +} + +struct MaterialHeader +{ + uint4 header; +}; +struct MaterialPayload +{ + float4 data; +}; +struct MaterialDataBlob +{ + MaterialHeader header; // 16B + MaterialPayload payload; // 16B +}; + +interface IMaterial : IDifferentiable +{ + associatedtype MaterialInstance : IMaterialInstance; + + [Differentiable] + MaterialInstance setupMaterialInstance( ShadingInput input ); +} + +interface IMaterialInstance : IDifferentiable +{ + [Differentiable] + float eval( float x ); +} + + +[BackwardDerivative(getMaterial_bwd)] +IMaterial getMaterial(int id) +{ + return createDynamicObject(id, g_materials[id]); +} + +void getMaterial_bwd(int id, IDifferentiable d) +{ + // Something random + outputBuffer[id] = 2.f; +} + +struct Material1: IMaterial +{ + typedef MaterialInstance1 MaterialInstance; + + MaterialHeader header; + float a; + float b; + float c; + + [Differentiable] + MaterialInstance1 setupMaterialInstance( ShadingInput input ) + { + MaterialInstance1 instance; + instance.a = a * input.scale; + instance.b = b * input.scale; + instance.c = c * input.scale; + return instance; + } + +} +struct MaterialInstance1: IMaterialInstance +{ + float a; + float b; + float c; + + [Differentiable] + float eval( float x ) + { + return a * x * x + b * x + c; + } +} + +struct Material2: IMaterial +{ + typedef MaterialInstance2 MaterialInstance; + + MaterialHeader header; + float a; + float b; + + [Differentiable] + MaterialInstance2 setupMaterialInstance( ShadingInput input ) + { + MaterialInstance2 instance; + instance.a = a * input.scale * input.scale; + instance.b = b * input.scale * input.scale; + return instance; + } + +} +public struct MaterialInstance2: IMaterialInstance +{ + float a; + float b; + + [Differentiable] + public float eval( float x ) + { + return a * x + b; + } +} + +[Differentiable] +public float shade(int material, ShadingInput input, float x) +{ + IMaterial m = getMaterial(material); + IMaterialInstance mi = m.setupMaterialInstance(input); + return mi.eval(x); +} + +//TEST_INPUT: type_conformance Material1:IMaterial = 0 +//TEST_INPUT: type_conformance Material2:IMaterial = 1 + +[shader("compute")] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + outputBuffer[0] = shade(0, {0.5}, 0.6); + + // TODO: VERIFY + DifferentialPair dpx = diffPair(3.0); + bwd_diff(shade)(0, {0.5}, dpx, 1.0); + + outputBuffer[3] = dpx.d; +} \ No newline at end of file diff --git a/tests/autodiff/dynamic-dispatch-material.slang.expected.txt b/tests/autodiff/dynamic-dispatch-material.slang.expected.txt new file mode 100644 index 0000000000..49e4a0e6b7 --- /dev/null +++ b/tests/autodiff/dynamic-dispatch-material.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +2.000000 +0.000000 +0.000000 +0.000000 +0.000000 diff --git a/tests/autodiff/material2/DiffuseMaterial.slang b/tests/autodiff/material2/DiffuseMaterial.slang index d4ae579ffc..1422ee30cf 100644 --- a/tests/autodiff/material2/DiffuseMaterial.slang +++ b/tests/autodiff/material2/DiffuseMaterial.slang @@ -21,6 +21,11 @@ public struct DiffuseMaterial : IMaterial float3 albedo = getAlbedo(baseColor); DiffuseMaterialInstance mi = { baseColor }; for (uint i = 0; i < 3; i++) miData.data[i] = albedo[i]; + + // unused + miData.data[3] = 0.f; + miData.data[4] = 0.f; + return mi; } }; diff --git a/tests/autodiff/material2/GlossyMaterial.slang b/tests/autodiff/material2/GlossyMaterial.slang index e937443d43..1070c6e63d 100644 --- a/tests/autodiff/material2/GlossyMaterial.slang +++ b/tests/autodiff/material2/GlossyMaterial.slang @@ -35,6 +35,11 @@ public struct GlossyMaterial : IMaterial float3 albedo = getAlbedo(baseColor); GlossyMaterialInstance mi = { baseColor }; for (uint i = 0; i < 3; i++) miData.data[i] = albedo[i]; + + // unused + miData.data[3] = 0.f; + miData.data[4] = 0.f; + return mi; } }; diff --git a/tests/autodiff/material2/MxLayeredMaterial.slang b/tests/autodiff/material2/MxLayeredMaterial.slang index 8c5cd6d3c7..fabfde80cd 100644 --- a/tests/autodiff/material2/MxLayeredMaterial.slang +++ b/tests/autodiff/material2/MxLayeredMaterial.slang @@ -57,6 +57,11 @@ public struct MxLayeredMaterial : IMaterial float3 albedo = getAlbedo(baseColor); UsedMaterialInstance mi; for (uint i = 0; i < 3; i++) miData.data[i] = albedo[i]; + + // unused + miData.data[3] = 0.f; + miData.data[4] = 0.f; + return mi; } }