forked from deepmodeling/deepmd-kit
-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathmake_base_descriptor.py
106 lines (86 loc) · 2.71 KB
/
make_base_descriptor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
# SPDX-License-Identifier: LGPL-3.0-or-later
from abc import (
ABC,
abstractclassmethod,
abstractmethod,
)
from typing import (
List,
Optional,
)
def make_base_descriptor(
    t_tensor,
    fwd_method_name: str = "forward",
):
    """Make the base class for the descriptor.

    A fresh abstract class is built on every call. Its abstract forward
    method is defined under the placeholder name ``fwd`` and then exposed
    under ``fwd_method_name``.

    Parameters
    ----------
    t_tensor
        The type of the tensor. Used in the type hint.
    fwd_method_name
        Name of the forward method. For dpmodels, it should be "call".
        For torch models, it should be "forward".

    Returns
    -------
    type
        The abstract base descriptor class, whose forward method is
        available (and required from subclasses) as ``fwd_method_name``.
    """

    class BD(ABC):
        """Base descriptor provides the interfaces of descriptor."""

        @abstractmethod
        def get_rcut(self) -> float:
            """Returns the cut-off radius."""

        @abstractmethod
        def get_sel(self) -> List[int]:
            """Returns the number of selected neighboring atoms for each type."""

        def get_nsel(self) -> int:
            """Returns the total number of selected neighboring atoms in the cut-off radius.

            Computed as the sum of the entries returned by ``get_sel()``.
            """
            return sum(self.get_sel())

        def get_nnei(self) -> int:
            """Returns the same value as ``get_nsel()``; provided as an alias."""
            return self.get_nsel()

        @abstractmethod
        def get_ntypes(self) -> int:
            """Returns the number of element types."""

        @abstractmethod
        def get_dim_out(self) -> int:
            """Returns the output descriptor dimension."""

        @abstractmethod
        def get_dim_emb(self) -> int:
            """Returns the embedding dimension of g2."""

        @abstractmethod
        def distinguish_types(self) -> bool:
            """Returns if the descriptor requires a neighbor list that distinguish different
            atomic types or not.
            """

        @abstractmethod
        def compute_input_stats(self, merged):
            """Update mean and stddev for descriptor elements."""

        @abstractmethod
        def init_desc_stat(self, stat_dict):
            """Initialize the model bias by the statistics."""

        @abstractmethod
        def fwd(
            self,
            extended_coord,
            extended_atype,
            nlist,
            mapping: Optional[t_tensor] = None,
        ):
            """Calculate descriptor.

            Defined under the placeholder name ``fwd``; the factory exposes
            it on the class as ``fwd_method_name``.
            """

        @abstractmethod
        def serialize(self) -> dict:
            """Serialize the obj to dict."""

        # Stacked decorators replace abc.abstractclassmethod, which has been
        # deprecated since Python 3.3.
        @classmethod
        @abstractmethod
        def deserialize(cls):
            """Deserialize from a dict.

            NOTE(review): this abstract signature declares no argument for
            the dict; concrete implementations presumably accept one --
            confirm against the subclasses.
            """

    # When the requested name is already "fwd" there is nothing to rename;
    # running setattr followed by delattr would delete the method entirely.
    if fwd_method_name != "fwd":
        setattr(BD, fwd_method_name, BD.fwd)
        delattr(BD, "fwd")
        # ABCMeta records abstract method names when the class is created,
        # so the registry still lists "fwd" and knows nothing about the new
        # name. Rewrite it so that subclasses are actually required to
        # implement the forward method under `fwd_method_name`.
        BD.__abstractmethods__ = frozenset(
            (BD.__abstractmethods__ - {"fwd"}) | {fwd_method_name}
        )

    return BD