Skip to content

Commit 30e57e4

Browse files
committed
tf: refactor neighbor stat (deepmodeling#3275)
Fix deepmodeling#3272. Apply implementation of deepmodeling#3271 into TF. Confirm consistent results on `examples/water`, `examples/nopbc`, and ANI-1x (deepmodeling#1624). 80x speed up: ![image](https://github.com/deepmodeling/deepmd-kit/assets/9496702/85aa1fed-e3c0-4cb6-9082-db45c9a03f9d) --------- Signed-off-by: Jinzhe Zeng <[email protected]> (cherry picked from commit 02080db) --------- Cleanup for r2. Signed-off-by: Jinzhe Zeng <[email protected]>
1 parent 4ac18c7 commit 30e57e4

File tree

4 files changed

+537
-80
lines changed

4 files changed

+537
-80
lines changed

deepmd/utils/neighbor_stat.py

+245-80
Original file line numberDiff line numberDiff line change
@@ -2,28 +2,161 @@
22
import logging
33
import math
44
from typing import (
5-
List,
5+
Iterator,
6+
Optional,
67
Tuple,
78
)
89

910
import numpy as np
1011

1112
from deepmd.env import (
1213
GLOBAL_NP_FLOAT_PRECISION,
14+
GLOBAL_TF_FLOAT_PRECISION,
1315
default_tf_session_config,
14-
op_module,
1516
tf,
1617
)
18+
from deepmd.utils.batch_size import (
19+
AutoBatchSize,
20+
)
1721
from deepmd.utils.data_system import (
1822
DeepmdDataSystem,
1923
)
20-
from deepmd.utils.parallel_op import (
21-
ParallelOp,
24+
from deepmd.utils.nlist import (
25+
extend_coord_with_ghosts,
26+
)
27+
from deepmd.utils.sess import (
28+
run_sess,
2229
)
2330

2431
log = logging.getLogger(__name__)
2532

2633

34+
class NeighborStatOP:
35+
"""Class for getting neighbor statics data information.
36+
37+
Parameters
38+
----------
39+
ntypes
40+
The num of atom types
41+
rcut
42+
The cut-off radius
43+
distinguish_types : bool, optional
44+
If False, treat all types as a single type.
45+
"""
46+
47+
def __init__(
48+
self,
49+
ntypes: int,
50+
rcut: float,
51+
distinguish_types: bool,
52+
) -> None:
53+
super().__init__()
54+
self.rcut = rcut
55+
self.ntypes = ntypes
56+
self.distinguish_types = distinguish_types
57+
58+
def build(
59+
self,
60+
coord: tf.Tensor,
61+
atype: tf.Tensor,
62+
cell: tf.Tensor,
63+
pbc: tf.Tensor,
64+
) -> Tuple[tf.Tensor, tf.Tensor]:
65+
"""Calculate the nearest neighbor distance between atoms, maximum nbor size of
66+
atoms and the output data range of the environment matrix.
67+
68+
Parameters
69+
----------
70+
coord
71+
The coordinates of atoms.
72+
atype
73+
The atom types.
74+
cell
75+
The cell.
76+
77+
Returns
78+
-------
79+
tf.Tensor
80+
The minimal squared distance between two atoms, in the shape of (nframes,)
81+
tf.Tensor
82+
The maximal number of neighbors
83+
"""
84+
# generated by GitHub Copilot, converted from PT codes
85+
nframes = tf.shape(coord)[0]
86+
coord = tf.reshape(coord, [nframes, -1, 3])
87+
nloc = tf.shape(coord)[1]
88+
coord = tf.reshape(coord, [nframes, nloc * 3])
89+
extend_coord, extend_atype, _ = extend_coord_with_ghosts(
90+
coord, atype, cell, self.rcut, pbc
91+
)
92+
93+
coord1 = tf.reshape(extend_coord, [nframes, -1])
94+
nall = tf.shape(coord1)[1] // 3
95+
coord0 = coord1[:, : nloc * 3]
96+
diff = (
97+
tf.reshape(coord1, [nframes, -1, 3])[:, None, :, :]
98+
- tf.reshape(coord0, [nframes, -1, 3])[:, :, None, :]
99+
)
100+
# shape of diff: nframes, nloc, nall, 3
101+
# remove the diagonal elements
102+
mask = tf.eye(nloc, nall, dtype=tf.bool)
103+
# expand mask
104+
mask = tf.tile(mask[None, :, :], [nframes, 1, 1])
105+
# expand inf
106+
inf_mask = tf.constant(
107+
float("inf"), dtype=GLOBAL_TF_FLOAT_PRECISION, shape=[1, 1, 1]
108+
)
109+
inf_mask = tf.tile(inf_mask, [nframes, nloc, nall])
110+
# virtual type (<0) are not counted
111+
virtual_type_mask_i = tf.tile(tf.less(atype, 0)[:, :, None], [1, 1, nall])
112+
virtual_type_mask_j = tf.tile(
113+
tf.less(extend_atype, 0)[:, None, :], [1, nloc, 1]
114+
)
115+
mask = mask | virtual_type_mask_i | virtual_type_mask_j
116+
rr2 = tf.reduce_sum(tf.square(diff), axis=-1)
117+
rr2 = tf.where(mask, inf_mask, rr2)
118+
min_rr2 = tf.reduce_min(rr2, axis=(1, 2))
119+
# count the number of neighbors
120+
if self.distinguish_types:
121+
mask = rr2 < self.rcut**2
122+
nnei = []
123+
for ii in range(self.ntypes):
124+
nnei.append(
125+
tf.reduce_sum(
126+
tf.cast(
127+
mask & (tf.equal(extend_atype, ii))[:, None, :], tf.int32
128+
),
129+
axis=-1,
130+
)
131+
)
132+
# shape: nframes, nloc, ntypes
133+
nnei = tf.stack(nnei, axis=-1)
134+
else:
135+
mask = rr2 < self.rcut**2
136+
# virtual types (<0) are not counted
137+
nnei = tf.reshape(
138+
tf.reduce_sum(
139+
tf.cast(
140+
mask & tf.greater_equal(extend_atype, 0)[:, None, :], tf.int32
141+
),
142+
axis=-1,
143+
),
144+
[nframes, nloc, 1],
145+
)
146+
# nnei: nframes, nloc, ntypes
147+
# virtual type i (<0) are not counted
148+
nnei = tf.where(
149+
tf.tile(
150+
tf.less(atype, 0)[:, :, None],
151+
[1, 1, self.ntypes if self.distinguish_types else 1],
152+
),
153+
tf.zeros_like(nnei, dtype=tf.int32),
154+
nnei,
155+
)
156+
max_nnei = tf.reduce_max(nnei, axis=1)
157+
return min_rr2, max_nnei
158+
159+
27160
class NeighborStat:
28161
"""Class for getting training data information.
29162
@@ -46,52 +179,15 @@ def __init__(
46179
one_type: bool = False,
47180
) -> None:
48181
"""Constructor."""
49-
self.rcut = rcut
50-
self.ntypes = ntypes
51-
self.one_type = one_type
52-
sub_graph = tf.Graph()
53-
54-
def builder():
55-
place_holders = {}
56-
for ii in ["coord", "box"]:
57-
place_holders[ii] = tf.placeholder(
58-
GLOBAL_NP_FLOAT_PRECISION, [None, None], name="t_" + ii
59-
)
60-
place_holders["type"] = tf.placeholder(
61-
tf.int32, [None, None], name="t_type"
62-
)
63-
place_holders["natoms_vec"] = tf.placeholder(
64-
tf.int32, [self.ntypes + 2], name="t_natoms"
65-
)
66-
place_holders["default_mesh"] = tf.placeholder(
67-
tf.int32, [None], name="t_mesh"
68-
)
69-
t_type = place_holders["type"]
70-
t_natoms = place_holders["natoms_vec"]
71-
if self.one_type:
72-
# all types = 0, natoms_vec = [natoms, natoms, natoms]
73-
t_type = tf.clip_by_value(t_type, -1, 0)
74-
t_natoms = tf.tile(t_natoms[0:1], [3])
75-
76-
_max_nbor_size, _min_nbor_dist = op_module.neighbor_stat(
77-
place_holders["coord"],
78-
t_type,
79-
t_natoms,
80-
place_holders["box"],
81-
place_holders["default_mesh"],
82-
rcut=self.rcut,
83-
)
84-
place_holders["dir"] = tf.placeholder(tf.string)
85-
_min_nbor_dist = tf.reduce_min(_min_nbor_dist)
86-
_max_nbor_size = tf.reduce_max(_max_nbor_size, axis=0)
87-
return place_holders, (_max_nbor_size, _min_nbor_dist, place_holders["dir"])
88-
89-
with sub_graph.as_default():
90-
self.p = ParallelOp(builder, config=default_tf_session_config)
91-
182+
super().__init__(ntypes, rcut, one_type)
183+
self.auto_batch_size = AutoBatchSize()
184+
self.neighbor_stat = NeighborStatOP(ntypes, rcut, not one_type)
185+
self.place_holders = {}
186+
with tf.Graph().as_default() as sub_graph:
187+
self.op = self.build()
92188
self.sub_sess = tf.Session(graph=sub_graph, config=default_tf_session_config)
93189

94-
def get_stat(self, data: DeepmdDataSystem) -> Tuple[float, List[int]]:
190+
def get_stat(self, data: DeepmdDataSystem) -> Tuple[float, np.ndarray]:
95191
"""Get the data statistics of the training data, including nearest nbor distance between atoms, max nbor size of atoms.
96192
97193
Parameters
@@ -104,50 +200,119 @@ def get_stat(self, data: DeepmdDataSystem) -> Tuple[float, List[int]]:
104200
min_nbor_dist
105201
The nearest distance between neighbor atoms
106202
max_nbor_size
107-
A list with ntypes integers, denotes the actual achieved max sel
203+
An array with ntypes integers, denotes the actual achieved max sel
108204
"""
109-
self.min_nbor_dist = 100.0
110-
self.max_nbor_size = [0]
111-
if not self.one_type:
112-
self.max_nbor_size *= self.ntypes
113-
114-
def feed():
115-
for ii in range(len(data.system_dirs)):
116-
for jj in data.data_systems[ii].dirs:
117-
data_set = data.data_systems[ii]._load_set(jj)
118-
for kk in range(np.array(data_set["type"]).shape[0]):
119-
yield {
120-
"coord": np.array(data_set["coord"])[kk].reshape(
121-
[-1, data.natoms[ii] * 3]
122-
),
123-
"type": np.array(data_set["type"])[kk].reshape(
124-
[-1, data.natoms[ii]]
125-
),
126-
"natoms_vec": np.array(data.natoms_vec[ii]),
127-
"box": np.array(data_set["box"])[kk].reshape([-1, 9]),
128-
"default_mesh": np.array(data.default_mesh[ii]),
129-
"dir": str(jj),
130-
}
131-
132-
for mn, dt, jj in self.p.generate(self.sub_sess, feed()):
205+
min_nbor_dist = 100.0
206+
max_nbor_size = np.zeros(1 if self.mixed_type else self.ntypes, dtype=int)
207+
208+
for mn, dt, jj in self.iterator(data):
133209
if np.isinf(dt):
134210
log.warning(
135211
"Atoms with no neighbors found in %s. Please make sure it's what you expected."
136212
% jj
137213
)
138-
if dt < self.min_nbor_dist:
214+
if dt < min_nbor_dist:
139215
if math.isclose(dt, 0.0, rel_tol=1e-6):
140216
# it's unexpected that the distance between two atoms is zero
141217
# zero distance will cause nan (#874)
142218
raise RuntimeError(
143219
"Some atoms are overlapping in %s. Please check your"
144220
" training data to remove duplicated atoms." % jj
145221
)
146-
self.min_nbor_dist = dt
147-
self.max_nbor_size = np.maximum(mn, self.max_nbor_size)
222+
min_nbor_dist = dt
223+
max_nbor_size = np.maximum(mn, max_nbor_size)
148224

149225
# do sqrt in the final
150-
self.min_nbor_dist = math.sqrt(self.min_nbor_dist)
151-
log.info("training data with min nbor dist: " + str(self.min_nbor_dist))
152-
log.info("training data with max nbor size: " + str(self.max_nbor_size))
153-
return self.min_nbor_dist, self.max_nbor_size
226+
min_nbor_dist = math.sqrt(min_nbor_dist)
227+
log.info("training data with min nbor dist: " + str(min_nbor_dist))
228+
log.info("training data with max nbor size: " + str(max_nbor_size))
229+
return min_nbor_dist, max_nbor_size
230+
231+
def build(self) -> Tuple[tf.Tensor, tf.Tensor]:
232+
"""Build the graph.
233+
234+
Returns
235+
-------
236+
tf.Tensor
237+
The minimal squared distance between two atoms, in the shape of (nframes,)
238+
tf.Tensor
239+
The maximal number of neighbors
240+
"""
241+
for ii in ["coord", "box"]:
242+
self.place_holders[ii] = tf.placeholder(
243+
GLOBAL_NP_FLOAT_PRECISION, [None, None], name="t_" + ii
244+
)
245+
self.place_holders["type"] = tf.placeholder(
246+
tf.int32, [None, None], name="t_type"
247+
)
248+
self.place_holders["pbc"] = tf.placeholder(tf.bool, [], name="t_pbc")
249+
ret = self.neighbor_stat.build(
250+
self.place_holders["coord"],
251+
self.place_holders["type"],
252+
self.place_holders["box"],
253+
self.place_holders["pbc"],
254+
)
255+
return ret
256+
257+
def iterator(
258+
self, data: DeepmdDataSystem
259+
) -> Iterator[Tuple[np.ndarray, float, str]]:
260+
"""Produce data.
261+
262+
Parameters
263+
----------
264+
data
265+
The data system
266+
267+
Yields
268+
------
269+
np.ndarray
270+
The maximal number of neighbors
271+
float
272+
The squared minimal distance between two atoms
273+
str
274+
The directory of the data system
275+
"""
276+
for ii in range(len(data.system_dirs)):
277+
for jj in data.data_systems[ii].dirs:
278+
data_set = data.data_systems[ii]
279+
data_set_data = data_set._load_set(jj)
280+
minrr2, max_nnei = self.auto_batch_size.execute_all(
281+
self._execute,
282+
data_set_data["coord"].shape[0],
283+
data_set.get_natoms(),
284+
data_set_data["coord"],
285+
data_set_data["type"],
286+
data_set_data["box"],
287+
data_set.pbc,
288+
)
289+
yield np.max(max_nnei, axis=0), np.min(minrr2), jj
290+
291+
def _execute(
292+
self,
293+
coord: np.ndarray,
294+
atype: np.ndarray,
295+
box: Optional[np.ndarray],
296+
pbc: bool,
297+
):
298+
"""Execute the operation.
299+
300+
Parameters
301+
----------
302+
coord
303+
The coordinates of atoms.
304+
atype
305+
The atom types.
306+
box
307+
The box.
308+
pbc
309+
Whether the box is periodic.
310+
"""
311+
feed_dict = {
312+
self.place_holders["coord"]: coord,
313+
self.place_holders["type"]: atype,
314+
self.place_holders["box"]: box,
315+
self.place_holders["pbc"]: pbc,
316+
}
317+
minrr2, max_nnei = run_sess(self.sub_sess, self.op, feed_dict=feed_dict)
318+
return minrr2, max_nnei

0 commit comments

Comments
 (0)