Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Implement frame plot #686

Merged
merged 7 commits into from
Aug 26, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions databricks/koalas/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from pandas.core.dtypes.common import infer_dtype_from_object
else:
from pandas.core.dtypes.common import _get_dtype_from_object as infer_dtype_from_object
from pandas.core.accessor import CachedAccessor
from pandas.core.dtypes.inference import is_sequence
from pyspark import sql as spark
from pyspark.sql import functions as F, Column
Expand All @@ -52,6 +53,7 @@
from databricks.koalas.ml import corr
from databricks.koalas.utils import column_index_level, scol_for
from databricks.koalas.typedef import as_spark_type
from databricks.koalas.plot import KoalasFramePlotMethods

# These regular expression patterns are compiled and defined here to avoid compiling the same
# pattern every time it is used in _repr_ and _repr_html_ in DataFrame.
Expand Down Expand Up @@ -466,6 +468,9 @@ def __rfloordiv__(self, other):
def add(self, other):
    # Method form of the ``+`` operator. No inline docstring here: the
    # docstring is attached after the definition via ``add.__doc__``.
    return self + other

# create accessor for plot
plot = CachedAccessor("plot", KoalasFramePlotMethods)

add.__doc__ = _flex_doc_FRAME.format(
desc='Addition',
op_name='+',
Expand Down
217 changes: 213 additions & 4 deletions databricks/koalas/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,20 @@ def set_result_text(self, ax):

class SampledPlot:
def get_sampled(self, data):
    """
    Sample roughly 1000 records from ``data`` and return them as pandas.

    The sampling fraction that was used is stored on ``self.fraction`` so
    that it can be read back later (``set_result_text`` asserts that the
    attribute exists).

    Parameters
    ----------
    data : DataFrame or Series
        Koalas object to sample. For a Series, the sample is taken from the
        DataFrame it is anchored to (``data._kdf``).

    Returns
    -------
    The result of ``to_pandas()`` on a Koalas DataFrame built from the
    sampled Spark DataFrame.

    Raises
    ------
    ValueError
        If ``data`` is neither a Koalas DataFrame nor a Koalas Series.
    """
    from databricks.koalas import DataFrame, Series

    # Validate the input first so that unsupported types fail with a clear
    # error instead of silently returning None.
    if isinstance(data, DataFrame):
        kdf = data
    elif isinstance(data, Series):
        kdf = data._kdf
    else:
        raise ValueError("Only DataFrame and Series are supported for plotting.")

    num_records = len(data)
    if num_records == 0:
        # Nothing to down-sample; also avoids dividing by zero below.
        self.fraction = 1
    else:
        self.fraction = 1 / (num_records / 1000)  # make sure the records are roughly 1000.
        if self.fraction > 1:
            self.fraction = 1

    sampled = kdf._sdf.sample(fraction=float(self.fraction))
    return DataFrame(kdf._internal.copy(sdf=sampled)).to_pandas()

def set_result_text(self, ax):
assert hasattr(self, "fraction")
Expand Down Expand Up @@ -635,7 +642,6 @@ def _plot(data, x=None, y=None, subplots=False,
klass = _plot_klass[kind]
else:
raise ValueError("%r is not a valid plot kind" % kind)

plot_obj = klass(data, subplots=subplots, ax=ax, kind=kind, **kwds)
plot_obj.generate()
plot_obj.draw()
Expand Down Expand Up @@ -879,3 +885,206 @@ def pie(self, **kwds):
>>> plot = df.mass.plot.pie(subplots=True, figsize=(6, 3))
"""
return self(kind='pie', **kwds)


class KoalasFramePlotMethods(PandasObject):
    # TODO: not sure if Koalas wanna combine plot method for Series and DataFrame
    """
    DataFrame plotting accessor and method.

    Plotting methods can also be accessed by calling the accessor as a method
    with the ``kind`` argument:
    ``df.plot(kind='hist')`` is equivalent to ``df.plot.hist()``
    """
    def __init__(self, data):
        # ``data`` is the object this accessor was created for; it is handed
        # to ``plot_frame`` unchanged whenever the accessor is called.
        self.data = data

    def __call__(self, x=None, y=None, kind='line', ax=None,
                 subplots=None, sharex=None, sharey=False, layout=None,
                 figsize=None, use_index=True, title=None, grid=None,
                 legend=True, style=None, logx=False, logy=False,
                 loglog=False, xticks=None, yticks=None, xlim=None,
                 ylim=None, rot=None, fontsize=None, colormap=None,
                 table=False, yerr=None, xerr=None, secondary_y=False,
                 sort_columns=False, **kwds):
        """
        Plot the wrapped data by forwarding every argument to ``plot_frame``.

        The parameters mirror those of ``plot_frame``; see that function for
        their meaning. Returns whatever ``plot_frame`` returns.
        """
        return plot_frame(self.data, x=x, y=y, kind=kind, ax=ax,
                          subplots=subplots, sharex=sharex, sharey=sharey, layout=layout,
                          figsize=figsize, use_index=use_index, title=title, grid=grid,
                          legend=legend, style=style, logx=logx, logy=logy,
                          loglog=loglog, xticks=xticks, yticks=yticks, xlim=xlim,
                          ylim=ylim, rot=rot, fontsize=fontsize, colormap=colormap,
                          table=table, yerr=yerr, xerr=xerr, secondary_y=secondary_y,
                          sort_columns=sort_columns, **kwds)

    def line(self, x=None, y=None, **kwargs):
        """
        Plot DataFrame as lines.

        Parameters
        ----------
        x : int or str, optional
            Columns to use for the horizontal axis.
        y : int, str, or list of them, optional
            The values to be plotted.
        **kwargs
            Keyword arguments to pass on to :meth:`DataFrame.plot`.

        Returns
        -------
        :class:`matplotlib.axes.Axes` or :class:`numpy.ndarray`
            Return an ndarray when ``subplots=True``.

        See Also
        --------
        matplotlib.pyplot.plot : Plot y versus x as lines and/or markers.
        """
        return self(kind='line', x=x, y=y, **kwargs)

    # The plot kinds below are not implemented for DataFrame yet. Each method
    # ignores its arguments and immediately invokes the callable produced by
    # ``_unsupported_function`` for the matching pandas method name. All of
    # them share the ``bw_method``/``ind`` signature of ``kde``; the parameters
    # are accepted but never used.

    def kde(self, bw_method=None, ind=None, **kwds):
        return _unsupported_function(class_name='pd.DataFrame', method_name='kde')()

    def pie(self, bw_method=None, ind=None, **kwds):
        return _unsupported_function(class_name='pd.DataFrame', method_name='pie')()

    def area(self, bw_method=None, ind=None, **kwds):
        return _unsupported_function(class_name='pd.DataFrame', method_name='area')()

    def bar(self, bw_method=None, ind=None, **kwds):
        return _unsupported_function(class_name='pd.DataFrame', method_name='bar')()

    def barh(self, bw_method=None, ind=None, **kwds):
        return _unsupported_function(class_name='pd.DataFrame', method_name='barh')()

    def hexbin(self, bw_method=None, ind=None, **kwds):
        return _unsupported_function(class_name='pd.DataFrame', method_name='hexbin')()

    def density(self, bw_method=None, ind=None, **kwds):
        return _unsupported_function(class_name='pd.DataFrame', method_name='density')()

    def box(self, bw_method=None, ind=None, **kwds):
        return _unsupported_function(class_name='pd.DataFrame', method_name='box')()

    def hist(self, bw_method=None, ind=None, **kwds):
        return _unsupported_function(class_name='pd.DataFrame', method_name='hist')()

    def scatter(self, bw_method=None, ind=None, **kwds):
        return _unsupported_function(class_name='pd.DataFrame', method_name='scatter')()


def plot_frame(data, x=None, y=None, kind='line', ax=None,
               subplots=None, sharex=None, sharey=False, layout=None,
               figsize=None, use_index=True, title=None, grid=None,
               legend=True, style=None, logx=False, logy=False,
               loglog=False, xticks=None, yticks=None, xlim=None,
               ylim=None, rot=None, fontsize=None, colormap=None,
               table=False, yerr=None, xerr=None, secondary_y=False,
               sort_columns=False, **kwds):
    """
    Draw a DataFrame with matplotlib / pylab.

    Every plot kind is also exposed as a method on the ``DataFrame.plot``
    accessor, so ``kdf.plot(kind='line')`` and ``kdf.plot.line()`` are
    equivalent.

    Parameters
    ----------
    data : DataFrame
        The frame to plot.
    x : label or position, default None
    y : label, position or list of label, positions, default None
        Allows plotting of one column versus another.
    kind : str
        One of:

        - 'line' : line plot (default)
        - 'bar' : vertical bar plot
        - 'barh' : horizontal bar plot
        - 'hist' : histogram
        - 'box' : boxplot
        - 'kde' : Kernel Density Estimation plot
        - 'density' : same as 'kde'
        - 'area' : area plot
        - 'pie' : pie plot
    ax : matplotlib axes object
        If not passed, uses gca().
    sharex : bool or None, default None
        Whether to share the x axis or not.
    sharey : bool, default False
        Whether to share the y axis or not.
    figsize : tuple
        ``(width, height)`` in inches.
    use_index : bool, default True
        Use the index as ticks for the x axis.
    title : str or list
        Title of the plot. A string is printed at the top of the figure. A
        list, when `subplots` is True, gives one title per subplot.
    grid : bool, default None (matlab style default)
        Axis grid lines.
    legend : False/True/'reverse'
        Place legend on axis subplots.
    style : list or dict
        matplotlib line style per column.
    logx : bool, default False
        Use log scaling on the x axis.
    logy : bool, default False
        Use log scaling on the y axis.
    loglog : bool, default False
        Use log scaling on both the x and y axes.
    xticks : sequence
        Values to use for the xticks.
    yticks : sequence
        Values to use for the yticks.
    xlim : 2-tuple/list
    ylim : 2-tuple/list
    rot : int, default None
        Rotation for ticks (xticks for vertical, yticks for horizontal plots).
    fontsize : int, default None
        Font size for xticks and yticks.
    colormap : str or matplotlib colormap object, default None
        Colormap to select colors from. A string is looked up by name in
        matplotlib.
    table : bool, Series or DataFrame, default False
        If True, draw a table using the data in the DataFrame; the data is
        transposed to meet matplotlib's default layout. If a Series or
        DataFrame is passed, that data is used to draw the table.
    yerr : DataFrame, Series, array-like, dict and str
        See :ref:`Plotting with Error Bars <visualization.errorbars>` for
        detail.
    xerr : same types as yerr.
    secondary_y : bool or sequence of ints, default False
        If True then the y-axis will be on the right.
    sort_columns : bool, default False
        When True, will sort values on plots.
    `**kwds` : keywords
        Options to pass to the matplotlib plotting method. The following
        options are not explicit parameters of this function; when given
        they are forwarded through ``**kwds``:

        - colorbar : bool, optional. If True, plot a colorbar (only relevant
          for 'scatter' and 'hexbin' plots).
        - position : float. Relative alignment for bar plot layout, from 0
          (left/bottom-end) to 1 (right/top-end). Default is 0.5 (center).
        - label : label argument to provide to plot.
        - mark_right : bool, default True. When using a secondary_y axis,
          automatically mark the column labels with "(right)" in the legend.

    Returns
    -------
    axes : :class:`matplotlib.axes.Axes` or numpy.ndarray of them

    Notes
    -----

    - See matplotlib documentation online for more on this subject
    - If `kind` = 'bar' or 'barh', you can specify relative alignments
      for bar plot layout by `position` keyword.
      From 0 (left/bottom-end) to 1 (right/top-end). Default is 0.5 (center)
    """
    # Collect the named options once, in the same order in which they are
    # handed to ``_plot``, and forward them together with the extra keywords.
    options = dict(
        kind=kind, x=x, y=y, ax=ax,
        figsize=figsize, use_index=use_index, title=title,
        grid=grid, legend=legend, subplots=subplots,
        style=style, logx=logx, logy=logy, loglog=loglog,
        xticks=xticks, yticks=yticks, xlim=xlim, ylim=ylim,
        rot=rot, fontsize=fontsize, colormap=colormap, table=table,
        yerr=yerr, xerr=xerr, sharex=sharex, sharey=sharey,
        secondary_y=secondary_y, layout=layout, sort_columns=sort_columns,
    )
    return _plot(data, **options, **kwds)
60 changes: 60 additions & 0 deletions databricks/koalas/tests/test_frame_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import base64
from io import BytesIO

import matplotlib
from matplotlib import pyplot as plt
import pandas as pd

from databricks import koalas
from databricks.koalas.exceptions import PandasNotImplementedError
from databricks.koalas.testing.utils import ReusedSQLTestCase, TestUtils


# Select matplotlib's 'agg' backend for this test module; the tests below only
# render figures into in-memory PNG buffers and never display them.
matplotlib.use('agg')


class DataFramePlotTest(ReusedSQLTestCase, TestUtils):
    """Compare Koalas DataFrame plots against the equivalent pandas plots."""

    @property
    def pdf1(self):
        """A small two-column pandas frame with a non-unique integer index."""
        columns = {
            'a': [1, 2, 3, 4, 5, 6, 7, 8, 9, 15, 50],
            'b': [2, 3, 4, 5, 7, 9, 10, 15, 34, 45, 49]
        }
        row_labels = [0, 1, 3, 5, 6, 8, 9, 9, 9, 10, 10]
        return pd.DataFrame(columns, index=row_labels)

    @property
    def kdf1(self):
        """The Koalas counterpart of ``pdf1``."""
        pandas_frame = self.pdf1
        return koalas.from_pandas(pandas_frame)

    @staticmethod
    def plot_to_base64(ax):
        """Render the figure owning ``ax`` to PNG, close it, return base64 bytes."""
        figure = ax.figure
        buffer = BytesIO()
        figure.savefig(buffer, format='png')
        encoded = base64.b64encode(buffer.getvalue())
        plt.close(figure)
        return encoded

    def compare_plots(self, ax1, ax2):
        """Assert that both axes render to the same PNG image."""
        first = self.plot_to_base64(ax1)
        second = self.plot_to_base64(ax2)
        self.assert_eq(first, second)

    def test_line_plot(self):
        pdf = self.pdf1
        kdf = self.kdf1

        # Exercise both spellings of a line plot; the pandas plot is always
        # drawn before the Koalas one.
        drawers = (
            lambda frame: frame.plot(kind="line", colormap='Paired'),
            lambda frame: frame.plot.line(colormap='Paired'),
        )
        for draw in drawers:
            expected_ax = draw(pdf)
            actual_ax = draw(kdf)
            self.compare_plots(expected_ax, actual_ax)

    def test_missing(self):
        kdf = self.kdf1

        # Plot kinds that are expected to raise for DataFrame.
        unsupported_functions = ['area', 'bar', 'barh', 'box', 'density', 'hexbin',
                                 'hist', 'kde', 'pie', 'scatter']
        for name in unsupported_functions:
            expected_message = "method.*DataFrame.*{}.*not implemented".format(name)
            with self.assertRaisesRegex(PandasNotImplementedError, expected_message):
                getattr(kdf.plot, name)()
14 changes: 14 additions & 0 deletions docs/source/reference/frame.rst
Original file line number Diff line number Diff line change
Expand Up @@ -219,3 +219,17 @@ Serialization / IO / Conversion
DataFrame.to_clipboard
DataFrame.to_records
DataFrame.to_latex

.. _api.dataframe.plot:

Plotting
-------------------------------
``DataFrame.plot`` is both a callable method and a namespace attribute for
specific plotting methods of the form ``DataFrame.plot.<kind>``.

.. currentmodule:: databricks.koalas.frame
.. autosummary::
:toctree: api/

DataFrame.plot
DataFrame.plot.line