forked from deepmodeling/deepmd-kit
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathexclude_mask.py
140 lines (121 loc) · 4.11 KB
/
exclude_mask.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
List,
Tuple,
)
import array_api_compat
import numpy as np
class AtomExcludeMask:
"""Computes the type exclusion mask for atoms."""
def __init__(
self,
ntypes: int,
exclude_types: List[int] = [],
):
self.ntypes = ntypes
self.exclude_types = exclude_types
self.type_mask = np.array(
[1 if tt_i not in self.exclude_types else 0 for tt_i in range(ntypes)],
dtype=np.int32,
)
# (ntypes)
self.type_mask = self.type_mask.reshape([-1])
def get_exclude_types(self):
return self.exclude_types
def get_type_mask(self):
return self.type_mask
def build_type_exclude_mask(
self,
atype: np.ndarray,
):
"""Compute type exclusion mask for atoms.
Parameters
----------
atype
The extended aotm types. shape: nf x natom
Returns
-------
mask
The type exclusion mask for atoms. shape: nf x natom
Element [ff,ii] being 0 if type(ii) is excluded,
otherwise being 1.
"""
xp = array_api_compat.array_namespace(atype)
nf, natom = atype.shape
return xp.reshape(self.type_mask[atype], (nf, natom))
class PairExcludeMask:
"""Computes the type exclusion mask for atom pairs."""
def __init__(
self,
ntypes: int,
exclude_types: List[Tuple[int, int]] = [],
):
self.ntypes = ntypes
self.exclude_types = set()
for tt in exclude_types:
assert len(tt) == 2
self.exclude_types.add((tt[0], tt[1]))
self.exclude_types.add((tt[1], tt[0]))
# ntypes + 1 for nlist masks
type_mask = np.array(
[
[
1 if (tt_i, tt_j) not in self.exclude_types else 0
for tt_i in range(ntypes + 1)
]
for tt_j in range(ntypes + 1)
],
dtype=np.int32,
)
# (ntypes+1 x ntypes+1)
self.type_mask = type_mask.reshape([-1])
def get_exclude_types(self):
return self.exclude_types
def build_type_exclude_mask(
self,
nlist: np.ndarray,
atype_ext: np.ndarray,
):
"""Compute type exclusion mask for atom pairs.
Parameters
----------
nlist
The neighbor list. shape: nf x nloc x nnei
atype_ext
The extended aotm types. shape: nf x nall
Returns
-------
mask
The type exclusion mask for pair atoms of shape: nf x nloc x nnei.
Element [ff,ii,jj] being 0 if type(ii), type(nlist[ff,ii,jj]) is excluded,
otherwise being 1.
"""
xp = array_api_compat.array_namespace(nlist, atype_ext)
if len(self.exclude_types) == 0:
# safely return 1 if nothing is excluded.
return xp.ones_like(nlist, dtype=xp.int32)
nf, nloc, nnei = nlist.shape
nall = atype_ext.shape[1]
# add virtual atom of type ntypes. nf x nall+1
ae = xp.concat(
[atype_ext, self.ntypes * xp.ones([nf, 1], dtype=atype_ext.dtype)], axis=-1
)
type_i = xp.reshape(atype_ext[:, :nloc], (nf, nloc)) * (self.ntypes + 1)
# nf x nloc x nnei
index = xp.reshape(
xp.where(nlist == -1, xp.full_like(nlist, nall), nlist), (nf, nloc * nnei)
)
# type_j = xp.take_along_axis(ae, index, axis=1).reshape(nf, nloc, nnei)
index = xp.reshape(index, [-1])
index += xp.repeat(xp.arange(nf) * (nall + 1), nloc * nnei)
type_j = xp.take(xp.reshape(ae, [-1]), index, axis=0)
type_j = xp.reshape(type_j, (nf, nloc, nnei))
type_ij = type_i[:, :, None] + type_j
# nf x (nloc x nnei)
type_ij = xp.reshape(type_ij, (nf, nloc * nnei))
mask = xp.reshape(
xp.take(self.type_mask, xp.reshape(type_ij, (-1,))), (nf, nloc, nnei)
)
return mask
def __contains__(self, item):
return item in self.exclude_types