diff --git a/java/vector/src/main/codegen/templates/UnionVector.java b/java/vector/src/main/codegen/templates/UnionVector.java index ea79c5c2fba76..b93250e402fd1 100644 --- a/java/vector/src/main/codegen/templates/UnionVector.java +++ b/java/vector/src/main/codegen/templates/UnionVector.java @@ -23,6 +23,7 @@ import org.apache.arrow.util.Preconditions; import org.apache.arrow.vector.BaseValueVector; import org.apache.arrow.vector.BitVectorHelper; +import org.apache.arrow.vector.DecimalVector; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.ValueVector; import org.apache.arrow.vector.complex.AbstractStructVector; @@ -305,6 +306,9 @@ public StructVector getStruct() { public ${name}Vector get${name}Vector(String name) { if (${uncappedName}Vector == null) { + ${uncappedName}Vector = internalStruct.getChild(fieldName(MinorType.${name?upper_case}), ${name}Vector.class); + if (${uncappedName}Vector == null) { + throw new IllegalArgumentException("No ${uncappedName} present. Provide ArrowType argument to create a new vector"); int vectorCount = internalStruct.size(); ${uncappedName}Vector = addOrGet(name, MinorType.${name?upper_case}, ${name}Vector.class); if (internalStruct.size() > vectorCount) { diff --git a/java/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java b/java/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java index b7fc681c16118..b5627fe43c17c 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java @@ -22,6 +22,8 @@ import static org.junit.Assert.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.math.BigDecimal; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.charset.StandardCharsets; @@ -29,6 +31,7 @@ import org.apache.arrow.memory.ArrowBuf; import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.DecimalVector; import org.apache.arrow.vector.DirtyRootAllocator; import org.apache.arrow.vector.LargeVarBinaryVector; import org.apache.arrow.vector.LargeVarCharVector; @@ -41,14 +44,18 @@ import org.apache.arrow.vector.complex.writer.BaseWriter.StructWriter; import org.apache.arrow.vector.holders.DurationHolder; import org.apache.arrow.vector.holders.FixedSizeBinaryHolder; +import org.apache.arrow.vector.holders.NullableDecimalHolder; +import org.apache.arrow.vector.holders.NullableIntHolder; import org.apache.arrow.vector.holders.NullableTimeStampMilliTZHolder; import org.apache.arrow.vector.holders.TimeStampMilliTZHolder; +import org.apache.arrow.vector.holders.UnionHolder; import org.apache.arrow.vector.types.TimeUnit; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.ArrowType.ArrowTypeID; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.util.DecimalUtility; import org.apache.arrow.vector.util.Text; import org.junit.After; import org.junit.Before; @@ -598,5 +605,41 @@ public void testPromoteLargeVarBinaryHelpersDirect() throws Exception { assertEquals("row3", new String(Objects.requireNonNull(uv.get(2)), StandardCharsets.UTF_8)); assertEquals("row4", new String(Objects.requireNonNull(uv.get(3)), StandardCharsets.UTF_8)); } + + @Test + public void testPromoteToUnionFromDecimal() throws Exception { + try (final NonNullableStructVector container = NonNullableStructVector.empty(EMPTY_SCHEMA_PATH, allocator); + final DecimalVector v = container.addOrGet("dec", + FieldType.nullable(new ArrowType.Decimal(38, 1, 128)), DecimalVector.class); + final PromotableWriter writer = new PromotableWriter(v, container)) { + + container.allocateNew(); + container.setValueCount(1); + + writer.setPosition(0); + writer.writeDecimal(new BigDecimal("0.1")); + writer.setPosition(1); + writer.writeInt(1); + + container.setValueCount(3); + + UnionVector unionVector = (UnionVector) container.getChild("dec"); + UnionHolder holder = new UnionHolder(); + + unionVector.get(0, holder); + NullableDecimalHolder decimalHolder = new NullableDecimalHolder(); + holder.reader.read(decimalHolder); + + assertEquals(1, decimalHolder.isSet); + assertEquals(new BigDecimal("0.1"), + DecimalUtility.getBigDecimalFromArrowBuf(decimalHolder.buffer, 0, decimalHolder.scale, 128)); + + unionVector.get(1, holder); + NullableIntHolder intHolder = new NullableIntHolder(); + holder.reader.read(intHolder); + + assertEquals(1, intHolder.isSet); + assertEquals(1, intHolder.value); + } } }