Skip to content

Commit

Permalink
Decorate
Browse files Browse the repository at this point in the history
  • Loading branch information
amdrexu committed Dec 11, 2023
1 parent daf27f7 commit 7b77de2
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 128 deletions.
47 changes: 1 addition & 46 deletions llpc/translator/lib/SPIRV/libSPIRV/SPIRVDecorate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ void SPIRVGroupDecorate::decorateTargets() {
auto Target = getOrCreate(I);
for (auto &Dec : DecorationGroup->getDecorations()) {
assert(Dec->isDecorate());
Target->addDecorate(static_cast<const SPIRVDecorate *const>(Dec));
Target->addDecorate(static_cast<SPIRVDecorate *const>(Dec));
}
}
}
Expand Down Expand Up @@ -164,49 +164,4 @@ void SPIRVDecorateId::decode(std::istream &I) {
getOrCreateTarget()->addDecorate(this);
}

bool SPIRVDecorateGeneric::Comparator::operator()(const SPIRVDecorateGeneric *A, const SPIRVDecorateGeneric *B) const {
auto Action = [=]() {
if (A->getOpCode() < B->getOpCode())
return true;
if (A->getOpCode() > B->getOpCode())
return false;
if (A->getDecorateKind() < B->getDecorateKind())
return true;
if (A->getDecorateKind() > B->getDecorateKind())
return false;
if (A->getLiteralCount() < B->getLiteralCount())
return true;
if (A->getLiteralCount() > B->getLiteralCount())
return false;
for (size_t I = 0, E = A->getLiteralCount(); I != E; ++I) {
auto EA = A->getLiteral(I);
auto EB = B->getLiteral(I);
if (EA < EB)
return true;
if (EA > EB)
return false;
}
return false;
};
auto Res = Action();
return Res;
}

bool operator==(const SPIRVDecorateGeneric &A, const SPIRVDecorateGeneric &B) {
if (A.getTargetId() != B.getTargetId())
return false;
if (A.getOpCode() != B.getOpCode())
return false;
if (A.getDecorateKind() != B.getDecorateKind())
return false;
if (A.getLiteralCount() != B.getLiteralCount())
return false;
for (size_t I = 0, E = A.getLiteralCount(); I != E; ++I) {
auto EA = A.getLiteral(I);
auto EB = B.getLiteral(I);
if (EA != EB)
return false;
}
return true;
}
} // namespace SPIRV
37 changes: 7 additions & 30 deletions llpc/translator/lib/SPIRV/libSPIRV/SPIRVDecorate.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,6 @@ class SPIRVDecorateGeneric : public SPIRVAnnotationGeneric {
}
Decoration getDecorateKind() const;
size_t getLiteralCount() const;
/// Compare for kind and literal only.
struct Comparator {
bool operator()(const SPIRVDecorateGeneric *A, const SPIRVDecorateGeneric *B) const;
};
/// Compare kind, literals and target.
friend bool operator==(const SPIRVDecorateGeneric &A, const SPIRVDecorateGeneric &B);

SPIRVDecorationGroup *getOwner() const { return Owner; }

Expand All @@ -96,18 +90,7 @@ class SPIRVDecorateGeneric : public SPIRVAnnotationGeneric {
SPIRVDecorationGroup *Owner; // Owning decorate group
};

class SPIRVDecorateSet : public std::multiset<const SPIRVDecorateGeneric *, SPIRVDecorateGeneric::Comparator> {
public:
typedef std::multiset<const SPIRVDecorateGeneric *, SPIRVDecorateGeneric::Comparator> BaseType;
iterator insert(const value_type &Dec) {
auto ER = BaseType::equal_range(Dec);
for (auto I = ER.first, E = ER.second; I != E; ++I) {
if (**I == *Dec)
return I;
}
return BaseType::insert(Dec);
}
};
typedef std::vector<SPIRVDecorateGeneric *> SPIRVDecorateVec;

class SPIRVDecorate : public SPIRVDecorateGeneric {
public:
Expand Down Expand Up @@ -205,23 +188,17 @@ class SPIRVDecorationGroup : public SPIRVEntry {
SPIRVDecorationGroup() : SPIRVEntry(OC) {}
_SPIRV_DCL_DECODE
// Move the given decorates to the decoration group
void takeDecorates(SPIRVDecorateSet &Decs) {
for (auto &I : Decs) {
// Insert decorates whose target ID is this decoration group
if (I->getTargetId() == Id) {
const_cast<SPIRVDecorateGeneric *>(I)->setOwner(this);
Decorations.insert(I);
}
}
// Remove those inserted decorates from original set
void takeDecorates(SPIRVDecorateVec &Decs) {
Decorations = std::move(Decs);
for (auto &I : Decorations)
Decs.erase(I);
const_cast<SPIRVDecorateGeneric *>(I)->setOwner(this);
Decs.clear();
}

SPIRVDecorateSet &getDecorations() { return Decorations; }
SPIRVDecorateVec &getDecorations() { return Decorations; }

protected:
SPIRVDecorateSet Decorations;
SPIRVDecorateVec Decorations;
void validate() const override {
assert(OpCode == OC);
assert(WordCount == WC);
Expand Down
4 changes: 2 additions & 2 deletions llpc/translator/lib/SPIRV/libSPIRV/SPIRVEntry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ void SPIRVEntry::validateBuiltin(SPIRVWord TheSet, SPIRVWord Index) const {
assert(TheSet != SPIRVWORD_MAX && Index != SPIRVWORD_MAX && "Invalid builtin");
}

void SPIRVEntry::addDecorate(const SPIRVDecorate *Dec) {
void SPIRVEntry::addDecorate(SPIRVDecorate *Dec) {
auto Kind = Dec->getDecorateKind();
Decorates.insert(std::make_pair(Dec->getDecorateKind(), Dec));
Module->addDecorate(Dec);
Expand Down Expand Up @@ -232,7 +232,7 @@ void SPIRVEntry::setLine(const SPIRVLine *L) {
Line = L;
}

void SPIRVEntry::addMemberDecorate(const SPIRVMemberDecorate *Dec) {
void SPIRVEntry::addMemberDecorate(SPIRVMemberDecorate *Dec) {
assert(Dec);
assert(canHaveMemberDecorates());
MemberDecorates[Dec->getPair()] = Dec;
Expand Down
4 changes: 2 additions & 2 deletions llpc/translator/lib/SPIRV/libSPIRV/SPIRVEntry.h
Original file line number Diff line number Diff line change
Expand Up @@ -237,11 +237,11 @@ class SPIRVEntry {
return false;
}

void addDecorate(const SPIRVDecorate *);
void addDecorate(SPIRVDecorate *);
void addDecorate(Decoration Kind);
void addDecorate(Decoration Kind, SPIRVWord Literal);
void eraseDecorate(Decoration);
void addMemberDecorate(const SPIRVMemberDecorate *);
void addMemberDecorate(SPIRVMemberDecorate *);
void addMemberDecorate(SPIRVWord MemberNumber, Decoration Kind);
void addMemberDecorate(SPIRVWord MemberNumber, Decoration Kind, SPIRVWord Literal);
void eraseMemberDecorate(SPIRVWord MemberNumber, Decoration Kind);
Expand Down
53 changes: 7 additions & 46 deletions llpc/translator/lib/SPIRV/libSPIRV/SPIRVModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,6 @@ class SPIRVModuleImpl : public SPIRVModule {
// Module changing functions
bool importBuiltinSet(const std::string &, SPIRVId *) override;
bool importBuiltinSetWithId(const std::string &, SPIRVId) override;
void optimizeDecorates() override;
void setAddressingModel(SPIRVAddressingModelKind AM) override { AddrModel = AM; }
void postProcessExecutionModeId();
void setMemoryModel(SPIRVMemoryModelKind MM) override { MemoryModel = MM; }
Expand All @@ -173,7 +172,7 @@ class SPIRVModuleImpl : public SPIRVModule {
void setCurrentLine(const SPIRVLine *Line) override;
void addCapability(SPIRVCapabilityKind) override;
void addCapabilityInternal(SPIRVCapabilityKind) override;
const SPIRVDecorateGeneric *addDecorate(const SPIRVDecorateGeneric *) override;
const SPIRVDecorateGeneric *addDecorate(SPIRVDecorateGeneric *) override;
SPIRVDecorationGroup *addDecorationGroup() override;
SPIRVDecorationGroup *addDecorationGroup(SPIRVDecorationGroup *Group) override;
SPIRVGroupDecorate *addGroupDecorate(SPIRVDecorationGroup *Group, const std::vector<SPIRVEntry *> &Targets) override;
Expand Down Expand Up @@ -323,7 +322,7 @@ class SPIRVModuleImpl : public SPIRVModule {
SPIRVStringVec StringVec;
SPIRVMemberNameVec MemberNameVec;
const SPIRVLine *CurrentLine;
SPIRVDecorateSet DecorateSet;
SPIRVDecorateVec DecorateVec;
SPIRVDecGroupVec DecGroupVec;
SPIRVGroupDecVec GroupDecVec;
SPIRVEnetryPointVec EntryPointVec;
Expand Down Expand Up @@ -357,44 +356,6 @@ void SPIRVModuleImpl::setCurrentLine(const SPIRVLine *Line) {
CurrentLine = Line;
}

// Creates decoration group and group decorates from decorates shared by
// multiple targets.
void SPIRVModuleImpl::optimizeDecorates() {
for (auto I = DecorateSet.begin(), E = DecorateSet.end(); I != E;) {
auto D = *I;
if (D->getOpCode() == OpMemberDecorate) {
++I;
continue;
}
auto ER = DecorateSet.equal_range(D);
if (std::distance(ER.first, ER.second) < 2) {
I = ER.second;
continue;
}
auto G = add(new SPIRVDecorationGroup(this, getId()));
std::vector<SPIRVId> Targets;
Targets.push_back(D->getTargetId());
const_cast<SPIRVDecorateGeneric *>(D)->setTargetId(G->getId());
G->getDecorations().insert(D);
for (I = ER.first; I != ER.second; ++I) {
auto E = *I;
if (*E == *D)
continue;
Targets.push_back(E->getTargetId());
}

// WordCount is only 16 bits. We can only have 65535 - FixedWC targets per
// group.
// For now, just skip using a group if the number of targets to too big
if (Targets.size() < 65530) {
DecorateSet.erase(ER.first, ER.second);
auto GD = add(new SPIRVGroupDecorate(G, Targets));
DecGroupVec.push_back(G);
GroupDecVec.push_back(GD);
}
}
}

void SPIRVModuleImpl::postProcessExecutionModeId() {
for (auto ExecModeId : ExecModeIdVec) {
SPIRVExecutionModeId *E = static_cast<SPIRVExecutionModeId *>(ExecModeId);
Expand Down Expand Up @@ -741,12 +702,12 @@ SPIRVBasicBlock *SPIRVModuleImpl::addBasicBlock(SPIRVFunction *Func, SPIRVId Id)
return Func->addBasicBlock(new SPIRVBasicBlock(getId(Id), Func));
}

const SPIRVDecorateGeneric *SPIRVModuleImpl::addDecorate(const SPIRVDecorateGeneric *Dec) {
const SPIRVDecorateGeneric *SPIRVModuleImpl::addDecorate(SPIRVDecorateGeneric *Dec) {
SPIRVEntry *Target = nullptr;
assert(exist(Dec->getTargetId(), &Target) && "Decorate target does not exist");
(void)Target;
if (!Dec->getOwner())
DecorateSet.insert(Dec);
DecorateVec.push_back(Dec);
addCapabilities(Dec->getRequiredCapability());
return Dec;
}
Expand Down Expand Up @@ -1038,7 +999,8 @@ SPIRVInstruction *SPIRVModuleImpl::addVariable(SPIRVType *Type, bool IsConstant,
return Variable;
}

template <class T, class B> spv_ostream &operator<<(spv_ostream &O, const std::multiset<T *, B> &V) {
template <class T, class B = std::less<T>>
spv_ostream &operator<<(spv_ostream &O, const std::unordered_set<T *, B> &V) {
for (auto &I : V)
O << *I;
return O;
Expand All @@ -1057,7 +1019,7 @@ SPIRVDecorationGroup *SPIRVModuleImpl::addDecorationGroup() {

SPIRVDecorationGroup *SPIRVModuleImpl::addDecorationGroup(SPIRVDecorationGroup *Group) {
add(Group);
Group->takeDecorates(DecorateSet);
Group->takeDecorates(DecorateVec);
DecGroupVec.push_back(Group);
return Group;
}
Expand Down Expand Up @@ -1127,7 +1089,6 @@ std::istream &operator>>(std::istream &I, SPIRVModule &M) {
Decoder.getEntry();

MI.postProcessExecutionModeId();
MI.optimizeDecorates();
MI.resolveUnknownStructFields();
MI.createForwardPointers();
return I;
Expand Down
3 changes: 1 addition & 2 deletions llpc/translator/lib/SPIRV/libSPIRV/SPIRVModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,6 @@ class SPIRVModule {
virtual void setSourceLanguage(SourceLanguage, SPIRVWord) = 0;
virtual void setSourceFile(SPIRVId) = 0;
virtual SPIRVString *getSourceFile(uint32_t FileId) const = 0;
virtual void optimizeDecorates() = 0;
virtual void setAutoAddCapability(bool E) { AutoAddCapability = E; }
virtual void setValidateCapability(bool E) { ValidateCapability = E; }
virtual void setGeneratorId(unsigned short) = 0;
Expand All @@ -168,7 +167,7 @@ class SPIRVModule {
virtual void addUnknownStructField(SPIRVTypeStruct *, unsigned Idx, SPIRVId Id) = 0;
virtual const SPIRVLine *getCurrentLine() const = 0;
virtual void setCurrentLine(const SPIRVLine *) = 0;
virtual const SPIRVDecorateGeneric *addDecorate(const SPIRVDecorateGeneric *) = 0;
virtual const SPIRVDecorateGeneric *addDecorate(SPIRVDecorateGeneric *) = 0;
virtual SPIRVDecorationGroup *addDecorationGroup() = 0;
virtual SPIRVDecorationGroup *addDecorationGroup(SPIRVDecorationGroup *Group) = 0;
virtual SPIRVGroupDecorate *addGroupDecorate(SPIRVDecorationGroup *Group,
Expand Down

0 comments on commit 7b77de2

Please sign in to comment.