Skip to content

Commit c9fed5a

Browse files
committed
add models/chgnet/ctk_struct_traj.py and models/wrenformer/analyze_wrenformer.py
1 parent abfc0ac commit c9fed5a

File tree

6 files changed

+324
-38
lines changed

6 files changed

+324
-38
lines changed

.pre-commit-config.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ default_install_hook_types: [pre-commit, commit-msg]
77

88
repos:
99
- repo: https://github.com/charliermarsh/ruff-pre-commit
10-
rev: v0.0.257
10+
rev: v0.0.258
1111
hooks:
1212
- id: ruff
1313
args: [--fix]

matbench_discovery/preds.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -32,22 +32,22 @@ class PredFiles(Files):
3232
# bowsr optimizer coupled with original megnet
3333
bowsr_megnet = "bowsr/2023-01-23-bowsr-megnet-wbm-IS2RE.csv"
3434
# default CHGNet model from publication with 400,438 params
35-
chgnet = CHGNet = "chgnet/2023-03-06-chgnet-wbm-IS2RE.csv"
35+
chgnet = "chgnet/2023-03-06-chgnet-wbm-IS2RE.csv"
3636
chgnet_megnet = "chgnet/2023-03-04-chgnet-wbm-IS2RE.csv"
3737
# CGCNN 10-member ensemble
38-
cgcnn = cgcnn = "cgcnn/2023-01-26-test-cgcnn-wbm-IS2RE/cgcnn-ensemble-preds.csv"
38+
cgcnn = "cgcnn/2023-01-26-test-cgcnn-wbm-IS2RE/cgcnn-ensemble-preds.csv"
3939
# cgcnn 10-member ensemble with 5-fold training set perturbations
40-
cgcnn_p = CGCNN_P = "cgcnn/2023-02-05-cgcnn-perturb=5.csv"
40+
cgcnn_p = "cgcnn/2023-02-05-cgcnn-perturb=5.csv"
4141
# original m3gnet straight from publication, not re-trained
42-
m3gnet = M3GNet = "m3gnet/2022-10-31-m3gnet-wbm-IS2RE.csv"
42+
m3gnet = "m3gnet/2022-10-31-m3gnet-wbm-IS2RE.csv"
4343
# m3gnet-relaxed structures fed into megnet for formation energy prediction
4444
m3gnet_megnet = "m3gnet/2022-10-31-m3gnet-wbm-IS2RE.csv"
4545
# original megnet straight from publication, not re-trained
46-
megnet = MEGNet = "megnet/2022-11-18-megnet-wbm-IS2RE/megnet-e-form-preds.csv"
46+
megnet = "megnet/2022-11-18-megnet-wbm-IS2RE/megnet-e-form-preds.csv"
4747
# magpie composition+voronoi tessellation structure features + sklearn random forest
48-
voronoi_rf = Voronoi_RF = "voronoi/2022-11-27-train-test/e-form-preds-IS2RE.csv"
48+
voronoi_rf = "voronoi/2022-11-27-train-test/e-form-preds-IS2RE.csv"
4949
# wrenformer 10-member ensemble
50-
wrenformer = Wrenformer = "wrenformer/2022-11-15-wrenformer-IS2RE-preds.csv"
50+
wrenformer = "wrenformer/2022-11-15-wrenformer-IS2RE-preds.csv"
5151

5252

5353
PRED_FILES = PredFiles()

models/chgnet/analyze_chgnet.py

+36-13
Original file line numberDiff line numberDiff line change
@@ -9,55 +9,78 @@
99
from pymatgen.core import Structure
1010
from pymatviz import density_scatter, plot_structure_2d, ptable_heatmap_plotly
1111

12-
from matbench_discovery import plots
12+
from matbench_discovery import plots as plots
1313
from matbench_discovery.data import DATA_FILES, df_wbm
1414
from matbench_discovery.preds import PRED_FILES
1515

1616
__author__ = "Janosh Riebesell"
1717
__date__ = "2023-03-06"
1818

1919
module_dir = os.path.dirname(__file__)
20-
del plots # https://github.com/PyCQA/pyflakes/issues/366
20+
id_col = "material_id"
2121

2222

2323
# %%
24-
df_chgnet = pd.read_csv(PRED_FILES.CHGNet)
25-
df_chgnet = df_chgnet.set_index("material_id").add_suffix("_2000")
26-
df_chgnet_500 = pd.read_csv(PRED_FILES.CHGNet.replace("-06", "-04"))
27-
df_chgnet_500 = df_chgnet_500.set_index("material_id").add_suffix("_500")
24+
df_chgnet = pd.read_csv(PRED_FILES.__dict__["CHGNet"])
25+
df_chgnet = df_chgnet.set_index(id_col).add_suffix("_2000")
26+
df_chgnet_500 = pd.read_csv(PRED_FILES.__dict__["CHGNet"].replace("-06", "-04"))
27+
df_chgnet_500 = df_chgnet_500.set_index(id_col).add_suffix("_500")
2828
df_chgnet[list(df_chgnet_500)] = df_chgnet_500
2929
df_chgnet["formula"] = df_wbm.formula
3030

3131
e_form_2000 = "e_form_per_atom_chgnet_2000"
3232
e_form_500 = "e_form_per_atom_chgnet_500"
3333

34-
min_e_diff = 0.35
34+
min_e_diff = 0.1
35+
# structures with smaller energy after longer relaxation need many steps
36+
df_long = df_chgnet.query(f"{e_form_2000} - {e_form_500} < -{min_e_diff}")
37+
# structures with larger energy after longer relaxation are problematic
3538
df_bad = df_chgnet.query(f"{e_form_2000} - {e_form_500} > {min_e_diff}")
39+
# both combined: energy differs by more than min_e_diff in either direction
40+
df_diff = df_chgnet.query(f"abs({e_form_2000} - {e_form_500}) > {min_e_diff}")
41+
42+
assert len(df_long) + len(df_bad) == len(df_diff)
43+
44+
45+
# %%
46+
density_scatter(df=df_chgnet, x=e_form_500, y=e_form_2000)
3647

3748

3849
# %%
39-
density_scatter(df=df_chgnet, x=e_form_2000, y=e_form_500)
50+
df_diff.reset_index().plot.scatter(
51+
x=e_form_500,
52+
y=e_form_2000,
53+
hover_name=id_col,
54+
hover_data=["formula"],
55+
backend="plotly",
56+
title=f"{len(df_diff)} structures have > {min_e_diff} eV/atom energy diff after "
57+
"longer relaxation",
58+
)
4059

4160

4261
# %%
4362
fig = ptable_heatmap_plotly(df_bad.formula)
44-
title = "structures with larger error after longer relaxation"
45-
fig.layout.title.update(text=f"{len(df_bad)} {title}")
63+
title = "structures with larger error<br>after longer relaxation"
64+
fig.layout.title.update(text=f"{len(df_diff)} {title}", x=0.4, y=0.9)
65+
fig.show()
4666

4767

4868
# %%
49-
df_cse = pd.read_json(DATA_FILES.wbm_initial_structures).set_index("material_id")
69+
df_cse = pd.read_json(DATA_FILES.wbm_cses_plus_init_structs).set_index(id_col)
70+
df_cse.loc[df_diff.index].reset_index().to_json(
71+
f"{module_dir}/wbm-chgnet-bad-relax.json.gz"
72+
)
5073

5174

5275
# %%
5376
n_rows, n_cols = 3, 4
5477
fig, axs = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 4 * n_rows))
55-
n_struct = min(n_rows * n_cols, len(df_bad))
78+
n_struct = min(n_rows * n_cols, len(df_diff))
5679
struct_col = "initial_structure"
5780

5881
fig.suptitle(f"{n_struct} {struct_col} {title}", fontsize=16, fontweight="bold", y=1.05)
5982
for idx, (ax, row) in enumerate(
60-
zip(axs.flat, df_cse.loc[df_bad.index].itertuples()), 1
83+
zip(axs.flat, df_cse.loc[df_diff.index].itertuples()), 1
6184
):
6285
struct = Structure.from_dict(getattr(row, struct_col))
6386
plot_structure_2d(struct, ax=ax)

scripts/ctk_structure_viewer.py models/chgnet/ctk_structure_viewer.py

+15-17
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import pandas as pd
44
from crystal_toolkit.helpers.utils import hook_up_fig_with_struct_viewer
55

6+
from matbench_discovery.preds import PRED_FILES
7+
68
__author__ = "Janosh Riebesell"
79
__date__ = "2023-03-07"
810

@@ -13,28 +15,24 @@
1315
Then open http://localhost:8000 in your browser.
1416
"""
1517

16-
df_plot = None
17-
min_e_diff = 0.1
1818
e_form_2000 = "e_form_per_atom_chgnet_2000"
1919
e_form_500 = "e_form_per_atom_chgnet_500"
2020

21-
if df_plot is None:
22-
from matbench_discovery.preds import PRED_FILES
21+
df_chgnet = pd.read_json(PRED_FILES.__dict__["CHGNet"].replace(".csv", ".json.gz"))
22+
df_chgnet = df_chgnet.set_index("material_id")
2323

24-
df_chgnet = pd.read_json(PRED_FILES.CHGNet.replace(".csv", ".json.gz"))
25-
df_chgnet = df_chgnet.set_index("material_id")
24+
df_chgnet_2000 = pd.read_csv(PRED_FILES.__dict__["CHGNet"])
25+
df_chgnet_2000 = df_chgnet_2000.set_index("material_id").add_suffix("_2000")
26+
df_chgnet[list(df_chgnet_2000)] = df_chgnet_2000
2627

27-
df_chgnet_2000 = pd.read_csv(PRED_FILES.CHGNet)
28-
df_chgnet_2000 = df_chgnet_2000.set_index("material_id").add_suffix("_2000")
29-
df_chgnet[list(df_chgnet_2000)] = df_chgnet_2000
28+
df_chgnet_500 = pd.read_csv(PRED_FILES.__dict__["CHGNet"].replace("-06", "-04"))
29+
df_chgnet_500 = df_chgnet_500.set_index("material_id").add_suffix("_500")
30+
df_chgnet[list(df_chgnet_500)] = df_chgnet_500
3031

31-
df_chgnet_500 = pd.read_csv(PRED_FILES.CHGNet.replace("-06", "-04"))
32-
df_chgnet_500 = df_chgnet_500.set_index("material_id").add_suffix("_500")
33-
df_chgnet[list(df_chgnet_500)] = df_chgnet_500
34-
35-
e_form_abs_diff = "e_form_abs_diff"
36-
df_chgnet[e_form_abs_diff] = abs(df_chgnet[e_form_2000] - df_chgnet[e_form_500])
37-
df_plot = df_chgnet.round(3).query(f"{e_form_abs_diff} > {min_e_diff}")
32+
e_form_abs_diff = "e_form_abs_diff"
33+
min_e_diff = 0.1
34+
df_chgnet[e_form_abs_diff] = abs(df_chgnet[e_form_2000] - df_chgnet[e_form_500])
35+
df_plot = df_chgnet.round(3).query(f"{e_form_abs_diff} > {min_e_diff}")
3836

3937

4038
plot_labels = {
@@ -69,4 +67,4 @@
6967
# validate_id requires material_id to be hover_name
7068
validate_id=lambda id: id.startswith(("wbm-", "mp-", "mvc-")),
7169
)
72-
app.run_server(debug=True, port=8000)
70+
app.run(debug=True, port=8000)

0 commit comments

Comments
 (0)