Skip to content

Commit 5e9f5df

Browse files
committed
fix path issue on case-sensitive file systems
1 parent 92451da commit 5e9f5df

File tree

3 files changed

+13
-10
lines changed

3 files changed

+13
-10
lines changed

mb_discovery/plot_scripts/plot_funcs.py

+4
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ def hist_classified_stable_as_func_of_hull_dist(
5555
e_above_hull_vals = df[e_above_hull_col]
5656
residuals = error + e_above_hull_vals
5757

58+
if stability_crit not in get_args(StabilityCriterion):
59+
raise ValueError(
60+
f"Invalid {stability_crit=} must be one of {get_args(StabilityCriterion)}"
61+
)
5862
if stability_crit == "energy":
5963
test = residuals
6064
elif "std" in stability_crit:

mb_discovery/plot_scripts/precision_recall_vs_calc_count.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,15 @@
2020

2121

2222
# %%
23-
df_hull = pd.read_csv(
24-
f"{ROOT}/data/2022-06-11-from-rhys/wbm-e-above-mp-hull.csv"
25-
).set_index("material_id")
23+
DATA_DIR = f"{ROOT}/data/2022-06-11-from-rhys"
24+
df_hull = pd.read_csv(f"{DATA_DIR}/wbm-e-above-mp-hull.csv").set_index("material_id")
2625

2726
dfs: dict[str, pd.DataFrame] = {}
2827
for model_name in ("Wren", "CGCNN", "Voronoi"):
29-
dfs[model_name] = pd.read_csv(
30-
f"{ROOT}/data/2022-06-11-from-rhys/{model_name}-mp-initial-structures.csv"
28+
df = pd.read_csv(
29+
f"{DATA_DIR}/{model_name.lower()}-mp-initial-structures.csv"
3130
).set_index("material_id")
31+
dfs[model_name] = df
3232

3333
# dfs["M3GNet"] = pd.read_json(
3434
# f"{ROOT}/data/2022-08-16-m3gnet-wbm-relax-results-IS2RE.json.gz"

tests/test_plot_funcs.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,14 @@
99
from mb_discovery.plot_scripts.plot_funcs import precision_recall_vs_calc_count
1010

1111

12-
df_hull = pd.read_csv(
13-
f"{ROOT}/data/2022-06-11-from-rhys/wbm-e-above-mp-hull.csv"
14-
).set_index("material_id")
12+
DATA_DIR = f"{ROOT}/data/2022-06-11-from-rhys"
13+
14+
df_hull = pd.read_csv(f"{DATA_DIR}/wbm-e-above-mp-hull.csv").set_index("material_id")
1515

1616
test_dfs: dict[str, pd.DataFrame] = {}
1717
for model_name in ("Wren", "CGCNN", "Voronoi"):
1818
df = pd.read_csv(
19-
f"{ROOT}/data/2022-06-11-from-rhys/{model_name}-mp-initial-structures.csv",
20-
nrows=100,
19+
f"{DATA_DIR}/{model_name.lower()}-mp-initial-structures.csv", nrows=100
2120
).set_index("material_id")
2221

2322
df["e_above_mp_hull"] = df_hull.e_above_mp_hull

0 commit comments

Comments
 (0)