Skip to content

Commit

Permalink
Add method to check if object is generically writeable in stream (#54936
Browse files Browse the repository at this point in the history
) (#55561)

When calling scripts in metric aggregation, the returned metric state is
passed along to the coordinating node to do the final reduce. However,
it is possible the object could contain nested state which is unknown to
StreamOutput/StreamInput. This would then result in the node crashing as
exceptions are not expected in the middle of serialization.

This commit adds a method to StreamOutput that can determine if an
object is writeable by the stream. It uses the same logic
writeGenericValue, special casing each of the supported collection types
to recursively determine if each contained value is itself writeable.

relates #54708
  • Loading branch information
rjernst authored Apr 28, 2020
1 parent 9e37658 commit fed296e
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -788,6 +788,23 @@ public final void writeOptionalInstant(@Nullable Instant instant) throws IOExcep
WRITERS = Collections.unmodifiableMap(writers);
}

private static Class<?> getGenericType(Object value) {
if (value instanceof List) {
return List.class;
} else if (value instanceof Object[]) {
return Object[].class;
} else if (value instanceof Map) {
return Map.class;
} else if (value instanceof Set) {
return Set.class;
} else if (value instanceof ReadableInstant) {
return ReadableInstant.class;
} else if (value instanceof BytesReference) {
return BytesReference.class;
} else {
return value.getClass();
}
}
/**
* Notice: when serialization a map, the stream out map with the stream in map maybe have the
* different key-value orders, they will maybe have different stream order.
Expand All @@ -799,22 +816,7 @@ public void writeGenericValue(@Nullable Object value) throws IOException {
writeByte((byte) -1);
return;
}
final Class type;
if (value instanceof List) {
type = List.class;
} else if (value instanceof Object[]) {
type = Object[].class;
} else if (value instanceof Map) {
type = Map.class;
} else if (value instanceof Set) {
type = Set.class;
} else if (value instanceof ReadableInstant) {
type = ReadableInstant.class;
} else if (value instanceof BytesReference) {
type = BytesReference.class;
} else {
type = value.getClass();
}
final Class<?> type = getGenericType(value);
final Writer writer = WRITERS.get(type);
if (writer != null) {
writer.write(this, value);
Expand All @@ -823,6 +825,38 @@ public void writeGenericValue(@Nullable Object value) throws IOException {
}
}

public static void checkWriteable(@Nullable Object value) throws IllegalArgumentException {
if (value == null) {
return;
}
final Class<?> type = getGenericType(value);

if (type == List.class) {
@SuppressWarnings("unchecked") List<Object> list = (List<Object>) value;
for (Object v : list) {
checkWriteable(v);
}
} else if (type == Object[].class) {
Object[] array = (Object[]) value;
for (Object v : array) {
checkWriteable(v);
}
} else if (type == Map.class) {
@SuppressWarnings("unchecked") Map<String, Object> map = (Map<String, Object>) value;
for (Map.Entry<String, Object> entry : map.entrySet()) {
checkWriteable(entry.getKey());
checkWriteable(entry.getValue());
}
} else if (type == Set.class) {
@SuppressWarnings("unchecked") Set<Object> set = (Set<Object>) value;
for (Object v : set) {
checkWriteable(v);
}
} else if (WRITERS.containsKey(type) == false) {
throw new IllegalArgumentException("Cannot write type [" + type.getCanonicalName() + "] to stream");
}
}

public void writeIntArray(int[] values) throws IOException {
writeVInt(values.length);
for (int value : values) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.ScoreMode;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.util.CollectionUtils;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptedMetricAggContexts;
Expand Down Expand Up @@ -90,6 +91,7 @@ public InternalAggregation buildAggregation(long owningBucketOrdinal) {
} else {
aggregation = aggState;
}
StreamOutput.checkWriteable(aggregation);
return new InternalScriptedMetric(name, aggregation, reduceScript, metadata());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
Expand Down Expand Up @@ -445,6 +446,40 @@ public void testGenericSet() throws IOException {
assertGenericRoundtrip(new LinkedHashSet<>(list));
}

private static class Unwriteable {}

private void assertNotWriteable(Object o, Class<?> type) {
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> StreamOutput.checkWriteable(o));
assertThat(e.getMessage(), equalTo("Cannot write type [" + type.getCanonicalName() + "] to stream"));
}

public void testIsWriteable() throws IOException {
assertNotWriteable(new Unwriteable(), Unwriteable.class);
}

public void testSetIsWriteable() throws IOException {
StreamOutput.checkWriteable(new HashSet<>(Arrays.asList("a", "b")));
assertNotWriteable(Collections.singleton(new Unwriteable()), Unwriteable.class);
}

public void testListIsWriteable() throws IOException {
StreamOutput.checkWriteable(Arrays.asList("a", "b"));
assertNotWriteable(Collections.singletonList(new Unwriteable()), Unwriteable.class);
}

public void testMapIsWriteable() throws IOException {
Map<String, Object> goodMap = new HashMap<>();
goodMap.put("a", "b");
goodMap.put("c", "d");
StreamOutput.checkWriteable(goodMap);
assertNotWriteable(Collections.singletonMap("a", new Unwriteable()), Unwriteable.class);
}

public void testObjectArrayIsWriteable() throws IOException {
StreamOutput.checkWriteable(new Object[] {"a", "b"});
assertNotWriteable(new Object[] {new Unwriteable()}, Unwriteable.class);
}

private void assertSerialization(CheckedConsumer<StreamOutput, IOException> outputAssertions,
CheckedConsumer<StreamInput, IOException> inputAssertions) throws IOException {
try (BytesStreamOutput output = new BytesStreamOutput()) {
Expand Down

0 comments on commit fed296e

Please sign in to comment.