From c0921afa81308f0191f0315a6024601b5fdeb26a Mon Sep 17 00:00:00 2001 From: Ryan Ernst Date: Tue, 7 Apr 2020 23:44:49 -0700 Subject: [PATCH 1/2] Add method to check if object is generically writeable in stream When calling scripts in metric aggregation, the returned metric state is passed along to the coordinating node to do the final reduce. However, it is possible the object could contain nested state which is unknown to StreamOutput/StreamInput. This would then result in the node crashing as exceptions are not expected in the middle of serialization. This commit adds a method to StreamOutput that can determine if an object is writeable by the stream. It uses the same logic writeGenericValue, special casing each of the supported collection types to recursively determine if each contained value is itself writeable. relates #54708 --- .../common/io/stream/StreamOutput.java | 63 ++++++++++++++----- .../metrics/ScriptedMetricAggregator.java | 2 + 2 files changed, 51 insertions(+), 14 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/common/io/stream/StreamOutput.java b/server/src/main/java/org/elasticsearch/common/io/stream/StreamOutput.java index 7fa616e1c6340..ee46fd5baa58f 100644 --- a/server/src/main/java/org/elasticsearch/common/io/stream/StreamOutput.java +++ b/server/src/main/java/org/elasticsearch/common/io/stream/StreamOutput.java @@ -68,6 +68,7 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.TimeUnit; import java.util.function.IntFunction; @@ -792,6 +793,21 @@ public final void writeOptionalInstant(@Nullable Instant instant) throws IOExcep o.writeLong(zonedDateTime.toInstant().toEpochMilli()); })); + private static Class getGenericType(Object value) { + if (value instanceof List) { + return List.class; + } else if (value instanceof Object[]) { + return Object[].class; + } else if (value instanceof Map) { + return Map.class; + } else if (value instanceof ReadableInstant) { + return ReadableInstant.class; + } else if (value instanceof BytesReference) { + return BytesReference.class; + } else { + return value.getClass(); + } + } /** * Notice: when serialization a map, the stream out map with the stream in map maybe have the * different key-value orders, they will maybe have different stream order. @@ -803,20 +819,7 @@ public void writeGenericValue(@Nullable Object value) throws IOException { writeByte((byte) -1); return; } - final Class type; - if (value instanceof List) { - type = List.class; - } else if (value instanceof Object[]) { - type = Object[].class; - } else if (value instanceof Map) { - type = Map.class; - } else if (value instanceof ReadableInstant) { - type = ReadableInstant.class; - } else if (value instanceof BytesReference) { - type = BytesReference.class; - } else { - type = value.getClass(); - } + final Class type = getGenericType(value); final Writer writer = WRITERS.get(type); if (writer != null) { writer.write(this, value); @@ -825,6 +828,38 @@ public void writeGenericValue(@Nullable Object value) throws IOException { } } + public static void checkWriteable(@Nullable Object value) throws IllegalArgumentException { + if (value == null) { + return; + } + final Class type = getGenericType(value); + + if (type == List.class) { + @SuppressWarnings("unchecked") List list = (List) value; + for (Object v : list) { + checkWriteable(v); + } + } else if (value instanceof Object[]) { + Object[] array = (Object[]) value; + for (Object v : array) { + checkWriteable(v); + } + } else if (value instanceof Map) { + @SuppressWarnings("unchecked") Map map = (Map) value; + for (Map.Entry entry : map.entrySet()) { + checkWriteable(entry.getKey()); + checkWriteable(entry.getValue()); + } + } else if (value instanceof Set) { + @SuppressWarnings("unchecked") Set set = (Set) value; + for (Object v : set) { + checkWriteable(v); + } + } else if (WRITERS.containsKey(type) == false) { + throw new IllegalArgumentException("Cannot write type [" + type + "] to stream"); + } + } + public void writeIntArray(int[] values) throws IOException { writeVInt(values.length); for (int value : values) { diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregator.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregator.java index 91a73de9aa07c..d88a26a75671c 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregator.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregator.java @@ -22,6 +22,7 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.Scorable; import org.apache.lucene.search.ScoreMode; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.util.CollectionUtils; import org.elasticsearch.script.Script; import org.elasticsearch.script.ScriptedMetricAggContexts; @@ -90,6 +91,7 @@ public InternalAggregation buildAggregation(long owningBucketOrdinal) { } else { aggregation = aggState; } + StreamOutput.checkWriteable(aggregation); return new InternalScriptedMetric(name, aggregation, reduceScript, metadata()); } From de5fd366753d20304aca42df09443106f347b23c Mon Sep 17 00:00:00 2001 From: Ryan Ernst Date: Mon, 20 Apr 2020 11:37:39 -0700 Subject: [PATCH 2/2] add tests --- .../common/io/stream/StreamTests.java | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/server/src/test/java/org/elasticsearch/common/io/stream/StreamTests.java b/server/src/test/java/org/elasticsearch/common/io/stream/StreamTests.java index 9b5013543c555..906f8a36bf52f 100644 --- a/server/src/test/java/org/elasticsearch/common/io/stream/StreamTests.java +++ b/server/src/test/java/org/elasticsearch/common/io/stream/StreamTests.java @@ -445,6 +445,37 @@ public void testGenericSet() throws IOException { assertGenericRoundtrip(new LinkedHashSet<>(list)); } + private static class Unwriteable {} + + private void assertNotWriteable(Object o, Class type) { + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> StreamOutput.checkWriteable(o)); + assertThat(e.getMessage(), equalTo("Cannot write type [" + type.getCanonicalName() + "] to stream")); + } + + public void testIsWriteable() throws IOException { + assertNotWriteable(new Unwriteable(), Unwriteable.class); + } + + public void testSetIsWriteable() throws IOException { + StreamOutput.checkWriteable(Set.of("a", "b")); + assertNotWriteable(Set.of(new Unwriteable()), Unwriteable.class); + } + + public void testListIsWriteable() throws IOException { + StreamOutput.checkWriteable(List.of("a", "b")); + assertNotWriteable(List.of(new Unwriteable()), Unwriteable.class); + } + + public void testMapIsWriteable() throws IOException { + StreamOutput.checkWriteable(Map.of("a", "b", "c", "d")); + assertNotWriteable(Map.of("a", new Unwriteable()), Unwriteable.class); + } + + public void testObjectArrayIsWriteable() throws IOException { + StreamOutput.checkWriteable(new Object[] {"a", "b"}); + assertNotWriteable(new Object[] {new Unwriteable()}, Unwriteable.class); + } + private void assertSerialization(CheckedConsumer outputAssertions, CheckedConsumer inputAssertions) throws IOException { try (BytesStreamOutput output = new BytesStreamOutput()) {