Skip to content

Commit

Permalink
[SPARK-22750][SQL] Reuse mutable states when possible
Browse files Browse the repository at this point in the history
  • Loading branch information
mgaido91 committed Dec 10, 2017
1 parent ab1b6ee commit 978bfd6
Show file tree
Hide file tree
Showing 9 changed files with 148 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,9 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val countTerm = ctx.freshName("count")
val partitionMaskTerm = ctx.freshName("partitionMask")
val partitionMaskTerm = "partitionMask"
ctx.addMutableState(ctx.JAVA_LONG, countTerm)
ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm)
ctx.addSingleMutableState(ctx.JAVA_LONG, partitionMaskTerm)
ctx.addPartitionInitializationStatement(s"$countTerm = 0L;")
ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic {
override protected def evalInternal(input: InternalRow): Int = partitionId

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val idTerm = ctx.freshName("partitionId")
ctx.addMutableState(ctx.JAVA_INT, idTerm)
val idTerm = "partitionId"
ctx.addSingleMutableState(ctx.JAVA_INT, idTerm)
ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;")
ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = $idTerm;", isNull = "false")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,12 @@ class CodegenContext {
val mutableStates: mutable.ArrayBuffer[(String, String, String)] =
mutable.ArrayBuffer.empty[(String, String, String)]

/**
* A set containing the names of the mutable states which have been defined so far using
* `addSingleMutableState`.
*/
val singleMutableStates: mutable.Set[String] = mutable.Set.empty[String]

/**
* Add a mutable state as a field to the generated class. c.f. the comments above.
*
Expand All @@ -184,6 +190,27 @@ class CodegenContext {
mutableStates += ((javaType, variableName, initCode))
}

/**
* Add a mutable state as a field to the generated class only if it does not exist yet a field
* with that name. This helps reducing the number of the generated class' fields, since the same
* variable can be reused by many functions.
*
* Internally, this method calls `addMutableState`.
*
* @param javaType Java type of the field.
* @param variableName Name of the field.
* @param initCode The statement(s) to put into the init() method to initialize this field.
*/
def addSingleMutableState(
javaType: String,
variableName: String,
initCode: String = ""): Unit = {
if (!singleMutableStates.contains(variableName)) {
addMutableState(javaType, variableName, initCode)
singleMutableStates += variableName
}
}

/**
* Add buffer variable which stores data coming from an [[InternalRow]]. This methods guarantees
* that the variable is safely stored, which is important for (potentially) byte array backed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -442,9 +442,9 @@ case class DayOfWeek(child: Expression) extends UnaryExpression with ImplicitCas
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, time => {
val cal = classOf[Calendar].getName
val c = ctx.freshName("cal")
val c = "calDayOfWeek"
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
ctx.addMutableState(cal, c, s"""$c = $cal.getInstance($dtu.getTimeZone("UTC"));""")
ctx.addSingleMutableState(cal, c, s"""$c = $cal.getInstance($dtu.getTimeZone("UTC"));""")
s"""
$c.setTimeInMillis($time * 1000L * 3600L * 24L);
${ev.value} = $c.get($cal.DAY_OF_WEEK);
Expand Down Expand Up @@ -484,14 +484,14 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCa
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, time => {
val cal = classOf[Calendar].getName
val c = ctx.freshName("cal")
val c = "calWeekOfYear"
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
ctx.addMutableState(cal, c,
ctx.addSingleMutableState(cal, c,
s"""
$c = $cal.getInstance($dtu.getTimeZone("UTC"));
$c.setFirstDayOfWeek($cal.MONDAY);
$c.setMinimalDaysInFirstWeek(4);
""")
|$c = $cal.getInstance($dtu.getTimeZone("UTC"));
|$c.setFirstDayOfWeek($cal.MONDAY);
|$c.setMinimalDaysInFirstWeek(4);
""".stripMargin)
s"""
$c.setTimeInMillis($time * 1000L * 3600L * 24L);
${ev.value} = $c.get($cal.WEEK_OF_YEAR);
Expand Down Expand Up @@ -1015,11 +1015,11 @@ case class FromUTCTimestamp(left: Expression, right: Expression)
""".stripMargin)
} else {
val tzTerm = ctx.freshName("tz")
val utcTerm = ctx.freshName("utc")
val utcTerm = "tzUTC"
val tzClass = classOf[TimeZone].getName
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
ctx.addMutableState(tzClass, tzTerm, s"""$tzTerm = $dtu.getTimeZone("$tz");""")
ctx.addMutableState(tzClass, utcTerm, s"""$utcTerm = $dtu.getTimeZone("UTC");""")
ctx.addSingleMutableState(tzClass, utcTerm, s"""$utcTerm = $dtu.getTimeZone("UTC");""")
val eval = left.genCode(ctx)
ev.copy(code = s"""
|${eval.code}
Expand Down Expand Up @@ -1191,11 +1191,11 @@ case class ToUTCTimestamp(left: Expression, right: Expression)
""".stripMargin)
} else {
val tzTerm = ctx.freshName("tz")
val utcTerm = ctx.freshName("utc")
val utcTerm = "tzUTC"
val tzClass = classOf[TimeZone].getName
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
ctx.addMutableState(tzClass, tzTerm, s"""$tzTerm = $dtu.getTimeZone("$tz");""")
ctx.addMutableState(tzClass, utcTerm, s"""$utcTerm = $dtu.getTimeZone("UTC");""")
ctx.addSingleMutableState(tzClass, utcTerm, s"""$utcTerm = $dtu.getTimeZone("UTC");""")
val eval = left.genCode(ctx)
ev.copy(code = s"""
|${eval.code}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1148,12 +1148,15 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean)

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
// Code to initialize the serializer.
val serializer = ctx.freshName("serializer")
val (serializerClass, serializerInstanceClass) = {
val (serializer, serializerClass, serializerInstanceClass) = {
if (kryo) {
(classOf[KryoSerializer].getName, classOf[KryoSerializerInstance].getName)
("kryoSerializer",
classOf[KryoSerializer].getName,
classOf[KryoSerializerInstance].getName)
} else {
(classOf[JavaSerializer].getName, classOf[JavaSerializerInstance].getName)
("javaSerializer",
classOf[JavaSerializer].getName,
classOf[JavaSerializerInstance].getName)
}
}
// try conf from env, otherwise create a new one
Expand All @@ -1166,7 +1169,7 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean)
$serializer = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance();
}
"""
ctx.addMutableState(serializerInstanceClass, serializer, serializerInit)
ctx.addSingleMutableState(serializerInstanceClass, serializer, serializerInit)

// Code to serialize.
val input = child.genCode(ctx)
Expand Down Expand Up @@ -1194,12 +1197,15 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
// Code to initialize the serializer.
val serializer = ctx.freshName("serializer")
val (serializerClass, serializerInstanceClass) = {
val (serializer, serializerClass, serializerInstanceClass) = {
if (kryo) {
(classOf[KryoSerializer].getName, classOf[KryoSerializerInstance].getName)
("kryoSerializer",
classOf[KryoSerializer].getName,
classOf[KryoSerializerInstance].getName)
} else {
(classOf[JavaSerializer].getName, classOf[JavaSerializerInstance].getName)
("javaSerializer",
classOf[JavaSerializer].getName,
classOf[JavaSerializerInstance].getName)
}
}
// try conf from env, otherwise create a new one
Expand All @@ -1212,7 +1218,7 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B
$serializer = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance();
}
"""
ctx.addMutableState(serializerInstanceClass, serializer, serializerInit)
ctx.addSingleMutableState(serializerInstanceClass, serializer, serializerInit)

// Code to deserialize.
val input = child.genCode(ctx)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,13 @@ package org.apache.spark.sql.catalyst

import java.sql.{Date, Timestamp}

import scala.reflect.classTag

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{BoundReference, Literal, SpecificInternalRow, UpCast}
import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, NewInstance}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, DecodeUsingSerializer, EncodeUsingSerializer, NewInstance}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

Expand Down Expand Up @@ -356,4 +359,33 @@ class ScalaReflectionSuite extends SparkFunSuite {
assert(deserializerFor[Int].isInstanceOf[AssertNotNull])
assert(!deserializerFor[String].isInstanceOf[AssertNotNull])
}

test("SPARK-22750: reuse serializer in DecodeUsingSerializer and EncodeUsingSerializer") {
val ctx = new CodegenContext
val integerClass = classOf[java.lang.Integer]
val enc = EncodeUsingSerializer(
NewInstance(integerClass, Seq.empty, ObjectType(integerClass)),
kryo = true)
DecodeUsingSerializer[java.lang.Integer](
enc, classTag[java.lang.Integer], kryo = true).genCode(ctx)
assert(ctx.mutableStates.length == 1)

val ctx2 = new CodegenContext
val enc2 = EncodeUsingSerializer(
NewInstance(integerClass, Seq.empty, ObjectType(integerClass)),
kryo = false)
DecodeUsingSerializer[java.lang.Integer](
enc2, classTag[java.lang.Integer], kryo = false).genCode(ctx2)
assert(ctx2.mutableStates.length == 1)

val ctx3 = new CodegenContext
val enc3 = EncodeUsingSerializer(
NewInstance(integerClass, Seq.empty, ObjectType(integerClass)),
kryo = false)
DecodeUsingSerializer[java.lang.Integer](
enc3, classTag[java.lang.Integer], kryo = true).genCode(ctx3)
// here we should have 2 because one is using javaSerializer, while the other is using
// kryoSerializer
assert(ctx3.mutableStates.length == 2)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -394,4 +394,16 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
Map("add" -> Literal(1))).genCode(ctx)
assert(ctx.mutableStates.isEmpty)
}

test("SPARK-22750: addSingleMutableState") {
val ctx = new CodegenContext
val mutableState1 = "field1"
val mutableState2 = "field2"
ctx.addSingleMutableState("int", mutableState1)
ctx.addSingleMutableState("int", mutableState1)
ctx.addSingleMutableState("String", mutableState2)
ctx.addSingleMutableState("int", mutableState1)
ctx.addSingleMutableState("String", mutableState2)
assert(ctx.mutableStates.length == 2)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import java.text.SimpleDateFormat
import java.util.{Calendar, Locale, TimeZone}

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.catalyst.util.DateTimeUtils.TimeZoneGMT
Expand Down Expand Up @@ -741,4 +742,31 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
test("2015-07-24 00:00:00", null, null)
test(null, null, null)
}

test("SPARK-22750: we should reuse the same UTC timezone object in code generation") {
val ctx = new CodegenContext
ToUTCTimestamp(
Literal.create(ts, TimestampType),
Literal.create(gmtId.get, StringType)).genCode(ctx)
FromUTCTimestamp(
Literal.create(Timestamp.valueOf("2017-12-10 00:00:00"), TimestampType),
Literal.create(gmtId.get, StringType)).genCode(ctx)
// we should have one mutable state for UTC timezone and one mutable state for each expression
// holding the other specific timezone
assert(ctx.mutableStates.length == 3)
}

test("SPARK-22750: we should reuse the same calendar for every DayOfWeek") {
val ctx = new CodegenContext
DayOfWeek(Literal(d)).genCode(ctx)
DayOfWeek(Cast(Literal(ts), DateType, gmtId)).genCode(ctx)
assert(ctx.mutableStates.length == 1)
}

test("SPARK-22750: we should reuse the same calendar for every WeekOfYear") {
val ctx = new CodegenContext
WeekOfYear(Literal(d)).genCode(ctx)
WeekOfYear(Cast(Literal(ts), DateType, gmtId)).genCode(ctx)
assert(ctx.mutableStates.length == 1)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext

class NondeterministicSuite extends SparkFunSuite with ExpressionEvalHelper {
test("MonotonicallyIncreasingID") {
Expand All @@ -31,4 +32,19 @@ class NondeterministicSuite extends SparkFunSuite with ExpressionEvalHelper {
test("InputFileName") {
checkEvaluation(InputFileName(), "")
}

test("SPARK-22750: SparkPartitionID should reuse the mutable state") {
val ctx = new CodegenContext
SparkPartitionID().genCode(ctx)
SparkPartitionID().genCode(ctx)
assert(ctx.mutableStates.length == 1)
}

test("SPARK-22750: MonotonicallyIncreasingID should reuse the mutable state") {
val ctx = new CodegenContext
MonotonicallyIncreasingID().genCode(ctx)
MonotonicallyIncreasingID().genCode(ctx)
// one mutable state for each counter and one for the shared partition mask
assert(ctx.mutableStates.length == 3)
}
}

0 comments on commit 978bfd6

Please sign in to comment.