11 | 11 | import os
12 | 12 | from importlib.metadata import version
13 | 13 |
   | 14 | +import numpy as np
14 | 15 | import pandas as pd
15 | 16 | import wandb
16 | 17 | from megnet.utils.models import load_model

21 | 22 | from matbench_discovery import DEBUG, timestamp, today
22 | 23 | from matbench_discovery.data import DATA_FILES, df_wbm
23 | 24 | from matbench_discovery.plots import wandb_scatter
   | 25 | +from matbench_discovery.preds import PRED_FILES
24 | 26 | from matbench_discovery.slurm import slurm_submit
25 | 27 |
26 | 28 | __author__ = "Janosh Riebesell"
27 | 29 | __date__ = "2022-11-14"
28 | 30 |
29 |    | -task_type = "IS2RE"
   | 31 | +task_type = "chgnet_structure"
30 | 32 | module_dir = os.path.dirname(__file__)
31 | 33 | job_name = f"megnet-wbm-{task_type}{'-debug' if DEBUG else ''}"
32 | 34 | out_dir = os.environ.get("SBATCH_OUTPUT", f"{module_dir}/{today}-{job_name}")
   | 35 | +slurm_array_task_count = 20
33 | 36 |
34 | 37 | slurm_vars = slurm_submit(
35 | 38 |     job_name=job_name,

38 | 41 |     account="LEE-SL3-CPU",
39 | 42 |     time="12:0:0",
40 | 43 |     slurm_flags=("--mem", "30G"),
   | 44 | +    array=f"1-{slurm_array_task_count}",
41 | 45 |     # TF_CPP_MIN_LOG_LEVEL=2 means INFO and WARNING logs are not printed
42 | 46 |     # https://stackoverflow.com/a/40982782
43 | 47 |     pre_cmd="TF_CPP_MIN_LOG_LEVEL=2",
44 | 48 | )
45 | 49 |
46 | 50 |
47 | 51 | # %%
   | 52 | +slurm_array_task_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0))
48 | 53 | out_path = f"{out_dir}/megnet-e-form-preds.csv"
49 | 54 | if os.path.isfile(out_path):
50 | 55 |     raise SystemExit(f"{out_path = } already exists, exiting early")
51 | 56 |
52 | 57 | data_path = {
53 | 58 |     "IS2RE": DATA_FILES.wbm_initial_structures,
54 | 59 |     "RS2RE": DATA_FILES.wbm_computed_structure_entries,
   | 60 | +    "chgnet_structure": PRED_FILES.__dict__["CHGNet"].replace(".csv", ".json.gz"),
   | 61 | +    "m3gnet_structure": PRED_FILES.__dict__["M3GNet"].replace(".csv", ".json.gz"),
55 | 62 | }[task_type]
56 | 63 | print(f"\nJob started running {timestamp}")
57 | 64 | print(f"{data_path=}")
58 | 65 | e_form_col = "e_form_per_atom_mp2020_corrected"
59 | 66 | assert e_form_col in df_wbm, f"{e_form_col=} not in {list(df_wbm)=}"
60 | 67 |
61 |    | -df_in = pd.read_json(data_path).set_index("material_id")
   | 68 | +df_in: pd.DataFrame = np.array_split(
   | 69 | +    pd.read_json(data_path).set_index("material_id"), slurm_array_task_count
   | 70 | +)[slurm_array_task_id - 1]
62 | 71 | megnet_mp_e_form = load_model(model_name := "Eform_MP_2019")
63 | 72 |
64 | 73 |

77 | 86 |
78 | 87 |
79 | 88 | # %%
80 |    | -input_col = {"IS2RE": "initial_structure", "RS2RE": "relaxed_structure"}[task_type]
   | 89 | +input_col = {"IS2RE": "initial_structure", "RS2RE": "relaxed_structure"}.get(
   | 90 | +    task_type, task_type  # input_col=task_type for CHGNet and M3GNet
   | 91 | +)
81 | 92 |
82 | 93 | if task_type == "RS2RE":
83 | 94 |     df_in[input_col] = [x["structure"] for x in df_in.computed_structure_entry]
84 | 95 |
85 | 96 | structures = df_in[input_col].map(Structure.from_dict).to_dict()
86 | 97 |
87 | 98 | megnet_e_form_preds = {}
88 |    | -for material_id in tqdm(structures, disable=None):
   | 99 | +for material_id in tqdm(structures):
89 | 100 |     if material_id in megnet_e_form_preds:
90 | 101 |         continue
91 | 102 |     try:
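The added lines shard the WBM dataframe across SLURM array tasks: every task reads the same JSON file, np.array_split cuts the dataframe into slurm_array_task_count roughly equal chunks, and each task keeps only the chunk matching its 1-based SLURM_ARRAY_TASK_ID. A minimal standalone sketch of that pattern (the toy dataframe below is illustrative, not part of the PR):

import os

import numpy as np
import pandas as pd

slurm_array_task_count = 20  # mirrors array=f"1-{slurm_array_task_count}" above
# SLURM sets SLURM_ARRAY_TASK_ID to 1..20 inside the job; the default of 0 makes
# a local run pick the last chunk (index -1), same as the diff's behavior
slurm_array_task_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0))

# toy stand-in for pd.read_json(data_path).set_index("material_id")
df_demo = pd.DataFrame(
    {"material_id": [f"wbm-1-{idx}" for idx in range(1000)]}
).set_index("material_id")
df_chunk = np.array_split(df_demo, slurm_array_task_count)[slurm_array_task_id - 1]
print(f"task {slurm_array_task_id} handles {len(df_chunk):,} of {len(df_demo):,} rows")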