From b02283d227878bbb7079a6da3b25d02f99dd8b86 Mon Sep 17 00:00:00 2001 From: Chris Larsen Date: Mon, 23 Oct 2023 13:56:27 -0700 Subject: [PATCH] GH-38414 [Java] [Vector] Add Delta dictionary support. Add a delta encoding flag to the DictionaryEncoding class. Add a BaseDictionary interface (poor name but provides backwards compatibility) that Dictionary implements it and a new BatchedDictionary class that handles writing data to a dictionary and index allowing for flushing in a writer writeBatch call and either replacing or delta encoding the dictionary. --- .../apache/arrow/flight/DictionaryUtils.java | 4 +- .../org/apache/arrow/flight/FlightStream.java | 3 +- .../vector/dictionary/BaseDictionary.java | 44 ++ .../vector/dictionary/BatchedDictionary.java | 270 +++++++++++ .../arrow/vector/dictionary/Dictionary.java | 2 +- .../vector/dictionary/DictionaryEncoder.java | 24 +- .../dictionary/DictionaryHashTable.java | 48 +- .../vector/dictionary/DictionaryProvider.java | 32 +- .../dictionary/StructSubfieldEncoder.java | 10 +- .../arrow/vector/ipc/ArrowFileReader.java | 65 ++- .../arrow/vector/ipc/ArrowFileWriter.java | 16 +- .../apache/arrow/vector/ipc/ArrowReader.java | 7 + .../arrow/vector/ipc/ArrowStreamReader.java | 4 - .../arrow/vector/ipc/ArrowStreamWriter.java | 28 +- .../apache/arrow/vector/ipc/ArrowWriter.java | 14 +- .../arrow/vector/ipc/JsonFileReader.java | 5 + .../arrow/vector/ipc/JsonFileWriter.java | 4 +- .../apache/arrow/vector/table/BaseTable.java | 10 +- .../org/apache/arrow/vector/table/Table.java | 10 +- .../vector/types/pojo/DictionaryEncoding.java | 23 +- .../arrow/vector/util/DictionaryUtility.java | 3 +- .../apache/arrow/vector/util/Validator.java | 10 +- .../arrow/vector/util/VectorAppender.java | 4 +- .../dictionary/TestBatchedDictionary.java | 426 ++++++++++++++++++ .../apache/arrow/vector/ipc/BaseFileTest.java | 311 ++++++++++++- .../arrow/vector/ipc/TestArrowFile.java | 84 +++- .../arrow/vector/ipc/TestArrowStream.java | 41 +- 27 files changed, 1394 insertions(+), 108 deletions(-) create mode 100644 java/vector/src/main/java/org/apache/arrow/vector/dictionary/BaseDictionary.java create mode 100644 java/vector/src/main/java/org/apache/arrow/vector/dictionary/BatchedDictionary.java create mode 100644 java/vector/src/test/java/org/apache/arrow/vector/dictionary/TestBatchedDictionary.java diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/DictionaryUtils.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/DictionaryUtils.java index 516dab01d8a1b..9970e44a39759 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/DictionaryUtils.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/DictionaryUtils.java @@ -30,7 +30,7 @@ import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.VectorUnloader; -import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.BaseDictionary; import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch; import org.apache.arrow.vector.ipc.message.IpcOption; @@ -66,7 +66,7 @@ static Schema generateSchemaMessages(final Schema originalSchema, final FlightDe } // Create and write dictionary batches for (Long id : dictionaryIds) { - final Dictionary dictionary = provider.lookup(id); + final BaseDictionary dictionary = provider.lookup(id); final FieldVector vector = dictionary.getVector(); final int count = vector.getValueCount(); // Do NOT close this root, as it does not actually own the vector. diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightStream.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightStream.java index ad4ffcbebdec1..f33cfde4ffd04 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightStream.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightStream.java @@ -37,6 +37,7 @@ import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.VectorLoader; import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.dictionary.BaseDictionary; import org.apache.arrow.vector.dictionary.Dictionary; import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch; @@ -268,7 +269,7 @@ public boolean next() { if (dictionaries == null) { throw new IllegalStateException("Dictionary ownership was claimed by the application."); } - final Dictionary dictionary = dictionaries.lookup(id); + final BaseDictionary dictionary = dictionaries.lookup(id); if (dictionary == null) { throw new IllegalArgumentException("Dictionary not defined in schema: ID " + id); } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/BaseDictionary.java b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/BaseDictionary.java new file mode 100644 index 0000000000000..0d785bff66b60 --- /dev/null +++ b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/BaseDictionary.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.vector.dictionary; + +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; + +/** + * Interface for all dictionary types. + */ +public interface BaseDictionary { + + /** + * The dictionary vector containing unique entries. + */ + FieldVector getVector(); + + /** + * The encoding used for the dictionary vector. + */ + DictionaryEncoding getEncoding(); + + /** + * The type of the dictionary vector. + */ + ArrowType getVectorType(); + +} diff --git a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/BatchedDictionary.java b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/BatchedDictionary.java new file mode 100644 index 0000000000000..59f54698f543a --- /dev/null +++ b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/BatchedDictionary.java @@ -0,0 +1,270 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.vector.dictionary; + +import java.io.Closeable; +import java.io.IOException; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.util.hash.MurmurHasher; +import org.apache.arrow.util.VisibleForTesting; +import org.apache.arrow.vector.BaseIntVector; +import org.apache.arrow.vector.BaseVariableWidthVector; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.ipc.ArrowWriter; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; +import org.apache.arrow.vector.types.pojo.FieldType; + +/** + * A dictionary implementation that can be used when writing batches of data to + * a stream or file. Supports delta or replacement encoding. + */ +public class BatchedDictionary implements Closeable, BaseDictionary { + + private final DictionaryEncoding encoding; + + private final BaseVariableWidthVector dictionary; + + private final BaseIntVector indexVector; + + private final DictionaryHashTable hashTable; + + private final boolean forFileIPC; + + private int deltaIndex; + + private int dictionaryIndex; + + private boolean wasReset; + + /** + * Creates a dictionary with two vectors of the given types. The dictionary vector + * will be named "{name}-dictionary". + *

+ * To use this dictionary, provide the dictionary vector to a {@link DictionaryProvider}, + * add the {@link #getIndexVector()} to the {@link org.apache.arrow.vector.VectorSchemaRoot} + * and call the {@link #setSafe(int, byte[], int, int)} or other set methods. + * + * @param name A name for the vector and dictionary. + * @param encoding The dictionary encoding to use. + * @param dictionaryType The type of the dictionary data. + * @param indexType The type of the encoded dictionary index. + * @param forFileIPC Whether the data will be written to a file or stream IPC. Throws an + * exception if a replacement dictionary is provided to a file IPC. + * @param allocator The allocator to use. + */ + public BatchedDictionary( + String name, + DictionaryEncoding encoding, + ArrowType dictionaryType, + ArrowType indexType, + boolean forFileIPC, + BufferAllocator allocator + ) { + this(name, encoding, dictionaryType, indexType, forFileIPC, allocator, "-dictionary"); + } + + /** + * Creates a dictionary with two vectors of the given types. + * + * @param name A name for the vector and dictionary. + * @param encoding The dictionary encoding to use. + * @param dictionaryType The type of the dictionary data. + * @param indexType The type of the encoded dictionary index. + * @param forFileIPC Whether the data will be written to a file or stream IPC. Throws an + * exception if a replacement dictionary is provided to a file IPC. + * @param allocator The allocator to use. + * @param suffix A non-null suffix to append to the name of the dictionary. + */ + public BatchedDictionary( + String name, + DictionaryEncoding encoding, + ArrowType dictionaryType, + ArrowType indexType, + boolean forFileIPC, + BufferAllocator allocator, + String suffix + ) { + this.encoding = encoding; + this.forFileIPC = forFileIPC; + FieldVector vector = new FieldType(false, dictionaryType, null) + .createNewSingleVector(name + suffix, allocator, null); + if (!(BaseVariableWidthVector.class.isAssignableFrom(vector.getClass()))) { + throw new IllegalArgumentException("Dictionary must be a superclass of 'BaseVariableWidthVector' " + + "such as 'VarCharVector'."); + } + dictionary = (BaseVariableWidthVector) vector; + vector = new FieldType(true, indexType, encoding) + .createNewSingleVector(name, allocator, null); + if (!(BaseIntVector.class.isAssignableFrom(vector.getClass()))) { + throw new IllegalArgumentException("Index vector must be a superclass type of 'BaseIntVector' " + + "such as 'IntVector' or 'Uint4Vector'."); + } + indexVector = (BaseIntVector) vector; + hashTable = new DictionaryHashTable(); + } + + /** + * Creates a dictionary that will populate the provided vectors with data. Useful if + * dictionaries need to be children of a parent vector. + * @param dictionary The dictionary to hold the original data. + * @param indexVector The index to store the encoded offsets. + * @param forFileIPC Whether the data will be written to a file or stream IPC. Throws an + * exception if a replacement dictionary is provided to a file IPC. + */ + public BatchedDictionary( + FieldVector dictionary, + FieldVector indexVector, + boolean forFileIPC + ) { + this.encoding = dictionary.getField().getDictionary(); + this.forFileIPC = forFileIPC; + if (!(BaseVariableWidthVector.class.isAssignableFrom(dictionary.getClass()))) { + throw new IllegalArgumentException("Dictionary must be a superclass of 'BaseVariableWidthVector' " + + "such as 'VarCharVector'."); + } + if (dictionary.getField().isNullable()) { + throw new IllegalArgumentException("Dictionary must be non-nullable."); + } + this.dictionary = (BaseVariableWidthVector) dictionary; + if (!(BaseIntVector.class.isAssignableFrom(indexVector.getClass()))) { + throw new IllegalArgumentException("Index vector must be a superclass type of 'BaseIntVector' " + + "such as 'IntVector' or 'Uint4Vector'."); + } + this.indexVector = (BaseIntVector) indexVector; + hashTable = new DictionaryHashTable(); + } + + /** + * The index vector. + */ + public FieldVector getIndexVector() { + return indexVector; + } + + @Override + public FieldVector getVector() { + return dictionary; + } + + @Override + public ArrowType getVectorType() { + return dictionary.getField().getType(); + } + + @Override + public DictionaryEncoding getEncoding() { + return encoding; + } + + /** + * Considers the entire byte array as the dictionary value. If the value is null, + * a null will be written to the index. + * + * @param index the value to change + * @param value the value to write. + */ + public void setSafe(int index, byte[] value) { + if (value == null) { + setNull(index); + return; + } + setSafe(index, value, 0, value.length); + } + + /** + * Encodes the given range in the dictionary. If the value is null, a null will be + * written to the index. + * + * @param index the value to change + * @param value the value to write. + * @param offset An offset into the value array. + * @param len The length of the value to write. + */ + public void setSafe(int index, byte[] value, int offset, int len) { + if (value == null || len == 0) { + setNull(index); + return; + } + int di = getIndex(value, offset, len); + indexVector.setWithPossibleTruncate(index, di); + } + + /** + * Set the element at the given index to null. + * + * @param index the value to change + */ + public void setNull(int index) { + indexVector.setNull(index); + } + + @Override + public void close() throws IOException { + dictionary.close(); + indexVector.close(); + } + + /** + * Mark the dictionary as complete for the batch. Called by the {@link ArrowWriter} + * on {@link ArrowWriter#writeBatch()}. + */ + public void mark() { + dictionary.setValueCount(dictionaryIndex); + // not setting the index vector value count. That will happen when the user calls + // VectorSchemaRoot#setRowCount(). + } + + /** + * Resets the dictionary to be used for a new batch. Called by the {@link ArrowWriter} on + * {@link ArrowWriter#writeBatch()}. + */ + public void reset() { + wasReset = true; + dictionaryIndex = 0; + dictionary.reset(); + indexVector.reset(); + if (!forFileIPC && !encoding.isDelta()) { + // replacement mode. + deltaIndex = 0; + hashTable.clear(); + } + } + + private int getIndex(byte[] value, int offset, int len) { + int hash = MurmurHasher.hashCode(value, offset, len, 0); + int i = hashTable.getIndex(hash); + if (i >= 0) { + return i; + } else { + if (wasReset && forFileIPC && !encoding.isDelta()) { + throw new IllegalStateException("Dictionary was reset and is not in delta mode. " + + "This is not supported for file IPC."); + } + hashTable.addEntry(hash, deltaIndex); + dictionary.setSafe(dictionaryIndex++, value, offset, len); + return deltaIndex++; + } + } + + @VisibleForTesting + DictionaryHashTable getHashTable() { + return hashTable; + } +} diff --git a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/Dictionary.java b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/Dictionary.java index 6f40e5814b972..5544855953d40 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/Dictionary.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/Dictionary.java @@ -28,7 +28,7 @@ * A dictionary (integer to Value mapping) that is used to facilitate * dictionary encoding compression. */ -public class Dictionary { +public class Dictionary implements BaseDictionary { private final DictionaryEncoding encoding; private final FieldVector dictionary; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryEncoder.java b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryEncoder.java index 4368501ffc7b5..959a1eb1d6627 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryEncoder.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryEncoder.java @@ -38,20 +38,20 @@ public class DictionaryEncoder { private final DictionaryHashTable hashTable; - private final Dictionary dictionary; + private final BaseDictionary dictionary; private final BufferAllocator allocator; /** * Construct an instance. */ - public DictionaryEncoder(Dictionary dictionary, BufferAllocator allocator) { + public DictionaryEncoder(BaseDictionary dictionary, BufferAllocator allocator) { this (dictionary, allocator, SimpleHasher.INSTANCE); } /** * Construct an instance. */ - public DictionaryEncoder(Dictionary dictionary, BufferAllocator allocator, ArrowBufHasher hasher) { + public DictionaryEncoder(BaseDictionary dictionary, BufferAllocator allocator, ArrowBufHasher hasher) { this.dictionary = dictionary; this.allocator = allocator; hashTable = new DictionaryHashTable(dictionary.getVector(), hasher); @@ -64,7 +64,7 @@ public DictionaryEncoder(Dictionary dictionary, BufferAllocator allocator, Arrow * @param dictionary dictionary used for encoding * @return dictionary encoded vector */ - public static ValueVector encode(ValueVector vector, Dictionary dictionary) { + public static ValueVector encode(ValueVector vector, BaseDictionary dictionary) { DictionaryEncoder encoder = new DictionaryEncoder(dictionary, vector.getAllocator()); return encoder.encode(vector); } @@ -76,7 +76,7 @@ public static ValueVector encode(ValueVector vector, Dictionary dictionary) { * @param dictionary dictionary used to decode the values * @return vector with values restored from dictionary */ - public static ValueVector decode(ValueVector indices, Dictionary dictionary) { + public static ValueVector decode(ValueVector indices, BaseDictionary dictionary) { return decode(indices, dictionary, indices.getAllocator()); } @@ -88,7 +88,7 @@ public static ValueVector decode(ValueVector indices, Dictionary dictionary) { * @param allocator allocator the decoded values use * @return vector with values restored from dictionary */ - public static ValueVector decode(ValueVector indices, Dictionary dictionary, BufferAllocator allocator) { + public static ValueVector decode(ValueVector indices, BaseDictionary dictionary, BufferAllocator allocator) { int count = indices.getValueCount(); ValueVector dictionaryVector = dictionary.getVector(); int dictionaryCount = dictionaryVector.getValueCount(); @@ -185,8 +185,11 @@ static void retrieveIndexVector( public ValueVector encode(ValueVector vector) { Field valueField = vector.getField(); - FieldType indexFieldType = new FieldType(valueField.isNullable(), dictionary.getEncoding().getIndexType(), - dictionary.getEncoding(), valueField.getMetadata()); + FieldType indexFieldType = new FieldType( + valueField.isNullable(), + dictionary.getEncoding().getIndexType(), + dictionary.getEncoding(), + valueField.getMetadata()); Field indexField = new Field(valueField.getName(), indexFieldType, null); // vector to hold our indices (dictionary encoded values) @@ -211,8 +214,9 @@ public ValueVector encode(ValueVector vector) { /** * Decodes a vector with the dictionary in this encoder. * - * {@link DictionaryEncoder#decode(ValueVector, Dictionary, BufferAllocator)} should be used instead if only decoding - * is required as it can avoid building the {@link DictionaryHashTable} which only makes sense when encoding. + * {@link DictionaryEncoder#decode(ValueVector, BaseDictionary, BufferAllocator)} + * should be used instead if only decoding is required as it can avoid building the + * {@link DictionaryHashTable} which only makes sense when encoding. */ public ValueVector decode(ValueVector indices) { return decode(indices, dictionary, allocator); diff --git a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryHashTable.java b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryHashTable.java index 9926a8e2a637f..fc00cdcc3d39f 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryHashTable.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryHashTable.java @@ -96,9 +96,11 @@ public DictionaryHashTable(int initialCapacity, ValueVector dictionary, ArrowBuf this.hasher = hasher; - // build hash table - for (int i = 0; i < this.dictionary.getValueCount(); i++) { - put(i); + if (dictionary != null) { + // build hash table + for (int i = 0; i < this.dictionary.getValueCount(); i++) { + put(i); + } } } @@ -110,6 +112,18 @@ public DictionaryHashTable(ValueVector dictionary) { this(dictionary, SimpleHasher.INSTANCE); } + /** + * Creates an empty cable used for batch writing of dictionaries. + */ + public DictionaryHashTable() { + this(DEFAULT_INITIAL_CAPACITY, null, SimpleHasher.INSTANCE); + + if (table == EMPTY_TABLE) { + inflateTable(threshold); + } + } + + /** * Compute the capacity with given threshold and create init table. */ @@ -194,6 +208,34 @@ void createEntry(int hash, int index, int bucketIndex) { size++; } + /** + * Returns the corresponding dictionary index entry given a hash code. If the hash has + * not been written to the table, returns -1. + * + * @param hash The hash to lookup. + * @return The dictionary index if present, -1 if not. + */ + int getIndex(int hash) { + int i = indexFor(hash, table.length); + for (DictionaryHashTable.Entry e = table[i]; e != null; e = e.next) { + if (e.hash == hash) { + return e.index; + } + } + return -1; + } + + /** + * Adds an entry to the hash table. + * + * @param hash The hash to add. + * @param index The corresponding dictionary index. + */ + void addEntry(int hash, int index) { + int bucketIndex = indexFor(hash, table.length); + addEntry(hash, index, bucketIndex); + } + /** * Add Entry at the specified location of the table. */ diff --git a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryProvider.java b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryProvider.java index f64c32be0f3e9..9845e36a54330 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryProvider.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryProvider.java @@ -30,24 +30,29 @@ public interface DictionaryProvider { /** Return the dictionary for the given ID. */ - Dictionary lookup(long id); + BaseDictionary lookup(long id); /** Get all dictionary IDs. */ Set getDictionaryIds(); + /** + * Reset all batched dictionaries associated with this provider. + */ + void resetBatchedDictionaries(); + /** * Implementation of {@link DictionaryProvider} that is backed by a hash-map. */ class MapDictionaryProvider implements AutoCloseable, DictionaryProvider { - private final Map map; + private final Map map; /** * Constructs a new instance from the given dictionaries. */ - public MapDictionaryProvider(Dictionary... dictionaries) { + public MapDictionaryProvider(BaseDictionary... dictionaries) { this.map = new HashMap<>(); - for (Dictionary dictionary : dictionaries) { + for (BaseDictionary dictionary : dictionaries) { put(dictionary); } } @@ -62,14 +67,14 @@ public MapDictionaryProvider(Dictionary... dictionaries) { @VisibleForTesting public void copyStructureFrom(DictionaryProvider other, BufferAllocator allocator) { for (Long id : other.getDictionaryIds()) { - Dictionary otherDict = other.lookup(id); - Dictionary newDict = new Dictionary(otherDict.getVector().getField().createVector(allocator), + BaseDictionary otherDict = other.lookup(id); + BaseDictionary newDict = new Dictionary(otherDict.getVector().getField().createVector(allocator), otherDict.getEncoding()); put(newDict); } } - public void put(Dictionary dictionary) { + public void put(BaseDictionary dictionary) { map.put(dictionary.getEncoding().getId(), dictionary); } @@ -79,15 +84,24 @@ public final Set getDictionaryIds() { } @Override - public Dictionary lookup(long id) { + public BaseDictionary lookup(long id) { return map.get(id); } @Override public void close() { - for (Dictionary dictionary : map.values()) { + for (BaseDictionary dictionary : map.values()) { dictionary.getVector().close(); } } + + @Override + public void resetBatchedDictionaries() { + map.values().forEach( dictionary -> { + if (dictionary instanceof BatchedDictionary) { + ((BatchedDictionary) dictionary).reset(); + } + }); + } } } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/StructSubfieldEncoder.java b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/StructSubfieldEncoder.java index 8500528a62b60..75cf63437b61c 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/StructSubfieldEncoder.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/StructSubfieldEncoder.java @@ -109,7 +109,7 @@ public StructVector encode(StructVector vector, Map columnToDicti if (dictionaryId == null) { childrenFields.add(childVector.getField()); } else { - Dictionary dictionary = provider.lookup(dictionaryId); + BaseDictionary dictionary = provider.lookup(dictionaryId); Preconditions.checkNotNull(dictionary, "Dictionary not found with id:" + dictionaryId); FieldType indexFieldType = new FieldType(childVector.getField().isNullable(), dictionary.getEncoding().getIndexType(), dictionary.getEncoding()); @@ -177,7 +177,7 @@ public static StructVector decode(StructVector vector, List childFields = new ArrayList<>(); for (int i = 0; i < childCount; i++) { FieldVector childVector = getChildVector(vector, i); - Dictionary dictionary = getChildVectorDictionary(childVector, provider); + BaseDictionary dictionary = getChildVectorDictionary(childVector, provider); // childVector is not encoded. if (dictionary == null) { childFields.add(childVector.getField()); @@ -192,7 +192,7 @@ public static StructVector decode(StructVector vector, // get child vector FieldVector childVector = getChildVector(vector, index); FieldVector decodedChildVector = getChildVector(decoded, index); - Dictionary dictionary = getChildVectorDictionary(childVector, provider); + BaseDictionary dictionary = getChildVectorDictionary(childVector, provider); if (dictionary == null) { childVector.makeTransferPair(decodedChildVector).splitAndTransfer(0, valueCount); } else { @@ -213,11 +213,11 @@ public static StructVector decode(StructVector vector, /** * Get the child vector dictionary, return null if not dictionary encoded. */ - private static Dictionary getChildVectorDictionary(FieldVector childVector, + private static BaseDictionary getChildVectorDictionary(FieldVector childVector, DictionaryProvider.MapDictionaryProvider provider) { DictionaryEncoding dictionaryEncoding = childVector.getField().getDictionary(); if (dictionaryEncoding != null) { - Dictionary dictionary = provider.lookup(dictionaryEncoding.getId()); + BaseDictionary dictionary = provider.lookup(dictionaryEncoding.getId()); Preconditions.checkNotNull(dictionary, "Dictionary not found with id:" + dictionary); return dictionary; } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileReader.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileReader.java index 8629cf93470b8..fad45816b93bd 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileReader.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileReader.java @@ -22,6 +22,7 @@ import java.nio.channels.SeekableByteChannel; import java.util.Arrays; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; @@ -50,6 +51,7 @@ public class ArrowFileReader extends ArrowReader { private SeekableReadChannel in; private ArrowFooter footer; + private int estimatedDictionaryRecordBatch = 0; private int currentDictionaryBatch = 0; private int currentRecordBatch = 0; @@ -123,11 +125,6 @@ public void initialize() throws IOException { if (footer.getRecordBatches().size() == 0) { return; } - // Read and load all dictionaries from schema - for (int i = 0; i < dictionaries.size(); i++) { - ArrowDictionaryBatch dictionaryBatch = readDictionary(); - loadDictionary(dictionaryBatch); - } } /** @@ -140,21 +137,6 @@ public Map getMetaData() { return new HashMap<>(); } - /** - * Read a dictionary batch from the source, will be invoked after the schema has been read and - * called N times, where N is the number of dictionaries indicated by the schema Fields. - * - * @return the read ArrowDictionaryBatch - * @throws IOException on error - */ - public ArrowDictionaryBatch readDictionary() throws IOException { - if (currentDictionaryBatch >= footer.getDictionaries().size()) { - throw new IOException("Requested more dictionaries than defined in footer: " + currentDictionaryBatch); - } - ArrowBlock block = footer.getDictionaries().get(currentDictionaryBatch++); - return readDictionaryBatch(in, block, allocator); - } - /** Returns true if a batch was read, false if no more batches. */ @Override public boolean loadNextBatch() throws IOException { @@ -164,12 +146,54 @@ public boolean loadNextBatch() throws IOException { ArrowBlock block = footer.getRecordBatches().get(currentRecordBatch++); ArrowRecordBatch batch = readRecordBatch(in, block, allocator); loadRecordBatch(batch); + loadDictionaries(); return true; } else { return false; } } + /** + * Loads any dictionaries that may be needed by the given record batch. It attempts + * to read as little as possible but may read in more deltas than are necessary for blocks + * toward the end of the file. + */ + private void loadDictionaries() throws IOException { + // initial load + if (currentDictionaryBatch == 0) { + for (int i = 0; i < dictionaries.size(); i++) { + ArrowBlock block = footer.getDictionaries().get(currentDictionaryBatch++); + ArrowDictionaryBatch dictionaryBatch = readDictionaryBatch(in, block, allocator); + loadDictionary(dictionaryBatch); + } + estimatedDictionaryRecordBatch++; + } else { + // we need to look for delta dictionaries. It involves a look-ahead, unfortunately. + HashSet visited = new HashSet(); + while (estimatedDictionaryRecordBatch < currentRecordBatch && + currentDictionaryBatch < footer.getDictionaries().size()) { + ArrowBlock block = footer.getDictionaries().get(currentDictionaryBatch++); + ArrowDictionaryBatch dictionaryBatch = readDictionaryBatch(in, block, allocator); + long dictionaryId = dictionaryBatch.getDictionaryId(); + if (visited.contains(dictionaryId)) { + // done + currentDictionaryBatch--; + estimatedDictionaryRecordBatch++; + } else if (!dictionaries.containsKey(dictionaryId)) { + throw new IOException("Dictionary ID " + dictionaryId + " was written " + + "after the initial batch. The file does not follow the IPC file protocol."); + } else if (!dictionaryBatch.isDelta()) { + throw new IOException("Dictionary ID " + dictionaryId + " was written as a replacement " + + "after the initial batch. Replacement dictionaries are not currently allowed in the IPC file protocol."); + } else { + loadDictionary(dictionaryBatch); + } + } + if (currentDictionaryBatch >= footer.getDictionaries().size()) { + estimatedDictionaryRecordBatch++; + } + } + } public List getDictionaryBlocks() throws IOException { ensureInitialized(); @@ -194,6 +218,7 @@ public boolean loadRecordBatch(ArrowBlock block) throws IOException { throw new IllegalArgumentException("Arrow block does not exist in record batches: " + block); } currentRecordBatch = blockIndex; + loadDictionaries(); return loadNextBatch(); } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileWriter.java index 71db79087a3e4..c0018efea3bf8 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileWriter.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileWriter.java @@ -29,7 +29,7 @@ import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.compression.CompressionCodec; import org.apache.arrow.vector.compression.CompressionUtil; -import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.BaseDictionary; import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.arrow.vector.ipc.message.ArrowBlock; import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch; @@ -130,14 +130,24 @@ protected void endInternal(WriteChannel out) throws IOException { protected void ensureDictionariesWritten(DictionaryProvider provider, Set dictionaryIdsUsed) throws IOException { if (dictionariesWritten) { + for (long id : dictionaryIdsUsed) { + BaseDictionary dictionary = provider.lookup(id); + if (dictionary.getEncoding().isDelta()) { + writeDictionaryBatch(dictionary, false); + } + // TODO - It would be useful to throw an exception here if a replacement dictionary was found + // with modifications. Replacements are not currently allowed in files. For now, we just drop it + // and throw an exception in the BatchedDictionary and hope users use that class. + } return; } + dictionariesWritten = true; // Write out all dictionaries required. // Replacement dictionaries are not supported in the IPC file format. for (long id : dictionaryIdsUsed) { - Dictionary dictionary = provider.lookup(id); - writeDictionaryBatch(dictionary); + BaseDictionary dictionary = provider.lookup(id); + writeDictionaryBatch(dictionary, true); } } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowReader.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowReader.java index 04c57d7e82fef..4b7aa83720cd0 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowReader.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowReader.java @@ -151,6 +151,11 @@ public void close(boolean closeReadSource) throws IOException { } } + @Override + public void resetBatchedDictionaries() { + // no-op + } + /** * Close the underlying read source. * @@ -243,6 +248,8 @@ protected void loadDictionary(ArrowDictionaryBatch dictionaryBatch) { } return; } + // TODO - If the super class is a file reader it may be good to throw an exception here + // the dictionary failed to satisfy the spec (i.e. being a replacement dictionary) load(dictionaryBatch, vector); } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowStreamReader.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowStreamReader.java index a0096aaf3ee56..6464a97c99efd 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowStreamReader.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowStreamReader.java @@ -177,10 +177,6 @@ public boolean loadNextBatch() throws IOException { * When read a record batch, check whether its dictionaries are available. */ private void checkDictionaries() throws IOException { - // if all dictionaries are loaded, return. - if (loadedDictionaryCount == dictionaries.size()) { - return; - } for (FieldVector vector : getVectorSchemaRoot().getFieldVectors()) { DictionaryEncoding encoding = vector.getField().getDictionary(); if (encoding != null) { diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowStreamWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowStreamWriter.java index 928e1de4c5f6b..1cf43ecb20787 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowStreamWriter.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowStreamWriter.java @@ -26,13 +26,11 @@ import java.util.Optional; import java.util.Set; -import org.apache.arrow.util.AutoCloseables; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.VectorSchemaRoot; -import org.apache.arrow.vector.compare.VectorEqualsVisitor; import org.apache.arrow.vector.compression.CompressionCodec; import org.apache.arrow.vector.compression.CompressionUtil; -import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.BaseDictionary; import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.arrow.vector.ipc.message.IpcOption; import org.apache.arrow.vector.ipc.message.MessageSerializer; @@ -41,7 +39,7 @@ * Writer for the Arrow stream format to send ArrowRecordBatches over a WriteChannel. */ public class ArrowStreamWriter extends ArrowWriter { - private final Map previousDictionaries = new HashMap<>(); + private final Map previousDictionaries = new HashMap<>(); /** * Construct an ArrowStreamWriter with an optional DictionaryProvider for the OutputStream. @@ -135,30 +133,16 @@ protected void ensureDictionariesWritten(DictionaryProvider provider, Set throws IOException { // write out any dictionaries that have changes for (long id : dictionaryIdsUsed) { - Dictionary dictionary = provider.lookup(id); - FieldVector vector = dictionary.getVector(); - if (previousDictionaries.containsKey(id) && - VectorEqualsVisitor.vectorEquals(vector, previousDictionaries.get(id))) { - // Dictionary was previously written and hasn't changed - continue; - } - writeDictionaryBatch(dictionary); - // Store a copy of the vector in case it is later mutated - if (previousDictionaries.containsKey(id)) { - previousDictionaries.get(id).close(); - } - previousDictionaries.put(id, copyVector(vector)); + BaseDictionary dictionary = provider.lookup(id); + boolean isInitial = previousDictionaries.containsKey(id) ? false : true; + writeDictionaryBatch(dictionary, isInitial); + previousDictionaries.put(id, true); } } @Override public void close() { super.close(); - try { - AutoCloseables.close(previousDictionaries.values()); - } catch (Exception e) { - throw new RuntimeException(e); - } } private static FieldVector copyVector(FieldVector source) { diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java index a33c55de53f23..c1ef3d425a4ec 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java @@ -32,7 +32,8 @@ import org.apache.arrow.vector.compression.CompressionCodec; import org.apache.arrow.vector.compression.CompressionUtil; import org.apache.arrow.vector.compression.NoCompressionCodec; -import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.BaseDictionary; +import org.apache.arrow.vector.dictionary.BatchedDictionary; import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.arrow.vector.ipc.message.ArrowBlock; import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch; @@ -123,9 +124,15 @@ public void writeBatch() throws IOException { try (ArrowRecordBatch batch = unloader.getRecordBatch()) { writeRecordBatch(batch); } + if (dictionaryProvider != null) { + dictionaryProvider.resetBatchedDictionaries(); + } } - protected void writeDictionaryBatch(Dictionary dictionary) throws IOException { + protected void writeDictionaryBatch(BaseDictionary dictionary, boolean isInitial) throws IOException { + if (dictionary instanceof BatchedDictionary) { + ((BatchedDictionary) dictionary).mark(); + } FieldVector vector = dictionary.getVector(); long id = dictionary.getEncoding().getId(); int count = vector.getValueCount(); @@ -135,7 +142,8 @@ protected void writeDictionaryBatch(Dictionary dictionary) throws IOException { count); VectorUnloader unloader = new VectorUnloader(dictRoot); ArrowRecordBatch batch = unloader.getRecordBatch(); - ArrowDictionaryBatch dictionaryBatch = new ArrowDictionaryBatch(id, batch, false); + boolean isDelta = isInitial ? false : dictionary.getEncoding().isDelta(); + ArrowDictionaryBatch dictionaryBatch = new ArrowDictionaryBatch(id, batch, isDelta); try { writeDictionaryBatch(dictionaryBatch); } finally { diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileReader.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileReader.java index 0c23a664f62d6..d8d58bf4f1307 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileReader.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileReader.java @@ -259,6 +259,11 @@ public int skip(int numBatches) throws IOException { return numBatches; } + @Override + public void resetBatchedDictionaries() { + // no-op + } + private abstract class BufferReader { protected abstract ArrowBuf read(BufferAllocator allocator, int count) throws IOException; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileWriter.java index f5e267e81256c..e0be27d468ab5 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileWriter.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileWriter.java @@ -68,7 +68,7 @@ import org.apache.arrow.vector.UInt4Vector; import org.apache.arrow.vector.UInt8Vector; import org.apache.arrow.vector.VectorSchemaRoot; -import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.BaseDictionary; import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.arrow.vector.types.Types.MinorType; import org.apache.arrow.vector.types.pojo.Field; @@ -173,7 +173,7 @@ private void writeDictionaryBatches(JsonGenerator generator, Set dictionar generator.writeObjectField("id", id); generator.writeFieldName("data"); - Dictionary dictionary = provider.lookup(id); + BaseDictionary dictionary = provider.lookup(id); FieldVector vector = dictionary.getVector(); List fields = Collections.singletonList(vector.getField()); List vectors = Collections.singletonList(vector); diff --git a/java/vector/src/main/java/org/apache/arrow/vector/table/BaseTable.java b/java/vector/src/main/java/org/apache/arrow/vector/table/BaseTable.java index 9f645b64bc5f6..eb766c617682a 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/table/BaseTable.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/table/BaseTable.java @@ -29,7 +29,7 @@ import org.apache.arrow.vector.ValueVector; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.complex.reader.FieldReader; -import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.BaseDictionary; import org.apache.arrow.vector.dictionary.DictionaryEncoder; import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.arrow.vector.types.pojo.Field; @@ -385,7 +385,7 @@ public DictionaryProvider getDictionaryProvider() { * @return A ValueVector */ public ValueVector decode(String vectorName, long dictionaryId) { - Dictionary dictionary = getDictionary(dictionaryId); + BaseDictionary dictionary = getDictionary(dictionaryId); FieldVector vector = getVector(vectorName); if (vector == null) { @@ -405,7 +405,7 @@ public ValueVector decode(String vectorName, long dictionaryId) { * @return A ValueVector */ public ValueVector encode(String vectorName, long dictionaryId) { - Dictionary dictionary = getDictionary(dictionaryId); + BaseDictionary dictionary = getDictionary(dictionaryId); FieldVector vector = getVector(vectorName); if (vector == null) { throw new IllegalArgumentException( @@ -419,12 +419,12 @@ public ValueVector encode(String vectorName, long dictionaryId) { * Returns the dictionary with given id. * @param dictionaryId A long integer that is the id returned by the dictionary's getId() method */ - private Dictionary getDictionary(long dictionaryId) { + private BaseDictionary getDictionary(long dictionaryId) { if (dictionaryProvider == null) { throw new IllegalStateException("No dictionary provider is present in table."); } - Dictionary dictionary = dictionaryProvider.lookup(dictionaryId); + BaseDictionary dictionary = dictionaryProvider.lookup(dictionaryId); if (dictionary == null) { throw new IllegalArgumentException("No dictionary with id '%n' exists in the table"); } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/table/Table.java b/java/vector/src/main/java/org/apache/arrow/vector/table/Table.java index 5768bb0ec75ec..7442342f37986 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/table/Table.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/table/Table.java @@ -28,6 +28,7 @@ import org.apache.arrow.util.Preconditions; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.dictionary.BaseDictionary; import org.apache.arrow.vector.dictionary.Dictionary; import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.arrow.vector.types.pojo.DictionaryEncoding; @@ -113,13 +114,18 @@ public Table copy() { Dictionary[] dictionaryCopies = new Dictionary[ids.size()]; int i = 0; for (Long id : ids) { - Dictionary src = dictionaryProvider.lookup(id); + BaseDictionary src = dictionaryProvider.lookup(id); FieldVector srcVector = src.getVector(); FieldVector destVector = srcVector.getField().createVector(srcVector.getAllocator()); destVector.copyFromSafe(0, srcVector.getValueCount(), srcVector); // TODO: Remove safe copy for perf DictionaryEncoding srcEncoding = src.getEncoding(); Dictionary dest = new Dictionary(destVector, - new DictionaryEncoding(srcEncoding.getId(), srcEncoding.isOrdered(), srcEncoding.getIndexType())); + new DictionaryEncoding( + srcEncoding.getId(), + srcEncoding.isOrdered(), + srcEncoding.getIndexType(), + srcEncoding.isDelta() + )); dictionaryCopies[i] = dest; i++; } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/DictionaryEncoding.java b/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/DictionaryEncoding.java index 8d41b92d867e9..3d7f86dcf6c68 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/DictionaryEncoding.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/DictionaryEncoding.java @@ -33,6 +33,7 @@ public class DictionaryEncoding { private final long id; private final boolean ordered; private final Int indexType; + private final boolean isDelta; /** * Constructs a new instance. @@ -42,14 +43,29 @@ public class DictionaryEncoding { * @param indexType (nullable). The integer type to use for indexing in the dictionary. Defaults to a signed * 32 bit integer. */ + public DictionaryEncoding(long id, boolean ordered, Int indexType) { + this(id, ordered, indexType, false); + } + + /** + * Constructs a new instance. + * + * @param id The ID of the dictionary to use for encoding. + * @param ordered Whether the keys in values in the dictionary are ordered. + * @param indexType (nullable). The integer type to use for indexing in the dictionary. Defaults to a signed + * 32 bit integer. + * @param isDelta Whether the dictionary is a delta dictionary. + */ @JsonCreator public DictionaryEncoding( @JsonProperty("id") long id, @JsonProperty("isOrdered") boolean ordered, - @JsonProperty("indexType") Int indexType) { + @JsonProperty("indexType") Int indexType, + @JsonProperty("isDelta") Boolean isDelta) { this.id = id; this.ordered = ordered; this.indexType = indexType == null ? new Int(32, true) : indexType; + this.isDelta = isDelta == null ? false : isDelta; } public long getId() { @@ -65,6 +81,11 @@ public Int getIndexType() { return indexType; } + @JsonGetter("isDelta") + public boolean isDelta() { + return isDelta; + } + @Override public String toString() { return "DictionaryEncoding[id=" + id + ",ordered=" + ordered + ",indexType=" + indexType + "]"; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/util/DictionaryUtility.java b/java/vector/src/main/java/org/apache/arrow/vector/util/DictionaryUtility.java index 9592f3975ab99..4154cfae44399 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/util/DictionaryUtility.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/util/DictionaryUtility.java @@ -24,6 +24,7 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.dictionary.BaseDictionary; import org.apache.arrow.vector.dictionary.Dictionary; import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.arrow.vector.types.pojo.ArrowType; @@ -58,7 +59,7 @@ public static Field toMessageFormat(Field field, DictionaryProvider provider, Se children = field.getChildren(); } else { long id = encoding.getId(); - Dictionary dictionary = provider.lookup(id); + BaseDictionary dictionary = provider.lookup(id); if (dictionary == null) { throw new IllegalArgumentException("Could not find dictionary with ID " + id); } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/util/Validator.java b/java/vector/src/main/java/org/apache/arrow/vector/util/Validator.java index 0c9ad1e2753f1..a0bc187c8f838 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/util/Validator.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/util/Validator.java @@ -24,7 +24,7 @@ import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.VectorSchemaRoot; -import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.BaseDictionary; import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.DictionaryEncoding; @@ -70,8 +70,8 @@ public static void compareDictionaries( } long id = encodings1.get(i).getId(); - Dictionary dict1 = provider1.lookup(id); - Dictionary dict2 = provider2.lookup(id); + BaseDictionary dict1 = provider1.lookup(id); + BaseDictionary dict2 = provider2.lookup(id); if (dict1 == null || dict2 == null) { throw new IllegalArgumentException("The DictionaryProvider did not contain the required " + @@ -101,8 +101,8 @@ public static void compareDictionaryProviders( ids1 + "\n" + ids2); } for (long id : ids1) { - Dictionary dict1 = provider1.lookup(id); - Dictionary dict2 = provider2.lookup(id); + BaseDictionary dict1 = provider1.lookup(id); + BaseDictionary dict2 = provider2.lookup(id); try { compareFieldVectors(dict1.getVector(), dict2.getVector()); } catch (IllegalArgumentException e) { diff --git a/java/vector/src/main/java/org/apache/arrow/vector/util/VectorAppender.java b/java/vector/src/main/java/org/apache/arrow/vector/util/VectorAppender.java index c5de380f9c173..7b8c8e2fb6b6a 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/util/VectorAppender.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/util/VectorAppender.java @@ -108,8 +108,8 @@ public ValueVector visit(BaseVariableWidthVector deltaVector, Void value) { int newValueCount = targetVector.getValueCount() + deltaVector.getValueCount(); - int targetDataSize = targetVector.getOffsetBuffer().getInt( - (long) targetVector.getValueCount() * BaseVariableWidthVector.OFFSET_WIDTH); + int targetDataSize = targetVector.getValueCount() > 0 ? targetVector.getOffsetBuffer().getInt( + (long) targetVector.getValueCount() * BaseVariableWidthVector.OFFSET_WIDTH) : 0; int deltaDataSize = deltaVector.getOffsetBuffer().getInt( (long) deltaVector.getValueCount() * BaseVariableWidthVector.OFFSET_WIDTH); int newValueCapacity = targetDataSize + deltaDataSize; diff --git a/java/vector/src/test/java/org/apache/arrow/vector/dictionary/TestBatchedDictionary.java b/java/vector/src/test/java/org/apache/arrow/vector/dictionary/TestBatchedDictionary.java new file mode 100644 index 0000000000000..d0aa813652a6a --- /dev/null +++ b/java/vector/src/test/java/org/apache/arrow/vector/dictionary/TestBatchedDictionary.java @@ -0,0 +1,426 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.vector.dictionary; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.UInt2Vector; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +public class TestBatchedDictionary { + + private static final DictionaryEncoding DELTA = + new DictionaryEncoding(42, false, new ArrowType.Int(16, false), true); + private static final DictionaryEncoding SINGLE = + new DictionaryEncoding(24, false, new ArrowType.Int(16, false)); + private static final byte[] FOO = "foo".getBytes(StandardCharsets.UTF_8); + private static final byte[] BAR = "bar".getBytes(StandardCharsets.UTF_8); + private static final byte[] BAZ = "baz".getBytes(StandardCharsets.UTF_8); + + private BufferAllocator allocator; + + private static List validDictionaryTypes = Arrays.asList( + new ArrowType.Utf8(), + ArrowType.Binary.INSTANCE + ); + private static List invalidDictionaryTypes = Arrays.asList( + new ArrowType.LargeUtf8(), + new ArrowType.LargeBinary(), + new ArrowType.Bool(), + new ArrowType.Int(8, false) + ); + private static List validIndexTypes = Arrays.asList( + new ArrowType.Int(16, false), + new ArrowType.Int(8, false), + new ArrowType.Int(32, false), + new ArrowType.Int(64, false), + new ArrowType.Int(16, true), + new ArrowType.Int(8, true), + new ArrowType.Int(32, true), + new ArrowType.Int(64, true) + ); + private static List invalidIndexTypes = Arrays.asList( + new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE), + new ArrowType.Bool(), + new ArrowType.Utf8() + ); + + @BeforeEach + public void prepare() { + allocator = new RootAllocator(1024 * 1024); + } + + @AfterEach + public void shutdown() { + allocator.close(); + } + + public static Collection validTypes() { + List params = new ArrayList<>(); + for (ArrowType dictType : validDictionaryTypes) { + for (ArrowType indexType : validIndexTypes) { + params.add(Arguments.arguments(dictType, indexType)); + } + } + return params; + } + + public static Collection invalidTypes() { + List params = new ArrayList<>(); + for (ArrowType dictType : invalidDictionaryTypes) { + for (ArrowType indexType : validIndexTypes) { + params.add(Arguments.arguments(dictType, indexType)); + } + } + for (ArrowType dictType : validDictionaryTypes) { + for (ArrowType indexType : invalidIndexTypes) { + params.add(Arguments.arguments(dictType, indexType)); + } + } + return params; + } + + @ParameterizedTest + @MethodSource("validTypes") + public void testValidDictionaryTypes(ArrowType dictType, ArrowType indexType) throws IOException { + new BatchedDictionary( + "vector", + DELTA, + dictType, + indexType, + false, + allocator + ).close(); + } + + @ParameterizedTest + @MethodSource("validTypes") + public void testValidDictionaryVectors(ArrowType dictType, ArrowType indexType) throws IOException { + try (FieldVector dictVector = new FieldType(false, dictType, SINGLE) + .createNewSingleVector("d", allocator, null); + FieldVector indexVector = new FieldType(false, indexType, SINGLE) + .createNewSingleVector("i", allocator, null);) { + new BatchedDictionary( + dictVector, + indexVector, + false + ).close(); + } + } + + @ParameterizedTest + @MethodSource("invalidTypes") + public void testInvalidTypes(ArrowType dictType, ArrowType indexType) { + assertThrows(IllegalArgumentException.class, () -> { + new BatchedDictionary( + "vector", + SINGLE, + dictType, + indexType, + false, + allocator + ); + }); + } + + @ParameterizedTest + @MethodSource("invalidTypes") + public void testInvalidValidVectors(ArrowType dictType, ArrowType indexType) { + assertThrows(IllegalArgumentException.class, () -> { + try (FieldVector dictVector = new FieldType(false, dictType, SINGLE) + .createNewSingleVector("d", allocator, null); + FieldVector indexVector = new FieldType(false, indexType, SINGLE) + .createNewSingleVector("i", allocator, null);) { + new BatchedDictionary( + dictVector, + indexVector, + false + ).close(); + } + }); + } + + @Test + public void testSuffix() throws IOException { + try (BatchedDictionary dictionary = new BatchedDictionary( + "vector", + SINGLE, + validDictionaryTypes.get(0), + validIndexTypes.get(0), + false, + allocator + )) { + assertEquals("vector-dictionary", dictionary.getVector().getField().getName()); + assertEquals("vector", dictionary.getIndexVector().getField().getName()); + } + } + + @Test + public void testTypedUnique() throws IOException { + try (BatchedDictionary dictionary = new BatchedDictionary( + "vector", + SINGLE, + validDictionaryTypes.get(0), + validIndexTypes.get(0), + false, + allocator + )) { + dictionary.setSafe(0, FOO); + dictionary.setSafe(1, BAR); + dictionary.setSafe(2, BAZ); + dictionary.mark(); + assertEquals(3, dictionary.getVector().getValueCount()); + assertEquals(0, dictionary.getIndexVector().getValueCount()); + dictionary.getIndexVector().setValueCount(3); + assertDecoded(dictionary, "foo", "bar", "baz"); + } + } + + @Test + public void testExistingdUnique() throws IOException { + List vectors = existingVectors(SINGLE); + try (BatchedDictionary dictionary = new BatchedDictionary( + vectors.get(0), + vectors.get(1), + false + )) { + dictionary.setSafe(0, FOO); + dictionary.setSafe(1, BAR); + dictionary.setSafe(2, BAZ); + dictionary.mark(); + assertEquals(3, dictionary.getVector().getValueCount()); + assertEquals(0, dictionary.getIndexVector().getValueCount()); + dictionary.getIndexVector().setValueCount(3); + assertDecoded(dictionary, "foo", "bar", "baz"); + } + } + + @Test + public void testTypedUniqueNulls() throws IOException { + try (BatchedDictionary dictionary = new BatchedDictionary( + "vector", + SINGLE, + validDictionaryTypes.get(0), + validIndexTypes.get(0), + false, + allocator + )) { + dictionary.setNull(0); + dictionary.setSafe(1, BAR); + dictionary.setNull(2); + dictionary.mark(); + assertEquals(1, dictionary.getVector().getValueCount()); + assertEquals(0, dictionary.getIndexVector().getValueCount()); + dictionary.getIndexVector().setValueCount(3); + assertDecoded(dictionary, null, "bar", null); + } + } + + @Test + public void testExistingUniqueNulls() throws IOException { + List vectors = existingVectors(SINGLE); + try (BatchedDictionary dictionary = new BatchedDictionary( + vectors.get(0), + vectors.get(1), + false + )) { + dictionary.setNull(0); + dictionary.setSafe(1, BAR); + dictionary.setNull(2); + dictionary.mark(); + assertEquals(1, dictionary.getVector().getValueCount()); + assertEquals(0, dictionary.getIndexVector().getValueCount()); + dictionary.getIndexVector().setValueCount(3); + assertDecoded(dictionary, null, "bar", null); + } + } + + @Test + public void testTypedReused() throws IOException { + try (BatchedDictionary dictionary = new BatchedDictionary( + "vector", + SINGLE, + validDictionaryTypes.get(0), + validIndexTypes.get(0), + false, + allocator + )) { + dictionary.setSafe(0, FOO); + dictionary.setSafe(1, BAR); + dictionary.setSafe(2, FOO); + dictionary.setSafe(3, FOO); + dictionary.mark(); + assertEquals(2, dictionary.getVector().getValueCount()); + assertEquals(0, dictionary.getIndexVector().getValueCount()); + dictionary.getIndexVector().setValueCount(4); + assertDecoded(dictionary, "foo", "bar", "foo", "foo"); + } + } + + @Test + public void testTypedResetReplacement() throws IOException { + try (BatchedDictionary dictionary = new BatchedDictionary( + "vector", + SINGLE, + validDictionaryTypes.get(0), + validIndexTypes.get(0), + false, + allocator + )) { + dictionary.setSafe(0, FOO); + dictionary.setSafe(1, BAR); + dictionary.mark(); + assertEquals(2, dictionary.getVector().getValueCount()); + assertEquals(0, dictionary.getIndexVector().getValueCount()); + dictionary.getIndexVector().setValueCount(2); + assertDecoded(dictionary, "foo", "bar"); + + dictionary.reset(); + assertEquals(0, dictionary.getHashTable().size); + dictionary.setSafe(0, BAZ); + dictionary.setNull(1); + dictionary.mark(); + assertEquals(1, dictionary.getVector().getValueCount()); + assertEquals(0, dictionary.getIndexVector().getValueCount()); + dictionary.getIndexVector().setValueCount(2); + assertDecoded(dictionary, "baz", null); + } + } + + @Test + public void testTypedResetDelta() throws IOException { + try (BatchedDictionary dictionary = new BatchedDictionary( + "vector", + DELTA, + validDictionaryTypes.get(0), + validIndexTypes.get(0), + false, + allocator + )) { + dictionary.setSafe(0, FOO); + dictionary.setSafe(1, BAR); + dictionary.mark(); + assertEquals(2, dictionary.getVector().getValueCount()); + assertEquals(0, dictionary.getIndexVector().getValueCount()); + dictionary.getIndexVector().setValueCount(2); + assertDecoded(dictionary, "foo", "bar"); + + dictionary.reset(); + assertEquals(2, dictionary.getHashTable().size); + dictionary.setSafe(0, BAZ); + dictionary.setSafe(1, FOO); + dictionary.setSafe(2, BAR); + dictionary.mark(); + assertEquals(1, dictionary.getVector().getValueCount()); + assertEquals(0, dictionary.getIndexVector().getValueCount()); + dictionary.getIndexVector().setValueCount(3); + assertEquals(3, dictionary.getHashTable().size); + + // on read the dictionaries must be merged. Let's look at the int index. + UInt2Vector index = (UInt2Vector) dictionary.getIndexVector(); + assertEquals(2, index.get(0)); + assertEquals(0, index.get(1)); + assertEquals(1, index.get(2)); + } + } + + @Test + public void testTypedNullData() throws IOException { + try (BatchedDictionary dictionary = new BatchedDictionary( + "vector", + SINGLE, + validDictionaryTypes.get(0), + validIndexTypes.get(0), + false, + allocator + )) { + dictionary.setSafe(0, null); + dictionary.setSafe(1, BAR); + dictionary.mark(); + assertEquals(1, dictionary.getVector().getValueCount()); + assertEquals(0, dictionary.getIndexVector().getValueCount()); + dictionary.getIndexVector().setValueCount(2); + assertDecoded(dictionary, null, "bar"); + } + } + + @Test + public void testReplacementNotAllowed() throws IOException { + try (BatchedDictionary dictionary = new BatchedDictionary( + "vector", + SINGLE, + validDictionaryTypes.get(0), + validIndexTypes.get(0), + true, + allocator + )) { + dictionary.setSafe(0, FOO); + dictionary.setSafe(1, BAR); + dictionary.reset(); + dictionary.setSafe(0, BAR); + assertThrows(IllegalStateException.class, () -> { + dictionary.setSafe(1, BAZ); + }); + } + } + + void assertDecoded(BatchedDictionary dictionary, String... expected) { + try (ValueVector decoded = DictionaryEncoder.decode(dictionary.getIndexVector(), dictionary)) { + assertEquals(expected.length, decoded.getValueCount()); + for (int i = 0; i < expected.length; i++) { + if (expected[i] == null) { + assertNull(decoded.getObject(i)); + } else { + assertNotNull(decoded.getObject(i)); + assertEquals(expected[i], decoded.getObject(i).toString()); + } + } + } + } + + private List existingVectors(DictionaryEncoding encoding) { + FieldVector dictionaryVector = new FieldType(false, validDictionaryTypes.get(0), null) + .createNewSingleVector("vector-dictionary", allocator, null); + FieldVector indexVector = new FieldType(true, validIndexTypes.get(0), encoding) + .createNewSingleVector("fector", allocator, null); + return Arrays.asList(new FieldVector[] { dictionaryVector, indexVector }); + } +} diff --git a/java/vector/src/test/java/org/apache/arrow/vector/ipc/BaseFileTest.java b/java/vector/src/test/java/org/apache/arrow/vector/ipc/BaseFileTest.java index 8663c0c49990d..509b7b059829d 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/ipc/BaseFileTest.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/ipc/BaseFileTest.java @@ -20,18 +20,29 @@ import static org.apache.arrow.vector.TestUtils.newVarCharVector; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; import java.io.IOException; +import java.io.OutputStream; import java.math.BigDecimal; import java.math.BigInteger; +import java.nio.channels.FileChannel; import java.nio.charset.StandardCharsets; import java.time.LocalDateTime; import java.time.LocalTime; import java.time.ZoneId; import java.time.ZoneOffset; +import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; import java.util.List; +import java.util.Map; import org.apache.arrow.memory.ArrowBuf; import org.apache.arrow.memory.BufferAllocator; @@ -48,6 +59,7 @@ import org.apache.arrow.vector.UInt2Vector; import org.apache.arrow.vector.UInt4Vector; import org.apache.arrow.vector.UInt8Vector; +import org.apache.arrow.vector.ValueVector; import org.apache.arrow.vector.VarBinaryVector; import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VectorSchemaRoot; @@ -74,6 +86,8 @@ import org.apache.arrow.vector.complex.writer.UInt2Writer; import org.apache.arrow.vector.complex.writer.UInt4Writer; import org.apache.arrow.vector.complex.writer.UInt8Writer; +import org.apache.arrow.vector.dictionary.BaseDictionary; +import org.apache.arrow.vector.dictionary.BatchedDictionary; import org.apache.arrow.vector.dictionary.Dictionary; import org.apache.arrow.vector.dictionary.DictionaryEncoder; import org.apache.arrow.vector.dictionary.DictionaryProvider; @@ -87,6 +101,8 @@ import org.junit.After; import org.junit.Assert; import org.junit.Before; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.params.provider.Arguments; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -400,7 +416,7 @@ protected void validateFlatDictionary(VectorSchemaRoot root, DictionaryProvider Assert.assertEquals(2, vector2.getObject(4)); Assert.assertEquals(null, vector2.getObject(5)); - Dictionary dictionary1 = provider.lookup(1L); + BaseDictionary dictionary1 = provider.lookup(1L); Assert.assertNotNull(dictionary1); VarCharVector dictionaryVector = ((VarCharVector) dictionary1.getVector()); Assert.assertEquals(3, dictionaryVector.getValueCount()); @@ -408,7 +424,7 @@ protected void validateFlatDictionary(VectorSchemaRoot root, DictionaryProvider Assert.assertEquals(new Text("bar"), dictionaryVector.getObject(1)); Assert.assertEquals(new Text("baz"), dictionaryVector.getObject(2)); - Dictionary dictionary2 = provider.lookup(2L); + BaseDictionary dictionary2 = provider.lookup(2L); Assert.assertNotNull(dictionary2); dictionaryVector = ((VarCharVector) dictionary2.getVector()); Assert.assertEquals(3, dictionaryVector.getValueCount()); @@ -470,7 +486,7 @@ protected void validateNestedDictionary(VectorSchemaRoot root, DictionaryProvide Assert.assertEquals(Arrays.asList(0), vector.getObject(1)); Assert.assertEquals(Arrays.asList(1), vector.getObject(2)); - Dictionary dictionary = provider.lookup(2L); + BaseDictionary dictionary = provider.lookup(2L); Assert.assertNotNull(dictionary); VarCharVector dictionaryVector = ((VarCharVector) dictionary.getVector()); Assert.assertEquals(2, dictionaryVector.getValueCount()); @@ -846,4 +862,293 @@ protected void validateListAsMapData(VectorSchemaRoot root) { } } } + + /** + * Utility to write permutations of dictionary encoding. + * + * state == 1, one delta dictionary. + * state == 2, one standalone dictionary. + * state == 3, one of each + * state == 4, delta with nothing at start and end + * state == 5, both deltas + * state == 6, both deltas and standalone + * state == 7, replacement dictionary + */ + protected void writeDataMultiBatchWithDictionaries(OutputStream stream, int state) throws IOException { + DictionaryProvider.MapDictionaryProvider provider = new DictionaryProvider.MapDictionaryProvider(); + DictionaryEncoding deltaEncoding = + new DictionaryEncoding(42, false, new ArrowType.Int(16, false), true); + DictionaryEncoding replacementEncoding = + new DictionaryEncoding(24, false, new ArrowType.Int(16, false), false); + DictionaryEncoding deltaCEncoding = + new DictionaryEncoding(1, false, new ArrowType.Int(16, false), true); + DictionaryEncoding replacementEncodingUpdated = + new DictionaryEncoding(2, false, new ArrowType.Int(16, false), false); + + boolean isFile = stream instanceof FileOutputStream; + try (BatchedDictionary vectorA = newDictionary("vectorA", deltaEncoding, isFile); + BatchedDictionary vectorB = newDictionary("vectorB", replacementEncoding, isFile); + BatchedDictionary vectorC = newDictionary("vectorC", deltaCEncoding, isFile); + BatchedDictionary vectorD = newDictionary("vectorD", replacementEncodingUpdated, isFile);) { + switch (state) { + case 1: + provider.put(vectorA); + break; + case 2: + provider.put(vectorB); + break; + case 3: + provider.put(vectorA); + provider.put(vectorB); + break; + case 4: + provider.put(vectorC); + break; + case 5: + provider.put(vectorA); + provider.put(vectorC); + break; + case 6: + provider.put(vectorA); + provider.put(vectorB); + provider.put(vectorC); + break; + case 7: + provider.put(vectorD); + break; + default: + throw new IllegalStateException("Unsupported state: " + state); + } + + VectorSchemaRoot root = null; + switch (state) { + case 1: + root = VectorSchemaRoot.of(vectorA.getIndexVector()); + break; + case 2: + root = VectorSchemaRoot.of(vectorB.getIndexVector()); + break; + case 3: + root = VectorSchemaRoot.of(vectorA.getIndexVector(), vectorB.getIndexVector()); + break; + case 4: + root = VectorSchemaRoot.of(vectorC.getIndexVector()); + break; + case 5: + root = VectorSchemaRoot.of(vectorA.getIndexVector(), vectorC.getIndexVector()); + break; + case 6: + root = VectorSchemaRoot.of(vectorA.getIndexVector(), vectorB.getIndexVector(), vectorC.getIndexVector()); + break; + case 7: + root = VectorSchemaRoot.of(vectorD.getIndexVector()); + break; + default: + throw new IllegalStateException("Unsupported state: " + state); + } + + ArrowWriter arrowWriter = null; + try { + if (stream instanceof FileOutputStream) { + FileChannel channel = ((FileOutputStream) stream).getChannel(); + arrowWriter = new ArrowFileWriter(root, provider, channel); + } else { + arrowWriter = new ArrowStreamWriter(root, provider, stream); + } + + vectorA.setSafe(0, "foo".getBytes(StandardCharsets.UTF_8)); + vectorA.setSafe(1, "bar".getBytes(StandardCharsets.UTF_8)); + vectorB.setSafe(0, "lorem".getBytes(StandardCharsets.UTF_8)); + vectorB.setSafe(1, "ipsum".getBytes(StandardCharsets.UTF_8)); + vectorC.setNull(0); + vectorC.setNull(1); + vectorD.setSafe(0, "porro".getBytes(StandardCharsets.UTF_8)); + vectorD.setSafe(1, "amet".getBytes(StandardCharsets.UTF_8)); + + // batch 1 + arrowWriter.start(); + root.setRowCount(2); + arrowWriter.writeBatch(); + + // batch 2 + vectorA.setSafe(0, "meep".getBytes(StandardCharsets.UTF_8)); + vectorA.setSafe(1, "bar".getBytes(StandardCharsets.UTF_8)); + vectorB.setSafe(0, "ipsum".getBytes(StandardCharsets.UTF_8)); + vectorB.setSafe(1, "lorem".getBytes(StandardCharsets.UTF_8)); + vectorC.setSafe(0, "qui".getBytes(StandardCharsets.UTF_8)); + vectorC.setSafe(1, "dolor".getBytes(StandardCharsets.UTF_8)); + vectorD.setSafe(0, "amet".getBytes(StandardCharsets.UTF_8)); + if (state == 7) { + vectorD.setSafe(1, "quia".getBytes(StandardCharsets.UTF_8)); + } + + root.setRowCount(2); + arrowWriter.writeBatch(); + + // batch 3 + vectorA.setNull(0); + vectorA.setNull(1); + vectorB.setSafe(0, "ipsum".getBytes(StandardCharsets.UTF_8)); + vectorB.setNull(1); + vectorC.setNull(0); + vectorC.setSafe(1, "qui".getBytes(StandardCharsets.UTF_8)); + vectorD.setNull(0); + if (state == 7) { + vectorD.setSafe(1, "quia".getBytes(StandardCharsets.UTF_8)); + } + + root.setRowCount(2); + arrowWriter.writeBatch(); + + // batch 4 + vectorA.setSafe(0, "bar".getBytes(StandardCharsets.UTF_8)); + vectorA.setSafe(1, "zap".getBytes(StandardCharsets.UTF_8)); + vectorB.setNull(0); + vectorB.setSafe(1, "lorem".getBytes(StandardCharsets.UTF_8)); + vectorC.setNull(0); + vectorC.setNull(1); + if (state == 7) { + vectorD.setSafe(0, "quia".getBytes(StandardCharsets.UTF_8)); + } + vectorD.setNull(1); + + root.setRowCount(2); + arrowWriter.writeBatch(); + + arrowWriter.end(); + } catch (Exception e) { + if (arrowWriter != null) { + arrowWriter.close(); + } + throw e; + } + } + } + + Map valuesPerBlock = new HashMap(); + + { + valuesPerBlock.put(0, new String[][]{ + {"foo", "bar"}, + {"lorem", "ipsum"}, + {null, null}, + {"porro", "amet"} + }); + valuesPerBlock.put(1, new String[][]{ + {"meep", "bar"}, + {"ipsum", "lorem"}, + {"qui", "dolor"}, + {"amet", "quia"} + }); + valuesPerBlock.put(2, new String[][]{ + {null, null}, + {"ipsum", null}, + {null, "qui"}, + {null, "quia"} + }); + valuesPerBlock.put(3, new String[][]{ + {"bar", "zap"}, + {null, "lorem"}, + {null, null}, + {"quia", null} + }); + } + + protected void assertDictionary(FieldVector encoded, ArrowReader reader, String... expected) throws Exception { + DictionaryEncoding dictionaryEncoding = encoded.getField().getDictionary(); + Dictionary dictionary = reader.getDictionaryVectors().get(dictionaryEncoding.getId()); + try (ValueVector decoded = DictionaryEncoder.decode(encoded, dictionary)) { + Assertions.assertEquals(expected.length, encoded.getValueCount()); + for (int i = 0; i < expected.length; i++) { + if (expected[i] == null) { + Assertions.assertNull(decoded.getObject(i)); + } else { + assertNotNull(decoded.getObject(i)); + Assertions.assertEquals(expected[i], decoded.getObject(i).toString()); + } + } + } + } + + protected void assertBlock(File file, int block, int state) throws Exception { + try (FileInputStream fileInputStream = new FileInputStream(file); + ArrowFileReader reader = new ArrowFileReader(fileInputStream.getChannel(), allocator);) { + reader.loadRecordBatch(reader.getRecordBlocks().get(block)); + assertBlock(reader, block, state); + } + } + + protected void assertBlock(ArrowReader reader, int block, int state) throws Exception { + VectorSchemaRoot r = reader.getVectorSchemaRoot(); + FieldVector dictA = r.getVector("vectorA"); + FieldVector dictB = r.getVector("vectorB"); + FieldVector dictC = r.getVector("vectorC"); + FieldVector dictD = r.getVector("vectorD"); + + switch (state) { + case 1: + assertDictionary(dictA, reader, valuesPerBlock.get(block)[0]); + assertNull(dictB); + assertNull(dictC); + assertNull(dictD); + break; + case 2: + assertNull(dictA); + assertDictionary(dictB, reader, valuesPerBlock.get(block)[1]); + assertNull(dictC); + assertNull(dictD); + break; + case 3: + assertDictionary(dictA, reader, valuesPerBlock.get(block)[0]); + assertDictionary(dictB, reader, valuesPerBlock.get(block)[1]); + assertNull(dictC); + assertNull(dictD); + break; + case 4: + assertNull(dictA); + assertNull(dictB); + assertDictionary(dictC, reader, valuesPerBlock.get(block)[2]); + assertNull(dictD); + break; + case 5: + assertDictionary(dictA, reader, valuesPerBlock.get(block)[0]); + assertNull(dictB); + assertDictionary(dictC, reader, valuesPerBlock.get(block)[2]); + assertNull(dictD); + break; + case 6: + assertDictionary(dictA, reader, valuesPerBlock.get(block)[0]); + assertDictionary(dictB, reader, valuesPerBlock.get(block)[1]); + assertDictionary(dictC, reader, valuesPerBlock.get(block)[2]); + assertNull(dictD); + break; + case 7: + assertNull(dictA); + assertNull(dictB); + assertNull(dictC); + assertDictionary(dictD, reader, valuesPerBlock.get(block)[3]); + break; + default: + throw new IllegalStateException("Unsupported state: " + state); + } + } + + protected static Collection dictionaryParams() { + List params = new ArrayList<>(); + for (int i = 1; i < 8; i++) { + params.add(Arguments.of(i)); + } + return params; + } + + protected BatchedDictionary newDictionary(String name, DictionaryEncoding encoding, boolean isFile) { + return new BatchedDictionary( + name, + encoding, + new ArrowType.Utf8(), + new ArrowType.Int(16, false), + isFile, + allocator + ); + } } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowFile.java b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowFile.java index 4fb5822786083..7f66cee6e55ba 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowFile.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowFile.java @@ -20,11 +20,13 @@ import static java.nio.channels.Channels.newChannel; import static org.apache.arrow.vector.TestUtils.newVarCharVector; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.File; +import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.IOException; import java.io.OutputStream; @@ -33,19 +35,36 @@ import java.util.List; import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.util.Collections2; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.complex.StructVector; import org.apache.arrow.vector.types.pojo.Field; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class TestArrowFile extends BaseFileTest { private static final Logger LOGGER = LoggerFactory.getLogger(TestArrowFile.class); + // overriding here since a number of other UTs sharing the BaseFileTest class use + // legacy JUnit. + @BeforeEach + public void init() { + allocator = new RootAllocator(Integer.MAX_VALUE); + } + + @AfterEach + public void tearDown() { + allocator.close(); + } + @Test public void testWrite() throws IOException { File file = new File("target/mytest_write.arrow"); @@ -131,4 +150,67 @@ public void testFileStreamHasEos() throws IOException { } } } + + @ParameterizedTest + @MethodSource("dictionaryParams") + public void testMultiBatchDictionaries(int state) throws Exception { + File file = new File("target/mytest_multi_batch_dictionaries_" + state + ".arrow"); + try (FileOutputStream stream = new FileOutputStream(file)) { + if (state == 7) { + assertThrows(IllegalStateException.class, () -> writeDataMultiBatchWithDictionaries(stream, state)); + return; + } else { + writeDataMultiBatchWithDictionaries(stream, state); + } + } + + try (FileInputStream fileInputStream = new FileInputStream(file); + ArrowFileReader reader = new ArrowFileReader(fileInputStream.getChannel(), allocator);) { + for (int i = 0; i < 4; i++) { + reader.loadNextBatch(); + assertBlock(reader, i, state); + } + } + } + + @ParameterizedTest + @MethodSource("dictionaryParams") + public void testMultiBatchDictionariesOutOfOrder(int state) throws Exception { + File file = new File("target/mytest_multi_batch_dictionaries_ooo_" + state + ".arrow"); + try (FileOutputStream stream = new FileOutputStream(file)) { + if (state == 7) { + assertThrows(IllegalStateException.class, () -> writeDataMultiBatchWithDictionaries(stream, state)); + return; + } else { + writeDataMultiBatchWithDictionaries(stream, state); + } + } + try (FileInputStream fileInputStream = new FileInputStream(file); + ArrowFileReader reader = new ArrowFileReader(fileInputStream.getChannel(), allocator);) { + int[] order = new int[] {2, 1, 3, 0}; + for (int i = 0; i < 4; i++) { + int block = order[i]; + reader.loadRecordBatch(reader.getRecordBlocks().get(block)); + assertBlock(reader, block, state); + } + } + } + + @ParameterizedTest + @MethodSource("dictionaryParams") + public void testMultiBatchDictionariesSeek(int state) throws Exception { + File file = new File("target/mytest_multi_batch_dictionaries_seek_" + state + ".arrow"); + try (FileOutputStream stream = new FileOutputStream(file)) { + if (state == 7) { + assertThrows(IllegalStateException.class, () -> writeDataMultiBatchWithDictionaries(stream, state)); + return; + } else { + writeDataMultiBatchWithDictionaries(stream, state); + } + } + for (int i = 0; i < 4; i++) { + assertBlock(file, i, state); + } + } + } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowStream.java b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowStream.java index 9348cd3a66708..c13328f9e8a1a 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowStream.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowStream.java @@ -27,14 +27,32 @@ import java.nio.channels.Channels; import java.util.Collections; +import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.IntVector; import org.apache.arrow.vector.TinyIntVector; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.Assert; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; public class TestArrowStream extends BaseFileTest { + + // overriding here since a number of other UTs sharing the BaseFileTest class use + // legacy JUnit. + @BeforeEach + public void init() { + allocator = new RootAllocator(Integer.MAX_VALUE); + } + + @AfterEach + public void tearDown() { + allocator.close(); + } + @Test public void testEmptyStream() throws IOException { Schema schema = MessageSerializerTest.testSchema(); @@ -64,7 +82,7 @@ public void testStreamZeroLengthBatch() throws IOException { try (IntVector vector = new IntVector("foo", allocator);) { Schema schema = new Schema(Collections.singletonList(vector.getField())); try (VectorSchemaRoot root = - new VectorSchemaRoot(schema, Collections.singletonList(vector), vector.getValueCount()); + new VectorSchemaRoot(schema, Collections.singletonList(vector), vector.getValueCount()); ArrowStreamWriter writer = new ArrowStreamWriter(root, null, Channels.newChannel(os));) { vector.setValueCount(0); root.setRowCount(0); @@ -131,7 +149,7 @@ public void testReadWriteMultipleBatches() throws IOException { try (IntVector vector = new IntVector("foo", allocator);) { Schema schema = new Schema(Collections.singletonList(vector.getField())); try (VectorSchemaRoot root = - new VectorSchemaRoot(schema, Collections.singletonList(vector), vector.getValueCount()); + new VectorSchemaRoot(schema, Collections.singletonList(vector), vector.getValueCount()); ArrowStreamWriter writer = new ArrowStreamWriter(root, null, Channels.newChannel(os));) { writeBatchData(writer, vector, root); } @@ -144,4 +162,21 @@ public void testReadWriteMultipleBatches() throws IOException { validateBatchData(reader, vector); } } + + @ParameterizedTest + @MethodSource("dictionaryParams") + public void testMultiBatchDictionaries(int state) throws Exception { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + writeDataMultiBatchWithDictionaries(out, state); + + ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); + out.close(); + + try (ArrowStreamReader reader = new ArrowStreamReader(in, allocator)) { + for (int i = 0; i < 4; i++) { + reader.loadNextBatch(); + assertBlock(reader, i, state); + } + } + } }