Skip to content

Commit 39fa65a

Browse files
committed
add kwarg intersect_lines: str | Sequence[str] = () to precision_recall_vs_calc_count()
1 parent 7ed1b09 commit 39fa65a

File tree

5 files changed

+65
-24
lines changed

5 files changed

+65
-24
lines changed

.pre-commit-config.yaml

+2-7
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ repos:
2020
rev: 4.0.1
2121
hooks:
2222
- id: flake8
23+
additional_dependencies: [flake8-bugbear]
2324

2425
- repo: https://github.com/asottile/pyupgrade
2526
rev: v2.34.0
@@ -56,13 +57,7 @@ repos:
5657
stages: [commit, commit-msg]
5758
exclude_types: [csv, html, json]
5859

59-
- repo: https://github.com/myint/autoflake
60+
- repo: https://github.com/PyCQA/autoflake
6061
rev: v1.4
6162
hooks:
6263
- id: autoflake
63-
args:
64-
- --in-place
65-
- --remove-unused-variables
66-
- --remove-all-unused-imports
67-
- --expand-star-imports
68-
- --ignore-init-module-imports

mb_discovery/m3gnet/eda_wbm_pre_vs_post_m3gnet_relaxation.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -193,9 +193,11 @@
193193
df_m3gnet_is2re["m3gnet_energy_rs2re"] = df_m3gnet_rs2re.m3gnet_energy
194194

195195
for task_type in ["is2re", "rs2re"]:
196-
e_per_atom = df_m3gnet_is2re[f"m3gnet_energy_{task_type}"] / df_m3gnet_is2re.n_sites
196+
energy_per_atom = (
197+
df_m3gnet_is2re[f"m3gnet_energy_{task_type}"] / df_m3gnet_is2re.n_sites
198+
)
197199

198-
df_m3gnet_is2re[f"e_m3gnet_per_atom_{task_type}"] = e_per_atom
200+
df_m3gnet_is2re[f"e_m3gnet_per_atom_{task_type}"] = energy_per_atom
199201

200202
fig = px.scatter(
201203
df_m3gnet_is2re,

mb_discovery/plot_scripts/plot_funcs.py

+39-2
Original file line numberDiff line numberDiff line change
@@ -261,9 +261,30 @@ def precision_recall_vs_calc_count(
261261
# in eV / atom, usually 0 or 0.1 eV
262262
ax: plt.Axes = None,
263263
label: str = None,
264+
intersect_lines: str | Sequence[str] = (),
264265
**kwargs: Any,
265266
) -> plt.Axes:
266-
"""Precision and recall as a function of the number of calculations performed."""
267+
"""Precision and recall as a function of the number of calculations performed.
268+
269+
Args:
270+
df (pd.DataFrame): Model predictions and target energy values.
271+
residual_col (str, optional): Column name with residuals of model predictions,
272+
i.e. residual = pred - target. Defaults to "residual".
273+
e_above_hull_col (str, optional): Column name with convex hull distance values.
274+
Defaults to "e_above_hull".
275+
criterion (Literal['energy', 'std', 'neg_std'], optional): Whether to use
276+
energy, energy+model_std, or energy-model_std as stability criterion.
277+
Defaults to "energy".
278+
stability_thresh (float, optional): Max distance from convex hull before
279+
material is considered unstable. Defaults to 0.
280+
label (str, optional): Model name used to identify its lines in the legend.
281+
Defaults to None.
282+
intersect_lines (str | Sequence[str], optional): precision_{x,y,xy} and/or
283+
recall_{x,y,xy}, or "all" as shorthand for ("precision_xy", "recall_xy"). Defaults to (), i.e. no intersect lines.
284+
285+
Returns:
286+
plt.Axes: The matplotlib axes object.
287+
"""
267288
if ax is None:
268289
ax = plt.gca()
269290

@@ -315,7 +336,7 @@ def precision_recall_vs_calc_count(
315336
rolling_recall_curve = scipy.interpolate.interp1d(xs, tpr[:end], kind="cubic")
316337

317338
line_kwargs = dict(
318-
linewidth=3,
339+
linewidth=4,
319340
markevery=[-1],
320341
marker="x",
321342
markersize=14,
@@ -326,6 +347,22 @@ def precision_recall_vs_calc_count(
326347
ax.plot(xs, rolling_recall_curve(xs), linestyle=":", **line_kwargs)
327348
ax.plot((0, 0), (0, 0), label=label, **line_kwargs)
328349

350+
if intersect_lines == "all":
351+
intersect_lines = ("precision_xy", "recall_xy")
352+
for line_name in intersect_lines:
353+
y_func = dict(
354+
precision=precision_curve,
355+
recall=rolling_recall_curve,
356+
)[line_name.split("_")[0]]
357+
intersect_kwargs = dict(
358+
linestyle=":", alpha=0.4, color=kwargs.get("color", "gray")
359+
)
360+
# Add some visual guidelines
361+
if "x" in line_name:
362+
ax.plot((0, xs[-1]), (y_func(xs[-1]), y_func(xs[-1])), **intersect_kwargs)
363+
if "y" in line_name:
364+
ax.plot((xs[-1], xs[-1]), (0, y_func(xs[-1])), **intersect_kwargs)
365+
329366
if not is_fresh_ax:
330367
# return earlier if all plot objects besides the line were already drawn by a
331368
# previous call

mb_discovery/plot_scripts/precision_recall_vs_calc_count.py

+19-12
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323
df_hull = pd.read_csv(
2424
f"{ROOT}/data/2022-06-11-from-rhys/wbm-e-above-mp-hull.csv"
2525
).set_index("material_id")
26-
dfs: dict[str, pd.DataFrame] = {}
2726

27+
dfs: dict[str, pd.DataFrame] = {}
2828
for model_name in ("Wren", "CGCNN", "Voronoi"):
2929
dfs[model_name] = pd.read_csv(
3030
f"{ROOT}/data/2022-06-11-from-rhys/{model_name}-mp-initial-structures.csv"
@@ -34,17 +34,21 @@
3434
# f"{ROOT}/data/2022-08-16-m3gnet-wbm-relax-results-IS2RE.json.gz"
3535
# ).set_index("material_id")
3636

37-
# dfs["Wrenformer"] = pd.read_csv(
38-
# f"{ROOT}/data/2022-08-16-wrenformer-ensemble-predictions.csv.bz2"
39-
# ).set_index("material_id")
37+
dfs["Wrenformer"] = pd.read_csv(
38+
f"{ROOT}/data/2022-08-16-wrenformer-ensemble-predictions.csv.bz2"
39+
).set_index("material_id")
40+
4041

41-
# dfs["Wrenformer"]["e_form_target"] = dfs["Wren"]["e_form_target"]
42-
# dfs["M3GNet"]["e_form_target"] = dfs["Wren"]["e_form_target"]
42+
# download wbm-steps-summary.csv (23.31 MB)
43+
df_summary = pd.read_csv(
44+
"https://figshare.com/ndownloader/files/36714216?private_link=ff0ad14505f9624f0c05"
45+
).set_index("material_id")
4346

4447

4548
# %%
4649
for (model_name, df), color in zip(
47-
dfs.items(), ("tab:blue", "tab:orange", "teal", "tab:pink", "black")
50+
dfs.items(),
51+
("tab:blue", "tab:orange", "teal", "tab:pink", "black", "red", "turquoise"),
4852
):
4953
df["e_above_mp_hull"] = df_hull.e_above_mp_hull
5054

@@ -66,9 +70,11 @@
6670
if model_name == "M3GNet":
6771
model_preds = df.e_form_m3gnet
6872
targets = df.e_form_wbm
69-
elif model_name == "Wrenformer":
70-
model_preds = df.e_form_pred_ens
71-
targets = df.e_form
73+
elif "Wrenformer" in model_name:
74+
df["e_form_per_atom_pred_ens"] = df.e_form_pred_ens / df.n_sites
75+
df["e_form_per_atom"] = df.e_form / df.n_sites
76+
model_preds = df.e_form_per_atom_pred_ens
77+
targets = df.e_form_per_atom
7278
elif df.filter(regex=r"_pred_\d").shape[1] > 1:
7379
assert df.filter(regex=r"_pred_\d").shape[1] == 10
7480
model_preds = df.filter(regex=r"_pred_\d").mean(axis=1)
@@ -89,10 +95,11 @@
8995
e_above_hull_col="e_above_mp_hull",
9096
color=color,
9197
label=model_name,
98+
intersect_lines="recall_xy",
99+
# intersect_lines="all",
92100
)
93101

94-
model_legend = ax.legend(frameon=False, loc="lower right")
95-
ax.add_artist(model_legend)
102+
ax.legend(frameon=False, loc="lower right")
96103

97104
ax.figure.set_size_inches(10, 9)
98105

readme.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
# ML Stability
1+
# Matbench Discovery

0 commit comments

Comments
 (0)