Skip to content

Commit

Permalink
Add method to check if object is generically writeable in stream
Browse files Browse the repository at this point in the history
When calling scripts in metric aggregation, the returned metric state is
passed along to the coordinating node to do the final reduce. However,
it is possible the object could contain nested state which is unknown to
StreamOutput/StreamInput. This would then result in the node crashing as
exceptions are not expected in the middle of serialization.

This commit adds a method to StreamOutput that can determine if an
object is writeable by the stream. It uses the same logic
writeGenericValue, special casing each of the supported collection types
to recursively determine if each contained value is itself writeable.

relates elastic#54708
  • Loading branch information
rjernst committed Apr 8, 2020
1 parent 57bd6d2 commit c0921af
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.function.IntFunction;

Expand Down Expand Up @@ -792,6 +793,21 @@ public final void writeOptionalInstant(@Nullable Instant instant) throws IOExcep
o.writeLong(zonedDateTime.toInstant().toEpochMilli());
}));

private static Class<?> getGenericType(Object value) {
if (value instanceof List) {
return List.class;
} else if (value instanceof Object[]) {
return Object[].class;
} else if (value instanceof Map) {
return Map.class;
} else if (value instanceof ReadableInstant) {
return ReadableInstant.class;
} else if (value instanceof BytesReference) {
return BytesReference.class;
} else {
return value.getClass();
}
}
/**
* Notice: when serialization a map, the stream out map with the stream in map maybe have the
* different key-value orders, they will maybe have different stream order.
Expand All @@ -803,20 +819,7 @@ public void writeGenericValue(@Nullable Object value) throws IOException {
writeByte((byte) -1);
return;
}
final Class type;
if (value instanceof List) {
type = List.class;
} else if (value instanceof Object[]) {
type = Object[].class;
} else if (value instanceof Map) {
type = Map.class;
} else if (value instanceof ReadableInstant) {
type = ReadableInstant.class;
} else if (value instanceof BytesReference) {
type = BytesReference.class;
} else {
type = value.getClass();
}
final Class<?> type = getGenericType(value);
final Writer writer = WRITERS.get(type);
if (writer != null) {
writer.write(this, value);
Expand All @@ -825,6 +828,38 @@ public void writeGenericValue(@Nullable Object value) throws IOException {
}
}

public static void checkWriteable(@Nullable Object value) throws IllegalArgumentException {
if (value == null) {
return;
}
final Class<?> type = getGenericType(value);

if (type == List.class) {
@SuppressWarnings("unchecked") List<Object> list = (List<Object>) value;
for (Object v : list) {
checkWriteable(v);
}
} else if (value instanceof Object[]) {
Object[] array = (Object[]) value;
for (Object v : array) {
checkWriteable(v);
}
} else if (value instanceof Map) {
@SuppressWarnings("unchecked") Map<String, Object> map = (Map<String, Object>) value;
for (Map.Entry<String, Object> entry : map.entrySet()) {
checkWriteable(entry.getKey());
checkWriteable(entry.getValue());
}
} else if (value instanceof Set) {
@SuppressWarnings("unchecked") Set<Object> set = (Set<Object>) value;
for (Object v : set) {
checkWriteable(v);
}
} else if (WRITERS.containsKey(type) == false) {
throw new IllegalArgumentException("Cannot write type [" + type + "] to stream");
}
}

public void writeIntArray(int[] values) throws IOException {
writeVInt(values.length);
for (int value : values) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.ScoreMode;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.util.CollectionUtils;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptedMetricAggContexts;
Expand Down Expand Up @@ -90,6 +91,7 @@ public InternalAggregation buildAggregation(long owningBucketOrdinal) {
} else {
aggregation = aggState;
}
StreamOutput.checkWriteable(aggregation);
return new InternalScriptedMetric(name, aggregation, reduceScript, metadata());
}

Expand Down

0 comments on commit c0921af

Please sign in to comment.