diff --git a/deepmd/dpmodel/descriptor/__init__.py b/deepmd/dpmodel/descriptor/__init__.py index a19a2aa034..bbc332588c 100644 --- a/deepmd/dpmodel/descriptor/__init__.py +++ b/deepmd/dpmodel/descriptor/__init__.py @@ -1,4 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from .dpa1 import ( + DescrptDPA1, +) from .hybrid import ( DescrptHybrid, ) @@ -15,6 +18,7 @@ __all__ = [ "DescrptSeA", "DescrptSeR", + "DescrptDPA1", "DescrptHybrid", "make_base_descriptor", ] diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py new file mode 100644 index 0000000000..a551a57628 --- /dev/null +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -0,0 +1,947 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import numpy as np + +from deepmd.dpmodel.utils.network import ( + LayerNorm, + NativeLayer, +) +from deepmd.dpmodel.utils.type_embed import ( + TypeEmbedNet, +) +from deepmd.dpmodel.utils.update_sel import ( + UpdateSel, +) +from deepmd.env import ( + GLOBAL_NP_FLOAT_PRECISION, +) +from deepmd.utils.path import ( + DPPath, +) +from deepmd.utils.version import ( + check_version_compatibility, +) + +try: + from deepmd._version import version as __version__ +except ImportError: + __version__ = "unknown" + +from typing import ( + Any, + List, + Optional, + Tuple, + Union, +) + +from deepmd.dpmodel import ( + DEFAULT_PRECISION, + PRECISION_DICT, + NativeOP, +) +from deepmd.dpmodel.utils import ( + EmbeddingNet, + EnvMat, + NetworkCollection, + PairExcludeMask, +) + +from .base_descriptor import ( + BaseDescriptor, +) + + +def np_softmax(x, axis=-1): + e_x = np.exp(x - np.max(x, axis=axis, keepdims=True)) + return e_x / np.sum(e_x, axis=axis, keepdims=True) + + +def np_normalize(x, axis=-1): + return x / np.linalg.norm(x, axis=axis, keepdims=True) + + +@BaseDescriptor.register("se_atten") +@BaseDescriptor.register("dpa1") +class DescrptDPA1(NativeOP, BaseDescriptor): + r"""Attention-based descriptor which is proposed in the pretrainable DPA-1[1] model. 
+ + This descriptor, :math:`\mathcal{D}^i \in \mathbb{R}^{M \times M_{<}}`, is given by + + .. math:: + \mathcal{D}^i = \frac{1}{N_c^2}(\hat{\mathcal{G}}^i)^T \mathcal{R}^i (\mathcal{R}^i)^T \hat{\mathcal{G}}^i_<, + + where :math:`\hat{\mathcal{G}}^i` represents the embedding matrix :math:`\mathcal{G}^i` + after an additional self-attention mechanism and :math:`\mathcal{R}^i` is defined by the full case in the se_e2_a descriptor. + Note that we obtain :math:`\mathcal{G}^i` using the type embedding method by default in this descriptor. + + To perform the self-attention mechanism, the queries :math:`\mathcal{Q}^{i,l} \in \mathbb{R}^{N_c\times d_k}`, + keys :math:`\mathcal{K}^{i,l} \in \mathbb{R}^{N_c\times d_k}`, + and values :math:`\mathcal{V}^{i,l} \in \mathbb{R}^{N_c\times d_v}` are first obtained: + + .. math:: + \left(\mathcal{Q}^{i,l}\right)_{j}=Q_{l}\left(\left(\mathcal{G}^{i,l-1}\right)_{j}\right), + + .. math:: + \left(\mathcal{K}^{i,l}\right)_{j}=K_{l}\left(\left(\mathcal{G}^{i,l-1}\right)_{j}\right), + + .. math:: + \left(\mathcal{V}^{i,l}\right)_{j}=V_{l}\left(\left(\mathcal{G}^{i,l-1}\right)_{j}\right), + + where :math:`Q_{l}`, :math:`K_{l}`, :math:`V_{l}` represent three trainable linear transformations + that output the queries and keys of dimension :math:`d_k` and values of dimension :math:`d_v`, and :math:`l` + is the index of the attention layer. + The input embedding matrix to the attention layers, denoted by :math:`\mathcal{G}^{i,0}`, + is chosen as the two-body embedding matrix. + + Then the scaled dot-product attention method is adopted: + + .. math:: + A(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l}, \mathcal{V}^{i,l}, \mathcal{R}^{i,l})=\varphi\left(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l},\mathcal{R}^{i,l}\right)\mathcal{V}^{i,l}, + + where :math:`\varphi\left(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l},\mathcal{R}^{i,l}\right) \in \mathbb{R}^{N_c\times N_c}` denotes the attention weights. 
+ In the original attention method, + one typically has :math:`\varphi\left(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l}\right)=\mathrm{softmax}\left(\frac{\mathcal{Q}^{i,l} (\mathcal{K}^{i,l})^{T}}{\sqrt{d_{k}}}\right)`, + with :math:`\sqrt{d_{k}}` being the normalization temperature. + This is slightly modified to incorporate the angular information: + + .. math:: + \varphi\left(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l},\mathcal{R}^{i,l}\right) = \mathrm{softmax}\left(\frac{\mathcal{Q}^{i,l} (\mathcal{K}^{i,l})^{T}}{\sqrt{d_{k}}}\right) \odot \hat{\mathcal{R}}^{i}(\hat{\mathcal{R}}^{i})^{T}, + + where :math:`\hat{\mathcal{R}}^{i} \in \mathbb{R}^{N_c\times 3}` denotes normalized relative coordinates, + :math:`\hat{\mathcal{R}}^{i}_{j} = \frac{\boldsymbol{r}_{ij}}{\lVert \boldsymbol{r}_{ij} \rVert}` + and :math:`\odot` means element-wise multiplication. + + Then layer normalization is added in a residual way to finally obtain the self-attention local embedding matrix + :math:`\hat{\mathcal{G}}^{i} = \mathcal{G}^{i,L_a}` after :math:`L_a` attention layers:[^1] + + .. math:: + \mathcal{G}^{i,l} = \mathcal{G}^{i,l-1} + \mathrm{LayerNorm}(A(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l}, \mathcal{V}^{i,l}, \mathcal{R}^{i,l})). + + Parameters + ---------- + rcut: float + The cut-off radius :math:`r_c` + rcut_smth: float + From where the environment matrix should be smoothed :math:`r_s` + sel : list[int], int + list[int]: sel[i] specifies the maximum number of type i atoms in the cut-off radius + int: the total maximum number of atoms in the cut-off radius + ntypes : int + Number of element types + neuron : list[int] + Number of neurons in each hidden layer of the embedding net :math:`\mathcal{N}` + axis_neuron: int + Number of the axis neuron :math:`M_2` (number of columns of the sub-matrix of the embedding matrix) + tebd_dim: int + Dimension of the type embedding + tebd_input_mode: str + The way to mix the type embeddings. Supported options are `concat`. 
+ (TODO need to support stripped_type_embedding option) + resnet_dt: bool + Time-step `dt` in the resnet construction: + y = x + dt * \phi (Wx + b) + trainable: bool + If the weights of this descriptor are trainable. + trainable_ln: bool + Whether to use trainable shift and scale weights in layer normalization. + ln_eps: float, Optional + The epsilon value for layer normalization. + type_one_side: bool + If 'False', type embeddings of both neighbor and central atoms are considered. + If 'True', only type embeddings of neighbor atoms are considered. + Default is 'False'. + attn: int + Hidden dimension of the attention vectors + attn_layer: int + Number of attention layers + attn_dotr: bool + Whether to dot the angular gate to the attention weights + attn_mask: bool + (Only support False to keep consistent with other backend references.) + (Not used in this version. True option is not implemented.) + Whether to mask the diagonal of attention weights + exclude_types : List[List[int]] + The excluded pairs of types which have no interaction with each other. + For example, `[[0, 1]]` means no interaction between type 0 and type 1. + env_protection: float + Protection parameter to prevent division by zero errors during environment matrix calculations. + set_davg_zero: bool + Set the shift of embedding net input to zero. + activation_function: str + The activation function in the embedding net. Supported options are |ACTIVATION_FN| + precision: str + The precision of the embedding net parameters. Supported options are |PRECISION| + scaling_factor: float + The scaling factor of normalization in calculations of attention weights. + If `temperature` is None, the scaling of attention weights is (N_dim * scaling_factor)**0.5 + normalize: bool + Whether to normalize the hidden vectors in attention weights calculation. + temperature: float + If not None, the scaling of attention weights is `temperature` itself. 
+ smooth_type_embedding: bool + Whether to use smooth process in attention weights calculation. + concat_output_tebd: bool + Whether to concat type embedding at the output of the descriptor. + spin + (Only support None to keep consistent with other backend references.) + (Not used in this version. Not-none option is not implemented.) + The old implementation of deepspin. + + Limitations + ----------- + The currently implementation does not support the following features + 1. tebd_input_mode != 'concat' + + The currently implementation will not support the following deprecated features + 1. spin is not None + 2. attn_mask == True + + References + ---------- + .. [1] Duo Zhang, Hangrui Bi, Fu-Zhi Dai, Wanrun Jiang, Linfeng Zhang, and Han Wang. 2022. + DPA-1: Pretraining of Attention-based Deep Potential Model for Molecular Simulation. + arXiv preprint arXiv:2208.08236. + """ + + def __init__( + self, + rcut: float, + rcut_smth: float, + sel: Union[List[int], int], + ntypes: int, + neuron: List[int] = [25, 50, 100], + axis_neuron: int = 8, + tebd_dim: int = 8, + tebd_input_mode: str = "concat", + resnet_dt: bool = False, + trainable: bool = True, + type_one_side: bool = False, + attn: int = 128, + attn_layer: int = 2, + attn_dotr: bool = True, + attn_mask: bool = False, + exclude_types: List[List[int]] = [], + env_protection: float = 0.0, + set_davg_zero: bool = False, + activation_function: str = "tanh", + precision: str = DEFAULT_PRECISION, + scaling_factor=1.0, + normalize: bool = True, + temperature: Optional[float] = None, + trainable_ln: bool = True, + ln_eps: Optional[float] = 1e-5, + smooth_type_embedding: bool = True, + concat_output_tebd: bool = True, + spin: Optional[Any] = None, + # consistent with argcheck, not used though + seed: Optional[int] = None, + ) -> None: + ## seed, uniform_seed, multi_task, not included. 
+ if spin is not None: + raise NotImplementedError("old implementation of spin is not supported.") + if attn_mask: + raise NotImplementedError( + "old implementation of attn_mask is not supported." + ) + # TODO + if tebd_input_mode != "concat": + raise NotImplementedError("tebd_input_mode != 'concat' not implemented") + # to keep consistent with default value in this backends + if ln_eps is None: + ln_eps = 1e-5 + + self.rcut = rcut + self.rcut_smth = rcut_smth + if isinstance(sel, int): + sel = [sel] + self.sel = sel + self.nnei = sum(sel) + self.ntypes = ntypes + self.neuron = neuron + self.filter_neuron = self.neuron + self.axis_neuron = axis_neuron + self.tebd_dim = tebd_dim + self.tebd_input_mode = tebd_input_mode + self.resnet_dt = resnet_dt + self.trainable = trainable + self.trainable_ln = trainable_ln + self.ln_eps = ln_eps + self.type_one_side = type_one_side + self.attn = attn + self.attn_layer = attn_layer + self.attn_dotr = attn_dotr + self.exclude_types = exclude_types + self.env_protection = env_protection + self.set_davg_zero = set_davg_zero + self.activation_function = activation_function + self.precision = precision + self.scaling_factor = scaling_factor + self.normalize = normalize + self.temperature = temperature + self.smooth = smooth_type_embedding + self.concat_output_tebd = concat_output_tebd + # order matters, placed after the assignment of self.ntypes + self.reinit_exclude(exclude_types) + + self.type_embedding = TypeEmbedNet( + ntypes=self.ntypes, + neuron=[self.tebd_dim], + padding=True, + activation_function="Linear", + precision=precision, + ) + if self.tebd_input_mode in ["concat"]: + if not self.type_one_side: + in_dim = 1 + self.tebd_dim * 2 + else: + in_dim = 1 + self.tebd_dim + else: + in_dim = 1 + self.embeddings = NetworkCollection( + ndim=0, + ntypes=self.ntypes, + network_type="embedding_network", + ) + self.embeddings[0] = EmbeddingNet( + in_dim, + self.neuron, + self.activation_function, + self.resnet_dt, + self.precision, + 
) + self.dpa1_attention = NeighborGatedAttention( + self.attn_layer, + self.nnei, + self.filter_neuron[-1], + self.attn, + dotr=self.attn_dotr, + scaling_factor=self.scaling_factor, + normalize=self.normalize, + temperature=self.temperature, + trainable_ln=self.trainable_ln, + ln_eps=self.ln_eps, + smooth=self.smooth, + precision=self.precision, + ) + + wanted_shape = (self.ntypes, self.nnei, 4) + self.env_mat = EnvMat(self.rcut, self.rcut_smth, protection=self.env_protection) + self.davg = np.zeros(wanted_shape, dtype=PRECISION_DICT[self.precision]) + self.dstd = np.ones(wanted_shape, dtype=PRECISION_DICT[self.precision]) + self.orig_sel = self.sel + + def __setitem__(self, key, value): + if key in ("avg", "data_avg", "davg"): + self.davg = value + elif key in ("std", "data_std", "dstd"): + self.dstd = value + else: + raise KeyError(key) + + def __getitem__(self, key): + if key in ("avg", "data_avg", "davg"): + return self.davg + elif key in ("std", "data_std", "dstd"): + return self.dstd + else: + raise KeyError(key) + + @property + def dim_out(self): + """Returns the output dimension of this descriptor.""" + return self.get_dim_out() + + def get_dim_out(self): + """Returns the output dimension of this descriptor.""" + return ( + self.neuron[-1] * self.axis_neuron + self.tebd_dim + if self.concat_output_tebd + else self.neuron[-1] * self.axis_neuron + ) + + def get_dim_emb(self): + """Returns the embedding (g2) dimension of this descriptor.""" + return self.neuron[-1] + + def get_rcut(self): + """Returns cutoff radius.""" + return self.rcut + + def get_sel(self): + """Returns cutoff radius.""" + return self.sel + + def mixed_types(self): + """If true, the discriptor + 1. assumes total number of atoms aligned across frames; + 2. requires a neighbor list that does not distinguish different atomic types. + + If false, the discriptor + 1. assumes total number of atoms of each atom type aligned across frames; + 2. 
requires a neighbor list that distinguishes different atomic types. + + """ + return True + + def share_params(self, base_class, shared_level, resume=False): + """ + Share the parameters of self to the base_class with shared_level during multitask training. + If not start from checkpoint (resume is False), + some seperated parameters (e.g. mean and stddev) will be re-calculated across different classes. + """ + raise NotImplementedError + + def get_ntypes(self) -> int: + """Returns the number of element types.""" + return self.ntypes + + def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None): + """Update mean and stddev for descriptor elements.""" + raise NotImplementedError + + def cal_g( + self, + ss, + embedding_idx, + ): + nfnl, nnei = ss.shape[0:2] + ss = ss.reshape(nfnl, nnei, -1) + # nfnl x nnei x ng + gg = self.embeddings[embedding_idx].call(ss) + return gg + + def reinit_exclude( + self, + exclude_types: List[Tuple[int, int]] = [], + ): + self.exclude_types = exclude_types + self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types) + + def call( + self, + coord_ext, + atype_ext, + nlist, + mapping: Optional[np.ndarray] = None, + ): + """Compute the descriptor. + + Parameters + ---------- + coord_ext + The extended coordinates of atoms. shape: nf x (nallx3) + atype_ext + The extended aotm types. shape: nf x nall + nlist + The neighbor list. shape: nf x nloc x nnei + mapping + The index mapping from extended to lcoal region. not used by this descriptor. + + Returns + ------- + descriptor + The descriptor. shape: nf x nloc x (ng x axis_neuron) + gr + The rotationally equivariant and permutationally invariant single particle + representation. shape: nf x nloc x ng x 3 + g2 + The rotationally invariant pair-partical representation. + this descriptor returns None + h2 + The rotationally equivariant pair-partical representation. + this descriptor returns None + sw + The smooth switch function. 
+ """ + del mapping + # nf x nloc x nnei x 4 + dmatrix, sw = self.env_mat.call( + coord_ext, atype_ext, nlist, self.davg, self.dstd + ) + nf, nloc, nnei, _ = dmatrix.shape + exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext) + # nfnl x nnei + nlist = nlist.reshape(nf * nloc, nnei) + # nfnl x nnei x 4 + dmatrix = dmatrix.reshape(nf * nloc, nnei, 4) + # nfnl x nnei x 1 + sw = sw.reshape(nf * nloc, nnei, 1) + + # add type embedding into input + # nf x nall x tebd_dim + atype_embd_ext = self.type_embedding.call()[atype_ext] + # nfnl x tebd_dim + atype_embd = atype_embd_ext[:, :nloc, :].reshape(nf * nloc, -1) + # nfnl x nnei x tebd_dim + atype_embd_nnei = np.tile(atype_embd[:, np.newaxis, :], (1, nnei, 1)) + # nfnl x nnei + nlist_mask = nlist != -1 + # nfnl x nnei x 1 + sw = np.where(nlist_mask[:, :, None], sw, 0.0) + nlist_masked = np.where(nlist_mask, nlist, 0) + index = np.tile(nlist_masked.reshape(nf, -1, 1), (1, 1, self.tebd_dim)) + # nfnl x nnei x tebd_dim + atype_embd_nlist = np.take_along_axis(atype_embd_ext, index, axis=1).reshape( + nf * nloc, nnei, self.tebd_dim + ) + + ng = self.neuron[-1] + # nfnl x nnei + exclude_mask = exclude_mask.reshape(nf * nloc, nnei) + # nfnl x nnei x 4 + rr = dmatrix.reshape(nf * nloc, nnei, 4) + rr = rr * exclude_mask[:, :, None] + # nfnl x nnei x 1 + ss = rr[..., 0:1] + if self.tebd_input_mode in ["concat"]: + if not self.type_one_side: + # nfnl x nnei x (1 + 2 * tebd_dim) + ss = np.concatenate([ss, atype_embd_nlist, atype_embd_nnei], axis=-1) + else: + # nfnl x nnei x (1 + tebd_dim) + ss = np.concatenate([ss, atype_embd_nlist], axis=-1) + else: + raise NotImplementedError + + # calculate gg + gg = self.cal_g(ss, 0) + input_r = dmatrix.reshape(-1, nnei, 4)[:, :, 1:4] / ( + np.linalg.norm( + dmatrix.reshape(-1, nnei, 4)[:, :, 1:4], axis=-1, keepdims=True + ) + + 1e-12 + ) + gg = self.dpa1_attention( + gg, nlist_mask, input_r=input_r, sw=sw + ) # shape is [nframes*nloc, self.neei, out_size] + # nfnl x ng x 4 + gr = 
np.einsum("lni,lnj->lij", gg, rr) + gr /= self.nnei + gr1 = gr[:, : self.axis_neuron, :] + # nfnl x ng x ng1 + grrg = np.einsum("lid,ljd->lij", gr, gr1) + # nf x nloc x (ng x ng1) + grrg = grrg.reshape(nf, nloc, ng * self.axis_neuron).astype( + GLOBAL_NP_FLOAT_PRECISION + ) + # nf x nloc x (ng x ng1 + tebd_dim) + if self.concat_output_tebd: + grrg = np.concatenate([grrg, atype_embd.reshape(nf, nloc, -1)], axis=-1) + return grrg, gr[..., 1:], None, None, sw + + def serialize(self) -> dict: + """Serialize the descriptor to dict.""" + return { + "@class": "Descriptor", + "type": "dpa1", + "@version": 1, + "rcut": self.rcut, + "rcut_smth": self.rcut_smth, + "sel": self.sel, + "ntypes": self.ntypes, + "neuron": self.neuron, + "axis_neuron": self.axis_neuron, + "tebd_dim": self.tebd_dim, + "tebd_input_mode": self.tebd_input_mode, + "set_davg_zero": self.set_davg_zero, + "attn": self.attn, + "attn_layer": self.attn_layer, + "attn_dotr": self.attn_dotr, + "attn_mask": False, + "activation_function": self.activation_function, + "resnet_dt": self.resnet_dt, + "scaling_factor": self.scaling_factor, + "normalize": self.normalize, + "temperature": self.temperature, + "trainable_ln": self.trainable_ln, + "ln_eps": self.ln_eps, + "smooth_type_embedding": self.smooth, + "type_one_side": self.type_one_side, + "concat_output_tebd": self.concat_output_tebd, + # make deterministic + "precision": np.dtype(PRECISION_DICT[self.precision]).name, + "embeddings": self.embeddings.serialize(), + "attention_layers": self.dpa1_attention.serialize(), + "env_mat": self.env_mat.serialize(), + "type_embedding": self.type_embedding.serialize(), + "exclude_types": self.exclude_types, + "env_protection": self.env_protection, + "@variables": { + "davg": self.davg, + "dstd": self.dstd, + }, + ## to be updated when the options are supported. 
+ "trainable": True, + "spin": None, + } + + @classmethod + def deserialize(cls, data: dict) -> "DescrptDPA1": + """Deserialize from dict.""" + data = data.copy() + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + data.pop("type") + variables = data.pop("@variables") + embeddings = data.pop("embeddings") + type_embedding = data.pop("type_embedding") + attention_layers = data.pop("attention_layers") + env_mat = data.pop("env_mat") + obj = cls(**data) + + obj["davg"] = variables["davg"] + obj["dstd"] = variables["dstd"] + obj.embeddings = NetworkCollection.deserialize(embeddings) + obj.type_embedding = TypeEmbedNet.deserialize(type_embedding) + obj.dpa1_attention = NeighborGatedAttention.deserialize(attention_layers) + return obj + + @classmethod + def update_sel(cls, global_jdata: dict, local_jdata: dict): + """Update the selection and perform neighbor statistics. + + Parameters + ---------- + global_jdata : dict + The global data, containing the training section + local_jdata : dict + The local data refer to the current class + """ + local_jdata_cpy = local_jdata.copy() + return UpdateSel().update_one_sel(global_jdata, local_jdata_cpy, True) + + +class NeighborGatedAttention(NativeOP): + def __init__( + self, + layer_num: int, + nnei: int, + embed_dim: int, + hidden_dim: int, + dotr: bool = False, + do_mask: bool = False, + scaling_factor: float = 1.0, + normalize: bool = True, + temperature: Optional[float] = None, + trainable_ln: bool = True, + ln_eps: float = 1e-5, + smooth: bool = True, + precision: str = DEFAULT_PRECISION, + ): + """Construct a neighbor-wise attention net.""" + super().__init__() + self.layer_num = layer_num + self.nnei = nnei + self.embed_dim = embed_dim + self.hidden_dim = hidden_dim + self.dotr = dotr + self.do_mask = do_mask + self.scaling_factor = scaling_factor + self.normalize = normalize + self.temperature = temperature + self.trainable_ln = trainable_ln + self.ln_eps = ln_eps + self.smooth = smooth + 
self.precision = precision + self.network_type = NeighborGatedAttentionLayer + + self.attention_layers = [ + NeighborGatedAttentionLayer( + nnei, + embed_dim, + hidden_dim, + dotr=dotr, + do_mask=do_mask, + scaling_factor=scaling_factor, + normalize=normalize, + temperature=temperature, + trainable_ln=trainable_ln, + ln_eps=ln_eps, + smooth=smooth, + precision=precision, + ) + for _ in range(layer_num) + ] + + def call( + self, + input_G, + nei_mask, + input_r: Optional[np.ndarray] = None, + sw: Optional[np.ndarray] = None, + ): + out = input_G + for layer in self.attention_layers: + out = layer(out, nei_mask, input_r=input_r, sw=sw) + return out + + def __getitem__(self, key): + if isinstance(key, int): + return self.attention_layers[key] + else: + raise TypeError(key) + + def __setitem__(self, key, value): + if not isinstance(key, int): + raise TypeError(key) + if isinstance(value, self.network_type): + pass + elif isinstance(value, dict): + value = self.network_type.deserialize(value) + else: + raise TypeError(value) + self.attention_layers[key] = value + + def serialize(self): + """Serialize the networks to a dict. + + Returns + ------- + dict + The serialized networks. + """ + return { + "@class": "NeighborGatedAttention", + "@version": 1, + "layer_num": self.layer_num, + "nnei": self.nnei, + "embed_dim": self.embed_dim, + "hidden_dim": self.hidden_dim, + "dotr": self.dotr, + "do_mask": self.do_mask, + "scaling_factor": self.scaling_factor, + "normalize": self.normalize, + "temperature": self.temperature, + "trainable_ln": self.trainable_ln, + "ln_eps": self.ln_eps, + "precision": self.precision, + "attention_layers": [layer.serialize() for layer in self.attention_layers], + } + + @classmethod + def deserialize(cls, data: dict) -> "NeighborGatedAttention": + """Deserialize the networks from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. 
+ """ + data = data.copy() + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + attention_layers = data.pop("attention_layers") + obj = cls(**data) + obj.attention_layers = [ + NeighborGatedAttentionLayer.deserialize(layer) for layer in attention_layers + ] + return obj + + +class NeighborGatedAttentionLayer(NativeOP): + def __init__( + self, + nnei: int, + embed_dim: int, + hidden_dim: int, + dotr: bool = False, + do_mask: bool = False, + scaling_factor: float = 1.0, + normalize: bool = True, + temperature: Optional[float] = None, + trainable_ln: bool = True, + ln_eps: float = 1e-5, + smooth: bool = True, + precision: str = DEFAULT_PRECISION, + ): + """Construct a neighbor-wise attention layer.""" + super().__init__() + self.nnei = nnei + self.embed_dim = embed_dim + self.hidden_dim = hidden_dim + self.dotr = dotr + self.do_mask = do_mask + self.scaling_factor = scaling_factor + self.normalize = normalize + self.temperature = temperature + self.trainable_ln = trainable_ln + self.ln_eps = ln_eps + self.precision = precision + self.attention_layer = GatedAttentionLayer( + nnei, + embed_dim, + hidden_dim, + dotr=dotr, + do_mask=do_mask, + scaling_factor=scaling_factor, + normalize=normalize, + temperature=temperature, + smooth=smooth, + precision=precision, + ) + self.attn_layer_norm = LayerNorm( + self.embed_dim, eps=ln_eps, trainable=self.trainable_ln, precision=precision + ) + + def call( + self, + x, + nei_mask, + input_r: Optional[np.ndarray] = None, + sw: Optional[np.ndarray] = None, + ): + residual = x + x = self.attention_layer(x, nei_mask, input_r=input_r, sw=sw) + x = residual + x + x = self.attn_layer_norm(x) + return x + + def serialize(self) -> dict: + """Serialize the networks to a dict. + + Returns + ------- + dict + The serialized networks. 
+ """ + return { + "nnei": self.nnei, + "embed_dim": self.embed_dim, + "hidden_dim": self.hidden_dim, + "dotr": self.dotr, + "do_mask": self.do_mask, + "scaling_factor": self.scaling_factor, + "normalize": self.normalize, + "temperature": self.temperature, + "trainable_ln": self.trainable_ln, + "ln_eps": self.ln_eps, + "precision": self.precision, + "attention_layer": self.attention_layer.serialize(), + "attn_layer_norm": self.attn_layer_norm.serialize(), + } + + @classmethod + def deserialize(cls, data) -> "NeighborGatedAttentionLayer": + """Deserialize the networks from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. + """ + data = data.copy() + attention_layer = data.pop("attention_layer") + attn_layer_norm = data.pop("attn_layer_norm") + obj = cls(**data) + obj.attention_layer = GatedAttentionLayer.deserialize(attention_layer) + obj.attn_layer_norm = LayerNorm.deserialize(attn_layer_norm) + return obj + + +class GatedAttentionLayer(NativeOP): + def __init__( + self, + nnei: int, + embed_dim: int, + hidden_dim: int, + dotr: bool = False, + do_mask: bool = False, + scaling_factor: float = 1.0, + normalize: bool = True, + temperature: Optional[float] = None, + bias: bool = True, + smooth: bool = True, + precision: str = DEFAULT_PRECISION, + ): + """Construct a neighbor-wise attention net.""" + super().__init__() + self.nnei = nnei + self.embed_dim = embed_dim + self.hidden_dim = hidden_dim + self.dotr = dotr + self.do_mask = do_mask + self.bias = bias + self.smooth = smooth + self.scaling_factor = scaling_factor + self.temperature = temperature + self.precision = precision + if temperature is None: + self.scaling = (self.hidden_dim * scaling_factor) ** -0.5 + else: + self.scaling = temperature + self.normalize = normalize + self.in_proj = NativeLayer( + embed_dim, + hidden_dim * 3, + bias=bias, + use_timestep=False, + precision=precision, + ) + self.out_proj = NativeLayer( + hidden_dim, + embed_dim, + bias=bias, + 
use_timestep=False, + precision=precision, + ) + + def call(self, query, nei_mask, input_r=None, sw=None, attnw_shift=20.0): + # Linear projection + q, k, v = np.split(self.in_proj(query), 3, axis=-1) + # Reshape and normalize + q = q.reshape(-1, self.nnei, self.hidden_dim) + k = k.reshape(-1, self.nnei, self.hidden_dim) + v = v.reshape(-1, self.nnei, self.hidden_dim) + if self.normalize: + q = np_normalize(q, axis=-1) + k = np_normalize(k, axis=-1) + v = np_normalize(v, axis=-1) + q = q * self.scaling + # Attention weights + attn_weights = q @ k.transpose(0, 2, 1) + nei_mask = nei_mask.reshape(-1, self.nnei) + if self.smooth: + sw = sw.reshape(-1, self.nnei) + attn_weights = (attn_weights + attnw_shift) * sw[:, None, :] * sw[ + :, :, None + ] - attnw_shift + else: + attn_weights = np.where(nei_mask[:, None, :], attn_weights, -np.inf) + attn_weights = np_softmax(attn_weights, axis=-1) + attn_weights = np.where(nei_mask[:, :, None], attn_weights, 0.0) + if self.smooth: + attn_weights = attn_weights * sw[:, None, :] * sw[:, :, None] + if self.dotr: + angular_weight = input_r @ input_r.transpose(0, 2, 1) + attn_weights = attn_weights * angular_weight + # Output projection + o = attn_weights @ v + output = self.out_proj(o) + return output + + def serialize(self): + return { + "nnei": self.nnei, + "embed_dim": self.embed_dim, + "hidden_dim": self.hidden_dim, + "dotr": self.dotr, + "do_mask": self.do_mask, + "scaling_factor": self.scaling_factor, + "normalize": self.normalize, + "temperature": self.temperature, + "bias": self.bias, + "smooth": self.smooth, + "precision": self.precision, + "in_proj": self.in_proj.serialize(), + "out_proj": self.out_proj.serialize(), + } + + @classmethod + def deserialize(cls, data): + data = data.copy() + in_proj = data.pop("in_proj") + out_proj = data.pop("out_proj") + obj = cls(**data) + obj.in_proj = NativeLayer.deserialize(in_proj) + obj.out_proj = NativeLayer.deserialize(out_proj) + return obj diff --git 
a/deepmd/dpmodel/utils/network.py b/deepmd/dpmodel/utils/network.py index 1cc8fda347..319f8a0dbd 100644 --- a/deepmd/dpmodel/utils/network.py +++ b/deepmd/dpmodel/utils/network.py @@ -274,6 +274,161 @@ def fn(x): raise NotImplementedError(activation_function) +class LayerNorm(NativeLayer): + """Implementation of Layer Normalization layer. + + Parameters + ---------- + num_in : int + The input dimension of the layer. + eps : float, optional + A small value added to prevent division by zero in calculations. + uni_init : bool, optional + If initialize the weights to be zeros and ones. + """ + + def __init__( + self, + num_in: int, + eps: float = 1e-5, + uni_init: bool = True, + trainable: bool = True, + precision: str = DEFAULT_PRECISION, + ) -> None: + self.eps = eps + self.uni_init = uni_init + self.num_in = num_in + super().__init__( + num_in=1, + num_out=num_in, + bias=True, + use_timestep=False, + activation_function=None, + resnet=False, + precision=precision, + ) + self.w = self.w.squeeze(0) # keep the weight shape to be [num_in] + if self.uni_init: + self.w = np.ones_like(self.w) + self.b = np.zeros_like(self.b) + # only to keep consistent with other backends + self.trainable = trainable + + def serialize(self) -> dict: + """Serialize the layer to a dict. + + Returns + ------- + dict + The serialized layer. + """ + data = { + "w": self.w, + "b": self.b, + } + return { + "@class": "LayerNorm", + "@version": 1, + "eps": self.eps, + "trainable": self.trainable, + "precision": self.precision, + "@variables": data, + } + + @classmethod + def deserialize(cls, data: dict) -> "LayerNorm": + """Deserialize the layer from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. 
+ """ + data = copy.deepcopy(data) + check_version_compatibility(data.pop("@version", 1), 1, 1) + data.pop("@class", None) + variables = data.pop("@variables") + if variables["w"] is not None: + assert len(variables["w"].shape) == 1 + if variables["b"] is not None: + assert len(variables["b"].shape) == 1 + (num_in,) = variables["w"].shape + obj = cls( + num_in, + **data, + ) + (obj.w,) = (variables["w"],) + (obj.b,) = (variables["b"],) + obj._check_shape_consistency() + return obj + + def _check_shape_consistency(self): + if self.b is not None and self.w.shape[0] != self.b.shape[0]: + raise ValueError( + f"dim 1 of w {self.w.shape[0]} is not equal to shape " + f"of b {self.b.shape[0]}", + ) + + def __setitem__(self, key, value): + if key in ("w", "matrix"): + self.w = value + elif key in ("b", "bias"): + self.b = value + elif key == "trainable": + self.trainable = value + elif key == "precision": + self.precision = value + elif key == "eps": + self.eps = value + else: + raise KeyError(key) + + def __getitem__(self, key): + if key in ("w", "matrix"): + return self.w + elif key in ("b", "bias"): + return self.b + elif key == "trainable": + return self.trainable + elif key == "precision": + return self.precision + elif key == "eps": + return self.eps + else: + raise KeyError(key) + + def dim_out(self) -> int: + return self.w.shape[0] + + def call(self, x: np.ndarray) -> np.ndarray: + """Forward pass. + + Parameters + ---------- + x : np.ndarray + The input. + + Returns + ------- + np.ndarray + The output. 
+ """ + y = self.layer_norm_numpy(x, (self.num_in,), self.w, self.b, self.eps) + return y + + @staticmethod + def layer_norm_numpy(x, shape, weight=None, bias=None, eps=1e-5): + # mean and variance + mean = np.mean(x, axis=tuple(range(-len(shape), 0)), keepdims=True) + var = np.var(x, axis=tuple(range(-len(shape), 0)), keepdims=True) + # normalize + x_normalized = (x - mean) / np.sqrt(var + eps) + # shift and scale + if weight is not None and bias is not None: + x_normalized = x_normalized * weight + bias + return x_normalized + + def make_multilayer_network(T_NetworkLayer, ModuleBase): class NN(ModuleBase): """Native representation of a neural network. diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py index d5e78296a5..852e08403c 100644 --- a/deepmd/pt/model/descriptor/dpa1.py +++ b/deepmd/pt/model/descriptor/dpa1.py @@ -9,8 +9,19 @@ import torch +from deepmd.dpmodel.utils import EnvMat as DPEnvMat +from deepmd.pt.model.network.mlp import ( + NetworkCollection, +) from deepmd.pt.model.network.network import ( TypeEmbedNet, + TypeEmbedNetConsistent, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + RESERVED_PRECISON_DICT, ) from deepmd.pt.utils.update_sel import ( UpdateSel, @@ -18,23 +29,167 @@ from deepmd.utils.path import ( DPPath, ) +from deepmd.utils.version import ( + check_version_compatibility, +) from .base_descriptor import ( BaseDescriptor, ) from .se_atten import ( DescrptBlockSeAtten, + NeighborGatedAttention, ) @BaseDescriptor.register("dpa1") @BaseDescriptor.register("se_atten") class DescrptDPA1(BaseDescriptor, torch.nn.Module): + r"""Attention-based descriptor which is proposed in the pretrainable DPA-1[1] model. + + This descriptor, :math:`\mathcal{D}^i \in \mathbb{R}^{M \times M_{<}}`, is given by + + .. 
math:: + \mathcal{D}^i = \frac{1}{N_c^2}(\hat{\mathcal{G}}^i)^T \mathcal{R}^i (\mathcal{R}^i)^T \hat{\mathcal{G}}^i_<, + + where :math:`\hat{\mathcal{G}}^i` represents the embedding matrix :math:`\mathcal{G}^i` + after additional self-attention mechanism and :math:`\mathcal{R}^i` is defined by the full case in the se_e2_a descriptor. + Note that we obtain :math:`\mathcal{G}^i` using the type embedding method by default in this descriptor. + + To perform the self-attention mechanism, the queries :math:`\mathcal{Q}^{i,l} \in \mathbb{R}^{N_c\times d_k}`, + keys :math:`\mathcal{K}^{i,l} \in \mathbb{R}^{N_c\times d_k}`, + and values :math:`\mathcal{V}^{i,l} \in \mathbb{R}^{N_c\times d_v}` are first obtained: + + .. math:: + \left(\mathcal{Q}^{i,l}\right)_{j}=Q_{l}\left(\left(\mathcal{G}^{i,l-1}\right)_{j}\right), + + .. math:: + \left(\mathcal{K}^{i,l}\right)_{j}=K_{l}\left(\left(\mathcal{G}^{i,l-1}\right)_{j}\right), + + .. math:: + \left(\mathcal{V}^{i,l}\right)_{j}=V_{l}\left(\left(\mathcal{G}^{i,l-1}\right)_{j}\right), + + where :math:`Q_{l}`, :math:`K_{l}`, :math:`V_{l}` represent three trainable linear transformations + that output the queries and keys of dimension :math:`d_k` and values of dimension :math:`d_v`, and :math:`l` + is the index of the attention layer. + The input embedding matrix to the attention layers, denoted by :math:`\mathcal{G}^{i,0}`, + is chosen as the two-body embedding matrix. + + Then the scaled dot-product attention method is adopted: + + .. math:: + A(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l}, \mathcal{V}^{i,l}, \mathcal{R}^{i,l})=\varphi\left(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l},\mathcal{R}^{i,l}\right)\mathcal{V}^{i,l}, + + where :math:`\varphi\left(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l},\mathcal{R}^{i,l}\right) \in \mathbb{R}^{N_c\times N_c}` is attention weights. 
+ In the original attention method, + one typically has :math:`\varphi\left(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l}\right)=\mathrm{softmax}\left(\frac{\mathcal{Q}^{i,l} (\mathcal{K}^{i,l})^{T}}{\sqrt{d_{k}}}\right)`, + with :math:`\sqrt{d_{k}}` being the normalization temperature. + This is slightly modified to incorporate the angular information: + + .. math:: + \varphi\left(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l},\mathcal{R}^{i,l}\right) = \mathrm{softmax}\left(\frac{\mathcal{Q}^{i,l} (\mathcal{K}^{i,l})^{T}}{\sqrt{d_{k}}}\right) \odot \hat{\mathcal{R}}^{i}(\hat{\mathcal{R}}^{i})^{T}, + + where :math:`\hat{\mathcal{R}}^{i} \in \mathbb{R}^{N_c\times 3}` denotes normalized relative coordinates, + :math:`\hat{\mathcal{R}}^{i}_{j} = \frac{\boldsymbol{r}_{ij}}{\lVert \boldsymbol{r}_{ij} \rVert}` + and :math:`\odot` means element-wise multiplication. + + Then layer normalization is added in a residual way to finally obtain the self-attention local embedding matrix + :math:`\hat{\mathcal{G}}^{i} = \mathcal{G}^{i,L_a}` after :math:`L_a` attention layers:[^1] + + .. math:: + \mathcal{G}^{i,l} = \mathcal{G}^{i,l-1} + \mathrm{LayerNorm}(A(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l}, \mathcal{V}^{i,l}, \mathcal{R}^{i,l})). + + Parameters + ---------- + rcut: float + The cut-off radius :math:`r_c` + rcut_smth: float + From where the environment matrix should be smoothed :math:`r_s` + sel : list[int], int + list[int]: sel[i] specifies the maximum number of type i atoms in the cut-off radius + int: the total maximum number of atoms in the cut-off radius + ntypes : int + Number of element types + neuron : list[int] + Number of neurons in each hidden layer of the embedding net :math:`\mathcal{N}` + axis_neuron: int + Number of the axis neuron :math:`M_2` (number of columns of the sub-matrix of the embedding matrix) + tebd_dim: int + Dimension of the type embedding + tebd_input_mode: str + The way to mix the type embeddings. Supported options are `concat`. 
+ (TODO need to support stripped_type_embedding option) + resnet_dt: bool + Time-step `dt` in the resnet construction: + y = x + dt * \phi (Wx + b) + trainable: bool + If the weights of this descriptor are trainable. + trainable_ln: bool + Whether to use trainable shift and scale weights in layer normalization. + ln_eps: float, Optional + The epsilon value for layer normalization. + type_one_side: bool + If 'False', type embeddings of both neighbor and central atoms are considered. + If 'True', only type embeddings of neighbor atoms are considered. + Default is 'False'. + attn: int + Hidden dimension of the attention vectors + attn_layer: int + Number of attention layers + attn_dotr: bool + Whether to dot the angular gate with the attention weights + attn_mask: bool + (Only support False to keep consistent with other backend references.) + (Not used in this version. True option is not implemented.) + Whether to mask the diagonal of attention weights + exclude_types : List[List[int]] + The excluded pairs of types which have no interaction with each other. + For example, `[[0, 1]]` means no interaction between type 0 and type 1. + env_protection: float + Protection parameter to prevent division by zero errors during environment matrix calculations. + set_davg_zero: bool + Set the shift of embedding net input to zero. + activation_function: str + The activation function in the embedding net. Supported options are |ACTIVATION_FN| + precision: str + The precision of the embedding net parameters. Supported options are |PRECISION| + scaling_factor: float + The scaling factor of normalization in calculations of attention weights. + If `temperature` is None, the scaling of attention weights is (N_dim * scaling_factor)**0.5 + normalize: bool + Whether to normalize the hidden vectors in attention weights calculation. + temperature: float + If not None, the scaling of attention weights is `temperature` itself. 
+ smooth_type_embedding: bool + Whether to use smooth process in attention weights calculation. + concat_output_tebd: bool + Whether to concat type embedding at the output of the descriptor. + spin + (Only support None to keep consistent with other backend references.) + (Not used in this version. Not-none option is not implemented.) + The old implementation of deepspin. + + Limitations + ----------- + The currently implementation does not support the following features + 1. tebd_input_mode != 'concat' + + The currently implementation will not support the following deprecated features + 1. spin is not None + 2. attn_mask == True + + References + ---------- + .. [1] Duo Zhang, Hangrui Bi, Fu-Zhi Dai, Wanrun Jiang, Linfeng Zhang, and Han Wang. 2022. + DPA-1: Pretraining of Attention-based Deep Potential Model for Molecular Simulation. + arXiv preprint arXiv:2208.08236. + """ + def __init__( self, - rcut, - rcut_smth, - sel, + rcut: float, + rcut_smth: float, + sel: Union[List[int], int], ntypes: int, neuron: list = [25, 50, 100], axis_neuron: int = 16, @@ -46,39 +201,44 @@ def __init__( attn_layer: int = 2, attn_dotr: bool = True, attn_mask: bool = False, - post_ln=True, - ffn=False, - ffn_embed_dim=1024, - activation_function="tanh", - scaling_factor=1.0, - head_num=1, + activation_function: str = "tanh", + precision: str = "float64", + resnet_dt: bool = False, + exclude_types: List[Tuple[int, int]] = [], + env_protection: float = 0.0, + scaling_factor: int = 1.0, normalize=True, temperature=None, - return_rot=False, concat_output_tebd: bool = True, - env_protection: float = 0.0, - type: Optional[str] = None, - # not implemented - resnet_dt: bool = False, - type_one_side: bool = True, - precision: str = "default", trainable: bool = True, - exclude_types: List[Tuple[int, int]] = [], + trainable_ln: bool = True, + ln_eps: Optional[float] = 1e-5, + smooth_type_embedding: bool = True, + type_one_side: bool = False, + # not implemented stripped_type_embedding: bool = 
False, - smooth_type_embedding: bool = False, + spin=None, + type: Optional[str] = None, + seed: Optional[int] = None, + old_impl: bool = False, ): super().__init__() - if resnet_dt: - raise NotImplementedError("resnet_dt is not supported.") - if not type_one_side: - raise NotImplementedError("type_one_side is not supported.") - if precision != "default" and precision != "float64": - raise NotImplementedError("precison is not supported.") if stripped_type_embedding: raise NotImplementedError("stripped_type_embedding is not supported.") - if smooth_type_embedding: - raise NotImplementedError("smooth_type_embedding is not supported.") - del type + if spin is not None: + raise NotImplementedError("old implementation of spin is not supported.") + if attn_mask: + raise NotImplementedError( + "old implementation of attn_mask is not supported." + ) + # TODO + if tebd_input_mode != "concat": + raise NotImplementedError("tebd_input_mode != 'concat' not implemented") + # to keep consistent with default value in this backends + if ln_eps is None: + ln_eps = 1e-5 + + del type, spin, attn_mask self.se_atten = DescrptBlockSeAtten( rcut, rcut_smth, @@ -92,20 +252,22 @@ def __init__( attn=attn, attn_layer=attn_layer, attn_dotr=attn_dotr, - attn_mask=attn_mask, - post_ln=post_ln, - ffn=ffn, - ffn_embed_dim=ffn_embed_dim, + attn_mask=False, activation_function=activation_function, + precision=precision, + resnet_dt=resnet_dt, scaling_factor=scaling_factor, - head_num=head_num, normalize=normalize, temperature=temperature, - return_rot=return_rot, + smooth=smooth_type_embedding, + type_one_side=type_one_side, exclude_types=exclude_types, env_protection=env_protection, + trainable_ln=trainable_ln, + ln_eps=ln_eps, + old_impl=old_impl, ) - self.type_embedding = TypeEmbedNet(ntypes, tebd_dim) + self.type_embedding = TypeEmbedNet(ntypes, tebd_dim, precision=precision) self.tebd_dim = tebd_dim self.concat_output_tebd = concat_output_tebd # set trainable @@ -204,14 +366,86 @@ def 
compute_input_stats( """ return self.se_atten.compute_input_stats(merged, path) + def set_stat_mean_and_stddev( + self, + mean: torch.Tensor, + stddev: torch.Tensor, + ) -> None: + self.se_atten.mean = mean + self.se_atten.stddev = stddev + def serialize(self) -> dict: - """Serialize the obj to dict.""" - raise NotImplementedError + obj = self.se_atten + return { + "@class": "Descriptor", + "type": "dpa1", + "@version": 1, + "rcut": obj.rcut, + "rcut_smth": obj.rcut_smth, + "sel": obj.sel, + "ntypes": obj.ntypes, + "neuron": obj.neuron, + "axis_neuron": obj.axis_neuron, + "tebd_dim": obj.tebd_dim, + "tebd_input_mode": obj.tebd_input_mode, + "set_davg_zero": obj.set_davg_zero, + "attn": obj.attn_dim, + "attn_layer": obj.attn_layer, + "attn_dotr": obj.attn_dotr, + "attn_mask": False, + "activation_function": obj.activation_function, + "resnet_dt": obj.resnet_dt, + "scaling_factor": obj.scaling_factor, + "normalize": obj.normalize, + "temperature": obj.temperature, + "trainable_ln": obj.trainable_ln, + "ln_eps": obj.ln_eps, + "smooth_type_embedding": obj.smooth, + "type_one_side": obj.type_one_side, + "concat_output_tebd": self.concat_output_tebd, + # make deterministic + "precision": RESERVED_PRECISON_DICT[obj.prec], + "embeddings": obj.filter_layers.serialize(), + "attention_layers": obj.dpa1_attention.serialize(), + "env_mat": DPEnvMat(obj.rcut, obj.rcut_smth).serialize(), + "type_embedding": self.type_embedding.embedding.serialize(), + "exclude_types": obj.exclude_types, + "env_protection": obj.env_protection, + "@variables": { + "davg": obj["davg"].detach().cpu().numpy(), + "dstd": obj["dstd"].detach().cpu().numpy(), + }, + ## to be updated when the options are supported. 
+ "trainable": True, + "spin": None, + } @classmethod - def deserialize(cls) -> "DescrptDPA1": - """Deserialize from a dict.""" - raise NotImplementedError + def deserialize(cls, data: dict) -> "DescrptDPA1": + data = data.copy() + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + data.pop("type") + variables = data.pop("@variables") + embeddings = data.pop("embeddings") + type_embedding = data.pop("type_embedding") + attention_layers = data.pop("attention_layers") + env_mat = data.pop("env_mat") + obj = cls(**data) + + def t_cvt(xx): + return torch.tensor(xx, dtype=obj.se_atten.prec, device=env.DEVICE) + + obj.type_embedding.embedding = TypeEmbedNetConsistent.deserialize( + type_embedding + ) + obj.se_atten["davg"] = t_cvt(variables["davg"]) + obj.se_atten["dstd"] = t_cvt(variables["dstd"]) + obj.se_atten.filter_layers = NetworkCollection.deserialize(embeddings) + obj.se_atten.dpa1_attention = NeighborGatedAttention.deserialize( + attention_layers + ) + return obj def forward( self, @@ -224,9 +458,9 @@ def forward( Parameters ---------- - coord_ext + extended_coord The extended coordinates of atoms. shape: nf x (nallx3) - atype_ext + extended_atype The extended aotm types. shape: nf x nall nlist The neighbor list. 
shape: nf x nloc x nnei diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index e17b7c5d54..8b83f0d27b 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -624,34 +624,3 @@ def forward( None, sw, ) - - -def analyze_descrpt(matrix, ndescrpt, natoms): - """Collect avg, square avg and count of descriptors in a batch.""" - ntypes = natoms.shape[1] - 2 - start_index = 0 - sysr = [] - sysa = [] - sysn = [] - sysr2 = [] - sysa2 = [] - for type_i in range(ntypes): - end_index = start_index + natoms[0, 2 + type_i] - dd = matrix[:, start_index:end_index] # all descriptors for this element - start_index = end_index - dd = np.reshape( - dd, [-1, 4] - ) # Shape is [nframes*natoms[2+type_id]*self.nnei, 4] - ddr = dd[:, :1] - dda = dd[:, 1:] - sumr = np.sum(ddr) - suma = np.sum(dda) / 3.0 - sumn = dd.shape[0] # Value is nframes*natoms[2+type_id]*self.nnei - sumr2 = np.sum(np.multiply(ddr, ddr)) - suma2 = np.sum(np.multiply(dda, dda)) / 3.0 - sysr.append(sumr) - sysa.append(suma) - sysn.append(sumn) - sysr2.append(sumr2) - sysa2.append(suma2) - return sysr, sysr2, sysa, sysa2, sysn diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index 051c66385c..66da86ce29 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -8,8 +8,9 @@ Union, ) -import numpy as np import torch +import torch.nn as nn +import torch.nn.functional as torch_func from deepmd.pt.model.descriptor.descriptor import ( DescriptorBlock, @@ -17,6 +18,14 @@ from deepmd.pt.model.descriptor.env_mat import ( prod_env_mat, ) +from deepmd.pt.model.network.layernorm import ( + LayerNorm, +) +from deepmd.pt.model.network.mlp import ( + EmbeddingNet, + MLPLayer, + NetworkCollection, +) from deepmd.pt.model.network.network import ( NeighborWiseAttention, TypeFilter, @@ -24,6 +33,10 @@ from deepmd.pt.utils import ( env, ) +from deepmd.pt.utils.env import ( + 
DEFAULT_PRECISION, + PRECISION_DICT, +) from deepmd.pt.utils.env_mat_stat import ( EnvMatStatSe, ) @@ -36,53 +49,111 @@ from deepmd.utils.path import ( DPPath, ) +from deepmd.utils.version import ( + check_version_compatibility, +) @DescriptorBlock.register("se_atten") class DescrptBlockSeAtten(DescriptorBlock): def __init__( self, - rcut, - rcut_smth, - sel, + rcut: float, + rcut_smth: float, + sel: Union[List[int], int], ntypes: int, neuron: list = [25, 50, 100], axis_neuron: int = 16, tebd_dim: int = 8, tebd_input_mode: str = "concat", - # set_davg_zero: bool = False, - set_davg_zero: bool = True, # TODO + set_davg_zero: bool = True, attn: int = 128, attn_layer: int = 2, attn_dotr: bool = True, attn_mask: bool = False, - post_ln=True, - ffn=False, - ffn_embed_dim=1024, activation_function="tanh", + precision: str = "float64", + resnet_dt: bool = False, scaling_factor=1.0, - head_num=1, normalize=True, temperature=None, - return_rot=False, + smooth: bool = True, + type_one_side: bool = False, exclude_types: List[Tuple[int, int]] = [], env_protection: float = 0.0, + trainable_ln: bool = True, + ln_eps: Optional[float] = 1e-5, type: Optional[str] = None, + old_impl: bool = False, ): - """Construct an embedding net of type `se_atten`. - - Args: - - rcut: Cut-off radius. - - rcut_smth: Smooth hyper-parameter for pair force & energy. - - sel: For each element type, how many atoms is selected as neighbors. - - filter_neuron: Number of neurons in each hidden layers of the embedding net. - - axis_neuron: Number of columns of the sub-matrix of the embedding matrix. + r"""Construct an embedding net of type `se_atten`. 
+ + Parameters + ---------- + rcut : float + The cut-off radius :math:`r_c` + rcut_smth : float + From where the environment matrix should be smoothed :math:`r_s` + sel : list[int], int + list[int]: sel[i] specifies the maximum number of type i atoms in the cut-off radius + int: the total maximum number of atoms in the cut-off radius + ntypes : int + Number of element types + neuron : list[int] + Number of neurons in each hidden layer of the embedding net :math:`\mathcal{N}` + axis_neuron : int + Number of the axis neuron :math:`M_2` (number of columns of the sub-matrix of the embedding matrix) + tebd_dim : int + Dimension of the type embedding + tebd_input_mode : str + The way to mix the type embeddings. Supported options are `concat`. + (TODO need to support stripped_type_embedding option) + resnet_dt : bool + Time-step `dt` in the resnet construction: + y = x + dt * \phi (Wx + b) + trainable_ln : bool + Whether to use trainable shift and scale weights in layer normalization. + ln_eps : float, Optional + The epsilon value for layer normalization. + type_one_side : bool + If 'False', type embeddings of both neighbor and central atoms are considered. + If 'True', only type embeddings of neighbor atoms are considered. + Default is 'False'. + attn : int + Hidden dimension of the attention vectors + attn_layer : int + Number of attention layers + attn_dotr : bool + Whether to dot the angular gate with the attention weights + attn_mask : bool + (Only support False to keep consistent with other backend references.) + (Not used in this version.) + Whether to mask the diagonal of attention weights + exclude_types : List[List[int]] + The excluded pairs of types which have no interaction with each other. + For example, `[[0, 1]]` means no interaction between type 0 and type 1. + env_protection : float + Protection parameter to prevent division by zero errors during environment matrix calculations. + set_davg_zero : bool + Set the shift of embedding net input to zero. 
+ activation_function : str + The activation function in the embedding net. Supported options are |ACTIVATION_FN| + precision : str + The precision of the embedding net parameters. Supported options are |PRECISION| + scaling_factor : float + The scaling factor of normalization in calculations of attention weights. + If `temperature` is None, the scaling of attention weights is (N_dim * scaling_factor)**0.5 + normalize : bool + Whether to normalize the hidden vectors in attention weights calculation. + temperature : float + If not None, the scaling of attention weights is `temperature` itself. """ super().__init__() del type self.rcut = rcut self.rcut_smth = rcut_smth - self.filter_neuron = neuron + self.neuron = neuron + self.filter_neuron = self.neuron self.axis_neuron = axis_neuron self.tebd_dim = tebd_dim self.tebd_input_mode = tebd_input_mode @@ -91,18 +162,22 @@ def __init__( self.attn_layer = attn_layer self.attn_dotr = attn_dotr self.attn_mask = attn_mask - self.post_ln = post_ln - self.ffn = ffn - self.ffn_embed_dim = ffn_embed_dim - self.activation = activation_function - # TODO: To be fixed: precision should be given from inputs - self.prec = torch.float64 + self.activation_function = activation_function + self.precision = precision + self.prec = PRECISION_DICT[self.precision] + self.resnet_dt = resnet_dt self.scaling_factor = scaling_factor - self.head_num = head_num self.normalize = normalize self.temperature = temperature - self.return_rot = return_rot + self.smooth = smooth + self.type_one_side = type_one_side self.env_protection = env_protection + self.trainable_ln = trainable_ln + # to keep consistent with default value in this backends + if ln_eps is None: + ln_eps = 1e-5 + self.ln_eps = ln_eps + self.old_impl = old_impl if isinstance(sel, int): sel = [sel] @@ -115,22 +190,36 @@ def __init__( self.ndescrpt = self.nnei * 4 # order matters, placed after the assignment of self.ntypes self.reinit_exclude(exclude_types) - self.dpa1_attention = 
NeighborWiseAttention( - self.attn_layer, - self.nnei, - self.filter_neuron[-1], - self.attn_dim, - dotr=self.attn_dotr, - do_mask=self.attn_mask, - post_ln=self.post_ln, - ffn=self.ffn, - ffn_embed_dim=self.ffn_embed_dim, - activation=self.activation, - scaling_factor=self.scaling_factor, - head_num=self.head_num, - normalize=self.normalize, - temperature=self.temperature, - ) + if self.old_impl: + self.dpa1_attention = NeighborWiseAttention( + self.attn_layer, + self.nnei, + self.filter_neuron[-1], + self.attn_dim, + dotr=self.attn_dotr, + do_mask=self.attn_mask, + activation=self.activation_function, + scaling_factor=self.scaling_factor, + normalize=self.normalize, + temperature=self.temperature, + smooth=self.smooth, + ) + else: + self.dpa1_attention = NeighborGatedAttention( + self.attn_layer, + self.nnei, + self.filter_neuron[-1], + self.attn_dim, + dotr=self.attn_dotr, + do_mask=self.attn_mask, + scaling_factor=self.scaling_factor, + normalize=self.normalize, + temperature=self.temperature, + trainable_ln=self.trainable_ln, + ln_eps=self.ln_eps, + smooth=self.smooth, + precision=self.precision, + ) wanted_shape = (self.ntypes, self.nnei, 4) mean = torch.zeros( @@ -141,19 +230,41 @@ def __init__( ) self.register_buffer("mean", mean) self.register_buffer("stddev", stddev) - - filter_layers = [] - one = TypeFilter( - 0, - self.nnei, - self.filter_neuron, - return_G=True, - tebd_dim=self.tebd_dim, - use_tebd=True, - tebd_mode=self.tebd_input_mode, - ) - filter_layers.append(one) - self.filter_layers = torch.nn.ModuleList(filter_layers) + if self.tebd_input_mode in ["concat"]: + if not self.type_one_side: + self.embd_input_dim = 1 + self.tebd_dim * 2 + else: + self.embd_input_dim = 1 + self.tebd_dim + else: + self.embd_input_dim = 1 + + self.filter_layers_old = None + self.filter_layers = None + if self.old_impl: + filter_layers = [] + one = TypeFilter( + 0, + self.nnei, + self.filter_neuron, + return_G=True, + tebd_dim=self.tebd_dim, + use_tebd=True, + 
tebd_mode=self.tebd_input_mode, + ) + filter_layers.append(one) + self.filter_layers_old = torch.nn.ModuleList(filter_layers) + else: + filter_layers = NetworkCollection( + ndim=0, ntypes=self.ntypes, network_type="embedding_network" + ) + filter_layers[0] = EmbeddingNet( + self.embd_input_dim, + self.filter_neuron, + activation_function=self.activation_function, + precision=self.precision, + resnet_dt=self.resnet_dt, + ) + self.filter_layers = filter_layers self.stats = None def get_rcut(self) -> float: @@ -184,6 +295,22 @@ def get_dim_emb(self) -> int: """Returns the output dimension of embedding.""" return self.filter_neuron[-1] + def __setitem__(self, key, value): + if key in ("avg", "data_avg", "davg"): + self.mean = value + elif key in ("std", "data_std", "dstd"): + self.stddev = value + else: + raise KeyError(key) + + def __getitem__(self, key): + if key in ("avg", "data_avg", "davg"): + return self.mean + elif key in ("std", "data_std", "dstd"): + return self.stddev + else: + raise KeyError(key) + def mixed_types(self) -> bool: """If true, the discriptor 1. assumes total number of atoms aligned across frames; @@ -272,19 +399,38 @@ def forward( extended_atype: torch.Tensor, extended_atype_embd: Optional[torch.Tensor] = None, mapping: Optional[torch.Tensor] = None, - ) -> List[torch.Tensor]: - """Calculate decoded embedding for each atom. + ): + """Compute the descriptor. - Args: - - coord: Tell atom coordinates with shape [nframes, natoms[1]*3]. - - atype: Tell atom types with shape [nframes, natoms[1]]. - - natoms: Tell atom count and element count. Its shape is [2+self.ntypes]. - - box: Tell simulation box with shape [nframes, 9]. + Parameters + ---------- + nlist + The neighbor list. shape: nf x nloc x nnei + extended_coord + The extended coordinates of atoms. shape: nf x (nallx3) + extended_atype + The extended aotm types. shape: nf x nall x nt + extended_atype_embd + The extended type embedding of atoms. 
shape: nf x nall + mapping + The index mapping, not required by this descriptor. Returns ------- - - result: descriptor with shape [nframes, nloc, self.filter_neuron[-1] * self.axis_neuron]. - - ret: environment matrix with shape [nframes, nloc, self.neei, out_size] + result + The descriptor. shape: nf x nloc x (ng x axis_neuron) + g2 + The rotationally invariant pair-partical representation. + shape: nf x nloc x nnei x ng + h2 + The rotationally equivariant pair-partical representation. + shape: nf x nloc x nnei x 3 + gr + The rotationally equivariant and permutationally invariant single particle + representation. shape: nf x nloc x ng x 3 + sw + The smooth switch function. shape: nf x nloc x nnei + """ del mapping assert extended_atype_embd is not None @@ -302,8 +448,6 @@ def forward( self.rcut_smth, protection=self.env_protection, ) - # [nfxnlocxnnei, self.ndescrpt] - dmatrix = dmatrix.view(-1, self.ndescrpt) nlist_mask = nlist != -1 nlist[nlist == -1] = 0 sw = torch.squeeze(sw, -1) @@ -321,23 +465,60 @@ def forward( atype_tebd_nlist = torch.gather(atype_tebd_ext, dim=1, index=index) # nb x nloc x nnei x nt atype_tebd_nlist = atype_tebd_nlist.view(nb, nloc, nnei, nt) - ret = self.filter_layers[0]( - dmatrix, - atype_tebd=atype_tebd_nnei, - nlist_tebd=atype_tebd_nlist, - ) # shape is [nframes*nall, self.neei, out_size] - input_r = torch.nn.functional.normalize( - dmatrix.reshape(-1, self.nnei, 4)[:, :, 1:4], dim=-1 - ) - ret = self.dpa1_attention( - ret, nlist_mask, input_r=input_r, sw=sw - ) # shape is [nframes*nloc, self.neei, out_size] - inputs_reshape = dmatrix.view(-1, self.nnei, 4).permute( - 0, 2, 1 - ) # shape is [nframes*natoms[0], 4, self.neei] - xyz_scatter = torch.matmul( - inputs_reshape, ret - ) # shape is [nframes*natoms[0], 4, out_size] + # (nb x nloc) x nnei + exclude_mask = self.emask(nlist, extended_atype).view(nb * nloc, nnei) + if self.old_impl: + assert self.filter_layers_old is not None + dmatrix = dmatrix.view( + -1, self.ndescrpt + ) # 
shape is [nframes*nall, self.ndescrpt] + gg = self.filter_layers_old[0]( + dmatrix, + atype_tebd=atype_tebd_nnei, + nlist_tebd=atype_tebd_nlist, + ) # shape is [nframes*nall, self.neei, out_size] + input_r = torch.nn.functional.normalize( + dmatrix.reshape(-1, self.nnei, 4)[:, :, 1:4], dim=-1 + ) + gg = self.dpa1_attention( + gg, nlist_mask, input_r=input_r, sw=sw + ) # shape is [nframes*nloc, self.neei, out_size] + inputs_reshape = dmatrix.view(-1, self.nnei, 4).permute( + 0, 2, 1 + ) # shape is [nframes*natoms[0], 4, self.neei] + xyz_scatter = torch.matmul( + inputs_reshape, gg + ) # shape is [nframes*natoms[0], 4, out_size] + else: + assert self.filter_layers is not None + # nfnl x nnei x 4 + dmatrix = dmatrix.view(-1, self.nnei, 4) + nfnl = dmatrix.shape[0] + # nfnl x nnei x 4 + rr = dmatrix + rr = rr * exclude_mask[:, :, None] + ss = rr[:, :, :1] + if self.tebd_input_mode in ["concat"]: + nlist_tebd = atype_tebd_nlist.reshape(nfnl, nnei, self.tebd_dim) + atype_tebd = atype_tebd_nnei.reshape(nfnl, nnei, self.tebd_dim) + if not self.type_one_side: + # nfnl x nnei x (1 + tebd_dim * 2) + ss = torch.concat([ss, nlist_tebd, atype_tebd], dim=2) + else: + # nfnl x nnei x (1 + tebd_dim) + ss = torch.concat([ss, nlist_tebd], dim=2) + else: + raise NotImplementedError + # nfnl x nnei x ng + gg = self.filter_layers._networks[0](ss) + input_r = torch.nn.functional.normalize( + dmatrix.reshape(-1, self.nnei, 4)[:, :, 1:4], dim=-1 + ) + gg = self.dpa1_attention( + gg, nlist_mask, input_r=input_r, sw=sw + ) # shape is [nframes*nloc, self.neei, out_size] + # nfnl x 4 x ng + xyz_scatter = torch.matmul(rr.permute(0, 2, 1), gg) xyz_scatter = xyz_scatter / self.nnei xyz_scatter_1 = xyz_scatter.permute(0, 2, 1) rot_mat = xyz_scatter_1[:, :, 1:4] @@ -347,66 +528,406 @@ def forward( ) # shape is [nframes*nloc, self.filter_neuron[-1], self.axis_neuron] return ( result.view(-1, nloc, self.filter_neuron[-1] * self.axis_neuron), - ret.view(-1, nloc, self.nnei, self.filter_neuron[-1]), + 
gg.view(-1, nloc, self.nnei, self.filter_neuron[-1]), dmatrix.view(-1, nloc, self.nnei, 4)[..., 1:], rot_mat.view(-1, nloc, self.filter_neuron[-1], 3), sw, ) -def analyze_descrpt(matrix, ndescrpt, natoms, mixed_types=False, real_atype=None): - """Collect avg, square avg and count of descriptors in a batch.""" - ntypes = natoms.shape[1] - 2 - if not mixed_types: - sysr = [] - sysa = [] - sysn = [] - sysr2 = [] - sysa2 = [] - start_index = 0 - for type_i in range(ntypes): - end_index = start_index + natoms[0, 2 + type_i] - dd = matrix[:, start_index:end_index] - start_index = end_index - dd = np.reshape( - dd, [-1, 4] - ) # Shape is [nframes*natoms[2+type_id]*self.nnei, 4] - ddr = dd[:, :1] - dda = dd[:, 1:] - sumr = np.sum(ddr) - suma = np.sum(dda) / 3.0 - sumn = dd.shape[0] # Value is nframes*natoms[2+type_id]*self.nnei - sumr2 = np.sum(np.multiply(ddr, ddr)) - suma2 = np.sum(np.multiply(dda, dda)) / 3.0 - sysr.append(sumr) - sysa.append(suma) - sysn.append(sumn) - sysr2.append(sumr2) - sysa2.append(suma2) - else: - sysr = [0.0 for i in range(ntypes)] - sysa = [0.0 for i in range(ntypes)] - sysn = [0 for i in range(ntypes)] - sysr2 = [0.0 for i in range(ntypes)] - sysa2 = [0.0 for i in range(ntypes)] - for frame_item in range(matrix.shape[0]): - dd_ff = matrix[frame_item] - atype_frame = real_atype[frame_item] - for type_i in range(ntypes): - type_idx = atype_frame == type_i - dd = dd_ff[type_idx] - dd = np.reshape(dd, [-1, 4]) # typen_atoms * nnei, 4 - ddr = dd[:, :1] - dda = dd[:, 1:] - sumr = np.sum(ddr) - suma = np.sum(dda) / 3.0 - sumn = dd.shape[0] - sumr2 = np.sum(np.multiply(ddr, ddr)) - suma2 = np.sum(np.multiply(dda, dda)) / 3.0 - sysr[type_i] += sumr - sysa[type_i] += suma - sysn[type_i] += sumn - sysr2[type_i] += sumr2 - sysa2[type_i] += suma2 - - return sysr, sysr2, sysa, sysa2, sysn +class NeighborGatedAttention(nn.Module): + def __init__( + self, + layer_num: int, + nnei: int, + embed_dim: int, + hidden_dim: int, + dotr: bool = False, + do_mask: bool 
= False, + scaling_factor: float = 1.0, + normalize: bool = True, + temperature: Optional[float] = None, + trainable_ln: bool = True, + ln_eps: float = 1e-5, + smooth: bool = True, + precision: str = DEFAULT_PRECISION, + ): + """Construct a neighbor-wise attention net.""" + super().__init__() + self.layer_num = layer_num + self.nnei = nnei + self.embed_dim = embed_dim + self.hidden_dim = hidden_dim + self.dotr = dotr + self.do_mask = do_mask + self.scaling_factor = scaling_factor + self.normalize = normalize + self.temperature = temperature + self.trainable_ln = trainable_ln + self.ln_eps = ln_eps + self.smooth = smooth + self.precision = precision + self.network_type = NeighborGatedAttentionLayer + attention_layers = [] + for i in range(self.layer_num): + attention_layers.append( + NeighborGatedAttentionLayer( + nnei, + embed_dim, + hidden_dim, + dotr=dotr, + do_mask=do_mask, + scaling_factor=scaling_factor, + normalize=normalize, + temperature=temperature, + trainable_ln=trainable_ln, + ln_eps=ln_eps, + smooth=smooth, + precision=precision, + ) + ) + self.attention_layers = nn.ModuleList(attention_layers) + + def forward( + self, + input_G, + nei_mask, + input_r: Optional[torch.Tensor] = None, + sw: Optional[torch.Tensor] = None, + ): + """Compute the multi-layer gated self-attention. + + Parameters + ---------- + input_G + inputs with shape: (nf x nloc) x nnei x embed_dim. + nei_mask + neighbor mask, with paddings being 0. shape: (nf x nloc) x nnei. + input_r + normalized radial. shape: (nf x nloc) x nnei x 3. + sw + The smooth switch function. 
shape: nf x nloc x nnei + """ + out = input_G + # https://github.com/pytorch/pytorch/issues/39165#issuecomment-635472592 + for layer in self.attention_layers: + out = layer(out, nei_mask, input_r=input_r, sw=sw) + return out + + def __getitem__(self, key): + if isinstance(key, int): + return self.attention_layers[key] + else: + raise TypeError(key) + + def __setitem__(self, key, value): + if not isinstance(key, int): + raise TypeError(key) + if isinstance(value, self.network_type): + pass + elif isinstance(value, dict): + value = self.network_type.deserialize(value) + else: + raise TypeError(value) + self.attention_layers[key] = value + + def serialize(self) -> dict: + """Serialize the networks to a dict. + + Returns + ------- + dict + The serialized networks. + """ + return { + "@class": "NeighborGatedAttention", + "@version": 1, + "layer_num": self.layer_num, + "nnei": self.nnei, + "embed_dim": self.embed_dim, + "hidden_dim": self.hidden_dim, + "dotr": self.dotr, + "do_mask": self.do_mask, + "scaling_factor": self.scaling_factor, + "normalize": self.normalize, + "temperature": self.temperature, + "trainable_ln": self.trainable_ln, + "ln_eps": self.ln_eps, + "precision": self.precision, + "attention_layers": [layer.serialize() for layer in self.attention_layers], + } + + @classmethod + def deserialize(cls, data: dict) -> "NeighborGatedAttention": + """Deserialize the networks from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. 
+ """ + data = data.copy() + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + attention_layers = data.pop("attention_layers") + obj = cls(**data) + for ii, network in enumerate(attention_layers): + obj[ii] = network + return obj + + +class NeighborGatedAttentionLayer(nn.Module): + def __init__( + self, + nnei: int, + embed_dim: int, + hidden_dim: int, + dotr: bool = False, + do_mask: bool = False, + scaling_factor: float = 1.0, + normalize: bool = True, + temperature: Optional[float] = None, + smooth: bool = True, + trainable_ln: bool = True, + ln_eps: float = 1e-5, + precision: str = DEFAULT_PRECISION, + ): + """Construct a neighbor-wise attention layer.""" + super().__init__() + self.nnei = nnei + self.embed_dim = embed_dim + self.hidden_dim = hidden_dim + self.dotr = dotr + self.do_mask = do_mask + self.scaling_factor = scaling_factor + self.normalize = normalize + self.temperature = temperature + self.precision = precision + self.trainable_ln = trainable_ln + self.ln_eps = ln_eps + self.attention_layer = GatedAttentionLayer( + nnei, + embed_dim, + hidden_dim, + dotr=dotr, + do_mask=do_mask, + scaling_factor=scaling_factor, + normalize=normalize, + temperature=temperature, + smooth=smooth, + precision=precision, + ) + self.attn_layer_norm = LayerNorm( + self.embed_dim, eps=ln_eps, trainable=trainable_ln, precision=precision + ) + + def forward( + self, + x, + nei_mask, + input_r: Optional[torch.Tensor] = None, + sw: Optional[torch.Tensor] = None, + ): + residual = x + x = self.attention_layer(x, nei_mask, input_r=input_r, sw=sw) + x = residual + x + x = self.attn_layer_norm(x) + return x + + def serialize(self) -> dict: + """Serialize the networks to a dict. + + Returns + ------- + dict + The serialized networks. 
+ """ + return { + "nnei": self.nnei, + "embed_dim": self.embed_dim, + "hidden_dim": self.hidden_dim, + "dotr": self.dotr, + "do_mask": self.do_mask, + "scaling_factor": self.scaling_factor, + "normalize": self.normalize, + "temperature": self.temperature, + "trainable_ln": self.trainable_ln, + "ln_eps": self.ln_eps, + "precision": self.precision, + "attention_layer": self.attention_layer.serialize(), + "attn_layer_norm": self.attn_layer_norm.serialize(), + } + + @classmethod + def deserialize(cls, data: dict) -> "NeighborGatedAttentionLayer": + """Deserialize the networks from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. + """ + data = data.copy() + attention_layer = data.pop("attention_layer") + attn_layer_norm = data.pop("attn_layer_norm") + obj = cls(**data) + obj.attention_layer = GatedAttentionLayer.deserialize(attention_layer) + obj.attn_layer_norm = LayerNorm.deserialize(attn_layer_norm) + return obj + + +class GatedAttentionLayer(nn.Module): + def __init__( + self, + nnei: int, + embed_dim: int, + hidden_dim: int, + dotr: bool = False, + do_mask: bool = False, + scaling_factor: float = 1.0, + normalize: bool = True, + temperature: Optional[float] = None, + bias: bool = True, + smooth: bool = True, + precision: str = DEFAULT_PRECISION, + ): + """Construct a neighbor-wise attention net.""" + super().__init__() + self.nnei = nnei + self.embed_dim = embed_dim + self.hidden_dim = hidden_dim + self.dotr = dotr + self.do_mask = do_mask + self.bias = bias + self.smooth = smooth + self.scaling_factor = scaling_factor + self.temperature = temperature + self.precision = precision + if temperature is None: + self.scaling = (self.hidden_dim * scaling_factor) ** -0.5 + else: + self.scaling = temperature + self.normalize = normalize + self.in_proj = MLPLayer( + embed_dim, + hidden_dim * 3, + bias=bias, + use_timestep=False, + bavg=0.0, + stddev=1.0, + precision=precision, + ) + self.out_proj = MLPLayer( + hidden_dim, + embed_dim, + 
bias=bias, + use_timestep=False, + bavg=0.0, + stddev=1.0, + precision=precision, + ) + + def forward( + self, + query, + nei_mask, + input_r: Optional[torch.Tensor] = None, + sw: Optional[torch.Tensor] = None, + attnw_shift: float = 20.0, + ): + """Compute the gated self-attention. + + Parameters + ---------- + query + inputs with shape: (nf x nloc) x nnei x embed_dim. + nei_mask + neighbor mask, with paddings being 0. shape: (nf x nloc) x nnei. + input_r + normalized radial. shape: (nf x nloc) x nnei x 3. + sw + The smooth switch function. shape: (nf x nloc) x nnei + attnw_shift : float + The attention weight shift to preserve smoothness when doing padding before softmax. + """ + q, k, v = self.in_proj(query).chunk(3, dim=-1) + # [nframes * nloc, nnei, hidden_dim] + q = q.view(-1, self.nnei, self.hidden_dim) + k = k.view(-1, self.nnei, self.hidden_dim) + v = v.view(-1, self.nnei, self.hidden_dim) + if self.normalize: + q = torch_func.normalize(q, dim=-1) + k = torch_func.normalize(k, dim=-1) + v = torch_func.normalize(v, dim=-1) + q = q * self.scaling + k = k.transpose(1, 2) + # [nframes * nloc, nnei, nnei] + attn_weights = torch.bmm(q, k) + # [nframes * nloc, nnei] + nei_mask = nei_mask.view(-1, self.nnei) + if self.smooth: + # [nframes * nloc, nnei] + assert sw is not None + sw = sw.view([-1, self.nnei]) + attn_weights = (attn_weights + attnw_shift) * sw[:, :, None] * sw[ + :, None, : + ] - attnw_shift + else: + attn_weights = attn_weights.masked_fill( + ~nei_mask.unsqueeze(1), float("-inf") + ) + attn_weights = torch_func.softmax(attn_weights, dim=-1) + attn_weights = attn_weights.masked_fill(~nei_mask.unsqueeze(-1), 0.0) + if self.smooth: + assert sw is not None + attn_weights = attn_weights * sw[:, :, None] * sw[:, None, :] + if self.dotr: + assert input_r is not None, "input_r must be provided when dotr is True!" 
+ angular_weight = torch.bmm(input_r, input_r.transpose(1, 2)) + attn_weights = attn_weights * angular_weight + o = torch.bmm(attn_weights, v) + output = self.out_proj(o) + return output + + def serialize(self) -> dict: + """Serialize the networks to a dict. + + Returns + ------- + dict + The serialized networks. + """ + # network_type_map_inv = {v: k for k, v in self.NETWORK_TYPE_MAP.items()} + # network_type_name = network_type_map_inv[self.network_type] + return { + "nnei": self.nnei, + "embed_dim": self.embed_dim, + "hidden_dim": self.hidden_dim, + "dotr": self.dotr, + "do_mask": self.do_mask, + "scaling_factor": self.scaling_factor, + "normalize": self.normalize, + "temperature": self.temperature, + "bias": self.bias, + "smooth": self.smooth, + "precision": self.precision, + "in_proj": self.in_proj.serialize(), + "out_proj": self.out_proj.serialize(), + } + + @classmethod + def deserialize(cls, data: dict) -> "GatedAttentionLayer": + """Deserialize the networks from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. 
+ """ + data = data.copy() + in_proj = data.pop("in_proj") + out_proj = data.pop("out_proj") + obj = cls(**data) + obj.in_proj = MLPLayer.deserialize(in_proj) + obj.out_proj = MLPLayer.deserialize(out_proj) + return obj diff --git a/deepmd/pt/model/network/layernorm.py b/deepmd/pt/model/network/layernorm.py new file mode 100644 index 0000000000..27b9808010 --- /dev/null +++ b/deepmd/pt/model/network/layernorm.py @@ -0,0 +1,131 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import torch +import torch.nn as nn + +from deepmd.dpmodel.utils.network import LayerNorm as DPLayerNorm +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + DEFAULT_PRECISION, + PRECISION_DICT, +) +from deepmd.pt.utils.utils import ( + to_numpy_array, + to_torch_tensor, +) + +from .mlp import ( + MLPLayer, +) + +device = env.DEVICE + + +class LayerNorm(MLPLayer): + def __init__( + self, + num_in, + eps: float = 1e-5, + uni_init: bool = True, + bavg: float = 0.0, + stddev: float = 1.0, + precision: str = DEFAULT_PRECISION, + trainable: bool = True, + ): + self.eps = eps + self.uni_init = uni_init + self.num_in = num_in + super().__init__( + num_in=1, + num_out=num_in, + bias=True, + use_timestep=False, + activation_function=None, + resnet=False, + bavg=bavg, + stddev=stddev, + precision=precision, + ) + self.matrix = torch.nn.Parameter(self.matrix.squeeze(0)) + if self.uni_init: + nn.init.ones_(self.matrix.data) + nn.init.zeros_(self.bias.data) + self.trainable = trainable + if not self.trainable: + self.matrix.requires_grad = False + self.bias.requires_grad = False + + def dim_out(self) -> int: + return self.matrix.shape[0] + + def forward( + self, + xx: torch.Tensor, + ) -> torch.Tensor: + """One Layer Norm used by DP model. + + Parameters + ---------- + xx : torch.Tensor + The input of index. + + Returns + ------- + yy: torch.Tensor + The output. 
+ """ + mean = xx.mean(dim=-1, keepdim=True) + variance = xx.var(dim=-1, unbiased=False, keepdim=True) + yy = (xx - mean) / torch.sqrt(variance + self.eps) + if self.matrix is not None and self.bias is not None: + yy = yy * self.matrix + self.bias + return yy + + def serialize(self) -> dict: + """Serialize the layer to a dict. + + Returns + ------- + dict + The serialized layer. + """ + nl = DPLayerNorm( + self.matrix.shape[0], + eps=self.eps, + trainable=self.trainable, + precision=self.precision, + ) + nl.w = to_numpy_array(self.matrix) + nl.b = to_numpy_array(self.bias) + data = nl.serialize() + return data + + @classmethod + def deserialize(cls, data: dict) -> "LayerNorm": + """Deserialize the layer from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. + """ + nl = DPLayerNorm.deserialize(data) + obj = cls( + nl["matrix"].shape[0], + eps=nl["eps"], + trainable=nl["trainable"], + precision=nl["precision"], + ) + prec = PRECISION_DICT[obj.precision] + + def check_load_param(ss): + return ( + nn.Parameter(data=to_torch_tensor(nl[ss])) + if nl[ss] is not None + else None + ) + + obj.matrix = check_load_param("matrix") + obj.bias = check_load_param("bias") + return obj diff --git a/deepmd/pt/model/network/network.py b/deepmd/pt/model/network/network.py index c895f642e1..09d9945b3b 100644 --- a/deepmd/pt/model/network/network.py +++ b/deepmd/pt/model/network/network.py @@ -556,15 +556,19 @@ def forward(self, inputs): class TypeEmbedNet(nn.Module): - def __init__(self, type_nums, embed_dim, bavg=0.0, stddev=1.0): + def __init__(self, type_nums, embed_dim, bavg=0.0, stddev=1.0, precision="default"): """Construct a type embedding net.""" super().__init__() + self.type_nums = type_nums + self.embed_dim = embed_dim + self.bavg = bavg + self.stddev = stddev self.embedding = TypeEmbedNetConsistent( - ntypes=type_nums, - neuron=[embed_dim], + ntypes=self.type_nums, + neuron=[self.embed_dim], padding=True, activation_function="Linear", - 
precision="default", + precision=precision, ) # nn.init.normal_(self.embedding.weight[:-1], mean=bavg, std=stddev) @@ -847,6 +851,7 @@ def __init__( head_num=1, normalize=True, temperature=None, + smooth=True, ): """Construct a neighbor-wise attention net.""" super().__init__() @@ -868,6 +873,7 @@ def __init__( head_num=head_num, normalize=normalize, temperature=temperature, + smooth=smooth, ) ) self.attention_layers = nn.ModuleList(attention_layers) @@ -915,6 +921,7 @@ def __init__( head_num=1, normalize=True, temperature=None, + smooth=True, ): """Construct a neighbor-wise attention layer.""" super().__init__() @@ -925,6 +932,7 @@ def __init__( self.do_mask = do_mask self.post_ln = post_ln self.ffn = ffn + self.smooth = smooth self.attention_layer = GatedSelfAttetion( nnei, embed_dim, @@ -935,6 +943,7 @@ def __init__( head_num=head_num, normalize=normalize, temperature=temperature, + smooth=smooth, ) self.attn_layer_norm = nn.LayerNorm( self.embed_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE diff --git a/deepmd/tf/descriptor/se.py b/deepmd/tf/descriptor/se.py index 4232503464..5a3c5dd447 100644 --- a/deepmd/tf/descriptor/se.py +++ b/deepmd/tf/descriptor/se.py @@ -296,7 +296,10 @@ def deserialize_network(cls, data: dict, suffix: str = "") -> dict: net_idx.append(rest_ii % embeddings.ntypes) rest_ii //= embeddings.ntypes net_idx = tuple(net_idx) - if embeddings.ndim in (0, 1): + if embeddings.ndim == 0: + key0 = "all" + key1 = "" + elif embeddings.ndim == 1: key0 = "all" key1 = f"_{ii}" elif embeddings.ndim == 2: diff --git a/deepmd/tf/descriptor/se_atten.py b/deepmd/tf/descriptor/se_atten.py index 0cd5a96632..0ba426ee4b 100644 --- a/deepmd/tf/descriptor/se_atten.py +++ b/deepmd/tf/descriptor/se_atten.py @@ -1,11 +1,14 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import logging +import re import warnings from typing import ( + Any, List, Optional, Set, Tuple, + Union, ) import numpy as np @@ -13,11 +16,19 @@ Version, ) +from 
deepmd.dpmodel.utils.env_mat import ( + EnvMat, +) +from deepmd.dpmodel.utils.network import ( + LayerNorm, + NativeLayer, +) from deepmd.tf.common import ( cast_precision, get_np_precision, ) from deepmd.tf.env import ( + ATTENTION_LAYER_PATTERN, GLOBAL_NP_FLOAT_PRECISION, GLOBAL_TF_FLOAT_PRECISION, TF_VERSION, @@ -50,6 +61,7 @@ ) from deepmd.tf.utils.network import ( embedding_net, + layernorm, one_layer, ) from deepmd.tf.utils.sess import ( @@ -58,9 +70,15 @@ from deepmd.tf.utils.tabulate import ( DPTabulate, ) +from deepmd.tf.utils.type_embed import ( + TypeEmbedNet, +) from deepmd.tf.utils.update_sel import ( UpdateSel, ) +from deepmd.utils.version import ( + check_version_compatibility, +) from .descriptor import ( Descriptor, @@ -79,51 +97,55 @@ class DescrptSeAtten(DescrptSeA): Parameters ---------- - rcut + rcut: float The cut-off radius :math:`r_c` - rcut_smth + rcut_smth: float From where the environment matrix should be smoothed :math:`r_s` - sel : int - sel[i] specifies the maxmum number of type i atoms in the cut-off radius - neuron : list[int] + sel: list[int], int + list[int]: sel[i] specifies the maximum number of type i atoms in the cut-off radius + int: the total maximum number of atoms in the cut-off radius + neuron: list[int] Number of neurons in each hidden layers of the embedding net :math:`\mathcal{N}` - axis_neuron + axis_neuron: int Number of the axis neuron :math:`M_2` (number of columns of the sub-matrix of the embedding matrix) - resnet_dt + resnet_dt: bool Time-step `dt` in the resnet construction: y = x + dt * \phi (Wx + b) - trainable + trainable: bool If the weights of embedding net are trainable. - seed + seed: int, Optional Random seed for initializing the network parameters. - type_one_side + type_one_side: bool Try to build N_types embedding nets. Otherwise, building N_types^2 embedding nets exclude_types : List[List[int]] The excluded pairs of types which have no interaction with each other. 
For example, `[[0, 1]]` means no interaction between type 0 and type 1. - set_davg_zero + set_davg_zero: bool Set the shift of embedding net input to zero. - activation_function + activation_function: str The activation function in the embedding net. Supported options are |ACTIVATION_FN| - precision + precision: str The precision of the embedding net parameters. Supported options are |PRECISION| - uniform_seed + uniform_seed: bool Only for the purpose of backward compatibility, retrieves the old behavior of using the random seed - attn + attn: int The length of hidden vector during scale-dot attention computation. - attn_layer + attn_layer: int The number of layers in attention mechanism. - attn_dotr + attn_dotr: bool Whether to dot the relative coordinates on the attention weights as a gated scheme. - attn_mask + attn_mask: bool Whether to mask the diagonal in the attention weights. - multi_task + ln_eps: float, Optional + The epsilon value for layer normalization. + multi_task: bool If the model has multi fitting nets to train. - stripped_type_embedding + stripped_type_embedding: bool Whether to strip the type embedding into a separated embedding network. Default value will be True in `se_atten_v2` descriptor. - smooth_type_embedding - When using stripped type embedding, whether to dot smooth factor on the network output of type embedding + smooth_type_embedding: bool + Whether to use smooth process in attention weights calculation. + And when using stripped type embedding, whether to dot smooth factor on the network output of type embedding to keep the network smooth, instead of setting `set_davg_zero` to be True. Default value will be True in `se_atten_v2` descriptor. 
@@ -137,9 +159,9 @@ def __init__( self, rcut: float, rcut_smth: float, - sel: int, + sel: Union[List[int], int], ntypes: int, - neuron: List[int] = [24, 48, 96], + neuron: List[int] = [25, 50, 100], axis_neuron: int = 8, resnet_dt: bool = False, trainable: bool = True, @@ -158,15 +180,13 @@ def __init__( stripped_type_embedding: bool = False, smooth_type_embedding: bool = False, # not implemented - post_ln=True, - ffn=False, - ffn_embed_dim=1024, scaling_factor=1.0, - head_num=1, normalize=True, temperature=None, - return_rot=False, + trainable_ln: bool = True, + ln_eps: Optional[float] = 1e-3, concat_output_tebd: bool = True, + env_protection: float = 0.0, # not implement!! **kwargs, ) -> None: if not set_davg_zero and not ( @@ -176,24 +196,21 @@ def __init__( "Set 'set_davg_zero' False in descriptor 'se_atten' " "may cause unexpected incontinuity during model inference!" ) - if not post_ln: - raise NotImplementedError("post_ln is not supported.") - if ffn: - raise NotImplementedError("ffn is not supported.") - if ffn_embed_dim != 1024: - raise NotImplementedError("ffn_embed_dim is not supported.") if scaling_factor != 1.0: raise NotImplementedError("scaling_factor is not supported.") - if head_num != 1: - raise NotImplementedError("head_num is not supported.") if not normalize: raise NotImplementedError("normalize is not supported.") if temperature is not None: raise NotImplementedError("temperature is not supported.") - if return_rot: - raise NotImplementedError("return_rot is not supported.") if not concat_output_tebd: raise NotImplementedError("concat_output_tebd is not supported.") + if env_protection != 0.0: + raise NotImplementedError("env_protection != 0.0 is not supported.") + # to keep consistent with default value in this backends + if ln_eps is None: + ln_eps = 1e-3 + if isinstance(sel, list): + sel = sum(sel) DescrptSeA.__init__( self, rcut, @@ -223,6 +240,8 @@ def __init__( raise ValueError("`model/type_map` is not set or empty!") 
self.stripped_type_embedding = stripped_type_embedding self.smooth = smooth_type_embedding + self.trainable_ln = trainable_ln + self.ln_eps = ln_eps self.ntypes = ntypes self.att_n = attn self.attn_layer = attn_layer @@ -241,12 +260,6 @@ def __init__( std_ones = np.ones([self.ntypes, self.ndescrpt]).astype( GLOBAL_NP_FLOAT_PRECISION ) - self.beta = np.zeros([self.attn_layer, self.filter_neuron[-1]]).astype( - GLOBAL_NP_FLOAT_PRECISION - ) - self.gamma = np.ones([self.attn_layer, self.filter_neuron[-1]]).astype( - GLOBAL_NP_FLOAT_PRECISION - ) self.attention_layer_variables = None sub_graph = tf.Graph() with sub_graph.as_default(): @@ -885,38 +898,6 @@ def _lookup_type_embedding( return self.embedding_input_2 return self.embedding_input - def _feedforward(self, input_xyz, d_in, d_mid): - residual = input_xyz - input_xyz = tf.nn.relu( - one_layer( - input_xyz, - d_mid, - name="c_ffn1", - reuse=tf.AUTO_REUSE, - seed=self.seed, - activation_fn=None, - precision=self.filter_precision, - trainable=True, - uniform_seed=self.uniform_seed, - initial_variables=self.attention_layer_variables, - ) - ) - input_xyz = one_layer( - input_xyz, - d_in, - name="c_ffn2", - reuse=tf.AUTO_REUSE, - seed=self.seed, - activation_fn=None, - precision=self.filter_precision, - trainable=True, - uniform_seed=self.uniform_seed, - initial_variables=self.attention_layer_variables, - ) - input_xyz += residual - input_xyz = tf.keras.layers.LayerNormalization()(input_xyz) - return input_xyz - def _scaled_dot_attn( self, Q, @@ -1053,12 +1034,19 @@ def _attention_layers( uniform_seed=self.uniform_seed, initial_variables=self.attention_layer_variables, ) - input_xyz = tf.keras.layers.LayerNormalization( - beta_initializer=tf.constant_initializer(self.beta[i]), - gamma_initializer=tf.constant_initializer(self.gamma[i]), - dtype=self.filter_precision, - )(input_xyz) - # input_xyz = self._feedforward(input_xyz, outputs_size[-1], self.att_n) + input_xyz = layernorm( + input_xyz, + outputs_size[-1], + 
precision=self.filter_precision, + name="layer_normalization", + scope=name + "/", + reuse=tf.AUTO_REUSE, + seed=self.seed, + uniform_seed=self.uniform_seed, + trainable=self.trainable_ln, + eps=self.ln_eps, + initial_variables=self.attention_layer_variables, + ) return input_xyz def _filter_lower( @@ -1366,20 +1354,17 @@ def init_variables( self.attention_layer_variables = get_attention_layer_variables_from_graph_def( graph_def, suffix=suffix ) - if self.attn_layer > 0: - self.beta[0] = self.attention_layer_variables[ - f"attention_layer_0{suffix}/layer_normalization/beta" - ] - self.gamma[0] = self.attention_layer_variables[ - f"attention_layer_0{suffix}/layer_normalization/gamma" - ] - for i in range(1, self.attn_layer): - self.beta[i] = self.attention_layer_variables[ - f"attention_layer_{i}{suffix}/layer_normalization_{i}/beta" - ] - self.gamma[i] = self.attention_layer_variables[ - f"attention_layer_{i}{suffix}/layer_normalization_{i}/gamma" - ] + + def compat_ln_pattern(old_key): + pattern = r"attention_layer_(\d+)/(layer_normalization)_\d+" + replacement = r"attention_layer_\1/\2" + if bool(re.search(pattern, old_key)): + new_key = re.sub(pattern, replacement, old_key) + v = self.attention_layer_variables.pop(old_key) + self.attention_layer_variables[new_key] = v + + for item_key in list(self.attention_layer_variables.keys()): + compat_ln_pattern(item_key) if self.stripped_type_embedding: self.two_side_embeeding_net_variables = ( @@ -1493,3 +1478,658 @@ def update_sel(cls, global_jdata: dict, local_jdata: dict): """ local_jdata_cpy = local_jdata.copy() return UpdateSel().update_one_sel(global_jdata, local_jdata_cpy, True) + + def serialize_attention_layers( + self, + nlayer: int, + nnei: int, + embed_dim: int, + hidden_dim: int, + dotr: bool, + do_mask: bool, + trainable_ln: bool, + ln_eps: float, + variables: dict, + bias: bool = True, + suffix: str = "", + ) -> dict: + data = { + "layer_num": nlayer, + "nnei": nnei, + "embed_dim": embed_dim, + 
"hidden_dim": hidden_dim, + "dotr": dotr, + "do_mask": do_mask, + "trainable_ln": trainable_ln, + "ln_eps": ln_eps, + "precision": self.precision.name, + "attention_layers": [], + } + if suffix != "": + attention_layer_pattern = ( + ATTENTION_LAYER_PATTERN.replace("/(c_query)", suffix + "/(c_query)") + .replace("/(c_key)", suffix + "/(c_key)") + .replace("/(c_value)", suffix + "/(c_value)") + .replace("/(c_out)", suffix + "/(c_out)") + .replace("/(layer_normalization)", suffix + "/(layer_normalization)") + ) + else: + attention_layer_pattern = ATTENTION_LAYER_PATTERN + attention_layer_params = [{} for _ in range(nlayer)] + for key, value in variables.items(): + m = re.search(attention_layer_pattern, key) + m = [mm for mm in m.groups() if mm is not None] + assert len(m) == 3 + if m[1] not in attention_layer_params[int(m[0])]: + attention_layer_params[int(m[0])][m[1]] = {} + attention_layer_params[int(m[0])][m[1]][m[2]] = value + + for layer_idx in range(nlayer): + in_proj = NativeLayer( + embed_dim, + hidden_dim * 3, + bias=bias, + use_timestep=False, + precision=self.precision.name, + ) + matrix_list = [ + attention_layer_params[layer_idx][key]["matrix"] + for key in ["c_query", "c_key", "c_value"] + ] + in_proj["matrix"] = np.concatenate(matrix_list, axis=-1) + if bias: + bias_list = [ + attention_layer_params[layer_idx][key]["bias"] + for key in ["c_query", "c_key", "c_value"] + ] + in_proj["bias"] = np.concatenate(bias_list, axis=-1) + out_proj = NativeLayer( + hidden_dim, + embed_dim, + bias=bias, + use_timestep=False, + precision=self.precision.name, + ) + out_proj["matrix"] = attention_layer_params[layer_idx]["c_out"]["matrix"] + if bias: + out_proj["bias"] = attention_layer_params[layer_idx]["c_out"]["bias"] + + layer_norm = LayerNorm( + embed_dim, + trainable=self.trainable_ln, + eps=self.ln_eps, + precision=self.precision.name, + ) + layer_norm["matrix"] = attention_layer_params[layer_idx][ + "layer_normalization" + ]["gamma"] + layer_norm["bias"] = 
attention_layer_params[layer_idx][ + "layer_normalization" + ]["beta"] + data["attention_layers"].append( + { + "attention_layer": { + "in_proj": in_proj.serialize(), + "out_proj": out_proj.serialize(), + "bias": bias, + "smooth": self.smooth, + }, + "attn_layer_norm": layer_norm.serialize(), + "trainable_ln": self.trainable_ln, + "ln_eps": self.ln_eps, + } + ) + return data + + @classmethod + def deserialize_attention_layers(cls, data: dict, suffix: str = "") -> dict: + """Deserialize attention layers. + + Parameters + ---------- + data : dict + The input attention layer data + suffix : str, optional + The suffix of the scope + + Returns + ------- + variables : dict + The input variables + """ + attention_layer_variables = {} + nlayer = data["layer_num"] + hidden_dim = data["hidden_dim"] + + for layer_idx in range(nlayer): + in_proj = NativeLayer.deserialize( + data["attention_layers"][layer_idx]["attention_layer"]["in_proj"] + ) + out_proj = NativeLayer.deserialize( + data["attention_layers"][layer_idx]["attention_layer"]["out_proj"] + ) + layer_norm = LayerNorm.deserialize( + data["attention_layers"][layer_idx]["attn_layer_norm"] + ) + + # Deserialize in_proj + c_query_matrix = in_proj["matrix"][:, :hidden_dim] + c_key_matrix = in_proj["matrix"][:, hidden_dim : 2 * hidden_dim] + c_value_matrix = in_proj["matrix"][:, 2 * hidden_dim :] + attention_layer_variables[ + f"attention_layer_{layer_idx}{suffix}/c_query/matrix" + ] = c_query_matrix + attention_layer_variables[ + f"attention_layer_{layer_idx}{suffix}/c_key/matrix" + ] = c_key_matrix + attention_layer_variables[ + f"attention_layer_{layer_idx}{suffix}/c_value/matrix" + ] = c_value_matrix + if data["attention_layers"][layer_idx]["attention_layer"]["bias"]: + c_query_bias = in_proj["bias"][:hidden_dim] + c_key_bias = in_proj["bias"][hidden_dim : 2 * hidden_dim] + c_value_bias = in_proj["bias"][2 * hidden_dim :] + attention_layer_variables[ + f"attention_layer_{layer_idx}{suffix}/c_query/bias" + ] = 
c_query_bias + attention_layer_variables[ + f"attention_layer_{layer_idx}{suffix}/c_key/bias" + ] = c_key_bias + attention_layer_variables[ + f"attention_layer_{layer_idx}{suffix}/c_value/bias" + ] = c_value_bias + + # Deserialize out_proj + attention_layer_variables[ + f"attention_layer_{layer_idx}{suffix}/c_out/matrix" + ] = out_proj["matrix"] + if data["attention_layers"][layer_idx]["attention_layer"]["bias"]: + attention_layer_variables[ + f"attention_layer_{layer_idx}{suffix}/c_out/bias" + ] = out_proj["bias"] + + # Deserialize layer_norm + attention_layer_variables[ + f"attention_layer_{layer_idx}{suffix}/layer_normalization/beta" + ] = layer_norm["bias"] + attention_layer_variables[ + f"attention_layer_{layer_idx}{suffix}/layer_normalization/gamma" + ] = layer_norm["matrix"] + return attention_layer_variables + + @classmethod + def deserialize(cls, data: dict, suffix: str = ""): + """Deserialize the model. + + Parameters + ---------- + data : dict + The serialized data + + Returns + ------- + Model + The deserialized model + """ + if cls is not DescrptSeAtten: + raise NotImplementedError("Not implemented in class %s" % cls.__name__) + data = data.copy() + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + data.pop("type") + embedding_net_variables = cls.deserialize_network( + data.pop("embeddings"), suffix=suffix + ) + attention_layer_variables = cls.deserialize_attention_layers( + data.pop("attention_layers"), suffix=suffix + ) + data.pop("env_mat") + variables = data.pop("@variables") + descriptor = cls(**data) + descriptor.embedding_net_variables = embedding_net_variables + descriptor.attention_layer_variables = attention_layer_variables + descriptor.davg = variables["davg"].reshape( + descriptor.ntypes, descriptor.ndescrpt + ) + descriptor.dstd = variables["dstd"].reshape( + descriptor.ntypes, descriptor.ndescrpt + ) + return descriptor + + def serialize(self, suffix: str = "") -> dict: + """Serialize the model. 
+ + Parameters + ---------- + suffix : str, optional + The suffix of the scope + + Returns + ------- + dict + The serialized data + """ + if type(self) not in [DescrptSeAtten, DescrptDPA1Compat]: + raise NotImplementedError( + "Not implemented in class %s" % self.__class__.__name__ + ) + if self.stripped_type_embedding: + raise NotImplementedError( + "stripped_type_embedding is unsupported by the native model" + ) + if (self.original_sel != self.sel_a).any(): + raise NotImplementedError( + "Adjusting sel is unsupported by the native model" + ) + if self.embedding_net_variables is None: + raise RuntimeError("init_variables must be called before serialize") + if self.spin is not None: + raise NotImplementedError("spin is unsupported") + assert self.davg is not None + assert self.dstd is not None + + return { + "@class": "Descriptor", + "type": "se_atten", + "@version": 1, + "rcut": self.rcut_r, + "rcut_smth": self.rcut_r_smth, + "sel": self.sel_a, + "ntypes": self.ntypes, + "neuron": self.filter_neuron, + "axis_neuron": self.n_axis_neuron, + "set_davg_zero": self.set_davg_zero, + "attn": self.att_n, + "attn_layer": self.attn_layer, + "attn_dotr": self.attn_dotr, + "attn_mask": self.attn_mask, + "activation_function": self.activation_function_name, + "resnet_dt": self.filter_resnet_dt, + "smooth_type_embedding": self.smooth, + "trainable_ln": self.trainable_ln, + "ln_eps": self.ln_eps, + "precision": self.filter_precision.name, + "embeddings": self.serialize_network( + ntypes=self.ntypes, + ndim=0, + in_dim=1 + if not hasattr(self, "embd_input_dim") + else self.embd_input_dim, + neuron=self.filter_neuron, + activation_function=self.activation_function_name, + resnet_dt=self.filter_resnet_dt, + variables=self.embedding_net_variables, + excluded_types=self.exclude_types, + suffix=suffix, + ), + "attention_layers": self.serialize_attention_layers( + nlayer=self.attn_layer, + nnei=self.nnei_a, + embed_dim=self.filter_neuron[-1], + hidden_dim=self.att_n, + 
dotr=self.attn_dotr, + do_mask=self.attn_mask, + trainable_ln=self.trainable_ln, + ln_eps=self.ln_eps, + variables=self.attention_layer_variables, + suffix=suffix, + ), + "env_mat": EnvMat(self.rcut_r, self.rcut_r_smth).serialize(), + "exclude_types": list(self.orig_exclude_types), + "env_protection": self.env_protection, + "@variables": { + "davg": self.davg.reshape(self.ntypes, self.nnei_a, 4), + "dstd": self.dstd.reshape(self.ntypes, self.nnei_a, 4), + }, + "trainable": self.trainable, + "type_one_side": self.type_one_side, + "spin": self.spin, + } + + +class DescrptDPA1Compat(DescrptSeAtten): + r"""Consistent version of the model for testing with other backend references. + + This model includes the type_embedding as attributes and other additional parameters. + + Parameters + ---------- + rcut: float + The cut-off radius :math:`r_c` + rcut_smth: float + From where the environment matrix should be smoothed :math:`r_s` + sel: list[int], int + list[int]: sel[i] specifies the maxmum number of type i atoms in the cut-off radius + int: the total maxmum number of atoms in the cut-off radius + ntypes: int + Number of element types + neuron: list[int] + Number of neurons in each hidden layers of the embedding net :math:`\mathcal{N}` + axis_neuron: int + Number of the axis neuron :math:`M_2` (number of columns of the sub-matrix of the embedding matrix) + tebd_dim: int + Dimension of the type embedding + tebd_input_mode: str + (Only support `concat` to keep consistent with other backend references.) + The way to mix the type embeddings. + resnet_dt: bool + Time-step `dt` in the resnet construction: + y = x + dt * \phi (Wx + b) + trainable: bool + If the weights of this descriptors are trainable. + trainable_ln: bool + Whether to use trainable shift and scale weights in layer normalization. + ln_eps: float, Optional + The epsilon value for layer normalization. + type_one_side: bool + If 'False', type embeddings of both neighbor and central atoms are considered. 
+ If 'True', only type embeddings of neighbor atoms are considered. + Default is 'False'. + attn: int + Hidden dimension of the attention vectors + attn_layer: int + Number of attention layers + attn_dotr: bool + If dot the angular gate to the attention weights + attn_mask: bool + (Only support False to keep consistent with other backend references.) + If mask the diagonal of attention weights + exclude_types : List[List[int]] + The excluded pairs of types which have no interaction with each other. + For example, `[[0, 1]]` means no interaction between type 0 and type 1. + env_protection: float + Protection parameter to prevent division by zero errors during environment matrix calculations. + set_davg_zero: bool + Set the shift of embedding net input to zero. + activation_function: str + The activation function in the embedding net. Supported options are |ACTIVATION_FN| + precision: str + The precision of the embedding net parameters. Supported options are |PRECISION| + scaling_factor: float + (Only to keep consistent with other backend references.) + (Not used in this version.) + The scaling factor of normalization in calculations of attention weights. + If `temperature` is None, the scaling of attention weights is (N_dim * scaling_factor)**0.5 + normalize: bool + (Only support True to keep consistent with other backend references.) + (Not used in this version.) + Whether to normalize the hidden vectors in attention weights calculation. + temperature: float + (Only support 1.0 to keep consistent with other backend references.) + (Not used in this version.) + If not None, the scaling of attention weights is `temperature` itself. + smooth_type_embedding: bool + (Only support False to keep consistent with other backend references.) + Whether to use smooth process in attention weights calculation. + concat_output_tebd: bool + Whether to concat type embedding at the output of the descriptor. + spin + (Only support None to keep consistent with old implementation.) 
+ The old implementation of deepspin. + """ + + def __init__( + self, + rcut: float, + rcut_smth: float, + sel: Union[List[int], int], + ntypes: int, + neuron: List[int] = [25, 50, 100], + axis_neuron: int = 8, + tebd_dim: int = 8, + tebd_input_mode: str = "concat", + resnet_dt: bool = False, + trainable: bool = True, + type_one_side: bool = True, + attn: int = 128, + attn_layer: int = 2, + attn_dotr: bool = True, + attn_mask: bool = False, + exclude_types: List[List[int]] = [], + env_protection: float = 0.0, + set_davg_zero: bool = False, + activation_function: str = "tanh", + precision: str = "default", + scaling_factor=1.0, + normalize: bool = True, + temperature: Optional[float] = None, + trainable_ln: bool = True, + ln_eps: Optional[float] = 1e-3, + smooth_type_embedding: bool = True, + concat_output_tebd: bool = True, + spin: Optional[Any] = None, + # consistent with argcheck, not used though + seed: Optional[int] = None, + uniform_seed: bool = False, + ) -> None: + if tebd_input_mode != "concat": + raise NotImplementedError( + "Only support tebd_input_mode == `concat` in this version." + ) + if not normalize: + raise NotImplementedError("Only support normalize == True in this version.") + if temperature != 1.0: + raise NotImplementedError( + "Only support temperature == 1.0 in this version." + ) + if spin is not None: + raise NotImplementedError("Only support spin is None in this version.") + if attn_mask: + raise NotImplementedError( + "old implementation of attn_mask is not supported." 
+ ) + # to keep consistent with default value in this backends + if ln_eps is None: + ln_eps = 1e-3 + + super().__init__( + rcut, + rcut_smth, + sel, + ntypes, + neuron=neuron, + axis_neuron=axis_neuron, + resnet_dt=resnet_dt, + trainable=trainable, + seed=seed, + type_one_side=type_one_side, + set_davg_zero=set_davg_zero, + exclude_types=exclude_types, + activation_function=activation_function, + precision=precision, + uniform_seed=uniform_seed, + attn=attn, + attn_layer=attn_layer, + attn_dotr=attn_dotr, + attn_mask=attn_mask, + multi_task=True, + stripped_type_embedding=False, + trainable_ln=trainable_ln, + ln_eps=ln_eps, + smooth_type_embedding=smooth_type_embedding, + env_protection=env_protection, + ) + self.tebd_dim = tebd_dim + self.tebd_input_mode = tebd_input_mode + self.scaling_factor = scaling_factor + self.normalize = normalize + self.temperature = temperature + self.type_embedding = TypeEmbedNet( + ntypes=self.ntypes, + neuron=[self.tebd_dim], + padding=True, + activation_function="Linear", + # precision=precision, + ) + self.concat_output_tebd = concat_output_tebd + if self.tebd_input_mode in ["concat"]: + if not self.type_one_side: + self.embd_input_dim = 1 + self.tebd_dim * 2 + else: + self.embd_input_dim = 1 + self.tebd_dim + else: + self.embd_input_dim = 1 + + def build( + self, + coord_: tf.Tensor, + atype_: tf.Tensor, + natoms: tf.Tensor, + box_: tf.Tensor, + mesh: tf.Tensor, + input_dict: dict, + reuse: Optional[bool] = None, + suffix: str = "", + ) -> tf.Tensor: + type_embedding = self.type_embedding.build(self.ntypes, suffix=suffix) + input_dict["type_embedding"] = type_embedding + + # nf x nloc x out_dim + self.dout = super().build( + coord_, + atype_, + natoms, + box_, + mesh, + input_dict, + reuse=reuse, + suffix=suffix, + ) + # self.dout = tf.cast(self.dout, self.filter_precision) + if self.concat_output_tebd: + atype = tf.reshape(atype_, [-1, natoms[1]]) + atype_nloc = tf.reshape( + tf.slice(atype, [0, 0], [-1, natoms[0]]), [-1] + ) ## 
lammps will have error without this + atom_embed = tf.reshape( + tf.nn.embedding_lookup(type_embedding, atype_nloc), + [-1, natoms[0], self.tebd_dim], + ) + atom_embed = tf.cast(atom_embed, GLOBAL_TF_FLOAT_PRECISION) + # nf x nloc x (out_dim + tebd_dim) + self.dout = tf.concat([self.dout, atom_embed], axis=-1) + return self.dout + + def init_variables( + self, + graph: tf.Graph, + graph_def: tf.GraphDef, + suffix: str = "", + ) -> None: + """Init the embedding net variables with the given dict. + + Parameters + ---------- + graph : tf.Graph + The input frozen model graph + graph_def : tf.GraphDef + The input frozen model graph_def + suffix : str, optional + The suffix of the scope + """ + super().init_variables(graph=graph, graph_def=graph_def, suffix=suffix) + self.type_embedding.init_variables( + graph=graph, graph_def=graph_def, suffix=suffix + ) + + def update_attention_layers_serialize(self, data: dict): + """Update the serialized data to be consistent with other backend references.""" + new_dict = { + "@class": "NeighborGatedAttention", + "@version": 1, + "scaling_factor": self.scaling_factor, + "normalize": self.normalize, + "temperature": self.temperature, + } + new_dict.update(data) + update_info = { + "nnei": self.nnei_a, + "embed_dim": self.filter_neuron[-1], + "hidden_dim": self.att_n, + "dotr": self.attn_dotr, + "do_mask": self.attn_mask, + "scaling_factor": self.scaling_factor, + "normalize": self.normalize, + "temperature": self.temperature, + "precision": self.filter_precision.name, + } + for layer_idx in range(self.attn_layer): + new_dict["attention_layers"][layer_idx].update(update_info) + new_dict["attention_layers"][layer_idx]["attention_layer"].update( + update_info + ) + return new_dict + + @classmethod + def deserialize(cls, data: dict, suffix: str = ""): + """Deserialize the model. 
+ + Parameters + ---------- + data : dict + The serialized data + + Returns + ------- + Model + The deserialized model + """ + if cls is not DescrptDPA1Compat: + raise NotImplementedError("Not implemented in class %s" % cls.__name__) + data = data.copy() + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + data.pop("type") + embedding_net_variables = cls.deserialize_network( + data.pop("embeddings"), suffix=suffix + ) + attention_layer_variables = cls.deserialize_attention_layers( + data.pop("attention_layers"), suffix=suffix + ) + data.pop("env_mat") + variables = data.pop("@variables") + type_embedding = data.pop("type_embedding") + descriptor = cls(**data) + descriptor.embedding_net_variables = embedding_net_variables + descriptor.attention_layer_variables = attention_layer_variables + descriptor.davg = variables["davg"].reshape( + descriptor.ntypes, descriptor.ndescrpt + ) + descriptor.dstd = variables["dstd"].reshape( + descriptor.ntypes, descriptor.ndescrpt + ) + descriptor.type_embedding = TypeEmbedNet.deserialize( + type_embedding, suffix=suffix + ) + return descriptor + + def serialize(self, suffix: str = "") -> dict: + """Serialize the model. 
+ + Parameters + ---------- + suffix : str, optional + The suffix of the scope + + Returns + ------- + dict + The serialized data + """ + data = super().serialize(suffix) + data.update( + { + "type": "dpa1", + "tebd_dim": self.tebd_dim, + "tebd_input_mode": self.tebd_input_mode, + "scaling_factor": self.scaling_factor, + "normalize": self.normalize, + "temperature": self.temperature, + "concat_output_tebd": self.concat_output_tebd, + "type_embedding": self.type_embedding.serialize(suffix), + } + ) + data["attention_layers"] = self.update_attention_layers_serialize( + data["attention_layers"] + ) + return data diff --git a/deepmd/tf/env.py b/deepmd/tf/env.py index c7873b951c..3d4edadd8a 100644 --- a/deepmd/tf/env.py +++ b/deepmd/tf/env.py @@ -178,19 +178,19 @@ def dlopen_library(module: str, filename: str): )[:-1] ATTENTION_LAYER_PATTERN = str( - r"attention_layer_\d+/c_query/matrix|" - r"attention_layer_\d+/c_query/bias|" - r"attention_layer_\d+/c_key/matrix|" - r"attention_layer_\d+/c_key/bias|" - r"attention_layer_\d+/c_value/matrix|" - r"attention_layer_\d+/c_value/bias|" - r"attention_layer_\d+/c_out/matrix|" - r"attention_layer_\d+/c_out/bias|" - r"attention_layer_\d+/layer_normalization/beta|" - r"attention_layer_\d+/layer_normalization/gamma|" - r"attention_layer_\d+/layer_normalization_\d+/beta|" - r"attention_layer_\d+/layer_normalization_\d+/gamma|" -) + r"attention_layer_(\d+)/(c_query)/(matrix)|" + r"attention_layer_(\d+)/(c_query)/(bias)|" + r"attention_layer_(\d+)/(c_key)/(matrix)|" + r"attention_layer_(\d+)/(c_key)/(bias)|" + r"attention_layer_(\d+)/(c_value)/(matrix)|" + r"attention_layer_(\d+)/(c_value)/(bias)|" + r"attention_layer_(\d+)/(c_out)/(matrix)|" + r"attention_layer_(\d+)/(c_out)/(bias)|" + r"attention_layer_(\d+)/(layer_normalization)/(beta)|" + r"attention_layer_(\d+)/(layer_normalization)/(gamma)|" + r"attention_layer_(\d+)/(layer_normalization)_\d+/(beta)|" + r"attention_layer_(\d+)/(layer_normalization)_\d+/(gamma)|" +)[:-1] 
TRANSFER_PATTERN = ( EMBEDDING_NET_PATTERN diff --git a/deepmd/tf/utils/graph.py b/deepmd/tf/utils/graph.py index a6e2ab7422..53de9c9ce2 100644 --- a/deepmd/tf/utils/graph.py +++ b/deepmd/tf/utils/graph.py @@ -455,11 +455,11 @@ def get_attention_layer_nodes_from_graph_def( """ if suffix != "": attention_layer_pattern = ( - ATTENTION_LAYER_PATTERN.replace("/c_query", suffix + "/c_query") - .replace("/c_key", suffix + "/c_key") - .replace("/c_value", suffix + "/c_value") - .replace("/c_out", suffix + "/c_out") - .replace("/layer_normalization", suffix + "/layer_normalization") + ATTENTION_LAYER_PATTERN.replace("/(c_query)", suffix + "/(c_query)") + .replace("/(c_key)", suffix + "/(c_key)") + .replace("/(c_value)", suffix + "/(c_value)") + .replace("/(c_out)", suffix + "/(c_out)") + .replace("/(layer_normalization)", suffix + "/(layer_normalization)") ) else: attention_layer_pattern = ATTENTION_LAYER_PATTERN diff --git a/deepmd/tf/utils/network.py b/deepmd/tf/utils/network.py index fb8e89c737..7918b58d0c 100644 --- a/deepmd/tf/utils/network.py +++ b/deepmd/tf/utils/network.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later + import numpy as np from deepmd.tf.common import ( @@ -105,6 +106,100 @@ def one_layer( return hidden +def layer_norm_tf(x, shape, weight=None, bias=None, eps=1e-5): + """ + Layer normalization implementation in TensorFlow. + + Parameters + ---------- + x : tf.Tensor + The input tensor. + shape : tuple + The shape of the weight and bias tensors. + weight : tf.Tensor + The weight tensor. + bias : tf.Tensor + The bias tensor. + eps : float + A small value added to prevent division by zero. + + Returns + ------- + tf.Tensor + The normalized output tensor. 
+ """ + # Calculate the mean and variance + mean = tf.reduce_mean(x, axis=list(range(-len(shape), 0)), keepdims=True) + variance = tf.reduce_mean( + tf.square(x - mean), axis=list(range(-len(shape), 0)), keepdims=True + ) + + # Normalize the input + x_ln = (x - mean) / tf.sqrt(variance + eps) + + # Scale and shift the normalized input + if weight is not None and bias is not None: + x_ln = x_ln * weight + bias + + return x_ln + + +def layernorm( + inputs, + outputs_size, + precision=GLOBAL_TF_FLOAT_PRECISION, + name="linear", + scope="", + reuse=None, + seed=None, + uniform_seed=False, + uni_init=True, + eps=1e-5, + trainable=True, + initial_variables=None, +): + with tf.variable_scope(name, reuse=reuse): + shape = inputs.get_shape().as_list() + if uni_init: + gamma_initializer = tf.ones_initializer() + beta_initializer = tf.zeros_initializer() + else: + gamma_initializer = tf.random_normal_initializer( + seed=seed if (seed is None or uniform_seed) else seed + 0 + ) + beta_initializer = tf.random_normal_initializer( + seed=seed if (seed is None or uniform_seed) else seed + 1 + ) + if initial_variables is not None: + gamma_initializer = tf.constant_initializer( + initial_variables[scope + name + "/gamma"] + ) + beta_initializer = tf.constant_initializer( + initial_variables[scope + name + "/beta"] + ) + gamma = tf.get_variable( + "gamma", + [outputs_size], + precision, + gamma_initializer, + trainable=trainable, + ) + variable_summaries(gamma, "gamma") + beta = tf.get_variable( + "beta", [outputs_size], precision, beta_initializer, trainable=trainable + ) + variable_summaries(beta, "beta") + + output = layer_norm_tf( + inputs, + (outputs_size,), + weight=gamma, + bias=beta, + eps=eps, + ) + return output + + def embedding_net_rand_seed_shift(network_size): shift = 3 * (len(network_size) + 1) return shift diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 323ce44dfe..94c225010a 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ 
-409,7 +409,7 @@ def descrpt_se_atten_common_args(): ) doc_type_one_side = ( doc_only_tf_supported - + r"If true, the embedding network parameters vary by types of neighbor atoms only, so there will be $N_\text{types}$ sets of embedding network parameters. Otherwise, the embedding network parameters vary by types of centric atoms and types of neighbor atoms, so there will be $N_\text{types}^2$ sets of embedding network parameters." + + r"If 'False', type embeddings of both neighbor and central atoms are considered. If 'True', only type embeddings of neighbor atoms are considered. Default is 'False'." ) doc_precision = ( doc_only_tf_supported @@ -476,8 +476,12 @@ def descrpt_se_atten_common_args(): @descrpt_args_plugin.register("se_atten", alias=["dpa1"]) def descrpt_se_atten_args(): doc_stripped_type_embedding = "Whether to strip the type embedding into a separated embedding network. Setting it to `False` will fall back to the previous version of `se_atten` which is non-compressible." - doc_smooth_type_embedding = "When using stripped type embedding, whether to dot smooth factor on the network output of type embedding to keep the network smooth, instead of setting `set_davg_zero` to be True." + doc_smooth_type_embedding = f"Whether to use smooth process in attention weights calculation. {doc_only_tf_supported} When using stripped type embedding, whether to dot smooth factor on the network output of type embedding to keep the network smooth, instead of setting `set_davg_zero` to be True." doc_set_davg_zero = "Set the normalization average to zero. This option should be set when `se_atten` descriptor or `atom_ener` in the energy fitting is used" + doc_trainable_ln = ( + "Whether to use trainable shift and scale weights in layer normalization." + ) + doc_ln_eps = "The epsilon value for layer normalization. The default value for TensorFlow is set to 1e-3 to keep consistent with keras while set to 1e-5 in PyTorch and DP implementation." 
doc_tebd_dim = "The dimension of atom type embedding." doc_temperature = "The scaling factor of normalization in calculations of attention weights, which is used to scale the matmul(Q, K)." doc_scaling_factor = ( @@ -508,11 +512,15 @@ def descrpt_se_atten_args(): optional=True, default=False, alias=["smooth_type_embdding"], - doc=doc_only_tf_supported + doc_smooth_type_embedding, + doc=doc_smooth_type_embedding, ), Argument( "set_davg_zero", bool, optional=True, default=True, doc=doc_set_davg_zero ), + Argument( + "trainable_ln", bool, optional=True, default=True, doc=doc_trainable_ln + ), + Argument("ln_eps", float, optional=True, default=None, doc=doc_ln_eps), # pt only Argument( "tebd_dim", @@ -528,27 +536,6 @@ def descrpt_se_atten_args(): default="concat", doc=doc_only_pt_supported + doc_deprecated, ), - Argument( - "post_ln", - bool, - optional=True, - default=True, - doc=doc_only_pt_supported + doc_deprecated, - ), - Argument( - "ffn", - bool, - optional=True, - default=False, - doc=doc_only_pt_supported + doc_deprecated, - ), - Argument( - "ffn_embed_dim", - int, - optional=True, - default=1024, - doc=doc_only_pt_supported + doc_deprecated, - ), Argument( "scaling_factor", float, @@ -556,13 +543,6 @@ def descrpt_se_atten_args(): default=1.0, doc=doc_only_pt_supported + doc_scaling_factor, ), - Argument( - "head_num", - int, - optional=True, - default=1, - doc=doc_only_pt_supported + doc_deprecated, - ), Argument( "normalize", bool, @@ -576,13 +556,6 @@ def descrpt_se_atten_args(): optional=True, doc=doc_only_pt_supported + doc_temperature, ), - Argument( - "return_rot", - bool, - optional=True, - default=False, - doc=doc_only_pt_supported + doc_deprecated, - ), Argument( "concat_output_tebd", bool, diff --git a/doc/model/train-se-atten.md b/doc/model/train-se-atten.md index 1511ac7fac..59333eb0da 100644 --- a/doc/model/train-se-atten.md +++ b/doc/model/train-se-atten.md @@ -133,7 +133,6 @@ An example of the DPA-1 descriptor is provided as follows 
"attn_layer": 2, "attn_mask": false, "attn_dotr": true, - "post_ln": true } ``` @@ -147,7 +146,6 @@ An example of the DPA-1 descriptor is provided as follows - {ref}`attn_layer ` sets the number of layers in attention mechanism. - {ref}`attn_mask ` determines whether to mask the diagonal in the attention weights and False is recommended. - {ref}`attn_dotr ` determines whether to dot the relative coordinates on the attention weights as a gated scheme, True is recommended. -- {ref}`post_ln ` determines whether to perform post layer norm. ::: diff --git a/examples/water/se_atten/input_torch.json b/examples/water/se_atten/input_torch.json index 4160feda17..cdb4b0db49 100644 --- a/examples/water/se_atten/input_torch.json +++ b/examples/water/se_atten/input_torch.json @@ -22,12 +22,8 @@ "attn_layer": 2, "attn_dotr": true, "attn_mask": false, - "post_ln": true, - "ffn": false, - "ffn_embed_dim": 1024, "activation_function": "tanh", "scaling_factor": 1.0, - "head_num": 1, "normalize": true, "temperature": 1.0 }, diff --git a/source/tests/common/dpmodel/test_descriptor_dpa1.py b/source/tests/common/dpmodel/test_descriptor_dpa1.py new file mode 100644 index 0000000000..d50e0cba86 --- /dev/null +++ b/source/tests/common/dpmodel/test_descriptor_dpa1.py @@ -0,0 +1,35 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import numpy as np + +from deepmd.dpmodel.descriptor import ( + DescrptDPA1, +) + +from .case_single_frame_with_nlist import ( + TestCaseSingleFrameWithNlist, +) + + +class TestDescrptDPA1(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self): + TestCaseSingleFrameWithNlist.setUp(self) + + def test_self_consistency( + self, + ): + rng = np.random.default_rng() + nf, nloc, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + em0 = DescrptDPA1(self.rcut, self.rcut_smth, self.sel, ntypes=2) + em0.davg = davg + em0.dstd = dstd + em1 = 
DescrptDPA1.deserialize(em0.serialize()) + mm0 = em0.call(self.coord_ext, self.atype_ext, self.nlist) + mm1 = em1.call(self.coord_ext, self.atype_ext, self.nlist) + for ii in [0, 1, 4]: + np.testing.assert_allclose(mm0[ii], mm1[ii]) diff --git a/source/tests/consistent/descriptor/common.py b/source/tests/consistent/descriptor/common.py index ef7b39b52e..9c8c1cea7f 100644 --- a/source/tests/consistent/descriptor/common.py +++ b/source/tests/consistent/descriptor/common.py @@ -57,7 +57,9 @@ def build_tf_descriptor(self, obj, natoms, coords, atype, box, suffix): t_mesh: make_default_mesh(True, False), } - def eval_dp_descriptor(self, dp_obj: Any, natoms, coords, atype, box) -> Any: + def eval_dp_descriptor( + self, dp_obj: Any, natoms, coords, atype, box, mixed_types: bool = False + ) -> Any: ext_coords, ext_atype, mapping = extend_coord_with_ghosts( coords.reshape(1, -1, 3), atype.reshape(1, -1), @@ -70,11 +72,13 @@ def eval_dp_descriptor(self, dp_obj: Any, natoms, coords, atype, box) -> Any: natoms[0], dp_obj.get_rcut(), dp_obj.get_sel(), - distinguish_types=True, + distinguish_types=(not mixed_types), ) return dp_obj(ext_coords, ext_atype, nlist=nlist) - def eval_pt_descriptor(self, pt_obj: Any, natoms, coords, atype, box) -> Any: + def eval_pt_descriptor( + self, pt_obj: Any, natoms, coords, atype, box, mixed_types: bool = False + ) -> Any: ext_coords, ext_atype, mapping = extend_coord_with_ghosts_pt( torch.from_numpy(coords).to(PT_DEVICE).reshape(1, -1, 3), torch.from_numpy(atype).to(PT_DEVICE).reshape(1, -1), @@ -87,7 +91,7 @@ def eval_pt_descriptor(self, pt_obj: Any, natoms, coords, atype, box) -> Any: natoms[0], pt_obj.get_rcut(), pt_obj.get_sel(), - distinguish_types=True, + distinguish_types=(not mixed_types), ) return [ x.detach().cpu().numpy() if torch.is_tensor(x) else x diff --git a/source/tests/consistent/descriptor/test_dpa1.py b/source/tests/consistent/descriptor/test_dpa1.py new file mode 100644 index 0000000000..c0ca46c91e --- /dev/null +++ 
b/source/tests/consistent/descriptor/test_dpa1.py @@ -0,0 +1,333 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest +from typing import ( + Any, + Tuple, +) + +import numpy as np +from dargs import ( + Argument, +) + +from deepmd.dpmodel.descriptor.dpa1 import DescrptDPA1 as DescrptDPA1DP +from deepmd.env import ( + GLOBAL_NP_FLOAT_PRECISION, +) + +from ..common import ( + INSTALLED_PT, + INSTALLED_TF, + CommonTest, + parameterized, +) +from .common import ( + DescriptorTest, +) + +if INSTALLED_PT: + from deepmd.pt.model.descriptor.dpa1 import DescrptDPA1 as DescrptDPA1PT +else: + DescrptDPA1PT = None +if INSTALLED_TF: + from deepmd.tf.descriptor.se_atten import DescrptDPA1Compat as DescrptDPA1TF +else: + DescrptDPA1TF = None +from deepmd.utils.argcheck import ( + descrpt_se_atten_args, +) + + +@parameterized( + (4,), # tebd_dim + ("concat",), # tebd_input_mode + (True,), # resnet_dt + (True, False), # type_one_side + (20,), # attn + (0, 2), # attn_layer + (True, False), # attn_dotr + ([], [[0, 1]]), # excluded_types + (0.0,), # env_protection + (True, False), # set_davg_zero + (1.0,), # scaling_factor + (True, False), # normalize + (None, 1.0), # temperature + (1e-5,), # ln_eps + (True, False), # smooth_type_embedding + (True, False), # concat_output_tebd + ("float64",), # precision +) +class TestDPA1(CommonTest, DescriptorTest, unittest.TestCase): + @property + def data(self) -> dict: + ( + tebd_dim, + tebd_input_mode, + resnet_dt, + type_one_side, + attn, + attn_layer, + attn_dotr, + excluded_types, + env_protection, + set_davg_zero, + scaling_factor, + normalize, + temperature, + ln_eps, + smooth_type_embedding, + concat_output_tebd, + precision, + ) = self.param + return { + "sel": [10], + "rcut_smth": 5.80, + "rcut": 6.00, + "neuron": [6, 12, 24], + "ntypes": self.ntypes, + "axis_neuron": 3, + "tebd_dim": tebd_dim, + "tebd_input_mode": tebd_input_mode, + "attn": attn, + "attn_layer": attn_layer, + "attn_dotr": attn_dotr, + "attn_mask": False, + 
"scaling_factor": scaling_factor, + "normalize": normalize, + "temperature": temperature, + "ln_eps": ln_eps, + "concat_output_tebd": concat_output_tebd, + "resnet_dt": resnet_dt, + "type_one_side": type_one_side, + "exclude_types": excluded_types, + "env_protection": env_protection, + "precision": precision, + "set_davg_zero": set_davg_zero, + "smooth_type_embedding": smooth_type_embedding, + "seed": 1145141919810, + } + + @property + def skip_pt(self) -> bool: + ( + tebd_dim, + tebd_input_mode, + resnet_dt, + type_one_side, + attn, + attn_layer, + attn_dotr, + excluded_types, + env_protection, + set_davg_zero, + scaling_factor, + normalize, + temperature, + ln_eps, + smooth_type_embedding, + concat_output_tebd, + precision, + ) = self.param + return CommonTest.skip_pt + + @property + def skip_dp(self) -> bool: + ( + tebd_dim, + tebd_input_mode, + resnet_dt, + type_one_side, + attn, + attn_layer, + attn_dotr, + excluded_types, + env_protection, + set_davg_zero, + scaling_factor, + normalize, + temperature, + ln_eps, + smooth_type_embedding, + concat_output_tebd, + precision, + ) = self.param + return CommonTest.skip_pt + + @property + def skip_tf(self) -> bool: + ( + tebd_dim, + tebd_input_mode, + resnet_dt, + type_one_side, + attn, + attn_layer, + attn_dotr, + excluded_types, + env_protection, + set_davg_zero, + scaling_factor, + normalize, + temperature, + ln_eps, + smooth_type_embedding, + concat_output_tebd, + precision, + ) = self.param + # TODO (excluded_types != [] and attn_layer > 0) need fix + return ( + env_protection != 0.0 + or smooth_type_embedding + or not normalize + or temperature != 1.0 + or (excluded_types != [] and attn_layer > 0) + ) + + tf_class = DescrptDPA1TF + dp_class = DescrptDPA1DP + pt_class = DescrptDPA1PT + args = descrpt_se_atten_args().append(Argument("ntypes", int, optional=False)) + + def setUp(self): + CommonTest.setUp(self) + + self.ntypes = 2 + self.coords = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 
00.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ) + self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32) + self.box = np.array( + [13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ) + self.natoms = np.array([6, 6, 2, 4], dtype=np.int32) + ( + tebd_dim, + tebd_input_mode, + resnet_dt, + type_one_side, + attn, + attn_layer, + attn_dotr, + excluded_types, + env_protection, + set_davg_zero, + scaling_factor, + normalize, + temperature, + ln_eps, + smooth_type_embedding, + concat_output_tebd, + precision, + ) = self.param + + def build_tf(self, obj: Any, suffix: str) -> Tuple[list, dict]: + return self.build_tf_descriptor( + obj, + self.natoms, + self.coords, + self.atype, + self.box, + suffix, + ) + + def eval_dp(self, dp_obj: Any) -> Any: + return self.eval_dp_descriptor( + dp_obj, + self.natoms, + self.coords, + self.atype, + self.box, + mixed_types=True, + ) + + def eval_pt(self, pt_obj: Any) -> Any: + return self.eval_pt_descriptor( + pt_obj, + self.natoms, + self.coords, + self.atype, + self.box, + mixed_types=True, + ) + + def extract_ret(self, ret: Any, backend) -> Tuple[np.ndarray, ...]: + return (ret[0],) + + @property + def rtol(self) -> float: + """Relative tolerance for comparing the return value.""" + ( + tebd_dim, + tebd_input_mode, + resnet_dt, + type_one_side, + attn, + attn_layer, + attn_dotr, + excluded_types, + env_protection, + set_davg_zero, + scaling_factor, + normalize, + temperature, + ln_eps, + smooth_type_embedding, + concat_output_tebd, + precision, + ) = self.param + if precision == "float64": + return 1e-10 + elif precision == "float32": + return 1e-4 + else: + raise ValueError(f"Unknown precision: {precision}") + + @property + def atol(self) -> float: + """Absolute tolerance for comparing the return value.""" + ( + tebd_dim, + tebd_input_mode, + resnet_dt, + type_one_side, + attn, + attn_layer, + attn_dotr, + 
excluded_types, + env_protection, + set_davg_zero, + scaling_factor, + normalize, + temperature, + ln_eps, + smooth_type_embedding, + concat_output_tebd, + precision, + ) = self.param + if precision == "float64": + return 1e-10 + elif precision == "float32": + return 1e-4 + else: + raise ValueError(f"Unknown precision: {precision}") diff --git a/source/tests/pt/model/models/dpa1.json b/source/tests/pt/model/models/dpa1.json index 5d2c65c214..1321acbd53 100644 --- a/source/tests/pt/model/models/dpa1.json +++ b/source/tests/pt/model/models/dpa1.json @@ -18,12 +18,8 @@ "attn_layer": 2, "attn_dotr": true, "attn_mask": false, - "post_ln": true, - "ffn": false, - "ffn_embed_dim": 10, "activation_function": "tanh", "scaling_factor": 1.0, - "head_num": 1, "normalize": true, "temperature": 1.0 }, diff --git a/source/tests/pt/model/models/dpa1.pth b/source/tests/pt/model/models/dpa1.pth index 75acf2fa15..47766b10c2 100644 Binary files a/source/tests/pt/model/models/dpa1.pth and b/source/tests/pt/model/models/dpa1.pth differ diff --git a/source/tests/pt/model/models/dpa2.pth b/source/tests/pt/model/models/dpa2.pth index 0559d30c48..26d6155272 100644 Binary files a/source/tests/pt/model/models/dpa2.pth and b/source/tests/pt/model/models/dpa2.pth differ diff --git a/source/tests/pt/model/models/dpa2_hyb.json b/source/tests/pt/model/models/dpa2_hyb.json index ee69ed4d69..f7d2234fae 100644 --- a/source/tests/pt/model/models/dpa2_hyb.json +++ b/source/tests/pt/model/models/dpa2_hyb.json @@ -22,12 +22,8 @@ "attn_layer": 0, "attn_dotr": true, "attn_mask": false, - "post_ln": true, - "ffn": false, - "ffn_embed_dim": 10, "activation_function": "tanh", "scaling_factor": 1.0, - "head_num": 1, "normalize": true, "temperature": 1.0 }, diff --git a/source/tests/pt/model/test_dpa1.py b/source/tests/pt/model/test_dpa1.py new file mode 100644 index 0000000000..7567f18593 --- /dev/null +++ b/source/tests/pt/model/test_dpa1.py @@ -0,0 +1,191 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later 
+import itertools
+import unittest
+
+import numpy as np
+import torch
+
+from deepmd.dpmodel.descriptor.dpa1 import DescrptDPA1 as DPDescrptDPA1
+from deepmd.pt.model.descriptor.dpa1 import (
+    DescrptDPA1,
+)
+from deepmd.pt.utils import (
+    env,
+)
+from deepmd.pt.utils.env import (
+    PRECISION_DICT,
+)
+
+from .test_env_mat import (
+    TestCaseSingleFrameWithNlist,
+)
+from .test_mlp import (
+    get_tols,
+)
+
+dtype = env.GLOBAL_PT_FLOAT_PRECISION
+
+
+class TestDescrptSeAtten(unittest.TestCase, TestCaseSingleFrameWithNlist):
+    """Tests of the PyTorch ``DescrptDPA1`` descriptor."""
+
+    def setUp(self):
+        TestCaseSingleFrameWithNlist.setUp(self)
+
+    def _torch_inputs(self, dtype):
+        """Return fresh input tensors (coord_ext, atype_ext, nlist) on env.DEVICE."""
+        return (
+            torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE),
+            torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE),
+            torch.tensor(self.nlist, dtype=int, device=env.DEVICE),
+        )
+
+    def test_consistency(
+        self,
+    ):
+        """Compare the new PyTorch implementation with its counterparts.
+
+        For every parameter combination the output must match a copy rebuilt
+        via ``serialize``/``deserialize`` and the ``deepmd.dpmodel`` descriptor
+        rebuilt from the same data; for a subset of the combinations it must
+        also match the old PyTorch implementation loaded with the same weights.
+        """
+        rng = np.random.default_rng(100)
+        _, _, nnei = self.nlist.shape
+        davg = rng.normal(size=(self.nt, nnei, 4))
+        dstd = 0.1 + np.abs(rng.normal(size=(self.nt, nnei, 4)))
+
+        for idt, prec, sm, to in itertools.product(
+            [False, True],  # resnet_dt
+            ["float64", "float32"],  # precision
+            [False, True],  # smooth_type_embedding
+            [False, True],  # type_one_side
+        ):
+            dtype = PRECISION_DICT[prec]
+            rtol, atol = get_tols(prec)
+            # name every swept parameter so that a failure identifies its combination
+            err_msg = f"idt={idt} prec={prec} sm={sm} to={to}"
+
+            # dpa1 new impl
+            dd0 = DescrptDPA1(
+                self.rcut,
+                self.rcut_smth,
+                self.sel_mix,
+                self.nt,
+                attn_layer=2,
+                precision=prec,
+                resnet_dt=idt,
+                smooth_type_embedding=sm,
+                type_one_side=to,
+                old_impl=False,
+            ).to(env.DEVICE)
+            dd0.se_atten.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE)
+            dd0.se_atten.dstd = torch.tensor(dstd, dtype=dtype, device=env.DEVICE)
+            rd0, _, _, _, _ = dd0(*self._torch_inputs(dtype))
+            # serialization
+            dd1 = DescrptDPA1.deserialize(dd0.serialize())
+            rd1, _, _, _, _ = dd1(*self._torch_inputs(dtype))
+            np.testing.assert_allclose(
+                rd0.detach().cpu().numpy(),
+                rd1.detach().cpu().numpy(),
+                rtol=rtol,
+                atol=atol,
+                err_msg=err_msg,
+            )
+            # dp impl
+            dd2 = DPDescrptDPA1.deserialize(dd0.serialize())
+            rd2, _, _, _, _ = dd2.call(
+                self.coord_ext,
+                self.atype_ext,
+                self.nlist,
+            )
+            np.testing.assert_allclose(
+                rd0.detach().cpu().numpy(),
+                rd2,
+                rtol=rtol,
+                atol=atol,
+                err_msg=err_msg,
+            )
+            # old impl
+            if idt is False and prec == "float64" and to is False:
+                dd3 = DescrptDPA1(
+                    self.rcut,
+                    self.rcut_smth,
+                    self.sel_mix,
+                    self.nt,
+                    attn_layer=2,
+                    precision=prec,
+                    resnet_dt=idt,
+                    smooth_type_embedding=sm,
+                    old_impl=True,
+                ).to(env.DEVICE)
+                dd0_state_dict = dd0.se_atten.state_dict()
+                dd3_state_dict = dd3.se_atten.state_dict()
+                # copy the new-impl weights, mapping old-impl key names to new ones
+                for i in dd3_state_dict:
+                    dd3_state_dict[i] = (
+                        dd0_state_dict[
+                            i.replace(".deep_layers.", ".layers.")
+                            .replace("filter_layers_old.", "filter_layers._networks.")
+                            .replace(
+                                ".attn_layer_norm.weight", ".attn_layer_norm.matrix"
+                            )
+                        ]
+                        .detach()
+                        .clone()
+                    )
+                    if ".bias" in i and "attn_layer_norm" not in i:
+                        dd3_state_dict[i] = dd3_state_dict[i].unsqueeze(0)
+                dd3.se_atten.load_state_dict(dd3_state_dict)
+
+                dd0_state_dict_tebd = dd0.type_embedding.state_dict()
+                dd3_state_dict_tebd = dd3.type_embedding.state_dict()
+                for i in dd3_state_dict_tebd:
+                    dd3_state_dict_tebd[i] = (
+                        dd0_state_dict_tebd[i.replace("embedding.weight", "matrix")]
+                        .detach()
+                        .clone()
+                    )
+                dd3.type_embedding.load_state_dict(dd3_state_dict_tebd)
+
+                rd3, _, _, _, _ = dd3(*self._torch_inputs(dtype))
+                np.testing.assert_allclose(
+                    rd0.detach().cpu().numpy(),
+                    rd3.detach().cpu().numpy(),
+                    rtol=rtol,
+                    atol=atol,
+                    err_msg=err_msg,
+                )
+
+    def test_jit(
+        self,
+    ):
+        """Check that ``torch.jit.script`` accepts the new implementation."""
+        rng = np.random.default_rng(100)
+        _, _, nnei = self.nlist.shape
+        davg = rng.normal(size=(self.nt, nnei, 4))
+        dstd = 0.1 + np.abs(rng.normal(size=(self.nt, nnei, 4)))
+
+        for idt, prec, sm, to in itertools.product(
+            [False, True],  # resnet_dt
+            ["float64", "float32"],  # precision
+            [False, True],  # smooth_type_embedding
+            [False, True],  # type_one_side
+        ):
+            dtype = PRECISION_DICT[prec]
+            # dpa1 new impl
+            dd0 = DescrptDPA1(
+                self.rcut,
+                self.rcut_smth,
+                self.sel,
+                self.nt,
+                precision=prec,
+                resnet_dt=idt,
+                smooth_type_embedding=sm,
+                type_one_side=to,
+                old_impl=False,
+            )
+            dd0.se_atten.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE)
+            dd0.se_atten.dstd = torch.tensor(dstd, dtype=dtype, device=env.DEVICE)
+            # TODO: also script the descriptor rebuilt via serialize()/deserialize()
+            torch.jit.script(dd0)
diff --git a/source/tests/pt/model/test_env_mat.py b/source/tests/pt/model/test_env_mat.py
index e18093b2f1..24ed886b86 100644
--- a/source/tests/pt/model/test_env_mat.py
+++ b/source/tests/pt/model/test_env_mat.py
@@ -35,6 +35,8 @@ def setUp(self):
         self.atype_ext = np.array([0, 0, 1, 0], dtype=int).reshape([1, self.nall])
         # sel = [5, 2]
         self.sel = [5, 2]
+        self.sel_mix = [7]
+        self.natoms = [3, 3, 2, 1]
         self.nlist = np.array(
             [
                 [1, 3, -1, -1, -1, 2, -1],
@@ -83,6 +85,8 @@ def setUp(self):
         self.atype_ext = np.array([0, -1, 0, 1, 0], dtype=int).reshape([1, self.nall])
         # sel = [5, 2]
         self.sel = [5, 2]
+        self.sel_mix = [7]
+        self.natoms = [3, 3, 2, 1]
         self.nlist = np.array(
             [
                 [2, 4, -1, -1, -1, 3, -1],
@@ -131,6 +135,8 @@ def setUp(self):
         self.cell = 2.0 * np.eye(3).reshape([1, 9])
         # sel = [5, 2]
         self.sel = [16, 8]
+        self.sel_mix = [24]
+        self.natoms = [3, 3, 2, 1]
         self.rcut = 2.2
         self.rcut_smth = 0.4
         self.atol = 1e-12
diff --git a/source/tests/pt/model/test_permutation.py b/source/tests/pt/model/test_permutation.py
index 5e395eb8c0..110abd0d23 100644
--- 
a/source/tests/pt/model/test_permutation.py +++ b/source/tests/pt/model/test_permutation.py @@ -158,12 +158,8 @@ "attn_layer": 2, "attn_dotr": True, "attn_mask": False, - "post_ln": True, - "ffn": False, - "ffn_embed_dim": 512, "activation_function": "tanh", "scaling_factor": 1.0, - "head_num": 1, "normalize": False, "temperature": 1.0, "set_davg_zero": True, @@ -193,12 +189,8 @@ "attn_layer": 0, "attn_dotr": True, "attn_mask": False, - "post_ln": True, - "ffn": False, - "ffn_embed_dim": 1024, "activation_function": "tanh", "scaling_factor": 1.0, - "head_num": 1, "normalize": True, "temperature": 1.0, }, diff --git a/source/tests/pt/model/water/se_atten.json b/source/tests/pt/model/water/se_atten.json index 6b6fca50d3..71cee94d8b 100644 --- a/source/tests/pt/model/water/se_atten.json +++ b/source/tests/pt/model/water/se_atten.json @@ -21,12 +21,8 @@ "attn_layer": 2, "attn_dotr": true, "attn_mask": false, - "post_ln": true, - "ffn": false, - "ffn_embed_dim": 512, "activation_function": "tanh", "scaling_factor": 1.0, - "head_num": 1, "normalize": false, "temperature": 1.0 },