[SPARK-20253][SQL] Remove unnecessary null checks of a return value from Spark runtime routines in generated Java code

## What changes were proposed in this pull request?

This PR eliminates unnecessary null checks of return values from known Spark runtime routines. For a given Spark runtime routine, we know statically whether it can return ``null`` (e.g. ``ArrayData.toDoubleArray()`` never returns ``null``), so the null check on its return value can be dropped from the generated Java code.
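At the expression level, this works by letting ``Invoke`` declare that the call's result is never ``null``. A minimal sketch of the idea, assuming the post-PR ``Invoke`` signature shown in the diffs below (``BoundReference`` stands in for a real input column):

```scala
import org.apache.spark.sql.catalyst.expressions.BoundReference
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.types.{ArrayType, DoubleType, ObjectType}

// The input: an ArrayData value at ordinal 0 of the incoming row.
val arrayData = BoundReference(0, ArrayType(DoubleType), nullable = true)

// toDoubleArray() never returns null, so the deserializer marks the call
// as non-nullable and codegen omits the "funcResult == null" branch.
val toPrimitiveArray = Invoke(
  arrayData,
  "toDoubleArray",
  ObjectType(classOf[Array[Double]]),
  returnNullable = false)
```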

When we run the following example program, we currently get the generated Java code shown under "Without this PR". Since we know ``ArrayData.toDoubleArray()`` never returns ``null``, the null checks at lines 90-92 and 97 can be eliminated.

```scala
val ds = sparkContext.parallelize(Seq(Array(1.1, 2.2)), 1).toDS.cache
ds.count
ds.map(e => e).show
```
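(One way to dump the generated Java below for inspection — not part of this PR — is Spark's debug helper; a sketch assuming the ``debugCodegen`` implicit from ``org.apache.spark.sql.execution.debug``:)

```scala
// Prints the whole-stage-codegen Java for the query above.
// debugCodegen() is provided by the implicit class in the debug
// package object; assumed available in this Spark version.
import org.apache.spark.sql.execution.debug._

ds.map(e => e).debugCodegen()
```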

Without this PR
```java
/* 050 */   protected void processNext() throws java.io.IOException {
/* 051 */     while (inputadapter_input.hasNext() && !stopEarly()) {
/* 052 */       InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 053 */       boolean inputadapter_isNull = inputadapter_row.isNullAt(0);
/* 054 */       ArrayData inputadapter_value = inputadapter_isNull ? null : (inputadapter_row.getArray(0));
/* 055 */
/* 056 */       ArrayData deserializetoobject_value1 = null;
/* 057 */
/* 058 */       if (!inputadapter_isNull) {
/* 059 */         int deserializetoobject_dataLength = inputadapter_value.numElements();
/* 060 */
/* 061 */         Double[] deserializetoobject_convertedArray = null;
/* 062 */         deserializetoobject_convertedArray = new Double[deserializetoobject_dataLength];
/* 063 */
/* 064 */         int deserializetoobject_loopIndex = 0;
/* 065 */         while (deserializetoobject_loopIndex < deserializetoobject_dataLength) {
/* 066 */           MapObjects_loopValue2 = (double) (inputadapter_value.getDouble(deserializetoobject_loopIndex));
/* 067 */           MapObjects_loopIsNull2 = inputadapter_value.isNullAt(deserializetoobject_loopIndex);
/* 068 */
/* 069 */           if (MapObjects_loopIsNull2) {
/* 070 */             throw new RuntimeException(((java.lang.String) references[0]));
/* 071 */           }
/* 072 */           if (false) {
/* 073 */             deserializetoobject_convertedArray[deserializetoobject_loopIndex] = null;
/* 074 */           } else {
/* 075 */             deserializetoobject_convertedArray[deserializetoobject_loopIndex] = MapObjects_loopValue2;
/* 076 */           }
/* 077 */
/* 078 */           deserializetoobject_loopIndex += 1;
/* 079 */         }
/* 080 */
/* 081 */         deserializetoobject_value1 = new org.apache.spark.sql.catalyst.util.GenericArrayData(deserializetoobject_convertedArray); /*###*/
/* 082 */       }
/* 083 */       boolean deserializetoobject_isNull = true;
/* 084 */       double[] deserializetoobject_value = null;
/* 085 */       if (!inputadapter_isNull) {
/* 086 */         deserializetoobject_isNull = false;
/* 087 */         if (!deserializetoobject_isNull) {
/* 088 */           Object deserializetoobject_funcResult = null;
/* 089 */           deserializetoobject_funcResult = deserializetoobject_value1.toDoubleArray();
/* 090 */           if (deserializetoobject_funcResult == null) {
/* 091 */             deserializetoobject_isNull = true;
/* 092 */           } else {
/* 093 */             deserializetoobject_value = (double[]) deserializetoobject_funcResult;
/* 094 */           }
/* 095 */
/* 096 */         }
/* 097 */         deserializetoobject_isNull = deserializetoobject_value == null;
/* 098 */       }
/* 099 */
/* 100 */       boolean mapelements_isNull = true;
/* 101 */       double[] mapelements_value = null;
/* 102 */       if (!false) {
/* 103 */         mapelements_resultIsNull = false;
/* 104 */
/* 105 */         if (!mapelements_resultIsNull) {
/* 106 */           mapelements_resultIsNull = deserializetoobject_isNull;
/* 107 */           mapelements_argValue = deserializetoobject_value;
/* 108 */         }
/* 109 */
/* 110 */         mapelements_isNull = mapelements_resultIsNull;
/* 111 */         if (!mapelements_isNull) {
/* 112 */           Object mapelements_funcResult = null;
/* 113 */           mapelements_funcResult = ((scala.Function1) references[1]).apply(mapelements_argValue);
/* 114 */           if (mapelements_funcResult == null) {
/* 115 */             mapelements_isNull = true;
/* 116 */           } else {
/* 117 */             mapelements_value = (double[]) mapelements_funcResult;
/* 118 */           }
/* 119 */
/* 120 */         }
/* 121 */         mapelements_isNull = mapelements_value == null;
/* 122 */       }
/* 123 */
/* 124 */       serializefromobject_resultIsNull = false;
/* 125 */
/* 126 */       if (!serializefromobject_resultIsNull) {
/* 127 */         serializefromobject_resultIsNull = mapelements_isNull;
/* 128 */         serializefromobject_argValue = mapelements_value;
/* 129 */       }
/* 130 */
/* 131 */       boolean serializefromobject_isNull = serializefromobject_resultIsNull;
/* 132 */       final ArrayData serializefromobject_value = serializefromobject_resultIsNull ? null : org.apache.spark.sql.catalyst.expressions.UnsafeArrayData.fromPrimitiveArray(serializefromobject_argValue);
/* 133 */       serializefromobject_isNull = serializefromobject_value == null;
/* 134 */       serializefromobject_holder.reset();
/* 135 */
/* 136 */       serializefromobject_rowWriter.zeroOutNullBytes();
/* 137 */
/* 138 */       if (serializefromobject_isNull) {
/* 139 */         serializefromobject_rowWriter.setNullAt(0);
/* 140 */       } else {
/* 141 */         // Remember the current cursor so that we can calculate how many bytes are
/* 142 */         // written later.
/* 143 */         final int serializefromobject_tmpCursor = serializefromobject_holder.cursor;
/* 144 */
/* 145 */         if (serializefromobject_value instanceof UnsafeArrayData) {
/* 146 */           final int serializefromobject_sizeInBytes = ((UnsafeArrayData) serializefromobject_value).getSizeInBytes();
/* 147 */           // grow the global buffer before writing data.
/* 148 */           serializefromobject_holder.grow(serializefromobject_sizeInBytes);
/* 149 */           ((UnsafeArrayData) serializefromobject_value).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
/* 150 */           serializefromobject_holder.cursor += serializefromobject_sizeInBytes;
/* 151 */
/* 152 */         } else {
/* 153 */           final int serializefromobject_numElements = serializefromobject_value.numElements();
/* 154 */           serializefromobject_arrayWriter.initialize(serializefromobject_holder, serializefromobject_numElements, 8);
/* 155 */
/* 156 */           for (int serializefromobject_index = 0; serializefromobject_index < serializefromobject_numElements; serializefromobject_index++) {
/* 157 */             if (serializefromobject_value.isNullAt(serializefromobject_index)) {
/* 158 */               serializefromobject_arrayWriter.setNullDouble(serializefromobject_index);
/* 159 */             } else {
/* 160 */               final double serializefromobject_element = serializefromobject_value.getDouble(serializefromobject_index);
/* 161 */               serializefromobject_arrayWriter.write(serializefromobject_index, serializefromobject_element);
/* 162 */             }
/* 163 */           }
/* 164 */         }
/* 165 */
/* 166 */         serializefromobject_rowWriter.setOffsetAndSize(0, serializefromobject_tmpCursor, serializefromobject_holder.cursor - serializefromobject_tmpCursor);
/* 167 */       }
/* 168 */       serializefromobject_result.setTotalSize(serializefromobject_holder.totalSize());
/* 169 */       append(serializefromobject_result);
/* 170 */       if (shouldStop()) return;
/* 171 */     }
/* 172 */   }
```

With this PR (most of lines 90-97 in the code above are removed)
```java
/* 050 */   protected void processNext() throws java.io.IOException {
/* 051 */     while (inputadapter_input.hasNext() && !stopEarly()) {
/* 052 */       InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 053 */       boolean inputadapter_isNull = inputadapter_row.isNullAt(0);
/* 054 */       ArrayData inputadapter_value = inputadapter_isNull ? null : (inputadapter_row.getArray(0));
/* 055 */
/* 056 */       ArrayData deserializetoobject_value1 = null;
/* 057 */
/* 058 */       if (!inputadapter_isNull) {
/* 059 */         int deserializetoobject_dataLength = inputadapter_value.numElements();
/* 060 */
/* 061 */         Double[] deserializetoobject_convertedArray = null;
/* 062 */         deserializetoobject_convertedArray = new Double[deserializetoobject_dataLength];
/* 063 */
/* 064 */         int deserializetoobject_loopIndex = 0;
/* 065 */         while (deserializetoobject_loopIndex < deserializetoobject_dataLength) {
/* 066 */           MapObjects_loopValue2 = (double) (inputadapter_value.getDouble(deserializetoobject_loopIndex));
/* 067 */           MapObjects_loopIsNull2 = inputadapter_value.isNullAt(deserializetoobject_loopIndex);
/* 068 */
/* 069 */           if (MapObjects_loopIsNull2) {
/* 070 */             throw new RuntimeException(((java.lang.String) references[0]));
/* 071 */           }
/* 072 */           if (false) {
/* 073 */             deserializetoobject_convertedArray[deserializetoobject_loopIndex] = null;
/* 074 */           } else {
/* 075 */             deserializetoobject_convertedArray[deserializetoobject_loopIndex] = MapObjects_loopValue2;
/* 076 */           }
/* 077 */
/* 078 */           deserializetoobject_loopIndex += 1;
/* 079 */         }
/* 080 */
/* 081 */         deserializetoobject_value1 = new org.apache.spark.sql.catalyst.util.GenericArrayData(deserializetoobject_convertedArray); /*###*/
/* 082 */       }
/* 083 */       boolean deserializetoobject_isNull = true;
/* 084 */       double[] deserializetoobject_value = null;
/* 085 */       if (!inputadapter_isNull) {
/* 086 */         deserializetoobject_isNull = false;
/* 087 */         if (!deserializetoobject_isNull) {
/* 088 */           Object deserializetoobject_funcResult = null;
/* 089 */           deserializetoobject_funcResult = deserializetoobject_value1.toDoubleArray();
/* 090 */           deserializetoobject_value = (double[]) deserializetoobject_funcResult;
/* 091 */
/* 092 */         }
/* 093 */
/* 094 */       }
/* 095 */
/* 096 */       boolean mapelements_isNull = true;
/* 097 */       double[] mapelements_value = null;
/* 098 */       if (!false) {
/* 099 */         mapelements_resultIsNull = false;
/* 100 */
/* 101 */         if (!mapelements_resultIsNull) {
/* 102 */           mapelements_resultIsNull = deserializetoobject_isNull;
/* 103 */           mapelements_argValue = deserializetoobject_value;
/* 104 */         }
/* 105 */
/* 106 */         mapelements_isNull = mapelements_resultIsNull;
/* 107 */         if (!mapelements_isNull) {
/* 108 */           Object mapelements_funcResult = null;
/* 109 */           mapelements_funcResult = ((scala.Function1) references[1]).apply(mapelements_argValue);
/* 110 */           if (mapelements_funcResult == null) {
/* 111 */             mapelements_isNull = true;
/* 112 */           } else {
/* 113 */             mapelements_value = (double[]) mapelements_funcResult;
/* 114 */           }
/* 115 */
/* 116 */         }
/* 117 */         mapelements_isNull = mapelements_value == null;
/* 118 */       }
/* 119 */
/* 120 */       serializefromobject_resultIsNull = false;
/* 121 */
/* 122 */       if (!serializefromobject_resultIsNull) {
/* 123 */         serializefromobject_resultIsNull = mapelements_isNull;
/* 124 */         serializefromobject_argValue = mapelements_value;
/* 125 */       }
/* 126 */
/* 127 */       boolean serializefromobject_isNull = serializefromobject_resultIsNull;
/* 128 */       final ArrayData serializefromobject_value = serializefromobject_resultIsNull ? null : org.apache.spark.sql.catalyst.expressions.UnsafeArrayData.fromPrimitiveArray(serializefromobject_argValue);
/* 129 */       serializefromobject_isNull = serializefromobject_value == null;
/* 130 */       serializefromobject_holder.reset();
/* 131 */
/* 132 */       serializefromobject_rowWriter.zeroOutNullBytes();
/* 133 */
/* 134 */       if (serializefromobject_isNull) {
/* 135 */         serializefromobject_rowWriter.setNullAt(0);
/* 136 */       } else {
/* 137 */         // Remember the current cursor so that we can calculate how many bytes are
/* 138 */         // written later.
/* 139 */         final int serializefromobject_tmpCursor = serializefromobject_holder.cursor;
/* 140 */
/* 141 */         if (serializefromobject_value instanceof UnsafeArrayData) {
/* 142 */           final int serializefromobject_sizeInBytes = ((UnsafeArrayData) serializefromobject_value).getSizeInBytes();
/* 143 */           // grow the global buffer before writing data.
/* 144 */           serializefromobject_holder.grow(serializefromobject_sizeInBytes);
/* 145 */           ((UnsafeArrayData) serializefromobject_value).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
/* 146 */           serializefromobject_holder.cursor += serializefromobject_sizeInBytes;
/* 147 */
/* 148 */         } else {
/* 149 */           final int serializefromobject_numElements = serializefromobject_value.numElements();
/* 150 */           serializefromobject_arrayWriter.initialize(serializefromobject_holder, serializefromobject_numElements, 8);
/* 151 */
/* 152 */           for (int serializefromobject_index = 0; serializefromobject_index < serializefromobject_numElements; serializefromobject_index++) {
/* 153 */             if (serializefromobject_value.isNullAt(serializefromobject_index)) {
/* 154 */               serializefromobject_arrayWriter.setNullDouble(serializefromobject_index);
/* 155 */             } else {
/* 156 */               final double serializefromobject_element = serializefromobject_value.getDouble(serializefromobject_index);
/* 157 */               serializefromobject_arrayWriter.write(serializefromobject_index, serializefromobject_element);
/* 158 */             }
/* 159 */           }
/* 160 */         }
/* 161 */
/* 162 */         serializefromobject_rowWriter.setOffsetAndSize(0, serializefromobject_tmpCursor, serializefromobject_holder.cursor - serializefromobject_tmpCursor);
/* 163 */       }
/* 164 */       serializefromobject_result.setTotalSize(serializefromobject_holder.totalSize());
/* 165 */       append(serializefromobject_result);
/* 166 */       if (shouldStop()) return;
/* 167 */     }
/* 168 */   }
```

## How was this patch tested?

Added a new test case (``mapPrimitiveArray``) to ``DatasetPrimitiveSuite``.

Author: Kazuaki Ishizaki <[email protected]>

Closes #17569 from kiszk/SPARK-20253.
kiszk authored and cloud-fan committed Apr 10, 2017
1 parent 261eaf5 commit 7a63f5e
Showing 4 changed files with 51 additions and 33 deletions.
In `ScalaReflection` (deserializer construction):

```diff
@@ -251,19 +251,22 @@ object ScalaReflection extends ScalaReflection {
           getPath :: Nil)
 
       case t if t <:< localTypeOf[java.lang.String] =>
-        Invoke(getPath, "toString", ObjectType(classOf[String]))
+        Invoke(getPath, "toString", ObjectType(classOf[String]), returnNullable = false)
 
       case t if t <:< localTypeOf[java.math.BigDecimal] =>
-        Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]))
+        Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]),
+          returnNullable = false)
 
       case t if t <:< localTypeOf[BigDecimal] =>
-        Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal]))
+        Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal]), returnNullable = false)
 
       case t if t <:< localTypeOf[java.math.BigInteger] =>
-        Invoke(getPath, "toJavaBigInteger", ObjectType(classOf[java.math.BigInteger]))
+        Invoke(getPath, "toJavaBigInteger", ObjectType(classOf[java.math.BigInteger]),
+          returnNullable = false)
 
       case t if t <:< localTypeOf[scala.math.BigInt] =>
-        Invoke(getPath, "toScalaBigInt", ObjectType(classOf[scala.math.BigInt]))
+        Invoke(getPath, "toScalaBigInt", ObjectType(classOf[scala.math.BigInt]),
+          returnNullable = false)
 
       case t if t <:< localTypeOf[Array[_]] =>
         val TypeRef(_, _, Seq(elementType)) = t
@@ -284,7 +287,7 @@ object ScalaReflection extends ScalaReflection {
         val arrayCls = arrayClassFor(elementType)
 
         if (elementNullable) {
-          Invoke(arrayData, "array", arrayCls)
+          Invoke(arrayData, "array", arrayCls, returnNullable = false)
         } else {
           val primitiveMethod = elementType match {
             case t if t <:< definitions.IntTpe => "toIntArray"
@@ -297,7 +300,7 @@ object ScalaReflection extends ScalaReflection {
             case other => throw new IllegalStateException("expect primitive array element type " +
               "but got " + other)
           }
-          Invoke(arrayData, primitiveMethod, arrayCls)
+          Invoke(arrayData, primitiveMethod, arrayCls, returnNullable = false)
         }
 
       case t if t <:< localTypeOf[Seq[_]] =>
@@ -330,19 +333,21 @@ object ScalaReflection extends ScalaReflection {
           Invoke(
             MapObjects(
               p => deserializerFor(keyType, Some(p), walkedTypePath),
-              Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType)),
+              Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType),
+                returnNullable = false),
              schemaFor(keyType).dataType),
            "array",
-           ObjectType(classOf[Array[Any]]))
+           ObjectType(classOf[Array[Any]]), returnNullable = false)
 
        val valueData =
          Invoke(
            MapObjects(
              p => deserializerFor(valueType, Some(p), walkedTypePath),
-             Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)),
+             Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType),
+               returnNullable = false),
             schemaFor(valueType).dataType),
           "array",
-          ObjectType(classOf[Array[Any]]))
+          ObjectType(classOf[Array[Any]]), returnNullable = false)
 
        StaticInvoke(
          ArrayBasedMapData.getClass,
```
In `RowEncoder`:

```diff
@@ -89,7 +89,7 @@ object RowEncoder {
         udtClass,
         Nil,
         dataType = ObjectType(udtClass), false)
-      Invoke(obj, "serialize", udt, inputObject :: Nil)
+      Invoke(obj, "serialize", udt, inputObject :: Nil, returnNullable = false)
 
     case TimestampType =>
       StaticInvoke(
@@ -136,16 +136,18 @@ object RowEncoder {
     case t @ MapType(kt, vt, valueNullable) =>
       val keys =
         Invoke(
-          Invoke(inputObject, "keysIterator", ObjectType(classOf[scala.collection.Iterator[_]])),
+          Invoke(inputObject, "keysIterator", ObjectType(classOf[scala.collection.Iterator[_]]),
+            returnNullable = false),
           "toSeq",
-          ObjectType(classOf[scala.collection.Seq[_]]))
+          ObjectType(classOf[scala.collection.Seq[_]]), returnNullable = false)
       val convertedKeys = serializerFor(keys, ArrayType(kt, false))
 
       val values =
         Invoke(
-          Invoke(inputObject, "valuesIterator", ObjectType(classOf[scala.collection.Iterator[_]])),
+          Invoke(inputObject, "valuesIterator", ObjectType(classOf[scala.collection.Iterator[_]]),
+            returnNullable = false),
           "toSeq",
-          ObjectType(classOf[scala.collection.Seq[_]]))
+          ObjectType(classOf[scala.collection.Seq[_]]), returnNullable = false)
       val convertedValues = serializerFor(values, ArrayType(vt, valueNullable))
 
       NewInstance(
@@ -262,17 +264,18 @@ object RowEncoder {
         input :: Nil)
 
     case _: DecimalType =>
-      Invoke(input, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]))
+      Invoke(input, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]),
+        returnNullable = false)
 
     case StringType =>
-      Invoke(input, "toString", ObjectType(classOf[String]))
+      Invoke(input, "toString", ObjectType(classOf[String]), returnNullable = false)
 
     case ArrayType(et, nullable) =>
       val arrayData =
         Invoke(
           MapObjects(deserializerFor(_), input, et),
           "array",
-          ObjectType(classOf[Array[_]]))
+          ObjectType(classOf[Array[_]]), returnNullable = false)
       StaticInvoke(
         scala.collection.mutable.WrappedArray.getClass,
         ObjectType(classOf[Seq[_]]),
```
In `Invoke` (Catalyst expression codegen):

```diff
@@ -225,25 +225,26 @@ case class Invoke(
       getFuncResult(ev.value, s"${obj.value}.$functionName($argString)")
     } else {
       val funcResult = ctx.freshName("funcResult")
+      // If the function can return null, we do an extra check to make sure our null bit is still
+      // set correctly.
+      val assignResult = if (!returnNullable) {
+        s"${ev.value} = (${ctx.boxedType(javaType)}) $funcResult;"
+      } else {
+        s"""
+          if ($funcResult != null) {
+            ${ev.value} = (${ctx.boxedType(javaType)}) $funcResult;
+          } else {
+            ${ev.isNull} = true;
+          }
+        """
+      }
       s"""
         Object $funcResult = null;
         ${getFuncResult(funcResult, s"${obj.value}.$functionName($argString)")}
-        if ($funcResult == null) {
-          ${ev.isNull} = true;
-        } else {
-          ${ev.value} = (${ctx.boxedType(javaType)}) $funcResult;
-        }
+        $assignResult
       """
     }
 
-    // If the function can return null, we do an extra check to make sure our null bit is still set
-    // correctly.
-    val postNullCheck = if (ctx.defaultValue(dataType) == "null") {
-      s"${ev.isNull} = ${ev.value} == null;"
-    } else {
-      ""
-    }
-
     val code = s"""
       ${obj.code}
       boolean ${ev.isNull} = true;
@@ -254,7 +255,6 @@ case class Invoke(
       if (!${ev.isNull}) {
        $evaluate
      }
-     $postNullCheck
    }
  """
    ev.copy(code = code)
```
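For reference, a reconstruction (from the call sites and codegen above, not the verbatim Spark source) of the shape ``Invoke`` takes after this change; the assumed default ``returnNullable = true`` keeps the old null-checking behavior for untouched call sites:

```scala
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.types.DataType

// Reconstructed sketch, not the verbatim declaration (base traits and
// methods elided). Call sites that know the target method never returns
// null pass returnNullable = false, as in the diffs above.
case class Invoke(
    targetObject: Expression,
    functionName: String,
    dataType: DataType,
    arguments: Seq[Expression] = Nil,
    propagateNull: Boolean = true,
    returnNullable: Boolean = true)
```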
In `DatasetPrimitiveSuite` (new test):

```diff
@@ -96,6 +96,16 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
     checkDataset(dsBoolean.map(e => !e), false, true)
   }
 
+  test("mapPrimitiveArray") {
+    val dsInt = Seq(Array(1, 2), Array(3, 4)).toDS()
+    checkDataset(dsInt.map(e => e), Array(1, 2), Array(3, 4))
+    checkDataset(dsInt.map(e => null: Array[Int]), null, null)
+
+    val dsDouble = Seq(Array(1D, 2D), Array(3D, 4D)).toDS()
+    checkDataset(dsDouble.map(e => e), Array(1D, 2D), Array(3D, 4D))
+    checkDataset(dsDouble.map(e => null: Array[Double]), null, null)
+  }
+
   test("filter") {
     val ds = Seq(1, 2, 3, 4).toDS()
     checkDataset(
```
