Skip to content

Commit

Permalink
[SPARK-23466][SQL] Remove redundant null checks in generated Java cod…
Browse files Browse the repository at this point in the history
…e by GenerateUnsafeProjection

## What changes were proposed in this pull request?

This PR addresses one of the TODOs in `GenerateUnsafeProjection` ("if the nullability of field is correct, we can use it to save null check") in order to simplify the generated code.
When `nullable=false` is declared in a `DataType`, `GenerateUnsafeProjection` now omits the corresponding null checks from the generated Java code.

## How was this patch tested?

Added new test cases into `GenerateUnsafeProjectionSuite`

Closes apache#20637 from kiszk/SPARK-23466.

Authored-by: Kazuaki Ishizaki <[email protected]>
Signed-off-by: Takuya UESHIN <[email protected]>
  • Loading branch information
kiszk authored and ueshin committed Sep 1, 2018
1 parent e1d72f2 commit c5583fd
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ import org.apache.spark.sql.types._
*/
object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafeProjection] {

/**
 * Pairs a value's data type with its declared nullability, so the code generator can carry
 * both together (built from a struct field's or an expression's `dataType` and `nullable`).
 *
 * @param dataType the type of the value to be written
 * @param nullable whether the value is declared as possibly null
 */
case class Schema(dataType: DataType, nullable: Boolean)

/** Returns true iff we support this data type. */
def canSupport(dataType: DataType): Boolean = UserDefinedType.sqlType(dataType) match {
case NullType => true
Expand All @@ -43,19 +45,21 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
case _ => false
}

// TODO: if the nullability of field is correct, we can use it to save null check.
private def writeStructToBuffer(
ctx: CodegenContext,
input: String,
index: String,
fieldTypes: Seq[DataType],
schemas: Seq[Schema],
rowWriter: String): String = {
// Puts `input` in a local variable to avoid to re-evaluate it if it's a statement.
val tmpInput = ctx.freshName("tmpInput")
val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) =>
ExprCode(
JavaCode.isNullExpression(s"$tmpInput.isNullAt($i)"),
JavaCode.expression(CodeGenerator.getValue(tmpInput, dt, i.toString), dt))
val fieldEvals = schemas.zipWithIndex.map { case (Schema(dt, nullable), i) =>
val isNull = if (nullable) {
JavaCode.isNullExpression(s"$tmpInput.isNullAt($i)")
} else {
FalseLiteral
}
ExprCode(isNull, JavaCode.expression(CodeGenerator.getValue(tmpInput, dt, i.toString), dt))
}

val rowWriterClass = classOf[UnsafeRowWriter].getName
Expand All @@ -70,7 +74,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
| // Remember the current cursor so that we can calculate how many bytes are
| // written later.
| final int $previousCursor = $rowWriter.cursor();
| ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, fieldTypes, structRowWriter)}
| ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, schemas, structRowWriter)}
| $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor);
|}
""".stripMargin
Expand All @@ -80,7 +84,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
ctx: CodegenContext,
row: String,
inputs: Seq[ExprCode],
inputTypes: Seq[DataType],
schemas: Seq[Schema],
rowWriter: String,
isTopLevel: Boolean = false): String = {
val resetWriter = if (isTopLevel) {
Expand All @@ -98,8 +102,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
s"$rowWriter.resetRowWriter();"
}

val writeFields = inputs.zip(inputTypes).zipWithIndex.map {
case ((input, dataType), index) =>
val writeFields = inputs.zip(schemas).zipWithIndex.map {
case ((input, Schema(dataType, nullable)), index) =>
val dt = UserDefinedType.sqlType(dataType)

val setNull = dt match {
Expand All @@ -110,7 +114,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
}

val writeField = writeElement(ctx, input.value, index.toString, dt, rowWriter)
if (input.isNull == FalseLiteral) {
if (!nullable) {
s"""
|${input.code}
|${writeField.trim}
Expand Down Expand Up @@ -143,11 +147,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
""".stripMargin
}

// TODO: if the nullability of array element is correct, we can use it to save null check.
private def writeArrayToBuffer(
ctx: CodegenContext,
input: String,
elementType: DataType,
containsNull: Boolean,
rowWriter: String): String = {
// Puts `input` in a local variable to avoid to re-evaluate it if it's a statement.
val tmpInput = ctx.freshName("tmpInput")
Expand All @@ -170,6 +174,18 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro

val element = CodeGenerator.getValue(tmpInput, et, index)

val elementAssignment = if (containsNull) {
s"""
|if ($tmpInput.isNullAt($index)) {
| $arrayWriter.setNull${elementOrOffsetSize}Bytes($index);
|} else {
| ${writeElement(ctx, element, index, et, arrayWriter)}
|}
""".stripMargin
} else {
writeElement(ctx, element, index, et, arrayWriter)
}

s"""
|final ArrayData $tmpInput = $input;
|if ($tmpInput instanceof UnsafeArrayData) {
Expand All @@ -179,30 +195,31 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
| $arrayWriter.initialize($numElements);
|
| for (int $index = 0; $index < $numElements; $index++) {
| if ($tmpInput.isNullAt($index)) {
| $arrayWriter.setNull${elementOrOffsetSize}Bytes($index);
| } else {
| ${writeElement(ctx, element, index, et, arrayWriter)}
| }
| $elementAssignment
| }
|}
""".stripMargin
}

// TODO: if the nullability of value element is correct, we can use it to save null check.
private def writeMapToBuffer(
ctx: CodegenContext,
input: String,
index: String,
keyType: DataType,
valueType: DataType,
valueContainsNull: Boolean,
rowWriter: String): String = {
// Puts `input` in a local variable to avoid to re-evaluate it if it's a statement.
val tmpInput = ctx.freshName("tmpInput")
val tmpCursor = ctx.freshName("tmpCursor")
val previousCursor = ctx.freshName("previousCursor")

// Writes out unsafe map according to the format described in `UnsafeMapData`.
val keyArray = writeArrayToBuffer(
ctx, s"$tmpInput.keyArray()", keyType, false, rowWriter)
val valueArray = writeArrayToBuffer(
ctx, s"$tmpInput.valueArray()", valueType, valueContainsNull, rowWriter)

s"""
|final MapData $tmpInput = $input;
|if ($tmpInput instanceof UnsafeMapData) {
Expand All @@ -219,15 +236,15 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
| // Remember the current cursor so that we can write numBytes of key array later.
| final int $tmpCursor = $rowWriter.cursor();
|
| ${writeArrayToBuffer(ctx, s"$tmpInput.keyArray()", keyType, rowWriter)}
| $keyArray
|
| // Write the numBytes of key array into the first 8 bytes.
| Platform.putLong(
| $rowWriter.getBuffer(),
| $tmpCursor - 8,
| $rowWriter.cursor() - $tmpCursor);
|
| ${writeArrayToBuffer(ctx, s"$tmpInput.valueArray()", valueType, rowWriter)}
| $valueArray
| $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor);
|}
""".stripMargin
Expand All @@ -240,20 +257,21 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
dt: DataType,
writer: String): String = dt match {
case t: StructType =>
writeStructToBuffer(ctx, input, index, t.map(_.dataType), writer)
writeStructToBuffer(
ctx, input, index, t.map(e => Schema(e.dataType, e.nullable)), writer)

case ArrayType(et, _) =>
case ArrayType(et, en) =>
val previousCursor = ctx.freshName("previousCursor")
s"""
|// Remember the current cursor so that we can calculate how many bytes are
|// written later.
|final int $previousCursor = $writer.cursor();
|${writeArrayToBuffer(ctx, input, et, writer)}
|${writeArrayToBuffer(ctx, input, et, en, writer)}
|$writer.setOffsetAndSizeFromPreviousCursor($index, $previousCursor);
""".stripMargin

case MapType(kt, vt, _) =>
writeMapToBuffer(ctx, input, index, kt, vt, writer)
case MapType(kt, vt, vn) =>
writeMapToBuffer(ctx, input, index, kt, vt, vn, writer)

case DecimalType.Fixed(precision, scale) =>
s"$writer.write($index, $input, $precision, $scale);"
Expand All @@ -268,12 +286,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
expressions: Seq[Expression],
useSubexprElimination: Boolean = false): ExprCode = {
val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination)
val exprTypes = expressions.map(_.dataType)
val exprSchemas = expressions.map(e => Schema(e.dataType, e.nullable))

val numVarLenFields = exprTypes.count {
case dt if UnsafeRow.isFixedLength(dt) => false
val numVarLenFields = exprSchemas.count {
case Schema(dt, _) => !UnsafeRow.isFixedLength(dt)
// TODO: consider large decimal and interval type
case _ => true
}

val rowWriterClass = classOf[UnsafeRowWriter].getName
Expand All @@ -284,7 +301,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
val evalSubexpr = ctx.subexprFunctions.mkString("\n")

val writeExpressions = writeExpressionsToBuffer(
ctx, ctx.INPUT_ROW, exprEvals, exprTypes, rowWriter, isTopLevel = true)
ctx, ctx.INPUT_ROW, exprEvals, exprSchemas, rowWriter, isTopLevel = true)

val code =
code"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -694,7 +694,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
|""".stripMargin
val jsonSchema = new StructType()
.add("a", LongType, nullable = false)
.add("b", StringType, nullable = false)
.add("b", StringType, nullable = !forceJsonNullableSchema)
.add("c", StringType, nullable = false)
val output = InternalRow(1L, null, UTF8String.fromString("foo"))
val expr = JsonToStructs(jsonSchema, Map.empty, Literal.create(input, StringType), gmtId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ package org.apache.spark.sql.catalyst.expressions.codegen
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.BoundReference
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types.{DataType, Decimal, StringType, StructType}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}

class GenerateUnsafeProjectionSuite extends SparkFunSuite {
Expand All @@ -33,6 +33,41 @@ class GenerateUnsafeProjectionSuite extends SparkFunSuite {
assert(!result.isNullAt(0))
assert(result.getStruct(0, 1).isNullAt(0))
}

test("Test unsafe projection for array/map/struct") {
  // Every schema below is declared non-nullable. `AlwaysNonNull.isNullAt` throws
  // UnsupportedOperationException, so a projection that still calls it on the input
  // row fails the test.

  // Array whose elements are declared non-null (containsNull = false).
  val dataType1 = ArrayType(StringType, false)
  val exprs1 = BoundReference(0, dataType1, nullable = false) :: Nil
  val projection1 = GenerateUnsafeProjection.generate(exprs1)
  val result1 = projection1.apply(AlwaysNonNull)
  assert(!result1.isNullAt(0))
  assert(!result1.getArray(0).isNullAt(0))
  assert(!result1.getArray(0).isNullAt(1))
  assert(!result1.getArray(0).isNullAt(2))

  // Map whose values are declared non-null (valueContainsNull = false).
  val dataType2 = MapType(StringType, StringType, false)
  val exprs2 = BoundReference(0, dataType2, nullable = false) :: Nil
  val projection2 = GenerateUnsafeProjection.generate(exprs2)
  val result2 = projection2.apply(AlwaysNonNull)
  assert(!result2.isNullAt(0))
  assert(!result2.getMap(0).keyArray.isNullAt(0))
  assert(!result2.getMap(0).keyArray.isNullAt(1))
  assert(!result2.getMap(0).keyArray.isNullAt(2))
  assert(!result2.getMap(0).valueArray.isNullAt(0))
  assert(!result2.getMap(0).valueArray.isNullAt(1))
  assert(!result2.getMap(0).valueArray.isNullAt(2))

  // Struct with three non-nullable fields.
  val dataType3 = (new StructType)
    .add("a", StringType, nullable = false)
    .add("b", StringType, nullable = false)
    .add("c", StringType, nullable = false)
  val exprs3 = BoundReference(0, dataType3, nullable = false) :: Nil
  val projection3 = GenerateUnsafeProjection.generate(exprs3)
  val result3 = projection3.apply(InternalRow(AlwaysNonNull))
  assert(!result3.isNullAt(0))
  // Read the nested struct once with its real field count and check each field.
  // Previously only field 0 was checked, three times, with varying field counts.
  val struct3 = result3.getStruct(0, dataType3.length)
  assert(!struct3.isNullAt(0))
  assert(!struct3.isNullAt(1))
  assert(!struct3.isNullAt(2))
}
}

object AlwaysNull extends InternalRow {
Expand All @@ -59,3 +94,35 @@ object AlwaysNull extends InternalRow {
override def get(ordinal: Int, dataType: DataType): AnyRef = notSupported
private def notSupported: Nothing = throw new UnsupportedOperationException
}

/**
 * An [[InternalRow]] stub that serves fixed non-null string, array and map values.
 *
 * `isNullAt` and `anyNull` throw `UnsupportedOperationException`, so any code that performs
 * a null check against this row fails immediately. Every accessor other than
 * `getUTF8String`, `getArray` and `getMap` throws as well.
 */
object AlwaysNonNull extends InternalRow {
  private def notSupported: Nothing = throw new UnsupportedOperationException

  /** Converts plain strings to an `ArrayData` of `UTF8String` elements. */
  private def stringToUTF8Array(stringArray: Array[String]): ArrayData =
    ArrayData.toArrayData(stringArray.map(UTF8String.fromString(_)))

  // Fixed key/value arrays backing the map returned by `getMap`.
  val keyArray = stringToUTF8Array(Array("1", "2", "3"))
  val valueArray = stringToUTF8Array(Array("a", "b", "c"))

  // Accessors that return data; the ordinal is ignored by all of them.
  override def getUTF8String(ordinal: Int): UTF8String = UTF8String.fromString("test")
  override def getArray(ordinal: Int): ArrayData = stringToUTF8Array(Array("1", "2", "3"))
  override def getMap(ordinal: Int): MapData = new ArrayBasedMapData(keyArray, valueArray)

  // Row-level operations that are harmless no-ops for this stub.
  override def numFields: Int = 1
  override def setNullAt(i: Int): Unit = {}
  override def copy(): InternalRow = this

  // Null checks are deliberately unsupported.
  override def anyNull: Boolean = notSupported
  override def isNullAt(ordinal: Int): Boolean = notSupported

  // Everything else is unsupported.
  override def update(i: Int, value: Any): Unit = notSupported
  override def getBoolean(ordinal: Int): Boolean = notSupported
  override def getByte(ordinal: Int): Byte = notSupported
  override def getShort(ordinal: Int): Short = notSupported
  override def getInt(ordinal: Int): Int = notSupported
  override def getLong(ordinal: Int): Long = notSupported
  override def getFloat(ordinal: Int): Float = notSupported
  override def getDouble(ordinal: Int): Double = notSupported
  override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = notSupported
  override def getBinary(ordinal: Int): Array[Byte] = notSupported
  override def getInterval(ordinal: Int): CalendarInterval = notSupported
  override def getStruct(ordinal: Int, numFields: Int): InternalRow = notSupported
  override def get(ordinal: Int, dataType: DataType): AnyRef = notSupported
}

0 comments on commit c5583fd

Please sign in to comment.