From ddb3950327058b5133644766bf35325b153e8866 Mon Sep 17 00:00:00 2001 From: azagrebin Date: Wed, 18 Mar 2015 01:12:49 +0100 Subject: [PATCH] [SPARK-6117] [SQL] simplify implementation, add test for DF without numeric columns --- .../org/apache/spark/sql/DataFrame.scala | 83 ++++++------------- .../org/apache/spark/sql/DataFrameSuite.scala | 14 ++-- .../scala/org/apache/spark/sql/TestData.scala | 4 +- 3 files changed, 37 insertions(+), 64 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index accf72b15170f..1746cf0b27179 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -751,56 +751,6 @@ class DataFrame private[sql]( select(colNames :_*) } - /** - * Compute specified aggregations for given columns of this [[DataFrame]]. - * Each row of the resulting [[DataFrame]] contains column with aggregation name - * and columns with aggregation results for each given column. - * The aggregations are described as a List of mappings of their name to function - * which generates aggregation expression from column name. - * - * Note: can process only simple aggregation expressions - * which can be parsed by spark [[SqlParser]] - * - * {{{ - * val aggregations = List( - * "max" -> (col => s"max($col)"), // expression computes max - * "avg" -> (col => s"sum($col)/count($col)")) // expression computes average - * df.multipleAggExpr("summary", aggregations, "age", "height") - * - * // summary age height - * // max 92.0 192.0 - * // avg 53.0 178.0 - * }}} - */ - @scala.annotation.varargs - private def multipleAggExpr( - aggCol: String, - aggregations: List[(String, String => String)], - cols: String*): DataFrame = { - - val sqlParser = new SqlParser() - - def addAggNameCol(aggDF: DataFrame, aggName: String = "") = - aggDF.selectExpr(s"'$aggName' as $aggCol"::cols.toList:_*) - - def unionWithNextAgg(aggSoFarDF: DataFrame, nextAgg: (String, String => String)) = - nextAgg match { case (aggName, colToAggExpr) => - val nextAggDF = if (cols.nonEmpty) { - def colToAggCol(col: String) = - Column(sqlParser.parseExpression(colToAggExpr(col))).as(col) - val aggCols = cols.map(colToAggCol) - agg(aggCols.head, aggCols.tail:_*) - } else { - sqlContext.emptyDataFrame - } - val nextAggWithNameDF = addAggNameCol(nextAggDF, aggName) - aggSoFarDF.unionAll(nextAggWithNameDF) - } - - val emptyAgg = addAggNameCol(this).limit(0) - aggregations.foldLeft(emptyAgg)(unionWithNextAgg) - } - /** * Compute numerical statistics for given columns of this [[DataFrame]]: * count, mean (avg), stddev (standard deviation), min, max. @@ -821,14 +771,33 @@ class DataFrame private[sql]( */ @scala.annotation.varargs def describe(cols: String*): DataFrame = { + + def aggCol(name: String = "") = s"'$name' as summary" + val statistics = List[(String, Expression => Expression)]( + "count" -> (expr => Count(expr)), + "mean" -> (expr => Average(expr)), + "stddev" -> (expr => Sqrt(Subtract(Average(Multiply(expr, expr)), + Multiply(Average(expr), Average(expr))))), + "min" -> (expr => Min(expr)), + "max" -> (expr => Max(expr))) + val numCols = if (cols.isEmpty) numericColumns.map(_.prettyString) else cols - val aggregations = List[(String, String => String)]( - "count" -> (col => s"count($col)"), - "mean" -> (col => s"avg($col)"), - "stddev" -> (col => s"sqrt(avg($col*$col) - avg($col)*avg($col))"), - "min" -> (col => s"min($col)"), - "max" -> (col => s"max($col)")) - multipleAggExpr("summary", aggregations, numCols:_*) + + // union all statistics starting from empty one + var description = selectExpr(aggCol()::numCols.toList:_*).limit(0) + for ((name, colToAgg) <- statistics) { + // generate next statistic aggregation + val nextAgg = if (numCols.nonEmpty) { + val aggCols = numCols.map(c => Column(colToAgg(Column(c).expr)).as(c)) + agg(aggCols.head, aggCols.tail:_*) + } else { + sqlContext.emptyDataFrame + } + // add statistic name column + val nextStat = nextAgg.selectExpr(aggCol(name)::numCols.toList:_*) + description = description.unionAll(nextStat) + } + description } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index bee5a49b0bb6f..0f37664ce1b06 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -439,18 +439,22 @@ class DataFrameSuite extends QueryTest { test("describe") { def getSchemaAsSeq(df: DataFrame) = df.schema.map(_.name).toSeq - val describeAllCols = describeTestData.describe("age", "height") + val describeTwoCols = describeTestData.describe("age", "height") + assert(getSchemaAsSeq(describeTwoCols) === Seq("summary", "age", "height")) + checkAnswer(describeTwoCols, describeResult) + + val describeAllCols = describeTestData.describe() assert(getSchemaAsSeq(describeAllCols) === Seq("summary", "age", "height")) checkAnswer(describeAllCols, describeResult) - val describeNoCols = describeTestData.describe() - assert(getSchemaAsSeq(describeNoCols) === Seq("summary", "age", "height")) - checkAnswer(describeNoCols, describeResult) - val describeOneCol = describeTestData.describe("age") assert(getSchemaAsSeq(describeOneCol) === Seq("summary", "age")) checkAnswer(describeOneCol, describeResult.map { case Row(s, d, _) => Row(s, d)} ) + val describeNoCol = describeTestData.select("name").describe() + assert(getSchemaAsSeq(describeNoCol) === Seq("summary")) + checkAnswer(describeNoCol, describeResult.map { case Row(s, _, _) => Row(s)} ) + val emptyDescription = describeTestData.limit(0).describe() assert(getSchemaAsSeq(emptyDescription) === Seq("summary", "age", "height")) checkAnswer(emptyDescription, emptyDescribeResult) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index d96b5be9aa9b9..e4446cd5e0818 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -209,8 +209,8 @@ object TestData { Row("count", 4.0, 4.0) :: Row("mean", 33.0, 178.0) :: Row("stddev", 16.583123951777, 10.0) :: - Row("min", 16.0, 164) :: - Row("max", 60.0, 192) :: Nil + Row("min", 16.0, 164.0) :: + Row("max", 60.0, 192.0) :: Nil val emptyDescribeResult = Row("count", 0, 0) :: Row("mean", null, null) ::