|
3 | 3 |
|
4 | 4 | import os
|
5 | 5 | from importlib.metadata import version
|
6 |
| -from typing import Any |
| 6 | +from typing import Any, Literal |
7 | 7 |
|
8 | 8 | import numpy as np
|
9 | 9 | import pandas as pd
|
10 | 10 | import torch
|
11 | 11 | import wandb
|
12 |
| -from ase.filters import FrechetCellFilter |
| 12 | +from ase.filters import ExpCellFilter, FrechetCellFilter |
13 | 13 | from ase.optimize import FIRE, LBFGS
|
14 | 14 | from mace.calculators import mace_mp
|
15 | 15 | from mace.tools import count_parameters
|
|
31 | 31 | task_type = "IS2RE" # "RS2RE"
|
32 | 32 | module_dir = os.path.dirname(__file__)
|
33 | 33 | # set large job array size for smaller data splits and faster testing/debugging
|
34 |
| -slurm_array_task_count = 20 |
| 34 | +slurm_array_task_count = 50 |
35 | 35 | ase_optimizer = "FIRE"
|
36 | 36 | job_name = f"mace-wbm-{task_type}-{ase_optimizer}"
|
37 | 37 | out_dir = os.getenv("SBATCH_OUTPUT", f"{module_dir}/{today}-{job_name}")
|
|
42 | 42 | "2023-10-29-mace-16M-pbenner-mptrj-no-conditional-loss",
|
43 | 43 | "https://tinyurl.com/y7uhwpje",
|
44 | 44 | ][-1]
|
| 45 | +ase_filter: Literal["frechet", "exp"] = "frechet" |
45 | 46 |
|
46 | 47 | slurm_vars = slurm_submit(
|
47 | 48 | job_name=job_name,
|
48 | 49 | out_dir=out_dir,
|
49 | 50 | account="matgen",
|
50 |
| - time="9:55:0", |
| 51 | + time="11:55:0", |
51 | 52 | array=f"1-{slurm_array_task_count}",
|
52 |
| - slurm_flags="--qos shared --constraint gpu --gpus 1", |
53 |
| - # slurm_flags="--qos shared --constraint cpu --mem 16G", |
| 53 | + # slurm_flags="--qos shared --constraint gpu --gpus 1", |
| 54 | + slurm_flags="--qos shared --constraint cpu --mem 32G", |
54 | 55 | )
|
55 | 56 |
|
56 | 57 |
|
|
98 | 99 | trainable_params=count_parameters(mace_calc.models[0]),
|
99 | 100 | model_name=model_name,
|
100 | 101 | dtype=dtype,
|
| 102 | + ase_filter=ase_filter, |
101 | 103 | )
|
102 | 104 |
|
103 | 105 | run_name = f"{job_name}-{slurm_array_task_id}"
|
|
112 | 114 | df_in[input_col] = [x["structure"] for x in df_in.computed_structure_entry]
|
113 | 115 |
|
114 | 116 | structs = df_in[input_col].map(Structure.from_dict).to_dict()
|
| 117 | +filter_cls = {"frechet": FrechetCellFilter, "exp": ExpCellFilter}[ase_filter] |
115 | 118 |
|
116 | 119 | for material_id in tqdm(structs, desc="Relaxing"):
|
117 | 120 | if material_id in relax_results:
|
|
121 | 124 | atoms = structs[material_id].to_ase_atoms()
|
122 | 125 | atoms.calc = mace_calc
|
123 | 126 | if max_steps > 0:
|
124 |
| - atoms = FrechetCellFilter(atoms) |
| 127 | + atoms = filter_cls(atoms) |
125 | 128 | optim_cls = {"FIRE": FIRE, "LBFGS": LBFGS}[ase_optimizer]
|
126 | 129 | optimizer = optim_cls(atoms, logfile="/dev/null")
|
127 | 130 |
|
|
0 commit comments