Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-22750][SQL] Reuse mutable states when possible #19940

Closed
wants to merge 9 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val countTerm = ctx.addMutableState(ctx.JAVA_LONG, "count")
val partitionMaskTerm = ctx.addMutableState(ctx.JAVA_LONG, "partitionMask")
val partitionMaskTerm = "partitionMask"
ctx.addImmutableStateIfNotExists(ctx.JAVA_LONG, partitionMaskTerm)
ctx.addPartitionInitializationStatement(s"$countTerm = 0L;")
ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic {
override protected def evalInternal(input: InternalRow): Int = partitionId

// Generates the Java code for this expression: the value is read from a class-level field named
// by `idTerm`, which is assigned from `partitionIndex` by a partition initialization statement.
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
// NOTE(review): this is a diff rendering without +/- markers. The next line looks like the
// removed definition of `idTerm`; the two lines after it look like its replacement -- confirm.
val idTerm = ctx.addMutableState(ctx.JAVA_INT, "partitionId")
val idTerm = "partitionId"
ctx.addImmutableStateIfNotExists(ctx.JAVA_INT, idTerm)
// Registers the statement that stores `partitionIndex` into the field (presumably executed
// once per partition -- confirm against `addPartitionInitializationStatement`).
ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;")
// The generated code copies the field into the expression's value; the result is never null.
ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = $idTerm;", isNull = "false")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,14 @@ class CodegenContext {

}

/**
* Registry of the states which have been added so far through `addImmutableStateIfNotExists`.
* Each entry maps the name of the field to a pair made of its Java type and the initialization
* code produced for it. `addImmutableStateIfNotExists` consults it to detect a repeated
* registration of the same name and to assert that the type and the initialization code of the
* repeated registration are equal to the recorded ones.
*/
private val immutableStates: mutable.Map[String, (String, String)] =
mutable.Map.empty[String, (String, String)]

/**
* Add a mutable state as a field to the generated class. c.f. the comments above.
*
Expand Down Expand Up @@ -252,6 +260,38 @@ class CodegenContext {
}
}

/**
 * Adds an immutable state as a field of the generated class, unless a field with the same name
 * has already been added through this method. This helps to reduce the number of fields of the
 * generated class, since the same variable can be shared by many functions.
 *
 * The field is not declared as final, but the generated code must never reassign it, in order
 * to prevent errors and unexpected behaviors.
 *
 * The first request for a name delegates to `addMutableState`, passing `useFreshName = false`
 * and `forceInline = true`. Any later request for the same name adds nothing: it only asserts
 * that the Java type and the initialization code are equal to the ones recorded the first time.
 *
 * @param javaType Java type of the field.
 * @param variableName Name of the field.
 * @param initFunc Function returning the statement(s) to put into the init() method to
 *                 initialize this field. Its argument is the name of the state variable.
 */
def addImmutableStateIfNotExists(
    javaType: String,
    variableName: String,
    initFunc: String => String = _ => ""): Unit = {
  immutableStates.get(variableName) match {
    case Some((prevJavaType, prevInitCode)) =>
      // Already registered: the new request must agree with the recorded definition.
      assert(prevJavaType == javaType, s"$variableName has already been defined with type " +
        s"$prevJavaType and now it is tried to define again with type $javaType.")
      assert(prevInitCode == initFunc(variableName), s"$variableName has already been defined " +
        s"with different initialization statements.")
    case None =>
      // First request for this name: create the field, then record its type and init code.
      addMutableState(javaType, variableName, initFunc, useFreshName = false, forceInline = true)
      immutableStates(variableName) = (javaType, initFunc(variableName))
  }
}

/**
* Add buffer variable which stores data coming from an [[InternalRow]]. This method guarantees
* that the variable is safely stored, which is important for (potentially) byte array backed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,8 @@ case class DayOfWeek(child: Expression) extends UnaryExpression with ImplicitCas
nullSafeCodeGen(ctx, ev, time => {
val cal = classOf[Calendar].getName
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
val c = ctx.addMutableState(cal, "cal",
val c = "calDayOfWeek"
ctx.addImmutableStateIfNotExists(cal, c,
v => s"""$v = $cal.getInstance($dtu.getTimeZone("UTC"));""")
s"""
$c.setTimeInMillis($time * 1000L * 3600L * 24L);
Expand Down Expand Up @@ -484,8 +485,9 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCa
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, time => {
val cal = classOf[Calendar].getName
val c = "calWeekOfYear"
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
val c = ctx.addMutableState(cal, "cal", v =>
ctx.addImmutableStateIfNotExists(cal, c, v =>
s"""
|$v = $cal.getInstance($dtu.getTimeZone("UTC"));
|$v.setFirstDayOfWeek($cal.MONDAY);
Expand Down Expand Up @@ -1017,7 +1019,8 @@ case class FromUTCTimestamp(left: Expression, right: Expression)
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
val tzTerm = ctx.addMutableState(tzClass, "tz",
v => s"""$v = $dtu.getTimeZone("$tz");""")
val utcTerm = ctx.addMutableState(tzClass, "utc",
val utcTerm = "tzUTC"
ctx.addImmutableStateIfNotExists(tzClass, utcTerm,
v => s"""$v = $dtu.getTimeZone("UTC");""")
val eval = left.genCode(ctx)
ev.copy(code = s"""
Expand Down Expand Up @@ -1193,7 +1196,8 @@ case class ToUTCTimestamp(left: Expression, right: Expression)
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
val tzTerm = ctx.addMutableState(tzClass, "tz",
v => s"""$v = $dtu.getTimeZone("$tz");""")
val utcTerm = ctx.addMutableState(tzClass, "utc",
val utcTerm = "tzUTC"
ctx.addImmutableStateIfNotExists(tzClass, utcTerm,
v => s"""$v = $dtu.getTimeZone("UTC");""")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated question: in the codebase we sometimes use UTC and sometimes GMT; is that correct?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, there is no difference between them in practice, but I think that being consistent would be better for readability.

val eval = left.genCode(ctx)
ev.copy(code = s"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1148,17 +1148,21 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean)

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
// Code to initialize the serializer.
val (serializerClass, serializerInstanceClass) = {
val (serializer, serializerClass, serializerInstanceClass) = {
if (kryo) {
(classOf[KryoSerializer].getName, classOf[KryoSerializerInstance].getName)
("kryoSerializer",
classOf[KryoSerializer].getName,
classOf[KryoSerializerInstance].getName)
} else {
(classOf[JavaSerializer].getName, classOf[JavaSerializerInstance].getName)
("javaSerializer",
classOf[JavaSerializer].getName,
classOf[JavaSerializerInstance].getName)
}
}
// try conf from env, otherwise create a new one
val env = s"${classOf[SparkEnv].getName}.get()"
val sparkConf = s"new ${classOf[SparkConf].getName}()"
val serializer = ctx.addMutableState(serializerInstanceClass, "serializerForEncode", v =>
ctx.addImmutableStateIfNotExists(serializerInstanceClass, serializer, v =>
s"""
|if ($env == null) {
| $v = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();
Expand Down Expand Up @@ -1193,17 +1197,21 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
// Code to initialize the serializer.
val (serializerClass, serializerInstanceClass) = {
val (serializer, serializerClass, serializerInstanceClass) = {
if (kryo) {
(classOf[KryoSerializer].getName, classOf[KryoSerializerInstance].getName)
("kryoSerializer",
classOf[KryoSerializer].getName,
classOf[KryoSerializerInstance].getName)
} else {
(classOf[JavaSerializer].getName, classOf[JavaSerializerInstance].getName)
("javaSerializer",
classOf[JavaSerializer].getName,
classOf[JavaSerializerInstance].getName)
}
}
// try conf from env, otherwise create a new one
val env = s"${classOf[SparkEnv].getName}.get()"
val sparkConf = s"new ${classOf[SparkConf].getName}()"
val serializer = ctx.addMutableState(serializerInstanceClass, "serializerForDecode", v =>
ctx.addImmutableStateIfNotExists(serializerInstanceClass, serializer, v =>
s"""
|if ($env == null) {
| $v = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -424,4 +424,16 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
assert(ctx2.arrayCompactedMutableStates("InternalRow[]").getCurrentIndex == 10)
assert(ctx2.mutableStateInitCode.size == CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT + 10)
}

test("SPARK-22750: addImmutableStateIfNotExists") {
  val ctx = new CodegenContext
  val intField = "field1"
  val stringField = "field2"
  // Request the same two fields several times, interleaved: each repeated request must reuse
  // the field created by the first one instead of adding a new one.
  val requests = Seq(
    "int" -> intField,
    "int" -> intField,
    "String" -> stringField,
    "int" -> intField,
    "String" -> stringField)
  requests.foreach { case (javaType, fieldName) =>
    ctx.addImmutableStateIfNotExists(javaType, fieldName)
  }
  // Only two distinct names were requested, so only two inlined fields may exist.
  assert(ctx.inlinedMutableStates.length == 2)
}
}