Skip to content

Commit

Permalink
address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
kiszk committed Dec 14, 2017
1 parent d6c1a97 commit a9d40e9
Show file tree
Hide file tree
Showing 18 changed files with 58 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -147,15 +147,15 @@ class CodegenContext {
*
* They will be kept as member variables in generated classes like `SpecificProjection`.
*/
val mutableStates: mutable.ArrayBuffer[(String, String)] =
val inlinedMutableStates: mutable.ArrayBuffer[(String, String)] =
mutable.ArrayBuffer.empty[(String, String)]

// A map, keyed by the mutable states' types, that holds the status of mutableStateArray
val mutableStateArrayMap: mutable.Map[String, MutableStateArrays] =
val arrayCompactedMutableStates: mutable.Map[String, MutableStateArrays] =
mutable.Map.empty[String, MutableStateArrays]

// An array holds the code that will initialize each state
val mutableStateInitCodes: mutable.ArrayBuffer[String] =
val mutableStateInitCode: mutable.ArrayBuffer[String] =
mutable.ArrayBuffer.empty[String]

// Holding names and current index of mutableStateArrays for a certain type
Expand Down Expand Up @@ -202,7 +202,7 @@ class CodegenContext {
* @param useFreshName If false and inline is true, the name is not changed
* @return the name of the mutable state variable, which is either the original name if the
* variable is inlined to the outer class, or an array access if the variable is to be
* stored in an array of variables of the same type and initialization.
* stored in an array of variables of the same type.
* There are two use cases. One is to use the original name for a global variable instead
* of a fresh name. The second is to use the original initialization statement since it is
* complex (e.g. allocate multi-dimensional array or object constructor has variables).
Expand All @@ -217,22 +217,22 @@ class CodegenContext {
initFunc: String => String = _ => "",
forceInline: Boolean = false,
useFreshName: Boolean = true): String = {
val varName = if (useFreshName) freshName(variableName) else variableName

// want to put a primitive type variable at outerClass for performance
val canInlinePrimitive = isPrimitiveType(javaType) &&
(mutableStates.length < CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD)
(inlinedMutableStates.length < CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD)
if (forceInline || canInlinePrimitive || javaType.contains("[][]")) {
val varName = if (useFreshName) freshName(variableName) else variableName
val initCode = initFunc(varName)
mutableStates += ((javaType, varName))
mutableStateInitCodes += initCode
inlinedMutableStates += ((javaType, varName))
mutableStateInitCode += initCode
varName
} else {
val arrays = mutableStateArrayMap.getOrElseUpdate(javaType, new MutableStateArrays)
val arrays = arrayCompactedMutableStates.getOrElseUpdate(javaType, new MutableStateArrays)
val element = arrays.getNextSlot()

val initCode = initFunc(element)
mutableStateInitCodes += initCode
mutableStateInitCode += initCode
element
}
}
Expand All @@ -255,11 +255,11 @@ class CodegenContext {
def declareMutableStates(): String = {
// It's possible that we add the same mutable state twice, e.g. the `mergeExpressions` in
// `TypedAggregateExpression`, we should call `distinct` here to remove the duplicated ones.
val inlinedStates = mutableStates.distinct.map { case (javaType, variableName) =>
val inlinedStates = inlinedMutableStates.distinct.map { case (javaType, variableName) =>
s"private $javaType $variableName;"
}

val arrayStates = mutableStateArrayMap.flatMap { case (javaType, mutableStateArrays) =>
val arrayStates = arrayCompactedMutableStates.flatMap { case (javaType, mutableStateArrays) =>
val numArrays = mutableStateArrays.arrayNames.size
mutableStateArrays.arrayNames.zipWithIndex.map { case (arrayName, index) =>
val length = if (index + 1 == numArrays) {
Expand All @@ -284,7 +284,7 @@ class CodegenContext {
def initMutableStates(): String = {
// It's possible that we add the same mutable state twice, e.g. the `mergeExpressions` in
// `TypedAggregateExpression`, we should call `distinct` here to remove the duplicated ones.
val initCodes = mutableStateInitCodes.distinct
val initCodes = mutableStateInitCode.distinct

// The generated initialization code may exceed 64kb function size limit in JVM if there are too
// many mutable states, so split it into multiple functions.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
case _ => true
}.unzip
val exprVals = ctx.generateExpressions(validExpr, useSubexprElimination)

// 4-tuples: (code for projection, isNull variable name, value variable name, column index)
val projectionCodes: Seq[(String, String, String, Int)] = exprVals.zip(index).map {
case (ev, i) =>
val e = expressions(i)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -490,11 +490,11 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCa
|$v = $cal.getInstance($dtu.getTimeZone("UTC"));
|$v.setFirstDayOfWeek($cal.MONDAY);
|$v.setMinimalDaysInFirstWeek(4);
""")
""".stripMargin)
s"""
|$c.setTimeInMillis($time * 1000L * 3600L * 24L);
|${ev.value} = $c.get($cal.WEEK_OF_YEAR);
"""
""".stripMargin
})
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1165,7 +1165,7 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean)
|} else {
| $v = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance();
|}
""", forceInline = true)
""".stripMargin)

// Code to serialize.
val input = child.genCode(ctx)
Expand Down Expand Up @@ -1203,14 +1203,14 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B
// try conf from env, otherwise create a new one
val env = s"${classOf[SparkEnv].getName}.get()"
val sparkConf = s"new ${classOf[SparkConf].getName}()"
val serializer = ctx.addMutableState(serializerInstanceClass, "serializerForDecode",
v => s"""
if ($env == null) {
$v = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();
} else {
$v = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance();
}
""", forceInline = true)
val serializer = ctx.addMutableState(serializerInstanceClass, "serializerForDecode", v =>
s"""
|if ($env == null) {
| $v = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();
|} else {
| $v = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance();
|}
""".stripMargin)

// Code to deserialize.
val input = child.genCode(ctx)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -348,10 +348,10 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
test("SPARK-22704: Least and greatest use less global variables") {
val ctx1 = new CodegenContext()
Least(Seq(Literal(1), Literal(1))).genCode(ctx1)
assert(ctx1.mutableStates.size == 1)
assert(ctx1.inlinedMutableStates.size == 1)

val ctx2 = new CodegenContext()
Greatest(Seq(Literal(1), Literal(1))).genCode(ctx2)
assert(ctx2.mutableStates.size == 1)
assert(ctx2.inlinedMutableStates.size == 1)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -851,6 +851,6 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
val ctx = new CodegenContext
cast("1", IntegerType).genCode(ctx)
cast("2", LongType).genCode(ctx)
assert(ctx.mutableStates.length == 0)
assert(ctx.inlinedMutableStates.length == 0)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -385,42 +385,44 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
val ctx = new CodegenContext
val schema = new StructType().add("a", IntegerType).add("b", StringType)
CreateExternalRow(Seq(Literal(1), Literal("x")), schema).genCode(ctx)
assert(ctx.mutableStates.isEmpty)
assert(ctx.inlinedMutableStates.isEmpty)
}

test("SPARK-22696: InitializeJavaBean should not use global variables") {
val ctx = new CodegenContext
InitializeJavaBean(Literal.fromObject(new java.util.LinkedList[Int]),
Map("add" -> Literal(1))).genCode(ctx)
assert(ctx.mutableStates.isEmpty)
assert(ctx.inlinedMutableStates.isEmpty)
}

test("SPARK-22716: addReferenceObj should not add mutable states") {
val ctx = new CodegenContext
val foo = new Object()
ctx.addReferenceObj("foo", foo)
assert(ctx.mutableStates.isEmpty)
assert(ctx.inlinedMutableStates.isEmpty)
}

test("SPARK-18016: def mutable states by using an array") {
test("SPARK-18016: define mutable states by using an array") {
val ctx1 = new CodegenContext
for (i <- 1 to CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD + 10) {
ctx1.addMutableState(ctx1.JAVA_INT, "i", v => s"$v = $i;")
}
assert(ctx1.mutableStates.size == CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD)
assert(ctx1.inlinedMutableStates.size == CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD)
// When the number of primitive type mutable states is over the threshold, others are
// allocated into an array
assert(ctx1.mutableStateArrayMap.get(ctx1.JAVA_INT).get.arrayNames.size == 1)
assert(ctx1.mutableStateInitCodes.size == CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD + 10)
assert(ctx1.arrayCompactedMutableStates.get(ctx1.JAVA_INT).get.arrayNames.size == 1)
assert(ctx1.mutableStateInitCode.size == CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD + 10)

val ctx2 = new CodegenContext
for (i <- 1 to CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT + 10) {
ctx2.addMutableState("InternalRow[]", "r", v => s"$v = new InternalRow[$i];")
}
// When the number of non-primitive type mutable states is over the threshold, others are
// allocated into a new array
assert(ctx2.mutableStateArrayMap.get("InternalRow[]").get.arrayNames.size == 2)
assert(ctx2.mutableStateInitCodes.size == CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT + 10)
assert(ctx2.inlinedMutableStates.isEmpty)
assert(ctx2.arrayCompactedMutableStates.get("InternalRow[]").get.arrayNames.size == 2)
assert(ctx2.arrayCompactedMutableStates("InternalRow[]").getCurrentIndex == 10)
assert(ctx2.mutableStateInitCode.size == CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT + 10)
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,6 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
test("SPARK-22693: CreateNamedStruct should not use global variables") {
val ctx = new CodegenContext
CreateNamedStruct(Seq("a", "x", "b", 2.0)).genCode(ctx)
assert(ctx.mutableStates.isEmpty)
assert(ctx.inlinedMutableStates.isEmpty)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,6 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
test("SPARK-22705: case when should use less global variables") {
val ctx = new CodegenContext()
CaseWhen(Seq((Literal.create(false, BooleanType), Literal(1))), Literal(-1)).genCode(ctx)
assert(ctx.mutableStates.size == 1)
assert(ctx.inlinedMutableStates.size == 1)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
test("SPARK-22705: Coalesce should use less global variables") {
val ctx = new CodegenContext()
Coalesce(Seq(Literal("a"), Literal("b"))).genCode(ctx)
assert(ctx.mutableStates.size == 1)
assert(ctx.inlinedMutableStates.size == 1)
}

test("AtLeastNNonNulls should not throw 64kb exception") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
test("SPARK-22705: In should use less global variables") {
val ctx = new CodegenContext()
In(Literal(1.0D), Seq(Literal(1.0D), Literal(2.0D))).genCode(ctx)
assert(ctx.mutableStates.isEmpty)
assert(ctx.inlinedMutableStates.isEmpty)
}

test("INSET") {
Expand Down Expand Up @@ -440,6 +440,6 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
test("SPARK-22693: InSet should not use global variables") {
val ctx = new CodegenContext
InSet(Literal(1), Set(1, 2, 3, 4)).genCode(ctx)
assert(ctx.mutableStates.isEmpty)
assert(ctx.inlinedMutableStates.isEmpty)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,8 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
RegExpReplace(Literal("100"), Literal("(\\d+)"), Literal("num")).genCode(ctx)
// four global variables (lastRegex, pattern, lastReplacement, and lastReplacementInUTF8)
// are always required, which are allocated in type-based global array
assert(ctx.mutableStates.length == 0)
assert(ctx.mutableStateInitCodes.length == 4)
assert(ctx.inlinedMutableStates.length == 0)
assert(ctx.mutableStateInitCode.length == 4)
}

test("RegexExtract") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,6 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
test("SPARK-22695: ScalaUDF should not use global variables") {
val ctx = new CodegenContext
ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil).genCode(ctx)
assert(ctx.mutableStates.isEmpty)
assert(ctx.inlinedMutableStates.isEmpty)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -222,10 +222,10 @@ class GeneratedProjectionSuite extends SparkFunSuite {

test("SPARK-18016: generated projections on wider table requiring state compaction") {
val N = 6000
val wideRow1 = new GenericInternalRow((0 until N).toArray[Any])
val wideRow1 = new GenericInternalRow(new Array[Any](N))
val schema1 = StructType((1 to N).map(i => StructField("", IntegerType)))
val wideRow2 = new GenericInternalRow(
(0 until N).map(i => UTF8String.fromString(i.toString)).toArray[Any])
Array.tabulate[Any](N)(i => UTF8String.fromString(i.toString)))
val schema2 = StructType((1 to N).map(i => StructField("", StringType)))
val joined = new JoinedRow(wideRow1, wideRow2)
val joinedSchema = StructType(schema1 ++ schema2)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ class ComplexTypesSuite extends PlanTest{
test("SPARK-22570: CreateArray should not create a lot of global variables") {
val ctx = new CodegenContext
CreateArray(Seq(Literal(1))).genCode(ctx)
assert(ctx.mutableStates.length == 0)
assert(ctx.inlinedMutableStates.length == 0)
}

test("simplify map ops") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,17 +70,17 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
override protected def doProduce(ctx: CodegenContext): String = {
// PhysicalRDD always just has one input
val input = ctx.addMutableState("scala.collection.Iterator", "input",
v => s"$v = inputs[0];", forceInline = true)
v => s"$v = inputs[0];")

// metrics
val numOutputRows = metricTerm(ctx, "numOutputRows")
val scanTimeMetric = metricTerm(ctx, "scanTime")
val scanTimeTotalNs = ctx.addMutableState(ctx.JAVA_LONG, "scanTime")
val scanTimeTotalNs = ctx.addMutableState(ctx.JAVA_LONG, "scanTime") // init as scanTime = 0

val columnarBatchClz = classOf[ColumnarBatch].getName
val batch = ctx.addMutableState(columnarBatchClz, "batch")

val idx = ctx.addMutableState(ctx.JAVA_INT, "batchIdx")
val idx = ctx.addMutableState(ctx.JAVA_INT, "batchIdx") // init as batchIdx = 0
val columnVectorClzs = vectorTypes.getOrElse(
Seq.fill(output.indices.size)(classOf[ColumnVector].getName))
val (colVars, columnAssigns) = columnVectorClzs.zipWithIndex.map {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,7 @@ case class RowDataSourceScanExec(
override protected def doProduce(ctx: CodegenContext): String = {
val numOutputRows = metricTerm(ctx, "numOutputRows")
// PhysicalRDD always just has one input
val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];",
forceInline = true)
val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];")
val exprRows = output.zipWithIndex.map{ case (a, i) =>
BoundReference(i, a.dataType, a.nullable)
}
Expand Down Expand Up @@ -353,8 +352,7 @@ case class FileSourceScanExec(
}
val numOutputRows = metricTerm(ctx, "numOutputRows")
// PhysicalRDD always just has one input
val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];",
forceInline = true)
val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];")
val row = ctx.freshName("row")

ctx.INPUT_ROW = row
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,15 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport {
}

override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
val stopEarly = ctx.addMutableState(ctx.JAVA_BOOLEAN, "stopEarly")
val stopEarly = ctx.addMutableState(ctx.JAVA_BOOLEAN, "stopEarly") // init as stopEarly = false

ctx.addNewFunction("stopEarly", s"""
@Override
protected boolean stopEarly() {
return $stopEarly;
}
""", inlineToOuterClass = true)
val countTerm = ctx.addMutableState(ctx.JAVA_INT, "count")
val countTerm = ctx.addMutableState(ctx.JAVA_INT, "count") // init as count = 0
s"""
| if ($countTerm < $limit) {
| $countTerm += 1;
Expand Down

0 comments on commit a9d40e9

Please sign in to comment.