Skip to content

Commit

Permalink
Make UnsafeStringCodec memoizing.
Browse files Browse the repository at this point in the history
* Makes the Memoizer memoize Strings using value equality instead of reference
  equality. This improves determinism.

PiperOrigin-RevId: 599232400
Change-Id: I24c4112e154b9b68352494427c39d74527cd2d28
  • Loading branch information
aoeui authored and copybara-github committed Jan 17, 2024
1 parent 916c3f5 commit c9e4446
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 220 deletions.
173 changes: 21 additions & 152 deletions src/main/java/com/google/devtools/build/lib/packages/CallStack.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,15 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.devtools.build.lib.packages;

import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;

import com.google.common.collect.Interner;
import com.google.devtools.build.lib.concurrent.BlazeInterners;
import com.google.devtools.build.lib.skyframe.serialization.VisibleForSerialization;
import com.google.devtools.build.lib.skyframe.serialization.autocodec.AutoCodec;
import com.google.devtools.build.lib.util.HashCodes;
import com.google.devtools.build.lib.util.StringCanonicalizer;
import com.google.protobuf.CodedInputStream;
import com.google.protobuf.CodedOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import javax.annotation.Nullable;
import net.starlark.java.eval.StarlarkThread;
Expand All @@ -48,7 +39,16 @@ final class CallStack {

private CallStack() {}

/**
 * Builds the <em>full</em> call stack of {@code rule}: a node named {@link
 * StarlarkThread#TOP_LEVEL} positioned at {@link Rule#getLocation}, whose successor is {@link
 * Rule#getInteriorCallStack}.
 */
static Node getFullCallStack(Rule rule) {
  Location ruleLocation = rule.getLocation();
  Node interiorStack = rule.getInteriorCallStack();
  return new Node(StarlarkThread.TOP_LEVEL, ruleLocation, interiorStack);
}

/** Compact representation of a call stack entry. */
@AutoCodec
static final class Node {
/** Function name. */
private final String name;
Expand All @@ -59,6 +59,16 @@ static final class Node {
private final int col;
@Nullable private final Node next;

/**
 * Factory annotated as the {@link AutoCodec.Instantiator} for {@link Node}.
 *
 * <p>{@code name} and {@code file} are passed through {@link StringCanonicalizer#intern} before
 * the node is constructed; {@code line}, {@code col} and {@code next} are used as given.
 */
@AutoCodec.Instantiator
@VisibleForSerialization
static Node createForDeserialization(
    String name, String file, int line, int col, @Nullable Node next) {
  // Go through the common canonicalizer, on the assertion that most strings (function names,
  // locations) were already shared across packages to some degree.
  String canonicalName = StringCanonicalizer.intern(name);
  String canonicalFile = StringCanonicalizer.intern(file);
  return new Node(canonicalName, canonicalFile, line, col, next);
}

/**
* Creates a node from a {@link Location} by unpacking it into its file, line and column and
* delegating to the five-argument constructor.
*/
private Node(String name, Location location, @Nullable Node next) {
this(name, location.file(), location.line(), location.column(), next);
}
Expand Down Expand Up @@ -132,145 +142,4 @@ static Node compactInterior(List<StarlarkThread.CallStackEntry> stack) {
}
return node;
}

/**
* Efficient serializer for {@link Node}s. Before call stacks are serialized in a package, {@link
* #prepareCallStack} must be called on them (to prepare a table of strings).
*
* <p>Output layout, as produced by this class: the string table (its size, then each string) is
* written once, ahead of the first call stack. Each node is then written as its table index
* followed, on first occurrence only, by its name index, file index, line, column and the
* encoding of its {@code next} node.
*/
static final class Serializer {
// Index written in place of a null node; registered in the constructor.
private static final int NULL_NODE_ID = 0;

// Nodes already emitted, keyed by identity, mapped to the index that was written for them.
private final IdentityHashMap<Node, Integer> nodeTable = new IdentityHashMap<>();
// Strings in index order; this is what gets written to the stream.
private final List<String> stringTable = new ArrayList<>();
// Set once the string table has been written; indexString rejects additions afterwards.
private boolean stringTableSerialized = false;

// Reverse mapping from string to its position in stringTable.
private final Map<String, Integer> stringTableIndex = new HashMap<>();

/** Pre-registers the null node and the {@link StarlarkThread#TOP_LEVEL} name. */
Serializer() {
nodeTable.put(null, NULL_NODE_ID);
indexString(StarlarkThread.TOP_LEVEL);
}

/**
* Adds {@code s} to the string table unless it is already present. Fails the state check if
* the string table has already been serialized.
*/
private void indexString(String s) {
checkState(!stringTableSerialized);
// Both collections grow together, so the current index size is the next free position.
int i = stringTableIndex.size();
if (stringTableIndex.putIfAbsent(s, i) == null) {
stringTable.add(s);
}
}

/** Returns the string table index of {@code s}; fails if {@code s} was never indexed. */
private int indexOf(String s) {
return checkNotNull(stringTableIndex.get(s), s);
}

/**
* Serializes the <em>full</em> call stack of the given rule, including both {@link
* Rule#getLocation} and {@link Rule#getInteriorCallStack}.
*
* <p>The first call also writes the string table ahead of the call stack.
*/
void serializeCallStack(Rule rule, CodedOutputStream codedOut) throws IOException {
if (!stringTableSerialized) {
codedOut.writeInt32NoTag(stringTable.size());
for (String string : stringTable) {
codedOut.writeStringNoTag(string);
}
stringTableSerialized = true;
}

// The top-level node is freshly allocated on every call, so the identity-keyed node table
// never already contains it; only the interior nodes can be emitted as back-references.
emitNode(
new Node(StarlarkThread.TOP_LEVEL, rule.getLocation(), rule.getInteriorCallStack()),
codedOut);
}

/**
* Writes {@code node}: just its index if it is already in the node table, otherwise a newly
* assigned index followed by its fields and, recursively, its {@code next} node.
*/
private void emitNode(Node node, CodedOutputStream codedOut) throws IOException {
// A table hit covers previously emitted nodes as well as null (pre-registered as
// NULL_NODE_ID); in both cases only the index is written.
Integer index = nodeTable.get(node);
if (index != null) {
codedOut.writeInt32NoTag(index);
return;
}

// NOTE(review): null is always present in nodeTable (see the constructor), so a null node
// returns above and this check appears to be unreachable.
if (node == null) {
return;
}

// First occurrence: take the next sequential index and register the node before recursing.
int newIndex = nodeTable.size();
codedOut.writeInt32NoTag(newIndex);
nodeTable.put(node, newIndex);
codedOut.writeInt32NoTag(indexOf(node.name));
codedOut.writeInt32NoTag(indexOf(node.file));
codedOut.writeInt32NoTag(node.line);
codedOut.writeInt32NoTag(node.col);
emitNode(node.next, codedOut);
}

/**
* Indexes the strings needed to serialize the call stack of {@code rule}: the file of the
* rule's location plus the name and file of every node of its interior call stack.
*/
void prepareCallStack(Rule rule) {
indexString(rule.getLocation().file());
for (Node n = rule.getInteriorCallStack(); n != null; n = n.next) {
indexString(n.name);
indexString(n.file);
}
}
}

/**
* Deserializes call stacks as serialized by a {@link Serializer}.
*
* <p>The string table is read on the first call to {@link #deserializeFullCallStack}; the node
* table persists across calls, so later call stacks can refer back to nodes read earlier.
*/
static final class Deserializer {
// Placeholder stored in nodeTable while a node is still being read; seeing it on lookup means
// the stream refers to a node that has not been completed yet.
private static final Node DUMMY_NODE = new Node("", "", -1, -1, null);

// Nodes read so far, positioned by the index under which they appear in the stream.
private final List<Node> nodeTable = new ArrayList<>();
// Null until the first call stack is deserialized.
@Nullable private List<String> stringTable;

Deserializer() {
// By convention index 0 = null.
nodeTable.add(null);
}

/**
* Deserializes a <em>full</em> call stack.
*
* <p>The returned {@link Node} represents {@link Rule#getLocation}. Calling {@link Node#next()}
* on the returned {@link Node} yields {@link Rule#getInteriorCallStack}.
*/
Node deserializeFullCallStack(CodedInputStream codedIn) throws IOException {
if (stringTable == null) {
int length = codedIn.readInt32();
stringTable = new ArrayList<>(length);
for (int i = 0; i < length; i++) {
// Avoid having a new set of strings per deserialized string table. Use common
// canonicalizer based on assertion that most strings (function names, locations) were
// already some degree of shared across packages.
stringTable.add(StringCanonicalizer.intern(codedIn.readString()));
}
}

return readNode(codedIn);
}

/**
* Reads one node. An index already in the node table is a back-reference (index 0 yields
* null); an index equal to the table size introduces a new node, whose fields and {@code next}
* node follow in the stream.
*/
@Nullable
@SuppressWarnings("ReferenceEquality")
private Node readNode(CodedInputStream codedIn) throws IOException {
int index = codedIn.readInt32();
if (index < nodeTable.size()) {
Node result = nodeTable.get(index);
// Identity comparison is intended: DUMMY_NODE is a sentinel instance.
checkState(result != DUMMY_NODE, "Loop detected at index %s", index);
return result;
}

// New nodes must arrive in strictly sequential index order.
checkState(
index == nodeTable.size(),
"Unexpected next value index - read %s, expected %s",
index,
nodeTable.size());

// Add dummy node to grow the table and save our spot in the table until we're done.
nodeTable.add(DUMMY_NODE);
int name = codedIn.readInt32();
int file = codedIn.readInt32();
int line = codedIn.readInt32();
int col = codedIn.readInt32();
Node child = readNode(codedIn);

// Replace the placeholder with the completed node.
Node result = new Node(stringTable.get(name), stringTable.get(file), line, col, child);
nodeTable.set(index, result);
return result;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -119,14 +119,6 @@ public void deserializeFully(CodedInputStream codedIn, Object obj, long offset,
deserialize(codedIn, obj, offset, done);
}

/**
 * Deserializes a value from {@code codedIn} by delegating to {@code deserializeInternal} with the
 * caller-supplied {@code memoizationStrategy}.
 */
@Nullable
@SuppressWarnings({"TypeParameterUnusedInFormals"})
public <T> T deserializeWithAdHocMemoizationStrategy(
    CodedInputStream codedIn, MemoizationStrategy memoizationStrategy)
    throws IOException, SerializationException {
  T deserialized = deserializeInternal(codedIn, memoizationStrategy);
  return deserialized;
}

@Nullable
@SuppressWarnings({"TypeParameterUnusedInFormals", "unchecked"})
private <T> T deserializeInternal(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.google.protobuf.CodedInputStream;
import com.google.protobuf.CodedOutputStream;
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;
import it.unimi.dsi.fastutil.objects.Reference2IntOpenHashMap;
import java.io.IOException;
import java.util.ArrayDeque;
Expand Down Expand Up @@ -149,12 +150,8 @@ <T> void serialize(
CodedOutputStream codedOut,
MemoizationStrategy strategy)
throws SerializationException, IOException {
if (strategy == MemoizationStrategy.DO_NOT_MEMOIZE) {
codec.serialize(context, obj, codedOut);
} else {
// The caller already checked the table, so this is definitely a new value.
serializeMemoContent(context, obj, codec, codedOut, strategy);
}
// The caller already checked the table, so this is definitely a new value.
serializeMemoContent(context, obj, codec, codedOut, strategy);
}

int getMemoizedIndex(Object obj) {
Expand Down Expand Up @@ -188,16 +185,18 @@ private <T> void serializeMemoContent(
codedOut.writeInt32NoTag(id);
break;
}
default:
throw new AssertionError("Unreachable (strategy=" + strategy + ")");
}
}

private static class SerializingMemoTable {
private final Reference2IntOpenHashMap<Object> table = new Reference2IntOpenHashMap<>();

/** Table for types memoized using values equality, currently only {@link String}. */
private final Object2IntOpenHashMap<Object> valuesTable = new Object2IntOpenHashMap<>();

SerializingMemoTable() {
table.defaultReturnValue(-1);
valuesTable.defaultReturnValue(-1);
}

/**
Expand All @@ -206,17 +205,24 @@ private static class SerializingMemoTable {
*/
private int memoize(Object value) {
Preconditions.checkArgument(
!table.containsKey(value), "Tried to memoize object '%s' multiple times", value);
lookup(value) == -1, "Tried to memoize object '%s' multiple times", value);
// Ids count sequentially from 0.
int newId = table.size();
table.put(value, newId);
int newId = table.size() + valuesTable.size();
if (value instanceof String) {
valuesTable.put(value, newId);
} else {
table.put(value, newId);
}
return newId;
}

/**
 * Returns the on-the-wire id already assigned to {@code value}, or {@code -1} if it has not been
 * memoized. {@link String}s are looked up in {@code valuesTable}; every other object is looked
 * up in {@code table}.
 */
private int lookup(Object value) {
  return value instanceof String ? valuesTable.getInt(value) : table.getInt(value);
}
}
Expand Down Expand Up @@ -247,18 +253,13 @@ <T> T deserialize(
"non-null memoized-before tag %s (%s)",
tagForMemoizedBefore,
codec);
if (strategy == MemoizationStrategy.DO_NOT_MEMOIZE) {
return codec.deserialize(context, codedIn);
} else {
switch (strategy) {
case MEMOIZE_BEFORE:
return deserializeMemoBeforeContent(context, codec, codedIn);
case MEMOIZE_AFTER:
return deserializeMemoAfterContent(context, codec, codedIn);
default:
throw new AssertionError("Unreachable (strategy=" + strategy + ")");
}
}
switch (strategy) {
case MEMOIZE_BEFORE:
return deserializeMemoBeforeContent(context, codec, codedIn);
case MEMOIZE_AFTER:
return deserializeMemoAfterContent(context, codec, codedIn);
}
throw new AssertionError("Unreachable (strategy=" + strategy + ")");
}

Object getMemoized(int memoIndex) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,18 +110,6 @@ default MemoizationStrategy getStrategy() {

/** Indicates how an {@link ObjectCodec} is memoized. */
enum MemoizationStrategy {
/**
* Indicates that memoization is not directly used by this codec.
*
* <p>Codecs with this strategy will always serialize payloads, never backreferences, even if
* the same value has been serialized before. This does not apply to other codecs that are
* delegated to within this codec. Deserialization behaves analogously.
*
* <p>This strategy is useful for codecs that write very little data themselves, but that still
* delegate to other codecs.
*/
DO_NOT_MEMOIZE,

/**
* Indicates that the value is memoized before recursing to its children, so that it is
* available to form cyclic references from its children. If this strategy is used, {@link
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,6 @@ public void serialize(Object object, CodedOutputStream codedOut)
serializeInternal(object, /*customMemoizationStrategy=*/ null, codedOut);
}

/**
* Serializes {@code object} to {@code codedOut}, forwarding {@code memoizationStrategy} to
* {@code serializeInternal} as the custom memoization strategy.
*/
void serializeWithAdHocMemoizationStrategy(
Object object, MemoizationStrategy memoizationStrategy, CodedOutputStream codedOut)
throws IOException, SerializationException {
serializeInternal(object, memoizationStrategy, codedOut);
}

private void serializeInternal(
Object object,
@Nullable MemoizationStrategy customMemoizationStrategy,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,6 @@ public Class<String> getEncodedClass() {
return String.class;
}

/** Reports {@link MemoizationStrategy#DO_NOT_MEMOIZE} as this codec's strategy. */
@Override
public MemoizationStrategy getStrategy() {
// Don't memoize strings inside memoizing serialization, to preserve current behavior.
// TODO(janakr,brandjon,michajlo): Is it actually a problem to memoize strings? Doubt there
// would be much performance impact from increasing the size of the identity map, and we
// could potentially drop our string tables in the future.
return MemoizationStrategy.DO_NOT_MEMOIZE;
}

@Override
public void serialize(
SerializationDependencyProvider dependencies, String obj, CodedOutputStream codedOut)
Expand Down
Loading

0 comments on commit c9e4446

Please sign in to comment.