From 54c5e8b0175be5b9dedf0f32f2e8f555d7ded274 Mon Sep 17 00:00:00 2001 From: Scott Cowell Date: Wed, 10 Jan 2024 15:56:08 -0800 Subject: [PATCH] DX-85876: Failure in UnionReader.read after DecimalVector promotion to UnionVector (#61) When a DecimalVector is promoted to a UnionVector via a PromotableWriter, the UnionVector will have the decimal vector in it's internal struct vector, but the decimalVector field will not be set. If UnionReader.read is then used to read from the UnionVector, it will fail when it tries to read one of the promoted decimal values, due to decimalVector being null, and the exact decimal type not being provided. This failure is unnecessary though as we have a pre-existing decimal vector, the caller just does not know the exact type - and it shouldn't be required to. The change here is to check for a pre-existing decimal vector in the internal struct when getDecimalVector() is called. If one exists, set the decimalVector field and return. Otherwise, if none exists, throw the exception. --- .../main/codegen/templates/UnionVector.java | 6 ++- .../complex/impl/TestPromotableWriter.java | 43 +++++++++++++++++++ 2 files changed, 48 insertions(+), 1 deletion(-) diff --git a/java/vector/src/main/codegen/templates/UnionVector.java b/java/vector/src/main/codegen/templates/UnionVector.java index ea79c5c2fba76..7724750a83d69 100644 --- a/java/vector/src/main/codegen/templates/UnionVector.java +++ b/java/vector/src/main/codegen/templates/UnionVector.java @@ -23,6 +23,7 @@ import org.apache.arrow.util.Preconditions; import org.apache.arrow.vector.BaseValueVector; import org.apache.arrow.vector.BitVectorHelper; +import org.apache.arrow.vector.DecimalVector; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.ValueVector; import org.apache.arrow.vector.complex.AbstractStructVector; @@ -278,7 +279,10 @@ public StructVector getStruct() { <#if minor.class?starts_with("Decimal") || is_timestamp_tz(minor.class) || minor.class == "Duration" || minor.class == "FixedSizeBinary"> public ${name}Vector get${name}Vector() { if (${uncappedName}Vector == null) { - throw new IllegalArgumentException("No ${name} present. Provide ArrowType argument to create a new vector"); + ${uncappedName}Vector = internalStruct.getChild(fieldName(MinorType.${name?upper_case}), ${name}Vector.class); + if (${uncappedName}Vector == null) { + throw new IllegalArgumentException("No ${name} present. Provide ArrowType argument to create a new vector"); + } } return ${uncappedName}Vector; } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java b/java/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java index 4c8c96a0d74d3..e80a52969870f 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java @@ -22,11 +22,13 @@ import static org.junit.Assert.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; +import java.math.BigDecimal; import java.nio.ByteBuffer; import java.nio.ByteOrder; import org.apache.arrow.memory.ArrowBuf; import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.DecimalVector; import org.apache.arrow.vector.DirtyRootAllocator; import org.apache.arrow.vector.LargeVarBinaryVector; import org.apache.arrow.vector.LargeVarCharVector; @@ -39,8 +41,11 @@ import org.apache.arrow.vector.complex.writer.BaseWriter.StructWriter; import org.apache.arrow.vector.holders.DurationHolder; import org.apache.arrow.vector.holders.FixedSizeBinaryHolder; +import org.apache.arrow.vector.holders.NullableDecimalHolder; +import org.apache.arrow.vector.holders.NullableIntHolder; import org.apache.arrow.vector.holders.NullableTimeStampMilliTZHolder; import org.apache.arrow.vector.holders.TimeStampMilliTZHolder; +import org.apache.arrow.vector.holders.UnionHolder; import org.apache.arrow.vector.types.TimeUnit; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.ArrowType; @@ -48,6 +53,7 @@ import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.arrow.vector.util.Text; +import org.apache.arrow.vector.util.DecimalUtility; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -589,4 +595,41 @@ public void testPromoteLargeVarBinaryHelpersDirect() throws Exception { assertEquals("row4", new String(uv.get(3))); } } + + @Test + public void testPromoteToUnionFromDecimal() throws Exception { + try (final NonNullableStructVector container = NonNullableStructVector.empty(EMPTY_SCHEMA_PATH, allocator); + final DecimalVector v = container.addOrGet("dec", + FieldType.nullable(new ArrowType.Decimal(38, 1, 128)), DecimalVector.class); + final PromotableWriter writer = new PromotableWriter(v, container)) { + + container.allocateNew(); + container.setValueCount(1); + + writer.setPosition(0); + writer.writeDecimal(new BigDecimal("0.1")); + writer.setPosition(1); + writer.writeInt(1); + + container.setValueCount(3); + + UnionVector unionVector = (UnionVector) container.getChild("dec"); + UnionHolder holder = new UnionHolder(); + + unionVector.get(0, holder); + NullableDecimalHolder decimalHolder = new NullableDecimalHolder(); + holder.reader.read(decimalHolder); + + assertEquals(1, decimalHolder.isSet); + assertEquals(new BigDecimal("0.1"), + DecimalUtility.getBigDecimalFromArrowBuf(decimalHolder.buffer, 0, decimalHolder.scale, 128)); + + unionVector.get(1, holder); + NullableIntHolder intHolder = new NullableIntHolder(); + holder.reader.read(intHolder); + + assertEquals(1, intHolder.isSet); + assertEquals(1, intHolder.value); + } + } }