Skip to content

Commit

Permalink
DX-85876: Failure in UnionReader.read after DecimalVector promotion t…
Browse files Browse the repository at this point in the history
…o UnionVector (dremio#61)

When a DecimalVector is promoted to a UnionVector via a
PromotableWriter, the UnionVector will have the decimal vector in it's
internal struct vector, but the decimalVector field will not be set.
If UnionReader.read is then used to read from the UnionVector, it will
fail when it tries to read one of the promoted decimal values, due
to decimalVector being null, and the exact decimal type not being
provided.  This failure is unnecessary though as we have a pre-existing
decimal vector, the caller just does not know the exact type - and
it shouldn't be required to.

The change here is to check for a pre-existing decimal vector in the
internal struct when getDecimalVector() is called.  If one exists,
set the decimalVector field and return.  Otherwise, if none exists,
throw the exception.
  • Loading branch information
sgcowell authored and lriggs committed Sep 6, 2024
1 parent b8aea7f commit d97bc33
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 0 deletions.
4 changes: 4 additions & 0 deletions java/vector/src/main/codegen/templates/UnionVector.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.arrow.util.Preconditions;
import org.apache.arrow.vector.BaseValueVector;
import org.apache.arrow.vector.BitVectorHelper;
import org.apache.arrow.vector.DecimalVector;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.ValueVector;
import org.apache.arrow.vector.complex.AbstractStructVector;
Expand Down Expand Up @@ -305,6 +306,9 @@ public StructVector getStruct() {
public ${name}Vector get${name}Vector(String name) {
if (${uncappedName}Vector == null) {
${uncappedName}Vector = internalStruct.getChild(fieldName(MinorType.${name?upper_case}), ${name}Vector.class);
if (${uncappedName}Vector == null) {
throw new IllegalArgumentException("No ${uncappedName} present. Provide ArrowType argument to create a new vector");
int vectorCount = internalStruct.size();
${uncappedName}Vector = addOrGet(name, MinorType.${name?upper_case}, ${name}Vector.class);
if (internalStruct.size() > vectorCount) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,14 @@
import static org.junit.Assert.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;


import java.math.BigDecimal;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

import org.apache.arrow.memory.ArrowBuf;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.DecimalVector;
import org.apache.arrow.vector.DirtyRootAllocator;
import org.apache.arrow.vector.LargeVarBinaryVector;
import org.apache.arrow.vector.LargeVarCharVector;
Expand All @@ -39,14 +42,18 @@
import org.apache.arrow.vector.complex.writer.BaseWriter.StructWriter;
import org.apache.arrow.vector.holders.DurationHolder;
import org.apache.arrow.vector.holders.FixedSizeBinaryHolder;
import org.apache.arrow.vector.holders.NullableDecimalHolder;
import org.apache.arrow.vector.holders.NullableIntHolder;
import org.apache.arrow.vector.holders.NullableTimeStampMilliTZHolder;
import org.apache.arrow.vector.holders.TimeStampMilliTZHolder;
import org.apache.arrow.vector.holders.UnionHolder;
import org.apache.arrow.vector.types.TimeUnit;
import org.apache.arrow.vector.types.Types;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.ArrowType.ArrowTypeID;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.arrow.vector.util.DecimalUtility;
import org.apache.arrow.vector.util.Text;
import org.junit.After;
import org.junit.Before;
Expand Down Expand Up @@ -588,5 +595,41 @@ public void testPromoteLargeVarBinaryHelpersDirect() throws Exception {
assertEquals("row3", new String(uv.get(2)));
assertEquals("row4", new String(uv.get(3)));
}

@Test
public void testPromoteToUnionFromDecimal() throws Exception {
try (final NonNullableStructVector container = NonNullableStructVector.empty(EMPTY_SCHEMA_PATH, allocator);
final DecimalVector v = container.addOrGet("dec",
FieldType.nullable(new ArrowType.Decimal(38, 1, 128)), DecimalVector.class);
final PromotableWriter writer = new PromotableWriter(v, container)) {

container.allocateNew();
container.setValueCount(1);

writer.setPosition(0);
writer.writeDecimal(new BigDecimal("0.1"));
writer.setPosition(1);
writer.writeInt(1);

container.setValueCount(3);

UnionVector unionVector = (UnionVector) container.getChild("dec");
UnionHolder holder = new UnionHolder();

unionVector.get(0, holder);
NullableDecimalHolder decimalHolder = new NullableDecimalHolder();
holder.reader.read(decimalHolder);

assertEquals(1, decimalHolder.isSet);
assertEquals(new BigDecimal("0.1"),
DecimalUtility.getBigDecimalFromArrowBuf(decimalHolder.buffer, 0, decimalHolder.scale, 128));

unionVector.get(1, holder);
NullableIntHolder intHolder = new NullableIntHolder();
holder.reader.read(intHolder);

assertEquals(1, intHolder.isSet);
assertEquals(1, intHolder.value);
}
}
}

0 comments on commit d97bc33

Please sign in to comment.