diff --git a/databricks/koalas/plot/core.py b/databricks/koalas/plot/core.py index a147369b5e..496d2ed71c 100644 --- a/databricks/koalas/plot/core.py +++ b/databricks/koalas/plot/core.py @@ -99,7 +99,7 @@ def set_result_text(self, ax): class HistogramPlotBase: @staticmethod def prepare_hist_data(data, bins): - # TODO: this logic is same with KdePlot. Might have to deduplicate it. + # TODO: this logic is similar with KdePlotBase. Might have to deduplicate it. from databricks.koalas.series import Series if isinstance(data, Series): @@ -339,19 +339,44 @@ def get_fliers(colname, outliers, min_val): class KdePlotBase: + @staticmethod + def prepare_kde_data(data): + # TODO: this logic is similar with HistogramPlotBase. Might have to deduplicate it. + from databricks.koalas.series import Series + + if isinstance(data, Series): + data = data.to_frame() + + numeric_data = data.select_dtypes( + include=["byte", "decimal", "integer", "float", "long", "double", np.datetime64] + ) + + # no empty frames or series allowed + if len(numeric_data.columns) == 0: + raise TypeError( + "Empty {0!r}: no numeric data to " "plot".format(numeric_data.__class__.__name__) + ) + + return numeric_data + @staticmethod def get_ind(sdf, ind): - # 'sdf' is a Spark DataFrame that selects one column. + def calc_min_max(): + if len(sdf.columns) > 1: + min_col = F.least(*map(F.min, sdf)) + max_col = F.greatest(*map(F.max, sdf)) + else: + min_col = F.min(sdf.columns[-1]) + max_col = F.max(sdf.columns[-1]) + return sdf.select(min_col, max_col).first() if ind is None: - min_val, max_val = sdf.select(F.min(sdf.columns[-1]), F.max(sdf.columns[-1])).first() - + min_val, max_val = calc_min_max() sample_range = max_val - min_val ind = np.linspace(min_val - 0.5 * sample_range, max_val + 0.5 * sample_range, 1000,) elif is_integer(ind): - min_val, max_val = sdf.select(F.min(sdf.columns[-1]), F.max(sdf.columns[-1])).first() - - sample_range = min_val - max_val + min_val, max_val = calc_min_max() + sample_range = max_val - min_val ind = np.linspace(min_val - 0.5 * sample_range, max_val + 0.5 * sample_range, ind,) return ind diff --git a/databricks/koalas/plot/matplotlib.py b/databricks/koalas/plot/matplotlib.py index f1792c7095..5e60121f4c 100644 --- a/databricks/koalas/plot/matplotlib.py +++ b/databricks/koalas/plot/matplotlib.py @@ -464,23 +464,7 @@ def _make_plot(self): class KoalasKdePlot(PandasKdePlot, KdePlotBase): def _compute_plot_data(self): - from databricks.koalas.series import Series - - data = self.data - if isinstance(data, Series): - data = data.to_frame() - - numeric_data = data.select_dtypes( - include=["byte", "decimal", "integer", "float", "long", "double", np.datetime64] - ) - - # no empty frames or series allowed - if len(numeric_data.columns) == 0: - raise TypeError( - "Empty {0!r}: no numeric data to " "plot".format(numeric_data.__class__.__name__) - ) - - self.data = numeric_data + self.data = KdePlotBase.prepare_kde_data(self.data) def _make_plot(self): # 'num_colors' requires to calculate `shape` which has to count all. diff --git a/databricks/koalas/plot/plotly.py b/databricks/koalas/plot/plotly.py index 1db8f1f19c..3d47a68713 100644 --- a/databricks/koalas/plot/plotly.py +++ b/databricks/koalas/plot/plotly.py @@ -22,6 +22,7 @@ name_like_string, KoalasPlotAccessor, BoxPlotBase, + KdePlotBase, ) if TYPE_CHECKING: @@ -38,6 +39,8 @@ def plot_koalas(data: Union["ks.DataFrame", "ks.Series"], kind: str, **kwargs): return plot_histogram(data, **kwargs) if kind == "box": return plot_box(data, **kwargs) + if kind == "kde" or kind == "density": + return plot_kde(data, **kwargs) # Other plots. return plotly.plot(KoalasPlotAccessor.pandas_plot_data_map[kind](data), kind, **kwargs) @@ -171,3 +174,38 @@ def plot_box(data: Union["ks.DataFrame", "ks.Series"], **kwargs): fig["layout"]["xaxis"]["title"] = colname fig["layout"]["yaxis"]["title"] = "value" return fig + + +def plot_kde(data: Union["ks.DataFrame", "ks.Series"], **kwargs): + from plotly import express + import databricks.koalas as ks + + if isinstance(data, ks.DataFrame) and "color" not in kwargs: + kwargs["color"] = "names" + + kdf = KdePlotBase.prepare_kde_data(data) + sdf = kdf._internal.spark_frame + data_columns = kdf._internal.data_spark_columns + ind = KdePlotBase.get_ind(sdf.select(*data_columns), kwargs.pop("ind", None)) + bw_method = kwargs.pop("bw_method", None) + + pdfs = [] + for label in kdf._internal.column_labels: + pdfs.append( + pd.DataFrame( + { + "Density": KdePlotBase.compute_kde( + sdf.select(kdf._internal.spark_column_for(label)), + ind=ind, + bw_method=bw_method, + ), + "names": name_like_string(label), + "index": ind, + } + ) + ) + pdf = pd.concat(pdfs) + + fig = express.line(pdf, x="index", y="Density", **kwargs) + fig["layout"]["xaxis"]["title"] = None + return fig diff --git a/databricks/koalas/tests/plot/test_frame_plot_plotly.py b/databricks/koalas/tests/plot/test_frame_plot_plotly.py index c48c3403e1..1199b12663 100644 --- a/databricks/koalas/tests/plot/test_frame_plot_plotly.py +++ b/databricks/koalas/tests/plot/test_frame_plot_plotly.py @@ -224,3 +224,31 @@ def check_hist_plot(kdf): columns = pd.MultiIndex.from_tuples([("x", "y"), ("y", "z")]) kdf1.columns = columns check_hist_plot(kdf1) + + def test_kde_plot(self): + kdf = ks.DataFrame({"a": [1, 2, 3, 4, 5], "b": [1, 3, 5, 7, 9], "c": [2, 4, 6, 8, 10]}) + + pdf = pd.DataFrame( + { + "Density": [ + 0.03515491, + 0.06834979, + 0.00663503, + 0.02372059, + 0.06834979, + 0.01806934, + 0.01806934, + 0.06834979, + 0.02372059, + ], + "names": ["a", "a", "a", "b", "b", "b", "c", "c", "c"], + "index": [-3.5, 5.5, 14.5, -3.5, 5.5, 14.5, -3.5, 5.5, 14.5], + } + ) + + actual = kdf.plot.kde(bw_method=5, ind=3) + + expected = express.line(pdf, x="index", y="Density", color="names") + expected["layout"]["xaxis"]["title"] = None + + self.assertEqual(pprint.pformat(actual.to_dict()), pprint.pformat(expected.to_dict())) diff --git a/databricks/koalas/tests/plot/test_series_plot_plotly.py b/databricks/koalas/tests/plot/test_series_plot_plotly.py index 4c015e20ea..17b0659fa1 100644 --- a/databricks/koalas/tests/plot/test_series_plot_plotly.py +++ b/databricks/koalas/tests/plot/test_series_plot_plotly.py @@ -206,3 +206,20 @@ def test_pox_plot_arguments(self): with self.assertRaisesRegex(ValueError, "does not support"): self.kdf1.a.plot.box(notched=True) self.kdf1.a.plot.box(hovertext="abc") # other arguments should not throw an exception + + def test_kde_plot(self): + kdf = ks.DataFrame({"a": [1, 2, 3, 4, 5]}) + pdf = pd.DataFrame( + { + "Density": [0.05709372, 0.07670272, 0.05709372], + "names": ["a", "a", "a"], + "index": [-1.0, 3.0, 7.0], + } + ) + + actual = kdf.a.plot.kde(bw_method=5, ind=3) + + expected = express.line(pdf, x="index", y="Density") + expected["layout"]["xaxis"]["title"] = None + + self.assertEqual(pprint.pformat(actual.to_dict()), pprint.pformat(expected.to_dict()))