diff --git a/server/src/main/java/org/elasticsearch/common/io/stream/StreamInput.java b/server/src/main/java/org/elasticsearch/common/io/stream/StreamInput.java index 1546886b4107c..e88188fe8c2a1 100644 --- a/server/src/main/java/org/elasticsearch/common/io/stream/StreamInput.java +++ b/server/src/main/java/org/elasticsearch/common/io/stream/StreamInput.java @@ -70,6 +70,7 @@ import java.util.HashMap; import java.util.HashSet; import java.util.LinkedHashMap; +import java.util.LinkedHashSet; import java.util.List; import java.util.Locale; import java.util.Map; @@ -735,6 +736,10 @@ public Object readGenericValue() throws IOException { return readGeoPoint(); case 23: return readZonedDateTime(); + case 24: + return readCollection(StreamInput::readGenericValue, LinkedHashSet::new, Collections.emptySet()); + case 25: + return readCollection(StreamInput::readGenericValue, HashSet::new, Collections.emptySet()); default: throw new IOException("Can't read unknown type [" + type + "]"); } diff --git a/server/src/main/java/org/elasticsearch/common/io/stream/StreamOutput.java b/server/src/main/java/org/elasticsearch/common/io/stream/StreamOutput.java index 603dc15f7cea8..8a93fb19eeaa0 100644 --- a/server/src/main/java/org/elasticsearch/common/io/stream/StreamOutput.java +++ b/server/src/main/java/org/elasticsearch/common/io/stream/StreamOutput.java @@ -70,6 +70,7 @@ import java.util.HashSet; import java.util.Iterator; import java.util.LinkedHashMap; +import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Set; @@ -776,6 +777,14 @@ public final void writeOptionalInstant(@Nullable Instant instant) throws IOExcep o.writeString(zoneId.equals("Z") ? DateTimeZone.UTC.getID() : zoneId); o.writeLong(zonedDateTime.toInstant().toEpochMilli()); }); + writers.put(Set.class, (o, v) -> { + if (v instanceof LinkedHashSet) { + o.writeByte((byte) 24); + } else { + o.writeByte((byte) 25); + } + o.writeCollection((Set) v, StreamOutput::writeGenericValue); + }); WRITERS = Collections.unmodifiableMap(writers); } @@ -797,6 +806,8 @@ public void writeGenericValue(@Nullable Object value) throws IOException { type = Object[].class; } else if (value instanceof Map) { type = Map.class; + } else if (value instanceof Set) { + type = Set.class; } else if (value instanceof ReadableInstant) { type = ReadableInstant.class; } else if (value instanceof BytesReference) { diff --git a/server/src/test/java/org/elasticsearch/common/io/stream/StreamTests.java b/server/src/test/java/org/elasticsearch/common/io/stream/StreamTests.java index ce0bce03b0335..9b5013543c555 100644 --- a/server/src/test/java/org/elasticsearch/common/io/stream/StreamTests.java +++ b/server/src/test/java/org/elasticsearch/common/io/stream/StreamTests.java @@ -21,6 +21,7 @@ import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.CheckedBiConsumer; +import org.elasticsearch.common.CheckedConsumer; import org.elasticsearch.common.CheckedFunction; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; @@ -38,6 +39,7 @@ import java.util.Collections; import java.util.HashSet; import java.util.LinkedHashMap; +import java.util.LinkedHashSet; import java.util.List; import java.util.Locale; import java.util.Map; @@ -434,4 +436,32 @@ public void testSecureStringSerialization() throws IOException { } } + public void testGenericSet() throws IOException { + Set set = Set.of("a", "b", "c", "d", "e"); + assertGenericRoundtrip(set); + // reverse order in normal set so linked hashset does not match the order + var list = new ArrayList<>(set); + Collections.reverse(list); + assertGenericRoundtrip(new LinkedHashSet<>(list)); + } + + private void assertSerialization(CheckedConsumer outputAssertions, + CheckedConsumer inputAssertions) throws IOException { + try (BytesStreamOutput output = new BytesStreamOutput()) { + outputAssertions.accept(output); + final BytesReference bytesReference = output.bytes(); + final StreamInput input = bytesReference.streamInput(); + inputAssertions.accept(input); + } + } + + private void assertGenericRoundtrip(Object original) throws IOException { + assertSerialization(output -> { + output.writeGenericValue(original); + }, input -> { + Object read = input.readGenericValue(); + assertThat(read, equalTo(original)); + }); + } + }