Skip to content

Commit 15ffaba

Browse files
committed
Support x and y properly in plots (both matplotlib and plotly)
1 parent eda0bb5 commit 15ffaba

File tree

4 files changed

+27
-1
lines changed

4 files changed

+27
-1
lines changed

databricks/koalas/plot/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def prepare_hist_data(data, bins):
117117

118118
if is_integer(bins):
119119
# computes boundaries for the column
120-
bins = HistogramPlotBase.get_bins(data.to_spark(), bins)
120+
bins = HistogramPlotBase.get_bins(numeric_data.to_spark(), bins)
121121

122122
return numeric_data, bins
123123

databricks/koalas/plot/matplotlib.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from pandas.core.dtypes.inference import is_list_like
2424
from pandas.io.formats.printing import pprint_thing
2525

26+
from databricks.koalas import Series
2627
from databricks.koalas.plot import (
2728
TopNPlotBase,
2829
SampledPlotBase,
@@ -373,6 +374,12 @@ class KoalasHistPlot(PandasHistPlot, HistogramPlotBase):
373374
def _args_adjust(self):
374375
if is_list_like(self.bottom):
375376
self.bottom = np.array(self.bottom)
377+
y = self.kwds.get("y")
378+
if y:
379+
if isinstance(self.data, Series):
380+
self.data = self.data.to_frame()
381+
# When y is explicitly specified, we can drop other columns.
382+
self.data = self.data[[y]]
376383

377384
def _compute_plot_data(self):
378385
self.data, self.bins = HistogramPlotBase.prepare_hist_data(self.data, self.bins)

databricks/koalas/plot/plotly.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,12 @@ def plot_histogram(data: Union["ks.DataFrame", "ks.Series"], **kwargs):
7474
import plotly.graph_objs as go
7575

7676
bins = kwargs.get("bins", 10)
77+
x = kwargs.get("x", None)
78+
if x:
79+
if isinstance(data, ks.Series):
80+
data = data.to_frame()
81+
# Note that this is the opposite of matplotlib, which takes y as the bars.
82+
data = data[[x]]
7783
kdf, bins = HistogramPlotBase.prepare_hist_data(data, bins)
7884
assert len(bins) > 2, "the number of buckets must be higher than 2."
7985
output_series = HistogramPlotBase.compute_hist(kdf, bins)

databricks/koalas/tests/plot/test_frame_plot_matplotlib.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,19 @@ def check_hist_plot(pdf, kdf):
402402
bin2 = self.plot_to_base64(ax2)
403403
self.assertEqual(bin1, bin2)
404404

405+
non_numeric_pdf = self.pdf1.copy()
406+
non_numeric_pdf.c = ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k"]
407+
non_numeric_kdf = ks.from_pandas(non_numeric_pdf)
408+
ax1 = non_numeric_pdf.plot.hist(
409+
x=non_numeric_pdf.columns[0], y=non_numeric_pdf.columns[1], bins=3
410+
)
411+
bin1 = self.plot_to_base64(ax1)
412+
ax2 = non_numeric_kdf.plot.hist(
413+
x=non_numeric_pdf.columns[0], y=non_numeric_pdf.columns[1], bins=3
414+
)
415+
bin2 = self.plot_to_base64(ax2)
416+
self.assertEqual(bin1, bin2)
417+
405418
pdf1 = self.pdf1
406419
kdf1 = self.kdf1
407420
check_hist_plot(pdf1, kdf1)

0 commit comments

Comments
 (0)