Merge pull request #13 from tnachen/master
Add GenericUDAF wrapper and HiveUDAFFunction
marmbrus committed Feb 2, 2014
2 parents 5b7afd8 + 63003e9 commit 2de89d0
Showing 4 changed files with 126 additions and 25 deletions.
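
The heart of the change is a new branch in lookupFunction that recognizes Hive UDAFs by their registered class. A minimal standalone sketch of that class check, using only Hive's own API (the function name "percentile_approx" is just an illustrative choice):

import org.apache.hadoop.hive.ql.exec.FunctionRegistry
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver

// Hive registers percentile_approx with GenericUDAFPercentileApprox, which extends
// AbstractGenericUDAFResolver, so the new branch would resolve it to HiveGenericUdaf.
val info = FunctionRegistry.getFunctionInfo("percentile_approx")
val isUdaf = classOf[AbstractGenericUDAFResolver].isAssignableFrom(info.getFunctionClass)
// isUdaf == true; lookupFunction returns HiveGenericUdaf("percentile_approx", children)
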
103 changes: 83 additions & 20 deletions src/main/scala/catalyst/execution/FunctionRegistry.scala
@@ -4,26 +4,29 @@ package execution
import scala.collection.JavaConversions._

import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry}
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF
import org.apache.hadoop.hive.ql.udf.generic.{GenericUDAFEvaluator, AbstractGenericUDAFResolver, GenericUDF}
import org.apache.hadoop.hive.ql.exec.UDF
import org.apache.hadoop.hive.serde2.{io => hiveIo}
import org.apache.hadoop.hive.serde2.objectinspector.primitive.AbstractPrimitiveJavaObjectInspector
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
import org.apache.hadoop.hive.serde2.objectinspector.primitive._
import org.apache.hadoop.{io => hadoopIo}

import expressions._
import types._
import org.apache.hadoop.hive.serde2.objectinspector.{ListObjectInspector, StructObjectInspector, ObjectInspector}
import catalyst.types.StructField
import catalyst.types.StructType
import catalyst.types.ArrayType
import catalyst.expressions.Cast

object HiveFunctionRegistry extends analysis.FunctionRegistry {
object HiveFunctionRegistry extends analysis.FunctionRegistry with HiveFunctionFactory {
def lookupFunction(name: String, children: Seq[Expression]): Expression = {
// We only look it up to see if it exists, but do not include it in the HiveUDF since it is
// not always serializable.
val functionInfo: FunctionInfo = Option(FunctionRegistry.getFunctionInfo(name)).getOrElse(
sys.error(s"Couldn't find function $name"))

if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) {
val functionInfo = FunctionRegistry.getFunctionInfo(name)
val function = functionInfo.getFunctionClass.newInstance.asInstanceOf[UDF]
val function = createFunction[UDF](name)
val method = function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo))

lazy val expectedDataTypes = method.getParameterTypes.map(javaClassToDataType)
@@ -34,6 +37,8 @@ object HiveFunctionRegistry extends analysis.FunctionRegistry {
)
} else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) {
HiveGenericUdf(name, IntegerType, children)
} else if (classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) {
HiveGenericUdaf(name, children)
} else {
sys.error(s"No handler for udf ${functionInfo.getFunctionClass}")
}
@@ -67,20 +72,10 @@ object HiveFunctionRegistry extends analysis.FunctionRegistry {
}
}

abstract class HiveUdf extends Expression with ImplementedUdf with Logging {
self: Product =>

type UDFType
val name: String

def nullable = true
def references = children.flatMap(_.references).toSet

// FunctionInfo is not serializable so we must look it up here again.
lazy val functionInfo = FunctionRegistry.getFunctionInfo(name)
lazy val function = functionInfo.getFunctionClass.newInstance.asInstanceOf[UDFType]

override def toString = s"${nodeName}#${functionInfo.getDisplayName}(${children.mkString(",")})"
trait HiveFunctionFactory {
def getFunctionInfo(name: String) = FunctionRegistry.getFunctionInfo(name)
def getFunctionClass(name: String) = getFunctionInfo(name).getFunctionClass
def createFunction[UDFType](name: String) = getFunctionClass(name).newInstance.asInstanceOf[UDFType]

def unwrap(a: Any): Any = a match {
case null => null
@@ -93,6 +88,7 @@ abstract class HiveUdf extends Expression with ImplementedUdf with Logging {
case b: hadoopIo.BooleanWritable => b.get()
case b: hiveIo.ByteWritable => b.get
case list: java.util.List[_] => list.map(unwrap)
case array: Array[_] => array.map(unwrap)
case p: java.lang.Short => p
case p: java.lang.Long => p
case p: java.lang.Float => p
@@ -104,6 +100,22 @@ abstract class HiveUdf extends Expression with ImplementedUdf with Logging {
}
}

abstract class HiveUdf extends Expression with ImplementedUdf with Logging with HiveFunctionFactory {
self: Product =>

type UDFType
val name: String

def nullable = true
def references = children.flatMap(_.references).toSet

// FunctionInfo is not serializable so we must look it up here again.
lazy val functionInfo = getFunctionInfo(name)
lazy val function = createFunction[UDFType](name)

override def toString = s"${nodeName}#${functionInfo.getDisplayName}(${children.mkString(",")})"
}

case class HiveSimpleUdf(name: String, children: Seq[Expression]) extends HiveUdf {
import HiveFunctionRegistry._
type UDFType = UDF
@@ -194,3 +206,54 @@ case class HiveGenericUdf(
unwrap(instance.evaluate(args))
}
}

trait HiveInspectors {
def toInspectors(exprs: Seq[Expression]) = exprs.map(_.dataType).map {
case StringType => PrimitiveObjectInspectorFactory.javaStringObjectInspector
case IntegerType => PrimitiveObjectInspectorFactory.javaIntObjectInspector
case DoubleType => PrimitiveObjectInspectorFactory.javaDoubleObjectInspector
case BooleanType => PrimitiveObjectInspectorFactory.javaBooleanObjectInspector
case LongType => PrimitiveObjectInspectorFactory.javaLongObjectInspector
case ShortType => PrimitiveObjectInspectorFactory.javaShortObjectInspector
case ByteType => PrimitiveObjectInspectorFactory.javaByteObjectInspector
}

def inspectorToDataType(inspector: ObjectInspector): DataType = inspector match {
case s: StructObjectInspector =>
StructType(s.getAllStructFieldRefs.map(f => {
StructField(f.getFieldName, inspectorToDataType(f.getFieldObjectInspector), true)
}))
case l: ListObjectInspector => ArrayType(inspectorToDataType(l.getListElementObjectInspector))
case _: WritableStringObjectInspector => StringType
case _: WritableIntObjectInspector => IntegerType
case _: WritableDoubleObjectInspector => DoubleType
case _: WritableBooleanObjectInspector => BooleanType
case _: WritableLongObjectInspector => LongType
case _: WritableShortObjectInspector => ShortType
case _: WritableByteObjectInspector => ByteType
}
}

case class HiveGenericUdaf(
name: String,
children: Seq[Expression]) extends AggregateExpression
with HiveInspectors
with HiveFunctionFactory {

type UDFType = AbstractGenericUDAFResolver

lazy val resolver = createFunction[AbstractGenericUDAFResolver](name)

lazy val objectInspector: ObjectInspector = {
resolver.getEvaluator(children.map(_.dataType.toTypeInfo).toArray)
.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray)
}

lazy val inspectors: Seq[ObjectInspector] = toInspectors(children)

def dataType: DataType = inspectorToDataType(objectInspector)

def nullable: Boolean = true

def references: Set[Attribute] = children.map(_.references).flatten.toSet
}
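
For reference, a minimal sketch of the resolver handshake that HiveGenericUdaf performs above, assuming Hive's built-in GenericUDAFSum and a single DOUBLE argument: obtain an evaluator for the argument TypeInfos, initialize it in COMPLETE mode with the input ObjectInspectors, and read back the output ObjectInspector from which dataType is derived.

import org.apache.hadoop.hive.ql.udf.generic.{GenericUDAFEvaluator, GenericUDAFSum}
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, TypeInfoFactory}

// Stand-in for createFunction[AbstractGenericUDAFResolver]("sum").
val resolver = new GenericUDAFSum

// What toInspectors would produce for a single DoubleType child expression.
val argTypes: Array[TypeInfo] = Array(TypeInfoFactory.doubleTypeInfo)
val argInspectors: Array[ObjectInspector] =
  Array(PrimitiveObjectInspectorFactory.javaDoubleObjectInspector)

// init returns the inspector describing the UDAF's result; inspectorToDataType
// maps it to a Catalyst DataType (here a writable double, i.e. DoubleType).
val evaluator = resolver.getEvaluator(argTypes)
val resultInspector = evaluator.init(GenericUDAFEvaluator.Mode.COMPLETE, argInspectors)
println(resultInspector.getTypeName) // "double"
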
8 changes: 6 additions & 2 deletions src/main/scala/catalyst/execution/TestShark.scala
@@ -86,8 +86,12 @@ object TestShark extends SharkInstance {
* hive test cases assume the system is set up.
*/
private def rewritePaths(cmd: String): String =
if (cmd.toUpperCase contains "LOAD DATA") {
cmd.replaceAll("\\.\\.", TestShark.inRepoTests.getCanonicalPath)
if (cmd.toUpperCase.contains("LOAD DATA") && cmd.contains("..")) {
"[\"\'](../.*)[\"\'] ".r.findFirstMatchIn(cmd)
.map(r => {
val newPath = new File(TestShark.inRepoTests.getCanonicalPath, cmd.substring(r.start + 1, r.end - 2).replaceFirst("(\\.\\./)+", "")).getAbsolutePath
cmd.substring(0, r.start + 1) + newPath + cmd.substring(r.end - 2)
}).getOrElse(cmd)
} else {
cmd
}
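
To make the rewrite concrete, a self-contained sketch of the same regex logic applied to a sample command; the repository location and data file are hypothetical:

import java.io.File

val inRepoTests = new File("/workspace/catalyst/src/test/hive") // hypothetical checkout location
val cmd = "LOAD DATA LOCAL INPATH '../data/files/kv1.txt' INTO TABLE src"

val rewritten = "[\"\'](../.*)[\"\'] ".r.findFirstMatchIn(cmd).map { r =>
  // Strip the leading ../ segments and re-root the path under the in-repo test data.
  val newPath = new File(inRepoTests.getCanonicalPath,
    cmd.substring(r.start + 1, r.end - 2).replaceFirst("(\\.\\./)+", "")).getAbsolutePath
  cmd.substring(0, r.start + 1) + newPath + cmd.substring(r.end - 2)
}.getOrElse(cmd)
// rewritten: LOAD DATA LOCAL INPATH '/workspace/catalyst/src/test/hive/data/files/kv1.txt' INTO TABLE src
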
38 changes: 35 additions & 3 deletions src/main/scala/catalyst/execution/aggregates.scala
@@ -1,13 +1,14 @@
package catalyst
package execution

import org.apache.hadoop.hive.ql.udf.generic.{GenericUDAFEvaluator, AbstractGenericUDAFResolver}

import catalyst.errors._
import catalyst.expressions._
import catalyst.plans.physical.{ClusteredDistribution, AllTuples}
import org.apache.spark.rdd.SharkPairRDDFunctions

/* Implicits */
import SharkPairRDDFunctions._
import org.apache.spark.rdd.SharkPairRDDFunctions._

case class Aggregate(
groupingExpressions: Seq[Expression],
@@ -23,6 +24,36 @@ case class Aggregate(
}

override def otherCopyArgs = sc :: Nil

case class HiveUdafFunction(
exprs: Seq[Expression],
base: AggregateExpression,
functionName: String)
extends AggregateFunction
with HiveInspectors
with HiveFunctionFactory {

def this() = this(null, null, null)

val resolver = createFunction[AbstractGenericUDAFResolver](functionName)

val function = {
val evaluator = resolver.getEvaluator(exprs.map(_.dataType.toTypeInfo).toArray)
evaluator.init(GenericUDAFEvaluator.Mode.COMPLETE, toInspectors(exprs).toArray)
evaluator
}

// Cast required to avoid type inference selecting a deprecated Hive API.
val buffer = function.getNewAggregationBuffer.asInstanceOf[GenericUDAFEvaluator.AbstractAggregationBuffer]

def result: Any = unwrap(function.evaluate(buffer))

def apply(input: Seq[Row]): Unit = {
val inputs = exprs.map(Evaluate(_, input).asInstanceOf[AnyRef]).toArray
function.iterate(buffer, inputs)
}
}

def output = aggregateExpressions.map(_.toAttribute)

/* Replace all aggregate expressions with spark functions that will compute the result. */
@@ -34,6 +65,7 @@ case class Aggregate(
// TODO: Create custom query plan node that calculates distinct values efficiently.
case base @ CountDistinct(expr) => new CountDistinctFunction(expr, base)
case base @ First(expr) => new FirstFunction(expr, base)
case base @ HiveGenericUdaf(resolver, expr) => new HiveUdafFunction(expr, base, resolver)
}

val remainingAttributes = impl.collect { case a: Attribute => a }
@@ -159,4 +191,4 @@ case class FirstFunction(expr: Expression, base: AggregateExpression) extends Ag
result = Evaluate(expr, input)
}
}
}
}
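
A companion sketch, under the same assumptions as the earlier GenericUDAFSum example, of the buffer lifecycle that HiveUdafFunction drives above: one aggregation buffer per group, one iterate call per input row, and a final read of the result (terminate here plays the same role as the evaluate call in result):

import org.apache.hadoop.hive.ql.udf.generic.{GenericUDAFEvaluator, GenericUDAFSum}
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, TypeInfoFactory}

val resolver = new GenericUDAFSum
val evaluator = resolver.getEvaluator(Array[TypeInfo](TypeInfoFactory.doubleTypeInfo))
evaluator.init(GenericUDAFEvaluator.Mode.COMPLETE,
  Array[ObjectInspector](PrimitiveObjectInspectorFactory.javaDoubleObjectInspector))

// One buffer per group, as HiveUdafFunction allocates above.
val buffer = evaluator.getNewAggregationBuffer

// apply(input) evaluates the child expressions for each row and feeds them to iterate.
Seq(1.0, 2.5, 4.0).foreach { v =>
  evaluator.iterate(buffer, Array[AnyRef](Double.box(v)))
}

// The final aggregate comes back as a Hive writable, which unwrap turns into a Scala value.
val total = evaluator.terminate(buffer) // DoubleWritable holding 7.5
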
2 changes: 2 additions & 0 deletions src/main/scala/catalyst/frontend/Hive.scala
@@ -705,6 +705,8 @@ object HiveQl {
/* Casts */
case Token("TOK_FUNCTION", Token("TOK_STRING", Nil) :: arg :: Nil) =>
Cast(nodeToExpr(arg), StringType)
case Token("TOK_FUNCTION", Token("TOK_VARCHAR", Nil) :: arg :: Nil) =>
Cast(nodeToExpr(arg), StringType)
case Token("TOK_FUNCTION", Token("TOK_INT", Nil) :: arg :: Nil) =>
Cast(nodeToExpr(arg), IntegerType)
case Token("TOK_FUNCTION", Token("TOK_BIGINT", Nil) :: arg :: Nil) =>
