From e079f2b32d3391bdfe835ca66dde7eaedf5df5c0 Mon Sep 17 00:00:00 2001 From: Timothy Chen Date: Wed, 15 Jan 2014 22:53:00 -0800 Subject: [PATCH 1/4] Add GenericUDAF wrapper and HiveUDAFFunction --- .../catalyst/execution/FunctionRegistry.scala | 105 ++++++++++++++---- .../scala/catalyst/execution/TestShark.scala | 8 +- .../scala/catalyst/execution/aggregates.scala | 31 ++++++ src/main/scala/catalyst/frontend/Hive.scala | 13 +++ 4 files changed, 135 insertions(+), 22 deletions(-) diff --git a/src/main/scala/catalyst/execution/FunctionRegistry.scala b/src/main/scala/catalyst/execution/FunctionRegistry.scala index 5ebc30fc982df..8b49366fcf46b 100644 --- a/src/main/scala/catalyst/execution/FunctionRegistry.scala +++ b/src/main/scala/catalyst/execution/FunctionRegistry.scala @@ -4,17 +4,21 @@ package execution import scala.collection.JavaConversions._ import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry} -import org.apache.hadoop.hive.ql.udf.generic.GenericUDF +import org.apache.hadoop.hive.ql.udf.generic.{GenericUDAFEvaluator, AbstractGenericUDAFResolver, GenericUDF} import org.apache.hadoop.hive.ql.exec.UDF import org.apache.hadoop.hive.serde2.{io => hiveIo} -import org.apache.hadoop.hive.serde2.objectinspector.primitive.AbstractPrimitiveJavaObjectInspector -import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory +import org.apache.hadoop.hive.serde2.objectinspector.primitive._ import org.apache.hadoop.{io => hadoopIo} import expressions._ import types._ +import org.apache.hadoop.hive.serde2.objectinspector.{ListObjectInspector, StructObjectInspector, ObjectInspector} +import catalyst.types.StructField +import catalyst.types.StructType +import catalyst.types.ArrayType +import catalyst.expressions.Cast -object HiveFunctionRegistry extends analysis.FunctionRegistry { +object HiveFunctionRegistry extends analysis.FunctionRegistry with HiveFunctionFactory { def lookupFunction(name: String, children: Seq[Expression]): Expression = { // We only look it up to see if it exists, but do not include it in the HiveUDF since it is // not always serializable. @@ -22,8 +26,7 @@ object HiveFunctionRegistry extends analysis.FunctionRegistry { sys.error(s"Couldn't find function $name")) if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) { - val functionInfo = FunctionRegistry.getFunctionInfo(name) - val function = functionInfo.getFunctionClass.newInstance.asInstanceOf[UDF] + val function = createFunction[UDF](name) val method = function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo)) lazy val expectedDataTypes = method.getParameterTypes.map(javaClassToDataType) @@ -34,6 +37,8 @@ object HiveFunctionRegistry extends analysis.FunctionRegistry { ) } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) { HiveGenericUdf(name, IntegerType, children) + } else if (classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) { + HiveGenericUdaf(name, children) } else { sys.error(s"No handler for udf ${functionInfo.getFunctionClass}") } @@ -67,20 +72,10 @@ object HiveFunctionRegistry extends analysis.FunctionRegistry { } } -abstract class HiveUdf extends Expression with ImplementedUdf with Logging { - self: Product => - - type UDFType - val name: String - - def nullable = true - def references = children.flatMap(_.references).toSet - - // FunctionInfo is not serializable so we must look it up here again. 
- lazy val functionInfo = FunctionRegistry.getFunctionInfo(name) - lazy val function = functionInfo.getFunctionClass.newInstance.asInstanceOf[UDFType] - - override def toString = s"${nodeName}#${functionInfo.getDisplayName}(${children.mkString(",")})" +trait HiveFunctionFactory { + def getFunctionInfo(name: String) = FunctionRegistry.getFunctionInfo(name) + def getFunctionClass(name: String) = getFunctionInfo(name).getFunctionClass + def createFunction[UDFType](name: String) = getFunctionClass(name).newInstance.asInstanceOf[UDFType] def unwrap(a: Any): Any = a match { case null => null @@ -93,6 +88,7 @@ abstract class HiveUdf extends Expression with ImplementedUdf with Logging { case b: hadoopIo.BooleanWritable => b.get() case b: hiveIo.ByteWritable => b.get case list: java.util.List[_] => list.map(unwrap) + case array: Array[_] => array.map(unwrap) case p: java.lang.Short => p case p: java.lang.Long => p case p: java.lang.Float => p @@ -104,6 +100,24 @@ abstract class HiveUdf extends Expression with ImplementedUdf with Logging { } } +abstract class HiveUdf extends Expression with ImplementedUdf with Logging with HiveFunctionFactory { + self: Product => + + type UDFType + val name: String + + def nullable = true + def references = children.flatMap(_.references).toSet + + // FunctionInfo is not serializable so we must look it up here again. + lazy val functionInfo = getFunctionInfo(name) + lazy val function = createFunction[UDFType](name) + + override def toString = s"${nodeName}#${functionInfo.getDisplayName}(${children.mkString(",")})" + + +} + case class HiveSimpleUdf(name: String, children: Seq[Expression]) extends HiveUdf { import HiveFunctionRegistry._ type UDFType = UDF @@ -194,3 +208,54 @@ case class HiveGenericUdf( unwrap(instance.evaluate(args)) } } + +trait HiveInspectors { + def toInspectors(exprs: Seq[Expression]) = exprs.map(_.dataType).map { + case StringType => PrimitiveObjectInspectorFactory.javaStringObjectInspector + case IntegerType => PrimitiveObjectInspectorFactory.javaIntObjectInspector + case DoubleType => PrimitiveObjectInspectorFactory.javaDoubleObjectInspector + case BooleanType => PrimitiveObjectInspectorFactory.javaBooleanObjectInspector + case LongType => PrimitiveObjectInspectorFactory.javaLongObjectInspector + case ShortType => PrimitiveObjectInspectorFactory.javaShortObjectInspector + case ByteType => PrimitiveObjectInspectorFactory.javaByteObjectInspector + } + + def inspectorToDataType(inspector: ObjectInspector): DataType = inspector match { + case s: StructObjectInspector => + StructType(s.getAllStructFieldRefs.map(f => { + StructField(f.getFieldName, inspectorToDataType(f.getFieldObjectInspector), true) + })) + case l: ListObjectInspector => ArrayType(inspectorToDataType(l.getListElementObjectInspector)) + case _: WritableStringObjectInspector => StringType + case _: WritableIntObjectInspector => IntegerType + case _: WritableDoubleObjectInspector => DoubleType + case _: WritableBooleanObjectInspector => BooleanType + case _: WritableLongObjectInspector => LongType + case _: WritableShortObjectInspector => ShortType + case _: WritableByteObjectInspector => ByteType + } +} + +case class HiveGenericUdaf( + name: String, + children: Seq[Expression]) extends AggregateExpression + with HiveInspectors + with HiveFunctionFactory { + + lazy val resolver = createFunction[AbstractGenericUDAFResolver](name) + + lazy val objectInspector: ObjectInspector = { + resolver.getEvaluator(children.map(_.dataType.toTypeInfo).toArray) + 
.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray) + } + + type UDFType = AbstractGenericUDAFResolver + + lazy val inspectors: Seq[ObjectInspector] = toInspectors(children) + + def dataType: DataType = inspectorToDataType(objectInspector) + + def nullable: Boolean = true + + def references: Set[Attribute] = children.map(_.references).flatten.toSet +} \ No newline at end of file diff --git a/src/main/scala/catalyst/execution/TestShark.scala b/src/main/scala/catalyst/execution/TestShark.scala index cf73abe135113..bf8ef9b4c3865 100644 --- a/src/main/scala/catalyst/execution/TestShark.scala +++ b/src/main/scala/catalyst/execution/TestShark.scala @@ -86,8 +86,12 @@ object TestShark extends SharkInstance { * hive test cases assume the system is set up. */ private def rewritePaths(cmd: String): String = - if (cmd.toUpperCase contains "LOAD DATA") { - cmd.replaceAll("\\.\\.", TestShark.inRepoTests.getCanonicalPath) + if (cmd.toUpperCase.contains("LOAD DATA") && cmd.contains("..")) { + "[\"\']([\\./\\w]+)[\"\'] ".r.findFirstMatchIn(cmd) + .map(r => { + val newPath = new File(TestShark.inRepoTests.getCanonicalPath, cmd.substring(r.start + 1, r.end - 2).replaceFirst("(\\.\\./)+", "")).getAbsolutePath + cmd.substring(0, r.start + 1) + newPath + cmd.substring(r.end - 2) + }).getOrElse(cmd) } else { cmd } diff --git a/src/main/scala/catalyst/execution/aggregates.scala b/src/main/scala/catalyst/execution/aggregates.scala index c15414cf002f9..2b01c6eccb719 100644 --- a/src/main/scala/catalyst/execution/aggregates.scala +++ b/src/main/scala/catalyst/execution/aggregates.scala @@ -3,6 +3,7 @@ package execution import catalyst.errors._ import catalyst.expressions._ +import org.apache.hadoop.hive.ql.udf.generic.{GenericUDAFEvaluator, AbstractGenericUDAFResolver} /* Implicits */ import org.apache.spark.SparkContext._ @@ -87,6 +88,35 @@ case class Aggregate( } override def otherCopyArgs = sc :: Nil + + case class HiveUdafFunction( + exprs: Seq[Expression], + base: AggregateExpression, + functionName: String) + extends AggregateFunction + with HiveInspectors + with HiveFunctionFactory { + + def this() = this(null, null, null) + + val resolver = createFunction[AbstractGenericUDAFResolver](functionName) + + val function = { + val evaluator = resolver.getEvaluator(exprs.map(_.dataType.toTypeInfo).toArray) + evaluator.init(GenericUDAFEvaluator.Mode.COMPLETE, toInspectors(exprs).toArray) + evaluator + } + + val buffer = function.getNewAggregationBuffer + + def result: Any = unwrap(function.evaluate(buffer)) + + def apply(input: Seq[Row]): Unit = { + val inputs = exprs.map(Evaluate(_, input).asInstanceOf[AnyRef]).toArray + function.iterate(buffer, inputs) + } + } + def output = aggregateExpressions.map(_.toAttribute) def createAggregateImplementations() = aggregateExpressions.map { agg => @@ -97,6 +127,7 @@ case class Aggregate( // TODO: Create custom query plan node that calculates distinct values efficiently. 
case base @ CountDistinct(expr) => new CountDistinctFunction(expr, base) case base @ First(expr) => new FirstFunction(expr, base) + case base @ HiveGenericUdaf(resolver, expr) => new HiveUdafFunction(expr, base, resolver) } val remainingAttributes = impl.collect { case a: Attribute => a } diff --git a/src/main/scala/catalyst/frontend/Hive.scala b/src/main/scala/catalyst/frontend/Hive.scala index b0af22caa8eb2..7d3f1c654c996 100644 --- a/src/main/scala/catalyst/frontend/Hive.scala +++ b/src/main/scala/catalyst/frontend/Hive.scala @@ -703,6 +703,7 @@ object HiveQl { case Token("TOK_FUNCTION", Token(SUM(), Nil) :: arg :: Nil) => Sum(nodeToExpr(arg)) /* Casts */ +<<<<<<< HEAD case Token("TOK_FUNCTION", Token("TOK_STRING", Nil) :: arg :: Nil) => Cast(nodeToExpr(arg), StringType) case Token("TOK_FUNCTION", Token("TOK_INT", Nil) :: arg :: Nil) => @@ -721,6 +722,18 @@ object HiveQl { Cast(nodeToExpr(arg), BinaryType) case Token("TOK_FUNCTION", Token("TOK_BOOLEAN", Nil) :: arg :: Nil) => Cast(nodeToExpr(arg), BooleanType) +======= + case Token("TOK_FUNCTION", Token("TOK_STRING", Nil) :: arg :: Nil) => Cast(nodeToExpr(arg), StringType) + case Token("TOK_FUNCTION", Token("TOK_VARCHAR", Nil) :: arg :: Nil) => Cast(nodeToExpr(arg), StringType) + case Token("TOK_FUNCTION", Token("TOK_INT", Nil) :: arg :: Nil) => Cast(nodeToExpr(arg), IntegerType) + case Token("TOK_FUNCTION", Token("TOK_BIGINT", Nil) :: arg :: Nil) => Cast(nodeToExpr(arg), LongType) + case Token("TOK_FUNCTION", Token("TOK_FLOAT", Nil) :: arg :: Nil) => Cast(nodeToExpr(arg), FloatType) + case Token("TOK_FUNCTION", Token("TOK_DOUBLE", Nil) :: arg :: Nil) => Cast(nodeToExpr(arg), DoubleType) + case Token("TOK_FUNCTION", Token("TOK_SMALLINT", Nil) :: arg :: Nil) => Cast(nodeToExpr(arg), ShortType) + case Token("TOK_FUNCTION", Token("TOK_TINYINT", Nil) :: arg :: Nil) => Cast(nodeToExpr(arg), ByteType) + case Token("TOK_FUNCTION", Token("TOK_BINARY", Nil) :: arg :: Nil) => Cast(nodeToExpr(arg), BinaryType) + case Token("TOK_FUNCTION", Token("TOK_BOOLEAN", Nil) :: arg :: Nil) => Cast(nodeToExpr(arg), BooleanType) +>>>>>>> Add GenericUDAF wrapper and HiveUDAFFunction /* Arithmetic */ case Token("-", child :: Nil) => UnaryMinus(nodeToExpr(child)) From 8e0931f1ca55aff597132c6a27ed058866680db5 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 28 Jan 2014 14:15:03 -0800 Subject: [PATCH 2/4] Cast to avoid using deprecated hive API. --- src/main/scala/catalyst/execution/aggregates.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/main/scala/catalyst/execution/aggregates.scala b/src/main/scala/catalyst/execution/aggregates.scala index 2b01c6eccb719..a7b5d5da489e2 100644 --- a/src/main/scala/catalyst/execution/aggregates.scala +++ b/src/main/scala/catalyst/execution/aggregates.scala @@ -107,7 +107,8 @@ case class Aggregate( evaluator } - val buffer = function.getNewAggregationBuffer + // Cast required to avoid type inference selecting a deprecated Hive API. 
+ val buffer = function.getNewAggregationBuffer.asInstanceOf[GenericUDAFEvaluator.AbstractAggregationBuffer] def result: Any = unwrap(function.evaluate(buffer)) From b1151a8a13b6a3cd1dfa53115b67610955112d66 Mon Sep 17 00:00:00 2001 From: Timothy Chen Date: Wed, 29 Jan 2014 09:58:26 -0800 Subject: [PATCH 3/4] Fix load data regex --- src/main/scala/catalyst/execution/TestShark.scala | 2 +- .../scala/catalyst/execution/aggregates.scala | 2 +- src/main/scala/catalyst/frontend/Hive.scala | 15 ++------------- 3 files changed, 4 insertions(+), 15 deletions(-) diff --git a/src/main/scala/catalyst/execution/TestShark.scala b/src/main/scala/catalyst/execution/TestShark.scala index bf8ef9b4c3865..51d138e019ae4 100644 --- a/src/main/scala/catalyst/execution/TestShark.scala +++ b/src/main/scala/catalyst/execution/TestShark.scala @@ -87,7 +87,7 @@ object TestShark extends SharkInstance { */ private def rewritePaths(cmd: String): String = if (cmd.toUpperCase.contains("LOAD DATA") && cmd.contains("..")) { - "[\"\']([\\./\\w]+)[\"\'] ".r.findFirstMatchIn(cmd) + "[\"\'](../.*)[\"\'] ".r.findFirstMatchIn(cmd) .map(r => { val newPath = new File(TestShark.inRepoTests.getCanonicalPath, cmd.substring(r.start + 1, r.end - 2).replaceFirst("(\\.\\./)+", "")).getAbsolutePath cmd.substring(0, r.start + 1) + newPath + cmd.substring(r.end - 2) diff --git a/src/main/scala/catalyst/execution/aggregates.scala b/src/main/scala/catalyst/execution/aggregates.scala index a7b5d5da489e2..5f2d2db15fe7e 100644 --- a/src/main/scala/catalyst/execution/aggregates.scala +++ b/src/main/scala/catalyst/execution/aggregates.scala @@ -107,7 +107,7 @@ case class Aggregate( evaluator } - // Cast required to avoid type inference selecting a deprecated Hive API. + // Cast required to avoid type inference selecting a deprecated Hive API. 
val buffer = function.getNewAggregationBuffer.asInstanceOf[GenericUDAFEvaluator.AbstractAggregationBuffer] def result: Any = unwrap(function.evaluate(buffer)) diff --git a/src/main/scala/catalyst/frontend/Hive.scala b/src/main/scala/catalyst/frontend/Hive.scala index 7d3f1c654c996..595d1cd9250c6 100644 --- a/src/main/scala/catalyst/frontend/Hive.scala +++ b/src/main/scala/catalyst/frontend/Hive.scala @@ -703,9 +703,10 @@ object HiveQl { case Token("TOK_FUNCTION", Token(SUM(), Nil) :: arg :: Nil) => Sum(nodeToExpr(arg)) /* Casts */ -<<<<<<< HEAD case Token("TOK_FUNCTION", Token("TOK_STRING", Nil) :: arg :: Nil) => Cast(nodeToExpr(arg), StringType) + case Token("TOK_FUNCTION", Token("TOK_VARCHAR", Nil) :: arg :: Nil) => + Cast(nodeToExpr(arg), StringType) case Token("TOK_FUNCTION", Token("TOK_INT", Nil) :: arg :: Nil) => Cast(nodeToExpr(arg), IntegerType) case Token("TOK_FUNCTION", Token("TOK_BIGINT", Nil) :: arg :: Nil) => @@ -722,18 +723,6 @@ object HiveQl { Cast(nodeToExpr(arg), BinaryType) case Token("TOK_FUNCTION", Token("TOK_BOOLEAN", Nil) :: arg :: Nil) => Cast(nodeToExpr(arg), BooleanType) -======= - case Token("TOK_FUNCTION", Token("TOK_STRING", Nil) :: arg :: Nil) => Cast(nodeToExpr(arg), StringType) - case Token("TOK_FUNCTION", Token("TOK_VARCHAR", Nil) :: arg :: Nil) => Cast(nodeToExpr(arg), StringType) - case Token("TOK_FUNCTION", Token("TOK_INT", Nil) :: arg :: Nil) => Cast(nodeToExpr(arg), IntegerType) - case Token("TOK_FUNCTION", Token("TOK_BIGINT", Nil) :: arg :: Nil) => Cast(nodeToExpr(arg), LongType) - case Token("TOK_FUNCTION", Token("TOK_FLOAT", Nil) :: arg :: Nil) => Cast(nodeToExpr(arg), FloatType) - case Token("TOK_FUNCTION", Token("TOK_DOUBLE", Nil) :: arg :: Nil) => Cast(nodeToExpr(arg), DoubleType) - case Token("TOK_FUNCTION", Token("TOK_SMALLINT", Nil) :: arg :: Nil) => Cast(nodeToExpr(arg), ShortType) - case Token("TOK_FUNCTION", Token("TOK_TINYINT", Nil) :: arg :: Nil) => Cast(nodeToExpr(arg), ByteType) - case Token("TOK_FUNCTION", Token("TOK_BINARY", Nil) :: arg :: Nil) => Cast(nodeToExpr(arg), BinaryType) - case Token("TOK_FUNCTION", Token("TOK_BOOLEAN", Nil) :: arg :: Nil) => Cast(nodeToExpr(arg), BooleanType) ->>>>>>> Add GenericUDAF wrapper and HiveUDAFFunction /* Arithmetic */ case Token("-", child :: Nil) => UnaryMinus(nodeToExpr(child)) From 63003e90fb70e13d22ad7e260e29897286a7776b Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sun, 2 Feb 2014 12:37:58 -0800 Subject: [PATCH 4/4] Fix spacing. --- src/main/scala/catalyst/execution/aggregates.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/main/scala/catalyst/execution/aggregates.scala b/src/main/scala/catalyst/execution/aggregates.scala index 51b6d27bd75d2..ea8b7a1aeab63 100644 --- a/src/main/scala/catalyst/execution/aggregates.scala +++ b/src/main/scala/catalyst/execution/aggregates.scala @@ -26,9 +26,9 @@ case class Aggregate( override def otherCopyArgs = sc :: Nil case class HiveUdafFunction( - exprs: Seq[Expression], - base: AggregateExpression, - functionName: String) + exprs: Seq[Expression], + base: AggregateExpression, + functionName: String) extends AggregateFunction with HiveInspectors with HiveFunctionFactory {
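
For readers unfamiliar with Hive's UDAF machinery, the control flow these patches wire up is: HiveGenericUdaf resolves the function's return type at planning time by asking the AbstractGenericUDAFResolver for an evaluator and initializing it in Mode.COMPLETE (raw input rows in, final result out, no partial aggregation), while HiveUdafFunction drives the same kind of evaluator per group at execution time: allocate an aggregation buffer, iterate() once per input row, then evaluate() and unwrap the Writable result. Below is a minimal standalone sketch of that lifecycle against Hive's own function registry, assuming the Hive 0.12-era APIs these patches appear to compile against; the built-in "avg" stands in for an arbitrary resolver, and the object name and input values are made up for illustration.

    import org.apache.hadoop.hive.ql.exec.FunctionRegistry
    import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, GenericUDAFEvaluator}
    import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector
    import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
    import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, TypeInfoFactory}

    // Hypothetical driver, not part of the patch series.
    object UdafLifecycleSketch {
      def main(args: Array[String]) {
        // Instantiate the resolver the same way HiveFunctionFactory.createFunction does.
        val resolver = FunctionRegistry.getFunctionInfo("avg")
          .getFunctionClass.newInstance().asInstanceOf[AbstractGenericUDAFResolver]

        // Pick an evaluator for one DOUBLE argument -- the analogue of
        // children.map(_.dataType.toTypeInfo).toArray for a single DoubleType child.
        val evaluator = resolver.getEvaluator(Array[TypeInfo](TypeInfoFactory.doubleTypeInfo))

        // COMPLETE mode: the evaluator sees all of a group's raw rows and emits the
        // final answer, matching how HiveGenericUdaf and HiveUdafFunction init theirs.
        val returnInspector = evaluator.init(
          GenericUDAFEvaluator.Mode.COMPLETE,
          Array[ObjectInspector](PrimitiveObjectInspectorFactory.javaDoubleObjectInspector))

        // Per-group execution, as in HiveUdafFunction: buffer, iterate, evaluate.
        val buffer = evaluator.getNewAggregationBuffer
        Seq(1.0, 2.0, 3.0).foreach { v =>  // made-up input rows
          evaluator.iterate(buffer, Array[AnyRef](Double.box(v)))
        }

        println(evaluator.evaluate(buffer))   // a DoubleWritable(2.0); still needs unwrap()
        println(returnInspector.getTypeName)  // "double" -> inspectorToDataType gives DoubleType
      }
    }

One consequence of initializing everything in Mode.COMPLETE is that every row of a group must reach a single evaluator: none of Hive's PARTIAL1/PARTIAL2/FINAL modes are exercised, so these UDAFs get no map-side partial aggregation in this first cut.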