Skip to content

Commit 0f817e1

Browse files
authored
style: extend no-explicit-dtype check to xp and jnp (#4247)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Expanded the `DPChecker` to recognize additional libraries ("xp" and "jnp") for enhanced validation of function calls. - **Bug Fixes** - Improved compatibility of the `offset` calculation in the `xp_take_along_axis` function to ensure it matches the data type of the `indices` array. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jinzhe Zeng <[email protected]>
1 parent c2d0560 commit 0f817e1

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

deepmd/dpmodel/array_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def xp_take_along_axis(arr, indices, axis):
6161
else:
6262
indices = xp.reshape(indices, (0, 0))
6363

64-
offset = (xp.arange(indices.shape[0]) * m)[:, xp.newaxis]
64+
offset = (xp.arange(indices.shape[0], dtype=indices.dtype) * m)[:, xp.newaxis]
6565
indices = xp.reshape(offset + indices, (-1,))
6666

6767
out = xp.take(arr, indices)

source/checker/deepmd_checker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def visit_call(self, node):
3737
if (
3838
isinstance(node.func, Attribute)
3939
and isinstance(node.func.expr, Name)
40-
and node.func.expr.name in {"np", "tf", "torch"}
40+
and node.func.expr.name in {"np", "tf", "torch", "xp", "jnp"}
4141
and node.func.attrname
4242
in {
4343
# https://pytorch.org/docs/stable/torch.html#creation-ops

0 commit comments

Comments
 (0)