Skip to content

Commit

Permalink
Allow creating DynamicCodec instances with custom FieldHandlers.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 606671932
Change-Id: I6d4ff01c2e7f514b5800bef560bf61f2729e2be5
  • Loading branch information
aoeui authored and copybara-github committed Feb 13, 2024
1 parent 723d772 commit 483ed29
Show file tree
Hide file tree
Showing 2 changed files with 154 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;

/** A codec that serializes arbitrary types. */
public final class DynamicCodec extends AsyncObjectCodec<Object> {
Expand All @@ -39,8 +41,13 @@ public final class DynamicCodec extends AsyncObjectCodec<Object> {
private final FieldHandler[] handlers;

public DynamicCodec(Class<?> type) {
this(type, getFieldHandlers(type));
}

@SuppressWarnings("AvoidObjectArrays") // less overhead
public DynamicCodec(Class<?> type, FieldHandler[] handlers) {
this.type = type;
this.handlers = getFieldHandlers(type);
this.handlers = handlers;
}

@Override
Expand Down Expand Up @@ -91,14 +98,29 @@ public Object deserializeAsync(AsyncDeserializationContext context, CodedInputSt
}

/** Handles serialization of a field. */
private interface FieldHandler {
public interface FieldHandler {
void serialize(SerializationContext context, CodedOutputStream codedOut, Object obj)
throws SerializationException, IOException;

void deserialize(AsyncDeserializationContext context, CodedInputStream codedIn, Object obj)
throws SerializationException, IOException;
}

/**
* Computes the default {@link FieldHandler}s that would be used for the given type.
*
* <p>The entries are ordered by {@link FieldComparator} for determinism. The returned value is a
* fresh copy that the caller may freely modify.
*/
@SuppressWarnings("NonApiType") // type communicates fixed ordering
public static <T> LinkedHashMap<Field, FieldHandler> getFieldHandlerMap(Class<T> type) {
LinkedHashMap<Field, FieldHandler> handlers = new LinkedHashMap<>();
for (Field field : getSerializableFields(type)) {
handlers.put(field, getHandlerForField(field));
}
return handlers;
}

private static final class BooleanHandler implements FieldHandler {
private final long offset;

Expand Down Expand Up @@ -331,8 +353,17 @@ public void deserialize(
}

private static <T> FieldHandler[] getFieldHandlers(Class<T> type) {
// NB: it's tempting to try to simplify this by ordering by offset, but it looks like offsets
// are not guaranteed to be stable, which is needed for deterministic serialization.
List<Field> fields = getSerializableFields(type);

FieldHandler[] handlers = new FieldHandler[fields.size()];
int i = 0;
for (Field field : fields) {
handlers[i++] = getHandlerForField(field);
}
return handlers;
}

private static <T> List<Field> getSerializableFields(Class<T> type) {
ArrayList<Field> fields = new ArrayList<>();
for (Class<? super T> next = type; next != null; next = next.getSuperclass()) {
for (Field field : next.getDeclaredFields()) {
Expand All @@ -342,42 +373,40 @@ private static <T> FieldHandler[] getFieldHandlers(Class<T> type) {
fields.add(field);
}
}
// NB: it's tempting to try to simplify this by ordering by offset, but it looks like offsets
// are not guaranteed to be stable, which is needed for deterministic serialization.
Collections.sort(fields, new FieldComparator());
FieldHandler[] handlers = new FieldHandler[fields.size()];
int i = 0;
for (Field field : fields) {
long offset = unsafe().objectFieldOffset(field);
Class<?> fieldType = field.getType();
FieldHandler handler;
if (fieldType.isPrimitive()) {
if (fieldType.equals(boolean.class)) {
handler = new BooleanHandler(offset);
} else if (fieldType.equals(byte.class)) {
handler = new ByteHandler(offset);
} else if (fieldType.equals(short.class)) {
handler = new ShortHandler(offset);
} else if (fieldType.equals(char.class)) {
handler = new CharHandler(offset);
} else if (fieldType.equals(int.class)) {
handler = new IntHandler(offset);
} else if (fieldType.equals(long.class)) {
handler = new LongHandler(offset);
} else if (fieldType.equals(float.class)) {
handler = new FloatHandler(offset);
} else if (fieldType.equals(double.class)) {
handler = new DoubleHandler(offset);
} else {
throw new UnsupportedOperationException(
"Unexpected primitive field type " + fieldType + " for " + type);
}
} else if (fieldType.isArray()) {
handler = new ArrayHandler(fieldType, offset);
return fields;
}

private static FieldHandler getHandlerForField(Field field) {
long offset = unsafe().objectFieldOffset(field);
Class<?> fieldType = field.getType();
if (fieldType.isPrimitive()) {
if (fieldType.equals(boolean.class)) {
return new BooleanHandler(offset);
} else if (fieldType.equals(byte.class)) {
return new ByteHandler(offset);
} else if (fieldType.equals(short.class)) {
return new ShortHandler(offset);
} else if (fieldType.equals(char.class)) {
return new CharHandler(offset);
} else if (fieldType.equals(int.class)) {
return new IntHandler(offset);
} else if (fieldType.equals(long.class)) {
return new LongHandler(offset);
} else if (fieldType.equals(float.class)) {
return new FloatHandler(offset);
} else if (fieldType.equals(double.class)) {
return new DoubleHandler(offset);
} else {
handler = new ObjectHandler(fieldType, offset);
throw new UnsupportedOperationException(
"Unexpected primitive field type " + fieldType + " for " + field.getDeclaringClass());
}
handlers[i++] = handler;
} else if (fieldType.isArray()) {
return new ArrayHandler(fieldType, offset);
}
return handlers;
return new ObjectHandler(fieldType, offset);
}

private static final class FieldComparator implements Comparator<Field> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,16 @@
import static org.junit.Assert.assertThrows;

import com.google.common.collect.ImmutableClassToInstanceMap;
import com.google.devtools.build.lib.skyframe.serialization.DynamicCodec.FieldHandler;
import com.google.devtools.build.lib.skyframe.serialization.testutils.SerializationTester;
import com.google.protobuf.ByteString;
import com.google.protobuf.CodedInputStream;
import com.google.protobuf.CodedOutputStream;
import java.io.BufferedInputStream;
import java.io.IOException;
import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Objects;
import org.junit.Test;
import org.junit.runner.RunWith;
Expand Down Expand Up @@ -353,7 +357,7 @@ static class PrimitiveExample {
@SuppressWarnings("EqualsHashCode") // Testing
@Override
public boolean equals(Object object) {
if (object == null) {
if (!(object instanceof PrimitiveExample)) {
return false;
}
PrimitiveExample that = (PrimitiveExample) object;
Expand Down Expand Up @@ -479,4 +483,88 @@ public Object deserialize(DeserializationContext context, CodedInputStream coded
assertThat(deserialized).isInstanceOf(SpecificObjectWrapper.class);
assertThat(((SpecificObjectWrapper) deserialized).field).isNull();
}

private static class CustomHandlerExample {
private final String text;
private Object tricky;

private CustomHandlerExample(String text, Object tricky) {
this.text = text;
this.tricky = tricky;
}

@SuppressWarnings("EqualsHashCode") // Testing
@Override
public boolean equals(Object other) {
if (!(other instanceof CustomHandlerExample)) {
return false;
}
CustomHandlerExample that = (CustomHandlerExample) other;
return Objects.equals(text, that.text) && Objects.equals(tricky, that.tricky);
}
}

/** An object for testing that can't be serialized. */
private static final Object NOT_SERIALIZABLE =
new Object() {
@Override
public String toString() {
return "not serializable";
}
};

@Test
public void customFieldHandler_counterfactual() throws Exception {
// Verifies that a naive DynamicCodec instance cannot serialize the `NOT_SERIALIZABLE` object.
ObjectCodecs codecs =
new ObjectCodecs(
ObjectCodecRegistry.newBuilder()
.add(new DynamicCodec(CustomHandlerExample.class))
.build());
SerializationException expected =
assertThrows(
SerializationException.class,
() -> codecs.serialize(new CustomHandlerExample("hello", NOT_SERIALIZABLE)));
assertThat(expected).hasMessageThat().contains("No default codec available");
}

@Test
public void customFieldHandler() throws Exception {
LinkedHashMap<Field, FieldHandler> fieldHandlers =
DynamicCodec.getFieldHandlerMap(CustomHandlerExample.class);
Field textField = CustomHandlerExample.class.getDeclaredField("text");
Field trickyField = CustomHandlerExample.class.getDeclaredField("tricky");
assertThat(fieldHandlers.keySet()).containsExactly(textField, trickyField);

// Creates and inserts a custom handler for the field "tricky".
fieldHandlers.put(
trickyField,
new FieldHandler() {
@Override
public void serialize(
SerializationContext context, CodedOutputStream codedOut, Object obj)
throws IOException {
CustomHandlerExample subject = (CustomHandlerExample) obj;
codedOut.writeBoolNoTag(subject.tricky != null);
}

@Override
public void deserialize(
AsyncDeserializationContext context, CodedInputStream codedIn, Object obj)
throws IOException {
if (codedIn.readBool()) {
((CustomHandlerExample) obj).tricky = NOT_SERIALIZABLE;
}
}
});
DynamicCodec customizedCodec =
new DynamicCodec(
CustomHandlerExample.class, fieldHandlers.values().toArray(new FieldHandler[2]));

// The NOT_SERIALIZABLE object round-trips successfully with the custom handler.
new SerializationTester(
new CustomHandlerExample("a", null), new CustomHandlerExample("b ", NOT_SERIALIZABLE))
.addCodec(customizedCodec)
.runTests();
}
}

0 comments on commit 483ed29

Please sign in to comment.