From 6b6c3670c2da3d8e8f886ae9c5ee0836ef16eb68 Mon Sep 17 00:00:00 2001 From: Sylvain Wallez Date: Mon, 21 Oct 2024 19:48:34 +0200 Subject: [PATCH] ESQL: Add support for multivalue fields in Arrow output (#114774) --- docs/changelog/114774.yaml | 5 + x-pack/plugin/esql/arrow/build.gradle | 1 + .../xpack/esql/arrow/ArrowResponse.java | 74 +++-- .../xpack/esql/arrow/BlockConverter.java | 214 ++++++++++----- .../xpack/esql/arrow/ArrowResponseTests.java | 252 +++++++++++++++--- 5 files changed, 431 insertions(+), 115 deletions(-) create mode 100644 docs/changelog/114774.yaml diff --git a/docs/changelog/114774.yaml b/docs/changelog/114774.yaml new file mode 100644 index 0000000000000..1becfe427fda0 --- /dev/null +++ b/docs/changelog/114774.yaml @@ -0,0 +1,5 @@ +pr: 114774 +summary: "ESQL: Add support for multivalue fields in Arrow output" +area: ES|QL +type: enhancement +issues: [] diff --git a/x-pack/plugin/esql/arrow/build.gradle b/x-pack/plugin/esql/arrow/build.gradle index 20c877a12bf0d..fac0bd0a77452 100644 --- a/x-pack/plugin/esql/arrow/build.gradle +++ b/x-pack/plugin/esql/arrow/build.gradle @@ -26,6 +26,7 @@ dependencies { testImplementation project(':test:framework') testImplementation('org.apache.arrow:arrow-memory-unsafe:16.1.0') + testImplementation("com.fasterxml.jackson.datatype:jackson-datatype-jsr310:${versions.jackson}") } tasks.named("dependencyLicenses").configure { diff --git a/x-pack/plugin/esql/arrow/src/main/java/org/elasticsearch/xpack/esql/arrow/ArrowResponse.java b/x-pack/plugin/esql/arrow/src/main/java/org/elasticsearch/xpack/esql/arrow/ArrowResponse.java index 7a8328060a390..208d3308d508b 100644 --- a/x-pack/plugin/esql/arrow/src/main/java/org/elasticsearch/xpack/esql/arrow/ArrowResponse.java +++ b/x-pack/plugin/esql/arrow/src/main/java/org/elasticsearch/xpack/esql/arrow/ArrowResponse.java @@ -17,6 +17,7 @@ import org.apache.arrow.vector.ipc.message.MessageSerializer; import org.apache.arrow.vector.types.Types.MinorType; import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.lucene.util.BytesRef; import org.elasticsearch.action.ActionListener; @@ -44,6 +45,7 @@ public class ArrowResponse implements ChunkedRestResponseBodyPart, Releasable { public static class Column { private final BlockConverter converter; private final String name; + private boolean multivalued; public Column(String esqlType, String name) { this.converter = ESQL_CONVERTERS.get(esqlType); @@ -61,20 +63,24 @@ public Column(String esqlType, String name) { public ArrowResponse(List columns, List pages) { this.columns = columns; + // Find multivalued columns + int colSize = columns.size(); + for (int col = 0; col < colSize; col++) { + for (Page page : pages) { + if (page.getBlock(col).mayHaveMultivaluedFields()) { + columns.get(col).multivalued = true; + break; + } + } + } + currentSegment = new SchemaResponse(this); List rest = new ArrayList<>(pages.size()); - for (int p = 0; p < pages.size(); p++) { - var page = pages.get(p); + + for (Page page : pages) { rest.add(new PageResponse(this, page)); - // Multivalued fields are not supported yet. - for (int b = 0; b < page.getBlockCount(); b++) { - if (page.getBlock(b).mayHaveMultivaluedFields()) { - throw new IllegalArgumentException( - "ES|QL response field [" + columns.get(b).name + "] is multi-valued. This isn't supported yet by the Arrow format" - ); - } - } } + rest.add(new EndResponse(this)); segments = rest.iterator(); } @@ -185,6 +191,9 @@ public void close() {} * @see IPC Streaming Format */ private static class SchemaResponse extends ResponseSegment { + + private static final FieldType LIST_FIELD_TYPE = FieldType.nullable(MinorType.LIST.getType()); + private boolean done = false; SchemaResponse(ArrowResponse response) { @@ -204,7 +213,20 @@ protected void encodeChunk(int sizeHint, RecyclerBytesStreamOutput out) throws I } private Schema arrowSchema() { - return new Schema(response.columns.stream().map(c -> new Field(c.name, c.converter.arrowFieldType(), List.of())).toList()); + return new Schema(response.columns.stream().map(c -> { + var fieldType = c.converter.arrowFieldType(); + if (c.multivalued) { + // A variable-sized list is a vector of offsets and a child vector of values + // See https://arrow.apache.org/docs/format/Columnar.html#variable-size-list-layout + var listType = new FieldType(true, LIST_FIELD_TYPE.getType(), null, fieldType.getMetadata()); + // Value vector is non-nullable (ES|QL multivalues cannot contain nulls). + var valueType = new FieldType(false, fieldType.getType(), fieldType.getDictionary(), null); + // The nested vector is named "$data$", following what the Arrow/Java library does. + return new Field(c.name, listType, List.of(new Field("$data$", valueType, null))); + } else { + return new Field(c.name, fieldType, null); + } + }).toList()); } } @@ -257,7 +279,14 @@ protected void encodeChunk(int sizeHint, RecyclerBytesStreamOutput out) throws I @Override public void write(ArrowBuf buffer) throws IOException { - extraPosition += bufWriters.get(bufIdx++).write(out); + var len = bufWriters.get(bufIdx++).write(out); + // Consistency check + if (len != buffer.writerIndex()) { + throw new IllegalStateException( + "Buffer [" + (bufIdx - 1) + "]: wrote [" + len + "] bytes, but expected [" + buffer.writerIndex() + "]" + ); + } + extraPosition += len; } @Override @@ -277,11 +306,26 @@ public long align() throws IOException { // Create Arrow buffers for each of the blocks in this page for (int b = 0; b < page.getBlockCount(); b++) { - var converter = response.columns.get(b).converter; + var column = response.columns.get(b); + var converter = column.converter; Block block = page.getBlock(b); - nodes.add(new ArrowFieldNode(block.getPositionCount(), converter.nullValuesCount(block))); - converter.convert(block, bufs, bufWriters); + if (column.multivalued) { + // List node. + nodes.add(new ArrowFieldNode(block.getPositionCount(), converter.nullValuesCount(block))); + // Value vector, does not contain nulls. + nodes.add(new ArrowFieldNode(BlockConverter.valueCount(block), 0)); + } else { + nodes.add(new ArrowFieldNode(block.getPositionCount(), converter.nullValuesCount(block))); + } + converter.convert(block, column.multivalued, bufs, bufWriters); + } + + // Consistency check + if (bufs.size() != bufWriters.size()) { + throw new IllegalStateException( + "Inconsistent Arrow buffers: [" + bufs.size() + "] buffers and [" + bufWriters.size() + "] writers" + ); } // Create the batch and serialize it diff --git a/x-pack/plugin/esql/arrow/src/main/java/org/elasticsearch/xpack/esql/arrow/BlockConverter.java b/x-pack/plugin/esql/arrow/src/main/java/org/elasticsearch/xpack/esql/arrow/BlockConverter.java index 0a65792ab8e13..2a305cfdbc503 100644 --- a/x-pack/plugin/esql/arrow/src/main/java/org/elasticsearch/xpack/esql/arrow/BlockConverter.java +++ b/x-pack/plugin/esql/arrow/src/main/java/org/elasticsearch/xpack/esql/arrow/BlockConverter.java @@ -71,10 +71,11 @@ public interface BufWriter { /** * Convert a block into Arrow buffers. * @param block the ESQL block + * @param multivalued is this column multivalued? This block may not, but some blocks in that column are. * @param bufs arrow buffers, used to track sizes * @param bufWriters buffer writers, that will do the actual work of writing the data */ - public abstract void convert(Block block, List bufs, List bufWriters); + public abstract void convert(Block block, boolean multivalued, List bufs, List bufWriters); /** * Conversion of Double blocks @@ -86,28 +87,31 @@ public AsFloat64(String esqlType) { } @Override - public void convert(Block b, List bufs, List bufWriters) { + public void convert(Block b, boolean multivalued, List bufs, List bufWriters) { DoubleBlock block = (DoubleBlock) b; - accumulateVectorValidity(bufs, bufWriters, block); + if (multivalued) { + addListOffsets(bufs, bufWriters, block); + } + accumulateVectorValidity(bufs, bufWriters, block, multivalued); - bufs.add(dummyArrowBuf(vectorLength(block))); + bufs.add(dummyArrowBuf(vectorByteSize(block))); bufWriters.add(out -> { if (block.areAllValuesNull()) { - return BlockConverter.writeZeroes(out, vectorLength(block)); + return BlockConverter.writeZeroes(out, vectorByteSize(block)); } // TODO could we "just" get the memory of the array and dump it? - int count = block.getPositionCount(); + int count = BlockConverter.valueCount(block); for (int i = 0; i < count; i++) { out.writeDoubleLE(block.getDouble(i)); } - return vectorLength(block); + return (long) count * Double.BYTES; }); } - private static int vectorLength(DoubleBlock b) { - return Double.BYTES * b.getPositionCount(); + private static int vectorByteSize(DoubleBlock b) { + return Double.BYTES * BlockConverter.valueCount(b); } } @@ -121,28 +125,31 @@ public AsInt32(String esqlType) { } @Override - public void convert(Block b, List bufs, List bufWriters) { + public void convert(Block b, boolean multivalued, List bufs, List bufWriters) { IntBlock block = (IntBlock) b; - accumulateVectorValidity(bufs, bufWriters, block); + if (multivalued) { + addListOffsets(bufs, bufWriters, block); + } + accumulateVectorValidity(bufs, bufWriters, block, multivalued); - bufs.add(dummyArrowBuf(vectorLength(block))); + bufs.add(dummyArrowBuf(vectorByteSize(block))); bufWriters.add(out -> { if (block.areAllValuesNull()) { - return BlockConverter.writeZeroes(out, vectorLength(block)); + return BlockConverter.writeZeroes(out, vectorByteSize(block)); } // TODO could we "just" get the memory of the array and dump it? - int count = block.getPositionCount(); + int count = BlockConverter.valueCount(block); for (int i = 0; i < count; i++) { out.writeIntLE(block.getInt(i)); } - return vectorLength(block); + return (long) count * Integer.BYTES; }); } - private static int vectorLength(IntBlock b) { - return Integer.BYTES * b.getPositionCount(); + private static int vectorByteSize(Block b) { + return Integer.BYTES * BlockConverter.valueCount(b); } } @@ -159,27 +166,31 @@ protected AsInt64(String esqlType, Types.MinorType minorType) { } @Override - public void convert(Block b, List bufs, List bufWriters) { + public void convert(Block b, boolean multivalued, List bufs, List bufWriters) { LongBlock block = (LongBlock) b; - accumulateVectorValidity(bufs, bufWriters, block); - bufs.add(dummyArrowBuf(vectorLength(block))); + if (multivalued) { + addListOffsets(bufs, bufWriters, block); + } + accumulateVectorValidity(bufs, bufWriters, block, multivalued); + + bufs.add(dummyArrowBuf(vectorByteSize(block))); bufWriters.add(out -> { if (block.areAllValuesNull()) { - return BlockConverter.writeZeroes(out, vectorLength(block)); + return BlockConverter.writeZeroes(out, vectorByteSize(block)); } // TODO could we "just" get the memory of the array and dump it? - int count = block.getPositionCount(); + int count = BlockConverter.valueCount(block); for (int i = 0; i < count; i++) { out.writeLongLE(block.getLong(i)); } - return vectorLength(block); + return (long) count * Long.BYTES; }); } - private static int vectorLength(LongBlock b) { - return Long.BYTES * b.getPositionCount(); + private static int vectorByteSize(LongBlock b) { + return Long.BYTES * BlockConverter.valueCount(b); } } @@ -192,13 +203,17 @@ public AsBoolean(String esqlType) { } @Override - public void convert(Block b, List bufs, List bufWriters) { + public void convert(Block b, boolean multivalued, List bufs, List bufWriters) { BooleanBlock block = (BooleanBlock) b; - accumulateVectorValidity(bufs, bufWriters, block); - bufs.add(dummyArrowBuf(vectorLength(block))); + if (multivalued) { + addListOffsets(bufs, bufWriters, block); + } + accumulateVectorValidity(bufs, bufWriters, block, multivalued); + + bufs.add(dummyArrowBuf(vectorByteSize(block))); bufWriters.add(out -> { - int count = block.getPositionCount(); + int count = BlockConverter.valueCount(block); BitSet bits = new BitSet(); // Only set the bits that are true, writeBitSet will take @@ -215,8 +230,8 @@ public void convert(Block b, List bufs, List bufWriters) { }); } - private static int vectorLength(BooleanBlock b) { - return BlockConverter.bitSetLength(b.getPositionCount()); + private static int vectorByteSize(BooleanBlock b) { + return BlockConverter.bitSetLength(BlockConverter.valueCount(b)); } } @@ -230,27 +245,30 @@ public BytesRefConverter(String esqlType, Types.MinorType minorType) { } @Override - public void convert(Block b, List bufs, List bufWriters) { + public void convert(Block b, boolean multivalued, List bufs, List bufWriters) { BytesRefBlock block = (BytesRefBlock) b; - BlockConverter.accumulateVectorValidity(bufs, bufWriters, block); + if (multivalued) { + addListOffsets(bufs, bufWriters, block); + } + accumulateVectorValidity(bufs, bufWriters, block, multivalued); // Offsets vector - bufs.add(dummyArrowBuf(offsetVectorLength(block))); + bufs.add(dummyArrowBuf(offsetvectorByteSize(block))); bufWriters.add(out -> { if (block.areAllValuesNull()) { - var count = block.getPositionCount() + 1; + var count = valueCount(block) + 1; for (int i = 0; i < count; i++) { out.writeIntLE(0); } - return offsetVectorLength(block); + return offsetvectorByteSize(block); } // TODO could we "just" get the memory of the array and dump it? BytesRef scratch = new BytesRef(); int offset = 0; - for (int i = 0; i < block.getPositionCount(); i++) { + for (int i = 0; i < valueCount(block); i++) { out.writeIntLE(offset); // FIXME: add a ByteRefsVector.getLength(position): there are some cases // where getBytesRef will allocate, which isn't needed here. @@ -259,11 +277,11 @@ public void convert(Block b, List bufs, List bufWriters) { offset += v.length; } out.writeIntLE(offset); - return offsetVectorLength(block); + return offsetvectorByteSize(block); }); // Data vector - bufs.add(BlockConverter.dummyArrowBuf(dataVectorLength(block))); + bufs.add(BlockConverter.dummyArrowBuf(dataVectorByteSize(block))); bufWriters.add(out -> { if (block.areAllValuesNull()) { @@ -273,7 +291,7 @@ public void convert(Block b, List bufs, List bufWriters) { // TODO could we "just" get the memory of the array and dump it? BytesRef scratch = new BytesRef(); long length = 0; - for (int i = 0; i < block.getPositionCount(); i++) { + for (int i = 0; i < valueCount(block); i++) { BytesRef v = block.getBytesRef(i, scratch); out.write(v.bytes, v.offset, v.length); @@ -283,11 +301,11 @@ public void convert(Block b, List bufs, List bufWriters) { }); } - private static int offsetVectorLength(BytesRefBlock block) { - return Integer.BYTES * (block.getPositionCount() + 1); + private static int offsetvectorByteSize(BytesRefBlock block) { + return Integer.BYTES * (valueCount(block) + 1); } - private int dataVectorLength(BytesRefBlock block) { + private int dataVectorByteSize(BytesRefBlock block) { if (block.areAllValuesNull()) { return 0; } @@ -296,7 +314,7 @@ private int dataVectorLength(BytesRefBlock block) { int length = 0; BytesRef scratch = new BytesRef(); - for (int i = 0; i < block.getPositionCount(); i++) { + for (int i = 0; i < valueCount(block); i++) { BytesRef v = block.getBytesRef(i, scratch); length += v.length; } @@ -323,10 +341,10 @@ public TransformedBytesRef(String esqlType, Types.MinorType minorType, BiFunctio } @Override - public void convert(Block b, List bufs, List bufWriters) { + public void convert(Block b, boolean multivalued, List bufs, List bufWriters) { BytesRefBlock block = (BytesRefBlock) b; try (BytesRefBlock transformed = transformValues(block)) { - super.convert(transformed, bufs, bufWriters); + super.convert(transformed, multivalued, bufs, bufWriters); } } @@ -336,20 +354,40 @@ public void convert(Block b, List bufs, List bufWriters) { private BytesRefBlock transformValues(BytesRefBlock block) { try (BytesRefBlock.Builder builder = block.blockFactory().newBytesRefBlockBuilder(block.getPositionCount())) { BytesRef scratch = new BytesRef(); - for (int i = 0; i < block.getPositionCount(); i++) { - if (block.isNull(i)) { - builder.appendNull(); - } else { - BytesRef bytes = block.getBytesRef(i, scratch); - if (bytes.length != 0) { - bytes = valueConverter.apply(bytes, scratch); + if (block.mayHaveMultivaluedFields() == false) { + for (int pos = 0; pos < valueCount(block); pos++) { + if (block.isNull(pos)) { + builder.appendNull(); + } else { + convertAndAppend(builder, block, pos, scratch); + } + } + } else { + for (int pos = 0; pos < block.getPositionCount(); pos++) { + if (block.isNull(pos)) { + builder.appendNull(); + } else { + builder.beginPositionEntry(); + int startPos = block.getFirstValueIndex(pos); + int lastPos = block.getFirstValueIndex(pos + 1); + for (int valuePos = startPos; valuePos < lastPos; valuePos++) { + convertAndAppend(builder, block, valuePos, scratch); + } + builder.endPositionEntry(); } - builder.appendBytesRef(bytes); } } return builder.build(); } } + + private void convertAndAppend(BytesRefBlock.Builder builder, BytesRefBlock block, int position, BytesRef scratch) { + BytesRef bytes = block.getBytesRef(position, scratch); + if (bytes.length != 0) { + bytes = valueConverter.apply(bytes, scratch); + } + builder.appendBytesRef(bytes); + } } public static class AsVarChar extends BytesRefConverter { @@ -370,7 +408,7 @@ public AsNull(String esqlType) { } @Override - public void convert(Block block, List bufs, List bufWriters) { + public void convert(Block block, boolean multivalued, List bufs, List bufWriters) { // Null vector in arrow has no associated buffers // See https://arrow.apache.org/docs/format/Columnar.html#null-layout } @@ -386,15 +424,38 @@ private static int bitSetLength(int totalValues) { return (totalValues + 7) / 8; } - private static void accumulateVectorValidity(List bufs, List bufWriters, Block b) { - bufs.add(dummyArrowBuf(bitSetLength(b.getPositionCount()))); + /** + * Get the value count for a block. For single-valued blocks this is the same as the position count. + * For multivalued blocks, this is the flattened number of items. + */ + static int valueCount(Block block) { + int result = block.getFirstValueIndex(block.getPositionCount()); + + // firstValueIndex is always zero for all-null blocks. + if (result == 0 && block.areAllValuesNull()) { + result = block.getPositionCount(); + } + + return result; + } + + private static void accumulateVectorValidity(List bufs, List bufWriters, Block b, boolean multivalued) { + // If that block is in a multivalued-column, validities are output in the parent Arrow List buffer (values themselves + // do not contain nulls per docvalues limitations). + if (multivalued || b.mayHaveNulls() == false) { + // Arrow IPC allows a compact form for "all true" validities using an empty buffer. + bufs.add(dummyArrowBuf(0)); + bufWriters.add(w -> 0); + return; + } + + int valueCount = b.getPositionCount(); + bufs.add(dummyArrowBuf(bitSetLength(valueCount))); bufWriters.add(out -> { - if (b.mayHaveNulls() == false) { - return writeAllTrueValidity(out, b.getPositionCount()); - } else if (b.areAllValuesNull()) { - return writeAllFalseValidity(out, b.getPositionCount()); + if (b.areAllValuesNull()) { + return writeAllFalseValidity(out, valueCount); } else { - return writeValidities(out, b); + return writeValidities(out, b, valueCount); } }); } @@ -420,10 +481,10 @@ private static long writeAllFalseValidity(RecyclerBytesStreamOutput out, int val return count; } - private static long writeValidities(RecyclerBytesStreamOutput out, Block block) { - int valueCount = block.getPositionCount(); + private static long writeValidities(RecyclerBytesStreamOutput out, Block block, int valueCount) { BitSet bits = new BitSet(valueCount); for (int i = 0; i < block.getPositionCount(); i++) { + // isNull is value indices, not multi-value positions if (block.isNull(i) == false) { bits.set(i); } @@ -449,4 +510,29 @@ private static long writeZeroes(RecyclerBytesStreamOutput out, int byteCount) { } return byteCount; } + + private static void addListOffsets(List bufs, List bufWriters, Block block) { + // Add validity buffer + accumulateVectorValidity(bufs, bufWriters, block, false); + + // Add offsets buffer + int bufferLen = Integer.BYTES * (block.getPositionCount() + 1); + + bufs.add(dummyArrowBuf(bufferLen)); + bufWriters.add(out -> { + if (block.mayHaveMultivaluedFields()) { + // '<=' is intentional to write the end position of the last item + for (int i = 0; i <= block.getPositionCount(); i++) { + // TODO could we get the block's firstValueIndexes and dump it? + out.writeIntLE(block.getFirstValueIndex(i)); + } + } else { + for (int i = 0; i <= block.getPositionCount(); i++) { + out.writeIntLE(i); + } + } + + return bufferLen; + }); + } } diff --git a/x-pack/plugin/esql/arrow/src/test/java/org/elasticsearch/xpack/esql/arrow/ArrowResponseTests.java b/x-pack/plugin/esql/arrow/src/test/java/org/elasticsearch/xpack/esql/arrow/ArrowResponseTests.java index cf49b37db2805..b187e49554f8b 100644 --- a/x-pack/plugin/esql/arrow/src/test/java/org/elasticsearch/xpack/esql/arrow/ArrowResponseTests.java +++ b/x-pack/plugin/esql/arrow/src/test/java/org/elasticsearch/xpack/esql/arrow/ArrowResponseTests.java @@ -19,7 +19,10 @@ import org.apache.arrow.vector.VarBinaryVector; import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.impl.UnionListWriter; import org.apache.arrow.vector.ipc.ArrowStreamReader; +import org.apache.arrow.vector.ipc.ArrowStreamWriter; import org.apache.arrow.vector.util.VectorSchemaRootAppender; import org.apache.lucene.document.InetAddressPoint; import org.apache.lucene.util.BytesRef; @@ -34,7 +37,6 @@ import org.elasticsearch.compute.data.BytesRefBlock; import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.IntBlock; -import org.elasticsearch.compute.data.IntVectorBlock; import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.Page; import org.elasticsearch.test.ESTestCase; @@ -42,6 +44,8 @@ import org.elasticsearch.xpack.versionfield.Version; import org.junit.AfterClass; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.ArrayList; @@ -75,6 +79,7 @@ public static void afterClass() throws Exception { // Value creation, getters for ESQL and Arrow static final ValueType INTEGER_VALUES = new ValueTypeImpl( + "integer", factory -> factory.newIntBlockBuilder(0), block -> block.appendInt(randomInt()), (block, i, scratch) -> block.getInt(i), @@ -82,6 +87,7 @@ public static void afterClass() throws Exception { ); static final ValueType LONG_VALUES = new ValueTypeImpl( + "long", factory -> factory.newLongBlockBuilder(0), block -> block.appendLong(randomLong()), (block, i, scratch) -> block.getLong(i), @@ -89,6 +95,7 @@ public static void afterClass() throws Exception { ); static final ValueType ULONG_VALUES = new ValueTypeImpl( + "ulong", factory -> factory.newLongBlockBuilder(0), block -> block.appendLong(randomLong()), (block, i, scratch) -> block.getLong(i), @@ -96,6 +103,7 @@ public static void afterClass() throws Exception { ); static final ValueType DATE_VALUES = new ValueTypeImpl( + "date", factory -> factory.newLongBlockBuilder(0), block -> block.appendLong(randomLong()), (block, i, scratch) -> block.getLong(i), @@ -103,6 +111,7 @@ public static void afterClass() throws Exception { ); static final ValueType DOUBLE_VALUES = new ValueTypeImpl( + "double", factory -> factory.newDoubleBlockBuilder(0), block -> block.appendDouble(randomDouble()), (block, i, scratch) -> block.getDouble(i), @@ -110,6 +119,7 @@ public static void afterClass() throws Exception { ); static final ValueType BOOLEAN_VALUES = new ValueTypeImpl( + "boolean", factory -> factory.newBooleanBlockBuilder(0), block -> block.appendBoolean(randomBoolean()), (b, i, s) -> b.getBoolean(i), @@ -117,21 +127,23 @@ public static void afterClass() throws Exception { ); static final ValueType TEXT_VALUES = new ValueTypeImpl( + "text", factory -> factory.newBytesRefBlockBuilder(0), - block -> block.appendBytesRef(new BytesRef("🚀" + randomAlphaOfLengthBetween(1, 20))), + block -> block.appendBytesRef(new BytesRef(randomUnicodeOfLengthBetween(1, 20))), (b, i, s) -> b.getBytesRef(i, s).utf8ToString(), (v, i) -> new String(v.get(i), StandardCharsets.UTF_8) ); static final ValueType SOURCE_VALUES = new ValueTypeImpl( + "source", factory -> factory.newBytesRefBlockBuilder(0), - // Use a constant value, conversion is tested separately - block -> block.appendBytesRef(new BytesRef("{\"foo\": 42}")), + block -> block.appendBytesRef(new BytesRef("{\"foo\": " + randomIntBetween(-42, 42) + "}")), (b, i, s) -> b.getBytesRef(i, s).utf8ToString(), (v, i) -> new String(v.get(i), StandardCharsets.UTF_8) ); static final ValueType IP_VALUES = new ValueTypeImpl( + "ip", factory -> factory.newBytesRefBlockBuilder(0), block -> { byte[] addr = InetAddressPoint.encode(randomIp(randomBoolean())); @@ -143,6 +155,7 @@ public static void afterClass() throws Exception { ); static final ValueType BINARY_VALUES = new ValueTypeImpl( + "binary", factory -> factory.newBytesRefBlockBuilder(0), block -> block.appendBytesRef(new BytesRef(randomByteArrayOfLength(randomIntBetween(1, 100)))), BytesRefBlock::getBytesRef, @@ -150,6 +163,7 @@ public static void afterClass() throws Exception { ); static final ValueType VERSION_VALUES = new ValueTypeImpl( + "version", factory -> factory.newBytesRefBlockBuilder(0), block -> block.appendBytesRef(new Version(between(0, 100) + "." + between(0, 100) + "." + between(0, 100)).toBytesRef()), (b, i, s) -> new Version(b.getBytesRef(i, s)).toString(), @@ -157,6 +171,7 @@ public static void afterClass() throws Exception { ); static final ValueType NULL_VALUES = new ValueTypeImpl( + "null", factory -> factory.newBytesRefBlockBuilder(0), Block.Builder::appendNull, (b, i, s) -> b.isNull(i) ? null : "non-null in block", @@ -201,9 +216,10 @@ public void testTestHarness() { TestBlock emptyBlock = TestBlock.create(BLOCK_FACTORY, testColumn, Density.Empty, 7); // Test that density works as expected - assertTrue(denseBlock.block instanceof IntVectorBlock); - assertEquals("IntArrayBlock", sparseBlock.block.getClass().getSimpleName()); // non-public class - assertEquals("ConstantNullBlock", emptyBlock.block.getClass().getSimpleName()); + assertFalse(denseBlock.block.mayHaveNulls()); + assertTrue(sparseBlock.block.mayHaveNulls()); + assertFalse(sparseBlock.block.areAllValuesNull()); + assertTrue(emptyBlock.block.areAllValuesNull()); // Test that values iterator scans all pages List pages = Stream.of(denseBlock, sparseBlock, emptyBlock).map(b -> new TestPage(List.of(b))).toList(); @@ -229,7 +245,7 @@ public void testTestHarness() { */ public void testSingleColumn() throws IOException { for (var type : VALUE_TYPES.keySet()) { - TestColumn testColumn = new TestColumn("foo", type, VALUE_TYPES.get(type)); + TestColumn testColumn = new TestColumn("foo", type, VALUE_TYPES.get(type), false); List pages = new ArrayList<>(); for (var density : Density.values()) { @@ -248,7 +264,7 @@ public void testSingleBlock() throws IOException { String type = "text"; Density density = Density.Dense; - TestColumn testColumn = new TestColumn("foo", type, VALUE_TYPES.get(type)); + TestColumn testColumn = new TestColumn("foo", type, VALUE_TYPES.get(type), false); List pages = new ArrayList<>(); TestBlock testBlock = TestBlock.create(BLOCK_FACTORY, testColumn, density, 10); @@ -261,44 +277,156 @@ public void testSingleBlock() throws IOException { } /** - * Test that multivalued arrays are rejected + * Test a multivalued field with fixed size values. */ - public void testMultivaluedField() throws IOException { + public void testMultivaluedInteger() throws IOException { IntBlock.Builder builder = BLOCK_FACTORY.newIntBlockBuilder(0); + builder.beginPositionEntry(); builder.appendInt(42); + builder.appendInt(43); + builder.endPositionEntry(); + + // The multivalue can be null, but a multivalue cannot contain nulls. + // Calling appendNull within a begin/endEntry causes consistency checks to fail in build() + // See also https://github.com/elastic/elasticsearch/issues/114324 builder.appendNull(); + builder.beginPositionEntry(); builder.appendInt(44); builder.appendInt(45); builder.endPositionEntry(); + + // single value builder.appendInt(46); + IntBlock block = builder.build(); + builder.close(); - // Consistency check + // Consistency check. + // AbstractArrayBlock.assertInvariants does some of these consistency checks, but those below + // specifically verify the assumptions on which the conversion to Arrow is built. assertTrue(block.mayHaveMultivaluedFields()); + assertEquals(4, block.getPositionCount()); // counts null entries + assertEquals(5, block.getTotalValueCount()); // nulls aren't counted + + // Value 0 + assertEquals(2, block.getValueCount(0)); assertEquals(0, block.getFirstValueIndex(0)); - assertEquals(1, block.getValueCount(0)); + assertEquals(42, block.getInt(block.getFirstValueIndex(0))); + assertEquals(43, block.getInt(block.getFirstValueIndex(0) + 1)); - // null values still use one position in the array + // Value 1 assertEquals(0, block.getValueCount(1)); - assertEquals(1, block.getFirstValueIndex(1)); - assertTrue(block.isNull(1)); - assertEquals(0, block.getInt(1)); + assertTrue(block.isNull(1)); // This is the position index, not value index + // No value, but still occupies a value slot with zero + assertEquals(2, block.getFirstValueIndex(1)); + assertEquals(0, block.getInt(block.getFirstValueIndex(1))); + assertEquals(3, block.getFirstValueIndex(2)); - assertEquals(2, block.getFirstValueIndex(2)); + // Value 2 assertEquals(2, block.getValueCount(2)); - assertEquals(2, block.getFirstValueIndex(2)); + assertEquals(3, block.getFirstValueIndex(2)); + assertEquals(44, block.getInt(block.getFirstValueIndex(2))); assertEquals(45, block.getInt(block.getFirstValueIndex(2) + 1)); - assertEquals(4, block.getFirstValueIndex(3)); + // Value 3 + assertEquals(1, block.getValueCount(3)); + assertEquals(5, block.getFirstValueIndex(3)); + assertEquals(46, block.getInt(block.getFirstValueIndex(3))); - var column = TestColumn.create("some-field", "integer"); - TestCase testCase = new TestCase(List.of(column), List.of(new TestPage(List.of(new TestBlock(column, block, Density.Dense))))); + // End of block + assertEquals(6, block.getFirstValueIndex(4)); - IllegalArgumentException exc = assertThrows(IllegalArgumentException.class, () -> compareEsqlAndArrow(testCase)); + var column = TestColumn.create("some-field", "integer", true); + TestCase testCase = new TestCase(List.of(column), List.of(new TestPage(List.of(TestBlock.create(column, block))))); - assertEquals("ES|QL response field [some-field] is multi-valued. This isn't supported yet by the Arrow format", exc.getMessage()); + compareEsqlAndArrow(testCase); + } + + /** + * Test a multivalued field with variable size values. + */ + public void testMultivalueString() throws IOException { + BytesRefBlock.Builder builder = BLOCK_FACTORY.newBytesRefBlockBuilder(0); + + builder.beginPositionEntry(); + builder.appendBytesRef(new BytesRef("a")); + builder.appendBytesRef(new BytesRef("b")); + builder.endPositionEntry(); + builder.beginPositionEntry(); + builder.appendBytesRef(new BytesRef("c")); + builder.appendBytesRef(new BytesRef("d")); + builder.endPositionEntry(); + + BytesRefBlock block = builder.build(); + builder.close(); + + var column = TestColumn.create("some-field", "text"); + TestCase testCase = new TestCase(List.of(column), List.of(new TestPage(List.of(TestBlock.create(column, block))))); + + compareEsqlAndArrow(testCase); + } + + // Test exercising Arrow's multivalue API + public void testMultiValueArrow() throws IOException { + + byte[] bytes; + + try (ListVector listVector = ListVector.empty("some-field", ALLOCATOR)) { + UnionListWriter writer = listVector.getWriter(); + + writer.startList(); + writer.writeInt(42); // 0x2A + writer.writeInt(43); // 0x2A + writer.endList(); + + writer.startList(); + // Size is zero without a writeNull() + writer.writeNull(); // Adds a null value in that list + writer.endList(); + + writer.startList(); + writer.writeInt(44); // 0x2C + writer.writeInt(45); // 0x2D + writer.endList(); + + writer.startList(); + writer.writeInt(46); // 0x2E + writer.endList(); + + listVector.setValueCount(4); + bytes = getBytes(listVector); + } + + try (var reader = new ArrowStreamReader(new ByteArrayInputStream(bytes), ALLOCATOR)) { + var root = reader.getVectorSchemaRoot(); + reader.loadNextBatch(); + + ListVector listVector = (ListVector) root.getVector("some-field"); + + assertEquals(4, listVector.getValueCount()); + assertEquals(List.of(42, 43), listVector.getObject(0)); + assertEquals(Collections.singletonList((Integer) null), listVector.getObject(1)); + assertEquals(List.of(44, 45), listVector.getObject(2)); + assertEquals(List.of(46), listVector.getObject(3)); + } + } + + private static byte[] getBytes(ListVector listVector) throws IOException { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + var fields = List.of(listVector.getField()); + List vectors = List.of(listVector); + + try ( + VectorSchemaRoot root = new VectorSchemaRoot(fields, vectors); + ArrowStreamWriter arrowWriter = new ArrowStreamWriter(root, null, baos); + ) { + arrowWriter.start(); + arrowWriter.writeBatch(); + arrowWriter.end(); + } + return baos.toByteArray(); } /** @@ -319,10 +447,6 @@ public void testRandomTypesAndSize() throws IOException { .toList(); TestCase testCase = new TestCase(columns, pages); - // System.out.println(testCase); - // for (TestPage page: pages) { - // System.out.println(page); - // } compareEsqlAndArrow(testCase); } @@ -347,8 +471,13 @@ private void compareEsqlAndArrow(TestCase testCase, VectorSchemaRoot root) { var esqlValuesIterator = new EsqlValuesIterator(testCase, i); var arrowValuesIterator = new ArrowValuesIterator(testCase, root, i); + int line = 0; + while (esqlValuesIterator.hasNext() && arrowValuesIterator.hasNext()) { - assertEquals(esqlValuesIterator.next(), arrowValuesIterator.next()); + Object esqlValue = esqlValuesIterator.next(); + Object arrowValue = arrowValuesIterator.next(); + assertEquals(("line " + line), esqlValue, arrowValue); + line++; } // Make sure we entirely consumed both sides. @@ -387,7 +516,6 @@ private VectorSchemaRoot toArrowVectors(TestCase testCase) throws IOException { static class EsqlValuesIterator implements Iterator { private final int fieldPos; private final ValueType type; - private final BytesRef scratch = new BytesRef(); private final Iterator pages; private TestPage page; @@ -412,7 +540,7 @@ public Object next() { throw new NoSuchElementException(); } Block block = page.blocks.get(fieldPos).block; - Object result = block.isNull(position) ? null : type.valueAt(block, position, scratch); + Object result = block.isNull(position) ? null : type.valueAt(block, position, new BytesRef()); position++; if (position >= block.getPositionCount()) { position = 0; @@ -475,9 +603,13 @@ public String toString() { } } - record TestColumn(String name, String type, ValueType valueType) { + record TestColumn(String name, String type, ValueType valueType, boolean multivalue) { static TestColumn create(String name, String type) { - return new TestColumn(name, type, VALUE_TYPES.get(type)); + return create(name, type, randomBoolean()); + } + + static TestColumn create(String name, String type, boolean multivalue) { + return new TestColumn(name, type, VALUE_TYPES.get(type), multivalue); } } @@ -498,6 +630,18 @@ public String toString() { record TestBlock(TestColumn column, Block block, Density density) { + static TestBlock create(TestColumn column, Block block) { + Density density; + if (block.areAllValuesNull()) { + density = Density.Empty; + } else if (block.mayHaveNulls()) { + density = Density.Sparse; + } else { + density = Density.Dense; + } + return new TestBlock(column, block, density); + } + static TestBlock create(BlockFactory factory, TestColumn column, int positions) { return create(factory, column, randomFrom(Density.values()), positions); } @@ -517,10 +661,21 @@ static TestBlock create(BlockFactory factory, TestColumn column, Density density start = 2; } for (int i = start; i < positions; i++) { - valueType.addValue(builder, density); + // If multivalued, randomly insert a series of values if the type isn't null (nulls are not allowed in multivalues) + if (column.multivalue && column.valueType != NULL_VALUES && randomBoolean()) { + builder.beginPositionEntry(); + int numEntries = randomIntBetween(2, 5); + for (int j = 0; j < numEntries; j++) { + valueType.addValue(builder, Density.Dense); + } + builder.endPositionEntry(); + } else { + valueType.addValue(builder, density); + } } // Will create an ArrayBlock if there are null values, VectorBlock otherwise block = builder.build(); + assertEquals(positions, block.getPositionCount()); } return new TestBlock(column, block, density); } @@ -553,17 +708,20 @@ interface ValueType { public static class ValueTypeImpl implements ValueType { + private final String name; private final Function builderCreator; private final Consumer valueAdder; private final TriFunction blockGetter; private final BiFunction vectorGetter; public ValueTypeImpl( + String name, Function builderCreator, Consumer valueAdder, TriFunction blockGetter, BiFunction vectorGetter ) { + this.name = name; this.builderCreator = builderCreator; this.valueAdder = valueAdder; this.blockGetter = blockGetter; @@ -588,13 +746,35 @@ public void addValue(Block.Builder builder, Density density) { @Override @SuppressWarnings("unchecked") public Object valueAt(Block block, int position, BytesRef scratch) { - return blockGetter.apply((BlockT) block, position, scratch); + // Build the list of values + var values = new ArrayList<>(); + for (int i = block.getFirstValueIndex(position); i < block.getFirstValueIndex(position + 1); i++) { + values.add(blockGetter.apply((BlockT) block, i, scratch)); + scratch = new BytesRef(); // do not overwrite previous value + } + return values.size() == 1 ? values.getFirst() : values; } @Override @SuppressWarnings("unchecked") public Object valueAt(ValueVector arrowVec, int position) { - return vectorGetter.apply((VectorT) arrowVec, position); + if (arrowVec instanceof ListVector listVector) { + var type = listVector.getField().getMetadata().get("elastic:type"); + // Build the list of values + var valueVec = listVector.getDataVector(); + var values = new ArrayList<>(); + for (int i = listVector.getElementStartIndex(position); i < listVector.getElementEndIndex(position); i++) { + values.add(vectorGetter.apply((VectorT) valueVec, i)); + } + return values.size() == 1 ? values.getFirst() : values; + } else { + return vectorGetter.apply((VectorT) arrowVec, position); + } + } + + @Override + public String toString() { + return name; } } }