Skip to content

Commit 993b41d

Browse files
authored
set dtype for xp_take_along_axis
Signed-off-by: Jinzhe Zeng <[email protected]>
1 parent 6074d33 commit 993b41d

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

deepmd/dpmodel/array_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def xp_take_along_axis(arr, indices, axis):
6161
else:
6262
indices = xp.reshape(indices, (0, 0))
6363

64-
offset = (xp.arange(indices.shape[0]) * m)[:, xp.newaxis]
64+
offset = (xp.arange(indices.shape[0], dtype=indices.type) * m)[:, xp.newaxis]
6565
indices = xp.reshape(offset + indices, (-1,))
6666

6767
out = xp.take(arr, indices)

0 commit comments

Comments
 (0)