1
1
from __future__ import annotations
2
2
3
- from typing import Any , Literal , get_args
3
+ from typing import Any , Literal
4
4
5
5
import matplotlib .pyplot as plt
6
6
import numpy as np
15
15
__author__ = "Janosh Riebesell"
16
16
__date__ = "2022-08-05"
17
17
18
- StabilityCriterion = Literal ["energy" , "energy+std" , "energy-std" ]
19
18
WhichEnergy = Literal ["true" , "pred" ]
20
19
AxLine = Literal ["x" , "y" , "xy" , "" ]
21
20
72
71
def hist_classified_stable_vs_hull_dist (
73
72
e_above_hull_pred : pd .Series ,
74
73
e_above_hull_true : pd .Series ,
75
- std_pred : pd .Series = None ,
76
74
ax : plt .Axes = None ,
77
75
which_energy : WhichEnergy = "true" ,
78
- stability_crit : StabilityCriterion = "energy" ,
79
76
stability_threshold : float = 0 ,
80
77
show_threshold : bool = True ,
81
78
x_lim : tuple [float | None , float | None ] = (- 0.4 , 0.4 ),
82
- rolling_accuracy : float = 0.02 ,
79
+ rolling_accuracy : float | None = 0.02 ,
83
80
) -> tuple [plt .Axes , dict [str , float ]]:
84
81
"""
85
82
Histogram of the energy difference (either according to DFT ground truth [default]
@@ -98,21 +95,16 @@ def hist_classified_stable_vs_hull_dist(
98
95
energy.
99
96
e_above_hull_true (pd.Series): energy diff to convex hull according to DFT
100
97
ground truth.
101
- std_pred (pd.Series, optional): standard deviation of the model's predicted
102
- formation energy.
103
98
ax (plt.Axes, optional): matplotlib axes to plot on.
104
99
which_energy (WhichEnergy, optional): Whether to use the true formation energy
105
100
or the model's predicted formation energy for the histogram.
106
- stability_crit (StabilityCriterion, optional): Whether to add/subtract the
107
- model's predicted uncertainty from its energy prediction when measuring
108
- predicted stability.
109
101
stability_threshold (float, optional): set stability threshold as distance to
110
102
convex hull in eV/atom, usually 0 or 0.1 eV.
111
103
show_threshold (bool, optional): Whether to plot stability threshold as dashed
112
104
vertical line.
113
105
x_lim (tuple[float | None, float | None]): x-axis limits.
114
- rolling_accuracy (float): Rolling accuracy window size in eV / atom. Set to 0 to
115
- disable. Defaults to 0.01.
106
+ rolling_accuracy (float): Rolling accuracy window size in eV / atom. Set to None
107
+ or 0 to disable. Defaults to 0.01.
116
108
117
109
Returns:
118
110
tuple[plt.Axes, dict[str, float]]: plot axes and classification metrics
@@ -122,17 +114,7 @@ def hist_classified_stable_vs_hull_dist(
122
114
"""
123
115
ax = ax or plt .gca ()
124
116
125
- if stability_crit not in get_args (StabilityCriterion ):
126
- raise ValueError (
127
- f"Invalid { stability_crit = } must be one of { get_args (StabilityCriterion )} "
128
- )
129
-
130
117
test = e_above_hull_pred + e_above_hull_true
131
- if stability_crit == "energy+std" :
132
- test += std_pred
133
- elif stability_crit == "energy-std" :
134
- test -= std_pred
135
-
136
118
# --- histogram of DFT-computed distance to convex hull
137
119
if which_energy == "true" :
138
120
actual_pos = e_above_hull_true <= stability_threshold
@@ -348,8 +330,6 @@ def cumulative_clf_metric(
348
330
e_above_hull_error : pd .Series ,
349
331
e_above_hull_true : pd .Series ,
350
332
metric : Literal ["precision" , "recall" ],
351
- std_pred : pd .Series = None ,
352
- stability_crit : StabilityCriterion = "energy" ,
353
333
stability_threshold : float = 0 , # set stability threshold as distance to convex
354
334
# hull in eV / atom, usually 0 or 0.1 eV
355
335
ax : plt .Axes = None ,
@@ -370,9 +350,6 @@ def cumulative_clf_metric(
370
350
e_above_hull_true (str, optional): Column name with convex hull distance values.
371
351
Defaults to "e_above_hull".
372
352
metric ('precision' | 'recall', optional): Metric to plot.
373
- stability_crit ('energy' | 'energy+std' | 'energy-std', optional): Whether to
374
- use energy+/-std as stability stability_crit where std is the model
375
- predicted uncertainty for the energy it stipulated. Defaults to "energy".
376
353
stability_threshold (float, optional): Max distance from convex hull before
377
354
material is considered unstable. Defaults to 0.
378
355
label (str, optional): Model name used to identify its liens in the legend.
@@ -391,15 +368,6 @@ def cumulative_clf_metric(
391
368
e_above_hull_error = e_above_hull_error .sort_values ()
392
369
e_above_hull_true = e_above_hull_true .loc [e_above_hull_error .index ]
393
370
394
- if stability_crit not in get_args (StabilityCriterion ):
395
- raise ValueError (
396
- f"Invalid { stability_crit = } must be one of { get_args (StabilityCriterion )} "
397
- )
398
- if stability_crit == "energy+std" :
399
- e_above_hull_error += std_pred
400
- elif stability_crit == "energy-std" :
401
- e_above_hull_error -= std_pred
402
-
403
371
true_pos_mask = (e_above_hull_true <= stability_threshold ) & (
404
372
e_above_hull_error <= stability_threshold
405
373
)
0 commit comments