Skip to content

Commit 4b6e83a

Browse files
committed
add site/src/figs/(largest-fp-diff-each-error-models|largest-each-errors-fp-diff-models).svelte
shown on /models/tmi page (ex /models/per-element), generated by scripts/difficult_structures.py
add col site_stats_fingerprint_init_final_norm_diff to data/wbm/2022-10-19-wbm-summary.csv
sort EACH scatter plot facets and legend by MAE
1 parent 7946b5e commit 4b6e83a

10 files changed

+220
-87
lines changed

.gitattributes

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
# exclude generated plot files when calculating repo language statistics on GitHub
2-
*/figs/* linguist-generated
2+
**/figs/* linguist-generated
33
data/**/*.svelte linguist-generated

data/mp/get_mp_energies.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import pandas as pd
55
from aviary.wren.utils import get_aflow_label_from_spglib
66
from mp_api.client import MPRester
7-
from pymatviz.utils import annotate_mae_r2
7+
from pymatviz.utils import annotate_metrics
88
from tqdm import tqdm
99

1010
from matbench_discovery import today
@@ -71,7 +71,7 @@
7171
title=f"{today} - {len(df):,} MP entries",
7272
)
7373

74-
annotate_mae_r2(df.formation_energy_per_atom, df.decomposition_enthalpy)
74+
annotate_metrics(df.formation_energy_per_atom, df.decomposition_enthalpy)
7575
# result on 2023-01-10: plots match. no correlation between formation energy and
7676
# decomposition enthalpy. R^2 = -1.571, MAE = 1.604
7777
# ax.figure.savefig(f"{module_dir}/mp-decomp-enth-vs-e-form.webp", dpi=300)

data/wbm/fetch_process_wbm_dataset.py

+13
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import warnings
77
from glob import glob
88

9+
import numpy as np
910
import pandas as pd
1011
from aviary.wren.utils import get_aflow_label_from_spglib
1112
from pymatgen.analysis.phase_diagram import PatchedPhaseDiagram
@@ -597,6 +598,18 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:
597598
assert df_summary[wyckoff_col].isna().sum() == 0
598599

599600

601+
# %% site-stats.json.gz was generated by scripts/compute_struct_fingerprints.py
602+
df_fp = pd.read_json(f"{module_dir}/site-stats.json.gz").set_index("material_id")
603+
init_fp_col = "initial_site_stats_fingerprint"
604+
final_fp_col = "final_site_stats_fingerprint"
605+
fp_diff_col = "site_stats_fingerprint_init_final_norm_diff"
606+
df_fp[fp_diff_col] = (
607+
df_fp[final_fp_col].map(np.array) - df_fp[init_fp_col].map(np.array)
608+
).map(np.linalg.norm)
609+
610+
df_fp[fp_diff_col].hist(bins=100, backend="plotly")
611+
612+
600613
# %% write final summary data to disk (yeah!)
601614
df_summary.round(6).to_csv(f"{module_dir}/{today}-wbm-summary.csv")
602615

scripts/compute_struct_fingerprints.py

+27-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
# %%
88
import os
99
import warnings
10+
from glob import glob
1011

1112
import numpy as np
1213
import pandas as pd
@@ -46,6 +47,7 @@
4647
account="LEE-SL3-CPU",
4748
time="6:0:0",
4849
array=f"1-{slurm_array_task_count}",
50+
slurm_flags=("--mem", "30G"),
4951
)
5052

5153

@@ -95,4 +97,28 @@
9597
except Exception as exc:
9698
print(f"{fp_col} for {row.Index} failed: {exc}")
9799

98-
df_in.filter(like="site_stats_fingerprint").to_json(out_path)
100+
df_in.filter(like="site_stats_fingerprint").reset_index().to_json(out_path)
101+
102+
103+
# %%
104+
running_as_slurm_job = os.getenv("SLURM_JOB_ID")
105+
if running_as_slurm_job:
106+
print(f"Job wrote {out_path=} and finished at {timestamp}")
107+
raise SystemExit(0)
108+
109+
110+
# %%
111+
out_files = glob(f"{out_dir}/site-stats-*.json.gz")
112+
113+
found_idx = [int(name.split("-")[-1].split(".")[0]) for name in out_files]
114+
print(f"Found {len(out_files)=:,}")
115+
missing_files = sorted(set(range(1, slurm_array_task_count + 1)) - set(found_idx))
116+
if missing_files:
117+
print(f"{len(missing_files)=}: {missing_files}")
118+
119+
df_out = pd.concat(pd.read_json(out_file) for out_file in tqdm(out_files))
120+
121+
122+
df_out.index.name = "material_id"
123+
124+
df_out.reset_index().to_json(f"{out_dir}/site-stats.json.gz")

0 commit comments

Comments
 (0)