forked from deepmodeling/deepmd-kit
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathpreprocess.py
303 lines (265 loc) · 10.1 KB
/
preprocess.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
from typing import (
Union,
)
import torch
from deepmd.pt.utils import (
env,
)
log = logging.getLogger(__name__)
class Region3D:
    """Triclinic simulation cell described by a 3x3 matrix.

    The rows of ``boxt`` are the cell vectors: a row vector of internal
    (fractional) coordinates multiplied by ``boxt`` gives physical
    coordinates, and multiplying by ``rec_boxt`` (its inverse) goes back.
    """

    def __init__(self, boxt):
        """Construct a simulation box.

        Args:
        - boxt: tensor holding 9 values, reshaped to [3, 3]
        """
        boxt = boxt.reshape([3, 3])
        self.boxt = boxt  # convert internal coordinates to physical ones
        self.rec_boxt = torch.linalg.inv(
            self.boxt
        )  # convert physical coordinates to internal ones
        self.volume = torch.linalg.det(self.boxt)  # signed determinant of the cell
        # Distance between two opposite faces = volume / area of that face,
        # the face area being the norm of the cross product of its two edges.
        # `dim` is passed explicitly: the implicit-dimension form of
        # torch.cross is deprecated.
        c_yz = torch.cross(boxt[1], boxt[2], dim=-1)
        self._h2yz = self.volume / torch.linalg.norm(c_yz)
        c_zx = torch.cross(boxt[2], boxt[0], dim=-1)
        self._h2zx = self.volume / torch.linalg.norm(c_zx)
        c_xy = torch.cross(boxt[0], boxt[1], dim=-1)
        self._h2xy = self.volume / torch.linalg.norm(c_xy)

    def phys2inter(self, coord):
        """Convert physical coordinates to internal ones."""
        return coord @ self.rec_boxt

    def inter2phys(self, coord):
        """Convert internal coordinates to physical ones."""
        return coord @ self.boxt

    def get_face_distance(self):
        """Return face distances to each surface of YZ, ZX, XY."""
        return torch.stack([self._h2yz, self._h2zx, self._h2xy])
def normalize_coord(coord, region: Region3D, nloc: int):
    """Wrap atoms lying outside the region back into it.

    The coordinates are converted to internal ones, reduced modulo 1 and
    converted back to physical ones, i.e. the atoms are moved by whole
    periods of the region.

    Args:
    - coord: physical coordinates, handed to region.phys2inter unchanged
    - region: provides the phys2inter / inter2phys conversions
    - nloc: not used by this function

    Returns a new tensor; the input `coord` is not modified.
    """
    physical = coord.clone()
    fractional = region.phys2inter(physical)
    wrapped = torch.remainder(fractional, 1.0)
    return region.inter2phys(wrapped)
def compute_serial_cid(cell_offset, ncell):
    """Tell the sequential cell ID in its 3D space.

    The ID is `x * ncell[1] * ncell[2] + y * ncell[2] + z`, i.e. the last
    axis varies fastest.

    Args:
    - cell_offset: shape is [ncells, 3], per-axis cell indices (x, y, z)
    - ncell: shape is [3], number of cells along each axis

    Returns a tensor of shape [ncells]. `cell_offset` is left untouched
    (it used to be scaled in place, silently corrupting the caller's tensor).
    """
    return (
        cell_offset[:, 0] * (ncell[1] * ncell[2])
        + cell_offset[:, 1] * ncell[2]
        + cell_offset[:, 2]
    )
def compute_pbc_shift(cell_offset, ncell):
    """Tell shift count to move the atom into region.

    For every per-axis cell index, returns the number of whole periods
    (multiples of `ncell`) that must be added so that
    `0 <= cell_offset + shift * ncell < ncell`. Indices that are already in
    range get a shift of zero.
    """
    below = cell_offset < 0
    above = cell_offset >= ncell
    # Periods needed to lift a negative index into range.
    lift = -torch.div(cell_offset, ncell, rounding_mode="floor")
    # Periods needed to bring an index >= ncell back into range.
    drop = -(torch.div(cell_offset - ncell, ncell, rounding_mode="floor") + 1)
    shift = torch.zeros_like(cell_offset) + below * lift + above * drop
    wrapped = cell_offset + shift * ncell
    assert torch.all(wrapped >= 0)
    assert torch.all(wrapped < ncell)
    return shift
def build_inside_clist(coord, region: Region3D, ncell):
    """Build cell list on atoms inside region.

    Args:
    - coord: shape is [nloc*3] (viewed as [nloc, 3])
    - region: region used to convert `coord` to internal coordinates
    - ncell: shape is [3], number of cells along each axis

    Returns
    -------
    a2c: shape is [nloc], serial cell id of each atom
    lst: list with one entry per cell, each a 1-D tensor holding the indices
        of the atoms in that cell in ascending order
    """
    loc_ncell = int(torch.prod(ncell))  # num of local cells
    inter_cell_size = 1.0 / ncell
    inter_cood = region.phys2inter(coord.view(-1, 3))
    cell_offset = torch.floor(inter_cood / inter_cell_size).to(torch.long)
    # Numerical error brought by converting between physical and internal
    # coordinates may push an index just outside [0, ncell) on EITHER side.
    # Clamp both ends: an index equal to ncell would otherwise produce an
    # out-of-range (or aliased) serial id and the atom would be lost from,
    # or misplaced in, the cell list.
    cell_offset = torch.maximum(cell_offset, torch.zeros_like(cell_offset))
    cell_offset = torch.minimum(cell_offset, ncell - 1)
    a2c = compute_serial_cid(cell_offset, ncell)  # cell id of atoms
    arange = torch.arange(loc_ncell, device=a2c.device)
    cellid = a2c == arange.unsqueeze(-1)  # one hot cellid, [ncell, nloc]
    # nonzero() enumerates (cell, atom) pairs row by row, so the atom column
    # is already grouped by cell and sorted within every cell.
    atoms_by_cell = cellid.nonzero()[:, 1]
    counts = torch.bincount(a2c, minlength=loc_ncell)
    lst = list(torch.split(atoms_by_cell, counts.tolist()))
    return a2c, lst
def append_neighbors(coord, region: Region3D, atype, rcut: float):
    """Make ghost atoms who are valid neighbors.

    The region is divided into cells of edge >= rcut (where possible); every
    cell of the surrounding ghost layer is mapped back onto the local cell it
    is a periodic image of, and the atoms of that local cell are replicated.

    Args:
    - coord: shape is [nloc, 3] (indexed per atom along dim 0)
    - region: simulation region
    - atype: shape is [nloc]
    - rcut: cut-off radius

    Returns
    -------
    merged_coord_shift: shape is [nall, 3]; zero for local atoms, for ghost
        atoms the physical shift such that
        `coord[merged_mapping] - merged_coord_shift` is the ghost position
    merged_atype: shape is [nall], types of local atoms followed by ghosts
    merged_mapping: shape is [nall], index of the local atom each entry
        is an image of
    """
    to_face = region.get_face_distance()
    # compute num and size of local cells
    ncell = torch.floor(to_face / rcut).to(torch.long)
    ncell[ncell == 0] = 1
    cell_size = to_face / ncell
    ngcell = (
        torch.floor(rcut / cell_size).to(torch.long) + 1
    )  # num of cells out of local, which contain ghost atoms
    # add ghost atoms
    _, c2a = build_inside_clist(coord, region, ncell)
    xi = torch.arange(-ngcell[0], ncell[0] + ngcell[0], 1)
    yi = torch.arange(-ngcell[1], ncell[1] + ngcell[1], 1)
    zi = torch.arange(-ngcell[2], ncell[2] + ngcell[2], 1)
    # all (x, y, z) cell indices of the local + ghost grid, z varying fastest
    xyz = xi.view(-1, 1, 1, 1) * torch.tensor([1, 0, 0], dtype=torch.long)
    xyz = xyz + yi.view(1, -1, 1, 1) * torch.tensor([0, 1, 0], dtype=torch.long)
    xyz = xyz + zi.view(1, 1, -1, 1) * torch.tensor([0, 0, 1], dtype=torch.long)
    xyz = xyz.view(-1, 3)
    # keep only the ghost cells, i.e. drop the cells inside the region
    inside = torch.logical_and((xyz >= 0).all(dim=-1), (xyz < ncell).all(dim=-1))
    xyz = xyz[~inside]  # cell coord
    shift = compute_pbc_shift(xyz, ncell)
    coord_shift = region.inter2phys(shift.to(env.GLOBAL_PT_FLOAT_PRECISION))
    mirrored = shift * ncell + xyz  # local cell each ghost cell is an image of
    cid = compute_serial_cid(mirrored, ncell)
    n_atoms = coord.shape[0]
    # Encode (ghost cell i, atom a) as a + i * n_atoms so that one flat tensor
    # carries both; it is decoded right below.
    aid = torch.cat([c2a[ci] + i * n_atoms for i, ci in enumerate(cid)])
    ghost_cell = torch.div(aid, n_atoms, rounding_mode="trunc")
    aid = aid % n_atoms
    ghost_shift = coord_shift[ghost_cell]
    # merge local and ghost atoms
    merged_coord_shift = torch.cat([torch.zeros_like(coord), ghost_shift])
    merged_atype = torch.cat([atype, atype[aid]])
    merged_mapping = torch.cat([torch.arange(atype.numel()), aid])
    return merged_coord_shift, merged_atype, merged_mapping
def build_neighbor_list(
    nloc: int, coord, atype, rcut: float, sec, mapping, type_split=True, min_check=False
):
    """For each atom inside region, build its neighbor list.

    Args:
    - nloc: number of local atoms; they are the first nloc atoms of `coord`
    - coord: shape is [nall*3] (viewed as [nall, 3]); cast to float32 here
    - atype: shape is [nall]
    - rcut: cut-off radius; only atoms closer than rcut are kept
    - sec: 1-D integer tensor of cumulative neighbor counts, one entry per
        atom type; only its last entry is used when type_split is False
    - mapping: shape is [nall], index of the local atom each atom is an
        image of
    - type_split: if True, neighbors are selected per type and stored in the
        column range given by `sec`; otherwise types are mixed
    - min_check: if True, raise when two distinct atoms (nearly) coincide

    Returns
    -------
    nlist: shape is [nloc, sec[-1]], neighbor indices into the nall atoms,
        sorted by distance inside each section, padded with -1
    nlist_loc: same shape, neighbor indices translated through `mapping`
    nlist_type: same shape, types of the neighbors, padded with -1

    Raises
    ------
    RuntimeError
        If min_check is set and the smallest pair distance is below 1e-6.
    """
    coord = coord.float()
    coord_l = coord.view(-1, 1, 3)[:nloc]
    coord_r = coord.view(1, -1, 3)
    distance = torch.linalg.norm(coord_l - coord_r, dim=-1)  # [nloc, nall]
    # Larger than any real distance and than rcut: marks "never a neighbor".
    DISTANCE_INF = distance.max().detach() + rcut
    # An atom must not be its own neighbor.
    distance[:nloc, :nloc] += torch.eye(nloc, dtype=distance.dtype) * DISTANCE_INF
    if min_check and distance.min().abs() < 1e-6:
        # The exception used to be created without `raise`, so the check
        # never fired.
        raise RuntimeError("Atom dist too close!")
    if not type_split:
        sec = sec[-1:]
    nsel = sec[-1].item()
    nlist = torch.zeros((nloc, nsel)).long() - 1
    nlist_loc = torch.zeros((nloc, nsel)).long() - 1
    nlist_type = torch.zeros((nloc, nsel)).long() - 1
    for i, nnei in enumerate(sec):
        # `sec` is cumulative: the size of section i is the difference.
        if i > 0:
            nnei = nnei - sec[i - 1]
        nnei = int(nnei)
        if type_split:
            type_mask = atype.unsqueeze(0) == i
            tmp = distance + (~type_mask) * DISTANCE_INF
        else:
            tmp = distance
        navail = tmp.shape[1]
        if navail >= nnei:
            _sorted, indices = torch.topk(tmp, nnei, dim=1, largest=False)
        else:
            # when nnei > nall: take every atom and pad the remainder
            indices = torch.zeros((nloc, nnei)).long() - 1
            _sorted = torch.ones((nloc, nnei), dtype=tmp.dtype) * DISTANCE_INF
            _sorted_nnei, indices_nnei = torch.topk(
                tmp, navail, dim=1, largest=False
            )
            _sorted[:, :navail] = _sorted_nnei
            indices[:, :navail] = indices_nnei
        mask = (_sorted < rcut).to(torch.long)
        # Padding entries (-1) index the last element here; they are
        # overwritten with -1 by the mask right after.
        indices_loc = mapping[indices]
        indices = indices * mask + -1 * (1 - mask)  # -1 for padding
        indices_loc = indices_loc * mask + -1 * (1 - mask)  # -1 for padding
        start = 0 if i == 0 else int(sec[i - 1])
        end = min(int(sec[i]), start + indices.shape[1])
        nlist[:, start:end] = indices[:, :nnei]
        nlist_loc[:, start:end] = indices_loc[:, :nnei]
        nlist_type[:, start:end] = atype[indices[:, :nnei]] * mask + -1 * (1 - mask)
    return nlist, nlist_loc, nlist_type
def compute_smooth_weight(distance, rmin: float, rmax: float):
    """Compute smooth weight for descriptor elements.

    The weight is 1 for distance <= rmin, 0 for distance >= rmax, and in
    between follows the polynomial u^3 * (-6 u^2 + 15 u - 10) + 1 with
    u = (distance - rmin) / (rmax - rmin).
    """
    below = distance <= rmin
    above = distance >= rmax
    in_between = ~(below | above)
    uu = (distance - rmin) / (rmax - rmin)
    poly = -6 * uu * uu + 15 * uu - 10
    vv = uu * uu * uu * poly + 1
    # Outside the switching interval the weight is just the "below" flag:
    # 1 where distance <= rmin, 0 where distance >= rmax.
    plateau = below.to(dtype=distance.dtype)
    return torch.where(in_between, vv, plateau)
def make_env_mat(
    coord,
    atype,
    region,
    rcut: Union[float, list],
    sec,
    pbc=True,
    type_split=True,
    min_check=False,
):
    """Based on atom coordinates, return environment matrix.

    When `rcut` is a list (hybrid mode) one neighbor list is built per
    cut-off, using `sec[ii]` for the ii-th cut-off, and the three nlist
    results are lists; ghost atoms are generated once with the largest
    cut-off.

    Returns
    -------
    nlist: nlist, [nloc, nnei]
    nlist_loc: nlist translated to local atom indices, [nloc, nnei]
    nlist_type: types of the neighbors in nlist, [nloc, nnei]
    merged_coord_shift: shift on nall atoms, [nall, 3]
    merged_mapping: mapping from nall index to nloc index, [nall]
    """
    hybrid = isinstance(rcut, list)
    # ghost atoms must cover the largest cut-off in use
    _rcut = max(rcut) if hybrid else rcut

    if not pbc:
        # no periodic images: the extended system is the local one
        merged_coord_shift = torch.zeros_like(coord)
        merged_atype = atype.clone()
        merged_mapping = torch.arange(atype.numel())
        merged_coord = coord.clone()
    else:
        merged_coord_shift, merged_atype, merged_mapping = append_neighbors(
            coord, region, atype, _rcut
        )
        merged_coord = coord[merged_mapping] - merged_coord_shift
        if merged_coord.shape[0] <= coord.shape[0]:
            log.warning("No ghost atom is added for system ")

    def _neighbors_for(cutoff, sel):
        """Build the neighbor list of the local atoms for one cut-off."""
        return build_neighbor_list(
            coord.shape[0],
            merged_coord,
            merged_atype,
            cutoff,
            sel,
            merged_mapping,
            type_split=type_split,
            min_check=min_check,
        )

    # build nlist
    if hybrid:
        per_cutoff = [
            _neighbors_for(single_rcut, sec[ii]) for ii, single_rcut in enumerate(rcut)
        ]
        nlist = [result[0] for result in per_cutoff]
        nlist_loc = [result[1] for result in per_cutoff]
        nlist_type = [result[2] for result in per_cutoff]
    else:
        nlist, nlist_loc, nlist_type = _neighbors_for(rcut, sec)
    return nlist, nlist_loc, nlist_type, merged_coord_shift, merged_mapping