|
5 | 5 | from datetime import datetime
|
6 | 6 | from importlib.metadata import version
|
7 | 7 |
|
8 |
| -import numpy as np |
9 | 8 | import pandas as pd
|
10 | 9 | import wandb
|
11 | 10 | from megnet.utils.models import load_model
|
12 | 11 | from tqdm import tqdm
|
13 | 12 |
|
14 | 13 | from matbench_discovery import ROOT
|
| 14 | +from matbench_discovery.plot_scripts import df_wbm |
15 | 15 | from matbench_discovery.slurm import slurm_submit
|
16 | 16 |
|
17 | 17 | """
|
|
23 | 23 | __author__ = "Janosh Riebesell"
|
24 | 24 | __date__ = "2022-11-14"
|
25 | 25 |
|
26 |
| -task_type = "IS2RE" # "RS2RE" |
| 26 | +task_type = "IS2RE" |
27 | 27 | timestamp = f"{datetime.now():%Y-%m-%d@%H-%M-%S}"
|
28 | 28 | today = timestamp.split("@")[0]
|
29 | 29 | module_dir = os.path.dirname(__file__)
|
30 |
| -# set large job array size for fast testing/debugging |
31 |
| -slurm_array_task_count = 1 |
32 |
| -slurm_job_id = os.environ.get("SLURM_JOB_ID", "debug") |
33 |
| -job_name = f"megnet-wbm-{task_type}-{slurm_job_id}" |
| 30 | +job_name = f"megnet-wbm-{task_type}" |
34 | 31 | out_dir = f"{module_dir}/{today}-{job_name}"
|
35 | 32 |
|
36 | 33 | slurm_vars = slurm_submit(
|
|
39 | 36 | partition="icelake-himem",
|
40 | 37 | account="LEE-SL3-CPU",
|
41 | 38 | time=(slurm_max_job_time := "12:0:0"),
|
42 |
| - array=f"1-{slurm_array_task_count}", |
43 | 39 | # TF_CPP_MIN_LOG_LEVEL=2 means INFO and WARNING logs are not printed
|
44 | 40 | # https://stackoverflow.com/a/40982782
|
45 | 41 | pre_cmd="TF_CPP_MIN_LOG_LEVEL=2",
|
46 | 42 | )
|
47 | 43 |
|
48 | 44 |
|
49 | 45 | # %%
|
50 |
| -slurm_array_task_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0)) |
51 |
| - |
52 | 46 | print(f"Job started running {timestamp}")
|
53 | 47 |
|
54 |
| -json_out_path = f"{out_dir}/{slurm_array_task_id}.json.gz" |
55 |
| -if os.path.isfile(json_out_path): |
56 |
| - raise SystemExit(f"{json_out_path = } already exists, exciting early") |
| 48 | +out_path = f"{out_dir}/megnet-e-form-preds.csv" |
| 49 | +if os.path.isfile(out_path): |
| 50 | + raise SystemExit(f"{out_path = } already exists, exciting early") |
57 | 51 |
|
58 | 52 |
|
59 | 53 | # %%
|
60 | 54 | data_path = f"{ROOT}/data/wbm/2022-10-19-wbm-init-structs.json.bz2"
|
61 | 55 | print(f"Loading from {data_path=}")
|
62 |
| -df_wbm = pd.read_json(data_path).set_index("material_id") |
| 56 | +df_wbm_structs = pd.read_json(data_path).set_index("material_id") |
63 | 57 |
|
64 |
| -df_this_job: pd.DataFrame = np.array_split(df_wbm, slurm_array_task_count)[ |
65 |
| - slurm_array_task_id - 1 |
66 |
| -] |
67 | 58 |
|
68 | 59 | megnet_mp_e_form = load_model(model_name := "Eform_MP_2019")
|
69 | 60 |
|
| 61 | + |
| 62 | +# %% |
70 | 63 | run_params = dict(
|
71 | 64 | data_path=data_path,
|
72 | 65 | megnet_version=version("megnet"),
|
73 | 66 | model_name=model_name,
|
74 | 67 | task_type=task_type,
|
75 | 68 | slurm_max_job_time=slurm_max_job_time,
|
76 |
| - df=dict(shape=str(df_this_job.shape), columns=", ".join(df_this_job)), |
| 69 | + df=dict(shape=str(df_wbm_structs.shape), columns=", ".join(df_wbm_structs)), |
77 | 70 | slurm_vars=slurm_vars,
|
78 | 71 | )
|
79 | 72 | if wandb.run is None:
|
80 | 73 | wandb.login()
|
81 | 74 |
|
82 |
| -wandb.init( |
83 |
| - project="matbench-discovery", |
84 |
| - name=f"{job_name}-{slurm_array_task_id}", |
85 |
| - config=run_params, |
86 |
| -) |
| 75 | +wandb.init(project="matbench-discovery", name=job_name, config=run_params) |
87 | 76 |
|
88 | 77 |
|
89 | 78 | # %%
|
90 | 79 | if task_type == "IS2RE":
|
91 | 80 | from pymatgen.core import Structure
|
92 | 81 |
|
93 |
| - structures = df_this_job.initial_structure.map(Structure.from_dict) |
| 82 | + structures = df_wbm_structs.initial_structure.map(Structure.from_dict) |
94 | 83 | elif task_type == "RS2RE":
|
95 | 84 | from pymatgen.entries.computed_entries import ComputedStructureEntry
|
96 | 85 |
|
97 |
| - df_this_job.cse = df_this_job.cse.map(ComputedStructureEntry.from_dict) |
98 |
| - structures = df_this_job.cse.map(lambda x: x.structure) |
| 86 | + df_wbm_structs.cse = df_wbm_structs.cse.map(ComputedStructureEntry.from_dict) |
| 87 | + structures = df_wbm_structs.cse.map(lambda x: x.structure) |
99 | 88 | else:
|
100 | 89 | raise ValueError(f"Unknown {task_type = }")
|
101 | 90 |
|
102 |
| -megnet_preds = {} |
103 |
| -for material_id, structure in tqdm(structures.items(), disable=None): |
104 |
| - if material_id in megnet_preds: |
| 91 | +megnet_e_form_preds = {} |
| 92 | +for material_id, structure in tqdm(structures.items(), total=len(structures)): |
| 93 | + if material_id in megnet_e_form_preds: |
105 | 94 | continue
|
106 |
| - e_form_per_atom = megnet_mp_e_form.predict_structure(structure)[0] |
107 |
| - megnet_preds[material_id] = e_form_per_atom |
| 95 | + try: |
| 96 | + e_form_per_atom = megnet_mp_e_form.predict_structure(structure)[0] |
| 97 | + megnet_e_form_preds[material_id] = e_form_per_atom |
| 98 | + except Exception as exc: |
| 99 | + print(f"Failed to predict {material_id=}: {exc}") |
| 100 | + |
108 | 101 |
|
| 102 | +# %% |
| 103 | +print(f"{len(megnet_e_form_preds)=:,}") |
| 104 | +print(f"{len(structures)=:,}") |
| 105 | +print(f"missing: {len(structures) - len(megnet_e_form_preds):,}") |
| 106 | +out_col = "e_form_per_atom_megnet" |
| 107 | +df_wbm[out_col] = pd.Series(megnet_e_form_preds) |
109 | 108 |
|
110 |
| -assert len(megnet_preds) == len(structures) == len(df_this_job) |
111 |
| -out_col = "megnet_e_form" |
112 |
| -df_this_job[out_col] = pd.Series(megnet_preds) |
| 109 | +df_wbm[out_col].reset_index().to_csv(out_path) |
113 | 110 |
|
114 | 111 |
|
115 | 112 | # %%
|
116 |
| -df_this_job[out_col].reset_index().to_json(json_out_path) |
| 113 | +fields = {"x": "e_form_per_atom_mp2020_corrected", "y": out_col} |
| 114 | +cols = list(fields.values()) |
| 115 | +assert all(col in df_wbm for col in cols) |
| 116 | + |
| 117 | +table = wandb.Table(dataframe=df_wbm[cols].reset_index()) |
| 118 | + |
| 119 | +MAE = (df_wbm[fields["x"]] - df_wbm[fields["y"]]).abs().mean() |
| 120 | + |
| 121 | +scatter_plot = wandb.plot_table( |
| 122 | + vega_spec_name="janosh/scatter-parity", |
| 123 | + data_table=table, |
| 124 | + fields=fields, |
| 125 | + string_fields={"title": f"{model_name} {task_type} {MAE=:.4}"}, |
| 126 | +) |
117 | 127 |
|
118 |
| -wandb.log_artifact(json_out_path, type=f"m3gnet-wbm-{task_type}") |
| 128 | +wandb.log({"true_pred_scatter": scatter_plot}) |
0 commit comments