diff --git a/java/c/src/main/java/org/apache/arrow/c/ArrayExporter.java b/java/c/src/main/java/org/apache/arrow/c/ArrayExporter.java index d6479a3ba4ca8..132a2044a613a 100644 --- a/java/c/src/main/java/org/apache/arrow/c/ArrayExporter.java +++ b/java/c/src/main/java/org/apache/arrow/c/ArrayExporter.java @@ -29,7 +29,7 @@ import org.apache.arrow.memory.ArrowBuf; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.vector.FieldVector; -import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.BaseDictionary; import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.arrow.vector.types.pojo.DictionaryEncoding; @@ -110,7 +110,7 @@ void export(ArrowArray array, FieldVector vector, DictionaryProvider dictionaryP } if (dictionaryEncoding != null) { - Dictionary dictionary = dictionaryProvider.lookup(dictionaryEncoding.getId()); + BaseDictionary dictionary = dictionaryProvider.lookup(dictionaryEncoding.getId()); checkNotNull(dictionary, "Dictionary lookup failed on export of dictionary encoded array"); data.dictionary = ArrowArray.allocateNew(allocator); diff --git a/java/c/src/main/java/org/apache/arrow/c/ArrayImporter.java b/java/c/src/main/java/org/apache/arrow/c/ArrayImporter.java index 7132887ddeed5..af851880c5449 100644 --- a/java/c/src/main/java/org/apache/arrow/c/ArrayImporter.java +++ b/java/c/src/main/java/org/apache/arrow/c/ArrayImporter.java @@ -29,7 +29,7 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.util.Preconditions; import org.apache.arrow.vector.FieldVector; -import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.BaseDictionary; import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.arrow.vector.ipc.message.ArrowFieldNode; import org.apache.arrow.vector.types.pojo.DictionaryEncoding; @@ -103,7 +103,7 @@ private void doImport(ArrowArray.Snapshot snapshot) { DictionaryEncoding encoding = vector.getField().getDictionary(); checkNotNull(encoding, "Missing encoding on import of ArrowArray with dictionary"); - Dictionary dictionary = dictionaryProvider.lookup(encoding.getId()); + BaseDictionary dictionary = dictionaryProvider.lookup(encoding.getId()); checkNotNull(dictionary, "Dictionary lookup failed on import of ArrowArray with dictionary"); // reset the dictionary vector to the initial state diff --git a/java/c/src/main/java/org/apache/arrow/c/ArrowArrayStreamReader.java b/java/c/src/main/java/org/apache/arrow/c/ArrowArrayStreamReader.java index b39a3be9b842f..b8b504d249100 100644 --- a/java/c/src/main/java/org/apache/arrow/c/ArrowArrayStreamReader.java +++ b/java/c/src/main/java/org/apache/arrow/c/ArrowArrayStreamReader.java @@ -26,7 +26,7 @@ import java.util.stream.Collectors; import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.BaseDictionary; import org.apache.arrow.vector.ipc.ArrowReader; import org.apache.arrow.vector.types.pojo.Schema; @@ -52,12 +52,12 @@ final class ArrowArrayStreamReader extends ArrowReader { } @Override - public Map getDictionaryVectors() { + public Map getDictionaryVectors() { return provider.getDictionaryIds().stream().collect(Collectors.toMap(Function.identity(), provider::lookup)); } @Override - public Dictionary lookup(long id) { + public BaseDictionary lookup(long id) { return provider.lookup(id); } diff --git a/java/c/src/main/java/org/apache/arrow/c/CDataDictionaryProvider.java 
b/java/c/src/main/java/org/apache/arrow/c/CDataDictionaryProvider.java index 4a84f11704c9a..fa534525efcf7 100644 --- a/java/c/src/main/java/org/apache/arrow/c/CDataDictionaryProvider.java +++ b/java/c/src/main/java/org/apache/arrow/c/CDataDictionaryProvider.java @@ -21,7 +21,8 @@ import java.util.Map; import java.util.Set; -import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.BaseDictionary; +import org.apache.arrow.vector.dictionary.BatchedDictionary; import org.apache.arrow.vector.dictionary.DictionaryProvider; /** @@ -39,14 +40,14 @@ */ public class CDataDictionaryProvider implements DictionaryProvider, AutoCloseable { - private final Map map; + private final Map map; public CDataDictionaryProvider() { this.map = new HashMap<>(); } - void put(Dictionary dictionary) { - Dictionary previous = map.put(dictionary.getEncoding().getId(), dictionary); + void put(BaseDictionary dictionary) { + BaseDictionary previous = map.put(dictionary.getEncoding().getId(), dictionary); if (previous != null) { previous.getVector().close(); } @@ -58,16 +59,25 @@ public final Set getDictionaryIds() { } @Override - public Dictionary lookup(long id) { + public BaseDictionary lookup(long id) { return map.get(id); } @Override public void close() { - for (Dictionary dictionary : map.values()) { + for (BaseDictionary dictionary : map.values()) { dictionary.getVector().close(); } map.clear(); } + @Override + public void resetDictionaries() { + map.values().forEach( dictionary -> { + if (dictionary instanceof BatchedDictionary) { + ((BatchedDictionary) dictionary).reset(); + } + }); + } + } diff --git a/java/c/src/main/java/org/apache/arrow/c/SchemaExporter.java b/java/c/src/main/java/org/apache/arrow/c/SchemaExporter.java index 04d41a4e4f9b0..dda5867024042 100644 --- a/java/c/src/main/java/org/apache/arrow/c/SchemaExporter.java +++ b/java/c/src/main/java/org/apache/arrow/c/SchemaExporter.java @@ -28,7 +28,7 @@ import org.apache.arrow.c.jni.PrivateData; import org.apache.arrow.memory.ArrowBuf; import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.BaseDictionary; import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.arrow.vector.types.pojo.DictionaryEncoding; import org.apache.arrow.vector.types.pojo.Field; @@ -95,7 +95,7 @@ void export(ArrowSchema schema, Field field, DictionaryProvider dictionaryProvid } if (dictionaryEncoding != null) { - Dictionary dictionary = dictionaryProvider.lookup(dictionaryEncoding.getId()); + BaseDictionary dictionary = dictionaryProvider.lookup(dictionaryEncoding.getId()); checkNotNull(dictionary, "Dictionary lookup failed on export of field with dictionary"); data.dictionary = ArrowSchema.allocateNew(allocator); diff --git a/java/c/src/test/java/org/apache/arrow/c/StreamTest.java b/java/c/src/test/java/org/apache/arrow/c/StreamTest.java index 68d4fc2a81e68..6a09a075ecd50 100644 --- a/java/c/src/test/java/org/apache/arrow/c/StreamTest.java +++ b/java/c/src/test/java/org/apache/arrow/c/StreamTest.java @@ -42,6 +42,7 @@ import org.apache.arrow.vector.VectorUnloader; import org.apache.arrow.vector.compare.Range; import org.apache.arrow.vector.compare.RangeEqualsVisitor; +import org.apache.arrow.vector.dictionary.BaseDictionary; import org.apache.arrow.vector.dictionary.Dictionary; import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.arrow.vector.ipc.ArrowReader; @@ -244,7 +245,7 @@ void roundtrip(Schema 
schema, List batches, DictionaryProvider } assertThat(reader.loadNextBatch()).isFalse(); assertThat(reader.getDictionaryIds()).isEqualTo(provider.getDictionaryIds()); - for (Map.Entry entry : reader.getDictionaryVectors().entrySet()) { + for (Map.Entry entry : reader.getDictionaryVectors().entrySet()) { final FieldVector expected = provider.lookup(entry.getKey()).getVector(); final FieldVector actual = entry.getValue().getVector(); assertVectorsEqual(expected, actual); @@ -286,7 +287,7 @@ static class InMemoryArrowReader extends ArrowReader { } @Override - public Dictionary lookup(long id) { + public BaseDictionary lookup(long id) { return provider.lookup(id); } @@ -296,7 +297,7 @@ public Set getDictionaryIds() { } @Override - public Map getDictionaryVectors() { + public Map getDictionaryVectors() { return getDictionaryIds().stream().collect(Collectors.toMap(Function.identity(), this::lookup)); } diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/DictionaryUtils.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/DictionaryUtils.java index 516dab01d8a1b..9970e44a39759 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/DictionaryUtils.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/DictionaryUtils.java @@ -30,7 +30,7 @@ import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.VectorUnloader; -import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.BaseDictionary; import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch; import org.apache.arrow.vector.ipc.message.IpcOption; @@ -66,7 +66,7 @@ static Schema generateSchemaMessages(final Schema originalSchema, final FlightDe } // Create and write dictionary batches for (Long id : dictionaryIds) { - final Dictionary dictionary = provider.lookup(id); + final BaseDictionary dictionary = provider.lookup(id); final FieldVector vector = dictionary.getVector(); final int count = vector.getValueCount(); // Do NOT close this root, as it does not actually own the vector. 
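The DictionaryUtils change above is typical of how call sites migrate: code that used to hold a Dictionary obtained from a provider now only needs the BaseDictionary supertype introduced by this patch. The following is a minimal sketch, not part of the patch, that uses only methods visible in this diff; the class name DictionarySummary and method name printDictionarySizes are illustrative.

import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.dictionary.BaseDictionary;
import org.apache.arrow.vector.dictionary.DictionaryProvider;

final class DictionarySummary {
  // Works for both a plain Dictionary and the new BatchedDictionary, since the
  // provider now exposes either one through the common BaseDictionary interface.
  static void printDictionarySizes(DictionaryProvider provider) {
    for (Long id : provider.getDictionaryIds()) {
      BaseDictionary dictionary = provider.lookup(id);
      FieldVector vector = dictionary.getVector(); // still owned by the provider; do not close it here
      System.out.println("dictionary " + dictionary.getEncoding().getId()
          + " holds " + vector.getValueCount() + " entries");
    }
  }
}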
diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightStream.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightStream.java index ad4ffcbebdec1..0269a41deeb17 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightStream.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightStream.java @@ -37,7 +37,7 @@ import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.VectorLoader; import org.apache.arrow.vector.VectorSchemaRoot; -import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.BaseDictionary; import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch; import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; @@ -268,7 +268,7 @@ public boolean next() { if (dictionaries == null) { throw new IllegalStateException("Dictionary ownership was claimed by the application."); } - final Dictionary dictionary = dictionaries.lookup(id); + final BaseDictionary dictionary = dictionaries.lookup(id); if (dictionary == null) { throw new IllegalArgumentException("Dictionary not defined in schema: ID " + id); } @@ -410,12 +410,12 @@ public void onNext(ArrowMessage msg) { } final List fields = new ArrayList<>(); - final Map dictionaryMap = new HashMap<>(); + final Map dictionaryMap = new HashMap<>(); for (final Field originalField : schema.getFields()) { final Field updatedField = DictionaryUtility.toMemoryFormat(originalField, allocator, dictionaryMap); fields.add(updatedField); } - for (final Map.Entry entry : dictionaryMap.entrySet()) { + for (final Map.Entry entry : dictionaryMap.entrySet()) { dictionaries.put(entry.getValue()); } schema = new Schema(fields, schema.getCustomMetadata()); diff --git a/java/memory/memory-core/src/test/java/org/apache/arrow/memory/util/TestArrowBufPointer.java b/java/memory/memory-core/src/test/java/org/apache/arrow/memory/util/TestArrowBufPointer.java index 49c10787fbe8d..0fb4308ed516b 100644 --- a/java/memory/memory-core/src/test/java/org/apache/arrow/memory/util/TestArrowBufPointer.java +++ b/java/memory/memory-core/src/test/java/org/apache/arrow/memory/util/TestArrowBufPointer.java @@ -206,7 +206,7 @@ public int hashCode(ArrowBuf buf, long offset, long length) { @Override public int hashCode(byte[] buf, int offset, int length) { - return 0; + throw new UnsupportedOperationException("Not used in UT."); } @Override diff --git a/java/memory/memory-core/src/test/java/org/apache/arrow/memory/util/hash/TestArrowBufHasher.java b/java/memory/memory-core/src/test/java/org/apache/arrow/memory/util/hash/TestArrowBufHasher.java index d016b7b50e25b..70c488f05217b 100644 --- a/java/memory/memory-core/src/test/java/org/apache/arrow/memory/util/hash/TestArrowBufHasher.java +++ b/java/memory/memory-core/src/test/java/org/apache/arrow/memory/util/hash/TestArrowBufHasher.java @@ -150,11 +150,11 @@ private void verifyHashCodeNotEqual(ArrowBuf buf1, byte[] ba1, int offset1, int @Parameterized.Parameters(name = "hasher = {0}") public static Collection getHasher() { return Arrays.asList( - new Object[] {SimpleHasher.class.getSimpleName(), - SimpleHasher.INSTANCE}, - new Object[] {MurmurHasher.class.getSimpleName(), - new MurmurHasher() - } + new Object[] {SimpleHasher.class.getSimpleName(), + SimpleHasher.INSTANCE}, + new Object[] {MurmurHasher.class.getSimpleName(), + new MurmurHasher() + } ); } } diff --git 
a/java/tools/src/test/java/org/apache/arrow/tools/EchoServerTest.java b/java/tools/src/test/java/org/apache/arrow/tools/EchoServerTest.java index 714cb416bf996..060ce3487de04 100644 --- a/java/tools/src/test/java/org/apache/arrow/tools/EchoServerTest.java +++ b/java/tools/src/test/java/org/apache/arrow/tools/EchoServerTest.java @@ -40,6 +40,7 @@ import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.complex.ListVector; import org.apache.arrow.vector.complex.impl.UnionListWriter; +import org.apache.arrow.vector.dictionary.BaseDictionary; import org.apache.arrow.vector.dictionary.Dictionary; import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.arrow.vector.dictionary.DictionaryProvider.MapDictionaryProvider; @@ -207,7 +208,7 @@ public void testFlatDictionary() throws IOException { Assert.assertEquals(1, readVector.getObject(4)); Assert.assertEquals(2, readVector.getObject(5)); - Dictionary dictionary = reader.lookup(1L); + BaseDictionary dictionary = reader.lookup(1L); Assert.assertNotNull(dictionary); VarCharVector dictionaryVector = ((VarCharVector) dictionary.getVector()); Assert.assertEquals(3, dictionaryVector.getValueCount()); @@ -289,7 +290,7 @@ public void testNestedDictionary() throws IOException { Assert.assertEquals(Arrays.asList(0), readVector.getObject(1)); Assert.assertEquals(Arrays.asList(1), readVector.getObject(2)); - Dictionary readDictionary = reader.lookup(2L); + BaseDictionary readDictionary = reader.lookup(2L); Assert.assertNotNull(readDictionary); VarCharVector dictionaryVector = ((VarCharVector) readDictionary.getVector()); Assert.assertEquals(2, dictionaryVector.getValueCount()); diff --git a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/BaseDictionary.java b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/BaseDictionary.java new file mode 100644 index 0000000000000..15139c60c33d0 --- /dev/null +++ b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/BaseDictionary.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.vector.dictionary; + +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.ipc.ArrowWriter; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; + +/** + * Base interface for various dictionary implementations. Implementations include + * {@link Dictionary} for encoding a complete vector and {@link BatchedDictionary} + * for continuous encoding of a vector. + * These methods provide means of accessing the dictionary vector containing the + * encoded data. + */ +public interface BaseDictionary { + + /** + * The dictionary vector containing unique entries. 
+ */ + FieldVector getVector(); + + /** + * The encoding used for the dictionary vector. + */ + DictionaryEncoding getEncoding(); + + /** + * The type of the dictionary vector. + */ + ArrowType getVectorType(); + + /** + * Mark the dictionary as complete for the batch. Called by the {@link ArrowWriter} + * on {@link ArrowWriter#writeBatch()}. + * @return The number of values written to the dictionary. + */ + int mark(); + + /** + * Resets the dictionary to be used for a new batch. Called by the {@link ArrowWriter} on + * {@link ArrowWriter#writeBatch()}. + */ + void reset(); + +} diff --git a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/BatchedDictionary.java b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/BatchedDictionary.java new file mode 100644 index 0000000000000..75cd99cbb3a91 --- /dev/null +++ b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/BatchedDictionary.java @@ -0,0 +1,312 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.vector.dictionary; + +import java.io.Closeable; +import java.io.IOException; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.util.hash.MurmurHasher; +import org.apache.arrow.util.VisibleForTesting; +import org.apache.arrow.vector.BaseIntVector; +import org.apache.arrow.vector.BaseVariableWidthVector; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; +import org.apache.arrow.vector.types.pojo.FieldType; + +/** + * A dictionary implementation for continuous encoding of data in a dictionary and + * index vector as opposed to the {@link Dictionary} that encodes a complete vector. + * Supports delta or replacement encoding. + */ +public class BatchedDictionary implements Closeable, BaseDictionary { + + private final DictionaryEncoding encoding; + + private final BaseVariableWidthVector dictionary; + + private final BaseIntVector indexVector; + + private final DictionaryHashTable hashTable; + + private int deltaIndex; + + private int dictionaryIndex; + + private boolean oneTimeEncoding; + + private boolean wasReset; + + private boolean vectorsProvided; + + /** + * Creates a dictionary and index vector with the respective types. The dictionary vector + * will be named "{name}-dictionary". + *

+ * To use this dictionary, provide the dictionary vector to a {@link DictionaryProvider}, + * add the {@link #getIndexVector()} to the {@link org.apache.arrow.vector.VectorSchemaRoot} + * and call the {@link #setSafe(int, byte[], int, int)} or other set methods. + * + * @param name A name for the vector and dictionary. + * @param encoding The dictionary encoding to use. + * @param dictionaryType The type of the dictionary data. + * @param indexType The type of the encoded dictionary index. + * @param allocator The allocator to use. + */ + public BatchedDictionary( + String name, + DictionaryEncoding encoding, + ArrowType dictionaryType, + ArrowType indexType, + BufferAllocator allocator + ) { + this(name, encoding, dictionaryType, indexType, allocator, "-dictionary"); + } + + /** + * Creates a dictionary and index vector with the respective types. + * + * @param name A name for the vector and dictionary. + * @param encoding The dictionary encoding to use. + * @param dictionaryType The type of the dictionary data. + * @param indexType The type of the encoded dictionary index. + * @param allocator The allocator to use. + * @param suffix A non-null suffix to append to the name of the dictionary. + */ + public BatchedDictionary( + String name, + DictionaryEncoding encoding, + ArrowType dictionaryType, + ArrowType indexType, + BufferAllocator allocator, + String suffix + ) { + this(name, encoding, dictionaryType, indexType, allocator, suffix, false); + } + + /** + * Creates a dictionary and index vector with the respective types. + * + * @param name A name for the vector and dictionary. + * @param encoding The dictionary encoding to use. + * @param dictionaryType The type of the dictionary data. + * @param indexType The type of the encoded dictionary index. + * @param allocator The allocator to use. + * @param suffix A non-null suffix to append to the name of the dictionary. + * @param oneTimeEncoding A mode where the entries can be added to the dictionary until + * the first stream batch is written. After that, any new entries + * to the dictionary will throw an exception. + */ + public BatchedDictionary( + String name, + DictionaryEncoding encoding, + ArrowType dictionaryType, + ArrowType indexType, + BufferAllocator allocator, + String suffix, + boolean oneTimeEncoding + ) { + this.encoding = encoding; + this.oneTimeEncoding = oneTimeEncoding; + if (dictionaryType.getTypeID() != ArrowType.ArrowTypeID.Utf8 && + dictionaryType.getTypeID() != ArrowType.ArrowTypeID.Binary) { + throw new IllegalArgumentException("Dictionary must be a subclass of 'BaseVariableWidthVector' " + + "such as 'VarCharVector'."); + } + if (indexType.getTypeID() != ArrowType.ArrowTypeID.Int) { + throw new IllegalArgumentException("Index vector must be a subclass of 'BaseIntVector' " + + "such as 'IntVector' or 'UInt4Vector'."); + } + FieldVector vector = new FieldType(false, dictionaryType, null) + .createNewSingleVector(name + suffix, allocator, null); + dictionary = (BaseVariableWidthVector) vector; + vector = new FieldType(true, indexType, encoding) + .createNewSingleVector(name, allocator, null); + indexVector = (BaseIntVector) vector; + hashTable = new DictionaryHashTable(); + } + + /** + * Creates a dictionary that will populate the provided vectors with data. Useful if + * dictionaries need to be children of a parent vector. + * WARNING: The caller is responsible for closing the provided vectors. + * + * @param dictionary The dictionary to hold the original data.
+ * @param indexVector The index to store the encoded offsets. + */ + public BatchedDictionary( + FieldVector dictionary, + FieldVector indexVector + ) { + this(dictionary, indexVector, false); + } + + /** + * Creates a dictionary that will populate the provided vectors with data. Useful if + * dictionaries need to be children of a parent vector. + * WARNING: The caller is responsible for closing the provided vectors. + * + * @param dictionary The dictionary to hold the original data. + * @param indexVector The index to store the encoded offsets. + * @param oneTimeEncoding A mode where the entries can be added to the dictionary until + * the first stream batch is written. After that, any new entries + * to the dictionary will throw an exception. + */ + public BatchedDictionary( + FieldVector dictionary, + FieldVector indexVector, + boolean oneTimeEncoding + ) { + this.encoding = dictionary.getField().getDictionary(); + this.oneTimeEncoding = oneTimeEncoding; + vectorsProvided = true; + if (!(BaseVariableWidthVector.class.isAssignableFrom(dictionary.getClass()))) { + throw new IllegalArgumentException("Dictionary must be a subclass of 'BaseVariableWidthVector' " + + "such as 'VarCharVector'."); + } + if (dictionary.getField().isNullable()) { + throw new IllegalArgumentException("Dictionary must be non-nullable."); + } + this.dictionary = (BaseVariableWidthVector) dictionary; + if (!(BaseIntVector.class.isAssignableFrom(indexVector.getClass()))) { + throw new IllegalArgumentException("Index vector must be a subclass of 'BaseIntVector' " + + "such as 'IntVector' or 'UInt4Vector'."); + } + this.indexVector = (BaseIntVector) indexVector; + hashTable = new DictionaryHashTable(); + } + + /** + * The index vector. + */ + public FieldVector getIndexVector() { + return indexVector; + } + + @Override + public FieldVector getVector() { + return dictionary; + } + + @Override + public ArrowType getVectorType() { + return dictionary.getField().getType(); + } + + @Override + public DictionaryEncoding getEncoding() { + return encoding; + } + + /** + * Considers the entire byte array as the dictionary value. If the value is null, + * a null will be written to the index. + * + * @param index the position in the index vector to set + * @param value the value to write. + */ + public void setSafe(int index, byte[] value) { + if (value == null) { + setNull(index); + return; + } + setSafe(index, value, 0, value.length); + } + + /** + * Encodes the given range of the byte array as a dictionary value. If the value is null, a null will be + * written to the index. + * + * @param index the position in the index vector to set + * @param value the value to write. + * @param offset An offset into the value array. + * @param len The length of the value to write. + */ + public void setSafe(int index, byte[] value, int offset, int len) { + if (value == null || len == 0) { + setNull(index); + return; + } + int di = getIndex(value, offset, len); + indexVector.setWithPossibleTruncate(index, di); + } + + /** + * Set the element at the given index to null. + * + * @param index the position in the index vector to set + */ + public void setNull(int index) { + indexVector.setNull(index); + } + + @Override + public void close() throws IOException { + if (!vectorsProvided) { + dictionary.close(); + indexVector.close(); + } + } + + @Override + public int mark() { + dictionary.setValueCount(dictionaryIndex); + // not setting the index vector value count. That will happen when the user calls + // VectorSchemaRoot#setRowCount().
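// Illustrative write loop (a sketch, not part of this patch): the index vector is sized by
// VectorSchemaRoot#setRowCount(), while mark()/reset() are driven by the writer through
// DictionaryProvider#resetDictionaries(). The names "city", "allocator" and "out" are invented;
// the rest assumes the BatchedDictionary API above plus the standard VectorSchemaRoot and
// ArrowStreamWriter entry points.
//
//   DictionaryEncoding encoding = new DictionaryEncoding(1L, false, new ArrowType.Int(32, true), true);
//   try (BatchedDictionary dict = new BatchedDictionary("city", encoding,
//            new ArrowType.Utf8(), new ArrowType.Int(32, true), allocator);
//        VectorSchemaRoot root = VectorSchemaRoot.of(dict.getIndexVector());
//        ArrowStreamWriter writer = new ArrowStreamWriter(root,
//            new DictionaryProvider.MapDictionaryProvider(dict), out)) {
//     writer.start();
//     dict.setSafe(0, "tokyo".getBytes(StandardCharsets.UTF_8));
//     dict.setSafe(1, "osaka".getBytes(StandardCharsets.UTF_8));
//     root.setRowCount(2);   // sizes the index vector for this batch
//     writer.writeBatch();   // mark()s the dictionary, writes it, then resetDictionaries()
//     writer.end();
//   }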
+ if (wasReset && oneTimeEncoding && !encoding.isDelta()) { + return 0; + } + return dictionaryIndex; + } + + @Override + public void reset() { + wasReset = true; + if (!oneTimeEncoding) { + dictionaryIndex = 0; + dictionary.reset(); + } + indexVector.reset(); + if (!oneTimeEncoding && !encoding.isDelta()) { + // replacement mode. + deltaIndex = 0; + hashTable.clear(); + } + } + + private int getIndex(byte[] value, int offset, int len) { + int hash = MurmurHasher.hashCode(value, offset, len, 0); + int i = hashTable.getIndex(hash); + if (i >= 0) { + return i; + } else { + if (wasReset && oneTimeEncoding && !encoding.isDelta()) { + throw new IllegalStateException("Dictionary was reset, not delta encoded and configured for onetime encoding."); + } + hashTable.addEntry(hash, deltaIndex); + dictionary.setSafe(dictionaryIndex++, value, offset, len); + return deltaIndex++; + } + } + + @VisibleForTesting + DictionaryHashTable getHashTable() { + return hashTable; + } +} diff --git a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/Dictionary.java b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/Dictionary.java index 6f40e5814b972..bffc0bbc04461 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/Dictionary.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/Dictionary.java @@ -28,7 +28,7 @@ * A dictionary (integer to Value mapping) that is used to facilitate * dictionary encoding compression. */ -public class Dictionary { +public class Dictionary implements BaseDictionary { private final DictionaryEncoding encoding; private final FieldVector dictionary; @@ -72,4 +72,14 @@ public boolean equals(Object o) { public int hashCode() { return Objects.hash(encoding, dictionary); } + + @Override + public int mark() { + return dictionary.getValueCount(); + } + + @Override + public void reset() { + // no-op + } } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryEncoder.java b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryEncoder.java index 4368501ffc7b5..959a1eb1d6627 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryEncoder.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryEncoder.java @@ -38,20 +38,20 @@ public class DictionaryEncoder { private final DictionaryHashTable hashTable; - private final Dictionary dictionary; + private final BaseDictionary dictionary; private final BufferAllocator allocator; /** * Construct an instance. */ - public DictionaryEncoder(Dictionary dictionary, BufferAllocator allocator) { + public DictionaryEncoder(BaseDictionary dictionary, BufferAllocator allocator) { this (dictionary, allocator, SimpleHasher.INSTANCE); } /** * Construct an instance. 
*/ - public DictionaryEncoder(Dictionary dictionary, BufferAllocator allocator, ArrowBufHasher hasher) { + public DictionaryEncoder(BaseDictionary dictionary, BufferAllocator allocator, ArrowBufHasher hasher) { this.dictionary = dictionary; this.allocator = allocator; hashTable = new DictionaryHashTable(dictionary.getVector(), hasher); @@ -64,7 +64,7 @@ public DictionaryEncoder(Dictionary dictionary, BufferAllocator allocator, Arrow * @param dictionary dictionary used for encoding * @return dictionary encoded vector */ - public static ValueVector encode(ValueVector vector, Dictionary dictionary) { + public static ValueVector encode(ValueVector vector, BaseDictionary dictionary) { DictionaryEncoder encoder = new DictionaryEncoder(dictionary, vector.getAllocator()); return encoder.encode(vector); } @@ -76,7 +76,7 @@ public static ValueVector encode(ValueVector vector, Dictionary dictionary) { * @param dictionary dictionary used to decode the values * @return vector with values restored from dictionary */ - public static ValueVector decode(ValueVector indices, Dictionary dictionary) { + public static ValueVector decode(ValueVector indices, BaseDictionary dictionary) { return decode(indices, dictionary, indices.getAllocator()); } @@ -88,7 +88,7 @@ public static ValueVector decode(ValueVector indices, Dictionary dictionary) { * @param allocator allocator the decoded values use * @return vector with values restored from dictionary */ - public static ValueVector decode(ValueVector indices, Dictionary dictionary, BufferAllocator allocator) { + public static ValueVector decode(ValueVector indices, BaseDictionary dictionary, BufferAllocator allocator) { int count = indices.getValueCount(); ValueVector dictionaryVector = dictionary.getVector(); int dictionaryCount = dictionaryVector.getValueCount(); @@ -185,8 +185,11 @@ static void retrieveIndexVector( public ValueVector encode(ValueVector vector) { Field valueField = vector.getField(); - FieldType indexFieldType = new FieldType(valueField.isNullable(), dictionary.getEncoding().getIndexType(), - dictionary.getEncoding(), valueField.getMetadata()); + FieldType indexFieldType = new FieldType( + valueField.isNullable(), + dictionary.getEncoding().getIndexType(), + dictionary.getEncoding(), + valueField.getMetadata()); Field indexField = new Field(valueField.getName(), indexFieldType, null); // vector to hold our indices (dictionary encoded values) @@ -211,8 +214,9 @@ public ValueVector encode(ValueVector vector) { /** * Decodes a vector with the dictionary in this encoder. * - * {@link DictionaryEncoder#decode(ValueVector, Dictionary, BufferAllocator)} should be used instead if only decoding - * is required as it can avoid building the {@link DictionaryHashTable} which only makes sense when encoding. + * {@link DictionaryEncoder#decode(ValueVector, BaseDictionary, BufferAllocator)} + * should be used instead if only decoding is required as it can avoid building the + * {@link DictionaryHashTable} which only makes sense when encoding. 
*/ public ValueVector decode(ValueVector indices) { return decode(indices, dictionary, allocator); diff --git a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryHashTable.java b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryHashTable.java index 9926a8e2a637f..c54a50601fa00 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryHashTable.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryHashTable.java @@ -96,9 +96,11 @@ public DictionaryHashTable(int initialCapacity, ValueVector dictionary, ArrowBuf this.hasher = hasher; - // build hash table - for (int i = 0; i < this.dictionary.getValueCount(); i++) { - put(i); + if (dictionary != null) { + // build hash table + for (int i = 0; i < this.dictionary.getValueCount(); i++) { + put(i); + } } } @@ -110,6 +112,18 @@ public DictionaryHashTable(ValueVector dictionary) { this(dictionary, SimpleHasher.INSTANCE); } + /** + * Creates an empty table used for batch writing of dictionaries. + */ + public DictionaryHashTable() { + this(DEFAULT_INITIAL_CAPACITY, null, SimpleHasher.INSTANCE); + + if (table == EMPTY_TABLE) { + inflateTable(threshold); + } + } + + /** * Compute the capacity with given threshold and create init table. */ @@ -194,6 +208,34 @@ void createEntry(int hash, int index, int bucketIndex) { size++; } + /** + * Returns the corresponding dictionary index entry given a hash code. If the hash has + * not been written to the table, returns -1. + * + * @param hash The hash to lookup. + * @return The dictionary index if present, -1 if not. + */ + int getIndex(int hash) { + int i = indexFor(hash, table.length); + for (DictionaryHashTable.Entry e = table[i]; e != null; e = e.next) { + if (e.hash == hash) { + return e.index; + } + } + return -1; + } + + /** + * Adds an entry to the hash table. + * + * @param hash The hash to add. + * @param index The corresponding dictionary index. + */ + void addEntry(int hash, int index) { + int bucketIndex = indexFor(hash, table.length); + addEntry(hash, index, bucketIndex); + } + /** * Add Entry at the specified location of the table. */ diff --git a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryProvider.java b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryProvider.java index f64c32be0f3e9..97d1ecd21730d 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryProvider.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryProvider.java @@ -30,24 +30,29 @@ public interface DictionaryProvider { /** Return the dictionary for the given ID. */ - Dictionary lookup(long id); + BaseDictionary lookup(long id); /** Get all dictionary IDs. */ Set getDictionaryIds(); + /** + * Reset all dictionaries associated with this provider. + */ + void resetDictionaries(); + /** * Implementation of {@link DictionaryProvider} that is backed by a hash-map. */ class MapDictionaryProvider implements AutoCloseable, DictionaryProvider { - private final Map map; + private final Map map; /** * Constructs a new instance from the given dictionaries. */ - public MapDictionaryProvider(Dictionary... dictionaries) { + public MapDictionaryProvider(BaseDictionary... dictionaries) { this.map = new HashMap<>(); - for (Dictionary dictionary : dictionaries) { + for (BaseDictionary dictionary : dictionaries) { put(dictionary); } } @@ -62,14 +67,14 @@ public MapDictionaryProvider(Dictionary... 
dictionaries) { @VisibleForTesting public void copyStructureFrom(DictionaryProvider other, BufferAllocator allocator) { for (Long id : other.getDictionaryIds()) { - Dictionary otherDict = other.lookup(id); - Dictionary newDict = new Dictionary(otherDict.getVector().getField().createVector(allocator), + BaseDictionary otherDict = other.lookup(id); + BaseDictionary newDict = new Dictionary(otherDict.getVector().getField().createVector(allocator), otherDict.getEncoding()); put(newDict); } } - public void put(Dictionary dictionary) { + public void put(BaseDictionary dictionary) { map.put(dictionary.getEncoding().getId(), dictionary); } @@ -79,15 +84,20 @@ public final Set getDictionaryIds() { } @Override - public Dictionary lookup(long id) { + public BaseDictionary lookup(long id) { return map.get(id); } @Override public void close() { - for (Dictionary dictionary : map.values()) { + for (BaseDictionary dictionary : map.values()) { dictionary.getVector().close(); } } + + @Override + public void resetDictionaries() { + map.values().forEach( dictionary -> dictionary.reset() ); + } } } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/StructSubfieldEncoder.java b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/StructSubfieldEncoder.java index 8500528a62b60..75cf63437b61c 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/StructSubfieldEncoder.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/StructSubfieldEncoder.java @@ -109,7 +109,7 @@ public StructVector encode(StructVector vector, Map columnToDicti if (dictionaryId == null) { childrenFields.add(childVector.getField()); } else { - Dictionary dictionary = provider.lookup(dictionaryId); + BaseDictionary dictionary = provider.lookup(dictionaryId); Preconditions.checkNotNull(dictionary, "Dictionary not found with id:" + dictionaryId); FieldType indexFieldType = new FieldType(childVector.getField().isNullable(), dictionary.getEncoding().getIndexType(), dictionary.getEncoding()); @@ -177,7 +177,7 @@ public static StructVector decode(StructVector vector, List childFields = new ArrayList<>(); for (int i = 0; i < childCount; i++) { FieldVector childVector = getChildVector(vector, i); - Dictionary dictionary = getChildVectorDictionary(childVector, provider); + BaseDictionary dictionary = getChildVectorDictionary(childVector, provider); // childVector is not encoded. if (dictionary == null) { childFields.add(childVector.getField()); @@ -192,7 +192,7 @@ public static StructVector decode(StructVector vector, // get child vector FieldVector childVector = getChildVector(vector, index); FieldVector decodedChildVector = getChildVector(decoded, index); - Dictionary dictionary = getChildVectorDictionary(childVector, provider); + BaseDictionary dictionary = getChildVectorDictionary(childVector, provider); if (dictionary == null) { childVector.makeTransferPair(decodedChildVector).splitAndTransfer(0, valueCount); } else { @@ -213,11 +213,11 @@ public static StructVector decode(StructVector vector, /** * Get the child vector dictionary, return null if not dictionary encoded. 
*/ - private static Dictionary getChildVectorDictionary(FieldVector childVector, + private static BaseDictionary getChildVectorDictionary(FieldVector childVector, DictionaryProvider.MapDictionaryProvider provider) { DictionaryEncoding dictionaryEncoding = childVector.getField().getDictionary(); if (dictionaryEncoding != null) { - Dictionary dictionary = provider.lookup(dictionaryEncoding.getId()); + BaseDictionary dictionary = provider.lookup(dictionaryEncoding.getId()); Preconditions.checkNotNull(dictionary, "Dictionary not found with id:" + dictionary); return dictionary; } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileReader.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileReader.java index 8629cf93470b8..fcb6ca89abaf3 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileReader.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileReader.java @@ -22,6 +22,7 @@ import java.nio.channels.SeekableByteChannel; import java.util.Arrays; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; @@ -50,6 +51,7 @@ public class ArrowFileReader extends ArrowReader { private SeekableReadChannel in; private ArrowFooter footer; + private int estimatedDictionaryRecordBatch = 0; private int currentDictionaryBatch = 0; private int currentRecordBatch = 0; @@ -123,11 +125,6 @@ public void initialize() throws IOException { if (footer.getRecordBatches().size() == 0) { return; } - // Read and load all dictionaries from schema - for (int i = 0; i < dictionaries.size(); i++) { - ArrowDictionaryBatch dictionaryBatch = readDictionary(); - loadDictionary(dictionaryBatch); - } } /** @@ -140,21 +137,6 @@ public Map getMetaData() { return new HashMap<>(); } - /** - * Read a dictionary batch from the source, will be invoked after the schema has been read and - * called N times, where N is the number of dictionaries indicated by the schema Fields. - * - * @return the read ArrowDictionaryBatch - * @throws IOException on error - */ - public ArrowDictionaryBatch readDictionary() throws IOException { - if (currentDictionaryBatch >= footer.getDictionaries().size()) { - throw new IOException("Requested more dictionaries than defined in footer: " + currentDictionaryBatch); - } - ArrowBlock block = footer.getDictionaries().get(currentDictionaryBatch++); - return readDictionaryBatch(in, block, allocator); - } - /** Returns true if a batch was read, false if no more batches. */ @Override public boolean loadNextBatch() throws IOException { @@ -164,12 +146,59 @@ public boolean loadNextBatch() throws IOException { ArrowBlock block = footer.getRecordBatches().get(currentRecordBatch++); ArrowRecordBatch batch = readRecordBatch(in, block, allocator); loadRecordBatch(batch); + try { + loadDictionaries(); + } catch (IOException e) { + batch.close(); + throw e; + } return true; } else { return false; } } + /** + * Loads any dictionaries that may be needed by the given record batch. It attempts + * to read as little as possible but may read in more deltas than are necessary for blocks + * toward the end of the file. 
+ */ + private void loadDictionaries() throws IOException { + // initial load + if (currentDictionaryBatch == 0) { + for (int i = 0; i < dictionaries.size(); i++) { + ArrowBlock block = footer.getDictionaries().get(currentDictionaryBatch++); + ArrowDictionaryBatch dictionaryBatch = readDictionaryBatch(in, block, allocator); + loadDictionary(dictionaryBatch, false); + } + estimatedDictionaryRecordBatch++; + } else { + // we need to look for delta dictionaries. It involves a look-ahead, unfortunately. + HashSet visited = new HashSet(); + while (estimatedDictionaryRecordBatch < currentRecordBatch && + currentDictionaryBatch < footer.getDictionaries().size()) { + ArrowBlock block = footer.getDictionaries().get(currentDictionaryBatch++); + ArrowDictionaryBatch dictionaryBatch = readDictionaryBatch(in, block, allocator); + long dictionaryId = dictionaryBatch.getDictionaryId(); + if (visited.contains(dictionaryId)) { + // done + currentDictionaryBatch--; + estimatedDictionaryRecordBatch++; + } else if (!dictionaries.containsKey(dictionaryId)) { + throw new IOException("Dictionary ID " + dictionaryId + " was written " + + "after the initial batch. The file does not follow the IPC file protocol."); + } else if (!dictionaryBatch.isDelta()) { + throw new IOException("Dictionary ID " + dictionaryId + " was written as a replacement " + + "after the initial batch. Replacement dictionaries are not currently allowed in the IPC file protocol."); + } else { + loadDictionary(dictionaryBatch, true); + } + } + if (currentDictionaryBatch >= footer.getDictionaries().size()) { + estimatedDictionaryRecordBatch++; + } + } + } public List getDictionaryBlocks() throws IOException { ensureInitialized(); @@ -194,6 +223,7 @@ public boolean loadRecordBatch(ArrowBlock block) throws IOException { throw new IllegalArgumentException("Arrow block does not exist in record batches: " + block); } currentRecordBatch = blockIndex; + loadDictionaries(); return loadNextBatch(); } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileWriter.java index 71db79087a3e4..781535aa94c10 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileWriter.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileWriter.java @@ -29,7 +29,7 @@ import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.compression.CompressionCodec; import org.apache.arrow.vector.compression.CompressionUtil; -import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.BaseDictionary; import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.arrow.vector.ipc.message.ArrowBlock; import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch; @@ -130,14 +130,24 @@ protected void endInternal(WriteChannel out) throws IOException { protected void ensureDictionariesWritten(DictionaryProvider provider, Set dictionaryIdsUsed) throws IOException { if (dictionariesWritten) { + for (long id : dictionaryIdsUsed) { + BaseDictionary dictionary = provider.lookup(id); + if (dictionary.getEncoding().isDelta()) { + writeDictionaryBatch(dictionary, false); + } else if (dictionary.mark() > 0) { + throw new IllegalStateException("Replacement dictionaries are not supported in the " + + "IPC file format. Dictionary ID: " + dictionary.getEncoding().getId()); + } + } return; } + dictionariesWritten = true; // Write out all dictionaries required. 
// Replacement dictionaries are not supported in the IPC file format. for (long id : dictionaryIdsUsed) { - Dictionary dictionary = provider.lookup(id); - writeDictionaryBatch(dictionary); + BaseDictionary dictionary = provider.lookup(id); + writeDictionaryBatch(dictionary, true); } } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowReader.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowReader.java index 04c57d7e82fef..548338471a97a 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowReader.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowReader.java @@ -31,7 +31,7 @@ import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.compression.CompressionCodec; import org.apache.arrow.vector.compression.NoCompressionCodec; -import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.BaseDictionary; import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch; import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; @@ -49,7 +49,7 @@ public abstract class ArrowReader implements DictionaryProvider, AutoCloseable { protected final BufferAllocator allocator; private VectorLoader loader; private VectorSchemaRoot root; - protected Map dictionaries; + protected Map dictionaries; private boolean initialized = false; private final CompressionCodec.Factory compressionFactory; @@ -80,7 +80,7 @@ public VectorSchemaRoot getVectorSchemaRoot() throws IOException { * @return Map of dictionaries to dictionary id, empty if no dictionaries loaded * @throws IOException if reading of schema fails */ - public Map getDictionaryVectors() throws IOException { + public Map getDictionaryVectors() throws IOException { ensureInitialized(); return dictionaries; } @@ -92,7 +92,7 @@ public Map getDictionaryVectors() throws IOException { * @return the requested dictionary or null if not found */ @Override - public Dictionary lookup(long id) { + public BaseDictionary lookup(long id) { if (!initialized) { throw new IllegalStateException("Unable to lookup until reader has been initialized"); } @@ -141,7 +141,7 @@ public void close() throws IOException { public void close(boolean closeReadSource) throws IOException { if (initialized) { root.close(); - for (Dictionary dictionary : dictionaries.values()) { + for (BaseDictionary dictionary : dictionaries.values()) { dictionary.getVector().close(); } } @@ -151,6 +151,11 @@ public void close(boolean closeReadSource) throws IOException { } } + @Override + public void resetDictionaries() { + // no-op + } + /** * Close the underlying read source. 
* @@ -185,7 +190,7 @@ protected void initialize() throws IOException { Schema originalSchema = readSchema(); List fields = new ArrayList<>(originalSchema.getFields().size()); List vectors = new ArrayList<>(originalSchema.getFields().size()); - Map dictionaries = new HashMap<>(); + Map dictionaries = new HashMap<>(); // Convert fields with dictionaries to have the index type for (Field field : originalSchema.getFields()) { @@ -228,9 +233,9 @@ protected void loadRecordBatch(ArrowRecordBatch batch) { * * @param dictionaryBatch dictionary batch to load */ - protected void loadDictionary(ArrowDictionaryBatch dictionaryBatch) { + protected void loadDictionary(ArrowDictionaryBatch dictionaryBatch, boolean validateReplacements) { long id = dictionaryBatch.getDictionaryId(); - Dictionary dictionary = dictionaries.get(id); + BaseDictionary dictionary = dictionaries.get(id); if (dictionary == null) { throw new IllegalArgumentException("Dictionary ID " + id + " not defined in schema"); } @@ -242,6 +247,9 @@ protected void loadDictionary(ArrowDictionaryBatch dictionaryBatch) { VectorBatchAppender.batchAppend(vector, deltaVector); } return; + } else if (validateReplacements && getClass() == ArrowFileReader.class) { + throw new IllegalStateException("Replacement dictionaries are not supported in " + + "the IPC file format. Dictionary ID: " + dictionary.getEncoding().getId()); } load(dictionaryBatch, vector); diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowStreamReader.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowStreamReader.java index a0096aaf3ee56..0a417f370d7db 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowStreamReader.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowStreamReader.java @@ -164,7 +164,7 @@ public boolean loadNextBatch() throws IOException { } else if (result.getMessage().headerType() == MessageHeader.DictionaryBatch) { // if it's dictionary message, read dictionary message out and continue to read unless get a batch or eos. ArrowDictionaryBatch dictionaryBatch = readDictionary(result); - loadDictionary(dictionaryBatch); + loadDictionary(dictionaryBatch, false); loadedDictionaryCount++; return loadNextBatch(); } else { @@ -177,10 +177,6 @@ public boolean loadNextBatch() throws IOException { * When read a record batch, check whether its dictionaries are available. */ private void checkDictionaries() throws IOException { - // if all dictionaries are loaded, return. 
- if (loadedDictionaryCount == dictionaries.size()) { - return; - } for (FieldVector vector : getVectorSchemaRoot().getFieldVectors()) { DictionaryEncoding encoding = vector.getField().getDictionary(); if (encoding != null) { diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowStreamWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowStreamWriter.java index 928e1de4c5f6b..1cf43ecb20787 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowStreamWriter.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowStreamWriter.java @@ -26,13 +26,11 @@ import java.util.Optional; import java.util.Set; -import org.apache.arrow.util.AutoCloseables; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.VectorSchemaRoot; -import org.apache.arrow.vector.compare.VectorEqualsVisitor; import org.apache.arrow.vector.compression.CompressionCodec; import org.apache.arrow.vector.compression.CompressionUtil; -import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.BaseDictionary; import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.arrow.vector.ipc.message.IpcOption; import org.apache.arrow.vector.ipc.message.MessageSerializer; @@ -41,7 +39,7 @@ * Writer for the Arrow stream format to send ArrowRecordBatches over a WriteChannel. */ public class ArrowStreamWriter extends ArrowWriter { - private final Map previousDictionaries = new HashMap<>(); + private final Map previousDictionaries = new HashMap<>(); /** * Construct an ArrowStreamWriter with an optional DictionaryProvider for the OutputStream. @@ -135,30 +133,16 @@ protected void ensureDictionariesWritten(DictionaryProvider provider, Set throws IOException { // write out any dictionaries that have changes for (long id : dictionaryIdsUsed) { - Dictionary dictionary = provider.lookup(id); - FieldVector vector = dictionary.getVector(); - if (previousDictionaries.containsKey(id) && - VectorEqualsVisitor.vectorEquals(vector, previousDictionaries.get(id))) { - // Dictionary was previously written and hasn't changed - continue; - } - writeDictionaryBatch(dictionary); - // Store a copy of the vector in case it is later mutated - if (previousDictionaries.containsKey(id)) { - previousDictionaries.get(id).close(); - } - previousDictionaries.put(id, copyVector(vector)); + BaseDictionary dictionary = provider.lookup(id); + boolean isInitial = previousDictionaries.containsKey(id) ? 
false : true; + writeDictionaryBatch(dictionary, isInitial); + previousDictionaries.put(id, true); } } @Override public void close() { super.close(); - try { - AutoCloseables.close(previousDictionaries.values()); - } catch (Exception e) { - throw new RuntimeException(e); - } } private static FieldVector copyVector(FieldVector source) { diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java index a33c55de53f23..112961d9ed7cd 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java @@ -32,7 +32,7 @@ import org.apache.arrow.vector.compression.CompressionCodec; import org.apache.arrow.vector.compression.CompressionUtil; import org.apache.arrow.vector.compression.NoCompressionCodec; -import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.BaseDictionary; import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.arrow.vector.ipc.message.ArrowBlock; import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch; @@ -123,9 +123,13 @@ public void writeBatch() throws IOException { try (ArrowRecordBatch batch = unloader.getRecordBatch()) { writeRecordBatch(batch); } + if (dictionaryProvider != null) { + dictionaryProvider.resetDictionaries(); + } } - protected void writeDictionaryBatch(Dictionary dictionary) throws IOException { + protected void writeDictionaryBatch(BaseDictionary dictionary, boolean isInitial) throws IOException { + dictionary.mark(); FieldVector vector = dictionary.getVector(); long id = dictionary.getEncoding().getId(); int count = vector.getValueCount(); @@ -135,7 +139,8 @@ protected void writeDictionaryBatch(Dictionary dictionary) throws IOException { count); VectorUnloader unloader = new VectorUnloader(dictRoot); ArrowRecordBatch batch = unloader.getRecordBatch(); - ArrowDictionaryBatch dictionaryBatch = new ArrowDictionaryBatch(id, batch, false); + boolean isDelta = isInitial ? 
false : dictionary.getEncoding().isDelta(); + ArrowDictionaryBatch dictionaryBatch = new ArrowDictionaryBatch(id, batch, isDelta); try { writeDictionaryBatch(dictionaryBatch); } finally { diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileReader.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileReader.java index 0c23a664f62d6..d46a8f23f46dd 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileReader.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileReader.java @@ -58,7 +58,7 @@ import org.apache.arrow.vector.TinyIntVector; import org.apache.arrow.vector.TypeLayout; import org.apache.arrow.vector.VectorSchemaRoot; -import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.BaseDictionary; import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.arrow.vector.ipc.message.ArrowFieldNode; import org.apache.arrow.vector.types.Types; @@ -87,7 +87,7 @@ public class JsonFileReader implements AutoCloseable, DictionaryProvider { private final JsonParser parser; private final BufferAllocator allocator; private Schema schema; - private Map dictionaries; + private Map dictionaries; private Boolean started = false; /** @@ -108,7 +108,7 @@ public JsonFileReader(File inputFile, BufferAllocator allocator) throws JsonPars } @Override - public Dictionary lookup(long id) { + public BaseDictionary lookup(long id) { if (!started) { throw new IllegalStateException("Unable to lookup until after read() has started"); } @@ -155,7 +155,7 @@ private void readDictionaryBatches() throws JsonParseException, IOException { // Lookup what dictionary for the batch about to be read long id = readNextField("id", Long.class); - Dictionary dict = dictionaries.get(id); + BaseDictionary dict = dictionaries.get(id); if (dict == null) { throw new IllegalArgumentException("Dictionary with id: " + id + " missing encoding from schema Field"); } @@ -259,6 +259,11 @@ public int skip(int numBatches) throws IOException { return numBatches; } + @Override + public void resetDictionaries() { + // no-op + } + private abstract class BufferReader { protected abstract ArrowBuf read(BufferAllocator allocator, int count) throws IOException; @@ -805,7 +810,7 @@ private byte[] decodeHexSafe(String hexString) throws IOException { public void close() throws IOException { parser.close(); if (dictionaries != null) { - for (Dictionary dictionary : dictionaries.values()) { + for (BaseDictionary dictionary : dictionaries.values()) { dictionary.getVector().close(); } } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileWriter.java index f5e267e81256c..e0be27d468ab5 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileWriter.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileWriter.java @@ -68,7 +68,7 @@ import org.apache.arrow.vector.UInt4Vector; import org.apache.arrow.vector.UInt8Vector; import org.apache.arrow.vector.VectorSchemaRoot; -import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.BaseDictionary; import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.arrow.vector.types.Types.MinorType; import org.apache.arrow.vector.types.pojo.Field; @@ -173,7 +173,7 @@ private void writeDictionaryBatches(JsonGenerator generator, Set dictionar generator.writeObjectField("id", id); 
generator.writeFieldName("data"); - Dictionary dictionary = provider.lookup(id); + BaseDictionary dictionary = provider.lookup(id); FieldVector vector = dictionary.getVector(); List fields = Collections.singletonList(vector.getField()); List vectors = Collections.singletonList(vector); diff --git a/java/vector/src/main/java/org/apache/arrow/vector/table/BaseTable.java b/java/vector/src/main/java/org/apache/arrow/vector/table/BaseTable.java index 9f645b64bc5f6..eb766c617682a 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/table/BaseTable.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/table/BaseTable.java @@ -29,7 +29,7 @@ import org.apache.arrow.vector.ValueVector; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.complex.reader.FieldReader; -import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.BaseDictionary; import org.apache.arrow.vector.dictionary.DictionaryEncoder; import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.arrow.vector.types.pojo.Field; @@ -385,7 +385,7 @@ public DictionaryProvider getDictionaryProvider() { * @return A ValueVector */ public ValueVector decode(String vectorName, long dictionaryId) { - Dictionary dictionary = getDictionary(dictionaryId); + BaseDictionary dictionary = getDictionary(dictionaryId); FieldVector vector = getVector(vectorName); if (vector == null) { @@ -405,7 +405,7 @@ public ValueVector decode(String vectorName, long dictionaryId) { * @return A ValueVector */ public ValueVector encode(String vectorName, long dictionaryId) { - Dictionary dictionary = getDictionary(dictionaryId); + BaseDictionary dictionary = getDictionary(dictionaryId); FieldVector vector = getVector(vectorName); if (vector == null) { throw new IllegalArgumentException( @@ -419,12 +419,12 @@ public ValueVector encode(String vectorName, long dictionaryId) { * Returns the dictionary with given id. 
* @param dictionaryId A long integer that is the id returned by the dictionary's getId() method */ - private Dictionary getDictionary(long dictionaryId) { + private BaseDictionary getDictionary(long dictionaryId) { if (dictionaryProvider == null) { throw new IllegalStateException("No dictionary provider is present in table."); } - Dictionary dictionary = dictionaryProvider.lookup(dictionaryId); + BaseDictionary dictionary = dictionaryProvider.lookup(dictionaryId); if (dictionary == null) { throw new IllegalArgumentException("No dictionary with id '%n' exists in the table"); } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/table/Table.java b/java/vector/src/main/java/org/apache/arrow/vector/table/Table.java index 5768bb0ec75ec..7442342f37986 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/table/Table.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/table/Table.java @@ -28,6 +28,7 @@ import org.apache.arrow.util.Preconditions; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.dictionary.BaseDictionary; import org.apache.arrow.vector.dictionary.Dictionary; import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.arrow.vector.types.pojo.DictionaryEncoding; @@ -113,13 +114,18 @@ public Table copy() { Dictionary[] dictionaryCopies = new Dictionary[ids.size()]; int i = 0; for (Long id : ids) { - Dictionary src = dictionaryProvider.lookup(id); + BaseDictionary src = dictionaryProvider.lookup(id); FieldVector srcVector = src.getVector(); FieldVector destVector = srcVector.getField().createVector(srcVector.getAllocator()); destVector.copyFromSafe(0, srcVector.getValueCount(), srcVector); // TODO: Remove safe copy for perf DictionaryEncoding srcEncoding = src.getEncoding(); Dictionary dest = new Dictionary(destVector, - new DictionaryEncoding(srcEncoding.getId(), srcEncoding.isOrdered(), srcEncoding.getIndexType())); + new DictionaryEncoding( + srcEncoding.getId(), + srcEncoding.isOrdered(), + srcEncoding.getIndexType(), + srcEncoding.isDelta() + )); dictionaryCopies[i] = dest; i++; } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/DictionaryEncoding.java b/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/DictionaryEncoding.java index 8d41b92d867e9..3d7f86dcf6c68 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/DictionaryEncoding.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/DictionaryEncoding.java @@ -33,6 +33,7 @@ public class DictionaryEncoding { private final long id; private final boolean ordered; private final Int indexType; + private final boolean isDelta; /** * Constructs a new instance. @@ -42,14 +43,29 @@ public class DictionaryEncoding { * @param indexType (nullable). The integer type to use for indexing in the dictionary. Defaults to a signed * 32 bit integer. */ + public DictionaryEncoding(long id, boolean ordered, Int indexType) { + this(id, ordered, indexType, false); + } + + /** + * Constructs a new instance. + * + * @param id The ID of the dictionary to use for encoding. + * @param ordered Whether the keys in values in the dictionary are ordered. + * @param indexType (nullable). The integer type to use for indexing in the dictionary. Defaults to a signed + * 32 bit integer. + * @param isDelta Whether the dictionary is a delta dictionary. 
+ */ @JsonCreator public DictionaryEncoding( @JsonProperty("id") long id, @JsonProperty("isOrdered") boolean ordered, - @JsonProperty("indexType") Int indexType) { + @JsonProperty("indexType") Int indexType, + @JsonProperty("isDelta") Boolean isDelta) { this.id = id; this.ordered = ordered; this.indexType = indexType == null ? new Int(32, true) : indexType; + this.isDelta = isDelta == null ? false : isDelta; } public long getId() { @@ -65,6 +81,11 @@ public Int getIndexType() { return indexType; } + @JsonGetter("isDelta") + public boolean isDelta() { + return isDelta; + } + @Override public String toString() { return "DictionaryEncoding[id=" + id + ",ordered=" + ordered + ",indexType=" + indexType + "]"; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/util/DictionaryUtility.java b/java/vector/src/main/java/org/apache/arrow/vector/util/DictionaryUtility.java index 9592f3975ab99..b6b4dd83b58b2 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/util/DictionaryUtility.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/util/DictionaryUtility.java @@ -24,6 +24,7 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.dictionary.BaseDictionary; import org.apache.arrow.vector.dictionary.Dictionary; import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.arrow.vector.types.pojo.ArrowType; @@ -58,7 +59,7 @@ public static Field toMessageFormat(Field field, DictionaryProvider provider, Se children = field.getChildren(); } else { long id = encoding.getId(); - Dictionary dictionary = provider.lookup(id); + BaseDictionary dictionary = provider.lookup(id); if (dictionary == null) { throw new IllegalArgumentException("Could not find dictionary with ID " + id); } @@ -104,7 +105,7 @@ public static boolean needConvertToMessageFormat(Field field) { * Convert field and child fields that have a dictionary encoding to memory format, so fields * have the index type. 
*/ - public static Field toMemoryFormat(Field field, BufferAllocator allocator, Map dictionaries) { + public static Field toMemoryFormat(Field field, BufferAllocator allocator, Map dictionaries) { DictionaryEncoding encoding = field.getDictionary(); List children = field.getChildren(); diff --git a/java/vector/src/main/java/org/apache/arrow/vector/util/Validator.java b/java/vector/src/main/java/org/apache/arrow/vector/util/Validator.java index 0c9ad1e2753f1..a0bc187c8f838 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/util/Validator.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/util/Validator.java @@ -24,7 +24,7 @@ import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.VectorSchemaRoot; -import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.BaseDictionary; import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.DictionaryEncoding; @@ -70,8 +70,8 @@ public static void compareDictionaries( } long id = encodings1.get(i).getId(); - Dictionary dict1 = provider1.lookup(id); - Dictionary dict2 = provider2.lookup(id); + BaseDictionary dict1 = provider1.lookup(id); + BaseDictionary dict2 = provider2.lookup(id); if (dict1 == null || dict2 == null) { throw new IllegalArgumentException("The DictionaryProvider did not contain the required " + @@ -101,8 +101,8 @@ public static void compareDictionaryProviders( ids1 + "\n" + ids2); } for (long id : ids1) { - Dictionary dict1 = provider1.lookup(id); - Dictionary dict2 = provider2.lookup(id); + BaseDictionary dict1 = provider1.lookup(id); + BaseDictionary dict2 = provider2.lookup(id); try { compareFieldVectors(dict1.getVector(), dict2.getVector()); } catch (IllegalArgumentException e) { diff --git a/java/vector/src/main/java/org/apache/arrow/vector/util/VectorAppender.java b/java/vector/src/main/java/org/apache/arrow/vector/util/VectorAppender.java index c5de380f9c173..7b8c8e2fb6b6a 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/util/VectorAppender.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/util/VectorAppender.java @@ -108,8 +108,8 @@ public ValueVector visit(BaseVariableWidthVector deltaVector, Void value) { int newValueCount = targetVector.getValueCount() + deltaVector.getValueCount(); - int targetDataSize = targetVector.getOffsetBuffer().getInt( - (long) targetVector.getValueCount() * BaseVariableWidthVector.OFFSET_WIDTH); + int targetDataSize = targetVector.getValueCount() > 0 ? targetVector.getOffsetBuffer().getInt( + (long) targetVector.getValueCount() * BaseVariableWidthVector.OFFSET_WIDTH) : 0; int deltaDataSize = deltaVector.getOffsetBuffer().getInt( (long) deltaVector.getValueCount() * BaseVariableWidthVector.OFFSET_WIDTH); int newValueCapacity = targetDataSize + deltaDataSize; diff --git a/java/vector/src/test/java/org/apache/arrow/vector/dictionary/TestBatchedDictionary.java b/java/vector/src/test/java/org/apache/arrow/vector/dictionary/TestBatchedDictionary.java new file mode 100644 index 0000000000000..72cc3e5aee663 --- /dev/null +++ b/java/vector/src/test/java/org/apache/arrow/vector/dictionary/TestBatchedDictionary.java @@ -0,0 +1,395 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.vector.dictionary; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.UInt2Vector; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +public class TestBatchedDictionary { + + private static final DictionaryEncoding DELTA = + new DictionaryEncoding(42, false, new ArrowType.Int(16, false), true); + private static final DictionaryEncoding SINGLE = + new DictionaryEncoding(24, false, new ArrowType.Int(16, false)); + private static final byte[] FOO = "foo".getBytes(StandardCharsets.UTF_8); + private static final byte[] BAR = "bar".getBytes(StandardCharsets.UTF_8); + private static final byte[] BAZ = "baz".getBytes(StandardCharsets.UTF_8); + + private BufferAllocator allocator; + + private static List validDictionaryTypes = Arrays.asList( + new ArrowType.Utf8(), + ArrowType.Binary.INSTANCE + ); + private static List invalidDictionaryTypes = Arrays.asList( + new ArrowType.LargeUtf8(), + new ArrowType.LargeBinary(), + new ArrowType.Bool(), + new ArrowType.Int(8, false) + ); + private static List validIndexTypes = Arrays.asList( + new ArrowType.Int(16, false), + new ArrowType.Int(8, false), + new ArrowType.Int(32, false), + new ArrowType.Int(64, false), + new ArrowType.Int(16, true), + new ArrowType.Int(8, true), + new ArrowType.Int(32, true), + new ArrowType.Int(64, true) + ); + private static List invalidIndexTypes = Arrays.asList( + new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE), + new ArrowType.Bool(), + new ArrowType.Utf8() + ); + + @BeforeEach + public void prepare() { + allocator = new RootAllocator(1024 * 1024); + } + + @AfterEach + public void shutdown() { + allocator.close(); + } + + public static Collection validTypes() { + List params = new ArrayList<>(); + for (ArrowType dictType : validDictionaryTypes) { + for (ArrowType indexType : validIndexTypes) { + params.add(Arguments.arguments(dictType, indexType)); + } + } + 
return params; + } + + public static Collection invalidTypes() { + List params = new ArrayList<>(); + for (ArrowType dictType : invalidDictionaryTypes) { + for (ArrowType indexType : validIndexTypes) { + params.add(Arguments.arguments(dictType, indexType)); + } + } + for (ArrowType dictType : validDictionaryTypes) { + for (ArrowType indexType : invalidIndexTypes) { + params.add(Arguments.arguments(dictType, indexType)); + } + } + return params; + } + + @ParameterizedTest + @MethodSource("validTypes") + public void testValidDictionaryTypes(ArrowType dictType, ArrowType indexType) throws IOException { + new BatchedDictionary( + "vector", + DELTA, + dictType, + indexType, + allocator + ).close(); + } + + @ParameterizedTest + @MethodSource("validTypes") + public void testValidDictionaryVectors(ArrowType dictType, ArrowType indexType) throws IOException { + try (FieldVector dictVector = new FieldType(false, dictType, SINGLE) + .createNewSingleVector("d", allocator, null); + FieldVector indexVector = new FieldType(false, indexType, SINGLE) + .createNewSingleVector("i", allocator, null);) { + new BatchedDictionary( + dictVector, + indexVector + ).close(); + } + } + + @ParameterizedTest + @MethodSource("invalidTypes") + public void testInvalidTypes(ArrowType dictType, ArrowType indexType) { + assertThrows(IllegalArgumentException.class, () -> { + new BatchedDictionary( + "vector", + SINGLE, + dictType, + indexType, + allocator + ); + }); + } + + @ParameterizedTest + @MethodSource("invalidTypes") + public void testInvalidValidVectors(ArrowType dictType, ArrowType indexType) { + assertThrows(IllegalArgumentException.class, () -> { + try (FieldVector dictVector = new FieldType(false, dictType, SINGLE) + .createNewSingleVector("d", allocator, null); + FieldVector indexVector = new FieldType(false, indexType, SINGLE) + .createNewSingleVector("i", allocator, null);) { + new BatchedDictionary( + dictVector, + indexVector + ).close(); + } + }); + } + + @Test + public void testSuffix() throws IOException { + try (BatchedDictionary dictionary = new BatchedDictionary( + "vector", + SINGLE, + validDictionaryTypes.get(0), + validIndexTypes.get(0), + allocator + )) { + assertEquals("vector-dictionary", dictionary.getVector().getField().getName()); + assertEquals("vector", dictionary.getIndexVector().getField().getName()); + } + } + + @Test + public void testTypedUnique() throws IOException { + try (BatchedDictionary dictionary = new BatchedDictionary( + "vector", + SINGLE, + validDictionaryTypes.get(0), + validIndexTypes.get(0), + allocator + )) { + dictionary.setSafe(0, FOO); + dictionary.setSafe(1, BAR); + dictionary.setSafe(2, BAZ); + dictionary.mark(); + assertEquals(3, dictionary.getVector().getValueCount()); + assertEquals(0, dictionary.getIndexVector().getValueCount()); + dictionary.getIndexVector().setValueCount(3); + assertDecoded(dictionary, "foo", "bar", "baz"); + } + } + + @Test + public void testExistingUnique() throws IOException { + List vectors = existingVectors(SINGLE); + try (BatchedDictionary dictionary = new BatchedDictionary( + vectors.get(0), + vectors.get(1) + )) { + dictionary.setSafe(0, FOO); + dictionary.setSafe(1, BAR); + dictionary.setSafe(2, BAZ); + dictionary.mark(); + assertEquals(3, dictionary.getVector().getValueCount()); + assertEquals(0, dictionary.getIndexVector().getValueCount()); + dictionary.getIndexVector().setValueCount(3); + assertDecoded(dictionary, "foo", "bar", "baz"); + } + vectors.forEach(vector -> vector.close()); + } + + @Test + public void testTypedUniqueNulls() 
throws IOException { + try (BatchedDictionary dictionary = new BatchedDictionary( + "vector", + SINGLE, + validDictionaryTypes.get(0), + validIndexTypes.get(0), + allocator + )) { + dictionary.setNull(0); + dictionary.setSafe(1, BAR); + dictionary.setNull(2); + dictionary.mark(); + assertEquals(1, dictionary.getVector().getValueCount()); + assertEquals(0, dictionary.getIndexVector().getValueCount()); + dictionary.getIndexVector().setValueCount(3); + assertDecoded(dictionary, null, "bar", null); + } + } + + @Test + public void testExistingUniqueNulls() throws IOException { + List vectors = existingVectors(SINGLE); + try (BatchedDictionary dictionary = new BatchedDictionary( + vectors.get(0), + vectors.get(1) + )) { + dictionary.setNull(0); + dictionary.setSafe(1, BAR); + dictionary.setNull(2); + dictionary.mark(); + assertEquals(1, dictionary.getVector().getValueCount()); + assertEquals(0, dictionary.getIndexVector().getValueCount()); + dictionary.getIndexVector().setValueCount(3); + assertDecoded(dictionary, null, "bar", null); + } + vectors.forEach(vector -> vector.close()); + } + + @Test + public void testTypedReused() throws IOException { + try (BatchedDictionary dictionary = new BatchedDictionary( + "vector", + SINGLE, + validDictionaryTypes.get(0), + validIndexTypes.get(0), + allocator + )) { + dictionary.setSafe(0, FOO); + dictionary.setSafe(1, BAR); + dictionary.setSafe(2, FOO); + dictionary.setSafe(3, FOO); + dictionary.mark(); + assertEquals(2, dictionary.getVector().getValueCount()); + assertEquals(0, dictionary.getIndexVector().getValueCount()); + dictionary.getIndexVector().setValueCount(4); + assertDecoded(dictionary, "foo", "bar", "foo", "foo"); + } + } + + @Test + public void testTypedResetReplacement() throws IOException { + try (BatchedDictionary dictionary = new BatchedDictionary( + "vector", + SINGLE, + validDictionaryTypes.get(0), + validIndexTypes.get(0), + allocator + )) { + dictionary.setSafe(0, FOO); + dictionary.setSafe(1, BAR); + dictionary.mark(); + assertEquals(2, dictionary.getVector().getValueCount()); + assertEquals(0, dictionary.getIndexVector().getValueCount()); + dictionary.getIndexVector().setValueCount(2); + assertDecoded(dictionary, "foo", "bar"); + + dictionary.reset(); + assertEquals(0, dictionary.getHashTable().size); + dictionary.setSafe(0, BAZ); + dictionary.setNull(1); + dictionary.mark(); + assertEquals(1, dictionary.getVector().getValueCount()); + assertEquals(0, dictionary.getIndexVector().getValueCount()); + dictionary.getIndexVector().setValueCount(2); + assertDecoded(dictionary, "baz", null); + } + } + + @Test + public void testTypedResetDelta() throws IOException { + try (BatchedDictionary dictionary = new BatchedDictionary( + "vector", + DELTA, + validDictionaryTypes.get(0), + validIndexTypes.get(0), + allocator + )) { + dictionary.setSafe(0, FOO); + dictionary.setSafe(1, BAR); + dictionary.mark(); + assertEquals(2, dictionary.getVector().getValueCount()); + assertEquals(0, dictionary.getIndexVector().getValueCount()); + dictionary.getIndexVector().setValueCount(2); + assertDecoded(dictionary, "foo", "bar"); + + dictionary.reset(); + assertEquals(2, dictionary.getHashTable().size); + dictionary.setSafe(0, BAZ); + dictionary.setSafe(1, FOO); + dictionary.setSafe(2, BAR); + dictionary.mark(); + assertEquals(1, dictionary.getVector().getValueCount()); + assertEquals(0, dictionary.getIndexVector().getValueCount()); + dictionary.getIndexVector().setValueCount(3); + assertEquals(3, dictionary.getHashTable().size); + + // on read the dictionaries 
must be merged. Let's look at the int index. + UInt2Vector index = (UInt2Vector) dictionary.getIndexVector(); + assertEquals(2, index.get(0)); + assertEquals(0, index.get(1)); + assertEquals(1, index.get(2)); + } + } + + @Test + public void testTypedNullData() throws IOException { + try (BatchedDictionary dictionary = new BatchedDictionary( + "vector", + SINGLE, + validDictionaryTypes.get(0), + validIndexTypes.get(0), + allocator + )) { + dictionary.setSafe(0, null); + dictionary.setSafe(1, BAR); + dictionary.mark(); + assertEquals(1, dictionary.getVector().getValueCount()); + assertEquals(0, dictionary.getIndexVector().getValueCount()); + dictionary.getIndexVector().setValueCount(2); + assertDecoded(dictionary, null, "bar"); + } + } + + void assertDecoded(BatchedDictionary dictionary, String... expected) { + try (ValueVector decoded = DictionaryEncoder.decode(dictionary.getIndexVector(), dictionary)) { + assertEquals(expected.length, decoded.getValueCount()); + for (int i = 0; i < expected.length; i++) { + if (expected[i] == null) { + assertNull(decoded.getObject(i)); + } else { + assertNotNull(decoded.getObject(i)); + assertEquals(expected[i], decoded.getObject(i).toString()); + } + } + } + } + + private List existingVectors(DictionaryEncoding encoding) { + FieldVector dictionaryVector = new FieldType(false, validDictionaryTypes.get(0), null) + .createNewSingleVector("vector-dictionary", allocator, null); + FieldVector indexVector = new FieldType(true, validIndexTypes.get(0), encoding) + .createNewSingleVector("fector", allocator, null); + return Arrays.asList(new FieldVector[] { dictionaryVector, indexVector }); + } +} diff --git a/java/vector/src/test/java/org/apache/arrow/vector/ipc/BaseFileTest.java b/java/vector/src/test/java/org/apache/arrow/vector/ipc/BaseFileTest.java index 8663c0c49990d..305ee88bb5c11 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/ipc/BaseFileTest.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/ipc/BaseFileTest.java @@ -20,18 +20,29 @@ import static org.apache.arrow.vector.TestUtils.newVarCharVector; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; import java.io.IOException; +import java.io.OutputStream; import java.math.BigDecimal; import java.math.BigInteger; +import java.nio.channels.FileChannel; import java.nio.charset.StandardCharsets; import java.time.LocalDateTime; import java.time.LocalTime; import java.time.ZoneId; import java.time.ZoneOffset; +import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; import java.util.List; +import java.util.Map; import org.apache.arrow.memory.ArrowBuf; import org.apache.arrow.memory.BufferAllocator; @@ -48,6 +59,7 @@ import org.apache.arrow.vector.UInt2Vector; import org.apache.arrow.vector.UInt4Vector; import org.apache.arrow.vector.UInt8Vector; +import org.apache.arrow.vector.ValueVector; import org.apache.arrow.vector.VarBinaryVector; import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VectorSchemaRoot; @@ -74,6 +86,8 @@ import org.apache.arrow.vector.complex.writer.UInt2Writer; import org.apache.arrow.vector.complex.writer.UInt4Writer; import org.apache.arrow.vector.complex.writer.UInt8Writer; +import 
org.apache.arrow.vector.dictionary.BaseDictionary; +import org.apache.arrow.vector.dictionary.BatchedDictionary; import org.apache.arrow.vector.dictionary.Dictionary; import org.apache.arrow.vector.dictionary.DictionaryEncoder; import org.apache.arrow.vector.dictionary.DictionaryProvider; @@ -87,6 +101,7 @@ import org.junit.After; import org.junit.Assert; import org.junit.Before; +import org.junit.jupiter.params.provider.Arguments; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -400,7 +415,7 @@ protected void validateFlatDictionary(VectorSchemaRoot root, DictionaryProvider Assert.assertEquals(2, vector2.getObject(4)); Assert.assertEquals(null, vector2.getObject(5)); - Dictionary dictionary1 = provider.lookup(1L); + BaseDictionary dictionary1 = provider.lookup(1L); Assert.assertNotNull(dictionary1); VarCharVector dictionaryVector = ((VarCharVector) dictionary1.getVector()); Assert.assertEquals(3, dictionaryVector.getValueCount()); @@ -408,7 +423,7 @@ protected void validateFlatDictionary(VectorSchemaRoot root, DictionaryProvider Assert.assertEquals(new Text("bar"), dictionaryVector.getObject(1)); Assert.assertEquals(new Text("baz"), dictionaryVector.getObject(2)); - Dictionary dictionary2 = provider.lookup(2L); + BaseDictionary dictionary2 = provider.lookup(2L); Assert.assertNotNull(dictionary2); dictionaryVector = ((VarCharVector) dictionary2.getVector()); Assert.assertEquals(3, dictionaryVector.getValueCount()); @@ -470,7 +485,7 @@ protected void validateNestedDictionary(VectorSchemaRoot root, DictionaryProvide Assert.assertEquals(Arrays.asList(0), vector.getObject(1)); Assert.assertEquals(Arrays.asList(1), vector.getObject(2)); - Dictionary dictionary = provider.lookup(2L); + BaseDictionary dictionary = provider.lookup(2L); Assert.assertNotNull(dictionary); VarCharVector dictionaryVector = ((VarCharVector) dictionary.getVector()); Assert.assertEquals(2, dictionaryVector.getValueCount()); @@ -846,4 +861,306 @@ protected void validateListAsMapData(VectorSchemaRoot root) { } } } + + + enum DictionaryUTState { + /** One delta dictionary with updates in each batch. */ + ONE_DELTA, + /** A replacement dictionary that is only updated in the first batch. */ + ONE_REPLACEMENT_NOT_UPDATED, + /** The delta and non-updated replacement dictionaries. */ + DELTA_AND_REPLACEMENT_NOT_UPDATED, + /** Delta without any data in the first or last batches. */ + DELTA_MID_BATCH, + /** Two delta dictionaries for differentiation with different batches written to test + * read offset issues. */ + TWO_DELTAS, + /** Both deltas and the non-updated replacement dictionary. */ + TWO_DELTAS_AND_REPLACEMENT_NOT_UPDATED, + /** A replacement dictionary with updates between batches that should pass in stream + * uses but fail in file writes. */ + REPLACEMENT_UPDATED + } + + /** + * Utility to write permutations of dictionary encoding. 
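+ * The REPLACEMENT_UPDATED state updates a non-delta dictionary between batches; it is expected to pass for stream writes but fail for file writes.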
+ */ + protected void writeDataMultiBatchWithDictionaries(OutputStream stream, DictionaryUTState state) throws IOException { + DictionaryProvider.MapDictionaryProvider provider = new DictionaryProvider.MapDictionaryProvider(); + DictionaryEncoding deltaEncoding = + new DictionaryEncoding(42, false, new ArrowType.Int(16, false), true); + DictionaryEncoding replacementEncoding = + new DictionaryEncoding(24, false, new ArrowType.Int(16, false), false); + DictionaryEncoding deltaCEncoding = + new DictionaryEncoding(1, false, new ArrowType.Int(16, false), true); + DictionaryEncoding replacementEncodingUpdated = + new DictionaryEncoding(2, false, new ArrowType.Int(16, false), false); + + try (BatchedDictionary vectorA = newDictionary("vectorA", deltaEncoding, false); + BatchedDictionary vectorB = newDictionary("vectorB", replacementEncoding, true); + BatchedDictionary vectorC = newDictionary("vectorC", deltaCEncoding, false); + BatchedDictionary vectorD = newDictionary("vectorD", replacementEncodingUpdated, false);) { + switch (state) { + case ONE_DELTA: + provider.put(vectorA); + break; + case ONE_REPLACEMENT_NOT_UPDATED: + provider.put(vectorB); + break; + case DELTA_AND_REPLACEMENT_NOT_UPDATED: + provider.put(vectorA); + provider.put(vectorB); + break; + case DELTA_MID_BATCH: + provider.put(vectorC); + break; + case TWO_DELTAS: + provider.put(vectorA); + provider.put(vectorC); + break; + case TWO_DELTAS_AND_REPLACEMENT_NOT_UPDATED: + provider.put(vectorA); + provider.put(vectorB); + provider.put(vectorC); + break; + case REPLACEMENT_UPDATED: + provider.put(vectorD); + break; + default: + throw new IllegalStateException("Unsupported state: " + state); + } + + VectorSchemaRoot root = null; + switch (state) { + case ONE_DELTA: + root = VectorSchemaRoot.of(vectorA.getIndexVector()); + break; + case ONE_REPLACEMENT_NOT_UPDATED: + root = VectorSchemaRoot.of(vectorB.getIndexVector()); + break; + case DELTA_AND_REPLACEMENT_NOT_UPDATED: + root = VectorSchemaRoot.of(vectorA.getIndexVector(), vectorB.getIndexVector()); + break; + case DELTA_MID_BATCH: + root = VectorSchemaRoot.of(vectorC.getIndexVector()); + break; + case TWO_DELTAS: + root = VectorSchemaRoot.of(vectorA.getIndexVector(), vectorC.getIndexVector()); + break; + case TWO_DELTAS_AND_REPLACEMENT_NOT_UPDATED: + root = VectorSchemaRoot.of(vectorA.getIndexVector(), vectorB.getIndexVector(), vectorC.getIndexVector()); + break; + case REPLACEMENT_UPDATED: + root = VectorSchemaRoot.of(vectorD.getIndexVector()); + break; + default: + throw new IllegalStateException("Unsupported state: " + state); + } + + ArrowWriter arrowWriter = null; + try { + if (stream instanceof FileOutputStream) { + FileChannel channel = ((FileOutputStream) stream).getChannel(); + arrowWriter = new ArrowFileWriter(root, provider, channel); + } else { + arrowWriter = new ArrowStreamWriter(root, provider, stream); + } + + vectorA.setSafe(0, "foo".getBytes(StandardCharsets.UTF_8)); + vectorA.setSafe(1, "bar".getBytes(StandardCharsets.UTF_8)); + vectorB.setSafe(0, "lorem".getBytes(StandardCharsets.UTF_8)); + vectorB.setSafe(1, "ipsum".getBytes(StandardCharsets.UTF_8)); + vectorC.setNull(0); + vectorC.setNull(1); + vectorD.setSafe(0, "porro".getBytes(StandardCharsets.UTF_8)); + vectorD.setSafe(1, "amet".getBytes(StandardCharsets.UTF_8)); + + // batch 1 + arrowWriter.start(); + root.setRowCount(2); + arrowWriter.writeBatch(); + + // batch 2 + vectorA.setSafe(0, "meep".getBytes(StandardCharsets.UTF_8)); + vectorA.setSafe(1, "bar".getBytes(StandardCharsets.UTF_8)); + 
vectorB.setSafe(0, "ipsum".getBytes(StandardCharsets.UTF_8)); + vectorB.setSafe(1, "lorem".getBytes(StandardCharsets.UTF_8)); + vectorC.setSafe(0, "qui".getBytes(StandardCharsets.UTF_8)); + vectorC.setSafe(1, "dolor".getBytes(StandardCharsets.UTF_8)); + vectorD.setSafe(0, "amet".getBytes(StandardCharsets.UTF_8)); + if (state == DictionaryUTState.REPLACEMENT_UPDATED) { + vectorD.setSafe(1, "quia".getBytes(StandardCharsets.UTF_8)); + } + + root.setRowCount(2); + arrowWriter.writeBatch(); + + // batch 3 + vectorA.setNull(0); + vectorA.setNull(1); + vectorB.setSafe(0, "ipsum".getBytes(StandardCharsets.UTF_8)); + vectorB.setNull(1); + vectorC.setNull(0); + vectorC.setSafe(1, "qui".getBytes(StandardCharsets.UTF_8)); + vectorD.setNull(0); + if (state == DictionaryUTState.REPLACEMENT_UPDATED) { + vectorD.setSafe(1, "quia".getBytes(StandardCharsets.UTF_8)); + } + + root.setRowCount(2); + arrowWriter.writeBatch(); + + // batch 4 + vectorA.setSafe(0, "bar".getBytes(StandardCharsets.UTF_8)); + vectorA.setSafe(1, "zap".getBytes(StandardCharsets.UTF_8)); + vectorB.setNull(0); + vectorB.setSafe(1, "lorem".getBytes(StandardCharsets.UTF_8)); + vectorC.setNull(0); + vectorC.setNull(1); + if (state == DictionaryUTState.REPLACEMENT_UPDATED) { + vectorD.setSafe(0, "quia".getBytes(StandardCharsets.UTF_8)); + } + vectorD.setNull(1); + + root.setRowCount(2); + arrowWriter.writeBatch(); + + arrowWriter.end(); + } catch (Exception e) { + if (arrowWriter != null) { + arrowWriter.close(); + } + throw e; + } + } + } + + Map valuesPerBlock = new HashMap(); + + { + valuesPerBlock.put(0, new String[][]{ + {"foo", "bar"}, + {"lorem", "ipsum"}, + {null, null}, + {"porro", "amet"} + }); + valuesPerBlock.put(1, new String[][]{ + {"meep", "bar"}, + {"ipsum", "lorem"}, + {"qui", "dolor"}, + {"amet", "quia"} + }); + valuesPerBlock.put(2, new String[][]{ + {null, null}, + {"ipsum", null}, + {null, "qui"}, + {null, "quia"} + }); + valuesPerBlock.put(3, new String[][]{ + {"bar", "zap"}, + {null, "lorem"}, + {null, null}, + {"quia", null} + }); + } + + protected void assertDictionary(FieldVector encoded, ArrowReader reader, String... 
expected) throws Exception { + DictionaryEncoding dictionaryEncoding = encoded.getField().getDictionary(); + BaseDictionary dictionary = reader.getDictionaryVectors().get(dictionaryEncoding.getId()); + try (ValueVector decoded = DictionaryEncoder.decode(encoded, dictionary)) { + assertEquals(expected.length, encoded.getValueCount()); + for (int i = 0; i < expected.length; i++) { + if (expected[i] == null) { + assertNull(decoded.getObject(i)); + } else { + assertNotNull(decoded.getObject(i)); + assertEquals(expected[i], decoded.getObject(i).toString()); + } + } + } + } + + protected void assertBlock(File file, int block, DictionaryUTState state) throws Exception { + try (FileInputStream fileInputStream = new FileInputStream(file); + ArrowFileReader reader = new ArrowFileReader(fileInputStream.getChannel(), allocator);) { + reader.loadRecordBatch(reader.getRecordBlocks().get(block)); + assertBlock(reader, block, state); + } + } + + protected void assertBlock(ArrowReader reader, int block, DictionaryUTState state) throws Exception { + VectorSchemaRoot r = reader.getVectorSchemaRoot(); + FieldVector dictA = r.getVector("vectorA"); + FieldVector dictB = r.getVector("vectorB"); + FieldVector dictC = r.getVector("vectorC"); + FieldVector dictD = r.getVector("vectorD"); + + switch (state) { + case ONE_DELTA: + assertDictionary(dictA, reader, valuesPerBlock.get(block)[0]); + assertNull(dictB); + assertNull(dictC); + assertNull(dictD); + break; + case ONE_REPLACEMENT_NOT_UPDATED: + assertNull(dictA); + assertDictionary(dictB, reader, valuesPerBlock.get(block)[1]); + assertNull(dictC); + assertNull(dictD); + break; + case DELTA_AND_REPLACEMENT_NOT_UPDATED: + assertDictionary(dictA, reader, valuesPerBlock.get(block)[0]); + assertDictionary(dictB, reader, valuesPerBlock.get(block)[1]); + assertNull(dictC); + assertNull(dictD); + break; + case DELTA_MID_BATCH: + assertNull(dictA); + assertNull(dictB); + assertDictionary(dictC, reader, valuesPerBlock.get(block)[2]); + assertNull(dictD); + break; + case TWO_DELTAS: + assertDictionary(dictA, reader, valuesPerBlock.get(block)[0]); + assertNull(dictB); + assertDictionary(dictC, reader, valuesPerBlock.get(block)[2]); + assertNull(dictD); + break; + case TWO_DELTAS_AND_REPLACEMENT_NOT_UPDATED: + assertDictionary(dictA, reader, valuesPerBlock.get(block)[0]); + assertDictionary(dictB, reader, valuesPerBlock.get(block)[1]); + assertDictionary(dictC, reader, valuesPerBlock.get(block)[2]); + assertNull(dictD); + break; + case REPLACEMENT_UPDATED: + assertNull(dictA); + assertNull(dictB); + assertNull(dictC); + assertDictionary(dictD, reader, valuesPerBlock.get(block)[3]); + break; + default: + throw new IllegalStateException("Unsupported state: " + state); + } + } + + protected static Collection dictionaryParams() { + List params = new ArrayList<>(); + // number of unique states from writeDataMultiBatchWithDictionaries + for (DictionaryUTState state : DictionaryUTState.values()) { + params.add(Arguments.of(state)); + } + return params; + } + + protected BatchedDictionary newDictionary(String name, DictionaryEncoding encoding, boolean oneTimeEncoding) { + return new BatchedDictionary( + name, + encoding, + new ArrowType.Utf8(), + new ArrowType.Int(16, false), + allocator, + "-dictionary", + oneTimeEncoding + ); + } } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/ipc/ITTestIPCWithLargeArrowBuffers.java b/java/vector/src/test/java/org/apache/arrow/vector/ipc/ITTestIPCWithLargeArrowBuffers.java index d3c91fd144356..72ab46fc53906 100644 --- 
a/java/vector/src/test/java/org/apache/arrow/vector/ipc/ITTestIPCWithLargeArrowBuffers.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/ipc/ITTestIPCWithLargeArrowBuffers.java @@ -33,6 +33,7 @@ import org.apache.arrow.vector.BigIntVector; import org.apache.arrow.vector.IntVector; import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.dictionary.BaseDictionary; import org.apache.arrow.vector.dictionary.Dictionary; import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.arrow.vector.types.pojo.ArrowType; @@ -150,9 +151,9 @@ private void testReadLargeArrowData(boolean streamMode) throws IOException { logger.trace("Verifying encoded vector finished"); // verify dictionary - Map dictVectors = reader.getDictionaryVectors(); + Map dictVectors = reader.getDictionaryVectors(); assertEquals(1, dictVectors.size()); - Dictionary dictionary = dictVectors.get(DICTIONARY_ID); + BaseDictionary dictionary = dictVectors.get(DICTIONARY_ID); assertNotNull(dictionary); assertTrue(dictionary.getVector() instanceof BigIntVector); diff --git a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowFile.java b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowFile.java index 4fb5822786083..bc6c537fe192a 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowFile.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowFile.java @@ -20,11 +20,13 @@ import static java.nio.channels.Channels.newChannel; import static org.apache.arrow.vector.TestUtils.newVarCharVector; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.File; +import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.IOException; import java.io.OutputStream; @@ -33,19 +35,36 @@ import java.util.List; import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.util.Collections2; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.complex.StructVector; import org.apache.arrow.vector.types.pojo.Field; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class TestArrowFile extends BaseFileTest { private static final Logger LOGGER = LoggerFactory.getLogger(TestArrowFile.class); + // overriding here since a number of other UTs sharing the BaseFileTest class use + // legacy JUnit. 
+ @BeforeEach + public void init() { + allocator = new RootAllocator(Integer.MAX_VALUE); + } + + @AfterEach + public void tearDown() { + allocator.close(); + } + @Test public void testWrite() throws IOException { File file = new File("target/mytest_write.arrow"); @@ -131,4 +150,67 @@ public void testFileStreamHasEos() throws IOException { } } } + + @ParameterizedTest + @MethodSource("dictionaryParams") + public void testMultiBatchDictionaries(DictionaryUTState state) throws Exception { + File file = new File("target/mytest_multi_batch_dictionaries_" + state + ".arrow"); + try (FileOutputStream stream = new FileOutputStream(file)) { + if (state == DictionaryUTState.REPLACEMENT_UPDATED) { + assertThrows(IllegalStateException.class, () -> writeDataMultiBatchWithDictionaries(stream, state)); + return; + } else { + writeDataMultiBatchWithDictionaries(stream, state); + } + } + + try (FileInputStream fileInputStream = new FileInputStream(file); + ArrowFileReader reader = new ArrowFileReader(fileInputStream.getChannel(), allocator);) { + for (int i = 0; i < 4; i++) { + reader.loadNextBatch(); + assertBlock(reader, i, state); + } + } + } + + @ParameterizedTest + @MethodSource("dictionaryParams") + public void testMultiBatchDictionariesOutOfOrder(DictionaryUTState state) throws Exception { + File file = new File("target/mytest_multi_batch_dictionaries_ooo_" + state + ".arrow"); + try (FileOutputStream stream = new FileOutputStream(file)) { + if (state == DictionaryUTState.REPLACEMENT_UPDATED) { + assertThrows(IllegalStateException.class, () -> writeDataMultiBatchWithDictionaries(stream, state)); + return; + } else { + writeDataMultiBatchWithDictionaries(stream, state); + } + } + try (FileInputStream fileInputStream = new FileInputStream(file); + ArrowFileReader reader = new ArrowFileReader(fileInputStream.getChannel(), allocator);) { + int[] order = new int[] {2, 1, 3, 0}; + for (int i = 0; i < 4; i++) { + int block = order[i]; + reader.loadRecordBatch(reader.getRecordBlocks().get(block)); + assertBlock(reader, block, state); + } + } + } + + @ParameterizedTest + @MethodSource("dictionaryParams") + public void testMultiBatchDictionariesSeek(DictionaryUTState state) throws Exception { + File file = new File("target/mytest_multi_batch_dictionaries_seek_" + state + ".arrow"); + try (FileOutputStream stream = new FileOutputStream(file)) { + if (state == DictionaryUTState.REPLACEMENT_UPDATED) { + assertThrows(IllegalStateException.class, () -> writeDataMultiBatchWithDictionaries(stream, state)); + return; + } else { + writeDataMultiBatchWithDictionaries(stream, state); + } + } + for (int i = 0; i < 4; i++) { + assertBlock(file, i, state); + } + } + } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowReaderWriter.java b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowReaderWriter.java index 07875b25029ea..f5b0607ae1e84 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowReaderWriter.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowReaderWriter.java @@ -68,6 +68,7 @@ import org.apache.arrow.vector.compare.TypeEqualsVisitor; import org.apache.arrow.vector.compare.VectorEqualsVisitor; import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.dictionary.BaseDictionary; import org.apache.arrow.vector.dictionary.Dictionary; import org.apache.arrow.vector.dictionary.DictionaryEncoder; import org.apache.arrow.vector.dictionary.DictionaryProvider; @@ -376,8 +377,8 @@ public void 
testWriteReadWithStructDictionaries() throws IOException { .rangeEquals(new Range(0, 0, encodedVector.getValueCount()))); // Read the dictionary - final Map readDictionaryMap = reader.getDictionaryVectors(); - final Dictionary readDictionary = + final Map readDictionaryMap = reader.getDictionaryVectors(); + final BaseDictionary readDictionary = readDictionaryMap.get(readEncoded.getField().getDictionary().getId()); assertNotNull(readDictionary); diff --git a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowStream.java b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowStream.java index 9348cd3a66708..8da1f389bd1e8 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowStream.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowStream.java @@ -27,14 +27,32 @@ import java.nio.channels.Channels; import java.util.Collections; +import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.IntVector; import org.apache.arrow.vector.TinyIntVector; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.Assert; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; public class TestArrowStream extends BaseFileTest { + + // overriding here since a number of other UTs sharing the BaseFileTest class use + // legacy JUnit. + @BeforeEach + public void init() { + allocator = new RootAllocator(Integer.MAX_VALUE); + } + + @AfterEach + public void tearDown() { + allocator.close(); + } + @Test public void testEmptyStream() throws IOException { Schema schema = MessageSerializerTest.testSchema(); @@ -64,7 +82,7 @@ public void testStreamZeroLengthBatch() throws IOException { try (IntVector vector = new IntVector("foo", allocator);) { Schema schema = new Schema(Collections.singletonList(vector.getField())); try (VectorSchemaRoot root = - new VectorSchemaRoot(schema, Collections.singletonList(vector), vector.getValueCount()); + new VectorSchemaRoot(schema, Collections.singletonList(vector), vector.getValueCount()); ArrowStreamWriter writer = new ArrowStreamWriter(root, null, Channels.newChannel(os));) { vector.setValueCount(0); root.setRowCount(0); @@ -131,7 +149,7 @@ public void testReadWriteMultipleBatches() throws IOException { try (IntVector vector = new IntVector("foo", allocator);) { Schema schema = new Schema(Collections.singletonList(vector.getField())); try (VectorSchemaRoot root = - new VectorSchemaRoot(schema, Collections.singletonList(vector), vector.getValueCount()); + new VectorSchemaRoot(schema, Collections.singletonList(vector), vector.getValueCount()); ArrowStreamWriter writer = new ArrowStreamWriter(root, null, Channels.newChannel(os));) { writeBatchData(writer, vector, root); } @@ -144,4 +162,21 @@ public void testReadWriteMultipleBatches() throws IOException { validateBatchData(reader, vector); } } + + @ParameterizedTest + @MethodSource("dictionaryParams") + public void testMultiBatchDictionaries(DictionaryUTState state) throws Exception { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + writeDataMultiBatchWithDictionaries(out, state); + + ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); + out.close(); + + try (ArrowStreamReader reader = new ArrowStreamReader(in, allocator)) { + for (int i = 0; i < 4; i++) { + 
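// loadNextBatch() also applies any dictionary (including delta) batches that precede the record batch, so assertBlock should observe the merged dictionaries. +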
reader.loadNextBatch(); + assertBlock(reader, i, state); + } + } + } } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestUIntDictionaryRoundTrip.java b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestUIntDictionaryRoundTrip.java index 6aa7a0c6df5c3..0e6e27b443c43 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestUIntDictionaryRoundTrip.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestUIntDictionaryRoundTrip.java @@ -42,6 +42,7 @@ import org.apache.arrow.vector.ValueVector; import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.dictionary.BaseDictionary; import org.apache.arrow.vector.dictionary.Dictionary; import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.arrow.vector.types.pojo.ArrowType; @@ -129,9 +130,9 @@ private void readData( } // verify dictionary - Map dictVectors = reader.getDictionaryVectors(); + Map dictVectors = reader.getDictionaryVectors(); assertEquals(1, dictVectors.size()); - Dictionary dictionary = dictVectors.get(dictionaryID); + BaseDictionary dictionary = dictVectors.get(dictionaryID); assertNotNull(dictionary); assertTrue(dictionary.getVector() instanceof VarCharVector);
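// --- Editor's note: illustrative sketch, not part of the patch --------------------------------------
// Reader-side usage stays the same apart from the BaseDictionary type: look the dictionary up by the
// encoding id and decode. Per the tests above, delta batches are merged into the dictionary vector on
// read, so a single decode yields the full values. "reader" is assumed to be an ArrowReader with a
// batch already loaded, and "vectorA" an existing dictionary-encoded field name.
static void decodeEncodedVector(ArrowReader reader) throws Exception {
  FieldVector encoded = reader.getVectorSchemaRoot().getVector("vectorA");
  long id = encoded.getField().getDictionary().getId();
  BaseDictionary dictionary = reader.getDictionaryVectors().get(id);
  try (ValueVector decoded = DictionaryEncoder.decode(encoded, dictionary)) {
    // decoded now holds the original (un-encoded) values for the current batch
  }
}
// -----------------------------------------------------------------------------------------------------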