diff --git a/java/vector/src/main/codegen/templates/UnionVector.java b/java/vector/src/main/codegen/templates/UnionVector.java index cf0ef06208d85..3b21a2bb9be4c 100644 --- a/java/vector/src/main/codegen/templates/UnionVector.java +++ b/java/vector/src/main/codegen/templates/UnionVector.java @@ -23,6 +23,7 @@ import org.apache.arrow.util.Preconditions; import org.apache.arrow.vector.BaseValueVector; import org.apache.arrow.vector.BitVectorHelper; +import org.apache.arrow.vector.DecimalVector; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.ValueVector; import org.apache.arrow.vector.complex.AbstractStructVector; @@ -293,7 +294,10 @@ public StructVector getStruct() { <#if minor.class?starts_with("Decimal")> public ${name}Vector get${name}Vector() { if (${uncappedName}Vector == null) { - throw new IllegalArgumentException("No ${uncappedName} present. Provide ArrowType argument to create a new vector"); + ${uncappedName}Vector = internalStruct.getChild(fieldName(MinorType.${name?upper_case}), ${name}Vector.class); + if (${uncappedName}Vector == null) { + throw new IllegalArgumentException("No ${uncappedName} present. Provide ArrowType argument to create a new vector"); + } } return ${uncappedName}Vector; } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java b/java/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java index 9dce33122e881..9ba5cc4cf91d9 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java @@ -21,17 +21,25 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; +import java.math.BigDecimal; + import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.DecimalVector; import org.apache.arrow.vector.DirtyRootAllocator; import org.apache.arrow.vector.complex.ListVector; import org.apache.arrow.vector.complex.NonNullableStructVector; import org.apache.arrow.vector.complex.StructVector; import org.apache.arrow.vector.complex.UnionVector; import org.apache.arrow.vector.complex.writer.BaseWriter.StructWriter; +import org.apache.arrow.vector.holders.NullableDecimalHolder; +import org.apache.arrow.vector.holders.NullableIntHolder; +import org.apache.arrow.vector.holders.UnionHolder; import org.apache.arrow.vector.types.Types; +import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.ArrowType.ArrowTypeID; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.util.DecimalUtility; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -164,4 +172,41 @@ public void testNoPromoteToUnionWithNull() throws Exception { lv.close(); } } + + @Test + public void testPromoteToUnionFromDecimal() throws Exception { + try (final NonNullableStructVector container = NonNullableStructVector.empty(EMPTY_SCHEMA_PATH, allocator); + final DecimalVector v = container.addOrGet("dec", + FieldType.nullable(new ArrowType.Decimal(38, 1, 128)), DecimalVector.class); + final PromotableWriter writer = new PromotableWriter(v, container)) { + + container.allocateNew(); + container.setValueCount(1); + + writer.setPosition(0); + writer.writeDecimal(new BigDecimal("0.1")); + writer.setPosition(1); + writer.writeInt(1); + + container.setValueCount(3); + + UnionVector unionVector = (UnionVector) container.getChild("dec"); + UnionHolder holder = new UnionHolder(); + + unionVector.get(0, holder); + NullableDecimalHolder decimalHolder = new NullableDecimalHolder(); + holder.reader.read(decimalHolder); + + assertEquals(1, decimalHolder.isSet); + assertEquals(new BigDecimal("0.1"), + DecimalUtility.getBigDecimalFromArrowBuf(decimalHolder.buffer, 0, decimalHolder.scale, 128)); + + unionVector.get(1, holder); + NullableIntHolder intHolder = new NullableIntHolder(); + holder.reader.read(intHolder); + + assertEquals(1, intHolder.isSet); + assertEquals(1, intHolder.value); + } + } }