2
2
3
3
from typing import Any , Sequence
4
4
5
+ import matplotlib .pyplot as plt
5
6
import pandas as pd
6
7
import pytest
7
8
8
9
from mb_discovery import ROOT
9
- from mb_discovery .plot_scripts .plot_funcs import precision_recall_vs_calc_count
10
+ from mb_discovery .plot_scripts .plot_funcs import (
11
+ precision_recall_vs_calc_count ,
12
+ rolling_mae_vs_hull_dist ,
13
+ )
10
14
11
15
12
16
DATA_DIR = f"{ ROOT } /data/2022-06-11-from-rhys"
28
32
"intersect_lines, stability_crit, stability_threshold, expected_line_count" ,
29
33
[
30
34
((), "energy" , 0 , 11 ),
31
- ("precision_x" , "energy+std" , 0 , 23 ),
32
- (["recall_y" ], "energy" , - 0.1 , 35 ),
33
- ("all" , "energy-std" , 0.1 , 56 ),
35
+ ("precision_x" , "energy+std" , 0 , 14 ),
36
+ (["recall_y" ], "energy" , - 0.1 , 14 ),
37
+ ("all" , "energy-std" , 0.1 , 23 ),
34
38
],
35
39
)
36
40
def test_precision_recall_vs_calc_count (
@@ -39,7 +43,7 @@ def test_precision_recall_vs_calc_count(
39
43
stability_threshold : float ,
40
44
expected_line_count : int ,
41
45
) -> None :
42
- ax = None
46
+ ax = plt . figure (). gca () # ensure test functions use different axes
43
47
44
48
for (model_name , df ), color in zip (
45
49
test_dfs .items (), ("tab:blue" , "tab:orange" , "tab:pink" )
@@ -66,6 +70,9 @@ def test_precision_recall_vs_calc_count(
66
70
assert ax .get_ylim () == (0 , 100 )
67
71
assert ax .get_xlim () == pytest .approx ((- 1.4 , 29.4 ))
68
72
73
+ assert ax .get_xlabel () == "Number of Calculations"
74
+ assert ax .get_ylabel () == "Precision and Recall (%)"
75
+
69
76
70
77
@pytest .mark .parametrize (
71
78
"kwargs, expected_exc, match_pat" ,
@@ -84,3 +91,37 @@ def test_precision_recall_vs_calc_count_raises(
84
91
e_above_hull_col = "e_above_mp_hull" ,
85
92
** kwargs ,
86
93
)
94
+
95
+
96
+ @pytest .mark .parametrize ("half_window" , (0.02 , 0.002 ))
97
+ @pytest .mark .parametrize ("bin_width" , (0.1 , 0.001 ))
98
+ @pytest .mark .parametrize ("x_lim" , ((0 , 0.6 ), (- 0.2 , 0.8 )))
99
+ def test_rolling_mae_vs_hull_dist (
100
+ half_window : float , bin_width : float , x_lim : tuple [float , float ]
101
+ ) -> None :
102
+ ax = plt .figure ().gca () # ensure test functions use different axes
103
+
104
+ for (model_name , df ), color in zip (
105
+ test_dfs .items (), ("tab:blue" , "tab:orange" , "tab:pink" )
106
+ ):
107
+ model_preds = df .filter (like = r"_pred" ).mean (axis = 1 )
108
+ targets = df .e_form_target
109
+
110
+ df ["residual" ] = model_preds - targets + df .e_above_mp_hull
111
+
112
+ ax = rolling_mae_vs_hull_dist (
113
+ df ,
114
+ residual_col = "residual" ,
115
+ e_above_hull_col = "e_above_mp_hull" ,
116
+ color = color ,
117
+ label = model_name ,
118
+ ax = ax ,
119
+ x_lim = x_lim ,
120
+ half_window = half_window ,
121
+ bin_width = bin_width ,
122
+ )
123
+
124
+ assert ax is not None
125
+ assert ax .get_ylim () == pytest .approx ((0 , 0.14 ))
126
+ assert ax .get_ylabel () == "MAE (eV / atom)"
127
+ assert ax .get_xlabel () == r"$\Delta E_{Hull-MP}$ (eV / atom)"
0 commit comments