Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify scratch local calculation #6583

Merged
merged 3 commits into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions src/wasm-stack.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,13 +146,14 @@ class BinaryInstWriter : public OverriddenVisitor<BinaryInstWriter> {
// type => number of locals of that type in the compact form
std::unordered_map<Type, size_t> numLocalsByType;

void noteLocalType(Type type);
void noteLocalType(Type type, Index count = 1);

// Keeps track of the binary index of the scratch locals used to lower
// tuple.extract.
// tuple.extract. If there are multiple scratch locals of the same type, they
// are contiguous and this map holds the index of the first.
InsertOrderedMap<Type, Index> scratchLocals;
void countScratchLocals();
void setScratchLocals();
// Return the type and number of required scratch locals.
InsertOrderedMap<Type, Index> countScratchLocals();

// local.get, local.tee, and global.get expressions that will be followed by
// tuple.extracts. We can optimize these by getting only the local for the
Expand Down
122 changes: 70 additions & 52 deletions src/wasm/wasm-stack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2569,6 +2569,8 @@ void BinaryInstWriter::mapLocalsAndEmitHeader() {
mappedLocals[std::make_pair(i, 0)] = i;
}

auto scratches = countScratchLocals();

// Normally we map all locals of the same type into a range of adjacent
// addresses, which is more compact. However, if we need to keep DWARF valid,
// do not do any reordering at all - instead, do a trivial mapping that
Expand All @@ -2584,25 +2586,26 @@ void BinaryInstWriter::mapLocalsAndEmitHeader() {
for (Index i = func->getVarIndexBase(); i < func->getNumLocals(); i++) {
size_t size = func->getLocalType(i).size();
for (Index j = 0; j < size; j++) {
mappedLocals[std::make_pair(i, j)] = mappedIndex + j;
mappedLocals[std::make_pair(i, j)] = mappedIndex++;
}
mappedIndex += size;
}
countScratchLocals();

size_t numBinaryLocals =
mappedIndex - func->getVarIndexBase() + scratchLocals.size();
mappedIndex - func->getVarIndexBase() + scratches.size();

o << U32LEB(numBinaryLocals);

for (Index i = func->getVarIndexBase(); i < func->getNumLocals(); i++) {
for (const auto& type : func->getLocalType(i)) {
o << U32LEB(1);
parent.writeType(type);
}
}
for (auto& [type, _] : scratchLocals) {
o << U32LEB(1);
for (auto& [type, count] : scratches) {
o << U32LEB(count);
parent.writeType(type);
scratchLocals[type] = mappedIndex++;
scratchLocals[type] = mappedIndex;
mappedIndex += count;
}
return;
}
Expand All @@ -2612,7 +2615,10 @@ void BinaryInstWriter::mapLocalsAndEmitHeader() {
noteLocalType(t);
}
}
countScratchLocals();

for (auto& [type, count] : scratches) {
noteLocalType(type, count);
}

if (parent.getModule()->features.hasReferenceTypes()) {
// Sort local types in a way that keeps all MVP types together and all
Expand All @@ -2636,23 +2642,28 @@ void BinaryInstWriter::mapLocalsAndEmitHeader() {
});
}

std::unordered_map<Type, size_t> currLocalsByType;
// Map IR (local index, tuple index) pairs to binary local indices. Since
// locals are grouped by type, start by calculating the base indices for each
// type.
std::unordered_map<Type, Index> nextFreeIndex;
Index baseIndex = func->getVarIndexBase();
for (auto& type : localTypes) {
nextFreeIndex[type] = baseIndex;
baseIndex += numLocalsByType[type];
}

// Map the IR index pairs to indices.
for (Index i = func->getVarIndexBase(); i < func->getNumLocals(); i++) {
Index j = 0;
for (const auto& type : func->getLocalType(i)) {
auto fullIndex = std::make_pair(i, j++);
Index index = func->getVarIndexBase();
for (auto& localType : localTypes) {
if (type == localType) {
mappedLocals[fullIndex] = index + currLocalsByType[localType];
currLocalsByType[type]++;
break;
}
index += numLocalsByType.at(localType);
}
mappedLocals[{i, j++}] = nextFreeIndex[type]++;
}
}
setScratchLocals();

// Map scratch locals to the remaining indices.
for (auto& [type, _] : scratches) {
scratchLocals[type] = nextFreeIndex[type];
}

o << U32LEB(numLocalsByType.size());
for (auto& localType : localTypes) {
Expand All @@ -2661,44 +2672,51 @@ void BinaryInstWriter::mapLocalsAndEmitHeader() {
}
}

void BinaryInstWriter::noteLocalType(Type type) {
if (!numLocalsByType.count(type)) {
void BinaryInstWriter::noteLocalType(Type type, Index count) {
auto& num = numLocalsByType[type];
if (num == 0) {
localTypes.push_back(type);
}
numLocalsByType[type]++;
num += count;
}

void BinaryInstWriter::countScratchLocals() {
// Add a scratch register in `numLocalsByType` for each type of
// tuple.extract with nonzero index present.
FindAll<TupleExtract> extracts(func->body);
for (auto* extract : extracts.list) {
if (extract->type != Type::unreachable && extract->index != 0) {
scratchLocals[extract->type] = 0;
}
}
for (auto& [type, _] : scratchLocals) {
noteLocalType(type);
}
// While we have all the tuple.extracts, also find extracts of local.gets,
// local.tees, and global.gets that we can optimize.
for (auto* extract : extracts.list) {
auto* tuple = extract->tuple;
if (tuple->is<LocalGet>() || tuple->is<LocalSet>() ||
tuple->is<GlobalGet>()) {
extractedGets.insert({tuple, extract->index});
}
}
}
InsertOrderedMap<Type, Index> BinaryInstWriter::countScratchLocals() {
struct ScratchLocalFinder : PostWalker<ScratchLocalFinder> {
BinaryInstWriter& parent;
InsertOrderedMap<Type, Index> scratches;

void BinaryInstWriter::setScratchLocals() {
Index index = func->getVarIndexBase();
for (auto& localType : localTypes) {
index += numLocalsByType[localType];
if (scratchLocals.find(localType) != scratchLocals.end()) {
scratchLocals[localType] = index - 1;
ScratchLocalFinder(BinaryInstWriter& parent) : parent(parent) {}

// We need two i32 scratch locals for reach string slice, but they can be
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// We need two i32 scratch locals for reach string slice, but they can be
// We need two i32 scratch locals for each string slice, but they can be

// reused.
bool hasStringSlice = false;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this used anywhere?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nope, this is left over development cruft, will remove.


void visitTupleExtract(TupleExtract* curr) {
if (curr->type == Type::unreachable) {
// We will not emit this instruction anyway.
return;
}
// Extracts from locals or globals are optimizable and do not require
// scratch locals. Record them.
auto* tuple = curr->tuple;
if (tuple->is<LocalGet>() || tuple->is<LocalSet>() ||
tuple->is<GlobalGet>()) {
parent.extractedGets.insert({tuple, curr->index});
return;
}
// Include a scratch register for each type of tuple.extract with nonzero
// index present.
if (curr->index != 0) {
auto& count = scratches[curr->type];
count = std::max(count, 1u);
}
}
}
};

ScratchLocalFinder finder(*this);
finder.walk(func->body);

return std::move(finder.scratches);
}

void BinaryInstWriter::emitMemoryAccess(size_t alignment,
Expand Down
5 changes: 2 additions & 3 deletions test/lit/binary/dwarf-multivalue.test
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,8 @@
;; ROUNDTRIP-NEXT: (local $11 i32)
;; ROUNDTRIP-NEXT: (local $12 f32)
;; ROUNDTRIP-NEXT: (local $13 i32)
;; ROUNDTRIP-NEXT: (local $14 f32)
;; ROUNDTRIP-NEXT: (local $15 (tuple i32 f32))
;; ROUNDTRIP-NEXT: (local $16 i32)
;; ROUNDTRIP-NEXT: (local $14 (tuple i32 f32))
;; ROUNDTRIP-NEXT: (local $15 i32)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment on line 80 looks like it needs to be updated.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So is the scratch local that was created when first reading the binary, (local $14 f32), now unnecessary? Is that because this case is now optimized via extractedGets? (It looks like we were previously counting extractedGets as scratch locals anyway.)

In any case the comment needs to be updated from line 78.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, good point, I think we no longer emit (local $14). I updated the comments.


;; We can see that we don't reorder the locals during the process and the
;; original list of locals, local $0~$10, is untouched, to NOT invalidate DWARF
Expand Down
37 changes: 17 additions & 20 deletions test/lit/multivalue.wast
Original file line number Diff line number Diff line change
Expand Up @@ -149,42 +149,40 @@
;; CHECK: (func $reverse (type $4) (result f32 i64 i32)
;; CHECK-NEXT: (local $x i32)
;; CHECK-NEXT: (local $1 i64)
;; CHECK-NEXT: (local $2 i64)
;; CHECK-NEXT: (local $3 f32)
;; CHECK-NEXT: (local $4 f32)
;; CHECK-NEXT: (local $5 (tuple i32 i64 f32))
;; CHECK-NEXT: (local $6 i64)
;; CHECK-NEXT: (local $7 i32)
;; CHECK-NEXT: (local.set $5
;; CHECK-NEXT: (local $2 f32)
;; CHECK-NEXT: (local $3 (tuple i32 i64 f32))
;; CHECK-NEXT: (local $4 i64)
;; CHECK-NEXT: (local $5 i32)
;; CHECK-NEXT: (local.set $3
;; CHECK-NEXT: (call $triple)
;; CHECK-NEXT: )
;; CHECK-NEXT: (local.set $x
;; CHECK-NEXT: (block (result i32)
;; CHECK-NEXT: (local.set $7
;; CHECK-NEXT: (local.set $5
;; CHECK-NEXT: (tuple.extract 3 0
;; CHECK-NEXT: (local.get $5)
;; CHECK-NEXT: (local.get $3)
;; CHECK-NEXT: )
;; CHECK-NEXT: )
;; CHECK-NEXT: (local.set $1
;; CHECK-NEXT: (block (result i64)
;; CHECK-NEXT: (local.set $6
;; CHECK-NEXT: (local.set $4
;; CHECK-NEXT: (tuple.extract 3 1
;; CHECK-NEXT: (local.get $5)
;; CHECK-NEXT: (local.get $3)
;; CHECK-NEXT: )
;; CHECK-NEXT: )
;; CHECK-NEXT: (local.set $3
;; CHECK-NEXT: (local.set $2
;; CHECK-NEXT: (tuple.extract 3 2
;; CHECK-NEXT: (local.get $5)
;; CHECK-NEXT: (local.get $3)
;; CHECK-NEXT: )
;; CHECK-NEXT: )
;; CHECK-NEXT: (local.get $6)
;; CHECK-NEXT: (local.get $4)
;; CHECK-NEXT: )
;; CHECK-NEXT: )
;; CHECK-NEXT: (local.get $7)
;; CHECK-NEXT: (local.get $5)
;; CHECK-NEXT: )
;; CHECK-NEXT: )
;; CHECK-NEXT: (tuple.make 3
;; CHECK-NEXT: (local.get $3)
;; CHECK-NEXT: (local.get $2)
;; CHECK-NEXT: (local.get $1)
;; CHECK-NEXT: (local.get $x)
;; CHECK-NEXT: )
Expand Down Expand Up @@ -228,17 +226,16 @@

;; Test multivalue globals
;; CHECK: (func $global (type $0) (result i32 i64)
;; CHECK-NEXT: (local $0 i64)
;; CHECK-NEXT: (local $1 i32)
;; CHECK-NEXT: (local $0 i32)
;; CHECK-NEXT: (global.set $g1
;; CHECK-NEXT: (block (result i32)
;; CHECK-NEXT: (local.set $1
;; CHECK-NEXT: (local.set $0
;; CHECK-NEXT: (i32.const 42)
;; CHECK-NEXT: )
;; CHECK-NEXT: (global.set $g2
;; CHECK-NEXT: (i64.const 7)
;; CHECK-NEXT: )
;; CHECK-NEXT: (local.get $1)
;; CHECK-NEXT: (local.get $0)
;; CHECK-NEXT: )
;; CHECK-NEXT: )
;; CHECK-NEXT: (drop
Expand Down
Loading