Skip to content

Commit 9ddef6a

Browse files
Yi-FanLipre-commit-ci[bot]njzjz
authored
model_devi: add support for pimd (#1366)
Add support for LAMMPS's fix pimd/langevin in model deviation tasks. --------- Signed-off-by: Yifan Li李一帆 <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jinzhe Zeng <[email protected]>
1 parent 3da2472 commit 9ddef6a

File tree

6 files changed

+426
-32
lines changed

6 files changed

+426
-32
lines changed

dpgen/generator/arginfo.py

+2
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,7 @@ def model_devi_jobs_args() -> list[Argument]:
262262
doc_press = "Pressure (Bar) in MD. Required when ensemble is npt."
263263
doc_trj_freq = "Frequecy of trajectory saved in MD."
264264
doc_nsteps = "Running steps of MD. It is not optional when not using a template."
265+
doc_nbeads = "Number of beads in PIMD. If not given, classical MD will be performed. Only supported for LAMMPS version >= 20230615."
265266
doc_ensemble = "Determining which ensemble used in MD, options include “npt” and “nvt”. It is not optional when not using a template."
266267
doc_neidelay = "delay building until this many steps since last build."
267268
doc_taut = "Coupling time of thermostat (ps)."
@@ -280,6 +281,7 @@ def model_devi_jobs_args() -> list[Argument]:
280281
Argument("press", list[float], optional=True, doc=doc_press),
281282
Argument("trj_freq", int, optional=False, doc=doc_trj_freq),
282283
Argument("nsteps", int, optional=True, doc=doc_nsteps),
284+
Argument("nbeads", int, optional=True, doc=doc_nbeads),
283285
Argument("ensemble", str, optional=True, doc=doc_ensemble),
284286
Argument("neidelay", int, optional=True, doc=doc_neidelay),
285287
Argument("taut", float, optional=True, doc=doc_taut),

dpgen/generator/lib/lammps.py

+80-24
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def make_lammps_input(
3737
max_seed=1000000,
3838
nopbc=False,
3939
deepmd_version="0.1",
40+
nbeads=None,
4041
):
4142
if (ele_temp_f is not None or ele_temp_a is not None) and Version(
4243
deepmd_version
@@ -49,9 +50,22 @@ def make_lammps_input(
4950
"the frame style ele_temp and atom style ele_temp should not be set at the same time"
5051
)
5152
ret = "variable NSTEPS equal %d\n" % nsteps
53+
if nbeads is not None:
54+
if nbeads <= 0:
55+
raise ValueError(
56+
"The number of beads should be positive. Check your nbeads setting."
57+
)
58+
power = 1
59+
while power < nbeads:
60+
power *= 10
61+
ret += "variable ibead uloop %d pad\n" % (power - 1)
62+
if nbeads is not None:
63+
ret += "atom_modify map yes\n"
5264
ret += "variable THERMO_FREQ equal %d\n" % trj_freq
5365
ret += "variable DUMP_FREQ equal %d\n" % trj_freq
5466
ret += "variable TEMP equal %f\n" % temp
67+
if nbeads is not None:
68+
ret += "variable TEMP_NBEADS equal %f\n" % (temp * nbeads)
5569
if ele_temp_f is not None:
5670
ret += "variable ELE_TEMP equal %f\n" % ele_temp_f
5771
if ele_temp_a is not None:
@@ -72,10 +86,16 @@ def make_lammps_input(
7286
ret += "neigh_modify delay %d\n" % neidelay
7387
ret += "\n"
7488
ret += "box tilt large\n"
75-
ret += (
76-
'if "${restart} > 0" then "read_restart dpgen.restart.*" else "read_data %s"\n'
77-
% conf_file
78-
)
89+
if nbeads is None:
90+
ret += (
91+
'if "${restart} > 0" then "read_restart dpgen.restart.*" else "read_data %s"\n'
92+
% conf_file
93+
)
94+
else:
95+
ret += (
96+
'if "${restart} > 0" then "read_restart dpgen.restart${ibead}.*" else "read_data %s"\n'
97+
% conf_file
98+
)
7999
ret += "change_box all triclinic\n"
80100
for jj in range(len(mass_map)):
81101
ret += "mass %d %f\n" % (jj + 1, mass_map[jj])
@@ -98,23 +118,43 @@ def make_lammps_input(
98118
keywords += "fparam ${ELE_TEMP}"
99119
if ele_temp_a is not None:
100120
keywords += "aparam ${ELE_TEMP}"
101-
ret += f"pair_style deepmd {graph_list} out_freq ${{THERMO_FREQ}} out_file model_devi.out {keywords}\n"
121+
if nbeads is None:
122+
ret += f"pair_style deepmd {graph_list} out_freq ${{THERMO_FREQ}} out_file model_devi.out {keywords}\n"
123+
else:
124+
ret += f"pair_style deepmd {graph_list} out_freq ${{THERMO_FREQ}} out_file model_devi${{ibead}}.out {keywords}\n"
102125
ret += "pair_coeff * *\n"
103126
ret += "\n"
104127
ret += "thermo_style custom step temp pe ke etotal press vol lx ly lz xy xz yz\n"
105128
ret += "thermo ${THERMO_FREQ}\n"
106129
model_devi_merge_traj = jdata.get("model_devi_merge_traj", False)
107-
if model_devi_merge_traj is True:
108-
ret += "dump 1 all custom ${DUMP_FREQ} all.lammpstrj id type x y z fx fy fz\n"
109-
ret += 'if "${restart} > 0" then "dump_modify 1 append yes"\n'
130+
if nbeads is None:
131+
if model_devi_merge_traj is True:
132+
ret += "dump 1 all custom ${DUMP_FREQ} all.lammpstrj id type x y z fx fy fz\n"
133+
ret += 'if "${restart} > 0" then "dump_modify 1 append yes"\n'
134+
else:
135+
ret += "dump 1 all custom ${DUMP_FREQ} traj/*.lammpstrj id type x y z fx fy fz\n"
110136
else:
111-
ret += "dump 1 all custom ${DUMP_FREQ} traj/*.lammpstrj id type x y z fx fy fz\n"
112-
ret += "restart 10000 dpgen.restart\n"
137+
if model_devi_merge_traj is True:
138+
ret += "dump 1 all custom ${DUMP_FREQ} all.lammpstrj${ibead} id type x y z fx fy fz\n"
139+
ret += 'if "${restart} > 0" then "dump_modify 1 append yes"\n'
140+
else:
141+
ret += "dump 1 all custom ${DUMP_FREQ} traj/*.lammpstrj${ibead} id type x y z fx fy fz\n"
142+
if nbeads is None:
143+
ret += "restart 10000 dpgen.restart\n"
144+
else:
145+
ret += "restart 10000 dpgen.restart${ibead}\n"
113146
ret += "\n"
114147
if pka_e is None:
115-
ret += 'if "${restart} == 0" then "velocity all create ${TEMP} %d"' % (
116-
random.randrange(max_seed - 1) + 1
117-
)
148+
if nbeads is None:
149+
ret += (
150+
'if "${restart} == 0" then "velocity all create ${TEMP} %d"'
151+
% (random.randrange(max_seed - 1) + 1)
152+
)
153+
else:
154+
ret += (
155+
'if "${restart} == 0" then "velocity all create ${TEMP_NBEADS} %d"'
156+
% (random.randrange(max_seed - 1) + 1)
157+
)
118158
else:
119159
sys = dpdata.System(conf_file, fmt="lammps/lmp")
120160
sys_data = sys.data
@@ -140,18 +180,34 @@ def make_lammps_input(
140180
assert pres is not None
141181
if nopbc:
142182
raise RuntimeError("ensemble %s is conflicting with nopbc" % ensemble)
143-
if ensemble == "npt" or ensemble == "npt-i" or ensemble == "npt-iso":
144-
ret += "fix 1 all npt temp ${TEMP} ${TEMP} ${TAU_T} iso ${PRES} ${PRES} ${TAU_P}\n"
145-
elif ensemble == "npt-a" or ensemble == "npt-aniso":
146-
ret += "fix 1 all npt temp ${TEMP} ${TEMP} ${TAU_T} aniso ${PRES} ${PRES} ${TAU_P}\n"
147-
elif ensemble == "npt-t" or ensemble == "npt-tri":
148-
ret += "fix 1 all npt temp ${TEMP} ${TEMP} ${TAU_T} tri ${PRES} ${PRES} ${TAU_P}\n"
149-
elif ensemble == "nvt":
150-
ret += "fix 1 all nvt temp ${TEMP} ${TEMP} ${TAU_T}\n"
151-
elif ensemble == "nve":
152-
ret += "fix 1 all nve\n"
183+
if nbeads is None:
184+
if ensemble == "npt" or ensemble == "npt-i" or ensemble == "npt-iso":
185+
ret += "fix 1 all npt temp ${TEMP} ${TEMP} ${TAU_T} iso ${PRES} ${PRES} ${TAU_P}\n"
186+
elif ensemble == "npt-a" or ensemble == "npt-aniso":
187+
ret += "fix 1 all npt temp ${TEMP} ${TEMP} ${TAU_T} aniso ${PRES} ${PRES} ${TAU_P}\n"
188+
elif ensemble == "npt-t" or ensemble == "npt-tri":
189+
ret += "fix 1 all npt temp ${TEMP} ${TEMP} ${TAU_T} tri ${PRES} ${PRES} ${TAU_P}\n"
190+
elif ensemble == "nvt":
191+
ret += "fix 1 all nvt temp ${TEMP} ${TEMP} ${TAU_T}\n"
192+
elif ensemble == "nve":
193+
ret += "fix 1 all nve\n"
194+
else:
195+
raise RuntimeError("unknown emsemble " + ensemble)
153196
else:
154-
raise RuntimeError("unknown emsemble " + ensemble)
197+
if ensemble == "npt" or ensemble == "npt-i" or ensemble == "npt-iso":
198+
ret += "fix 1 all pimd/langevin fmmode physical ensemble npt integrator obabo thermostat PILE_L ${ibead} temp ${TEMP} tau ${TAU_T} scale 1.0 barostat BZP iso ${PRES} taup ${TAU_P}\n"
199+
elif ensemble == "npt-a" or ensemble == "npt-aniso":
200+
ret += "fix 1 all pimd/langevin fmmode physical ensemble npt integrator obabo thermostat PILE_L ${ibead} temp ${TEMP} tau ${TAU_T} scale 1.0 barostat BZP aniso ${PRES} taup ${TAU_P}\n"
201+
elif ensemble == "nvt":
202+
ret += "fix 1 all pimd/langevin fmmode physical ensemble nvt integrator obabo thermostat PILE_L ${ibead} temp ${TEMP} tau ${TAU_T} scale 1.0\n"
203+
elif ensemble == "nve":
204+
ret += "fix 1 all pimd/langevin fmmode physical ensemble nve integrator obabo temp ${TEMP}\n"
205+
else:
206+
raise RuntimeError(
207+
"unknown emsemble "
208+
+ ensemble
209+
+ " for fix pimd/langevin\nrefer to https://docs.lammps.org/fix_pimd.html for more information"
210+
)
155211
if nopbc:
156212
ret += "velocity all zero linear\n"
157213
ret += "fix fm all momentum 1 linear 1 1 1\n"

dpgen/generator/run.py

+105-5
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import os
1919
import queue
2020
import random
21+
import re
2122
import shutil
2223
import sys
2324
import warnings
@@ -915,7 +916,11 @@ def parse_cur_job(cur_job):
915916
dt = _get_param_alias(cur_job, ["dt"])
916917
else:
917918
dt = None
918-
return ensemble, nsteps, trj_freq, temps, press, pka_e, dt
919+
if "nbeads" in cur_job:
920+
nbeads = _get_param_alias(cur_job, ["nbeads"])
921+
else:
922+
nbeads = None
923+
return ensemble, nsteps, trj_freq, temps, press, pka_e, dt, nbeads
919924

920925

921926
def expand_matrix_values(target_list, cur_idx=0):
@@ -1457,7 +1462,21 @@ def _make_model_devi_native(iter_index, jdata, mdata, conf_systems):
14571462
if iter_index >= len(model_devi_jobs):
14581463
return False
14591464
cur_job = model_devi_jobs[iter_index]
1460-
ensemble, nsteps, trj_freq, temps, press, pka_e, dt = parse_cur_job(cur_job)
1465+
ensemble, nsteps, trj_freq, temps, press, pka_e, dt, nbeads = parse_cur_job(cur_job)
1466+
model_devi_f_avg_relative = jdata.get("model_devi_f_avg_relative", False)
1467+
model_devi_merge_traj = jdata.get("model_devi_merge_traj", False)
1468+
if (nbeads is not None) and model_devi_f_avg_relative:
1469+
raise RuntimeError(
1470+
"model_devi_f_avg_relative has not been supported for pimd. Set model_devi_f_avg_relative to False."
1471+
)
1472+
if (nbeads is not None) and (model_devi_merge_traj):
1473+
raise RuntimeError(
1474+
"model_devi_merge_traj has not been supported for pimd. Set model_devi_merge_traj to False."
1475+
)
1476+
if (nbeads is not None) and (not (nsteps % trj_freq == 0)):
1477+
raise RuntimeError(
1478+
"trj_freq should be a factor of nsteps for pimd. Please check your input."
1479+
)
14611480
if dt is not None:
14621481
model_devi_dt = dt
14631482
sys_idx = expand_idx(cur_job["sys_idx"])
@@ -1560,6 +1579,7 @@ def _make_model_devi_native(iter_index, jdata, mdata, conf_systems):
15601579
ele_temp_a=te_a,
15611580
nopbc=nopbc,
15621581
deepmd_version=deepmd_version,
1582+
nbeads=nbeads,
15631583
)
15641584
job = {}
15651585
job["ensemble"] = ensemble
@@ -1916,12 +1936,23 @@ def run_md_model_devi(iter_index, jdata, mdata):
19161936

19171937
model_devi_engine = jdata.get("model_devi_engine", "lammps")
19181938
if model_devi_engine == "lammps":
1919-
command = f"{{ if [ ! -f dpgen.restart.10000 ]; then {model_devi_exec} -i input.lammps -v restart 0; else {model_devi_exec} -i input.lammps -v restart 1; fi }}"
1920-
command = "/bin/sh -c '%s'" % command
1939+
nbeads = jdata["model_devi_jobs"][iter_index].get("nbeads")
1940+
if nbeads is None:
1941+
command = f"{{ if [ ! -f dpgen.restart.10000 ]; then {model_devi_exec} -i input.lammps -v restart 0; else {model_devi_exec} -i input.lammps -v restart 1; fi }}"
1942+
else:
1943+
command = f"{{ all_exist=true; for i in $(seq -w 1 {nbeads}); do [[ ! -f dpgen.restart${{i}}.10000 ]] && {{ all_exist=false; break; }}; done; $all_exist && {{ {model_devi_exec} -p {nbeads}x1 -i input.lammps -v restart 1; }} || {{ {model_devi_exec} -p {nbeads}x1 -i input.lammps -v restart 0; }} }}"
1944+
command = "/bin/bash -c '%s'" % command
19211945
commands = [command]
19221946

19231947
forward_files = ["conf.lmp", "input.lammps"]
1924-
backward_files = ["model_devi.out", "model_devi.log"]
1948+
backward_files = ["model_devi.log"]
1949+
if nbeads is None:
1950+
backward_files += ["model_devi.out"]
1951+
else:
1952+
num_digits = np.ceil(np.log10(nbeads + 1)).astype(int)
1953+
backward_files += [
1954+
f"model_devi{i+1:0{num_digits}d}.out" for i in range(nbeads)
1955+
]
19251956
if model_devi_merge_traj:
19261957
backward_files += ["all.lammpstrj"]
19271958
else:
@@ -2131,6 +2162,75 @@ def _read_model_devi_file(
21312162
model_devi_f_avg_relative: bool = False,
21322163
model_devi_merge_traj: bool = False,
21332164
):
2165+
model_devi_files = glob.glob(os.path.join(task_path, "model_devi*.out"))
2166+
model_devi_files_sorted = sorted(
2167+
model_devi_files, key=lambda x: int(re.search(r"(\d+)", x).group(1))
2168+
)
2169+
if len(model_devi_files_sorted) > 1:
2170+
with open(model_devi_files_sorted[0]) as f:
2171+
first_line = f.readline()
2172+
if not (first_line.startswith("#")):
2173+
first_line = "#"
2174+
num_beads = len(model_devi_files_sorted)
2175+
model_devi_contents = []
2176+
for file in model_devi_files_sorted:
2177+
model_devi_contents.append(np.loadtxt(file))
2178+
assert all(
2179+
model_devi_content.shape[0] == model_devi_contents[0].shape[0]
2180+
for model_devi_content in model_devi_contents
2181+
), "Not all beads generated the same number of lines in the model_devi$\{ibead\}.out file. Check your pimd task carefully."
2182+
for file in model_devi_files_sorted:
2183+
os.remove(file)
2184+
last_step = model_devi_contents[0][-1, 0]
2185+
for ibead in range(1, num_beads):
2186+
model_devi_contents[ibead][:, 0] = model_devi_contents[ibead][
2187+
:, 0
2188+
] + ibead * (last_step + 1)
2189+
model_devi = np.concatenate(model_devi_contents, axis=0)
2190+
num_columns = model_devi.shape[1]
2191+
formats = ["%12d"] + ["%22.6e"] * (num_columns - 1)
2192+
np.savetxt(
2193+
os.path.join(task_path, "model_devi.out"),
2194+
model_devi,
2195+
fmt=formats,
2196+
header=first_line.rstrip(),
2197+
comments="",
2198+
)
2199+
2200+
if not model_devi_merge_traj:
2201+
num_digits = np.ceil(np.log10(num_beads + 1)).astype(int)
2202+
traj_files_sorted = []
2203+
for ibead in range(num_beads):
2204+
traj_files = glob.glob(
2205+
os.path.join(
2206+
task_path, "traj", f"*lammpstrj{ibead+1:0{num_digits}d}"
2207+
)
2208+
)
2209+
traj_files_sorted.append(
2210+
sorted(
2211+
traj_files,
2212+
key=lambda x: int(
2213+
re.search(r"^(\d+)\.lammpstrj", os.path.basename(x)).group(
2214+
1
2215+
)
2216+
),
2217+
)
2218+
)
2219+
assert all(
2220+
len(traj_list) == len(traj_files_sorted[0])
2221+
for traj_list in traj_files_sorted
2222+
), "Not all beads generated the same number of frames. Check your pimd task carefully."
2223+
for ibead in range(num_beads):
2224+
for itraj in range(len(traj_files_sorted[0])):
2225+
base_path, original_filename = os.path.split(
2226+
traj_files_sorted[ibead][itraj]
2227+
)
2228+
frame_number = int(original_filename.split(".")[0])
2229+
new_filename = os.path.join(
2230+
base_path,
2231+
f"{frame_number + ibead * (int(last_step)+1):d}.lammpstrj",
2232+
)
2233+
os.rename(traj_files_sorted[ibead][itraj], new_filename)
21342234
model_devi = np.loadtxt(os.path.join(task_path, "model_devi.out"))
21352235
if model_devi_f_avg_relative:
21362236
if model_devi_merge_traj is True:

tests/generator/context.py

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from dpgen.util import setup_ele_temp # noqa: F401
2121

2222
param_file = "param-mg-vasp.json"
23+
param_pimd_file = "param-mg-pimd-vasp.json"
2324
param_file_merge_traj = "param-mg-vasp_merge_traj.json"
2425
param_file_v1 = "param-mg-vasp-v1.json"
2526
param_file_v1_et = "param-mg-vasp-v1-et.json"

0 commit comments

Comments
 (0)