From 8ca7033fcd3fcf377cb7924eae9be45b8f6ebd5d Mon Sep 17 00:00:00 2001 From: Wes McKinney Date: Fri, 20 Jan 2017 17:56:23 -0500 Subject: [PATCH] ARROW-499: Update file serialization to use the streaming serialization format. Author: Wes McKinney Author: Nong Li Closes #292 from nongli/file and squashes the following commits: 18890a9 [Wes McKinney] Message fixes. Fix Java test suite. Integration tests pass f187539 [Nong Li] Merge pull request #1 from wesm/file-change-cpp-impl e3af434 [Wes McKinney] Remove unused variable 664d5be [Wes McKinney] Fixes, stream tests pass again ba8db91 [Wes McKinney] Redo MessageSerializer with unions. Still has bugs 21854cc [Wes McKinney] Restore Block.bodyLength to long 7c6f7ef [Nong Li] Update to restore Block behavior 27b3909 [Nong Li] [ARROW-499]: [Java] Update file serialization to use the streaming serialization format. --- cpp/src/arrow/ipc/adapter.cc | 11 +- cpp/src/arrow/ipc/metadata-internal.cc | 21 +-- format/File.fbs | 5 +- integration/integration_test.py | 2 +- .../apache/arrow/vector/file/ArrowFooter.java | 5 +- .../apache/arrow/vector/file/ArrowReader.java | 64 ++----- .../apache/arrow/vector/file/ArrowWriter.java | 43 +---- .../apache/arrow/vector/file/ReadChannel.java | 11 +- .../vector/stream/MessageSerializer.java | 169 +++++++++++------- .../arrow/vector/file/TestArrowFile.java | 4 - .../arrow/vector/file/TestArrowFooter.java | 8 + .../vector/file/TestArrowReaderWriter.java | 16 ++ 12 files changed, 174 insertions(+), 185 deletions(-) diff --git a/cpp/src/arrow/ipc/adapter.cc b/cpp/src/arrow/ipc/adapter.cc index 2b5ef11f861af..7b4d18c267d43 100644 --- a/cpp/src/arrow/ipc/adapter.cc +++ b/cpp/src/arrow/ipc/adapter.cc @@ -129,13 +129,12 @@ class RecordBatchWriter : public ArrayVisitor { num_rows_, body_length, field_nodes_, buffer_meta_, &metadata_fb)); // Need to write 4 bytes (metadata size), the metadata, plus padding to - // fall on a 64-byte offset - int64_t padded_metadata_length = - 
BitUtil::RoundUpToMultipleOf64(metadata_fb->size() + 4); + // fall on an 8-byte offset + int64_t padded_metadata_length = BitUtil::CeilByte(metadata_fb->size() + 4); // The returned metadata size includes the length prefix, the flatbuffer, // plus padding - *metadata_length = padded_metadata_length; + *metadata_length = static_cast(padded_metadata_length); // Write the flatbuffer size prefix int32_t flatbuffer_size = metadata_fb->size(); @@ -604,7 +603,9 @@ Status ReadRecordBatchMetadata(int64_t offset, int32_t metadata_length, return Status::Invalid(ss.str()); } - *metadata = std::make_shared(buffer, sizeof(int32_t)); + std::shared_ptr message; + RETURN_NOT_OK(Message::Open(buffer, 4, &message)); + *metadata = std::make_shared(message); return Status::OK(); } diff --git a/cpp/src/arrow/ipc/metadata-internal.cc b/cpp/src/arrow/ipc/metadata-internal.cc index 16069a8f9dcf0..cc160c42ec9ef 100644 --- a/cpp/src/arrow/ipc/metadata-internal.cc +++ b/cpp/src/arrow/ipc/metadata-internal.cc @@ -320,23 +320,10 @@ Status MessageBuilder::SetRecordBatch(int32_t length, int64_t body_length, Status WriteRecordBatchMetadata(int32_t length, int64_t body_length, const std::vector& nodes, const std::vector& buffers, std::shared_ptr* out) { - flatbuffers::FlatBufferBuilder fbb; - - auto batch = flatbuf::CreateRecordBatch( - fbb, length, fbb.CreateVectorOfStructs(nodes), fbb.CreateVectorOfStructs(buffers)); - - fbb.Finish(batch); - - int32_t size = fbb.GetSize(); - - auto result = std::make_shared(); - RETURN_NOT_OK(result->Resize(size)); - - uint8_t* dst = result->mutable_data(); - memcpy(dst, fbb.GetBufferPointer(), size); - - *out = result; - return Status::OK(); + MessageBuilder builder; + RETURN_NOT_OK(builder.SetRecordBatch(length, body_length, nodes, buffers)); + RETURN_NOT_OK(builder.Finish()); + return builder.GetBuffer(out); } Status MessageBuilder::Finish() { diff --git a/format/File.fbs b/format/File.fbs index f28dc204d58d9..e8d6da4f848ff 100644 --- a/format/File.fbs +++ 
b/format/File.fbs @@ -35,12 +35,15 @@ table Footer { struct Block { + /// Index to the start of the RecordBlock (note this is past the Message header) offset: long; + /// Length of the metadata metaDataLength: int; + /// Length of the data (this is aligned so there can be a gap between this and + /// the metadata). bodyLength: long; - } root_type Footer; diff --git a/integration/integration_test.py b/integration/integration_test.py index 417354bc83d9e..77510daecc0b4 100644 --- a/integration/integration_test.py +++ b/integration/integration_test.py @@ -648,7 +648,7 @@ def get_static_json_files(): def run_all_tests(debug=False): - testers = [JavaTester(debug=debug), CPPTester(debug=debug)] + testers = [CPPTester(debug=debug), JavaTester(debug=debug)] static_json_files = get_static_json_files() generated_json_files = get_generated_json_files() json_files = static_json_files + generated_json_files diff --git a/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowFooter.java b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowFooter.java index 3be19296cb56d..38903068570c7 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowFooter.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowFooter.java @@ -65,10 +65,11 @@ private static List recordBatches(Footer footer) { private static List dictionaries(Footer footer) { List dictionaries = new ArrayList<>(); - Block tempBLock = new Block(); + Block tempBlock = new Block(); + int dictionariesLength = footer.dictionariesLength(); for (int i = 0; i < dictionariesLength; i++) { - Block block = footer.dictionaries(tempBLock, i); + Block block = footer.dictionaries(tempBlock, i); dictionaries.add(new ArrowBlock(block.offset(), block.metaDataLength(), block.bodyLength())); } return dictionaries; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowReader.java b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowReader.java index 
58c51605c5600..8f4f4978d66cf 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowReader.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowReader.java @@ -20,23 +20,15 @@ import java.io.IOException; import java.nio.ByteBuffer; import java.nio.channels.SeekableByteChannel; -import java.util.ArrayList; import java.util.Arrays; -import java.util.List; -import org.apache.arrow.flatbuf.Buffer; -import org.apache.arrow.flatbuf.FieldNode; import org.apache.arrow.flatbuf.Footer; -import org.apache.arrow.flatbuf.RecordBatch; import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.vector.schema.ArrowFieldNode; import org.apache.arrow.vector.schema.ArrowRecordBatch; import org.apache.arrow.vector.stream.MessageSerializer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import io.netty.buffer.ArrowBuf; - public class ArrowReader implements AutoCloseable { private static final Logger LOGGER = LoggerFactory.getLogger(ArrowReader.class); @@ -54,15 +46,6 @@ public ArrowReader(SeekableByteChannel in, BufferAllocator allocator) { this.allocator = allocator; } - private int readFully(ArrowBuf buffer, int l) throws IOException { - int n = readFully(buffer.nioBuffer(buffer.writerIndex(), l)); - buffer.writerIndex(n); - if (n != l) { - throw new IllegalStateException(n + " != " + l); - } - return n; - } - private int readFully(ByteBuffer buffer) throws IOException { int total = 0; int n; @@ -104,46 +87,21 @@ public ArrowFooter readFooter() throws IOException { // TODO: read dictionaries - public ArrowRecordBatch readRecordBatch(ArrowBlock recordBatchBlock) throws IOException { - LOGGER.debug(String.format("RecordBatch at %d, metadata: %d, body: %d", recordBatchBlock.getOffset(), recordBatchBlock.getMetadataLength(), recordBatchBlock.getBodyLength())); - int l = (int)(recordBatchBlock.getMetadataLength() + recordBatchBlock.getBodyLength()); - if (l < 0) { - throw new InvalidArrowFileException("block invalid: " + 
recordBatchBlock); - } - final ArrowBuf buffer = allocator.buffer(l); - LOGGER.debug("allocated buffer " + buffer); - in.position(recordBatchBlock.getOffset()); - int n = readFully(buffer, l); - if (n != l) { - throw new IllegalStateException(n + " != " + l); - } - - // Record batch flatbuffer is prefixed by its size as int32le - final ArrowBuf metadata = buffer.slice(4, recordBatchBlock.getMetadataLength() - 4); - RecordBatch recordBatchFB = RecordBatch.getRootAsRecordBatch(metadata.nioBuffer().asReadOnlyBuffer()); - - int nodesLength = recordBatchFB.nodesLength(); - final ArrowBuf body = buffer.slice(recordBatchBlock.getMetadataLength(), (int)recordBatchBlock.getBodyLength()); - List nodes = new ArrayList<>(); - for (int i = 0; i < nodesLength; ++i) { - FieldNode node = recordBatchFB.nodes(i); - nodes.add(new ArrowFieldNode(node.length(), node.nullCount())); + public ArrowRecordBatch readRecordBatch(ArrowBlock block) throws IOException { + LOGGER.debug(String.format("RecordBatch at %d, metadata: %d, body: %d", + block.getOffset(), block.getMetadataLength(), + block.getBodyLength())); + in.position(block.getOffset()); + ArrowRecordBatch batch = MessageSerializer.deserializeRecordBatch( + new ReadChannel(in, block.getOffset()), block, allocator); + if (batch == null) { + throw new IOException("Invalid file. 
No batch at offset: " + block.getOffset()); } - List buffers = new ArrayList<>(); - for (int i = 0; i < recordBatchFB.buffersLength(); ++i) { - Buffer bufferFB = recordBatchFB.buffers(i); - LOGGER.debug(String.format("Buffer in RecordBatch at %d, length: %d", bufferFB.offset(), bufferFB.length())); - ArrowBuf vectorBuffer = body.slice((int)bufferFB.offset(), (int)bufferFB.length()); - buffers.add(vectorBuffer); - } - ArrowRecordBatch arrowRecordBatch = new ArrowRecordBatch(recordBatchFB.length(), nodes, buffers); - LOGGER.debug("released buffer " + buffer); - buffer.release(); - return arrowRecordBatch; + return batch; } + @Override public void close() throws IOException { in.close(); } - } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowWriter.java index 3febd11f4c76a..24c667e67d98d 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowWriter.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowWriter.java @@ -23,14 +23,12 @@ import java.util.Collections; import java.util.List; -import org.apache.arrow.vector.schema.ArrowBuffer; import org.apache.arrow.vector.schema.ArrowRecordBatch; +import org.apache.arrow.vector.stream.MessageSerializer; import org.apache.arrow.vector.types.pojo.Schema; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import io.netty.buffer.ArrowBuf; - public class ArrowWriter implements AutoCloseable { private static final Logger LOGGER = LoggerFactory.getLogger(ArrowWriter.class); @@ -39,7 +37,6 @@ public class ArrowWriter implements AutoCloseable { private final Schema schema; private final List recordBatches = new ArrayList<>(); - private boolean started = false; public ArrowWriter(WritableByteChannel out, Schema schema) { @@ -49,47 +46,19 @@ public ArrowWriter(WritableByteChannel out, Schema schema) { private void start() throws IOException { writeMagic(); + MessageSerializer.serialize(out, 
schema); } - // TODO: write dictionaries public void writeRecordBatch(ArrowRecordBatch recordBatch) throws IOException { checkStarted(); - out.align(); + ArrowBlock batchDesc = MessageSerializer.serialize(out, recordBatch); + LOGGER.debug(String.format("RecordBatch at %d, metadata: %d, body: %d", + batchDesc.getOffset(), batchDesc.getMetadataLength(), batchDesc.getBodyLength())); - // write metadata header with int32 size prefix - long offset = out.getCurrentPosition(); - out.write(recordBatch, true); - out.align(); - // write body - long bodyOffset = out.getCurrentPosition(); - List buffers = recordBatch.getBuffers(); - List buffersLayout = recordBatch.getBuffersLayout(); - if (buffers.size() != buffersLayout.size()) { - throw new IllegalStateException("the layout does not match: " + buffers.size() + " != " + buffersLayout.size()); - } - for (int i = 0; i < buffers.size(); i++) { - ArrowBuf buffer = buffers.get(i); - ArrowBuffer layout = buffersLayout.get(i); - long startPosition = bodyOffset + layout.getOffset(); - if (startPosition != out.getCurrentPosition()) { - out.writeZeros((int)(startPosition - out.getCurrentPosition())); - } - - out.write(buffer); - if (out.getCurrentPosition() != startPosition + layout.getSize()) { - throw new IllegalStateException("wrong buffer size: " + out.getCurrentPosition() + " != " + startPosition + layout.getSize()); - } - } - int metadataLength = (int)(bodyOffset - offset); - if (metadataLength <= 0) { - throw new InvalidArrowFileException("invalid recordBatch"); - } - long bodyLength = out.getCurrentPosition() - bodyOffset; - LOGGER.debug(String.format("RecordBatch at %d, metadata: %d, body: %d", offset, metadataLength, bodyLength)); // add metadata to footer - recordBatches.add(new ArrowBlock(offset, metadataLength, bodyLength)); + recordBatches.add(batchDesc); } private void checkStarted() throws IOException { diff --git a/java/vector/src/main/java/org/apache/arrow/vector/file/ReadChannel.java 
b/java/vector/src/main/java/org/apache/arrow/vector/file/ReadChannel.java index b062f3826eab3..a9dc1293b8193 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/file/ReadChannel.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/file/ReadChannel.java @@ -32,9 +32,16 @@ public class ReadChannel implements AutoCloseable { private ReadableByteChannel in; private long bytesRead = 0; + // The starting byte offset into 'in'. + private final long startByteOffset; - public ReadChannel(ReadableByteChannel in) { + public ReadChannel(ReadableByteChannel in, long startByteOffset) { this.in = in; + this.startByteOffset = startByteOffset; + } + + public ReadChannel(ReadableByteChannel in) { + this(in, 0); } public long bytesRead() { return bytesRead; } @@ -65,6 +72,8 @@ public int readFully(ArrowBuf buffer, int l) throws IOException { return n; } + public long getCurrentPositiion() { return startByteOffset + bytesRead; } + @Override public void close() throws IOException { if (this.in != null) { diff --git a/java/vector/src/main/java/org/apache/arrow/vector/stream/MessageSerializer.java b/java/vector/src/main/java/org/apache/arrow/vector/stream/MessageSerializer.java index 22c46e2817b1e..6e22dbd164d6e 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/stream/MessageSerializer.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/stream/MessageSerializer.java @@ -29,6 +29,7 @@ import org.apache.arrow.flatbuf.MetadataVersion; import org.apache.arrow.flatbuf.RecordBatch; import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.file.ArrowBlock; import org.apache.arrow.vector.file.ReadChannel; import org.apache.arrow.vector.file.WriteChannel; import org.apache.arrow.vector.schema.ArrowBuffer; @@ -52,7 +53,8 @@ * For RecordBatch messages the serialization is: * 1. 4 byte little endian batch metadata header * 2. FB serialized RowBatch - * 3. serialized RowBatch buffers. + * 3. Padding to align to 8 byte boundary. + * 4. 
serialized RowBatch buffers. */ public class MessageSerializer { @@ -68,14 +70,10 @@ public static int bytesToInt(byte[] bytes) { */ public static long serialize(WriteChannel out, Schema schema) throws IOException { FlatBufferBuilder builder = new FlatBufferBuilder(); - builder.finish(schema.getSchema(builder)); - ByteBuffer serializedBody = builder.dataBuffer(); - ByteBuffer serializedHeader = - serializeHeader(MessageHeader.Schema, serializedBody.remaining()); - - long size = out.writeIntLittleEndian(serializedHeader.remaining()); - size += out.write(serializedHeader); - size += out.write(serializedBody); + int schemaOffset = schema.getSchema(builder); + ByteBuffer serializedMessage = serializeMessage(builder, MessageHeader.Schema, schemaOffset, 0); + long size = out.writeIntLittleEndian(serializedMessage.remaining()); + size += out.write(serializedMessage); return size; } @@ -83,49 +81,51 @@ public static long serialize(WriteChannel out, Schema schema) throws IOException * Deserializes a schema object. Format is from serialize(). */ public static Schema deserializeSchema(ReadChannel in) throws IOException { - Message header = deserializeHeader(in, MessageHeader.Schema); - if (header == null) { + Message message = deserializeMessage(in, MessageHeader.Schema); + if (message == null) { throw new IOException("Unexpected end of input. Missing schema."); } - // Now read the schema. - ByteBuffer buffer = ByteBuffer.allocate((int)header.bodyLength()); - if (in.readFully(buffer) != header.bodyLength()) { - throw new IOException("Unexpected end of input trying to read schema."); - } - buffer.rewind(); - return Schema.deserialize(buffer); + return Schema.convertSchema((org.apache.arrow.flatbuf.Schema) + message.header(new org.apache.arrow.flatbuf.Schema())); } /** - * Serializes an ArrowRecordBatch. + * Serializes an ArrowRecordBatch. Returns the offset and length of the written batch. 
*/ - public static long serialize(WriteChannel out, ArrowRecordBatch batch) + public static ArrowBlock serialize(WriteChannel out, ArrowRecordBatch batch) throws IOException { long start = out.getCurrentPosition(); int bodyLength = batch.computeBodyLength(); - ByteBuffer metadata = WriteChannel.serialize(batch); - ByteBuffer serializedHeader = - serializeHeader(MessageHeader.RecordBatch, bodyLength + metadata.remaining() + 4); + FlatBufferBuilder builder = new FlatBufferBuilder(); + int batchOffset = batch.writeTo(builder); + + ByteBuffer serializedMessage = serializeMessage(builder, MessageHeader.RecordBatch, + batchOffset, bodyLength); + + int metadataLength = serializedMessage.remaining(); + + // Add extra padding bytes so that length prefix + metadata is a multiple + // of 8 after alignment + if ((start + metadataLength + 4) % 8 != 0) { + metadataLength += 8 - (start + metadataLength + 4) % 8; + } - // Write message header. - out.writeIntLittleEndian(serializedHeader.remaining()); - out.write(serializedHeader); + out.writeIntLittleEndian(metadataLength); + out.write(serializedMessage); - // Write the metadata, with the 4 byte little endian prefix - out.writeIntLittleEndian(metadata.remaining()); - out.write(metadata); + // Align the output to 8 byte boundary. + out.align(); - // Write batch header. 
- long offset = out.getCurrentPosition(); + long bufferStart = out.getCurrentPosition(); List buffers = batch.getBuffers(); List buffersLayout = batch.getBuffersLayout(); for (int i = 0; i < buffers.size(); i++) { ArrowBuf buffer = buffers.get(i); ArrowBuffer layout = buffersLayout.get(i); - long startPosition = offset + layout.getOffset(); + long startPosition = bufferStart + layout.getOffset(); if (startPosition != out.getCurrentPosition()) { out.writeZeros((int)(startPosition - out.getCurrentPosition())); } @@ -135,7 +135,8 @@ public static long serialize(WriteChannel out, ArrowRecordBatch batch) " != " + startPosition + layout.getSize()); } } - return out.getCurrentPosition() - start; + // Metadata size in the Block accounts for the size prefix + return new ArrowBlock(start, metadataLength + 4, out.getCurrentPosition() - bufferStart); } /** @@ -143,23 +144,62 @@ public static long serialize(WriteChannel out, ArrowRecordBatch batch) */ public static ArrowRecordBatch deserializeRecordBatch(ReadChannel in, BufferAllocator alloc) throws IOException { - Message header = deserializeHeader(in, MessageHeader.RecordBatch); - if (header == null) return null; + Message message = deserializeMessage(in, MessageHeader.RecordBatch); + if (message == null) return null; + + if (message.bodyLength() > Integer.MAX_VALUE) { + throw new IOException("Cannot currently deserialize record batches over 2GB"); + } + + RecordBatch recordBatchFB = (RecordBatch) message.header(new RecordBatch()); + + int bodyLength = (int) message.bodyLength(); + + // Now read the record batch body + ArrowBuf buffer = alloc.buffer(bodyLength); + if (in.readFully(buffer, bodyLength) != bodyLength) { + throw new IOException("Unexpected end of input trying to read batch."); + } + return deserializeRecordBatch(recordBatchFB, buffer); + } + + /** + * Deserializes a RecordBatch knowing the size of the entire message up front. This + * minimizes the number of reads to the underlying stream. 
+ */ + public static ArrowRecordBatch deserializeRecordBatch(ReadChannel in, ArrowBlock block, + BufferAllocator alloc) throws IOException { + // Metadata length contains integer prefix plus byte padding + long totalLen = block.getMetadataLength() + block.getBodyLength(); - int messageLen = (int)header.bodyLength(); - // Now read the buffer. This has the metadata followed by the data. - ArrowBuf buffer = alloc.buffer(messageLen); - if (in.readFully(buffer, messageLen) != messageLen) { + if (totalLen > Integer.MAX_VALUE) { + throw new IOException("Cannot currently deserialize record batches over 2GB"); + } + + ArrowBuf buffer = alloc.buffer((int) totalLen); + if (in.readFully(buffer, (int) totalLen) != totalLen) { throw new IOException("Unexpected end of input trying to read batch."); } - // Read the metadata. It starts with the 4 byte size of the metadata. - int metadataSize = buffer.readInt(); - RecordBatch recordBatchFB = - RecordBatch.getRootAsRecordBatch( buffer.nioBuffer().asReadOnlyBuffer()); + ArrowBuf metadataBuffer = buffer.slice(4, block.getMetadataLength() - 4); + + Message messageFB = + Message.getRootAsMessage(metadataBuffer.nioBuffer().asReadOnlyBuffer()); + + RecordBatch recordBatchFB = (RecordBatch) messageFB.header(new RecordBatch()); + + // Now read the body + final ArrowBuf body = buffer.slice(block.getMetadataLength(), + (int) totalLen - block.getMetadataLength()); + ArrowRecordBatch result = deserializeRecordBatch(recordBatchFB, body); + + return result; + } - // No read the body - final ArrowBuf body = buffer.slice(4 + metadataSize, messageLen - metadataSize - 4); + // Deserializes a record batch given the Flatbuffer metadata and in-memory body + private static ArrowRecordBatch deserializeRecordBatch(RecordBatch recordBatchFB, + ArrowBuf body) { + // Now read the body int nodesLength = recordBatchFB.nodesLength(); List nodes = new ArrayList<>(); for (int i = 0; i < nodesLength; ++i) { @@ -174,43 +214,44 @@ public static ArrowRecordBatch 
deserializeRecordBatch(ReadChannel in, } ArrowRecordBatch arrowRecordBatch = new ArrowRecordBatch(recordBatchFB.length(), nodes, buffers); - buffer.release(); + body.release(); return arrowRecordBatch; } /** * Serializes a message header. */ - private static ByteBuffer serializeHeader(byte headerType, int bodyLength) { - FlatBufferBuilder headerBuilder = new FlatBufferBuilder(); - Message.startMessage(headerBuilder); - Message.addHeaderType(headerBuilder, headerType); - Message.addVersion(headerBuilder, MetadataVersion.V1); - Message.addBodyLength(headerBuilder, bodyLength); - headerBuilder.finish(Message.endMessage(headerBuilder)); - return headerBuilder.dataBuffer(); + private static ByteBuffer serializeMessage(FlatBufferBuilder builder, byte headerType, + int headerOffset, int bodyLength) { + Message.startMessage(builder); + Message.addHeaderType(builder, headerType); + Message.addHeader(builder, headerOffset); + Message.addVersion(builder, MetadataVersion.V1); + Message.addBodyLength(builder, bodyLength); + builder.finish(Message.endMessage(builder)); + return builder.dataBuffer(); } - private static Message deserializeHeader(ReadChannel in, byte headerType) throws IOException { - // Read the header size. There is an i32 little endian prefix. + private static Message deserializeMessage(ReadChannel in, byte headerType) throws IOException { + // Read the message size. There is an i32 little endian prefix. 
ByteBuffer buffer = ByteBuffer.allocate(4); if (in.readFully(buffer) != 4) { return null; } - int headerLength = bytesToInt(buffer.array()); - buffer = ByteBuffer.allocate(headerLength); - if (in.readFully(buffer) != headerLength) { + int messageLength = bytesToInt(buffer.array()); + buffer = ByteBuffer.allocate(messageLength); + if (in.readFully(buffer) != messageLength) { throw new IOException( - "Unexpected end of stream trying to read header."); + "Unexpected end of stream trying to read message."); } buffer.rewind(); - Message header = Message.getRootAsMessage(buffer); - if (header.headerType() != headerType) { + Message message = Message.getRootAsMessage(buffer); + if (message.headerType() != headerType) { throw new IOException("Invalid message: expecting " + headerType + - ". Message contained: " + header.headerType()); + ". Message contained: " + message.headerType()); } - return header; + return message; } } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/file/TestArrowFile.java b/java/vector/src/test/java/org/apache/arrow/vector/file/TestArrowFile.java index bf635fb39f5b8..9b9914480bad0 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/file/TestArrowFile.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/file/TestArrowFile.java @@ -109,8 +109,6 @@ public void testWriteRead() throws IOException { List recordBatches = footer.getRecordBatches(); for (ArrowBlock rbBlock : recordBatches) { - Assert.assertEquals(0, rbBlock.getOffset() % 8); - Assert.assertEquals(0, rbBlock.getMetadataLength() % 8); try (ArrowRecordBatch recordBatch = arrowReader.readRecordBatch(rbBlock)) { List buffersLayout = recordBatch.getBuffersLayout(); for (ArrowBuffer arrowBuffer : buffersLayout) { @@ -271,8 +269,6 @@ public void testWriteReadMultipleRBs() throws IOException { for (ArrowBlock rbBlock : recordBatches) { Assert.assertTrue(rbBlock.getOffset() + " > " + previousOffset, rbBlock.getOffset() > previousOffset); previousOffset = 
rbBlock.getOffset(); - Assert.assertEquals(0, rbBlock.getOffset() % 8); - Assert.assertEquals(0, rbBlock.getMetadataLength() % 8); try (ArrowRecordBatch recordBatch = arrowReader.readRecordBatch(rbBlock)) { Assert.assertEquals("RB #" + i, counts[i], recordBatch.getLength()); List buffersLayout = recordBatch.getBuffersLayout(); diff --git a/java/vector/src/test/java/org/apache/arrow/vector/file/TestArrowFooter.java b/java/vector/src/test/java/org/apache/arrow/vector/file/TestArrowFooter.java index 707dba2af9898..1e514585e502f 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/file/TestArrowFooter.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/file/TestArrowFooter.java @@ -21,7 +21,9 @@ import static org.junit.Assert.assertEquals; import java.nio.ByteBuffer; +import java.util.ArrayList; import java.util.Collections; +import java.util.List; import org.apache.arrow.flatbuf.Footer; import org.apache.arrow.vector.types.pojo.ArrowType; @@ -41,6 +43,12 @@ public void test() { ArrowFooter footer = new ArrowFooter(schema, Collections.emptyList(), Collections.emptyList()); ArrowFooter newFooter = roundTrip(footer); assertEquals(footer, newFooter); + + List ids = new ArrayList<>(); + ids.add(new ArrowBlock(0, 1, 2)); + ids.add(new ArrowBlock(4, 5, 6)); + footer = new ArrowFooter(schema, ids, ids); + assertEquals(footer, roundTrip(footer)); } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/file/TestArrowReaderWriter.java b/java/vector/src/test/java/org/apache/arrow/vector/file/TestArrowReaderWriter.java index 8ed89fa347b3b..96bcbb1dae71c 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/file/TestArrowReaderWriter.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/file/TestArrowReaderWriter.java @@ -24,10 +24,14 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.nio.ByteBuffer; import java.nio.channels.Channels; import java.util.Collections; import java.util.List; +import 
org.apache.arrow.flatbuf.FieldNode; +import org.apache.arrow.flatbuf.Message; +import org.apache.arrow.flatbuf.RecordBatch; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.schema.ArrowFieldNode; @@ -96,6 +100,18 @@ public void test() throws IOException { assertArrayEquals(validity, array(buffers.get(0))); assertArrayEquals(values, array(buffers.get(1))); + // Read just the header. This demonstrates being able to read without need to + // deserialize the buffer. + ByteBuffer headerBuffer = ByteBuffer.allocate(recordBatches.get(0).getMetadataLength()); + headerBuffer.put(byteArray, (int)recordBatches.get(0).getOffset(), headerBuffer.capacity()); + headerBuffer.position(4); + Message messageFB = Message.getRootAsMessage(headerBuffer); + RecordBatch recordBatchFB = (RecordBatch) messageFB.header(new RecordBatch()); + assertEquals(2, recordBatchFB.buffersLength()); + assertEquals(1, recordBatchFB.nodesLength()); + FieldNode nodeFB = recordBatchFB.nodes(0); + assertEquals(16, nodeFB.length()); + assertEquals(8, nodeFB.nullCount()); } }