Skip to content

Commit 0ea0ef6

Browse files
committed
add true_pred_hist to histograms.py
1 parent 39a3b15 commit 0ea0ef6

File tree

9 files changed

+108
-18
lines changed

9 files changed

+108
-18
lines changed

assets/true_pred_hist.svg

+1
Loading

mlmatrics/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
ptable_elemental_prevalence,
77
ptable_elemental_ratio,
88
)
9-
from .histograms import residual_hist
9+
from .histograms import residual_hist, true_pred_hist
1010
from .metrics import regression_metrics
1111
from .parity import (
1212
density_hexbin,

mlmatrics/elements.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def ptable_elemental_prevalence(
4545
"""Display the prevalence of each element in a materials dataset plotted as a
4646
heatmap over the periodic table. `formulas` xor `elem_counts` must be passed.
4747
48-
Adapted from https://github.com/kaaiian/ML_figures.
48+
Adapted from https://github.com/kaaiian/ML_figures (https://git.io/JmbaI).
4949
5050
Args:
5151
formulas (list[str]): compositional strings, e.g. ["Fe2O3", "Bi2Te3"]
@@ -240,7 +240,7 @@ def hist_elemental_prevalence(
240240
) -> None:
241241
"""Plots a histogram of the prevalence of each element in a materials dataset.
242242
243-
Adapted from https://github.com/kaaiian/ML_figures.
243+
Adapted from https://github.com/kaaiian/ML_figures (https://git.io/JmbaI).
244244
245245
Args:
246246
formulas (list): compositional strings, e.g. ["Fe2O3", "Bi2Te3"]

mlmatrics/histograms.py

+83-2
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,34 @@
11
import matplotlib.pyplot as plt
22
import numpy as np
33
from matplotlib.axes import Axes
4+
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
45
from numpy import ndarray as Array
56
from scipy.stats import gaussian_kde
67

78

89
def residual_hist(
9-
y_true: Array, y_pred: Array, ax: Axes = None, xlabel: str = None
10+
y_true: Array, y_pred: Array, ax: Axes = None, xlabel: str = None, **kwargs
1011
) -> Axes:
12+
"""Plot the residual distribution overlayed with a Gaussian kernel
13+
density estimate.
14+
15+
Adapted from https://github.com/kaaiian/ML_figures (https://git.io/Jmb2O).
16+
17+
Args:
18+
y_true (Array): ground truth targets
19+
y_pred (Array): model predictions
20+
ax (Axes, optional): plt axes. Defaults to None.
21+
xlabel (str, optional): x-axis label. Defaults to None.
22+
23+
Returns:
24+
Axes: plt axes with plotted data.
25+
"""
1126

1227
if ax is None:
1328
ax = plt.gca()
1429

1530
y_res = y_pred - y_true
16-
plt.hist(y_res, bins=35, density=True, edgecolor="black")
31+
plt.hist(y_res, bins=35, density=True, edgecolor="black", **kwargs)
1732

1833
# Gaussian kernel density estimation: evaluates the Gaussian
1934
# probability density estimated based on the points in y_res
@@ -27,3 +42,69 @@ def residual_hist(
2742
plt.legend(loc=2, framealpha=0.5, handlelength=1)
2843

2944
return ax
45+
46+
47+
def true_pred_hist(
    y_true: Array,
    y_pred: Array,
    y_std: Array,
    ax: Axes = None,
    cmap: str = "hot",
    bins: int = 50,
    log: bool = True,
    truth_color: str = "blue",
    **kwargs,
) -> Axes:
    """Plot a histogram of model predictions with bars colored by mean uncertainty.

    Each bar of the y_pred histogram is colored by the average y_std of the
    predictions that fall into that bin. A more transparent histogram of the
    ground truth values, drawn with the same bin edges, is overlaid on top.
    A colorbar spanning [y_std.min(), y_std.max()] is added as an inset and
    the bar colors are mapped through that same normalization.

    Args:
        y_true (Array): ground truth targets
        y_pred (Array): model predictions
        y_std (Array): model uncertainty, one value per entry in y_pred
        ax (Axes, optional): plt axes. Defaults to None.
        cmap (str, optional): string identifier of a plt colormap. Defaults to "hot".
        bins (int, optional): Histogram resolution. Defaults to 50.
        log (bool, optional): Whether to log-scale the y-axis. Defaults to True.
        truth_color (str, optional): Face color to use for y_true bars.
            Defaults to "blue".
        **kwargs: forwarded to both ax.hist() calls.

    Returns:
        Axes: plt axes with plotted data.
    """
    if ax is None:
        ax = plt.gca()

    color_map = plt.get_cmap(cmap)
    # Convert each input separately: stacking them into a single array would
    # force a common dtype and require all three to have the same length.
    y_true, y_pred, y_std = (np.asarray(arr) for arr in (y_true, y_pred, y_std))

    # One normalization shared by the bars and the colorbar so that a bar's
    # color can actually be read off the colorbar.
    norm = plt.Normalize(vmin=y_std.min(), vmax=y_std.max())

    _, bin_edges, bars = ax.hist(
        y_pred, bins=bins, alpha=0.8, label=r"$y_\mathrm{pred}$", **kwargs
    )
    # Reuse the bin edges computed for y_pred so both histograms line up.
    ax.hist(
        y_true,
        bins=bin_edges,
        alpha=0.2,
        color=truth_color,
        label=r"$y_\mathrm{true}$",
        **kwargs,
    )

    last_bin = len(bars.patches) - 1
    for idx, (xmin, xmax, rect) in enumerate(
        zip(bin_edges, bin_edges[1:], bars.patches)
    ):
        # Histogram bins are half-open [xmin, xmax) except for the last one,
        # which also includes its right edge. Mirror that here so values lying
        # exactly on an edge (e.g. the min and max prediction) are not dropped.
        below_upper = y_pred <= xmax if idx == last_bin else y_pred < xmax
        in_bin = np.logical_and(y_pred >= xmin, below_upper)

        if not in_bin.any():
            # Empty bin: the bar has zero height and there is no uncertainty
            # to average, so leave it alone instead of computing a NaN mean.
            continue

        # Colormaps expect floats in [0, 1], hence the normalization.
        rect.set_color(color_map(norm(y_std[in_bin].mean())))

    if log:
        # Scale the axes that was drawn on, which may not be the current one.
        ax.set_yscale("log")
    ax.legend(frameon=False)

    cb_ax = inset_axes(ax, width="3%", height="50%", loc="center right")
    ax.figure.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=color_map), cax=cb_ax)
    cb_ax.yaxis.set_ticks_position("left")

    return ax

mlmatrics/parity.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def density_scatter(
9090
Defaults to True.
9191
9292
Returns:
93-
Axes: plt axes containing the plot.
93+
Axes: plt axes with plotted data.
9494
"""
9595
if ax is None:
9696
ax = plt.gca()
@@ -137,7 +137,7 @@ def scatter_with_err_bar(
137137
title (str, optional): Plot title. Defaults to None.
138138
139139
Returns:
140-
Axes: plt axes on which the data was plotted.
140+
Axes: plt axes with plotted data.
141141
"""
142142
if ax is None:
143143
ax = plt.gca()
@@ -169,7 +169,7 @@ def density_hexbin(
169169

170170
# the scatter plot
171171
hexbin = ax.hexbin(targets, preds, gridsize=75, mincnt=1, bins="log", C=color_map)
172-
cb_ax = inset_axes(ax, width="3%", height="70%", loc=4)
172+
cb_ax = inset_axes(ax, width="3%", height="70%", loc="lower right")
173173
plt.colorbar(hexbin, cax=cb_ax)
174174
cb_ax.yaxis.set_ticks_position("left")
175175

readme.md

+8-8
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ numpy==1.20.1
3232
git+git://github.com/janosh/mlmatrics
3333
```
3434

35-
To specify a certain branch or commit, append it's name or hash, e.g.
35+
To select a particular branch or commit, append its name or hash, e.g.
3636

3737
```txt
3838
git+git://github.com/janosh/mlmatrics@master # default
@@ -107,9 +107,9 @@ See [`mlmatrics/correlation.py`](mlmatrics/correlation.py).
107107

108108
See [`mlmatrics/histograms.py`](mlmatrics/histograms.py).
109109

110-
| [`residual_hist(y_true, y_pred)`](mlmatrics/histograms.py) | |
111-
| :--------------------------------------------------------: | :---: |
112-
| ![residual_hist](assets/residual_hist.svg) | |
110+
| [`residual_hist(y_true, y_pred)`](mlmatrics/histograms.py) | [`true_pred_hist(y_true, y_pred, y_std)`](mlmatrics/histograms.py) |
111+
| :--------------------------------------------------------: | :----------------------------------------------------------------: |
112+
| ![residual_hist](assets/residual_hist.svg) | ![true_pred_hist](assets/true_pred_hist.svg) |
113113

114114
## Adding Assets
115115

@@ -130,7 +130,7 @@ python -m pytest tests/test_cumulative.py
130130
python -m pytest **/test_*_metrics.py
131131
```
132132

133-
You can also run single tests by passing its name to the `-k` flag:
133+
To run a single test, pass its name to the `-k` flag:
134134

135135
```sh
136136
python -m pytest -k test_precision_recall_curve
@@ -140,6 +140,6 @@ Consult the [`pytest`](https://docs.pytest.org/en/stable/usage.html) docs for mo
140140

141141
## Glossary
142142

143-
1. **Residual** `y - y_hat`: The difference between ground truth target and model prediction.
144-
2. **Error** `abs(y - y_hat)`: Absolute error between target and model prediction.
145-
3. **Uncertainty** `y_std`: The model's estimate for its own error, i.e. how much the model thinks its prediction can be trusted. (`std` for standard deviation.)
143+
1. **Residual** `y_res = y_true - y_pred`: The difference between ground truth target and model prediction.
144+
2. **Error** `y_err = abs(y_true - y_pred)`: Absolute error between target and model prediction.
145+
3. **Uncertainty** `y_std`: The model's estimate for its error, i.e. how much the model thinks its prediction can be trusted. (`std` for standard deviation.)

scripts/plot_all.py

+4
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
residual_vs_actual,
2323
roc_curve,
2424
scatter_with_err_bar,
25+
true_pred_hist,
2526
)
2627

2728
plt.rcParams.update({"font.size": 20})
@@ -158,6 +159,9 @@ def savefig(filename: str) -> None:
158159
residual_hist(y_true, y_pred)
159160
savefig("residual_hist")
160161

162+
true_pred_hist(y_true, y_pred, y_std)
163+
savefig("true_pred_hist")
164+
161165

162166
# %% Correlation Plots
163167
rand_wide_mat = pd.read_csv(f"{ROOT}/data/rand_wide_matrix.csv", header=None).to_numpy()

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
version="0.0.1",
66
author="Janosh Riebesell",
77
author_email="[email protected]",
8-
description="A collection of plots useful in data-driven research of materials",
8+
description="A collection of plots useful in data-driven materials science",
99
long_description=open("readme.md").read(),
1010
long_description_content_type="text/markdown",
1111
url="https://github.com/janosh/mlmatrics",

tests/test_histograms.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1-
from mlmatrics import residual_hist
1+
from mlmatrics import residual_hist, true_pred_hist
22

33
from . import y_pred, y_true
44

55

66
def test_residual_hist():
77
residual_hist(y_true, y_pred)
8+
9+
10+
def test_true_pred_hist():
11+
true_pred_hist(y_true, y_pred, y_true - y_pred)

0 commit comments

Comments
 (0)