Skip to content

Commit

Permalink
consts fold
Browse files Browse the repository at this point in the history
  • Loading branch information
Davies Liu committed Jun 5, 2015
1 parent 86fac2c commit e03edaa
Show file tree
Hide file tree
Showing 7 changed files with 33 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
if (!${ev.nullTerm}) {
${eval2.code}
if(!${eval2.nullTerm}) {
${ev.primitiveTerm} = (${ctx.primitiveType(dataType)})($resultCode);
${ev.primitiveTerm} = $resultCode;
} else {
${ev.nullTerm} = true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,15 @@ abstract class BinaryArithmetic extends BinaryExpression {
}
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
if (left.dataType.isInstanceOf[DecimalType]) {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = dataType match {
case dt: DecimalType =>
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)")
} else {
// byte and short are casted into int when add, minus, times or divide
case ByteType | ShortType =>
defineCodeGen(ctx, ev, (eval1, eval2) =>
s"(${ctx.primitiveType(dataType)})($eval1 $symbol $eval2)")
case _ =>
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2")
}
}

protected def evalInternal(evalE1: Any, evalE2: Any): Any =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long]
* @param primitiveTerm A term for a possible primitive value of the result of the evaluation. Not
* valid if `nullTerm` is set to `true`.
*/
case class GeneratedExpressionCode(var code: Code, var nullTerm: Term, primitiveTerm: Term)
case class GeneratedExpressionCode(var code: Code, var nullTerm: Term, var primitiveTerm: Term)

/**
* A context for codegen, which is used to bookkeeping the expressions those are not supported
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,23 +82,25 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres
override def eval(input: Row): Any = value

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
// change the nullTerm and primitiveTerm to consts, to inline them
if (value == null) {
s"""
final boolean ${ev.nullTerm} = true;
final ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)};
"""
ev.nullTerm = "true"
ev.primitiveTerm = ctx.defaultValue(dataType)
""
} else {
dataType match {
case BooleanType =>
ev.nullTerm = "false"
ev.primitiveTerm = value.toString
""
case FloatType => // This must go before NumericType
s"""
final boolean ${ev.nullTerm} = false;
final float ${ev.primitiveTerm} = ${value}f;
"""
ev.nullTerm = "false"
ev.primitiveTerm = s"${value}f"
""
case dt: NumericType if !dt.isInstanceOf[DecimalType] =>
s"""
final boolean ${ev.nullTerm} = false;
final ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = $value;
"""
ev.nullTerm = "false"
ev.primitiveTerm = value.toString
""
// eval() version may be faster for non-primitive types
case other =>
super.genCode(ctx, ev)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,9 @@ case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expr

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
val eval = child.gen(ctx)
eval.code + s"""
final boolean ${ev.nullTerm} = false;
final boolean ${ev.primitiveTerm} = ${eval.nullTerm};
"""
ev.nullTerm = "false"
ev.primitiveTerm = eval.nullTerm
eval.code
}

override def toString: String = s"IS NULL $child"
Expand All @@ -103,10 +102,9 @@ case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[E

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
val eval = child.gen(ctx)
eval.code + s"""
boolean ${ev.nullTerm} = false;
boolean ${ev.primitiveTerm} = !${eval.nullTerm};
"""
ev.nullTerm = "false"
ev.primitiveTerm = s"(!(${eval.nullTerm}))"
eval.code
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -304,8 +304,8 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
val equalCode = ctx.equalFunc(left.dataType)(eval1.primitiveTerm, eval2.primitiveTerm)
ev.nullTerm = "false"
eval1.code + eval2.code + s"""
final boolean ${ev.nullTerm} = false;
final boolean ${ev.primitiveTerm} = (${eval1.nullTerm} && ${eval2.nullTerm}) ||
(!${eval1.nullTerm} && $equalCode);
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ case class NewSet(elementType: DataType) extends LeafExpression {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
elementType match {
case IntegerType | LongType =>
ev.nullTerm = "false"
s"""
boolean ${ev.nullTerm} = false;
${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = new ${ctx.primitiveType(dataType)}();
"""
case _ => super.genCode(ctx, ev)
Expand Down Expand Up @@ -111,11 +111,11 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression {
val setEval = set.gen(ctx)
val htype = ctx.primitiveType(dataType)

ev.nullTerm = "false"
itemEval.code + setEval.code + s"""
if (!${itemEval.nullTerm} && !${setEval.nullTerm}) {
(($htype)${setEval.primitiveTerm}).add(${itemEval.primitiveTerm});
}
boolean ${ev.nullTerm} = false;
${htype} ${ev.primitiveTerm} = ($htype)${setEval.primitiveTerm};
"""
case _ => super.genCode(ctx, ev)
Expand Down Expand Up @@ -164,8 +164,8 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres
val rightEval = right.gen(ctx)
val htype = ctx.primitiveType(dataType)

ev.nullTerm = "false"
leftEval.code + rightEval.code + s"""
boolean ${ev.nullTerm} = false;
${htype} ${ev.primitiveTerm} = (${htype})${leftEval.primitiveTerm};
${ev.primitiveTerm}.union((${htype})${rightEval.primitiveTerm});
"""
Expand Down

0 comments on commit e03edaa

Please sign in to comment.