From 3fc0cc0cdfcfb9426e04f44e18b85a67b34367c3 Mon Sep 17 00:00:00 2001
From: David Li
Date: Tue, 18 Jun 2019 21:40:16 -0700
Subject: [PATCH] ARROW-5255: [Java] Proof-of-concept of Java extension types

Not quite sure that `ExtensionType` and `ExtensionTypeVector` are correct.
Also, needs a more sophisticated example/more complete unit tests.

Author: David Li

Closes #4251 from lihalite/arrow-java-extension-types and squashes the following commits:

7c740ed76 Proof-of-concept of Java extension types
---
 .../src/main/codegen/templates/ArrowType.java |  60 +++++
 .../arrow/vector/ExtensionTypeVector.java     | 236 ++++++++++++++++++
 .../org/apache/arrow/vector/types/Types.java  |  25 +-
 .../types/pojo/ExtensionTypeRegistry.java     |  42 ++++
 .../apache/arrow/vector/types/pojo/Field.java |  35 ++-
 .../arrow/vector/types/pojo/FieldType.java    |  17 +-
 .../vector/types/pojo/TestExtensionType.java  | 214 ++++++++++++++++
 7 files changed, 619 insertions(+), 10 deletions(-)
 create mode 100644 java/vector/src/main/java/org/apache/arrow/vector/ExtensionTypeVector.java
 create mode 100644 java/vector/src/main/java/org/apache/arrow/vector/types/pojo/ExtensionTypeRegistry.java
 create mode 100644 java/vector/src/test/java/org/apache/arrow/vector/types/pojo/TestExtensionType.java

diff --git a/java/vector/src/main/codegen/templates/ArrowType.java b/java/vector/src/main/codegen/templates/ArrowType.java
index ae343fc5d3d29..a4d32ef345bda 100644
--- a/java/vector/src/main/codegen/templates/ArrowType.java
+++ b/java/vector/src/main/codegen/templates/ArrowType.java
@@ -26,7 +26,9 @@
 import java.util.Objects;
 
 import org.apache.arrow.flatbuf.Type;
+import org.apache.arrow.memory.BufferAllocator;
 import org.apache.arrow.vector.types.*;
+import org.apache.arrow.vector.FieldVector;
 
 import com.fasterxml.jackson.annotation.JsonCreator;
 import com.fasterxml.jackson.annotation.JsonIgnore;
@@ -108,6 +110,9 @@ public static interface ArrowTypeVisitor<T> {
   <#list arrowTypes.types as type>
     T visit(${type.name?remove_ending("_")} type);
   </#list>
+    default T visit(ExtensionType type) {
+      return type.storageType().accept(this);
+    }
   }
 
   /**
@@ -246,6 +251,61 @@ public <T> T accept(ArrowTypeVisitor<T> visitor) {
   }
   </#list>
 
+  /**
+   * A user-defined data type that wraps an underlying storage type.
+   */
+  public abstract static class ExtensionType extends ComplexType {
+    /** The on-wire type for this user-defined type. */
+    public abstract ArrowType storageType();
+    /** The name of this user-defined type. Used to identify the type during serialization. */
+    public abstract String extensionName();
+    /** Check equality of this type to another user-defined type. */
+    public abstract boolean extensionEquals(ExtensionType other);
+    /** Save any metadata for this type. */
+    public abstract String serialize();
+    /** Given saved metadata and the underlying storage type, construct a new instance of the user type. */
+    public abstract ArrowType deserialize(ArrowType storageType, String serializedData);
+    /** Construct a vector for the user type. */
+    public abstract FieldVector getNewVector(String name, FieldType fieldType, BufferAllocator allocator);
+
+    /** The field metadata key storing the name of the extension type. */
+    public static final String EXTENSION_METADATA_KEY_NAME = "ARROW:extension:name";
+    /** The field metadata key storing metadata for the extension type. */
+    public static final String EXTENSION_METADATA_KEY_METADATA = "ARROW:extension:metadata";
+
+    @Override
+    public ArrowTypeID getTypeID() {
+      return storageType().getTypeID();
+    }
+
+    @Override
+    public int getType(FlatBufferBuilder builder) {
+      return storageType().getType(builder);
+    }
+
+    public String toString() {
+      return "ExtensionType(" + extensionName() + ", " + storageType().toString() + ")";
+    }
+
+    @Override
+    public int hashCode() {
+      return java.util.Arrays.deepHashCode(new Object[] {storageType(), extensionName()});
+    }
+
+    @Override
+    public boolean equals(Object obj) {
+      if (!(obj instanceof ExtensionType)) {
+        return false;
+      }
+      return this.extensionEquals((ExtensionType) obj);
+    }
+
+    @Override
+    public <T> T accept(ArrowTypeVisitor<T> visitor) {
+      return visitor.visit(this);
+    }
+  }
+
   public static org.apache.arrow.vector.types.pojo.ArrowType getTypeForField(org.apache.arrow.flatbuf.Field field) {
     switch(field.typeType()) {
     <#list arrowTypes.types as type>
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ExtensionTypeVector.java b/java/vector/src/main/java/org/apache/arrow/vector/ExtensionTypeVector.java
new file mode 100644
index 0000000000000..9594d9e581479
--- /dev/null
+++ b/java/vector/src/main/java/org/apache/arrow/vector/ExtensionTypeVector.java
@@ -0,0 +1,236 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. 
You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.vector;
+
+import java.util.Iterator;
+import java.util.List;
+
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.OutOfMemoryException;
+import org.apache.arrow.vector.complex.reader.FieldReader;
+import org.apache.arrow.vector.ipc.message.ArrowFieldNode;
+import org.apache.arrow.vector.types.Types.MinorType;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.util.CallBack;
+import org.apache.arrow.vector.util.TransferPair;
+
+import io.netty.buffer.ArrowBuf;
+
+/**
+ * A vector that wraps an underlying vector, used to help implement extension types.
+ * @param <T> The wrapped vector type.
+ */
+public abstract class ExtensionTypeVector<T extends BaseValueVector & FieldVector> extends BaseValueVector implements
+    FieldVector {
+
+  private final T underlyingVector;
+
+  public ExtensionTypeVector(String name, BufferAllocator allocator, T underlyingVector) {
+    super(name, allocator);
+    this.underlyingVector = underlyingVector;
+  }
+
+  /** Get the underlying vector. */
+  public T getUnderlyingVector() {
+    return underlyingVector;
+  }
+
+  @Override
+  public void allocateNew() throws OutOfMemoryException {
+    this.underlyingVector.allocateNew();
+  }
+
+  @Override
+  public boolean allocateNewSafe() {
+    return this.underlyingVector.allocateNewSafe();
+  }
+
+  @Override
+  public void reAlloc() {
+    this.underlyingVector.reAlloc();
+  }
+
+  @Override
+  public void setInitialCapacity(int numRecords) {
+    this.underlyingVector.setInitialCapacity(numRecords);
+  }
+
+  @Override
+  public int getValueCapacity() {
+    return this.underlyingVector.getValueCapacity();
+  }
+
+  @Override
+  public void reset() {
+    this.underlyingVector.reset();
+  }
+
+  @Override
+  public Field getField() {
+    return this.underlyingVector.getField();
+  }
+
+  @Override
+  public MinorType getMinorType() {
+    return MinorType.EXTENSIONTYPE;
+  }
+
+  @Override
+  public TransferPair getTransferPair(String ref, BufferAllocator allocator) {
+    return underlyingVector.getTransferPair(ref, allocator);
+  }
+
+  @Override
+  public TransferPair getTransferPair(String ref, BufferAllocator allocator, CallBack callBack) {
+    return underlyingVector.getTransferPair(ref, allocator, callBack);
+  }
+
+  @Override
+  public TransferPair makeTransferPair(ValueVector target) {
+    return underlyingVector.makeTransferPair(target);
+  }
+
+  @Override
+  public FieldReader getReader() {
+    return underlyingVector.getReader();
+  }
+
+  @Override
+  public int getBufferSize() {
+    return underlyingVector.getBufferSize();
+  }
+
+  @Override
+  public int getBufferSizeFor(int valueCount) {
+    return underlyingVector.getBufferSizeFor(valueCount);
+  }
+
+  @Override
+  public ArrowBuf[] getBuffers(boolean clear) {
+    return underlyingVector.getBuffers(clear);
+  }
+
+  @Override
+  public ArrowBuf getValidityBuffer() {
+    return underlyingVector.getValidityBuffer();
+  }
+
+  @Override
+  public ArrowBuf getDataBuffer() {
+    return underlyingVector.getDataBuffer();
+  }
+
+  @Override
+  public ArrowBuf getOffsetBuffer() {
+    return underlyingVector.getOffsetBuffer();
+  }
+
+  @Override
+  public int getValueCount() {
+    return underlyingVector.getValueCount();
+  }
+
+  @Override
+  public void setValueCount(int valueCount) {
+    underlyingVector.setValueCount(valueCount);
+  }
+
+  /**
+   * Get the extension object at the specified index.
+   *
+   * <p>Generally, this should access the underlying vector and construct the corresponding Java object from the raw
+   * data.
+   */
+  @Override
+  public abstract Object getObject(int index);
+
+  @Override
+  public int getNullCount() {
+    return underlyingVector.getNullCount();
+  }
+
+  @Override
+  public boolean isNull(int index) {
+    return underlyingVector.isNull(index);
+  }
+
+  @Override
+  public void initializeChildrenFromFields(List<Field> children) {
+    underlyingVector.initializeChildrenFromFields(children);
+  }
+
+  @Override
+  public List<FieldVector> getChildrenFromFields() {
+    return underlyingVector.getChildrenFromFields();
+  }
+
+  @Override
+  public void loadFieldBuffers(ArrowFieldNode fieldNode, List<ArrowBuf> ownBuffers) {
+    underlyingVector.loadFieldBuffers(fieldNode, ownBuffers);
+  }
+
+  @Override
+  public List<ArrowBuf> getFieldBuffers() {
+    return underlyingVector.getFieldBuffers();
+  }
+
+  @Override
+  public List<BufferBacked> getFieldInnerVectors() {
+    return underlyingVector.getFieldInnerVectors();
+  }
+
+  @Override
+  public long getValidityBufferAddress() {
+    return underlyingVector.getValidityBufferAddress();
+  }
+
+  @Override
+  public long getDataBufferAddress() {
+    return underlyingVector.getDataBufferAddress();
+  }
+
+  @Override
+  public long getOffsetBufferAddress() {
+    return underlyingVector.getOffsetBufferAddress();
+  }
+
+  @Override
+  public void clear() {
+    underlyingVector.clear();
+  }
+
+  @Override
+  public void close() {
+    underlyingVector.close();
+  }
+
+  @Override
+  public TransferPair getTransferPair(BufferAllocator allocator) {
+    return underlyingVector.getTransferPair(allocator);
+  }
+
+  @Override
+  public Iterator<ValueVector> iterator() {
+    return underlyingVector.iterator();
+  }
+
+  @Override
+  public BufferAllocator getAllocator() {
+    return underlyingVector.getAllocator();
+  }
+}
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/types/Types.java b/java/vector/src/main/java/org/apache/arrow/vector/types/Types.java
index 709610476aa75..9eb7d95ab391d 100644
--- 
a/java/vector/src/main/java/org/apache/arrow/vector/types/Types.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/types/Types.java
@@ -28,6 +28,7 @@
 import org.apache.arrow.vector.DateMilliVector;
 import org.apache.arrow.vector.DecimalVector;
 import org.apache.arrow.vector.DurationVector;
+import org.apache.arrow.vector.ExtensionTypeVector;
 import org.apache.arrow.vector.FieldVector;
 import org.apache.arrow.vector.FixedSizeBinaryVector;
 import org.apache.arrow.vector.Float4Vector;
@@ -105,6 +106,7 @@
 import org.apache.arrow.vector.types.pojo.ArrowType.Date;
 import org.apache.arrow.vector.types.pojo.ArrowType.Decimal;
 import org.apache.arrow.vector.types.pojo.ArrowType.Duration;
+import org.apache.arrow.vector.types.pojo.ArrowType.ExtensionType;
 import org.apache.arrow.vector.types.pojo.ArrowType.FixedSizeBinary;
 import org.apache.arrow.vector.types.pojo.ArrowType.FixedSizeList;
 import org.apache.arrow.vector.types.pojo.ArrowType.FloatingPoint;
@@ -710,7 +712,25 @@ public FieldVector getNewVector(
       public FieldWriter getNewFieldWriter(ValueVector vector) {
         return new TimeStampNanoTZWriterImpl((TimeStampNanoTZVector) vector);
       }
-    };
+    },
+    EXTENSIONTYPE(null) {
+      @Override
+      public FieldVector getNewVector(
+          String name,
+          FieldType fieldType,
+          BufferAllocator allocator,
+          CallBack schemaChangeCallback) {
+        return ((ExtensionType) fieldType.getType()).getNewVector(name, fieldType, allocator);
+      }
+
+      @Override
+      public FieldWriter getNewFieldWriter(ValueVector vector) {
+        // Storage writers cast their argument to the storage vector class, so hand them the wrapped vector.
+        final ValueVector storage = ((ExtensionTypeVector) vector).getUnderlyingVector();
+        return storage.getMinorType().getNewFieldWriter(storage);
+      }
+    },
+    ;
 
     private final ArrowType type;
 
@@ -889,6 +909,11 @@ public MinorType visit(Interval type) {
       public MinorType visit(Duration type) {
         return MinorType.DURATION;
       }
+
+      @Override
+      public MinorType visit(ExtensionType type) {
+        return MinorType.EXTENSIONTYPE;
+      }
     });
   }
 
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/ExtensionTypeRegistry.java 
b/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/ExtensionTypeRegistry.java
new file mode 100644
index 0000000000000..f347008b42627
--- /dev/null
+++ b/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/ExtensionTypeRegistry.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.vector.types.pojo;
+
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
+
+import org.apache.arrow.vector.types.pojo.ArrowType.ExtensionType;
+
+/**
+ * A registry of recognized extension types.
+ */
+public final class ExtensionTypeRegistry {
+  private static final ConcurrentMap<String, ExtensionType> registry = new ConcurrentHashMap<>();
+
+  /** Static-only holder; not meant to be instantiated. */
+  private ExtensionTypeRegistry() {}
+
+  /** Register {@code type} under its extension name, replacing any type already stored under that name. */
+  public static void register(ExtensionType type) {
+    registry.put(type.extensionName(), type);
+  }
+
+  /** Remove whatever type is stored under the extension name of {@code type}. */
+  public static void unregister(ExtensionType type) {
+    registry.remove(type.extensionName());
+  }
+
+  /** Return the type registered under {@code name}, or {@code null} if there is none. */
+  public static ExtensionType lookup(String name) {
+    return registry.get(name);
+  }
+}
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/Field.java b/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/Field.java
index 54a69fcf56e0f..99ceb6a0f993a 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/Field.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/Field.java
@@ -36,6 +36,10 @@
 import org.apache.arrow.memory.BufferAllocator;
 import org.apache.arrow.vector.FieldVector;
 import org.apache.arrow.vector.TypeLayout;
+import org.apache.arrow.vector.types.pojo.ArrowType.ExtensionType;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import com.fasterxml.jackson.annotation.JsonCreator;
 import com.fasterxml.jackson.annotation.JsonIgnore;
@@ -49,6 +53,8 @@
  */
 public class Field {
 
+  private static final Logger logger = LoggerFactory.getLogger(Field.class);
+
   public static Field nullablePrimitive(String name, ArrowType.PrimitiveType type) {
     return nullable(name, type);
   }
@@ -111,9 +117,31 @@ public FieldVector createVector(BufferAllocator allocator) {
    * Constructs a new instance from a flatbuffer representation of the field.
    */
   public static Field convertField(org.apache.arrow.flatbuf.Field field) {
+    Map<String, String> metadata = new HashMap<>();
+    for (int i = 0; i < field.customMetadataLength(); i++) {
+      KeyValue kv = field.customMetadata(i);
+      String key = kv.key();
+      String value = kv.value();
+      metadata.put(key == null ? "" : key, value == null ? "" : value);
+    }
+    metadata = Collections.unmodifiableMap(metadata);
+
     String name = field.name();
     boolean nullable = field.nullable();
     ArrowType type = getTypeForField(field);
+
+    if (metadata.containsKey(ExtensionType.EXTENSION_METADATA_KEY_NAME)) {
+      final String extensionName = metadata.get(ExtensionType.EXTENSION_METADATA_KEY_NAME);
+      final String extensionMetadata = metadata.getOrDefault(ExtensionType.EXTENSION_METADATA_KEY_METADATA, "");
+      ExtensionType extensionType = ExtensionTypeRegistry.lookup(extensionName);
+      if (extensionType != null) {
+        type = extensionType.deserialize(type, extensionMetadata);
+      } else {
+        // The type is not registered: keep the storage type and report the name once.
+        logger.info("Unrecognized extension type: {}", extensionName);
+      }
+    }
+
     DictionaryEncoding dictionary = null;
     org.apache.arrow.flatbuf.DictionaryEncoding dictionaryFB = field.dictionary();
     if (dictionaryFB != null) {
@@ -131,14 +159,6 @@ public static Field convertField(org.apache.arrow.flatbuf.Field field) {
       children.add(childField);
     }
     children = Collections.unmodifiableList(children);
-    Map<String, String> metadata = new HashMap<>();
-    for (int i = 0; i < field.customMetadataLength(); i++) {
-      KeyValue kv = field.customMetadata(i);
-      String key = kv.key();
-      String value = kv.value();
-      metadata.put(key == null ? "" : key, value == null ? "" : value);
-    }
-    metadata = Collections.unmodifiableMap(metadata);
     return new Field(name, nullable, type, dictionary, children, metadata);
   }
 
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/FieldType.java b/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/FieldType.java
index cf33e56b25a11..4cc4067c997c3 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/FieldType.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/FieldType.java
@@ -17,6 +17,7 @@
 
 package org.apache.arrow.vector.types.pojo;
 
+import java.util.HashMap;
 import java.util.Map;
 
 import org.apache.arrow.memory.BufferAllocator;
@@ -25,6 +26,7 @@
 import org.apache.arrow.vector.FieldVector;
 import org.apache.arrow.vector.types.Types;
 import org.apache.arrow.vector.types.Types.MinorType;
+import org.apache.arrow.vector.types.pojo.ArrowType.ExtensionType;
 import org.apache.arrow.vector.util.CallBack;
 
 /**
@@ -59,7 +61,20 @@ public FieldType(boolean nullable, ArrowType type, DictionaryEncoding dictionary
     this.nullable = nullable;
     this.type = Preconditions.checkNotNull(type);
     this.dictionary = dictionary;
-    this.metadata = metadata == null ? java.util.Collections.emptyMap() : Collections2.immutableMapCopy(metadata);
+    if (type instanceof ExtensionType) {
+      // Save the extension type name/metadata
+      final Map<String, String> extensionMetadata = new HashMap<>();
+      extensionMetadata.put(ExtensionType.EXTENSION_METADATA_KEY_NAME, ((ExtensionType) type).extensionName());
+      extensionMetadata.put(ExtensionType.EXTENSION_METADATA_KEY_METADATA, ((ExtensionType) type).serialize());
+      if (metadata != null) {
+        for (Map.Entry<String, String> entry : metadata.entrySet()) {
+          extensionMetadata.put(entry.getKey(), entry.getValue());
+        }
+      }
+      this.metadata = Collections2.immutableMapCopy(extensionMetadata);
+    } else {
+      this.metadata = metadata == null ? java.util.Collections.emptyMap() : Collections2.immutableMapCopy(metadata);
+    }
   }
 
   public boolean isNullable() {
diff --git a/java/vector/src/test/java/org/apache/arrow/vector/types/pojo/TestExtensionType.java b/java/vector/src/test/java/org/apache/arrow/vector/types/pojo/TestExtensionType.java
new file mode 100644
index 0000000000000..20d270c8988e6
--- /dev/null
+++ b/java/vector/src/test/java/org/apache/arrow/vector/types/pojo/TestExtensionType.java
@@ -0,0 +1,214 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.vector.types.pojo;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.channels.FileChannel;
+import java.nio.channels.SeekableByteChannel;
+import java.nio.channels.WritableByteChannel;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.nio.file.StandardOpenOption;
+import java.util.Collections;
+import java.util.UUID;
+
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.ExtensionTypeVector;
+import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.FixedSizeBinaryVector;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.ipc.ArrowFileReader;
+import org.apache.arrow.vector.ipc.ArrowFileWriter;
+import org.apache.arrow.vector.types.pojo.ArrowType.ExtensionType;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class TestExtensionType {
+  /**
+   * Test that a custom UUID type can be round-tripped through a temporary file.
+   */
+  @Test
+  public void roundtripUuid() throws IOException {
+    ExtensionTypeRegistry.register(new UuidType());
+    final Schema schema = new Schema(Collections.singletonList(Field.nullable("a", new UuidType())));
+    try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE);
+         final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
+      UUID u1 = UUID.randomUUID();
+      UUID u2 = UUID.randomUUID();
+      UuidVector vector = (UuidVector) root.getVector("a");
+      vector.setValueCount(2);
+      vector.set(0, u1);
+      vector.set(1, u2);
+      root.setRowCount(2);
+
+      final File file = File.createTempFile("uuidtest", ".arrow");
+      try (final WritableByteChannel channel = FileChannel
+          .open(Paths.get(file.getAbsolutePath()), StandardOpenOption.WRITE);
+           final ArrowFileWriter writer = new ArrowFileWriter(root, null, channel)) {
+        writer.start();
+        writer.writeBatch();
+        writer.end();
+      }
+
+      try (final SeekableByteChannel channel = Files.newByteChannel(Paths.get(file.getAbsolutePath()));
+           final ArrowFileReader reader = new ArrowFileReader(channel, allocator)) {
+        reader.loadNextBatch();
+        final VectorSchemaRoot readerRoot = reader.getVectorSchemaRoot();
+        Assert.assertEquals(root.getSchema(), readerRoot.getSchema());
+
+        final Field field = readerRoot.getSchema().getFields().get(0);
+        final UuidType expectedType = new UuidType();
+        Assert.assertEquals(field.getMetadata().get(ExtensionType.EXTENSION_METADATA_KEY_NAME),
+            expectedType.extensionName());
+        Assert.assertEquals(field.getMetadata().get(ExtensionType.EXTENSION_METADATA_KEY_METADATA),
+            expectedType.serialize());
+
+        final ExtensionTypeVector deserialized = (ExtensionTypeVector) readerRoot.getFieldVectors().get(0);
+        Assert.assertEquals(vector.getValueCount(), deserialized.getValueCount());
+        for (int i = 0; i < vector.getValueCount(); i++) {
+          Assert.assertEquals(vector.isNull(i), deserialized.isNull(i));
+          if (!vector.isNull(i)) {
+            Assert.assertEquals(vector.getObject(i), deserialized.getObject(i));
+          }
+        }
+      }
+    }
+  }
+
+  /**
+   * Test that a custom UUID type can be read as its underlying type.
+   */
+  @Test
+  public void readUnderlyingType() throws IOException {
+    ExtensionTypeRegistry.register(new UuidType());
+    final Schema schema = new Schema(Collections.singletonList(Field.nullable("a", new UuidType())));
+    try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE);
+         final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
+      UUID u1 = UUID.randomUUID();
+      UUID u2 = UUID.randomUUID();
+      UuidVector vector = (UuidVector) root.getVector("a");
+      vector.setValueCount(2);
+      vector.set(0, u1);
+      vector.set(1, u2);
+      root.setRowCount(2);
+
+      final File file = File.createTempFile("uuidtest", ".arrow");
+      try (final WritableByteChannel channel = FileChannel
+          .open(Paths.get(file.getAbsolutePath()), StandardOpenOption.WRITE);
+           final ArrowFileWriter writer = new ArrowFileWriter(root, null, channel)) {
+        writer.start();
+        writer.writeBatch();
+        writer.end();
+      }
+
+      ExtensionTypeRegistry.unregister(new UuidType());
+
+      try (final SeekableByteChannel channel = Files.newByteChannel(Paths.get(file.getAbsolutePath()));
+           final ArrowFileReader reader = new ArrowFileReader(channel, allocator)) {
+        reader.loadNextBatch();
+        final VectorSchemaRoot readerRoot = reader.getVectorSchemaRoot();
+        Assert.assertEquals(1, readerRoot.getSchema().getFields().size());
+        Assert.assertEquals("a", readerRoot.getSchema().getFields().get(0).getName());
+        Assert.assertTrue(readerRoot.getSchema().getFields().get(0).getType() instanceof ArrowType.FixedSizeBinary);
+        Assert.assertEquals(16,
+            ((ArrowType.FixedSizeBinary) readerRoot.getSchema().getFields().get(0).getType()).getByteWidth());
+
+        final Field field = readerRoot.getSchema().getFields().get(0);
+        final UuidType expectedType = new UuidType();
+        Assert.assertEquals(field.getMetadata().get(ExtensionType.EXTENSION_METADATA_KEY_NAME),
+            expectedType.extensionName());
+        Assert.assertEquals(field.getMetadata().get(ExtensionType.EXTENSION_METADATA_KEY_METADATA),
+            expectedType.serialize());
+
+        final FixedSizeBinaryVector deserialized = (FixedSizeBinaryVector) readerRoot.getFieldVectors().get(0);
+        Assert.assertEquals(vector.getValueCount(), deserialized.getValueCount());
+        for (int i = 0; i < vector.getValueCount(); i++) {
+          Assert.assertEquals(vector.isNull(i), deserialized.isNull(i));
+          if (!vector.isNull(i)) {
+            final UUID uuid = vector.getObject(i);
+            final ByteBuffer bb = ByteBuffer.allocate(16);
+            bb.putLong(uuid.getMostSignificantBits());
+            bb.putLong(uuid.getLeastSignificantBits());
+            Assert.assertArrayEquals(bb.array(), deserialized.get(i));
+          }
+        }
+      }
+    }
+  }
+
+  static class UuidType extends ExtensionType {
+
+    @Override
+    public ArrowType storageType() {
+      return new ArrowType.FixedSizeBinary(16);
+    }
+
+    @Override
+    public String extensionName() {
+      return "uuid";
+    }
+
+    @Override
+    public boolean extensionEquals(ExtensionType other) {
+      return other instanceof UuidType;
+    }
+
+    @Override
+    public ArrowType deserialize(ArrowType storageType, String serializedData) {
+      if (!storageType.equals(storageType())) {
+        throw new UnsupportedOperationException("Cannot construct UuidType from underlying type " + storageType);
+      }
+      return new UuidType();
+    }
+
+    @Override
+    public String serialize() {
+      return "";
+    }
+
+    @Override
+    public FieldVector getNewVector(String name, FieldType fieldType, BufferAllocator allocator) {
+      return new UuidVector(name, allocator, new FixedSizeBinaryVector(name, allocator, 16));
+    }
+
+  }
+
+  static class UuidVector extends ExtensionTypeVector<FixedSizeBinaryVector> {
+
+    public UuidVector(String name, BufferAllocator allocator, FixedSizeBinaryVector underlyingVector) {
+      super(name, allocator, underlyingVector);
+    }
+
+    @Override
+    public UUID getObject(int index) {
+      final ByteBuffer bb = ByteBuffer.wrap(getUnderlyingVector().getObject(index));
+      return new UUID(bb.getLong(), bb.getLong());
+    }
+
+    public void set(int index, UUID uuid) {
+      ByteBuffer bb = ByteBuffer.allocate(16);
+      bb.putLong(uuid.getMostSignificantBits());
+      bb.putLong(uuid.getLeastSignificantBits());
+      getUnderlyingVector().set(index, bb.array());
+    }
+  }
+}