Skip to content

Commit 3bd950f

Browse files
pd: update loc_mapping for dpa3 in paddle backend (#4797)
adapt preparing input spec in paddle backend to support `loc_mapping` for dpa3 <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Added support for dynamic neighbor and angle selection, enabling flexible graph-based computations. - Introduced options for local atom index mapping and alternative edge feature initialization. - Added exponential switch function for neighbor smoothing. - Introduced JIT-compiled SiLUT activation with higher-order gradient support, configurable via environment variable. - New utility functions for aggregation and graph index computation. - **Improvements** - Enhanced documentation for descriptor parameters. - Defaulted local mapping to enabled for relevant descriptors. - Improved device handling and tensor operations for better compatibility and clarity. - **Bug Fixes** - Simplified test skipping logic to rely on a unified flag. - **Chores** - Refactored and clarified internal logic for smoother weight and switch function computations. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent fdfd4c0 commit 3bd950f

File tree

11 files changed

+872
-136
lines changed

11 files changed

+872
-136
lines changed

deepmd/pd/model/descriptor/dpa3.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ class DescrptDPA3(BaseDescriptor, paddle.nn.Layer):
9191
Whether to use bias in the type embedding layer.
9292
use_loc_mapping : bool, Optional
9393
Whether to use local atom index mapping in training or non-parallel inference.
94-
Not supported yet in Paddle.
94+
When True, local indexing and mapping are applied to neighbor lists and embeddings during descriptor computation.
9595
type_map : list[str], Optional
9696
A list of strings. Give the name to each type of atoms.
9797
@@ -117,7 +117,7 @@ def __init__(
117117
seed: Optional[Union[int, list[int]]] = None,
118118
use_econf_tebd: bool = False,
119119
use_tebd_bias: bool = False,
120-
use_loc_mapping: bool = False,
120+
use_loc_mapping: bool = True,
121121
type_map: Optional[list[str]] = None,
122122
) -> None:
123123
super().__init__()
@@ -160,6 +160,8 @@ def init_subclass_params(sub_data, sub_class):
160160
fix_stat_std=self.repflow_args.fix_stat_std,
161161
optim_update=self.repflow_args.optim_update,
162162
smooth_edge_update=self.repflow_args.smooth_edge_update,
163+
edge_init_use_dist=self.repflow_args.edge_init_use_dist,
164+
use_exp_switch=self.repflow_args.use_exp_switch,
163165
use_dynamic_sel=self.repflow_args.use_dynamic_sel,
164166
sel_reduce_factor=self.repflow_args.sel_reduce_factor,
165167
use_loc_mapping=use_loc_mapping,
@@ -170,8 +172,8 @@ def init_subclass_params(sub_data, sub_class):
170172
)
171173

172174
self.use_econf_tebd = use_econf_tebd
173-
self.use_tebd_bias = use_tebd_bias
174175
self.use_loc_mapping = use_loc_mapping
176+
self.use_tebd_bias = use_tebd_bias
175177
self.type_map = type_map
176178
self.tebd_dim = self.repflow_args.n_dim
177179
self.type_embedding = TypeEmbedNet(
@@ -487,12 +489,16 @@ def forward(
487489
The smooth switch function. shape: nf x nloc x nnei
488490
489491
"""
492+
parallel_mode = comm_dict is not None
490493
# cast the input to internal precsion
491494
extended_coord = extended_coord.to(dtype=self.prec)
492495
nframes, nloc, nnei = nlist.shape
493496
nall = extended_coord.reshape([nframes, -1]).shape[1] // 3
494497

495-
node_ebd_ext = self.type_embedding(extended_atype)
498+
if not parallel_mode and self.use_loc_mapping:
499+
node_ebd_ext = self.type_embedding(extended_atype[:, :nloc])
500+
else:
501+
node_ebd_ext = self.type_embedding(extended_atype)
496502
node_ebd_inp = node_ebd_ext[:, :nloc, :]
497503
# repflows
498504
node_ebd, edge_ebd, h2, rot_mat, sw = self.repflows(

deepmd/pd/model/descriptor/env_mat.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import paddle
44

55
from deepmd.pd.utils.preprocess import (
6+
compute_exp_sw,
67
compute_smooth_weight,
78
)
89

@@ -14,6 +15,7 @@ def _make_env_mat(
1415
ruct_smth: float,
1516
radial_only: bool = False,
1617
protection: float = 0.0,
18+
use_exp_switch: bool = False,
1719
):
1820
"""Make smooth environment matrix."""
1921
bsz, natoms, nnei = nlist.shape
@@ -24,15 +26,20 @@ def _make_env_mat(
2426
nlist = paddle.where(mask, nlist, nall - 1)
2527
coord_l = coord[:, :natoms].reshape([bsz, -1, 1, 3])
2628
index = nlist.reshape([bsz, -1]).unsqueeze(-1).expand([-1, -1, 3])
27-
coord_r = paddle.take_along_axis(coord, axis=1, indices=index)
29+
coord_pad = paddle.concat([coord, coord[:, -1:, :] + rcut], axis=1)
30+
coord_r = paddle.take_along_axis(coord_pad, axis=1, indices=index)
2831
coord_r = coord_r.reshape([bsz, natoms, nnei, 3])
2932
diff = coord_r - coord_l
3033
length = paddle.linalg.norm(diff, axis=-1, keepdim=True)
3134
# for index 0 nloc atom
3235
length = length + (~mask.unsqueeze(-1)).astype(length.dtype)
3336
t0 = 1 / (length + protection)
3437
t1 = diff / (length + protection) ** 2
35-
weight = compute_smooth_weight(length, ruct_smth, rcut)
38+
weight = (
39+
compute_smooth_weight(length, ruct_smth, rcut)
40+
if not use_exp_switch
41+
else compute_exp_sw(length, ruct_smth, rcut)
42+
)
3643
weight = weight * mask.unsqueeze(-1).astype(weight.dtype)
3744
if radial_only:
3845
env_mat = t0 * weight
@@ -51,6 +58,7 @@ def prod_env_mat(
5158
rcut_smth: float,
5259
radial_only: bool = False,
5360
protection: float = 0.0,
61+
use_exp_switch: bool = False,
5462
):
5563
"""Generate smooth environment matrix from atom coordinates and other context.
5664
@@ -63,6 +71,7 @@ def prod_env_mat(
6371
- rcut_smth: Smooth hyper-parameter for pair force & energy.
6472
- radial_only: Whether to return a full description or a radial-only descriptor.
6573
- protection: Protection parameter to prevent division by zero errors during calculations.
74+
- use_exp_switch: Whether to use the exponential switch function.
6675
6776
Returns
6877
-------
@@ -75,6 +84,7 @@ def prod_env_mat(
7584
rcut_smth,
7685
radial_only,
7786
protection=protection,
87+
use_exp_switch=use_exp_switch,
7888
) # shape [n_atom, dim, 4 or 1]
7989
t_avg = mean[atype] # [n_atom, dim, 4 or 1]
8090
t_std = stddev[atype] # [n_atom, dim, 4 or 1]

0 commit comments

Comments
 (0)