forked from deepmodeling/deepmd-kit
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathse_t.py
32 lines (27 loc) · 889 Bytes
/
se_t.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
)
from deepmd.dpmodel.descriptor.se_t import DescrptSeT as DescrptSeTDP
from ..common import (
to_array_api_strict_array,
)
from ..utils.exclude_mask import (
PairExcludeMask,
)
from ..utils.network import (
NetworkCollection,
)
class DescrptSeT(DescrptSeTDP):
    """Variant of ``DescrptSeTDP`` that converts selected attributes on assignment.

    Every attribute assignment is routed through ``__setattr__``, which
    rewrites the incoming value for a handful of attribute names before
    handing it to the parent class.
    NOTE(review): presumably this keeps the descriptor's state in this
    package's own array/network/mask types -- confirm against the helpers.
    """

    def __setattr__(self, name: str, value: Any) -> None:
        """Store ``value`` under ``name``, converting it first for known names.

        Parameters
        ----------
        name : str
            Name of the attribute being assigned.
        value : Any
            Value being assigned. It is converted as follows:

            - ``"dstd"`` / ``"davg"``: passed through
              ``to_array_api_strict_array``.
            - ``"embeddings"``: when not ``None``, rebuilt with
              ``NetworkCollection.deserialize(value.serialize())``;
              ``None`` is stored as-is.
            - ``"emask"``: rebuilt as a ``PairExcludeMask`` from the value's
              ``ntypes`` and ``exclude_types``.
            - anything else (including ``"env_mat"``): stored unchanged.
        """
        if name == "dstd" or name == "davg":
            converted = to_array_api_strict_array(value)
        elif name == "embeddings" and value is not None:
            # Round-trip through serialization to rebuild the collection
            # with this package's NetworkCollection class.
            converted = NetworkCollection.deserialize(value.serialize())
        elif name == "emask":
            converted = PairExcludeMask(value.ntypes, value.exclude_types)
        else:
            # Covers env_mat (which doesn't store any value), a None
            # embeddings, and every attribute that needs no conversion.
            converted = value
        return super().__setattr__(name, converted)