Skip to content

Commit

Permalink
ENH: Add bar plot for DataFrame (#695)
Browse files Browse the repository at this point in the history
kdf:

<img width="348" alt="Screen Shot 2019-08-26 at 2 36 11 PM" src="https://user-images.githubusercontent.com/9269816/63691312-21e17880-c80f-11e9-83d3-cb87655a383f.png">

pdf:

<img width="327" alt="Screen Shot 2019-08-26 at 2 36 21 PM" src="https://user-images.githubusercontent.com/9269816/63691323-2ad24a00-c80f-11e9-91ba-c2c1ec7589a4.png">
  • Loading branch information
charlesdong1991 authored and HyukjinKwon committed Aug 27, 2019
1 parent fa36a48 commit 4404178
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 5 deletions.
43 changes: 39 additions & 4 deletions databricks/koalas/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,16 @@ class TopNPlot:
max_rows = 1000

def get_top_n(self, data):
from databricks.koalas import DataFrame, Series
# Simply use the first 1k elements and make it into a pandas dataframe
# For categorical variables, it is likely called from df.x.value_counts().plot.xxx().
data = data.head(TopNPlot.max_rows + 1).to_pandas().to_frame()
# Fetch one row more than the limit so the length check that follows can
# tell whether the data had to be truncated.
if isinstance(data, Series):
    # A Series is promoted to a one-column pandas DataFrame.
    data = data.head(TopNPlot.max_rows + 1).to_pandas().to_frame()
elif isinstance(data, DataFrame):
    data = data.head(TopNPlot.max_rows + 1).to_pandas()
else:
    # The exception has to be raised: merely constructing it would let
    # unsupported input fall through to the length check unnoticed.
    raise ValueError("Only DataFrame and Series are supported for plotting.")

self.partial = False
if len(data) > TopNPlot.max_rows:
self.partial = True
Expand Down Expand Up @@ -633,7 +640,7 @@ def plot_series(data, kind='line', ax=None, # Series unique

def _plot(data, x=None, y=None, subplots=False,
ax=None, kind='line', **kwds):

from databricks.koalas import DataFrame
# function copied from pandas.plotting._core
# and adapted to handle Koalas DataFrame and Series

Expand All @@ -642,6 +649,15 @@ def _plot(data, x=None, y=None, subplots=False,
klass = _plot_klass[kind]
else:
raise ValueError("%r is not a valid plot kind" % kind)

# check the data type and preprocess the data before plotting
if isinstance(data, DataFrame):
if x is not None:
data = data.set_index(x)
# TODO: check if value of y is plottable
if y is not None:
data = data[y]

plot_obj = klass(data, subplots=subplots, ax=ax, kind=kind, **kwds)
plot_obj.generate()
plot_obj.draw()
Expand Down Expand Up @@ -949,8 +965,27 @@ def pie(self, bw_method=None, ind=None, **kwds):
# Placeholder for the unsupported `area` plot kind: the callable built by
# _unsupported_function is invoked immediately, so `bw_method`, `ind` and
# `kwds` are accepted but never used.
# NOTE(review): presumably the call raises PandasNotImplementedError, which is
# what test_missing expects for unsupported plot kinds -- confirm in the helper.
def area(self, bw_method=None, ind=None, **kwds):
return _unsupported_function(class_name='pd.DataFrame', method_name='area')()

def bar(self, bw_method=None, ind=None, **kwds):
return _unsupported_function(class_name='pd.DataFrame', method_name='bar')()
def bar(self, x=None, y=None, **kwds):
    """
    Draw a vertical bar plot.

    Parameters
    ----------
    x : label or position, optional
        Allows plotting of one column versus another.
        If not specified, the index of the DataFrame is used.
    y : label or position, optional
        Allows plotting of one column versus another.
        If not specified, all numerical columns are used.
    **kwds : optional
        Additional keyword arguments are documented in
        :meth:`databricks.koalas.DataFrame.plot`.

    Returns
    -------
    axes : :class:`matplotlib.axes.Axes` or numpy.ndarray of them
    """
    # Delegate to this accessor's generic call with the plot kind pinned.
    return self(x=x, y=y, kind='bar', **kwds)

# Placeholder for the unsupported `barh` plot kind: the callable built by
# _unsupported_function is invoked immediately, so `bw_method`, `ind` and
# `kwds` are accepted but never used.
# NOTE(review): presumably the call raises PandasNotImplementedError, which is
# what test_missing expects for unsupported plot kinds -- confirm in the helper.
def barh(self, bw_method=None, ind=None, **kwds):
return _unsupported_function(class_name='pd.DataFrame', method_name='barh')()
Expand Down
27 changes: 26 additions & 1 deletion databricks/koalas/tests/test_frame_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,35 @@ def test_line_plot(self):
ax4 = kdf.plot.line(colormap='Paired')
self.compare_plots(ax3, ax4)

def test_bar_plot(self):
    """Bar plots drawn from the Koalas frame must match the pandas ones."""
    pandas_frame = self.pdf1
    koalas_frame = self.kdf1

    # Exercise both entry points: the generic plot(kind='bar') call and
    # the plot.bar() accessor method.
    drawers = (
        lambda frame: frame.plot(kind='bar', colormap='Paired'),
        lambda frame: frame.plot.bar(colormap='Paired'),
    )
    for draw in drawers:
        # pandas is drawn first, then Koalas, then the two are compared.
        self.compare_plots(draw(pandas_frame), draw(koalas_frame))

def test_bar_with_x_y(self):
    """Bar plots with explicit x and y columns must match the pandas ones."""
    pandas_frame = pd.DataFrame({'lab': ['A', 'B', 'C'], 'val': [10, 30, 20]})
    koalas_frame = koalas.from_pandas(pandas_frame)

    # Same axis selection and colormap for every plot under comparison.
    options = dict(x='lab', y='val', colormap='Paired')
    drawers = (
        lambda frame: frame.plot(kind="bar", **options),
        lambda frame: frame.plot.bar(**options),
    )
    for draw in drawers:
        # pandas is drawn first, then Koalas, then the two are compared.
        self.compare_plots(draw(pandas_frame), draw(koalas_frame))

def test_missing(self):
ks = self.kdf1

unsupported_functions = ['area', 'bar', 'barh', 'box', 'density', 'hexbin',
unsupported_functions = ['area', 'barh', 'box', 'density', 'hexbin',
'hist', 'kde', 'pie', 'scatter']
for name in unsupported_functions:
with self.assertRaisesRegex(PandasNotImplementedError,
Expand Down
1 change: 1 addition & 0 deletions docs/source/reference/frame.rst
Original file line number Diff line number Diff line change
Expand Up @@ -233,3 +233,4 @@ specific plotting methods of the form ``DataFrame.plot.<kind>``.

DataFrame.plot
DataFrame.plot.line
DataFrame.plot.bar

0 comments on commit 4404178

Please sign in to comment.