File tree 3 files changed +13
-10
lines changed
mb_discovery/plot_scripts
3 files changed +13
-10
lines changed Original file line number Diff line number Diff line change @@ -55,6 +55,10 @@ def hist_classified_stable_as_func_of_hull_dist(
55
55
e_above_hull_vals = df [e_above_hull_col ]
56
56
residuals = error + e_above_hull_vals
57
57
58
+ if stability_crit not in get_args (StabilityCriterion ):
59
+ raise ValueError (
60
+ f"Invalid { stability_crit = } must be one of { get_args (StabilityCriterion )} "
61
+ )
58
62
if stability_crit == "energy" :
59
63
test = residuals
60
64
elif "std" in stability_crit :
Original file line number Diff line number Diff line change 20
20
21
21
22
22
# %%
23
- df_hull = pd .read_csv (
24
- f"{ ROOT } /data/2022-06-11-from-rhys/wbm-e-above-mp-hull.csv"
25
- ).set_index ("material_id" )
23
+ DATA_DIR = f"{ ROOT } /data/2022-06-11-from-rhys"
24
+ df_hull = pd .read_csv (f"{ DATA_DIR } /wbm-e-above-mp-hull.csv" ).set_index ("material_id" )
26
25
27
26
dfs : dict [str , pd .DataFrame ] = {}
28
27
for model_name in ("Wren" , "CGCNN" , "Voronoi" ):
29
- dfs [ model_name ] = pd .read_csv (
30
- f"{ ROOT } /data/2022-06-11-from-rhys/ { model_name } -mp-initial-structures.csv"
28
+ df = pd .read_csv (
29
+ f"{ DATA_DIR } / { model_name . lower () } -mp-initial-structures.csv"
31
30
).set_index ("material_id" )
31
+ dfs [model_name ] = df
32
32
33
33
# dfs["M3GNet"] = pd.read_json(
34
34
# f"{ROOT}/data/2022-08-16-m3gnet-wbm-relax-results-IS2RE.json.gz"
Original file line number Diff line number Diff line change 9
9
from mb_discovery .plot_scripts .plot_funcs import precision_recall_vs_calc_count
10
10
11
11
12
- df_hull = pd . read_csv (
13
- f" { ROOT } /data/2022-06-11-from-rhys/wbm-e-above-mp-hull.csv"
14
- ).set_index ("material_id" )
12
+ DATA_DIR = f" { ROOT } /data/2022-06-11-from-rhys"
13
+
14
+ df_hull = pd . read_csv ( f" { DATA_DIR } /wbm-e-above-mp-hull.csv" ).set_index ("material_id" )
15
15
16
16
test_dfs : dict [str , pd .DataFrame ] = {}
17
17
for model_name in ("Wren" , "CGCNN" , "Voronoi" ):
18
18
df = pd .read_csv (
19
- f"{ ROOT } /data/2022-06-11-from-rhys/{ model_name } -mp-initial-structures.csv" ,
20
- nrows = 100 ,
19
+ f"{ DATA_DIR } /{ model_name .lower ()} -mp-initial-structures.csv" , nrows = 100
21
20
).set_index ("material_id" )
22
21
23
22
df ["e_above_mp_hull" ] = df_hull .e_above_mp_hull
You can’t perform that action at this time.
0 commit comments