# %%
from __future__ import annotations

import contextlib
import os
from datetime import datetime
from importlib.metadata import version
from typing import Any

import numpy as np
import pandas as pd
import wandb
from maml.apps.bowsr.model.megnet import MEGNet
from maml.apps.bowsr.optimizer import BayesianOptimizer
from tqdm import tqdm

from mb_discovery import ROOT, as_dict_handler

| 19 | +""" |
| 20 | +To slurm submit this file, use |
| 21 | +
|
| 22 | +```sh |
| 23 | +# slurm will not create logdir automatically and fail if missing |
| 24 | +mkdir -p models/bowsr/slurm_logs |
| 25 | +sbatch --partition icelake-himem --account LEE-SL3-CPU --array 1-500 \ |
| 26 | + --time 12:0:0 --job-name bowsr-megnet-wbm-IS2RE --mem 12000 \ |
| 27 | + --output models/bowsr/slurm_logs/slurm-%A-%a.out \ |
| 28 | + --wrap "TF_CPP_MIN_LOG_LEVEL=2 python models/bowsr/slurm_array_bowsr_wbm.py" |
| 29 | +``` |
| 30 | +
|
| 31 | +--time 2h is probably enough but missing indices are annoying so best be safe. |
| 32 | +--mem 12000 avoids slurmstepd: error: Detected 1 oom-kill event(s) |
| 33 | + Some of your processes may have been killed by the cgroup out-of-memory handler. |
| 34 | +
|
| 35 | +TF_CPP_MIN_LOG_LEVEL=2 means INFO and WARNING logs are not printed |
| 36 | +https://stackoverflow.com/a/40982782 |
| 37 | +
|
| 38 | +Requires MEGNet and MAML installation: pip install megnet maml |
| 39 | +""" |

__author__ = "Janosh Riebesell"
__date__ = "2022-08-15"


task_type = "IS2RE"
# task_type = "RS2RE"
data_path = f"{ROOT}/data/2022-06-26-wbm-cses-and-initial-structures.json.gz"

module_dir = os.path.dirname(__file__)
job_id = os.environ.get("SLURM_JOB_ID", "debug")
job_array_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0))
# set large fallback job array size for fast testing/debugging
job_array_size = int(os.environ.get("SLURM_ARRAY_TASK_COUNT", 10_000))
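# with the sbatch command above (--array 1-500), SLURM_ARRAY_TASK_ID runs from 1 to 500
# and SLURM_ARRAY_TASK_COUNT is 500; outside Slurm the fallbacks enable quick local debugging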

print(f"Job started running {datetime.now():%Y-%m-%d@%H-%M}")
print(f"{job_id = }")
print(f"{job_array_id = }")
print(f"{version('maml') = }")
print(f"{version('megnet') = }")

today = f"{datetime.now():%Y-%m-%d}"
out_dir = f"{module_dir}/{today}-bowsr-megnet-wbm-{task_type}"
os.makedirs(out_dir, exist_ok=True)
json_out_path = f"{out_dir}/{job_array_id}.json.gz"

if os.path.isfile(json_out_path):
    raise SystemExit(f"{json_out_path = } already exists, exiting early")


# %%
bayes_optim_kwargs = dict(
    relax_coords=True,
    relax_lattice=True,
    use_symmetry=True,
    seed=42,
)
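# n_init random exploration steps followed by n_iter Bayesian optimization steps;
# alpha is the noise level of the Gaussian process surrogate (0.026 eV/atom here,
# presumably chosen to match MEGNet's reported formation-energy error)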
optimize_kwargs = dict(n_init=100, n_iter=100, alpha=0.026**2)

run_params = dict(
    megnet_version=version("megnet"),
    maml_version=version("maml"),
    job_id=job_id,
    job_array_id=job_array_id,
    data_path=data_path,
    bayes_optim_kwargs=bayes_optim_kwargs,
    optimize_kwargs=optimize_kwargs,
)
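# run_params is logged as the W&B run config below so each array task records the
# package versions, Slurm IDs and BO hyperparameters it was launched with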
if wandb.run is None:
    wandb.login()

# getting wandb: 429 encountered ({"error":"rate limit exceeded"}), retrying request
# https://community.wandb.ai/t/753/14
wandb.init(
    entity="janosh",
    project="matbench-discovery",
    name=f"bowsr-megnet-wbm-{task_type}-{job_id}-{job_array_id}",
    config=run_params,
)


# %%
print(f"Loading from {data_path=}")
df_wbm = pd.read_json(data_path).set_index("material_id")

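# split the full dataframe into job_array_size + 1 chunks and keep only this task's
# slice; thanks to the large fallback job_array_size, a local debug run
# (job_array_id = 0) gets a deliberately small first chunk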
df_this_job = np.array_split(df_wbm, job_array_size + 1)[job_array_id]


# %%
model = MEGNet()
relax_results: dict[str, dict[str, Any]] = {}

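# IS2RE starts the BOWSR relaxation from the unrelaxed WBM initial structures,
# RS2RE from the DFT-relaxed structures stored in each ComputedStructureEntry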
if task_type == "IS2RE":
    from pymatgen.core import Structure

    structures = df_this_job.initial_structure.map(Structure.from_dict)
elif task_type == "RS2RE":
    from pymatgen.entries.computed_entries import ComputedStructureEntry

    structures = df_this_job.cse.map(
        lambda x: ComputedStructureEntry.from_dict(x).structure
    )
else:
    raise ValueError(f"Unknown {task_type = }")


for material_id, structure in tqdm(
    structures.items(), desc="Main loop", total=len(structures)
):
    if material_id in relax_results:
        continue
    bayes_optimizer = BayesianOptimizer(
        model=model, structure=structure, **bayes_optim_kwargs
    )
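    # set_bounds() presumably defines the search-space bounds for the lattice and
    # site-coordinate perturbations explored by the Bayesian optimizer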
    bayes_optimizer.set_bounds()
    # reason for devnull here: https://github.com/materialsvirtuallab/maml/issues/469
    with open(os.devnull, "w") as devnull, contextlib.redirect_stdout(devnull):
        bayes_optimizer.optimize(**optimize_kwargs)

    structure_pred, energy_pred = bayes_optimizer.get_optimized_structure_and_energy()

    results = dict(
        e_form_per_atom_pred=model.predict_energy(structure),
        structure_pred=structure_pred,
        energy_pred=energy_pred,
    )

    relax_results[material_id] = results


# %%
df_output = pd.DataFrame(relax_results).T
df_output.index.name = "material_id"

df_output.reset_index().to_json(json_out_path, default_handler=as_dict_handler)

wandb.log_artifact(json_out_path, type=f"bowsr-megnet-wbm-{task_type}")
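
# A minimal sketch (not part of this job) of how the per-task outputs written above
# could be merged into a single dataframe once all array tasks have finished,
# assuming the default file layout created by this script ({out_dir}/{task_id}.json.gz):
#
#     from glob import glob
#
#     dfs = [pd.read_json(path) for path in glob(f"{out_dir}/*.json.gz")]
#     df_merged = pd.concat(dfs).set_index("material_id")
#     df_merged.to_json(f"{out_dir}/merged.json.gz", default_handler=as_dict_handler)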