From c5583fdcd2289559ad98371475eb7288ced9b148 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sat, 1 Sep 2018 12:19:19 +0900 Subject: [PATCH] [SPARK-23466][SQL] Remove redundant null checks in generated Java code by GenerateUnsafeProjection ## What changes were proposed in this pull request? This PR addresses one of the TODOs in `GenerateUnsafeProjection` ("if the nullability of field is correct, we can use it to save null check") to simplify the generated code. When a struct field, array element, or map value is declared non-nullable (`nullable = false`, `containsNull = false`, or `valueContainsNull = false` in its `DataType`), `GenerateUnsafeProjection` now omits the corresponding null checks from the generated Java code. ## How was this patch tested? Added new test cases to `GenerateUnsafeProjectionSuite`. Closes #20637 from kiszk/SPARK-23466. Authored-by: Kazuaki Ishizaki Signed-off-by: Takuya UESHIN --- .../codegen/GenerateUnsafeProjection.scala | 77 +++++++++++-------- .../expressions/JsonExpressionsSuite.scala | 2 +- .../GenerateUnsafeProjectionSuite.scala | 71 ++++++++++++++++- 3 files changed, 117 insertions(+), 33 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 998a675eecc62..0ecd0de8d8203 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -32,6 +32,8 @@ import org.apache.spark.sql.types._ */ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafeProjection] { + case class Schema(dataType: DataType, nullable: Boolean) + /** Returns true iff we support this data type. 
*/ def canSupport(dataType: DataType): Boolean = UserDefinedType.sqlType(dataType) match { case NullType => true @@ -43,19 +45,21 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _ => false } - // TODO: if the nullability of field is correct, we can use it to save null check. private def writeStructToBuffer( ctx: CodegenContext, input: String, index: String, - fieldTypes: Seq[DataType], + schemas: Seq[Schema], rowWriter: String): String = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") - val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) => - ExprCode( - JavaCode.isNullExpression(s"$tmpInput.isNullAt($i)"), - JavaCode.expression(CodeGenerator.getValue(tmpInput, dt, i.toString), dt)) + val fieldEvals = schemas.zipWithIndex.map { case (Schema(dt, nullable), i) => + val isNull = if (nullable) { + JavaCode.isNullExpression(s"$tmpInput.isNullAt($i)") + } else { + FalseLiteral + } + ExprCode(isNull, JavaCode.expression(CodeGenerator.getValue(tmpInput, dt, i.toString), dt)) } val rowWriterClass = classOf[UnsafeRowWriter].getName @@ -70,7 +74,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro | // Remember the current cursor so that we can calculate how many bytes are | // written later. 
| final int $previousCursor = $rowWriter.cursor(); - | ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, fieldTypes, structRowWriter)} + | ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, schemas, structRowWriter)} | $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); |} """.stripMargin @@ -80,7 +84,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx: CodegenContext, row: String, inputs: Seq[ExprCode], - inputTypes: Seq[DataType], + schemas: Seq[Schema], rowWriter: String, isTopLevel: Boolean = false): String = { val resetWriter = if (isTopLevel) { @@ -98,8 +102,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s"$rowWriter.resetRowWriter();" } - val writeFields = inputs.zip(inputTypes).zipWithIndex.map { - case ((input, dataType), index) => + val writeFields = inputs.zip(schemas).zipWithIndex.map { + case ((input, Schema(dataType, nullable)), index) => val dt = UserDefinedType.sqlType(dataType) val setNull = dt match { @@ -110,7 +114,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } val writeField = writeElement(ctx, input.value, index.toString, dt, rowWriter) - if (input.isNull == FalseLiteral) { + if (!nullable) { s""" |${input.code} |${writeField.trim} @@ -143,11 +147,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro """.stripMargin } - // TODO: if the nullability of array element is correct, we can use it to save null check. private def writeArrayToBuffer( ctx: CodegenContext, input: String, elementType: DataType, + containsNull: Boolean, rowWriter: String): String = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. 
val tmpInput = ctx.freshName("tmpInput") @@ -170,6 +174,18 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val element = CodeGenerator.getValue(tmpInput, et, index) + val elementAssignment = if (containsNull) { + s""" + |if ($tmpInput.isNullAt($index)) { + | $arrayWriter.setNull${elementOrOffsetSize}Bytes($index); + |} else { + | ${writeElement(ctx, element, index, et, arrayWriter)} + |} + """.stripMargin + } else { + writeElement(ctx, element, index, et, arrayWriter) + } + s""" |final ArrayData $tmpInput = $input; |if ($tmpInput instanceof UnsafeArrayData) { @@ -179,23 +195,19 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro | $arrayWriter.initialize($numElements); | | for (int $index = 0; $index < $numElements; $index++) { - | if ($tmpInput.isNullAt($index)) { - | $arrayWriter.setNull${elementOrOffsetSize}Bytes($index); - | } else { - | ${writeElement(ctx, element, index, et, arrayWriter)} - | } + | $elementAssignment | } |} """.stripMargin } - // TODO: if the nullability of value element is correct, we can use it to save null check. private def writeMapToBuffer( ctx: CodegenContext, input: String, index: String, keyType: DataType, valueType: DataType, + valueContainsNull: Boolean, rowWriter: String): String = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") @@ -203,6 +215,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val previousCursor = ctx.freshName("previousCursor") // Writes out unsafe map according to the format described in `UnsafeMapData`. 
+ val keyArray = writeArrayToBuffer( + ctx, s"$tmpInput.keyArray()", keyType, false, rowWriter) + val valueArray = writeArrayToBuffer( + ctx, s"$tmpInput.valueArray()", valueType, valueContainsNull, rowWriter) + s""" |final MapData $tmpInput = $input; |if ($tmpInput instanceof UnsafeMapData) { @@ -219,7 +236,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro | // Remember the current cursor so that we can write numBytes of key array later. | final int $tmpCursor = $rowWriter.cursor(); | - | ${writeArrayToBuffer(ctx, s"$tmpInput.keyArray()", keyType, rowWriter)} + | $keyArray | | // Write the numBytes of key array into the first 8 bytes. | Platform.putLong( @@ -227,7 +244,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro | $tmpCursor - 8, | $rowWriter.cursor() - $tmpCursor); | - | ${writeArrayToBuffer(ctx, s"$tmpInput.valueArray()", valueType, rowWriter)} + | $valueArray | $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); |} """.stripMargin @@ -240,20 +257,21 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro dt: DataType, writer: String): String = dt match { case t: StructType => - writeStructToBuffer(ctx, input, index, t.map(_.dataType), writer) + writeStructToBuffer( + ctx, input, index, t.map(e => Schema(e.dataType, e.nullable)), writer) - case ArrayType(et, _) => + case ArrayType(et, en) => val previousCursor = ctx.freshName("previousCursor") s""" |// Remember the current cursor so that we can calculate how many bytes are |// written later. 
|final int $previousCursor = $writer.cursor(); - |${writeArrayToBuffer(ctx, input, et, writer)} + |${writeArrayToBuffer(ctx, input, et, en, writer)} |$writer.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); """.stripMargin - case MapType(kt, vt, _) => - writeMapToBuffer(ctx, input, index, kt, vt, writer) + case MapType(kt, vt, vn) => + writeMapToBuffer(ctx, input, index, kt, vt, vn, writer) case DecimalType.Fixed(precision, scale) => s"$writer.write($index, $input, $precision, $scale);" @@ -268,12 +286,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro expressions: Seq[Expression], useSubexprElimination: Boolean = false): ExprCode = { val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination) - val exprTypes = expressions.map(_.dataType) + val exprSchemas = expressions.map(e => Schema(e.dataType, e.nullable)) - val numVarLenFields = exprTypes.count { - case dt if UnsafeRow.isFixedLength(dt) => false + val numVarLenFields = exprSchemas.count { + case Schema(dt, _) => !UnsafeRow.isFixedLength(dt) // TODO: consider large decimal and interval type - case _ => true } val rowWriterClass = classOf[UnsafeRowWriter].getName @@ -284,7 +301,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val evalSubexpr = ctx.subexprFunctions.mkString("\n") val writeExpressions = writeExpressionsToBuffer( - ctx, ctx.INPUT_ROW, exprEvals, exprTypes, rowWriter, isTopLevel = true) + ctx, ctx.INPUT_ROW, exprEvals, exprSchemas, rowWriter, isTopLevel = true) val code = code""" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index 04f1c8ce0b83d..0e9c8abec33e4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -694,7 +694,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with |""".stripMargin val jsonSchema = new StructType() .add("a", LongType, nullable = false) - .add("b", StringType, nullable = false) + .add("b", StringType, nullable = !forceJsonNullableSchema) .add("c", StringType, nullable = false) val output = InternalRow(1L, null, UTF8String.fromString("foo")) val expr = JsonToStructs(jsonSchema, Map.empty, Literal.create(input, StringType), gmtId) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala index e9d21f8a8ebcd..01aa3579aea98 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.BoundReference -import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} -import org.apache.spark.sql.types.{DataType, Decimal, StringType, StructType} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, MapData} +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} class GenerateUnsafeProjectionSuite extends SparkFunSuite { @@ -33,6 +33,41 @@ class GenerateUnsafeProjectionSuite extends SparkFunSuite { assert(!result.isNullAt(0)) assert(result.getStruct(0, 1).isNullAt(0)) } + + test("Test unsafe projection for array/map/struct") { + val dataType1 = ArrayType(StringType, false) + val exprs1 
= BoundReference(0, dataType1, nullable = false) :: Nil + val projection1 = GenerateUnsafeProjection.generate(exprs1) + val result1 = projection1.apply(AlwaysNonNull) + assert(!result1.isNullAt(0)) + assert(!result1.getArray(0).isNullAt(0)) + assert(!result1.getArray(0).isNullAt(1)) + assert(!result1.getArray(0).isNullAt(2)) + + val dataType2 = MapType(StringType, StringType, false) + val exprs2 = BoundReference(0, dataType2, nullable = false) :: Nil + val projection2 = GenerateUnsafeProjection.generate(exprs2) + val result2 = projection2.apply(AlwaysNonNull) + assert(!result2.isNullAt(0)) + assert(!result2.getMap(0).keyArray.isNullAt(0)) + assert(!result2.getMap(0).keyArray.isNullAt(1)) + assert(!result2.getMap(0).keyArray.isNullAt(2)) + assert(!result2.getMap(0).valueArray.isNullAt(0)) + assert(!result2.getMap(0).valueArray.isNullAt(1)) + assert(!result2.getMap(0).valueArray.isNullAt(2)) + + val dataType3 = (new StructType) + .add("a", StringType, nullable = false) + .add("b", StringType, nullable = false) + .add("c", StringType, nullable = false) + val exprs3 = BoundReference(0, dataType3, nullable = false) :: Nil + val projection3 = GenerateUnsafeProjection.generate(exprs3) + val result3 = projection3.apply(InternalRow(AlwaysNonNull)) + assert(!result3.isNullAt(0)) + assert(!result3.getStruct(0, 1).isNullAt(0)) + assert(!result3.getStruct(0, 2).isNullAt(0)) + assert(!result3.getStruct(0, 3).isNullAt(0)) + } } object AlwaysNull extends InternalRow { @@ -59,3 +94,35 @@ object AlwaysNull extends InternalRow { override def get(ordinal: Int, dataType: DataType): AnyRef = notSupported private def notSupported: Nothing = throw new UnsupportedOperationException } + +object AlwaysNonNull extends InternalRow { + private def stringToUTF8Array(stringArray: Array[String]): ArrayData = { + val utf8Array = stringArray.map(s => UTF8String.fromString(s)).toArray + ArrayData.toArrayData(utf8Array) + } + override def numFields: Int = 1 + override def setNullAt(i: Int): Unit = {} 
+ override def copy(): InternalRow = this + override def anyNull: Boolean = notSupported + override def isNullAt(ordinal: Int): Boolean = notSupported + override def update(i: Int, value: Any): Unit = notSupported + override def getBoolean(ordinal: Int): Boolean = notSupported + override def getByte(ordinal: Int): Byte = notSupported + override def getShort(ordinal: Int): Short = notSupported + override def getInt(ordinal: Int): Int = notSupported + override def getLong(ordinal: Int): Long = notSupported + override def getFloat(ordinal: Int): Float = notSupported + override def getDouble(ordinal: Int): Double = notSupported + override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = notSupported + override def getUTF8String(ordinal: Int): UTF8String = UTF8String.fromString("test") + override def getBinary(ordinal: Int): Array[Byte] = notSupported + override def getInterval(ordinal: Int): CalendarInterval = notSupported + override def getStruct(ordinal: Int, numFields: Int): InternalRow = notSupported + override def getArray(ordinal: Int): ArrayData = stringToUTF8Array(Array("1", "2", "3")) + val keyArray = stringToUTF8Array(Array("1", "2", "3")) + val valueArray = stringToUTF8Array(Array("a", "b", "c")) + override def getMap(ordinal: Int): MapData = new ArrayBasedMapData(keyArray, valueArray) + override def get(ordinal: Int, dataType: DataType): AnyRef = notSupported + private def notSupported: Nothing = throw new UnsupportedOperationException + +}