diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileReader.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileReader.java index 8629cf93470b8..585bc8ce2f450 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileReader.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileReader.java @@ -123,11 +123,6 @@ public void initialize() throws IOException { if (footer.getRecordBatches().size() == 0) { return; } - // Read and load all dictionaries from schema - for (int i = 0; i < dictionaries.size(); i++) { - ArrowDictionaryBatch dictionaryBatch = readDictionary(); - loadDictionary(dictionaryBatch); - } } /** @@ -164,6 +159,13 @@ public boolean loadNextBatch() throws IOException { ArrowBlock block = footer.getRecordBatches().get(currentRecordBatch++); ArrowRecordBatch batch = readRecordBatch(in, block, allocator); loadRecordBatch(batch); + + // Read and load all dictionaries from schema + for (int i = 0; i < dictionaries.size(); i++) { + ArrowDictionaryBatch dictionaryBatch = readDictionary(); + loadDictionary(dictionaryBatch); + } + return true; } else { return false; @@ -185,7 +187,7 @@ public List getRecordBlocks() throws IOException { } /** - * Loads record batch for the given block. + * Loads record batch and dictionaries for the given block. */ public boolean loadRecordBatch(ArrowBlock block) throws IOException { ensureInitialized(); @@ -193,6 +195,8 @@ public boolean loadRecordBatch(ArrowBlock block) throws IOException { if (blockIndex == -1) { throw new IllegalArgumentException("Arrow block does not exist in record batches: " + block); } + + currentDictionaryBatch = blockIndex * dictionaries.size(); currentRecordBatch = blockIndex; return loadNextBatch(); } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileWriter.java index 71db79087a3e4..2dd136af1778a 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileWriter.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileWriter.java @@ -52,7 +52,6 @@ public class ArrowFileWriter extends ArrowWriter { private final List recordBlocks = new ArrayList<>(); private Map metaData; - private boolean dictionariesWritten = false; public ArrowFileWriter(VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out) { super(root, provider, out); @@ -129,12 +128,6 @@ protected void endInternal(WriteChannel out) throws IOException { @Override protected void ensureDictionariesWritten(DictionaryProvider provider, Set dictionaryIdsUsed) throws IOException { - if (dictionariesWritten) { - return; - } - dictionariesWritten = true; - // Write out all dictionaries required. - // Replacement dictionaries are not supported in the IPC file format. for (long id : dictionaryIdsUsed) { Dictionary dictionary = provider.lookup(id); writeDictionaryBatch(dictionary); diff --git a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowFile.java b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowFile.java index 4fb5822786083..1739f0dbeffb9 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowFile.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowFile.java @@ -19,26 +19,39 @@ import static java.nio.channels.Channels.newChannel; import static org.apache.arrow.vector.TestUtils.newVarCharVector; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.File; +import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.IOException; import java.io.OutputStream; import java.nio.charset.StandardCharsets; import java.util.Arrays; +import java.util.HashMap; import java.util.List; +import java.util.Map; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.util.Collections2; import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.UInt2Vector; +import org.apache.arrow.vector.ValueVector; import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.DictionaryEncoder; +import org.apache.arrow.vector.dictionary.DictionaryProvider; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; import org.junit.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -72,6 +85,179 @@ public void testWriteComplex() throws IOException { } } + @Test + public void testMultiBatchWithOneDictionary() throws Exception { + File file = new File("target/mytest_multi_dictionary.arrow"); + writeSingleDictionary(file); + + try (FileInputStream fileInputStream = new FileInputStream(file); + ArrowFileReader reader = new ArrowFileReader(fileInputStream.getChannel(), allocator);) { + VectorSchemaRoot root = reader.getVectorSchemaRoot(); + assertEquals(reader.getRecordBlocks().size(), 3); + assertTrue(reader.loadNextBatch()); + + FieldVector encoded = root.getVector("vector"); + DictionaryEncoding dictionaryEncoding = encoded.getField().getDictionary(); + Dictionary dictionary = reader.getDictionaryVectors().get(dictionaryEncoding.getId()); + try (ValueVector decoded = DictionaryEncoder.decode(encoded, dictionary)) { + assertEquals("foo", decoded.getObject(0).toString()); + assertEquals("bar", decoded.getObject(1).toString()); + assertEquals("bar", decoded.getObject(2).toString()); + assertEquals("foo", decoded.getObject(3).toString()); + } + + assertTrue(reader.loadNextBatch()); + + try (ValueVector decoded = DictionaryEncoder.decode(encoded, dictionary)) { + assertEquals("bar", decoded.getObject(0).toString()); + assertEquals("bar", decoded.getObject(1).toString()); + assertEquals("foo", decoded.getObject(2).toString()); + assertEquals("foo", decoded.getObject(3).toString()); + } + + assertTrue(reader.loadNextBatch()); + + try (ValueVector decoded = DictionaryEncoder.decode(encoded, dictionary)) { + assertEquals("baz", decoded.getObject(0).toString()); + assertEquals("baz", decoded.getObject(1).toString()); + } + } + + // load just the 3rd batch. + try (FileInputStream fileInputStream = new FileInputStream(file); + ArrowFileReader reader = new ArrowFileReader(fileInputStream.getChannel(), allocator);) { + VectorSchemaRoot root = reader.getVectorSchemaRoot(); + assertEquals(reader.getRecordBlocks().size(), 3); + assertTrue(reader.loadRecordBatch(reader.getRecordBlocks().get(2))); + + FieldVector encoded = root.getVector("vector"); + DictionaryEncoding dictionaryEncoding = encoded.getField().getDictionary(); + Dictionary dictionary = reader.getDictionaryVectors().get(dictionaryEncoding.getId()); + + try (ValueVector decoded = DictionaryEncoder.decode(encoded, dictionary)) { + assertEquals("baz", decoded.getObject(0).toString()); + assertEquals("baz", decoded.getObject(1).toString()); + } + } + + // load just the first batch. + try (FileInputStream fileInputStream = new FileInputStream(file); + ArrowFileReader reader = new ArrowFileReader(fileInputStream.getChannel(), allocator);) { + VectorSchemaRoot root = reader.getVectorSchemaRoot(); + assertEquals(reader.getRecordBlocks().size(), 3); + assertTrue(reader.loadRecordBatch(reader.getRecordBlocks().get(0))); + + FieldVector encoded = root.getVector("vector"); + DictionaryEncoding dictionaryEncoding = encoded.getField().getDictionary(); + Dictionary dictionary = reader.getDictionaryVectors().get(dictionaryEncoding.getId()); + + try (ValueVector decoded = DictionaryEncoder.decode(encoded, dictionary)) { + assertEquals("foo", decoded.getObject(0).toString()); + assertEquals("bar", decoded.getObject(1).toString()); + assertEquals("bar", decoded.getObject(2).toString()); + assertEquals("foo", decoded.getObject(3).toString()); + } + } + } + + @Test + public void testMultiBatchWithTwoDictionaries() throws Exception { + File file = new File("target/mytest_multi_dictionaries.arrow"); + writeTwoDictionaries(file); + + try (FileInputStream fileInputStream = new FileInputStream(file); + ArrowFileReader reader = new ArrowFileReader(fileInputStream.getChannel(), allocator);) { + VectorSchemaRoot root = reader.getVectorSchemaRoot(); + assertEquals(reader.getRecordBlocks().size(), 3); + assertTrue(reader.loadNextBatch()); + + FieldVector encoded = root.getVector("vector1"); + DictionaryEncoding dictionaryEncoding = encoded.getField().getDictionary(); + Dictionary dictionary = reader.getDictionaryVectors().get(dictionaryEncoding.getId()); + try (ValueVector decoded = DictionaryEncoder.decode(encoded, dictionary)) { + assertEquals("foo", decoded.getObject(0).toString()); + assertEquals("bar", decoded.getObject(1).toString()); + assertEquals("bar", decoded.getObject(2).toString()); + assertEquals("foo", decoded.getObject(3).toString()); + } + + FieldVector encoded2 = root.getVector("vector2"); + DictionaryEncoding dictionaryEncoding2 = encoded2.getField().getDictionary(); + Dictionary dictionary2 = reader.getDictionaryVectors().get(dictionaryEncoding2.getId()); + try (ValueVector decoded = DictionaryEncoder.decode(encoded2, dictionary2)) { + assertNull(decoded.getObject(0)); + assertNull(decoded.getObject(1)); + assertEquals("bur", decoded.getObject(2).toString()); + assertNull(decoded.getObject(3)); + } + + assertTrue(reader.loadNextBatch()); + + try (ValueVector decoded = DictionaryEncoder.decode(encoded, dictionary)) { + assertEquals("bar", decoded.getObject(0).toString()); + assertEquals("bar", decoded.getObject(1).toString()); + assertEquals("foo", decoded.getObject(2).toString()); + assertEquals("foo", decoded.getObject(3).toString()); + } + + try (ValueVector decoded = DictionaryEncoder.decode(encoded2, dictionary2)) { + assertEquals("arg", decoded.getObject(0).toString()); + assertNull(decoded.getObject(1)); + assertNull(decoded.getObject(2)); + assertNull(decoded.getObject(3)); + } + + assertTrue(reader.loadNextBatch()); + + try (ValueVector decoded = DictionaryEncoder.decode(encoded, dictionary)) { + assertEquals("baz", decoded.getObject(0).toString()); + assertEquals("baz", decoded.getObject(1).toString()); + } + + try (ValueVector decoded = DictionaryEncoder.decode(encoded2, dictionary2)) { + for (int i = 0; i < 4; i++) { + assertNull(decoded.getObject(i)); + } + } + } + + // load just the 3rd batch. + try (FileInputStream fileInputStream = new FileInputStream(file); + ArrowFileReader reader = new ArrowFileReader(fileInputStream.getChannel(), allocator);) { + VectorSchemaRoot root = reader.getVectorSchemaRoot(); + assertEquals(reader.getRecordBlocks().size(), 3); + assertTrue(reader.loadRecordBatch(reader.getRecordBlocks().get(2))); + + FieldVector encoded = root.getVector("vector1"); + DictionaryEncoding dictionaryEncoding = encoded.getField().getDictionary(); + Dictionary dictionary = reader.getDictionaryVectors().get(dictionaryEncoding.getId()); + + try (ValueVector decoded = DictionaryEncoder.decode(encoded, dictionary)) { + assertEquals("baz", decoded.getObject(0).toString()); + assertEquals("baz", decoded.getObject(1).toString()); + } + } + + // load just the first batch. + try (FileInputStream fileInputStream = new FileInputStream(file); + ArrowFileReader reader = new ArrowFileReader(fileInputStream.getChannel(), allocator);) { + VectorSchemaRoot root = reader.getVectorSchemaRoot(); + assertEquals(reader.getRecordBlocks().size(), 3); + assertTrue(reader.loadRecordBatch(reader.getRecordBlocks().get(0))); + + FieldVector encoded = root.getVector("vector1"); + DictionaryEncoding dictionaryEncoding = encoded.getField().getDictionary(); + Dictionary dictionary = reader.getDictionaryVectors().get(dictionaryEncoding.getId()); + + try (ValueVector decoded = DictionaryEncoder.decode(encoded, dictionary)) { + assertEquals("foo", decoded.getObject(0).toString()); + assertEquals("bar", decoded.getObject(1).toString()); + assertEquals("bar", decoded.getObject(2).toString()); + assertEquals("foo", decoded.getObject(3).toString()); + } + } + } + /** * Writes the contents of parents to file. If outStream is non-null, also writes it * to outStream in the streaming serialized format. @@ -99,7 +285,6 @@ private void write(FieldVector parent, File file, OutputStream outStream) throws @Test public void testFileStreamHasEos() throws IOException { - try (VarCharVector vector1 = newVarCharVector("varchar1", allocator)) { vector1.allocateNewSafe(); vector1.set(0, "foo".getBytes(StandardCharsets.UTF_8)); @@ -131,4 +316,183 @@ public void testFileStreamHasEos() throws IOException { } } } + + private void writeSingleDictionary(File file) throws Exception { + Map stringToIndex = new HashMap<>(); + + try (VarCharVector dictionaryVector = new VarCharVector("dictionary", allocator)) { + DictionaryEncoding dictionaryEncoding = new DictionaryEncoding(42, false, new ArrowType.Int(16, false)); + + Dictionary dictionary = new Dictionary(dictionaryVector, dictionaryEncoding); + DictionaryProvider.MapDictionaryProvider provider = new DictionaryProvider.MapDictionaryProvider(); + provider.put(dictionary); + + try (UInt2Vector vector = new UInt2Vector( + "vector", + new FieldType(false, new ArrowType.Int(16, false), dictionaryEncoding), + allocator)) { + vector.allocateNew(4); + dictionaryVector.allocateNew(2); + + dictionaryVector.set(0, "foo".getBytes(StandardCharsets.UTF_8)); + stringToIndex.put("foo", 0); + dictionaryVector.set(1, "bar".getBytes(StandardCharsets.UTF_8)); + stringToIndex.put("bar", 1); + + vector.set(0, stringToIndex.get("foo")); + vector.set(1, stringToIndex.get("bar")); + vector.set(2, stringToIndex.get("bar")); + vector.set(3, stringToIndex.get("foo")); + + VectorSchemaRoot root = VectorSchemaRoot.of(dictionaryVector, vector); + root.setRowCount(4); + try (FileOutputStream fileOutputStream = new FileOutputStream(file); + ArrowFileWriter arrowWriter = new ArrowFileWriter(root, provider, fileOutputStream.getChannel());) { + + // batch 1 + arrowWriter.start(); + arrowWriter.writeBatch(); + dictionaryVector.reset(); + vector.reset(); + stringToIndex.clear(); + + // batch 2 + // note the order is different for the strings + dictionaryVector.set(0, "bar".getBytes(StandardCharsets.UTF_8)); + stringToIndex.put("bar", 0); + dictionaryVector.set(1, "foo".getBytes(StandardCharsets.UTF_8)); + stringToIndex.put("foo", 1); + + vector.set(0, stringToIndex.get("bar")); + vector.set(1, stringToIndex.get("bar")); + vector.set(2, stringToIndex.get("foo")); + vector.set(3, stringToIndex.get("foo")); + + root.setRowCount(4); + + arrowWriter.writeBatch(); + + // batch 3 + dictionaryVector.reset(); + vector.reset(); + + stringToIndex.clear(); + dictionaryVector.set(0, "baz".getBytes(StandardCharsets.UTF_8)); + stringToIndex.put("baz", 0); + + vector.set(0, stringToIndex.get("baz")); + vector.set(1, stringToIndex.get("baz")); + + root.setRowCount(2); + + arrowWriter.writeBatch(); + arrowWriter.end(); + } + } + } + } + + private void writeTwoDictionaries(File file) throws Exception { + Map stringToIndex1 = new HashMap<>(); + Map stringToIndex2 = new HashMap<>(); + + try (VarCharVector dictionaryVector1 = new VarCharVector("dict1", allocator); + VarCharVector dictionaryVector2 = new VarCharVector("dict2", allocator)) { + DictionaryEncoding dictionaryEncoding1 = new DictionaryEncoding(1, false, new ArrowType.Int(16, false)); + DictionaryEncoding dictionaryEncoding2 = new DictionaryEncoding(2, false, new ArrowType.Int(16, false)); + + Dictionary dictionary1 = new Dictionary(dictionaryVector1, dictionaryEncoding1); + Dictionary dictionary2 = new Dictionary(dictionaryVector2, dictionaryEncoding2); + DictionaryProvider.MapDictionaryProvider provider = new DictionaryProvider.MapDictionaryProvider(); + provider.put(dictionary1); + provider.put(dictionary2); + + try (UInt2Vector vector1 = new UInt2Vector( + "vector1", + new FieldType(false, new ArrowType.Int(16, false), dictionaryEncoding1), + allocator); + UInt2Vector vector2 = new UInt2Vector( + "vector2", + new FieldType(false, new ArrowType.Int(16, false), dictionaryEncoding2), + allocator)) { + vector1.allocateNew(4); + vector2.allocateNew(4); + dictionaryVector1.allocateNew(2); + dictionaryVector2.allocateNew(1); + + dictionaryVector1.set(0, "foo".getBytes(StandardCharsets.UTF_8)); + stringToIndex1.put("foo", 0); + dictionaryVector1.set(1, "bar".getBytes(StandardCharsets.UTF_8)); + stringToIndex1.put("bar", 1); + dictionaryVector2.set(0, "bur".getBytes(StandardCharsets.UTF_8)); + stringToIndex2.put("bur", 0); + + vector1.set(0, stringToIndex1.get("foo")); + vector1.set(1, stringToIndex1.get("bar")); + vector1.set(2, stringToIndex1.get("bar")); + vector1.set(3, stringToIndex1.get("foo")); + + vector2.set(2, stringToIndex2.get("bur")); + + VectorSchemaRoot root = VectorSchemaRoot.of(dictionaryVector1, vector1, dictionaryVector2, vector2); + root.setRowCount(4); + try (FileOutputStream fileOutputStream = new FileOutputStream(file); + ArrowFileWriter arrowWriter = new ArrowFileWriter(root, provider, fileOutputStream.getChannel());) { + + // batch 1 + arrowWriter.start(); + arrowWriter.writeBatch(); + System.out.println("WRITE vector: " + vector2); + System.out.println("WRITE dict: " + dictionaryVector2); + dictionaryVector1.reset(); + dictionaryVector2.reset(); + vector1.reset(); + vector2.reset(); + stringToIndex1.clear(); + stringToIndex2.clear(); + + // batch 2 + // note the order is different for the strings + dictionaryVector1.set(0, "bar".getBytes(StandardCharsets.UTF_8)); + stringToIndex1.put("bar", 0); + dictionaryVector1.set(1, "foo".getBytes(StandardCharsets.UTF_8)); + stringToIndex1.put("foo", 1); + dictionaryVector2.set(0, "arg".getBytes(StandardCharsets.UTF_8)); + stringToIndex2.put("arg", 0); + + vector1.set(0, stringToIndex1.get("bar")); + vector1.set(1, stringToIndex1.get("bar")); + vector1.set(2, stringToIndex1.get("foo")); + vector1.set(3, stringToIndex1.get("foo")); + + vector2.set(0, stringToIndex2.get("arg")); + + root.setRowCount(4); + + arrowWriter.writeBatch(); + + // batch 3 + dictionaryVector1.reset(); + dictionaryVector2.reset(); + vector1.reset(); + vector2.reset(); + stringToIndex1.clear(); + stringToIndex2.clear(); + + dictionaryVector1.set(0, "baz".getBytes(StandardCharsets.UTF_8)); + stringToIndex1.put("baz", 0); + + vector1.set(0, stringToIndex1.get("baz")); + vector1.set(1, stringToIndex1.get("baz")); + + // nothing for vector 2 + + root.setRowCount(2); + + arrowWriter.writeBatch(); + arrowWriter.end(); + } + } + } + } }