forked from deepmodeling/deepmd-kit
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathdpa1.py
588 lines (532 loc) · 22.7 KB
/
dpa1.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Callable,
Dict,
List,
Optional,
Tuple,
Union,
)
import torch
from deepmd.dpmodel.utils import EnvMat as DPEnvMat
from deepmd.pt.model.network.mlp import (
NetworkCollection,
)
from deepmd.pt.model.network.network import (
TypeEmbedNet,
TypeEmbedNetConsistent,
)
from deepmd.pt.utils import (
env,
)
from deepmd.pt.utils.env import (
RESERVED_PRECISON_DICT,
)
from deepmd.pt.utils.update_sel import (
UpdateSel,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)
from deepmd.utils.path import (
DPPath,
)
from deepmd.utils.version import (
check_version_compatibility,
)
from .base_descriptor import (
BaseDescriptor,
)
from .se_atten import (
DescrptBlockSeAtten,
NeighborGatedAttention,
)
@BaseDescriptor.register("dpa1")
@BaseDescriptor.register("se_atten")
class DescrptDPA1(BaseDescriptor, torch.nn.Module):
    r"""Attention-based descriptor proposed in the pretrainable DPA-1 model [1]_.

    This descriptor, :math:`\mathcal{D}^i \in \mathbb{R}^{M \times M_{<}}`, is given by

    .. math::
        \mathcal{D}^i = \frac{1}{N_c^2}(\hat{\mathcal{G}}^i)^T \mathcal{R}^i (\mathcal{R}^i)^T \hat{\mathcal{G}}^i_<,

    where :math:`\hat{\mathcal{G}}^i` represents the embedding matrix :math:`\mathcal{G}^i`
    after additional self-attention mechanism and :math:`\mathcal{R}^i` is defined by the full case in the se_e2_a descriptor.
    Note that we obtain :math:`\mathcal{G}^i` using the type embedding method by default in this descriptor.

    To perform the self-attention mechanism, the queries :math:`\mathcal{Q}^{i,l} \in \mathbb{R}^{N_c\times d_k}`,
    keys :math:`\mathcal{K}^{i,l} \in \mathbb{R}^{N_c\times d_k}`,
    and values :math:`\mathcal{V}^{i,l} \in \mathbb{R}^{N_c\times d_v}` are first obtained:

    .. math::
        \left(\mathcal{Q}^{i,l}\right)_{j}=Q_{l}\left(\left(\mathcal{G}^{i,l-1}\right)_{j}\right),

    .. math::
        \left(\mathcal{K}^{i,l}\right)_{j}=K_{l}\left(\left(\mathcal{G}^{i,l-1}\right)_{j}\right),

    .. math::
        \left(\mathcal{V}^{i,l}\right)_{j}=V_{l}\left(\left(\mathcal{G}^{i,l-1}\right)_{j}\right),

    where :math:`Q_{l}`, :math:`K_{l}`, :math:`V_{l}` represent three trainable linear transformations
    that output the queries and keys of dimension :math:`d_k` and values of dimension :math:`d_v`, and :math:`l`
    is the index of the attention layer.
    The input embedding matrix to the attention layers, denoted by :math:`\mathcal{G}^{i,0}`,
    is chosen as the two-body embedding matrix.

    Then the scaled dot-product attention method is adopted:

    .. math::
        A(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l}, \mathcal{V}^{i,l}, \mathcal{R}^{i,l})=\varphi\left(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l},\mathcal{R}^{i,l}\right)\mathcal{V}^{i,l},

    where :math:`\varphi\left(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l},\mathcal{R}^{i,l}\right) \in \mathbb{R}^{N_c\times N_c}` is attention weights.
    In the original attention method,
    one typically has :math:`\varphi\left(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l}\right)=\mathrm{softmax}\left(\frac{\mathcal{Q}^{i,l} (\mathcal{K}^{i,l})^{T}}{\sqrt{d_{k}}}\right)`,
    with :math:`\sqrt{d_{k}}` being the normalization temperature.
    This is slightly modified to incorporate the angular information:

    .. math::
        \varphi\left(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l},\mathcal{R}^{i,l}\right) = \mathrm{softmax}\left(\frac{\mathcal{Q}^{i,l} (\mathcal{K}^{i,l})^{T}}{\sqrt{d_{k}}}\right) \odot \hat{\mathcal{R}}^{i}(\hat{\mathcal{R}}^{i})^{T},

    where :math:`\hat{\mathcal{R}}^{i} \in \mathbb{R}^{N_c\times 3}` denotes normalized relative coordinates,
    :math:`\hat{\mathcal{R}}^{i}_{j} = \frac{\boldsymbol{r}_{ij}}{\lVert \boldsymbol{r}_{ij} \rVert}`
    and :math:`\odot` means element-wise multiplication.

    Then layer normalization is added in a residual way to finally obtain the self-attention local embedding matrix
    :math:`\hat{\mathcal{G}}^{i} = \mathcal{G}^{i,L_a}` after :math:`L_a` attention layers [1]_:

    .. math::
        \mathcal{G}^{i,l} = \mathcal{G}^{i,l-1} + \mathrm{LayerNorm}(A(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l}, \mathcal{V}^{i,l}, \mathcal{R}^{i,l})).

    Parameters
    ----------
    rcut: float
        The cut-off radius :math:`r_c`
    rcut_smth: float
        From where the environment matrix should be smoothed :math:`r_s`
    sel : list[int], int
        list[int]: sel[i] specifies the maximum number of type i atoms in the cut-off radius
        int: the total maximum number of atoms in the cut-off radius
    ntypes : int
        Number of element types
    neuron : list[int]
        Number of neurons in each hidden layers of the embedding net :math:`\mathcal{N}`
    axis_neuron: int
        Number of the axis neuron :math:`M_2` (number of columns of the sub-matrix of the embedding matrix)
    tebd_dim: int
        Dimension of the type embedding
    tebd_input_mode: str
        The input mode of the type embedding. Supported modes are ["concat", "strip"].

        - "concat": Concatenate the type embedding with the smoothed radial information as the union input for the embedding network.
        - "strip": Use a separated embedding network for the type embedding and combine the output with the radial embedding network output.
    resnet_dt: bool
        Time-step `dt` in the resnet construction:
        y = x + dt * \phi (Wx + b)
    trainable: bool
        If the weights of this descriptors are trainable.
    trainable_ln: bool
        Whether to use trainable shift and scale weights in layer normalization.
    ln_eps: float, Optional
        The epsilon value for layer normalization.
        `None` is replaced by 1e-5.
    type_one_side: bool
        If 'False', type embeddings of both neighbor and central atoms are considered.
        If 'True', only type embeddings of neighbor atoms are considered.
        Default is 'False'.
    attn: int
        Hidden dimension of the attention vectors
    attn_layer: int
        Number of attention layers
    attn_dotr: bool
        If dot the angular gate to the attention weights
    attn_mask: bool
        (Only support False to keep consistent with other backend references.)
        (Not used in this version. True option is not implemented.)
        If mask the diagonal of attention weights
    exclude_types : List[List[int]]
        The excluded pairs of types which have no interaction with each other.
        For example, `[[0, 1]]` means no interaction between type 0 and type 1.
    env_protection: float
        Protection parameter to prevent division by zero errors during environment matrix calculations.
    set_davg_zero: bool
        Set the shift of embedding net input to zero.
    activation_function: str
        The activation function in the embedding net. Supported options are |ACTIVATION_FN|
    precision: str
        The precision of the embedding net parameters. Supported options are |PRECISION|
    scaling_factor: float
        The scaling factor of normalization in calculations of attention weights.
        If `temperature` is None, the scaling of attention weights is (N_dim * scaling_factor)**0.5
    normalize: bool
        Whether to normalize the hidden vectors in attention weights calculation.
    temperature: float
        If not None, the scaling of attention weights is `temperature` itself.
    smooth_type_embedding: bool
        Whether to use smooth process in attention weights calculation.
    concat_output_tebd: bool
        Whether to concat type embedding at the output of the descriptor.
    stripped_type_embedding: bool, Optional
        (Deprecated, kept only for compatibility.)
        Whether to strip the type embedding into a separate embedding network.
        Setting this parameter to `True` is equivalent to setting `tebd_input_mode` to 'strip'.
        Setting it to `False` is equivalent to setting `tebd_input_mode` to 'concat'.
        The default value is `None`, which means the `tebd_input_mode` setting will be used instead.
    seed: int, Optional
        Random seed for parameter initialization.
    use_econf_tebd: bool, Optional
        Whether to use electronic configuration type embedding.
    type_map: List[str], Optional
        A list of strings. Give the name to each type of atoms.
        Only used if `use_econf_tebd` is `True` in type embedding net.
    spin
        (Only support None to keep consistent with other backend references.)
        (Not used in this version. Not-none option is not implemented.)
        The old implementation of deepspin.
    type: str, Optional
        Accepted and discarded by the constructor.
    old_impl: bool
        Forwarded unchanged to the underlying `DescrptBlockSeAtten`.

    Limitations
    -----------
    The current implementation does not support the following deprecated features

    1. spin is not None
    2. attn_mask == True

    References
    ----------
    .. [1] Duo Zhang, Hangrui Bi, Fu-Zhi Dai, Wanrun Jiang, Linfeng Zhang, and Han Wang. 2022.
       DPA-1: Pretraining of Attention-based Deep Potential Model for Molecular Simulation.
       arXiv preprint arXiv:2208.08236.
    """

    def __init__(
        self,
        rcut: float,
        rcut_smth: float,
        sel: Union[List[int], int],
        ntypes: int,
        neuron: list = [25, 50, 100],
        axis_neuron: int = 16,
        tebd_dim: int = 8,
        tebd_input_mode: str = "concat",
        set_davg_zero: bool = True,
        attn: int = 128,
        attn_layer: int = 2,
        attn_dotr: bool = True,
        attn_mask: bool = False,
        activation_function: str = "tanh",
        precision: str = "float64",
        resnet_dt: bool = False,
        exclude_types: List[Tuple[int, int]] = [],
        env_protection: float = 0.0,
        scaling_factor: float = 1.0,
        normalize: bool = True,
        temperature: Optional[float] = None,
        concat_output_tebd: bool = True,
        trainable: bool = True,
        trainable_ln: bool = True,
        ln_eps: Optional[float] = 1e-5,
        smooth_type_embedding: bool = True,
        type_one_side: bool = False,
        stripped_type_embedding: Optional[bool] = None,
        seed: Optional[int] = None,
        use_econf_tebd: bool = False,
        type_map: Optional[List[str]] = None,
        # not implemented
        spin=None,
        type: Optional[str] = None,
        old_impl: bool = False,
    ):
        super().__init__()
        # The deprecated `stripped_type_embedding` flag, when given explicitly,
        # takes precedence over `tebd_input_mode`.
        if stripped_type_embedding is not None:
            tebd_input_mode = "strip" if stripped_type_embedding else "concat"
        if spin is not None:
            raise NotImplementedError("old implementation of spin is not supported.")
        if attn_mask:
            raise NotImplementedError(
                "old implementation of attn_mask is not supported."
            )
        # Fall back to this backend's default epsilon when None is given.
        if ln_eps is None:
            ln_eps = 1e-5

        del type, spin, attn_mask

        # The signature keeps its historical (mutable) defaults for backward
        # compatibility, but only copies are forwarded: the block stores these
        # lists and `serialize()` hands them back to callers, so forwarding the
        # default objects themselves would let later mutation leak into every
        # subsequently constructed descriptor.
        neuron = list(neuron)
        exclude_types = list(exclude_types)

        self.se_atten = DescrptBlockSeAtten(
            rcut,
            rcut_smth,
            sel,
            ntypes,
            neuron=neuron,
            axis_neuron=axis_neuron,
            tebd_dim=tebd_dim,
            tebd_input_mode=tebd_input_mode,
            set_davg_zero=set_davg_zero,
            attn=attn,
            attn_layer=attn_layer,
            attn_dotr=attn_dotr,
            attn_mask=False,
            activation_function=activation_function,
            precision=precision,
            resnet_dt=resnet_dt,
            scaling_factor=scaling_factor,
            normalize=normalize,
            temperature=temperature,
            smooth=smooth_type_embedding,
            type_one_side=type_one_side,
            exclude_types=exclude_types,
            env_protection=env_protection,
            trainable_ln=trainable_ln,
            ln_eps=ln_eps,
            seed=seed,
            old_impl=old_impl,
        )
        self.use_econf_tebd = use_econf_tebd
        self.type_map = type_map
        self.type_embedding = TypeEmbedNet(
            ntypes,
            tebd_dim,
            precision=precision,
            seed=seed,
            use_econf_tebd=use_econf_tebd,
            type_map=type_map,
        )
        self.tebd_dim = tebd_dim
        self.concat_output_tebd = concat_output_tebd
        self.trainable = trainable
        # Apply the trainable flag to every parameter registered so far
        # (se_atten block and type embedding).
        for param in self.parameters():
            param.requires_grad = trainable

    def get_rcut(self) -> float:
        """Returns the cut-off radius."""
        return self.se_atten.get_rcut()

    def get_rcut_smth(self) -> float:
        """Returns the radius where the neighbor information starts to smoothly decay to 0."""
        return self.se_atten.get_rcut_smth()

    def get_nsel(self) -> int:
        """Returns the number of selected atoms in the cut-off radius."""
        return self.se_atten.get_nsel()

    def get_sel(self) -> List[int]:
        """Returns the number of selected atoms for each type."""
        return self.se_atten.get_sel()

    def get_ntypes(self) -> int:
        """Returns the number of element types."""
        return self.se_atten.get_ntypes()

    def get_dim_out(self) -> int:
        """Returns the output dimension.

        This is the output dimension of the se_atten block, plus `tebd_dim`
        when the type embedding is concatenated to the output.
        """
        ret = self.se_atten.get_dim_out()
        if self.concat_output_tebd:
            ret += self.tebd_dim
        return ret

    def get_dim_emb(self) -> int:
        """Returns the `dim_emb` attribute of the underlying se_atten block."""
        return self.se_atten.dim_emb

    def mixed_types(self) -> bool:
        """If true, the descriptor
        1. assumes total number of atoms aligned across frames;
        2. requires a neighbor list that does not distinguish different atomic types.

        If false, the descriptor
        1. assumes total number of atoms of each atom type aligned across frames;
        2. requires a neighbor list that distinguishes different atomic types.

        """
        return self.se_atten.mixed_types()

    def has_message_passing(self) -> bool:
        """Returns whether the descriptor has message passing."""
        return self.se_atten.has_message_passing()

    def get_env_protection(self) -> float:
        """Returns the protection of building environment matrix."""
        return self.se_atten.get_env_protection()

    def share_params(self, base_class, shared_level, resume: bool = False) -> None:
        """
        Share the parameters of self to the base_class with shared_level during multitask training.
        If not start from checkpoint (resume is False),
        some separated parameters (e.g. mean and stddev) will be re-calculated across different classes.

        Supported values of `shared_level`:

        - 0: share the type embedding and all parameters of the se_atten block.
        - 1: share the type embedding only.

        Any other level raises `NotImplementedError`.
        """
        assert (
            self.__class__ == base_class.__class__
        ), "Only descriptors of the same type can share params!"
        if shared_level not in (0, 1):
            raise NotImplementedError
        # Both supported levels share the type embedding module itself.
        self._modules["type_embedding"] = base_class._modules["type_embedding"]
        if shared_level == 0:
            # Level 0 additionally shares everything inside the se_atten block.
            self.se_atten.share_params(base_class.se_atten, 0, resume=resume)

    @property
    def dim_out(self) -> int:
        """Output dimension; same value as `get_dim_out()`."""
        return self.get_dim_out()

    @property
    def dim_emb(self) -> int:
        """Embedding dimension; same value as `get_dim_emb()`."""
        return self.get_dim_emb()

    def compute_input_stats(
        self,
        merged: Union[Callable[[], List[dict]], List[dict]],
        path: Optional[DPPath] = None,
    ):
        """
        Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data.

        The work is delegated to the se_atten block and its result is returned.

        Parameters
        ----------
        merged : Union[Callable[[], List[dict]], List[dict]]
            - List[dict]: A list of data samples from various data systems.
                Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
                originating from the `i`-th data system.
            - Callable[[], List[dict]]: A lazy function that returns data samples in the above format
                only when needed. Since the sampling process can be slow and memory-intensive,
                the lazy function helps by only sampling once.
        path : Optional[DPPath]
            The path to the stat file.

        """
        return self.se_atten.compute_input_stats(merged, path)

    def set_stat_mean_and_stddev(
        self,
        mean: torch.Tensor,
        stddev: torch.Tensor,
    ) -> None:
        """Overwrite the `mean` and `stddev` attributes of the se_atten block."""
        self.se_atten.mean = mean
        self.se_atten.stddev = stddev

    def serialize(self) -> dict:
        """Serialize the descriptor to a dict.

        Returns
        -------
        dict
            Constructor arguments, the serialized sub-networks and the
            `davg` / `dstd` statistics (as numpy arrays under `@variables`).
            The key `embeddings_strip` is present only when
            `tebd_input_mode` is "strip".
        """
        obj = self.se_atten
        data = {
            "@class": "Descriptor",
            "type": "dpa1",
            "@version": 1,
            "rcut": obj.rcut,
            "rcut_smth": obj.rcut_smth,
            "sel": obj.sel,
            "ntypes": obj.ntypes,
            "neuron": obj.neuron,
            "axis_neuron": obj.axis_neuron,
            "tebd_dim": obj.tebd_dim,
            "tebd_input_mode": obj.tebd_input_mode,
            "set_davg_zero": obj.set_davg_zero,
            "attn": obj.attn_dim,
            "attn_layer": obj.attn_layer,
            "attn_dotr": obj.attn_dotr,
            "attn_mask": False,
            "activation_function": obj.activation_function,
            "resnet_dt": obj.resnet_dt,
            "scaling_factor": obj.scaling_factor,
            "normalize": obj.normalize,
            "temperature": obj.temperature,
            "trainable_ln": obj.trainable_ln,
            "ln_eps": obj.ln_eps,
            "smooth_type_embedding": obj.smooth,
            "type_one_side": obj.type_one_side,
            "concat_output_tebd": self.concat_output_tebd,
            "use_econf_tebd": self.use_econf_tebd,
            "type_map": self.type_map,
            # make deterministic
            "precision": RESERVED_PRECISON_DICT[obj.prec],
            "embeddings": obj.filter_layers.serialize(),
            "attention_layers": obj.dpa1_attention.serialize(),
            "env_mat": DPEnvMat(obj.rcut, obj.rcut_smth).serialize(),
            "type_embedding": self.type_embedding.embedding.serialize(),
            "exclude_types": obj.exclude_types,
            "env_protection": obj.env_protection,
            "@variables": {
                "davg": obj["davg"].detach().cpu().numpy(),
                "dstd": obj["dstd"].detach().cpu().numpy(),
            },
            "trainable": self.trainable,
            "spin": None,
        }
        if obj.tebd_input_mode == "strip":
            data["embeddings_strip"] = obj.filter_layers_strip.serialize()
        return data

    @classmethod
    def deserialize(cls, data: dict) -> "DescrptDPA1":
        """Rebuild a descriptor from the dict produced by `serialize`.

        Parameters
        ----------
        data : dict
            The serialized descriptor. It is shallow-copied first, so the
            caller's dict is not modified by the key removals below.

        Returns
        -------
        DescrptDPA1
            A new descriptor whose sub-networks and `davg` / `dstd`
            statistics are restored from `data`.
        """
        data = data.copy()
        check_version_compatibility(data.pop("@version"), 1, 1)
        data.pop("@class")
        data.pop("type")
        variables = data.pop("@variables")
        embeddings = data.pop("embeddings")
        type_embedding = data.pop("type_embedding")
        attention_layers = data.pop("attention_layers")
        # "env_mat" is not a constructor argument; drop it so that it does not
        # reach `cls(**data)`. Its content is not needed here.
        data.pop("env_mat")
        tebd_input_mode = data["tebd_input_mode"]
        strip_mode = tebd_input_mode == "strip"
        embeddings_strip = data.pop("embeddings_strip") if strip_mode else None
        # Everything left in `data` is passed as constructor keyword arguments.
        obj = cls(**data)

        def t_cvt(xx):
            # Tensor in the block's precision, on the configured device.
            return torch.tensor(xx, dtype=obj.se_atten.prec, device=env.DEVICE)

        obj.type_embedding.embedding = TypeEmbedNetConsistent.deserialize(
            type_embedding
        )
        obj.se_atten["davg"] = t_cvt(variables["davg"])
        obj.se_atten["dstd"] = t_cvt(variables["dstd"])
        obj.se_atten.filter_layers = NetworkCollection.deserialize(embeddings)
        if strip_mode:
            obj.se_atten.filter_layers_strip = NetworkCollection.deserialize(
                embeddings_strip
            )
        obj.se_atten.dpa1_attention = NeighborGatedAttention.deserialize(
            attention_layers
        )
        return obj

    def forward(
        self,
        extended_coord: torch.Tensor,
        extended_atype: torch.Tensor,
        nlist: torch.Tensor,
        mapping: Optional[torch.Tensor] = None,
        comm_dict: Optional[Dict[str, torch.Tensor]] = None,
    ):
        """Compute the descriptor.

        Parameters
        ----------
        extended_coord
            The extended coordinates of atoms. shape: nf x (nallx3)
        extended_atype
            The extended atom types. shape: nf x nall
        nlist
            The neighbor list. shape: nf x nloc x nnei
        mapping
            The index mapping, not required by this descriptor.
        comm_dict
            The data needed for communication for parallel inference.
            Not referenced by this descriptor.

        Returns
        -------
        descriptor
            The descriptor. shape: nf x nloc x (ng x axis_neuron).
            When `concat_output_tebd` is set, the type embedding of the local
            atoms is additionally concatenated along the last axis.
        gr
            The rotationally equivariant and permutationally invariant single particle
            representation. shape: nf x nloc x ng x 3
        g2
            The rotationally invariant pair-particle representation.
            shape: nf x nloc x nnei x ng
        h2
            The rotationally equivariant pair-particle representation.
            shape: nf x nloc x nnei x 3
        sw
            The smooth switch function. shape: nf x nloc x nnei

        """
        del mapping
        # Unpacking also requires `nlist` to be rank 3; only nloc is used.
        _, nloc, _ = nlist.shape
        g1_ext = self.type_embedding(extended_atype)
        # Type embedding restricted to the first nloc (local) atoms.
        g1_inp = g1_ext[:, :nloc, :]
        g1, g2, h2, rot_mat, sw = self.se_atten(
            nlist,
            extended_coord,
            extended_atype,
            g1_ext,
            mapping=None,
        )
        if self.concat_output_tebd:
            g1 = torch.cat([g1, g1_inp], dim=-1)

        return g1, rot_mat, g2, h2, sw

    @classmethod
    def update_sel(
        cls,
        train_data: DeepmdDataSystem,
        type_map: Optional[List[str]],
        local_jdata: dict,
    ) -> Tuple[dict, Optional[float]]:
        """Update the selection and perform neighbor statistics.

        Parameters
        ----------
        train_data : DeepmdDataSystem
            data used to do neighbor statistics
        type_map : list[str], optional
            The name of each type of atoms
        local_jdata : dict
            The local data refer to the current class

        Returns
        -------
        dict
            The updated local data (a shallow copy; the input dict is not modified)
        float
            The minimum distance between two atoms
        """
        local_jdata_cpy = local_jdata.copy()
        # NOTE(review): the trailing True presumably selects mixed-type
        # neighbor statistics -- confirm against UpdateSel.update_one_sel.
        min_nbor_dist, sel = UpdateSel().update_one_sel(
            train_data, type_map, local_jdata_cpy["rcut"], local_jdata_cpy["sel"], True
        )
        # Only the first entry of the returned selection is stored.
        local_jdata_cpy["sel"] = sel[0]
        return local_jdata_cpy, min_nbor_dist