Skip to content

Commit

Permalink
[SQL] Improve DataFrame API error reporting
Browse files Browse the repository at this point in the history
1. Throw UnsupportedOperationException if a Column is not computable.
2. Perform eager analysis on DataFrame so we can catch errors when they happen (not when an action is run).

Author: Reynold Xin <[email protected]>
Author: Davies Liu <[email protected]>

Closes apache#4296 from rxin/col-computability and squashes the following commits:

6527b86 [Reynold Xin] Merge pull request alteryx#8 from davies/col-computability
fd92bc7 [Reynold Xin] Merge branch 'master' into col-computability
f79034c [Davies Liu] fix python tests
5afe1ff [Reynold Xin] Fix scala test.
17f6bae [Reynold Xin] Various fixes.
b932e86 [Reynold Xin] Added eager analysis for error reporting.
e6f00b8 [Reynold Xin] [SQL][API] ComputableColumn vs IncomputableColumn
  • Loading branch information
rxin committed Feb 3, 2015
1 parent eccb9fb commit 554403f
Show file tree
Hide file tree
Showing 20 changed files with 896 additions and 381 deletions.
75 changes: 52 additions & 23 deletions python/pyspark/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -2124,6 +2124,10 @@ def head(self, n=None):
return rs[0] if rs else None
return self.take(n)

def first(self):
""" Return the first row. """
return self.head()

def tail(self):
raise NotImplemented

Expand Down Expand Up @@ -2159,7 +2163,7 @@ def select(self, *cols):
else:
cols = [c._jc for c in cols]
jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client)
jdf = self._jdf.select(self._jdf.toColumnArray(jcols))
jdf = self._jdf.select(self.sql_ctx._sc._jvm.Dsl.toColumns(jcols))
return DataFrame(jdf, self.sql_ctx)

def filter(self, condition):
Expand Down Expand Up @@ -2189,7 +2193,7 @@ def groupBy(self, *cols):
else:
cols = [c._jc for c in cols]
jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client)
jdf = self._jdf.groupBy(self._jdf.toColumnArray(jcols))
jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.Dsl.toColumns(jcols))
return GroupedDataFrame(jdf, self.sql_ctx)

def agg(self, *exprs):
Expand Down Expand Up @@ -2278,14 +2282,17 @@ def agg(self, *exprs):
:param exprs: list or aggregate columns or a map from column
name to agregate methods.
"""
assert exprs, "exprs should not be empty"
if len(exprs) == 1 and isinstance(exprs[0], dict):
jmap = MapConverter().convert(exprs[0],
self.sql_ctx._sc._gateway._gateway_client)
jdf = self._jdf.agg(jmap)
else:
# Columns
assert all(isinstance(c, Column) for c in exprs), "all exprs should be Columns"
jdf = self._jdf.agg(*exprs)
assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
jcols = ListConverter().convert([c._jc for c in exprs[1:]],
self.sql_ctx._sc._gateway._gateway_client)
jdf = self._jdf.agg(exprs[0]._jc, self.sql_ctx._sc._jvm.Dsl.toColumns(jcols))
return DataFrame(jdf, self.sql_ctx)

@dfapi
Expand Down Expand Up @@ -2347,7 +2354,7 @@ def _create_column_from_literal(literal):

def _create_column_from_name(name):
sc = SparkContext._active_spark_context
return sc._jvm.Column(name)
return sc._jvm.IncomputableColumn(name)


def _scalaMethod(name):
Expand All @@ -2371,7 +2378,7 @@ def _(self):
return _


def _bin_op(name, pass_literal_through=False):
def _bin_op(name, pass_literal_through=True):
""" Create a method for given binary operator
Keyword arguments:
Expand Down Expand Up @@ -2465,18 +2472,17 @@ def __init__(self, jc, jdf=None, sql_ctx=None):
# __getattr__ = _bin_op("getField")

# string methods
rlike = _bin_op("rlike", pass_literal_through=True)
like = _bin_op("like", pass_literal_through=True)
startswith = _bin_op("startsWith", pass_literal_through=True)
endswith = _bin_op("endsWith", pass_literal_through=True)
rlike = _bin_op("rlike")
like = _bin_op("like")
startswith = _bin_op("startsWith")
endswith = _bin_op("endsWith")
upper = _unary_op("upper")
lower = _unary_op("lower")

def substr(self, startPos, pos):
if type(startPos) != type(pos):
raise TypeError("Can not mix the type")
if isinstance(startPos, (int, long)):

jc = self._jc.substr(startPos, pos)
elif isinstance(startPos, Column):
jc = self._jc.substr(startPos._jc, pos._jc)
Expand Down Expand Up @@ -2507,30 +2513,53 @@ def cast(self, dataType):
return Column(self._jc.cast(jdt), self._jdf, self.sql_ctx)


def _to_java_column(col):
if isinstance(col, Column):
jcol = col._jc
else:
jcol = _create_column_from_name(col)
return jcol


def _aggregate_func(name):
""" Create a function for aggregator by name"""
def _(col):
sc = SparkContext._active_spark_context
if isinstance(col, Column):
jcol = col._jc
else:
jcol = _create_column_from_name(col)
jc = getattr(sc._jvm.org.apache.spark.sql.Dsl, name)(jcol)
jc = getattr(sc._jvm.Dsl, name)(_to_java_column(col))
return Column(jc)

return staticmethod(_)


class Aggregator(object):
"""
A collections of builtin aggregators
"""
max = _aggregate_func("max")
min = _aggregate_func("min")
avg = mean = _aggregate_func("mean")
sum = _aggregate_func("sum")
first = _aggregate_func("first")
last = _aggregate_func("last")
count = _aggregate_func("count")
AGGS = [
'lit', 'col', 'column', 'upper', 'lower', 'sqrt', 'abs',
'min', 'max', 'first', 'last', 'count', 'avg', 'mean', 'sum', 'sumDistinct',
]
for _name in AGGS:
locals()[_name] = _aggregate_func(_name)
del _name

@staticmethod
def countDistinct(col, *cols):
sc = SparkContext._active_spark_context
jcols = ListConverter().convert([_to_java_column(c) for c in cols],
sc._gateway._gateway_client)
jc = sc._jvm.Dsl.countDistinct(_to_java_column(col),
sc._jvm.Dsl.toColumns(jcols))
return Column(jc)

@staticmethod
def approxCountDistinct(col, rsd=None):
sc = SparkContext._active_spark_context
if rsd is None:
jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col))
else:
jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col), rsd)
return Column(jc)


def _test():
Expand Down
6 changes: 4 additions & 2 deletions python/pyspark/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1029,9 +1029,11 @@ def test_aggregator(self):
g = df.groupBy()
self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0]))
self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect())
# TODO(davies): fix aggregators

from pyspark.sql import Aggregator as Agg
# self.assertEqual((0, '100'), tuple(g.agg(Agg.first(df.key), Agg.last(df.value)).first()))
self.assertEqual((0, u'99'), tuple(g.agg(Agg.first(df.key), Agg.last(df.value)).first()))
self.assertTrue(95 < g.agg(Agg.approxCountDistinct(df.key)).first()[0])
self.assertEqual(100, g.agg(Agg.countDistinct(df.value)).first()[0])

def test_help_command(self):
# Regression test for SPARK-5464
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -822,7 +822,8 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
* have a name matching the given name, `null` will be returned.
*/
def apply(name: String): StructField = {
nameToField.getOrElse(name, throw new IllegalArgumentException(s"Field $name does not exist."))
nameToField.getOrElse(name,
throw new IllegalArgumentException(s"""Field "$name" does not exist."""))
}

/**
Expand Down
Loading

0 comments on commit 554403f

Please sign in to comment.