Skip to content

Commit 12b5c54

Browse files
committed
fix(jax): use more safe_for_vector_norm
Signed-off-by: Jinzhe Zeng <[email protected]> (cherry picked from commit 2433566)
1 parent 1d95c18 commit 12b5c54

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

deepmd/dpmodel/descriptor/repflows.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -490,7 +490,7 @@ def call(
490490
sw = xp.where(nlist_mask, sw, xp.zeros_like(sw))
491491

492492
# get angle nlist (maybe smaller)
493-
a_dist_mask = (xp.linalg.vector_norm(diff, axis=-1) < self.a_rcut)[
493+
a_dist_mask = (safe_for_vector_norm(diff, axis=-1) < self.a_rcut)[
494494
:, :, : self.a_sel
495495
]
496496
a_nlist = nlist[:, :, : self.a_sel]

0 commit comments

Comments
 (0)