diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/DynamicCodec.java b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/DynamicCodec.java index 6927fa71c084d9..d5f8691c8ec4af 100644 --- a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/DynamicCodec.java +++ b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/DynamicCodec.java @@ -29,6 +29,8 @@ import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; +import java.util.LinkedHashMap; +import java.util.List; /** A codec that serializes arbitrary types. */ public final class DynamicCodec extends AsyncObjectCodec { @@ -39,8 +41,13 @@ public final class DynamicCodec extends AsyncObjectCodec { private final FieldHandler[] handlers; public DynamicCodec(Class type) { + this(type, getFieldHandlers(type)); + } + + @SuppressWarnings("AvoidObjectArrays") // less overhead + public DynamicCodec(Class type, FieldHandler[] handlers) { this.type = type; - this.handlers = getFieldHandlers(type); + this.handlers = handlers; } @Override @@ -91,7 +98,7 @@ public Object deserializeAsync(AsyncDeserializationContext context, CodedInputSt } /** Handles serialization of a field. */ - private interface FieldHandler { + public interface FieldHandler { void serialize(SerializationContext context, CodedOutputStream codedOut, Object obj) throws SerializationException, IOException; @@ -99,6 +106,21 @@ void deserialize(AsyncDeserializationContext context, CodedInputStream codedIn, throws SerializationException, IOException; } + /** + * Computes the default {@link FieldHandler}s that would be used for the given type. + * + *

The entries are ordered by {@link FieldComparator} for determinism. The returned value is a + * fresh copy that the caller may freely modify. + */ + @SuppressWarnings("NonApiType") // type communicates fixed ordering + public static LinkedHashMap getFieldHandlerMap(Class type) { + LinkedHashMap handlers = new LinkedHashMap<>(); + for (Field field : getSerializableFields(type)) { + handlers.put(field, getHandlerForField(field)); + } + return handlers; + } + private static final class BooleanHandler implements FieldHandler { private final long offset; @@ -331,8 +353,17 @@ public void deserialize( } private static FieldHandler[] getFieldHandlers(Class type) { - // NB: it's tempting to try to simplify this by ordering by offset, but it looks like offsets - // are not guaranteed to be stable, which is needed for deterministic serialization. + List fields = getSerializableFields(type); + + FieldHandler[] handlers = new FieldHandler[fields.size()]; + int i = 0; + for (Field field : fields) { + handlers[i++] = getHandlerForField(field); + } + return handlers; + } + + private static List getSerializableFields(Class type) { ArrayList fields = new ArrayList<>(); for (Class next = type; next != null; next = next.getSuperclass()) { for (Field field : next.getDeclaredFields()) { @@ -342,42 +373,40 @@ private static FieldHandler[] getFieldHandlers(Class type) { fields.add(field); } } + // NB: it's tempting to try to simplify this by ordering by offset, but it looks like offsets + // are not guaranteed to be stable, which is needed for deterministic serialization. Collections.sort(fields, new FieldComparator()); - FieldHandler[] handlers = new FieldHandler[fields.size()]; - int i = 0; - for (Field field : fields) { - long offset = unsafe().objectFieldOffset(field); - Class fieldType = field.getType(); - FieldHandler handler; - if (fieldType.isPrimitive()) { - if (fieldType.equals(boolean.class)) { - handler = new BooleanHandler(offset); - } else if (fieldType.equals(byte.class)) { - handler = new ByteHandler(offset); - } else if (fieldType.equals(short.class)) { - handler = new ShortHandler(offset); - } else if (fieldType.equals(char.class)) { - handler = new CharHandler(offset); - } else if (fieldType.equals(int.class)) { - handler = new IntHandler(offset); - } else if (fieldType.equals(long.class)) { - handler = new LongHandler(offset); - } else if (fieldType.equals(float.class)) { - handler = new FloatHandler(offset); - } else if (fieldType.equals(double.class)) { - handler = new DoubleHandler(offset); - } else { - throw new UnsupportedOperationException( - "Unexpected primitive field type " + fieldType + " for " + type); - } - } else if (fieldType.isArray()) { - handler = new ArrayHandler(fieldType, offset); + return fields; + } + + private static FieldHandler getHandlerForField(Field field) { + long offset = unsafe().objectFieldOffset(field); + Class fieldType = field.getType(); + if (fieldType.isPrimitive()) { + if (fieldType.equals(boolean.class)) { + return new BooleanHandler(offset); + } else if (fieldType.equals(byte.class)) { + return new ByteHandler(offset); + } else if (fieldType.equals(short.class)) { + return new ShortHandler(offset); + } else if (fieldType.equals(char.class)) { + return new CharHandler(offset); + } else if (fieldType.equals(int.class)) { + return new IntHandler(offset); + } else if (fieldType.equals(long.class)) { + return new LongHandler(offset); + } else if (fieldType.equals(float.class)) { + return new FloatHandler(offset); + } else if (fieldType.equals(double.class)) { + return new DoubleHandler(offset); } else { - handler = new ObjectHandler(fieldType, offset); + throw new UnsupportedOperationException( + "Unexpected primitive field type " + fieldType + " for " + field.getDeclaringClass()); } - handlers[i++] = handler; + } else if (fieldType.isArray()) { + return new ArrayHandler(fieldType, offset); } - return handlers; + return new ObjectHandler(fieldType, offset); } private static final class FieldComparator implements Comparator { diff --git a/src/test/java/com/google/devtools/build/lib/skyframe/serialization/DynamicCodecTest.java b/src/test/java/com/google/devtools/build/lib/skyframe/serialization/DynamicCodecTest.java index 8a7c8897aaf296..a0a3eecc866645 100644 --- a/src/test/java/com/google/devtools/build/lib/skyframe/serialization/DynamicCodecTest.java +++ b/src/test/java/com/google/devtools/build/lib/skyframe/serialization/DynamicCodecTest.java @@ -18,12 +18,16 @@ import static org.junit.Assert.assertThrows; import com.google.common.collect.ImmutableClassToInstanceMap; +import com.google.devtools.build.lib.skyframe.serialization.DynamicCodec.FieldHandler; import com.google.devtools.build.lib.skyframe.serialization.testutils.SerializationTester; import com.google.protobuf.ByteString; import com.google.protobuf.CodedInputStream; import com.google.protobuf.CodedOutputStream; import java.io.BufferedInputStream; +import java.io.IOException; +import java.lang.reflect.Field; import java.util.Arrays; +import java.util.LinkedHashMap; import java.util.Objects; import org.junit.Test; import org.junit.runner.RunWith; @@ -353,7 +357,7 @@ static class PrimitiveExample { @SuppressWarnings("EqualsHashCode") // Testing @Override public boolean equals(Object object) { - if (object == null) { + if (!(object instanceof PrimitiveExample)) { return false; } PrimitiveExample that = (PrimitiveExample) object; @@ -479,4 +483,88 @@ public Object deserialize(DeserializationContext context, CodedInputStream coded assertThat(deserialized).isInstanceOf(SpecificObjectWrapper.class); assertThat(((SpecificObjectWrapper) deserialized).field).isNull(); } + + private static class CustomHandlerExample { + private final String text; + private Object tricky; + + private CustomHandlerExample(String text, Object tricky) { + this.text = text; + this.tricky = tricky; + } + + @SuppressWarnings("EqualsHashCode") // Testing + @Override + public boolean equals(Object other) { + if (!(other instanceof CustomHandlerExample)) { + return false; + } + CustomHandlerExample that = (CustomHandlerExample) other; + return Objects.equals(text, that.text) && Objects.equals(tricky, that.tricky); + } + } + + /** An object for testing that can't be serialized. */ + private static final Object NOT_SERIALIZABLE = + new Object() { + @Override + public String toString() { + return "not serializable"; + } + }; + + @Test + public void customFieldHandler_counterfactual() throws Exception { + // Verifies that a naive DynamicCodec instance cannot serialize the `NOT_SERIALIZABLE` object. + ObjectCodecs codecs = + new ObjectCodecs( + ObjectCodecRegistry.newBuilder() + .add(new DynamicCodec(CustomHandlerExample.class)) + .build()); + SerializationException expected = + assertThrows( + SerializationException.class, + () -> codecs.serialize(new CustomHandlerExample("hello", NOT_SERIALIZABLE))); + assertThat(expected).hasMessageThat().contains("No default codec available"); + } + + @Test + public void customFieldHandler() throws Exception { + LinkedHashMap fieldHandlers = + DynamicCodec.getFieldHandlerMap(CustomHandlerExample.class); + Field textField = CustomHandlerExample.class.getDeclaredField("text"); + Field trickyField = CustomHandlerExample.class.getDeclaredField("tricky"); + assertThat(fieldHandlers.keySet()).containsExactly(textField, trickyField); + + // Creates and inserts a custom handler for the field "tricky". + fieldHandlers.put( + trickyField, + new FieldHandler() { + @Override + public void serialize( + SerializationContext context, CodedOutputStream codedOut, Object obj) + throws IOException { + CustomHandlerExample subject = (CustomHandlerExample) obj; + codedOut.writeBoolNoTag(subject.tricky != null); + } + + @Override + public void deserialize( + AsyncDeserializationContext context, CodedInputStream codedIn, Object obj) + throws IOException { + if (codedIn.readBool()) { + ((CustomHandlerExample) obj).tricky = NOT_SERIALIZABLE; + } + } + }); + DynamicCodec customizedCodec = + new DynamicCodec( + CustomHandlerExample.class, fieldHandlers.values().toArray(new FieldHandler[2])); + + // The NOT_SERIALIZABLE object round-trips successfully with the custom handler. + new SerializationTester( + new CustomHandlerExample("a", null), new CustomHandlerExample("b ", NOT_SERIALIZABLE)) + .addCodec(customizedCodec) + .runTests(); + } }