|
18 | 18 | import os
|
19 | 19 | import queue
|
20 | 20 | import random
|
| 21 | +import re |
21 | 22 | import shutil
|
22 | 23 | import sys
|
23 | 24 | import warnings
|
@@ -915,7 +916,11 @@ def parse_cur_job(cur_job):
|
915 | 916 | dt = _get_param_alias(cur_job, ["dt"])
|
916 | 917 | else:
|
917 | 918 | dt = None
|
918 |
| - return ensemble, nsteps, trj_freq, temps, press, pka_e, dt |
| 919 | + if "nbeads" in cur_job: |
| 920 | + nbeads = _get_param_alias(cur_job, ["nbeads"]) |
| 921 | + else: |
| 922 | + nbeads = None |
| 923 | + return ensemble, nsteps, trj_freq, temps, press, pka_e, dt, nbeads |
919 | 924 |
|
920 | 925 |
|
921 | 926 | def expand_matrix_values(target_list, cur_idx=0):
|
@@ -1457,7 +1462,21 @@ def _make_model_devi_native(iter_index, jdata, mdata, conf_systems):
|
1457 | 1462 | if iter_index >= len(model_devi_jobs):
|
1458 | 1463 | return False
|
1459 | 1464 | cur_job = model_devi_jobs[iter_index]
|
1460 |
| - ensemble, nsteps, trj_freq, temps, press, pka_e, dt = parse_cur_job(cur_job) |
| 1465 | + ensemble, nsteps, trj_freq, temps, press, pka_e, dt, nbeads = parse_cur_job(cur_job) |
| 1466 | + model_devi_f_avg_relative = jdata.get("model_devi_f_avg_relative", False) |
| 1467 | + model_devi_merge_traj = jdata.get("model_devi_merge_traj", False) |
| 1468 | + if (nbeads is not None) and model_devi_f_avg_relative: |
| 1469 | + raise RuntimeError( |
| 1470 | + "model_devi_f_avg_relative has not been supported for pimd. Set model_devi_f_avg_relative to False." |
| 1471 | + ) |
| 1472 | + if (nbeads is not None) and (model_devi_merge_traj): |
| 1473 | + raise RuntimeError( |
| 1474 | + "model_devi_merge_traj has not been supported for pimd. Set model_devi_merge_traj to False." |
| 1475 | + ) |
| 1476 | + if (nbeads is not None) and (not (nsteps % trj_freq == 0)): |
| 1477 | + raise RuntimeError( |
| 1478 | + "trj_freq should be a factor of nsteps for pimd. Please check your input." |
| 1479 | + ) |
1461 | 1480 | if dt is not None:
|
1462 | 1481 | model_devi_dt = dt
|
1463 | 1482 | sys_idx = expand_idx(cur_job["sys_idx"])
|
@@ -1560,6 +1579,7 @@ def _make_model_devi_native(iter_index, jdata, mdata, conf_systems):
|
1560 | 1579 | ele_temp_a=te_a,
|
1561 | 1580 | nopbc=nopbc,
|
1562 | 1581 | deepmd_version=deepmd_version,
|
| 1582 | + nbeads=nbeads, |
1563 | 1583 | )
|
1564 | 1584 | job = {}
|
1565 | 1585 | job["ensemble"] = ensemble
|
@@ -1916,12 +1936,23 @@ def run_md_model_devi(iter_index, jdata, mdata):
|
1916 | 1936 |
|
1917 | 1937 | model_devi_engine = jdata.get("model_devi_engine", "lammps")
|
1918 | 1938 | if model_devi_engine == "lammps":
|
1919 |
| - command = f"{{ if [ ! -f dpgen.restart.10000 ]; then {model_devi_exec} -i input.lammps -v restart 0; else {model_devi_exec} -i input.lammps -v restart 1; fi }}" |
1920 |
| - command = "/bin/sh -c '%s'" % command |
| 1939 | + nbeads = jdata["model_devi_jobs"][iter_index].get("nbeads") |
| 1940 | + if nbeads is None: |
| 1941 | + command = f"{{ if [ ! -f dpgen.restart.10000 ]; then {model_devi_exec} -i input.lammps -v restart 0; else {model_devi_exec} -i input.lammps -v restart 1; fi }}" |
| 1942 | + else: |
| 1943 | + command = f"{{ all_exist=true; for i in $(seq -w 1 {nbeads}); do [[ ! -f dpgen.restart${{i}}.10000 ]] && {{ all_exist=false; break; }}; done; $all_exist && {{ {model_devi_exec} -p {nbeads}x1 -i input.lammps -v restart 1; }} || {{ {model_devi_exec} -p {nbeads}x1 -i input.lammps -v restart 0; }} }}" |
| 1944 | + command = "/bin/bash -c '%s'" % command |
1921 | 1945 | commands = [command]
|
1922 | 1946 |
|
1923 | 1947 | forward_files = ["conf.lmp", "input.lammps"]
|
1924 |
| - backward_files = ["model_devi.out", "model_devi.log"] |
| 1948 | + backward_files = ["model_devi.log"] |
| 1949 | + if nbeads is None: |
| 1950 | + backward_files += ["model_devi.out"] |
| 1951 | + else: |
| 1952 | + num_digits = np.ceil(np.log10(nbeads + 1)).astype(int) |
| 1953 | + backward_files += [ |
| 1954 | + f"model_devi{i+1:0{num_digits}d}.out" for i in range(nbeads) |
| 1955 | + ] |
1925 | 1956 | if model_devi_merge_traj:
|
1926 | 1957 | backward_files += ["all.lammpstrj"]
|
1927 | 1958 | else:
|
@@ -2131,6 +2162,75 @@ def _read_model_devi_file(
|
2131 | 2162 | model_devi_f_avg_relative: bool = False,
|
2132 | 2163 | model_devi_merge_traj: bool = False,
|
2133 | 2164 | ):
|
| 2165 | + model_devi_files = glob.glob(os.path.join(task_path, "model_devi*.out")) |
| 2166 | + model_devi_files_sorted = sorted( |
| 2167 | + model_devi_files, key=lambda x: int(re.search(r"(\d+)", x).group(1)) |
| 2168 | + ) |
| 2169 | + if len(model_devi_files_sorted) > 1: |
| 2170 | + with open(model_devi_files_sorted[0]) as f: |
| 2171 | + first_line = f.readline() |
| 2172 | + if not (first_line.startswith("#")): |
| 2173 | + first_line = "#" |
| 2174 | + num_beads = len(model_devi_files_sorted) |
| 2175 | + model_devi_contents = [] |
| 2176 | + for file in model_devi_files_sorted: |
| 2177 | + model_devi_contents.append(np.loadtxt(file)) |
| 2178 | + assert all( |
| 2179 | + model_devi_content.shape[0] == model_devi_contents[0].shape[0] |
| 2180 | + for model_devi_content in model_devi_contents |
| 2181 | + ), "Not all beads generated the same number of lines in the model_devi$\{ibead\}.out file. Check your pimd task carefully." |
| 2182 | + for file in model_devi_files_sorted: |
| 2183 | + os.remove(file) |
| 2184 | + last_step = model_devi_contents[0][-1, 0] |
| 2185 | + for ibead in range(1, num_beads): |
| 2186 | + model_devi_contents[ibead][:, 0] = model_devi_contents[ibead][ |
| 2187 | + :, 0 |
| 2188 | + ] + ibead * (last_step + 1) |
| 2189 | + model_devi = np.concatenate(model_devi_contents, axis=0) |
| 2190 | + num_columns = model_devi.shape[1] |
| 2191 | + formats = ["%12d"] + ["%22.6e"] * (num_columns - 1) |
| 2192 | + np.savetxt( |
| 2193 | + os.path.join(task_path, "model_devi.out"), |
| 2194 | + model_devi, |
| 2195 | + fmt=formats, |
| 2196 | + header=first_line.rstrip(), |
| 2197 | + comments="", |
| 2198 | + ) |
| 2199 | + |
| 2200 | + if not model_devi_merge_traj: |
| 2201 | + num_digits = np.ceil(np.log10(num_beads + 1)).astype(int) |
| 2202 | + traj_files_sorted = [] |
| 2203 | + for ibead in range(num_beads): |
| 2204 | + traj_files = glob.glob( |
| 2205 | + os.path.join( |
| 2206 | + task_path, "traj", f"*lammpstrj{ibead+1:0{num_digits}d}" |
| 2207 | + ) |
| 2208 | + ) |
| 2209 | + traj_files_sorted.append( |
| 2210 | + sorted( |
| 2211 | + traj_files, |
| 2212 | + key=lambda x: int( |
| 2213 | + re.search(r"^(\d+)\.lammpstrj", os.path.basename(x)).group( |
| 2214 | + 1 |
| 2215 | + ) |
| 2216 | + ), |
| 2217 | + ) |
| 2218 | + ) |
| 2219 | + assert all( |
| 2220 | + len(traj_list) == len(traj_files_sorted[0]) |
| 2221 | + for traj_list in traj_files_sorted |
| 2222 | + ), "Not all beads generated the same number of frames. Check your pimd task carefully." |
| 2223 | + for ibead in range(num_beads): |
| 2224 | + for itraj in range(len(traj_files_sorted[0])): |
| 2225 | + base_path, original_filename = os.path.split( |
| 2226 | + traj_files_sorted[ibead][itraj] |
| 2227 | + ) |
| 2228 | + frame_number = int(original_filename.split(".")[0]) |
| 2229 | + new_filename = os.path.join( |
| 2230 | + base_path, |
| 2231 | + f"{frame_number + ibead * (int(last_step)+1):d}.lammpstrj", |
| 2232 | + ) |
| 2233 | + os.rename(traj_files_sorted[ibead][itraj], new_filename) |
2134 | 2234 | model_devi = np.loadtxt(os.path.join(task_path, "model_devi.out"))
|
2135 | 2235 | if model_devi_f_avg_relative:
|
2136 | 2236 | if model_devi_merge_traj is True:
|
|
0 commit comments