From fed296ebb755d37e09e6672a828347da8511d799 Mon Sep 17 00:00:00 2001 From: Ryan Ernst Date: Tue, 28 Apr 2020 13:08:41 -0700 Subject: [PATCH] Add method to check if object is generically writeable in stream (#54936) (#55561) When calling scripts in metric aggregation, the returned metric state is passed along to the coordinating node to do the final reduce. However, it is possible the object could contain nested state which is unknown to StreamOutput/StreamInput. This would then result in the node crashing as exceptions are not expected in the middle of serialization. This commit adds a method to StreamOutput that can determine if an object is writeable by the stream. It uses the same logic writeGenericValue, special casing each of the supported collection types to recursively determine if each contained value is itself writeable. relates #54708 --- .../common/io/stream/StreamOutput.java | 66 ++++++++++++++----- .../metrics/ScriptedMetricAggregator.java | 2 + .../common/io/stream/StreamTests.java | 35 ++++++++++ 3 files changed, 87 insertions(+), 16 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/common/io/stream/StreamOutput.java b/server/src/main/java/org/elasticsearch/common/io/stream/StreamOutput.java index 8a93fb19eeaa0..b08aea9c963ec 100644 --- a/server/src/main/java/org/elasticsearch/common/io/stream/StreamOutput.java +++ b/server/src/main/java/org/elasticsearch/common/io/stream/StreamOutput.java @@ -788,6 +788,23 @@ public final void writeOptionalInstant(@Nullable Instant instant) throws IOExcep WRITERS = Collections.unmodifiableMap(writers); } + private static Class getGenericType(Object value) { + if (value instanceof List) { + return List.class; + } else if (value instanceof Object[]) { + return Object[].class; + } else if (value instanceof Map) { + return Map.class; + } else if (value instanceof Set) { + return Set.class; + } else if (value instanceof ReadableInstant) { + return ReadableInstant.class; + } else if (value instanceof BytesReference) { + return BytesReference.class; + } else { + return value.getClass(); + } + } /** * Notice: when serialization a map, the stream out map with the stream in map maybe have the * different key-value orders, they will maybe have different stream order. @@ -799,22 +816,7 @@ public void writeGenericValue(@Nullable Object value) throws IOException { writeByte((byte) -1); return; } - final Class type; - if (value instanceof List) { - type = List.class; - } else if (value instanceof Object[]) { - type = Object[].class; - } else if (value instanceof Map) { - type = Map.class; - } else if (value instanceof Set) { - type = Set.class; - } else if (value instanceof ReadableInstant) { - type = ReadableInstant.class; - } else if (value instanceof BytesReference) { - type = BytesReference.class; - } else { - type = value.getClass(); - } + final Class type = getGenericType(value); final Writer writer = WRITERS.get(type); if (writer != null) { writer.write(this, value); @@ -823,6 +825,38 @@ public void writeGenericValue(@Nullable Object value) throws IOException { } } + public static void checkWriteable(@Nullable Object value) throws IllegalArgumentException { + if (value == null) { + return; + } + final Class type = getGenericType(value); + + if (type == List.class) { + @SuppressWarnings("unchecked") List list = (List) value; + for (Object v : list) { + checkWriteable(v); + } + } else if (type == Object[].class) { + Object[] array = (Object[]) value; + for (Object v : array) { + checkWriteable(v); + } + } else if (type == Map.class) { + @SuppressWarnings("unchecked") Map map = (Map) value; + for (Map.Entry entry : map.entrySet()) { + checkWriteable(entry.getKey()); + checkWriteable(entry.getValue()); + } + } else if (type == Set.class) { + @SuppressWarnings("unchecked") Set set = (Set) value; + for (Object v : set) { + checkWriteable(v); + } + } else if (WRITERS.containsKey(type) == false) { + throw new IllegalArgumentException("Cannot write type [" + type.getCanonicalName() + "] to stream"); + } + } + public void writeIntArray(int[] values) throws IOException { writeVInt(values.length); for (int value : values) { diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregator.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregator.java index 91a73de9aa07c..d88a26a75671c 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregator.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregator.java @@ -22,6 +22,7 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.Scorable; import org.apache.lucene.search.ScoreMode; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.util.CollectionUtils; import org.elasticsearch.script.Script; import org.elasticsearch.script.ScriptedMetricAggContexts; @@ -90,6 +91,7 @@ public InternalAggregation buildAggregation(long owningBucketOrdinal) { } else { aggregation = aggState; } + StreamOutput.checkWriteable(aggregation); return new InternalScriptedMetric(name, aggregation, reduceScript, metadata()); } diff --git a/server/src/test/java/org/elasticsearch/common/io/stream/StreamTests.java b/server/src/test/java/org/elasticsearch/common/io/stream/StreamTests.java index 56a1ce73289f7..62dab6025c8f2 100644 --- a/server/src/test/java/org/elasticsearch/common/io/stream/StreamTests.java +++ b/server/src/test/java/org/elasticsearch/common/io/stream/StreamTests.java @@ -37,6 +37,7 @@ import java.util.Arrays; import java.util.Collection; import java.util.Collections; +import java.util.HashMap; import java.util.HashSet; import java.util.LinkedHashMap; import java.util.LinkedHashSet; @@ -445,6 +446,40 @@ public void testGenericSet() throws IOException { assertGenericRoundtrip(new LinkedHashSet<>(list)); } + private static class Unwriteable {} + + private void assertNotWriteable(Object o, Class type) { + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> StreamOutput.checkWriteable(o)); + assertThat(e.getMessage(), equalTo("Cannot write type [" + type.getCanonicalName() + "] to stream")); + } + + public void testIsWriteable() throws IOException { + assertNotWriteable(new Unwriteable(), Unwriteable.class); + } + + public void testSetIsWriteable() throws IOException { + StreamOutput.checkWriteable(new HashSet<>(Arrays.asList("a", "b"))); + assertNotWriteable(Collections.singleton(new Unwriteable()), Unwriteable.class); + } + + public void testListIsWriteable() throws IOException { + StreamOutput.checkWriteable(Arrays.asList("a", "b")); + assertNotWriteable(Collections.singletonList(new Unwriteable()), Unwriteable.class); + } + + public void testMapIsWriteable() throws IOException { + Map goodMap = new HashMap<>(); + goodMap.put("a", "b"); + goodMap.put("c", "d"); + StreamOutput.checkWriteable(goodMap); + assertNotWriteable(Collections.singletonMap("a", new Unwriteable()), Unwriteable.class); + } + + public void testObjectArrayIsWriteable() throws IOException { + StreamOutput.checkWriteable(new Object[] {"a", "b"}); + assertNotWriteable(new Object[] {new Unwriteable()}, Unwriteable.class); + } + private void assertSerialization(CheckedConsumer outputAssertions, CheckedConsumer inputAssertions) throws IOException { try (BytesStreamOutput output = new BytesStreamOutput()) {