Skip to content

Commit 6e36278

Browse files
committed
add cumulative F1 score to cumulative precision recall plot
update figures/2022-12-05-precision-recall-curves.svelte and render with neg margins to increase width
1 parent 359dfae commit 6e36278

File tree

10 files changed

+96
-41
lines changed

10 files changed

+96
-41
lines changed
27.6 KB
Loading

data/wbm/readme.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ The full set of processing steps used to curate the WBM test set from the raw da
2020
- correctly aligning initial structures to DFT-relaxed [`ComputedStructureEntries`](https://pymatgen.org/pymatgen.entries.computed_entries.html#pymatgen.entries.computed_entries.ComputedStructureEntry)
2121
- remove 6 pathological structures (with 0 volume)
2222
- remove formation energy outliers below -5 and above 5 eV/atom (502 and 22 crystals respectively out of 257,487 total, including an anomaly of 500 structures at exactly -10 eV/atom)
23-
<!-- ![WBM formation energy histogram indicating outlier cutoffs](2022-12-07-hist-e-form-per-atom.png) -->
23+
![WBM formation energy histogram indicating outlier cutoffs](2022-12-07-hist-e-form-per-atom.png)
2424
- apply the [`MaterialsProject2020Compatibility`](https://pymatgen.org/pymatgen.entries.compatibility.html#pymatgen.entries.compatibility.MaterialsProject2020Compatibility) energy correction scheme to the formation energies
2525
- compute energy to the Materials Project convex hull constructed from all MP `ComputedStructureEntries` queried on 2022-09-16 ([database release 2021.05.13](https://docs.materialsproject.org/changes/database-versions#v2021.05.13))
2626

matbench_discovery/plots.py

+44-20
Original file line numberDiff line numberDiff line change
@@ -489,8 +489,8 @@ def cumulative_precision_recall(
489489
tuple[plt.Figure | go.Figure, pd.DataFrame]: The matplotlib/plotly figure and
490490
dataframe of cumulative metrics for each model.
491491
"""
492-
fact = lambda: pd.DataFrame(index=range(len(e_above_hull_true)))
493-
dfs = dict(Precision=fact(), Recall=fact())
492+
factory = lambda: pd.DataFrame(index=range(len(e_above_hull_true)))
493+
dfs = dict(Precision=factory(), Recall=factory(), F1=factory())
494494

495495
for model_name in df_preds:
496496
model_preds = df_preds[model_name].sort_values()
@@ -502,36 +502,43 @@ def cumulative_precision_recall(
502502

503503
true_pos_cumsum = true_pos.cumsum()
504504
# precision aka positive predictive value (PPV)
505-
precision = true_pos_cumsum / (true_pos_cumsum + false_pos.cumsum())
505+
precision_cum = true_pos_cumsum / (true_pos_cumsum + false_pos.cumsum())
506506
n_total_pos = sum(true_pos) + sum(false_neg)
507-
recall = true_pos_cumsum / n_total_pos # aka true_pos_rate aka sensitivity
507+
recall_cum = true_pos_cumsum / n_total_pos # aka true_pos_rate aka sensitivity
508508

509-
end = int(np.argmax(recall))
509+
end = int(np.argmax(recall_cum))
510510
xs = np.arange(end)
511511

512-
prec_interp = scipy.interpolate.interp1d(xs, precision[:end], kind="cubic")
513-
recall_interp = scipy.interpolate.interp1d(xs, recall[:end], kind="cubic")
512+
# cumulative F1 score
513+
f1_cum = 2 * (precision_cum * recall_cum) / (precision_cum + recall_cum)
514+
515+
prec_interp = scipy.interpolate.interp1d(xs, precision_cum[:end], kind="cubic")
516+
recall_interp = scipy.interpolate.interp1d(xs, recall_cum[:end], kind="cubic")
517+
f1_interp = scipy.interpolate.interp1d(xs, f1_cum[:end], kind="cubic")
514518
dfs["Precision"][model_name] = pd.Series(prec_interp(xs))
515519
dfs["Recall"][model_name] = pd.Series(recall_interp(xs))
520+
dfs["F1"][model_name] = pd.Series(f1_interp(xs))
516521

517522
for key, df in dfs.items():
518523
# drop all-NaN rows so plotly plot x-axis only extends to largest number of
519524
# predicted materials by any model
520525
df.dropna(how="all", inplace=True)
526+
# will be used as facet_col in plotly to split different metrics into subplots
521527
df["metric"] = key
522528

523529
df_cum = pd.concat(dfs.values())
530+
# subselect rows for speed, plot has sufficient precision with 1k rows
531+
df_cum = df_cum.iloc[:: len(df_cum) // 1000 or 1]
524532

525533
if backend == "matplotlib":
526-
fig, axs = plt.subplots(1, 2, figsize=(15, 7), sharey=True)
534+
fig, axs = plt.subplots(1, len(dfs), figsize=(15, 7), sharey=True)
527535
line_kwargs = dict(
528536
linewidth=3, markevery=[-1], marker="x", markersize=14, markeredgewidth=2.5
529537
)
530538
for (key, df), ax in zip(dfs.items(), axs):
531539
# select every n-th row of df so that 1000 rows are left for increased
532540
# plotting speed and reduced file size
533541
# falls back on every row if df has less than 1000 rows
534-
535542
df.iloc[:: len(df) // 1000 or 1].plot(
536543
ax=ax, legend=False, backend=backend, **line_kwargs | kwargs, ylabel=key
537544
)
@@ -541,9 +548,12 @@ def cumulative_precision_recall(
541548
bbox = dict(facecolor="white", alpha=0.5, edgecolor="none")
542549
assert len(axs) == len(dfs), f"{len(axs)} != {len(dfs)}"
543550

544-
for ax, df in zip(axs, dfs.values()):
551+
for ax, (key, df) in zip(axs.flat, dfs.items()):
545552
ax.set(ylim=(0, 1), xlim=(0, None), ylabel=key)
546553
for model in df_preds:
554+
# TODO is this if really necessary?
555+
if len(df[model].dropna()) == 0:
556+
continue
547557
x_end = df[model].dropna().index[-1]
548558
y_end = df[model].dropna().iloc[-1]
549559
# place model name at the end of every line
@@ -556,11 +566,12 @@ def cumulative_precision_recall(
556566
# optimal recall line finds all stable materials without any false positives
557567
# can be included to confirm all models start out of with near optimal recall
558568
# and to see how much each model overshoots total n_stable
559-
n_below_hull = sum(e_above_hull_true < 0)
560569
if show_optimal:
570+
ax = next(filter(lambda ax: ax.get_ylabel() == "Recall", axs.flat))
571+
n_below_hull = sum(e_above_hull_true < 0)
561572
opt_label = "Optimal Recall"
562-
axs[1].plot([0, n_below_hull], [0, 1], color="green", linestyle="--")
563-
axs[1].text(
573+
ax.plot([0, n_below_hull], [0, 1], color="green", linestyle="--")
574+
ax.text(
564575
*[n_below_hull, 0.81],
565576
opt_label,
566577
color="green",
@@ -571,16 +582,29 @@ def cumulative_precision_recall(
571582
)
572583

573584
elif backend == "plotly":
574-
fig = df_cum.iloc[:: len(df_cum) // 1000 or 1].plot(
575-
backend=backend, facet_col="metric", **kwargs
585+
fig = df_cum.plot(
586+
backend=backend,
587+
facet_col="metric",
588+
facet_col_wrap=3,
589+
facet_col_spacing=0.03,
590+
# pivot df in case we want to show all 3 metrics in each plot's hover
591+
# requires fixing index mismatch due to df subsampling above
592+
# customdata=dict(
593+
# df_cum.reset_index()
594+
# .pivot(index="index", columns="metric")["Voronoi RF above hull pred"]
595+
# .items()
596+
# ),
597+
**kwargs,
576598
)
577599
fig.update_traces(line=dict(width=4))
578-
for idx in range(1, 3):
579-
fig.update_xaxes(
580-
title_text="Number of materials predicted stable", row=1, col=idx
600+
for idx, metric in enumerate(df_cum.metric.unique(), 1):
601+
x_axis_label = "Number of materials predicted stable" if idx == 2 else ""
602+
fig.update_xaxes(title=x_axis_label, col=idx)
603+
fig.update_yaxes(title=dict(text=metric, standoff=0), col=idx)
604+
fig.update_traces(
605+
hovertemplate=f"Index = %{{x:d}}<br>{metric} = %{{y:.2f}}",
606+
col=idx, # model = %{customdata[0]}<br>
581607
)
582-
fig.update_yaxes(title="Precision", col=1)
583-
fig.update_yaxes(title="Recall", col=2)
584608
fig.for_each_annotation(lambda a: a.update(text=""))
585609
fig.update_layout(legend=dict(title=""))
586610
fig.update_layout(showlegend=False)

scripts/precision_recall.py

+7-10
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# %%
2-
from sklearn.metrics import f1_score
2+
import pandas as pd
33

44
from matbench_discovery import ROOT, today
55
from matbench_discovery.data import load_df_wbm_with_preds
@@ -11,8 +11,8 @@
1111

1212
# %%
1313
models = (
14-
"Wren, CGCNN IS2RE, CGCNN RS2RE, Voronoi RF, "
15-
"Wrenformer, MEGNet, M3GNet, BOWSR MEGNet"
14+
# Wren, CGCNN IS2RE, CGCNN RS2RE
15+
"Voronoi RF, Wrenformer, MEGNet, M3GNet, BOWSR MEGNet"
1616
).split(", ")
1717

1818
df_wbm = load_df_wbm_with_preds(models=models).round(3)
@@ -23,21 +23,19 @@
2323

2424

2525
# %%
26+
df_e_above_hull_pred = pd.DataFrame()
2627
for model in models:
27-
pred_col = f"{model}_e_form"
28-
F1 = f1_score(df_wbm[e_above_hull_col] < 0, df_wbm[model] < 0)
29-
plot_label = f"{model} {F1=:.2}"
30-
df_wbm[plot_label] = df_wbm[e_above_hull_col] + df_wbm[model] - df_wbm[target_col]
28+
e_above_hul_pred = df_wbm[e_above_hull_col] + df_wbm[model] - df_wbm[target_col]
29+
df_e_above_hull_pred[model] = e_above_hul_pred
3130

3231
fig, df_metric = cumulative_precision_recall(
3332
e_above_hull_true=df_wbm[e_above_hull_col],
34-
df_preds=df_wbm.filter(like="F1="),
33+
df_preds=df_e_above_hull_pred,
3534
project_end_point="xy",
3635
backend=(backend := "plotly"),
3736
show_optimal=True,
3837
)
3938

40-
4139
title = f"{today} - Cumulative Precision and Recall for Stable Materials"
4240
# xlabel_cumulative = "Materials predicted stable sorted by hull distance"
4341
if backend == "matplotlib":
@@ -46,7 +44,6 @@
4644
elif backend == "plotly":
4745
fig.update_layout(title=title)
4846

49-
5047
fig.show()
5148

5249

site/.eslintrc.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ settings:
1818
svelte3/typescript: true
1919
rules:
2020
indent: [error, 2, SwitchCase: 1]
21-
# '@typescript-eslint/quotes': [error, backtick, avoidEscape: true]
21+
'@typescript-eslint/quotes': [error, backtick, avoidEscape: true]
2222
semi: [error, never]
2323
linebreak-style: [error, unix]
2424
no-console: [error, allow: [warn, error]]

site/src/routes/+layout.svelte

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import '../app.css'
88
99
const routes = Object.keys(import.meta.glob(`./*/+page.{svx,svelte,md}`)).map(
10-
(filename) => '/' + filename.split(`/`)[1]
10+
(filename) => `/` + filename.split(`/`)[1]
1111
)
1212
</script>
1313

site/src/routes/+page.svelte

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
<script lang="ts">
2-
import Plot from '$root/figures/2022-12-05-precision-recall-curves.svelte'
2+
import Plot from '$root/figures/2022-12-25-precision-recall-curves.svelte'
33
import Readme from '$root/readme.md'
44
</script>
55

66
<Readme />
77

8-
{#if typeof document !== 'undefined'}
9-
<Plot />
8+
{#if typeof document !== `undefined`}
9+
<Plot style="margin: 0 -5vw;" />
1010
{/if}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
<script lang="ts">
2+
import DataReadme from '$root/data/wbm/readme.md'
3+
import { onMount } from 'svelte'
4+
5+
const figs = import.meta.glob(`$root/data/wbm/*.{png,svg,pdf}`, {
6+
eager: true,
7+
as: `url`,
8+
})
9+
10+
onMount(() => {
11+
for (const img of document.querySelectorAll(`img`)) {
12+
const src = img.getAttribute(`src`)
13+
if (figs[`../data/wbm/${src}`]) {
14+
img.src = figs[`../data/wbm/${src}`]
15+
}
16+
}
17+
})
18+
</script>
19+
20+
<main>
21+
<DataReadme />
22+
</main>
23+
24+
<style>
25+
:global(img) {
26+
max-width: 100%;
27+
margin: 1em auto;
28+
}
29+
</style>

site/svelte.config.js

+4
Original file line numberDiff line numberDiff line change
@@ -33,5 +33,9 @@ export default {
3333

3434
kit: {
3535
adapter: adapter(),
36+
37+
prerender: {
38+
handleHttpError: `warn`,
39+
},
3640
},
3741
}

tests/test_plots.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,12 @@ def test_cumulative_precision_recall(
4444

4545
if backend == "matplotlib":
4646
assert isinstance(fig, plt.Figure)
47-
ax1, ax2 = fig.axes
48-
assert ax1.get_ylim() == ax2.get_ylim() == (0, 1)
49-
assert ax1.get_ylabel() == "Recall"
50-
# TODO ax2 ylabel also 'Recall', should be 'Precision'
51-
# assert ax2.get_ylabel() == "Precision"
47+
assert all(ax.get_ylim() == (0, 1) for ax in fig.axes)
48+
assert (
49+
[ax.get_ylabel() for ax in fig.axes]
50+
== list(df_metrics.metric.unique())
51+
== ["Precision", "Recall", "F1"]
52+
)
5253
elif backend == "plotly":
5354
assert isinstance(fig, go.Figure)
5455
assert fig.layout.yaxis1.title.text == "Precision"

0 commit comments

Comments
 (0)