diff --git a/java/c/src/main/java/org/apache/arrow/c/StructVectorLoader.java b/java/c/src/main/java/org/apache/arrow/c/StructVectorLoader.java index d9afd0189d807..27acf84d30157 100644 --- a/java/c/src/main/java/org/apache/arrow/c/StructVectorLoader.java +++ b/java/c/src/main/java/org/apache/arrow/c/StructVectorLoader.java @@ -90,8 +90,12 @@ public StructVector load(BufferAllocator allocator, ArrowRecordBatch recordBatch .fromCompressionType(recordBatch.getBodyCompression().getCodec()); decompressionNeeded = codecType != CompressionUtil.CodecType.NO_COMPRESSION; CompressionCodec codec = decompressionNeeded ? factory.createCodec(codecType) : NoCompressionCodec.INSTANCE; + Iterator variadicBufferCounts = null; + if (recordBatch.getVariadicBufferCounts() != null && !recordBatch.getVariadicBufferCounts().isEmpty()) { + variadicBufferCounts = recordBatch.getVariadicBufferCounts().iterator(); + } for (FieldVector fieldVector : result.getChildrenFromFields()) { - loadBuffers(fieldVector, fieldVector.getField(), buffers, nodes, codec); + loadBuffers(fieldVector, fieldVector.getField(), buffers, nodes, codec, variadicBufferCounts); } result.loadFieldBuffers(new ArrowFieldNode(recordBatch.getLength(), 0), Collections.singletonList(null)); if (nodes.hasNext() || buffers.hasNext()) { @@ -102,10 +106,15 @@ public StructVector load(BufferAllocator allocator, ArrowRecordBatch recordBatch } private void loadBuffers(FieldVector vector, Field field, Iterator buffers, Iterator nodes, - CompressionCodec codec) { + CompressionCodec codec, Iterator variadicBufferCounts) { checkArgument(nodes.hasNext(), "no more field nodes for field %s and vector %s", field, vector); ArrowFieldNode fieldNode = nodes.next(); - int bufferLayoutCount = TypeLayout.getTypeBufferCount(field.getType()); + // variadicBufferLayoutCount will be 0 for vectors of type except BaseVariableWidthViewVector + long variadicBufferLayoutCount = 0; + if (variadicBufferCounts != null) { + variadicBufferLayoutCount = variadicBufferCounts.next(); + } + int bufferLayoutCount = (int) (variadicBufferLayoutCount + TypeLayout.getTypeBufferCount(field.getType())); List ownBuffers = new ArrayList<>(bufferLayoutCount); for (int j = 0; j < bufferLayoutCount; j++) { ArrowBuf nextBuf = buffers.next(); @@ -138,7 +147,7 @@ private void loadBuffers(FieldVector vector, Field field, Iterator buf for (int i = 0; i < childrenFromFields.size(); i++) { Field child = children.get(i); FieldVector fieldVector = childrenFromFields.get(i); - loadBuffers(fieldVector, child, buffers, nodes, codec); + loadBuffers(fieldVector, child, buffers, nodes, codec, variadicBufferCounts); } } } diff --git a/java/c/src/main/java/org/apache/arrow/c/StructVectorUnloader.java b/java/c/src/main/java/org/apache/arrow/c/StructVectorUnloader.java index aa6d9b4d0f6a7..8d015157ebf38 100644 --- a/java/c/src/main/java/org/apache/arrow/c/StructVectorUnloader.java +++ b/java/c/src/main/java/org/apache/arrow/c/StructVectorUnloader.java @@ -21,6 +21,7 @@ import java.util.List; import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.vector.BaseVariableWidthViewVector; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.TypeLayout; import org.apache.arrow.vector.complex.StructVector; @@ -87,17 +88,28 @@ public StructVectorUnloader(StructVector root, boolean includeNullCount, Compres public ArrowRecordBatch getRecordBatch() { List nodes = new ArrayList<>(); List buffers = new ArrayList<>(); + List variadicBufferCounts = new ArrayList<>(); for (FieldVector vector : root.getChildrenFromFields()) { - appendNodes(vector, nodes, buffers); + appendNodes(vector, nodes, buffers, variadicBufferCounts); } return new ArrowRecordBatch(root.getValueCount(), nodes, buffers, CompressionUtil.createBodyCompression(codec), - alignBuffers); + variadicBufferCounts, alignBuffers); } - private void appendNodes(FieldVector vector, List nodes, List buffers) { + private long getVariadicBufferCount(FieldVector vector) { + if (vector instanceof BaseVariableWidthViewVector) { + return ((BaseVariableWidthViewVector) vector).getDataBuffers().size(); + } + return 0L; + } + + private void appendNodes(FieldVector vector, List nodes, List buffers, + List variadicBufferCounts) { nodes.add(new ArrowFieldNode(vector.getValueCount(), includeNullCount ? vector.getNullCount() : -1)); List fieldBuffers = vector.getFieldBuffers(); - int expectedBufferCount = TypeLayout.getTypeBufferCount(vector.getField().getType()); + long variadicBufferCount = getVariadicBufferCount(vector); + int expectedBufferCount = (int) (TypeLayout.getTypeBufferCount(vector.getField().getType()) + variadicBufferCount); + variadicBufferCounts.add(variadicBufferCount); if (fieldBuffers.size() != expectedBufferCount) { throw new IllegalArgumentException(String.format("wrong number of buffers for field %s in vector %s. found: %s", vector.getField(), vector.getClass().getSimpleName(), fieldBuffers)); @@ -106,7 +118,7 @@ private void appendNodes(FieldVector vector, List nodes, List getDataBuffers() { return dataBuffers; @@ -368,8 +368,21 @@ public List getChildrenFromFields() { */ @Override public void loadFieldBuffers(ArrowFieldNode fieldNode, List ownBuffers) { - // TODO: https://github.com/apache/arrow/issues/40931 - throw new UnsupportedOperationException("loadFieldBuffers is not supported for BaseVariableWidthViewVector"); + ArrowBuf bitBuf = ownBuffers.get(0); + ArrowBuf viewBuf = ownBuffers.get(1); + List dataBufs = ownBuffers.subList(2, ownBuffers.size()); + + this.clear(); + + this.viewBuffer = viewBuf.getReferenceManager().retain(viewBuf, allocator); + this.validityBuffer = BitVectorHelper.loadValidityBuffer(fieldNode, bitBuf, allocator); + + for (ArrowBuf dataBuf : dataBufs) { + this.dataBuffers.add(dataBuf.getReferenceManager().retain(dataBuf, allocator)); + } + + lastSet = fieldNode.getLength() - 1; + valueCount = fieldNode.getLength(); } /** diff --git a/java/vector/src/main/java/org/apache/arrow/vector/TypeLayout.java b/java/vector/src/main/java/org/apache/arrow/vector/TypeLayout.java index ea92efdc55f61..0d01d77632bde 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/TypeLayout.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/TypeLayout.java @@ -28,6 +28,7 @@ import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.ArrowType.ArrowTypeVisitor; import org.apache.arrow.vector.types.pojo.ArrowType.Binary; +import org.apache.arrow.vector.types.pojo.ArrowType.BinaryView; import org.apache.arrow.vector.types.pojo.ArrowType.Bool; import org.apache.arrow.vector.types.pojo.ArrowType.Date; import org.apache.arrow.vector.types.pojo.ArrowType.Decimal; @@ -186,8 +187,7 @@ public TypeLayout visit(Binary type) { @Override public TypeLayout visit(ArrowType.BinaryView type) { - // TODO: https://github.com/apache/arrow/issues/40934 - throw new UnsupportedOperationException("BinaryView not supported"); + return newVariableWidthViewTypeLayout(); } @Override @@ -197,8 +197,7 @@ public TypeLayout visit(Utf8 type) { @Override public TypeLayout visit(Utf8View type) { - // TODO: https://github.com/apache/arrow/issues/40934 - throw new UnsupportedOperationException("Utf8View not supported"); + return newVariableWidthViewTypeLayout(); } @Override @@ -216,7 +215,12 @@ private TypeLayout newVariableWidthTypeLayout() { BufferLayout.byteVector()); } + private TypeLayout newVariableWidthViewTypeLayout() { + return newPrimitiveTypeLayout(BufferLayout.validityVector(), BufferLayout.byteVector()); + } + private TypeLayout newLargeVariableWidthTypeLayout() { + // NOTE: only considers the non variadic buffers return newPrimitiveTypeLayout(BufferLayout.validityVector(), BufferLayout.largeOffsetBuffer(), BufferLayout.byteVector()); } @@ -377,9 +381,9 @@ public Integer visit(Binary type) { } @Override - public Integer visit(ArrowType.BinaryView type) { - // TODO: https://github.com/apache/arrow/issues/40935 - return VARIABLE_WIDTH_BUFFER_COUNT; + public Integer visit(BinaryView type) { + // NOTE: only consider the validity and view buffers + return 2; } @Override @@ -389,8 +393,8 @@ public Integer visit(Utf8 type) { @Override public Integer visit(Utf8View type) { - // TODO: https://github.com/apache/arrow/issues/40935 - return VARIABLE_WIDTH_BUFFER_COUNT; + // NOTE: only consider the validity and view buffers + return 2; } @Override diff --git a/java/vector/src/main/java/org/apache/arrow/vector/VectorLoader.java b/java/vector/src/main/java/org/apache/arrow/vector/VectorLoader.java index 510cef24c7e16..9590e70f46770 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/VectorLoader.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/VectorLoader.java @@ -80,8 +80,13 @@ public void load(ArrowRecordBatch recordBatch) { CompressionUtil.CodecType.fromCompressionType(recordBatch.getBodyCompression().getCodec()); decompressionNeeded = codecType != CompressionUtil.CodecType.NO_COMPRESSION; CompressionCodec codec = decompressionNeeded ? factory.createCodec(codecType) : NoCompressionCodec.INSTANCE; + Iterator variadicBufferCounts = null; + if (recordBatch.getVariadicBufferCounts() != null && !recordBatch.getVariadicBufferCounts().isEmpty()) { + variadicBufferCounts = recordBatch.getVariadicBufferCounts().iterator(); + } + for (FieldVector fieldVector : root.getFieldVectors()) { - loadBuffers(fieldVector, fieldVector.getField(), buffers, nodes, codec); + loadBuffers(fieldVector, fieldVector.getField(), buffers, nodes, codec, variadicBufferCounts); } root.setRowCount(recordBatch.getLength()); if (nodes.hasNext() || buffers.hasNext()) { @@ -95,10 +100,16 @@ private void loadBuffers( Field field, Iterator buffers, Iterator nodes, - CompressionCodec codec) { + CompressionCodec codec, + Iterator variadicBufferCounts) { checkArgument(nodes.hasNext(), "no more field nodes for field %s and vector %s", field, vector); ArrowFieldNode fieldNode = nodes.next(); - int bufferLayoutCount = TypeLayout.getTypeBufferCount(field.getType()); + // variadicBufferLayoutCount will be 0 for vectors of type except BaseVariableWidthViewVector + long variadicBufferLayoutCount = 0; + if (variadicBufferCounts != null) { + variadicBufferLayoutCount = variadicBufferCounts.next(); + } + int bufferLayoutCount = (int) (variadicBufferLayoutCount + TypeLayout.getTypeBufferCount(field.getType())); List ownBuffers = new ArrayList<>(bufferLayoutCount); for (int j = 0; j < bufferLayoutCount; j++) { ArrowBuf nextBuf = buffers.next(); @@ -130,7 +141,7 @@ private void loadBuffers( for (int i = 0; i < childrenFromFields.size(); i++) { Field child = children.get(i); FieldVector fieldVector = childrenFromFields.get(i); - loadBuffers(fieldVector, child, buffers, nodes, codec); + loadBuffers(fieldVector, child, buffers, nodes, codec, variadicBufferCounts); } } } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/VectorUnloader.java b/java/vector/src/main/java/org/apache/arrow/vector/VectorUnloader.java index 1d44e37ac71af..8528099b6d619 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/VectorUnloader.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/VectorUnloader.java @@ -80,19 +80,30 @@ public VectorUnloader( public ArrowRecordBatch getRecordBatch() { List nodes = new ArrayList<>(); List buffers = new ArrayList<>(); + List variadicBufferCounts = new ArrayList<>(); for (FieldVector vector : root.getFieldVectors()) { - appendNodes(vector, nodes, buffers); + appendNodes(vector, nodes, buffers, variadicBufferCounts); } // Do NOT retain buffers in ArrowRecordBatch constructor since we have already retained them. return new ArrowRecordBatch( - root.getRowCount(), nodes, buffers, CompressionUtil.createBodyCompression(codec), alignBuffers, - /*retainBuffers*/ false); + root.getRowCount(), nodes, buffers, CompressionUtil.createBodyCompression(codec), + variadicBufferCounts, alignBuffers, /*retainBuffers*/ false); } - private void appendNodes(FieldVector vector, List nodes, List buffers) { + private long getVariadicBufferCount(FieldVector vector) { + if (vector instanceof BaseVariableWidthViewVector) { + return ((BaseVariableWidthViewVector) vector).getDataBuffers().size(); + } + return 0L; + } + + private void appendNodes(FieldVector vector, List nodes, List buffers, + List variadicBufferCounts) { nodes.add(new ArrowFieldNode(vector.getValueCount(), includeNullCount ? vector.getNullCount() : -1)); List fieldBuffers = vector.getFieldBuffers(); - int expectedBufferCount = TypeLayout.getTypeBufferCount(vector.getField().getType()); + long variadicBufferCount = getVariadicBufferCount(vector); + int expectedBufferCount = (int) (TypeLayout.getTypeBufferCount(vector.getField().getType()) + variadicBufferCount); + variadicBufferCounts.add(variadicBufferCount); if (fieldBuffers.size() != expectedBufferCount) { throw new IllegalArgumentException(String.format( "wrong number of buffers for field %s in vector %s. found: %s", @@ -107,7 +118,7 @@ private void appendNodes(FieldVector vector, List nodes, List vectorTypes = typeLayout.getBufferTypes(); ArrowBuf[] vectorBuffers = new ArrowBuf[vectorTypes.size()]; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileWriter.java index f5e267e81256c..670881b238ecb 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileWriter.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileWriter.java @@ -208,6 +208,7 @@ private void writeBatch(VectorSchemaRoot recordBatch) throws IOException { } private void writeFromVectorIntoJson(Field field, FieldVector vector) throws IOException { + // TODO: https://github.com/apache/arrow/issues/41733 List vectorTypes = TypeLayout.getTypeLayout(field.getType()).getBufferTypes(); List vectorBuffers = vector.getFieldBuffers(); if (vectorTypes.size() != vectorBuffers.size()) { diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/ArrowRecordBatch.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/ArrowRecordBatch.java index f81d049a9257f..b910cfc6ecc25 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/ArrowRecordBatch.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/ArrowRecordBatch.java @@ -56,17 +56,19 @@ public class ArrowRecordBatch implements ArrowMessage { private final List buffersLayout; + private final List variadicBufferCounts; + private boolean closed = false; public ArrowRecordBatch( int length, List nodes, List buffers) { - this(length, nodes, buffers, NoCompressionCodec.DEFAULT_BODY_COMPRESSION, true); + this(length, nodes, buffers, NoCompressionCodec.DEFAULT_BODY_COMPRESSION, null, true); } public ArrowRecordBatch( int length, List nodes, List buffers, ArrowBodyCompression bodyCompression) { - this(length, nodes, buffers, bodyCompression, true); + this(length, nodes, buffers, bodyCompression, null, true); } /** @@ -81,7 +83,7 @@ public ArrowRecordBatch( public ArrowRecordBatch( int length, List nodes, List buffers, ArrowBodyCompression bodyCompression, boolean alignBuffers) { - this(length, nodes, buffers, bodyCompression, alignBuffers, /*retainBuffers*/ true); + this(length, nodes, buffers, bodyCompression, null, alignBuffers, /*retainBuffers*/ true); } /** @@ -98,12 +100,48 @@ public ArrowRecordBatch( public ArrowRecordBatch( int length, List nodes, List buffers, ArrowBodyCompression bodyCompression, boolean alignBuffers, boolean retainBuffers) { + this(length, nodes, buffers, bodyCompression, null, alignBuffers, retainBuffers); + } + + /** + * Construct a record batch from nodes. + * + * @param length how many rows in this batch + * @param nodes field level info + * @param buffers will be retained until this recordBatch is closed + * @param bodyCompression compression info. + * @param variadicBufferCounts the number of buffers in each variadic section. + * @param alignBuffers Whether to align buffers to an 8 byte boundary. + */ + public ArrowRecordBatch( + int length, List nodes, List buffers, + ArrowBodyCompression bodyCompression, List variadicBufferCounts, boolean alignBuffers) { + this(length, nodes, buffers, bodyCompression, variadicBufferCounts, alignBuffers, /*retainBuffers*/ true); + } + + /** + * Construct a record batch from nodes. + * + * @param length how many rows in this batch + * @param nodes field level info + * @param buffers will be retained until this recordBatch is closed + * @param bodyCompression compression info. + * @param variadicBufferCounts the number of buffers in each variadic section. + * @param alignBuffers Whether to align buffers to an 8 byte boundary. + * @param retainBuffers Whether to retain() each source buffer in the constructor. If false, the caller is + * responsible for retaining the buffers beforehand. + */ + public ArrowRecordBatch( + int length, List nodes, List buffers, + ArrowBodyCompression bodyCompression, List variadicBufferCounts, boolean alignBuffers, + boolean retainBuffers) { super(); this.length = length; this.nodes = nodes; this.buffers = buffers; Preconditions.checkArgument(bodyCompression != null, "body compression cannot be null"); this.bodyCompression = bodyCompression; + this.variadicBufferCounts = variadicBufferCounts; List arrowBuffers = new ArrayList<>(buffers.size()); long offset = 0; for (ArrowBuf arrowBuf : buffers) { @@ -129,12 +167,14 @@ public ArrowRecordBatch( // to distinguish this from the public constructor. private ArrowRecordBatch( boolean dummy, int length, List nodes, - List buffers, ArrowBodyCompression bodyCompression) { + List buffers, ArrowBodyCompression bodyCompression, + List variadicBufferCounts) { this.length = length; this.nodes = nodes; this.buffers = buffers; Preconditions.checkArgument(bodyCompression != null, "body compression cannot be null"); this.bodyCompression = bodyCompression; + this.variadicBufferCounts = variadicBufferCounts; this.closed = false; List arrowBuffers = new ArrayList<>(); long offset = 0; @@ -179,6 +219,14 @@ public List getBuffers() { return buffers; } + /** + * Get the record batch variadic buffer counts. + * @return the variadic buffer counts + */ + public List getVariadicBufferCounts() { + return variadicBufferCounts; + } + /** * Create a new ArrowRecordBatch which has the same information as this batch but whose buffers * are owned by that Allocator. @@ -195,7 +243,7 @@ public ArrowRecordBatch cloneWithTransfer(final BufferAllocator allocator) { .writerIndex(buf.writerIndex())) .collect(Collectors.toList()); close(); - return new ArrowRecordBatch(false, length, nodes, newBufs, bodyCompression); + return new ArrowRecordBatch(false, length, nodes, newBufs, bodyCompression, variadicBufferCounts); } /** @@ -217,6 +265,24 @@ public int writeTo(FlatBufferBuilder builder) { if (bodyCompression.getCodec() != NoCompressionCodec.COMPRESSION_TYPE) { compressOffset = bodyCompression.writeTo(builder); } + + // Start the variadicBufferCounts vector. + int variadicBufferCountsOffset = 0; + if (variadicBufferCounts != null && !variadicBufferCounts.isEmpty()) { + variadicBufferCountsOffset = variadicBufferCounts.size(); + int elementSizeInBytes = 8; // Size of long in bytes + builder.startVector(elementSizeInBytes, variadicBufferCountsOffset, elementSizeInBytes); + + // Add each long to the builder. Note that elements should be added in reverse order. + for (int i = variadicBufferCounts.size() - 1; i >= 0; i--) { + long value = variadicBufferCounts.get(i); + builder.addLong(value); + } + + // End the vector. This returns an offset that you can use to refer to the vector. + variadicBufferCountsOffset = builder.endVector(); + } + RecordBatch.startRecordBatch(builder); RecordBatch.addLength(builder, length); RecordBatch.addNodes(builder, nodesOffset); @@ -224,6 +290,12 @@ public int writeTo(FlatBufferBuilder builder) { if (bodyCompression.getCodec() != NoCompressionCodec.COMPRESSION_TYPE) { RecordBatch.addCompression(builder, compressOffset); } + + // Add the variadicBufferCounts to the RecordBatch + if (variadicBufferCounts != null && !variadicBufferCounts.isEmpty()) { + RecordBatch.addVariadicBufferCounts(builder, variadicBufferCountsOffset); + } + return RecordBatch.endRecordBatch(builder); } @@ -247,8 +319,13 @@ public void close() { @Override public String toString() { + int variadicBufCount = 0; + if (variadicBufferCounts != null && !variadicBufferCounts.isEmpty()) { + variadicBufCount = variadicBufferCounts.size(); + } return "ArrowRecordBatch [length=" + length + ", nodes=" + nodes + ", #buffers=" + buffers.size() + - ", buffersLayout=" + buffersLayout + ", closed=" + closed + "]"; + ", #variadicBufferCounts=" + variadicBufCount + ", buffersLayout=" + buffersLayout + + ", closed=" + closed + "]"; } /** diff --git a/java/vector/src/main/java/org/apache/arrow/vector/validate/ValidateVectorBufferVisitor.java b/java/vector/src/main/java/org/apache/arrow/vector/validate/ValidateVectorBufferVisitor.java index 0a67db0455b41..af5a67049f722 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/validate/ValidateVectorBufferVisitor.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/validate/ValidateVectorBufferVisitor.java @@ -51,6 +51,7 @@ private void validateVectorCommon(ValueVector vector) { if (vector instanceof FieldVector) { FieldVector fieldVector = (FieldVector) vector; + // TODO: https://github.com/apache/arrow/issues/41734 int typeBufferCount = TypeLayout.getTypeBufferCount(arrowType); validateOrThrow(fieldVector.getFieldBuffers().size() == typeBufferCount, "Expected %s buffers in vector of type %s, got %s.", diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestTypeLayout.java b/java/vector/src/test/java/org/apache/arrow/vector/TestTypeLayout.java index 97930f433d301..5a58133f2e2bd 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestTypeLayout.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestTypeLayout.java @@ -17,82 +17,158 @@ package org.apache.arrow.vector; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import java.util.Random; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.types.DateUnit; import org.apache.arrow.vector.types.FloatingPointPrecision; import org.apache.arrow.vector.types.IntervalUnit; import org.apache.arrow.vector.types.TimeUnit; import org.apache.arrow.vector.types.UnionMode; import org.apache.arrow.vector.types.pojo.ArrowType; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; public class TestTypeLayout { + private BufferAllocator allocator; + + @BeforeEach + public void prepare() { + allocator = new RootAllocator(Integer.MAX_VALUE); + } + + @AfterEach + public void shutdown() { + allocator.close(); + } + + @Test public void testTypeBufferCount() { ArrowType type = new ArrowType.Int(8, true); - assertEquals(TypeLayout.getTypeBufferCount(type), TypeLayout.getTypeLayout(type).getBufferLayouts().size()); + assertEquals(TypeLayout.getTypeBufferCount(type), + TypeLayout.getTypeLayout(type).getBufferLayouts().size()); type = new ArrowType.Union(UnionMode.Sparse, new int[2]); - assertEquals(TypeLayout.getTypeBufferCount(type), TypeLayout.getTypeLayout(type).getBufferLayouts().size()); + assertEquals(TypeLayout.getTypeBufferCount(type), + TypeLayout.getTypeLayout(type).getBufferLayouts().size()); type = new ArrowType.Union(UnionMode.Dense, new int[1]); - assertEquals(TypeLayout.getTypeBufferCount(type), TypeLayout.getTypeLayout(type).getBufferLayouts().size()); + assertEquals(TypeLayout.getTypeBufferCount(type), + TypeLayout.getTypeLayout(type).getBufferLayouts().size()); type = new ArrowType.Struct(); - assertEquals(TypeLayout.getTypeBufferCount(type), TypeLayout.getTypeLayout(type).getBufferLayouts().size()); + assertEquals(TypeLayout.getTypeBufferCount(type), + TypeLayout.getTypeLayout(type).getBufferLayouts().size()); type = new ArrowType.Timestamp(TimeUnit.MILLISECOND, null); - assertEquals(TypeLayout.getTypeBufferCount(type), TypeLayout.getTypeLayout(type).getBufferLayouts().size()); + assertEquals(TypeLayout.getTypeBufferCount(type), + TypeLayout.getTypeLayout(type).getBufferLayouts().size()); type = new ArrowType.List(); - assertEquals(TypeLayout.getTypeBufferCount(type), TypeLayout.getTypeLayout(type).getBufferLayouts().size()); + assertEquals(TypeLayout.getTypeBufferCount(type), + TypeLayout.getTypeLayout(type).getBufferLayouts().size()); type = new ArrowType.FixedSizeList(5); - assertEquals(TypeLayout.getTypeBufferCount(type), TypeLayout.getTypeLayout(type).getBufferLayouts().size()); + assertEquals(TypeLayout.getTypeBufferCount(type), + TypeLayout.getTypeLayout(type).getBufferLayouts().size()); type = new ArrowType.Map(false); - assertEquals(TypeLayout.getTypeBufferCount(type), TypeLayout.getTypeLayout(type).getBufferLayouts().size()); + assertEquals(TypeLayout.getTypeBufferCount(type), + TypeLayout.getTypeLayout(type).getBufferLayouts().size()); type = new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE); - assertEquals(TypeLayout.getTypeBufferCount(type), TypeLayout.getTypeLayout(type).getBufferLayouts().size()); + assertEquals(TypeLayout.getTypeBufferCount(type), + TypeLayout.getTypeLayout(type).getBufferLayouts().size()); type = new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE); - assertEquals(TypeLayout.getTypeBufferCount(type), TypeLayout.getTypeLayout(type).getBufferLayouts().size()); + assertEquals(TypeLayout.getTypeBufferCount(type), + TypeLayout.getTypeLayout(type).getBufferLayouts().size()); type = new ArrowType.Decimal(10, 10, 128); - assertEquals(TypeLayout.getTypeBufferCount(type), TypeLayout.getTypeLayout(type).getBufferLayouts().size()); + assertEquals(TypeLayout.getTypeBufferCount(type), + TypeLayout.getTypeLayout(type).getBufferLayouts().size()); type = new ArrowType.Decimal(10, 10, 256); - assertEquals(TypeLayout.getTypeBufferCount(type), TypeLayout.getTypeLayout(type).getBufferLayouts().size()); + assertEquals(TypeLayout.getTypeBufferCount(type), + TypeLayout.getTypeLayout(type).getBufferLayouts().size()); type = new ArrowType.FixedSizeBinary(5); - assertEquals(TypeLayout.getTypeBufferCount(type), TypeLayout.getTypeLayout(type).getBufferLayouts().size()); + assertEquals(TypeLayout.getTypeBufferCount(type), + TypeLayout.getTypeLayout(type).getBufferLayouts().size()); type = new ArrowType.Bool(); - assertEquals(TypeLayout.getTypeBufferCount(type), TypeLayout.getTypeLayout(type).getBufferLayouts().size()); + assertEquals(TypeLayout.getTypeBufferCount(type), + TypeLayout.getTypeLayout(type).getBufferLayouts().size()); type = new ArrowType.Binary(); - assertEquals(TypeLayout.getTypeBufferCount(type), TypeLayout.getTypeLayout(type).getBufferLayouts().size()); + assertEquals(TypeLayout.getTypeBufferCount(type), + TypeLayout.getTypeLayout(type).getBufferLayouts().size()); type = new ArrowType.Utf8(); - assertEquals(TypeLayout.getTypeBufferCount(type), TypeLayout.getTypeLayout(type).getBufferLayouts().size()); + assertEquals(TypeLayout.getTypeBufferCount(type), + TypeLayout.getTypeLayout(type).getBufferLayouts().size()); type = new ArrowType.Null(); - assertEquals(TypeLayout.getTypeBufferCount(type), TypeLayout.getTypeLayout(type).getBufferLayouts().size()); + assertEquals(TypeLayout.getTypeBufferCount(type), + TypeLayout.getTypeLayout(type).getBufferLayouts().size()); type = new ArrowType.Date(DateUnit.DAY); - assertEquals(TypeLayout.getTypeBufferCount(type), TypeLayout.getTypeLayout(type).getBufferLayouts().size()); + assertEquals(TypeLayout.getTypeBufferCount(type), + TypeLayout.getTypeLayout(type).getBufferLayouts().size()); type = new ArrowType.Time(TimeUnit.MILLISECOND, 32); - assertEquals(TypeLayout.getTypeBufferCount(type), TypeLayout.getTypeLayout(type).getBufferLayouts().size()); + assertEquals(TypeLayout.getTypeBufferCount(type), + TypeLayout.getTypeLayout(type).getBufferLayouts().size()); type = new ArrowType.Interval(IntervalUnit.DAY_TIME); - assertEquals(TypeLayout.getTypeBufferCount(type), TypeLayout.getTypeLayout(type).getBufferLayouts().size()); + assertEquals(TypeLayout.getTypeBufferCount(type), + TypeLayout.getTypeLayout(type).getBufferLayouts().size()); type = new ArrowType.Duration(TimeUnit.MILLISECOND); - assertEquals(TypeLayout.getTypeBufferCount(type), TypeLayout.getTypeLayout(type).getBufferLayouts().size()); + assertEquals(TypeLayout.getTypeBufferCount(type), + TypeLayout.getTypeLayout(type).getBufferLayouts().size()); + } + + private String generateRandomString(int length) { + Random random = new Random(); + StringBuilder sb = new StringBuilder(length); + for (int i = 0; i < length; i++) { + sb.append(random.nextInt(10)); // 0-9 + } + return sb.toString(); + } + + @Test + public void testTypeBufferCountInVectorsWithVariadicBuffers() { + // empty vector + try (ViewVarCharVector viewVarCharVector = new ViewVarCharVector("myvector", allocator)) { + ArrowType type = viewVarCharVector.getMinorType().getType(); + assertEquals(TypeLayout.getTypeBufferCount(type), + TypeLayout.getTypeLayout(type).getBufferLayouts().size()); + } + // vector with long strings + try (ViewVarCharVector viewVarCharVector = new ViewVarCharVector("myvector", allocator)) { + viewVarCharVector.allocateNew(32, 6); + + viewVarCharVector.setSafe(0, generateRandomString(8).getBytes()); + viewVarCharVector.setSafe(1, generateRandomString(12).getBytes()); + viewVarCharVector.setSafe(2, generateRandomString(14).getBytes()); + viewVarCharVector.setSafe(3, generateRandomString(18).getBytes()); + viewVarCharVector.setSafe(4, generateRandomString(22).getBytes()); + viewVarCharVector.setSafe(5, generateRandomString(24).getBytes()); + + viewVarCharVector.setValueCount(6); + + ArrowType type = viewVarCharVector.getMinorType().getType(); + assertEquals(TypeLayout.getTypeBufferCount(type), + TypeLayout.getTypeLayout(type).getBufferLayouts().size()); + } } } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestVarCharViewVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestVarCharViewVector.java index efb5afac91b13..2d37b0b4eb9ad 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestVarCharViewVector.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestVarCharViewVector.java @@ -31,6 +31,7 @@ import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Objects; @@ -41,8 +42,11 @@ import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.memory.util.ArrowBufPointer; import org.apache.arrow.memory.util.CommonUtil; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; import org.apache.arrow.vector.testing.ValueVectorDataPopulator; import org.apache.arrow.vector.types.Types; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; import org.apache.arrow.vector.util.ReusableByteArray; import org.apache.arrow.vector.util.Text; import org.junit.jupiter.api.AfterEach; @@ -1451,6 +1455,68 @@ public void testSafeOverwriteLongFromALongerLongString() { } } + @Test + public void testVectorLoadUnload() { + + try (final ViewVarCharVector vector1 = new ViewVarCharVector("myvector", allocator)) { + + setVector(vector1, STR1, STR2, STR3, STR4, STR5, STR6); + + assertEquals(5, vector1.getLastSet()); + vector1.setValueCount(15); + assertEquals(14, vector1.getLastSet()); + + /* Check the vector output */ + assertArrayEquals(STR1, vector1.get(0)); + assertArrayEquals(STR2, vector1.get(1)); + assertArrayEquals(STR3, vector1.get(2)); + assertArrayEquals(STR4, vector1.get(3)); + assertArrayEquals(STR5, vector1.get(4)); + assertArrayEquals(STR6, vector1.get(5)); + + Field field = vector1.getField(); + String fieldName = field.getName(); + + List fields = new ArrayList<>(); + List fieldVectors = new ArrayList<>(); + + fields.add(field); + fieldVectors.add(vector1); + + Schema schema = new Schema(fields); + + VectorSchemaRoot schemaRoot1 = new VectorSchemaRoot(schema, fieldVectors, vector1.getValueCount()); + VectorUnloader vectorUnloader = new VectorUnloader(schemaRoot1); + + try ( + ArrowRecordBatch recordBatch = vectorUnloader.getRecordBatch(); + BufferAllocator finalVectorsAllocator = allocator.newChildAllocator("new vector", 0, Long.MAX_VALUE); + VectorSchemaRoot schemaRoot2 = VectorSchemaRoot.create(schema, finalVectorsAllocator); + ) { + + VectorLoader vectorLoader = new VectorLoader(schemaRoot2); + vectorLoader.load(recordBatch); + + ViewVarCharVector vector2 = (ViewVarCharVector) schemaRoot2.getVector(fieldName); + /* + * lastSet would have internally been set by VectorLoader.load() when it invokes + * loadFieldBuffers. + */ + assertEquals(14, vector2.getLastSet()); + vector2.setValueCount(25); + assertEquals(24, vector2.getLastSet()); + + /* Check the vector output */ + assertArrayEquals(STR1, vector2.get(0)); + assertArrayEquals(STR2, vector2.get(1)); + assertArrayEquals(STR3, vector2.get(2)); + assertArrayEquals(STR4, vector2.get(3)); + assertArrayEquals(STR5, vector2.get(4)); + assertArrayEquals(STR6, vector2.get(5)); + } + } + } + private String generateRandomString(int length) { Random random = new Random(); StringBuilder sb = new StringBuilder(length);