From 1f485b084c1a9e79ca2a2f864281b763e66689c6 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Sun, 21 Apr 2024 17:07:27 +0800 Subject: [PATCH 01/18] Reformat DPA1 --- deepmd/dpmodel/descriptor/dpa1.py | 911 ++++++++++++++++++ deepmd/dpmodel/descriptor/dpa1_bk.py | 402 ++++++++ deepmd/dpmodel/utils/network.py | 148 +++ deepmd/pt/model/descriptor/dpa1.py | 283 +++++- deepmd/pt/model/descriptor/se_atten.py | 610 ++++++++++-- deepmd/pt/model/network/layernorm.py | 126 +++ deepmd/pt/model/network/network.py | 17 +- deepmd/tf/descriptor/se.py | 5 +- deepmd/tf/descriptor/se_atten.py | 738 +++++++++++++- deepmd/tf/env.py | 24 +- deepmd/tf/utils/graph.py | 10 +- deepmd/tf/utils/network.py | 266 +++++ deepmd/utils/argcheck.py | 41 +- doc/model/train-se-atten.md | 2 - examples/water/se_atten/input_torch.json | 4 - ...9545d53bb64e65febe2ff48926b4145285f3a.json | 11 + source/tests/consistent/descriptor/common.py | 12 +- .../tests/consistent/descriptor/test_dpa1.py | 258 +++++ source/tests/pt/model/models/dpa1.json | 4 - source/tests/pt/model/test_dpa1.py | 214 ++++ source/tests/pt/model/test_env_mat.py | 6 + source/tests/pt/model/test_permutation.py | 8 - source/tests/pt/model/water/se_atten.json | 4 - 23 files changed, 3868 insertions(+), 236 deletions(-) create mode 100644 deepmd/dpmodel/descriptor/dpa1.py create mode 100644 deepmd/dpmodel/descriptor/dpa1_bk.py create mode 100644 deepmd/pt/model/network/layernorm.py create mode 100644 node_modules/.cache/prettier/.prettier-caches/9be9545d53bb64e65febe2ff48926b4145285f3a.json create mode 100644 source/tests/consistent/descriptor/test_dpa1.py create mode 100644 source/tests/pt/model/test_dpa1.py diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py new file mode 100644 index 0000000000..d38879b62b --- /dev/null +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -0,0 +1,911 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import numpy as np + +from deepmd.dpmodel.utils.network import ( + LayerNorm, + NativeLayer, +) +from deepmd.dpmodel.utils.type_embed import ( + TypeEmbedNet, +) +from deepmd.dpmodel.utils.update_sel import ( + UpdateSel, +) +from deepmd.env import ( + GLOBAL_NP_FLOAT_PRECISION, +) +from deepmd.utils.path import ( + DPPath, +) +from deepmd.utils.version import ( + check_version_compatibility, +) + +try: + from deepmd._version import version as __version__ +except ImportError: + __version__ = "unknown" + +from typing import ( + Any, + List, + Optional, + Tuple, + Union, +) + +from deepmd.dpmodel import ( + DEFAULT_PRECISION, + PRECISION_DICT, + NativeOP, +) +from deepmd.dpmodel.utils import ( + EmbeddingNet, + EnvMat, + NetworkCollection, + PairExcludeMask, +) + +from .base_descriptor import ( + BaseDescriptor, +) + + +def np_softmax(x, axis=-1): + e_x = np.exp(x - np.max(x, axis=axis, keepdims=True)) + return e_x / np.sum(e_x, axis=axis, keepdims=True) + + +def np_normalize(x, axis=-1): + return x / np.linalg.norm(x, axis=axis, keepdims=True) + + +@BaseDescriptor.register("se_atten") +@BaseDescriptor.register("dpa1") +class DescrptDPA1(NativeOP, BaseDescriptor): + r"""Attention-based descriptor which is proposed in the pretrainable DPA-1[1] model. + + This descriptor, :math:`\mathcal{D}^i \in \mathbb{R}^{M \times M_{<}}`, is given by + + .. 
math:: + \mathcal{D}^i = \frac{1}{N_c^2}(\hat{\mathcal{G}}^i)^T \mathcal{R}^i (\mathcal{R}^i)^T \hat{\mathcal{G}}^i_<, + + where :math:`\hat{\mathcal{G}}^i` represents the embedding matrix:math:`\mathcal{G}^i` + after additional self-attention mechanism and :math:`\mathcal{R}^i` is defined by the full case in the se_e2_a descriptor. + Note that we obtain :math:`\mathcal{G}^i` using the type embedding method by default in this descriptor. + + To perform the self-attention mechanism, the queries :math:`\mathcal{Q}^{i,l} \in \mathbb{R}^{N_c\times d_k}`, + keys :math:`\mathcal{K}^{i,l} \in \mathbb{R}^{N_c\times d_k}`, + and values :math:`\mathcal{V}^{i,l} \in \mathbb{R}^{N_c\times d_v}` are first obtained: + + .. math:: + \left(\mathcal{Q}^{i,l}\right)_{j}=Q_{l}\left(\left(\mathcal{G}^{i,l-1}\right)_{j}\right), + + .. math:: + \left(\mathcal{K}^{i,l}\right)_{j}=K_{l}\left(\left(\mathcal{G}^{i,l-1}\right)_{j}\right), + + .. math:: + \left(\mathcal{V}^{i,l}\right)_{j}=V_{l}\left(\left(\mathcal{G}^{i,l-1}\right)_{j}\right), + + where :math:`Q_{l}`, :math:`K_{l}`, :math:`V_{l}` represent three trainable linear transformations + that output the queries and keys of dimension :math:`d_k` and values of dimension :math:`d_v`, and :math:`l` + is the index of the attention layer. + The input embedding matrix to the attention layers, denoted by :math:`\mathcal{G}^{i,0}`, + is chosen as the two-body embedding matrix. + + Then the scaled dot-product attention method is adopted: + + .. math:: + A(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l}, \mathcal{V}^{i,l}, \mathcal{R}^{i,l})=\varphi\left(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l},\mathcal{R}^{i,l}\right)\mathcal{V}^{i,l}, + + where :math:`\varphi\left(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l},\mathcal{R}^{i,l}\right) \in \mathbb{R}^{N_c\times N_c}` is attention weights. + In the original attention method, + one typically has :math:`\varphi\left(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l}\right)=\mathrm{softmax}\left(\frac{\mathcal{Q}^{i,l} (\mathcal{K}^{i,l})^{T}}{\sqrt{d_{k}}}\right)`, + with :math:`\sqrt{d_{k}}` being the normalization temperature. + This is slightly modified to incorporate the angular information: + + .. math:: + \varphi\left(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l},\mathcal{R}^{i,l}\right) = \mathrm{softmax}\left(\frac{\mathcal{Q}^{i,l} (\mathcal{K}^{i,l})^{T}}{\sqrt{d_{k}}}\right) \odot \hat{\mathcal{R}}^{i}(\hat{\mathcal{R}}^{i})^{T}, + + where :math:`\hat{\mathcal{R}}^{i} \in \mathbb{R}^{N_c\times 3}` denotes normalized relative coordinates, + :math:`\hat{\mathcal{R}}^{i}_{j} = \frac{\boldsymbol{r}_{ij}}{\lVert \boldsymbol{r}_{ij} \lVert}` + and :math:`\odot` means element-wise multiplication. + + Then layer normalization is added in a residual way to finally obtain the self-attention local embedding matrix + :math:`\hat{\mathcal{G}}^{i} = \mathcal{G}^{i,L_a}` after :math:`L_a` attention layers:[^1] + + .. math:: + \mathcal{G}^{i,l} = \mathcal{G}^{i,l-1} + \mathrm{LayerNorm}(A(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l}, \mathcal{V}^{i,l}, \mathcal{R}^{i,l})). 
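
For intuition, a minimal NumPy sketch of the gated attention weights defined by the last two equations, using hypothetical sizes (`nnei` neighbors, key dimension `d_k`); this is an illustration only, not part of the implementation in this patch.

    import numpy as np

    nnei, d_k = 6, 4                                   # hypothetical neighbor count and key dimension
    rng = np.random.default_rng(0)
    q = rng.normal(size=(nnei, d_k))                   # queries  Q^{i,l}
    k = rng.normal(size=(nnei, d_k))                   # keys     K^{i,l}
    r = rng.normal(size=(nnei, 3))                     # relative coordinates r_ij
    r_hat = r / np.linalg.norm(r, axis=-1, keepdims=True)

    logits = q @ k.T / np.sqrt(d_k)                    # scaled dot-product scores
    logits -= logits.max(axis=-1, keepdims=True)       # numerically stable softmax
    softmax = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
    weights = softmax * (r_hat @ r_hat.T)              # element-wise angular gate from the equation above
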
+ + Parameters + ---------- + rcut: float + The cut-off radius :math:`r_c` + rcut_smth: float + From where the environment matrix should be smoothed :math:`r_s` + sel : list[int], int + list[int]: sel[i] specifies the maxmum number of type i atoms in the cut-off radius + int: the total maxmum number of atoms in the cut-off radius + ntypes : int + Number of element types + neuron : list[int] + Number of neurons in each hidden layers of the embedding net :math:`\mathcal{N}` + axis_neuron: int + Number of the axis neuron :math:`M_2` (number of columns of the sub-matrix of the embedding matrix) + tebd_dim: int + Dimension of the type embedding + tebd_input_mode: str + The way to mix the type embeddings. Supported options are `concat`, `dot_residual_s`. + resnet_dt: bool + Time-step `dt` in the resnet construction: + y = x + dt * \phi (Wx + b) + trainable: bool + If the weights of embedding net are trainable. + type_one_side: bool + If 'False', type embeddings of both neighbor and central atoms are considered. + If 'True', only type embeddings of neighbor atoms are considered. + Default is 'False'. + attn: int + Hidden dimension of the attention vectors + attn_layer: int + Number of attention layers + attn_dotr: bool + If dot the angular gate to the attention weights + attn_mask: bool + If mask the diagonal of attention weights + exclude_types : List[List[int]] + The excluded pairs of types which have no interaction with each other. + For example, `[[0, 1]]` means no interaction between type 0 and type 1. + env_protection: float + Protection parameter to prevent division by zero errors during environment matrix calculations. + set_davg_zero: bool + Set the shift of embedding net input to zero. + activation_function: str + The activation function in the embedding net. Supported options are |ACTIVATION_FN| + precision: str + The precision of the embedding net parameters. Supported options are |PRECISION| + scaling_factor: float + The scaling factor of normalization in calculations of attention weights. + If `temperature` is None, the scaling of attention weights is (N_dim * scaling_factor)**0.5 + normalize: bool + Whether to normalize the hidden vectors in attention weights calculation. + temperature: float + If not None, the scaling of attention weights is `temperature` itself. + smooth_type_embdding: bool + Whether to use smooth process in attention weights calculation. + concat_output_tebd: bool + Whether to concat type embedding at the output of the descriptor. + spin + The old implementation of deepspin (deprecated in the descriptor). + + Limitations + ----------- + The currently implementation does not support the following features + + 1. type_one_side == True + 2. exclude_types != [] + 3. spin is not None + 4. tebd_input_mode != 'concat' + + References + ---------- + .. [1] Duo Zhang, Hangrui Bi, Fu-Zhi Dai, Wanrun Jiang, Linfeng Zhang, and Han Wang. 2022. + DPA-1: Pretraining of Attention-based Deep Potential Model for Molecular Simulation. + arXiv preprint arXiv:2208.08236. 
+ """ + + def __init__( + self, + rcut: float, + rcut_smth: float, + sel: Union[List[int], int], + ntypes: int, + neuron: List[int] = [25, 50, 100], + axis_neuron: int = 8, + tebd_dim: int = 8, + tebd_input_mode: str = "concat", + resnet_dt: bool = False, + trainable: bool = True, + type_one_side: bool = False, + attn: int = 128, + attn_layer: int = 2, + attn_dotr: bool = True, + attn_mask: bool = False, + exclude_types: List[List[int]] = [], + env_protection: float = 0.0, + set_davg_zero: bool = False, + activation_function: str = "tanh", + precision: str = DEFAULT_PRECISION, + scaling_factor=1.0, + normalize: bool = True, + temperature: Optional[float] = None, + smooth_type_embdding: bool = True, + concat_output_tebd: bool = True, + spin: Optional[Any] = None, + # consistent with argcheck, not used though + seed: Optional[int] = None, + ) -> None: + ## seed, uniform_seed, multi_task, not included. + if spin is not None: + raise NotImplementedError("old implementation of spin is not supported.") + # TODO + if tebd_input_mode != "concat": + raise NotImplementedError("tebd_input_mode != 'concat' not implemented") + + self.rcut = rcut + self.rcut_smth = rcut_smth + if isinstance(sel, int): + sel = [sel] + self.sel = sel + self.nnei = sum(sel) + self.ntypes = ntypes + self.neuron = neuron + self.filter_neuron = self.neuron + self.axis_neuron = axis_neuron + self.tebd_dim = tebd_dim + self.tebd_input_mode = tebd_input_mode + self.resnet_dt = resnet_dt + self.trainable = trainable + self.type_one_side = type_one_side + self.attn = attn + self.attn_layer = attn_layer + self.attn_dotr = attn_dotr + self.attn_mask = attn_mask + self.exclude_types = exclude_types + self.env_protection = env_protection + self.set_davg_zero = set_davg_zero + self.activation_function = activation_function + self.precision = precision + self.scaling_factor = scaling_factor + self.normalize = normalize + self.temperature = temperature + self.smooth = smooth_type_embdding + self.concat_output_tebd = concat_output_tebd + self.spin = spin + # order matters, placed after the assignment of self.ntypes + self.reinit_exclude(exclude_types) + + self.type_embedding = TypeEmbedNet( + ntypes=self.ntypes, + neuron=[self.tebd_dim], + padding=True, + activation_function="Linear", + precision=precision, + ) + if self.tebd_input_mode in ["concat"]: + if not self.type_one_side: + in_dim = 1 + self.tebd_dim * 2 + else: + in_dim = 1 + self.tebd_dim + else: + in_dim = 1 + self.embeddings = NetworkCollection( + ndim=0, + ntypes=self.ntypes, + network_type="embedding_network", + ) + self.embeddings[0] = EmbeddingNet( + in_dim, + self.neuron, + self.activation_function, + self.resnet_dt, + self.precision, + ) + self.dpa1_attention = NeighborGatedAttention( + self.attn_layer, + self.nnei, + self.filter_neuron[-1], + self.attn, + dotr=self.attn_dotr, + do_mask=self.attn_mask, + scaling_factor=self.scaling_factor, + normalize=self.normalize, + temperature=self.temperature, + smooth=self.smooth, + precision=self.precision, + ) + + wanted_shape = (self.ntypes, self.nnei, 4) + self.env_mat = EnvMat(self.rcut, self.rcut_smth, protection=self.env_protection) + self.davg = np.zeros(wanted_shape, dtype=PRECISION_DICT[self.precision]) + self.dstd = np.ones(wanted_shape, dtype=PRECISION_DICT[self.precision]) + self.orig_sel = self.sel + + def __setitem__(self, key, value): + if key in ("avg", "data_avg", "davg"): + self.davg = value + elif key in ("std", "data_std", "dstd"): + self.dstd = value + else: + raise KeyError(key) + + def __getitem__(self, 
key): + if key in ("avg", "data_avg", "davg"): + return self.davg + elif key in ("std", "data_std", "dstd"): + return self.dstd + else: + raise KeyError(key) + + @property + def dim_out(self): + """Returns the output dimension of this descriptor.""" + return self.get_dim_out() + + def get_dim_out(self): + """Returns the output dimension of this descriptor.""" + return ( + self.neuron[-1] * self.axis_neuron + self.tebd_dim + if self.concat_output_tebd + else self.neuron[-1] * self.axis_neuron + ) + + def get_dim_emb(self): + """Returns the embedding (g2) dimension of this descriptor.""" + return self.neuron[-1] + + def get_rcut(self): + """Returns cutoff radius.""" + return self.rcut + + def get_sel(self): + """Returns cutoff radius.""" + return self.sel + + def mixed_types(self): + """If true, the discriptor + 1. assumes total number of atoms aligned across frames; + 2. requires a neighbor list that does not distinguish different atomic types. + + If false, the discriptor + 1. assumes total number of atoms of each atom type aligned across frames; + 2. requires a neighbor list that distinguishes different atomic types. + + """ + return True + + def share_params(self, base_class, shared_level, resume=False): + """ + Share the parameters of self to the base_class with shared_level during multitask training. + If not start from checkpoint (resume is False), + some seperated parameters (e.g. mean and stddev) will be re-calculated across different classes. + """ + raise NotImplementedError + + def get_ntypes(self) -> int: + """Returns the number of element types.""" + return self.ntypes + + def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None): + """Update mean and stddev for descriptor elements.""" + raise NotImplementedError + + def cal_g( + self, + ss, + embedding_idx, + ): + nfnl, nnei = ss.shape[0:2] + ss = ss.reshape(nfnl, nnei, -1) + # nfnl x nnei x ng + gg = self.embeddings[embedding_idx].call(ss) + return gg + + def reinit_exclude( + self, + exclude_types: List[Tuple[int, int]] = [], + ): + self.exclude_types = exclude_types + self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types) + + def call( + self, + coord_ext, + atype_ext, + nlist, + mapping: Optional[np.ndarray] = None, + ): + """Compute the descriptor. + + Parameters + ---------- + coord_ext + The extended coordinates of atoms. shape: nf x (nallx3) + atype_ext + The extended aotm types. shape: nf x nall + nlist + The neighbor list. shape: nf x nloc x nnei + mapping + The index mapping from extended to lcoal region. not used by this descriptor. + + Returns + ------- + descriptor + The descriptor. shape: nf x nloc x (ng x axis_neuron) + gr + The rotationally equivariant and permutationally invariant single particle + representation. shape: nf x nloc x ng x 3 + g2 + The rotationally invariant pair-partical representation. + this descriptor returns None + h2 + The rotationally equivariant pair-partical representation. + this descriptor returns None + sw + The smooth switch function. 
+ """ + del mapping + # nf x nloc x nnei x 4 + dmatrix, sw = self.env_mat.call( + coord_ext, atype_ext, nlist, self.davg, self.dstd + ) + nf, nloc, nnei, _ = dmatrix.shape + exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext) + # nfnl x nnei + nlist = nlist.reshape(nf * nloc, nnei) + # nfnl x nnei x 4 + dmatrix = dmatrix.reshape(nf * nloc, nnei, 4) + # nfnl x nnei x 1 + sw = sw.reshape(nf * nloc, nnei, 1) + + # add type embedding into input + # nf x nall x tebd_dim + atype_embd_ext = self.type_embedding.call()[atype_ext] + # nfnl x tebd_dim + atype_embd = atype_embd_ext[:, :nloc, :].reshape(nf * nloc, -1) + # nfnl x nnei x tebd_dim + atype_embd_nnei = np.tile(atype_embd[:, np.newaxis, :], (1, nnei, 1)) + # nfnl x nnei + nlist_mask = nlist != -1 + nlist_masked = np.copy(nlist) + # nfnl x nnei x 1 + sw = np.where(nlist_mask[:, :, None], sw, 0.0) + nlist_masked[nlist_masked == -1] = 0 + index = np.tile(nlist_masked.reshape(nf, -1, 1), (1, 1, self.tebd_dim)) + # nfnl x nnei x tebd_dim + atype_embd_nlist = np.take_along_axis(atype_embd_ext, index, axis=1).reshape( + nf * nloc, nnei, self.tebd_dim + ) + + ng = self.neuron[-1] + # nfnl x nnei + exclude_mask = exclude_mask.reshape(nf * nloc, nnei) + # nfnl x nnei x 4 + rr = dmatrix.reshape(nf * nloc, nnei, 4) + rr = rr * exclude_mask[:, :, None] + # nfnl x nnei x 1 + ss = rr[..., 0:1] + if self.tebd_input_mode in ["concat"]: + if not self.type_one_side: + # nfnl x nnei x (1 + 2 * tebd_dim) + ss = np.concatenate([ss, atype_embd_nlist, atype_embd_nnei], axis=-1) + else: + # nfnl x nnei x (1 + tebd_dim) + ss = np.concatenate([ss, atype_embd_nlist], axis=-1) + else: + raise NotImplementedError + + # calculate gg + gg = self.cal_g(ss, 0) + input_r = dmatrix.reshape(-1, nnei, 4)[:, :, 1:4] / ( + np.linalg.norm( + dmatrix.reshape(-1, nnei, 4)[:, :, 1:4], axis=-1, keepdims=True + ) + + 1e-12 + ) + gg = self.dpa1_attention( + gg, nlist_mask, input_r=input_r, sw=sw + ) # shape is [nframes*nloc, self.neei, out_size] + # nfnl x ng x 4 + gr = np.einsum("lni,lnj->lij", gg, rr) + gr /= self.nnei + gr1 = gr[:, : self.axis_neuron, :] + # nfnl x ng x ng1 + grrg = np.einsum("lid,ljd->lij", gr, gr1) + # nf x nloc x (ng x ng1) + grrg = grrg.reshape(nf, nloc, ng * self.axis_neuron).astype( + GLOBAL_NP_FLOAT_PRECISION + ) + # nf x nloc x (ng x ng1 + tebd_dim) + if self.concat_output_tebd: + grrg = np.concatenate([grrg, atype_embd.reshape(nf, nloc, -1)], axis=-1) + return grrg, gr[..., 1:], None, None, sw + + def serialize(self) -> dict: + """Serialize the descriptor to dict.""" + return { + "@class": "Descriptor", + "type": "dpa1", + "@version": 1, + "rcut": self.rcut, + "rcut_smth": self.rcut_smth, + "sel": self.sel, + "ntypes": self.ntypes, + "neuron": self.neuron, + "axis_neuron": self.axis_neuron, + "tebd_dim": self.tebd_dim, + "tebd_input_mode": self.tebd_input_mode, + "set_davg_zero": self.set_davg_zero, + "attn": self.attn, + "attn_layer": self.attn_layer, + "attn_dotr": self.attn_dotr, + "attn_mask": self.attn_mask, + "activation_function": self.activation_function, + "resnet_dt": self.resnet_dt, + "scaling_factor": self.scaling_factor, + "normalize": self.normalize, + "temperature": self.temperature, + "smooth_type_embdding": self.smooth, + "type_one_side": self.type_one_side, + "concat_output_tebd": self.concat_output_tebd, + # make deterministic + "precision": np.dtype(PRECISION_DICT[self.precision]).name, + "embeddings": self.embeddings.serialize(), + "attention_layers": self.dpa1_attention.serialize(), + "env_mat": self.env_mat.serialize(), 
+ "type_embedding": self.type_embedding.serialize(), + "exclude_types": self.exclude_types, + "env_protection": self.env_protection, + "@variables": { + "davg": self.davg, + "dstd": self.dstd, + }, + ## to be updated when the options are supported. + "trainable": True, + "spin": None, + } + + @classmethod + def deserialize(cls, data: dict) -> "DescrptDPA1": + """Deserialize from dict.""" + data = data.copy() + check_version_compatibility(data.pop("@version", 1), 1, 1) + data.pop("@class", None) + data.pop("type", None) + variables = data.pop("@variables") + embeddings = data.pop("embeddings") + type_embedding = data.pop("type_embedding") + attention_layers = data.pop("attention_layers") + env_mat = data.pop("env_mat") + obj = cls(**data) + + obj["davg"] = variables["davg"] + obj["dstd"] = variables["dstd"] + obj.embeddings = NetworkCollection.deserialize(embeddings) + obj.type_embedding = TypeEmbedNet.deserialize(type_embedding) + obj.dpa1_attention = NeighborGatedAttention.deserialize(attention_layers) + return obj + + @classmethod + def update_sel(cls, global_jdata: dict, local_jdata: dict): + """Update the selection and perform neighbor statistics. + + Parameters + ---------- + global_jdata : dict + The global data, containing the training section + local_jdata : dict + The local data refer to the current class + """ + local_jdata_cpy = local_jdata.copy() + return UpdateSel().update_one_sel(global_jdata, local_jdata_cpy, True) + + +class NeighborGatedAttention(NativeOP): + def __init__( + self, + layer_num: int, + nnei: int, + embed_dim: int, + hidden_dim: int, + dotr: bool = False, + do_mask: bool = False, + scaling_factor: float = 1.0, + normalize: bool = True, + temperature: Optional[float] = None, + smooth: bool = True, + precision: str = DEFAULT_PRECISION, + ): + """Construct a neighbor-wise attention net.""" + super().__init__() + self.layer_num = layer_num + self.nnei = nnei + self.embed_dim = embed_dim + self.hidden_dim = hidden_dim + self.dotr = dotr + self.do_mask = do_mask + self.scaling_factor = scaling_factor + self.normalize = normalize + self.temperature = temperature + self.smooth = smooth + self.precision = precision + self.network_type = NeighborGatedAttentionLayer + + self.attention_layers = [ + NeighborGatedAttentionLayer( + nnei, + embed_dim, + hidden_dim, + dotr=dotr, + do_mask=do_mask, + scaling_factor=scaling_factor, + normalize=normalize, + temperature=temperature, + smooth=smooth, + precision=precision, + ) + for _ in range(layer_num) + ] + + def call( + self, + input_G, + nei_mask, + input_r: Optional[np.ndarray] = None, + sw: Optional[np.ndarray] = None, + ): + out = input_G + for layer in self.attention_layers: + out = layer(out, nei_mask, input_r=input_r, sw=sw) + return out + + def __getitem__(self, key): + if isinstance(key, int): + return self.attention_layers[key] + else: + raise TypeError(key) + + def __setitem__(self, key, value): + if not isinstance(key, int): + raise TypeError(key) + if isinstance(value, self.network_type): + pass + elif isinstance(value, dict): + value = self.network_type.deserialize(value) + else: + raise TypeError(value) + self.attention_layers[key] = value + + def serialize(self): + """Serialize the networks to a dict. + + Returns + ------- + dict + The serialized networks. 
+ """ + return { + "@class": "NeighborGatedAttention", + "@version": 1, + "layer_num": self.layer_num, + "nnei": self.nnei, + "embed_dim": self.embed_dim, + "hidden_dim": self.hidden_dim, + "dotr": self.dotr, + "do_mask": self.do_mask, + "scaling_factor": self.scaling_factor, + "normalize": self.normalize, + "temperature": self.temperature, + "precision": self.precision, + "attention_layers": [layer.serialize() for layer in self.attention_layers], + } + + @classmethod + def deserialize(cls, data: dict) -> "NeighborGatedAttention": + """Deserialize the networks from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. + """ + data = data.copy() + check_version_compatibility(data.pop("@version", 1), 1, 1) + data.pop("@class", None) + attention_layers = data.pop("attention_layers") + obj = cls(**data) + obj.attention_layers = [ + NeighborGatedAttentionLayer.deserialize(layer) for layer in attention_layers + ] + return obj + + +class NeighborGatedAttentionLayer(NativeOP): + def __init__( + self, + nnei: int, + embed_dim: int, + hidden_dim: int, + dotr: bool = False, + do_mask: bool = False, + scaling_factor: float = 1.0, + normalize: bool = True, + temperature: Optional[float] = None, + smooth: bool = True, + precision: str = DEFAULT_PRECISION, + ): + """Construct a neighbor-wise attention layer.""" + super().__init__() + self.nnei = nnei + self.embed_dim = embed_dim + self.hidden_dim = hidden_dim + self.dotr = dotr + self.do_mask = do_mask + self.scaling_factor = scaling_factor + self.normalize = normalize + self.temperature = temperature + self.precision = precision + self.attention_layer = GatedAttentionLayer( + nnei, + embed_dim, + hidden_dim, + dotr=dotr, + do_mask=do_mask, + scaling_factor=scaling_factor, + normalize=normalize, + temperature=temperature, + smooth=smooth, + precision=precision, + ) + self.attn_layer_norm = LayerNorm(self.embed_dim, precision=precision) + + def call( + self, + x, + nei_mask, + input_r: Optional[np.ndarray] = None, + sw: Optional[np.ndarray] = None, + ): + residual = x + x = self.attention_layer(x, nei_mask, input_r=input_r, sw=sw) + x = residual + x + x = self.attn_layer_norm(x) + return x + + def serialize(self) -> dict: + """Serialize the networks to a dict. + + Returns + ------- + dict + The serialized networks. + """ + return { + "nnei": self.nnei, + "embed_dim": self.embed_dim, + "hidden_dim": self.hidden_dim, + "dotr": self.dotr, + "do_mask": self.do_mask, + "scaling_factor": self.scaling_factor, + "normalize": self.normalize, + "temperature": self.temperature, + "precision": self.precision, + "attention_layer": self.attention_layer.serialize(), + "attn_layer_norm": self.attn_layer_norm.serialize(), + } + + @classmethod + def deserialize(cls, data) -> "NeighborGatedAttentionLayer": + """Deserialize the networks from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. 
+ """ + data = data.copy() + attention_layer = data.pop("attention_layer") + attn_layer_norm = data.pop("attn_layer_norm") + obj = cls(**data) + obj.attention_layer = GatedAttentionLayer.deserialize(attention_layer) + obj.attn_layer_norm = LayerNorm.deserialize(attn_layer_norm) + return obj + + +class GatedAttentionLayer(NativeOP): + def __init__( + self, + nnei: int, + embed_dim: int, + hidden_dim: int, + dotr: bool = False, + do_mask: bool = False, + scaling_factor: float = 1.0, + normalize: bool = True, + temperature: Optional[float] = None, + bias: bool = True, + smooth: bool = True, + precision: str = DEFAULT_PRECISION, + ): + """Construct a neighbor-wise attention net.""" + super().__init__() + self.nnei = nnei + self.embed_dim = embed_dim + self.hidden_dim = hidden_dim + self.dotr = dotr + self.do_mask = do_mask + self.bias = bias + self.smooth = smooth + self.scaling_factor = scaling_factor + self.temperature = temperature + self.precision = precision + if temperature is None: + self.scaling = (self.hidden_dim * scaling_factor) ** -0.5 + else: + self.scaling = temperature + self.normalize = normalize + self.in_proj = NativeLayer( + embed_dim, + hidden_dim * 3, + bias=bias, + use_timestep=False, + precision=precision, + ) + self.out_proj = NativeLayer( + hidden_dim, + embed_dim, + bias=bias, + use_timestep=False, + precision=precision, + ) + + def call(self, query, nei_mask, input_r=None, sw=None, attnw_shift=20.0): + # Linear projection + q, k, v = np.split(self.in_proj(query), 3, axis=-1) + # Reshape and normalize + q = q.reshape(-1, self.nnei, self.hidden_dim) + k = k.reshape(-1, self.nnei, self.hidden_dim) + v = v.reshape(-1, self.nnei, self.hidden_dim) + if self.normalize: + q = np_normalize(q, axis=-1) + k = np_normalize(k, axis=-1) + v = np_normalize(v, axis=-1) + q = q * self.scaling + # Attention weights + attn_weights = q @ k.transpose(0, 2, 1) + nei_mask = nei_mask.reshape(-1, self.nnei) + if self.smooth: + sw = sw.reshape(-1, self.nnei) + attn_weights = (attn_weights + attnw_shift) * sw[:, None, :] * sw[ + :, :, None + ] - attnw_shift + else: + attn_weights = np.where(nei_mask[:, None, :], attn_weights, -np.inf) + attn_weights = np_softmax(attn_weights, axis=-1) + attn_weights = np.where(nei_mask[:, :, None], attn_weights, 0.0) + if self.smooth: + attn_weights = attn_weights * sw[:, None, :] * sw[:, :, None] + if self.dotr: + angular_weight = input_r @ input_r.transpose(0, 2, 1) + attn_weights = attn_weights * angular_weight + # Output projection + o = attn_weights @ v + output = self.out_proj(o) + return output + + def serialize(self): + return { + "nnei": self.nnei, + "embed_dim": self.embed_dim, + "hidden_dim": self.hidden_dim, + "dotr": self.dotr, + "do_mask": self.do_mask, + "scaling_factor": self.scaling_factor, + "normalize": self.normalize, + "temperature": self.temperature, + "bias": self.bias, + "smooth": self.smooth, + "precision": self.precision, + "in_proj": self.in_proj.serialize(), + "out_proj": self.out_proj.serialize(), + } + + @classmethod + def deserialize(cls, data): + data = data.copy() + in_proj = data.pop("in_proj") + out_proj = data.pop("out_proj") + obj = cls(**data) + obj.in_proj = NativeLayer.deserialize(in_proj) + obj.out_proj = NativeLayer.deserialize(out_proj) + return obj diff --git a/deepmd/dpmodel/descriptor/dpa1_bk.py b/deepmd/dpmodel/descriptor/dpa1_bk.py new file mode 100644 index 0000000000..3ca28a4fae --- /dev/null +++ b/deepmd/dpmodel/descriptor/dpa1_bk.py @@ -0,0 +1,402 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import numpy 
as np + +try: + from deepmd._version import version as __version__ +except ImportError: + __version__ = "unknown" + +import copy +from typing import ( + Any, + List, + Optional, +) + +from .common import ( + DEFAULT_PRECISION, + NativeOP, +) +from .env_mat import ( + EnvMat, +) +from .network import ( + EmbdLayer, + EmbeddingNet, + NetworkCollection, +) + + +class DescrptDPA1(NativeOP): + r"""Attention-based descriptor which is proposed in the pretrainable DPA-1[1] model. + + This descriptor, :math:`\mathcal{D}^i \in \mathbb{R}^{M \times M_{<}}`, is given by + + .. math:: + \mathcal{D}^i = \frac{1}{N_c^2}(\hat{\mathcal{G}}^i)^T \mathcal{R}^i (\mathcal{R}^i)^T \hat{\mathcal{G}}^i_<, + + where :math:`\hat{\mathcal{G}}^i` represents the embedding matrix:math:`\mathcal{G}^i` + after additional self-attention mechanism and :math:`\mathcal{R}^i` is defined by the full case in the se_e2_a descriptor. + Note that we obtain :math:`\mathcal{G}^i` using the type embedding method by default in this descriptor. + + To perform the self-attention mechanism, the queries :math:`\mathcal{Q}^{i,l} \in \mathbb{R}^{N_c\times d_k}`, + keys :math:`\mathcal{K}^{i,l} \in \mathbb{R}^{N_c\times d_k}`, + and values :math:`\mathcal{V}^{i,l} \in \mathbb{R}^{N_c\times d_v}` are first obtained: + + .. math:: + \left(\mathcal{Q}^{i,l}\right)_{j}=Q_{l}\left(\left(\mathcal{G}^{i,l-1}\right)_{j}\right), + + .. math:: + \left(\mathcal{K}^{i,l}\right)_{j}=K_{l}\left(\left(\mathcal{G}^{i,l-1}\right)_{j}\right), + + .. math:: + \left(\mathcal{V}^{i,l}\right)_{j}=V_{l}\left(\left(\mathcal{G}^{i,l-1}\right)_{j}\right), + + where :math:`Q_{l}`, :math:`K_{l}`, :math:`V_{l}` represent three trainable linear transformations + that output the queries and keys of dimension :math:`d_k` and values of dimension :math:`d_v`, and :math:`l` + is the index of the attention layer. + The input embedding matrix to the attention layers, denoted by :math:`\mathcal{G}^{i,0}`, + is chosen as the two-body embedding matrix. + + Then the scaled dot-product attention method is adopted: + + .. math:: + A(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l}, \mathcal{V}^{i,l}, \mathcal{R}^{i,l})=\varphi\left(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l},\mathcal{R}^{i,l}\right)\mathcal{V}^{i,l}, + + where :math:`\varphi\left(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l},\mathcal{R}^{i,l}\right) \in \mathbb{R}^{N_c\times N_c}` is attention weights. + In the original attention method, + one typically has :math:`\varphi\left(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l}\right)=\mathrm{softmax}\left(\frac{\mathcal{Q}^{i,l} (\mathcal{K}^{i,l})^{T}}{\sqrt{d_{k}}}\right)`, + with :math:`\sqrt{d_{k}}` being the normalization temperature. + This is slightly modified to incorporate the angular information: + + .. math:: + \varphi\left(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l},\mathcal{R}^{i,l}\right) = \mathrm{softmax}\left(\frac{\mathcal{Q}^{i,l} (\mathcal{K}^{i,l})^{T}}{\sqrt{d_{k}}}\right) \odot \hat{\mathcal{R}}^{i}(\hat{\mathcal{R}}^{i})^{T}, + + where :math:`\hat{\mathcal{R}}^{i} \in \mathbb{R}^{N_c\times 3}` denotes normalized relative coordinates, + :math:`\hat{\mathcal{R}}^{i}_{j} = \frac{\boldsymbol{r}_{ij}}{\lVert \boldsymbol{r}_{ij} \lVert}` + and :math:`\odot` means element-wise multiplication. + + Then layer normalization is added in a residual way to finally obtain the self-attention local embedding matrix + :math:`\hat{\mathcal{G}}^{i} = \mathcal{G}^{i,L_a}` after :math:`L_a` attention layers:[^1] + + .. 
math:: + \mathcal{G}^{i,l} = \mathcal{G}^{i,l-1} + \mathrm{LayerNorm}(A(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l}, \mathcal{V}^{i,l}, \mathcal{R}^{i,l})). + + Parameters + ---------- + rcut + The cut-off radius :math:`r_c` + rcut_smth + From where the environment matrix should be smoothed :math:`r_s` + sel : list[str] + sel[i] specifies the maxmum number of type i atoms in the cut-off radius + ntypes : int + Number of element types + neuron : list[int] + Number of neurons in each hidden layers of the embedding net :math:`\mathcal{N}` + axis_neuron + Number of the axis neuron :math:`M_2` (number of columns of the sub-matrix of the embedding matrix) + tebd_dim: int + Dimension of the type embedding + tebd_input_mode: str + The way to mix the type embeddings. Supported options are `concat`, `dot_residual_s`. + resnet_dt + Time-step `dt` in the resnet construction: + y = x + dt * \phi (Wx + b) + trainable + If the weights of embedding net are trainable. + type_one_side + Try to build N_types embedding nets. Otherwise, building N_types^2 embedding nets + attn: int + Hidden dimension of the attention vectors + attn_layer: int + Number of attention layers + attn_dotr: bool + If dot the angular gate to the attention weights + attn_mask: bool + If mask the diagonal of attention weights + exclude_types : List[List[int]] + The excluded pairs of types which have no interaction with each other. + For example, `[[0, 1]]` means no interaction between type 0 and type 1. + set_davg_zero + Set the shift of embedding net input to zero. + activation_function + The activation function in the embedding net. Supported options are |ACTIVATION_FN| + precision + The precision of the embedding net parameters. Supported options are |PRECISION| + scaling_factor: float + The scaling factor of normalization in calculations of attention weights. + If `temperature` is None, the scaling of attention weights is (N_dim * scaling_factor)**0.5 + temperature: Optional[float] + If not None, the scaling of attention weights is `temperature` itself. + spin + The deepspin object. + + Limitations + ----------- + The currently implementation does not support the following features + + 1. type_one_side == False + 2. exclude_types != [] + 3. spin is not None + 4. tebd_input_mode != 'concat' + 5. smooth == True + + References + ---------- + .. [1] Duo Zhang, Hangrui Bi, Fu-Zhi Dai, Wanrun Jiang, Linfeng Zhang, and Han Wang. 2022. + DPA-1: Pretraining of Attention-based Deep Potential Model for Molecular Simulation. + arXiv preprint arXiv:2208.08236. + """ + + def __init__( + self, + rcut: float, + rcut_smth: float, + sel: List[str], + ntypes: int, + neuron: List[int] = [25, 50, 100], + axis_neuron: int = 8, + tebd_dim: int = 8, + tebd_input_mode: str = "concat", + resnet_dt: bool = False, + trainable: bool = True, + type_one_side: bool = True, + attn: int = 128, + attn_layer: int = 2, + attn_dotr: bool = True, + attn_mask: bool = False, + exclude_types: List[List[int]] = [], + set_davg_zero: bool = False, + activation_function: str = "tanh", + precision: str = DEFAULT_PRECISION, + scaling_factor=1.0, + normalize=True, + temperature=None, + smooth: bool = True, + concat_output_tebd: bool = True, + spin: Optional[Any] = None, + ) -> None: + ## seed, uniform_seed, multi_task, not included. 
+ if not type_one_side: + raise NotImplementedError("type_one_side == False not implemented") + if exclude_types != []: + raise NotImplementedError("exclude_types is not implemented") + if spin is not None: + raise NotImplementedError("spin is not implemented") + # TODO + if tebd_input_mode != "concat": + raise NotImplementedError("tebd_input_mode != 'concat' not implemented") + if not smooth: + raise NotImplementedError("smooth == False not implemented") + + self.rcut = rcut + self.rcut_smth = rcut_smth + if isinstance(sel, int): + sel = [sel] + self.sel = sel + self.ntypes = ntypes + self.neuron = neuron + self.axis_neuron = axis_neuron + self.tebd_dim = tebd_dim + self.tebd_input_mode = tebd_input_mode + self.resnet_dt = resnet_dt + self.trainable = trainable + self.type_one_side = type_one_side + self.attn = attn + self.attn_layer = attn_layer + self.attn_dotr = attn_dotr + self.attn_mask = attn_mask + self.exclude_types = exclude_types + self.set_davg_zero = set_davg_zero + self.activation_function = activation_function + self.precision = precision + self.scaling_factor = scaling_factor + self.normalize = normalize + self.temperature = temperature + self.concat_output_tebd = concat_output_tebd + self.spin = spin + + self.type_embedding = EmbdLayer( + ntypes, tebd_dim, padding=True, precision=precision + ) + in_dim = 1 + self.tebd_dim * 2 if self.tebd_input_mode in ["concat"] else 1 + self.embeddings = NetworkCollection( + ndim=0, + ntypes=self.ntypes, + network_type="embedding_network", + ) + self.embeddings[0] = EmbeddingNet( + in_dim, + self.neuron, + self.activation_function, + self.resnet_dt, + self.precision, + ) + # self.dpa1_attention = NeighborGatedAttention + self.env_mat = EnvMat(self.rcut, self.rcut_smth) + self.nnei = np.sum(self.sel) + self.davg = np.zeros([self.ntypes, self.nnei, 4]) + self.dstd = np.ones([self.ntypes, self.nnei, 4]) + self.orig_sel = self.sel + + def __setitem__(self, key, value): + if key in ("avg", "data_avg", "davg"): + self.davg = value + elif key in ("std", "data_std", "dstd"): + self.dstd = value + else: + raise KeyError(key) + + def __getitem__(self, key): + if key in ("avg", "data_avg", "davg"): + return self.davg + elif key in ("std", "data_std", "dstd"): + return self.dstd + else: + raise KeyError(key) + + @property + def dim_out(self): + """Returns the output dimension of this descriptor.""" + return ( + self.neuron[-1] * self.axis_neuron + self.tebd_dim * 2 + if self.concat_output_tebd + else self.neuron[-1] * self.axis_neuron + ) + + def cal_g( + self, + ss, + ll, + ): + nf, nloc, nnei = ss.shape[0:3] + ss = ss.reshape(nf, nloc, nnei, -1) + # nf x nloc x nnei x ng + gg = self.embeddings[ll].call(ss) + return gg + + def call( + self, + coord_ext, + atype_ext, + nlist, + ): + """Compute the descriptor. + + Parameters + ---------- + coord_ext + The extended coordinates of atoms. shape: nf x (nallx3) + atype_ext + The extended aotm types. shape: nf x nall + nlist + The neighbor list. shape: nf x nloc x nnei + + Returns + ------- + descriptor + The descriptor. shape: nf x nloc x (ng x axis_neuron) + gr + The rotationally equivariant and permutationally invariant single particle + representation. shape: nf x nloc x ng x 3 + g2 + The rotationally invariant pair-partical representation. + this descriptor returns None + h2 + The rotationally equivariant pair-partical representation. + this descriptor returns None + sw + The smooth switch function. 
+ """ + # nf x nloc x nnei x 4 + rr, ww = self.env_mat.call(coord_ext, atype_ext, nlist, self.davg, self.dstd) + nf, nloc, nnei, _ = rr.shape + + # add type embedding into input + # nf x nall x tebd_dim + atype_embd_ext = self.type_embedding.call(atype_ext) + atype_embd = atype_embd_ext[:, :nloc, :] + # nf x nloc x nnei x tebd_dim + atype_embd_nnei = np.tile(atype_embd[:, :, np.newaxis, :], (1, 1, nnei, 1)) + nlist_mask = nlist != -1 + nlist_masked = np.copy(nlist) + nlist_masked[nlist_masked == -1] = 0 + index = np.tile(nlist_masked.reshape(nf, -1, 1), (1, 1, self.tebd_dim)) + # nf x nloc x nnei x tebd_dim + atype_embd_nlist = np.take_along_axis(atype_embd_ext, index, axis=1).reshape( + nf, nloc, nnei, self.tebd_dim + ) + ng = self.neuron[-1] + ss = rr[..., 0:1] + ss = np.concatenate([ss, atype_embd_nlist, atype_embd_nnei], axis=-1) + + # calculate gg + gg = self.cal_g(ss, 0) + # nf x nloc x ng x 4 + gr = np.einsum("flni,flnj->flij", gg, rr) + # nf x nloc x ng x 4 + gr /= self.nnei + gr1 = gr[:, :, : self.axis_neuron, :] + # nf x nloc x ng x ng1 + grrg = np.einsum("flid,fljd->flij", gr, gr1) + # nf x nloc x (ng x ng1) + grrg = grrg.reshape(nf, nloc, ng * self.axis_neuron) + if self.concat_output_tebd: + grrg = np.concatenate([grrg, atype_embd], axis=-1) + return grrg, gr[..., 1:], None, None, ww + + def serialize(self) -> dict: + """Serialize the descriptor to dict.""" + return { + "rcut": self.rcut, + "rcut_smth": self.rcut_smth, + "sel": self.sel, + "ntypes": self.ntypes, + "neuron": self.neuron, + "axis_neuron": self.axis_neuron, + "tebd_dim": self.tebd_dim, + "tebd_input_mode": self.tebd_input_mode, + "resnet_dt": self.resnet_dt, + "trainable": self.trainable, + "type_one_side": self.type_one_side, + "exclude_types": self.exclude_types, + "set_davg_zero": self.set_davg_zero, + "attn": self.attn, + "attn_layer": self.attn_layer, + "attn_dotr": self.attn_dotr, + "attn_mask": self.attn_mask, + "activation_function": self.activation_function, + "precision": self.precision, + "spin": self.spin, + "scaling_factor": self.scaling_factor, + "normalize": self.normalize, + "temperature": self.temperature, + "concat_output_tebd": self.concat_output_tebd, + "embeddings": self.embeddings.serialize(), + # "attention_layers": self.dpa1_attention.serialize(), + "env_mat": self.env_mat.serialize(), + "type_embedding": self.type_embedding.serialize(), + "@variables": { + "davg": self.davg, + "dstd": self.dstd, + }, + } + + @classmethod + def deserialize(cls, data: dict) -> "DescrptDPA1": + """Deserialize from dict.""" + data = copy.deepcopy(data) + variables = data.pop("@variables") + embeddings = data.pop("embeddings") + type_embedding = data.pop("type_embedding") + attention_layers = data.pop("attention_layers", None) + env_mat = data.pop("env_mat") + obj = cls(**data) + obj["davg"] = variables["davg"] + obj["dstd"] = variables["dstd"] + obj.type_embedding = EmbdLayer.deserialize(type_embedding) + obj.embeddings = NetworkCollection.deserialize(embeddings) + obj.env_mat = EnvMat.deserialize(env_mat) + # obj.dpa1_attention = NeighborGatedAttention.deserialize(attention_layers) + return obj diff --git a/deepmd/dpmodel/utils/network.py b/deepmd/dpmodel/utils/network.py index 661358ed70..88e97ee3c4 100644 --- a/deepmd/dpmodel/utils/network.py +++ b/deepmd/dpmodel/utils/network.py @@ -381,6 +381,154 @@ def fn(x): raise NotImplementedError(activation_function) +class LayerNorm(NativeLayer): + """Implementation of Layer Normalization layer. 
+ + Parameters + ---------- + num_in : int + The input dimension of the layer. + eps : float, optional + A small value added to prevent division by zero in calculations. + uni_init : bool, optional + If initialize the weights to be zeros and ones. + """ + + def __init__( + self, + num_in: int, + eps: float = 1e-5, + uni_init: bool = True, + precision: str = DEFAULT_PRECISION, + ) -> None: + self.eps = eps + self.uni_init = uni_init + self.num_in = num_in + super().__init__( + num_in=1, + num_out=num_in, + bias=True, + use_timestep=False, + activation_function=None, + resnet=False, + precision=precision, + ) + self.w = self.w.squeeze(0) # keep the weight shape to be [num_in] + if self.uni_init: + self.w = np.ones_like(self.w) + self.b = np.zeros_like(self.b) + + def serialize(self) -> dict: + """Serialize the layer to a dict. + + Returns + ------- + dict + The serialized layer. + """ + data = { + "w": self.w, + "b": self.b, + } + return { + "@class": "LayerNorm", + "@version": 1, + "eps": self.eps, + "precision": self.precision, + "@variables": data, + } + + @classmethod + def deserialize(cls, data: dict) -> "LayerNorm": + """Deserialize the layer from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. + """ + data = copy.deepcopy(data) + check_version_compatibility(data.pop("@version", 1), 1, 1) + data.pop("@class", None) + variables = data.pop("@variables") + if variables["w"] is not None: + assert len(variables["w"].shape) == 1 + if variables["b"] is not None: + assert len(variables["b"].shape) == 1 + (num_in,) = variables["w"].shape + obj = cls( + num_in, + **data, + ) + (obj.w,) = (variables["w"],) + (obj.b,) = (variables["b"],) + obj._check_shape_consistency() + return obj + + def _check_shape_consistency(self): + if self.b is not None and self.w.shape[0] != self.b.shape[0]: + raise ValueError( + f"dim 1 of w {self.w.shape[0]} is not equal to shape " + f"of b {self.b.shape[0]}", + ) + + def __setitem__(self, key, value): + if key in ("w", "matrix"): + self.w = value + elif key in ("b", "bias"): + self.b = value + elif key == "precision": + self.precision = value + elif key == "eps": + self.eps = value + else: + raise KeyError(key) + + def __getitem__(self, key): + if key in ("w", "matrix"): + return self.w + elif key in ("b", "bias"): + return self.b + elif key == "precision": + return self.precision + elif key == "eps": + return self.eps + else: + raise KeyError(key) + + def dim_out(self) -> int: + return self.w.shape[0] + + def call(self, x: np.ndarray) -> np.ndarray: + """Forward pass. + + Parameters + ---------- + x : np.ndarray + The input. + + Returns + ------- + np.ndarray + The output. + """ + if self.w is None or self.b is None: + raise ValueError("w/b must be set") + y = self.layer_norm_numpy(x, (self.num_in,), self.w, self.b, self.eps) + return y + + @staticmethod + def layer_norm_numpy(x, shape, weight, bias, eps): + # mean and variance + mean = np.mean(x, axis=tuple(range(-len(shape), 0)), keepdims=True) + var = np.var(x, axis=tuple(range(-len(shape), 0)), keepdims=True) + # normalize + x_normalized = (x - mean) / np.sqrt(var + eps) + # shift and scale + x_ln = x_normalized * weight + bias + return x_ln + + def make_multilayer_network(T_NetworkLayer, ModuleBase): class NN(ModuleBase): """Native representation of a neural network. 
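
A standalone NumPy sketch mirroring the `layer_norm_numpy` helper above, to show the normalization it performs over the trailing feature axis (shapes are hypothetical; weight and bias follow the ones/zeros `uni_init` defaults used in the patch):

    import numpy as np

    def layer_norm_numpy(x, shape, weight, bias, eps=1e-5):
        # normalize over the trailing `shape` axes, then scale and shift
        axes = tuple(range(-len(shape), 0))
        mean = np.mean(x, axis=axes, keepdims=True)
        var = np.var(x, axis=axes, keepdims=True)
        return (x - mean) / np.sqrt(var + eps) * weight + bias

    x = np.random.default_rng(1).normal(size=(2, 5, 8))   # hypothetical (nf, nloc, num_in)
    w, b = np.ones(8), np.zeros(8)                        # uni_init defaults
    y = layer_norm_numpy(x, (8,), w, b)
    assert np.allclose(y.mean(-1), 0.0, atol=1e-6)
    assert np.allclose(y.std(-1), 1.0, atol=1e-3)
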
diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py index 21275317dc..9f03edce12 100644 --- a/deepmd/pt/model/descriptor/dpa1.py +++ b/deepmd/pt/model/descriptor/dpa1.py @@ -9,8 +9,19 @@ import torch +from deepmd.dpmodel.utils import EnvMat as DPEnvMat +from deepmd.pt.model.network.mlp import ( + NetworkCollection, +) from deepmd.pt.model.network.network import ( TypeEmbedNet, + TypeEmbedNetConsistent, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + RESERVED_PRECISON_DICT, ) from deepmd.pt.utils.update_sel import ( UpdateSel, @@ -18,23 +29,157 @@ from deepmd.utils.path import ( DPPath, ) +from deepmd.utils.version import ( + check_version_compatibility, +) from .base_descriptor import ( BaseDescriptor, ) from .se_atten import ( DescrptBlockSeAtten, + NeighborGatedAttention, ) @BaseDescriptor.register("dpa1") @BaseDescriptor.register("se_atten") class DescrptDPA1(BaseDescriptor, torch.nn.Module): + r"""Attention-based descriptor which is proposed in the pretrainable DPA-1[1] model. + + This descriptor, :math:`\mathcal{D}^i \in \mathbb{R}^{M \times M_{<}}`, is given by + + .. math:: + \mathcal{D}^i = \frac{1}{N_c^2}(\hat{\mathcal{G}}^i)^T \mathcal{R}^i (\mathcal{R}^i)^T \hat{\mathcal{G}}^i_<, + + where :math:`\hat{\mathcal{G}}^i` represents the embedding matrix:math:`\mathcal{G}^i` + after additional self-attention mechanism and :math:`\mathcal{R}^i` is defined by the full case in the se_e2_a descriptor. + Note that we obtain :math:`\mathcal{G}^i` using the type embedding method by default in this descriptor. + + To perform the self-attention mechanism, the queries :math:`\mathcal{Q}^{i,l} \in \mathbb{R}^{N_c\times d_k}`, + keys :math:`\mathcal{K}^{i,l} \in \mathbb{R}^{N_c\times d_k}`, + and values :math:`\mathcal{V}^{i,l} \in \mathbb{R}^{N_c\times d_v}` are first obtained: + + .. math:: + \left(\mathcal{Q}^{i,l}\right)_{j}=Q_{l}\left(\left(\mathcal{G}^{i,l-1}\right)_{j}\right), + + .. math:: + \left(\mathcal{K}^{i,l}\right)_{j}=K_{l}\left(\left(\mathcal{G}^{i,l-1}\right)_{j}\right), + + .. math:: + \left(\mathcal{V}^{i,l}\right)_{j}=V_{l}\left(\left(\mathcal{G}^{i,l-1}\right)_{j}\right), + + where :math:`Q_{l}`, :math:`K_{l}`, :math:`V_{l}` represent three trainable linear transformations + that output the queries and keys of dimension :math:`d_k` and values of dimension :math:`d_v`, and :math:`l` + is the index of the attention layer. + The input embedding matrix to the attention layers, denoted by :math:`\mathcal{G}^{i,0}`, + is chosen as the two-body embedding matrix. + + Then the scaled dot-product attention method is adopted: + + .. math:: + A(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l}, \mathcal{V}^{i,l}, \mathcal{R}^{i,l})=\varphi\left(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l},\mathcal{R}^{i,l}\right)\mathcal{V}^{i,l}, + + where :math:`\varphi\left(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l},\mathcal{R}^{i,l}\right) \in \mathbb{R}^{N_c\times N_c}` is attention weights. + In the original attention method, + one typically has :math:`\varphi\left(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l}\right)=\mathrm{softmax}\left(\frac{\mathcal{Q}^{i,l} (\mathcal{K}^{i,l})^{T}}{\sqrt{d_{k}}}\right)`, + with :math:`\sqrt{d_{k}}` being the normalization temperature. + This is slightly modified to incorporate the angular information: + + .. 
math:: + \varphi\left(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l},\mathcal{R}^{i,l}\right) = \mathrm{softmax}\left(\frac{\mathcal{Q}^{i,l} (\mathcal{K}^{i,l})^{T}}{\sqrt{d_{k}}}\right) \odot \hat{\mathcal{R}}^{i}(\hat{\mathcal{R}}^{i})^{T}, + + where :math:`\hat{\mathcal{R}}^{i} \in \mathbb{R}^{N_c\times 3}` denotes normalized relative coordinates, + :math:`\hat{\mathcal{R}}^{i}_{j} = \frac{\boldsymbol{r}_{ij}}{\lVert \boldsymbol{r}_{ij} \lVert}` + and :math:`\odot` means element-wise multiplication. + + Then layer normalization is added in a residual way to finally obtain the self-attention local embedding matrix + :math:`\hat{\mathcal{G}}^{i} = \mathcal{G}^{i,L_a}` after :math:`L_a` attention layers:[^1] + + .. math:: + \mathcal{G}^{i,l} = \mathcal{G}^{i,l-1} + \mathrm{LayerNorm}(A(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l}, \mathcal{V}^{i,l}, \mathcal{R}^{i,l})). + + Parameters + ---------- + rcut: float + The cut-off radius :math:`r_c` + rcut_smth: float + From where the environment matrix should be smoothed :math:`r_s` + sel : list[int], int + list[int]: sel[i] specifies the maxmum number of type i atoms in the cut-off radius + int: the total maxmum number of atoms in the cut-off radius + ntypes : int + Number of element types + neuron : list[int] + Number of neurons in each hidden layers of the embedding net :math:`\mathcal{N}` + axis_neuron: int + Number of the axis neuron :math:`M_2` (number of columns of the sub-matrix of the embedding matrix) + tebd_dim: int + Dimension of the type embedding + tebd_input_mode: str + The way to mix the type embeddings. Supported options are `concat`, `dot_residual_s`. + resnet_dt: bool + Time-step `dt` in the resnet construction: + y = x + dt * \phi (Wx + b) + trainable: bool + If the weights of embedding net are trainable. + type_one_side: bool + If 'False', type embeddings of both neighbor and central atoms are considered. + If 'True', only type embeddings of neighbor atoms are considered. + Default is 'False'. + attn: int + Hidden dimension of the attention vectors + attn_layer: int + Number of attention layers + attn_dotr: bool + If dot the angular gate to the attention weights + attn_mask: bool + If mask the diagonal of attention weights + exclude_types : List[List[int]] + The excluded pairs of types which have no interaction with each other. + For example, `[[0, 1]]` means no interaction between type 0 and type 1. + env_protection: float + Protection parameter to prevent division by zero errors during environment matrix calculations. + set_davg_zero: bool + Set the shift of embedding net input to zero. + activation_function: str + The activation function in the embedding net. Supported options are |ACTIVATION_FN| + precision: str + The precision of the embedding net parameters. Supported options are |PRECISION| + scaling_factor: float + The scaling factor of normalization in calculations of attention weights. + If `temperature` is None, the scaling of attention weights is (N_dim * scaling_factor)**0.5 + normalize: bool + Whether to normalize the hidden vectors in attention weights calculation. + temperature: float + If not None, the scaling of attention weights is `temperature` itself. + smooth_type_embdding: bool + Whether to use smooth process in attention weights calculation. + concat_output_tebd: bool + Whether to concat type embedding at the output of the descriptor. + spin + The old implementation of deepspin (deprecated in the descriptor). + + Limitations + ----------- + The currently implementation does not support the following features + + 1. 
exclude_types != [] + 2. spin is not None + 3. tebd_input_mode != 'concat' + + References + ---------- + .. [1] Duo Zhang, Hangrui Bi, Fu-Zhi Dai, Wanrun Jiang, Linfeng Zhang, and Han Wang. 2022. + DPA-1: Pretraining of Attention-based Deep Potential Model for Molecular Simulation. + arXiv preprint arXiv:2208.08236. + """ + def __init__( self, - rcut, - rcut_smth, - sel, + rcut: float, + rcut_smth: float, + sel: Union[List[int], int], ntypes: int, neuron: list = [25, 50, 100], axis_neuron: int = 16, @@ -46,39 +191,31 @@ def __init__( attn_layer: int = 2, attn_dotr: bool = True, attn_mask: bool = False, - post_ln=True, - ffn=False, - ffn_embed_dim=1024, - activation_function="tanh", - scaling_factor=1.0, - head_num=1, + activation_function: str = "tanh", + precision: str = "float64", + resnet_dt: bool = False, + exclude_types: List[Tuple[int, int]] = [], + env_protection: float = 0.0, + scaling_factor: int = 1.0, normalize=True, temperature=None, - return_rot=False, concat_output_tebd: bool = True, - env_protection: float = 0.0, - type: Optional[str] = None, - # not implemented - resnet_dt: bool = False, - type_one_side: bool = True, - precision: str = "default", trainable: bool = True, - exclude_types: List[Tuple[int, int]] = [], + smooth_type_embdding: bool = True, + type_one_side: bool = False, + # not implemented stripped_type_embedding: bool = False, - smooth_type_embdding: bool = False, + spin=None, + type: Optional[str] = None, + seed: Optional[int] = None, + old_impl: bool = False, ): super().__init__() - if resnet_dt: - raise NotImplementedError("resnet_dt is not supported.") - if not type_one_side: - raise NotImplementedError("type_one_side is not supported.") - if precision != "default" and precision != "float64": - raise NotImplementedError("precison is not supported.") if stripped_type_embedding: raise NotImplementedError("stripped_type_embedding is not supported.") - if smooth_type_embdding: - raise NotImplementedError("smooth_type_embdding is not supported.") - del type + if spin is not None: + raise NotImplementedError("old implementation of spin is not supported.") + del type, spin self.se_atten = DescrptBlockSeAtten( rcut, rcut_smth, @@ -93,19 +230,19 @@ def __init__( attn_layer=attn_layer, attn_dotr=attn_dotr, attn_mask=attn_mask, - post_ln=post_ln, - ffn=ffn, - ffn_embed_dim=ffn_embed_dim, activation_function=activation_function, + precision=precision, + resnet_dt=resnet_dt, scaling_factor=scaling_factor, - head_num=head_num, normalize=normalize, temperature=temperature, - return_rot=return_rot, + smooth=smooth_type_embdding, + type_one_side=type_one_side, exclude_types=exclude_types, env_protection=env_protection, + old_impl=old_impl, ) - self.type_embedding = TypeEmbedNet(ntypes, tebd_dim) + self.type_embedding = TypeEmbedNet(ntypes, tebd_dim, precision=precision) self.tebd_dim = tebd_dim self.concat_output_tebd = concat_output_tebd # set trainable @@ -204,14 +341,84 @@ def compute_input_stats( """ return self.se_atten.compute_input_stats(merged, path) + def set_stat_mean_and_stddev( + self, + mean: torch.Tensor, + stddev: torch.Tensor, + ) -> None: + self.se_atten.mean = mean + self.se_atten.stddev = stddev + def serialize(self) -> dict: - """Serialize the obj to dict.""" - raise NotImplementedError + obj = self.se_atten + return { + "@class": "Descriptor", + "type": "dpa1", + "@version": 1, + "rcut": obj.rcut, + "rcut_smth": obj.rcut_smth, + "sel": obj.sel, + "ntypes": obj.ntypes, + "neuron": obj.neuron, + "axis_neuron": obj.axis_neuron, + "tebd_dim": obj.tebd_dim, + 
"tebd_input_mode": obj.tebd_input_mode, + "set_davg_zero": obj.set_davg_zero, + "attn": obj.attn_dim, + "attn_layer": obj.attn_layer, + "attn_dotr": obj.attn_dotr, + "attn_mask": obj.attn_mask, + "activation_function": obj.activation_function, + "resnet_dt": obj.resnet_dt, + "scaling_factor": obj.scaling_factor, + "normalize": obj.normalize, + "temperature": obj.temperature, + "smooth_type_embdding": obj.smooth, + "type_one_side": obj.type_one_side, + "concat_output_tebd": self.concat_output_tebd, + # make deterministic + "precision": RESERVED_PRECISON_DICT[obj.prec], + "embeddings": obj.filter_layers.serialize(), + "attention_layers": obj.dpa1_attention.serialize(), + "env_mat": DPEnvMat(obj.rcut, obj.rcut_smth).serialize(), + "type_embedding": self.type_embedding.embedding.serialize(), + "exclude_types": obj.exclude_types, + "env_protection": obj.env_protection, + "@variables": { + "davg": obj["davg"].detach().cpu().numpy(), + "dstd": obj["dstd"].detach().cpu().numpy(), + }, + ## to be updated when the options are supported. + "trainable": True, + "spin": None, + } @classmethod - def deserialize(cls) -> "DescrptDPA1": - """Deserialize from a dict.""" - raise NotImplementedError + def deserialize(cls, data: dict) -> "DescrptDPA1": + data = data.copy() + check_version_compatibility(data.pop("@version", 1), 1, 1) + data.pop("@class", None) + data.pop("type", None) + variables = data.pop("@variables") + embeddings = data.pop("embeddings") + type_embedding = data.pop("type_embedding") + attention_layers = data.pop("attention_layers") + env_mat = data.pop("env_mat") + obj = cls(**data) + + def t_cvt(xx): + return torch.tensor(xx, dtype=obj.se_atten.prec, device=env.DEVICE) + + obj.type_embedding.embedding = TypeEmbedNetConsistent.deserialize( + type_embedding + ) + obj.se_atten["davg"] = t_cvt(variables["davg"]) + obj.se_atten["dstd"] = t_cvt(variables["dstd"]) + obj.se_atten.filter_layers = NetworkCollection.deserialize(embeddings) + obj.se_atten.dpa1_attention = NeighborGatedAttention.deserialize( + attention_layers + ) + return obj def forward( self, diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index 051c66385c..cfd0a7f95d 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -10,6 +10,8 @@ import numpy as np import torch +import torch.nn as nn +import torch.nn.functional as torch_func from deepmd.pt.model.descriptor.descriptor import ( DescriptorBlock, @@ -17,6 +19,14 @@ from deepmd.pt.model.descriptor.env_mat import ( prod_env_mat, ) +from deepmd.pt.model.network.layernorm import ( + LayerNorm, +) +from deepmd.pt.model.network.mlp import ( + EmbeddingNet, + MLPLayer, + NetworkCollection, +) from deepmd.pt.model.network.network import ( NeighborWiseAttention, TypeFilter, @@ -24,6 +34,10 @@ from deepmd.pt.utils import ( env, ) +from deepmd.pt.utils.env import ( + DEFAULT_PRECISION, + PRECISION_DICT, +) from deepmd.pt.utils.env_mat_stat import ( EnvMatStatSe, ) @@ -36,15 +50,18 @@ from deepmd.utils.path import ( DPPath, ) +from deepmd.utils.version import ( + check_version_compatibility, +) @DescriptorBlock.register("se_atten") class DescrptBlockSeAtten(DescriptorBlock): def __init__( self, - rcut, - rcut_smth, - sel, + rcut: float, + rcut_smth: float, + sel: Union[List[int], int], ntypes: int, neuron: list = [25, 50, 100], axis_neuron: int = 16, @@ -56,18 +73,18 @@ def __init__( attn_layer: int = 2, attn_dotr: bool = True, attn_mask: bool = False, - post_ln=True, - ffn=False, - ffn_embed_dim=1024, 
activation_function="tanh", + precision: str = "float64", + resnet_dt: bool = False, scaling_factor=1.0, - head_num=1, normalize=True, temperature=None, - return_rot=False, + smooth: bool = True, + type_one_side: bool = False, exclude_types: List[Tuple[int, int]] = [], env_protection: float = 0.0, type: Optional[str] = None, + old_impl: bool = False, ): """Construct an embedding net of type `se_atten`. @@ -82,7 +99,8 @@ def __init__( del type self.rcut = rcut self.rcut_smth = rcut_smth - self.filter_neuron = neuron + self.neuron = neuron + self.filter_neuron = self.neuron self.axis_neuron = axis_neuron self.tebd_dim = tebd_dim self.tebd_input_mode = tebd_input_mode @@ -91,18 +109,17 @@ def __init__( self.attn_layer = attn_layer self.attn_dotr = attn_dotr self.attn_mask = attn_mask - self.post_ln = post_ln - self.ffn = ffn - self.ffn_embed_dim = ffn_embed_dim - self.activation = activation_function - # TODO: To be fixed: precision should be given from inputs - self.prec = torch.float64 + self.activation_function = activation_function + self.precision = precision + self.prec = PRECISION_DICT[self.precision] + self.resnet_dt = resnet_dt self.scaling_factor = scaling_factor - self.head_num = head_num self.normalize = normalize self.temperature = temperature - self.return_rot = return_rot + self.smooth = smooth + self.type_one_side = type_one_side self.env_protection = env_protection + self.old_impl = old_impl if isinstance(sel, int): sel = [sel] @@ -115,22 +132,34 @@ def __init__( self.ndescrpt = self.nnei * 4 # order matters, placed after the assignment of self.ntypes self.reinit_exclude(exclude_types) - self.dpa1_attention = NeighborWiseAttention( - self.attn_layer, - self.nnei, - self.filter_neuron[-1], - self.attn_dim, - dotr=self.attn_dotr, - do_mask=self.attn_mask, - post_ln=self.post_ln, - ffn=self.ffn, - ffn_embed_dim=self.ffn_embed_dim, - activation=self.activation, - scaling_factor=self.scaling_factor, - head_num=self.head_num, - normalize=self.normalize, - temperature=self.temperature, - ) + if self.old_impl: + self.dpa1_attention = NeighborWiseAttention( + self.attn_layer, + self.nnei, + self.filter_neuron[-1], + self.attn_dim, + dotr=self.attn_dotr, + do_mask=self.attn_mask, + activation=self.activation_function, + scaling_factor=self.scaling_factor, + normalize=self.normalize, + temperature=self.temperature, + smooth=self.smooth, + ) + else: + self.dpa1_attention = NeighborGatedAttention( + self.attn_layer, + self.nnei, + self.filter_neuron[-1], + self.attn_dim, + dotr=self.attn_dotr, + do_mask=self.attn_mask, + scaling_factor=self.scaling_factor, + normalize=self.normalize, + temperature=self.temperature, + smooth=self.smooth, + precision=self.precision, + ) wanted_shape = (self.ntypes, self.nnei, 4) mean = torch.zeros( @@ -141,19 +170,41 @@ def __init__( ) self.register_buffer("mean", mean) self.register_buffer("stddev", stddev) - - filter_layers = [] - one = TypeFilter( - 0, - self.nnei, - self.filter_neuron, - return_G=True, - tebd_dim=self.tebd_dim, - use_tebd=True, - tebd_mode=self.tebd_input_mode, - ) - filter_layers.append(one) - self.filter_layers = torch.nn.ModuleList(filter_layers) + if self.tebd_input_mode in ["concat"]: + if not self.type_one_side: + self.embd_input_dim = 1 + self.tebd_dim * 2 + else: + self.embd_input_dim = 1 + self.tebd_dim + else: + self.embd_input_dim = 1 + + self.filter_layers_old = None + self.filter_layers = None + if self.old_impl: + filter_layers = [] + one = TypeFilter( + 0, + self.nnei, + self.filter_neuron, + return_G=True, + 
tebd_dim=self.tebd_dim, + use_tebd=True, + tebd_mode=self.tebd_input_mode, + ) + filter_layers.append(one) + self.filter_layers_old = torch.nn.ModuleList(filter_layers) + else: + filter_layers = NetworkCollection( + ndim=0, ntypes=self.ntypes, network_type="embedding_network" + ) + filter_layers[0] = EmbeddingNet( + self.embd_input_dim, + self.filter_neuron, + activation_function=self.activation_function, + precision=self.precision, + resnet_dt=self.resnet_dt, + ) + self.filter_layers = filter_layers self.stats = None def get_rcut(self) -> float: @@ -184,6 +235,22 @@ def get_dim_emb(self) -> int: """Returns the output dimension of embedding.""" return self.filter_neuron[-1] + def __setitem__(self, key, value): + if key in ("avg", "data_avg", "davg"): + self.mean = value + elif key in ("std", "data_std", "dstd"): + self.stddev = value + else: + raise KeyError(key) + + def __getitem__(self, key): + if key in ("avg", "data_avg", "davg"): + return self.mean + elif key in ("std", "data_std", "dstd"): + return self.stddev + else: + raise KeyError(key) + def mixed_types(self) -> bool: """If true, the discriptor 1. assumes total number of atoms aligned across frames; @@ -272,7 +339,7 @@ def forward( extended_atype: torch.Tensor, extended_atype_embd: Optional[torch.Tensor] = None, mapping: Optional[torch.Tensor] = None, - ) -> List[torch.Tensor]: + ): """Calculate decoded embedding for each atom. Args: @@ -302,8 +369,6 @@ def forward( self.rcut_smth, protection=self.env_protection, ) - # [nfxnlocxnnei, self.ndescrpt] - dmatrix = dmatrix.view(-1, self.ndescrpt) nlist_mask = nlist != -1 nlist[nlist == -1] = 0 sw = torch.squeeze(sw, -1) @@ -321,23 +386,60 @@ def forward( atype_tebd_nlist = torch.gather(atype_tebd_ext, dim=1, index=index) # nb x nloc x nnei x nt atype_tebd_nlist = atype_tebd_nlist.view(nb, nloc, nnei, nt) - ret = self.filter_layers[0]( - dmatrix, - atype_tebd=atype_tebd_nnei, - nlist_tebd=atype_tebd_nlist, - ) # shape is [nframes*nall, self.neei, out_size] - input_r = torch.nn.functional.normalize( - dmatrix.reshape(-1, self.nnei, 4)[:, :, 1:4], dim=-1 - ) - ret = self.dpa1_attention( - ret, nlist_mask, input_r=input_r, sw=sw - ) # shape is [nframes*nloc, self.neei, out_size] - inputs_reshape = dmatrix.view(-1, self.nnei, 4).permute( - 0, 2, 1 - ) # shape is [nframes*natoms[0], 4, self.neei] - xyz_scatter = torch.matmul( - inputs_reshape, ret - ) # shape is [nframes*natoms[0], 4, out_size] + # (nb x nloc) x nnei + exclude_mask = self.emask(nlist, extended_atype).view(nb * nloc, nnei) + if self.old_impl: + assert self.filter_layers_old is not None + dmatrix = dmatrix.view( + -1, self.ndescrpt + ) # shape is [nframes*nall, self.ndescrpt] + gg = self.filter_layers_old[0]( + dmatrix, + atype_tebd=atype_tebd_nnei, + nlist_tebd=atype_tebd_nlist, + ) # shape is [nframes*nall, self.neei, out_size] + input_r = torch.nn.functional.normalize( + dmatrix.reshape(-1, self.nnei, 4)[:, :, 1:4], dim=-1 + ) + gg = self.dpa1_attention( + gg, nlist_mask, input_r=input_r, sw=sw + ) # shape is [nframes*nloc, self.neei, out_size] + inputs_reshape = dmatrix.view(-1, self.nnei, 4).permute( + 0, 2, 1 + ) # shape is [nframes*natoms[0], 4, self.neei] + xyz_scatter = torch.matmul( + inputs_reshape, gg + ) # shape is [nframes*natoms[0], 4, out_size] + else: + assert self.filter_layers is not None + # nfnl x nnei x 4 + dmatrix = dmatrix.view(-1, self.nnei, 4) + nfnl = dmatrix.shape[0] + # nfnl x nnei x 4 + rr = dmatrix + rr = rr * exclude_mask[:, :, None] + ss = rr[:, :, :1] + if self.tebd_input_mode in ["concat"]: 
+ nlist_tebd = atype_tebd_nlist.reshape(nfnl, nnei, self.tebd_dim) + atype_tebd = atype_tebd_nnei.reshape(nfnl, nnei, self.tebd_dim) + if not self.type_one_side: + # nfnl x nnei x (1 + tebd_dim * 2) + ss = torch.concat([ss, nlist_tebd, atype_tebd], dim=2) + else: + # nfnl x nnei x (1 + tebd_dim) + ss = torch.concat([ss, nlist_tebd], dim=2) + else: + raise NotImplementedError + # nfnl x nnei x ng + gg = self.filter_layers._networks[0](ss) + input_r = torch.nn.functional.normalize( + dmatrix.reshape(-1, self.nnei, 4)[:, :, 1:4], dim=-1 + ) + gg = self.dpa1_attention( + gg, nlist_mask, input_r=input_r, sw=sw + ) # shape is [nframes*nloc, self.neei, out_size] + # nfnl x 4 x ng + xyz_scatter = torch.matmul(rr.permute(0, 2, 1), gg) xyz_scatter = xyz_scatter / self.nnei xyz_scatter_1 = xyz_scatter.permute(0, 2, 1) rot_mat = xyz_scatter_1[:, :, 1:4] @@ -347,13 +449,387 @@ def forward( ) # shape is [nframes*nloc, self.filter_neuron[-1], self.axis_neuron] return ( result.view(-1, nloc, self.filter_neuron[-1] * self.axis_neuron), - ret.view(-1, nloc, self.nnei, self.filter_neuron[-1]), + gg.view(-1, nloc, self.nnei, self.filter_neuron[-1]), dmatrix.view(-1, nloc, self.nnei, 4)[..., 1:], rot_mat.view(-1, nloc, self.filter_neuron[-1], 3), sw, ) +class NeighborGatedAttention(nn.Module): + def __init__( + self, + layer_num: int, + nnei: int, + embed_dim: int, + hidden_dim: int, + dotr: bool = False, + do_mask: bool = False, + scaling_factor: float = 1.0, + normalize: bool = True, + temperature: Optional[float] = None, + smooth: bool = True, + precision: str = DEFAULT_PRECISION, + ): + """Construct a neighbor-wise attention net.""" + super().__init__() + self.layer_num = layer_num + self.nnei = nnei + self.embed_dim = embed_dim + self.hidden_dim = hidden_dim + self.dotr = dotr + self.do_mask = do_mask + self.scaling_factor = scaling_factor + self.normalize = normalize + self.temperature = temperature + self.smooth = smooth + self.precision = precision + self.network_type = NeighborGatedAttentionLayer + attention_layers = [] + for i in range(self.layer_num): + attention_layers.append( + NeighborGatedAttentionLayer( + nnei, + embed_dim, + hidden_dim, + dotr=dotr, + do_mask=do_mask, + scaling_factor=scaling_factor, + normalize=normalize, + temperature=temperature, + smooth=self.smooth, + precision=precision, + ) + ) + self.attention_layers = nn.ModuleList(attention_layers) + + def forward( + self, + input_G, + nei_mask, + input_r: Optional[torch.Tensor] = None, + sw: Optional[torch.Tensor] = None, + ): + """ + Args: + input_G: Input G, [nframes * nloc, nnei, embed_dim]. + nei_mask: neighbor mask, [nframes * nloc, nnei]. + input_r: normalized radial, [nframes, nloc, nei, 3]. + + Returns + ------- + out: Output G, [nframes * nloc, nnei, embed_dim] + """ + out = input_G + # https://github.com/pytorch/pytorch/issues/39165#issuecomment-635472592 + for layer in self.attention_layers: + out = layer(out, nei_mask, input_r=input_r, sw=sw) + return out + + def __getitem__(self, key): + if isinstance(key, int): + return self.attention_layers[key] + else: + raise TypeError(key) + + def __setitem__(self, key, value): + if not isinstance(key, int): + raise TypeError(key) + if isinstance(value, self.network_type): + pass + elif isinstance(value, dict): + value = self.network_type.deserialize(value) + else: + raise TypeError(value) + self.attention_layers[key] = value + + def serialize(self) -> dict: + """Serialize the networks to a dict. + + Returns + ------- + dict + The serialized networks. 
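For reference, the tensor contraction at the end of the new (non-`old_impl`) forward path above can be sketched in NumPy as follows. Shapes and the division by `nnei` follow the PyTorch code; the random data and sizes are placeholders, and the slicing of the first `axis_neuron` columns is assumed from the usual se_e2_a convention rather than shown verbatim in this hunk.

```python
# Illustrative NumPy sketch of the final contraction in the new forward path.
# rr is the (masked) environment matrix, gg the attention-refined embedding.
import numpy as np

nfnl, nnei, ng, axis_neuron = 4, 8, 16, 4
rr = np.random.randn(nfnl, nnei, 4)    # rows are (s, s*x/r, s*y/r, s*z/r)
gg = np.random.randn(nfnl, nnei, ng)   # output of filter_layers + dpa1_attention

xyz_scatter = np.matmul(rr.transpose(0, 2, 1), gg) / nnei   # nfnl x 4 x ng
xyz_scatter_1 = xyz_scatter.transpose(0, 2, 1)              # nfnl x ng x 4
rot_mat = xyz_scatter_1[:, :, 1:4]                          # nfnl x ng x 3
xyz_scatter_2 = xyz_scatter[:, :, :axis_neuron]             # nfnl x 4 x axis_neuron
result = np.matmul(xyz_scatter_1, xyz_scatter_2)            # nfnl x ng x axis_neuron
```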
+ """ + return { + "@class": "NeighborGatedAttention", + "@version": 1, + "layer_num": self.layer_num, + "nnei": self.nnei, + "embed_dim": self.embed_dim, + "hidden_dim": self.hidden_dim, + "dotr": self.dotr, + "do_mask": self.do_mask, + "scaling_factor": self.scaling_factor, + "normalize": self.normalize, + "temperature": self.temperature, + "precision": self.precision, + "attention_layers": [layer.serialize() for layer in self.attention_layers], + } + + @classmethod + def deserialize(cls, data: dict) -> "NeighborGatedAttention": + """Deserialize the networks from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. + """ + data = data.copy() + check_version_compatibility(data.pop("@version", 1), 1, 1) + data.pop("@class", None) + attention_layers = data.pop("attention_layers") + obj = cls(**data) + for ii, network in enumerate(attention_layers): + obj[ii] = network + return obj + + +class NeighborGatedAttentionLayer(nn.Module): + def __init__( + self, + nnei: int, + embed_dim: int, + hidden_dim: int, + dotr: bool = False, + do_mask: bool = False, + scaling_factor: float = 1.0, + normalize: bool = True, + temperature: Optional[float] = None, + smooth: bool = True, + precision: str = DEFAULT_PRECISION, + ): + """Construct a neighbor-wise attention layer.""" + super().__init__() + self.nnei = nnei + self.embed_dim = embed_dim + self.hidden_dim = hidden_dim + self.dotr = dotr + self.do_mask = do_mask + self.scaling_factor = scaling_factor + self.normalize = normalize + self.temperature = temperature + self.precision = precision + self.attention_layer = GatedAttentionLayer( + nnei, + embed_dim, + hidden_dim, + dotr=dotr, + do_mask=do_mask, + scaling_factor=scaling_factor, + normalize=normalize, + temperature=temperature, + smooth=smooth, + precision=precision, + ) + self.attn_layer_norm = LayerNorm(self.embed_dim, precision=precision) + + def forward( + self, + x, + nei_mask, + input_r: Optional[torch.Tensor] = None, + sw: Optional[torch.Tensor] = None, + ): + residual = x + x = self.attention_layer(x, nei_mask, input_r=input_r, sw=sw) + x = residual + x + x = self.attn_layer_norm(x) + return x + + def serialize(self) -> dict: + """Serialize the networks to a dict. + + Returns + ------- + dict + The serialized networks. + """ + return { + "nnei": self.nnei, + "embed_dim": self.embed_dim, + "hidden_dim": self.hidden_dim, + "dotr": self.dotr, + "do_mask": self.do_mask, + "scaling_factor": self.scaling_factor, + "normalize": self.normalize, + "temperature": self.temperature, + "precision": self.precision, + "attention_layer": self.attention_layer.serialize(), + "attn_layer_norm": self.attn_layer_norm.serialize(), + } + + @classmethod + def deserialize(cls, data: dict) -> "NeighborGatedAttentionLayer": + """Deserialize the networks from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. 
+ """ + data = data.copy() + attention_layer = data.pop("attention_layer") + attn_layer_norm = data.pop("attn_layer_norm") + obj = cls(**data) + obj.attention_layer = GatedAttentionLayer.deserialize(attention_layer) + obj.attn_layer_norm = LayerNorm.deserialize(attn_layer_norm) + return obj + + +class GatedAttentionLayer(nn.Module): + def __init__( + self, + nnei: int, + embed_dim: int, + hidden_dim: int, + dotr: bool = False, + do_mask: bool = False, + scaling_factor: float = 1.0, + normalize: bool = True, + temperature: Optional[float] = None, + bias: bool = True, + smooth: bool = True, + precision: str = DEFAULT_PRECISION, + ): + """Construct a neighbor-wise attention net.""" + super().__init__() + self.nnei = nnei + self.embed_dim = embed_dim + self.hidden_dim = hidden_dim + self.dotr = dotr + self.do_mask = do_mask + self.bias = bias + self.smooth = smooth + self.scaling_factor = scaling_factor + self.temperature = temperature + self.precision = precision + if temperature is None: + self.scaling = (self.hidden_dim * scaling_factor) ** -0.5 + else: + self.scaling = temperature + self.normalize = normalize + self.in_proj = MLPLayer( + embed_dim, + hidden_dim * 3, + bias=bias, + use_timestep=False, + bavg=0.0, + stddev=1.0, + precision=precision, + ) + self.out_proj = MLPLayer( + hidden_dim, + embed_dim, + bias=bias, + use_timestep=False, + bavg=0.0, + stddev=1.0, + precision=precision, + ) + + def forward( + self, + query, + nei_mask, + input_r: Optional[torch.Tensor] = None, + sw: Optional[torch.Tensor] = None, + attnw_shift: float = 20.0, + ): + """ + Args: + query: input G, [nframes * nloc, nnei, embed_dim]. + nei_mask: neighbor mask, [nframes * nloc, nnei]. + input_r: normalized radial, [nframes, nloc, nei, 3]. + + Returns + ------- + type_embedding: + """ + q, k, v = self.in_proj(query).chunk(3, dim=-1) + # [nframes * nloc, nnei, hidden_dim] + q = q.view(-1, self.nnei, self.hidden_dim) + k = k.view(-1, self.nnei, self.hidden_dim) + v = v.view(-1, self.nnei, self.hidden_dim) + if self.normalize: + q = torch_func.normalize(q, dim=-1) + k = torch_func.normalize(k, dim=-1) + v = torch_func.normalize(v, dim=-1) + q = q * self.scaling + k = k.transpose(1, 2) + # [nframes * nloc, nnei, nnei] + attn_weights = torch.bmm(q, k) + # [nframes * nloc, nnei] + nei_mask = nei_mask.view(-1, self.nnei) + if self.smooth: + # [nframes * nloc, nnei] + assert sw is not None + sw = sw.view([-1, self.nnei]) + attn_weights = (attn_weights + attnw_shift) * sw[:, :, None] * sw[ + :, None, : + ] - attnw_shift + else: + attn_weights = attn_weights.masked_fill( + ~nei_mask.unsqueeze(1), float("-inf") + ) + attn_weights = torch_func.softmax(attn_weights, dim=-1) + attn_weights = attn_weights.masked_fill(~nei_mask.unsqueeze(-1), 0.0) + if self.smooth: + assert sw is not None + attn_weights = attn_weights * sw[:, :, None] * sw[:, None, :] + if self.dotr: + assert input_r is not None, "input_r must be provided when dotr is True!" + angular_weight = torch.bmm(input_r, input_r.transpose(1, 2)) + attn_weights = attn_weights * angular_weight + o = torch.bmm(attn_weights, v) + output = self.out_proj(o) + return output + + def serialize(self) -> dict: + """Serialize the networks to a dict. + + Returns + ------- + dict + The serialized networks. 
+ """ + # network_type_map_inv = {v: k for k, v in self.NETWORK_TYPE_MAP.items()} + # network_type_name = network_type_map_inv[self.network_type] + return { + "nnei": self.nnei, + "embed_dim": self.embed_dim, + "hidden_dim": self.hidden_dim, + "dotr": self.dotr, + "do_mask": self.do_mask, + "scaling_factor": self.scaling_factor, + "normalize": self.normalize, + "temperature": self.temperature, + "bias": self.bias, + "smooth": self.smooth, + "precision": self.precision, + "in_proj": self.in_proj.serialize(), + "out_proj": self.out_proj.serialize(), + } + + @classmethod + def deserialize(cls, data: dict) -> "GatedAttentionLayer": + """Deserialize the networks from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. + """ + data = data.copy() + in_proj = data.pop("in_proj") + out_proj = data.pop("out_proj") + obj = cls(**data) + obj.in_proj = MLPLayer.deserialize(in_proj) + obj.out_proj = MLPLayer.deserialize(out_proj) + return obj + + def analyze_descrpt(matrix, ndescrpt, natoms, mixed_types=False, real_atype=None): """Collect avg, square avg and count of descriptors in a batch.""" ntypes = natoms.shape[1] - 2 diff --git a/deepmd/pt/model/network/layernorm.py b/deepmd/pt/model/network/layernorm.py new file mode 100644 index 0000000000..efb4836db7 --- /dev/null +++ b/deepmd/pt/model/network/layernorm.py @@ -0,0 +1,126 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import torch +import torch.nn as nn + +from deepmd.dpmodel.utils.network import LayerNorm as DPLayerNorm +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + DEFAULT_PRECISION, + PRECISION_DICT, +) +from deepmd.pt.utils.utils import ( + to_numpy_array, + to_torch_tensor, +) + +from .mlp import ( + MLPLayer, +) + +device = env.DEVICE + + +class LayerNorm(MLPLayer): + def __init__( + self, + num_in, + eps: float = 1e-5, + uni_init: bool = True, + bavg: float = 0.0, + stddev: float = 1.0, + precision: str = DEFAULT_PRECISION, + ): + self.eps = eps + self.uni_init = uni_init + self.num_in = num_in + super().__init__( + num_in=1, + num_out=num_in, + bias=True, + use_timestep=False, + activation_function=None, + resnet=False, + bavg=bavg, + stddev=stddev, + precision=precision, + ) + self.matrix = torch.nn.Parameter(self.matrix.squeeze(0)) + if self.uni_init: + nn.init.ones_(self.matrix.data) + nn.init.zeros_(self.bias.data) + + def dim_out(self) -> int: + return self.matrix.shape[0] + + def forward( + self, + xx: torch.Tensor, + ) -> torch.Tensor: + """One Layer Norm used by DP model. + + Parameters + ---------- + xx : torch.Tensor + The input of index. + + Returns + ------- + yy: torch.Tensor + The output. + """ + mean = xx.mean(dim=-1, keepdim=True) + variance = xx.var(dim=-1, unbiased=False, keepdim=True) + yy = (xx - mean) / torch.sqrt(variance + self.eps) + if self.matrix is not None: + yy = yy * self.matrix + if self.bias is not None: + yy = yy + self.bias + return yy + + def serialize(self) -> dict: + """Serialize the layer to a dict. + + Returns + ------- + dict + The serialized layer. + """ + nl = DPLayerNorm( + self.matrix.shape[0], + eps=self.eps, + precision=self.precision, + ) + nl.w = to_numpy_array(self.matrix) + nl.b = to_numpy_array(self.bias) + data = nl.serialize() + return data + + @classmethod + def deserialize(cls, data: dict) -> "LayerNorm": + """Deserialize the layer from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. 
+ """ + nl = DPLayerNorm.deserialize(data) + obj = cls( + nl["matrix"].shape[0], + eps=nl["eps"], + precision=nl["precision"], + ) + prec = PRECISION_DICT[obj.precision] + + def check_load_param(ss): + return ( + nn.Parameter(data=to_torch_tensor(nl[ss])) + if nl[ss] is not None + else None + ) + + obj.matrix = check_load_param("matrix") + obj.bias = check_load_param("bias") + return obj diff --git a/deepmd/pt/model/network/network.py b/deepmd/pt/model/network/network.py index c895f642e1..09d9945b3b 100644 --- a/deepmd/pt/model/network/network.py +++ b/deepmd/pt/model/network/network.py @@ -556,15 +556,19 @@ def forward(self, inputs): class TypeEmbedNet(nn.Module): - def __init__(self, type_nums, embed_dim, bavg=0.0, stddev=1.0): + def __init__(self, type_nums, embed_dim, bavg=0.0, stddev=1.0, precision="default"): """Construct a type embedding net.""" super().__init__() + self.type_nums = type_nums + self.embed_dim = embed_dim + self.bavg = bavg + self.stddev = stddev self.embedding = TypeEmbedNetConsistent( - ntypes=type_nums, - neuron=[embed_dim], + ntypes=self.type_nums, + neuron=[self.embed_dim], padding=True, activation_function="Linear", - precision="default", + precision=precision, ) # nn.init.normal_(self.embedding.weight[:-1], mean=bavg, std=stddev) @@ -847,6 +851,7 @@ def __init__( head_num=1, normalize=True, temperature=None, + smooth=True, ): """Construct a neighbor-wise attention net.""" super().__init__() @@ -868,6 +873,7 @@ def __init__( head_num=head_num, normalize=normalize, temperature=temperature, + smooth=smooth, ) ) self.attention_layers = nn.ModuleList(attention_layers) @@ -915,6 +921,7 @@ def __init__( head_num=1, normalize=True, temperature=None, + smooth=True, ): """Construct a neighbor-wise attention layer.""" super().__init__() @@ -925,6 +932,7 @@ def __init__( self.do_mask = do_mask self.post_ln = post_ln self.ffn = ffn + self.smooth = smooth self.attention_layer = GatedSelfAttetion( nnei, embed_dim, @@ -935,6 +943,7 @@ def __init__( head_num=head_num, normalize=normalize, temperature=temperature, + smooth=smooth, ) self.attn_layer_norm = nn.LayerNorm( self.embed_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE diff --git a/deepmd/tf/descriptor/se.py b/deepmd/tf/descriptor/se.py index 4232503464..5a3c5dd447 100644 --- a/deepmd/tf/descriptor/se.py +++ b/deepmd/tf/descriptor/se.py @@ -296,7 +296,10 @@ def deserialize_network(cls, data: dict, suffix: str = "") -> dict: net_idx.append(rest_ii % embeddings.ntypes) rest_ii //= embeddings.ntypes net_idx = tuple(net_idx) - if embeddings.ndim in (0, 1): + if embeddings.ndim == 0: + key0 = "all" + key1 = "" + elif embeddings.ndim == 1: key0 = "all" key1 = f"_{ii}" elif embeddings.ndim == 2: diff --git a/deepmd/tf/descriptor/se_atten.py b/deepmd/tf/descriptor/se_atten.py index 3ca763870b..249337e162 100644 --- a/deepmd/tf/descriptor/se_atten.py +++ b/deepmd/tf/descriptor/se_atten.py @@ -1,11 +1,14 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import logging +import re import warnings from typing import ( + Any, List, Optional, Set, Tuple, + Union, ) import numpy as np @@ -13,11 +16,19 @@ Version, ) +from deepmd.dpmodel.utils.env_mat import ( + EnvMat, +) +from deepmd.dpmodel.utils.network import ( + LayerNorm, + NativeLayer, +) from deepmd.tf.common import ( cast_precision, get_np_precision, ) from deepmd.tf.env import ( + ATTENTION_LAYER_PATTERN, GLOBAL_NP_FLOAT_PRECISION, GLOBAL_TF_FLOAT_PRECISION, TF_VERSION, @@ -50,6 +61,7 @@ ) from deepmd.tf.utils.network import ( embedding_net, + layernorm, one_layer, ) 
from deepmd.tf.utils.sess import ( @@ -58,9 +70,15 @@ from deepmd.tf.utils.tabulate import ( DPTabulate, ) +from deepmd.tf.utils.type_embed import ( + TypeEmbedNet, +) from deepmd.tf.utils.update_sel import ( UpdateSel, ) +from deepmd.utils.version import ( + check_version_compatibility, +) from .descriptor import ( Descriptor, @@ -83,8 +101,9 @@ class DescrptSeAtten(DescrptSeA): The cut-off radius :math:`r_c` rcut_smth From where the environment matrix should be smoothed :math:`r_s` - sel : int - sel[i] specifies the maxmum number of type i atoms in the cut-off radius + sel : list[int], int + list[int]: sel[i] specifies the maxmum number of type i atoms in the cut-off radius + int: the total maxmum number of atoms in the cut-off radius neuron : list[int] Number of neurons in each hidden layers of the embedding net :math:`\mathcal{N}` axis_neuron @@ -123,7 +142,8 @@ class DescrptSeAtten(DescrptSeA): Whether to strip the type embedding into a separated embedding network. Default value will be True in `se_atten_v2` descriptor. smooth_type_embdding - When using stripped type embedding, whether to dot smooth factor on the network output of type embedding + Whether to use smooth process in attention weights calculation. + And when using stripped type embedding, whether to dot smooth factor on the network output of type embedding to keep the network smooth, instead of setting `set_davg_zero` to be True. Default value will be True in `se_atten_v2` descriptor. @@ -137,9 +157,9 @@ def __init__( self, rcut: float, rcut_smth: float, - sel: int, + sel: Union[List[int], int], ntypes: int, - neuron: List[int] = [24, 48, 96], + neuron: List[int] = [25, 50, 100], axis_neuron: int = 8, resnet_dt: bool = False, trainable: bool = True, @@ -158,15 +178,11 @@ def __init__( stripped_type_embedding: bool = False, smooth_type_embdding: bool = False, # not implemented - post_ln=True, - ffn=False, - ffn_embed_dim=1024, scaling_factor=1.0, - head_num=1, normalize=True, temperature=None, - return_rot=False, concat_output_tebd: bool = True, + env_protection: float = 0.0, # not implement!! **kwargs, ) -> None: if not set_davg_zero and not (stripped_type_embedding and smooth_type_embdding): @@ -174,24 +190,18 @@ def __init__( "Set 'set_davg_zero' False in descriptor 'se_atten' " "may cause unexpected incontinuity during model inference!" 
) - if not post_ln: - raise NotImplementedError("post_ln is not supported.") - if ffn: - raise NotImplementedError("ffn is not supported.") - if ffn_embed_dim != 1024: - raise NotImplementedError("ffn_embed_dim is not supported.") if scaling_factor != 1.0: raise NotImplementedError("scaling_factor is not supported.") - if head_num != 1: - raise NotImplementedError("head_num is not supported.") if not normalize: raise NotImplementedError("normalize is not supported.") if temperature is not None: raise NotImplementedError("temperature is not supported.") - if return_rot: - raise NotImplementedError("return_rot is not supported.") if not concat_output_tebd: raise NotImplementedError("concat_output_tebd is not supported.") + if env_protection != 0.0: + raise NotImplementedError("env_protection != 0.0 is not supported.") + if isinstance(sel, list): + sel = sum(sel) DescrptSeA.__init__( self, rcut, @@ -239,12 +249,12 @@ def __init__( std_ones = np.ones([self.ntypes, self.ndescrpt]).astype( GLOBAL_NP_FLOAT_PRECISION ) - self.beta = np.zeros([self.attn_layer, self.filter_neuron[-1]]).astype( - GLOBAL_NP_FLOAT_PRECISION - ) - self.gamma = np.ones([self.attn_layer, self.filter_neuron[-1]]).astype( - GLOBAL_NP_FLOAT_PRECISION - ) + # self.beta = np.zeros([self.attn_layer, self.filter_neuron[-1]]).astype( + # GLOBAL_NP_FLOAT_PRECISION + # ) + # self.gamma = np.ones([self.attn_layer, self.filter_neuron[-1]]).astype( + # GLOBAL_NP_FLOAT_PRECISION + # ) self.attention_layer_variables = None sub_graph = tf.Graph() with sub_graph.as_default(): @@ -1047,11 +1057,23 @@ def _attention_layers( uniform_seed=self.uniform_seed, initial_variables=self.attention_layer_variables, ) - input_xyz = tf.keras.layers.LayerNormalization( - beta_initializer=tf.constant_initializer(self.beta[i]), - gamma_initializer=tf.constant_initializer(self.gamma[i]), - dtype=self.filter_precision, - )(input_xyz) + input_xyz = layernorm( + input_xyz, + outputs_size[-1], + precision=self.filter_precision, + name="layer_normalization", + scope=name + "/", + reuse=tf.AUTO_REUSE, + seed=self.seed, + uniform_seed=self.uniform_seed, + trainable=trainable, + initial_variables=self.attention_layer_variables, + ) + # input_xyz = tf.keras.layers.LayerNormalization( + # beta_initializer=tf.constant_initializer(self.beta[i]), + # gamma_initializer=tf.constant_initializer(self.gamma[i]), + # dtype=self.filter_precision, + # )(input_xyz) # input_xyz = self._feedforward(input_xyz, outputs_size[-1], self.att_n) return input_xyz @@ -1360,20 +1382,27 @@ def init_variables( self.attention_layer_variables = get_attention_layer_variables_from_graph_def( graph_def, suffix=suffix ) - if self.attn_layer > 0: - self.beta[0] = self.attention_layer_variables[ - f"attention_layer_0{suffix}/layer_normalization/beta" - ] - self.gamma[0] = self.attention_layer_variables[ - f"attention_layer_0{suffix}/layer_normalization/gamma" - ] - for i in range(1, self.attn_layer): - self.beta[i] = self.attention_layer_variables[ - f"attention_layer_{i}{suffix}/layer_normalization_{i}/beta" - ] - self.gamma[i] = self.attention_layer_variables[ - f"attention_layer_{i}{suffix}/layer_normalization_{i}/gamma" - ] + # if self.attn_layer > 0: + # self.beta[0] = self.attention_layer_variables[ + # f"attention_layer_0{suffix}/layer_normalization/beta" + # ] + # self.gamma[0] = self.attention_layer_variables[ + # f"attention_layer_0{suffix}/layer_normalization/gamma" + # ] + # for i in range(1, self.attn_layer): + # self.beta[i] = self.attention_layer_variables[ + # 
f"attention_layer_{i}{suffix}/layer_normalization_{i}/beta" + # ] + # self.gamma[i] = self.attention_layer_variables[ + # f"attention_layer_{i}{suffix}/layer_normalization_{i}/gamma" + # ] + # for i in range(self.attn_layer): + # self.beta[i] = self.attention_layer_variables[ + # f"attention_layer_{i}{suffix}/layer_normalization/beta" + # ] + # self.gamma[i] = self.attention_layer_variables[ + # f"attention_layer_{i}{suffix}/layer_normalization/gamma" + # ] if self.stripped_type_embedding: self.two_side_embeeding_net_variables = ( @@ -1487,3 +1516,622 @@ def update_sel(cls, global_jdata: dict, local_jdata: dict): """ local_jdata_cpy = local_jdata.copy() return UpdateSel().update_one_sel(global_jdata, local_jdata_cpy, True) + + def serialize_attention_layers( + self, + nlayer: int, + nnei: int, + embed_dim: int, + hidden_dim: int, + dotr: bool, + do_mask: bool, + variables: dict, + bias: bool = True, + suffix: str = "", + ) -> dict: + data = { + "layer_num": nlayer, + "nnei": nnei, + "embed_dim": embed_dim, + "hidden_dim": hidden_dim, + "dotr": dotr, + "do_mask": do_mask, + "precision": self.precision.name, + "attention_layers": [], + } + if suffix != "": + attention_layer_pattern = ( + ATTENTION_LAYER_PATTERN.replace("/(c_query)", suffix + "/(c_query)") + .replace("/(c_key)", suffix + "/(c_key)") + .replace("/(c_value)", suffix + "/(c_value)") + .replace("/(c_out)", suffix + "/(c_out)") + .replace("/(layer_normalization)", suffix + "/(layer_normalization)") + ) + else: + attention_layer_pattern = ATTENTION_LAYER_PATTERN + attention_layer_params = [{} for _ in range(nlayer)] + for key, value in variables.items(): + m = re.search(attention_layer_pattern, key) + m = [mm for mm in m.groups() if mm is not None] + assert len(m) == 3 + if m[1] not in attention_layer_params[int(m[0])]: + attention_layer_params[int(m[0])][m[1]] = {} + attention_layer_params[int(m[0])][m[1]][m[2]] = value + + for layer_idx in range(nlayer): + in_proj = NativeLayer( + embed_dim, + hidden_dim * 3, + bias=bias, + use_timestep=False, + precision=self.precision.name, + ) + matrix_list = [ + attention_layer_params[layer_idx][key]["matrix"] + for key in ["c_query", "c_key", "c_value"] + ] + in_proj["matrix"] = np.concatenate(matrix_list, axis=-1) + if bias: + bias_list = [ + attention_layer_params[layer_idx][key]["bias"] + for key in ["c_query", "c_key", "c_value"] + ] + in_proj["bias"] = np.concatenate(bias_list, axis=-1) + out_proj = NativeLayer( + hidden_dim, + embed_dim, + bias=bias, + use_timestep=False, + precision=self.precision.name, + ) + out_proj["matrix"] = attention_layer_params[layer_idx]["c_out"]["matrix"] + if bias: + out_proj["bias"] = attention_layer_params[layer_idx]["c_out"]["bias"] + + layer_norm = LayerNorm( + embed_dim, + precision=self.precision.name, + ) + layer_norm["matrix"] = attention_layer_params[layer_idx][ + "layer_normalization" + ]["gamma"] + layer_norm["bias"] = attention_layer_params[layer_idx][ + "layer_normalization" + ]["beta"] + data["attention_layers"].append( + { + "attention_layer": { + "in_proj": in_proj.serialize(), + "out_proj": out_proj.serialize(), + "bias": bias, + "smooth": self.smooth, + }, + "attn_layer_norm": layer_norm.serialize(), + } + ) + return data + + @classmethod + def deserialize_attention_layers(cls, data: dict, suffix: str = "") -> dict: + """Deserialize attention layers. 
+ + Parameters + ---------- + data : dict + The input attention layer data + suffix : str, optional + The suffix of the scope + + Returns + ------- + variables : dict + The input variables + """ + attention_layer_variables = {} + nlayer = data["layer_num"] + hidden_dim = data["hidden_dim"] + + for layer_idx in range(nlayer): + in_proj = NativeLayer.deserialize( + data["attention_layers"][layer_idx]["attention_layer"]["in_proj"] + ) + out_proj = NativeLayer.deserialize( + data["attention_layers"][layer_idx]["attention_layer"]["out_proj"] + ) + layer_norm = LayerNorm.deserialize( + data["attention_layers"][layer_idx]["attn_layer_norm"] + ) + + # Deserialize in_proj + c_query_matrix = in_proj["matrix"][:, :hidden_dim] + c_key_matrix = in_proj["matrix"][:, hidden_dim : 2 * hidden_dim] + c_value_matrix = in_proj["matrix"][:, 2 * hidden_dim :] + attention_layer_variables[ + f"attention_layer_{layer_idx}{suffix}/c_query/matrix" + ] = c_query_matrix + attention_layer_variables[ + f"attention_layer_{layer_idx}{suffix}/c_key/matrix" + ] = c_key_matrix + attention_layer_variables[ + f"attention_layer_{layer_idx}{suffix}/c_value/matrix" + ] = c_value_matrix + if data["attention_layers"][layer_idx]["attention_layer"]["bias"]: + c_query_bias = in_proj["bias"][:hidden_dim] + c_key_bias = in_proj["bias"][hidden_dim : 2 * hidden_dim] + c_value_bias = in_proj["bias"][2 * hidden_dim :] + attention_layer_variables[ + f"attention_layer_{layer_idx}{suffix}/c_query/bias" + ] = c_query_bias + attention_layer_variables[ + f"attention_layer_{layer_idx}{suffix}/c_key/bias" + ] = c_key_bias + attention_layer_variables[ + f"attention_layer_{layer_idx}{suffix}/c_value/bias" + ] = c_value_bias + + # Deserialize out_proj + attention_layer_variables[ + f"attention_layer_{layer_idx}{suffix}/c_out/matrix" + ] = out_proj["matrix"] + if data["attention_layers"][layer_idx]["attention_layer"]["bias"]: + attention_layer_variables[ + f"attention_layer_{layer_idx}{suffix}/c_out/bias" + ] = out_proj["bias"] + + # Deserialize layer_norm + attention_layer_variables[ + f"attention_layer_{layer_idx}{suffix}/layer_normalization/beta" + ] = layer_norm["bias"] + attention_layer_variables[ + f"attention_layer_{layer_idx}{suffix}/layer_normalization/gamma" + ] = layer_norm["matrix"] + return attention_layer_variables + + @classmethod + def deserialize(cls, data: dict, suffix: str = ""): + """Deserialize the model. + + Parameters + ---------- + data : dict + The serialized data + + Returns + ------- + Model + The deserialized model + """ + if cls is not DescrptSeAtten: + raise NotImplementedError("Not implemented in class %s" % cls.__name__) + data = data.copy() + check_version_compatibility(data.pop("@version", 1), 1, 1) + data.pop("@class", None) + data.pop("type", None) + embedding_net_variables = cls.deserialize_network( + data.pop("embeddings"), suffix=suffix + ) + attention_layer_variables = cls.deserialize_attention_layers( + data.pop("attention_layers"), suffix=suffix + ) + data.pop("env_mat") + variables = data.pop("@variables") + descriptor = cls(**data) + descriptor.embedding_net_variables = embedding_net_variables + descriptor.attention_layer_variables = attention_layer_variables + descriptor.davg = variables["davg"].reshape( + descriptor.ntypes, descriptor.ndescrpt + ) + descriptor.dstd = variables["dstd"].reshape( + descriptor.ntypes, descriptor.ndescrpt + ) + return descriptor + + def serialize(self, suffix: str = "") -> dict: + """Serialize the model. 
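The packing convention used by `serialize_attention_layers` and undone by `deserialize_attention_layers` above can be verified with a few lines of NumPy: the TF `c_query`/`c_key`/`c_value` matrices are concatenated along the output axis on serialize and recovered by slicing `hidden_dim`-wide blocks on deserialize. The sizes below are placeholders.

```python
# Round-trip check of the in_proj packing convention (illustrative sizes).
import numpy as np

embed_dim, hidden_dim = 100, 128
c_query = np.random.randn(embed_dim, hidden_dim)
c_key = np.random.randn(embed_dim, hidden_dim)
c_value = np.random.randn(embed_dim, hidden_dim)

in_proj_matrix = np.concatenate([c_query, c_key, c_value], axis=-1)  # embed_dim x 3*hidden_dim
assert np.array_equal(in_proj_matrix[:, :hidden_dim], c_query)
assert np.array_equal(in_proj_matrix[:, hidden_dim:2 * hidden_dim], c_key)
assert np.array_equal(in_proj_matrix[:, 2 * hidden_dim:], c_value)
```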
+ + Parameters + ---------- + suffix : str, optional + The suffix of the scope + + Returns + ------- + dict + The serialized data + """ + if type(self) not in [DescrptSeAtten, DescrptDPA1Compat]: + raise NotImplementedError( + "Not implemented in class %s" % self.__class__.__name__ + ) + if self.stripped_type_embedding: + raise NotImplementedError( + "stripped_type_embedding is unsupported by the native model" + ) + if (self.original_sel != self.sel_a).any(): + raise NotImplementedError( + "Adjusting sel is unsupported by the native model" + ) + if self.embedding_net_variables is None: + raise RuntimeError("init_variables must be called before serialize") + if self.spin is not None: + raise NotImplementedError("spin is unsupported") + assert self.davg is not None + assert self.dstd is not None + + return { + "@class": "Descriptor", + "type": "se_atten", + "@version": 1, + "rcut": self.rcut_r, + "rcut_smth": self.rcut_r_smth, + "sel": self.sel_a, + "ntypes": self.ntypes, + "neuron": self.filter_neuron, + "axis_neuron": self.n_axis_neuron, + "set_davg_zero": self.set_davg_zero, + "attn": self.att_n, + "attn_layer": self.attn_layer, + "attn_dotr": self.attn_dotr, + "attn_mask": self.attn_mask, + "activation_function": self.activation_function_name, + "resnet_dt": self.filter_resnet_dt, + "smooth_type_embdding": self.smooth, + "precision": self.filter_precision.name, + "embeddings": self.serialize_network( + ntypes=self.ntypes, + ndim=0, + in_dim=1 + if not hasattr(self, "embd_input_dim") + else self.embd_input_dim, + neuron=self.filter_neuron, + activation_function=self.activation_function_name, + resnet_dt=self.filter_resnet_dt, + variables=self.embedding_net_variables, + excluded_types=self.exclude_types, + suffix=suffix, + ), + "attention_layers": self.serialize_attention_layers( + nlayer=self.attn_layer, + nnei=self.nnei_a, + embed_dim=self.filter_neuron[-1], + hidden_dim=self.att_n, + dotr=self.attn_dotr, + do_mask=self.attn_mask, + variables=self.attention_layer_variables, + suffix=suffix, + ), + "env_mat": EnvMat(self.rcut_r, self.rcut_r_smth).serialize(), + "exclude_types": list(self.orig_exclude_types), + "env_protection": self.env_protection, + "@variables": { + "davg": self.davg.reshape(self.ntypes, self.nnei_a, 4), + "dstd": self.dstd.reshape(self.ntypes, self.nnei_a, 4), + }, + "trainable": self.trainable, + "type_one_side": self.type_one_side, + "spin": self.spin, + } + + +class DescrptDPA1Compat(DescrptSeAtten): + r"""Consistent version of the model for testing with other backend references. + + This model includes the type_embedding as attributes and other additional parameters. + + Parameters + ---------- + rcut: float + The cut-off radius :math:`r_c` + rcut_smth: float + From where the environment matrix should be smoothed :math:`r_s` + sel : list[int], int + list[int]: sel[i] specifies the maxmum number of type i atoms in the cut-off radius + int: the total maxmum number of atoms in the cut-off radius + ntypes : int + Number of element types + neuron : list[int] + Number of neurons in each hidden layers of the embedding net :math:`\mathcal{N}` + axis_neuron: int + Number of the axis neuron :math:`M_2` (number of columns of the sub-matrix of the embedding matrix) + tebd_dim: int + Dimension of the type embedding + tebd_input_mode: str + The way to mix the type embeddings. Only support `concat` in this version. + resnet_dt: bool + Time-step `dt` in the resnet construction: + y = x + dt * \phi (Wx + b) + trainable: bool + If the weights of embedding net are trainable. 
+ type_one_side: bool + Try to build N_types embedding nets. Otherwise, building N_types^2 embedding nets + attn: int + Hidden dimension of the attention vectors + attn_layer: int + Number of attention layers + attn_dotr: bool + If dot the angular gate to the attention weights + attn_mask: bool + If mask the diagonal of attention weights + exclude_types : List[List[int]] + The excluded pairs of types which have no interaction with each other. + For example, `[[0, 1]]` means no interaction between type 0 and type 1. + env_protection: float + Protection parameter to prevent division by zero errors during environment matrix calculations. + set_davg_zero: bool + Set the shift of embedding net input to zero. + activation_function: str + The activation function in the embedding net. Supported options are |ACTIVATION_FN| + precision: str + The precision of the embedding net parameters. Supported options are |PRECISION| + scaling_factor: float + Not supported in this version. + normalize: bool + Not supported in this version. + temperature: float + Not supported in this version. + smooth_type_embdding: bool + Whether to use smooth process in attention weights calculation. Only support True in this version. + concat_output_tebd: bool + Whether to concat type embedding at the output of the descriptor. Only support True in this version. + spin + The old implementation of deepspin (deprecated in the descriptor). Not supported in this version. + """ + + def __init__( + self, + rcut: float, + rcut_smth: float, + sel: Union[List[int], int], + ntypes: int, + neuron: List[int] = [25, 50, 100], + axis_neuron: int = 8, + tebd_dim: int = 8, + tebd_input_mode: str = "concat", + resnet_dt: bool = False, + trainable: bool = True, + type_one_side: bool = True, + attn: int = 128, + attn_layer: int = 2, + attn_dotr: bool = True, + attn_mask: bool = False, + exclude_types: List[List[int]] = [], + env_protection: float = 0.0, + set_davg_zero: bool = False, + activation_function: str = "tanh", + precision: str = "default", + scaling_factor=1.0, + normalize: bool = True, + temperature: Optional[float] = None, + smooth_type_embdding: bool = True, + concat_output_tebd: bool = True, + spin: Optional[Any] = None, + # consistent with argcheck, not used though + seed: Optional[int] = None, + uniform_seed: bool = False, + ) -> None: + if tebd_input_mode != "concat": + raise NotImplementedError( + "Only support tebd_input_mode == `concat` in this version." + ) + if not normalize: + raise NotImplementedError("Only support normalize == True in this version.") + if temperature != 1.0: + raise NotImplementedError( + "Only support temperature == 1.0 in this version." + ) + if not concat_output_tebd: + raise NotImplementedError( + "Only support concat_output_tebd == True in this version." 
+ ) + if spin is not None: + raise NotImplementedError("Only support spin is None in this version.") + + super().__init__( + rcut, + rcut_smth, + sel, + ntypes, + neuron=neuron, + axis_neuron=axis_neuron, + resnet_dt=resnet_dt, + trainable=trainable, + seed=seed, + type_one_side=type_one_side, + set_davg_zero=set_davg_zero, + exclude_types=exclude_types, + activation_function=activation_function, + precision=precision, + uniform_seed=uniform_seed, + attn=attn, + attn_layer=attn_layer, + attn_dotr=attn_dotr, + attn_mask=attn_mask, + multi_task=True, + stripped_type_embedding=False, + smooth_type_embdding=smooth_type_embdding, + env_protection=env_protection, + ) + self.tebd_dim = tebd_dim + self.tebd_input_mode = tebd_input_mode + self.scaling_factor = scaling_factor + self.normalize = normalize + self.temperature = temperature + self.type_embedding = TypeEmbedNet( + ntypes=self.ntypes, + neuron=[self.tebd_dim], + padding=True, + activation_function="Linear", + # precision=precision, + ) + self.concat_output_tebd = concat_output_tebd + if self.tebd_input_mode in ["concat"]: + if not self.type_one_side: + self.embd_input_dim = 1 + self.tebd_dim * 2 + else: + self.embd_input_dim = 1 + self.tebd_dim + else: + self.embd_input_dim = 1 + + def build( + self, + coord_: tf.Tensor, + atype_: tf.Tensor, + natoms: tf.Tensor, + box_: tf.Tensor, + mesh: tf.Tensor, + input_dict: dict, + reuse: Optional[bool] = None, + suffix: str = "", + ) -> tf.Tensor: + type_embedding = self.type_embedding.build(self.ntypes, suffix=suffix) + input_dict["type_embedding"] = type_embedding + + # nf x nloc x out_dim + self.dout = super().build( + coord_, + atype_, + natoms, + box_, + mesh, + input_dict, + reuse=reuse, + suffix=suffix, + ) + # self.dout = tf.cast(self.dout, self.filter_precision) + if self.concat_output_tebd: + atype = tf.reshape(atype_, [-1, natoms[1]]) + atype_nloc = tf.reshape( + tf.slice(atype, [0, 0], [-1, natoms[0]]), [-1] + ) ## lammps will have error without this + atom_embed = tf.reshape( + tf.nn.embedding_lookup(type_embedding, atype_nloc), + [-1, natoms[0], self.tebd_dim], + ) + atom_embed = tf.cast(atom_embed, GLOBAL_TF_FLOAT_PRECISION) + # nf x nloc x (out_dim + tebd_dim) + self.dout = tf.concat([self.dout, atom_embed], axis=-1) + return self.dout + + def init_variables( + self, + graph: tf.Graph, + graph_def: tf.GraphDef, + suffix: str = "", + ) -> None: + """Init the embedding net variables with the given dict. 
+ + Parameters + ---------- + graph : tf.Graph + The input frozen model graph + graph_def : tf.GraphDef + The input frozen model graph_def + suffix : str, optional + The suffix of the scope + """ + super().init_variables(graph=graph, graph_def=graph_def, suffix=suffix) + self.type_embedding.init_variables( + graph=graph, graph_def=graph_def, suffix=suffix + ) + + def update_attention_layers_serialize(self, data: dict): + """Update the serialized data to be consistent with other backend references.""" + new_dict = { + "@class": "NeighborGatedAttention", + "@version": 1, + "scaling_factor": self.scaling_factor, + "normalize": self.normalize, + "temperature": self.temperature, + } + new_dict.update(data) + update_info = { + "nnei": self.nnei_a, + "embed_dim": self.filter_neuron[-1], + "hidden_dim": self.att_n, + "dotr": self.attn_dotr, + "do_mask": self.attn_mask, + "scaling_factor": self.scaling_factor, + "normalize": self.normalize, + "temperature": self.temperature, + "precision": self.filter_precision.name, + } + for layer_idx in range(self.attn_layer): + new_dict["attention_layers"][layer_idx].update(update_info) + new_dict["attention_layers"][layer_idx]["attention_layer"].update( + update_info + ) + return new_dict + + @classmethod + def deserialize(cls, data: dict, suffix: str = ""): + """Deserialize the model. + + Parameters + ---------- + data : dict + The serialized data + + Returns + ------- + Model + The deserialized model + """ + if cls is not DescrptDPA1Compat: + raise NotImplementedError("Not implemented in class %s" % cls.__name__) + data = data.copy() + check_version_compatibility(data.pop("@version", 1), 1, 1) + data.pop("@class", None) + data.pop("type", None) + embedding_net_variables = cls.deserialize_network( + data.pop("embeddings"), suffix=suffix + ) + attention_layer_variables = cls.deserialize_attention_layers( + data.pop("attention_layers"), suffix=suffix + ) + data.pop("env_mat") + variables = data.pop("@variables") + type_embedding = data.pop("type_embedding") + descriptor = cls(**data) + descriptor.embedding_net_variables = embedding_net_variables + descriptor.attention_layer_variables = attention_layer_variables + descriptor.davg = variables["davg"].reshape( + descriptor.ntypes, descriptor.ndescrpt + ) + descriptor.dstd = variables["dstd"].reshape( + descriptor.ntypes, descriptor.ndescrpt + ) + descriptor.type_embedding = TypeEmbedNet.deserialize( + type_embedding, suffix=suffix + ) + return descriptor + + def serialize(self, suffix: str = "") -> dict: + """Serialize the model. 
+ + Parameters + ---------- + suffix : str, optional + The suffix of the scope + + Returns + ------- + dict + The serialized data + """ + data = super().serialize(suffix) + data.update( + { + "type": "dpa1", + "tebd_dim": self.tebd_dim, + "tebd_input_mode": self.tebd_input_mode, + "scaling_factor": self.scaling_factor, + "normalize": self.normalize, + "temperature": self.temperature, + "concat_output_tebd": self.concat_output_tebd, + "type_embedding": self.type_embedding.serialize(suffix), + } + ) + data["attention_layers"] = self.update_attention_layers_serialize( + data["attention_layers"] + ) + return data diff --git a/deepmd/tf/env.py b/deepmd/tf/env.py index c7873b951c..129b10b6ff 100644 --- a/deepmd/tf/env.py +++ b/deepmd/tf/env.py @@ -178,18 +178,18 @@ def dlopen_library(module: str, filename: str): )[:-1] ATTENTION_LAYER_PATTERN = str( - r"attention_layer_\d+/c_query/matrix|" - r"attention_layer_\d+/c_query/bias|" - r"attention_layer_\d+/c_key/matrix|" - r"attention_layer_\d+/c_key/bias|" - r"attention_layer_\d+/c_value/matrix|" - r"attention_layer_\d+/c_value/bias|" - r"attention_layer_\d+/c_out/matrix|" - r"attention_layer_\d+/c_out/bias|" - r"attention_layer_\d+/layer_normalization/beta|" - r"attention_layer_\d+/layer_normalization/gamma|" - r"attention_layer_\d+/layer_normalization_\d+/beta|" - r"attention_layer_\d+/layer_normalization_\d+/gamma|" + r"attention_layer_(\d+)/(c_query)/(matrix)|" + r"attention_layer_(\d+)/(c_query)/(bias)|" + r"attention_layer_(\d+)/(c_key)/(matrix)|" + r"attention_layer_(\d+)/(c_key)/(bias)|" + r"attention_layer_(\d+)/(c_value)/(matrix)|" + r"attention_layer_(\d+)/(c_value)/(bias)|" + r"attention_layer_(\d+)/(c_out)/(matrix)|" + r"attention_layer_(\d+)/(c_out)/(bias)|" + r"attention_layer_(\d+)/(layer_normalization)/(beta)|" + r"attention_layer_(\d+)/(layer_normalization)/(gamma)|" + # r"attention_layer_(\d+)/(layer_normalization)_\d+/(beta)|" + # r"attention_layer_(\d+)/(layer_normalization)_\d+/(gamma)|" ) TRANSFER_PATTERN = ( diff --git a/deepmd/tf/utils/graph.py b/deepmd/tf/utils/graph.py index a6e2ab7422..53de9c9ce2 100644 --- a/deepmd/tf/utils/graph.py +++ b/deepmd/tf/utils/graph.py @@ -455,11 +455,11 @@ def get_attention_layer_nodes_from_graph_def( """ if suffix != "": attention_layer_pattern = ( - ATTENTION_LAYER_PATTERN.replace("/c_query", suffix + "/c_query") - .replace("/c_key", suffix + "/c_key") - .replace("/c_value", suffix + "/c_value") - .replace("/c_out", suffix + "/c_out") - .replace("/layer_normalization", suffix + "/layer_normalization") + ATTENTION_LAYER_PATTERN.replace("/(c_query)", suffix + "/(c_query)") + .replace("/(c_key)", suffix + "/(c_key)") + .replace("/(c_value)", suffix + "/(c_value)") + .replace("/(c_out)", suffix + "/(c_out)") + .replace("/(layer_normalization)", suffix + "/(layer_normalization)") ) else: attention_layer_pattern = ATTENTION_LAYER_PATTERN diff --git a/deepmd/tf/utils/network.py b/deepmd/tf/utils/network.py index fb8e89c737..916f783050 100644 --- a/deepmd/tf/utils/network.py +++ b/deepmd/tf/utils/network.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later + import numpy as np from deepmd.tf.common import ( @@ -105,6 +106,271 @@ def one_layer( return hidden +def layer_norm_tf(x, shape, weight=None, bias=None, eps=1e-5): + """ + Layer normalization implementation in TensorFlow. + + Parameters + ---------- + x : tf.Tensor + The input tensor. + shape : tuple + The shape of the weight and bias tensors. + weight : tf.Tensor + The weight tensor. + bias : tf.Tensor + The bias tensor. 
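The capture groups added to `ATTENTION_LAYER_PATTERN` above are what let `serialize_attention_layers` recover the layer index, sub-layer name, and parameter name from a variable key. A minimal illustration follows, using only two of the alternatives and a made-up key.

```python
# Consuming the grouped pattern: re.search matches one alternative, and the
# non-None groups identify (layer index, sub-layer, parameter).
import re

pattern = (
    r"attention_layer_(\d+)/(c_query)/(matrix)|"
    r"attention_layer_(\d+)/(layer_normalization)/(gamma)"
)
key = "attention_layer_3/layer_normalization/gamma"
m = re.search(pattern, key)
groups = [g for g in m.groups() if g is not None]
assert groups == ["3", "layer_normalization", "gamma"]
```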
+ eps : float + A small value added to prevent division by zero. + + Returns + ------- + tf.Tensor + The normalized output tensor. + """ + # Calculate the mean and variance + mean = tf.reduce_mean(x, axis=list(range(-len(shape), 0)), keepdims=True) + variance = tf.reduce_mean( + tf.square(x - mean), axis=list(range(-len(shape), 0)), keepdims=True + ) + + # Normalize the input + x_ln = (x - mean) / tf.sqrt(variance + eps) + + # Scale and shift the normalized input + if weight is not None and bias is not None: + x_ln = x_ln * weight + bias + + return x_ln + + +def layernorm( + inputs, + outputs_size, + precision=GLOBAL_TF_FLOAT_PRECISION, + name="linear", + scope="", + reuse=None, + seed=None, + uniform_seed=False, + uni_init=True, + eps=1e-5, + trainable=True, + initial_variables=None, +): + with tf.variable_scope(name, reuse=reuse): + shape = inputs.get_shape().as_list() + if uni_init: + gamma_initializer = tf.ones_initializer() + beta_initializer = tf.zeros_initializer() + else: + gamma_initializer = tf.random_normal_initializer( + seed=seed if (seed is None or uniform_seed) else seed + 0 + ) + beta_initializer = tf.random_normal_initializer( + seed=seed if (seed is None or uniform_seed) else seed + 1 + ) + if initial_variables is not None: + gamma_initializer = tf.constant_initializer( + initial_variables[scope + name + "/gamma"] + ) + beta_initializer = tf.constant_initializer( + initial_variables[scope + name + "/beta"] + ) + gamma = tf.get_variable( + "gamma", + [outputs_size], + precision, + gamma_initializer, + trainable=trainable, + ) + variable_summaries(gamma, "gamma") + beta = tf.get_variable( + "beta", [outputs_size], precision, beta_initializer, trainable=trainable + ) + variable_summaries(beta, "beta") + + output = layer_norm_tf( + inputs, + (outputs_size,), + weight=gamma, + bias=beta, + eps=eps, + ) + return output + + +# class LayerNormCompat: +# """Implementation of Layer Normalization layer for testing with other backend references. +# +# Parameters +# ---------- +# num_in : int +# The input dimension of the layer. +# eps : float, optional +# A small value added to prevent division by zero in calculations. +# uni_init : bool, optional +# If initialize the weights to be zeros and ones. +# precision : str, optional +# The precision of the layer parameters. Supported options are |PRECISION| +# """ +# +# def __init__( +# self, +# num_in: int, +# eps: float = 1e-5, +# uni_init: bool = True, +# precision: str = "default", +# ) -> None: +# self.eps = eps +# self.uni_init = uni_init +# self.num_in = num_in +# self.filter_precision = get_precision(precision) +# self.layer_norm_variables = None +# +# def build( +# self, +# inputs, +# input_shape: List[int], +# reuse=None, +# suffix="", +# ): +# """Build the computational graph for the layer normalization. +# +# Parameters +# ---------- +# input_shape +# The shape of the input tensor. +# reuse +# The weights in the networks should be reused when get the variable. 
+# suffix +# Name suffix to identify this layer +# +# Returns +# ------- +# normalized_output +# The computational graph for the normalized output +# """ +# assert input_shape[-1] == self.num_in +# name = "layer_norm" + suffix +# with tf.variable_scope(name, reuse=reuse): +# gamma = tf.get_variable( +# "gamma", +# shape=[self.num_in], +# initializer=tf.ones_initializer(), +# dtype=self.filter_precision, +# trainable=True, +# ) +# beta = tf.get_variable( +# "beta", +# shape=[self.num_in], +# initializer=tf.zeros_initializer(), +# dtype=self.filter_precision, +# trainable=True, +# ) +# normalized_output = tf.contrib.layers.layer_norm( +# inputs=input, +# begin_norm_axis=-1, +# begin_params_axis=-1, +# epsilon=self.eps, +# activation_fn=None, +# param_initializers={ +# "gamma": tf.ones_initializer(), +# "beta": tf.zeros_initializer(), +# }, +# trainable=True, +# reuse=reuse, +# variables_collections=None, +# outputs_collections=None, +# data_format="NHWC", +# name=name, +# ) +# return normalized_output +# +# def init_variables( +# self, +# graph: tf.Graph, +# graph_def: tf.GraphDef, +# suffix="", +# model_type="original_model", +# ) -> None: +# """Init the layer norm variables with the given dict. +# +# Parameters +# ---------- +# graph : tf.Graph +# The input frozen model graph +# graph_def : tf.GraphDef +# The input frozen model graph_def +# suffix +# Name suffix to identify this layer +# model_type +# Indicator of whether this model is a compressed model +# """ +# self.layer_norm_variables = get_layer_norm_variables_from_graph_def( +# graph_def, suffix=suffix +# ) +# +# @classmethod +# def deserialize(cls, data: dict, suffix: str = ""): +# """Deserialize the layer from a dict. +# +# Parameters +# ---------- +# data : dict +# The dict to deserialize from. +# suffix : str, optional +# The suffix of the scope +# +# Returns +# ------- +# LayerNorm +# The deserialized layer +# """ +# data = data.copy() +# check_version_compatibility(data.pop("@version", 1), 1, 1) +# data_cls = data.pop("@class") +# assert data_cls == "LayerNorm", f"Invalid class {data_cls}" +# variables = data.pop("@variables") +# obj = cls( +# num_in=variables["w"].shape[0], +# eps=data.pop("eps"), +# precision=data.pop("precision"), +# ) +# obj.layer_norm_variables = { +# f"layer_norm{suffix}/gamma": variables["w"], +# f"layer_norm{suffix}/beta": variables["b"], +# } +# return obj +# +# def serialize(self, suffix: str = "") -> dict: +# """Serialize the layer to a dict. +# +# Parameters +# ---------- +# suffix : str, optional +# The suffix of the scope +# +# Returns +# ------- +# dict +# The serialized layer. +# """ +# assert self.layer_norm_variables is not None +# gamma = self.layer_norm_variables[f"layer_norm{suffix}/gamma"] +# beta = self.layer_norm_variables[f"layer_norm{suffix}/beta"] +# return { +# "@class": "LayerNorm", +# "@version": 1, +# "eps": self.eps, +# "precision": self.filter_precision.name, +# "@variables": { +# "w": gamma, +# "b": beta, +# }, +# } + + def embedding_net_rand_seed_shift(network_size): shift = 3 * (len(network_size) + 1) return shift diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 07add486c1..276f4d63e9 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -409,7 +409,7 @@ def descrpt_se_atten_common_args(): ) doc_type_one_side = ( doc_only_tf_supported - + r"If true, the embedding network parameters vary by types of neighbor atoms only, so there will be $N_\text{types}$ sets of embedding network parameters. 
Otherwise, the embedding network parameters vary by types of centric atoms and types of neighbor atoms, so there will be $N_\text{types}^2$ sets of embedding network parameters." + + r"If 'False', type embeddings of both neighbor and central atoms are considered. If 'True', only type embeddings of neighbor atoms are considered. Default is 'False'." ) doc_precision = ( doc_only_tf_supported @@ -476,7 +476,7 @@ def descrpt_se_atten_common_args(): @descrpt_args_plugin.register("se_atten", alias=["dpa1"]) def descrpt_se_atten_args(): doc_stripped_type_embedding = "Whether to strip the type embedding into a separated embedding network. Setting it to `False` will fall back to the previous version of `se_atten` which is non-compressible." - doc_smooth_type_embdding = "When using stripped type embedding, whether to dot smooth factor on the network output of type embedding to keep the network smooth, instead of setting `set_davg_zero` to be True." + doc_smooth_type_embdding = f"Whether to use smooth process in attention weights calculation. {doc_only_tf_supported} When using stripped type embedding, whether to dot smooth factor on the network output of type embedding to keep the network smooth, instead of setting `set_davg_zero` to be True." doc_set_davg_zero = "Set the normalization average to zero. This option should be set when `se_atten` descriptor or `atom_ener` in the energy fitting is used" doc_tebd_dim = "The dimension of atom type embedding." doc_temperature = "The scaling factor of normalization in calculations of attention weights, which is used to scale the matmul(Q, K)." @@ -507,7 +507,7 @@ def descrpt_se_atten_args(): bool, optional=True, default=False, - doc=doc_only_tf_supported + doc_smooth_type_embdding, + doc=doc_smooth_type_embdding, ), Argument( "set_davg_zero", bool, optional=True, default=True, doc=doc_set_davg_zero @@ -527,27 +527,6 @@ def descrpt_se_atten_args(): default="concat", doc=doc_only_pt_supported + doc_deprecated, ), - Argument( - "post_ln", - bool, - optional=True, - default=True, - doc=doc_only_pt_supported + doc_deprecated, - ), - Argument( - "ffn", - bool, - optional=True, - default=False, - doc=doc_only_pt_supported + doc_deprecated, - ), - Argument( - "ffn_embed_dim", - int, - optional=True, - default=1024, - doc=doc_only_pt_supported + doc_deprecated, - ), Argument( "scaling_factor", float, @@ -555,13 +534,6 @@ def descrpt_se_atten_args(): default=1.0, doc=doc_only_pt_supported + doc_scaling_factor, ), - Argument( - "head_num", - int, - optional=True, - default=1, - doc=doc_only_pt_supported + doc_deprecated, - ), Argument( "normalize", bool, @@ -575,13 +547,6 @@ def descrpt_se_atten_args(): optional=True, doc=doc_only_pt_supported + doc_temperature, ), - Argument( - "return_rot", - bool, - optional=True, - default=False, - doc=doc_only_pt_supported + doc_deprecated, - ), Argument( "concat_output_tebd", bool, diff --git a/doc/model/train-se-atten.md b/doc/model/train-se-atten.md index 364d35805b..ccc1e476e0 100644 --- a/doc/model/train-se-atten.md +++ b/doc/model/train-se-atten.md @@ -133,7 +133,6 @@ An example of the DPA-1 descriptor is provided as follows "attn_layer": 2, "attn_mask": false, "attn_dotr": true, - "post_ln": true } ``` @@ -147,7 +146,6 @@ An example of the DPA-1 descriptor is provided as follows - {ref}`attn_layer ` sets the number of layers in attention mechanism. - {ref}`attn_mask ` determines whether to mask the diagonal in the attention weights and False is recommended. 
- {ref}`attn_dotr ` determines whether to dot the relative coordinates on the attention weights as a gated scheme, True is recommended. -- {ref}`post_ln ` determines whether to perform post layer norm. ::: diff --git a/examples/water/se_atten/input_torch.json b/examples/water/se_atten/input_torch.json index 4160feda17..cdb4b0db49 100644 --- a/examples/water/se_atten/input_torch.json +++ b/examples/water/se_atten/input_torch.json @@ -22,12 +22,8 @@ "attn_layer": 2, "attn_dotr": true, "attn_mask": false, - "post_ln": true, - "ffn": false, - "ffn_embed_dim": 1024, "activation_function": "tanh", "scaling_factor": 1.0, - "head_num": 1, "normalize": true, "temperature": 1.0 }, diff --git a/node_modules/.cache/prettier/.prettier-caches/9be9545d53bb64e65febe2ff48926b4145285f3a.json b/node_modules/.cache/prettier/.prettier-caches/9be9545d53bb64e65febe2ff48926b4145285f3a.json new file mode 100644 index 0000000000..b926803eba --- /dev/null +++ b/node_modules/.cache/prettier/.prettier-caches/9be9545d53bb64e65febe2ff48926b4145285f3a.json @@ -0,0 +1,11 @@ +{ + "30fbbb65ecaaf3cbf6c6d2b94493105df86183b4": { + "files": { + "doc/model/train-se-atten.md": [ + "6hG+T7kQKxMcKc50vx7IIcn0lrk=", + true + ] + }, + "modified": 1713690325168 + } +} diff --git a/source/tests/consistent/descriptor/common.py b/source/tests/consistent/descriptor/common.py index ef7b39b52e..9c8c1cea7f 100644 --- a/source/tests/consistent/descriptor/common.py +++ b/source/tests/consistent/descriptor/common.py @@ -57,7 +57,9 @@ def build_tf_descriptor(self, obj, natoms, coords, atype, box, suffix): t_mesh: make_default_mesh(True, False), } - def eval_dp_descriptor(self, dp_obj: Any, natoms, coords, atype, box) -> Any: + def eval_dp_descriptor( + self, dp_obj: Any, natoms, coords, atype, box, mixed_types: bool = False + ) -> Any: ext_coords, ext_atype, mapping = extend_coord_with_ghosts( coords.reshape(1, -1, 3), atype.reshape(1, -1), @@ -70,11 +72,13 @@ def eval_dp_descriptor(self, dp_obj: Any, natoms, coords, atype, box) -> Any: natoms[0], dp_obj.get_rcut(), dp_obj.get_sel(), - distinguish_types=True, + distinguish_types=(not mixed_types), ) return dp_obj(ext_coords, ext_atype, nlist=nlist) - def eval_pt_descriptor(self, pt_obj: Any, natoms, coords, atype, box) -> Any: + def eval_pt_descriptor( + self, pt_obj: Any, natoms, coords, atype, box, mixed_types: bool = False + ) -> Any: ext_coords, ext_atype, mapping = extend_coord_with_ghosts_pt( torch.from_numpy(coords).to(PT_DEVICE).reshape(1, -1, 3), torch.from_numpy(atype).to(PT_DEVICE).reshape(1, -1), @@ -87,7 +91,7 @@ def eval_pt_descriptor(self, pt_obj: Any, natoms, coords, atype, box) -> Any: natoms[0], pt_obj.get_rcut(), pt_obj.get_sel(), - distinguish_types=True, + distinguish_types=(not mixed_types), ) return [ x.detach().cpu().numpy() if torch.is_tensor(x) else x diff --git a/source/tests/consistent/descriptor/test_dpa1.py b/source/tests/consistent/descriptor/test_dpa1.py new file mode 100644 index 0000000000..cc2bd57457 --- /dev/null +++ b/source/tests/consistent/descriptor/test_dpa1.py @@ -0,0 +1,258 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest +from typing import ( + Any, + Tuple, +) + +import numpy as np +from dargs import ( + Argument, +) + +from deepmd.dpmodel.descriptor.dpa1 import DescrptDPA1 as DescrptDPA1DP +from deepmd.env import ( + GLOBAL_NP_FLOAT_PRECISION, +) + +from ..common import ( + INSTALLED_PT, + INSTALLED_TF, + CommonTest, + parameterized, +) +from .common import ( + DescriptorTest, +) + +if INSTALLED_PT: + from 
deepmd.pt.model.descriptor.dpa1 import DescrptDPA1 as DescrptDPA1PT +else: + DescrptDPA1PT = None +if INSTALLED_TF: + from deepmd.tf.descriptor.se_atten import DescrptDPA1Compat as DescrptDPA1TF +else: + DescrptDPA1TF = None +from deepmd.utils.argcheck import ( + descrpt_se_atten_args, +) + + +@parameterized( + (True, False), # resnet_dt + ([], [[0, 1]]), # excluded_types + ("float32", "float64"), # precision + (0.0, 1e-8, 1e-2), # env_protection + (True, False), # smooth_type_embdding + (True, False), # type_one_side + (True, False), # set_davg_zero + (0, 2), # attn_layer +) +class TestDPA1(CommonTest, DescriptorTest, unittest.TestCase): + @property + def data(self) -> dict: + ( + resnet_dt, + excluded_types, + precision, + env_protection, + smooth_type_embdding, + type_one_side, + set_davg_zero, + attn_layer, + ) = self.param + return { + "sel": [10], + "rcut_smth": 5.80, + "rcut": 6.00, + "neuron": [6, 12, 24], + "ntypes": self.ntypes, + "axis_neuron": 3, + "tebd_dim": 4, + # "tebd_input_mode": tebd_input_mode, + "attn": 20, + "attn_layer": attn_layer, + "attn_dotr": True, + "attn_mask": False, + "scaling_factor": 1.0, + "normalize": True, + "temperature": 1.0, + "concat_output_tebd": True, + "resnet_dt": resnet_dt, + "type_one_side": type_one_side, + "exclude_types": excluded_types, + "env_protection": env_protection, + "precision": precision, + "set_davg_zero": set_davg_zero, + "smooth_type_embdding": smooth_type_embdding, + "seed": 1145141919810, + } + + @property + def skip_pt(self) -> bool: + ( + resnet_dt, + excluded_types, + precision, + env_protection, + smooth_type_embdding, + type_one_side, + set_davg_zero, + attn_layer, + ) = self.param + return CommonTest.skip_pt + + @property + def skip_dp(self) -> bool: + ( + resnet_dt, + excluded_types, + precision, + env_protection, + smooth_type_embdding, + type_one_side, + set_davg_zero, + attn_layer, + ) = self.param + return CommonTest.skip_pt + + @property + def skip_tf(self) -> bool: + ( + resnet_dt, + excluded_types, + precision, + env_protection, + smooth_type_embdding, + type_one_side, + set_davg_zero, + attn_layer, + ) = self.param + # TODO (excluded_types != [] and attn_layer > 0) need fix + return ( + env_protection != 0.0 + or smooth_type_embdding + or (excluded_types != [] and attn_layer > 0) + ) + + tf_class = DescrptDPA1TF + dp_class = DescrptDPA1DP + pt_class = DescrptDPA1PT + args = descrpt_se_atten_args().append(Argument("ntypes", int, optional=False)) + + def setUp(self): + CommonTest.setUp(self) + + self.ntypes = 2 + self.coords = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 00.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ) + self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32) + self.box = np.array( + [13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ) + self.natoms = np.array([6, 6, 2, 4], dtype=np.int32) + ( + resnet_dt, + excluded_types, + precision, + env_protection, + smooth_type_embdding, + type_one_side, + set_davg_zero, + attn_layer, + ) = self.param + + def build_tf(self, obj: Any, suffix: str) -> Tuple[list, dict]: + return self.build_tf_descriptor( + obj, + self.natoms, + self.coords, + self.atype, + self.box, + suffix, + ) + + def eval_dp(self, dp_obj: Any) -> Any: + return self.eval_dp_descriptor( + dp_obj, + self.natoms, + self.coords, + self.atype, + self.box, + mixed_types=True, + ) + + def eval_pt(self, pt_obj: Any) -> Any: + return 
self.eval_pt_descriptor( + pt_obj, + self.natoms, + self.coords, + self.atype, + self.box, + mixed_types=True, + ) + + def extract_ret(self, ret: Any, backend) -> Tuple[np.ndarray, ...]: + return (ret[0],) + + @property + def rtol(self) -> float: + """Relative tolerance for comparing the return value.""" + ( + resnet_dt, + excluded_types, + precision, + env_protection, + smooth_type_embdding, + type_one_side, + set_davg_zero, + attn_layer, + ) = self.param + if precision == "float64": + return 1e-10 + elif precision == "float32": + return 1e-4 + else: + raise ValueError(f"Unknown precision: {precision}") + + @property + def atol(self) -> float: + """Absolute tolerance for comparing the return value.""" + ( + resnet_dt, + excluded_types, + precision, + env_protection, + smooth_type_embdding, + type_one_side, + set_davg_zero, + attn_layer, + ) = self.param + if precision == "float64": + return 1e-10 + elif precision == "float32": + return 1e-4 + else: + raise ValueError(f"Unknown precision: {precision}") diff --git a/source/tests/pt/model/models/dpa1.json b/source/tests/pt/model/models/dpa1.json index 5d2c65c214..1321acbd53 100644 --- a/source/tests/pt/model/models/dpa1.json +++ b/source/tests/pt/model/models/dpa1.json @@ -18,12 +18,8 @@ "attn_layer": 2, "attn_dotr": true, "attn_mask": false, - "post_ln": true, - "ffn": false, - "ffn_embed_dim": 10, "activation_function": "tanh", "scaling_factor": 1.0, - "head_num": 1, "normalize": true, "temperature": 1.0 }, diff --git a/source/tests/pt/model/test_dpa1.py b/source/tests/pt/model/test_dpa1.py new file mode 100644 index 0000000000..57b55fcf16 --- /dev/null +++ b/source/tests/pt/model/test_dpa1.py @@ -0,0 +1,214 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import itertools +import unittest + +import numpy as np +import torch + +try: + from deepmd.dpmodel.descriptor.dpa1 import DescrptDPA1 as DPDescrptDPA1 + + support_se_atten = True +except ModuleNotFoundError: + support_se_atten = False +except ImportError: + support_se_atten = False + +from deepmd.pt.model.descriptor.dpa1 import ( + DescrptDPA1, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + PRECISION_DICT, +) + +from .test_env_mat import ( + TestCaseSingleFrameWithNlist, +) +from .test_mlp import ( + get_tols, +) + +dtype = env.GLOBAL_PT_FLOAT_PRECISION + + +@unittest.skipIf(not support_se_atten, "EnvMat not supported") +class TestDescrptSeAtten(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self): + TestCaseSingleFrameWithNlist.setUp(self) + + def test_consistency( + self, + ): + rng = np.random.default_rng(100) + nf, nloc, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + for idt, prec, sm, to in itertools.product( + [False, True], + ["float64", "float32"], + [False, True], + [False, True], + ): + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + err_msg = f"idt={idt} prec={prec}" + + # dpa1 new impl + dd0 = DescrptDPA1( + self.rcut, + self.rcut_smth, + self.sel_mix, + self.nt, + attn_layer=2, + precision=prec, + resnet_dt=idt, + smooth_type_embdding=sm, + type_one_side=to, + old_impl=False, + ).to(env.DEVICE) + dd0.se_atten.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE) + dd0.se_atten.dstd = torch.tensor(dstd, dtype=dtype, device=env.DEVICE) + rd0, _, _, _, _ = dd0( + torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE), + torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE), + torch.tensor(self.nlist, 
dtype=int, device=env.DEVICE), + ) + # serialization + dd1 = DescrptDPA1.deserialize(dd0.serialize()) + rd1, _, _, _, _ = dd1( + torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE), + torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE), + torch.tensor(self.nlist, dtype=int, device=env.DEVICE), + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd1.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + # dp impl + dd2 = DPDescrptDPA1.deserialize(dd0.serialize()) + rd2, _, _, _, _ = dd2.call( + self.coord_ext, + self.atype_ext, + self.nlist, + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd2, + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + # dp impl serialization + dd3 = DPDescrptDPA1.deserialize(dd2.serialize()) + rd3, _, _, _, _ = dd3.call( + self.coord_ext, + self.atype_ext, + self.nlist, + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd3, + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + # old impl + if idt is False and prec == "float64" and to is False: + dd4 = DescrptDPA1( + self.rcut, + self.rcut_smth, + self.sel_mix, + self.nt, + attn_layer=2, + precision=prec, + resnet_dt=idt, + smooth_type_embdding=sm, + old_impl=True, + ).to(env.DEVICE) + dd0_state_dict = dd0.se_atten.state_dict() + dd4_state_dict = dd4.se_atten.state_dict() + + dd0_state_dict_attn = dd0.se_atten.dpa1_attention.state_dict() + dd4_state_dict_attn = dd4.se_atten.dpa1_attention.state_dict() + for i in dd4_state_dict: + dd4_state_dict[i] = ( + dd0_state_dict[ + i.replace(".deep_layers.", ".layers.") + .replace("filter_layers_old.", "filter_layers._networks.") + .replace( + ".attn_layer_norm.weight", ".attn_layer_norm.matrix" + ) + ] + .detach() + .clone() + ) + if ".bias" in i and "attn_layer_norm" not in i: + dd4_state_dict[i] = dd4_state_dict[i].unsqueeze(0) + dd4.se_atten.load_state_dict(dd4_state_dict) + + dd0_state_dict_tebd = dd0.type_embedding.state_dict() + dd4_state_dict_tebd = dd4.type_embedding.state_dict() + for i in dd4_state_dict_tebd: + dd4_state_dict_tebd[i] = ( + dd0_state_dict_tebd[i.replace("embedding.weight", "matrix")] + .detach() + .clone() + ) + dd4.type_embedding.load_state_dict(dd4_state_dict_tebd) + + rd4, _, _, _, _ = dd4( + torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE), + torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE), + torch.tensor(self.nlist, dtype=int, device=env.DEVICE), + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd4.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + + def test_jit( + self, + ): + rng = np.random.default_rng() + nf, nloc, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + for idt, prec, sm, to in itertools.product( + [False, True], + ["float64", "float32"], + [False, True], + [False, True], + ): + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + err_msg = f"idt={idt} prec={prec}" + # sea new impl + dd0 = DescrptDPA1( + self.rcut, + self.rcut_smth, + self.sel, + self.nt, + precision=prec, + resnet_dt=idt, + smooth_type_embdding=sm, + type_one_side=to, + old_impl=False, + ) + dd0.se_atten.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE) + dd0.se_atten.dstd = torch.tensor(dstd, dtype=dtype, device=env.DEVICE) + # dd1 = DescrptDPA1.deserialize(dd0.serialize()) + model = torch.jit.script(dd0) + # model = torch.jit.script(dd1) diff --git a/source/tests/pt/model/test_env_mat.py 
b/source/tests/pt/model/test_env_mat.py index e18093b2f1..24ed886b86 100644 --- a/source/tests/pt/model/test_env_mat.py +++ b/source/tests/pt/model/test_env_mat.py @@ -35,6 +35,8 @@ def setUp(self): self.atype_ext = np.array([0, 0, 1, 0], dtype=int).reshape([1, self.nall]) # sel = [5, 2] self.sel = [5, 2] + self.sel_mix = [7] + self.natoms = [3, 3, 2, 1] self.nlist = np.array( [ [1, 3, -1, -1, -1, 2, -1], @@ -83,6 +85,8 @@ def setUp(self): self.atype_ext = np.array([0, -1, 0, 1, 0], dtype=int).reshape([1, self.nall]) # sel = [5, 2] self.sel = [5, 2] + self.sel_mix = [7] + self.natoms = [3, 3, 2, 1] self.nlist = np.array( [ [2, 4, -1, -1, -1, 3, -1], @@ -131,6 +135,8 @@ def setUp(self): self.cell = 2.0 * np.eye(3).reshape([1, 9]) # sel = [5, 2] self.sel = [16, 8] + self.sel_mix = [24] + self.natoms = [3, 3, 2, 1] self.rcut = 2.2 self.rcut_smth = 0.4 self.atol = 1e-12 diff --git a/source/tests/pt/model/test_permutation.py b/source/tests/pt/model/test_permutation.py index 5e395eb8c0..110abd0d23 100644 --- a/source/tests/pt/model/test_permutation.py +++ b/source/tests/pt/model/test_permutation.py @@ -158,12 +158,8 @@ "attn_layer": 2, "attn_dotr": True, "attn_mask": False, - "post_ln": True, - "ffn": False, - "ffn_embed_dim": 512, "activation_function": "tanh", "scaling_factor": 1.0, - "head_num": 1, "normalize": False, "temperature": 1.0, "set_davg_zero": True, @@ -193,12 +189,8 @@ "attn_layer": 0, "attn_dotr": True, "attn_mask": False, - "post_ln": True, - "ffn": False, - "ffn_embed_dim": 1024, "activation_function": "tanh", "scaling_factor": 1.0, - "head_num": 1, "normalize": True, "temperature": 1.0, }, diff --git a/source/tests/pt/model/water/se_atten.json b/source/tests/pt/model/water/se_atten.json index 6b6fca50d3..71cee94d8b 100644 --- a/source/tests/pt/model/water/se_atten.json +++ b/source/tests/pt/model/water/se_atten.json @@ -21,12 +21,8 @@ "attn_layer": 2, "attn_dotr": true, "attn_mask": false, - "post_ln": true, - "ffn": false, - "ffn_embed_dim": 512, "activation_function": "tanh", "scaling_factor": 1.0, - "head_num": 1, "normalize": false, "temperature": 1.0 }, From ae8b0d1ba1c288d19bcbdf576f81bfb9961d84b4 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Sun, 21 Apr 2024 17:27:23 +0800 Subject: [PATCH 02/18] remove bk files --- deepmd/dpmodel/descriptor/dpa1_bk.py | 402 ------------------ ...9545d53bb64e65febe2ff48926b4145285f3a.json | 11 - 2 files changed, 413 deletions(-) delete mode 100644 deepmd/dpmodel/descriptor/dpa1_bk.py delete mode 100644 node_modules/.cache/prettier/.prettier-caches/9be9545d53bb64e65febe2ff48926b4145285f3a.json diff --git a/deepmd/dpmodel/descriptor/dpa1_bk.py b/deepmd/dpmodel/descriptor/dpa1_bk.py deleted file mode 100644 index 3ca28a4fae..0000000000 --- a/deepmd/dpmodel/descriptor/dpa1_bk.py +++ /dev/null @@ -1,402 +0,0 @@ -# SPDX-License-Identifier: LGPL-3.0-or-later -import numpy as np - -try: - from deepmd._version import version as __version__ -except ImportError: - __version__ = "unknown" - -import copy -from typing import ( - Any, - List, - Optional, -) - -from .common import ( - DEFAULT_PRECISION, - NativeOP, -) -from .env_mat import ( - EnvMat, -) -from .network import ( - EmbdLayer, - EmbeddingNet, - NetworkCollection, -) - - -class DescrptDPA1(NativeOP): - r"""Attention-based descriptor which is proposed in the pretrainable DPA-1[1] model. - - This descriptor, :math:`\mathcal{D}^i \in \mathbb{R}^{M \times M_{<}}`, is given by - - .. 
math:: - \mathcal{D}^i = \frac{1}{N_c^2}(\hat{\mathcal{G}}^i)^T \mathcal{R}^i (\mathcal{R}^i)^T \hat{\mathcal{G}}^i_<, - - where :math:`\hat{\mathcal{G}}^i` represents the embedding matrix:math:`\mathcal{G}^i` - after additional self-attention mechanism and :math:`\mathcal{R}^i` is defined by the full case in the se_e2_a descriptor. - Note that we obtain :math:`\mathcal{G}^i` using the type embedding method by default in this descriptor. - - To perform the self-attention mechanism, the queries :math:`\mathcal{Q}^{i,l} \in \mathbb{R}^{N_c\times d_k}`, - keys :math:`\mathcal{K}^{i,l} \in \mathbb{R}^{N_c\times d_k}`, - and values :math:`\mathcal{V}^{i,l} \in \mathbb{R}^{N_c\times d_v}` are first obtained: - - .. math:: - \left(\mathcal{Q}^{i,l}\right)_{j}=Q_{l}\left(\left(\mathcal{G}^{i,l-1}\right)_{j}\right), - - .. math:: - \left(\mathcal{K}^{i,l}\right)_{j}=K_{l}\left(\left(\mathcal{G}^{i,l-1}\right)_{j}\right), - - .. math:: - \left(\mathcal{V}^{i,l}\right)_{j}=V_{l}\left(\left(\mathcal{G}^{i,l-1}\right)_{j}\right), - - where :math:`Q_{l}`, :math:`K_{l}`, :math:`V_{l}` represent three trainable linear transformations - that output the queries and keys of dimension :math:`d_k` and values of dimension :math:`d_v`, and :math:`l` - is the index of the attention layer. - The input embedding matrix to the attention layers, denoted by :math:`\mathcal{G}^{i,0}`, - is chosen as the two-body embedding matrix. - - Then the scaled dot-product attention method is adopted: - - .. math:: - A(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l}, \mathcal{V}^{i,l}, \mathcal{R}^{i,l})=\varphi\left(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l},\mathcal{R}^{i,l}\right)\mathcal{V}^{i,l}, - - where :math:`\varphi\left(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l},\mathcal{R}^{i,l}\right) \in \mathbb{R}^{N_c\times N_c}` is attention weights. - In the original attention method, - one typically has :math:`\varphi\left(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l}\right)=\mathrm{softmax}\left(\frac{\mathcal{Q}^{i,l} (\mathcal{K}^{i,l})^{T}}{\sqrt{d_{k}}}\right)`, - with :math:`\sqrt{d_{k}}` being the normalization temperature. - This is slightly modified to incorporate the angular information: - - .. math:: - \varphi\left(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l},\mathcal{R}^{i,l}\right) = \mathrm{softmax}\left(\frac{\mathcal{Q}^{i,l} (\mathcal{K}^{i,l})^{T}}{\sqrt{d_{k}}}\right) \odot \hat{\mathcal{R}}^{i}(\hat{\mathcal{R}}^{i})^{T}, - - where :math:`\hat{\mathcal{R}}^{i} \in \mathbb{R}^{N_c\times 3}` denotes normalized relative coordinates, - :math:`\hat{\mathcal{R}}^{i}_{j} = \frac{\boldsymbol{r}_{ij}}{\lVert \boldsymbol{r}_{ij} \lVert}` - and :math:`\odot` means element-wise multiplication. - - Then layer normalization is added in a residual way to finally obtain the self-attention local embedding matrix - :math:`\hat{\mathcal{G}}^{i} = \mathcal{G}^{i,L_a}` after :math:`L_a` attention layers:[^1] - - .. math:: - \mathcal{G}^{i,l} = \mathcal{G}^{i,l-1} + \mathrm{LayerNorm}(A(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l}, \mathcal{V}^{i,l}, \mathcal{R}^{i,l})). 
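
For readers following the attention equations in this docstring, a single gated attention step for one atom can be sketched in plain NumPy as below. The function and argument names, and passing the normalization temperature as an explicit scalar, are illustrative assumptions for this sketch rather than code from the patch.

```python
import numpy as np

def np_softmax(x, axis=-1):
    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e_x / np.sum(e_x, axis=axis, keepdims=True)

def gated_attention_step(g, wq, wk, wv, r_hat, temperature):
    # g:     (nnei, d)  input embedding rows G^{i,l-1}
    # wq/wk: (d, d_k)   query/key projections, wv: (d, d_v) value projection
    # r_hat: (nnei, 3)  normalized relative coordinates
    q, k, v = g @ wq, g @ wk, g @ wv
    # scaled dot-product weights ...
    weights = np_softmax(q @ k.T / temperature, axis=-1)
    # ... gated element-wise by the angular term R_hat R_hat^T
    weights = weights * (r_hat @ r_hat.T)
    return weights @ v
```

In the full descriptor this output is added back to the previous embedding through a layer normalization, which is what the layer-norm utilities added earlier in this patch provide for the TensorFlow implementation.
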
- - Parameters - ---------- - rcut - The cut-off radius :math:`r_c` - rcut_smth - From where the environment matrix should be smoothed :math:`r_s` - sel : list[str] - sel[i] specifies the maxmum number of type i atoms in the cut-off radius - ntypes : int - Number of element types - neuron : list[int] - Number of neurons in each hidden layers of the embedding net :math:`\mathcal{N}` - axis_neuron - Number of the axis neuron :math:`M_2` (number of columns of the sub-matrix of the embedding matrix) - tebd_dim: int - Dimension of the type embedding - tebd_input_mode: str - The way to mix the type embeddings. Supported options are `concat`, `dot_residual_s`. - resnet_dt - Time-step `dt` in the resnet construction: - y = x + dt * \phi (Wx + b) - trainable - If the weights of embedding net are trainable. - type_one_side - Try to build N_types embedding nets. Otherwise, building N_types^2 embedding nets - attn: int - Hidden dimension of the attention vectors - attn_layer: int - Number of attention layers - attn_dotr: bool - If dot the angular gate to the attention weights - attn_mask: bool - If mask the diagonal of attention weights - exclude_types : List[List[int]] - The excluded pairs of types which have no interaction with each other. - For example, `[[0, 1]]` means no interaction between type 0 and type 1. - set_davg_zero - Set the shift of embedding net input to zero. - activation_function - The activation function in the embedding net. Supported options are |ACTIVATION_FN| - precision - The precision of the embedding net parameters. Supported options are |PRECISION| - scaling_factor: float - The scaling factor of normalization in calculations of attention weights. - If `temperature` is None, the scaling of attention weights is (N_dim * scaling_factor)**0.5 - temperature: Optional[float] - If not None, the scaling of attention weights is `temperature` itself. - spin - The deepspin object. - - Limitations - ----------- - The currently implementation does not support the following features - - 1. type_one_side == False - 2. exclude_types != [] - 3. spin is not None - 4. tebd_input_mode != 'concat' - 5. smooth == True - - References - ---------- - .. [1] Duo Zhang, Hangrui Bi, Fu-Zhi Dai, Wanrun Jiang, Linfeng Zhang, and Han Wang. 2022. - DPA-1: Pretraining of Attention-based Deep Potential Model for Molecular Simulation. - arXiv preprint arXiv:2208.08236. - """ - - def __init__( - self, - rcut: float, - rcut_smth: float, - sel: List[str], - ntypes: int, - neuron: List[int] = [25, 50, 100], - axis_neuron: int = 8, - tebd_dim: int = 8, - tebd_input_mode: str = "concat", - resnet_dt: bool = False, - trainable: bool = True, - type_one_side: bool = True, - attn: int = 128, - attn_layer: int = 2, - attn_dotr: bool = True, - attn_mask: bool = False, - exclude_types: List[List[int]] = [], - set_davg_zero: bool = False, - activation_function: str = "tanh", - precision: str = DEFAULT_PRECISION, - scaling_factor=1.0, - normalize=True, - temperature=None, - smooth: bool = True, - concat_output_tebd: bool = True, - spin: Optional[Any] = None, - ) -> None: - ## seed, uniform_seed, multi_task, not included. 
- if not type_one_side: - raise NotImplementedError("type_one_side == False not implemented") - if exclude_types != []: - raise NotImplementedError("exclude_types is not implemented") - if spin is not None: - raise NotImplementedError("spin is not implemented") - # TODO - if tebd_input_mode != "concat": - raise NotImplementedError("tebd_input_mode != 'concat' not implemented") - if not smooth: - raise NotImplementedError("smooth == False not implemented") - - self.rcut = rcut - self.rcut_smth = rcut_smth - if isinstance(sel, int): - sel = [sel] - self.sel = sel - self.ntypes = ntypes - self.neuron = neuron - self.axis_neuron = axis_neuron - self.tebd_dim = tebd_dim - self.tebd_input_mode = tebd_input_mode - self.resnet_dt = resnet_dt - self.trainable = trainable - self.type_one_side = type_one_side - self.attn = attn - self.attn_layer = attn_layer - self.attn_dotr = attn_dotr - self.attn_mask = attn_mask - self.exclude_types = exclude_types - self.set_davg_zero = set_davg_zero - self.activation_function = activation_function - self.precision = precision - self.scaling_factor = scaling_factor - self.normalize = normalize - self.temperature = temperature - self.concat_output_tebd = concat_output_tebd - self.spin = spin - - self.type_embedding = EmbdLayer( - ntypes, tebd_dim, padding=True, precision=precision - ) - in_dim = 1 + self.tebd_dim * 2 if self.tebd_input_mode in ["concat"] else 1 - self.embeddings = NetworkCollection( - ndim=0, - ntypes=self.ntypes, - network_type="embedding_network", - ) - self.embeddings[0] = EmbeddingNet( - in_dim, - self.neuron, - self.activation_function, - self.resnet_dt, - self.precision, - ) - # self.dpa1_attention = NeighborGatedAttention - self.env_mat = EnvMat(self.rcut, self.rcut_smth) - self.nnei = np.sum(self.sel) - self.davg = np.zeros([self.ntypes, self.nnei, 4]) - self.dstd = np.ones([self.ntypes, self.nnei, 4]) - self.orig_sel = self.sel - - def __setitem__(self, key, value): - if key in ("avg", "data_avg", "davg"): - self.davg = value - elif key in ("std", "data_std", "dstd"): - self.dstd = value - else: - raise KeyError(key) - - def __getitem__(self, key): - if key in ("avg", "data_avg", "davg"): - return self.davg - elif key in ("std", "data_std", "dstd"): - return self.dstd - else: - raise KeyError(key) - - @property - def dim_out(self): - """Returns the output dimension of this descriptor.""" - return ( - self.neuron[-1] * self.axis_neuron + self.tebd_dim * 2 - if self.concat_output_tebd - else self.neuron[-1] * self.axis_neuron - ) - - def cal_g( - self, - ss, - ll, - ): - nf, nloc, nnei = ss.shape[0:3] - ss = ss.reshape(nf, nloc, nnei, -1) - # nf x nloc x nnei x ng - gg = self.embeddings[ll].call(ss) - return gg - - def call( - self, - coord_ext, - atype_ext, - nlist, - ): - """Compute the descriptor. - - Parameters - ---------- - coord_ext - The extended coordinates of atoms. shape: nf x (nallx3) - atype_ext - The extended aotm types. shape: nf x nall - nlist - The neighbor list. shape: nf x nloc x nnei - - Returns - ------- - descriptor - The descriptor. shape: nf x nloc x (ng x axis_neuron) - gr - The rotationally equivariant and permutationally invariant single particle - representation. shape: nf x nloc x ng x 3 - g2 - The rotationally invariant pair-partical representation. - this descriptor returns None - h2 - The rotationally equivariant pair-partical representation. - this descriptor returns None - sw - The smooth switch function. 
- """ - # nf x nloc x nnei x 4 - rr, ww = self.env_mat.call(coord_ext, atype_ext, nlist, self.davg, self.dstd) - nf, nloc, nnei, _ = rr.shape - - # add type embedding into input - # nf x nall x tebd_dim - atype_embd_ext = self.type_embedding.call(atype_ext) - atype_embd = atype_embd_ext[:, :nloc, :] - # nf x nloc x nnei x tebd_dim - atype_embd_nnei = np.tile(atype_embd[:, :, np.newaxis, :], (1, 1, nnei, 1)) - nlist_mask = nlist != -1 - nlist_masked = np.copy(nlist) - nlist_masked[nlist_masked == -1] = 0 - index = np.tile(nlist_masked.reshape(nf, -1, 1), (1, 1, self.tebd_dim)) - # nf x nloc x nnei x tebd_dim - atype_embd_nlist = np.take_along_axis(atype_embd_ext, index, axis=1).reshape( - nf, nloc, nnei, self.tebd_dim - ) - ng = self.neuron[-1] - ss = rr[..., 0:1] - ss = np.concatenate([ss, atype_embd_nlist, atype_embd_nnei], axis=-1) - - # calculate gg - gg = self.cal_g(ss, 0) - # nf x nloc x ng x 4 - gr = np.einsum("flni,flnj->flij", gg, rr) - # nf x nloc x ng x 4 - gr /= self.nnei - gr1 = gr[:, :, : self.axis_neuron, :] - # nf x nloc x ng x ng1 - grrg = np.einsum("flid,fljd->flij", gr, gr1) - # nf x nloc x (ng x ng1) - grrg = grrg.reshape(nf, nloc, ng * self.axis_neuron) - if self.concat_output_tebd: - grrg = np.concatenate([grrg, atype_embd], axis=-1) - return grrg, gr[..., 1:], None, None, ww - - def serialize(self) -> dict: - """Serialize the descriptor to dict.""" - return { - "rcut": self.rcut, - "rcut_smth": self.rcut_smth, - "sel": self.sel, - "ntypes": self.ntypes, - "neuron": self.neuron, - "axis_neuron": self.axis_neuron, - "tebd_dim": self.tebd_dim, - "tebd_input_mode": self.tebd_input_mode, - "resnet_dt": self.resnet_dt, - "trainable": self.trainable, - "type_one_side": self.type_one_side, - "exclude_types": self.exclude_types, - "set_davg_zero": self.set_davg_zero, - "attn": self.attn, - "attn_layer": self.attn_layer, - "attn_dotr": self.attn_dotr, - "attn_mask": self.attn_mask, - "activation_function": self.activation_function, - "precision": self.precision, - "spin": self.spin, - "scaling_factor": self.scaling_factor, - "normalize": self.normalize, - "temperature": self.temperature, - "concat_output_tebd": self.concat_output_tebd, - "embeddings": self.embeddings.serialize(), - # "attention_layers": self.dpa1_attention.serialize(), - "env_mat": self.env_mat.serialize(), - "type_embedding": self.type_embedding.serialize(), - "@variables": { - "davg": self.davg, - "dstd": self.dstd, - }, - } - - @classmethod - def deserialize(cls, data: dict) -> "DescrptDPA1": - """Deserialize from dict.""" - data = copy.deepcopy(data) - variables = data.pop("@variables") - embeddings = data.pop("embeddings") - type_embedding = data.pop("type_embedding") - attention_layers = data.pop("attention_layers", None) - env_mat = data.pop("env_mat") - obj = cls(**data) - obj["davg"] = variables["davg"] - obj["dstd"] = variables["dstd"] - obj.type_embedding = EmbdLayer.deserialize(type_embedding) - obj.embeddings = NetworkCollection.deserialize(embeddings) - obj.env_mat = EnvMat.deserialize(env_mat) - # obj.dpa1_attention = NeighborGatedAttention.deserialize(attention_layers) - return obj diff --git a/node_modules/.cache/prettier/.prettier-caches/9be9545d53bb64e65febe2ff48926b4145285f3a.json b/node_modules/.cache/prettier/.prettier-caches/9be9545d53bb64e65febe2ff48926b4145285f3a.json deleted file mode 100644 index b926803eba..0000000000 --- a/node_modules/.cache/prettier/.prettier-caches/9be9545d53bb64e65febe2ff48926b4145285f3a.json +++ /dev/null @@ -1,11 +0,0 @@ -{ - 
"30fbbb65ecaaf3cbf6c6d2b94493105df86183b4": { - "files": { - "doc/model/train-se-atten.md": [ - "6hG+T7kQKxMcKc50vx7IIcn0lrk=", - true - ] - }, - "modified": 1713690325168 - } -} From d16e5a2b151a15ce0c9c4ed5e73e13500f25827f Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Mon, 22 Apr 2024 00:13:20 +0800 Subject: [PATCH 03/18] fix typo --- deepmd/dpmodel/descriptor/dpa1.py | 8 ++++---- deepmd/pt/model/descriptor/dpa1.py | 8 ++++---- deepmd/tf/descriptor/se_atten.py | 18 +++++++++-------- deepmd/tf/descriptor/se_atten_v2.py | 2 +- deepmd/utils/argcheck.py | 6 +++--- doc/model/train-se-atten.md | 2 +- .../tests/consistent/descriptor/test_dpa1.py | 20 +++++++++---------- source/tests/pt/model/test_dpa1.py | 6 +++--- .../tf/test_model_compression_se_atten.py | 14 ++++++------- source/tests/tf/test_model_se_atten.py | 4 ++-- 10 files changed, 45 insertions(+), 43 deletions(-) diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index d38879b62b..437c37b47a 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -171,7 +171,7 @@ class DescrptDPA1(NativeOP, BaseDescriptor): Whether to normalize the hidden vectors in attention weights calculation. temperature: float If not None, the scaling of attention weights is `temperature` itself. - smooth_type_embdding: bool + smooth_type_embedding: bool Whether to use smooth process in attention weights calculation. concat_output_tebd: bool Whether to concat type embedding at the output of the descriptor. @@ -219,7 +219,7 @@ def __init__( scaling_factor=1.0, normalize: bool = True, temperature: Optional[float] = None, - smooth_type_embdding: bool = True, + smooth_type_embedding: bool = True, concat_output_tebd: bool = True, spin: Optional[Any] = None, # consistent with argcheck, not used though @@ -259,7 +259,7 @@ def __init__( self.scaling_factor = scaling_factor self.normalize = normalize self.temperature = temperature - self.smooth = smooth_type_embdding + self.smooth = smooth_type_embedding self.concat_output_tebd = concat_output_tebd self.spin = spin # order matters, placed after the assignment of self.ntypes @@ -535,7 +535,7 @@ def serialize(self) -> dict: "scaling_factor": self.scaling_factor, "normalize": self.normalize, "temperature": self.temperature, - "smooth_type_embdding": self.smooth, + "smooth_type_embedding": self.smooth, "type_one_side": self.type_one_side, "concat_output_tebd": self.concat_output_tebd, # make deterministic diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py index 9f03edce12..24f1012cab 100644 --- a/deepmd/pt/model/descriptor/dpa1.py +++ b/deepmd/pt/model/descriptor/dpa1.py @@ -153,7 +153,7 @@ class DescrptDPA1(BaseDescriptor, torch.nn.Module): Whether to normalize the hidden vectors in attention weights calculation. temperature: float If not None, the scaling of attention weights is `temperature` itself. - smooth_type_embdding: bool + smooth_type_embedding: bool Whether to use smooth process in attention weights calculation. concat_output_tebd: bool Whether to concat type embedding at the output of the descriptor. 
@@ -201,7 +201,7 @@ def __init__( temperature=None, concat_output_tebd: bool = True, trainable: bool = True, - smooth_type_embdding: bool = True, + smooth_type_embedding: bool = True, type_one_side: bool = False, # not implemented stripped_type_embedding: bool = False, @@ -236,7 +236,7 @@ def __init__( scaling_factor=scaling_factor, normalize=normalize, temperature=temperature, - smooth=smooth_type_embdding, + smooth=smooth_type_embedding, type_one_side=type_one_side, exclude_types=exclude_types, env_protection=env_protection, @@ -373,7 +373,7 @@ def serialize(self) -> dict: "scaling_factor": obj.scaling_factor, "normalize": obj.normalize, "temperature": obj.temperature, - "smooth_type_embdding": obj.smooth, + "smooth_type_embedding": obj.smooth, "type_one_side": obj.type_one_side, "concat_output_tebd": self.concat_output_tebd, # make deterministic diff --git a/deepmd/tf/descriptor/se_atten.py b/deepmd/tf/descriptor/se_atten.py index 249337e162..0ddb8b9da6 100644 --- a/deepmd/tf/descriptor/se_atten.py +++ b/deepmd/tf/descriptor/se_atten.py @@ -141,7 +141,7 @@ class DescrptSeAtten(DescrptSeA): stripped_type_embedding Whether to strip the type embedding into a separated embedding network. Default value will be True in `se_atten_v2` descriptor. - smooth_type_embdding + smooth_type_embedding Whether to use smooth process in attention weights calculation. And when using stripped type embedding, whether to dot smooth factor on the network output of type embedding to keep the network smooth, instead of setting `set_davg_zero` to be True. @@ -176,7 +176,7 @@ def __init__( attn_mask: bool = False, multi_task: bool = False, stripped_type_embedding: bool = False, - smooth_type_embdding: bool = False, + smooth_type_embedding: bool = False, # not implemented scaling_factor=1.0, normalize=True, @@ -185,7 +185,9 @@ def __init__( env_protection: float = 0.0, # not implement!! **kwargs, ) -> None: - if not set_davg_zero and not (stripped_type_embedding and smooth_type_embdding): + if not set_davg_zero and not ( + stripped_type_embedding and smooth_type_embedding + ): warnings.warn( "Set 'set_davg_zero' False in descriptor 'se_atten' " "may cause unexpected incontinuity during model inference!" @@ -230,7 +232,7 @@ def __init__( if ntypes == 0: raise ValueError("`model/type_map` is not set or empty!") self.stripped_type_embedding = stripped_type_embedding - self.smooth = smooth_type_embdding + self.smooth = smooth_type_embedding self.ntypes = ntypes self.att_n = attn self.attn_layer = attn_layer @@ -1775,7 +1777,7 @@ def serialize(self, suffix: str = "") -> dict: "attn_mask": self.attn_mask, "activation_function": self.activation_function_name, "resnet_dt": self.filter_resnet_dt, - "smooth_type_embdding": self.smooth, + "smooth_type_embedding": self.smooth, "precision": self.filter_precision.name, "embeddings": self.serialize_network( ntypes=self.ntypes, @@ -1869,7 +1871,7 @@ class DescrptDPA1Compat(DescrptSeAtten): Not supported in this version. temperature: float Not supported in this version. - smooth_type_embdding: bool + smooth_type_embedding: bool Whether to use smooth process in attention weights calculation. Only support True in this version. concat_output_tebd: bool Whether to concat type embedding at the output of the descriptor. Only support True in this version. 
@@ -1902,7 +1904,7 @@ def __init__( scaling_factor=1.0, normalize: bool = True, temperature: Optional[float] = None, - smooth_type_embdding: bool = True, + smooth_type_embedding: bool = True, concat_output_tebd: bool = True, spin: Optional[Any] = None, # consistent with argcheck, not used though @@ -1948,7 +1950,7 @@ def __init__( attn_mask=attn_mask, multi_task=True, stripped_type_embedding=False, - smooth_type_embdding=smooth_type_embdding, + smooth_type_embedding=smooth_type_embedding, env_protection=env_protection, ) self.tebd_dim = tebd_dim diff --git a/deepmd/tf/descriptor/se_atten_v2.py b/deepmd/tf/descriptor/se_atten_v2.py index 784e02d84d..01c4d93ad8 100644 --- a/deepmd/tf/descriptor/se_atten_v2.py +++ b/deepmd/tf/descriptor/se_atten_v2.py @@ -110,6 +110,6 @@ def __init__( attn_mask=attn_mask, multi_task=multi_task, stripped_type_embedding=True, - smooth_type_embdding=True, + smooth_type_embedding=True, **kwargs, ) diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 276f4d63e9..c51900e9a9 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -476,7 +476,7 @@ def descrpt_se_atten_common_args(): @descrpt_args_plugin.register("se_atten", alias=["dpa1"]) def descrpt_se_atten_args(): doc_stripped_type_embedding = "Whether to strip the type embedding into a separated embedding network. Setting it to `False` will fall back to the previous version of `se_atten` which is non-compressible." - doc_smooth_type_embdding = f"Whether to use smooth process in attention weights calculation. {doc_only_tf_supported} When using stripped type embedding, whether to dot smooth factor on the network output of type embedding to keep the network smooth, instead of setting `set_davg_zero` to be True." + doc_smooth_type_embedding = f"Whether to use smooth process in attention weights calculation. {doc_only_tf_supported} When using stripped type embedding, whether to dot smooth factor on the network output of type embedding to keep the network smooth, instead of setting `set_davg_zero` to be True." doc_set_davg_zero = "Set the normalization average to zero. This option should be set when `se_atten` descriptor or `atom_ener` in the energy fitting is used" doc_tebd_dim = "The dimension of atom type embedding." doc_temperature = "The scaling factor of normalization in calculations of attention weights, which is used to scale the matmul(Q, K)." 
@@ -503,11 +503,11 @@ def descrpt_se_atten_args(): doc=doc_only_tf_supported + doc_stripped_type_embedding, ), Argument( - "smooth_type_embdding", + "smooth_type_embedding", bool, optional=True, default=False, - doc=doc_smooth_type_embdding, + doc=doc_smooth_type_embedding, ), Argument( "set_davg_zero", bool, optional=True, default=True, doc=doc_set_davg_zero diff --git a/doc/model/train-se-atten.md b/doc/model/train-se-atten.md index ccc1e476e0..59333eb0da 100644 --- a/doc/model/train-se-atten.md +++ b/doc/model/train-se-atten.md @@ -157,7 +157,7 @@ We highly recommend using the version 2.0 of the attention-based descriptor `"se ```json "stripped_type_embedding": true, - "smooth_type_embdding": true, + "smooth_type_embedding": true, "set_davg_zero": false ``` diff --git a/source/tests/consistent/descriptor/test_dpa1.py b/source/tests/consistent/descriptor/test_dpa1.py index cc2bd57457..854e0fa43e 100644 --- a/source/tests/consistent/descriptor/test_dpa1.py +++ b/source/tests/consistent/descriptor/test_dpa1.py @@ -43,7 +43,7 @@ ([], [[0, 1]]), # excluded_types ("float32", "float64"), # precision (0.0, 1e-8, 1e-2), # env_protection - (True, False), # smooth_type_embdding + (True, False), # smooth_type_embedding (True, False), # type_one_side (True, False), # set_davg_zero (0, 2), # attn_layer @@ -56,7 +56,7 @@ def data(self) -> dict: excluded_types, precision, env_protection, - smooth_type_embdding, + smooth_type_embedding, type_one_side, set_davg_zero, attn_layer, @@ -84,7 +84,7 @@ def data(self) -> dict: "env_protection": env_protection, "precision": precision, "set_davg_zero": set_davg_zero, - "smooth_type_embdding": smooth_type_embdding, + "smooth_type_embedding": smooth_type_embedding, "seed": 1145141919810, } @@ -95,7 +95,7 @@ def skip_pt(self) -> bool: excluded_types, precision, env_protection, - smooth_type_embdding, + smooth_type_embedding, type_one_side, set_davg_zero, attn_layer, @@ -109,7 +109,7 @@ def skip_dp(self) -> bool: excluded_types, precision, env_protection, - smooth_type_embdding, + smooth_type_embedding, type_one_side, set_davg_zero, attn_layer, @@ -123,7 +123,7 @@ def skip_tf(self) -> bool: excluded_types, precision, env_protection, - smooth_type_embdding, + smooth_type_embedding, type_one_side, set_davg_zero, attn_layer, @@ -131,7 +131,7 @@ def skip_tf(self) -> bool: # TODO (excluded_types != [] and attn_layer > 0) need fix return ( env_protection != 0.0 - or smooth_type_embdding + or smooth_type_embedding or (excluded_types != [] and attn_layer > 0) ) @@ -178,7 +178,7 @@ def setUp(self): excluded_types, precision, env_protection, - smooth_type_embdding, + smooth_type_embedding, type_one_side, set_davg_zero, attn_layer, @@ -225,7 +225,7 @@ def rtol(self) -> float: excluded_types, precision, env_protection, - smooth_type_embdding, + smooth_type_embedding, type_one_side, set_davg_zero, attn_layer, @@ -245,7 +245,7 @@ def atol(self) -> float: excluded_types, precision, env_protection, - smooth_type_embdding, + smooth_type_embedding, type_one_side, set_davg_zero, attn_layer, diff --git a/source/tests/pt/model/test_dpa1.py b/source/tests/pt/model/test_dpa1.py index 57b55fcf16..3cec5d1876 100644 --- a/source/tests/pt/model/test_dpa1.py +++ b/source/tests/pt/model/test_dpa1.py @@ -67,7 +67,7 @@ def test_consistency( attn_layer=2, precision=prec, resnet_dt=idt, - smooth_type_embdding=sm, + smooth_type_embedding=sm, type_one_side=to, old_impl=False, ).to(env.DEVICE) @@ -130,7 +130,7 @@ def test_consistency( attn_layer=2, precision=prec, resnet_dt=idt, - 
smooth_type_embdding=sm, + smooth_type_embedding=sm, old_impl=True, ).to(env.DEVICE) dd0_state_dict = dd0.se_atten.state_dict() @@ -203,7 +203,7 @@ def test_jit( self.nt, precision=prec, resnet_dt=idt, - smooth_type_embdding=sm, + smooth_type_embedding=sm, type_one_side=to, old_impl=False, ) diff --git a/source/tests/tf/test_model_compression_se_atten.py b/source/tests/tf/test_model_compression_se_atten.py index aa1f0afa38..03ddedad39 100644 --- a/source/tests/tf/test_model_compression_se_atten.py +++ b/source/tests/tf/test_model_compression_se_atten.py @@ -37,27 +37,27 @@ def _file_delete(file): { "se_atten precision": "float64", "type embedding precision": "float64", - "smooth_type_embdding": True, + "smooth_type_embedding": True, }, { "se_atten precision": "float64", "type embedding precision": "float64", - "smooth_type_embdding": False, + "smooth_type_embedding": False, }, { "se_atten precision": "float64", "type embedding precision": "float32", - "smooth_type_embdding": True, + "smooth_type_embedding": True, }, { "se_atten precision": "float32", "type embedding precision": "float64", - "smooth_type_embdding": True, + "smooth_type_embedding": True, }, { "se_atten precision": "float32", "type embedding precision": "float32", - "smooth_type_embdding": True, + "smooth_type_embedding": True, }, ] @@ -82,8 +82,8 @@ def _init_models(): jdata["model"]["descriptor"]["stripped_type_embedding"] = True jdata["model"]["descriptor"]["sel"] = 120 jdata["model"]["descriptor"]["attn_layer"] = 0 - jdata["model"]["descriptor"]["smooth_type_embdding"] = tests[i][ - "smooth_type_embdding" + jdata["model"]["descriptor"]["smooth_type_embedding"] = tests[i][ + "smooth_type_embedding" ] jdata["model"]["type_embedding"] = {} jdata["model"]["type_embedding"]["precision"] = tests[i][ diff --git a/source/tests/tf/test_model_se_atten.py b/source/tests/tf/test_model_se_atten.py index 36cf4887c0..d75dc0cfff 100644 --- a/source/tests/tf/test_model_se_atten.py +++ b/source/tests/tf/test_model_se_atten.py @@ -764,7 +764,7 @@ def test_smoothness_of_stripped_type_embedding_smooth_model(self): jdata["model"]["descriptor"].pop("type", None) jdata["model"]["descriptor"]["ntypes"] = 2 jdata["model"]["descriptor"]["stripped_type_embedding"] = True - jdata["model"]["descriptor"]["smooth_type_embdding"] = True + jdata["model"]["descriptor"]["smooth_type_embedding"] = True jdata["model"]["descriptor"]["attn_layer"] = 1 jdata["model"]["descriptor"]["rcut"] = 6.0 jdata["model"]["descriptor"]["rcut_smth"] = 4.0 @@ -910,7 +910,7 @@ def test_smoothness_of_stripped_type_embedding_smooth_model_excluded_types(self) jdata["model"]["descriptor"].pop("type", None) jdata["model"]["descriptor"]["ntypes"] = 2 jdata["model"]["descriptor"]["stripped_type_embedding"] = True - jdata["model"]["descriptor"]["smooth_type_embdding"] = True + jdata["model"]["descriptor"]["smooth_type_embedding"] = True jdata["model"]["descriptor"]["attn_layer"] = 1 jdata["model"]["descriptor"]["rcut"] = 6.0 jdata["model"]["descriptor"]["rcut_smth"] = 4.0 From 19a47520e5ab0dbfb819eb5118c90c6247489ab1 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Mon, 22 Apr 2024 18:43:09 +0800 Subject: [PATCH 04/18] Cleanup params to be consistent --- deepmd/dpmodel/descriptor/dpa1.py | 25 +-- deepmd/pt/model/descriptor/dpa1.py | 28 +++- deepmd/tf/descriptor/se_atten.py | 37 +++-- deepmd/tf/env.py | 2 +- .../tests/consistent/descriptor/test_dpa1.py | 150 +++++++++++++----- 5 files changed, 169 insertions(+), 73 deletions(-) diff --git 
a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index 437c37b47a..03f1e36ca8 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -135,7 +135,8 @@ class DescrptDPA1(NativeOP, BaseDescriptor): tebd_dim: int Dimension of the type embedding tebd_input_mode: str - The way to mix the type embeddings. Supported options are `concat`, `dot_residual_s`. + The way to mix the type embeddings. Supported options are `concat`. + (TODO need to support stripped_type_embedding option) resnet_dt: bool Time-step `dt` in the resnet construction: y = x + dt * \phi (Wx + b) @@ -152,6 +153,7 @@ class DescrptDPA1(NativeOP, BaseDescriptor): attn_dotr: bool If dot the angular gate to the attention weights attn_mask: bool + (Deprecated, only support False to keep consistent with old implementation.) If mask the diagonal of attention weights exclude_types : List[List[int]] The excluded pairs of types which have no interaction with each other. @@ -176,16 +178,17 @@ class DescrptDPA1(NativeOP, BaseDescriptor): concat_output_tebd: bool Whether to concat type embedding at the output of the descriptor. spin - The old implementation of deepspin (deprecated in the descriptor). + (Deprecated, only support None to keep consistent with old implementation.) + The old implementation of deepspin. Limitations ----------- The currently implementation does not support the following features + 1. tebd_input_mode != 'concat' - 1. type_one_side == True - 2. exclude_types != [] - 3. spin is not None - 4. tebd_input_mode != 'concat' + The currently implementation will not support the following deprecated features + 1. spin is not None + 2. attn_mask == True References ---------- @@ -228,10 +231,15 @@ def __init__( ## seed, uniform_seed, multi_task, not included. if spin is not None: raise NotImplementedError("old implementation of spin is not supported.") + if attn_mask: + raise NotImplementedError( + "old implementation of attn_mask is not supported." 
+ ) # TODO if tebd_input_mode != "concat": raise NotImplementedError("tebd_input_mode != 'concat' not implemented") + del attn_mask, spin self.rcut = rcut self.rcut_smth = rcut_smth if isinstance(sel, int): @@ -250,7 +258,6 @@ def __init__( self.attn = attn self.attn_layer = attn_layer self.attn_dotr = attn_dotr - self.attn_mask = attn_mask self.exclude_types = exclude_types self.env_protection = env_protection self.set_davg_zero = set_davg_zero @@ -261,7 +268,6 @@ def __init__( self.temperature = temperature self.smooth = smooth_type_embedding self.concat_output_tebd = concat_output_tebd - self.spin = spin # order matters, placed after the assignment of self.ntypes self.reinit_exclude(exclude_types) @@ -297,7 +303,6 @@ def __init__( self.filter_neuron[-1], self.attn, dotr=self.attn_dotr, - do_mask=self.attn_mask, scaling_factor=self.scaling_factor, normalize=self.normalize, temperature=self.temperature, @@ -529,7 +534,7 @@ def serialize(self) -> dict: "attn": self.attn, "attn_layer": self.attn_layer, "attn_dotr": self.attn_dotr, - "attn_mask": self.attn_mask, + "attn_mask": False, "activation_function": self.activation_function, "resnet_dt": self.resnet_dt, "scaling_factor": self.scaling_factor, diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py index 24f1012cab..aec2e75e6e 100644 --- a/deepmd/pt/model/descriptor/dpa1.py +++ b/deepmd/pt/model/descriptor/dpa1.py @@ -117,7 +117,8 @@ class DescrptDPA1(BaseDescriptor, torch.nn.Module): tebd_dim: int Dimension of the type embedding tebd_input_mode: str - The way to mix the type embeddings. Supported options are `concat`, `dot_residual_s`. + The way to mix the type embeddings. Supported options are `concat`. + (TODO need to support stripped_type_embedding option) resnet_dt: bool Time-step `dt` in the resnet construction: y = x + dt * \phi (Wx + b) @@ -134,6 +135,7 @@ class DescrptDPA1(BaseDescriptor, torch.nn.Module): attn_dotr: bool If dot the angular gate to the attention weights attn_mask: bool + (Deprecated, only support False to keep consistent with old implementation.) If mask the diagonal of attention weights exclude_types : List[List[int]] The excluded pairs of types which have no interaction with each other. @@ -158,15 +160,17 @@ class DescrptDPA1(BaseDescriptor, torch.nn.Module): concat_output_tebd: bool Whether to concat type embedding at the output of the descriptor. spin - The old implementation of deepspin (deprecated in the descriptor). + (Deprecated, only support None to keep consistent with old implementation.) + The old implementation of deepspin. Limitations ----------- The currently implementation does not support the following features + 1. tebd_input_mode != 'concat' - 1. exclude_types != [] - 2. spin is not None - 3. tebd_input_mode != 'concat' + The currently implementation will not support the following deprecated features + 1. spin is not None + 2. attn_mask == True References ---------- @@ -215,7 +219,15 @@ def __init__( raise NotImplementedError("stripped_type_embedding is not supported.") if spin is not None: raise NotImplementedError("old implementation of spin is not supported.") - del type, spin + if attn_mask: + raise NotImplementedError( + "old implementation of attn_mask is not supported." 
+ ) + # TODO + if tebd_input_mode != "concat": + raise NotImplementedError("tebd_input_mode != 'concat' not implemented") + + del type, spin, attn_mask self.se_atten = DescrptBlockSeAtten( rcut, rcut_smth, @@ -229,7 +241,7 @@ def __init__( attn=attn, attn_layer=attn_layer, attn_dotr=attn_dotr, - attn_mask=attn_mask, + attn_mask=False, activation_function=activation_function, precision=precision, resnet_dt=resnet_dt, @@ -367,7 +379,7 @@ def serialize(self) -> dict: "attn": obj.attn_dim, "attn_layer": obj.attn_layer, "attn_dotr": obj.attn_dotr, - "attn_mask": obj.attn_mask, + "attn_mask": False, "activation_function": obj.activation_function, "resnet_dt": obj.resnet_dt, "scaling_factor": obj.scaling_factor, diff --git a/deepmd/tf/descriptor/se_atten.py b/deepmd/tf/descriptor/se_atten.py index 0ddb8b9da6..15db4f5a22 100644 --- a/deepmd/tf/descriptor/se_atten.py +++ b/deepmd/tf/descriptor/se_atten.py @@ -1838,14 +1838,17 @@ class DescrptDPA1Compat(DescrptSeAtten): tebd_dim: int Dimension of the type embedding tebd_input_mode: str - The way to mix the type embeddings. Only support `concat` in this version. + (Only support `concat` to keep consistent with other backend references.) + The way to mix the type embeddings. resnet_dt: bool Time-step `dt` in the resnet construction: y = x + dt * \phi (Wx + b) trainable: bool If the weights of embedding net are trainable. type_one_side: bool - Try to build N_types embedding nets. Otherwise, building N_types^2 embedding nets + If 'False', type embeddings of both neighbor and central atoms are considered. + If 'True', only type embeddings of neighbor atoms are considered. + Default is 'False'. attn: int Hidden dimension of the attention vectors attn_layer: int @@ -1853,6 +1856,7 @@ class DescrptDPA1Compat(DescrptSeAtten): attn_dotr: bool If dot the angular gate to the attention weights attn_mask: bool + (Only support False to keep consistent with other backend references.) If mask the diagonal of attention weights exclude_types : List[List[int]] The excluded pairs of types which have no interaction with each other. @@ -1866,17 +1870,26 @@ class DescrptDPA1Compat(DescrptSeAtten): precision: str The precision of the embedding net parameters. Supported options are |PRECISION| scaling_factor: float - Not supported in this version. + (Only to keep consistent with other backend references.) + (Not used in this version.) + The scaling factor of normalization in calculations of attention weights. + If `temperature` is None, the scaling of attention weights is (N_dim * scaling_factor)**0.5 normalize: bool - Not supported in this version. + (Only support True to keep consistent with other backend references.) + (Not used in this version.) + Whether to normalize the hidden vectors in attention weights calculation. temperature: float - Not supported in this version. + (Only support 1.0 to keep consistent with other backend references.) + (Not used in this version.) + If not None, the scaling of attention weights is `temperature` itself. smooth_type_embedding: bool - Whether to use smooth process in attention weights calculation. Only support True in this version. + (Only support False to keep consistent with other backend references.) + Whether to use smooth process in attention weights calculation. concat_output_tebd: bool - Whether to concat type embedding at the output of the descriptor. Only support True in this version. + Whether to concat type embedding at the output of the descriptor. spin - The old implementation of deepspin (deprecated in the descriptor). 
Not supported in this version. + (Only support None to keep consistent with old implementation.) + The old implementation of deepspin. """ def __init__( @@ -1921,12 +1934,12 @@ def __init__( raise NotImplementedError( "Only support temperature == 1.0 in this version." ) - if not concat_output_tebd: - raise NotImplementedError( - "Only support concat_output_tebd == True in this version." - ) if spin is not None: raise NotImplementedError("Only support spin is None in this version.") + if attn_mask: + raise NotImplementedError( + "old implementation of attn_mask is not supported." + ) super().__init__( rcut, diff --git a/deepmd/tf/env.py b/deepmd/tf/env.py index 129b10b6ff..0bd637dc02 100644 --- a/deepmd/tf/env.py +++ b/deepmd/tf/env.py @@ -190,7 +190,7 @@ def dlopen_library(module: str, filename: str): r"attention_layer_(\d+)/(layer_normalization)/(gamma)|" # r"attention_layer_(\d+)/(layer_normalization)_\d+/(beta)|" # r"attention_layer_(\d+)/(layer_normalization)_\d+/(gamma)|" -) +)[:-1] TRANSFER_PATTERN = ( EMBEDDING_NET_PATTERN diff --git a/source/tests/consistent/descriptor/test_dpa1.py b/source/tests/consistent/descriptor/test_dpa1.py index 854e0fa43e..d6c54b22f5 100644 --- a/source/tests/consistent/descriptor/test_dpa1.py +++ b/source/tests/consistent/descriptor/test_dpa1.py @@ -39,27 +39,43 @@ @parameterized( - (True, False), # resnet_dt - ([], [[0, 1]]), # excluded_types - ("float32", "float64"), # precision - (0.0, 1e-8, 1e-2), # env_protection - (True, False), # smooth_type_embedding + (4,), # tebd_dim + ("concat",), # tebd_input_mode + (True,), # resnet_dt (True, False), # type_one_side - (True, False), # set_davg_zero + (20,), # attn (0, 2), # attn_layer + (True, False), # attn_dotr + ([], [[0, 1]]), # excluded_types + (0.0,), # env_protection + (True, False), # set_davg_zero + (1.0,), # scaling_factor + (True, False), # normalize + (None, 1.0), # temperature + (True, False), # smooth_type_embedding + (True, False), # concat_output_tebd + ("float64",), # precision ) class TestDPA1(CommonTest, DescriptorTest, unittest.TestCase): @property def data(self) -> dict: ( + tebd_dim, + tebd_input_mode, resnet_dt, + type_one_side, + attn, + attn_layer, + attn_dotr, excluded_types, - precision, env_protection, - smooth_type_embedding, - type_one_side, set_davg_zero, - attn_layer, + scaling_factor, + normalize, + temperature, + smooth_type_embedding, + concat_output_tebd, + precision, ) = self.param return { "sel": [10], @@ -68,16 +84,16 @@ def data(self) -> dict: "neuron": [6, 12, 24], "ntypes": self.ntypes, "axis_neuron": 3, - "tebd_dim": 4, - # "tebd_input_mode": tebd_input_mode, - "attn": 20, + "tebd_dim": tebd_dim, + "tebd_input_mode": tebd_input_mode, + "attn": attn, "attn_layer": attn_layer, - "attn_dotr": True, + "attn_dotr": attn_dotr, "attn_mask": False, - "scaling_factor": 1.0, - "normalize": True, - "temperature": 1.0, - "concat_output_tebd": True, + "scaling_factor": scaling_factor, + "normalize": normalize, + "temperature": temperature, + "concat_output_tebd": concat_output_tebd, "resnet_dt": resnet_dt, "type_one_side": type_one_side, "exclude_types": excluded_types, @@ -91,47 +107,73 @@ def data(self) -> dict: @property def skip_pt(self) -> bool: ( + tebd_dim, + tebd_input_mode, resnet_dt, + type_one_side, + attn, + attn_layer, + attn_dotr, excluded_types, - precision, env_protection, - smooth_type_embedding, - type_one_side, set_davg_zero, - attn_layer, + scaling_factor, + normalize, + temperature, + smooth_type_embedding, + concat_output_tebd, + precision, ) = self.param 
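As an aside on how these test axes combine: the decorator enumerates the Cartesian product of the tuples listed above, and every property then unpacks `self.param` in exactly the same order, which is why the two listings must stay in sync. Below is a minimal sketch of that expansion; the helper is hypothetical and only mirrors the assumed behavior of the test harness, it is not the harness itself.

    import itertools

    # A few of the axes from the decorator above, kept in the same order.
    param_axes = {
        "attn_layer": (0, 2),
        "excluded_types": ([], [[0, 1]]),
        "normalize": (True, False),
        "temperature": (None, 1.0),
    }

    # Each generated case is a tuple ordered like the axes, matching the
    # `(...) = self.param` unpacking used in the properties.
    for case in itertools.product(*param_axes.values()):
        params = dict(zip(param_axes.keys(), case))
        # e.g. the TF reference is skipped for combinations it does not support
        skip_tf = (not params["normalize"]) or (params["temperature"] != 1.0)
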
return CommonTest.skip_pt @property def skip_dp(self) -> bool: ( + tebd_dim, + tebd_input_mode, resnet_dt, + type_one_side, + attn, + attn_layer, + attn_dotr, excluded_types, - precision, env_protection, - smooth_type_embedding, - type_one_side, set_davg_zero, - attn_layer, + scaling_factor, + normalize, + temperature, + smooth_type_embedding, + concat_output_tebd, + precision, ) = self.param return CommonTest.skip_pt @property def skip_tf(self) -> bool: ( + tebd_dim, + tebd_input_mode, resnet_dt, + type_one_side, + attn, + attn_layer, + attn_dotr, excluded_types, - precision, env_protection, - smooth_type_embedding, - type_one_side, set_davg_zero, - attn_layer, + scaling_factor, + normalize, + temperature, + smooth_type_embedding, + concat_output_tebd, + precision, ) = self.param # TODO (excluded_types != [] and attn_layer > 0) need fix return ( env_protection != 0.0 or smooth_type_embedding + or not normalize + or temperature != 1.0 or (excluded_types != [] and attn_layer > 0) ) @@ -174,14 +216,22 @@ def setUp(self): ) self.natoms = np.array([6, 6, 2, 4], dtype=np.int32) ( + tebd_dim, + tebd_input_mode, resnet_dt, + type_one_side, + attn, + attn_layer, + attn_dotr, excluded_types, - precision, env_protection, - smooth_type_embedding, - type_one_side, set_davg_zero, - attn_layer, + scaling_factor, + normalize, + temperature, + smooth_type_embedding, + concat_output_tebd, + precision, ) = self.param def build_tf(self, obj: Any, suffix: str) -> Tuple[list, dict]: @@ -221,14 +271,22 @@ def extract_ret(self, ret: Any, backend) -> Tuple[np.ndarray, ...]: def rtol(self) -> float: """Relative tolerance for comparing the return value.""" ( + tebd_dim, + tebd_input_mode, resnet_dt, + type_one_side, + attn, + attn_layer, + attn_dotr, excluded_types, - precision, env_protection, - smooth_type_embedding, - type_one_side, set_davg_zero, - attn_layer, + scaling_factor, + normalize, + temperature, + smooth_type_embedding, + concat_output_tebd, + precision, ) = self.param if precision == "float64": return 1e-10 @@ -241,14 +299,22 @@ def rtol(self) -> float: def atol(self) -> float: """Absolute tolerance for comparing the return value.""" ( + tebd_dim, + tebd_input_mode, resnet_dt, + type_one_side, + attn, + attn_layer, + attn_dotr, excluded_types, - precision, env_protection, - smooth_type_embedding, - type_one_side, set_davg_zero, - attn_layer, + scaling_factor, + normalize, + temperature, + smooth_type_embedding, + concat_output_tebd, + precision, ) = self.param if precision == "float64": return 1e-10 From 3bc25daa7dec49c3ca19d2cb8b90c72311e7febe Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Mon, 22 Apr 2024 18:46:03 +0800 Subject: [PATCH 05/18] Update dpa1.py --- deepmd/dpmodel/descriptor/dpa1.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index 03f1e36ca8..a976d8774b 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -462,10 +462,9 @@ def call( atype_embd_nnei = np.tile(atype_embd[:, np.newaxis, :], (1, nnei, 1)) # nfnl x nnei nlist_mask = nlist != -1 - nlist_masked = np.copy(nlist) # nfnl x nnei x 1 sw = np.where(nlist_mask[:, :, None], sw, 0.0) - nlist_masked[nlist_masked == -1] = 0 + nlist_masked = np.where(nlist_mask, nlist, 0) index = np.tile(nlist_masked.reshape(nf, -1, 1), (1, 1, self.tebd_dim)) # nfnl x nnei x tebd_dim atype_embd_nlist = np.take_along_axis(atype_embd_ext, index, axis=1).reshape( From 
85c7d6e1bc11665038a73f06a72bd64bb9988a15 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Mon, 22 Apr 2024 22:48:21 +0800 Subject: [PATCH 06/18] Add trainable option for layernorm --- deepmd/dpmodel/descriptor/dpa1.py | 19 ++- deepmd/dpmodel/utils/network.py | 17 ++- deepmd/pt/model/descriptor/dpa1.py | 7 +- deepmd/pt/model/descriptor/se_atten.py | 16 ++- deepmd/pt/model/network/layernorm.py | 13 +- deepmd/tf/descriptor/se_atten.py | 81 ++---------- deepmd/tf/utils/network.py | 171 ------------------------- deepmd/utils/argcheck.py | 7 + 8 files changed, 79 insertions(+), 252 deletions(-) diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index a976d8774b..ecef21c9d4 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -141,7 +141,9 @@ class DescrptDPA1(NativeOP, BaseDescriptor): Time-step `dt` in the resnet construction: y = x + dt * \phi (Wx + b) trainable: bool - If the weights of embedding net are trainable. + If the weights of this descriptors are trainable. + trainable_ln: bool + Whether to use trainable shift and scale weights in layer normalization. type_one_side: bool If 'False', type embeddings of both neighbor and central atoms are considered. If 'True', only type embeddings of neighbor atoms are considered. @@ -222,6 +224,7 @@ def __init__( scaling_factor=1.0, normalize: bool = True, temperature: Optional[float] = None, + trainable_ln: bool = True, smooth_type_embedding: bool = True, concat_output_tebd: bool = True, spin: Optional[Any] = None, @@ -254,6 +257,7 @@ def __init__( self.tebd_input_mode = tebd_input_mode self.resnet_dt = resnet_dt self.trainable = trainable + self.trainable_ln = trainable_ln self.type_one_side = type_one_side self.attn = attn self.attn_layer = attn_layer @@ -306,6 +310,7 @@ def __init__( scaling_factor=self.scaling_factor, normalize=self.normalize, temperature=self.temperature, + trainable_ln=self.trainable_ln, smooth=self.smooth, precision=self.precision, ) @@ -539,6 +544,7 @@ def serialize(self) -> dict: "scaling_factor": self.scaling_factor, "normalize": self.normalize, "temperature": self.temperature, + "trainable_ln": self.trainable_ln, "smooth_type_embedding": self.smooth, "type_one_side": self.type_one_side, "concat_output_tebd": self.concat_output_tebd, @@ -607,6 +613,7 @@ def __init__( scaling_factor: float = 1.0, normalize: bool = True, temperature: Optional[float] = None, + trainable_ln: bool = True, smooth: bool = True, precision: str = DEFAULT_PRECISION, ): @@ -621,6 +628,7 @@ def __init__( self.scaling_factor = scaling_factor self.normalize = normalize self.temperature = temperature + self.trainable_ln = trainable_ln self.smooth = smooth self.precision = precision self.network_type = NeighborGatedAttentionLayer @@ -635,6 +643,7 @@ def __init__( scaling_factor=scaling_factor, normalize=normalize, temperature=temperature, + trainable_ln=trainable_ln, smooth=smooth, precision=precision, ) @@ -690,6 +699,7 @@ def serialize(self): "scaling_factor": self.scaling_factor, "normalize": self.normalize, "temperature": self.temperature, + "trainable_ln": self.trainable_ln, "precision": self.precision, "attention_layers": [layer.serialize() for layer in self.attention_layers], } @@ -725,6 +735,7 @@ def __init__( scaling_factor: float = 1.0, normalize: bool = True, temperature: Optional[float] = None, + trainable_ln: bool = True, smooth: bool = True, precision: str = DEFAULT_PRECISION, ): @@ -738,6 +749,7 @@ def __init__( self.scaling_factor = 
scaling_factor self.normalize = normalize self.temperature = temperature + self.trainable_ln = trainable_ln self.precision = precision self.attention_layer = GatedAttentionLayer( nnei, @@ -751,7 +763,9 @@ def __init__( smooth=smooth, precision=precision, ) - self.attn_layer_norm = LayerNorm(self.embed_dim, precision=precision) + self.attn_layer_norm = LayerNorm( + self.embed_dim, trainable=self.trainable_ln, precision=precision + ) def call( self, @@ -783,6 +797,7 @@ def serialize(self) -> dict: "scaling_factor": self.scaling_factor, "normalize": self.normalize, "temperature": self.temperature, + "trainable_ln": self.trainable_ln, "precision": self.precision, "attention_layer": self.attention_layer.serialize(), "attn_layer_norm": self.attn_layer_norm.serialize(), diff --git a/deepmd/dpmodel/utils/network.py b/deepmd/dpmodel/utils/network.py index 88e97ee3c4..3490b654dd 100644 --- a/deepmd/dpmodel/utils/network.py +++ b/deepmd/dpmodel/utils/network.py @@ -399,6 +399,7 @@ def __init__( num_in: int, eps: float = 1e-5, uni_init: bool = True, + trainable: bool = True, precision: str = DEFAULT_PRECISION, ) -> None: self.eps = eps @@ -417,6 +418,8 @@ def __init__( if self.uni_init: self.w = np.ones_like(self.w) self.b = np.zeros_like(self.b) + # only to keep consistent with other backends + self.trainable = trainable def serialize(self) -> dict: """Serialize the layer to a dict. @@ -434,6 +437,7 @@ def serialize(self) -> dict: "@class": "LayerNorm", "@version": 1, "eps": self.eps, + "trainable": self.trainable, "precision": self.precision, "@variables": data, } @@ -477,6 +481,8 @@ def __setitem__(self, key, value): self.w = value elif key in ("b", "bias"): self.b = value + elif key == "trainable": + self.trainable = value elif key == "precision": self.precision = value elif key == "eps": @@ -489,6 +495,8 @@ def __getitem__(self, key): return self.w elif key in ("b", "bias"): return self.b + elif key == "trainable": + return self.trainable elif key == "precision": return self.precision elif key == "eps": @@ -512,21 +520,20 @@ def call(self, x: np.ndarray) -> np.ndarray: np.ndarray The output. """ - if self.w is None or self.b is None: - raise ValueError("w/b must be set") y = self.layer_norm_numpy(x, (self.num_in,), self.w, self.b, self.eps) return y @staticmethod - def layer_norm_numpy(x, shape, weight, bias, eps): + def layer_norm_numpy(x, shape, weight=None, bias=None, eps=1e-5): # mean and variance mean = np.mean(x, axis=tuple(range(-len(shape), 0)), keepdims=True) var = np.var(x, axis=tuple(range(-len(shape), 0)), keepdims=True) # normalize x_normalized = (x - mean) / np.sqrt(var + eps) # shift and scale - x_ln = x_normalized * weight + bias - return x_ln + if weight is not None and bias is not None: + x_normalized = x_normalized * weight + bias + return x_normalized def make_multilayer_network(T_NetworkLayer, ModuleBase): diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py index aec2e75e6e..e95af36674 100644 --- a/deepmd/pt/model/descriptor/dpa1.py +++ b/deepmd/pt/model/descriptor/dpa1.py @@ -123,7 +123,9 @@ class DescrptDPA1(BaseDescriptor, torch.nn.Module): Time-step `dt` in the resnet construction: y = x + dt * \phi (Wx + b) trainable: bool - If the weights of embedding net are trainable. + If the weights of this descriptors are trainable. + trainable_ln: bool + Whether to use trainable shift and scale weights in layer normalization. type_one_side: bool If 'False', type embeddings of both neighbor and central atoms are considered. 
If 'True', only type embeddings of neighbor atoms are considered. @@ -205,6 +207,7 @@ def __init__( temperature=None, concat_output_tebd: bool = True, trainable: bool = True, + trainable_ln: bool = True, smooth_type_embedding: bool = True, type_one_side: bool = False, # not implemented @@ -252,6 +255,7 @@ def __init__( type_one_side=type_one_side, exclude_types=exclude_types, env_protection=env_protection, + trainable_ln=trainable_ln, old_impl=old_impl, ) self.type_embedding = TypeEmbedNet(ntypes, tebd_dim, precision=precision) @@ -385,6 +389,7 @@ def serialize(self) -> dict: "scaling_factor": obj.scaling_factor, "normalize": obj.normalize, "temperature": obj.temperature, + "trainable_ln": obj.trainable_ln, "smooth_type_embedding": obj.smooth, "type_one_side": obj.type_one_side, "concat_output_tebd": self.concat_output_tebd, diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index cfd0a7f95d..d857bc31f7 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -83,6 +83,7 @@ def __init__( type_one_side: bool = False, exclude_types: List[Tuple[int, int]] = [], env_protection: float = 0.0, + trainable_ln: bool = True, type: Optional[str] = None, old_impl: bool = False, ): @@ -119,6 +120,7 @@ def __init__( self.smooth = smooth self.type_one_side = type_one_side self.env_protection = env_protection + self.trainable_ln = trainable_ln self.old_impl = old_impl if isinstance(sel, int): @@ -157,6 +159,7 @@ def __init__( scaling_factor=self.scaling_factor, normalize=self.normalize, temperature=self.temperature, + trainable_ln=self.trainable_ln, smooth=self.smooth, precision=self.precision, ) @@ -468,6 +471,7 @@ def __init__( scaling_factor: float = 1.0, normalize: bool = True, temperature: Optional[float] = None, + trainable_ln: bool = True, smooth: bool = True, precision: str = DEFAULT_PRECISION, ): @@ -482,6 +486,7 @@ def __init__( self.scaling_factor = scaling_factor self.normalize = normalize self.temperature = temperature + self.trainable_ln = trainable_ln self.smooth = smooth self.precision = precision self.network_type = NeighborGatedAttentionLayer @@ -497,7 +502,8 @@ def __init__( scaling_factor=scaling_factor, normalize=normalize, temperature=temperature, - smooth=self.smooth, + trainable_ln=trainable_ln, + smooth=smooth, precision=precision, ) ) @@ -563,6 +569,7 @@ def serialize(self) -> dict: "scaling_factor": self.scaling_factor, "normalize": self.normalize, "temperature": self.temperature, + "trainable_ln": self.trainable_ln, "precision": self.precision, "attention_layers": [layer.serialize() for layer in self.attention_layers], } @@ -598,6 +605,7 @@ def __init__( normalize: bool = True, temperature: Optional[float] = None, smooth: bool = True, + trainable_ln: bool = True, precision: str = DEFAULT_PRECISION, ): """Construct a neighbor-wise attention layer.""" @@ -611,6 +619,7 @@ def __init__( self.normalize = normalize self.temperature = temperature self.precision = precision + self.trainable_ln = trainable_ln self.attention_layer = GatedAttentionLayer( nnei, embed_dim, @@ -623,7 +632,9 @@ def __init__( smooth=smooth, precision=precision, ) - self.attn_layer_norm = LayerNorm(self.embed_dim, precision=precision) + self.attn_layer_norm = LayerNorm( + self.embed_dim, trainable=trainable_ln, precision=precision + ) def forward( self, @@ -655,6 +666,7 @@ def serialize(self) -> dict: "scaling_factor": self.scaling_factor, "normalize": self.normalize, "temperature": self.temperature, + "trainable_ln": 
self.trainable_ln, "precision": self.precision, "attention_layer": self.attention_layer.serialize(), "attn_layer_norm": self.attn_layer_norm.serialize(), diff --git a/deepmd/pt/model/network/layernorm.py b/deepmd/pt/model/network/layernorm.py index efb4836db7..27b9808010 100644 --- a/deepmd/pt/model/network/layernorm.py +++ b/deepmd/pt/model/network/layernorm.py @@ -31,6 +31,7 @@ def __init__( bavg: float = 0.0, stddev: float = 1.0, precision: str = DEFAULT_PRECISION, + trainable: bool = True, ): self.eps = eps self.uni_init = uni_init @@ -50,6 +51,10 @@ def __init__( if self.uni_init: nn.init.ones_(self.matrix.data) nn.init.zeros_(self.bias.data) + self.trainable = trainable + if not self.trainable: + self.matrix.requires_grad = False + self.bias.requires_grad = False def dim_out(self) -> int: return self.matrix.shape[0] @@ -73,10 +78,8 @@ def forward( mean = xx.mean(dim=-1, keepdim=True) variance = xx.var(dim=-1, unbiased=False, keepdim=True) yy = (xx - mean) / torch.sqrt(variance + self.eps) - if self.matrix is not None: - yy = yy * self.matrix - if self.bias is not None: - yy = yy + self.bias + if self.matrix is not None and self.bias is not None: + yy = yy * self.matrix + self.bias return yy def serialize(self) -> dict: @@ -90,6 +93,7 @@ def serialize(self) -> dict: nl = DPLayerNorm( self.matrix.shape[0], eps=self.eps, + trainable=self.trainable, precision=self.precision, ) nl.w = to_numpy_array(self.matrix) @@ -110,6 +114,7 @@ def deserialize(cls, data: dict) -> "LayerNorm": obj = cls( nl["matrix"].shape[0], eps=nl["eps"], + trainable=nl["trainable"], precision=nl["precision"], ) prec = PRECISION_DICT[obj.precision] diff --git a/deepmd/tf/descriptor/se_atten.py b/deepmd/tf/descriptor/se_atten.py index 15db4f5a22..dcf785d6f4 100644 --- a/deepmd/tf/descriptor/se_atten.py +++ b/deepmd/tf/descriptor/se_atten.py @@ -181,6 +181,7 @@ def __init__( scaling_factor=1.0, normalize=True, temperature=None, + trainable_ln: bool = True, concat_output_tebd: bool = True, env_protection: float = 0.0, # not implement!! 
**kwargs, @@ -233,6 +234,7 @@ def __init__( raise ValueError("`model/type_map` is not set or empty!") self.stripped_type_embedding = stripped_type_embedding self.smooth = smooth_type_embedding + self.trainable_ln = trainable_ln self.ntypes = ntypes self.att_n = attn self.attn_layer = attn_layer @@ -251,12 +253,6 @@ def __init__( std_ones = np.ones([self.ntypes, self.ndescrpt]).astype( GLOBAL_NP_FLOAT_PRECISION ) - # self.beta = np.zeros([self.attn_layer, self.filter_neuron[-1]]).astype( - # GLOBAL_NP_FLOAT_PRECISION - # ) - # self.gamma = np.ones([self.attn_layer, self.filter_neuron[-1]]).astype( - # GLOBAL_NP_FLOAT_PRECISION - # ) self.attention_layer_variables = None sub_graph = tf.Graph() with sub_graph.as_default(): @@ -891,38 +887,6 @@ def _lookup_type_embedding( return self.embedding_input_2 return self.embedding_input - def _feedforward(self, input_xyz, d_in, d_mid): - residual = input_xyz - input_xyz = tf.nn.relu( - one_layer( - input_xyz, - d_mid, - name="c_ffn1", - reuse=tf.AUTO_REUSE, - seed=self.seed, - activation_fn=None, - precision=self.filter_precision, - trainable=True, - uniform_seed=self.uniform_seed, - initial_variables=self.attention_layer_variables, - ) - ) - input_xyz = one_layer( - input_xyz, - d_in, - name="c_ffn2", - reuse=tf.AUTO_REUSE, - seed=self.seed, - activation_fn=None, - precision=self.filter_precision, - trainable=True, - uniform_seed=self.uniform_seed, - initial_variables=self.attention_layer_variables, - ) - input_xyz += residual - input_xyz = tf.keras.layers.LayerNormalization()(input_xyz) - return input_xyz - def _scaled_dot_attn( self, Q, @@ -1068,15 +1032,9 @@ def _attention_layers( reuse=tf.AUTO_REUSE, seed=self.seed, uniform_seed=self.uniform_seed, - trainable=trainable, + trainable=self.trainable_ln, initial_variables=self.attention_layer_variables, ) - # input_xyz = tf.keras.layers.LayerNormalization( - # beta_initializer=tf.constant_initializer(self.beta[i]), - # gamma_initializer=tf.constant_initializer(self.gamma[i]), - # dtype=self.filter_precision, - # )(input_xyz) - # input_xyz = self._feedforward(input_xyz, outputs_size[-1], self.att_n) return input_xyz def _filter_lower( @@ -1384,27 +1342,6 @@ def init_variables( self.attention_layer_variables = get_attention_layer_variables_from_graph_def( graph_def, suffix=suffix ) - # if self.attn_layer > 0: - # self.beta[0] = self.attention_layer_variables[ - # f"attention_layer_0{suffix}/layer_normalization/beta" - # ] - # self.gamma[0] = self.attention_layer_variables[ - # f"attention_layer_0{suffix}/layer_normalization/gamma" - # ] - # for i in range(1, self.attn_layer): - # self.beta[i] = self.attention_layer_variables[ - # f"attention_layer_{i}{suffix}/layer_normalization_{i}/beta" - # ] - # self.gamma[i] = self.attention_layer_variables[ - # f"attention_layer_{i}{suffix}/layer_normalization_{i}/gamma" - # ] - # for i in range(self.attn_layer): - # self.beta[i] = self.attention_layer_variables[ - # f"attention_layer_{i}{suffix}/layer_normalization/beta" - # ] - # self.gamma[i] = self.attention_layer_variables[ - # f"attention_layer_{i}{suffix}/layer_normalization/gamma" - # ] if self.stripped_type_embedding: self.two_side_embeeding_net_variables = ( @@ -1527,6 +1464,7 @@ def serialize_attention_layers( hidden_dim: int, dotr: bool, do_mask: bool, + trainable_ln: bool, variables: dict, bias: bool = True, suffix: str = "", @@ -1538,6 +1476,7 @@ def serialize_attention_layers( "hidden_dim": hidden_dim, "dotr": dotr, "do_mask": do_mask, + "trainable_ln": trainable_ln, "precision": 
self.precision.name, "attention_layers": [], } @@ -1592,6 +1531,7 @@ def serialize_attention_layers( layer_norm = LayerNorm( embed_dim, + trainable=self.trainable_ln, precision=self.precision.name, ) layer_norm["matrix"] = attention_layer_params[layer_idx][ @@ -1609,6 +1549,7 @@ def serialize_attention_layers( "smooth": self.smooth, }, "attn_layer_norm": layer_norm.serialize(), + "trainable_ln": self.trainable_ln, } ) return data @@ -1778,6 +1719,7 @@ def serialize(self, suffix: str = "") -> dict: "activation_function": self.activation_function_name, "resnet_dt": self.filter_resnet_dt, "smooth_type_embedding": self.smooth, + "trainable_ln": self.trainable_ln, "precision": self.filter_precision.name, "embeddings": self.serialize_network( ntypes=self.ntypes, @@ -1799,6 +1741,7 @@ def serialize(self, suffix: str = "") -> dict: hidden_dim=self.att_n, dotr=self.attn_dotr, do_mask=self.attn_mask, + trainable_ln=self.trainable_ln, variables=self.attention_layer_variables, suffix=suffix, ), @@ -1844,7 +1787,9 @@ class DescrptDPA1Compat(DescrptSeAtten): Time-step `dt` in the resnet construction: y = x + dt * \phi (Wx + b) trainable: bool - If the weights of embedding net are trainable. + If the weights of this descriptors are trainable. + trainable_ln: bool + Whether to use trainable shift and scale weights in layer normalization. type_one_side: bool If 'False', type embeddings of both neighbor and central atoms are considered. If 'True', only type embeddings of neighbor atoms are considered. @@ -1917,6 +1862,7 @@ def __init__( scaling_factor=1.0, normalize: bool = True, temperature: Optional[float] = None, + trainable_ln: bool = True, smooth_type_embedding: bool = True, concat_output_tebd: bool = True, spin: Optional[Any] = None, @@ -1963,6 +1909,7 @@ def __init__( attn_mask=attn_mask, multi_task=True, stripped_type_embedding=False, + trainable_ln=trainable_ln, smooth_type_embedding=smooth_type_embedding, env_protection=env_protection, ) diff --git a/deepmd/tf/utils/network.py b/deepmd/tf/utils/network.py index 916f783050..7918b58d0c 100644 --- a/deepmd/tf/utils/network.py +++ b/deepmd/tf/utils/network.py @@ -200,177 +200,6 @@ def layernorm( return output -# class LayerNormCompat: -# """Implementation of Layer Normalization layer for testing with other backend references. -# -# Parameters -# ---------- -# num_in : int -# The input dimension of the layer. -# eps : float, optional -# A small value added to prevent division by zero in calculations. -# uni_init : bool, optional -# If initialize the weights to be zeros and ones. -# precision : str, optional -# The precision of the layer parameters. Supported options are |PRECISION| -# """ -# -# def __init__( -# self, -# num_in: int, -# eps: float = 1e-5, -# uni_init: bool = True, -# precision: str = "default", -# ) -> None: -# self.eps = eps -# self.uni_init = uni_init -# self.num_in = num_in -# self.filter_precision = get_precision(precision) -# self.layer_norm_variables = None -# -# def build( -# self, -# inputs, -# input_shape: List[int], -# reuse=None, -# suffix="", -# ): -# """Build the computational graph for the layer normalization. -# -# Parameters -# ---------- -# input_shape -# The shape of the input tensor. -# reuse -# The weights in the networks should be reused when get the variable. 
-# suffix -# Name suffix to identify this layer -# -# Returns -# ------- -# normalized_output -# The computational graph for the normalized output -# """ -# assert input_shape[-1] == self.num_in -# name = "layer_norm" + suffix -# with tf.variable_scope(name, reuse=reuse): -# gamma = tf.get_variable( -# "gamma", -# shape=[self.num_in], -# initializer=tf.ones_initializer(), -# dtype=self.filter_precision, -# trainable=True, -# ) -# beta = tf.get_variable( -# "beta", -# shape=[self.num_in], -# initializer=tf.zeros_initializer(), -# dtype=self.filter_precision, -# trainable=True, -# ) -# normalized_output = tf.contrib.layers.layer_norm( -# inputs=input, -# begin_norm_axis=-1, -# begin_params_axis=-1, -# epsilon=self.eps, -# activation_fn=None, -# param_initializers={ -# "gamma": tf.ones_initializer(), -# "beta": tf.zeros_initializer(), -# }, -# trainable=True, -# reuse=reuse, -# variables_collections=None, -# outputs_collections=None, -# data_format="NHWC", -# name=name, -# ) -# return normalized_output -# -# def init_variables( -# self, -# graph: tf.Graph, -# graph_def: tf.GraphDef, -# suffix="", -# model_type="original_model", -# ) -> None: -# """Init the layer norm variables with the given dict. -# -# Parameters -# ---------- -# graph : tf.Graph -# The input frozen model graph -# graph_def : tf.GraphDef -# The input frozen model graph_def -# suffix -# Name suffix to identify this layer -# model_type -# Indicator of whether this model is a compressed model -# """ -# self.layer_norm_variables = get_layer_norm_variables_from_graph_def( -# graph_def, suffix=suffix -# ) -# -# @classmethod -# def deserialize(cls, data: dict, suffix: str = ""): -# """Deserialize the layer from a dict. -# -# Parameters -# ---------- -# data : dict -# The dict to deserialize from. -# suffix : str, optional -# The suffix of the scope -# -# Returns -# ------- -# LayerNorm -# The deserialized layer -# """ -# data = data.copy() -# check_version_compatibility(data.pop("@version", 1), 1, 1) -# data_cls = data.pop("@class") -# assert data_cls == "LayerNorm", f"Invalid class {data_cls}" -# variables = data.pop("@variables") -# obj = cls( -# num_in=variables["w"].shape[0], -# eps=data.pop("eps"), -# precision=data.pop("precision"), -# ) -# obj.layer_norm_variables = { -# f"layer_norm{suffix}/gamma": variables["w"], -# f"layer_norm{suffix}/beta": variables["b"], -# } -# return obj -# -# def serialize(self, suffix: str = "") -> dict: -# """Serialize the layer to a dict. -# -# Parameters -# ---------- -# suffix : str, optional -# The suffix of the scope -# -# Returns -# ------- -# dict -# The serialized layer. -# """ -# assert self.layer_norm_variables is not None -# gamma = self.layer_norm_variables[f"layer_norm{suffix}/gamma"] -# beta = self.layer_norm_variables[f"layer_norm{suffix}/beta"] -# return { -# "@class": "LayerNorm", -# "@version": 1, -# "eps": self.eps, -# "precision": self.filter_precision.name, -# "@variables": { -# "w": gamma, -# "b": beta, -# }, -# } - - def embedding_net_rand_seed_shift(network_size): shift = 3 * (len(network_size) + 1) return shift diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index c51900e9a9..8dd2be2b6b 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -478,6 +478,9 @@ def descrpt_se_atten_args(): doc_stripped_type_embedding = "Whether to strip the type embedding into a separated embedding network. Setting it to `False` will fall back to the previous version of `se_atten` which is non-compressible." 
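For orientation, the arguments documented in this hunk end up as keys of the descriptor section of the input script. A hedged example of a descriptor configuration exercising the new `trainable_ln` flag follows; only the two keys touched by this patch are taken from the hunk, the remaining keys and all values are illustrative.

    descriptor = {
        "type": "se_atten",
        "sel": 120,
        "rcut": 6.0,
        "rcut_smth": 0.5,
        "neuron": [25, 50, 100],
        "axis_neuron": 16,
        "attn": 128,
        "attn_layer": 2,
        "set_davg_zero": True,
        # "smooth_type_embdding" is kept as a legacy alias of this key
        "smooth_type_embedding": True,
        # new in this patch: toggle trainable shift/scale in the attention layer norms
        "trainable_ln": True,
    }
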
doc_smooth_type_embedding = f"Whether to use smooth process in attention weights calculation. {doc_only_tf_supported} When using stripped type embedding, whether to dot smooth factor on the network output of type embedding to keep the network smooth, instead of setting `set_davg_zero` to be True." doc_set_davg_zero = "Set the normalization average to zero. This option should be set when `se_atten` descriptor or `atom_ener` in the energy fitting is used" + doc_trainable_ln = ( + "Whether to use trainable shift and scale weights in layer normalization." + ) doc_tebd_dim = "The dimension of atom type embedding." doc_temperature = "The scaling factor of normalization in calculations of attention weights, which is used to scale the matmul(Q, K)." doc_scaling_factor = ( @@ -507,11 +510,15 @@ def descrpt_se_atten_args(): bool, optional=True, default=False, + alias=["smooth_type_embdding"], doc=doc_smooth_type_embedding, ), Argument( "set_davg_zero", bool, optional=True, default=True, doc=doc_set_davg_zero ), + Argument( + "trainable_ln", bool, optional=True, default=True, doc=doc_trainable_ln + ), # pt only Argument( "tebd_dim", From 756779af8b6285ff1235ba3d62e6882cae7e936a Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Mon, 22 Apr 2024 23:08:04 +0800 Subject: [PATCH 07/18] Updated docstr --- deepmd/dpmodel/descriptor/dpa1.py | 6 ++++-- deepmd/pt/model/descriptor/dpa1.py | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index ecef21c9d4..d5eb863b6f 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -155,7 +155,8 @@ class DescrptDPA1(NativeOP, BaseDescriptor): attn_dotr: bool If dot the angular gate to the attention weights attn_mask: bool - (Deprecated, only support False to keep consistent with old implementation.) + (Only support False to keep consistent with other backend references.) + (Not used in this version.) If mask the diagonal of attention weights exclude_types : List[List[int]] The excluded pairs of types which have no interaction with each other. @@ -180,7 +181,8 @@ class DescrptDPA1(NativeOP, BaseDescriptor): concat_output_tebd: bool Whether to concat type embedding at the output of the descriptor. spin - (Deprecated, only support None to keep consistent with old implementation.) + (Only support None to keep consistent with other backend references.) + (Not used in this version.) The old implementation of deepspin. Limitations diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py index e95af36674..fabf51ba84 100644 --- a/deepmd/pt/model/descriptor/dpa1.py +++ b/deepmd/pt/model/descriptor/dpa1.py @@ -137,7 +137,8 @@ class DescrptDPA1(BaseDescriptor, torch.nn.Module): attn_dotr: bool If dot the angular gate to the attention weights attn_mask: bool - (Deprecated, only support False to keep consistent with old implementation.) + (Only support False to keep consistent with other backend references.) + (Not used in this version.) If mask the diagonal of attention weights exclude_types : List[List[int]] The excluded pairs of types which have no interaction with each other. @@ -162,7 +163,8 @@ class DescrptDPA1(BaseDescriptor, torch.nn.Module): concat_output_tebd: bool Whether to concat type embedding at the output of the descriptor. spin - (Deprecated, only support None to keep consistent with old implementation.) + (Only support None to keep consistent with other backend references.) 
+ (Not used in this version.) The old implementation of deepspin. Limitations From cd6b47c1297bc0e4315d963d92b332b52dbd009e Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Mon, 22 Apr 2024 23:12:28 +0800 Subject: [PATCH 08/18] remove DPDescrptDPA1 test from pt/model/test_dpa1 --- source/tests/pt/model/test_dpa1.py | 62 ++++++------------------------ 1 file changed, 12 insertions(+), 50 deletions(-) diff --git a/source/tests/pt/model/test_dpa1.py b/source/tests/pt/model/test_dpa1.py index 3cec5d1876..7a08ecc826 100644 --- a/source/tests/pt/model/test_dpa1.py +++ b/source/tests/pt/model/test_dpa1.py @@ -5,15 +5,6 @@ import numpy as np import torch -try: - from deepmd.dpmodel.descriptor.dpa1 import DescrptDPA1 as DPDescrptDPA1 - - support_se_atten = True -except ModuleNotFoundError: - support_se_atten = False -except ImportError: - support_se_atten = False - from deepmd.pt.model.descriptor.dpa1 import ( DescrptDPA1, ) @@ -34,7 +25,6 @@ dtype = env.GLOBAL_PT_FLOAT_PRECISION -@unittest.skipIf(not support_se_atten, "EnvMat not supported") class TestDescrptSeAtten(unittest.TestCase, TestCaseSingleFrameWithNlist): def setUp(self): TestCaseSingleFrameWithNlist.setUp(self) @@ -49,10 +39,10 @@ def test_consistency( dstd = 0.1 + np.abs(dstd) for idt, prec, sm, to in itertools.product( - [False, True], - ["float64", "float32"], - [False, True], - [False, True], + [False, True], # resnet_dt + ["float64", "float32"], # precision + [False, True], # smooth_type_embedding + [False, True], # type_one_side ): dtype = PRECISION_DICT[prec] rtol, atol = get_tols(prec) @@ -92,37 +82,9 @@ def test_consistency( atol=atol, err_msg=err_msg, ) - # dp impl - dd2 = DPDescrptDPA1.deserialize(dd0.serialize()) - rd2, _, _, _, _ = dd2.call( - self.coord_ext, - self.atype_ext, - self.nlist, - ) - np.testing.assert_allclose( - rd0.detach().cpu().numpy(), - rd2, - rtol=rtol, - atol=atol, - err_msg=err_msg, - ) - # dp impl serialization - dd3 = DPDescrptDPA1.deserialize(dd2.serialize()) - rd3, _, _, _, _ = dd3.call( - self.coord_ext, - self.atype_ext, - self.nlist, - ) - np.testing.assert_allclose( - rd0.detach().cpu().numpy(), - rd3, - rtol=rtol, - atol=atol, - err_msg=err_msg, - ) # old impl if idt is False and prec == "float64" and to is False: - dd4 = DescrptDPA1( + dd2 = DescrptDPA1( self.rcut, self.rcut_smth, self.sel_mix, @@ -134,10 +96,10 @@ def test_consistency( old_impl=True, ).to(env.DEVICE) dd0_state_dict = dd0.se_atten.state_dict() - dd4_state_dict = dd4.se_atten.state_dict() + dd4_state_dict = dd2.se_atten.state_dict() dd0_state_dict_attn = dd0.se_atten.dpa1_attention.state_dict() - dd4_state_dict_attn = dd4.se_atten.dpa1_attention.state_dict() + dd4_state_dict_attn = dd2.se_atten.dpa1_attention.state_dict() for i in dd4_state_dict: dd4_state_dict[i] = ( dd0_state_dict[ @@ -152,26 +114,26 @@ def test_consistency( ) if ".bias" in i and "attn_layer_norm" not in i: dd4_state_dict[i] = dd4_state_dict[i].unsqueeze(0) - dd4.se_atten.load_state_dict(dd4_state_dict) + dd2.se_atten.load_state_dict(dd4_state_dict) dd0_state_dict_tebd = dd0.type_embedding.state_dict() - dd4_state_dict_tebd = dd4.type_embedding.state_dict() + dd4_state_dict_tebd = dd2.type_embedding.state_dict() for i in dd4_state_dict_tebd: dd4_state_dict_tebd[i] = ( dd0_state_dict_tebd[i.replace("embedding.weight", "matrix")] .detach() .clone() ) - dd4.type_embedding.load_state_dict(dd4_state_dict_tebd) + dd2.type_embedding.load_state_dict(dd4_state_dict_tebd) - rd4, _, _, _, _ = dd4( + rd2, _, _, _, _ = dd2( 
torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE), torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE), torch.tensor(self.nlist, dtype=int, device=env.DEVICE), ) np.testing.assert_allclose( rd0.detach().cpu().numpy(), - rd4.detach().cpu().numpy(), + rd2.detach().cpu().numpy(), rtol=rtol, atol=atol, err_msg=err_msg, From 4ac5019a26cdde2e0c6e16243c68a7d5296d715e Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Mon, 22 Apr 2024 23:15:18 +0800 Subject: [PATCH 09/18] rm default value for version check --- deepmd/dpmodel/descriptor/dpa1.py | 10 +++++----- deepmd/pt/model/descriptor/dpa1.py | 6 +++--- deepmd/pt/model/descriptor/se_atten.py | 4 ++-- deepmd/tf/descriptor/se_atten.py | 12 ++++++------ 4 files changed, 16 insertions(+), 16 deletions(-) diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index d5eb863b6f..cbbc124267 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -571,9 +571,9 @@ def serialize(self) -> dict: def deserialize(cls, data: dict) -> "DescrptDPA1": """Deserialize from dict.""" data = data.copy() - check_version_compatibility(data.pop("@version", 1), 1, 1) - data.pop("@class", None) - data.pop("type", None) + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + data.pop("type") variables = data.pop("@variables") embeddings = data.pop("embeddings") type_embedding = data.pop("type_embedding") @@ -716,8 +716,8 @@ def deserialize(cls, data: dict) -> "NeighborGatedAttention": The dict to deserialize from. """ data = data.copy() - check_version_compatibility(data.pop("@version", 1), 1, 1) - data.pop("@class", None) + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") attention_layers = data.pop("attention_layers") obj = cls(**data) obj.attention_layers = [ diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py index fabf51ba84..e4358de0dc 100644 --- a/deepmd/pt/model/descriptor/dpa1.py +++ b/deepmd/pt/model/descriptor/dpa1.py @@ -415,9 +415,9 @@ def serialize(self) -> dict: @classmethod def deserialize(cls, data: dict) -> "DescrptDPA1": data = data.copy() - check_version_compatibility(data.pop("@version", 1), 1, 1) - data.pop("@class", None) - data.pop("type", None) + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + data.pop("type") variables = data.pop("@variables") embeddings = data.pop("embeddings") type_embedding = data.pop("type_embedding") diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index d857bc31f7..15cfe44962 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -584,8 +584,8 @@ def deserialize(cls, data: dict) -> "NeighborGatedAttention": The dict to deserialize from. 
""" data = data.copy() - check_version_compatibility(data.pop("@version", 1), 1, 1) - data.pop("@class", None) + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") attention_layers = data.pop("attention_layers") obj = cls(**data) for ii, network in enumerate(attention_layers): diff --git a/deepmd/tf/descriptor/se_atten.py b/deepmd/tf/descriptor/se_atten.py index dcf785d6f4..e1eae22e72 100644 --- a/deepmd/tf/descriptor/se_atten.py +++ b/deepmd/tf/descriptor/se_atten.py @@ -1647,9 +1647,9 @@ def deserialize(cls, data: dict, suffix: str = ""): if cls is not DescrptSeAtten: raise NotImplementedError("Not implemented in class %s" % cls.__name__) data = data.copy() - check_version_compatibility(data.pop("@version", 1), 1, 1) - data.pop("@class", None) - data.pop("type", None) + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + data.pop("type") embedding_net_variables = cls.deserialize_network( data.pop("embeddings"), suffix=suffix ) @@ -2041,9 +2041,9 @@ def deserialize(cls, data: dict, suffix: str = ""): if cls is not DescrptDPA1Compat: raise NotImplementedError("Not implemented in class %s" % cls.__name__) data = data.copy() - check_version_compatibility(data.pop("@version", 1), 1, 1) - data.pop("@class", None) - data.pop("type", None) + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + data.pop("type") embedding_net_variables = cls.deserialize_network( data.pop("embeddings"), suffix=suffix ) From 42a32d2fc0d8ee4f2ec4cdc2915fb800937bc336 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Tue, 23 Apr 2024 00:34:35 +0800 Subject: [PATCH 10/18] Update ut model --- source/tests/pt/model/models/dpa1.pth | Bin 15469 -> 14399 bytes source/tests/pt/model/models/dpa2.pth | Bin 179745 -> 158910 bytes source/tests/pt/model/models/dpa2_hyb.json | 4 ---- 3 files changed, 4 deletions(-) diff --git a/source/tests/pt/model/models/dpa1.pth b/source/tests/pt/model/models/dpa1.pth index 75acf2fa15d874dc7c63a0fe1bbc3653267b2c4b..47766b10c2e942f3072857e8bb9334f506e5fdef 100644 GIT binary patch delta 3304 zcma)<{ZCtE6vyvv3oYGC%ZzLxi~;4PV_V-CEuh;Km1~UMe!&j96;=wTJR2CD zWhyLFafQdXMCTBv&iOW1-*t-r0sdxUO#Dq_qAu|l6MykMy|?XiZ?A?DAnEsezUT9t z^N?PCpZ#_!S0AbsYK3%SW_WZe?)Q$yN!&AaDm4(A(R zEwv`MJ7&dqrpaU`elh{(rIos@B_`rqKB$VB@%PBoX{eTN5wH$Pf@G2!CChoRNwbng z63;`;g^`iTR4Oq-#wOF#_9L0mL?$tM_t*%rgMGIQn+R-{VN0yK&}4KpaR#;~v1`(t zRNDB#ku1bJi}AEUbclAkSCkoE$rym`$!dC0+=1FvF&f~tF4%=34gyUq#4!)ek~j}7 z7m^IYMW9uNwpc&)Z+6is`!@QssfMOZHJyCKe(1mmHvu<`sFG|8uv@Z8;sSJH483d` zYUl0U;6=NQfREX2^WbM41I$`)?)Of_NoH&of|{cSA*9sws5QN;rjOM;Z64Ibhv1N6 z4Xf6EW{oiG-)1?;lk~tnXq_c6Sg=Zt1-MsoFz~}@$MZ`B?vvsE*g-u10eaJ7)pY8G zBUp~3i?VVY-4N>}vrd0l9@tatWWxRVeee)^aYixbEzIbsGNWV6pQ;+_(D?l@teoOu z^%QZIeVloJt?KuwLpo6!QlUuzPAWx7wP=hLjkBVo)q`10gD#jkoMZ4$DDY1bcuIzOL->7qdVi6A zLE9k+=kZTw-MXVd{0xC-Wmw!0vENf5eomVgf{RFOA&_IlmU&oW#LEgXCGflqFKmoB zP$Yg)+o1A$p+g-83 z;dJ)DugF`;v#}0?5`&7u5Z^8$CDnajRrmc4W4WfVyi4Fc8Qw3kWXrd8xbnJ|7XCqU zabqm%LO(3AsGGd5Zt^3>azkOs6ZkkQ!zU$}|79ty^3&v{jj^b!`mDsF?&x!MM_(Yz zF8ES``HH~TGTc1viAq+R>4(tOv|SL=c6x4W!*+*oyY_?6>l=vF3iQwBq5S&Ru$6}E zI?;5O&g37691`eKotLiEIrD3E*Q`67LPT%EsDnk9Bfr)-E*iNotL&n4P2)zUfW{PV zHZ*ndrsWnx&wk#MUu)@OOaVQ4_dvwM$UQXN+Gk`4mJDmHhJ8YOA3E1|7*ksGZF&YG zjjT=MjrLB%+4O2tI>rrlfurGe5qAG(HZb96$|gFuceTMGY|`Q2XspTuPyXP(lxUEG zV=1RBdyQytFl7tHMH$54Y|0+olYs>%Q?}&8e#1lIbjl_=*S)&aDcH-0I)fin_((Mn zsa8%!4+UNA<@+z-GoSuddjyw1ywJ^JumBUc@-keiw~`AUVn!ppx@$kAr5y0cf``xY-x zRLD^mSA892?aZC6ocvd>e&$hpE-9^~yH76<;-^wcp6PQnYs1$&%W~nF-`h=BLJc~E 
z+@7+F{!`H-x|fp=+@VKs?&IWag&g&B@^>g}FU??plg)Sf`lKr5QZ1vvUQwx<{p%t8 X6kU@FK04$1@+CC8n;pzdr|CJr{h$9i-?!gx z&y8K)I$hPbn6WS;$rzvL-@*?M#f5Z7EGk6pWA`M20W-59_nFiU4ki-`eoz=rj;2f3 zr(%4HkF6aZ6iR_<4FhY7=oy9?;iIFV+t=HaDY4liCXA$oWGXtugMJXd48}{>CdUU7 zd>@ULf+4~p9Cn&#}9qbWfa=DigpGhJ_Jh zRy`KuC*YbmmW%7-RLu$Han>YR*cz6RTG{Zmg-I7&A2&&R*(I1e#)jGob8e`Qo49fT z8c2DTE1!Z!uAF11pb1M$)w({Jln0ifR7OXn4(fVkQeNmlDYF0$DrKGmCub&7E+SQ4 zvbuwg8#-m)JE4nS*G;Z7?*bpk?gBr#CR4KR#zMDUa0^9ya_DM;1_*k&MBa2a+=@;^ zdg63gFK^lh>uJ-WD^2&(ro%+)4}D*IzN!u+k?JEKrThTlFB(=?=O(e~$|SOdTrh;DoklB~4jR`jlDTk05|f9G*4xOINDCl@ zVLZ};1~&~S1U}WDh;HLkX}iN7!fZl~lGI$<}R z=!2w!j?+)U9&(3F!9%pd!vgFL!y{L1z-iAm@aVjJhOBTGJcjyu0j6mOsK1Zu@2C2b z0FQ^^iJ~18HsH!P@Z@|2QlJ|SV1Wh!4w3?TvW1?aTO#u`t?-Nh&xYa9RXcFe4xW?i z``~#xpu@BQ)PI5MAEELu3UD+G$Hql?DY5_^^^fpElpK3&;N|QZy<&@Cn1NSCIG#Pm zCU7(1$p^1-Y&vocreb4JXMa=>_)&5q!s{Zu5zi*;rBiH^1>O|l1PwG|pv3|wML0zR zP7HV~a9V^J8t`FYr3Ky+;cXfSU|^Lsiihn)RN&zq5zf$r9}`}w!CsNBEW#%=7{~=H=Zk(S!e=z;&m|X9tDlST1x>6eNEFz;Ai|dva3bKzz*i!CO#=&Y z*0Ng~z7gSDnk!f4><;)&gzss#TAeLC3}8-#A1GO?ATz3S1AY|Yr{b`Zb;8dgT>O8m z3w{ydR~73ic6jv*@S6y~XW{YmfIY;S%*zfRNVst~lAr0SW-C)x&8#MbN?Aynt=X&zx$@9t ztv2M&Lp^ockSFh{-cLL&FKXNy478&0HfgfKuL*VLp#6>7)w=SQW*cuuq1vL=x`Tlm zuv*njPg9%roP1ZLnm22$3XUneDl?NUf0bxKIHoKl&DxICSQ%q6PdFtVGqb%l!D^}C zh^R_WyQUqS5M^-Hv~+MnIh{I91_zX7W+s(Rp zFT5uf-?i)DwZl4w^!-sq>2F*Q?X2` za{?Lpdd;br%Ug4J>nP(-=RVXN4$mDPX6 zdI18X8N>_V9aKQE34)hZ`+2PxYpYnT0%FaK#l@EMAzrY{l|b zB{RjM#bWWIGcQ}1EYV6vB(&0`)-|DZn^OTMz0fmY3}RT)X$hYqp!HZHsr?q}vMhDa z8MmpHCHUSda)p|Y(Y|7wp#16#L8c% z^#%Wsq*g)vWsTmoT0iu_vsUZR+yj!@z=SqvP9)3SsRWeZkml?roE8w_S|Qmc}Y zU$0d+dal=oFzYEv?I#KC)E|vFkd1h#!W+_tfww5B4d)T*yjB}Q1HD!o$=s(UwbK*Y zs2`0uP$MI*QJCws(O~YJ)M|;j^YvPQBCccBU{VVuw0bk*8s|9FWg~7-cpJ3v5}&@T zQ$;sQ*4U&rE}@M#!)tUSlHpBIXu{e=?s$yCE4p5rMBz>5@XknTXC}0>^yilLrU@{3 zo|7~;1Ewq)q{pJtZ;Gp*Jo(h@$(LmEI4zUMv+1LK4+3Ws)TplFPCrOPS>Iq;^F@TQ;XklEg15ie;Ad@*I~J`$(WxTLH2z zNo^&Oc^bQ{)wISgv`PdMtx9Uig!c0}m06-Xne{7Gnt*l{)2z(R869*vPSOVj4ZE-I~yD+f+Cz z+^ROM*9`9V=3KXL=cZe@=^aV!&V+WCet2f3z8syorLZ`Bx609=Z3V~hq_&MXisyoYJ&&WwsqPY z+_po;aJ}{>t?@r_ySI|s9~0W2bbo)Gzi8|g%XoiyTcrqU?=Z!o48^-d@n@!ZFR8ts z(Eg$yS=3vvT2iX7xV%XJqNH2+1C=ABeFzRrki*0wC&&@*!cpe=Yf}4LLiHVpP=oMr1p1i+ik7(kH&74(oebVXG!hzg!YAA6{_5{tX%IXEewCDQZ#7CKoLu7 z$C(1F>?>YnUo*uwN$sBr?O*!*CP@+N+GW(YDn(fP4is3ZCy2sBEBxO~@jWQ?@!e`; zx!vIhwP{2<$-9FnfRdu-v_S|DDw08Xfr#XhZqZr1O$Z-MDoa!pl8s^zML^jwL@_GN zQvxLWBT9jiqASXT=%%kMuh!QrE7F&BD{2g zn7J2FQk0`ih~6GutEnRDbNSLD`oOM|2Si`8(fkz^KpZy6`=P?*{ehT#08mm4M41qS zJo@&jRr>eKOU}w(-$W&jsPuJWFxh;VTcQex=xKV1YE+nh2oR0#V?E0CGlmTUBkey$ zO+%?+Rv*%ZBp4LK$R^2#1Lep@pu%J$fjHLFfRf^LlnF7)qaPbxrJr2UOJClja7GP` zDsxDTCL8Tcq85mU*tb>$P?7tSr~_i=AW%|-P$opZ{#=g)Vo6K1RND}I? zvhtqLU1BUONihy(LX7w5o&b`xqEH{M_0&JCDjGL|nkqeS5EIEJ>0&?}ABJ%fD%|tQ zKs1aUW##&!5rcAsv3|a$Db!dY42!8`lWCs@M1-_+iL+5*!s$TV;YOgOn1M1Onml^- z>}ox=!l$n*E1lL1qskl+Gsz~I<3KXh>%}>M*q1~C@60+2C@IcGnGomchsy$IpAW2( zMa68gNitG58L|riiR`C%XR-@{lHwwi3Gp+Jeq>~o{zK31`r!G+O><#XSVOg99@%^} zDdz(btDKYzP~pyB48+U}fs$eo%7j?#(c>3Y>nD5m9d`)~Dr-PoN;cXG#S$Q9?Y352 zhKk$}#Zn-!>hsIXb$!4fgjlMt@8xe=P7SU33Dt=eWTQ<{tOQ~*>`EFcyqgFh4sjJw zQY29(#LtnRv#9gLT&%~sb{-(E#1WM~D6S%#PmlJl_LS%s_Kr@;?vCObhj&Pwu_h@Al5^k!z`u(yMxL1KH#r`5PeOp{-inhzj%E1jN0%87L`k zL75Qiut$!l&FwSmVN{tL#0Ij-1=0dU%w6VW6aV z1Z6@z%9~=3ib8!u-(E9z!Kg6T*NWX_lbhlmAUURu*c2Z_g*U~=ftYzOP*OaBG9iA8 zO|e?~-Q%gMRG9sFAZC97C@Fr2G9mugqc0dyg_tFKf@Gd}5l2+|pm>RFa$dX) zMD%jkd<7Mze-()7Ujs^t-=j>3*FAdKETS*c5A-XUb`VCDIV9d7n`G_)l7U|@-UOtL z?hklp*0+F?;*Tg3;!oH$X`GAB$w))-HtZ^Ky?BRglK2o%hWK4TBK|YpnfN`Rq<9}? 
zLi|PUpqsAt>#4DYO&`Ffk~WAB$>yV(d>DvGtd=12e-vA}WKT#&czj!0vGoYL|(Qjc@siWdMvibB&2GJ(^ zn?aUM^as0lL#;R|z2zoa06jyIu+T(>Bk=+u65UrBJvS>fVNy=p5D*Foi5_H=3&;;dsp;<&N$`=P?r{ehTz01%WWN+?aQRGRT$mgrAc z7SE`JQDqK^!DRDEkqJc1EHY7H<{?1Ld**+*i_BpGWu)Rxg9Qo`B^0Lq9GDJQ7wVP2>PjNBsfJowuPO}%UWCHQn zlE_3QBQk+F%ped%CQ67*uYNpM1@=@oCDm(~d zfS7$O5CkYn2vDySpv>NH#snBu=7^X`Hpv_V$}vwug_$P`^h5Y;sym10r@-qNp(YbRcGL1cDMp2_@>4d+Oopo@Y11psE8_34;6o`0O zo}$7$Hvw^nZU%xpMG1MzJL|SlW5?|RRtf9GZnCi=1>#{OMT$yBkpfXy<0H!TZ(@V=+F#Q{#GU+^F zD8s83Uv??(Ne@zUWds_;8)TD|9Y8Whh)zH$I`Pi&y#)l(i4vldH%V~oEhCGjy$z$v z8W!)6O|l*W%CNo*C>1B(ne{y&C{C16oY*5H<30dZ$s*!IviYRa1R^q4ny7HEj{;Gz z7o1+M&$r*r3@U4b z7)&;w6pcW{%Aye!W*q{=tfv4$G@^uP>?H5T*NyHyZYT^YYgh~;n`9jhlw%!%3bT#` z0;_(!wp{;ubhxR88YojCB1V&qg(47-APGfOGC~oELkj{yD58W=>?DQaoJbk~iFY0)v~Uz+j^dh^b_hDW3*Jv@8lyVcO|HOxp+qQHTT01-bc zMN~M=n}C>K2ZB;W38jeLvswrJd7008T4A=42ZTXJIW;!}aquJ+QDO2eKumrI5T0wI z#B)vFNvDQ;HQx<`jkZo~C8JLYM}zbE!5Yvl;UIPqUK$&+t@>5HyK$=0`WkT zmP93^C4rcIFA$z?qQuip-elKr?4xfxr?9CFMjLg#XeXo0;XWXyCS{2V2fiPOsSf}_ zS)zop+!Ua=tDB!e)kfMNo+YCsO#|gfpF@R7p9f;n7l0rrQ9@AiPCGZ+e-=O1gw@6! z7B7)e&Wo3Uh?}(~D$M;V5OcobKqS%U&GKRxP1 zg}0`!ftddrAV^Y_kfgldGJp30-@<644vO!{=+o;b&~`g}f@Qn?!Ra0nC#5?})Xp3V zOH@=i6fY1$(chksyCe6(WIJ)aC?q4xR3PdT%T!dDs04`nR0;%{ijqE6Khj*SUvpc@ zdD&;3Fx$u*L=Q5`CFKVqa#pISaISj-F?lZ_C{>hDs-5M-POxVlcEW1o4vW5IlvA?; zhyy3liVAc02V(94KoG4cAzJlqQ}Ov~Yl*&oVsTR?Y&PbI7)(Z=6s|y=W-wQy!puW} znE4bS2v?N!w6l5|G0Uf&%|l_cF-OHPGD_y*Ksn|Ss4(+LAZ9)d2v0vz;^}8+`T4qk z56Uiv~Vn7@rhIJAu-1*5s9N`&2 zP_ihYWIIbGo1y1Db_$Nz_=9398RdkS21NWUX;ESR=|Ifi2n0!s5|XyFl(hayMf$Uo zOQOv%+NeWfCK)Al97u)^F$*ZgEZ#w@d(JG^|5i3g2a~B^si8U$^lGWw)`1tMD3uc&bUF9xDO{b!Zy zz0a&~zJwa6!G*=8WMrud#4IFLQOQVEAP(?yAV^h|kgA=fRBhQJpJp~KhtWnI5i7{Z zautZFNv@*8Ta5r>>Qz9Ht0?JlX4P3^#$O3+V~mQc$S60j)j-53H?ONvVa8toQI~I> z64g7-DxPo+HBh>YZ!4@Jqx}A&NC6U!+)S@Ug=ww>BAQ-Pqk4Qw;rO-GLZu0a8_3AQ z6(|?Xji@loO+dubGPPX)sCkgyaYwO!a7xkm_0&*hs}mc@C|64h5V5gLMTN7u35eLr zrj_f-8p$>*^9+-ks(e9VkWo&$%|IL-R^RQYFy9s+@Zp1(JAj}}Q9_wwy9_nm4Qyi! 
ziLGSxNl^+!j4Vn~Va9uanDJg9h*FgFG_!hIL-PZ`HpY7KAQ>g&LqIvk9jGwlP9SD{ z7zlb4CG;q7k1co6cOr|LcEM<)ZVw-302CwW`;haOmW1<9~Cc=QO=8(fry`#DJsnWDiHI(1_Wh_ z63Uc!&6SO%`pM~q(+|R^Qse6lZ;(+^cL2%Ip-BOyCdE6`z6AtLiV~U>`y4(ZTG{BI z{x%FY%7A!>jFR#YP=@kdK&eOZ&Xn%~L64$@9_6ib{&u-_Hhlo2jkZpFNJgL3q(DT= zniLftgrh)A{Z}ApQk2l7dfAzn9bo>&jBd>z!)9X+iciSMk`#!UNs^+Hk)%M({234= zDN0CE-cILlmrq2S#h0+!xI^L?8Ci}3F*nIkRCvSr8i={S0fHPwNl!?t&Z4L@cmLnQ zXrrzd-;vR$?`|eJ`et)f@77#A;RmN@gE%QYS&4SxC|HT2!clmE5QQF}8P(S{7dHD) zx1BI73dzVa6o~r5G87eNDFNcXlmbD9qJ#|XBA=FixZQHs-yKJ6{1MTEjB*+Ifry`# zC@P%cox>>{6-u8H?<8VZArHXw$P zQPK_v%F&KMg=t3uG3{wUP@O30GxsC$vCTEWHpV(Jnv5(qfp|1YY@(78n?TGM1cKN^ z39;ElK0p2N-a>s~qJ*BIj&HyL8+}lO$taT*0b+WLLlhMbcnlEeH_a;7I|l~oOxVA9 z0<}{IH6$jIQIf@ga%7WGVY10U9P1fCkew(YJG)5P$<)0vpRVJGjlW(@C8L}N(}0Md zWhW}kKOKnq8-XA@Q9^d=)#q1J?P{ zI?NFO1&qzi%Y^b{qYp6WC4jnBCc%7>@pFM-X*92J+6krgQr zGm|1kC8J1zD9*C;p=-~JG%u$HN__af!U{67_yl4W5}&B>1|xttzEwaFpC}>fXX5R}0-HDQP=c5-)Otb;p7^9+{j50I(fS3_u`XnkG?|vX=JOBi}i4uAfC4QrE z!OsTkgCFTj--Ay)1CvUR?)6{dd`i0NMgg4{$2x!Fa2)OhfMV$0`F2XV-Df;#aA8RZE& zfMom-oq$qw;+@Ig0)ps73DLLjlJ*c#hW1@RDLnDcwC@2y zc%r1w+-F`CZTAEu>Peq8#XpA4 z##}ExAtUQeAZ8}LiAqLq0x|PvK+u~gp*Km;pgK{K>O2y!)?0SX8=L5qJ&~a`SY51V`0&|zdx(|{zU}^^lItfw?{#x`o|GQ*G`$-zd(QUh5mcf zkMAsm!x=V*z?8`o$zim9_yzyo^!}ZN75;+Bb|=tG$a?kAQ%4g0$_n~9oZ!`)ys+~B z_iK1W9UeGaeU`_kr(SgDsmtOm-z9HG=tnLIulO833# zSC)2ib!$-EQm>Eny1Ip}Zksxmx;jOyPJ7cGOZ~2HQLEdgqf6Zqf`POoB>lPTI0ejv zq>s8LB%mZ@)8E`_a{1Vs`kVV`29OcUXvRLy*#Rf(XpwmacXmiePR_%QH=b$d5Tqv1%L} zbP%awkM`Ipc}6p+>||}HdfX>qP}$4c)5AUPGcc$mr}cD?>u!dwSyMCBeV{@}jX6;5 zX^^>Ypdn=q8L9E^nGPvy$cRnIb5=Dysfq6C3FXq$7RxiVA!Pxzrzhp}Q&v!W>Wn;d zFoaZP=it6G-REEkc{0bWbxNK-B1@*1k(%al-R2NKZ*xX!vHK|0=dL(;CTTr#!&5Nd zrMn(!VUM=yc?PIn+1A?Ajd><%y|S^jr)K7vr1eT-S`+T`rCv=-8g8!BuU=U}MrvUm zI%NeJu|@8wsaI0ddWm~x5E^HuP0KT84a&aOp1LHTpR#z`V@vY+DXXV#X`VUSfK+AY zXykJDIog0US!drB?jzBFELq&r%iO16LvEAHGX)!vpC6xs4dzuRwjy8um5r@EEieb& zt~xlGxwW;YH|C)W+tPDoo~asE($jji``ijE$!Sl+&2_4Vm4#%aZpuTatRo|KvwMQV zYJyU?xF;x_OHkXoe50W(q4xCpe16IrYEQLb7UIHZ`wMZS17ReqTY<4k_nbtI#@OW> zSQ@go0)=glaK zqU>nxvD@?PkP&51Yr8AYtcxi5Fj8CH=Sf6OM{HZ3X%|t}juCs@T@}E~TslVTCHGku zMOw^pZ`+$^M54;VX-_|q&re-CskS^L7gd%{du)H6p^K_(C;gQB&_&gRrJnYTEW#{v z+_pxQgz4$$KXtjMM;FX{`nDH6V@MCz*=9Z6kkrtuuQAdud%kvA7j!RTu8Y=JT`xU! 
zC&HH=OlXkU?~UjdgpcRBd7`@yc4$w%hNOVQ{`jE|gJaHw_Soz0du<)Y#&m0Y6B%(M zQ+8w{^_J`8qZ<#Ct3CEdOl3DR+Z_GKv(zOY7y>gNtsi;DfXSW?y6?!OBlfXtI_Q?e zOox&B#5Esu)4{{p_ILNur27p^*V6yU>KwfN%t70TibW;Q!=Dv#7ORvG3idB;940r)(|ty4`Vw zl%-;%OTBIz0Nr4iIce?Ybz9(chrvVMp6cdxUB4k^m9(dOz|CcuAaHY;7_lDS>6lDI zYzd@0j7rL2WqYd3>pCUr9wRp;dwSibB;90KdRFe*eY(vsos3v-*Z$K@2KV26O47{* zAJ4M^(tQOxw5R*H42|PB&y?aQG)oYyBZ5zYcGTT^ctm|0P{e;;$ z_nAre5_Vh*WgG#p&7Pr*g zJZ{Q*X^+i!pPF<>k)4`XEpVNhbW36BU;1L#QJ|X&)2nrn>%yZu3Z66Wla6jB_;{X4 zM>i4d(4JbHr{l^7)*f4mN#}AsftO_mi(TrT6S{ve2fJ;#`^==<2h+<)t#qACQCmhd z_nAre4klUab?#isk}=ZPx#ooK8!S2bwQEl3rXf3|zjhzesIpkvV{7xcDXXRJH<)p_ zM%y)E6eHig;NGuapo11QcqRU3IsRO_HDV)nqj$(q>&KX7`0Y^jC;R8#jdeGAPwn4B zIruMp@R>XQ=8fJ{=Bmd#t}u_wKNmkb5O4J0*ZvC%ii()2W6i#Mh6B>e&!$Y-k#saR zw%z0%bb%V<>YqE@YHRIhS#ajE4cW7(%{O0ZHaFhC$=lC(dXKkjf7K~+mBp#*W^cdN z$GlxnQGNDVd@_Bl4H}Q#>>X$eII`L#%Q}Ya0e8IA`5pR2bb6`h>A2c#ZhVedjI(cX z1l;l~i&O0_a9Rf^Tg3j9>68s5Y&>_1cc9I&!^JUTOk3w2Xd9=NjV7bB}#MXH$j62skLOWq`%7hlEH9Bxc+hA7QYO-Y=9dqp354tUj;jfm_OC8L8mgYw3 zdW_SO^^U=ex0>fLW|Gq*aI%e&f3xY74J2e7UhjxFb-UBC-k$y!-+fM@}@UZ4s}&({##495IgX@eUlKp7Ml?qw(iW-a&KJ%&xx6 zL0D(+*m=KQ`6u^O_;;DDjpZ$f`87-)Tg)B1O{YxEbphjKi+A7}HRhJRPDe*DpG5}^ zieZH5r8f7sIhzOVeOYwmkNsm90(zD81Bdz3VQFt%vk{$nZ=<6#EiXAwQg8Iw3H`a*0tp6M5BaZ>-;h&eZ7MuC@X1c+GUm z#u&Aqq~YI&FNRT~mpXS^UUxKijQ84mu0C)aE$LNOFdJp=Cw|b`+;I_T9W>>i?u#n@ zhS}V>OGgjCM-OeA+o2AN(-Cr-aH}Jy6@M_DGQ$uI8rRglsdHW4Le}CmS*^yWC9)I8B>uJT+;AgFlLEZcp(_coof5j*zuX@8ya zV)nwJuI)%4M{`Fv{`LAbi?Xe4dwNG-XLHA%zUL2-yUg|&i`-93KeM?p_fGWiz@3hq z#0Qv8nI1OO8OQE)T!PaBosNzR)Yv~C-r>Hp92(?o?$G?N>=^%5Zgr`b;EKU!dt>=s z=!}b?+h=jgcBVmIQfwE3NR`Q!b#z<^o@{$FpJ;Yz9x-0M+i@X?4>8YUTywX#U(Lo* zUF)*xQYVXlsOg)DDQtA#>PSS#FsGkmR*cv>_N*9Y0=?A55g+Mn?zlo+TAa=&?EI&h z?TvL?k%&*XIu=*QSc}t}Pu+vl_Pf^++v;g1TkE~>`5Am{q4}qqKH1Rfjg9v@uMtg7N5|&+`8!p=j$z8s z%eD;SW^?2B6wHYG90B{!u{infgVP>3*&^;(U^-ar@%;v^t6wI^-91+KtTbxEe0H={RIh;DKv^cGLzHTLgxB)bLw zUnEPIi8^R}vD1<4_?-^FkWu@vBiZS@9Da^upE+^nr0k_e4f)XBj@FK3zjNS+I?Lc# zlU=dRY;Np&7>o6pJ>F7VviDh?zIvFKfXyee-9i+4#F6WLmg78XA>+A69J!9)W71^Y z9Jzjf+ZUa)H#X{E`0q1Y8~=R-xqd$OW&b>@Wx@#>uRrR@ zwf|vfWyy`>sPwAUhSUimZu!89l1{K?7S^ct`9wJHaFJqLasmE<;eBw zXDm){?!vGv-|bxeX^T_&Zbz~^o->d0u+sZ;5j~sp>qXT}nWOsb*@N*>l znqxD!9!J;cW$PaPp-wjBOf-&U@7T0$tfj-&WEXsDHaD7|K#zAk;mBhJ{wGiN9L7d+ zs`;(MX%7BnPtNI&PdJi&!g8Dg4a>U>+Y)K{!lcQ#)$&b-+V2SdD^Tuqcm5V>y;nc_ zpF9h~>iZ#ljWKPGw8rs|gtCO069$a z#LzE#sojg?zYy74JJLGthl5>czS9e>t!ecad2`K;ns)T~#&$=>I;>6_Id$LXaH=S= zFg?-kNb5e!aUQmaF=?M;4YzbNX)6_dW81G@WHh#Ge$?kdIi_oEY+?RTVRbsv+=IES3RbUOJfEKaxXcjT{QfO(t~6E&(2IP%v~>2P!8@5F$w zC(&G_mztA!wb|OZ=m7Hf|Hk)S6@wGKRA2uPi|G)6t;$X`p) z;TJMi!q2u?4}~0lj>Xz%d#v7ldakZ_w011k(~s`nU}>!^*09;!`0^?A`0S@0$?k|) zoB~h7X&K!ct7C`n@#Utqk$l?w6Wj4QV=c$qpZ3=Fzo{0KArKN0)&LqZGs(<^HHfl>gfJjPP#`R_FYhu6+A7gj+%@u) z!4*+kD^^jlfO}l3ShuLHOSP_T{kyQJtP5J$DbM86c``&%`-gDn~ z-^>e7zMHw{zgcblD+<$krmf$&e8sxOTUM-IxhlMQ^zxQ)OXa4u>!zHNk+vZEuho6s zpZjN|jR+a8isnMo(q_s zEM2!^j(D?-Zp|)$s*LW$=_WP#DLaa@aUEWSoXM_s=moHy&uF+`WsWoN^Pm{(oJH`_do;hJ-Zpb*@Y+3FZW6bM3-i`U~sWRq+peSrK z6G0d1pfluB7RbR1!^W8*+vh9S*YLa;e(ghR$TW>XarWR2swcdnJC+Sr0tU^ur*D?oCdken|hKMNZ_ z4;dG5D_+XTGB;+VpM0T}qsG_mf62ND`@Pras-WtaO~>jxKTK65{{e0#w{V^R&&pYQgc&Qx|xxc7r4#JQETi5 zM`_r&T`s=+c4JR{cbWos2*;gaPi4gf>0xp7pc+19^%(74}9U1B^VB+rJ8=R(Hs&D@@qzUQsSRU0qJafhSFy-3IXL5_PVY`h#YUNOI} zE;ZYF_6@vhJ+{Vp4aZ_qbkMOpDPEVxzahuJ88-eHGX7*9KaGyhHU~_}3>t4)N&Lp! 
zAXybQ-Vu^s+l_bYd(nvgSxDXs8}EmVznI785lIhoZ9(s#@qv{jV0;LYy0Gz)kRY!< zmc05zNInf4e+?ObGiO%{$(Iu_7Js*r1dY!?g5>-^B5_fMe=a0n2+5aW`LQ13xlu9_+A=vqO&3Y z5`uq&z#N!eHL>$1^rQ7qkMR=@?H5*Q&=zzwtq;nD3J;8O!y=Arvuo$dDxtc=WhL>d zbn?->R2i_O72wH4MR>Ac@x`bfu)?Y*%8=^iGB=f#nf{!-hHN;jtUlG7d^BNH4lH7& zIiqq>5!ODigf$OVSmmP(sREbz^rSNLKu*E%LO86f)v7P~n6(HNvzo_y=l6klQN{3t z)c}+sRpK%)no(vxm78u(%*~rwN{7Xi5Z$1BHEJOFXn9d(u%!FCZ&&50@FJrI!IB0J zh80#rP=?e{^S0ax4Z~nt3H@p~`FK1kU^@jwRgE&FYFuXS;23-MOh24f=Aa6YkHxi&9r`ULo_#IQPlFX@o(_|kXW+Lm z*TV{{nJ7c5!DU`Iq|97W(8IjBXOBkJ2(Ojaqh^thXIB#}qUAX|8x`q^5GMWP( zT3BwrpPz-5FsES=oK|L^T1-A#4%HG^!kn>PwV=Ywh*}Cun3ur{tK}#|YK6;epFnNS z4{5DTCF4>Es|;#ETX1$QmsZssMo*}>b0=KY8}duT94eDR%UJ}%x&BNhn2NP zZ6qJ7E42wG(bBT13{-?x!4lfdu)->gGNjINndKEkn>CPId@j6JZok?>zI3y=5Ls-F zFRVQ+wsfis6VL&*jnTA#w!)%@@B(@fDl*&`!$L#M#|vxCV|~*bw$lOnGEyB>my(aC z;$^U;laM`^qaq|%z>=%E5>{AUg)*eBmeq4vQQq)t;IOjRs%y!|GqDX8v1XWi`pPG~&poKf>Ujq&VZIYqSlxv(q<)3fv)pVe&M+_R zpJ{gF_H6hy{8nN%7l z^?Q^MdyqQCJsLSPo`=Uu>{l<4kBJY%iW0vFlZgL--$MKntgw0+Wk|guE9TWDJ!On6Y5Ki;66me}*N@@4*VI_fdw_U%=d0W?o%V-0%S$R@PeeA^CW*{0Nq?V(Ix9 z6TZf0g=-e7soxA1tEg#qx7hg!&6uLj5JI zu=)yRNF9;I^5)Xq#-ngpSv~3)`DC#?4wGnkvHThpq5TGy(0&Umtp0&Ar2Z+3<;|th z#qv9Nt=wMqJ^9kjX{F^@EZ3CQnzs$i9QR`a(5HT4APY@sCu)NTO;lu*-LTLG^YhZ$ zIWeILm-U3zDxG{h2Qy$v7a^@PQ4yXjSaK~rU_of2gwQlMmQ^=o!?qIEsNUq`>6ZhG z2t{L}B7}Wl31J>AXiSvQm~PgXIpqa23*oS``c+@@F>4X5PS$>?2y1^>!deUqk`pB) zr)iXrZ779pB@C#6-2}=kE!-CR838m>acjuOwPYlj5?;XU8<+x!8 zSlNSWIQc|&!V-4$>IhV%xg%kL-E1k(H~%=NBudzCW+fe{uRs1;HJW^)GGTR+c~B8D zFDz-Q4;EA=N~laTXK-DEAGU=sP^ALo<1r7yB0>?Ds0iU0SVA}!7Q`h=h)cJ*shF<4 zPjCVpR#uOiNIqt*gT>8+vV_UX62HN?V{pFNrzc%*j|LHvbdY|%0k4`$J|>z5D@rsS zCJRaYmZsIif{;WBA!)V^3C?VUZ6))mS>)r1(*%piL_VS-*BXK)gmYj)KB9ztG$);0 zZazA+hgmnY*SLA`S&6IFeDaBUgeAnJ9#M&^M_3>>TZiPE4>y*8ta#=kI>>UlHEJ>W zL@dG*A`*+J$ZE0_mb7pgEQm#v5Q}E{urUp1!?qIo)k^Y-M1)0zA`wv$!Zomja4jrI zM3j(-ZkC9Lhvkmn0Ed+|pf-{(DiL84tw=;vgjT_#3(g&uZw{##YR(#--v3-W%5uR$ zwS|1?<~hS?jf)JcHD4N@Iqt#)V6EE5K-Pk==!~KTQIWB`7#5weaCqk$x1EmBPZ+FH zmy(Ys)Mc=wLy(P^qar+4z>*HR5*FkjO2|R1apmR{6?x;XLETE~QP+}>CsrFQBISv7 z9V$Y)0~R$dry}2cZc>Rkeni2{opg|;bzT)A9~13@)k$pkvrkl@?$ZNO_J}a?L?IxedL0Hm#BnMH6%0XB{d?zf(L6ne#vcz?a%oQyNr+!Ka6bzR0udzyB1#sBU^myFlrimjIIO%u z^#b{r_b{v|?~5>5AL6(0z61;U5GC{>7Q%9KV8`P_CT3Kt=Yvkkk)&YxHMIWLf zz3~PtVSW=9^dU;5FSr%@r z%r{>hnm+CWIIYYc^&$C08p0B0l7^_rlJf~HFq@u}^Ua~dN^s%*8vafP={Z*GRiBYh zgdi*;5+R6+5Pbnln)oFw2tkw(g0dtYADvwMw63it6+7~dQcHsFD#+;!GeZF2@TnWpH%+f<;Atp52uaV zs{-WYDG`K4)S@O)5$Z9pgnBG2s7aJilb8}vlRtPf=EPUm2?*Iv;8PRH$tS3T#chYM zgvr7ZzlD4^gX7AH+q|_pW5yKtY^>F4Dmj^T8muVmbeOC%@mpByVL@f0 zgv#u~k1KoE^lfN_!$w)7W|5O8ToWv!6s?Jh^h5}j(9VGct%;Jf<^$u)&E7Ti@l00Z zJos$Pel?$*A~srykf10@f<85Y8cyuljc3DWV-Bj7Dl_>GNvI{@2+#Bd! zd@dYTUQdP~u%!P;oT3sHr?7`x%h`*NelPGf*eH&IVwx0&}TOufYU}DR0qk)!+r=BQHvl&MW`QzCDf0>f*?f+ zK`Kk;zA+i*qH$T%pM=v!U8|lVCsVh>;#Nb8!elLq-$MH=END@b(4tr@Ju{w%ZDsRT zsTat}Y=>b**d>Q#7cydL!$IeB7rz#?AJpr}X( zya5Z`=ALobJddwwe2WgSjF8u>-X^CAPguf2!V{IK@Ps8Td=D0cCrSuUSrtE;z`}F< z2XNY`edad)TBI7k@e*hSk%DV$LE{RPVhJWoer=zuv&daPLZ9kh(%;4D#G#w zENS4Eupm28LUzjPw{l|ctfO$)Xlv9laz-^LOyU*IiHh)k150?ng$2!t5}H$1zt1N2 zF8&S<8?Rq|PtJ5ReL*6>)5osT0FjfZ$oRQoQTH5}7?+c9 z*-jZ$>Ez@ImH|t8i1Z{XLX-tduA~Pn=t-3H_;KPSJbr|2W2{xZ$;neI2bPQ+7;{k( z#y+rwF%K5BBuZ#Wvwf1kp%Au}(O0GVl9L&WV0AL~Lq!<-!xF|~SkRFu>DgoJ4BQ=0 zOE=e^k~h8-J{z+~4J4=NNLa+&t)6rwDp4H?i@NcFN%`g-xg}uBFt<(S=Z=lT=wQo? z^Lf>9a*BL}C5*WG5vWL8N5Yb}j)DdGh!XO#E6Ycr@7Gufr;XaDMw3${B&<$q4=O_K zg(cKJSdfq?AtAf6gcNH2n5M#yfQ`Od1<1)WAqb1;U8ZNUEZ<|{fo?? 
zn75spK79fmX6=OVnj)b@#v|Z8B<`}*!*fLIhk!5 ztSH-bn5-1>TiEJhK`EkyQtZk~F?&k3d0WqnS*j6U8*f0(A}7zQCRoHPauF5j&=4%) zo&yVV5hdheSN>!>drDEmJUDE$K{cP8q7`8YEontmqFND_&@O}pt%wp@(R`q>+{~WR z!`wY3r*RQ{Hs)Hjn4BULVF@$IL{wziSqe*-m%)NeLdfdz+lW~C@eP8N zq9XijUmj==fhFV51MJyU8g+5|+@TXYWBp8oU=4_v{S$@J1N>HQrAL zT06B??IWkiNLZb0527M$55bbg?uP{#i4rnWmQA7NWwYS`0#?|MVPE6@mrW*f(12+5^7Qw()z~eLi#GaHeR24jhsBg zI$#m6s7X|$58i+!+;76d6HSzOqA3e$ePdC>TX5KDtJT}&6nP0tXh~k85|x**g!VmH zke4VSFR_r0Yy1GVjj=|3NKR3eu!NCRB`UHKeF94uKZS*lx>4ezZdpL<8+(}VH)fu~ z9sGAhZ2W%p897C7!Xkdro2Ur?7qGx@-rk7MBxY64I!XsvW^7GB9V2H{Ucw|Ak(a0l z%{Q>5ncu>KyhI6k+0~rXG`jdZRBW6<^*uS$&1+`SD*3=HZI%2n5nii)VmRx|ZqfqL zm#D~yx?$0UR5jI}9{Y3~F6(*m4S;lV@|4SfB|U;%&O}9svS7)j^ne9@i4yv<8$YjH z7|NR8xnyP|V&nIy-sI$|mjg>i6#TiU2!9_~!k-5VY7-^YW;cFvxiD1FSO|xW)~ouG zlWB`!b<*}jMQHoO653)|kenzXIlJ+r%htL4=(4dCP8+pP4J4;%PFO@OniG|%=7c5G zgJI#>B}zQI?8b{{>sS{HdoFYD92{k%n1S-<%k+6h%6fB5Oln|e0$DC6d zD`DFhYt(3RisppX$>>2v7`?EB(FY5f6D2ffH`bi3a|e$1!(pTKs{lE9Rs>-Ytvs?s zMQF#s0w7L!vXC@i5R35tp=O-o@3?J`)9peX4{<%5eL zZs+IE;l=Z8IBnEkwUV5oK4B5Hs83XcdJQb0UJDEA6D8CqmI^$_-rd}zO1_S+Ho$LV z_o2}`Ka!}p*fZQTnC z)aLO8`K~chuHMl%W70$S)8Up;3wqQ(a*FVT)yerFD#G~?ENSk3SP-5lAv{qw2engV zR~&%LM(JP+H(R;^wjC$k-f6=i!7CQD5G7PgmQL1LnW#FW)+ z=AxqUufk!YtWmF#ljl(fETR;3iHh{&8?b0>Za*vE{M{md=q)H``l=Iss|KTJ5%SZkYUyatbv_{ z8@7!RFY=I+XHyO=VZ?08MMY*)A6UYe2Me+jC1fW`o=qj@krsS|c$v(mLbz?Gz*{}! zBXKemJQoB47QCCZ}yNgy&=sNRuwR%fn`vesMMF?^))pTmQap@1+|G1 zY7^x-7pO}7YR<&RaEsTk2)EE1=4v!t^v1QSf>F_oSeo*{oP2@u5}Mv9KNoLuuim&p z`QWH`ZEo%j-|`ikm#Nj8@K*Wab*nc+gCcf;szJFX2Ft~}+xZ2`&yj#0DWkW#%a?Cz z@hol$hgWO}6MK+xwHglZ%@z3ZhPDM)*?C0t9dk8?4;-r<*gP=$+IX)5HIAd>^=NrC znw?*zCUAJ77Tz%fkHx#?lv!R4S9Oe=q~Y*>A_E6mfR$D16oyXz-%yV_m7%BoZ>U#I zVd&H(&>@xvEDoz_44tk)n{)AEDaK+3N9y%RF}?R(8a{8+ip49|FI}!?GNeI=4CPm& zw=T~sP>md&rAPZwbihil9Nxr$**ahl0ft5al`YG{tIur-uTUXI&e4(dB6x9un#+;X z^+qLqu+w`@|Y7_wS}&;<=eCT;LC z0)Jd!ld)Eas;c{`gvz0N35VUC}p#hb@Q&%mE=$9wr)MsCrOgQLidzU$VjtsFlu zG2YfQD^H!z(VuBi>|D#tAA9!9H`LELc!3_2_k45W7jhwEw&@sp(;sgo7py&3wQ}?# zJxZ5}XzAK>;{$v#!+xQ|T9)@U+cGoJlFZ^&D_WNGDF2dim*_ZpA6?#F&W!iTcE(+* z;hOtHY?TzK%Q$km9vMUtbA0dge2=<e@|Yf>`_Irk^*BeK&>}p6@Um&V z%b#S-QyPZvUPJP|s+~hmYoX}jmIQl-QO{~9zIl}ni09XH4EenVq5Sf8ri_-Z&olG| z4a#@2($W|krPpBwy{LgGy+*|eE%&rPF#078&G)p@lK3r4z08q=pm6|L2<;GX&4o1DMp^#=}=FxuH1?mlsy{X0cMplYwUY`0RNB^WnInCm)ly-}s zyEN2tS6!yN`V-o76`f|5%Sxw#i$U+O>SbxKDLlQz2fZEy;)`+g};*HyuD13nF{6 z{hg7Y=}1bpkNReTE+yPmc`GPTDY8WKhp20}6uQ+lsN*B~Ss+sZW2c!DAS(jZ7Ruvk*<-wgVX2I4zePPuj``!8dD&@f1_J%a~F1>=t# z{z(gS=R7eu-Mq0Jxs`@(;RVWtvU9grO@||@-NM|A>Y}0e7KZl-OV+Pi+_GuYy7N?5 z#&pv#$gw?xheY2r?3$;#Ga_9>aGu5AM%3nJ24gdIEKNPZBG0lo-b0Hc&*W@<;=|sP zQN483V4#K-sBDh(wns|M$A?1QJel|MJzi!s*C~U8i2E-I<{4d-*pEuaT@ zPgQVagchOvdZI$huaOKHr9qHi+s)=V)P`M0l@ipeH^f<}&JZ4TYQ%65Rggadf^GjdEzWKYI9gfjUFOAeD&W zt13wM1sq+dMJb)WsM6BuOop7LL6A;Fu+D>rB#St@Sc@X1$`I2SmN2G8!yu(T_Qvj0 zOF6tu3-eXrH6}|sEoayY4Tf|AOY&4Yo6##ZG^f%JUZm10j;+>WNTrXx_%xr=Xbnf# zYEiyb$-7TlDy?JOdL2jW_YdBFczVLIjam$8^s(0;pEsKrVd#j#Knx-Ir#QCR9xF8m z_>n$gj-I1MIemK9^v+Y~a%_tpqnlIj8trk$R>qvCV<>a*77snnIG>|G)1!1%h+5L- z=M1|*hsAD6@>UHRdLd)B=@?3%=sP{}u4rZ4MH-HGSOW_E>S7N4LJ!H4LVxV`qkhSd zOEd`2o{<492evcjQVqjN#2kvD2c&b57GkZU(_>}D;-+b?3lbHrOXdaD-Y zT-!HBd&+SeV|MEpnvwFF3T|1qb9|2$=WQ3A!S)>S4o2RoBjq_F-bB4hLyUYxLm~k?nY`*zjz6ZwElip+e4LR_Xh>wAYpT8ldy*kfX%M=n zEt#a49f?Pa+$PqFZ*lx>dmPU?-r>l* zTEzVN)G{+?3O%_0GY8+(gOqJ}Z=pncruaUi{-UEO)v~AHxx@z?`%sV3C1y|2I{YKX ze5_+Azm`nQGRvld;}ed5s>kUT7kyzMevA7n-Ze27TyRPFL`C5zYm{FlW z?f-@W-)aCP)J4;W@rPjl;NU;CASY7k3~k%>9b>-NFi0X?fSyDr81pX;!`IL;(<&7I z&9MJyFys=b%sXdjL;GJw{h*;_uf^Z7M=t%y@t?Ffa*0?P8@flog%7cqmch?ZP{y92 
za8AX)sjuafo3UMVEKNAUBB#1?yqgwBPQ}ltJ=^Hc$aEb!IEqBeGdP-=7;R{TD$C+% z4=u{KF?>%QS=Ez+z4Rbu6@q$JWizU`j-srpZ$w7qa4c7k(M8rbYF*!lF?l*BCLZOR z;%H_*V+wQ(Wmfc?HZ8<|zak$!q)5F$gc_rjmw$DXKI3-LtypvjU` z0~l4Jp^#IPX5;yODMtosk?8SSGL4Xc(l`*K=cAuaO)crG@!I@HK49*6SpOovgu- zQD8~F^{Qm_Xbo+)H(}=;>SnH=lUbmuIPTHn$gHpD4ly^(>6NFv9QA2Y&M5i%ww6uR zjH}UcG|^h;V*BOiSU`&*ljIDU)(^gdjI7m>gMq|T!Z92fYmH2rJ+S}AW%PZQ4Qd<% z#%lmhr0n@oJvf2m6ZJUXq2}B5U>yS|>A?1;a`Vc0=nee85tyly^%&oz<|jzOQyF=h zj^w184=Ff>V^j4Q-=^kk4>G1PX1a!HF2+ZIrTD%HN9y$meK=UpL2&e4!`hgdKi$uyTkr)wc|+kz1#nrxiM zi1|8#{@DJl3$pW6Gsn)*WAsLIxl6=!ajIZ z!tup=d{6E`bI`)P_`}{MjA_v@JMagL?^}pwF6GcNEyU@*cVS_HTF$W*TCDSnu=#QF zvl+EgL-Aim-@CAnJp5b5@zq+qc^q9N&TkWSYZ$p!N79G;(e7A`z0^8JthXcR%b<90 zwSl7>wJ6_wXD-4WWD~~>Jw{W0<|6G5q8PJT!|cG?+_b0%J;Dfc@Ek43mpOA$Z_6W$ za~ZZpgV72s-z>#lbSsC?)582!=+1Am#^=WQjQp97q|7NP^r)Y6=mI??pQ|5Pf`@Mx za&(&(?VK=2mT1}2%D9U(9BnZ0jfJ%>*g9Oy(O>9M`5wXA7H!A#OGaIyqxkC>=8+cc zZMSpuQZ35k@l^}P<1&t2uE%IRzG~6N;|j)HsbO~D3jjx2_>t089K2c!@_2mJVjYib z7WDs0#z+gV!AC zVAaWCIA;}L9<%+N6ew>LwPQ=Dv@k6{IhAk9DLHswcgs>|NHs%_r-IZFWdEkE=ewM* z%^$tC{dMQB>Ea9vL}3qIIMVHm3Pw?8YwH?kRBh)eFIto1I5y}_9vidn9OsVmU}(@$ z$FDoz@pK*#@#vX#rt`3PV#8v#UZ3K0JTCi>U!S5!Js221-~RTUYn(gYgTWD4>#qJT z40E#o>T0ydgYn5oGY{SPZ(LyVdFR2{@Vwh`bBgQ3(8%PhyCuc-VQ6IDl@z2EB%PY* zr`$Ai-;QODGt(O#lN2-4i!q^@d4Gz2^x{m?L3IzLn3-M-3Bw+GIK|BL#sWbu zy`5uHcW_OP5HYz{`OZ>%uOFg2y^qqYtD1ihjvA$-RsW1;KS$$s`U-$Zt(GP zmU2<_VPt4INHO(%7#B>v<8P*zdcN2?w!i&PDW;yUb0KSeccj~KXmD{nG&^2(o-Ng} zf!XnX3Q%lV9{M1~7+1&g#%w>Df-9CbX6rHMfvN5snD*n&15+*cnX0c-bZK=gXCC@C z#l)+QCC&c!f25dr)fgX|c>i>sc-0u2=*0W}8w^cN(!~r7_u!80HQV2J?t~hQO>}mi zaPEW}UeZ#|&Kiu&|CpUM7#3O>|DB>MYhr8J{<{CBn4LAT#q42MiiNQzHZGA?x666D z)WinnU|kniipf_K%bcD0sn}wP6KT(PIgbvS#-p?2U(Vg`=bKIBc!A4VD*1V64ptSp zQuL}nmNp0ayHd`pR)d!gEMeShwAXkc+8i=iH2kRDACol{bw zFi_EnJvBvV1bMMbIb(wuqyI5ugBYKjw1W>!OVy>p*fMvpbw;YG7>uuTjj5($Fg8Sy z_NEkbC)hbct+O!|9TN&LNvQTyQ!OgB7$htz!KE(cC{AnR$+SGxM8Ax7yC(YTw`*nc=j_$0hJ4|6uD@*BDx^PEI!eX+JUo zclnV!mpR{TXh*^~o1N#U0O5oTiq!qg6>{1l(T)ULq>WpaIU>+mKpabfKsV*49o+YG z*Uuc6VcLd7m*ERBG)|Mt+qqiTU5t@&gkeTATcrJ$t~HLR&Y9YJ2`04Dz@P_sV0K1a zW1NPC_98qik-DAEC#0>2G^;h@JTSBqk%4L7iQ3ul1fn8jzp)_?iAOD!4XvJ?J2GggCi5q>CqW3jvt-v zyu-Q2X)D6lwqvW~6tBjp(Db|`MNiT`MEa=pt`yU=8iOLRy5FRjp0pETSfqV#ine!- zOY6NU)?3<)2wC0ZDac}pv-1h(k)f?fbo`!l9vRw=P=7s%{&LD0G@WxMQumB&2Bu3% zvbBu%BK@$I^*`9!o??E|W<)x*{h1W=lQtwYKc97;pR^%~Uf=JXdx5qj(fRqjb1%@2 zB=y2byOGojBW*;`lgD02G4p5>621BtQ_Q@Y_*(V~W}efyphvbVy$1yB!~3(56G?a9u%)>*EzI zm0YCFM(T-2`waBazC!eo(>ex-tYdWpQj7_0F``{muQtFyy%~kk(oEvJ!+KxVQ3-!1sco0^&6M3ShraI|M|_4jB)Ov1M@9_t@sSp z3Yt33J-EdRYI@WL8cl!xX0)fN-i7}qC@n1mAJWplJ9hsk(vIKsQ-GiNzoutw43VeD zxr-vl*SdQSw6?nKSv&HlaqfPRcgH0l-+9iCoG>29qVWmHqUY_%3kbPsc>?lK0`hS} zR!vDj24Apq{s+iy6TkiUwmSSzoAhJ7utW9(qMcY(WADa}7rMSp|8WidSUH+rw4F4@ z;g51`>l}+AzsRdgqtU^Sl_cjC2a>%i3!YyoUm1q3lh;ypExjkB98-Yn~J+WwV@-PI!uNh zb%~SVuXf}nLe5A;cKpqb+(*dE6OnD7*^x&Hc{mZ7_PHH-@+5b^QP02ozIhlUM?co# zDEh*Fz&bkMgQs8qc?H1qV?ChxOZx%$;s8#d;a21cJ2E;$J&_3$-6ghx-ujh|HVXDe z%BCcY^a&?$Rm5{j0($Ea2Wn4bQzAO&m;>4y*_Vj!_{IV4iyTcvH~rHAT^%`jazeXv zzH>m=MAiXqo1q=uUAEQx0XC1ryuK->KJL`h+H=nQ|hFt3A43rlpXonRLtvp5|K?O*^$}Pfc(EiWX^ayvH{3Z z&mNgHC;9B?NI2jsI^goFvkxbq3Pls`XLtz*Y#ZiZ{96Z^VLu=`cd8R)#`*Q$)yX-U zCfQCJ<4BYlJx-i&B+tm4$qpn5GNZZU<42NrSL>txHAG3LASA9Glc}>3sGL%adp0 zAt#Eeh^rxC(ge?TIH4y}mx%6I>45e|wk4w5Ry&}5ktY+;X=@$O)sY_((M9VV&^3{9 zjR_j0dA$uCd9V=^D9PsPL;~_SA(L#bwr;RY;Cj7 zg0@X|>icHlhA}OXx=GoQUxIqu4f7@pp-mP2*rxBfwgaLQERdioP zs6kd{w)b~Pyz?AL5;Vw$X-9e|Z%ED+$vNDC>d1pl33I)}0UoG{98W|SU6q_X;E#-+ zoxpv_2_1-RNJKYXw^+y}HxgXFY1pc6F6btm`qCZAHR*V#BCEg?*)|ISXBtvRM0 
zxjTe;eO@R5*?hen`2ishB_em-U`LkB0WwMc1$Ww!XA&|={vAp{-VEfnT{r%7LaYm{ zQ$HB7QAekKFhO-3=|AuOHcHgY?dZK@;fAqQ|(q~!Djbej{}6FC!To9bwK z*v1`s{d7!*drwc8ia7`D$eejVevydmc*KsJMaU8J5;&V4w&FHhn*v`yOwz9t^$5uri2iVrX9d>Y}Y61E;YeB+D6n$o+jcg#~v_#~g1mtgl z+%{lu(cIV?Z0&-k&u!E)629urZjAQg&b{AP#`bJhikvT#Q^Z$sdOer^y*6-2zvRi$ z`c-lcf23+*LND(+k{paXdm{LqqshTR&LrEAY&~WJN6OB053zmvaUub<=uCHCz9Vk4 z_DtJx8=Z7+b;7DK<=w8n z82J%kTR(Ob+rSaklCXRBq{E<4w-tix^82I=I=F99$jQ zmk2)L0IrD~0oXPgZ8bLLNXs(Jz_H5`I=U%fNA4!%d5OsOpdI-aAV>Wy&~sq&zIdnB zen9bZ_Ym9KP&C#Cl5tB=Lp@6_f3*VLO+VHy?{Hv{Cyn-m8<+fa@}SWMKH&iNM&4hZ z(B)gl+n6IwE6|Z6S0r?K&O|%%8bYp0M0T8FN4^Z?wgGGQd|ZcqpdV{rG@WKUAljP= zOT?~e51x@ck#eRar-+Y8EDOf{_-eQ0&FGkyoFf=%Iy<4`+UDE9k>}2K4;gB`g}l>j z1IdL&u0GpcVw?LN3vJNI#FgmYek&6OV(Xc9! z!M{FtJ-;%!V(eI)oFe|DDb~F!e|5{n$rG!tB{@fR#I-7Mky)A?Toaj?2;Q|UIoO{} zF$S00z>%+4p`-cnh;>>XNv!QM#angno%1K1Z? z2e56rw_Rjoj(oHRef#K|guZRM*p4h+3*?DJ-_`efg=$tgN6Nlp=$r?E}#pYF?9ojkePF12w)3fDn|SjR2xayw}5I(Jq7_0=+I aLI~5pruy#b^!uqNBDbw`7iLjhV*d{UJn66i diff --git a/source/tests/pt/model/models/dpa2_hyb.json b/source/tests/pt/model/models/dpa2_hyb.json index ee69ed4d69..f7d2234fae 100644 --- a/source/tests/pt/model/models/dpa2_hyb.json +++ b/source/tests/pt/model/models/dpa2_hyb.json @@ -22,12 +22,8 @@ "attn_layer": 0, "attn_dotr": true, "attn_mask": false, - "post_ln": true, - "ffn": false, - "ffn_embed_dim": 10, "activation_function": "tanh", "scaling_factor": 1.0, - "head_num": 1, "normalize": true, "temperature": 1.0 }, From 93cdd6a3007621a44781792dda805c7f5a7f4569 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Tue, 23 Apr 2024 02:20:37 +0800 Subject: [PATCH 11/18] Fix numeric tests for layernorm --- source/tests/tf/test_data_large_batch.py | 112 +++---- source/tests/tf/test_descrpt_hybrid.py | 120 ++++---- source/tests/tf/test_descrpt_se_atten.py | 360 +++++++++++------------ source/tests/tf/test_model_se_atten.py | 112 +++---- 4 files changed, 352 insertions(+), 352 deletions(-) diff --git a/source/tests/tf/test_data_large_batch.py b/source/tests/tf/test_data_large_batch.py index dad6bbf252..baa173ad7a 100644 --- a/source/tests/tf/test_data_large_batch.py +++ b/source/tests/tf/test_data_large_batch.py @@ -201,37 +201,37 @@ def test_data_mixed_type(self): np.savetxt("f.out", f.reshape([1, -1]), delimiter=",") np.savetxt("v.out", v.reshape([1, -1]), delimiter=",") - refe = [6.121172052273665543e01] + refe = [6.12116933882038480874e01] reff = [ - 1.154685702881510720e-02, - 1.756040710324277901e-02, - 7.130177886472930130e-04, - 2.368263097437618356e-02, - 1.684273251820418010e-02, - -2.240810960870319706e-03, - -7.940856869069763679e-03, - 9.685611956408284387e-03, - 1.905551469314455948e-05, - 8.701750245920510801e-03, - -2.715303056974926327e-02, - -8.833855542191653386e-04, - -4.384116594545389017e-02, - 5.810410831752661764e-03, - 2.624317854200653062e-03, - 7.850784565411857499e-03, - -2.274613183985864026e-02, - -2.321946424516053086e-04, + 1.15647422509625782266e-02, + 1.75814420518816301453e-02, + 7.13827966845788537686e-04, + 2.37054385620869625950e-02, + 1.68638656843611636771e-02, + -2.24281688803243482028e-03, + -7.95826529019691246425e-03, + 9.69601584607941019422e-03, + 1.91505445834121360688e-05, + 8.71431743822387999687e-03, + -2.71847766570148252629e-02, + -8.84489238783629392812e-04, + -4.38853499152154838403e-02, + 5.81882595276563344133e-03, + 2.62678184040523532775e-03, + 7.85911695413897895546e-03, + -2.27753728780730156644e-02, + -2.32454225018371036246e-04, ] refv = [ - -1.048816094719852016e-01, - 
1.669430893268222804e-02, - 3.444164500535986783e-03, - 1.669430893268222110e-02, - -5.415326614376372166e-02, - -1.079201716688232750e-03, - 3.444164500535985916e-03, - -1.079201716688232750e-03, - -2.093268197504977288e-04, + -1.05000199239240657456e-01, + 1.67161895068729630942e-02, + 3.44771431604021672684e-03, + 1.67161895068729769720e-02, + -5.42193765251950815509e-02, + -1.08055824874348557069e-03, + 3.44771431604021585948e-03, + -1.08055824874348492017e-03, + -2.09534775642289020732e-04, ] refe = np.reshape(refe, [-1]) @@ -400,37 +400,37 @@ def test_stripped_data_mixed_type(self): np.savetxt("f11.out", f.reshape([1, -1]), delimiter=",") np.savetxt("v11.out", v.reshape([1, -1]), delimiter=",") - refe = [6.124119974943835132e01] + refe = [6.12411774224343261608e01] reff = [ - 8.617444257623986525e-03, - 1.622774527785437321e-02, - 7.219537519817814273e-04, - 2.465257480331137924e-02, - 1.507377800325802181e-02, - -2.267846199393293988e-03, - -6.217685260668888089e-03, - 9.187965356558825195e-03, - -2.082402632037372596e-05, - 6.179226045047841662e-03, - -2.505229190184387472e-02, - -7.834051085801594424e-04, - -4.104669576212031240e-02, - 4.721690416727373704e-03, - 2.565744238275521286e-03, - 7.815135916805987862e-03, - -2.015888715255471572e-02, - -2.156226559634751916e-04, + 8.63770820796567855016e-03, + 1.62522026393666710331e-02, + 7.22919459978568399415e-04, + 2.46800946249053909654e-02, + 1.50982535714741239463e-02, + -2.27024703144847314271e-03, + -6.23780053459390554371e-03, + 9.20020798171328375858e-03, + -2.07267961671176406842e-05, + 6.19326848220238136006e-03, + -2.50892401262326376898e-02, + -7.84679030459893762407e-04, + -4.10982573216296109830e-02, + 4.73129070889465389027e-03, + 2.56865811814534347399e-03, + 7.82498654115005611021e-03, + -2.01927147752160932037e-02, + -2.15924720048428232235e-04, ] refv = [ - -8.500718686149140446e-02, - 1.389198522732191729e-02, - 3.059204598073241802e-03, - 1.389198522732190168e-02, - -4.908897840490741155e-02, - -9.530658829897690944e-04, - 3.059204598073239634e-03, - -9.530658829897688776e-04, - -1.999114402095244765e-04, + -8.51412082980419621103e-02, + 1.39169542815959588339e-02, + 3.06329019931955021105e-03, + 1.39169542815959536297e-02, + -4.91657098529515168561e-02, + -9.54629874035841556948e-04, + 3.06329019931954891001e-03, + -9.54629874035841773788e-04, + -2.00155580095981406675e-04, ] refe = np.reshape(refe, [-1]) diff --git a/source/tests/tf/test_descrpt_hybrid.py b/source/tests/tf/test_descrpt_hybrid.py index 6aa04118da..7f5e064376 100644 --- a/source/tests/tf/test_descrpt_hybrid.py +++ b/source/tests/tf/test_descrpt_hybrid.py @@ -185,66 +185,66 @@ def test_descriptor_hybrid(self): ] # below is copied from test_descript_se_atten.py ref_dout2 = [ - 1.3503570575883254e-04, - -9.3606804794552518e-05, - -9.3606804794552518e-05, - 6.4931435609575354e-05, - -3.4432462227712845e-04, - 2.3883309310633266e-04, - -2.1612770334269806e-04, - 1.4980041766865035e-04, - 5.1902342465554648e-04, - -3.5995814159000579e-04, - 1.0061650355705337e-04, - -7.5148260042556979e-05, - -7.5148260042556979e-05, - 5.6249549384058458e-05, - -2.7820514647114664e-04, - 2.0819618461713165e-04, - -1.5698895407951743e-04, - 1.1721016363267746e-04, - 4.0972585703616773e-04, - -3.0650763759131061e-04, - 7.5599650998659526e-05, - -5.8808888720672558e-05, - -5.8808888720672558e-05, - 4.5766209906762655e-05, - -2.1712714013251668e-04, - 1.6899894453623564e-04, - -1.2167120597162636e-04, - 9.4648599144861605e-05, - 3.2200758382615601e-04, - 
-2.5060486486718734e-04, - 1.1293831101452813e-04, - -7.9512063028041913e-05, - -7.9512063028041913e-05, - 5.5979262682797850e-05, - -2.9058515610909440e-04, - 2.0457554106366365e-04, - -1.8732839505532627e-04, - 1.3188376232775540e-04, - 4.4448730317793450e-04, - -3.1292650304617497e-04, - 1.3015885894252541e-04, - -8.8816609587789126e-05, - -8.8816609587789126e-05, - 6.0613949400496957e-05, - -3.2308121544925519e-04, - 2.2046786823295058e-04, - -2.1781481424814687e-04, - 1.4862599684199924e-04, - 4.9955378034266583e-04, - -3.4089120488765758e-04, - 1.0160496779809329e-04, - -7.4538471222199861e-05, - -7.4538471222199861e-05, - 5.4703671679263269e-05, - -2.7394267959121653e-04, - 2.0103409637607701e-04, - -1.6657135958432620e-04, - 1.2219321453198225e-04, - 4.1344754259964935e-04, - -3.0339251136512270e-04, + 1.35077997858830281628e-04, + -9.36317565146126714985e-05, + -9.36317565146126714985e-05, + 6.49457155161046269156e-05, + -3.44426119482271894060e-04, + 2.38892351975707574810e-04, + -2.16192628113445024177e-04, + 1.49838432021978586618e-04, + 5.19172506251499308108e-04, + -3.60044742999178198160e-04, + 1.00648981900694042455e-04, + -7.51687985725674168679e-05, + -7.51687985725674168679e-05, + 5.62621404496089786633e-05, + -2.78288905170686305408e-04, + 2.08248552733448707985e-04, + -1.57037506111419247626e-04, + 1.17240613774749092711e-04, + 4.09846227953978995209e-04, + -3.06582508385239355716e-04, + 7.56236313388503977959e-05, + -5.88249954799233110928e-05, + -5.88249954799233110928e-05, + 4.57767614608878164778e-05, + -2.17191782618980676941e-04, + 1.69041932410352632298e-04, + -1.21708419050609283887e-04, + 9.46734475047640323129e-05, + 3.22101565810662901230e-04, + -2.50667145896081176772e-04, + 1.12972766463605449241e-04, + -7.95331652304217509748e-05, + -7.95331652304217509748e-05, + 5.59918979793375151091e-05, + -2.90669309441163412500e-04, + 2.04626666596480422588e-04, + -1.87383581443938113499e-04, + 1.31917380775058677711e-04, + 4.44613289651917854839e-04, + -3.13002780120454830552e-04, + 1.30198051172878586420e-04, + -8.88399346622230731045e-05, + -8.88399346622230731045e-05, + 6.06275354032895547767e-05, + -3.23173886613725041324e-04, + 2.20522620462074609186e-04, + -2.17878181114203837987e-04, + 1.48663514408247710887e-04, + 4.99693951217273298233e-04, + -3.40973735611388808521e-04, + 1.01636483586918407768e-04, + -7.45585238544824841465e-05, + -7.45585238544824841465e-05, + 5.47161372646580776566e-05, + -2.74022957033491422908e-04, + 2.01084733576426032218e-04, + -1.66621218118959135701e-04, + 1.22224760787930633501e-04, + 4.13566215420014648540e-04, + -3.03467107774532218571e-04, ] places = 10 diff --git a/source/tests/tf/test_descrpt_se_atten.py b/source/tests/tf/test_descrpt_se_atten.py index 7a1bfd18f6..f4ce374c42 100644 --- a/source/tests/tf/test_descrpt_se_atten.py +++ b/source/tests/tf/test_descrpt_se_atten.py @@ -150,66 +150,66 @@ def test_descriptor_two_sides(self): np.savetxt("two.out", model_dout.reshape([1, -1]), delimiter=",") ref_dout = [ - 1.3503570575883254e-04, - -9.3606804794552518e-05, - -9.3606804794552518e-05, - 6.4931435609575354e-05, - -3.4432462227712845e-04, - 2.3883309310633266e-04, - -2.1612770334269806e-04, - 1.4980041766865035e-04, - 5.1902342465554648e-04, - -3.5995814159000579e-04, - 1.0061650355705337e-04, - -7.5148260042556979e-05, - -7.5148260042556979e-05, - 5.6249549384058458e-05, - -2.7820514647114664e-04, - 2.0819618461713165e-04, - -1.5698895407951743e-04, - 1.1721016363267746e-04, - 4.0972585703616773e-04, - 
-3.0650763759131061e-04, - 7.5599650998659526e-05, - -5.8808888720672558e-05, - -5.8808888720672558e-05, - 4.5766209906762655e-05, - -2.1712714013251668e-04, - 1.6899894453623564e-04, - -1.2167120597162636e-04, - 9.4648599144861605e-05, - 3.2200758382615601e-04, - -2.5060486486718734e-04, - 1.1293831101452813e-04, - -7.9512063028041913e-05, - -7.9512063028041913e-05, - 5.5979262682797850e-05, - -2.9058515610909440e-04, - 2.0457554106366365e-04, - -1.8732839505532627e-04, - 1.3188376232775540e-04, - 4.4448730317793450e-04, - -3.1292650304617497e-04, - 1.3015885894252541e-04, - -8.8816609587789126e-05, - -8.8816609587789126e-05, - 6.0613949400496957e-05, - -3.2308121544925519e-04, - 2.2046786823295058e-04, - -2.1781481424814687e-04, - 1.4862599684199924e-04, - 4.9955378034266583e-04, - -3.4089120488765758e-04, - 1.0160496779809329e-04, - -7.4538471222199861e-05, - -7.4538471222199861e-05, - 5.4703671679263269e-05, - -2.7394267959121653e-04, - 2.0103409637607701e-04, - -1.6657135958432620e-04, - 1.2219321453198225e-04, - 4.1344754259964935e-04, - -3.0339251136512270e-04, + 1.35077997858830281628e-04, + -9.36317565146126714985e-05, + -9.36317565146126714985e-05, + 6.49457155161046269156e-05, + -3.44426119482271894060e-04, + 2.38892351975707574810e-04, + -2.16192628113445024177e-04, + 1.49838432021978586618e-04, + 5.19172506251499308108e-04, + -3.60044742999178198160e-04, + 1.00648981900694042455e-04, + -7.51687985725674168679e-05, + -7.51687985725674168679e-05, + 5.62621404496089786633e-05, + -2.78288905170686305408e-04, + 2.08248552733448707985e-04, + -1.57037506111419247626e-04, + 1.17240613774749092711e-04, + 4.09846227953978995209e-04, + -3.06582508385239355716e-04, + 7.56236313388503977959e-05, + -5.88249954799233110928e-05, + -5.88249954799233110928e-05, + 4.57767614608878164778e-05, + -2.17191782618980676941e-04, + 1.69041932410352632298e-04, + -1.21708419050609283887e-04, + 9.46734475047640323129e-05, + 3.22101565810662901230e-04, + -2.50667145896081176772e-04, + 1.12972766463605449241e-04, + -7.95331652304217509748e-05, + -7.95331652304217509748e-05, + 5.59918979793375151091e-05, + -2.90669309441163412500e-04, + 2.04626666596480422588e-04, + -1.87383581443938113499e-04, + 1.31917380775058677711e-04, + 4.44613289651917854839e-04, + -3.13002780120454830552e-04, + 1.30198051172878586420e-04, + -8.88399346622230731045e-05, + -8.88399346622230731045e-05, + 6.06275354032895547767e-05, + -3.23173886613725041324e-04, + 2.20522620462074609186e-04, + -2.17878181114203837987e-04, + 1.48663514408247710887e-04, + 4.99693951217273298233e-04, + -3.40973735611388808521e-04, + 1.01636483586918407768e-04, + -7.45585238544824841465e-05, + -7.45585238544824841465e-05, + 5.47161372646580776566e-05, + -2.74022957033491422908e-04, + 2.01084733576426032218e-04, + -1.66621218118959135701e-04, + 1.22224760787930633501e-04, + 4.13566215420014648540e-04, + -3.03467107774532218571e-04, ] places = 10 @@ -328,66 +328,66 @@ def test_descriptor_one_side(self): np.savetxt("one.out", model_dout.reshape([1, -1]), delimiter=",") ref_dout = [ - 8.9336098555659429e-05, - -3.8921422089719007e-05, - -3.8921422089719007e-05, - 1.6975109833017758e-05, - -2.9184951813034413e-04, - 1.2724836941382651e-04, - -1.8062533253590169e-04, - 7.8681048972093648e-05, - 4.2206017420030542e-04, - -1.8398310612921889e-04, - 6.4996467281506633e-05, - -3.0812041327073575e-05, - -3.0812041327073575e-05, - 1.4663988013438402e-05, - -2.3274950984084172e-04, - 1.1059587214865573e-04, - -1.3043761448464089e-04, - 6.1788865409826698e-05, - 
3.2900269837104958e-04, - -1.5623668424484728e-04, - 5.0697927477465942e-05, - -2.3511768544350768e-05, - -2.3511768544350768e-05, - 1.0919808814040025e-05, - -1.8622373494960208e-04, - 8.6439275444049409e-05, - -1.0326450661269683e-04, - 4.7880797898768150e-05, - 2.6230208262918372e-04, - -1.2172811361250681e-04, - 7.8240863239649707e-05, - -3.2501260967978116e-05, - -3.2501260967978116e-05, - 1.3502267073810926e-05, - -2.5360559687597850e-04, - 1.0535336854834091e-04, - -1.6047265448841568e-04, - 6.6660202062744658e-05, - 3.6833864909272261e-04, - -1.5301457671691837e-04, - 9.1148582997925288e-05, - -3.6614945467066073e-05, - -3.6614945467066073e-05, - 1.4709958908948206e-05, - -2.8364168092837332e-04, - 1.1394466218003484e-04, - -1.8721615730559043e-04, - 7.5203967811613109e-05, - 4.1632420070310456e-04, - -1.6724364343353009e-04, - 6.9506193268190631e-05, - -3.0228106532898472e-05, - -3.0228106532898472e-05, - 1.3156705594652870e-05, - -2.3740975974826574e-04, - 1.0328972070195332e-04, - -1.4218547815143072e-04, - 6.1827596642872941e-05, - 3.4031715116440432e-04, - -1.4804591640658066e-04, + 8.93630739076099766573e-05, + -3.89301763666544977088e-05, + -3.89301763666544977088e-05, + 1.69776207161541659875e-05, + -2.91934413405367434308e-04, + 1.27275579758193970945e-04, + -1.80678576267614851526e-04, + 7.86981804444128273503e-05, + 4.22180092132026806885e-04, + -1.84021204552106459797e-04, + 6.50166826308631336630e-05, + -3.08191630112232239067e-05, + -3.08191630112232239067e-05, + 1.46662082284045218266e-05, + -2.32818649311590855893e-04, + 1.10619882905346373389e-04, + -1.30477133579203922803e-04, + 6.18026466291577325669e-05, + 3.29098263271154821506e-04, + -1.56269574751685376771e-04, + 5.07138199677916164739e-05, + -2.35171440781703185510e-05, + -2.35171440781703185510e-05, + 1.09213797907981395710e-05, + -1.86279366618262112341e-04, + 8.64577620996147407865e-05, + -1.03296053419269992513e-04, + 4.78913622480582772448e-05, + 2.62378744147910732392e-04, + -1.21753360060300813640e-04, + 7.82644227540903690814e-05, + -3.25084361414888650958e-05, + -3.25084361414888650958e-05, + 1.35041631983765535098e-05, + -2.53679234140297192677e-04, + 1.05375493947693795707e-04, + -1.60519879294703589519e-04, + 6.66744631236456129558e-05, + 3.68443126822399244329e-04, + -1.53045684128227086913e-04, + 9.11756668850765601567e-05, + -3.66229408732609030826e-05, + -3.66229408732609030826e-05, + 1.47120125015788778301e-05, + -2.83723246380394433092e-04, + 1.13968452838666050924e-04, + -1.87270570170312914944e-04, + 7.52199008968667767218e-05, + 4.16441090538891684186e-04, + -1.67277425363850822723e-04, + 6.95274814976590320665e-05, + -3.02348814013024743688e-05, + -3.02348814013024743688e-05, + 1.31585743503078956499e-05, + -2.37479534432029007343e-04, + 1.03311591705779548338e-04, + -1.42227987950226271961e-04, + 6.18410015482571886070e-05, + 3.40414922285898623351e-04, + -1.48076286203042110793e-04, ] places = 10 @@ -499,66 +499,66 @@ def test_stripped_type_embedding_descriptor_two_sides(self): np.savetxt("two1.out", model_dout.reshape([1, -1]), delimiter=",") ref_dout = [ - 2.910296358673981606e-06, - -3.297689549631518680e-05, - -3.297689549631518680e-05, - 3.790996417030466402e-04, - -3.082208958603667925e-05, - 3.544004728264616810e-04, - -2.397997896082787038e-05, - 2.744923480535521121e-04, - 8.486866768450577558e-05, - -9.750155670867453753e-04, - 8.680391572974659491e-07, - -1.596948473518331016e-05, - -1.596948473518331016e-05, - 3.249686279109944903e-04, - -1.508338456375446526e-05, - 
3.070479490395221158e-04, - -1.047241469038003787e-05, - 2.085462014454144320e-04, - 4.065724483202033993e-05, - -8.245932936607477210e-04, - 5.959146184656097397e-07, - -1.265847984116858078e-05, - -1.265847984116858078e-05, - 2.713109337202710531e-04, - -1.163070862097512446e-05, - 2.491582022684395484e-04, - -8.056716526966370043e-06, - 1.720174894426871476e-04, - 3.174999037064446555e-05, - -6.798281455902291598e-04, - 3.145148216891492605e-06, - -3.245585831548520087e-05, - -3.245585831548520087e-05, - 3.350745140453206166e-04, - -2.936281422860278914e-05, - 3.031890775924862423e-04, - -2.408578375619038739e-05, - 2.487530226589902390e-04, - 8.275930808338685728e-05, - -8.545607559813118157e-04, - 4.745334138737575192e-06, - -4.149649152356857482e-05, - -4.149649152356857482e-05, - 3.633282453063247882e-04, - -3.734652895210441184e-05, - 3.270295126452897193e-04, - -3.235347865588130865e-05, - 2.832387658145111447e-04, - 1.064511649928167193e-04, - -9.321000322425568741e-04, - 1.879347284602219830e-06, - -2.470327295060103235e-05, - -2.470327295060103235e-05, - 3.269344178119031551e-04, - -2.248434624179290029e-05, - 2.975826199248595046e-04, - -1.721291645154368551e-05, - 2.273800448313684436e-04, - 6.252118835933537862e-05, - -8.271938096175299659e-04, + 2.91097766899578214544e-06, + -3.29852641315371480153e-05, + -3.29852641315371480153e-05, + 3.79203396610324763253e-04, + -3.08296489918391639377e-05, + 3.54494448654088176176e-04, + -2.39859951795545287153e-05, + 2.74566675797922735754e-04, + 8.48899306339350606405e-05, + -9.75279256930798588154e-04, + 8.68233546069119197236e-07, + -1.59734540671145569350e-05, + -1.59734540671145569350e-05, + 3.25058299172223158675e-04, + -1.50870029997722798618e-05, + 3.07130006247707560297e-04, + -1.04749968193353404274e-05, + 2.08603290940140382453e-04, + 4.06672203401530534743e-05, + -8.24818142292956771496e-04, + 5.96048958156013435895e-07, + -1.26616643393676577874e-05, + -1.26616643393676577874e-05, + 2.71386217904519277955e-04, + -1.16335252819255226156e-05, + 2.49225002219057890918e-04, + -8.05872731607348350672e-06, + 1.72064906604221990903e-04, + 3.17578679792106490973e-05, + -6.80014462388431415590e-04, + 3.14589844246059013866e-06, + -3.24641804781093787271e-05, + -3.24641804781093787271e-05, + 3.35166446053445504782e-04, + -2.93700743352437964023e-05, + 3.03269488552582232397e-04, + -2.40918900326344598056e-05, + 2.48820558204534102165e-04, + 8.27802464035270346319e-05, + -8.54792312332452379302e-04, + 4.74647063755037437353e-06, + -4.15071266538516597008e-05, + -4.15071266538516597008e-05, + 3.63427481731051901445e-04, + -3.73557622901099313961e-05, + 3.27115874272415044135e-04, + -3.23616690622182231118e-05, + 2.83315238433851622219e-04, + 1.06478087368629440682e-04, + -9.32351467783467118162e-04, + 1.87979034371445873837e-06, + -2.47095892917853045061e-05, + -2.47095892917853045061e-05, + 3.27024569668371480752e-04, + -2.24898874228677589208e-05, + 2.97661928194053256209e-04, + -1.72172753256989610575e-05, + 2.27442187831376464941e-04, + 6.25369616966375661696e-05, + -8.27419096402015846574e-04, ] places = 10 diff --git a/source/tests/tf/test_model_se_atten.py b/source/tests/tf/test_model_se_atten.py index d75dc0cfff..8d6c5afa4c 100644 --- a/source/tests/tf/test_model_se_atten.py +++ b/source/tests/tf/test_model_se_atten.py @@ -155,37 +155,37 @@ def test_model(self): np.savetxt("f.out", f.reshape([1, -1]), delimiter=",") np.savetxt("v.out", v.reshape([1, -1]), delimiter=",") - refe = [6.121172052273667e01] + refe = 
[6.12116933882038480874e01] reff = [ - 1.1546857028815118e-02, - 1.7560407103242779e-02, - 7.1301778864729290e-04, - 2.3682630974376197e-02, - 1.6842732518204180e-02, - -2.2408109608703206e-03, - -7.9408568690697776e-03, - 9.6856119564082792e-03, - 1.9055514693144326e-05, - 8.7017502459205160e-03, - -2.7153030569749256e-02, - -8.8338555421916490e-04, - -4.3841165945453904e-02, - 5.8104108317526765e-03, - 2.6243178542006552e-03, - 7.8507845654118558e-03, - -2.2746131839858654e-02, - -2.3219464245160639e-04, + 1.15647422509625782266e-02, + 1.75814420518816301453e-02, + 7.13827966845788537686e-04, + 2.37054385620869625950e-02, + 1.68638656843611636771e-02, + -2.24281688803243482028e-03, + -7.95826529019691246425e-03, + 9.69601584607941019422e-03, + 1.91505445834121360688e-05, + 8.71431743822387999687e-03, + -2.71847766570148252629e-02, + -8.84489238783629392812e-04, + -4.38853499152154838403e-02, + 5.81882595276563344133e-03, + 2.62678184040523532775e-03, + 7.85911695413897895546e-03, + -2.27753728780730156644e-02, + -2.32454225018371036246e-04, ] refv = [ - -0.10488160947198523, - 0.016694308932682225, - 0.003444164500535988, - 0.016694308932682235, - -0.05415326614376374, - -0.0010792017166882334, - 0.003444164500535988, - -0.001079201716688233, - -0.00020932681975049773, + -1.05000199239240685212e-01, + 1.67161895068729665637e-02, + 3.44771431604021759421e-03, + 1.67161895068729804414e-02, + -5.42193765251950954287e-02, + -1.08055824874348513701e-03, + 3.44771431604021802789e-03, + -1.08055824874348470332e-03, + -2.09534775642288966522e-04, ] refe = np.reshape(refe, [-1]) @@ -618,37 +618,37 @@ def test_stripped_type_embedding_model(self): np.savetxt("f.out", f.reshape([1, -1]), delimiter=",") np.savetxt("v.out", v.reshape([1, -1]), delimiter=",") - refe = [6.124119974943835132e01] + refe = [6.12411774224343261608e01] reff = [ - 8.617444257623986525e-03, - 1.622774527785437321e-02, - 7.219537519817814273e-04, - 2.465257480331137924e-02, - 1.507377800325802181e-02, - -2.267846199393293988e-03, - -6.217685260668888089e-03, - 9.187965356558825195e-03, - -2.082402632037372596e-05, - 6.179226045047841662e-03, - -2.505229190184387472e-02, - -7.834051085801594424e-04, - -4.104669576212031240e-02, - 4.721690416727373704e-03, - 2.565744238275521286e-03, - 7.815135916805987862e-03, - -2.015888715255471572e-02, - -2.156226559634751916e-04, + 8.63770820796567855016e-03, + 1.62522026393666710331e-02, + 7.22919459978568399415e-04, + 2.46800946249053909654e-02, + 1.50982535714741239463e-02, + -2.27024703144847314271e-03, + -6.23780053459390554371e-03, + 9.20020798171328375858e-03, + -2.07267961671176406842e-05, + 6.19326848220238136006e-03, + -2.50892401262326376898e-02, + -7.84679030459893762407e-04, + -4.10982573216296109830e-02, + 4.73129070889465389027e-03, + 2.56865811814534347399e-03, + 7.82498654115005611021e-03, + -2.01927147752160932037e-02, + -2.15924720048428232235e-04, ] refv = [ - -8.500718686149139058e-02, - 1.389198522732191729e-02, - 3.059204598073241802e-03, - 1.389198522732190515e-02, - -4.908897840490741848e-02, - -9.530658829897693113e-04, - 3.059204598073239634e-03, - -9.530658829897692029e-04, - -1.999114402095244223e-04, + -8.51412082980419621103e-02, + 1.39169542815959605686e-02, + 3.06329019931955021105e-03, + 1.39169542815959553644e-02, + -4.91657098529515099172e-02, + -9.54629874035841340107e-04, + 3.06329019931954847633e-03, + -9.54629874035841340107e-04, + -2.00155580095981352464e-04, ] refe = np.reshape(refe, [-1]) From ddd62e86d1e53417bb5d9f2c02d15208cc06cee1 Mon Sep 17 
00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Tue, 23 Apr 2024 14:16:51 +0800 Subject: [PATCH 12/18] update docs --- deepmd/pt/model/descriptor/dpa1.py | 4 +- deepmd/pt/model/descriptor/se_atten.py | 145 +++++++++++++++++++------ 2 files changed, 113 insertions(+), 36 deletions(-) diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py index e4358de0dc..5a91bb15af 100644 --- a/deepmd/pt/model/descriptor/dpa1.py +++ b/deepmd/pt/model/descriptor/dpa1.py @@ -450,9 +450,9 @@ def forward( Parameters ---------- - coord_ext + extended_coord The extended coordinates of atoms. shape: nf x (nallx3) - atype_ext + extended_atype The extended aotm types. shape: nf x nall nlist The neighbor list. shape: nf x nloc x nnei diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index 15cfe44962..8eb1f6c2d9 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -67,8 +67,7 @@ def __init__( axis_neuron: int = 16, tebd_dim: int = 8, tebd_input_mode: str = "concat", - # set_davg_zero: bool = False, - set_davg_zero: bool = True, # TODO + set_davg_zero: bool = True, attn: int = 128, attn_layer: int = 2, attn_dotr: bool = True, @@ -87,14 +86,65 @@ def __init__( type: Optional[str] = None, old_impl: bool = False, ): - """Construct an embedding net of type `se_atten`. - - Args: - - rcut: Cut-off radius. - - rcut_smth: Smooth hyper-parameter for pair force & energy. - - sel: For each element type, how many atoms is selected as neighbors. - - filter_neuron: Number of neurons in each hidden layers of the embedding net. - - axis_neuron: Number of columns of the sub-matrix of the embedding matrix. + r"""Construct an embedding net of type `se_atten`. + + Parameters + ---------- + rcut : float + The cut-off radius :math:`r_c` + rcut_smth : float + From where the environment matrix should be smoothed :math:`r_s` + sel : list[int], int + list[int]: sel[i] specifies the maxmum number of type i atoms in the cut-off radius + int: the total maxmum number of atoms in the cut-off radius + ntypes : int + Number of element types + neuron : list[int] + Number of neurons in each hidden layers of the embedding net :math:`\mathcal{N}` + axis_neuron : int + Number of the axis neuron :math:`M_2` (number of columns of the sub-matrix of the embedding matrix) + tebd_dim : int + Dimension of the type embedding + tebd_input_mode : str + The way to mix the type embeddings. Supported options are `concat`. + (TODO need to support stripped_type_embedding option) + resnet_dt : bool + Time-step `dt` in the resnet construction: + y = x + dt * \phi (Wx + b) + trainable_ln : bool + Whether to use trainable shift and scale weights in layer normalization. + type_one_side : bool + If 'False', type embeddings of both neighbor and central atoms are considered. + If 'True', only type embeddings of neighbor atoms are considered. + Default is 'False'. + attn : int + Hidden dimension of the attention vectors + attn_layer : int + Number of attention layers + attn_dotr : bool + If dot the angular gate to the attention weights + attn_mask : bool + (Only support False to keep consistent with other backend references.) + (Not used in this version.) + If mask the diagonal of attention weights + exclude_types : List[List[int]] + The excluded pairs of types which have no interaction with each other. + For example, `[[0, 1]]` means no interaction between type 0 and type 1. 
+ env_protection : float + Protection parameter to prevent division by zero errors during environment matrix calculations. + set_davg_zero : bool + Set the shift of embedding net input to zero. + activation_function : str + The activation function in the embedding net. Supported options are |ACTIVATION_FN| + precision : str + The precision of the embedding net parameters. Supported options are |PRECISION| + scaling_factor : float + The scaling factor of normalization in calculations of attention weights. + If `temperature` is None, the scaling of attention weights is (N_dim * scaling_factor)**0.5 + normalize : bool + Whether to normalize the hidden vectors in attention weights calculation. + temperature : float + If not None, the scaling of attention weights is `temperature` itself. """ super().__init__() del type @@ -343,18 +393,37 @@ def forward( extended_atype_embd: Optional[torch.Tensor] = None, mapping: Optional[torch.Tensor] = None, ): - """Calculate decoded embedding for each atom. + """Compute the descriptor. - Args: - - coord: Tell atom coordinates with shape [nframes, natoms[1]*3]. - - atype: Tell atom types with shape [nframes, natoms[1]]. - - natoms: Tell atom count and element count. Its shape is [2+self.ntypes]. - - box: Tell simulation box with shape [nframes, 9]. + Parameters + ---------- + nlist + The neighbor list. shape: nf x nloc x nnei + extended_coord + The extended coordinates of atoms. shape: nf x (nallx3) + extended_atype + The extended aotm types. shape: nf x nall x nt + extended_atype_embd + The extended type embedding of atoms. shape: nf x nall + mapping + The index mapping, not required by this descriptor. Returns ------- - - result: descriptor with shape [nframes, nloc, self.filter_neuron[-1] * self.axis_neuron]. - - ret: environment matrix with shape [nframes, nloc, self.neei, out_size] + result + The descriptor. shape: nf x nloc x (ng x axis_neuron) + g2 + The rotationally invariant pair-partical representation. + shape: nf x nloc x nnei x ng + h2 + The rotationally equivariant pair-partical representation. + shape: nf x nloc x nnei x 3 + gr + The rotationally equivariant and permutationally invariant single particle + representation. shape: nf x nloc x ng x 3 + sw + The smooth switch function. shape: nf x nloc x nnei + """ del mapping assert extended_atype_embd is not None @@ -516,15 +585,18 @@ def forward( input_r: Optional[torch.Tensor] = None, sw: Optional[torch.Tensor] = None, ): - """ - Args: - input_G: Input G, [nframes * nloc, nnei, embed_dim]. - nei_mask: neighbor mask, [nframes * nloc, nnei]. - input_r: normalized radial, [nframes, nloc, nei, 3]. + """Compute the multi-layer gated self-attention. - Returns - ------- - out: Output G, [nframes * nloc, nnei, embed_dim] + Parameters + ---------- + input_G + inputs with shape: (nf x nloc) x nnei x embed_dim. + nei_mask + neighbor mask, with paddings being 0. shape: (nf x nloc) x nnei. + input_r + normalized radial. shape: (nf x nloc) x nnei x 3. + sw + The smooth switch function. shape: nf x nloc x nnei """ out = input_G # https://github.com/pytorch/pytorch/issues/39165#issuecomment-635472592 @@ -749,15 +821,20 @@ def forward( sw: Optional[torch.Tensor] = None, attnw_shift: float = 20.0, ): - """ - Args: - query: input G, [nframes * nloc, nnei, embed_dim]. - nei_mask: neighbor mask, [nframes * nloc, nnei]. - input_r: normalized radial, [nframes, nloc, nei, 3]. + """Compute the gated self-attention. 
- Returns - ------- - type_embedding: + Parameters + ---------- + query + inputs with shape: (nf x nloc) x nnei x embed_dim. + nei_mask + neighbor mask, with paddings being 0. shape: (nf x nloc) x nnei. + input_r + normalized radial. shape: (nf x nloc) x nnei x 3. + sw + The smooth switch function. shape: (nf x nloc) x nnei + attnw_shift : float + The attention weight shift to preserve smoothness when doing padding before softmax. """ q, k, v = self.in_proj(query).chunk(3, dim=-1) # [nframes * nloc, nnei, hidden_dim] From 859a2e5fbbdfaa6d1bb9ef913f5e8f451c53fb5f Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Tue, 23 Apr 2024 14:30:49 +0800 Subject: [PATCH 13/18] Update deepmd/dpmodel/descriptor/dpa1.py Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Signed-off-by: Duo <50307526+iProzd@users.noreply.github.com> --- deepmd/dpmodel/descriptor/dpa1.py | 1 - 1 file changed, 1 deletion(-) diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index cbbc124267..7cf3b0ce9f 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -244,7 +244,6 @@ def __init__( if tebd_input_mode != "concat": raise NotImplementedError("tebd_input_mode != 'concat' not implemented") - del attn_mask, spin self.rcut = rcut self.rcut_smth = rcut_smth if isinstance(sel, int): From 70c9f1777f360768cfe2f0fcc2f9f5df68e37f14 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Tue, 23 Apr 2024 14:33:33 +0800 Subject: [PATCH 14/18] cleanup `analyze_descrpt` --- deepmd/pt/model/descriptor/se_a.py | 31 -------------- deepmd/pt/model/descriptor/se_atten.py | 59 -------------------------- 2 files changed, 90 deletions(-) diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index e17b7c5d54..8b83f0d27b 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -624,34 +624,3 @@ def forward( None, sw, ) - - -def analyze_descrpt(matrix, ndescrpt, natoms): - """Collect avg, square avg and count of descriptors in a batch.""" - ntypes = natoms.shape[1] - 2 - start_index = 0 - sysr = [] - sysa = [] - sysn = [] - sysr2 = [] - sysa2 = [] - for type_i in range(ntypes): - end_index = start_index + natoms[0, 2 + type_i] - dd = matrix[:, start_index:end_index] # all descriptors for this element - start_index = end_index - dd = np.reshape( - dd, [-1, 4] - ) # Shape is [nframes*natoms[2+type_id]*self.nnei, 4] - ddr = dd[:, :1] - dda = dd[:, 1:] - sumr = np.sum(ddr) - suma = np.sum(dda) / 3.0 - sumn = dd.shape[0] # Value is nframes*natoms[2+type_id]*self.nnei - sumr2 = np.sum(np.multiply(ddr, ddr)) - suma2 = np.sum(np.multiply(dda, dda)) / 3.0 - sysr.append(sumr) - sysa.append(suma) - sysn.append(sumn) - sysr2.append(sumr2) - sysa2.append(suma2) - return sysr, sysr2, sysa, sysa2, sysn diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index 8eb1f6c2d9..73a974cc4e 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -8,7 +8,6 @@ Union, ) -import numpy as np import torch import torch.nn as nn import torch.nn.functional as torch_func @@ -917,61 +916,3 @@ def deserialize(cls, data: dict) -> "GatedAttentionLayer": obj.in_proj = MLPLayer.deserialize(in_proj) obj.out_proj = MLPLayer.deserialize(out_proj) return obj - - -def analyze_descrpt(matrix, ndescrpt, natoms, mixed_types=False, real_atype=None): - """Collect avg, square avg and 
count of descriptors in a batch.""" - ntypes = natoms.shape[1] - 2 - if not mixed_types: - sysr = [] - sysa = [] - sysn = [] - sysr2 = [] - sysa2 = [] - start_index = 0 - for type_i in range(ntypes): - end_index = start_index + natoms[0, 2 + type_i] - dd = matrix[:, start_index:end_index] - start_index = end_index - dd = np.reshape( - dd, [-1, 4] - ) # Shape is [nframes*natoms[2+type_id]*self.nnei, 4] - ddr = dd[:, :1] - dda = dd[:, 1:] - sumr = np.sum(ddr) - suma = np.sum(dda) / 3.0 - sumn = dd.shape[0] # Value is nframes*natoms[2+type_id]*self.nnei - sumr2 = np.sum(np.multiply(ddr, ddr)) - suma2 = np.sum(np.multiply(dda, dda)) / 3.0 - sysr.append(sumr) - sysa.append(suma) - sysn.append(sumn) - sysr2.append(sumr2) - sysa2.append(suma2) - else: - sysr = [0.0 for i in range(ntypes)] - sysa = [0.0 for i in range(ntypes)] - sysn = [0 for i in range(ntypes)] - sysr2 = [0.0 for i in range(ntypes)] - sysa2 = [0.0 for i in range(ntypes)] - for frame_item in range(matrix.shape[0]): - dd_ff = matrix[frame_item] - atype_frame = real_atype[frame_item] - for type_i in range(ntypes): - type_idx = atype_frame == type_i - dd = dd_ff[type_idx] - dd = np.reshape(dd, [-1, 4]) # typen_atoms * nnei, 4 - ddr = dd[:, :1] - dda = dd[:, 1:] - sumr = np.sum(ddr) - suma = np.sum(dda) / 3.0 - sumn = dd.shape[0] - sumr2 = np.sum(np.multiply(ddr, ddr)) - suma2 = np.sum(np.multiply(dda, dda)) / 3.0 - sysr[type_i] += sumr - sysa[type_i] += suma - sysn[type_i] += sumn - sysr2[type_i] += sumr2 - sysa2[type_i] += suma2 - - return sysr, sysr2, sysa, sysa2, sysn From 7ce15f022f4be58a8f382392be7bc38b5dd3a6ba Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Wed, 24 Apr 2024 00:34:19 +0800 Subject: [PATCH 15/18] add dpmodel ut --- deepmd/dpmodel/descriptor/__init__.py | 4 ++ .../common/dpmodel/test_descriptor_dpa1.py | 35 ++++++++++++++++ source/tests/pt/model/test_dpa1.py | 41 +++++++++++++------ 3 files changed, 67 insertions(+), 13 deletions(-) create mode 100644 source/tests/common/dpmodel/test_descriptor_dpa1.py diff --git a/deepmd/dpmodel/descriptor/__init__.py b/deepmd/dpmodel/descriptor/__init__.py index a19a2aa034..bbc332588c 100644 --- a/deepmd/dpmodel/descriptor/__init__.py +++ b/deepmd/dpmodel/descriptor/__init__.py @@ -1,4 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from .dpa1 import ( + DescrptDPA1, +) from .hybrid import ( DescrptHybrid, ) @@ -15,6 +18,7 @@ __all__ = [ "DescrptSeA", "DescrptSeR", + "DescrptDPA1", "DescrptHybrid", "make_base_descriptor", ] diff --git a/source/tests/common/dpmodel/test_descriptor_dpa1.py b/source/tests/common/dpmodel/test_descriptor_dpa1.py new file mode 100644 index 0000000000..d50e0cba86 --- /dev/null +++ b/source/tests/common/dpmodel/test_descriptor_dpa1.py @@ -0,0 +1,35 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import numpy as np + +from deepmd.dpmodel.descriptor import ( + DescrptDPA1, +) + +from .case_single_frame_with_nlist import ( + TestCaseSingleFrameWithNlist, +) + + +class TestDescrptDPA1(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self): + TestCaseSingleFrameWithNlist.setUp(self) + + def test_self_consistency( + self, + ): + rng = np.random.default_rng() + nf, nloc, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + em0 = DescrptDPA1(self.rcut, self.rcut_smth, self.sel, ntypes=2) + em0.davg = davg + em0.dstd = dstd + em1 = DescrptDPA1.deserialize(em0.serialize()) + mm0 = 
em0.call(self.coord_ext, self.atype_ext, self.nlist) + mm1 = em1.call(self.coord_ext, self.atype_ext, self.nlist) + for ii in [0, 1, 4]: + np.testing.assert_allclose(mm0[ii], mm1[ii]) diff --git a/source/tests/pt/model/test_dpa1.py b/source/tests/pt/model/test_dpa1.py index 7a08ecc826..7567f18593 100644 --- a/source/tests/pt/model/test_dpa1.py +++ b/source/tests/pt/model/test_dpa1.py @@ -5,6 +5,7 @@ import numpy as np import torch +from deepmd.dpmodel.descriptor.dpa1 import DescrptDPA1 as DPDescrptDPA1 from deepmd.pt.model.descriptor.dpa1 import ( DescrptDPA1, ) @@ -82,9 +83,23 @@ def test_consistency( atol=atol, err_msg=err_msg, ) + # dp impl + dd2 = DPDescrptDPA1.deserialize(dd0.serialize()) + rd2, _, _, _, _ = dd2.call( + self.coord_ext, + self.atype_ext, + self.nlist, + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd2, + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) # old impl if idt is False and prec == "float64" and to is False: - dd2 = DescrptDPA1( + dd3 = DescrptDPA1( self.rcut, self.rcut_smth, self.sel_mix, @@ -96,12 +111,12 @@ def test_consistency( old_impl=True, ).to(env.DEVICE) dd0_state_dict = dd0.se_atten.state_dict() - dd4_state_dict = dd2.se_atten.state_dict() + dd3_state_dict = dd3.se_atten.state_dict() dd0_state_dict_attn = dd0.se_atten.dpa1_attention.state_dict() - dd4_state_dict_attn = dd2.se_atten.dpa1_attention.state_dict() - for i in dd4_state_dict: - dd4_state_dict[i] = ( + dd3_state_dict_attn = dd3.se_atten.dpa1_attention.state_dict() + for i in dd3_state_dict: + dd3_state_dict[i] = ( dd0_state_dict[ i.replace(".deep_layers.", ".layers.") .replace("filter_layers_old.", "filter_layers._networks.") @@ -113,27 +128,27 @@ def test_consistency( .clone() ) if ".bias" in i and "attn_layer_norm" not in i: - dd4_state_dict[i] = dd4_state_dict[i].unsqueeze(0) - dd2.se_atten.load_state_dict(dd4_state_dict) + dd3_state_dict[i] = dd3_state_dict[i].unsqueeze(0) + dd3.se_atten.load_state_dict(dd3_state_dict) dd0_state_dict_tebd = dd0.type_embedding.state_dict() - dd4_state_dict_tebd = dd2.type_embedding.state_dict() - for i in dd4_state_dict_tebd: - dd4_state_dict_tebd[i] = ( + dd3_state_dict_tebd = dd3.type_embedding.state_dict() + for i in dd3_state_dict_tebd: + dd3_state_dict_tebd[i] = ( dd0_state_dict_tebd[i.replace("embedding.weight", "matrix")] .detach() .clone() ) - dd2.type_embedding.load_state_dict(dd4_state_dict_tebd) + dd3.type_embedding.load_state_dict(dd3_state_dict_tebd) - rd2, _, _, _, _ = dd2( + rd3, _, _, _, _ = dd3( torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE), torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE), torch.tensor(self.nlist, dtype=int, device=env.DEVICE), ) np.testing.assert_allclose( rd0.detach().cpu().numpy(), - rd2.detach().cpu().numpy(), + rd3.detach().cpu().numpy(), rtol=rtol, atol=atol, err_msg=err_msg, From 9662a2c97528448f6065deaddcc50ac87ffa91cc Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Wed, 24 Apr 2024 00:37:20 +0800 Subject: [PATCH 16/18] add docs for not-used param --- deepmd/dpmodel/descriptor/dpa1.py | 4 ++-- deepmd/pt/model/descriptor/dpa1.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index 7cf3b0ce9f..af58f8f2e2 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -156,7 +156,7 @@ class DescrptDPA1(NativeOP, BaseDescriptor): If dot the angular gate to the attention weights attn_mask: bool (Only support False to keep 
consistent with other backend references.) - (Not used in this version.) + (Not used in this version. True option is not implemented.) If mask the diagonal of attention weights exclude_types : List[List[int]] The excluded pairs of types which have no interaction with each other. @@ -182,7 +182,7 @@ class DescrptDPA1(NativeOP, BaseDescriptor): Whether to concat type embedding at the output of the descriptor. spin (Only support None to keep consistent with other backend references.) - (Not used in this version.) + (Not used in this version. Not-none option is not implemented.) The old implementation of deepspin. Limitations diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py index 5a91bb15af..d255d13a3e 100644 --- a/deepmd/pt/model/descriptor/dpa1.py +++ b/deepmd/pt/model/descriptor/dpa1.py @@ -138,7 +138,7 @@ class DescrptDPA1(BaseDescriptor, torch.nn.Module): If dot the angular gate to the attention weights attn_mask: bool (Only support False to keep consistent with other backend references.) - (Not used in this version.) + (Not used in this version. True option is not implemented.) If mask the diagonal of attention weights exclude_types : List[List[int]] The excluded pairs of types which have no interaction with each other. @@ -164,7 +164,7 @@ class DescrptDPA1(BaseDescriptor, torch.nn.Module): Whether to concat type embedding at the output of the descriptor. spin (Only support None to keep consistent with other backend references.) - (Not used in this version.) + (Not used in this version. Not-none option is not implemented.) The old implementation of deepspin. Limitations From 2a7c0c7f7d8702d9c43364679d584f5eeee82796 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Wed, 24 Apr 2024 01:14:47 +0800 Subject: [PATCH 17/18] Revert "Fix numeric tests for layernorm" This reverts commit 93cdd6a3007621a44781792dda805c7f5a7f4569. 
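The hunks below simply restore the previous hard-coded reference arrays in the
four TF regression tests listed in the diffstat. As a rough, illustrative sketch
of the check pattern those tests follow (the assertion call and variable names
here are assumptions for illustration; only the reference literals and
`places = 10` appear in the tests themselves):

    import numpy as np

    places = 10
    # Reference energy literal, as hard-coded in the restored test below.
    refe = np.array([6.121172052273665543e01])
    # Stand-in for the energy the model would compute in the real test.
    e = refe.copy()
    np.testing.assert_almost_equal(e.reshape([-1]), refe.reshape([-1]), places)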
--- source/tests/tf/test_data_large_batch.py | 112 +++---- source/tests/tf/test_descrpt_hybrid.py | 120 ++++---- source/tests/tf/test_descrpt_se_atten.py | 360 +++++++++++------------ source/tests/tf/test_model_se_atten.py | 112 +++---- 4 files changed, 352 insertions(+), 352 deletions(-) diff --git a/source/tests/tf/test_data_large_batch.py b/source/tests/tf/test_data_large_batch.py index baa173ad7a..dad6bbf252 100644 --- a/source/tests/tf/test_data_large_batch.py +++ b/source/tests/tf/test_data_large_batch.py @@ -201,37 +201,37 @@ def test_data_mixed_type(self): np.savetxt("f.out", f.reshape([1, -1]), delimiter=",") np.savetxt("v.out", v.reshape([1, -1]), delimiter=",") - refe = [6.12116933882038480874e01] + refe = [6.121172052273665543e01] reff = [ - 1.15647422509625782266e-02, - 1.75814420518816301453e-02, - 7.13827966845788537686e-04, - 2.37054385620869625950e-02, - 1.68638656843611636771e-02, - -2.24281688803243482028e-03, - -7.95826529019691246425e-03, - 9.69601584607941019422e-03, - 1.91505445834121360688e-05, - 8.71431743822387999687e-03, - -2.71847766570148252629e-02, - -8.84489238783629392812e-04, - -4.38853499152154838403e-02, - 5.81882595276563344133e-03, - 2.62678184040523532775e-03, - 7.85911695413897895546e-03, - -2.27753728780730156644e-02, - -2.32454225018371036246e-04, + 1.154685702881510720e-02, + 1.756040710324277901e-02, + 7.130177886472930130e-04, + 2.368263097437618356e-02, + 1.684273251820418010e-02, + -2.240810960870319706e-03, + -7.940856869069763679e-03, + 9.685611956408284387e-03, + 1.905551469314455948e-05, + 8.701750245920510801e-03, + -2.715303056974926327e-02, + -8.833855542191653386e-04, + -4.384116594545389017e-02, + 5.810410831752661764e-03, + 2.624317854200653062e-03, + 7.850784565411857499e-03, + -2.274613183985864026e-02, + -2.321946424516053086e-04, ] refv = [ - -1.05000199239240657456e-01, - 1.67161895068729630942e-02, - 3.44771431604021672684e-03, - 1.67161895068729769720e-02, - -5.42193765251950815509e-02, - -1.08055824874348557069e-03, - 3.44771431604021585948e-03, - -1.08055824874348492017e-03, - -2.09534775642289020732e-04, + -1.048816094719852016e-01, + 1.669430893268222804e-02, + 3.444164500535986783e-03, + 1.669430893268222110e-02, + -5.415326614376372166e-02, + -1.079201716688232750e-03, + 3.444164500535985916e-03, + -1.079201716688232750e-03, + -2.093268197504977288e-04, ] refe = np.reshape(refe, [-1]) @@ -400,37 +400,37 @@ def test_stripped_data_mixed_type(self): np.savetxt("f11.out", f.reshape([1, -1]), delimiter=",") np.savetxt("v11.out", v.reshape([1, -1]), delimiter=",") - refe = [6.12411774224343261608e01] + refe = [6.124119974943835132e01] reff = [ - 8.63770820796567855016e-03, - 1.62522026393666710331e-02, - 7.22919459978568399415e-04, - 2.46800946249053909654e-02, - 1.50982535714741239463e-02, - -2.27024703144847314271e-03, - -6.23780053459390554371e-03, - 9.20020798171328375858e-03, - -2.07267961671176406842e-05, - 6.19326848220238136006e-03, - -2.50892401262326376898e-02, - -7.84679030459893762407e-04, - -4.10982573216296109830e-02, - 4.73129070889465389027e-03, - 2.56865811814534347399e-03, - 7.82498654115005611021e-03, - -2.01927147752160932037e-02, - -2.15924720048428232235e-04, + 8.617444257623986525e-03, + 1.622774527785437321e-02, + 7.219537519817814273e-04, + 2.465257480331137924e-02, + 1.507377800325802181e-02, + -2.267846199393293988e-03, + -6.217685260668888089e-03, + 9.187965356558825195e-03, + -2.082402632037372596e-05, + 6.179226045047841662e-03, + -2.505229190184387472e-02, + -7.834051085801594424e-04, + 
-4.104669576212031240e-02, + 4.721690416727373704e-03, + 2.565744238275521286e-03, + 7.815135916805987862e-03, + -2.015888715255471572e-02, + -2.156226559634751916e-04, ] refv = [ - -8.51412082980419621103e-02, - 1.39169542815959588339e-02, - 3.06329019931955021105e-03, - 1.39169542815959536297e-02, - -4.91657098529515168561e-02, - -9.54629874035841556948e-04, - 3.06329019931954891001e-03, - -9.54629874035841773788e-04, - -2.00155580095981406675e-04, + -8.500718686149140446e-02, + 1.389198522732191729e-02, + 3.059204598073241802e-03, + 1.389198522732190168e-02, + -4.908897840490741155e-02, + -9.530658829897690944e-04, + 3.059204598073239634e-03, + -9.530658829897688776e-04, + -1.999114402095244765e-04, ] refe = np.reshape(refe, [-1]) diff --git a/source/tests/tf/test_descrpt_hybrid.py b/source/tests/tf/test_descrpt_hybrid.py index 7f5e064376..6aa04118da 100644 --- a/source/tests/tf/test_descrpt_hybrid.py +++ b/source/tests/tf/test_descrpt_hybrid.py @@ -185,66 +185,66 @@ def test_descriptor_hybrid(self): ] # below is copied from test_descript_se_atten.py ref_dout2 = [ - 1.35077997858830281628e-04, - -9.36317565146126714985e-05, - -9.36317565146126714985e-05, - 6.49457155161046269156e-05, - -3.44426119482271894060e-04, - 2.38892351975707574810e-04, - -2.16192628113445024177e-04, - 1.49838432021978586618e-04, - 5.19172506251499308108e-04, - -3.60044742999178198160e-04, - 1.00648981900694042455e-04, - -7.51687985725674168679e-05, - -7.51687985725674168679e-05, - 5.62621404496089786633e-05, - -2.78288905170686305408e-04, - 2.08248552733448707985e-04, - -1.57037506111419247626e-04, - 1.17240613774749092711e-04, - 4.09846227953978995209e-04, - -3.06582508385239355716e-04, - 7.56236313388503977959e-05, - -5.88249954799233110928e-05, - -5.88249954799233110928e-05, - 4.57767614608878164778e-05, - -2.17191782618980676941e-04, - 1.69041932410352632298e-04, - -1.21708419050609283887e-04, - 9.46734475047640323129e-05, - 3.22101565810662901230e-04, - -2.50667145896081176772e-04, - 1.12972766463605449241e-04, - -7.95331652304217509748e-05, - -7.95331652304217509748e-05, - 5.59918979793375151091e-05, - -2.90669309441163412500e-04, - 2.04626666596480422588e-04, - -1.87383581443938113499e-04, - 1.31917380775058677711e-04, - 4.44613289651917854839e-04, - -3.13002780120454830552e-04, - 1.30198051172878586420e-04, - -8.88399346622230731045e-05, - -8.88399346622230731045e-05, - 6.06275354032895547767e-05, - -3.23173886613725041324e-04, - 2.20522620462074609186e-04, - -2.17878181114203837987e-04, - 1.48663514408247710887e-04, - 4.99693951217273298233e-04, - -3.40973735611388808521e-04, - 1.01636483586918407768e-04, - -7.45585238544824841465e-05, - -7.45585238544824841465e-05, - 5.47161372646580776566e-05, - -2.74022957033491422908e-04, - 2.01084733576426032218e-04, - -1.66621218118959135701e-04, - 1.22224760787930633501e-04, - 4.13566215420014648540e-04, - -3.03467107774532218571e-04, + 1.3503570575883254e-04, + -9.3606804794552518e-05, + -9.3606804794552518e-05, + 6.4931435609575354e-05, + -3.4432462227712845e-04, + 2.3883309310633266e-04, + -2.1612770334269806e-04, + 1.4980041766865035e-04, + 5.1902342465554648e-04, + -3.5995814159000579e-04, + 1.0061650355705337e-04, + -7.5148260042556979e-05, + -7.5148260042556979e-05, + 5.6249549384058458e-05, + -2.7820514647114664e-04, + 2.0819618461713165e-04, + -1.5698895407951743e-04, + 1.1721016363267746e-04, + 4.0972585703616773e-04, + -3.0650763759131061e-04, + 7.5599650998659526e-05, + -5.8808888720672558e-05, + -5.8808888720672558e-05, + 4.5766209906762655e-05, + 
-2.1712714013251668e-04, + 1.6899894453623564e-04, + -1.2167120597162636e-04, + 9.4648599144861605e-05, + 3.2200758382615601e-04, + -2.5060486486718734e-04, + 1.1293831101452813e-04, + -7.9512063028041913e-05, + -7.9512063028041913e-05, + 5.5979262682797850e-05, + -2.9058515610909440e-04, + 2.0457554106366365e-04, + -1.8732839505532627e-04, + 1.3188376232775540e-04, + 4.4448730317793450e-04, + -3.1292650304617497e-04, + 1.3015885894252541e-04, + -8.8816609587789126e-05, + -8.8816609587789126e-05, + 6.0613949400496957e-05, + -3.2308121544925519e-04, + 2.2046786823295058e-04, + -2.1781481424814687e-04, + 1.4862599684199924e-04, + 4.9955378034266583e-04, + -3.4089120488765758e-04, + 1.0160496779809329e-04, + -7.4538471222199861e-05, + -7.4538471222199861e-05, + 5.4703671679263269e-05, + -2.7394267959121653e-04, + 2.0103409637607701e-04, + -1.6657135958432620e-04, + 1.2219321453198225e-04, + 4.1344754259964935e-04, + -3.0339251136512270e-04, ] places = 10 diff --git a/source/tests/tf/test_descrpt_se_atten.py b/source/tests/tf/test_descrpt_se_atten.py index f4ce374c42..7a1bfd18f6 100644 --- a/source/tests/tf/test_descrpt_se_atten.py +++ b/source/tests/tf/test_descrpt_se_atten.py @@ -150,66 +150,66 @@ def test_descriptor_two_sides(self): np.savetxt("two.out", model_dout.reshape([1, -1]), delimiter=",") ref_dout = [ - 1.35077997858830281628e-04, - -9.36317565146126714985e-05, - -9.36317565146126714985e-05, - 6.49457155161046269156e-05, - -3.44426119482271894060e-04, - 2.38892351975707574810e-04, - -2.16192628113445024177e-04, - 1.49838432021978586618e-04, - 5.19172506251499308108e-04, - -3.60044742999178198160e-04, - 1.00648981900694042455e-04, - -7.51687985725674168679e-05, - -7.51687985725674168679e-05, - 5.62621404496089786633e-05, - -2.78288905170686305408e-04, - 2.08248552733448707985e-04, - -1.57037506111419247626e-04, - 1.17240613774749092711e-04, - 4.09846227953978995209e-04, - -3.06582508385239355716e-04, - 7.56236313388503977959e-05, - -5.88249954799233110928e-05, - -5.88249954799233110928e-05, - 4.57767614608878164778e-05, - -2.17191782618980676941e-04, - 1.69041932410352632298e-04, - -1.21708419050609283887e-04, - 9.46734475047640323129e-05, - 3.22101565810662901230e-04, - -2.50667145896081176772e-04, - 1.12972766463605449241e-04, - -7.95331652304217509748e-05, - -7.95331652304217509748e-05, - 5.59918979793375151091e-05, - -2.90669309441163412500e-04, - 2.04626666596480422588e-04, - -1.87383581443938113499e-04, - 1.31917380775058677711e-04, - 4.44613289651917854839e-04, - -3.13002780120454830552e-04, - 1.30198051172878586420e-04, - -8.88399346622230731045e-05, - -8.88399346622230731045e-05, - 6.06275354032895547767e-05, - -3.23173886613725041324e-04, - 2.20522620462074609186e-04, - -2.17878181114203837987e-04, - 1.48663514408247710887e-04, - 4.99693951217273298233e-04, - -3.40973735611388808521e-04, - 1.01636483586918407768e-04, - -7.45585238544824841465e-05, - -7.45585238544824841465e-05, - 5.47161372646580776566e-05, - -2.74022957033491422908e-04, - 2.01084733576426032218e-04, - -1.66621218118959135701e-04, - 1.22224760787930633501e-04, - 4.13566215420014648540e-04, - -3.03467107774532218571e-04, + 1.3503570575883254e-04, + -9.3606804794552518e-05, + -9.3606804794552518e-05, + 6.4931435609575354e-05, + -3.4432462227712845e-04, + 2.3883309310633266e-04, + -2.1612770334269806e-04, + 1.4980041766865035e-04, + 5.1902342465554648e-04, + -3.5995814159000579e-04, + 1.0061650355705337e-04, + -7.5148260042556979e-05, + -7.5148260042556979e-05, + 5.6249549384058458e-05, + 
-2.7820514647114664e-04, + 2.0819618461713165e-04, + -1.5698895407951743e-04, + 1.1721016363267746e-04, + 4.0972585703616773e-04, + -3.0650763759131061e-04, + 7.5599650998659526e-05, + -5.8808888720672558e-05, + -5.8808888720672558e-05, + 4.5766209906762655e-05, + -2.1712714013251668e-04, + 1.6899894453623564e-04, + -1.2167120597162636e-04, + 9.4648599144861605e-05, + 3.2200758382615601e-04, + -2.5060486486718734e-04, + 1.1293831101452813e-04, + -7.9512063028041913e-05, + -7.9512063028041913e-05, + 5.5979262682797850e-05, + -2.9058515610909440e-04, + 2.0457554106366365e-04, + -1.8732839505532627e-04, + 1.3188376232775540e-04, + 4.4448730317793450e-04, + -3.1292650304617497e-04, + 1.3015885894252541e-04, + -8.8816609587789126e-05, + -8.8816609587789126e-05, + 6.0613949400496957e-05, + -3.2308121544925519e-04, + 2.2046786823295058e-04, + -2.1781481424814687e-04, + 1.4862599684199924e-04, + 4.9955378034266583e-04, + -3.4089120488765758e-04, + 1.0160496779809329e-04, + -7.4538471222199861e-05, + -7.4538471222199861e-05, + 5.4703671679263269e-05, + -2.7394267959121653e-04, + 2.0103409637607701e-04, + -1.6657135958432620e-04, + 1.2219321453198225e-04, + 4.1344754259964935e-04, + -3.0339251136512270e-04, ] places = 10 @@ -328,66 +328,66 @@ def test_descriptor_one_side(self): np.savetxt("one.out", model_dout.reshape([1, -1]), delimiter=",") ref_dout = [ - 8.93630739076099766573e-05, - -3.89301763666544977088e-05, - -3.89301763666544977088e-05, - 1.69776207161541659875e-05, - -2.91934413405367434308e-04, - 1.27275579758193970945e-04, - -1.80678576267614851526e-04, - 7.86981804444128273503e-05, - 4.22180092132026806885e-04, - -1.84021204552106459797e-04, - 6.50166826308631336630e-05, - -3.08191630112232239067e-05, - -3.08191630112232239067e-05, - 1.46662082284045218266e-05, - -2.32818649311590855893e-04, - 1.10619882905346373389e-04, - -1.30477133579203922803e-04, - 6.18026466291577325669e-05, - 3.29098263271154821506e-04, - -1.56269574751685376771e-04, - 5.07138199677916164739e-05, - -2.35171440781703185510e-05, - -2.35171440781703185510e-05, - 1.09213797907981395710e-05, - -1.86279366618262112341e-04, - 8.64577620996147407865e-05, - -1.03296053419269992513e-04, - 4.78913622480582772448e-05, - 2.62378744147910732392e-04, - -1.21753360060300813640e-04, - 7.82644227540903690814e-05, - -3.25084361414888650958e-05, - -3.25084361414888650958e-05, - 1.35041631983765535098e-05, - -2.53679234140297192677e-04, - 1.05375493947693795707e-04, - -1.60519879294703589519e-04, - 6.66744631236456129558e-05, - 3.68443126822399244329e-04, - -1.53045684128227086913e-04, - 9.11756668850765601567e-05, - -3.66229408732609030826e-05, - -3.66229408732609030826e-05, - 1.47120125015788778301e-05, - -2.83723246380394433092e-04, - 1.13968452838666050924e-04, - -1.87270570170312914944e-04, - 7.52199008968667767218e-05, - 4.16441090538891684186e-04, - -1.67277425363850822723e-04, - 6.95274814976590320665e-05, - -3.02348814013024743688e-05, - -3.02348814013024743688e-05, - 1.31585743503078956499e-05, - -2.37479534432029007343e-04, - 1.03311591705779548338e-04, - -1.42227987950226271961e-04, - 6.18410015482571886070e-05, - 3.40414922285898623351e-04, - -1.48076286203042110793e-04, + 8.9336098555659429e-05, + -3.8921422089719007e-05, + -3.8921422089719007e-05, + 1.6975109833017758e-05, + -2.9184951813034413e-04, + 1.2724836941382651e-04, + -1.8062533253590169e-04, + 7.8681048972093648e-05, + 4.2206017420030542e-04, + -1.8398310612921889e-04, + 6.4996467281506633e-05, + -3.0812041327073575e-05, + -3.0812041327073575e-05, + 
1.4663988013438402e-05, + -2.3274950984084172e-04, + 1.1059587214865573e-04, + -1.3043761448464089e-04, + 6.1788865409826698e-05, + 3.2900269837104958e-04, + -1.5623668424484728e-04, + 5.0697927477465942e-05, + -2.3511768544350768e-05, + -2.3511768544350768e-05, + 1.0919808814040025e-05, + -1.8622373494960208e-04, + 8.6439275444049409e-05, + -1.0326450661269683e-04, + 4.7880797898768150e-05, + 2.6230208262918372e-04, + -1.2172811361250681e-04, + 7.8240863239649707e-05, + -3.2501260967978116e-05, + -3.2501260967978116e-05, + 1.3502267073810926e-05, + -2.5360559687597850e-04, + 1.0535336854834091e-04, + -1.6047265448841568e-04, + 6.6660202062744658e-05, + 3.6833864909272261e-04, + -1.5301457671691837e-04, + 9.1148582997925288e-05, + -3.6614945467066073e-05, + -3.6614945467066073e-05, + 1.4709958908948206e-05, + -2.8364168092837332e-04, + 1.1394466218003484e-04, + -1.8721615730559043e-04, + 7.5203967811613109e-05, + 4.1632420070310456e-04, + -1.6724364343353009e-04, + 6.9506193268190631e-05, + -3.0228106532898472e-05, + -3.0228106532898472e-05, + 1.3156705594652870e-05, + -2.3740975974826574e-04, + 1.0328972070195332e-04, + -1.4218547815143072e-04, + 6.1827596642872941e-05, + 3.4031715116440432e-04, + -1.4804591640658066e-04, ] places = 10 @@ -499,66 +499,66 @@ def test_stripped_type_embedding_descriptor_two_sides(self): np.savetxt("two1.out", model_dout.reshape([1, -1]), delimiter=",") ref_dout = [ - 2.91097766899578214544e-06, - -3.29852641315371480153e-05, - -3.29852641315371480153e-05, - 3.79203396610324763253e-04, - -3.08296489918391639377e-05, - 3.54494448654088176176e-04, - -2.39859951795545287153e-05, - 2.74566675797922735754e-04, - 8.48899306339350606405e-05, - -9.75279256930798588154e-04, - 8.68233546069119197236e-07, - -1.59734540671145569350e-05, - -1.59734540671145569350e-05, - 3.25058299172223158675e-04, - -1.50870029997722798618e-05, - 3.07130006247707560297e-04, - -1.04749968193353404274e-05, - 2.08603290940140382453e-04, - 4.06672203401530534743e-05, - -8.24818142292956771496e-04, - 5.96048958156013435895e-07, - -1.26616643393676577874e-05, - -1.26616643393676577874e-05, - 2.71386217904519277955e-04, - -1.16335252819255226156e-05, - 2.49225002219057890918e-04, - -8.05872731607348350672e-06, - 1.72064906604221990903e-04, - 3.17578679792106490973e-05, - -6.80014462388431415590e-04, - 3.14589844246059013866e-06, - -3.24641804781093787271e-05, - -3.24641804781093787271e-05, - 3.35166446053445504782e-04, - -2.93700743352437964023e-05, - 3.03269488552582232397e-04, - -2.40918900326344598056e-05, - 2.48820558204534102165e-04, - 8.27802464035270346319e-05, - -8.54792312332452379302e-04, - 4.74647063755037437353e-06, - -4.15071266538516597008e-05, - -4.15071266538516597008e-05, - 3.63427481731051901445e-04, - -3.73557622901099313961e-05, - 3.27115874272415044135e-04, - -3.23616690622182231118e-05, - 2.83315238433851622219e-04, - 1.06478087368629440682e-04, - -9.32351467783467118162e-04, - 1.87979034371445873837e-06, - -2.47095892917853045061e-05, - -2.47095892917853045061e-05, - 3.27024569668371480752e-04, - -2.24898874228677589208e-05, - 2.97661928194053256209e-04, - -1.72172753256989610575e-05, - 2.27442187831376464941e-04, - 6.25369616966375661696e-05, - -8.27419096402015846574e-04, + 2.910296358673981606e-06, + -3.297689549631518680e-05, + -3.297689549631518680e-05, + 3.790996417030466402e-04, + -3.082208958603667925e-05, + 3.544004728264616810e-04, + -2.397997896082787038e-05, + 2.744923480535521121e-04, + 8.486866768450577558e-05, + -9.750155670867453753e-04, + 
8.680391572974659491e-07, + -1.596948473518331016e-05, + -1.596948473518331016e-05, + 3.249686279109944903e-04, + -1.508338456375446526e-05, + 3.070479490395221158e-04, + -1.047241469038003787e-05, + 2.085462014454144320e-04, + 4.065724483202033993e-05, + -8.245932936607477210e-04, + 5.959146184656097397e-07, + -1.265847984116858078e-05, + -1.265847984116858078e-05, + 2.713109337202710531e-04, + -1.163070862097512446e-05, + 2.491582022684395484e-04, + -8.056716526966370043e-06, + 1.720174894426871476e-04, + 3.174999037064446555e-05, + -6.798281455902291598e-04, + 3.145148216891492605e-06, + -3.245585831548520087e-05, + -3.245585831548520087e-05, + 3.350745140453206166e-04, + -2.936281422860278914e-05, + 3.031890775924862423e-04, + -2.408578375619038739e-05, + 2.487530226589902390e-04, + 8.275930808338685728e-05, + -8.545607559813118157e-04, + 4.745334138737575192e-06, + -4.149649152356857482e-05, + -4.149649152356857482e-05, + 3.633282453063247882e-04, + -3.734652895210441184e-05, + 3.270295126452897193e-04, + -3.235347865588130865e-05, + 2.832387658145111447e-04, + 1.064511649928167193e-04, + -9.321000322425568741e-04, + 1.879347284602219830e-06, + -2.470327295060103235e-05, + -2.470327295060103235e-05, + 3.269344178119031551e-04, + -2.248434624179290029e-05, + 2.975826199248595046e-04, + -1.721291645154368551e-05, + 2.273800448313684436e-04, + 6.252118835933537862e-05, + -8.271938096175299659e-04, ] places = 10 diff --git a/source/tests/tf/test_model_se_atten.py b/source/tests/tf/test_model_se_atten.py index 8d6c5afa4c..d75dc0cfff 100644 --- a/source/tests/tf/test_model_se_atten.py +++ b/source/tests/tf/test_model_se_atten.py @@ -155,37 +155,37 @@ def test_model(self): np.savetxt("f.out", f.reshape([1, -1]), delimiter=",") np.savetxt("v.out", v.reshape([1, -1]), delimiter=",") - refe = [6.12116933882038480874e01] + refe = [6.121172052273667e01] reff = [ - 1.15647422509625782266e-02, - 1.75814420518816301453e-02, - 7.13827966845788537686e-04, - 2.37054385620869625950e-02, - 1.68638656843611636771e-02, - -2.24281688803243482028e-03, - -7.95826529019691246425e-03, - 9.69601584607941019422e-03, - 1.91505445834121360688e-05, - 8.71431743822387999687e-03, - -2.71847766570148252629e-02, - -8.84489238783629392812e-04, - -4.38853499152154838403e-02, - 5.81882595276563344133e-03, - 2.62678184040523532775e-03, - 7.85911695413897895546e-03, - -2.27753728780730156644e-02, - -2.32454225018371036246e-04, + 1.1546857028815118e-02, + 1.7560407103242779e-02, + 7.1301778864729290e-04, + 2.3682630974376197e-02, + 1.6842732518204180e-02, + -2.2408109608703206e-03, + -7.9408568690697776e-03, + 9.6856119564082792e-03, + 1.9055514693144326e-05, + 8.7017502459205160e-03, + -2.7153030569749256e-02, + -8.8338555421916490e-04, + -4.3841165945453904e-02, + 5.8104108317526765e-03, + 2.6243178542006552e-03, + 7.8507845654118558e-03, + -2.2746131839858654e-02, + -2.3219464245160639e-04, ] refv = [ - -1.05000199239240685212e-01, - 1.67161895068729665637e-02, - 3.44771431604021759421e-03, - 1.67161895068729804414e-02, - -5.42193765251950954287e-02, - -1.08055824874348513701e-03, - 3.44771431604021802789e-03, - -1.08055824874348470332e-03, - -2.09534775642288966522e-04, + -0.10488160947198523, + 0.016694308932682225, + 0.003444164500535988, + 0.016694308932682235, + -0.05415326614376374, + -0.0010792017166882334, + 0.003444164500535988, + -0.001079201716688233, + -0.00020932681975049773, ] refe = np.reshape(refe, [-1]) @@ -618,37 +618,37 @@ def test_stripped_type_embedding_model(self): np.savetxt("f.out", f.reshape([1, 
-1]), delimiter=",") np.savetxt("v.out", v.reshape([1, -1]), delimiter=",") - refe = [6.12411774224343261608e01] + refe = [6.124119974943835132e01] reff = [ - 8.63770820796567855016e-03, - 1.62522026393666710331e-02, - 7.22919459978568399415e-04, - 2.46800946249053909654e-02, - 1.50982535714741239463e-02, - -2.27024703144847314271e-03, - -6.23780053459390554371e-03, - 9.20020798171328375858e-03, - -2.07267961671176406842e-05, - 6.19326848220238136006e-03, - -2.50892401262326376898e-02, - -7.84679030459893762407e-04, - -4.10982573216296109830e-02, - 4.73129070889465389027e-03, - 2.56865811814534347399e-03, - 7.82498654115005611021e-03, - -2.01927147752160932037e-02, - -2.15924720048428232235e-04, + 8.617444257623986525e-03, + 1.622774527785437321e-02, + 7.219537519817814273e-04, + 2.465257480331137924e-02, + 1.507377800325802181e-02, + -2.267846199393293988e-03, + -6.217685260668888089e-03, + 9.187965356558825195e-03, + -2.082402632037372596e-05, + 6.179226045047841662e-03, + -2.505229190184387472e-02, + -7.834051085801594424e-04, + -4.104669576212031240e-02, + 4.721690416727373704e-03, + 2.565744238275521286e-03, + 7.815135916805987862e-03, + -2.015888715255471572e-02, + -2.156226559634751916e-04, ] refv = [ - -8.51412082980419621103e-02, - 1.39169542815959605686e-02, - 3.06329019931955021105e-03, - 1.39169542815959553644e-02, - -4.91657098529515099172e-02, - -9.54629874035841340107e-04, - 3.06329019931954847633e-03, - -9.54629874035841340107e-04, - -2.00155580095981352464e-04, + -8.500718686149139058e-02, + 1.389198522732191729e-02, + 3.059204598073241802e-03, + 1.389198522732190515e-02, + -4.908897840490741848e-02, + -9.530658829897693113e-04, + 3.059204598073239634e-03, + -9.530658829897692029e-04, + -1.999114402095244223e-04, ] refe = np.reshape(refe, [-1]) From 8e17498f4136b9904ff87db02683b5d11daa50e3 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Wed, 24 Apr 2024 02:19:07 +0800 Subject: [PATCH 18/18] Make layernorm compat with old models --- deepmd/dpmodel/descriptor/dpa1.py | 18 ++++- deepmd/pt/model/descriptor/dpa1.py | 8 ++ deepmd/pt/model/descriptor/se_atten.py | 17 +++- deepmd/tf/descriptor/se_atten.py | 78 +++++++++++++------ deepmd/tf/env.py | 4 +- deepmd/utils/argcheck.py | 2 + .../tests/consistent/descriptor/test_dpa1.py | 9 +++ 7 files changed, 109 insertions(+), 27 deletions(-) diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index af58f8f2e2..a551a57628 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -144,6 +144,8 @@ class DescrptDPA1(NativeOP, BaseDescriptor): If the weights of this descriptors are trainable. trainable_ln: bool Whether to use trainable shift and scale weights in layer normalization. + ln_eps: float, Optional + The epsilon value for layer normalization. type_one_side: bool If 'False', type embeddings of both neighbor and central atoms are considered. If 'True', only type embeddings of neighbor atoms are considered. 
@@ -227,6 +229,7 @@ def __init__( normalize: bool = True, temperature: Optional[float] = None, trainable_ln: bool = True, + ln_eps: Optional[float] = 1e-5, smooth_type_embedding: bool = True, concat_output_tebd: bool = True, spin: Optional[Any] = None, @@ -243,6 +246,9 @@ def __init__( # TODO if tebd_input_mode != "concat": raise NotImplementedError("tebd_input_mode != 'concat' not implemented") + # to keep consistent with default value in this backends + if ln_eps is None: + ln_eps = 1e-5 self.rcut = rcut self.rcut_smth = rcut_smth @@ -259,6 +265,7 @@ def __init__( self.resnet_dt = resnet_dt self.trainable = trainable self.trainable_ln = trainable_ln + self.ln_eps = ln_eps self.type_one_side = type_one_side self.attn = attn self.attn_layer = attn_layer @@ -312,6 +319,7 @@ def __init__( normalize=self.normalize, temperature=self.temperature, trainable_ln=self.trainable_ln, + ln_eps=self.ln_eps, smooth=self.smooth, precision=self.precision, ) @@ -546,6 +554,7 @@ def serialize(self) -> dict: "normalize": self.normalize, "temperature": self.temperature, "trainable_ln": self.trainable_ln, + "ln_eps": self.ln_eps, "smooth_type_embedding": self.smooth, "type_one_side": self.type_one_side, "concat_output_tebd": self.concat_output_tebd, @@ -615,6 +624,7 @@ def __init__( normalize: bool = True, temperature: Optional[float] = None, trainable_ln: bool = True, + ln_eps: float = 1e-5, smooth: bool = True, precision: str = DEFAULT_PRECISION, ): @@ -630,6 +640,7 @@ def __init__( self.normalize = normalize self.temperature = temperature self.trainable_ln = trainable_ln + self.ln_eps = ln_eps self.smooth = smooth self.precision = precision self.network_type = NeighborGatedAttentionLayer @@ -645,6 +656,7 @@ def __init__( normalize=normalize, temperature=temperature, trainable_ln=trainable_ln, + ln_eps=ln_eps, smooth=smooth, precision=precision, ) @@ -701,6 +713,7 @@ def serialize(self): "normalize": self.normalize, "temperature": self.temperature, "trainable_ln": self.trainable_ln, + "ln_eps": self.ln_eps, "precision": self.precision, "attention_layers": [layer.serialize() for layer in self.attention_layers], } @@ -737,6 +750,7 @@ def __init__( normalize: bool = True, temperature: Optional[float] = None, trainable_ln: bool = True, + ln_eps: float = 1e-5, smooth: bool = True, precision: str = DEFAULT_PRECISION, ): @@ -751,6 +765,7 @@ def __init__( self.normalize = normalize self.temperature = temperature self.trainable_ln = trainable_ln + self.ln_eps = ln_eps self.precision = precision self.attention_layer = GatedAttentionLayer( nnei, @@ -765,7 +780,7 @@ def __init__( precision=precision, ) self.attn_layer_norm = LayerNorm( - self.embed_dim, trainable=self.trainable_ln, precision=precision + self.embed_dim, eps=ln_eps, trainable=self.trainable_ln, precision=precision ) def call( @@ -799,6 +814,7 @@ def serialize(self) -> dict: "normalize": self.normalize, "temperature": self.temperature, "trainable_ln": self.trainable_ln, + "ln_eps": self.ln_eps, "precision": self.precision, "attention_layer": self.attention_layer.serialize(), "attn_layer_norm": self.attn_layer_norm.serialize(), diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py index d255d13a3e..852e08403c 100644 --- a/deepmd/pt/model/descriptor/dpa1.py +++ b/deepmd/pt/model/descriptor/dpa1.py @@ -126,6 +126,8 @@ class DescrptDPA1(BaseDescriptor, torch.nn.Module): If the weights of this descriptors are trainable. trainable_ln: bool Whether to use trainable shift and scale weights in layer normalization. 
+ ln_eps: float, Optional + The epsilon value for layer normalization. type_one_side: bool If 'False', type embeddings of both neighbor and central atoms are considered. If 'True', only type embeddings of neighbor atoms are considered. @@ -210,6 +212,7 @@ def __init__( concat_output_tebd: bool = True, trainable: bool = True, trainable_ln: bool = True, + ln_eps: Optional[float] = 1e-5, smooth_type_embedding: bool = True, type_one_side: bool = False, # not implemented @@ -231,6 +234,9 @@ def __init__( # TODO if tebd_input_mode != "concat": raise NotImplementedError("tebd_input_mode != 'concat' not implemented") + # to keep consistent with default value in this backends + if ln_eps is None: + ln_eps = 1e-5 del type, spin, attn_mask self.se_atten = DescrptBlockSeAtten( @@ -258,6 +264,7 @@ def __init__( exclude_types=exclude_types, env_protection=env_protection, trainable_ln=trainable_ln, + ln_eps=ln_eps, old_impl=old_impl, ) self.type_embedding = TypeEmbedNet(ntypes, tebd_dim, precision=precision) @@ -392,6 +399,7 @@ def serialize(self) -> dict: "normalize": obj.normalize, "temperature": obj.temperature, "trainable_ln": obj.trainable_ln, + "ln_eps": obj.ln_eps, "smooth_type_embedding": obj.smooth, "type_one_side": obj.type_one_side, "concat_output_tebd": self.concat_output_tebd, diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index 73a974cc4e..66da86ce29 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -82,6 +82,7 @@ def __init__( exclude_types: List[Tuple[int, int]] = [], env_protection: float = 0.0, trainable_ln: bool = True, + ln_eps: Optional[float] = 1e-5, type: Optional[str] = None, old_impl: bool = False, ): @@ -112,6 +113,8 @@ def __init__( y = x + dt * \phi (Wx + b) trainable_ln : bool Whether to use trainable shift and scale weights in layer normalization. + ln_eps : float, Optional + The epsilon value for layer normalization. type_one_side : bool If 'False', type embeddings of both neighbor and central atoms are considered. If 'True', only type embeddings of neighbor atoms are considered. 
@@ -170,6 +173,10 @@ def __init__( self.type_one_side = type_one_side self.env_protection = env_protection self.trainable_ln = trainable_ln + # to keep consistent with default value in this backends + if ln_eps is None: + ln_eps = 1e-5 + self.ln_eps = ln_eps self.old_impl = old_impl if isinstance(sel, int): @@ -209,6 +216,7 @@ def __init__( normalize=self.normalize, temperature=self.temperature, trainable_ln=self.trainable_ln, + ln_eps=self.ln_eps, smooth=self.smooth, precision=self.precision, ) @@ -540,6 +548,7 @@ def __init__( normalize: bool = True, temperature: Optional[float] = None, trainable_ln: bool = True, + ln_eps: float = 1e-5, smooth: bool = True, precision: str = DEFAULT_PRECISION, ): @@ -555,6 +564,7 @@ def __init__( self.normalize = normalize self.temperature = temperature self.trainable_ln = trainable_ln + self.ln_eps = ln_eps self.smooth = smooth self.precision = precision self.network_type = NeighborGatedAttentionLayer @@ -571,6 +581,7 @@ def __init__( normalize=normalize, temperature=temperature, trainable_ln=trainable_ln, + ln_eps=ln_eps, smooth=smooth, precision=precision, ) @@ -641,6 +652,7 @@ def serialize(self) -> dict: "normalize": self.normalize, "temperature": self.temperature, "trainable_ln": self.trainable_ln, + "ln_eps": self.ln_eps, "precision": self.precision, "attention_layers": [layer.serialize() for layer in self.attention_layers], } @@ -677,6 +689,7 @@ def __init__( temperature: Optional[float] = None, smooth: bool = True, trainable_ln: bool = True, + ln_eps: float = 1e-5, precision: str = DEFAULT_PRECISION, ): """Construct a neighbor-wise attention layer.""" @@ -691,6 +704,7 @@ def __init__( self.temperature = temperature self.precision = precision self.trainable_ln = trainable_ln + self.ln_eps = ln_eps self.attention_layer = GatedAttentionLayer( nnei, embed_dim, @@ -704,7 +718,7 @@ def __init__( precision=precision, ) self.attn_layer_norm = LayerNorm( - self.embed_dim, trainable=trainable_ln, precision=precision + self.embed_dim, eps=ln_eps, trainable=trainable_ln, precision=precision ) def forward( @@ -738,6 +752,7 @@ def serialize(self) -> dict: "normalize": self.normalize, "temperature": self.temperature, "trainable_ln": self.trainable_ln, + "ln_eps": self.ln_eps, "precision": self.precision, "attention_layer": self.attention_layer.serialize(), "attn_layer_norm": self.attn_layer_norm.serialize(), diff --git a/deepmd/tf/descriptor/se_atten.py b/deepmd/tf/descriptor/se_atten.py index 34e99b045a..0ba426ee4b 100644 --- a/deepmd/tf/descriptor/se_atten.py +++ b/deepmd/tf/descriptor/se_atten.py @@ -97,51 +97,53 @@ class DescrptSeAtten(DescrptSeA): Parameters ---------- - rcut + rcut: float The cut-off radius :math:`r_c` - rcut_smth + rcut_smth: float From where the environment matrix should be smoothed :math:`r_s` - sel : list[int], int + sel: list[int], int list[int]: sel[i] specifies the maxmum number of type i atoms in the cut-off radius int: the total maxmum number of atoms in the cut-off radius - neuron : list[int] + neuron: list[int] Number of neurons in each hidden layers of the embedding net :math:`\mathcal{N}` - axis_neuron + axis_neuron: int Number of the axis neuron :math:`M_2` (number of columns of the sub-matrix of the embedding matrix) - resnet_dt + resnet_dt: bool Time-step `dt` in the resnet construction: y = x + dt * \phi (Wx + b) - trainable + trainable: bool If the weights of embedding net are trainable. - seed + seed: int, Optional Random seed for initializing the network parameters. 
- type_one_side + type_one_side: bool Try to build N_types embedding nets. Otherwise, building N_types^2 embedding nets exclude_types : List[List[int]] The excluded pairs of types which have no interaction with each other. For example, `[[0, 1]]` means no interaction between type 0 and type 1. - set_davg_zero + set_davg_zero: bool Set the shift of embedding net input to zero. - activation_function + activation_function: str The activation function in the embedding net. Supported options are |ACTIVATION_FN| - precision + precision: str The precision of the embedding net parameters. Supported options are |PRECISION| - uniform_seed + uniform_seed: bool Only for the purpose of backward compatibility, retrieves the old behavior of using the random seed - attn + attn: int The length of hidden vector during scale-dot attention computation. - attn_layer + attn_layer: int The number of layers in attention mechanism. - attn_dotr + attn_dotr: bool Whether to dot the relative coordinates on the attention weights as a gated scheme. - attn_mask + attn_mask: bool Whether to mask the diagonal in the attention weights. - multi_task + ln_eps: float, Optional + The epsilon value for layer normalization. + multi_task: bool If the model has multi fitting nets to train. - stripped_type_embedding + stripped_type_embedding: bool Whether to strip the type embedding into a separated embedding network. Default value will be True in `se_atten_v2` descriptor. - smooth_type_embedding + smooth_type_embedding: bool Whether to use smooth process in attention weights calculation. And when using stripped type embedding, whether to dot smooth factor on the network output of type embedding to keep the network smooth, instead of setting `set_davg_zero` to be True. @@ -182,6 +184,7 @@ def __init__( normalize=True, temperature=None, trainable_ln: bool = True, + ln_eps: Optional[float] = 1e-3, concat_output_tebd: bool = True, env_protection: float = 0.0, # not implement!! 
**kwargs, @@ -203,6 +206,9 @@ def __init__( raise NotImplementedError("concat_output_tebd is not supported.") if env_protection != 0.0: raise NotImplementedError("env_protection != 0.0 is not supported.") + # to keep consistent with default value in this backends + if ln_eps is None: + ln_eps = 1e-3 if isinstance(sel, list): sel = sum(sel) DescrptSeA.__init__( @@ -235,6 +241,7 @@ def __init__( self.stripped_type_embedding = stripped_type_embedding self.smooth = smooth_type_embedding self.trainable_ln = trainable_ln + self.ln_eps = ln_eps self.ntypes = ntypes self.att_n = attn self.attn_layer = attn_layer @@ -1037,6 +1044,7 @@ def _attention_layers( seed=self.seed, uniform_seed=self.uniform_seed, trainable=self.trainable_ln, + eps=self.ln_eps, initial_variables=self.attention_layer_variables, ) return input_xyz @@ -1347,6 +1355,17 @@ def init_variables( graph_def, suffix=suffix ) + def compat_ln_pattern(old_key): + pattern = r"attention_layer_(\d+)/(layer_normalization)_\d+" + replacement = r"attention_layer_\1/\2" + if bool(re.search(pattern, old_key)): + new_key = re.sub(pattern, replacement, old_key) + v = self.attention_layer_variables.pop(old_key) + self.attention_layer_variables[new_key] = v + + for item_key in list(self.attention_layer_variables.keys()): + compat_ln_pattern(item_key) + if self.stripped_type_embedding: self.two_side_embeeding_net_variables = ( get_extra_embedding_net_variables_from_graph_def( @@ -1469,6 +1488,7 @@ def serialize_attention_layers( dotr: bool, do_mask: bool, trainable_ln: bool, + ln_eps: float, variables: dict, bias: bool = True, suffix: str = "", @@ -1481,6 +1501,7 @@ def serialize_attention_layers( "dotr": dotr, "do_mask": do_mask, "trainable_ln": trainable_ln, + "ln_eps": ln_eps, "precision": self.precision.name, "attention_layers": [], } @@ -1536,6 +1557,7 @@ def serialize_attention_layers( layer_norm = LayerNorm( embed_dim, trainable=self.trainable_ln, + eps=self.ln_eps, precision=self.precision.name, ) layer_norm["matrix"] = attention_layer_params[layer_idx][ @@ -1554,6 +1576,7 @@ def serialize_attention_layers( }, "attn_layer_norm": layer_norm.serialize(), "trainable_ln": self.trainable_ln, + "ln_eps": self.ln_eps, } ) return data @@ -1724,6 +1747,7 @@ def serialize(self, suffix: str = "") -> dict: "resnet_dt": self.filter_resnet_dt, "smooth_type_embedding": self.smooth, "trainable_ln": self.trainable_ln, + "ln_eps": self.ln_eps, "precision": self.filter_precision.name, "embeddings": self.serialize_network( ntypes=self.ntypes, @@ -1746,6 +1770,7 @@ def serialize(self, suffix: str = "") -> dict: dotr=self.attn_dotr, do_mask=self.attn_mask, trainable_ln=self.trainable_ln, + ln_eps=self.ln_eps, variables=self.attention_layer_variables, suffix=suffix, ), @@ -1773,12 +1798,12 @@ class DescrptDPA1Compat(DescrptSeAtten): The cut-off radius :math:`r_c` rcut_smth: float From where the environment matrix should be smoothed :math:`r_s` - sel : list[int], int + sel: list[int], int list[int]: sel[i] specifies the maxmum number of type i atoms in the cut-off radius int: the total maxmum number of atoms in the cut-off radius - ntypes : int + ntypes: int Number of element types - neuron : list[int] + neuron: list[int] Number of neurons in each hidden layers of the embedding net :math:`\mathcal{N}` axis_neuron: int Number of the axis neuron :math:`M_2` (number of columns of the sub-matrix of the embedding matrix) @@ -1794,6 +1819,8 @@ class DescrptDPA1Compat(DescrptSeAtten): If the weights of this descriptors are trainable. 
trainable_ln: bool Whether to use trainable shift and scale weights in layer normalization. + ln_eps: float, Optional + The epsilon value for layer normalization. type_one_side: bool If 'False', type embeddings of both neighbor and central atoms are considered. If 'True', only type embeddings of neighbor atoms are considered. @@ -1867,6 +1894,7 @@ def __init__( normalize: bool = True, temperature: Optional[float] = None, trainable_ln: bool = True, + ln_eps: Optional[float] = 1e-3, smooth_type_embedding: bool = True, concat_output_tebd: bool = True, spin: Optional[Any] = None, @@ -1890,6 +1918,9 @@ def __init__( raise NotImplementedError( "old implementation of attn_mask is not supported." ) + # to keep consistent with default value in this backends + if ln_eps is None: + ln_eps = 1e-3 super().__init__( rcut, @@ -1914,6 +1945,7 @@ def __init__( multi_task=True, stripped_type_embedding=False, trainable_ln=trainable_ln, + ln_eps=ln_eps, smooth_type_embedding=smooth_type_embedding, env_protection=env_protection, ) diff --git a/deepmd/tf/env.py b/deepmd/tf/env.py index 0bd637dc02..3d4edadd8a 100644 --- a/deepmd/tf/env.py +++ b/deepmd/tf/env.py @@ -188,8 +188,8 @@ def dlopen_library(module: str, filename: str): r"attention_layer_(\d+)/(c_out)/(bias)|" r"attention_layer_(\d+)/(layer_normalization)/(beta)|" r"attention_layer_(\d+)/(layer_normalization)/(gamma)|" - # r"attention_layer_(\d+)/(layer_normalization)_\d+/(beta)|" - # r"attention_layer_(\d+)/(layer_normalization)_\d+/(gamma)|" + r"attention_layer_(\d+)/(layer_normalization)_\d+/(beta)|" + r"attention_layer_(\d+)/(layer_normalization)_\d+/(gamma)|" )[:-1] TRANSFER_PATTERN = ( diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 8dd2be2b6b..94c225010a 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -481,6 +481,7 @@ def descrpt_se_atten_args(): doc_trainable_ln = ( "Whether to use trainable shift and scale weights in layer normalization." ) + doc_ln_eps = "The epsilon value for layer normalization. The default value for TensorFlow is set to 1e-3 to keep consistent with keras while set to 1e-5 in PyTorch and DP implementation." doc_tebd_dim = "The dimension of atom type embedding." doc_temperature = "The scaling factor of normalization in calculations of attention weights, which is used to scale the matmul(Q, K)." 
doc_scaling_factor = ( @@ -519,6 +520,7 @@ def descrpt_se_atten_args(): Argument( "trainable_ln", bool, optional=True, default=True, doc=doc_trainable_ln ), + Argument("ln_eps", float, optional=True, default=None, doc=doc_ln_eps), # pt only Argument( "tebd_dim", diff --git a/source/tests/consistent/descriptor/test_dpa1.py b/source/tests/consistent/descriptor/test_dpa1.py index d6c54b22f5..c0ca46c91e 100644 --- a/source/tests/consistent/descriptor/test_dpa1.py +++ b/source/tests/consistent/descriptor/test_dpa1.py @@ -52,6 +52,7 @@ (1.0,), # scaling_factor (True, False), # normalize (None, 1.0), # temperature + (1e-5,), # ln_eps (True, False), # smooth_type_embedding (True, False), # concat_output_tebd ("float64",), # precision @@ -73,6 +74,7 @@ def data(self) -> dict: scaling_factor, normalize, temperature, + ln_eps, smooth_type_embedding, concat_output_tebd, precision, @@ -93,6 +95,7 @@ def data(self) -> dict: "scaling_factor": scaling_factor, "normalize": normalize, "temperature": temperature, + "ln_eps": ln_eps, "concat_output_tebd": concat_output_tebd, "resnet_dt": resnet_dt, "type_one_side": type_one_side, @@ -120,6 +123,7 @@ def skip_pt(self) -> bool: scaling_factor, normalize, temperature, + ln_eps, smooth_type_embedding, concat_output_tebd, precision, @@ -142,6 +146,7 @@ def skip_dp(self) -> bool: scaling_factor, normalize, temperature, + ln_eps, smooth_type_embedding, concat_output_tebd, precision, @@ -164,6 +169,7 @@ def skip_tf(self) -> bool: scaling_factor, normalize, temperature, + ln_eps, smooth_type_embedding, concat_output_tebd, precision, @@ -229,6 +235,7 @@ def setUp(self): scaling_factor, normalize, temperature, + ln_eps, smooth_type_embedding, concat_output_tebd, precision, @@ -284,6 +291,7 @@ def rtol(self) -> float: scaling_factor, normalize, temperature, + ln_eps, smooth_type_embedding, concat_output_tebd, precision, @@ -312,6 +320,7 @@ def atol(self) -> float: scaling_factor, normalize, temperature, + ln_eps, smooth_type_embedding, concat_output_tebd, precision,
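
Editor's note (not part of the patch): the final commit above threads an optional `ln_eps` argument through every backend, falling back to a backend-specific default when the caller passes `None` (1e-5 for the DP/PyTorch code paths, 1e-3 for TensorFlow to stay consistent with Keras). Below is a minimal, self-contained Python sketch of that convention; `TinyLayerNorm` and `make_attn_layer_norm` are illustrative names, not functions from deepmd-kit.

from typing import Optional

import numpy as np


class TinyLayerNorm:
    """Plain-NumPy layer normalization with a configurable epsilon."""

    def __init__(self, dim: int, eps: float = 1e-5, trainable: bool = True):
        self.eps = eps
        self.trainable = trainable
        # gamma/beta mirror the trainable scale and shift weights.
        self.gamma = np.ones(dim)
        self.beta = np.zeros(dim)

    def __call__(self, x: np.ndarray) -> np.ndarray:
        mean = x.mean(axis=-1, keepdims=True)
        var = x.var(axis=-1, keepdims=True)
        return self.gamma * (x - mean) / np.sqrt(var + self.eps) + self.beta


def make_attn_layer_norm(dim: int, ln_eps: Optional[float], backend: str = "dp"):
    # Keep consistent with the default value in each backend when the user
    # does not set ln_eps explicitly (the pattern added by the patch above).
    if ln_eps is None:
        ln_eps = 1e-3 if backend == "tf" else 1e-5
    return TinyLayerNorm(dim, eps=ln_eps)


if __name__ == "__main__":
    ln = make_attn_layer_norm(8, ln_eps=None, backend="dp")
    out = ln(np.random.rand(4, 8))
    print(out.shape, ln.eps)  # (4, 8) 1e-05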