Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Various WGSL fixes. #5490

Merged
merged 9 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion external/slang-rhi
127 changes: 122 additions & 5 deletions source/slang/slang-emit-wgsl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,34 @@ void WGSLSourceEmitter::emitLayoutQualifiersImpl(IRVarLayout* layout)
}
}

static bool isStaticConst(IRInst* inst)
{
if (inst->getParent()->getOp() == kIROp_Module)
{
return true;
}
switch (inst->getOp())
{
case kIROp_MakeVector:
case kIROp_swizzle:
case kIROp_swizzleSet:
case kIROp_IntCast:
case kIROp_FloatCast:
case kIROp_CastFloatToInt:
case kIROp_CastIntToFloat:
case kIROp_BitCast:
{
for (UInt i = 0; i < inst->getOperandCount(); i++)
{
if (!isStaticConst(inst->getOperand(i)))
return false;
}
return true;
}
}
return false;
}

void WGSLSourceEmitter::emitVarKeywordImpl(IRType* type, IRInst* varDecl)
{
switch (varDecl->getOp())
Expand All @@ -505,14 +533,10 @@ void WGSLSourceEmitter::emitVarKeywordImpl(IRType* type, IRInst* varDecl)
case kIROp_GlobalVar:
case kIROp_Var: m_writer->emit("var"); break;
default:
if (as<IRModuleInst>(varDecl->getParent()))
{
if (isStaticConst(varDecl))
m_writer->emit("const");
}
else
{
m_writer->emit("var");
}
break;
}

Expand Down Expand Up @@ -977,6 +1001,33 @@ void WGSLSourceEmitter::emitCallArg(IRInst* inst)
}
}

bool WGSLSourceEmitter::shouldFoldInstIntoUseSites(IRInst* inst)
{
bool result = CLikeSourceEmitter::shouldFoldInstIntoUseSites(inst);
if (result)
{
// If inst is a matrix, and is used in a component-wise multiply,
// we need to not fold it.
if (as<IRMatrixType>(inst->getDataType()))
{
for (auto use = inst->firstUse; use; use = use->nextUse)
{
auto user = use->getUser();
if (user->getOp() == kIROp_Mul)
{
if (as<IRMatrixType>(user->getOperand(0)->getDataType()) &&
as<IRMatrixType>(user->getOperand(1)->getDataType()))
{
return false;
}
}
}
}
}
return result;
}


bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec)
{
EmitOpInfo outerPrec = inOuterPrec;
Expand Down Expand Up @@ -1126,6 +1177,72 @@ bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu
return true;
}
break;

case kIROp_GetStringHash:
{
auto getStringHashInst = as<IRGetStringHash>(inst);
auto stringLit = getStringHashInst->getStringLit();

if (stringLit)
{
auto slice = stringLit->getStringSlice();
emitType(inst->getDataType());
m_writer->emit("(");
m_writer->emit((int)getStableHashCode32(slice.begin(), slice.getLength()).hash);
m_writer->emit(")");
}
else
{
// Couldn't handle
diagnoseUnhandledInst(inst);
}
return true;
}

case kIROp_Mul:
{
if (!as<IRMatrixType>(inst->getOperand(0)->getDataType()) ||
!as<IRMatrixType>(inst->getOperand(1)->getDataType()))
{
return false;
}
// Mul(m1, m2) should be translated to component-wise multiplication in WGSL.
auto matrixType = as<IRMatrixType>(inst->getDataType());
auto rowCount = getIntVal(matrixType->getRowCount());
emitType(inst->getDataType());
m_writer->emit("(");
for (IRIntegerValue i = 0; i < rowCount; i++)
{
if (i != 0)
{
m_writer->emit(", ");
}
emitOperand(inst->getOperand(0), getInfo(EmitOp::Postfix));
m_writer->emit("[");
m_writer->emit(i);
m_writer->emit("] * ");
emitOperand(inst->getOperand(1), getInfo(EmitOp::Postfix));
m_writer->emit("[");
m_writer->emit(i);
m_writer->emit("]");
}
m_writer->emit(")");

return true;
}

case kIROp_Select:
{
auto prec = getInfo(EmitOp::Conditional);
m_writer->emit("select(");
emitOperand(inst->getOperand(2), getInfo(EmitOp::General));
m_writer->emit(", ");
emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
m_writer->emit(", ");
emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
m_writer->emit(")");
return true;
}
}

return false;
Expand Down
2 changes: 2 additions & 0 deletions source/slang/slang-emit-wgsl.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ class WGSLSourceEmitter : public CLikeSourceEmitter

void emit(const AddressSpace addressSpace);

virtual bool shouldFoldInstIntoUseSites(IRInst* inst) SLANG_OVERRIDE;

private:
// Emit the matrix type with 'rowCountWGSL' WGSL-rows and 'colCountWGSL' WGSL-columns
void emitMatrixType(
Expand Down
32 changes: 31 additions & 1 deletion source/slang/slang-ir-wgsl-legalize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ struct LegalizeWGSLEntryPointContext
String* optionalSemanticIndex,
IRInst* parentVar);
void legalizeCall(IRCall* call);
void legalizeSwitch(IRSwitch* switchInst);
void processInst(IRInst* inst);
};

Expand Down Expand Up @@ -349,11 +350,40 @@ void LegalizeWGSLEntryPointContext::legalizeCall(IRCall* call)
}
}

void LegalizeWGSLEntryPointContext::legalizeSwitch(IRSwitch* switchInst)
{
// WGSL Requires all switch statements to contain a default case.
// If the switch statement does not contain a default case, we will add one.
if (switchInst->getDefaultLabel() != switchInst->getBreakLabel())
return;
IRBuilder builder(switchInst);
auto defaultBlock = builder.createBlock();
builder.setInsertInto(defaultBlock);
builder.emitBranch(switchInst->getBreakLabel());
defaultBlock->insertBefore(switchInst->getBreakLabel());
List<IRInst*> cases;
for (UInt i = 0; i < switchInst->getCaseCount(); i++)
{
cases.add(switchInst->getCaseValue(i));
cases.add(switchInst->getCaseLabel(i));
}
builder.setInsertBefore(switchInst);
auto newSwitch = builder.emitSwitch(
switchInst->getCondition(),
switchInst->getBreakLabel(),
defaultBlock,
(UInt)cases.getCount(),
cases.getBuffer());
switchInst->transferDecorationsTo(newSwitch);
switchInst->removeAndDeallocate();
}

void LegalizeWGSLEntryPointContext::processInst(IRInst* inst)
{
switch (inst->getOp())
{
case kIROp_Call: legalizeCall(static_cast<IRCall*>(inst)); break;
case kIROp_Call: legalizeCall(static_cast<IRCall*>(inst)); break;
case kIROp_Switch: legalizeSwitch(as<IRSwitch>(inst)); break;
default:
for (auto child : inst->getModifiableChildren())
processInst(child);
Expand Down
2 changes: 1 addition & 1 deletion tests/autodiff/matrix-arithmetic-fwd.slang
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//TEST(compute):COMPARE_COMPUTE_EX:-wgpu -compute -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-wgpu

//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;
Expand Down
1 change: 1 addition & 0 deletions tests/autodiff/reverse-loop-checkpoint-test.slang
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//TEST(compute):COMPARE_COMPUTE_EX:-dx12 -compute -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-cuda -compute -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-wgpu -compute -shaderobj -output-using-type
//TEST:SIMPLE(filecheck=CHECK): -target hlsl -profile cs_5_0 -entry computeMain -line-directive-mode none
//DISABLE_TEST:SIMPLE(filecheck=CHK):-target glsl -stage compute -entry computeMain -report-checkpoint-intermediates

Expand Down
2 changes: 1 addition & 1 deletion tests/bugs/nested-switch.slang
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
//TEST(compute):COMPARE_COMPUTE: -shaderobj
//TEST(compute):COMPARE_COMPUTE:-vk -shaderobj
//TEST(compute):COMPARE_COMPUTE:-cpu -shaderobj
//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-wgpu
//TEST(compute):COMPARE_COMPUTE:-wgpu
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure it does, but shouldn't this have -shaderobj like other lines?

//TEST(compute):COMPARE_COMPUTE:-wgpu -shaderobj

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That flag is no longer meaningful. Every test is -shaderobj now, and this flag should be removed from all tests.


int test(int t, int r)
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
//TEST:COMPILE: -entry computeMain -stage compute -target callable tests/hlsl-intrinsic/sampler-feedback/compute-sampler-feedback.slang
//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-wgpu

// Not available on non PS shader
// dx.op.writeSamplerFeedback WriteSamplerFeedback
Expand Down
2 changes: 1 addition & 1 deletion tests/ir/string-literal-hash.slang
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//TEST(compute):COMPARE_COMPUTE: -shaderobj
//TEST(compute):COMPARE_COMPUTE: -vk -shaderobj
//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-wgpu
//TEST(compute):COMPARE_COMPUTE:-wgpu

// Note: disabled on CPU target until we can fill
// in a more correct/complete `String` and `getStringHash`
Expand Down
2 changes: 1 addition & 1 deletion tests/language-feature/constants/constexpr-loop.slang
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-wgpu
//TEST(compute):COMPARE_COMPUTE_EX: -wgpu -compute -output-using-type

//TEST_INPUT: set g_texture = Texture2D(size=8, content = one)
//TEST_INPUT: set g_sampler = Sampler
Expand Down
Binary file removed tests/library/linked.spirv
Binary file not shown.
Loading