From 92e2c733c6c7ff001fba6aa79d58c52fd56a64ad Mon Sep 17 00:00:00 2001
From: Chris Larsen
Date: Mon, 9 Oct 2023 14:33:28 -0700
Subject: [PATCH] GH-38168: [Java] Fix multi-batch Dictionary in Arrow{Reader|Writer}

When manually writing dictionary vectors across multiple batches with a single
`ArrowFileWriter`, only the first dictionary batch was written and the
dictionaries for subsequent batches were ignored. On reading, the
`ArrowFileReader` would load only that first dictionary batch and use it to
decode every subsequent record batch, resulting in errors or incorrect
decodings.

This patch flushes the dictionaries on each batch write and loads the
corresponding dictionary batches for each record batch on read, following the
dictionary message semantics described at
https://arrow.apache.org/docs/format/Columnar.html#dictionary-messages.

Note that this does not address delta dictionary encoding, as the writer does
not currently have a means of setting the delta flag. Nor does it allow
streaming writes of dictionaries (though the unit tests show a work-around).

Fix for #38168
---
 .../arrow/vector/ipc/ArrowFileReader.java |  16 +-
 .../arrow/vector/ipc/ArrowFileWriter.java |   7 -
 .../arrow/vector/ipc/TestArrowFile.java   | 366 +++++++++++++++++-
 3 files changed, 375 insertions(+), 14 deletions(-)

diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileReader.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileReader.java
index 8629cf93470b8..585bc8ce2f450 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileReader.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileReader.java
@@ -123,11 +123,6 @@ public void initialize() throws IOException {
     if (footer.getRecordBatches().size() == 0) {
       return;
     }
-    // Read and load all dictionaries from schema
-    for (int i = 0; i < dictionaries.size(); i++) {
-      ArrowDictionaryBatch dictionaryBatch = readDictionary();
-      loadDictionary(dictionaryBatch);
-    }
   }

   /**
@@ -164,6 +159,13 @@ public boolean loadNextBatch() throws IOException {
       ArrowBlock block = footer.getRecordBatches().get(currentRecordBatch++);
       ArrowRecordBatch batch = readRecordBatch(in, block, allocator);
       loadRecordBatch(batch);
+
+      // Read and load all dictionaries from schema
+      for (int i = 0; i < dictionaries.size(); i++) {
+        ArrowDictionaryBatch dictionaryBatch = readDictionary();
+        loadDictionary(dictionaryBatch);
+      }
+
       return true;
     } else {
       return false;
@@ -185,7 +187,7 @@ public List<ArrowBlock> getRecordBlocks() throws IOException {
   }

   /**
-   * Loads record batch for the given block.
+   * Loads record batch and dictionaries for the given block.
   */
  public boolean loadRecordBatch(ArrowBlock block) throws IOException {
    ensureInitialized();
@@ -193,6 +195,8 @@ public boolean loadRecordBatch(ArrowBlock block) throws IOException {
     if (blockIndex == -1) {
       throw new IllegalArgumentException("Arrow block does not exist in record batches: " + block);
     }
+
+    currentDictionaryBatch = blockIndex * dictionaries.size();
     currentRecordBatch = blockIndex;
     return loadNextBatch();
   }
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileWriter.java
index 71db79087a3e4..2dd136af1778a 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileWriter.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileWriter.java
@@ -52,7 +52,6 @@ public class ArrowFileWriter extends ArrowWriter {
   private final List<ArrowBlock> recordBlocks = new ArrayList<>();

   private Map<String, String> metaData;
-  private boolean dictionariesWritten = false;

   public ArrowFileWriter(VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out) {
     super(root, provider, out);
@@ -129,12 +128,6 @@ protected void endInternal(WriteChannel out) throws IOException {
   @Override
   protected void ensureDictionariesWritten(DictionaryProvider provider, Set<Long> dictionaryIdsUsed)
       throws IOException {
-    if (dictionariesWritten) {
-      return;
-    }
-    dictionariesWritten = true;
-    // Write out all dictionaries required.
-    // Replacement dictionaries are not supported in the IPC file format.
     for (long id : dictionaryIdsUsed) {
       Dictionary dictionary = provider.lookup(id);
       writeDictionaryBatch(dictionary);
diff --git a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowFile.java b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowFile.java
index 4fb5822786083..1739f0dbeffb9 100644
--- a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowFile.java
+++ b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowFile.java
@@ -19,26 +19,39 @@
 import static java.nio.channels.Channels.newChannel;
 import static org.apache.arrow.vector.TestUtils.newVarCharVector;
+import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;

 import java.io.ByteArrayInputStream;
 import java.io.ByteArrayOutputStream;
 import java.io.File;
+import java.io.FileInputStream;
 import java.io.FileOutputStream;
 import java.io.IOException;
 import java.io.OutputStream;
 import java.nio.charset.StandardCharsets;
 import java.util.Arrays;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;

 import org.apache.arrow.memory.BufferAllocator;
 import org.apache.arrow.util.Collections2;
 import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.UInt2Vector;
+import org.apache.arrow.vector.ValueVector;
 import org.apache.arrow.vector.VarCharVector;
 import org.apache.arrow.vector.VectorSchemaRoot;
 import org.apache.arrow.vector.complex.StructVector;
+import org.apache.arrow.vector.dictionary.Dictionary;
+import org.apache.arrow.vector.dictionary.DictionaryEncoder;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.DictionaryEncoding;
 import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.FieldType;

 import org.junit.Test;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -72,6 +85,179 @@ public void
testWriteComplex() throws IOException { } } + @Test + public void testMultiBatchWithOneDictionary() throws Exception { + File file = new File("target/mytest_multi_dictionary.arrow"); + writeSingleDictionary(file); + + try (FileInputStream fileInputStream = new FileInputStream(file); + ArrowFileReader reader = new ArrowFileReader(fileInputStream.getChannel(), allocator);) { + VectorSchemaRoot root = reader.getVectorSchemaRoot(); + assertEquals(reader.getRecordBlocks().size(), 3); + assertTrue(reader.loadNextBatch()); + + FieldVector encoded = root.getVector("vector"); + DictionaryEncoding dictionaryEncoding = encoded.getField().getDictionary(); + Dictionary dictionary = reader.getDictionaryVectors().get(dictionaryEncoding.getId()); + try (ValueVector decoded = DictionaryEncoder.decode(encoded, dictionary)) { + assertEquals("foo", decoded.getObject(0).toString()); + assertEquals("bar", decoded.getObject(1).toString()); + assertEquals("bar", decoded.getObject(2).toString()); + assertEquals("foo", decoded.getObject(3).toString()); + } + + assertTrue(reader.loadNextBatch()); + + try (ValueVector decoded = DictionaryEncoder.decode(encoded, dictionary)) { + assertEquals("bar", decoded.getObject(0).toString()); + assertEquals("bar", decoded.getObject(1).toString()); + assertEquals("foo", decoded.getObject(2).toString()); + assertEquals("foo", decoded.getObject(3).toString()); + } + + assertTrue(reader.loadNextBatch()); + + try (ValueVector decoded = DictionaryEncoder.decode(encoded, dictionary)) { + assertEquals("baz", decoded.getObject(0).toString()); + assertEquals("baz", decoded.getObject(1).toString()); + } + } + + // load just the 3rd batch. + try (FileInputStream fileInputStream = new FileInputStream(file); + ArrowFileReader reader = new ArrowFileReader(fileInputStream.getChannel(), allocator);) { + VectorSchemaRoot root = reader.getVectorSchemaRoot(); + assertEquals(reader.getRecordBlocks().size(), 3); + assertTrue(reader.loadRecordBatch(reader.getRecordBlocks().get(2))); + + FieldVector encoded = root.getVector("vector"); + DictionaryEncoding dictionaryEncoding = encoded.getField().getDictionary(); + Dictionary dictionary = reader.getDictionaryVectors().get(dictionaryEncoding.getId()); + + try (ValueVector decoded = DictionaryEncoder.decode(encoded, dictionary)) { + assertEquals("baz", decoded.getObject(0).toString()); + assertEquals("baz", decoded.getObject(1).toString()); + } + } + + // load just the first batch. 
+ try (FileInputStream fileInputStream = new FileInputStream(file); + ArrowFileReader reader = new ArrowFileReader(fileInputStream.getChannel(), allocator);) { + VectorSchemaRoot root = reader.getVectorSchemaRoot(); + assertEquals(reader.getRecordBlocks().size(), 3); + assertTrue(reader.loadRecordBatch(reader.getRecordBlocks().get(0))); + + FieldVector encoded = root.getVector("vector"); + DictionaryEncoding dictionaryEncoding = encoded.getField().getDictionary(); + Dictionary dictionary = reader.getDictionaryVectors().get(dictionaryEncoding.getId()); + + try (ValueVector decoded = DictionaryEncoder.decode(encoded, dictionary)) { + assertEquals("foo", decoded.getObject(0).toString()); + assertEquals("bar", decoded.getObject(1).toString()); + assertEquals("bar", decoded.getObject(2).toString()); + assertEquals("foo", decoded.getObject(3).toString()); + } + } + } + + @Test + public void testMultiBatchWithTwoDictionaries() throws Exception { + File file = new File("target/mytest_multi_dictionaries.arrow"); + writeTwoDictionaries(file); + + try (FileInputStream fileInputStream = new FileInputStream(file); + ArrowFileReader reader = new ArrowFileReader(fileInputStream.getChannel(), allocator);) { + VectorSchemaRoot root = reader.getVectorSchemaRoot(); + assertEquals(reader.getRecordBlocks().size(), 3); + assertTrue(reader.loadNextBatch()); + + FieldVector encoded = root.getVector("vector1"); + DictionaryEncoding dictionaryEncoding = encoded.getField().getDictionary(); + Dictionary dictionary = reader.getDictionaryVectors().get(dictionaryEncoding.getId()); + try (ValueVector decoded = DictionaryEncoder.decode(encoded, dictionary)) { + assertEquals("foo", decoded.getObject(0).toString()); + assertEquals("bar", decoded.getObject(1).toString()); + assertEquals("bar", decoded.getObject(2).toString()); + assertEquals("foo", decoded.getObject(3).toString()); + } + + FieldVector encoded2 = root.getVector("vector2"); + DictionaryEncoding dictionaryEncoding2 = encoded2.getField().getDictionary(); + Dictionary dictionary2 = reader.getDictionaryVectors().get(dictionaryEncoding2.getId()); + try (ValueVector decoded = DictionaryEncoder.decode(encoded2, dictionary2)) { + assertNull(decoded.getObject(0)); + assertNull(decoded.getObject(1)); + assertEquals("bur", decoded.getObject(2).toString()); + assertNull(decoded.getObject(3)); + } + + assertTrue(reader.loadNextBatch()); + + try (ValueVector decoded = DictionaryEncoder.decode(encoded, dictionary)) { + assertEquals("bar", decoded.getObject(0).toString()); + assertEquals("bar", decoded.getObject(1).toString()); + assertEquals("foo", decoded.getObject(2).toString()); + assertEquals("foo", decoded.getObject(3).toString()); + } + + try (ValueVector decoded = DictionaryEncoder.decode(encoded2, dictionary2)) { + assertEquals("arg", decoded.getObject(0).toString()); + assertNull(decoded.getObject(1)); + assertNull(decoded.getObject(2)); + assertNull(decoded.getObject(3)); + } + + assertTrue(reader.loadNextBatch()); + + try (ValueVector decoded = DictionaryEncoder.decode(encoded, dictionary)) { + assertEquals("baz", decoded.getObject(0).toString()); + assertEquals("baz", decoded.getObject(1).toString()); + } + + try (ValueVector decoded = DictionaryEncoder.decode(encoded2, dictionary2)) { + for (int i = 0; i < 4; i++) { + assertNull(decoded.getObject(i)); + } + } + } + + // load just the 3rd batch. 
+ try (FileInputStream fileInputStream = new FileInputStream(file); + ArrowFileReader reader = new ArrowFileReader(fileInputStream.getChannel(), allocator);) { + VectorSchemaRoot root = reader.getVectorSchemaRoot(); + assertEquals(reader.getRecordBlocks().size(), 3); + assertTrue(reader.loadRecordBatch(reader.getRecordBlocks().get(2))); + + FieldVector encoded = root.getVector("vector1"); + DictionaryEncoding dictionaryEncoding = encoded.getField().getDictionary(); + Dictionary dictionary = reader.getDictionaryVectors().get(dictionaryEncoding.getId()); + + try (ValueVector decoded = DictionaryEncoder.decode(encoded, dictionary)) { + assertEquals("baz", decoded.getObject(0).toString()); + assertEquals("baz", decoded.getObject(1).toString()); + } + } + + // load just the first batch. + try (FileInputStream fileInputStream = new FileInputStream(file); + ArrowFileReader reader = new ArrowFileReader(fileInputStream.getChannel(), allocator);) { + VectorSchemaRoot root = reader.getVectorSchemaRoot(); + assertEquals(reader.getRecordBlocks().size(), 3); + assertTrue(reader.loadRecordBatch(reader.getRecordBlocks().get(0))); + + FieldVector encoded = root.getVector("vector1"); + DictionaryEncoding dictionaryEncoding = encoded.getField().getDictionary(); + Dictionary dictionary = reader.getDictionaryVectors().get(dictionaryEncoding.getId()); + + try (ValueVector decoded = DictionaryEncoder.decode(encoded, dictionary)) { + assertEquals("foo", decoded.getObject(0).toString()); + assertEquals("bar", decoded.getObject(1).toString()); + assertEquals("bar", decoded.getObject(2).toString()); + assertEquals("foo", decoded.getObject(3).toString()); + } + } + } + /** * Writes the contents of parents to file. If outStream is non-null, also writes it * to outStream in the streaming serialized format. 
@@ -99,7 +285,6 @@ private void write(FieldVector parent, File file, OutputStream outStream) throws @Test public void testFileStreamHasEos() throws IOException { - try (VarCharVector vector1 = newVarCharVector("varchar1", allocator)) { vector1.allocateNewSafe(); vector1.set(0, "foo".getBytes(StandardCharsets.UTF_8)); @@ -131,4 +316,183 @@ public void testFileStreamHasEos() throws IOException { } } } + + private void writeSingleDictionary(File file) throws Exception { + Map stringToIndex = new HashMap<>(); + + try (VarCharVector dictionaryVector = new VarCharVector("dictionary", allocator)) { + DictionaryEncoding dictionaryEncoding = new DictionaryEncoding(42, false, new ArrowType.Int(16, false)); + + Dictionary dictionary = new Dictionary(dictionaryVector, dictionaryEncoding); + DictionaryProvider.MapDictionaryProvider provider = new DictionaryProvider.MapDictionaryProvider(); + provider.put(dictionary); + + try (UInt2Vector vector = new UInt2Vector( + "vector", + new FieldType(false, new ArrowType.Int(16, false), dictionaryEncoding), + allocator)) { + vector.allocateNew(4); + dictionaryVector.allocateNew(2); + + dictionaryVector.set(0, "foo".getBytes(StandardCharsets.UTF_8)); + stringToIndex.put("foo", 0); + dictionaryVector.set(1, "bar".getBytes(StandardCharsets.UTF_8)); + stringToIndex.put("bar", 1); + + vector.set(0, stringToIndex.get("foo")); + vector.set(1, stringToIndex.get("bar")); + vector.set(2, stringToIndex.get("bar")); + vector.set(3, stringToIndex.get("foo")); + + VectorSchemaRoot root = VectorSchemaRoot.of(dictionaryVector, vector); + root.setRowCount(4); + try (FileOutputStream fileOutputStream = new FileOutputStream(file); + ArrowFileWriter arrowWriter = new ArrowFileWriter(root, provider, fileOutputStream.getChannel());) { + + // batch 1 + arrowWriter.start(); + arrowWriter.writeBatch(); + dictionaryVector.reset(); + vector.reset(); + stringToIndex.clear(); + + // batch 2 + // note the order is different for the strings + dictionaryVector.set(0, "bar".getBytes(StandardCharsets.UTF_8)); + stringToIndex.put("bar", 0); + dictionaryVector.set(1, "foo".getBytes(StandardCharsets.UTF_8)); + stringToIndex.put("foo", 1); + + vector.set(0, stringToIndex.get("bar")); + vector.set(1, stringToIndex.get("bar")); + vector.set(2, stringToIndex.get("foo")); + vector.set(3, stringToIndex.get("foo")); + + root.setRowCount(4); + + arrowWriter.writeBatch(); + + // batch 3 + dictionaryVector.reset(); + vector.reset(); + + stringToIndex.clear(); + dictionaryVector.set(0, "baz".getBytes(StandardCharsets.UTF_8)); + stringToIndex.put("baz", 0); + + vector.set(0, stringToIndex.get("baz")); + vector.set(1, stringToIndex.get("baz")); + + root.setRowCount(2); + + arrowWriter.writeBatch(); + arrowWriter.end(); + } + } + } + } + + private void writeTwoDictionaries(File file) throws Exception { + Map stringToIndex1 = new HashMap<>(); + Map stringToIndex2 = new HashMap<>(); + + try (VarCharVector dictionaryVector1 = new VarCharVector("dict1", allocator); + VarCharVector dictionaryVector2 = new VarCharVector("dict2", allocator)) { + DictionaryEncoding dictionaryEncoding1 = new DictionaryEncoding(1, false, new ArrowType.Int(16, false)); + DictionaryEncoding dictionaryEncoding2 = new DictionaryEncoding(2, false, new ArrowType.Int(16, false)); + + Dictionary dictionary1 = new Dictionary(dictionaryVector1, dictionaryEncoding1); + Dictionary dictionary2 = new Dictionary(dictionaryVector2, dictionaryEncoding2); + DictionaryProvider.MapDictionaryProvider provider = new DictionaryProvider.MapDictionaryProvider(); 
+ provider.put(dictionary1); + provider.put(dictionary2); + + try (UInt2Vector vector1 = new UInt2Vector( + "vector1", + new FieldType(false, new ArrowType.Int(16, false), dictionaryEncoding1), + allocator); + UInt2Vector vector2 = new UInt2Vector( + "vector2", + new FieldType(false, new ArrowType.Int(16, false), dictionaryEncoding2), + allocator)) { + vector1.allocateNew(4); + vector2.allocateNew(4); + dictionaryVector1.allocateNew(2); + dictionaryVector2.allocateNew(1); + + dictionaryVector1.set(0, "foo".getBytes(StandardCharsets.UTF_8)); + stringToIndex1.put("foo", 0); + dictionaryVector1.set(1, "bar".getBytes(StandardCharsets.UTF_8)); + stringToIndex1.put("bar", 1); + dictionaryVector2.set(0, "bur".getBytes(StandardCharsets.UTF_8)); + stringToIndex2.put("bur", 0); + + vector1.set(0, stringToIndex1.get("foo")); + vector1.set(1, stringToIndex1.get("bar")); + vector1.set(2, stringToIndex1.get("bar")); + vector1.set(3, stringToIndex1.get("foo")); + + vector2.set(2, stringToIndex2.get("bur")); + + VectorSchemaRoot root = VectorSchemaRoot.of(dictionaryVector1, vector1, dictionaryVector2, vector2); + root.setRowCount(4); + try (FileOutputStream fileOutputStream = new FileOutputStream(file); + ArrowFileWriter arrowWriter = new ArrowFileWriter(root, provider, fileOutputStream.getChannel());) { + + // batch 1 + arrowWriter.start(); + arrowWriter.writeBatch(); + System.out.println("WRITE vector: " + vector2); + System.out.println("WRITE dict: " + dictionaryVector2); + dictionaryVector1.reset(); + dictionaryVector2.reset(); + vector1.reset(); + vector2.reset(); + stringToIndex1.clear(); + stringToIndex2.clear(); + + // batch 2 + // note the order is different for the strings + dictionaryVector1.set(0, "bar".getBytes(StandardCharsets.UTF_8)); + stringToIndex1.put("bar", 0); + dictionaryVector1.set(1, "foo".getBytes(StandardCharsets.UTF_8)); + stringToIndex1.put("foo", 1); + dictionaryVector2.set(0, "arg".getBytes(StandardCharsets.UTF_8)); + stringToIndex2.put("arg", 0); + + vector1.set(0, stringToIndex1.get("bar")); + vector1.set(1, stringToIndex1.get("bar")); + vector1.set(2, stringToIndex1.get("foo")); + vector1.set(3, stringToIndex1.get("foo")); + + vector2.set(0, stringToIndex2.get("arg")); + + root.setRowCount(4); + + arrowWriter.writeBatch(); + + // batch 3 + dictionaryVector1.reset(); + dictionaryVector2.reset(); + vector1.reset(); + vector2.reset(); + stringToIndex1.clear(); + stringToIndex2.clear(); + + dictionaryVector1.set(0, "baz".getBytes(StandardCharsets.UTF_8)); + stringToIndex1.put("baz", 0); + + vector1.set(0, stringToIndex1.get("baz")); + vector1.set(1, stringToIndex1.get("baz")); + + // nothing for vector 2 + + root.setRowCount(2); + + arrowWriter.writeBatch(); + arrowWriter.end(); + } + } + } + } }
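
For context, here is a minimal sketch of the read path this change enables, mirroring the new tests: each call to loadNextBatch() now also loads the dictionary batch(es) written alongside that record batch, so decoding each batch uses the dictionary values that were current when that batch was written. This example is not part of the patch; the file path and the field name "vector" are taken from the writeSingleDictionary() test fixture above, and the class and method names are illustrative only.

// Illustrative sketch only, assuming a file produced the same way as
// writeSingleDictionary() above (dictionary-encoded field named "vector").
import java.io.FileInputStream;

import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.ValueVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.dictionary.Dictionary;
import org.apache.arrow.vector.dictionary.DictionaryEncoder;
import org.apache.arrow.vector.ipc.ArrowFileReader;

public class ReadMultiBatchDictionaryExample {
  public static void main(String[] args) throws Exception {
    try (BufferAllocator allocator = new RootAllocator();
         FileInputStream in = new FileInputStream("target/mytest_multi_dictionary.arrow");
         ArrowFileReader reader = new ArrowFileReader(in.getChannel(), allocator)) {
      VectorSchemaRoot root = reader.getVectorSchemaRoot();
      while (reader.loadNextBatch()) {
        // The dictionary-encoded index vector for the batch that was just loaded.
        FieldVector encoded = root.getVector("vector");
        // With this patch, the dictionary batch written for this record batch has
        // been loaded as well, so the lookup below sees this batch's values.
        long dictionaryId = encoded.getField().getDictionary().getId();
        Dictionary dictionary = reader.getDictionaryVectors().get(dictionaryId);
        try (ValueVector decoded = DictionaryEncoder.decode(encoded, dictionary)) {
          for (int i = 0; i < decoded.getValueCount(); i++) {
            System.out.println(decoded.getObject(i));
          }
        }
      }
    }
  }
}

As noted in the commit message, delta dictionaries and streaming writes of dictionaries remain out of scope for this change.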