@@ -377,9 +377,6 @@ def test_np_ndarray_copy():
377
377
@with_seed ()
378
378
@use_np
379
379
def test_np_ndarray_indexing ():
380
- """
381
- Test all indexing.
382
- """
383
380
def np_int (index , int_type = np .int32 ):
384
381
"""
385
382
Helper function for testing indexing that converts slices to slices of ints or None, and tuples to
@@ -507,156 +504,160 @@ def test_setitem_autograd(np_array, index):
507
504
508
505
shape = (8 , 16 , 9 , 9 )
509
506
np_array = _np .arange (_np .prod (_np .array (shape )), dtype = 'int32' ).reshape (shape ) # native np array
510
-
507
+
511
508
# Test sliced output being ndarray:
512
509
index_list = [
513
- # Basic indexing
514
- # Single int as index
515
- 0 ,
516
- np .int32 (0 ),
517
- np .int64 (0 ),
518
- 5 ,
519
- np .int32 (5 ),
520
- np .int64 (5 ),
521
- - 1 ,
522
- np .int32 (- 1 ),
523
- np .int64 (- 1 ),
524
- # Slicing as index
525
- slice (5 ),
526
- np_int (slice (5 ), np .int32 ),
527
- np_int (slice (5 ), np .int64 ),
528
- slice (1 , 5 ),
529
- np_int (slice (1 , 5 ), np .int32 ),
530
- np_int (slice (1 , 5 ), np .int64 ),
531
- slice (1 , 5 , 2 ),
532
- np_int (slice (1 , 5 , 2 ), np .int32 ),
533
- np_int (slice (1 , 5 , 2 ), np .int64 ),
534
- slice (7 , 0 , - 1 ),
535
- np_int (slice (7 , 0 , - 1 )),
536
- np_int (slice (7 , 0 , - 1 ), np .int64 ),
537
- slice (None , 6 ),
538
- np_int (slice (None , 6 )),
539
- np_int (slice (None , 6 ), np .int64 ),
540
- slice (None , 6 , 3 ),
541
- np_int (slice (None , 6 , 3 )),
542
- np_int (slice (None , 6 , 3 ), np .int64 ),
543
- slice (1 , None ),
544
- np_int (slice (1 , None )),
545
- np_int (slice (1 , None ), np .int64 ),
546
- slice (1 , None , 3 ),
547
- np_int (slice (1 , None , 3 )),
548
- np_int (slice (1 , None , 3 ), np .int64 ),
549
- slice (None , None , 2 ),
550
- np_int (slice (None , None , 2 )),
551
- np_int (slice (None , None , 2 ), np .int64 ),
552
- slice (None , None , - 1 ),
553
- np_int (slice (None , None , - 1 )),
554
- np_int (slice (None , None , - 1 ), np .int64 ),
555
- slice (None , None , - 2 ),
556
- np_int (slice (None , None , - 2 ), np .int32 ),
557
- np_int (slice (None , None , - 2 ), np .int64 ),
558
- # Multiple ints as indices
559
- (1 , 2 , 3 ),
560
- np_int ((1 , 2 , 3 )),
561
- np_int ((1 , 2 , 3 ), np .int64 ),
562
- (- 1 , - 2 , - 3 ),
563
- np_int ((- 1 , - 2 , - 3 )),
564
- np_int ((- 1 , - 2 , - 3 ), np .int64 ),
565
- (1 , 2 , 3 , 4 ),
566
- np_int ((1 , 2 , 3 , 4 )),
567
- np_int ((1 , 2 , 3 , 4 ), np .int64 ),
568
- (- 4 , - 3 , - 2 , - 1 ),
569
- np_int ((- 4 , - 3 , - 2 , - 1 )),
570
- np_int ((- 4 , - 3 , - 2 , - 1 ), np .int64 ),
571
- # slice(None) as indices
572
- (slice (None ), slice (None ), 1 , 8 ),
573
- (slice (None ), slice (None ), - 1 , 8 ),
574
- (slice (None ), slice (None ), 1 , - 8 ),
575
- (slice (None ), slice (None ), - 1 , - 8 ),
576
- np_int ((slice (None ), slice (None ), 1 , 8 )),
577
- np_int ((slice (None ), slice (None ), 1 , 8 ), np .int64 ),
578
- (slice (None ), slice (None ), 1 , 8 ),
579
- np_int ((slice (None ), slice (None ), - 1 , - 8 )),
580
- np_int ((slice (None ), slice (None ), - 1 , - 8 ), np .int64 ),
581
- (slice (None ), 2 , slice (1 , 5 ), 1 ),
582
- np_int ((slice (None ), 2 , slice (1 , 5 ), 1 )),
583
- np_int ((slice (None ), 2 , slice (1 , 5 ), 1 ), np .int64 ),
584
- # Mixture of ints and slices as indices
585
- (slice (None , None , - 1 ), 2 , slice (1 , 5 ), 1 ),
586
- np_int ((slice (None , None , - 1 ), 2 , slice (1 , 5 ), 1 )),
587
- np_int ((slice (None , None , - 1 ), 2 , slice (1 , 5 ), 1 ), np .int64 ),
588
- (slice (None , None , - 1 ), 2 , slice (1 , 7 , 2 ), 1 ),
589
- np_int ((slice (None , None , - 1 ), 2 , slice (1 , 7 , 2 ), 1 )),
590
- np_int ((slice (None , None , - 1 ), 2 , slice (1 , 7 , 2 ), 1 ), np .int64 ),
591
- (slice (1 , 8 , 2 ), slice (14 , 2 , - 2 ), slice (3 , 8 ), slice (0 , 7 , 3 )),
592
- np_int ((slice (1 , 8 , 2 ), slice (14 , 2 , - 2 ), slice (3 , 8 ), slice (0 , 7 , 3 ))),
593
- np_int ((slice (1 , 8 , 2 ), slice (14 , 2 , - 2 ), slice (3 , 8 ), slice (0 , 7 , 3 )), np .int64 ),
594
- (slice (1 , 8 , 2 ), 1 , slice (3 , 8 ), 2 ),
595
- np_int ((slice (1 , 8 , 2 ), 1 , slice (3 , 8 ), 2 )),
596
- np_int ((slice (1 , 8 , 2 ), 1 , slice (3 , 8 ), 2 ), np .int64 ),
597
- # Test Ellipsis ('...')
598
- (1 , Ellipsis , - 1 ),
599
- (slice (2 ), Ellipsis , None , 0 ),
600
- # Test newaxis
601
- None ,
602
- (1 , None , - 2 , 3 , - 4 ),
603
- (1 , slice (2 , 5 ), None ),
604
- (slice (None ), slice (1 , 4 ), None , slice (2 , 3 )),
605
- (slice (1 , 3 ), slice (1 , 3 ), slice (1 , 3 ), slice (1 , 3 ), None ),
606
- (slice (1 , 3 ), slice (1 , 3 ), None , slice (1 , 3 ), slice (1 , 3 )),
607
- (None , slice (1 , 2 ), 3 , None ),
608
- (1 , None , 2 , 3 , None , None , 4 ),
609
- # Advanced indexing
610
- ([1 , 2 ], slice (3 , 5 ), None , None , [3 , 4 ]),
611
- (slice (None ), slice (3 , 5 ), None , None , [2 , 3 ], [3 , 4 ]),
612
- (slice (None ), slice (3 , 5 ), None , [2 , 3 ], None , [3 , 4 ]),
613
- (None , slice (None ), slice (3 , 5 ), [2 , 3 ], None , [3 , 4 ]),
614
- [1 ],
615
- [1 , 2 ],
616
- [2 , 1 , 3 ],
617
- [7 , 5 , 0 , 3 , 6 , 2 , 1 ],
618
- np .array ([6 , 3 ], dtype = np .int32 ),
619
- np .array ([[3 , 4 ], [0 , 6 ]], dtype = np .int32 ),
620
- np .array ([[7 , 3 ], [2 , 6 ], [0 , 5 ], [4 , 1 ]], dtype = np .int32 ),
621
- np .array ([[7 , 3 ], [2 , 6 ], [0 , 5 ], [4 , 1 ]], dtype = np .int64 ),
622
- np .array ([[2 ], [0 ], [1 ]], dtype = np .int32 ),
623
- np .array ([[2 ], [0 ], [1 ]], dtype = np .int64 ),
624
- np .array ([4 , 7 ], dtype = np .int32 ),
625
- np .array ([4 , 7 ], dtype = np .int64 ),
626
- np .array ([[3 , 6 ], [2 , 1 ]], dtype = np .int32 ),
627
- np .array ([[3 , 6 ], [2 , 1 ]], dtype = np .int64 ),
628
- np .array ([[7 , 3 ], [2 , 6 ], [0 , 5 ], [4 , 1 ]], dtype = np .int32 ),
629
- np .array ([[7 , 3 ], [2 , 6 ], [0 , 5 ], [4 , 1 ]], dtype = np .int64 ),
630
- (1 , [2 , 3 ]),
631
- (1 , [2 , 3 ], np .array ([[3 ], [0 ]], dtype = np .int32 )),
632
- (1 , [2 , 3 ]),
633
- (1 , [2 , 3 ], np .array ([[3 ], [0 ]], dtype = np .int64 )),
634
- (1 , [2 ], np .array ([[5 ], [3 ]], dtype = np .int32 ), slice (None )),
635
- (1 , [2 ], np .array ([[5 ], [3 ]], dtype = np .int64 ), slice (None )),
636
- (1 , [2 , 3 ], np .array ([[6 ], [0 ]], dtype = np .int32 ), slice (2 , 5 )),
637
- (1 , [2 , 3 ], np .array ([[6 ], [0 ]], dtype = np .int64 ), slice (2 , 5 )),
638
- (1 , [2 , 3 ], np .array ([[4 ], [7 ]], dtype = np .int32 ), slice (2 , 5 , 2 )),
639
- (1 , [2 , 3 ], np .array ([[4 ], [7 ]], dtype = np .int64 ), slice (2 , 5 , 2 )),
640
- (1 , [2 ], np .array ([[3 ]], dtype = np .int32 ), slice (None , None , - 1 )),
641
- (1 , [2 ], np .array ([[3 ]], dtype = np .int64 ), slice (None , None , - 1 )),
642
- (1 , [2 ], np .array ([[3 ]], dtype = np .int32 ), np .array ([[5 , 7 ], [2 , 4 ]], dtype = np .int64 )),
643
- (1 , [2 ], np .array ([[4 ]], dtype = np .int32 ), np .array ([[1 , 3 ], [5 , 7 ]], dtype = 'int64' )),
644
- [0 ],
645
- [0 , 1 ],
646
- [1 , 2 , 3 ],
647
- [2 , 0 , 5 , 6 ],
648
- ([1 , 1 ], [2 , 3 ]),
649
- ([1 ], [4 ], [5 ]),
650
- ([1 ], [4 ], [5 ], [6 ]),
651
- ([[1 ]], [[2 ]]),
652
- ([[1 ]], [[2 ]], [[3 ]], [[4 ]]),
653
- (slice (0 , 2 ), [[1 ], [6 ]], slice (0 , 2 ), slice (0 , 5 , 2 )),
654
- ([[[[1 ]]]], [[1 ]], slice (0 , 3 ), [1 , 5 ]),
655
- ([[[[1 ]]]], 3 , slice (0 , 3 ), [1 , 3 ]),
656
- ([[[[1 ]]]], 3 , slice (0 , 3 ), 0 ),
657
- ([[[[1 ]]]], [[2 ], [12 ]], slice (0 , 3 ), slice (None )),
658
- ([1 , 2 ], slice (3 , 5 ), [2 , 3 ], [3 , 4 ]),
659
- ([1 , 2 ], slice (3 , 5 ), (2 , 3 ), [3 , 4 ]),
510
+ (),
511
+ # Basic indexing
512
+ # Single int as index
513
+ 0 ,
514
+ np .int32 (0 ),
515
+ np .int64 (0 ),
516
+ 5 ,
517
+ np .int32 (5 ),
518
+ np .int64 (5 ),
519
+ - 1 ,
520
+ np .int32 (- 1 ),
521
+ np .int64 (- 1 ),
522
+ # Slicing as index
523
+ slice (5 ),
524
+ np_int (slice (5 ), np .int32 ),
525
+ np_int (slice (5 ), np .int64 ),
526
+ slice (1 , 5 ),
527
+ np_int (slice (1 , 5 ), np .int32 ),
528
+ np_int (slice (1 , 5 ), np .int64 ),
529
+ slice (1 , 5 , 2 ),
530
+ np_int (slice (1 , 5 , 2 ), np .int32 ),
531
+ np_int (slice (1 , 5 , 2 ), np .int64 ),
532
+ slice (7 , 0 , - 1 ),
533
+ np_int (slice (7 , 0 , - 1 )),
534
+ np_int (slice (7 , 0 , - 1 ), np .int64 ),
535
+ slice (None , 6 ),
536
+ np_int (slice (None , 6 )),
537
+ np_int (slice (None , 6 ), np .int64 ),
538
+ slice (None , 6 , 3 ),
539
+ np_int (slice (None , 6 , 3 )),
540
+ np_int (slice (None , 6 , 3 ), np .int64 ),
541
+ slice (1 , None ),
542
+ np_int (slice (1 , None )),
543
+ np_int (slice (1 , None ), np .int64 ),
544
+ slice (1 , None , 3 ),
545
+ np_int (slice (1 , None , 3 )),
546
+ np_int (slice (1 , None , 3 ), np .int64 ),
547
+ slice (None , None , 2 ),
548
+ np_int (slice (None , None , 2 )),
549
+ np_int (slice (None , None , 2 ), np .int64 ),
550
+ slice (None , None , - 1 ),
551
+ np_int (slice (None , None , - 1 )),
552
+ np_int (slice (None , None , - 1 ), np .int64 ),
553
+ slice (None , None , - 2 ),
554
+ np_int (slice (None , None , - 2 ), np .int32 ),
555
+ np_int (slice (None , None , - 2 ), np .int64 ),
556
+ # Multiple ints as indices
557
+ (1 , 2 , 3 ),
558
+ np_int ((1 , 2 , 3 )),
559
+ np_int ((1 , 2 , 3 ), np .int64 ),
560
+ (- 1 , - 2 , - 3 ),
561
+ np_int ((- 1 , - 2 , - 3 )),
562
+ np_int ((- 1 , - 2 , - 3 ), np .int64 ),
563
+ (1 , 2 , 3 , 4 ),
564
+ np_int ((1 , 2 , 3 , 4 )),
565
+ np_int ((1 , 2 , 3 , 4 ), np .int64 ),
566
+ (- 4 , - 3 , - 2 , - 1 ),
567
+ np_int ((- 4 , - 3 , - 2 , - 1 )),
568
+ np_int ((- 4 , - 3 , - 2 , - 1 ), np .int64 ),
569
+ # slice(None) as indices
570
+ (slice (None ), slice (None ), 1 , 8 ),
571
+ (slice (None ), slice (None ), - 1 , 8 ),
572
+ (slice (None ), slice (None ), 1 , - 8 ),
573
+ (slice (None ), slice (None ), - 1 , - 8 ),
574
+ np_int ((slice (None ), slice (None ), 1 , 8 )),
575
+ np_int ((slice (None ), slice (None ), 1 , 8 ), np .int64 ),
576
+ (slice (None ), slice (None ), 1 , 8 ),
577
+ np_int ((slice (None ), slice (None ), - 1 , - 8 )),
578
+ np_int ((slice (None ), slice (None ), - 1 , - 8 ), np .int64 ),
579
+ (slice (None ), 2 , slice (1 , 5 ), 1 ),
580
+ np_int ((slice (None ), 2 , slice (1 , 5 ), 1 )),
581
+ np_int ((slice (None ), 2 , slice (1 , 5 ), 1 ), np .int64 ),
582
+ # Mixture of ints and slices as indices
583
+ (slice (None , None , - 1 ), 2 , slice (1 , 5 ), 1 ),
584
+ np_int ((slice (None , None , - 1 ), 2 , slice (1 , 5 ), 1 )),
585
+ np_int ((slice (None , None , - 1 ), 2 , slice (1 , 5 ), 1 ), np .int64 ),
586
+ (slice (None , None , - 1 ), 2 , slice (1 , 7 , 2 ), 1 ),
587
+ np_int ((slice (None , None , - 1 ), 2 , slice (1 , 7 , 2 ), 1 )),
588
+ np_int ((slice (None , None , - 1 ), 2 , slice (1 , 7 , 2 ), 1 ), np .int64 ),
589
+ (slice (1 , 8 , 2 ), slice (14 , 2 , - 2 ), slice (3 , 8 ), slice (0 , 7 , 3 )),
590
+ np_int ((slice (1 , 8 , 2 ), slice (14 , 2 , - 2 ), slice (3 , 8 ), slice (0 , 7 , 3 ))),
591
+ np_int ((slice (1 , 8 , 2 ), slice (14 , 2 , - 2 ), slice (3 , 8 ), slice (0 , 7 , 3 )), np .int64 ),
592
+ (slice (1 , 8 , 2 ), 1 , slice (3 , 8 ), 2 ),
593
+ np_int ((slice (1 , 8 , 2 ), 1 , slice (3 , 8 ), 2 )),
594
+ np_int ((slice (1 , 8 , 2 ), 1 , slice (3 , 8 ), 2 ), np .int64 ),
595
+ # Test Ellipsis ('...')
596
+ (1 , Ellipsis , - 1 ),
597
+ (slice (2 ), Ellipsis , None , 0 ),
598
+ # Test newaxis
599
+ None ,
600
+ (1 , None , - 2 , 3 , - 4 ),
601
+ (1 , slice (2 , 5 ), None ),
602
+ (slice (None ), slice (1 , 4 ), None , slice (2 , 3 )),
603
+ (slice (1 , 3 ), slice (1 , 3 ), slice (1 , 3 ), slice (1 , 3 ), None ),
604
+ (slice (1 , 3 ), slice (1 , 3 ), None , slice (1 , 3 ), slice (1 , 3 )),
605
+ (None , slice (1 , 2 ), 3 , None ),
606
+ (1 , None , 2 , 3 , None , None , 4 ),
607
+ # Advanced indexing
608
+ ([1 , 2 ], slice (3 , 5 ), None , None , [3 , 4 ]),
609
+ (slice (None ), slice (3 , 5 ), None , None , [2 , 3 ], [3 , 4 ]),
610
+ (slice (None ), slice (3 , 5 ), None , [2 , 3 ], None , [3 , 4 ]),
611
+ (None , slice (None ), slice (3 , 5 ), [2 , 3 ], None , [3 , 4 ]),
612
+ [1 ],
613
+ [1 , 2 ],
614
+ [2 , 1 , 3 ],
615
+ [7 , 5 , 0 , 3 , 6 , 2 , 1 ],
616
+ np .array ([6 , 3 ], dtype = np .int32 ),
617
+ np .array ([[3 , 4 ], [0 , 6 ]], dtype = np .int32 ),
618
+ np .array ([[7 , 3 ], [2 , 6 ], [0 , 5 ], [4 , 1 ]], dtype = np .int32 ),
619
+ np .array ([[7 , 3 ], [2 , 6 ], [0 , 5 ], [4 , 1 ]], dtype = np .int64 ),
620
+ np .array ([[2 ], [0 ], [1 ]], dtype = np .int32 ),
621
+ np .array ([[2 ], [0 ], [1 ]], dtype = np .int64 ),
622
+ np .array ([4 , 7 ], dtype = np .int32 ),
623
+ np .array ([4 , 7 ], dtype = np .int64 ),
624
+ np .array ([[3 , 6 ], [2 , 1 ]], dtype = np .int32 ),
625
+ np .array ([[3 , 6 ], [2 , 1 ]], dtype = np .int64 ),
626
+ np .array ([[7 , 3 ], [2 , 6 ], [0 , 5 ], [4 , 1 ]], dtype = np .int32 ),
627
+ np .array ([[7 , 3 ], [2 , 6 ], [0 , 5 ], [4 , 1 ]], dtype = np .int64 ),
628
+ (1 , [2 , 3 ]),
629
+ (1 , [2 , 3 ], np .array ([[3 ], [0 ]], dtype = np .int32 )),
630
+ (1 , [2 , 3 ]),
631
+ (1 , [2 , 3 ], np .array ([[3 ], [0 ]], dtype = np .int64 )),
632
+ (1 , [2 ], np .array ([[5 ], [3 ]], dtype = np .int32 ), slice (None )),
633
+ (1 , [2 ], np .array ([[5 ], [3 ]], dtype = np .int64 ), slice (None )),
634
+ (1 , [2 , 3 ], np .array ([[6 ], [0 ]], dtype = np .int32 ), slice (2 , 5 )),
635
+ (1 , [2 , 3 ], np .array ([[6 ], [0 ]], dtype = np .int64 ), slice (2 , 5 )),
636
+ (1 , [2 , 3 ], np .array ([[4 ], [7 ]], dtype = np .int32 ), slice (2 , 5 , 2 )),
637
+ (1 , [2 , 3 ], np .array ([[4 ], [7 ]], dtype = np .int64 ), slice (2 , 5 , 2 )),
638
+ (1 , [2 ], np .array ([[3 ]], dtype = np .int32 ), slice (None , None , - 1 )),
639
+ (1 , [2 ], np .array ([[3 ]], dtype = np .int64 ), slice (None , None , - 1 )),
640
+ (1 , [2 ], np .array ([[3 ]], dtype = np .int32 ), np .array ([[5 , 7 ], [2 , 4 ]], dtype = np .int64 )),
641
+ (1 , [2 ], np .array ([[4 ]], dtype = np .int32 ), np .array ([[1 , 3 ], [5 , 7 ]], dtype = 'int64' )),
642
+ [0 ],
643
+ [0 , 1 ],
644
+ [1 , 2 , 3 ],
645
+ [2 , 0 , 5 , 6 ],
646
+ ([1 , 1 ], [2 , 3 ]),
647
+ ([1 ], [4 ], [5 ]),
648
+ ([1 ], [4 ], [5 ], [6 ]),
649
+ ([[1 ]], [[2 ]]),
650
+ ([[1 ]], [[2 ]], [[3 ]], [[4 ]]),
651
+ (slice (0 , 2 ), [[1 ], [6 ]], slice (0 , 2 ), slice (0 , 5 , 2 )),
652
+ ([[[[1 ]]]], [[1 ]], slice (0 , 3 ), [1 , 5 ]),
653
+ ([[[[1 ]]]], 3 , slice (0 , 3 ), [1 , 3 ]),
654
+ ([[[[1 ]]]], 3 , slice (0 , 3 ), 0 ),
655
+ ([[[[1 ]]]], [[2 ], [12 ]], slice (0 , 3 ), slice (None )),
656
+ ([1 , 2 ], slice (3 , 5 ), [2 , 3 ], [3 , 4 ]),
657
+ ([1 , 2 ], slice (3 , 5 ), (2 , 3 ), [3 , 4 ]),
658
+ range (4 ),
659
+ range (3 , 0 , - 1 ),
660
+ (range (4 ,), [1 ]),
660
661
]
661
662
for index in index_list :
662
663
test_getitem (np_array , index )
0 commit comments