Skip to content

Scatter matrix diagonals #1237

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions doc/source/visualization.rst
Original file line number Diff line number Diff line change
Expand Up @@ -241,5 +241,8 @@ Scatter plot matrix
from pandas.tools.plotting import scatter_matrix
df = DataFrame(np.random.randn(1000, 4), columns=['a', 'b', 'c', 'd'])

@savefig scatter_matrix_ex.png width=6in
scatter_matrix(df, alpha=0.2, figsize=(8, 8))
@savefig scatter_matrix_kde.png width=6in
scatter_matrix(df, alpha=0.2, figsize=(8, 8), diagonal='kde')

@savefig scatter_matrix_hist.png width=6in
scatter_matrix(df, alpha=0.2, figsize=(8, 8), diagonal='hist')
2 changes: 2 additions & 0 deletions pandas/tests/test_graphics.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,8 @@ def scat(**kwds):
_check_plot_works(scat)
_check_plot_works(scat, marker='+')
_check_plot_works(scat, vmin=0)
_check_plot_works(scat, diagonal='kde')
_check_plot_works(scat, diagonal='hist')

def scat2(x, y, by=None, ax=None, figsize=None):
return plt.scatter_plot(df, x, y, by, ax, figsize=None)
Expand Down
32 changes: 22 additions & 10 deletions pandas/tools/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from itertools import izip

import numpy as np
from scipy import stats

from pandas.util.decorators import cache_readonly
import pandas.core.common as com
Expand All @@ -12,12 +13,19 @@
from pandas.tseries.offsets import DateOffset

def scatter_matrix(frame, alpha=0.5, figsize=None, ax=None, grid=False,
**kwds):
diagonal='hist', **kwds):
"""
Draw a matrix of scatter plots.

Parameters
----------
alpha : amount of transparency applied
figsize : a tuple (width, height) in inches
ax : Matplotlib axis object
grid : setting this to True will show the grid
diagonal : pick between 'kde' and 'hist' for
either Kernel Density Estimation or Histogram
plon in the diagonal
kwds : other plotting keyword arguments
To be passed to scatter function

Expand All @@ -36,15 +44,26 @@ def scatter_matrix(frame, alpha=0.5, figsize=None, ax=None, grid=False,

for i, a in zip(range(n), df.columns):
for j, b in zip(range(n), df.columns):
axes[i, j].scatter(df[b], df[a], alpha=alpha, **kwds)
if i == j:
# Deal with the diagonal by drawing a histogram there.
if diagonal == 'hist':
axes[i, j].hist(df[a])
elif diagonal == 'kde':
y = df[a]
gkde = stats.gaussian_kde(y)
ind = np.linspace(min(y), max(y), 1000)
axes[i, j].plot(ind, gkde.evaluate(ind), **kwds)
else:
axes[i, j].scatter(df[b], df[a], alpha=alpha, **kwds)

axes[i, j].set_xlabel('')
axes[i, j].set_ylabel('')
axes[i, j].set_xticklabels([])
axes[i, j].set_yticklabels([])
ticks = df.index

is_datetype = ticks.inferred_type in ('datetime', 'date',
'datetime64')
'datetime64')

if ticks.is_numeric() or is_datetype:
"""
Expand Down Expand Up @@ -87,13 +106,6 @@ def scatter_matrix(frame, alpha=0.5, figsize=None, ax=None, grid=False,

axes[i, j].grid(b=grid)

# ensure {x,y}lim off diagonal are the same as diagonal
for i in range(n):
for j in range(n):
if i != j:
axes[i, j].set_xlim(axes[j, j].get_xlim())
axes[i, j].set_ylim(axes[i, i].get_ylim())

return axes

def _gca():
Expand Down