Skip to content

Commit

Permalink
[Java] fix nested collection num elements (#1306)
Browse files Browse the repository at this point in the history
Closes #1305 

When deserializing nested collections, the outer collection may read the
`num elements` value of an inner collection, which makes deserialization fail.
Currently only `StringKeyMapSerializer` has this bug, but future
customized serializers may have the same issue. This PR makes `num elements`
defensive against such cases by using a read-and-reset pattern.
  • Loading branch information
chaokunyang authored Jan 5, 2024
1 parent d2bf62e commit a9057eb
Show file tree
Hide file tree
Showing 16 changed files with 143 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1214,7 +1214,7 @@ protected Expression deserializeForCollection(
}
Invoke supportHook = inlineInvoke(serializer, "supportCodegenHook", PRIMITIVE_BOOLEAN_TYPE);
Expression collection = new Invoke(serializer, "newCollection", COLLECTION_TYPE, buffer);
Expression size = new Invoke(serializer, "getNumElements", "size", PRIMITIVE_INT_TYPE);
Expression size = new Invoke(serializer, "getAndClearNumElements", "size", PRIMITIVE_INT_TYPE);
// if add branch by `ArrayList`, generated code will be > 325 bytes.
// and List#add is more likely be inlined if there is only one subclass.
Expression hookRead = readCollectionCodegen(buffer, collection, size, elementType);
Expand Down Expand Up @@ -1442,7 +1442,7 @@ protected Expression deserializeForMap(
}
Invoke supportHook = inlineInvoke(serializer, "supportCodegenHook", PRIMITIVE_BOOLEAN_TYPE);
Expression newMap = new Invoke(serializer, "newMap", MAP_TYPE, buffer);
Expression size = new Invoke(serializer, "getNumElements", "size", PRIMITIVE_INT_TYPE);
Expression size = new Invoke(serializer, "getAndClearNumElements", "size", PRIMITIVE_INT_TYPE);
Expression start = new Literal(0, PRIMITIVE_INT_TYPE);
Expression step = new Literal(1, PRIMITIVE_INT_TYPE);
ExprHolder exprHolder = ExprHolder.of("map", newMap, "buffer", buffer);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
@SuppressWarnings({"unchecked", "rawtypes"})
public abstract class AbstractCollectionSerializer<T> extends Serializer<T> {
private MethodHandle constructor;
protected int numElements;
private int numElements;
private final boolean supportCodegenHook;
// TODO remove elemSerializer, support generics in CompatibleSerializer.
private Serializer<?> elemSerializer;
Expand Down Expand Up @@ -507,9 +507,19 @@ public Collection newCollection(MemoryBuffer buffer) {
}
}

/** Get numElements of deserializing collection. Should be called after {@link #newCollection}. */
public int getNumElements() {
return numElements;
/**
* Get and reset numElements of the collection being deserialized. Should be called after {@link
* #newCollection}. A nested read may overwrite this field, so it is reset here to avoid using a
* wrong value by mistake.
*/
public int getAndClearNumElements() {
int size = numElements;
numElements = -1; // nested read may overwrite this element.
return size;
}

protected void setNumElements(int numElements) {
this.numElements = numElements;
}

public abstract T onCollectionRead(Collection collection);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public abstract class AbstractMapSerializer<T> extends Serializer<T> {
// field. So we will write those extra kv classes to keep protocol consistency between
// interpreter and jit mode although it seems unnecessary.
// With kv header in future, we can write this kv classes only once, the cost won't be too much.
protected int numElements;
private int numElements;

public AbstractMapSerializer(Fury fury, Class<T> cls) {
this(fury, cls, !ReflectionUtils.isDynamicGeneratedCLass(cls));
Expand Down Expand Up @@ -718,9 +718,18 @@ public Map newMap(MemoryBuffer buffer) {
}
}

/** Get numElements of deserializing collection. Should be called after {@link #newMap}. */
public int getNumElements() {
return numElements;
/**
* Get and reset numElements of the map being deserialized. Should be called after {@link #newMap}.
* A nested read may overwrite this field, so it is reset to avoid using a wrong value by mistake.
*/
public int getAndClearNumElements() {
int size = numElements;
numElements = -1; // nested read may overwrite this element.
return size;
}

public void setNumElements(int numElements) {
this.numElements = numElements;
}

public abstract T onMapRead(Map map);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ public ChildArrayListSerializer(Fury fury, Class<T> cls) {
@Override
public T newCollection(MemoryBuffer buffer) {
T collection = (T) super.newCollection(buffer);
int numElements = getAndClearNumElements();
setNumElements(numElements);
collection.ensureCapacity(numElements);
return collection;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ public T onCollectionRead(Collection collection) {
@Override
public T read(MemoryBuffer buffer) {
Collection collection = newCollection(buffer);
int numElements = getAndClearNumElements();
if (numElements != 0) {
readElements(fury, buffer, collection, numElements);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ public short getXtypeId() {

@Override
public ArrayList newCollection(MemoryBuffer buffer) {
numElements = buffer.readPositiveVarInt();
int numElements = buffer.readPositiveVarInt();
setNumElements(numElements);
ArrayList arrayList = new ArrayList(numElements);
fury.getRefResolver().reference(arrayList);
return arrayList;
Expand Down Expand Up @@ -147,7 +148,8 @@ public short getXtypeId() {

@Override
public HashSet newCollection(MemoryBuffer buffer) {
numElements = buffer.readPositiveVarInt();
int numElements = buffer.readPositiveVarInt();
setNumElements(numElements);
HashSet hashSet = new HashSet(numElements);
fury.getRefResolver().reference(hashSet);
return hashSet;
Expand All @@ -166,7 +168,8 @@ public short getXtypeId() {

@Override
public LinkedHashSet newCollection(MemoryBuffer buffer) {
numElements = buffer.readPositiveVarInt();
int numElements = buffer.readPositiveVarInt();
setNumElements(numElements);
LinkedHashSet hashSet = new LinkedHashSet(numElements);
fury.getRefResolver().reference(hashSet);
return hashSet;
Expand Down Expand Up @@ -197,7 +200,8 @@ public Collection onCollectionWrite(MemoryBuffer buffer, T value) {
@SuppressWarnings("unchecked")
@Override
public T newCollection(MemoryBuffer buffer) {
numElements = buffer.readPositiveVarInt();
int numElements = buffer.readPositiveVarInt();
setNumElements(numElements);
T collection;
Comparator comparator = (Comparator) fury.readRef(buffer);
if (type == TreeSet.class) {
Expand Down Expand Up @@ -376,7 +380,8 @@ public ConcurrentSkipListSetSerializer(Fury fury, Class<ConcurrentSkipListSet> c

@Override
public ConcurrentSkipListSet newCollection(MemoryBuffer buffer) {
numElements = buffer.readPositiveVarInt();
int numElements = buffer.readPositiveVarInt();
setNumElements(numElements);
Comparator comparator = (Comparator) fury.readRef(buffer);
ConcurrentSkipListSet skipListSet = new ConcurrentSkipListSet(comparator);
fury.getRefResolver().reference(skipListSet);
Expand All @@ -392,7 +397,8 @@ public VectorSerializer(Fury fury, Class<Vector> cls) {

@Override
public Vector newCollection(MemoryBuffer buffer) {
numElements = buffer.readPositiveVarInt();
int numElements = buffer.readPositiveVarInt();
setNumElements(numElements);
Vector<Object> vector = new Vector<>(numElements);
fury.getRefResolver().reference(vector);
return vector;
Expand All @@ -407,7 +413,8 @@ public ArrayDequeSerializer(Fury fury, Class<ArrayDeque> cls) {

@Override
public ArrayDeque newCollection(MemoryBuffer buffer) {
numElements = buffer.readPositiveVarInt();
int numElements = buffer.readPositiveVarInt();
setNumElements(numElements);
ArrayDeque deque = new ArrayDeque(numElements);
fury.getRefResolver().reference(deque);
return deque;
Expand Down Expand Up @@ -486,7 +493,8 @@ public Collection onCollectionWrite(MemoryBuffer buffer, PriorityQueue value) {

@Override
public PriorityQueue newCollection(MemoryBuffer buffer) {
numElements = buffer.readPositiveVarInt();
int numElements = buffer.readPositiveVarInt();
setNumElements(numElements);
Comparator comparator = (Comparator) fury.readRef(buffer);
PriorityQueue queue = new PriorityQueue(comparator);
fury.getRefResolver().reference(queue);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ public short getXtypeId() {
}

public Collection newCollection(MemoryBuffer buffer) {
numElements = buffer.readPositiveVarInt();
int numElements = buffer.readPositiveVarInt();
setNumElements(numElements);
return new ArrayAsList(numElements);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ public ImmutableListSerializer(Fury fury, Class<T> cls) {

@Override
public Collection newCollection(MemoryBuffer buffer) {
numElements = buffer.readPositiveVarInt();
int numElements = buffer.readPositiveVarInt();
setNumElements(numElements);
return new CollectionContainer<>(numElements);
}

Expand Down Expand Up @@ -124,7 +125,8 @@ public RegularImmutableListSerializer(Fury fury, Class<T> cls) {

@Override
public Collection newCollection(MemoryBuffer buffer) {
numElements = buffer.readPositiveVarInt();
int numElements = buffer.readPositiveVarInt();
setNumElements(numElements);
return new CollectionContainer(numElements);
}

Expand Down Expand Up @@ -156,7 +158,8 @@ public ImmutableSetSerializer(Fury fury, Class<T> cls) {

@Override
public Collection newCollection(MemoryBuffer buffer) {
numElements = buffer.readPositiveVarInt();
int numElements = buffer.readPositiveVarInt();
setNumElements(numElements);
return new CollectionContainer<>(numElements);
}

Expand Down Expand Up @@ -195,7 +198,8 @@ public Collection onCollectionWrite(MemoryBuffer buffer, T value) {

@Override
public Collection newCollection(MemoryBuffer buffer) {
numElements = buffer.readPositiveVarInt();
int numElements = buffer.readPositiveVarInt();
setNumElements(numElements);
Comparator comparator = (Comparator) fury.readRef(buffer);
return new SortedCollectionContainer(comparator, numElements);
}
Expand All @@ -221,7 +225,9 @@ public GuavaMapSerializer(Fury fury, Class<T> cls) {

@Override
public Map newMap(MemoryBuffer buffer) {
return new MapContainer(numElements = buffer.readPositiveVarInt());
int numElements = buffer.readPositiveVarInt();
setNumElements(numElements);
return new MapContainer(numElements);
}

@Override
Expand Down Expand Up @@ -333,7 +339,8 @@ public Map onMapWrite(MemoryBuffer buffer, T value) {

@Override
public Map newMap(MemoryBuffer buffer) {
numElements = buffer.readPositiveVarInt();
int numElements = buffer.readPositiveVarInt();
setNumElements(numElements);
Comparator comparator = (Comparator) fury.readRef(buffer);
return new SortedMapContainer<>(comparator, numElements);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ public ImmutableListSerializer(Fury fury, Class cls) {

@Override
public Collection newCollection(MemoryBuffer buffer) {
numElements = buffer.readPositiveVarInt();
int numElements = buffer.readPositiveVarInt();
setNumElements(numElements);
if (Platform.JAVA_VERSION > 8) {
return new CollectionContainer<>(numElements);
} else {
Expand Down Expand Up @@ -141,7 +142,8 @@ public ImmutableSetSerializer(Fury fury, Class cls) {

@Override
public Collection newCollection(MemoryBuffer buffer) {
numElements = buffer.readPositiveVarInt();
int numElements = buffer.readPositiveVarInt();
setNumElements(numElements);
if (Platform.JAVA_VERSION > 8) {
return new CollectionContainer<>(numElements);
} else {
Expand Down Expand Up @@ -173,7 +175,8 @@ public ImmutableMapSerializer(Fury fury, Class cls) {

@Override
public Map newMap(MemoryBuffer buffer) {
numElements = buffer.readPositiveVarInt();
int numElements = buffer.readPositiveVarInt();
setNumElements(numElements);
if (Platform.JAVA_VERSION > 8) {
return new JDKImmutableMapContainer(numElements);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public Map onMapWrite(MemoryBuffer buffer, T value) {
@Override
public T read(MemoryBuffer buffer) {
Map map = newMap(buffer);
readElements(buffer, numElements, map);
readElements(buffer, getAndClearNumElements(), map);
return onMapRead(map);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@ public short getXtypeId() {

@Override
public HashMap newMap(MemoryBuffer buffer) {
HashMap hashMap = new HashMap(numElements = buffer.readPositiveVarInt());
int numElements = buffer.readPositiveVarInt();
setNumElements(numElements);
HashMap hashMap = new HashMap(numElements);
fury.getRefResolver().reference(hashMap);
return hashMap;
}
Expand All @@ -81,7 +83,9 @@ public short getXtypeId() {

@Override
public LinkedHashMap newMap(MemoryBuffer buffer) {
LinkedHashMap hashMap = new LinkedHashMap(numElements = buffer.readPositiveVarInt());
int numElements = buffer.readPositiveVarInt();
setNumElements(numElements);
LinkedHashMap hashMap = new LinkedHashMap(numElements);
fury.getRefResolver().reference(hashMap);
return hashMap;
}
Expand All @@ -99,7 +103,9 @@ public short getXtypeId() {

@Override
public LazyMap newMap(MemoryBuffer buffer) {
LazyMap map = new LazyMap(numElements = buffer.readPositiveVarInt());
int numElements = buffer.readPositiveVarInt();
setNumElements(numElements);
LazyMap map = new LazyMap(numElements);
fury.getRefResolver().reference(map);
return map;
}
Expand All @@ -124,7 +130,7 @@ public Map onMapWrite(MemoryBuffer buffer, T value) {
@SuppressWarnings("unchecked")
@Override
public Map newMap(MemoryBuffer buffer) {
numElements = buffer.readPositiveVarInt();
setNumElements(buffer.readPositiveVarInt());
T map;
Comparator comparator = (Comparator) fury.readRef(buffer);
if (type == TreeMap.class) {
Expand Down Expand Up @@ -237,7 +243,9 @@ public ConcurrentHashMapSerializer(Fury fury, Class<ConcurrentHashMap> type) {

@Override
public ConcurrentHashMap newMap(MemoryBuffer buffer) {
ConcurrentHashMap map = new ConcurrentHashMap(numElements = buffer.readPositiveVarInt());
int numElements = buffer.readPositiveVarInt();
setNumElements(numElements);
ConcurrentHashMap map = new ConcurrentHashMap(numElements);
fury.getRefResolver().reference(map);
return map;
}
Expand All @@ -257,7 +265,8 @@ public ConcurrentSkipListMapSerializer(Fury fury, Class<ConcurrentSkipListMap> c

@Override
public ConcurrentSkipListMap newMap(MemoryBuffer buffer) {
numElements = buffer.readPositiveVarInt();
int numElements = buffer.readPositiveVarInt();
setNumElements(numElements);
Comparator comparator = (Comparator) fury.readRef(buffer);
ConcurrentSkipListMap map = new ConcurrentSkipListMap(comparator);
fury.getRefResolver().reference(map);
Expand Down Expand Up @@ -298,7 +307,7 @@ public Map onMapWrite(MemoryBuffer buffer, EnumMap value) {

@Override
public EnumMap newMap(MemoryBuffer buffer) {
numElements = buffer.readPositiveVarInt();
setNumElements(buffer.readPositiveVarInt());
Class<?> keyType = fury.getClassResolver().readClassInfo(buffer).getCls();
return new EnumMap(keyType);
}
Expand All @@ -325,6 +334,7 @@ public void write(MemoryBuffer buffer, Map<String, T> value) {
@Override
public Map<String, T> read(MemoryBuffer buffer) {
Map map = newMap(buffer);
int numElements = getAndClearNumElements();
for (int i = 0; i < numElements; i++) {
map.put(fury.readJavaStringRef(buffer), fury.readRef(buffer));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ public void testChildCollection(Fury fury) {
ChildArrayList<Integer> list = new ChildArrayList<>();
list.addAll(data);
list.state = 3;
ChildArrayList<Integer> newList = (ChildArrayList) serDe(fury, list);
ChildArrayList<Integer> newList = serDe(fury, list);
Assert.assertEquals(newList, list);
Assert.assertEquals(newList.state, 3);
Assert.assertEquals(
Expand Down
Loading

0 comments on commit a9057eb

Please sign in to comment.