Commit 63b14aa
create no-structure CSVs of m3gnet and bowsr-megnet preds for faster loading in load_df_wbm_with_preds()
add test_glob_to_df() and dummy_df fixture
1 parent 6e58a1b commit 63b14aa
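
Note: the commit message also mentions a test_glob_to_df() and a dummy_df fixture, which are not among the five diffed files shown below. A minimal sketch of what such a test could look like, assuming glob_to_df() reads every CSV matching a glob pattern and concatenates the results (the test body and file layout here are illustrative, not the actual test from this commit):

import pandas as pd
import pytest

from matbench_discovery.plot_scripts import glob_to_df


@pytest.fixture
def dummy_df() -> pd.DataFrame:
    # small fixture frame with a material_id column, mirroring the MP/WBM CSVs
    return pd.DataFrame(
        {"material_id": ["wbm-1", "wbm-2"], "e_form_per_atom": [0.1, -0.2]}
    )


def test_glob_to_df(dummy_df: pd.DataFrame, tmp_path) -> None:
    # write the fixture as two CSV shards, then check glob_to_df stitches them back together
    dummy_df.iloc[:1].to_csv(tmp_path / "part-1.csv", index=False)
    dummy_df.iloc[1:].to_csv(tmp_path / "part-2.csv", index=False)

    df = glob_to_df(f"{tmp_path}/part-*.csv")
    assert len(df) == len(dummy_df)
    assert set(df.material_id) == set(dummy_df.material_id)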

File tree: 5 files changed, +32 -30 lines changed
matbench_discovery/plot_scripts/precision_recall.py (+2 -2)

@@ -10,8 +10,8 @@

 # %%
 models = (
-    "Wren, CGCNN IS2RE, CGCNN RS2RE, Voronoi IS2RE, Voronoi RS2RE, "
-    "Wrenformer, MEGNet"
+    "Wren, CGCNN IS2RE, CGCNN RS2RE, Voronoi RF, "
+    "Wrenformer, MEGNet, M3GNet, BOWSR MEGNet"
 ).split(", ")

 df_wbm = load_df_wbm_with_preds(models=models).round(3)
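
The two new model names, M3GNet and BOWSR MEGNet, correspond to the predictions this commit also exports as structure-free CSVs, so load_df_wbm_with_preds() can avoid parsing the much larger JSON files. A rough sketch of the difference, with illustrative paths (the function's internals are not part of this diff):

import pandas as pd

# illustrative paths; the real files live under models/{m3gnet,bowsr}/
json_path = "models/m3gnet/2022-10-31-m3gnet-wbm-IS2RE.json.gz"
csv_path = json_path.replace(".json.gz", ".csv")

# slow: the JSON also carries structures/trajectories the plot scripts never use
# df = pd.read_json(json_path).set_index("material_id")

# fast: the CSV written by this commit holds only the numeric prediction columns
df = pd.read_csv(csv_path).set_index("material_id")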

models/bowsr/join_bowsr_results.py (+6 -3)

@@ -35,7 +35,7 @@


 # %%
-df_bowsr = pd.concat(dfs.values())
+df_bowsr = pd.concat(dfs.values()).round(6)


 # %% compare against WBM formation energy targets to make sure we got sensible results
@@ -57,5 +57,8 @@
 out_path = f"{ROOT}/models/bowsr/{today}-bowsr-megnet-wbm-{task_type}.json.gz"
 df_bowsr.reset_index().to_json(out_path, default_handler=lambda x: x.as_dict())

-# out_path = f"{ROOT}/models/bowsr/2022-08-16-bowsr-megnet-wbm-IS2RE.json.gz"
-# df_bowsr = pd.read_json(out_path).set_index("material_id")
+# save energy and formation energy as CSV for fast loading
+df_bowsr.select_dtypes("number").to_csv(out_path.replace(".json.gz", ".csv"))
+
+# in_path = f"{ROOT}/models/bowsr/2022-11-22-bowsr-megnet-wbm-IS2RE.json.gz"
+# df_bowsr = pd.read_json(in_path).set_index("material_id")
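
select_dtypes("number") is what makes the exported CSV "no-structure": only the numeric columns (energies, formation energies) survive, while the pymatgen Structure objects that dominate the JSON's size stay out. A toy illustration with invented column names:

import pandas as pd

# toy frame mixing numeric predictions with an object column, as in df_bowsr
df = pd.DataFrame(
    {
        "energy_bowsr_megnet": [-1.234567891, -2.345678912],
        "e_form_per_atom_bowsr_megnet": [0.123456789, -0.234567891],
        "structure_bowsr_megnet": [object(), object()],  # stand-in for pymatgen Structures
    },
    index=["wbm-1-1", "wbm-1-2"],
)

# keep only numeric columns and round, mirroring the join script
df.select_dtypes("number").round(6).to_csv("bowsr-preds.csv", index_label="material_id")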

models/m3gnet/join_m3gnet_results.py (+10 -14)

@@ -27,30 +27,24 @@


 # %%
-# 2022-08-16 tried multiprocessing.Pool() to load files in parallel but was somehow
-# slower than serial loading
 for file_path in tqdm(file_paths):
     if file_path in dfs:
         continue
     df = pd.read_json(file_path).set_index("material_id")
-    df.index.name = "material_id"
-    col_map = dict(final_structure="structure_m3gnet", trajectory="m3gnet_trajectory")
-    df = df.rename(columns=col_map)
-    df.reset_index().to_json(file_path)
-    df[f"m3gnet_energy_{task_type}"] = df.m3gnet_trajectory.map(
-        lambda x: x["energies"][-1][0]
-    )
+    df[f"m3gnet_energy_{task_type}"] = [
+        x["energies"][-1][0] for x in df.m3gnet_trajectory
+    ]
     # drop trajectory to save memory
-    dfs[file_path] = df.drop(columns=["m3gnet_trajectory"])
+    dfs[file_path] = df.drop(columns="m3gnet_trajectory")


 # %%
-df_m3gnet = pd.concat(dfs.values())
+df_m3gnet = pd.concat(dfs.values()).round(6)


 # %%
 df_m3gnet["e_form_per_atom_m3gnet"] = [
-    get_e_form_per_atom(PDEntry(row.structure_m3gnet.composition, row.m3gnet_energy))
+    get_e_form_per_atom(PDEntry(row.m3gnet_structure.composition, row.m3gnet_energy))
     for row in tqdm(df_m3gnet.itertuples(), total=len(df_m3gnet), disable=None)
 ]
 df_m3gnet.isna().sum()
@@ -60,5 +54,7 @@
 out_path = f"{ROOT}/models/m3gnet/{today}-m3gnet-wbm-{task_type}.json.gz"
 df_m3gnet.reset_index().to_json(out_path, default_handler=as_dict_handler)

-# out_path = f"{ROOT}/models/m3gnet/2022-10-31-m3gnet-wbm-IS2RE.json.gz"
-# df_m3gnet = pd.read_json(out_path).set_index("material_id")
+df_m3gnet.select_dtypes("number").to_csv(out_path.replace(".json.gz", ".csv"))
+
+# in_path = f"{ROOT}/models/m3gnet/2022-10-31-m3gnet-wbm-IS2RE.json.gz"
+# df_m3gnet = pd.read_json(in_path).set_index("material_id")
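
The new list comprehension does the same work as the removed .map() call: take the energy of the last relaxation step from each stored trajectory. A toy record showing the x["energies"][-1][0] indexing, assuming the same nesting as in the diff (per-step energies each wrapped in a 1-element list):

# toy M3GNet relaxation trajectory with three steps
trajectory = {"energies": [[-10.2], [-10.8], [-11.1]]}

final_energy = trajectory["energies"][-1][0]  # energy of the last relaxation step
assert final_energy == -11.1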

models/voronoi/join_voronoi_features.py (+1 -3)

@@ -22,8 +22,6 @@


 # %%
-# 2022-08-16 tried multiprocessing.Pool() to load files in parallel but was somehow
-# slower than serial loading
 for file_path in tqdm(file_paths):
     if file_path in dfs:
         continue
@@ -32,7 +30,7 @@


 # %%
-df_features = pd.concat(dfs.values())
+df_features = pd.concat(dfs.values()).round(6)

 ax = df_features.isna().sum().value_counts().T.plot.bar()
 ax.set(xlabel="# NaNs", ylabel="# columns", title="NaNs per column")
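
The bar plot summarizes feature completeness: isna().sum() counts NaNs per column, and value_counts() then groups columns by how many NaNs they contain. The same idiom without the plot, on a toy frame:

import pandas as pd

# toy feature frame: column "a" is complete, "b" and "c" each have one NaN
df = pd.DataFrame({"a": [1.0, 2.0], "b": [1.0, None], "c": [None, 2.0]})

nans_per_col = df.isna().sum()  # a: 0, b: 1, c: 1
cols_per_nan_count = nans_per_col.value_counts()  # 1 NaN -> 2 columns, 0 NaNs -> 1 column
print(cols_per_nan_count)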

models/voronoi/train_test_voronoi_rf.py (+13 -8)

@@ -10,7 +10,7 @@
 from sklearn.pipeline import Pipeline

 from matbench_discovery import DEBUG, ROOT, today
-from matbench_discovery.plot_scripts import df_wbm
+from matbench_discovery.plot_scripts import df_wbm, glob_to_df
 from matbench_discovery.plots import wandb_log_scatter
 from matbench_discovery.slurm import slurm_submit
 from models.voronoi import featurizer
@@ -41,24 +41,29 @@


 # %%
-train_path = f"{module_dir}/2022-11-25-features-mp.csv.bz2"
-print(f"{train_path=}")
-df_train = pd.read_csv(train_path).set_index("material_id")
+train_path = f"{module_dir}/2022-11-25-features-mp/voronoi-features-mp-*.csv.bz2"
+df_train = glob_to_df(train_path).set_index("material_id")
 print(f"{df_train.shape=}")

 mp_energies_path = f"{ROOT}/data/mp/2022-08-13-mp-energies.json.gz"
 df_mp = pd.read_json(mp_energies_path).set_index("material_id")
 train_target_col = "formation_energy_per_atom"
-df_train[train_target_col] = df_mp[train_target_col]
-

 test_path = f"{module_dir}/2022-11-18-features-wbm-{task_type}.csv.bz2"
-print(f"{test_path=}")
 df_test = pd.read_csv(test_path).set_index("material_id")
 print(f"{df_test.shape=}")

 test_target_col = "e_form_per_atom_mp2020_corrected"
-df_test[test_target_col] = df_wbm[test_target_col]
+
+
+for df, df_tar, col in (
+    (df_train, df_mp, train_target_col),
+    (df_test, df_wbm, test_target_col),
+):
+    df[train_target_col] = df_tar[train_target_col]
+    nans = df_tar[col].isna().sum()
+    assert nans == 0, f"{nans} NaNs in {col} targets"
+
 model_name = "Voronoi RandomForestRegressor"

 run_params = dict(
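
glob_to_df() lets the training features be loaded from several CSV shards instead of one monolithic file. Its implementation is not part of this diff; a minimal sketch under the assumption that it simply globs and concatenates:

from glob import glob

import pandas as pd


def glob_to_df(pattern: str) -> pd.DataFrame:
    """Concatenate all CSV files matching a glob pattern into one DataFrame.

    Sketch only; the real matbench_discovery.plot_scripts.glob_to_df may differ,
    e.g. support other file formats or show a progress bar.
    """
    files = sorted(glob(pattern))
    if not files:
        raise FileNotFoundError(f"no files match {pattern!r}")
    return pd.concat(pd.read_csv(file) for file in files)

In the script it is used as df_train = glob_to_df(train_path).set_index("material_id"), as shown in the hunk above.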
