From 1863fe1deecc3f5b15b45020105f6cadcc2f9999 Mon Sep 17 00:00:00 2001 From: Yong He Date: Thu, 2 May 2024 16:48:27 -0700 Subject: [PATCH] Support generic constraints that are dependent on another generic param. (#4091) --- source/slang/slang-ir-link.cpp | 29 ++++++++- source/slang/slang-lower-to-ir.cpp | 11 ++-- .../generics/generic-witness-derived.slang | 64 +++++++++++++++++++ 3 files changed, 98 insertions(+), 6 deletions(-) create mode 100644 tests/language-feature/generics/generic-witness-derived.slang diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp index 4871062d4d..e652745e78 100644 --- a/source/slang/slang-ir-link.cpp +++ b/source/slang/slang-ir-link.cpp @@ -746,6 +746,12 @@ void cloneGlobalValueWithCodeCommon( { IRBlock* ob = originalValue->getFirstBlock(); IRBlock* cb = clonedValue->getFirstBlock(); + struct ParamCloneInfo + { + IRParam* originalParam; + IRParam* clonedParam; + }; + ShortList paramCloneInfos; while (ob) { SLANG_ASSERT(cb); @@ -753,9 +759,28 @@ void cloneGlobalValueWithCodeCommon( builder->setInsertInto(cb); for (auto oi = ob->getFirstInst(); oi; oi = oi->getNextInst()) { - cloneInst(context, builder, oi); + if (oi->getOp() == kIROp_Param) + { + // Params may have forward references in its type and + // decorations, so we just create a placeholder for it + // in this first pass. + IRParam* clonedParam = builder->emitParam(nullptr); + registerClonedValue(context, clonedParam, oi); + paramCloneInfos.add({ (IRParam*)oi, clonedParam }); + } + else + { + cloneInst(context, builder, oi); + } + } + // Clone the type and decorations of parameters after all instructs in the block + // have been cloned. + for (auto param : paramCloneInfos) + { + builder->setInsertInto(param.clonedParam); + param.clonedParam->setFullType((IRType*)cloneValue(context, param.originalParam->getFullType())); + cloneDecorations(context, param.clonedParam, param.originalParam); } - ob = ob->getNextBlock(); cb = cb->getNextBlock(); } diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index a78110a846..2bf5f1e964 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -8823,7 +8823,7 @@ struct DeclLoweringVisitor : DeclVisitor IRGenContext* subContext, GenericTypeConstraintDecl* constraintDecl) { - auto supType = lowerType(context, constraintDecl->sup.type); + auto supType = lowerType(subContext, constraintDecl->sup.type); auto value = emitGenericConstraintValue(subContext, constraintDecl, supType); subContext->setValue(constraintDecl, LoweredValInfo::simple(value)); } @@ -8972,9 +8972,11 @@ struct DeclLoweringVisitor : DeclVisitor auto operand = value->getOperand(i); markInstsToClone(valuesToClone, parentBlock, operand); } + if (value->getFullType()) + markInstsToClone(valuesToClone, parentBlock, value->getFullType()); + for (auto child : value->getDecorationsAndChildren()) + markInstsToClone(valuesToClone, parentBlock, child); } - for (auto child : value->getChildren()) - markInstsToClone(valuesToClone, parentBlock, child); auto parent = parentBlock->getParent(); while (parent && parent != parentBlock) { @@ -9025,7 +9027,8 @@ struct DeclLoweringVisitor : DeclVisitor markInstsToClone(valuesToClone, parentGeneric->getFirstBlock(), returnType); // For Function Types, we always clone all generic parameters regardless of whether // the generic parameter appears in the function signature or not. - if (returnType->getOp() == kIROp_FuncType) + if (returnType->getOp() == kIROp_FuncType || + returnType->getOp() == kIROp_Generic) { for (auto genericParam : parentGeneric->getParams()) { diff --git a/tests/language-feature/generics/generic-witness-derived.slang b/tests/language-feature/generics/generic-witness-derived.slang new file mode 100644 index 0000000000..e9659102c3 --- /dev/null +++ b/tests/language-feature/generics/generic-witness-derived.slang @@ -0,0 +1,64 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj -output-using-type + +// Test that we can compile a generic function with a generic type constraint that is dependent on an +// outer generic type parameter. + +namespace ns{ + + public interface IBinaryElementWiseFunction + { + public static T call(const in T lhs, const in T rhs); + } + public struct AddOp : IBinaryElementWiseFunction + { + public static T call(const in T lhs, const in T rhs) + { + return lhs + rhs; + } + } + public struct BinaryElementWiseInputData + { + T lhs; + T rhs; + + // Note: `U` is constrainted by `IBinaryElementWiseFunction`, which is dependent on `T`, + // that is another generic type parameter defined on the outer type. + // This eventually leads to a IRGeneric where one param has a type that is dependent on + // another param. + // In this case, the IR for `test` after generic flattening will be: + // ``` + // %g_test = IRGeneric + // { + // IRBlock + // { + // %T = IRParam : Type; + // %T_w = IRParam : IRWitnessTableType; + // %U = IRParam : Type; + // %U_w = IRRaram : IRWitnessTableType<%s>; // note that the type here is a forward reference to %s + // %s = specialize(%IBinaryElementWiseFunction, %T) // %s is dependent on %T. + // ... + // } + // } + // + public T test>(U x) + { + return x.call(lhs ,rhs); + } + } +} + + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +[shader("compute")] +[numthreads(1,1,1)] +void computeMain(uint3 threadId: SV_DispatchThreadID) +{ + ns::BinaryElementWiseInputData cb; + cb.lhs = threadId.x + 1; + cb.rhs = 2; + // CHECK: 3 + outputBuffer[0] = cb.test(ns::AddOp()); +} \ No newline at end of file