Skip to content

Commit d30a29f

Browse files
committed
remove ml_matrics.utils.add_identity, use plt.axline instead https://git.io/JERaj
1 parent 9efee97 commit d30a29f

11 files changed

+102
-111
lines changed

assets/residual_vs_actual.svg

+1-1
Loading

ml_matrics/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,4 @@
1919
from .quantile import qq_gaussian
2020
from .ranking import err_decay
2121
from .relevance import precision_recall_curve, roc_curve
22-
from .utils import ROOT, add_identity, annotate_bar_heights
22+
from .utils import ROOT, annotate_bar_heights

ml_matrics/cumulative.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ def cum_res(preds: NumArray, targets: NumArray, ax: Axes = None) -> None:
2424
"""Plot the empirical cumulative distribution for the residuals (y - mu).
2525
2626
Args:
27-
preds (NumArray): Numpy array of predictions.
28-
targets (NumArray): Numpy array of targets.
27+
preds (array): Numpy array of predictions.
28+
targets (array): Numpy array of targets.
2929
ax (Axes, optional): plt.Axes object. Defaults to None.
3030
"""
3131
if ax is None:
@@ -65,8 +65,8 @@ def cum_err(preds: NumArray, targets: NumArray, ax: Axes = None) -> None:
6565
"""Plot the empirical cumulative distribution for the absolute errors abs(y - y_hat).
6666
6767
Args:
68-
preds (NumArray): Numpy array of predictions.
69-
targets (NumArray): Numpy array of targets.
68+
preds (array): Numpy array of predictions.
69+
targets (array): Numpy array of targets.
7070
ax (Axes, optional): plt.Axes object. Defaults to None.
7171
"""
7272
if ax is None:

ml_matrics/histograms.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ def residual_hist(
2424
Adapted from https://github.com/kaaiian/ML_figures (https://git.io/Jmb2O).
2525
2626
Args:
27-
y_true (NumArray): ground truth targets
28-
y_pred (NumArray): model predictions
27+
y_true (array): ground truth targets
28+
y_pred (array): model predictions
2929
ax (Axes, optional): plt.Axes object. Defaults to None.
3030
xlabel (str, optional): x-axis label. Defaults to None.
3131
@@ -68,9 +68,9 @@ def true_pred_hist(
6868
predictions in that bin. Overlayed by a more transparent histogram of ground truth values.
6969
7070
Args:
71-
y_true (NumArray): ground truth targets
72-
y_pred (NumArray): model predictions
73-
y_std (NumArray): model uncertainty
71+
y_true (array): ground truth targets
72+
y_pred (array): model predictions
73+
y_std (array): model uncertainty
7474
ax (Axes, optional): plt.Axes object. Defaults to None.
7575
cmap (str, optional): string identifier of a plt colormap. Defaults to "hot".
7676
bins (int, optional): Histogram resolution. Defaults to 50.
@@ -129,7 +129,7 @@ def spacegroup_hist(spacegroups: NumArray, ax: Axes = None, **kwargs: Any) -> Ax
129129
(triclinic, monoclinic, orthorhombic, tetragonal, trigonal, hexagonal, cubic)
130130
131131
Args:
132-
spacegroups (NumArray): A list of spacegroup numbers.
132+
spacegroups (array): A list of spacegroup numbers.
133133
ax (Axes, optional): plt.Axes object. Defaults to None.
134134
kwargs: Keywords passed to pd.Series.plot.bar().
135135

ml_matrics/metrics.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ def regression_metrics(
1919
TODO make robust by finding the common axis
2020
2121
Args:
22-
y_true (NumArray): Regression targets.
23-
y_preds (NumArray): Model predictions.
22+
y_true (array): Regression targets.
23+
y_preds (array): Model predictions.
2424
verbose (bool, optional): Whether to print metrics. Defaults to False.
2525
2626
Returns:
@@ -107,8 +107,8 @@ def classification_metrics(
107107
to multi-task automatically?
108108
109109
Args:
110-
target (NumArray): categorical encoding of the tasks
111-
logits (NumArray): logits predicted by the model
110+
target (array): categorical encoding of the tasks
111+
logits (array): logits predicted by the model
112112
verbose (bool, optional): Whether to print metrics. Defaults to False.
113113
"""
114114

ml_matrics/parity.py

+29-21
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from scipy.interpolate import interpn
1111
from sklearn.metrics import r2_score
1212

13-
from ml_matrics.utils import NumArray, add_identity, with_hist
13+
from ml_matrics.utils import NumArray, with_hist
1414

1515

1616
def hist_density(
@@ -19,8 +19,8 @@ def hist_density(
1919
"""Return an approximate density of 2d points.
2020
2121
Args:
22-
xs (NumArray): x-coordinates of points
23-
ys (NumArray): y-coordinates of points
22+
xs (array): x-coordinates of points
23+
ys (array): y-coordinates of points
2424
sort (bool, optional): Whether to sort points by density so that densest points
2525
are plotted last. Defaults to True.
2626
bins (int, optional): Number of bins (histogram resolution). Defaults to 100.
@@ -76,8 +76,8 @@ def density_scatter(
7676
"""Scatter plot colored (and optionally sorted) by density.
7777
7878
Args:
79-
xs (NumArray): x values.
80-
ys (NumArray): y values.
79+
xs (array): x values.
80+
ys (array): y values.
8181
ax (Axes, optional): plt.Axes object. Defaults to None.
8282
color_map (str, optional): plt color map or valid string name. Defaults to "Blues".
8383
sort (bool, optional): Whether to sort the data. Defaults to True.
@@ -102,8 +102,12 @@ def density_scatter(
102102
norm = mpl.colors.LogNorm() if log else None
103103

104104
ax.scatter(xs, ys, c=cs, cmap=color_map, norm=norm, **kwargs)
105+
105106
if identity:
106-
add_identity(ax, label="ideal")
107+
ax.axline(
108+
(0, 0), (1, 1), alpha=0.5, zorder=0, linestyle="dashed", color="black"
109+
)
110+
107111
if stats:
108112
add_mae_r2_box(xs, ys, ax)
109113

@@ -128,10 +132,10 @@ def scatter_with_err_bar(
128132
i.e. if points farther from the parity line have larger uncertainty.
129133
130134
Args:
131-
xs (NumArray): x-values
132-
ys (NumArray): y-values
133-
xerr (NumArray, optional): Horizontal error bars. Defaults to None.
134-
yerr (NumArray, optional): Vertical error bars. Defaults to None.
135+
xs (array): x-values
136+
ys (array): y-values
137+
xerr (array, optional): Horizontal error bars. Defaults to None.
138+
yerr (array, optional): Vertical error bars. Defaults to None.
135139
ax (Axes, optional): plt.Axes object. Defaults to None.
136140
xlabel (str, optional): x-axis label. Defaults to "Actual".
137141
ylabel (str, optional): y-axis label. Defaults to "Predicted".
@@ -145,7 +149,10 @@ def scatter_with_err_bar(
145149

146150
styles = dict(markersize=6, fmt="o", ecolor="g", capthick=2, elinewidth=2)
147151
ax.errorbar(xs, ys, yerr=yerr, xerr=xerr, **kwargs, **styles)
148-
add_identity(ax)
152+
153+
# identity line
154+
ax.axline((0, 0), (1, 1), alpha=0.5, zorder=0, linestyle="dashed", color="black")
155+
149156
add_mae_r2_box(xs, ys, ax)
150157

151158
ax.set(xlabel=xlabel, ylabel=ylabel, title=title)
@@ -166,10 +173,10 @@ def density_hexbin(
166173
dimension passed as weights.
167174
168175
Args:
169-
xs (NumArray): x values
170-
yx (NumArray): y values
176+
xs (array): x values
177+
yx (array): y values
171178
ax (Axes, optional): plt.Axes object. Defaults to None.
172-
weights (NumArray, optional): If given, these values are accumulated in the bins.
179+
weights (array, optional): If given, these values are accumulated in the bins.
173180
Otherwise, every point has value 1. Must be of the same length as x and y.
174181
Defaults to None.
175182
xlabel (str, optional): x-axis label. Defaults to "Actual".
@@ -188,7 +195,9 @@ def density_hexbin(
188195
plt.colorbar(hexbin, cax=cb_ax)
189196
cb_ax.yaxis.set_ticks_position("left")
190197

191-
add_identity(ax, label="ideal")
198+
# identity line
199+
ax.axline((0, 0), (1, 1), alpha=0.5, zorder=0, linestyle="dashed", color="black")
200+
192201
add_mae_r2_box(xs, yx, ax, loc="upper left")
193202

194203
ax.set(xlabel=xlabel, ylabel=ylabel)
@@ -235,8 +244,8 @@ def residual_vs_actual(y_true: NumArray, y_pred: NumArray, ax: Axes = None) -> A
235244
(y_err = y_true - y_pred) on the y-axis.
236245
237246
Args:
238-
y_true (NumArray): Ground truth values
239-
y_pred (NumArray): Model predictions
247+
y_true (array): Ground truth values
248+
y_pred (array): Model predictions
240249
ax (Axes, optional): plt.Axes object. Defaults to None.
241250
242251
Returns:
@@ -248,11 +257,10 @@ def residual_vs_actual(y_true: NumArray, y_pred: NumArray, ax: Axes = None) -> A
248257

249258
y_err = y_true - y_pred
250259

251-
xmin = np.min(y_true) * 0.9
252-
xmax = np.max(y_true) / 0.9
253-
254260
plt.plot(y_true, y_err, "o", alpha=0.5, label=None, mew=1.2, ms=5.2)
255-
plt.plot([xmin, xmax], [0, 0], "k--", alpha=0.5, label="ideal")
261+
plt.axline(
262+
[1, 0], [2, 0], linestyle="dashed", color="black", alpha=0.5, label="ideal"
263+
)
256264

257265
plt.ylabel(r"Residual ($y_\mathrm{test} - y_\mathrm{pred}$)")
258266
plt.xlabel("Actual value")

ml_matrics/quantile.py

+21-16
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,19 @@
22

33
import matplotlib.pyplot as plt
44
import numpy as np
5+
from matplotlib.axes import Axes
56
from scipy.stats import norm
67

7-
from ml_matrics.utils import NumArray, add_identity
8+
from ml_matrics.utils import NumArray
89

910

1011
def qq_gaussian(
11-
y_true: NumArray, y_pred: NumArray, y_std: Union[NumArray, Dict[str, NumArray]]
12+
y_true: NumArray,
13+
y_pred: NumArray,
14+
y_std: Union[NumArray, Dict[str, NumArray]],
15+
ax: Axes = None,
1216
) -> None:
13-
"""Plot the Gaussian quantile-quantile (Q-Q) plot of one (passed as NumArray)
17+
"""Plot the Gaussian quantile-quantile (Q-Q) plot of one (passed as array)
1418
or multiple (passed as dict) sets of uncertainty estimates for a single
1519
pair of ground truth targets `y_true` and model predictions `y_pred`.
1620
@@ -25,10 +29,13 @@ def qq_gaussian(
2529
Info on Q-Q plots: https://wikipedia.org/wiki/Q-Q_plot
2630
2731
Args:
28-
y_true (NumArray): ground truth targets
29-
y_pred (NumArray): model predictions
30-
y_std (NumArray | dict): model uncertainties
32+
y_true (array): ground truth targets
33+
y_pred (array): model predictions
34+
y_std (array | dict[str, array]): model uncertainties
3135
"""
36+
if ax is None:
37+
ax = plt.gca()
38+
3239
if isinstance(y_std, np.ndarray):
3340
y_std = {"std": y_std}
3441

@@ -38,35 +45,33 @@ def qq_gaussian(
3845
lines = [] # collect plotted lines to show second legend with miscalibration areas
3946
for key, std in y_std.items():
4047

41-
z_scored = (res / std).reshape(-1, 1)
48+
z_scored = (np.array(res) / std).reshape(-1, 1)
4249

4350
exp_proportions = np.linspace(0, 1, resolution)
4451
gaussian_upper_bound = norm.ppf(0.5 + exp_proportions / 2)
4552
obs_proportions = np.mean(z_scored <= gaussian_upper_bound, axis=0)
4653

47-
[line] = plt.plot(
54+
[line] = ax.plot(
4855
exp_proportions, obs_proportions, linewidth=2, alpha=0.8, label=key
4956
)
50-
plt.fill_between(
57+
ax.fill_between(
5158
exp_proportions, y1=obs_proportions, y2=exp_proportions, alpha=0.2
5259
)
5360
miscal_area = np.trapz(
5461
np.abs(obs_proportions - exp_proportions), dx=1 / resolution
5562
)
5663
lines.append([line, miscal_area])
5764

58-
add_identity(label="ideal")
59-
60-
plt.xlim(0, 1)
61-
plt.ylim(0, 1)
65+
# identity line
66+
ax.axline((0, 0), (1, 1), alpha=0.5, zorder=0, linestyle="dashed", color="black")
6267

63-
plt.xlabel("Theoretical Quantile")
64-
plt.ylabel("Observed Quantile")
68+
ax.set(xlim=(0, 1), ylim=(0, 1))
69+
ax.set(xlabel="Theoretical Quantile", ylabel="Observed Quantile")
6570

6671
legend1 = plt.legend(loc="upper left", frameon=False)
6772
# Multiple legends on the same axes:
6873
# https://matplotlib.org/3.3.3/tutorials/intermediate/legend_guide.html#multiple-legends-on-the-same-axes
69-
plt.gca().add_artist(legend1)
74+
ax.add_artist(legend1)
7075

7176
lines, areas = zip(*lines)
7277

ml_matrics/ranking.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,9 @@ def err_decay(
5454
similarly to how it decays when removing the predictions of largest error.
5555
5656
Args:
57-
y_true (NumArray): Ground truth regression targets.
58-
y_pred (NumArray): Model predictions.
59-
y_stds (NumArray | dict[str, NumArray]): Model uncertainties. Can be a single or
57+
y_true (array): Ground truth regression targets.
58+
y_pred (array): Model predictions.
59+
y_stds (array | dict[str, NumArray]): Model uncertainties. Can be a single or
6060
multiple types (e.g. aleatoric/epistemic/total uncertainty) in dict form.
6161
title (str, optional): Plot title. Defaults to None.
6262
n_rand (int, optional): Number of shuffles from which to compute std.dev.

ml_matrics/relevance.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ def roc_curve(
1515
the positive class.
1616
1717
Args:
18-
targets (NumArray): Ground truth targets.
19-
proba_pos (NumArray): predicted probabilities for the positive class.
18+
targets (array): Ground truth targets.
19+
proba_pos (array): predicted probabilities for the positive class.
2020
2121
Returns:
2222
float: The classifier's ROC area under the curve.
@@ -44,8 +44,8 @@ def precision_recall_curve(
4444
"""Plot the precision recall curve of a binary classifier.
4545
4646
Args:
47-
targets (NumArray): Ground truth targets.
48-
proba_pos (NumArray): predicted probabilities for the positive class.
47+
targets (array): Ground truth targets.
48+
proba_pos (array): predicted probabilities for the positive class.
4949
5050
Returns:
5151
float: The classifier's precision score.

ml_matrics/utils.py

+3-26
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from os.path import abspath, dirname
2-
from typing import Any, Sequence, Union
2+
from typing import Sequence, Union
33

44
import matplotlib.pyplot as plt
55
import numpy as np
@@ -14,29 +14,6 @@
1414
NumArray = NDArray[Union[np.float64, np.int_]]
1515

1616

17-
def add_identity(ax: Axes = None, **line_kwargs: Any) -> None:
18-
"""Add a parity line (y = x) to the provided axis."""
19-
if ax is None:
20-
ax = plt.gca()
21-
22-
# zorder=0 ensures other plotted data displays on top of line
23-
default_kwargs = dict(alpha=0.5, zorder=0, linestyle="dashed", color="black")
24-
(identity,) = ax.plot([], [], **default_kwargs, **line_kwargs)
25-
26-
def callback(axes: Axes) -> None:
27-
x_min, x_max = axes.get_xlim()
28-
y_min, y_max = axes.get_ylim()
29-
low = max(x_min, y_min)
30-
high = min(x_max, y_max)
31-
identity.set_data([low, high], [low, high])
32-
33-
callback(ax)
34-
# Register callbacks to update identity line when moving plots in interactive
35-
# mode to ensure line always extend to plot edges.
36-
ax.callbacks.connect("xlim_changed", callback)
37-
ax.callbacks.connect("ylim_changed", callback)
38-
39-
4017
def with_hist(
4118
xs: NumArray, ys: NumArray, cell: GridSpec = None, bins: int = 100 # type: ignore
4219
) -> Axes:
@@ -46,8 +23,8 @@ def with_hist(
4623
above and near the right edge.
4724
4825
Args:
49-
xs (NumArray): x values.
50-
ys (NumArray): y values.
26+
xs (array): x values.
27+
ys (array): y values.
5128
cell (GridSpec, optional): Cell of a plt GridSpec at which to add the
5229
grid of plots. Defaults to None.
5330
bins (int, optional): Resolution/bin count of the histograms. Defaults to 100.

0 commit comments

Comments
 (0)