Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix lowering logic of nested generics. #4091

Merged
merged 4 commits into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions source/slang/slang-ir-link.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -746,16 +746,41 @@ void cloneGlobalValueWithCodeCommon(
{
IRBlock* ob = originalValue->getFirstBlock();
IRBlock* cb = clonedValue->getFirstBlock();
struct ParamCloneInfo
{
IRParam* originalParam;
IRParam* clonedParam;
};
ShortList<ParamCloneInfo> paramCloneInfos;
while (ob)
{
SLANG_ASSERT(cb);

builder->setInsertInto(cb);
for (auto oi = ob->getFirstInst(); oi; oi = oi->getNextInst())
{
cloneInst(context, builder, oi);
if (oi->getOp() == kIROp_Param)
{
// Params may have forward references in its type and
// decorations, so we just create a placeholder for it
// in this first pass.
IRParam* clonedParam = builder->emitParam(nullptr);
registerClonedValue(context, clonedParam, oi);
paramCloneInfos.add({ (IRParam*)oi, clonedParam });
}
else
{
cloneInst(context, builder, oi);
}
}
// Clone the type and decorations of parameters after all instructs in the block
// have been cloned.
for (auto param : paramCloneInfos)
{
builder->setInsertInto(param.clonedParam);
param.clonedParam->setFullType((IRType*)cloneValue(context, param.originalParam->getFullType()));
cloneDecorations(context, param.clonedParam, param.originalParam);
}

ob = ob->getNextBlock();
cb = cb->getNextBlock();
}
Expand Down
11 changes: 7 additions & 4 deletions source/slang/slang-lower-to-ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8823,7 +8823,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
IRGenContext* subContext,
GenericTypeConstraintDecl* constraintDecl)
{
auto supType = lowerType(context, constraintDecl->sup.type);
auto supType = lowerType(subContext, constraintDecl->sup.type);
auto value = emitGenericConstraintValue(subContext, constraintDecl, supType);
subContext->setValue(constraintDecl, LoweredValInfo::simple(value));
}
Expand Down Expand Up @@ -8972,9 +8972,11 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
auto operand = value->getOperand(i);
markInstsToClone(valuesToClone, parentBlock, operand);
}
if (value->getFullType())
markInstsToClone(valuesToClone, parentBlock, value->getFullType());
for (auto child : value->getDecorationsAndChildren())
markInstsToClone(valuesToClone, parentBlock, child);
}
for (auto child : value->getChildren())
markInstsToClone(valuesToClone, parentBlock, child);
auto parent = parentBlock->getParent();
while (parent && parent != parentBlock)
{
Expand Down Expand Up @@ -9025,7 +9027,8 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
markInstsToClone(valuesToClone, parentGeneric->getFirstBlock(), returnType);
// For Function Types, we always clone all generic parameters regardless of whether
// the generic parameter appears in the function signature or not.
if (returnType->getOp() == kIROp_FuncType)
if (returnType->getOp() == kIROp_FuncType ||
returnType->getOp() == kIROp_Generic)
{
for (auto genericParam : parentGeneric->getParams())
{
Expand Down
64 changes: 64 additions & 0 deletions tests/language-feature/generics/generic-witness-derived.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj -output-using-type

// Test that we can compile a generic function with a generic type constraint that is dependent on an
// outer generic type parameter.

namespace ns{

public interface IBinaryElementWiseFunction<T>
{
public static T call(const in T lhs, const in T rhs);
}
public struct AddOp<T : IArithmetic> : IBinaryElementWiseFunction<T>
{
public static T call(const in T lhs, const in T rhs)
{
return lhs + rhs;
}
}
public struct BinaryElementWiseInputData<T : IArithmetic>
{
T lhs;
T rhs;

// Note: `U` is constrainted by `IBinaryElementWiseFunction<T>`, which is dependent on `T`,
// that is another generic type parameter defined on the outer type.
// This eventually leads to a IRGeneric where one param has a type that is dependent on
// another param.
// In this case, the IR for `test` after generic flattening will be:
// ```
// %g_test = IRGeneric
// {
// IRBlock
// {
// %T = IRParam : Type;
// %T_w = IRParam : IRWitnessTableType<IArithmetic>;
// %U = IRParam : Type;
// %U_w = IRRaram : IRWitnessTableType<%s>; // note that the type here is a forward reference to %s
// %s = specialize(%IBinaryElementWiseFunction, %T) // %s is dependent on %T.
// ...
// }
// }
//
public T test<U : IBinaryElementWiseFunction<T>>(U x)
{
return x.call(lhs ,rhs);
}
}
}


//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<int> outputBuffer;

[shader("compute")]
[numthreads(1,1,1)]
void computeMain(uint3 threadId: SV_DispatchThreadID)
{
ns::BinaryElementWiseInputData<int> cb;
cb.lhs = threadId.x + 1;
cb.rhs = 2;
// CHECK: 3
outputBuffer[0] = cb.test(ns::AddOp<int>());
}
Loading