Skip to content

Commit 0ea1ff3

Browse files
committed
add m3gnet and bowsr-megnet to plot moving_hull_dist_mae_compare_models.py
record energy_model in test_bowsr.py run_params
1 parent 8777870 commit 0ea1ff3

File tree

4 files changed

+19
-26
lines changed

4 files changed

+19
-26
lines changed

matbench_discovery/plot_scripts/hist_classified_stable_vs_hull_dist_batches.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -60,7 +60,7 @@
6060
df[pred_col] = df.e_form_per_atom_m3gnet
6161
if "bowsr_megnet" in dfs:
6262
df = dfs["bowsr_megnet"]
63-
df[pred_col] = df.e_form_per_atom_bowsr
63+
df[pred_col] = df.e_form_per_atom_bowsr_megnet
6464
if "wrenformer" in dfs:
6565
pred_col = "e_form_per_atom_mp2020_corrected_pred_ens"
6666

matbench_discovery/plot_scripts/precision_recall.py

+3 -3
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,7 @@
4343
colors = "tab:blue tab:orange teal tab:pink black red turquoise tab:purple".split()
4444
F1s: dict[str, float] = {}
4545

46-
for model_name, df in dfs.items():
46+
for model_name, df in sorted(dfs.items()):
4747
if "std" in stability_crit:
4848
# TODO column names to compute standard deviation from are currently hardcoded
4949
# needs to be updated when adding non-aviary models with uncertainty estimation
@@ -63,8 +63,8 @@
6363
# other cases are unexpected
6464
assert len(pred_cols) in (1, 10), f"{model_name=} has {len(pred_cols)=}"
6565
model_preds = df[pred_cols].mean(axis=1)
66-
elif "bowsr" in model_name:
67-
model_preds = df.e_form_per_atom_bowsr
66+
elif model_name == "bowsr_megnet":
67+
model_preds = df.e_form_per_atom_bowsr_megnet
6868
else:
6969
raise ValueError(f"Unhandled {model_name = }")
7070
except AttributeError as exc:

models/bowsr/join_bowsr_results.py

+3 -10
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,6 @@
3333
for file_path in tqdm(file_paths):
3434
if file_path in dfs:
3535
continue
36-
# keep whole dataframe in memory
3736
df = pd.read_json(file_path).set_index("material_id")
3837

3938
df["bowsr_structure"] = df.structure_bowsr.map(Structure.from_dict)
@@ -51,20 +50,14 @@
5150
data_path = f"{ROOT}/data/wbm/2022-10-19-wbm-summary.csv"
5251
df_wbm = pd.read_csv(data_path).set_index("material_id")
5352

54-
df_bowsr["e_form_wbm"] = df_wbm.e_form_per_atom
5553

56-
print(f"{len(df_bowsr) - len(df_wbm) = :,} = {len(df_bowsr):,} - {len(df_wbm):,}")
57-
58-
59-
# %%
60-
df_bowsr.hist(bins=200, figsize=(18, 12))
61-
df_bowsr.isna().sum()
54+
print(f"{len(df_bowsr):,} - {len(df_wbm):,} = {len(df_bowsr) - len(df_wbm) = :,}")
6255

6356

6457
# %%
6558
pymatviz.density_scatter(
66-
df_bowsr.dropna().e_form_per_atom_bowsr,
67-
df_bowsr.dropna().e_form_wbm,
59+
x=df_bowsr.e_form_per_atom_bowsr_megnet,
60+
y=df_bowsr.e_form_wbm,
6861
)
6962

7063

models/bowsr/test_bowsr.py

+12 -12
Original file line number | Diff line number | Diff line change
@@ -35,8 +35,8 @@
3535
slurm_array_task_count = 500
3636
timestamp = f"{datetime.now():%Y-%m-%d@%H-%M-%S}"
3737
today = timestamp.split("@")[0]
38-
slurm_job_id = os.environ.get("SLURM_JOB_ID", "debug")
39-
job_name = f"bowsr-megnet-wbm-{task_type}-{slurm_job_id}"
38+
energy_model = "megnet"
39+
job_name = f"bowsr-{energy_model}-wbm-{task_type}"
4040
out_dir = f"{module_dir}/{today}-{job_name}"
4141

4242
data_path = f"{ROOT}/data/wbm/2022-10-19-wbm-init-structs.json.bz2"
@@ -64,7 +64,7 @@
6464
print(f"{data_path = }")
6565
print(f"{out_path = }")
6666
print(f"{version('maml') = }")
67-
print(f"{version('megnet') = }")
67+
print(f"{version(energy_model) = }")
6868

6969

7070
if os.path.isfile(out_path):
@@ -94,7 +94,8 @@
9494
data_path=data_path,
9595
df=dict(shape=str(df_this_job.shape), columns=", ".join(df_this_job)),
9696
maml_version=version("maml"),
97-
megnet_version=version("megnet"),
97+
energy_model=energy_model,
98+
energy_model_version=version(energy_model),
9899
optimize_kwargs=optimize_kwargs,
99100
task_type=task_type,
100101
slurm_max_job_time=slurm_max_job_time,
@@ -103,12 +104,11 @@
103104
if wandb.run is None:
104105
wandb.login()
105106

106-
# getting wandb: 429 encountered ({"error":"rate limit exceeded"}), retrying request
107-
# https://community.wandb.ai/t/753/14
107+
slurm_job_id = os.environ.get("SLURM_JOB_ID", "debug")
108108
wandb.init(
109109
entity="janosh",
110110
project="matbench-discovery",
111-
name=f"{job_name}-{slurm_array_task_id}",
111+
name=f"{job_name}-{slurm_job_id}-{slurm_array_task_id}",
112112
config=run_params,
113113
)
114114

@@ -146,11 +146,11 @@
146146

147147
structure_bowsr, energy_bowsr = bayes_optimizer.get_optimized_structure_and_energy()
148148

149-
results = dict(
150-
e_form_per_atom_bowsr=model.predict_energy(structure),
151-
structure_bowsr=structure_bowsr,
152-
energy_bowsr=energy_bowsr,
153-
)
149+
results = {
150+
f"e_form_per_atom_bowsr_{energy_model}": model.predict_energy(structure),
151+
"structure_bowsr": structure_bowsr,
152+
"energy_bowsr": energy_bowsr,
153+
}
154154

155155
relax_results[material_id] = results
156156

0 commit comments

Comments
 (0)