From 90848d507457d30abb36e3ba07618dfc87c34cd6 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Fri, 2 Feb 2018 10:18:32 +0800
Subject: [PATCH] [SPARK-23284][SQL] Document the behavior of several
 ColumnVector's get APIs when accessing null slot

## What changes were proposed in this pull request?

For the ColumnVector get APIs that return objects, such as getDecimal, getBinary, getStruct, getArray, getInterval, and getUTF8String, this patch clearly documents their behavior when accessing a null slot: they return null in that case. With that contract documented, the null checks at the call sites of these APIs can be removed.

For the APIs that return primitive values, such as getInt and getInts, this patch also documents their behavior when accessing null slots: their return values are undefined and can be anything.

## How was this patch tested?

Added tests to `ColumnarBatchSuite`.

Author: Liang-Chi Hsieh

Closes #20455 from viirya/SPARK-23272-followup.

---
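A reviewer-facing sketch of the caller-side effect (illustration only, not part of the applied diff): once the contract is documented, object-returning getters encode nullness in their return value, so per-call-site guards become redundant. The `readNameBefore`/`readNameAfter` helpers and their parameters below are hypothetical.

```java
import org.apache.spark.sql.vectorized.ColumnVector;
import org.apache.spark.unsafe.types.UTF8String;

class NullSlotCallers {
  // Before this change, call sites guarded each object-returning getter by hand.
  static UTF8String readNameBefore(ColumnVector vector, int rowId) {
    return vector.isNullAt(rowId) ? null : vector.getUTF8String(rowId);
  }

  // After it, getUTF8String is documented to return null for a null slot, so the
  // guard lives once inside the ColumnVector implementation (see the diff below)
  // and the call site can simply delegate.
  static UTF8String readNameAfter(ColumnVector vector, int rowId) {
    return vector.getUTF8String(rowId);
  }
}
```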
 .../datasources/orc/OrcColumnVector.java      |  3 +
 .../vectorized/MutableColumnarRow.java        |  7 --
 .../vectorized/WritableColumnVector.java      |  5 ++
 .../sql/vectorized/ArrowColumnVector.java     |  4 +
 .../spark/sql/vectorized/ColumnVector.java    | 63 ++++++++++------
 .../spark/sql/vectorized/ColumnarRow.java     |  7 --
 .../vectorized/ColumnarBatchSuite.scala       | 74 ++++++++++++++++++-
 7 files changed, 124 insertions(+), 39 deletions(-)

diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java
index c8add4c9f486c..12f4d658b1868 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java
@@ -154,12 +154,14 @@ public double getDouble(int rowId) {
 
   @Override
   public Decimal getDecimal(int rowId, int precision, int scale) {
+    if (isNullAt(rowId)) return null;
     BigDecimal data = decimalData.vector[getRowIndex(rowId)].getHiveDecimal().bigDecimalValue();
     return Decimal.apply(data, precision, scale);
   }
 
   @Override
   public UTF8String getUTF8String(int rowId) {
+    if (isNullAt(rowId)) return null;
     int index = getRowIndex(rowId);
     BytesColumnVector col = bytesData;
     return UTF8String.fromBytes(col.vector[index], col.start[index], col.length[index]);
@@ -167,6 +169,7 @@ public UTF8String getUTF8String(int rowId) {
 
   @Override
   public byte[] getBinary(int rowId) {
+    if (isNullAt(rowId)) return null;
     int index = getRowIndex(rowId);
     byte[] binary = new byte[bytesData.length[index]];
     System.arraycopy(bytesData.vector[index], bytesData.start[index], binary, 0, binary.length);
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java
index 307c19032dee5..4e4242fe8d9b9 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java
@@ -127,43 +127,36 @@ public boolean anyNull() {
 
   @Override
   public Decimal getDecimal(int ordinal, int precision, int scale) {
-    if (columns[ordinal].isNullAt(rowId)) return null;
     return columns[ordinal].getDecimal(rowId, precision, scale);
   }
 
   @Override
   public UTF8String getUTF8String(int ordinal) {
-    if (columns[ordinal].isNullAt(rowId)) return null;
     return columns[ordinal].getUTF8String(rowId);
   }
 
   @Override
   public byte[] getBinary(int ordinal) {
-    if (columns[ordinal].isNullAt(rowId)) return null;
     return columns[ordinal].getBinary(rowId);
   }
 
   @Override
   public CalendarInterval getInterval(int ordinal) {
-    if (columns[ordinal].isNullAt(rowId)) return null;
     return columns[ordinal].getInterval(rowId);
   }
 
   @Override
   public ColumnarRow getStruct(int ordinal, int numFields) {
-    if (columns[ordinal].isNullAt(rowId)) return null;
     return columns[ordinal].getStruct(rowId);
   }
 
   @Override
   public ColumnarArray getArray(int ordinal) {
-    if (columns[ordinal].isNullAt(rowId)) return null;
     return columns[ordinal].getArray(rowId);
   }
 
   @Override
   public ColumnarMap getMap(int ordinal) {
-    if (columns[ordinal].isNullAt(rowId)) return null;
     return columns[ordinal].getMap(rowId);
   }
 
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
index 9d447cdc79063..5275e4a91eac0 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
@@ -341,6 +341,7 @@ public final int putByteArray(int rowId, byte[] value) {
 
   @Override
   public Decimal getDecimal(int rowId, int precision, int scale) {
+    if (isNullAt(rowId)) return null;
     if (precision <= Decimal.MAX_INT_DIGITS()) {
       return Decimal.createUnsafe(getInt(rowId), precision, scale);
     } else if (precision <= Decimal.MAX_LONG_DIGITS()) {
@@ -367,6 +368,7 @@ public void putDecimal(int rowId, Decimal value, int precision) {
 
   @Override
   public UTF8String getUTF8String(int rowId) {
+    if (isNullAt(rowId)) return null;
     if (dictionary == null) {
       return arrayData().getBytesAsUTF8String(getArrayOffset(rowId), getArrayLength(rowId));
     } else {
@@ -384,6 +386,7 @@ public UTF8String getUTF8String(int rowId) {
 
   @Override
   public byte[] getBinary(int rowId) {
+    if (isNullAt(rowId)) return null;
     if (dictionary == null) {
       return arrayData().getBytes(getArrayOffset(rowId), getArrayLength(rowId));
     } else {
@@ -613,6 +616,7 @@ public final int appendStruct(boolean isNull) {
   // array offsets and lengths in the current column vector.
   @Override
   public final ColumnarArray getArray(int rowId) {
+    if (isNullAt(rowId)) return null;
     return new ColumnarArray(arrayData(), getArrayOffset(rowId), getArrayLength(rowId));
   }
 
@@ -620,6 +624,7 @@ public final ColumnarArray getArray(int rowId) {
   // second child column vector, and puts the offsets and lengths in the current column vector.
   @Override
   public final ColumnarMap getMap(int rowId) {
+    if (isNullAt(rowId)) return null;
     return new ColumnarMap(getChild(0), getChild(1),
       getArrayOffset(rowId), getArrayLength(rowId));
   }
diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
index f3ece538c3b80..f8e37e995a17f 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
@@ -101,21 +101,25 @@ public double getDouble(int rowId) {
 
   @Override
   public Decimal getDecimal(int rowId, int precision, int scale) {
+    if (isNullAt(rowId)) return null;
     return accessor.getDecimal(rowId, precision, scale);
   }
 
   @Override
   public UTF8String getUTF8String(int rowId) {
+    if (isNullAt(rowId)) return null;
     return accessor.getUTF8String(rowId);
   }
 
   @Override
   public byte[] getBinary(int rowId) {
+    if (isNullAt(rowId)) return null;
     return accessor.getBinary(rowId);
   }
 
   @Override
   public ColumnarArray getArray(int rowId) {
+    if (isNullAt(rowId)) return null;
     return accessor.getArray(rowId);
   }
 
diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java
index 530d4d23d4eaf..ad99b450a4809 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java
@@ -80,12 +80,14 @@ public abstract class ColumnVector implements AutoCloseable {
   public abstract boolean isNullAt(int rowId);
 
   /**
-   * Returns the boolean type value for rowId.
+   * Returns the boolean type value for rowId. The return value is undefined and can be anything,
+   * if the slot for rowId is null.
    */
   public abstract boolean getBoolean(int rowId);
 
   /**
-   * Gets boolean type values from [rowId, rowId + count)
+   * Gets boolean type values from [rowId, rowId + count). The return values for the null slots
+   * are undefined and can be anything.
    */
   public boolean[] getBooleans(int rowId, int count) {
     boolean[] res = new boolean[count];
@@ -96,12 +98,14 @@ public boolean[] getBooleans(int rowId, int count) {
   }
 
   /**
-   * Returns the byte type value for rowId.
+   * Returns the byte type value for rowId. The return value is undefined and can be anything,
+   * if the slot for rowId is null.
    */
   public abstract byte getByte(int rowId);
 
   /**
-   * Gets byte type values from [rowId, rowId + count)
+   * Gets byte type values from [rowId, rowId + count). The return values for the null slots
+   * are undefined and can be anything.
    */
   public byte[] getBytes(int rowId, int count) {
     byte[] res = new byte[count];
@@ -112,12 +116,14 @@ public byte[] getBytes(int rowId, int count) {
   }
 
   /**
-   * Returns the short type value for rowId.
+   * Returns the short type value for rowId. The return value is undefined and can be anything,
+   * if the slot for rowId is null.
    */
   public abstract short getShort(int rowId);
 
   /**
-   * Gets short type values from [rowId, rowId + count)
+   * Gets short type values from [rowId, rowId + count). The return values for the null slots
+   * are undefined and can be anything.
    */
   public short[] getShorts(int rowId, int count) {
     short[] res = new short[count];
@@ -128,12 +134,14 @@ public short[] getShorts(int rowId, int count) {
   }
 
   /**
-   * Returns the int type value for rowId.
+   * Returns the int type value for rowId. The return value is undefined and can be anything,
+   * if the slot for rowId is null.
    */
   public abstract int getInt(int rowId);
 
   /**
-   * Gets int type values from [rowId, rowId + count)
+   * Gets int type values from [rowId, rowId + count). The return values for the null slots
+   * are undefined and can be anything.
    */
   public int[] getInts(int rowId, int count) {
     int[] res = new int[count];
@@ -144,12 +152,14 @@ public int[] getInts(int rowId, int count) {
   }
 
   /**
-   * Returns the long type value for rowId.
+   * Returns the long type value for rowId. The return value is undefined and can be anything,
+   * if the slot for rowId is null.
    */
   public abstract long getLong(int rowId);
 
   /**
-   * Gets long type values from [rowId, rowId + count)
+   * Gets long type values from [rowId, rowId + count). The return values for the null slots
+   * are undefined and can be anything.
    */
   public long[] getLongs(int rowId, int count) {
     long[] res = new long[count];
@@ -160,12 +170,14 @@ public long[] getLongs(int rowId, int count) {
   }
 
   /**
-   * Returns the float type value for rowId.
+   * Returns the float type value for rowId. The return value is undefined and can be anything,
+   * if the slot for rowId is null.
    */
   public abstract float getFloat(int rowId);
 
   /**
-   * Gets float type values from [rowId, rowId + count)
+   * Gets float type values from [rowId, rowId + count). The return values for the null slots
+   * are undefined and can be anything.
    */
   public float[] getFloats(int rowId, int count) {
     float[] res = new float[count];
@@ -176,12 +188,14 @@ public float[] getFloats(int rowId, int count) {
   }
 
   /**
-   * Returns the double type value for rowId.
+   * Returns the double type value for rowId. The return value is undefined and can be anything,
+   * if the slot for rowId is null.
    */
   public abstract double getDouble(int rowId);
 
   /**
-   * Gets double type values from [rowId, rowId + count)
+   * Gets double type values from [rowId, rowId + count). The return values for the null slots
+   * are undefined and can be anything.
    */
   public double[] getDoubles(int rowId, int count) {
     double[] res = new double[count];
@@ -192,7 +206,7 @@ public double[] getDoubles(int rowId, int count) {
   }
 
   /**
-   * Returns the struct type value for rowId.
+   * Returns the struct type value for rowId. If the slot for rowId is null, it should return null.
    *
    * To support struct type, implementations must implement {@link #getChild(int)} and make this
    * vector a tree structure. The number of child vectors must be same as the number of fields of
@@ -205,7 +219,7 @@ public final ColumnarRow getStruct(int rowId) {
   }
 
   /**
-   * Returns the array type value for rowId.
+   * Returns the array type value for rowId. If the slot for rowId is null, it should return null.
    *
    * To support array type, implementations must construct an {@link ColumnarArray} and return it in
    * this method. {@link ColumnarArray} requires a {@link ColumnVector} that stores the data of all
@@ -218,13 +232,13 @@ public final ColumnarRow getStruct(int rowId) {
   public abstract ColumnarArray getArray(int rowId);
 
   /**
-   * Returns the map type value for rowId.
+   * Returns the map type value for rowId. If the slot for rowId is null, it should return null.
    *
    * In Spark, map type value is basically a key data array and a value data array. A key from the
    * key array with a index and a value from the value array with the same index contribute to
    * an entry of this map type value.
    *
-   * To support map type, implementations must construct an {@link ColumnarMap} and return it in
+   * To support map type, implementations must construct a {@link ColumnarMap} and return it in
    * this method. {@link ColumnarMap} requires a {@link ColumnVector} that stores the data of all
    * the keys of all the maps in this vector, and another {@link ColumnVector} that stores the data
    * of all the values of all the maps in this vector, and a pair of offset and length which
@@ -233,24 +247,25 @@ public final ColumnarRow getStruct(int rowId) {
   public abstract ColumnarMap getMap(int ordinal);
 
   /**
-   * Returns the decimal type value for rowId.
+   * Returns the decimal type value for rowId. If the slot for rowId is null, it should return null.
    */
   public abstract Decimal getDecimal(int rowId, int precision, int scale);
 
   /**
-   * Returns the string type value for rowId. Note that the returned UTF8String may point to the
-   * data of this column vector, please copy it if you want to keep it after this column vector is
-   * freed.
+   * Returns the string type value for rowId. If the slot for rowId is null, it should return null.
+   * Note that the returned UTF8String may point to the data of this column vector, please copy it
+   * if you want to keep it after this column vector is freed.
    */
   public abstract UTF8String getUTF8String(int rowId);
 
   /**
-   * Returns the binary type value for rowId.
+   * Returns the binary type value for rowId. If the slot for rowId is null, it should return null.
    */
   public abstract byte[] getBinary(int rowId);
 
   /**
-   * Returns the calendar interval type value for rowId.
+   * Returns the calendar interval type value for rowId. If the slot for rowId is null, it should
+   * return null.
    *
    * In Spark, calendar interval type value is basically an integer value representing the number of
    * months in this interval, and a long value representing the number of microseconds in this
diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java
index b400f7f93c1fe..f2f2279590023 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java
@@ -119,43 +119,36 @@ public boolean anyNull() {
 
   @Override
   public Decimal getDecimal(int ordinal, int precision, int scale) {
-    if (data.getChild(ordinal).isNullAt(rowId)) return null;
     return data.getChild(ordinal).getDecimal(rowId, precision, scale);
   }
 
   @Override
   public UTF8String getUTF8String(int ordinal) {
-    if (data.getChild(ordinal).isNullAt(rowId)) return null;
     return data.getChild(ordinal).getUTF8String(rowId);
   }
 
   @Override
   public byte[] getBinary(int ordinal) {
-    if (data.getChild(ordinal).isNullAt(rowId)) return null;
     return data.getChild(ordinal).getBinary(rowId);
   }
 
   @Override
   public CalendarInterval getInterval(int ordinal) {
-    if (data.getChild(ordinal).isNullAt(rowId)) return null;
     return data.getChild(ordinal).getInterval(rowId);
   }
 
   @Override
   public ColumnarRow getStruct(int ordinal, int numFields) {
-    if (data.getChild(ordinal).isNullAt(rowId)) return null;
     return data.getChild(ordinal).getStruct(rowId);
   }
 
   @Override
   public ColumnarArray getArray(int ordinal) {
-    if (data.getChild(ordinal).isNullAt(rowId)) return null;
     return data.getChild(ordinal).getArray(rowId);
   }
 
   @Override
   public ColumnarMap getMap(int ordinal) {
-    if (data.getChild(ordinal).isNullAt(rowId)) return null;
     return data.getChild(ordinal).getMap(rowId);
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
index 8fe2985836f2e..772f687526008 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
@@ -572,7 +572,7 @@ class ColumnarBatchSuite extends SparkFunSuite {
     }
   }
 
-  testVector("String APIs", 6, StringType) {
+  testVector("String APIs", 7, StringType) {
     column =>
       val reference = mutable.ArrayBuffer.empty[String]
 
@@ -619,6 +619,10 @@ class ColumnarBatchSuite extends SparkFunSuite {
       idx += 1
       assert(column.arrayData().elementsAppended == 17 + (s + s).length)
 
+      column.putNull(idx)
+      assert(column.getUTF8String(idx) == null)
+      idx += 1
+
       reference.zipWithIndex.foreach { v =>
         val errMsg = "VectorType=" + column.getClass.getSimpleName
         assert(v._1.length == column.getArrayLength(v._2), errMsg)
@@ -647,6 +651,7 @@ class ColumnarBatchSuite extends SparkFunSuite {
       reference += new CalendarInterval(0, 2000)
 
       column.putNull(2)
+      assert(column.getInterval(2) == null)
       reference += null
 
       months.putInt(3, 20)
@@ -683,6 +688,7 @@ class ColumnarBatchSuite extends SparkFunSuite {
       assert(column.getArray(0).numElements == 1)
       assert(column.getArray(1).numElements == 2)
       assert(column.isNullAt(2))
+      assert(column.getArray(2) == null)
       assert(column.getArray(3).numElements == 0)
       assert(column.getArray(4).numElements == 3)
 
@@ -785,6 +791,7 @@ class ColumnarBatchSuite extends SparkFunSuite {
       column.putArray(0, 0, 1)
       column.putArray(1, 1, 2)
       column.putNull(2)
+      assert(column.getMap(2) == null)
       column.putArray(3, 3, 0)
       column.putArray(4, 3, 3)
 
@@ -821,6 +828,7 @@ class ColumnarBatchSuite extends SparkFunSuite {
       c2.putDouble(0, 3.45)
 
       column.putNull(1)
+      assert(column.getStruct(1) == null)
 
       c1.putInt(2, 456)
       c2.putDouble(2, 5.67)
@@ -1261,4 +1269,68 @@ class ColumnarBatchSuite extends SparkFunSuite {
     batch.close()
     allocator.close()
   }
+
+  testVector("Decimal API", 4, DecimalType.IntDecimal) {
+    column =>
+
+      val reference = mutable.ArrayBuffer.empty[Decimal]
+
+      var idx = 0
+      column.putDecimal(idx, new Decimal().set(10), 10)
+      reference += new Decimal().set(10)
+      idx += 1
+
+      column.putDecimal(idx, new Decimal().set(20), 10)
+      reference += new Decimal().set(20)
+      idx += 1
+
+      column.putNull(idx)
+      assert(column.getDecimal(idx, 10, 0) == null)
+      reference += null
+      idx += 1
+
+      column.putDecimal(idx, new Decimal().set(30), 10)
+      reference += new Decimal().set(30)
+
+      reference.zipWithIndex.foreach { case (v, i) =>
+        val errMsg = "VectorType=" + column.getClass.getSimpleName
+        assert(v == column.getDecimal(i, 10, 0), errMsg)
+        if (v == null) assert(column.isNullAt(i), errMsg)
+      }
+
+      column.close()
+  }
+
+  testVector("Binary APIs", 4, BinaryType) {
+    column =>
+
+      val reference = mutable.ArrayBuffer.empty[String]
+      var idx = 0
+      column.putByteArray(idx, "Hello".getBytes(StandardCharsets.UTF_8))
+      reference += "Hello"
+      idx += 1
+
+      column.putByteArray(idx, "World".getBytes(StandardCharsets.UTF_8))
+      reference += "World"
+      idx += 1
+
+      column.putNull(idx)
+      reference += null
+      idx += 1
+
+      column.putByteArray(idx, "abc".getBytes(StandardCharsets.UTF_8))
+      reference += "abc"
+
+      reference.zipWithIndex.foreach { case (v, i) =>
+        val errMsg = "VectorType=" + column.getClass.getSimpleName
+        if (v != null) {
+          assert(v == new String(column.getBinary(i)), errMsg)
+        } else {
+          assert(column.isNullAt(i), errMsg)
+          assert(column.getBinary(i) == null, errMsg)
+        }
+      }
+
+      column.close()
+  }
 }
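A usage note on the asymmetry of the documented contract (illustration only, not part of the diff): object getters now return null for null slots, but primitive getters such as getInt remain undefined there, since a Java primitive has no null representation and forcing a sentinel would add branching to the hot path. Callers must therefore still consult isNullAt before reading primitives. A minimal sketch, assuming a ColumnarBatch with a nullable int column at the given ordinal; `sumNonNull` is a hypothetical helper:

```java
import org.apache.spark.sql.vectorized.ColumnVector;
import org.apache.spark.sql.vectorized.ColumnarBatch;

class NullSlotScan {
  // Sums the non-null values of an int column. getInt(r) is documented as
  // undefined on null slots, so isNullAt(r) must gate every primitive read.
  static long sumNonNull(ColumnarBatch batch, int ordinal) {
    ColumnVector v = batch.column(ordinal);
    long sum = 0;
    for (int r = 0; r < batch.numRows(); r++) {
      if (!v.isNullAt(r)) {
        sum += v.getInt(r);
      }
    }
    return sum;
  }
}
```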