Skip to content

Commit f50e1b3

Browse files
committed
add m3gnet/plots/2022-08-16-m3gnet-wbm-hull-dist-hist.pdf
also tweak hist_classify_stable_as_func_of_hull_dist()
1 parent 1e96458 commit f50e1b3

File tree

2 files changed

+26
-19
lines changed

2 files changed

+26
-19
lines changed

ml_stability/m3gnet/join_and_plot_m3gnet_relax_results.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from urllib.request import urlopen
1010

1111
import pandas as pd
12-
from diel_frontier.patched_phase_diagram import load_ppd
1312
from pymatgen.analysis.phase_diagram import PDEntry
1413
from pymatgen.core import Structure
1514
from tqdm import tqdm
@@ -64,8 +63,6 @@
6463
zipped_file = urlopen(ppd_pickle_url)
6564
ppd_mp_wbm = pickle.load(io.BytesIO(gzip.decompress(zipped_file.read())))
6665

67-
ppd_mp_wbm = load_ppd("ppd-mp+wbm-2022-01-25.pkl.gz")
68-
6966

7067
df_m3gnet["m3gnet_structure"] = df_m3gnet.m3gnet_structure.map(Structure.from_dict)
7168
df_m3gnet["pd_entry"] = [
@@ -102,4 +99,6 @@
10299
e_above_hull_vals=df_m3gnet.e_above_mp_hull,
103100
)
104101

105-
ax_hull_dist_hist.figure.savefig(f"{ROOT}/data/{today}-m3gnet-wbm-hull-dist-hist.pdf")
102+
ax_hull_dist_hist.figure.savefig(
103+
f"{ROOT}/ml_stability/m3gnet/plots/{today}-m3gnet-wbm-hull-dist-hist.pdf"
104+
)

ml_stability/plots/plot_funcs.py

+23-15
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import matplotlib.pyplot as plt
55
import pandas as pd
6+
from matplotlib.offsetbox import AnchoredText
67

78

89
__author__ = "Janosh Riebesell"
@@ -23,6 +24,7 @@
2324
plt.rc("savefig", bbox="tight", dpi=200)
2425
plt.rcParams["figure.constrained_layout.use"] = True
2526
plt.rc("figure", dpi=150)
27+
plt.rc("font", size=14)
2628

2729

2830
def hist_classify_stable_as_func_of_hull_dist(
@@ -34,6 +36,7 @@ def hist_classify_stable_as_func_of_hull_dist(
3436
std_vals: pd.Series = None,
3537
criterion: Literal["energy", "std", "neg"] = "energy",
3638
energy_type: Literal["true", "pred"] = "true",
39+
annotate_all_stats: bool = False,
3740
) -> plt.Axes:
3841
"""
3942
Histogram of the energy difference (either according to DFT ground truth [default]
@@ -77,7 +80,6 @@ def hist_classify_stable_as_func_of_hull_dist(
7780
n_false_neg = len(e_above_hull_vals[actual_pos & model_neg])
7881

7982
n_total_pos = n_true_pos + n_false_neg
80-
null = n_total_pos / len(e_above_hull_vals)
8183

8284
# --- histogram by DFT-computed distance to convex hull
8385
if energy_type == "true":
@@ -121,29 +123,35 @@ def hist_classify_stable_as_func_of_hull_dist(
121123
len(false_neg),
122124
)
123125
# null = (tp + fn) / (tp + tn + fp + fn)
124-
ppv = n_true_pos / (n_true_pos + n_false_pos)
125-
tpr = n_true_pos / n_total_pos
126-
f1 = 2 * ppv * tpr / (ppv + tpr)
126+
Null = n_total_pos / len(e_above_hull_vals)
127+
PPV = n_true_pos / (n_true_pos + n_false_pos)
128+
TPR = n_true_pos / n_total_pos
129+
F1 = 2 * PPV * TPR / (PPV + TPR)
127130

128131
assert n_true_pos + n_false_pos + n_true_neg + n_false_neg == len(
129132
formation_energy_targets
130133
)
131134

132-
print(f"PPV: {ppv:.2f}")
133-
print(f"TPR: {tpr:.2f}")
134-
print(f"F1: {f1:.2f}")
135-
print(f"Enrich: {ppv/null:.2f}")
136-
print(f"Null: {null:.2f}")
137-
138135
RMSE = (error**2.0).mean() ** 0.5
139136
MAE = error.abs().mean()
140-
print(f"{MAE=:.3}")
141-
print(f"{RMSE=:.3}")
142137

143138
# anno_text = f"Prevalence = {null:.2f}\nPrecision = {ppv:.2f}\nRecall = {tpr:.2f}",
144-
anno_text = f"Enrichment\nFactor = {ppv/null:.1f}"
145-
146-
ax.text(0.75, 0.9, anno_text, transform=ax.transAxes, fontsize=20)
139+
anno_text = f"Enrichment Factor = {PPV/Null:.3}"
140+
if annotate_all_stats:
141+
anno_text += f"\n{MAE = :.3}\n{RMSE = :.3}\n{Null = :.3}\n{TPR = :.3}"
142+
else:
143+
print(f"{PPV = :.3}")
144+
print(f"{TPR = :.3}")
145+
print(f"{F1 = :.3}")
146+
print(f"Enrich: {PPV/Null:.2f}")
147+
print(f"{Null = :.3}")
148+
print(f"{MAE = :.3}")
149+
print(f"{RMSE = :.3}")
150+
151+
text_box = AnchoredText(
152+
anno_text, loc="upper right", frameon=False, prop=dict(fontsize=16)
153+
)
154+
ax.add_artist(text_box)
147155

148156
ax.set(
149157
xlabel=xlabel,

0 commit comments

Comments
 (0)