Skip to content

Commit 87d293a

Browse files
feat: apply descriptor exclude_types to env mat stat (deepmodeling#3625)
Signed-off-by: Jinzhe Zeng <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 39d027e commit 87d293a

File tree

6 files changed

+84
-4
lines changed

6 files changed

+84
-4
lines changed

deepmd/pt/utils/env_mat_stat.py

+3
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,9 @@ def iter(
146146
radial_only,
147147
protection=self.descriptor.env_protection,
148148
)
149+
# apply excluded_types
150+
exclude_mask = self.descriptor.emask(nlist, extended_atype)
151+
env_mat *= exclude_mask.unsqueeze(-1)
149152
# reshape to nframes * nloc at the atom level,
150153
# so nframes/mixed_type do not matter
151154
env_mat = env_mat.view(

deepmd/tf/descriptor/descriptor.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
Dict,
88
List,
99
Optional,
10+
Set,
1011
Tuple,
1112
)
1213

@@ -357,7 +358,7 @@ def pass_tensors_from_frz_model(
357358

358359
def build_type_exclude_mask(
359360
self,
360-
exclude_types: List[Tuple[int, int]],
361+
exclude_types: Set[Tuple[int, int]],
361362
ntypes: int,
362363
sel: List[int],
363364
ndescrpt: int,

deepmd/tf/descriptor/se_a.py

+12
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,18 @@ def __init__(
288288
sel_a=self.sel_a,
289289
sel_r=self.sel_r,
290290
)
291+
if len(self.exclude_types):
292+
# exclude types applied to data stat
293+
mask = self.build_type_exclude_mask(
294+
self.exclude_types,
295+
self.ntypes,
296+
self.sel_a,
297+
self.ndescrpt,
298+
# for data stat, nloc == nall
299+
self.place_holders["type"],
300+
tf.size(self.place_holders["type"]),
301+
)
302+
self.stat_descrpt *= tf.reshape(mask, tf.shape(self.stat_descrpt))
291303
self.sub_sess = tf.Session(graph=sub_graph, config=default_tf_session_config)
292304
self.original_sel = None
293305
self.multi_task = multi_task

deepmd/tf/descriptor/se_atten.py

+17-3
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import (
55
List,
66
Optional,
7+
Set,
78
Tuple,
89
)
910

@@ -282,6 +283,19 @@ def __init__(
282283
sel_a=self.sel_all_a,
283284
sel_r=self.sel_all_r,
284285
)
286+
if len(self.exclude_types):
287+
# exclude types applied to data stat
288+
mask = self.build_type_exclude_mask_mixed(
289+
self.exclude_types,
290+
self.ntypes,
291+
self.sel_a,
292+
self.ndescrpt,
293+
# for data stat, nloc == nall
294+
self.place_holders["type"],
295+
tf.size(self.place_holders["type"]),
296+
self.nei_type_vec_t, # extra input for atten
297+
)
298+
self.stat_descrpt *= tf.reshape(mask, tf.shape(self.stat_descrpt))
285299
self.sub_sess = tf.Session(graph=sub_graph, config=default_tf_session_config)
286300

287301
def compute_input_stats(
@@ -672,7 +686,7 @@ def _pass_filter(
672686
inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt])
673687
type_i = -1
674688
if len(self.exclude_types):
675-
mask = self.build_type_exclude_mask(
689+
mask = self.build_type_exclude_mask_mixed(
676690
self.exclude_types,
677691
self.ntypes,
678692
self.sel_a,
@@ -1367,9 +1381,9 @@ def init_variables(
13671381
)
13681382
)
13691383

1370-
def build_type_exclude_mask(
1384+
def build_type_exclude_mask_mixed(
13711385
self,
1372-
exclude_types: List[Tuple[int, int]],
1386+
exclude_types: Set[Tuple[int, int]],
13731387
ntypes: int,
13741388
sel: List[int],
13751389
ndescrpt: int,

deepmd/tf/descriptor/se_r.py

+12
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,18 @@ def __init__(
196196
rcut_smth=self.rcut_smth,
197197
sel=self.sel_r,
198198
)
199+
if len(self.exclude_types):
200+
# exclude types applied to data stat
201+
mask = self.build_type_exclude_mask(
202+
self.exclude_types,
203+
self.ntypes,
204+
self.sel_r,
205+
self.ndescrpt,
206+
# for data stat, nloc == nall
207+
self.place_holders["type"],
208+
tf.size(self.place_holders["type"]),
209+
)
210+
self.stat_descrpt *= tf.reshape(mask, tf.shape(self.stat_descrpt))
199211
self.sub_sess = tf.Session(
200212
graph=sub_graph, config=default_tf_session_config
201213
)

source/tests/pt/test_stat.py

+38
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,44 @@ def tf_compute_input_stats(self):
337337
)
338338

339339

340+
class TestExcludeTypes(DatasetTest, unittest.TestCase):
341+
def setup_data(self):
342+
original_data = str(Path(__file__).parent / "water/data/data_0")
343+
picked_data = str(Path(__file__).parent / "picked_data_for_test_stat")
344+
dpdata.LabeledSystem(original_data, fmt="deepmd/npy")[:2].to_deepmd_npy(
345+
picked_data
346+
)
347+
self.mixed_type = False
348+
return picked_data
349+
350+
def setup_tf(self):
351+
return DescrptSeA_tf(
352+
rcut=self.rcut,
353+
rcut_smth=self.rcut_smth,
354+
sel=self.sel,
355+
neuron=self.filter_neuron,
356+
axis_neuron=self.axis_neuron,
357+
exclude_types=[[0, 0], [1, 1]],
358+
)
359+
360+
def setup_pt(self):
361+
return DescrptSeA(
362+
self.rcut,
363+
self.rcut_smth,
364+
self.sel,
365+
self.filter_neuron,
366+
self.axis_neuron,
367+
exclude_types=[[0, 0], [1, 1]],
368+
).sea # get the block who has stat as private vars
369+
370+
def tf_compute_input_stats(self):
371+
coord = self.dp_merged["coord"]
372+
atype = self.dp_merged["type"]
373+
natoms = self.dp_merged["natoms_vec"]
374+
box = self.dp_merged["box"]
375+
self.dp_d.compute_input_stats(coord, box, atype, natoms, self.dp_mesh, {})
376+
377+
340378
class TestOutputStat(unittest.TestCase):
341379
def setUp(self):
342380
self.data_file = [str(Path(__file__).parent / "water/data/data_0")]

0 commit comments

Comments
 (0)