Skip to content

simplify: support model deviation of energy per atom #1312

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions dpgen/simplify/arginfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ def general_simplify_arginfo() -> Argument:
doc_model_devi_f_trust_hi = (
"The higher bound of forces for the selection for the model deviation."
)
doc_model_devi_e_trust_lo = "The lower bound of energy per atom for the selection for the model deviation. Requires DeePMD-kit version >=2.2.2."
doc_model_devi_e_trust_hi = "The higher bound of energy per atom for the selection for the model deviation. Requires DeePMD-kit version >=2.2.2."

return [
Argument("labeled", bool, optional=True, default=False, doc=doc_labeled),
Expand All @@ -50,6 +52,20 @@ def general_simplify_arginfo() -> Argument:
optional=False,
doc=doc_model_devi_f_trust_hi,
),
Argument(
"model_devi_e_trust_lo",
float,
optional=True,
default=float("inf"),
doc=doc_model_devi_e_trust_lo,
),
Argument(
"model_devi_e_trust_hi",
float,
optional=True,
default=float("inf"),
doc=doc_model_devi_e_trust_hi,
),
]


Expand Down
30 changes: 23 additions & 7 deletions dpgen/simplify/simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,8 @@ def post_model_devi(iter_index, jdata, mdata):

f_trust_lo = jdata["model_devi_f_trust_lo"]
f_trust_hi = jdata["model_devi_f_trust_hi"]
e_trust_lo = jdata["model_devi_e_trust_lo"]
e_trust_hi = jdata["model_devi_e_trust_hi"]

type_map = jdata.get("type_map", [])
sys_accurate = dpdata.MultiSystems(type_map=type_map)
Expand All @@ -285,16 +287,30 @@ def post_model_devi(iter_index, jdata, mdata):
if line.startswith("# data.rest.old"):
name = (line.split()[1]).split("/")[-1]
elif line.startswith("#"):
pass
columns = line.split()[1:]
cidx_step = columns.index("step")
cidx_max_devi_f = columns.index("max_devi_f")
try:
cidx_devi_e = columns.index("devi_e")
except ValueError:
# DeePMD-kit < 2.2.2
cidx_devi_e = None
else:
idx = int(line.split()[0])
f_devi = float(line.split()[4])
idx = int(line.split()[cidx_step])
f_devi = float(line.split()[cidx_max_devi_f])
if cidx_devi_e is not None:
e_devi = float(line.split()[cidx_devi_e])
else:
e_devi = 0.0
subsys = sys_entire[name][idx]
if f_trust_lo <= f_devi < f_trust_hi:
sys_candinate.append(subsys)
elif f_devi >= f_trust_hi:
if f_devi >= f_trust_hi or e_devi >= e_trust_hi:
sys_failed.append(subsys)
elif f_devi < f_trust_lo:
elif (
f_trust_lo <= f_devi < f_trust_hi
or e_trust_lo <= e_devi < e_trust_hi
):
sys_candinate.append(subsys)
elif f_devi < f_trust_lo and e_devi < e_trust_lo:
sys_accurate.append(subsys)
else:
raise RuntimeError("reach a place that should NOT be reached...")
Expand Down
116 changes: 116 additions & 0 deletions tests/simplify/test_post_model_devi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import os
import shutil
import sys
import unittest
from pathlib import Path

import dpdata
import numpy as np

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
__package__ = "simplify"
from .context import dpgen


class TestSimplifyModelDevi(unittest.TestCase):
def setUp(self):
self.work_path = Path("iter.000001") / dpgen.simplify.simplify.model_devi_name
self.work_path.mkdir(exist_ok=True, parents=True)
self.system = dpdata.System(
data={
"atom_names": ["H"],
"atom_numbs": [1],
"atom_types": np.zeros((1,), dtype=int),
"coords": np.zeros((1, 1, 3), dtype=np.float32),
"cells": np.zeros((1, 3, 3), dtype=np.float32),
"orig": np.zeros(3, dtype=np.float32),
"nopbc": True,
"energies": np.zeros((1,), dtype=np.float32),
"forces": np.zeros((1, 1, 3), dtype=np.float32),
}
)
self.system.to_deepmd_npy(
self.work_path / "data.rest.old" / self.system.formula
)
model_devi = np.array([[0, 0.2, 0.1, 0.15, 0.2, 0.1, 0.15, 0.2]])
np.savetxt(
self.work_path / "details",
model_devi,
fmt=["%12d"] + ["%19.6e" for _ in range(7)],
header="data.rest.old/"
+ self.system.formula
+ "\n step max_devi_v min_devi_v avg_devi_v max_devi_f min_devi_f avg_devi_f devi_e",
)

def tearDown(self):
shutil.rmtree("iter.000001", ignore_errors=True)

def test_post_model_devi_f_candidate(self):
dpgen.simplify.simplify.post_model_devi(
1,
{
"model_devi_f_trust_lo": 0.15,
"model_devi_f_trust_hi": 0.25,
"model_devi_e_trust_lo": float("inf"),
"model_devi_e_trust_hi": float("inf"),
"iter_pick_number": 1,
},
{},
)
assert (self.work_path / "data.picked" / self.system.formula).exists()

def test_post_model_devi_e_candidate(self):
dpgen.simplify.simplify.post_model_devi(
1,
{
"model_devi_e_trust_lo": 0.15,
"model_devi_e_trust_hi": 0.25,
"model_devi_f_trust_lo": float("inf"),
"model_devi_f_trust_hi": float("inf"),
"iter_pick_number": 1,
},
{},
)
assert (self.work_path / "data.picked" / self.system.formula).exists()

def test_post_model_devi_f_failed(self):
with self.assertRaises(RuntimeError):
dpgen.simplify.simplify.post_model_devi(
1,
{
"model_devi_f_trust_lo": 0.0,
"model_devi_f_trust_hi": 0.0,
"model_devi_e_trust_lo": float("inf"),
"model_devi_e_trust_hi": float("inf"),
"iter_pick_number": 1,
},
{},
)

def test_post_model_devi_e_failed(self):
with self.assertRaises(RuntimeError):
dpgen.simplify.simplify.post_model_devi(
1,
{
"model_devi_e_trust_lo": 0.0,
"model_devi_e_trust_hi": 0.0,
"model_devi_f_trust_lo": float("inf"),
"model_devi_f_trust_hi": float("inf"),
"iter_pick_number": 1,
},
{},
)

def test_post_model_devi_accurate(self):
dpgen.simplify.simplify.post_model_devi(
1,
{
"model_devi_e_trust_lo": 0.3,
"model_devi_e_trust_hi": 0.4,
"model_devi_f_trust_lo": 0.3,
"model_devi_f_trust_hi": 0.4,
"iter_pick_number": 1,
},
{},
)
assert (self.work_path / "data.accurate" / self.system.formula).exists()