Skip to content

Rerun M3GNet with new ASE FrechetCellFilter (prev ExpCellFilter) #72

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions matbench_discovery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
stress_col = "stress"
stress_trace_col = "stress_trace"
n_sites_col = "n_sites"
entry_col = "computed_structure_entry"

# load figshare 1.0.0
with open(f"{FIGSHARE}/1.0.0.json") as file:
Expand Down
4 changes: 2 additions & 2 deletions matbench_discovery/energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,6 @@ def get_e_form_per_atom(
if isinstance(ref_entry, dict):
e_refs[key] = PDEntry.from_dict(ref_entry)

form_energy = energy - sum(comp[el] * e_refs[str(el)] for el in comp)
e_form = energy - sum(comp[el] * e_refs[str(el)] for el in comp)

return form_energy / comp.num_atoms
return e_form / comp.num_atoms
2 changes: 1 addition & 1 deletion matbench_discovery/preds.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class PredFiles(Files):
cgcnn_p = "cgcnn/2023-02-05-cgcnn-perturb=5-wbm-IS2RE.csv.gz"

# original M3GNet straight from publication, not re-trained
m3gnet = "m3gnet/2022-10-31-m3gnet-wbm-IS2RE.csv.gz"
m3gnet = "m3gnet/2023-12-28-m3gnet-wbm-IS2RE.csv.gz"
# m3gnet_direct = "m3gnet/2023-05-30-m3gnet-direct-wbm-IS2RE.csv.gz"
# m3gnet_ms = "m3gnet/2023-06-01-m3gnet-manual-sampling-wbm-IS2RE.csv.gz"

Expand Down
Binary file added models/m3gnet/2023-12-28-m3gnet-wbm-IS2RE.csv.gz
Binary file not shown.
18 changes: 10 additions & 8 deletions models/m3gnet/join_m3gnet_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
from pymatgen.entries.computed_entries import ComputedStructureEntry
from tqdm import tqdm

from matbench_discovery.data import DATA_FILES, as_dict_handler, id_col
from matbench_discovery import entry_col, id_col
from matbench_discovery.data import DATA_FILES, as_dict_handler
from matbench_discovery.energy import get_e_form_per_atom

__author__ = "Janosh Riebesell"
Expand All @@ -26,12 +27,11 @@
# %%
module_dir = os.path.dirname(__file__)
task_type = "IS2RE"
date = "2023-05-30"
date = "2023-12-28"
# direct: cluster sampling, ms: manual sampling
model_type: Literal["orig", "direct", "ms"] = "ms"
model_type: Literal["orig", "direct", "ms"] = "orig"
glob_pattern = f"{date}-m3gnet-{model_type}-wbm-{task_type}/*.json.gz"
file_paths = sorted(glob(f"{module_dir}/{glob_pattern}"))
struct_col = "m3gnet_structure"
print(f"Found {len(file_paths):,} files for {glob_pattern = }")

# prevent accidental overwrites
Expand All @@ -52,7 +52,6 @@

# %%
df_cse = pd.read_json(DATA_FILES.wbm_computed_structure_entries).set_index(id_col)
entry_col = "computed_structure_entry"
df_cse[entry_col] = [
ComputedStructureEntry.from_dict(dct)
for dct in tqdm(df_cse.computed_structure_entry)
Expand All @@ -62,9 +61,12 @@
# %% transfer M3GNet energies and relaxed structures WBM CSEs since MP2020 energy
# corrections applied below are structure-dependent (for oxides and sulfides)
cse: ComputedStructureEntry
for row in tqdm(df_m3gnet.itertuples(), total=len(df_m3gnet)):
mat_id, struct_dict, m3gnet_energy, *_ = row
mlip_struct = Structure.from_dict(struct_dict)
e_col = "m3gnet_orig_energy"
struct_col = "m3gnet_orig_structure"

for mat_id in tqdm(df_m3gnet.index):
m3gnet_energy = df_m3gnet.loc[mat_id, e_col]
mlip_struct = Structure.from_dict(df_m3gnet.loc[mat_id, struct_col])
df_m3gnet.at[mat_id, struct_col] = mlip_struct # noqa: PD008
cse = df_cse.loc[mat_id, entry_col]
cse._energy = m3gnet_energy # cse._energy is the uncorrected energy # noqa: SLF001
Expand Down
13 changes: 9 additions & 4 deletions models/m3gnet/test_m3gnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
model_type: Literal["orig", "direct", "manual-sampling"] = "orig"
# set large job array size for smaller data splits and faster testing/debugging
slurm_array_task_count = 50
record_traj = False
job_name = f"m3gnet-{model_type}-wbm-{task_type}"
out_dir = os.getenv("SBATCH_OUTPUT", f"{module_dir}/{today}-{job_name}")

Expand Down Expand Up @@ -94,6 +95,7 @@
model_type=model_type,
out_path=out_path,
job_name=job_name,
record_traj=record_traj,
)

run_name = f"{job_name}-{slurm_array_task_id}"
Expand All @@ -112,12 +114,15 @@
if material_id in relax_results:
continue
try:
relax_result = m3gnet.relax(structures[material_id])
result = m3gnet.relax(structures[material_id])
relax_results[material_id] = {
f"m3gnet_{model_type}_structure": relax_result["final_structure"],
f"m3gnet_{model_type}_trajectory": relax_result["trajectory"].__dict__,
e_pred_col: relax_result["trajectory"].energies[-1],
f"m3gnet_{model_type}_structure": result["final_structure"],
e_pred_col: result["trajectory"].energies[-1],
}
if record_traj:
relax_results[f"m3gnet_{model_type}_trajectory"] = result[
"trajectory"
].__dict__
except Exception as exc:
print(f"Failed to relax {material_id}: {exc!r}")

Expand Down
3 changes: 1 addition & 2 deletions models/mace/join_mace_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from pymatviz import density_scatter
from tqdm import tqdm

from matbench_discovery import formula_col, id_col
from matbench_discovery import entry_col, formula_col, id_col
from matbench_discovery.data import DATA_FILES, as_dict_handler, df_wbm
from matbench_discovery.energy import get_e_form_per_atom
from matbench_discovery.preds import e_form_col
Expand Down Expand Up @@ -52,7 +52,6 @@
# %%
df_cse = pd.read_json(DATA_FILES.wbm_computed_structure_entries).set_index(id_col)

entry_col = "computed_structure_entry"
df_cse[entry_col] = [
ComputedStructureEntry.from_dict(dct)
for dct in tqdm(df_cse.computed_structure_entry)
Expand Down
11 changes: 3 additions & 8 deletions scripts/model_figs/rolling_mae_vs_hull_dist_wbm_batches.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,16 @@

from matbench_discovery import PDF_FIGS, SITE_FIGS, today
from matbench_discovery.plots import plt, rolling_mae_vs_hull_dist
from matbench_discovery.preds import (
df_each_pred,
df_preds,
e_form_col,
each_true_col,
models,
)
from matbench_discovery.preds import df_each_pred, df_preds, e_form_col, each_true_col
from matbench_discovery.preds import models as all_models

__author__ = "Rhys Goodall, Janosh Riebesell"
__date__ = "2022-06-18"

batch_col = "batch_idx"
df_each_pred[batch_col] = "Batch " + df_each_pred.index.str.split("-").str[1]
df_err, df_std = None, None # variables to cache rolling MAE and std
models = globals().get("models", models)
models = globals().get("models", all_models)


# %% matplotlib version
Expand Down
File renamed without changes.

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion site/src/figs/box-hull-dist-errors.svelte

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion site/src/figs/cumulative-mae.svelte

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion site/src/figs/cumulative-precision-recall.svelte

Large diffs are not rendered by default.

Large diffs are not rendered by default.

16 changes: 10 additions & 6 deletions site/src/figs/metrics-table-first-10k.svelte

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

70 changes: 39 additions & 31 deletions site/src/figs/metrics-table.svelte

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading