extends BaseValueVector implements
+ FieldVector {
+
+ private final T underlyingVector;
+
+ public ExtensionTypeVector(String name, BufferAllocator allocator, T underlyingVector) {
+ super(name, allocator);
+ this.underlyingVector = underlyingVector;
+ }
+
+ /** Get the underlying vector. */
+ public T getUnderlyingVector() {
+ return underlyingVector;
+ }
+
+ @Override
+ public void allocateNew() throws OutOfMemoryException {
+ this.underlyingVector.allocateNew();
+ }
+
+ @Override
+ public boolean allocateNewSafe() {
+ return this.underlyingVector.allocateNewSafe();
+ }
+
+ @Override
+ public void reAlloc() {
+ this.underlyingVector.reAlloc();
+ }
+
+ @Override
+ public void setInitialCapacity(int numRecords) {
+ this.underlyingVector.setInitialCapacity(numRecords);
+ }
+
+ @Override
+ public int getValueCapacity() {
+ return this.underlyingVector.getValueCapacity();
+ }
+
+ @Override
+ public void reset() {
+ this.underlyingVector.reset();
+ }
+
+ @Override
+ public Field getField() {
+ return this.underlyingVector.getField();
+ }
+
+ @Override
+ public MinorType getMinorType() {
+ return MinorType.EXTENSIONTYPE;
+ }
+
+ @Override
+ public TransferPair getTransferPair(String ref, BufferAllocator allocator) {
+ return underlyingVector.getTransferPair(ref, allocator);
+ }
+
+ @Override
+ public TransferPair getTransferPair(String ref, BufferAllocator allocator, CallBack callBack) {
+ return underlyingVector.getTransferPair(ref, allocator, callBack);
+ }
+
+ @Override
+ public TransferPair makeTransferPair(ValueVector target) {
+ return underlyingVector.makeTransferPair(target);
+ }
+
+ @Override
+ public FieldReader getReader() {
+ return underlyingVector.getReader();
+ }
+
+ @Override
+ public int getBufferSize() {
+ return underlyingVector.getBufferSize();
+ }
+
+ @Override
+ public int getBufferSizeFor(int valueCount) {
+ return underlyingVector.getBufferSizeFor(valueCount);
+ }
+
+ @Override
+ public ArrowBuf[] getBuffers(boolean clear) {
+ return underlyingVector.getBuffers(clear);
+ }
+
+ @Override
+ public ArrowBuf getValidityBuffer() {
+ return underlyingVector.getValidityBuffer();
+ }
+
+ @Override
+ public ArrowBuf getDataBuffer() {
+ return underlyingVector.getDataBuffer();
+ }
+
+ @Override
+ public ArrowBuf getOffsetBuffer() {
+ return underlyingVector.getOffsetBuffer();
+ }
+
+ @Override
+ public int getValueCount() {
+ return underlyingVector.getValueCount();
+ }
+
+ @Override
+ public void setValueCount(int valueCount) {
+ underlyingVector.setValueCount(valueCount);
+ }
+
+ /**
+ * Get the extension object at the specified index.
+ *
+ * Generally, this should access the underlying vector and construct the corresponding Java object from the raw
+ * data.
+ */
+ @Override
+ public abstract Object getObject(int index);
+
+ @Override
+ public int getNullCount() {
+ return underlyingVector.getNullCount();
+ }
+
+ @Override
+ public boolean isNull(int index) {
+ return underlyingVector.isNull(index);
+ }
+
+ @Override
+ public void initializeChildrenFromFields(List children) {
+ underlyingVector.initializeChildrenFromFields(children);
+ }
+
+ @Override
+ public List getChildrenFromFields() {
+ return underlyingVector.getChildrenFromFields();
+ }
+
+ @Override
+ public void loadFieldBuffers(ArrowFieldNode fieldNode, List ownBuffers) {
+ underlyingVector.loadFieldBuffers(fieldNode, ownBuffers);
+ }
+
+ @Override
+ public List getFieldBuffers() {
+ return underlyingVector.getFieldBuffers();
+ }
+
+ @Override
+ public List getFieldInnerVectors() {
+ return underlyingVector.getFieldInnerVectors();
+ }
+
+ @Override
+ public long getValidityBufferAddress() {
+ return underlyingVector.getValidityBufferAddress();
+ }
+
+ @Override
+ public long getDataBufferAddress() {
+ return underlyingVector.getDataBufferAddress();
+ }
+
+ @Override
+ public long getOffsetBufferAddress() {
+ return underlyingVector.getOffsetBufferAddress();
+ }
+
+ @Override
+ public void clear() {
+ underlyingVector.clear();
+ }
+
+ @Override
+ public void close() {
+ underlyingVector.close();
+ }
+
+ @Override
+ public TransferPair getTransferPair(BufferAllocator allocator) {
+ return underlyingVector.getTransferPair(allocator);
+ }
+
+ @Override
+ public Iterator iterator() {
+ return underlyingVector.iterator();
+ }
+
+ @Override
+ public BufferAllocator getAllocator() {
+ return underlyingVector.getAllocator();
+ }
+}
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/types/Types.java b/java/vector/src/main/java/org/apache/arrow/vector/types/Types.java
index 709610476aa75..9eb7d95ab391d 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/types/Types.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/types/Types.java
@@ -28,6 +28,7 @@
import org.apache.arrow.vector.DateMilliVector;
import org.apache.arrow.vector.DecimalVector;
import org.apache.arrow.vector.DurationVector;
+import org.apache.arrow.vector.ExtensionTypeVector;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.FixedSizeBinaryVector;
import org.apache.arrow.vector.Float4Vector;
@@ -105,6 +106,7 @@
import org.apache.arrow.vector.types.pojo.ArrowType.Date;
import org.apache.arrow.vector.types.pojo.ArrowType.Decimal;
import org.apache.arrow.vector.types.pojo.ArrowType.Duration;
+import org.apache.arrow.vector.types.pojo.ArrowType.ExtensionType;
import org.apache.arrow.vector.types.pojo.ArrowType.FixedSizeBinary;
import org.apache.arrow.vector.types.pojo.ArrowType.FixedSizeList;
import org.apache.arrow.vector.types.pojo.ArrowType.FloatingPoint;
@@ -710,7 +712,23 @@ public FieldVector getNewVector(
public FieldWriter getNewFieldWriter(ValueVector vector) {
return new TimeStampNanoTZWriterImpl((TimeStampNanoTZVector) vector);
}
- };
+ },
+ EXTENSIONTYPE(null) {
+ @Override
+ public FieldVector getNewVector(
+ String name,
+ FieldType fieldType,
+ BufferAllocator allocator,
+ CallBack schemaChangeCallback) {
+ return ((ExtensionType) fieldType.getType()).getNewVector(name, fieldType, allocator);
+ }
+
+ @Override
+ public FieldWriter getNewFieldWriter(ValueVector vector) {
+ return ((ExtensionTypeVector) vector).getUnderlyingVector().getMinorType().getNewFieldWriter(vector);
+ }
+ },
+ ;
private final ArrowType type;
@@ -889,6 +907,11 @@ public MinorType visit(Interval type) {
public MinorType visit(Duration type) {
return MinorType.DURATION;
}
+
+ @Override
+ public MinorType visit(ExtensionType type) {
+ return MinorType.EXTENSIONTYPE;
+ }
});
}
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/ExtensionTypeRegistry.java b/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/ExtensionTypeRegistry.java
new file mode 100644
index 0000000000000..f347008b42627
--- /dev/null
+++ b/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/ExtensionTypeRegistry.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.vector.types.pojo;
+
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
+
+import org.apache.arrow.vector.types.pojo.ArrowType.ExtensionType;
+
+/**
+ * A registry of recognized extension types.
+ */
+public final class ExtensionTypeRegistry {
+ private static final ConcurrentMap registry = new ConcurrentHashMap<>();
+
+ public static void register(ExtensionType type) {
+ registry.put(type.extensionName(), type);
+ }
+
+ public static void unregister(ExtensionType type) {
+ registry.remove(type.extensionName());
+ }
+
+ public static ExtensionType lookup(String name) {
+ return registry.get(name);
+ }
+}
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/Field.java b/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/Field.java
index 54a69fcf56e0f..99ceb6a0f993a 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/Field.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/Field.java
@@ -36,6 +36,10 @@
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.TypeLayout;
+import org.apache.arrow.vector.types.pojo.ArrowType.ExtensionType;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonIgnore;
@@ -49,6 +53,8 @@
*/
public class Field {
+ private static final Logger logger = LoggerFactory.getLogger(Field.class);
+
public static Field nullablePrimitive(String name, ArrowType.PrimitiveType type) {
return nullable(name, type);
}
@@ -111,9 +117,30 @@ public FieldVector createVector(BufferAllocator allocator) {
* Constructs a new instance from a flatbuffer representation of the field.
*/
public static Field convertField(org.apache.arrow.flatbuf.Field field) {
+ Map metadata = new HashMap<>();
+ for (int i = 0; i < field.customMetadataLength(); i++) {
+ KeyValue kv = field.customMetadata(i);
+ String key = kv.key();
+ String value = kv.value();
+ metadata.put(key == null ? "" : key, value == null ? "" : value);
+ }
+ metadata = Collections.unmodifiableMap(metadata);
+
String name = field.name();
boolean nullable = field.nullable();
ArrowType type = getTypeForField(field);
+
+ if (metadata.containsKey(ExtensionType.EXTENSION_METADATA_KEY_NAME)) {
+ final String extensionName = metadata.get(ExtensionType.EXTENSION_METADATA_KEY_NAME);
+ final String extensionMetadata = metadata.getOrDefault(ExtensionType.EXTENSION_METADATA_KEY_METADATA, "");
+ ExtensionType extensionType = ExtensionTypeRegistry.lookup(extensionName);
+ if (extensionType != null) {
+ type = extensionType.deserialize(type, extensionMetadata);
+ }
+ // Otherwise, we haven't registered the type
+ logger.info("Unrecognized extension type: {}", extensionName);
+ }
+
DictionaryEncoding dictionary = null;
org.apache.arrow.flatbuf.DictionaryEncoding dictionaryFB = field.dictionary();
if (dictionaryFB != null) {
@@ -131,14 +158,6 @@ public static Field convertField(org.apache.arrow.flatbuf.Field field) {
children.add(childField);
}
children = Collections.unmodifiableList(children);
- Map metadata = new HashMap<>();
- for (int i = 0; i < field.customMetadataLength(); i++) {
- KeyValue kv = field.customMetadata(i);
- String key = kv.key();
- String value = kv.value();
- metadata.put(key == null ? "" : key, value == null ? "" : value);
- }
- metadata = Collections.unmodifiableMap(metadata);
return new Field(name, nullable, type, dictionary, children, metadata);
}
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/FieldType.java b/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/FieldType.java
index cf33e56b25a11..4cc4067c997c3 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/FieldType.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/FieldType.java
@@ -17,6 +17,7 @@
package org.apache.arrow.vector.types.pojo;
+import java.util.HashMap;
import java.util.Map;
import org.apache.arrow.memory.BufferAllocator;
@@ -25,6 +26,7 @@
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.types.Types;
import org.apache.arrow.vector.types.Types.MinorType;
+import org.apache.arrow.vector.types.pojo.ArrowType.ExtensionType;
import org.apache.arrow.vector.util.CallBack;
/**
@@ -59,7 +61,20 @@ public FieldType(boolean nullable, ArrowType type, DictionaryEncoding dictionary
this.nullable = nullable;
this.type = Preconditions.checkNotNull(type);
this.dictionary = dictionary;
- this.metadata = metadata == null ? java.util.Collections.emptyMap() : Collections2.immutableMapCopy(metadata);
+ if (type instanceof ExtensionType) {
+ // Save the extension type name/metadata
+ final Map extensionMetadata = new HashMap<>();
+ extensionMetadata.put(ExtensionType.EXTENSION_METADATA_KEY_NAME, ((ExtensionType) type).extensionName());
+ extensionMetadata.put(ExtensionType.EXTENSION_METADATA_KEY_METADATA, ((ExtensionType) type).serialize());
+ if (metadata != null) {
+ for (Map.Entry entry : metadata.entrySet()) {
+ extensionMetadata.put(entry.getKey(), entry.getValue());
+ }
+ }
+ this.metadata = Collections2.immutableMapCopy(extensionMetadata);
+ } else {
+ this.metadata = metadata == null ? java.util.Collections.emptyMap() : Collections2.immutableMapCopy(metadata);
+ }
}
public boolean isNullable() {
diff --git a/java/vector/src/test/java/org/apache/arrow/vector/types/pojo/TestExtensionType.java b/java/vector/src/test/java/org/apache/arrow/vector/types/pojo/TestExtensionType.java
new file mode 100644
index 0000000000000..20d270c8988e6
--- /dev/null
+++ b/java/vector/src/test/java/org/apache/arrow/vector/types/pojo/TestExtensionType.java
@@ -0,0 +1,214 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.vector.types.pojo;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.channels.FileChannel;
+import java.nio.channels.SeekableByteChannel;
+import java.nio.channels.WritableByteChannel;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.nio.file.StandardOpenOption;
+import java.util.Collections;
+import java.util.UUID;
+
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.ExtensionTypeVector;
+import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.FixedSizeBinaryVector;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.ipc.ArrowFileReader;
+import org.apache.arrow.vector.ipc.ArrowFileWriter;
+import org.apache.arrow.vector.types.pojo.ArrowType.ExtensionType;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class TestExtensionType {
+ /**
+ * Test that a custom UUID type can be round-tripped through a temporary file.
+ */
+ @Test
+ public void roundtripUuid() throws IOException {
+ ExtensionTypeRegistry.register(new UuidType());
+ final Schema schema = new Schema(Collections.singletonList(Field.nullable("a", new UuidType())));
+ try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE);
+ final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
+ UUID u1 = UUID.randomUUID();
+ UUID u2 = UUID.randomUUID();
+ UuidVector vector = (UuidVector) root.getVector("a");
+ vector.setValueCount(2);
+ vector.set(0, u1);
+ vector.set(1, u2);
+ root.setRowCount(2);
+
+ final File file = File.createTempFile("uuidtest", ".arrow");
+ try (final WritableByteChannel channel = FileChannel
+ .open(Paths.get(file.getAbsolutePath()), StandardOpenOption.WRITE);
+ final ArrowFileWriter writer = new ArrowFileWriter(root, null, channel)) {
+ writer.start();
+ writer.writeBatch();
+ writer.end();
+ }
+
+ try (final SeekableByteChannel channel = Files.newByteChannel(Paths.get(file.getAbsolutePath()));
+ final ArrowFileReader reader = new ArrowFileReader(channel, allocator)) {
+ reader.loadNextBatch();
+ final VectorSchemaRoot readerRoot = reader.getVectorSchemaRoot();
+ Assert.assertEquals(root.getSchema(), readerRoot.getSchema());
+
+ final Field field = readerRoot.getSchema().getFields().get(0);
+ final UuidType expectedType = new UuidType();
+ Assert.assertEquals(field.getMetadata().get(ExtensionType.EXTENSION_METADATA_KEY_NAME),
+ expectedType.extensionName());
+ Assert.assertEquals(field.getMetadata().get(ExtensionType.EXTENSION_METADATA_KEY_METADATA),
+ expectedType.serialize());
+
+ final ExtensionTypeVector deserialized = (ExtensionTypeVector) readerRoot.getFieldVectors().get(0);
+ Assert.assertEquals(vector.getValueCount(), deserialized.getValueCount());
+ for (int i = 0; i < vector.getValueCount(); i++) {
+ Assert.assertEquals(vector.isNull(i), deserialized.isNull(i));
+ if (!vector.isNull(i)) {
+ Assert.assertEquals(vector.getObject(i), deserialized.getObject(i));
+ }
+ }
+ }
+ }
+ }
+
+ /**
+ * Test that a custom UUID type can be read as its underlying type.
+ */
+ @Test
+ public void readUnderlyingType() throws IOException {
+ ExtensionTypeRegistry.register(new UuidType());
+ final Schema schema = new Schema(Collections.singletonList(Field.nullable("a", new UuidType())));
+ try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE);
+ final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
+ UUID u1 = UUID.randomUUID();
+ UUID u2 = UUID.randomUUID();
+ UuidVector vector = (UuidVector) root.getVector("a");
+ vector.setValueCount(2);
+ vector.set(0, u1);
+ vector.set(1, u2);
+ root.setRowCount(2);
+
+ final File file = File.createTempFile("uuidtest", ".arrow");
+ try (final WritableByteChannel channel = FileChannel
+ .open(Paths.get(file.getAbsolutePath()), StandardOpenOption.WRITE);
+ final ArrowFileWriter writer = new ArrowFileWriter(root, null, channel)) {
+ writer.start();
+ writer.writeBatch();
+ writer.end();
+ }
+
+ ExtensionTypeRegistry.unregister(new UuidType());
+
+ try (final SeekableByteChannel channel = Files.newByteChannel(Paths.get(file.getAbsolutePath()));
+ final ArrowFileReader reader = new ArrowFileReader(channel, allocator)) {
+ reader.loadNextBatch();
+ final VectorSchemaRoot readerRoot = reader.getVectorSchemaRoot();
+ Assert.assertEquals(1, readerRoot.getSchema().getFields().size());
+ Assert.assertEquals("a", readerRoot.getSchema().getFields().get(0).getName());
+ Assert.assertTrue(readerRoot.getSchema().getFields().get(0).getType() instanceof ArrowType.FixedSizeBinary);
+ Assert.assertEquals(16,
+ ((ArrowType.FixedSizeBinary) readerRoot.getSchema().getFields().get(0).getType()).getByteWidth());
+
+ final Field field = readerRoot.getSchema().getFields().get(0);
+ final UuidType expectedType = new UuidType();
+ Assert.assertEquals(field.getMetadata().get(ExtensionType.EXTENSION_METADATA_KEY_NAME),
+ expectedType.extensionName());
+ Assert.assertEquals(field.getMetadata().get(ExtensionType.EXTENSION_METADATA_KEY_METADATA),
+ expectedType.serialize());
+
+ final FixedSizeBinaryVector deserialized = (FixedSizeBinaryVector) readerRoot.getFieldVectors().get(0);
+ Assert.assertEquals(vector.getValueCount(), deserialized.getValueCount());
+ for (int i = 0; i < vector.getValueCount(); i++) {
+ Assert.assertEquals(vector.isNull(i), deserialized.isNull(i));
+ if (!vector.isNull(i)) {
+ final UUID uuid = vector.getObject(i);
+ final ByteBuffer bb = ByteBuffer.allocate(16);
+ bb.putLong(uuid.getMostSignificantBits());
+ bb.putLong(uuid.getLeastSignificantBits());
+ Assert.assertArrayEquals(bb.array(), deserialized.get(i));
+ }
+ }
+ }
+ }
+ }
+
+ static class UuidType extends ExtensionType {
+
+ @Override
+ public ArrowType storageType() {
+ return new ArrowType.FixedSizeBinary(16);
+ }
+
+ @Override
+ public String extensionName() {
+ return "uuid";
+ }
+
+ @Override
+ public boolean extensionEquals(ExtensionType other) {
+ return other instanceof UuidType;
+ }
+
+ @Override
+ public ArrowType deserialize(ArrowType storageType, String serializedData) {
+ if (!storageType.equals(storageType())) {
+ throw new UnsupportedOperationException("Cannot construct UuidType from underlying type " + storageType);
+ }
+ return new UuidType();
+ }
+
+ @Override
+ public String serialize() {
+ return "";
+ }
+
+ @Override
+ public FieldVector getNewVector(String name, FieldType fieldType, BufferAllocator allocator) {
+ return new UuidVector(name, allocator, new FixedSizeBinaryVector(name, allocator, 16));
+ }
+
+ }
+
+ static class UuidVector extends ExtensionTypeVector {
+
+ public UuidVector(String name, BufferAllocator allocator, FixedSizeBinaryVector underlyingVector) {
+ super(name, allocator, underlyingVector);
+ }
+
+ @Override
+ public UUID getObject(int index) {
+ final ByteBuffer bb = ByteBuffer.wrap(getUnderlyingVector().getObject(index));
+ return new UUID(bb.getLong(), bb.getLong());
+ }
+
+ public void set(int index, UUID uuid) {
+ ByteBuffer bb = ByteBuffer.allocate(16);
+ bb.putLong(uuid.getMostSignificantBits());
+ bb.putLong(uuid.getLeastSignificantBits());
+ getUnderlyingVector().set(index, bb.array());
+ }
+ }
+}