Skip to content

Commit

Permalink
Fix handling of pointer logic in wgsl backend. (#5129)
Browse files Browse the repository at this point in the history
  • Loading branch information
csyonghe authored Sep 21, 2024
1 parent c42b5e2 commit 53684ed
Show file tree
Hide file tree
Showing 7 changed files with 183 additions and 189 deletions.
20 changes: 10 additions & 10 deletions source/slang/slang-emit-c-like.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1646,16 +1646,11 @@ bool CLikeSourceEmitter::shouldFoldInstIntoUseSites(IRInst* inst)
return true;
}

bool CLikeSourceEmitter::isPointerSyntaxRequiredImpl(IRInst* /* inst */)
{
return doesTargetSupportPtrTypes();
}

void CLikeSourceEmitter::emitDereferenceOperand(IRInst* inst, EmitOpInfo const& outerPrec)
{
EmitOpInfo newOuterPrec = outerPrec;

if (isPointerSyntaxRequiredImpl(inst))
if (doesTargetSupportPtrTypes())
{
switch (inst->getOp())
{
Expand Down Expand Up @@ -1754,7 +1749,7 @@ void CLikeSourceEmitter::emitDereferenceOperand(IRInst* inst, EmitOpInfo const&

void CLikeSourceEmitter::emitVarExpr(IRInst* inst, EmitOpInfo const& outerPrec)
{
if (isPointerSyntaxRequiredImpl(inst))
if (doesTargetSupportPtrTypes())
{
auto prec = getInfo(EmitOp::Prefix);
auto newOuterPrec = outerPrec;
Expand Down Expand Up @@ -2003,6 +1998,11 @@ void CLikeSourceEmitter::emitIntrinsicCallExprImpl(
}
}

void CLikeSourceEmitter::emitCallArg(IRInst* inst)
{
emitOperand(inst, getInfo(EmitOp::General));
}

void CLikeSourceEmitter::_emitCallArgList(IRCall* inst, int startingOperandIndex)
{
bool isFirstArg = true;
Expand All @@ -2023,7 +2023,7 @@ void CLikeSourceEmitter::_emitCallArgList(IRCall* inst, int startingOperandIndex
m_writer->emit(", ");
else
isFirstArg = false;
emitOperand(inst->getOperand(aa), getInfo(EmitOp::General));
emitCallArg(inst->getOperand(aa));
}
m_writer->emit(")");
}
Expand Down Expand Up @@ -2296,7 +2296,7 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO

IRFieldAddress* ii = (IRFieldAddress*) inst;

if (isPointerSyntaxRequiredImpl(inst))
if (doesTargetSupportPtrTypes())
{
auto prec = getInfo(EmitOp::Prefix);
needClose = maybeEmitParens(outerPrec, prec);
Expand Down Expand Up @@ -4206,7 +4206,7 @@ void CLikeSourceEmitter::emitGlobalParam(IRGlobalParam* varDecl)

emitRateQualifiersAndAddressSpace(varDecl);
emitVarKeyword(varType, varDecl);
emitGlobalParamType(varType, getName(varDecl));
emitType(varType, getName(varDecl));

emitSemantics(varDecl);

Expand Down
3 changes: 1 addition & 2 deletions source/slang/slang-emit-c-like.h
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,6 @@ class CLikeSourceEmitter: public SourceEmitterBase
void emitType(IRType* type);
void emitType(IRType* type, Name* name, SourceLoc const& nameLoc);
void emitType(IRType* type, NameLoc const& nameAndLoc);
virtual void emitGlobalParamType(IRType* type, String const& name) {emitType(type, name);}
bool hasExplicitConstantBufferOffset(IRInst* cbufferType);
bool isSingleElementConstantBuffer(IRInst* cbufferType);
bool shouldForceUnpackConstantBufferElements(IRInst* cbufferType);
Expand Down Expand Up @@ -430,7 +429,6 @@ class CLikeSourceEmitter: public SourceEmitterBase

void emitGlobalInst(IRInst* inst);
virtual void emitGlobalInstImpl(IRInst* inst);
virtual bool isPointerSyntaxRequiredImpl(IRInst* inst);

void ensureInstOperand(ComputeEmitActionsContext* ctx, IRInst* inst, EmitAction::Level requiredLevel = EmitAction::Level::Definition);

Expand Down Expand Up @@ -567,6 +565,7 @@ class CLikeSourceEmitter: public SourceEmitterBase

// Emit the argument list (including paranthesis) in a `CallInst`
void _emitCallArgList(IRCall* call, int startingOperandIndex = 1);
virtual void emitCallArg(IRInst* arg);

String _generateUniqueName(const UnownedStringSlice& slice);

Expand Down
179 changes: 52 additions & 127 deletions source/slang/slang-emit-wgsl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,13 @@
// 'transpose' calls, or else perform more complicated transformations that
// end up duplicating expressions many times.

namespace Slang {
namespace Slang
{

void WGSLSourceEmitter::emitSwitchCaseSelectorsImpl(
IRBasicType *const switchConditionType,
const SwitchRegion::Case *const currentCase, const bool isDefault
)
const SwitchRegion::Case *const currentCase,
const bool isDefault)
{
// WGSL has special syntax for blocks sharing case labels:
// "case 2, 3, 4: ...;" instead of the C-like syntax
Expand Down Expand Up @@ -80,8 +81,8 @@ void WGSLSourceEmitter::emitSwitchCaseSelectorsImpl(
}

void WGSLSourceEmitter::emitParameterGroupImpl(
IRGlobalParam* varDecl, IRUniformParameterGroupType* type
)
IRGlobalParam* varDecl,
IRUniformParameterGroupType* type)
{
auto varLayout = getVarLayout(varDecl);
SLANG_RELEASE_ASSERT(varLayout);
Expand Down Expand Up @@ -140,8 +141,8 @@ void WGSLSourceEmitter::emitParameterGroupImpl(
}

void WGSLSourceEmitter::emitEntryPointAttributesImpl(
IRFunc* irFunc, IREntryPointDecoration* entryPointDecor
)
IRFunc* irFunc,
IREntryPointDecoration* entryPointDecor)
{
auto stage = entryPointDecor->getProfile().getStage();

Expand Down Expand Up @@ -238,9 +239,7 @@ static bool isPowerOf2(const uint32_t n)
return (n != 0U) && ((n - 1U) & n) == 0U;
}

void WGSLSourceEmitter::emitStructFieldAttributes(
IRStructType * structType, IRStructField * field
)
void WGSLSourceEmitter::emitStructFieldAttributes(IRStructType * structType, IRStructField * field)
{
// Tint emits errors unless we explicitly spell out the layout in some cases, so emit
// offset and align attribtues for all fields.
Expand Down Expand Up @@ -273,26 +272,6 @@ void WGSLSourceEmitter::emitStructFieldAttributes(
m_writer->emit(")");
}

bool WGSLSourceEmitter::isPointerSyntaxRequiredImpl(IRInst* inst)
{
if (inst->getOp() == kIROp_RWStructuredBufferGetElementPtr)
return false;

// Don't emit "->" to access fields in resource structs
if (inst->getOp() == kIROp_FieldAddress)
return false;

// Don't emit "*" to access fields in resource structs
if (inst->getOp() == kIROp_GlobalParam)
return false;

// Emit 'globalVar' instead of "*&globalVar"
if (inst->getOp() == kIROp_GlobalVar)
return false;

return true;
}

void WGSLSourceEmitter::emit(const AddressSpace addressSpace)
{
switch (addressSpace)
Expand Down Expand Up @@ -325,32 +304,14 @@ void WGSLSourceEmitter::emitSimpleTypeImpl(IRType* type)
{

case kIROp_HLSLRWStructuredBufferType:
{
auto structuredBufferType = as<IRHLSLStructuredBufferTypeBase>(type);
m_writer->emit("ptr<");
emit(AddressSpace::StorageBuffer);
m_writer->emit(", ");
m_writer->emit("array");
m_writer->emit("<");
emitType(structuredBufferType->getElementType());
m_writer->emit(">");
m_writer->emit(", read_write");
m_writer->emit(">");
}
break;

case kIROp_HLSLStructuredBufferType:
case kIROp_HLSLRasterizerOrderedStructuredBufferType:
{
auto structuredBufferType = as<IRHLSLStructuredBufferTypeBase>(type);
m_writer->emit("ptr<");
emit(AddressSpace::StorageBuffer);
m_writer->emit(", ");
m_writer->emit("array");
m_writer->emit("<");
emitType(structuredBufferType->getElementType());
m_writer->emit(">");
m_writer->emit(", read");
m_writer->emit(">");
}
break;

Expand Down Expand Up @@ -582,7 +543,8 @@ void WGSLSourceEmitter::emitVarKeywordImpl(IRType * type, IRInst* varDecl)
{
m_writer->emit("<workgroup>");
}
else if (type->getOp() == kIROp_HLSLRWStructuredBufferType)
else if (type->getOp() == kIROp_HLSLRWStructuredBufferType ||
type->getOp() == kIROp_HLSLRasterizerOrderedStructuredBufferType)
{
m_writer->emit("<");
m_writer->emit("storage, read_write");
Expand Down Expand Up @@ -692,9 +654,26 @@ void WGSLSourceEmitter::emitDeclaratorImpl(DeclaratorInfo* declarator)
}
}

void WGSLSourceEmitter::emitOperandImpl(IRInst* operand, EmitOpInfo const& outerPrec)
{
if (operand->getOp() == kIROp_Param && as<IRPtrTypeBase>(operand->getDataType()))
{
// If we are emitting a reference to a pointer typed operand, then
// we should dereference it now since we want to treat all the remaining
// part of wgsl as pointer-free target.
m_writer->emit("(*");
m_writer->emit(getName(operand));
m_writer->emit(")");
}
else
{
CLikeSourceEmitter::emitOperandImpl(operand, outerPrec);
}
}

void WGSLSourceEmitter::emitSimpleTypeAndDeclaratorImpl(
IRType* type, DeclaratorInfo* declarator
)
IRType* type,
DeclaratorInfo* declarator)
{
if (declarator)
{
Expand Down Expand Up @@ -999,13 +978,29 @@ bool WGSLSourceEmitter::tryEmitInstStmtImpl(IRInst* inst)
}
}

void WGSLSourceEmitter::emitCallArg(IRInst* inst)
{
if (as<IRPtrTypeBase>(inst->getDataType()))
{
// If we are calling a function with a pointer-typed argument, we need to
// explicitly prefix the argument with `&` to pass a pointer.
//
m_writer->emit("&(");
emitOperand(inst, getInfo(EmitOp::General));
m_writer->emit(")");
}
else
{
emitOperand(inst, getInfo(EmitOp::General));
}
}

bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec)
{
EmitOpInfo outerPrec = inOuterPrec;

switch (inst->getOp())
{

case kIROp_MakeVectorFromScalar:
{
// In WGSL this is done by calling the vec* overloads listed in [1]
Expand Down Expand Up @@ -1079,25 +1074,13 @@ bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu
}
break;

case kIROp_RWStructuredBufferGetElementPtr:
{
m_writer->emit("(*");
emitOperand(inst->getOperand(0), leftSide(outerPrec, getInfo(EmitOp::Postfix)));
m_writer->emit(")[");
emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
m_writer->emit("]");
return true;
}
break;

case kIROp_StructuredBufferLoad:
case kIROp_RWStructuredBufferLoad:
case kIROp_RWStructuredBufferGetElementPtr:
{
// Structured buffers are just arrays in WGSL
auto base = inst->getOperand(0);
emitOperand(base, outerPrec);
emitOperand(inst->getOperand(0), leftSide(outerPrec, getInfo(EmitOp::Postfix)));
m_writer->emit("[");
emitOperand(inst->getOperand(1), EmitOpInfo());
emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
m_writer->emit("]");
return true;
}
Expand Down Expand Up @@ -1134,15 +1117,12 @@ bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu
}
}
break;

}

return false;
}

void WGSLSourceEmitter::emitVectorTypeNameImpl(
IRType* elementType, IRIntegerValue elementCount
)
void WGSLSourceEmitter::emitVectorTypeNameImpl(IRType* elementType, IRIntegerValue elementCount)
{

if (elementCount > 1)
Expand All @@ -1159,61 +1139,6 @@ void WGSLSourceEmitter::emitVectorTypeNameImpl(
}
}

void WGSLSourceEmitter::emitOperandImpl(IRInst* inst, const EmitOpInfo& outerPrec)
{
// In WGSL, the structured buffer types are converted to ptr<AS, array<E>, AM>
// everywhere, except for the global parameter declaration.
// Thus, when these globals are used in expressions, we need an ampersand.

if (inst->getOp() == kIROp_GlobalParam)
{
switch (inst->getDataType()->getOp())
{
case kIROp_HLSLStructuredBufferType:
case kIROp_HLSLRWStructuredBufferType:

m_writer->emit("(&");
CLikeSourceEmitter::emitOperandImpl(inst, outerPrec);
m_writer->emit(")");
return;
}
}

CLikeSourceEmitter::emitOperandImpl(inst, outerPrec);
}

void WGSLSourceEmitter::emitGlobalParamType(IRType* type, const String& name)
{
// In WGSL, the structured buffer types are converted to ptr<AS, array<E>, AM>
// everywhere, except for the global parameter declaration.

switch (type->getOp())
{

case kIROp_HLSLStructuredBufferType:
case kIROp_HLSLRWStructuredBufferType:
{
StringSliceLoc nameAndLoc(name.getUnownedSlice());
NameDeclaratorInfo nameDeclarator(&nameAndLoc);
emitDeclarator(&nameDeclarator);
m_writer->emit(" : ");
auto structuredBufferType = as<IRHLSLStructuredBufferTypeBase>(type);
m_writer->emit("array");
m_writer->emit("<");
emitType(structuredBufferType->getElementType());
m_writer->emit(">");
}
break;

default:

emitType(type, name);
break;

}

}

void WGSLSourceEmitter::emitFrontMatterImpl(TargetRequest* /* targetReq */)
{
if (m_f16ExtensionEnabled)
Expand Down
Loading

0 comments on commit 53684ed

Please sign in to comment.