Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 54d27cb

Browse files
reminisceeric-haibin-lin
authored andcommitted
[OP] Support range as advanced index for ndarrays (#16047)
1 parent 47f8ceb commit 54d27cb

File tree

3 files changed

+161
-158
lines changed

3 files changed

+161
-158
lines changed

python/mxnet/ndarray/ndarray.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1053,17 +1053,15 @@ def _advanced_index_to_array(idx, ax_len, ctx):
10531053
if idx.dtype != idx_dtype:
10541054
idx = idx.astype(idx_dtype)
10551055
return idx.as_in_context(ctx)
1056-
10571056
elif isinstance(idx, (np.ndarray, list, tuple)):
10581057
return array(idx, ctx, idx_dtype)
1059-
10601058
elif isinstance(idx, integer_types):
10611059
return array([idx], ctx, idx_dtype)
1062-
10631060
elif isinstance(idx, py_slice):
10641061
start, stop, step = idx.indices(ax_len)
10651062
return arange(start, stop, step, ctx=ctx, dtype=idx_dtype)
1066-
1063+
elif sys.version_info[0] > 2 and isinstance(idx, range):
1064+
return arange(idx.start, idx.stop, idx.step, ctx=ctx, dtype=idx_dtype)
10671065
else:
10681066
raise RuntimeError('illegal index type {}'.format(type(idx)))
10691067

@@ -2888,6 +2886,7 @@ def _scatter_set_nd(self, value_nd, indices):
28882886
lhs=self, rhs=value_nd, indices=indices, shape=self.shape, out=self
28892887
)
28902888

2889+
28912890
def indexing_key_expand_implicit_axes(key, shape):
28922891
"""Make implicit axes explicit by adding ``slice(None)``.
28932892
Examples
@@ -2984,6 +2983,8 @@ def _is_advanced_index(idx):
29842983
return True
29852984
elif isinstance(idx, py_slice) or idx is None:
29862985
return False
2986+
elif sys.version_info[0] > 2 and isinstance(idx, range):
2987+
return True
29872988
else:
29882989
raise RuntimeError('illegal index type {}'.format(type(idx)))
29892990

@@ -2995,7 +2996,8 @@ def get_indexing_dispatch_code(key):
29952996
for idx in key:
29962997
if isinstance(idx, (NDArray, np.ndarray, list, tuple)):
29972998
return _NDARRAY_ADVANCED_INDEXING
2998-
2999+
elif sys.version_info[0] > 2 and isinstance(idx, range):
3000+
return _NDARRAY_ADVANCED_INDEXING
29993001
elif not (isinstance(idx, (py_slice, integer_types)) or idx is None):
30003002
raise ValueError(
30013003
'NDArray does not support slicing with key {} of type {}.'

python/mxnet/numpy/multiarray.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -235,8 +235,8 @@ def __getitem__(self, key):
235235
key, shape[0]))
236236
return self._at(key)
237237
elif isinstance(key, py_slice):
238-
if (key.step is None or key.step == 1):
239-
if key.start is not None or key.stop is not None:
238+
if key.step is None or key.step == 1:
239+
if key.start is not None or key.stop is not None:
240240
return self._slice(key.start, key.stop)
241241
else:
242242
return self

tests/python/unittest/test_numpy_ndarray.py

Lines changed: 152 additions & 151 deletions
Original file line numberDiff line numberDiff line change
@@ -377,9 +377,6 @@ def test_np_ndarray_copy():
377377
@with_seed()
378378
@use_np
379379
def test_np_ndarray_indexing():
380-
"""
381-
Test all indexing.
382-
"""
383380
def np_int(index, int_type=np.int32):
384381
"""
385382
Helper function for testing indexing that converts slices to slices of ints or None, and tuples to
@@ -507,156 +504,160 @@ def test_setitem_autograd(np_array, index):
507504

508505
shape = (8, 16, 9, 9)
509506
np_array = _np.arange(_np.prod(_np.array(shape)), dtype='int32').reshape(shape) # native np array
510-
507+
511508
# Test sliced output being ndarray:
512509
index_list = [
513-
# Basic indexing
514-
# Single int as index
515-
0,
516-
np.int32(0),
517-
np.int64(0),
518-
5,
519-
np.int32(5),
520-
np.int64(5),
521-
-1,
522-
np.int32(-1),
523-
np.int64(-1),
524-
# Slicing as index
525-
slice(5),
526-
np_int(slice(5), np.int32),
527-
np_int(slice(5), np.int64),
528-
slice(1, 5),
529-
np_int(slice(1, 5), np.int32),
530-
np_int(slice(1, 5), np.int64),
531-
slice(1, 5, 2),
532-
np_int(slice(1, 5, 2), np.int32),
533-
np_int(slice(1, 5, 2), np.int64),
534-
slice(7, 0, -1),
535-
np_int(slice(7, 0, -1)),
536-
np_int(slice(7, 0, -1), np.int64),
537-
slice(None, 6),
538-
np_int(slice(None, 6)),
539-
np_int(slice(None, 6), np.int64),
540-
slice(None, 6, 3),
541-
np_int(slice(None, 6, 3)),
542-
np_int(slice(None, 6, 3), np.int64),
543-
slice(1, None),
544-
np_int(slice(1, None)),
545-
np_int(slice(1, None), np.int64),
546-
slice(1, None, 3),
547-
np_int(slice(1, None, 3)),
548-
np_int(slice(1, None, 3), np.int64),
549-
slice(None, None, 2),
550-
np_int(slice(None, None, 2)),
551-
np_int(slice(None, None, 2), np.int64),
552-
slice(None, None, -1),
553-
np_int(slice(None, None, -1)),
554-
np_int(slice(None, None, -1), np.int64),
555-
slice(None, None, -2),
556-
np_int(slice(None, None, -2), np.int32),
557-
np_int(slice(None, None, -2), np.int64),
558-
# Multiple ints as indices
559-
(1, 2, 3),
560-
np_int((1, 2, 3)),
561-
np_int((1, 2, 3), np.int64),
562-
(-1, -2, -3),
563-
np_int((-1, -2, -3)),
564-
np_int((-1, -2, -3), np.int64),
565-
(1, 2, 3, 4),
566-
np_int((1, 2, 3, 4)),
567-
np_int((1, 2, 3, 4), np.int64),
568-
(-4, -3, -2, -1),
569-
np_int((-4, -3, -2, -1)),
570-
np_int((-4, -3, -2, -1), np.int64),
571-
# slice(None) as indices
572-
(slice(None), slice(None), 1, 8),
573-
(slice(None), slice(None), -1, 8),
574-
(slice(None), slice(None), 1, -8),
575-
(slice(None), slice(None), -1, -8),
576-
np_int((slice(None), slice(None), 1, 8)),
577-
np_int((slice(None), slice(None), 1, 8), np.int64),
578-
(slice(None), slice(None), 1, 8),
579-
np_int((slice(None), slice(None), -1, -8)),
580-
np_int((slice(None), slice(None), -1, -8), np.int64),
581-
(slice(None), 2, slice(1, 5), 1),
582-
np_int((slice(None), 2, slice(1, 5), 1)),
583-
np_int((slice(None), 2, slice(1, 5), 1), np.int64),
584-
# Mixture of ints and slices as indices
585-
(slice(None, None, -1), 2, slice(1, 5), 1),
586-
np_int((slice(None, None, -1), 2, slice(1, 5), 1)),
587-
np_int((slice(None, None, -1), 2, slice(1, 5), 1), np.int64),
588-
(slice(None, None, -1), 2, slice(1, 7, 2), 1),
589-
np_int((slice(None, None, -1), 2, slice(1, 7, 2), 1)),
590-
np_int((slice(None, None, -1), 2, slice(1, 7, 2), 1), np.int64),
591-
(slice(1, 8, 2), slice(14, 2, -2), slice(3, 8), slice(0, 7, 3)),
592-
np_int((slice(1, 8, 2), slice(14, 2, -2), slice(3, 8), slice(0, 7, 3))),
593-
np_int((slice(1, 8, 2), slice(14, 2, -2), slice(3, 8), slice(0, 7, 3)), np.int64),
594-
(slice(1, 8, 2), 1, slice(3, 8), 2),
595-
np_int((slice(1, 8, 2), 1, slice(3, 8), 2)),
596-
np_int((slice(1, 8, 2), 1, slice(3, 8), 2), np.int64),
597-
# Test Ellipsis ('...')
598-
(1, Ellipsis, -1),
599-
(slice(2), Ellipsis, None, 0),
600-
# Test newaxis
601-
None,
602-
(1, None, -2, 3, -4),
603-
(1, slice(2, 5), None),
604-
(slice(None), slice(1, 4), None, slice(2, 3)),
605-
(slice(1, 3), slice(1, 3), slice(1, 3), slice(1, 3), None),
606-
(slice(1, 3), slice(1, 3), None, slice(1, 3), slice(1, 3)),
607-
(None, slice(1, 2), 3, None),
608-
(1, None, 2, 3, None, None, 4),
609-
# Advanced indexing
610-
([1, 2], slice(3, 5), None, None, [3, 4]),
611-
(slice(None), slice(3, 5), None, None, [2, 3], [3, 4]),
612-
(slice(None), slice(3, 5), None, [2, 3], None, [3, 4]),
613-
(None, slice(None), slice(3, 5), [2, 3], None, [3, 4]),
614-
[1],
615-
[1, 2],
616-
[2, 1, 3],
617-
[7, 5, 0, 3, 6, 2, 1],
618-
np.array([6, 3], dtype=np.int32),
619-
np.array([[3, 4], [0, 6]], dtype=np.int32),
620-
np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int32),
621-
np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int64),
622-
np.array([[2], [0], [1]], dtype=np.int32),
623-
np.array([[2], [0], [1]], dtype=np.int64),
624-
np.array([4, 7], dtype=np.int32),
625-
np.array([4, 7], dtype=np.int64),
626-
np.array([[3, 6], [2, 1]], dtype=np.int32),
627-
np.array([[3, 6], [2, 1]], dtype=np.int64),
628-
np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int32),
629-
np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int64),
630-
(1, [2, 3]),
631-
(1, [2, 3], np.array([[3], [0]], dtype=np.int32)),
632-
(1, [2, 3]),
633-
(1, [2, 3], np.array([[3], [0]], dtype=np.int64)),
634-
(1, [2], np.array([[5], [3]], dtype=np.int32), slice(None)),
635-
(1, [2], np.array([[5], [3]], dtype=np.int64), slice(None)),
636-
(1, [2, 3], np.array([[6], [0]], dtype=np.int32), slice(2, 5)),
637-
(1, [2, 3], np.array([[6], [0]], dtype=np.int64), slice(2, 5)),
638-
(1, [2, 3], np.array([[4], [7]], dtype=np.int32), slice(2, 5, 2)),
639-
(1, [2, 3], np.array([[4], [7]], dtype=np.int64), slice(2, 5, 2)),
640-
(1, [2], np.array([[3]], dtype=np.int32), slice(None, None, -1)),
641-
(1, [2], np.array([[3]], dtype=np.int64), slice(None, None, -1)),
642-
(1, [2], np.array([[3]], dtype=np.int32), np.array([[5, 7], [2, 4]], dtype=np.int64)),
643-
(1, [2], np.array([[4]], dtype=np.int32), np.array([[1, 3], [5, 7]], dtype='int64')),
644-
[0],
645-
[0, 1],
646-
[1, 2, 3],
647-
[2, 0, 5, 6],
648-
([1, 1], [2, 3]),
649-
([1], [4], [5]),
650-
([1], [4], [5], [6]),
651-
([[1]], [[2]]),
652-
([[1]], [[2]], [[3]], [[4]]),
653-
(slice(0, 2), [[1], [6]], slice(0, 2), slice(0, 5, 2)),
654-
([[[[1]]]], [[1]], slice(0, 3), [1, 5]),
655-
([[[[1]]]], 3, slice(0, 3), [1, 3]),
656-
([[[[1]]]], 3, slice(0, 3), 0),
657-
([[[[1]]]], [[2], [12]], slice(0, 3), slice(None)),
658-
([1, 2], slice(3, 5), [2, 3], [3, 4]),
659-
([1, 2], slice(3, 5), (2, 3), [3, 4]),
510+
(),
511+
# Basic indexing
512+
# Single int as index
513+
0,
514+
np.int32(0),
515+
np.int64(0),
516+
5,
517+
np.int32(5),
518+
np.int64(5),
519+
-1,
520+
np.int32(-1),
521+
np.int64(-1),
522+
# Slicing as index
523+
slice(5),
524+
np_int(slice(5), np.int32),
525+
np_int(slice(5), np.int64),
526+
slice(1, 5),
527+
np_int(slice(1, 5), np.int32),
528+
np_int(slice(1, 5), np.int64),
529+
slice(1, 5, 2),
530+
np_int(slice(1, 5, 2), np.int32),
531+
np_int(slice(1, 5, 2), np.int64),
532+
slice(7, 0, -1),
533+
np_int(slice(7, 0, -1)),
534+
np_int(slice(7, 0, -1), np.int64),
535+
slice(None, 6),
536+
np_int(slice(None, 6)),
537+
np_int(slice(None, 6), np.int64),
538+
slice(None, 6, 3),
539+
np_int(slice(None, 6, 3)),
540+
np_int(slice(None, 6, 3), np.int64),
541+
slice(1, None),
542+
np_int(slice(1, None)),
543+
np_int(slice(1, None), np.int64),
544+
slice(1, None, 3),
545+
np_int(slice(1, None, 3)),
546+
np_int(slice(1, None, 3), np.int64),
547+
slice(None, None, 2),
548+
np_int(slice(None, None, 2)),
549+
np_int(slice(None, None, 2), np.int64),
550+
slice(None, None, -1),
551+
np_int(slice(None, None, -1)),
552+
np_int(slice(None, None, -1), np.int64),
553+
slice(None, None, -2),
554+
np_int(slice(None, None, -2), np.int32),
555+
np_int(slice(None, None, -2), np.int64),
556+
# Multiple ints as indices
557+
(1, 2, 3),
558+
np_int((1, 2, 3)),
559+
np_int((1, 2, 3), np.int64),
560+
(-1, -2, -3),
561+
np_int((-1, -2, -3)),
562+
np_int((-1, -2, -3), np.int64),
563+
(1, 2, 3, 4),
564+
np_int((1, 2, 3, 4)),
565+
np_int((1, 2, 3, 4), np.int64),
566+
(-4, -3, -2, -1),
567+
np_int((-4, -3, -2, -1)),
568+
np_int((-4, -3, -2, -1), np.int64),
569+
# slice(None) as indices
570+
(slice(None), slice(None), 1, 8),
571+
(slice(None), slice(None), -1, 8),
572+
(slice(None), slice(None), 1, -8),
573+
(slice(None), slice(None), -1, -8),
574+
np_int((slice(None), slice(None), 1, 8)),
575+
np_int((slice(None), slice(None), 1, 8), np.int64),
576+
(slice(None), slice(None), 1, 8),
577+
np_int((slice(None), slice(None), -1, -8)),
578+
np_int((slice(None), slice(None), -1, -8), np.int64),
579+
(slice(None), 2, slice(1, 5), 1),
580+
np_int((slice(None), 2, slice(1, 5), 1)),
581+
np_int((slice(None), 2, slice(1, 5), 1), np.int64),
582+
# Mixture of ints and slices as indices
583+
(slice(None, None, -1), 2, slice(1, 5), 1),
584+
np_int((slice(None, None, -1), 2, slice(1, 5), 1)),
585+
np_int((slice(None, None, -1), 2, slice(1, 5), 1), np.int64),
586+
(slice(None, None, -1), 2, slice(1, 7, 2), 1),
587+
np_int((slice(None, None, -1), 2, slice(1, 7, 2), 1)),
588+
np_int((slice(None, None, -1), 2, slice(1, 7, 2), 1), np.int64),
589+
(slice(1, 8, 2), slice(14, 2, -2), slice(3, 8), slice(0, 7, 3)),
590+
np_int((slice(1, 8, 2), slice(14, 2, -2), slice(3, 8), slice(0, 7, 3))),
591+
np_int((slice(1, 8, 2), slice(14, 2, -2), slice(3, 8), slice(0, 7, 3)), np.int64),
592+
(slice(1, 8, 2), 1, slice(3, 8), 2),
593+
np_int((slice(1, 8, 2), 1, slice(3, 8), 2)),
594+
np_int((slice(1, 8, 2), 1, slice(3, 8), 2), np.int64),
595+
# Test Ellipsis ('...')
596+
(1, Ellipsis, -1),
597+
(slice(2), Ellipsis, None, 0),
598+
# Test newaxis
599+
None,
600+
(1, None, -2, 3, -4),
601+
(1, slice(2, 5), None),
602+
(slice(None), slice(1, 4), None, slice(2, 3)),
603+
(slice(1, 3), slice(1, 3), slice(1, 3), slice(1, 3), None),
604+
(slice(1, 3), slice(1, 3), None, slice(1, 3), slice(1, 3)),
605+
(None, slice(1, 2), 3, None),
606+
(1, None, 2, 3, None, None, 4),
607+
# Advanced indexing
608+
([1, 2], slice(3, 5), None, None, [3, 4]),
609+
(slice(None), slice(3, 5), None, None, [2, 3], [3, 4]),
610+
(slice(None), slice(3, 5), None, [2, 3], None, [3, 4]),
611+
(None, slice(None), slice(3, 5), [2, 3], None, [3, 4]),
612+
[1],
613+
[1, 2],
614+
[2, 1, 3],
615+
[7, 5, 0, 3, 6, 2, 1],
616+
np.array([6, 3], dtype=np.int32),
617+
np.array([[3, 4], [0, 6]], dtype=np.int32),
618+
np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int32),
619+
np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int64),
620+
np.array([[2], [0], [1]], dtype=np.int32),
621+
np.array([[2], [0], [1]], dtype=np.int64),
622+
np.array([4, 7], dtype=np.int32),
623+
np.array([4, 7], dtype=np.int64),
624+
np.array([[3, 6], [2, 1]], dtype=np.int32),
625+
np.array([[3, 6], [2, 1]], dtype=np.int64),
626+
np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int32),
627+
np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int64),
628+
(1, [2, 3]),
629+
(1, [2, 3], np.array([[3], [0]], dtype=np.int32)),
630+
(1, [2, 3]),
631+
(1, [2, 3], np.array([[3], [0]], dtype=np.int64)),
632+
(1, [2], np.array([[5], [3]], dtype=np.int32), slice(None)),
633+
(1, [2], np.array([[5], [3]], dtype=np.int64), slice(None)),
634+
(1, [2, 3], np.array([[6], [0]], dtype=np.int32), slice(2, 5)),
635+
(1, [2, 3], np.array([[6], [0]], dtype=np.int64), slice(2, 5)),
636+
(1, [2, 3], np.array([[4], [7]], dtype=np.int32), slice(2, 5, 2)),
637+
(1, [2, 3], np.array([[4], [7]], dtype=np.int64), slice(2, 5, 2)),
638+
(1, [2], np.array([[3]], dtype=np.int32), slice(None, None, -1)),
639+
(1, [2], np.array([[3]], dtype=np.int64), slice(None, None, -1)),
640+
(1, [2], np.array([[3]], dtype=np.int32), np.array([[5, 7], [2, 4]], dtype=np.int64)),
641+
(1, [2], np.array([[4]], dtype=np.int32), np.array([[1, 3], [5, 7]], dtype='int64')),
642+
[0],
643+
[0, 1],
644+
[1, 2, 3],
645+
[2, 0, 5, 6],
646+
([1, 1], [2, 3]),
647+
([1], [4], [5]),
648+
([1], [4], [5], [6]),
649+
([[1]], [[2]]),
650+
([[1]], [[2]], [[3]], [[4]]),
651+
(slice(0, 2), [[1], [6]], slice(0, 2), slice(0, 5, 2)),
652+
([[[[1]]]], [[1]], slice(0, 3), [1, 5]),
653+
([[[[1]]]], 3, slice(0, 3), [1, 3]),
654+
([[[[1]]]], 3, slice(0, 3), 0),
655+
([[[[1]]]], [[2], [12]], slice(0, 3), slice(None)),
656+
([1, 2], slice(3, 5), [2, 3], [3, 4]),
657+
([1, 2], slice(3, 5), (2, 3), [3, 4]),
658+
range(4),
659+
range(3, 0, -1),
660+
(range(4,), [1]),
660661
]
661662
for index in index_list:
662663
test_getitem(np_array, index)

0 commit comments

Comments
 (0)