3
3
4
4
import matplotlib .pyplot as plt
5
5
import pandas as pd
6
+ from matplotlib .offsetbox import AnchoredText
6
7
7
8
8
9
__author__ = "Janosh Riebesell"
23
24
plt .rc ("savefig" , bbox = "tight" , dpi = 200 )
24
25
plt .rcParams ["figure.constrained_layout.use" ] = True
25
26
plt .rc ("figure" , dpi = 150 )
27
+ plt .rc ("font" , size = 14 )
26
28
27
29
28
30
def hist_classify_stable_as_func_of_hull_dist (
@@ -34,6 +36,7 @@ def hist_classify_stable_as_func_of_hull_dist(
34
36
std_vals : pd .Series = None ,
35
37
criterion : Literal ["energy" , "std" , "neg" ] = "energy" ,
36
38
energy_type : Literal ["true" , "pred" ] = "true" ,
39
+ annotate_all_stats : bool = False ,
37
40
) -> plt .Axes :
38
41
"""
39
42
Histogram of the energy difference (either according to DFT ground truth [default]
@@ -77,7 +80,6 @@ def hist_classify_stable_as_func_of_hull_dist(
77
80
n_false_neg = len (e_above_hull_vals [actual_pos & model_neg ])
78
81
79
82
n_total_pos = n_true_pos + n_false_neg
80
- null = n_total_pos / len (e_above_hull_vals )
81
83
82
84
# --- histogram by DFT-computed distance to convex hull
83
85
if energy_type == "true" :
@@ -121,29 +123,35 @@ def hist_classify_stable_as_func_of_hull_dist(
121
123
len (false_neg ),
122
124
)
123
125
# null = (tp + fn) / (tp + tn + fp + fn)
124
- ppv = n_true_pos / (n_true_pos + n_false_pos )
125
- tpr = n_true_pos / n_total_pos
126
- f1 = 2 * ppv * tpr / (ppv + tpr )
126
+ Null = n_total_pos / len (e_above_hull_vals )
127
+ PPV = n_true_pos / (n_true_pos + n_false_pos )
128
+ TPR = n_true_pos / n_total_pos
129
+ F1 = 2 * PPV * TPR / (PPV + TPR )
127
130
128
131
assert n_true_pos + n_false_pos + n_true_neg + n_false_neg == len (
129
132
formation_energy_targets
130
133
)
131
134
132
- print (f"PPV: { ppv :.2f} " )
133
- print (f"TPR: { tpr :.2f} " )
134
- print (f"F1: { f1 :.2f} " )
135
- print (f"Enrich: { ppv / null :.2f} " )
136
- print (f"Null: { null :.2f} " )
137
-
138
135
RMSE = (error ** 2.0 ).mean () ** 0.5
139
136
MAE = error .abs ().mean ()
140
- print (f"{ MAE = :.3} " )
141
- print (f"{ RMSE = :.3} " )
142
137
143
138
# anno_text = f"Prevalence = {null:.2f}\nPrecision = {ppv:.2f}\nRecall = {tpr:.2f}",
144
- anno_text = f"Enrichment\n Factor = { ppv / null :.1f} "
145
-
146
- ax .text (0.75 , 0.9 , anno_text , transform = ax .transAxes , fontsize = 20 )
139
+ anno_text = f"Enrichment Factor = { PPV / Null :.3} "
140
+ if annotate_all_stats :
141
+ anno_text += f"\n { MAE = :.3} \n { RMSE = :.3} \n { Null = :.3} \n { TPR = :.3} "
142
+ else :
143
+ print (f"{ PPV = :.3} " )
144
+ print (f"{ TPR = :.3} " )
145
+ print (f"{ F1 = :.3} " )
146
+ print (f"Enrich: { PPV / Null :.2f} " )
147
+ print (f"{ Null = :.3} " )
148
+ print (f"{ MAE = :.3} " )
149
+ print (f"{ RMSE = :.3} " )
150
+
151
+ text_box = AnchoredText (
152
+ anno_text , loc = "upper right" , frameon = False , prop = dict (fontsize = 16 )
153
+ )
154
+ ax .add_artist (text_box )
147
155
148
156
ax .set (
149
157
xlabel = xlabel ,
0 commit comments