Skip to content

Commit

Permalink
Various WGSL fixes. (#5490)
Browse files Browse the repository at this point in the history
* [WGSL] make sure switch has a default label.

* Various WGSL fixes.

* Update rhi submodule commit

* format code

* Remove unnecessary DISABLE_TEST directive on not applicable test.

* Matrix comp mul + `select`.

* Legalize binary ops for wgsl.

---------

Co-authored-by: slangbot <[email protected]>
  • Loading branch information
csyonghe and slangbot authored Nov 5, 2024
1 parent 2c8dacf commit 7c2ff54
Show file tree
Hide file tree
Showing 14 changed files with 231 additions and 13 deletions.
2 changes: 1 addition & 1 deletion external/slang-rhi
126 changes: 121 additions & 5 deletions source/slang/slang-emit-wgsl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,34 @@ void WGSLSourceEmitter::emitLayoutQualifiersImpl(IRVarLayout* layout)
}
}

static bool isStaticConst(IRInst* inst)
{
if (inst->getParent()->getOp() == kIROp_Module)
{
return true;
}
switch (inst->getOp())
{
case kIROp_MakeVector:
case kIROp_swizzle:
case kIROp_swizzleSet:
case kIROp_IntCast:
case kIROp_FloatCast:
case kIROp_CastFloatToInt:
case kIROp_CastIntToFloat:
case kIROp_BitCast:
{
for (UInt i = 0; i < inst->getOperandCount(); i++)
{
if (!isStaticConst(inst->getOperand(i)))
return false;
}
return true;
}
}
return false;
}

void WGSLSourceEmitter::emitVarKeywordImpl(IRType* type, IRInst* varDecl)
{
switch (varDecl->getOp())
Expand All @@ -505,14 +533,10 @@ void WGSLSourceEmitter::emitVarKeywordImpl(IRType* type, IRInst* varDecl)
case kIROp_GlobalVar:
case kIROp_Var: m_writer->emit("var"); break;
default:
if (as<IRModuleInst>(varDecl->getParent()))
{
if (isStaticConst(varDecl))
m_writer->emit("const");
}
else
{
m_writer->emit("var");
}
break;
}

Expand Down Expand Up @@ -977,6 +1001,33 @@ void WGSLSourceEmitter::emitCallArg(IRInst* inst)
}
}

bool WGSLSourceEmitter::shouldFoldInstIntoUseSites(IRInst* inst)
{
bool result = CLikeSourceEmitter::shouldFoldInstIntoUseSites(inst);
if (result)
{
// If inst is a matrix, and is used in a component-wise multiply,
// we need to not fold it.
if (as<IRMatrixType>(inst->getDataType()))
{
for (auto use = inst->firstUse; use; use = use->nextUse)
{
auto user = use->getUser();
if (user->getOp() == kIROp_Mul)
{
if (as<IRMatrixType>(user->getOperand(0)->getDataType()) &&
as<IRMatrixType>(user->getOperand(1)->getDataType()))
{
return false;
}
}
}
}
}
return result;
}


bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec)
{
EmitOpInfo outerPrec = inOuterPrec;
Expand Down Expand Up @@ -1126,6 +1177,71 @@ bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu
return true;
}
break;

case kIROp_GetStringHash:
{
auto getStringHashInst = as<IRGetStringHash>(inst);
auto stringLit = getStringHashInst->getStringLit();

if (stringLit)
{
auto slice = stringLit->getStringSlice();
emitType(inst->getDataType());
m_writer->emit("(");
m_writer->emit((int)getStableHashCode32(slice.begin(), slice.getLength()).hash);
m_writer->emit(")");
}
else
{
// Couldn't handle
diagnoseUnhandledInst(inst);
}
return true;
}

case kIROp_Mul:
{
if (!as<IRMatrixType>(inst->getOperand(0)->getDataType()) ||
!as<IRMatrixType>(inst->getOperand(1)->getDataType()))
{
return false;
}
// Mul(m1, m2) should be translated to component-wise multiplication in WGSL.
auto matrixType = as<IRMatrixType>(inst->getDataType());
auto rowCount = getIntVal(matrixType->getRowCount());
emitType(inst->getDataType());
m_writer->emit("(");
for (IRIntegerValue i = 0; i < rowCount; i++)
{
if (i != 0)
{
m_writer->emit(", ");
}
emitOperand(inst->getOperand(0), getInfo(EmitOp::Postfix));
m_writer->emit("[");
m_writer->emit(i);
m_writer->emit("] * ");
emitOperand(inst->getOperand(1), getInfo(EmitOp::Postfix));
m_writer->emit("[");
m_writer->emit(i);
m_writer->emit("]");
}
m_writer->emit(")");

return true;
}

case kIROp_Select:
{
m_writer->emit("select(");
emitOperand(inst->getOperand(2), getInfo(EmitOp::General));
m_writer->emit(", ");
emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
m_writer->emit(", ");
emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
m_writer->emit(")");
return true;
}
}

return false;
Expand Down
2 changes: 2 additions & 0 deletions source/slang/slang-emit-wgsl.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ class WGSLSourceEmitter : public CLikeSourceEmitter

void emit(const AddressSpace addressSpace);

virtual bool shouldFoldInstIntoUseSites(IRInst* inst) SLANG_OVERRIDE;

private:
// Emit the matrix type with 'rowCountWGSL' WGSL-rows and 'colCountWGSL' WGSL-columns
void emitMatrixType(
Expand Down
1 change: 1 addition & 0 deletions source/slang/slang-ir-insts.h
Original file line number Diff line number Diff line change
Expand Up @@ -4021,6 +4021,7 @@ struct IRBuilder
IRInst* emitDifferentialPairGetPrimalUserCode(IRInst* diffPair);
IRInst* emitMakeVector(IRType* type, UInt argCount, IRInst* const* args);
IRInst* emitMakeVectorFromScalar(IRType* type, IRInst* scalarValue);
IRInst* emitMakeCompositeFromScalar(IRType* type, IRInst* scalarValue);

IRInst* emitMakeVector(IRType* type, List<IRInst*> const& args)
{
Expand Down
90 changes: 89 additions & 1 deletion source/slang/slang-ir-wgsl-legalize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ struct LegalizeWGSLEntryPointContext
String* optionalSemanticIndex,
IRInst* parentVar);
void legalizeCall(IRCall* call);
void legalizeSwitch(IRSwitch* switchInst);
void legalizeBinaryOp(IRInst* inst);
void processInst(IRInst* inst);
};

Expand Down Expand Up @@ -349,11 +351,97 @@ void LegalizeWGSLEntryPointContext::legalizeCall(IRCall* call)
}
}

void LegalizeWGSLEntryPointContext::legalizeSwitch(IRSwitch* switchInst)
{
// WGSL Requires all switch statements to contain a default case.
// If the switch statement does not contain a default case, we will add one.
if (switchInst->getDefaultLabel() != switchInst->getBreakLabel())
return;
IRBuilder builder(switchInst);
auto defaultBlock = builder.createBlock();
builder.setInsertInto(defaultBlock);
builder.emitBranch(switchInst->getBreakLabel());
defaultBlock->insertBefore(switchInst->getBreakLabel());
List<IRInst*> cases;
for (UInt i = 0; i < switchInst->getCaseCount(); i++)
{
cases.add(switchInst->getCaseValue(i));
cases.add(switchInst->getCaseLabel(i));
}
builder.setInsertBefore(switchInst);
auto newSwitch = builder.emitSwitch(
switchInst->getCondition(),
switchInst->getBreakLabel(),
defaultBlock,
(UInt)cases.getCount(),
cases.getBuffer());
switchInst->transferDecorationsTo(newSwitch);
switchInst->removeAndDeallocate();
}

void LegalizeWGSLEntryPointContext::legalizeBinaryOp(IRInst* inst)
{
auto isVectorOrMatrix = [](IRType* type)
{
switch (type->getOp())
{
case kIROp_VectorType:
case kIROp_MatrixType: return true;
default: return false;
}
};
if (isVectorOrMatrix(inst->getOperand(0)->getDataType()) &&
as<IRBasicType>(inst->getOperand(1)->getDataType()))
{
IRBuilder builder(inst);
builder.setInsertBefore(inst);
auto newRhs = builder.emitMakeCompositeFromScalar(
inst->getOperand(0)->getDataType(),
inst->getOperand(1));
builder.replaceOperand(inst->getOperands() + 1, newRhs);
}
else if (
as<IRBasicType>(inst->getOperand(0)->getDataType()) &&
isVectorOrMatrix(inst->getOperand(1)->getDataType()))
{
IRBuilder builder(inst);
builder.setInsertBefore(inst);
auto newLhs = builder.emitMakeCompositeFromScalar(
inst->getOperand(1)->getDataType(),
inst->getOperand(0));
builder.replaceOperand(inst->getOperands(), newLhs);
}
}

void LegalizeWGSLEntryPointContext::processInst(IRInst* inst)
{
switch (inst->getOp())
{
case kIROp_Call: legalizeCall(static_cast<IRCall*>(inst)); break;
case kIROp_Call: legalizeCall(static_cast<IRCall*>(inst)); break;
case kIROp_Switch: legalizeSwitch(as<IRSwitch>(inst)); break;

// For all binary operators, make sure both side of the operator have the same type
// (vector-ness and matrix-ness).
case kIROp_Add:
case kIROp_Sub:
case kIROp_Mul:
case kIROp_Div:
case kIROp_FRem:
case kIROp_IRem:
case kIROp_And:
case kIROp_Or:
case kIROp_BitAnd:
case kIROp_BitOr:
case kIROp_BitXor:
case kIROp_Lsh:
case kIROp_Rsh:
case kIROp_Eql:
case kIROp_Neq:
case kIROp_Greater:
case kIROp_Less:
case kIROp_Geq:
case kIROp_Leq: legalizeBinaryOp(inst); break;

default:
for (auto child : inst->getModifiableChildren())
processInst(child);
Expand Down
11 changes: 11 additions & 0 deletions source/slang/slang-ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4162,6 +4162,17 @@ IRInst* IRBuilder::emitMakeVectorFromScalar(IRType* type, IRInst* scalarValue)
return emitIntrinsicInst(type, kIROp_MakeVectorFromScalar, 1, &scalarValue);
}

IRInst* IRBuilder::emitMakeCompositeFromScalar(IRType* type, IRInst* scalarValue)
{
switch (type->getOp())
{
case kIROp_VectorType: return emitMakeVectorFromScalar(type, scalarValue);
case kIROp_MatrixType: return emitMakeMatrixFromScalar(type, scalarValue);
case kIROp_ArrayType: return emitMakeArrayFromElement(type, scalarValue);
default: SLANG_UNEXPECTED("unhandled composite type"); UNREACHABLE_RETURN(nullptr);
}
}

IRInst* IRBuilder::emitMatrixReshape(IRType* type, IRInst* inst)
{
return emitIntrinsicInst(type, kIROp_MatrixReshape, 1, &inst);
Expand Down
2 changes: 1 addition & 1 deletion tests/autodiff-dstdlib/dstdlib-abs.slang
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-wgpu
//TEST(compute):COMPARE_COMPUTE_EX:-wgpu -compute -output-using-type

//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;
Expand Down
2 changes: 1 addition & 1 deletion tests/autodiff/matrix-arithmetic-fwd.slang
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//TEST(compute):COMPARE_COMPUTE_EX:-wgpu -compute -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-wgpu

//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;
Expand Down
1 change: 1 addition & 0 deletions tests/autodiff/reverse-loop-checkpoint-test.slang
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//TEST(compute):COMPARE_COMPUTE_EX:-dx12 -compute -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-cuda -compute -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-wgpu -compute -shaderobj -output-using-type
//TEST:SIMPLE(filecheck=CHECK): -target hlsl -profile cs_5_0 -entry computeMain -line-directive-mode none
//DISABLE_TEST:SIMPLE(filecheck=CHK):-target glsl -stage compute -entry computeMain -report-checkpoint-intermediates

Expand Down
2 changes: 1 addition & 1 deletion tests/bugs/nested-switch.slang
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
//TEST(compute):COMPARE_COMPUTE: -shaderobj
//TEST(compute):COMPARE_COMPUTE:-vk -shaderobj
//TEST(compute):COMPARE_COMPUTE:-cpu -shaderobj
//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-wgpu
//TEST(compute):COMPARE_COMPUTE:-wgpu

int test(int t, int r)
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
//TEST:COMPILE: -entry computeMain -stage compute -target callable tests/hlsl-intrinsic/sampler-feedback/compute-sampler-feedback.slang
//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-wgpu

// Not available on non PS shader
// dx.op.writeSamplerFeedback WriteSamplerFeedback
Expand Down
2 changes: 1 addition & 1 deletion tests/ir/string-literal-hash.slang
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//TEST(compute):COMPARE_COMPUTE: -shaderobj
//TEST(compute):COMPARE_COMPUTE: -vk -shaderobj
//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-wgpu
//TEST(compute):COMPARE_COMPUTE:-wgpu

// Note: disabled on CPU target until we can fill
// in a more correct/complete `String` and `getStringHash`
Expand Down
2 changes: 1 addition & 1 deletion tests/language-feature/constants/constexpr-loop.slang
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-wgpu
//TEST(compute):COMPARE_COMPUTE_EX: -wgpu -compute -output-using-type

//TEST_INPUT: set g_texture = Texture2D(size=8, content = one)
//TEST_INPUT: set g_sampler = Sampler
Expand Down
Binary file removed tests/library/linked.spirv
Binary file not shown.

0 comments on commit 7c2ff54

Please sign in to comment.