From 0382d2b3b225c7be70ccc03d6479f91f816c8d13 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sat, 2 Mar 2024 20:31:19 -0500 Subject: [PATCH] do not return g2, h2, sw in hybrid descriptors g2, h2, and sw are heavily dependent on the neighbor list. We cannot ensure the sub descriptors require the same neighbor list as the parent descriptor. Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/descriptor/hybrid.py | 8 +------- deepmd/pt/model/descriptor/hybrid.py | 8 +------- 2 files changed, 2 insertions(+), 14 deletions(-) diff --git a/deepmd/dpmodel/descriptor/hybrid.py b/deepmd/dpmodel/descriptor/hybrid.py index 46f2616b84..96640d75c8 100644 --- a/deepmd/dpmodel/descriptor/hybrid.py +++ b/deepmd/dpmodel/descriptor/hybrid.py @@ -176,7 +176,7 @@ def call( """ out_descriptor = [] out_gr = [] - out_g2 = [] + out_g2 = None out_h2 = None out_sw = None if self.sel_no_mixed_types is not None: @@ -199,15 +199,9 @@ def call( out_descriptor.append(odescriptor) if gr is not None: out_gr.append(gr) - if g2 is not None: - out_g2.append(g2) - if self.get_rcut() == descrpt.get_rcut(): - out_h2 = h2 - out_sw = sw out_descriptor = np.concatenate(out_descriptor, axis=-1) out_gr = np.concatenate(out_gr, axis=-2) if out_gr else None - out_g2 = np.concatenate(out_g2, axis=-1) if out_g2 else None return out_descriptor, out_gr, out_g2, out_h2, out_sw @classmethod diff --git a/deepmd/pt/model/descriptor/hybrid.py b/deepmd/pt/model/descriptor/hybrid.py index b53adca462..204ca7589d 100644 --- a/deepmd/pt/model/descriptor/hybrid.py +++ b/deepmd/pt/model/descriptor/hybrid.py @@ -200,7 +200,7 @@ def forward( """ out_descriptor = [] out_gr = [] - out_g2 = [] + out_g2: Optional[torch.Tensor] = None out_h2: Optional[torch.Tensor] = None out_sw: Optional[torch.Tensor] = None if self.sel_no_mixed_types is not None: @@ -225,14 +225,8 @@ def forward( out_descriptor.append(odescriptor) if gr is not None: out_gr.append(gr) - if g2 is not None: - out_g2.append(g2) - if self.get_rcut() == descrpt.get_rcut(): - out_h2 = h2 - out_sw = sw out_descriptor = torch.cat(out_descriptor, dim=-1) out_gr = torch.cat(out_gr, dim=-2) if out_gr else None - out_g2 = torch.cat(out_g2, dim=-1) if out_g2 else None return out_descriptor, out_gr, out_g2, out_h2, out_sw @classmethod