From f79034cc3d18d1051905a8ee77fad1f676373037 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Mon, 2 Feb 2015 16:57:54 -0800
Subject: [PATCH] fix python tests

---
 python/pyspark/sql.py                              | 75 +++++++++++++------
 python/pyspark/tests.py                            |  6 +-
 .../org/apache/spark/sql/DataFrame.scala           |  4 -
 .../org/apache/spark/sql/DataFrameImpl.scala       |  4 -
 .../main/scala/org/apache/spark/sql/Dsl.scala      | 13 +++-
 .../apache/spark/sql/GroupedDataFrame.scala        |  2 +
 .../apache/spark/sql/IncomputableColumn.scala      |  2 -
 7 files changed, 69 insertions(+), 37 deletions(-)

diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 3f2d7ac82585f..32bff0c7e8c55 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -2124,6 +2124,10 @@ def head(self, n=None):
             return rs[0] if rs else None
         return self.take(n)
 
+    def first(self):
+        """ Return the first row. """
+        return self.head()
+
     def tail(self):
         raise NotImplemented
 
@@ -2159,7 +2163,7 @@ def select(self, *cols):
         else:
             cols = [c._jc for c in cols]
         jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client)
-        jdf = self._jdf.select(self._jdf.toColumnArray(jcols))
+        jdf = self._jdf.select(self.sql_ctx._sc._jvm.Dsl.toColumns(jcols))
         return DataFrame(jdf, self.sql_ctx)
 
     def filter(self, condition):
@@ -2189,7 +2193,7 @@ def groupBy(self, *cols):
         else:
             cols = [c._jc for c in cols]
         jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client)
-        jdf = self._jdf.groupBy(self._jdf.toColumnArray(jcols))
+        jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.Dsl.toColumns(jcols))
         return GroupedDataFrame(jdf, self.sql_ctx)
 
     def agg(self, *exprs):
@@ -2278,14 +2282,17 @@ def agg(self, *exprs):
         :param exprs: list or aggregate columns or a map from column name
                       to agregate methods.
         """
+        assert exprs, "exprs should not be empty"
         if len(exprs) == 1 and isinstance(exprs[0], dict):
             jmap = MapConverter().convert(exprs[0],
                                           self.sql_ctx._sc._gateway._gateway_client)
             jdf = self._jdf.agg(jmap)
         else:
             # Columns
-            assert all(isinstance(c, Column) for c in exprs), "all exprs should be Columns"
-            jdf = self._jdf.agg(*exprs)
+            assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
+            jcols = ListConverter().convert([c._jc for c in exprs[1:]],
+                                            self.sql_ctx._sc._gateway._gateway_client)
+            jdf = self._jdf.agg(exprs[0]._jc, self.sql_ctx._sc._jvm.Dsl.toColumns(jcols))
         return DataFrame(jdf, self.sql_ctx)
 
     @dfapi
@@ -2347,7 +2354,7 @@ def _create_column_from_literal(literal):
 
 def _create_column_from_name(name):
     sc = SparkContext._active_spark_context
-    return sc._jvm.Column(name)
+    return sc._jvm.IncomputableColumn(name)
 
 
 def _scalaMethod(name):
@@ -2371,7 +2378,7 @@ def _(self):
     return _
 
 
-def _bin_op(name, pass_literal_through=False):
+def _bin_op(name, pass_literal_through=True):
     """ Create a method for given binary operator
 
     Keyword arguments:
@@ -2465,10 +2472,10 @@ def __init__(self, jc, jdf=None, sql_ctx=None):
     # __getattr__ = _bin_op("getField")
 
     # string methods
-    rlike = _bin_op("rlike", pass_literal_through=True)
-    like = _bin_op("like", pass_literal_through=True)
-    startswith = _bin_op("startsWith", pass_literal_through=True)
-    endswith = _bin_op("endsWith", pass_literal_through=True)
+    rlike = _bin_op("rlike")
+    like = _bin_op("like")
+    startswith = _bin_op("startsWith")
+    endswith = _bin_op("endsWith")
     upper = _unary_op("upper")
     lower = _unary_op("lower")
 
@@ -2476,7 +2483,6 @@ def substr(self, startPos, pos):
         if type(startPos) != type(pos):
             raise TypeError("Can not mix the type")
         if isinstance(startPos, (int, long)):
-
             jc = self._jc.substr(startPos, pos)
         elif isinstance(startPos, Column):
             jc = self._jc.substr(startPos._jc, pos._jc)
@@ -2507,16 +2513,21 @@ def cast(self, dataType):
         return Column(self._jc.cast(jdt), self._jdf, self.sql_ctx)
 
 
+def _to_java_column(col):
+    if isinstance(col, Column):
+        jcol = col._jc
+    else:
+        jcol = _create_column_from_name(col)
+    return jcol
+
+
 def _aggregate_func(name):
     """ Create a function for aggregator by name"""
     def _(col):
         sc = SparkContext._active_spark_context
-        if isinstance(col, Column):
-            jcol = col._jc
-        else:
-            jcol = _create_column_from_name(col)
-        jc = getattr(sc._jvm.org.apache.spark.sql.Dsl, name)(jcol)
+        jc = getattr(sc._jvm.Dsl, name)(_to_java_column(col))
         return Column(jc)
+
     return staticmethod(_)
 
 
@@ -2524,13 +2535,31 @@ class Aggregator(object):
     """
     A collections of builtin aggregators
     """
-    max = _aggregate_func("max")
-    min = _aggregate_func("min")
-    avg = mean = _aggregate_func("mean")
-    sum = _aggregate_func("sum")
-    first = _aggregate_func("first")
-    last = _aggregate_func("last")
-    count = _aggregate_func("count")
+    AGGS = [
+        'lit', 'col', 'column', 'upper', 'lower', 'sqrt', 'abs',
+        'min', 'max', 'first', 'last', 'count', 'avg', 'mean', 'sum', 'sumDistinct',
+    ]
+    for _name in AGGS:
+        locals()[_name] = _aggregate_func(_name)
+    del _name
+
+    @staticmethod
+    def countDistinct(col, *cols):
+        sc = SparkContext._active_spark_context
+        jcols = ListConverter().convert([_to_java_column(c) for c in cols],
+                                        sc._gateway._gateway_client)
+        jc = sc._jvm.Dsl.countDistinct(_to_java_column(col),
+                                       sc._jvm.Dsl.toColumns(jcols))
+        return Column(jc)
+
+    @staticmethod
+    def approxCountDistinct(col, rsd=None):
+        sc = SparkContext._active_spark_context
+        if rsd is None:
+            jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col))
+        else:
+            jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col), rsd)
+        return Column(jc)
 
 
 def _test():
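The sql.py half of the patch reroutes every multi-Column call (select, groupBy, agg) through the new Dsl.toColumns bridge and rebuilds Aggregator from a name table. As a minimal usage sketch of the resulting API, assuming a local Spark build of this vintage; the app name, the data, and the inferSchema construction are illustrative and mirror test_aggregator in tests.py below:

    from pyspark import SparkContext
    from pyspark.sql import SQLContext, Row, Aggregator as Agg

    sc = SparkContext("local", "agg_demo")
    sqlCtx = SQLContext(sc)

    # 100 rows keyed 0..99, the same shape of data test_aggregator uses.
    rdd = sc.parallelize([Row(key=i, value=str(i)) for i in range(100)])
    df = sqlCtx.inferSchema(rdd)

    g = df.groupBy()
    # Dict form: {column name: aggregate method}, handled by jdf.agg(jmap).
    print(g.agg({'key': 'max', 'value': 'count'}).collect())
    # Column form: the first Column is passed alone, the rest are wrapped
    # with Dsl.toColumns to match the Scala varargs signature agg(Column, Column*).
    print(g.agg(Agg.first(df.key), Agg.last(df.value)).first())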
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index bec1961f26393..fef6c92875a1c 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -1029,9 +1029,11 @@ def test_aggregator(self):
         g = df.groupBy()
         self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0]))
         self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect())
-        # TODO(davies): fix aggregators
+
         from pyspark.sql import Aggregator as Agg
-        # self.assertEqual((0, '100'), tuple(g.agg(Agg.first(df.key), Agg.last(df.value)).first()))
+        self.assertEqual((0, u'99'), tuple(g.agg(Agg.first(df.key), Agg.last(df.value)).first()))
+        self.assertTrue(95 < g.agg(Agg.approxCountDistinct(df.key)).first()[0])
+        self.assertEqual(100, g.agg(Agg.countDistinct(df.value)).first()[0])
 
     def test_help_command(self):
         # Regression test for SPARK-5464
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 95830bd1b0f18..385e1ec74f5f7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -500,10 +500,6 @@ trait DataFrame extends DataFrameSpecificApi with RDDApi[Row] {
   ////////////////////////////////////////////////////////////////////////////
   // for Python API
   ////////////////////////////////////////////////////////////////////////////
-  /**
-   * A helpful function for Py4j, convert a list of Column to an array
-   */
-  protected[sql] def toColumnArray(cols: JList[Column]): Array[Column]
 
   /**
    * Converts a JavaRDD to a PythonRDD.
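Dropping toColumnArray from the DataFrame trait is possible because the Py4j bridge no longer lives on each DataFrame implementation: the Python side converts its list into a java.util.List, and the static Dsl.toColumns turns that into the Seq[Column] the Scala varargs methods expect. A minimal sketch of that recurring conversion, factored into a hypothetical helper (_to_seq_of_columns is not in the patch; _to_java_column is the helper added in sql.py above, and an active SparkContext with its gateway is assumed):

    from py4j.java_collections import ListConverter

    def _to_seq_of_columns(sc, cols):
        # Turn Python Column objects (or plain column names) into JVM Column
        # objects, ship them over Py4j as a java.util.List, then let
        # Dsl.toColumns produce the Seq[Column] that Scala varargs accept.
        jcols = ListConverter().convert([_to_java_column(c) for c in cols],
                                        sc._gateway._gateway_client)
        return sc._jvm.Dsl.toColumns(jcols)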
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
index 87441abb1357f..f8fcc25569482 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
@@ -323,10 +323,6 @@ private[sql] class DataFrameImpl protected[sql](
   ////////////////////////////////////////////////////////////////////////////
   // for Python API
   ////////////////////////////////////////////////////////////////////////////
-  protected[sql] override def toColumnArray(cols: JList[Column]): Array[Column] = {
-    cols.toList.toArray
-  }
-
   protected[sql] override def javaToPython: JavaRDD[Array[Byte]] = {
     val fieldTypes = schema.fields.map(_.dataType)
     val jrdd = rdd.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala
index 4d6320c05aed2..b4279a32ffa21 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala
@@ -17,8 +17,11 @@
 
 package org.apache.spark.sql
 
+import java.util.{List => JList}
+
 import scala.language.implicitConversions
 import scala.reflect.runtime.universe.{TypeTag, typeTag}
+import scala.collection.JavaConversions._
 
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.expressions._
@@ -105,8 +108,7 @@ object Dsl {
   def countDistinct(expr: Column, exprs: Column*): Column =
     CountDistinct((expr +: exprs).map(_.expr))
 
-  def approxCountDistinct(e: Column): Column =
-    ApproxCountDistinct(e.expr)
+  def approxCountDistinct(e: Column): Column = ApproxCountDistinct(e.expr)
 
   def approxCountDistinct(e: Column, rsd: Double): Column =
     ApproxCountDistinct(e.expr, rsd)
@@ -121,6 +123,13 @@ object Dsl {
   def sqrt(e: Column): Column = Sqrt(e.expr)
   def abs(e: Column): Column = Abs(e.expr)
 
+  /**
+   * This is a private API for Python
+   * TODO: move this to a private package
+   */
+  def toColumns(cols: JList[Column]): Seq[Column] = {
+    cols.toList.toSeq
+  }
 
   // scalastyle:off
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala
index 3187a92fc7b71..d3acd41bbf3eb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql
 
+import java.util.{List => JList}
+
 import scala.language.implicitConversions
 import scala.collection.JavaConversions._
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
index 215b57ec63ba9..2f8c695d5654b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
@@ -156,7 +156,5 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten
 
   override def toJSON: RDD[String] = err()
 
-  protected[sql] override def toColumnArray(cols: java.util.List[Column]): Array[Column] = err()
-
   protected[sql] override def javaToPython: JavaRDD[Array[Byte]] = err()
 }
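With the list-to-Seq bridge consolidated in Dsl, the two distinct-count aggregates also become reachable from Python. A short sketch of how they differ, reusing the df and g from the first sketch above (the approximate variant is an estimate, HyperLogLog-based in this era of catalyst, and rsd sets its target relative standard deviation):

    from pyspark.sql import Aggregator as Agg

    # Exact distinct count over 100 distinct values.
    print(g.agg(Agg.countDistinct(df.value)).first()[0])          # 100
    # Approximate distinct count; close to 100 but not guaranteed exact.
    print(g.agg(Agg.approxCountDistinct(df.key)).first()[0])
    # Same, with an explicit 5% relative standard deviation.
    print(g.agg(Agg.approxCountDistinct(df.key, 0.05)).first()[0])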