Skip to content

Commit

Permalink
address comment, refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
Davies Liu committed Jun 19, 2015
1 parent 7f5ffbe commit 4ff0457
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,11 @@ class CodeGenContext {
def isPrimitiveType(dt: DataType): Boolean = primitiveTypes.contains(dt)
}


/**
 * Base class for the classes produced by compiling generated code.
 *
 * `compile` registers this type as the extended class of the `ClassBodyEvaluator`, then
 * instantiates the compiled class and hands it back as a `GeneratedClass`, so every generated
 * class body has to supply an implementation of `generate`.
 */
abstract class GeneratedClass {
/**
 * Builds the object implemented by the generated code.
 *
 * @param expressions expressions passed to the generated code; the generators in this package
 *                    pass `ctx.references.toArray`
 * @return the generated object, typed as `Any`; callers cast it to the type they expect
 *         (for example `Predicate`, `Projection` or `BaseOrdering`)
 */
def generate(expressions: Array[Expression]): Any
}

/**
* A base class for generators of byte code to perform expression evaluation. Includes a set of
* helpers for referring to Catalyst types and building trees that perform evaluation of individual
Expand Down Expand Up @@ -234,26 +239,23 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
*
* It will track the time used to compile
*/
protected def compile(code: String): Class[_] = {
protected def compile(code: String): GeneratedClass = {
val startTime = System.nanoTime()
// The current context class loader may be an ExecutorClassLoader, which can fail to find
// classes in java.lang and is also slow, so use the default JVM class loader instead.
val currentThread = Thread.currentThread()
val oldClassLoader = currentThread.getContextClassLoader
currentThread.setContextClassLoader(getClass.getClassLoader)
val clazz = try {
new ClassBodyEvaluator(code).getClazz()
val evaluator = new ClassBodyEvaluator()
evaluator.setParentClassLoader(getClass.getClassLoader)
evaluator.setDefaultImports(Array("org.apache.spark.sql.catalyst.InternalRow"))
evaluator.setExtendedClass(classOf[GeneratedClass])
try {
evaluator.cook(code)
} catch {
case e: Exception =>
logError(s"failed to compile:\n $code", e)
throw e
} finally {
currentThread.setContextClassLoader(oldClassLoader)
}
val endTime = System.nanoTime()
def timeMs: Double = (endTime - startTime).toDouble / 1000000
logDebug(s"Code (${code.size} bytes) compiled in $timeMs ms")
clazz
evaluator.getClazz().newInstance().asInstanceOf[GeneratedClass]
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
"""
}.mkString("\n")
val code = s"""
import org.apache.spark.sql.catalyst.InternalRow;

public SpecificProjection generate($exprType[] expr) {
public Object generate($exprType[] expr) {
return new SpecificProjection(expr);
}

Expand Down Expand Up @@ -85,10 +83,8 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
logDebug(s"code for ${expressions.mkString(",")}:\n$code")

val c = compile(code)
// fetch the only one method `generate(Expression[])`
val m = c.getDeclaredMethods()(0)
() => {
m.invoke(c.newInstance(), ctx.references.toArray).asInstanceOf[BaseMutableProjection]
c.generate(ctx.references.toArray).asInstanceOf[MutableProjection]
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,6 @@ object GenerateOrdering
}.mkString("\n")

val code = s"""
import org.apache.spark.sql.catalyst.InternalRow;

public SpecificOrdering generate($exprType[] expr) {
return new SpecificOrdering(expr);
}
Expand All @@ -100,9 +98,6 @@ object GenerateOrdering

logDebug(s"Generated Ordering: $code")

val c = compile(code)
// fetch the only one method `generate(Expression[])`
val m = c.getDeclaredMethods()(0)
m.invoke(c.newInstance(), ctx.references.toArray).asInstanceOf[BaseOrdering]
compile(code).generate(ctx.references.toArray).asInstanceOf[BaseOrdering]
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

package org.apache.spark.sql.catalyst.expressions.codegen

import org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.expressions._

/**
Expand All @@ -41,8 +40,6 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool
val ctx = newCodeGenContext()
val eval = predicate.gen(ctx)
val code = s"""
import org.apache.spark.sql.catalyst.InternalRow;

public SpecificPredicate generate($exprType[] expr) {
return new SpecificPredicate(expr);
}
Expand All @@ -62,10 +59,7 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool

logDebug(s"Generated predicate '$predicate':\n$code")

val c = compile(code)
// fetch the only one method `generate(Expression[])`
val m = c.getDeclaredMethods()(0)
val p = m.invoke(c.newInstance(), ctx.references.toArray).asInstanceOf[Predicate]
val p = compile(code).generate(ctx.references.toArray).asInstanceOf[Predicate]
(r: InternalRow) => p.eval(r)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,6 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
}.mkString("\n")

val code = s"""
import org.apache.spark.sql.catalyst.InternalRow;

public SpecificProjection generate($exprType[] expr) {
return new SpecificProjection(expr);
}
Expand Down Expand Up @@ -220,9 +218,6 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {

logDebug(s"MutableRow, initExprs: ${expressions.mkString(",")} code:\n${code}")

val c = compile(code)
// fetch the only one method `generate(Expression[])`
val m = c.getDeclaredMethods()(0)
m.invoke(c.newInstance(), ctx.references.toArray).asInstanceOf[Projection]
compile(code).generate(ctx.references.toArray).asInstanceOf[Projection]
}
}

0 comments on commit 4ff0457

Please sign in to comment.