@@ -617,49 +617,70 @@ def is_int(dtype):
617
617
in_data_dim = random .choice ([2 , 3 , 4 ])
618
618
shape = rand_shape_nd (in_data_dim , dim = 3 )
619
619
acc_type = {'float16' : 'float32' , 'float32' : 'float64' , 'float64' : 'float64' ,
620
- 'int8' : 'int32' , 'int32' : 'int64' , 'int64' : 'int64' }
620
+ 'bool' : 'int64' , 'int8' : 'int32' , 'int32' : 'int64' , 'int64' : 'int64' }
621
+ ft_types = ['float16' , 'float32' , 'float64' ]
622
+ it_types = ['bool' , 'int8' , 'int32' , 'int64' ]
621
623
for hybridize in [False , True ]:
622
624
for keepdims in [True , False ]:
623
625
for axis in ([i for i in range (in_data_dim )] + [(), None ]):
624
- for itype in ['float16' , 'float32' , 'float64' ]:
625
- for dtype in ['float16' , 'float32' , 'float64' ]:
626
- if is_int (dtype ) and not is_int (itype ):
627
- continue
628
- # test gluon
629
- test_mean = TestMean (axis = axis , dtype = dtype , keepdims = keepdims )
630
- if hybridize :
631
- test_mean .hybridize ()
632
- if is_int (itype ):
633
- x = _np .random .randint (- 128 , 128 , shape , dtype = itype )
634
- x = mx .nd .array (x , dtype = itype )
635
- else :
636
- x = mx .nd .random .uniform (- 1.0 , 1.0 , shape = shape , dtype = itype )
637
- x = x .as_np_ndarray ()
638
- x .attach_grad ()
626
+ for itype , dtype in itertools .product (ft_types , [None ] + ft_types + it_types ):
627
+ if dtype == 'bool' :
628
+ continue
629
+ # test gluon
630
+ test_mean = TestMean (axis = axis , dtype = dtype , keepdims = keepdims )
631
+ if hybridize :
632
+ test_mean .hybridize ()
633
+ x = np .random .uniform (- 1.0 , 1.0 , size = shape ).astype (itype )
634
+ x = x .as_np_ndarray ()
635
+ x .attach_grad ()
639
636
640
- expected_ret = _np .mean (x .asnumpy (), axis = axis , dtype = acc_type [itype ], keepdims = keepdims )
641
- expected_ret = expected_ret .astype (dtype )
642
- with mx .autograd .record ():
643
- y = test_mean (x )
644
- assert y .shape == expected_ret .shape
645
- assert_almost_equal (y .asnumpy (), expected_ret , rtol = 1e-3 if dtype == 'float16' else 1e-3 ,
646
- atol = 1e-5 if dtype == 'float16' else 1e-5 )
637
+ expected_ret = _np .mean (x .asnumpy (), axis = axis , dtype = acc_type [itype ], keepdims = keepdims )
638
+ expected_ret = expected_ret .astype (dtype )
639
+ with mx .autograd .record ():
640
+ y = test_mean (x )
641
+ assert y .shape == expected_ret .shape
642
+ assert_almost_equal (y .asnumpy (), expected_ret , rtol = 1e-3 if dtype == 'float16' else 1e-3 ,
643
+ atol = 1e-5 if dtype == 'float16' else 1e-5 )
647
644
648
- y .backward ()
649
- N = x .size / y .size
650
- assert same (x .grad .asnumpy (), _np .ones (shape = x .shape , dtype = x .dtype ) / N )
645
+ y .backward ()
646
+ N = x .size / y .size
647
+ assert same (x .grad .asnumpy (), _np .ones (shape = x .shape , dtype = x .dtype ) / N )
651
648
652
- # test numeric
653
- if itype == 'float32' and dtype == 'float32' :
654
- x_sym = mx .sym .Variable ("x" ).as_np_ndarray ()
655
- mx_sym = mx .sym .np .mean (x_sym , axis = axis , dtype = dtype , keepdims = keepdims ).as_nd_ndarray ()
656
- check_numeric_gradient (mx_sym , [x .as_nd_ndarray ()],
657
- numeric_eps = 1e-3 , rtol = 1e-3 , atol = 1e-4 , dtype = _np .float32 )
649
+ # test numeric
650
+ if itype == 'float32' and dtype == 'float32' :
651
+ x_sym = mx .sym .Variable ("x" ).as_np_ndarray ()
652
+ mx_sym = mx .sym .np .mean (x_sym , axis = axis , dtype = dtype , keepdims = keepdims ).as_nd_ndarray ()
653
+ check_numeric_gradient (mx_sym , [x .as_nd_ndarray ()],
654
+ numeric_eps = 1e-3 , rtol = 1e-3 , atol = 1e-4 , dtype = _np .float32 )
658
655
659
- # test imperative
660
- mx_out = np .mean (x , axis = axis , dtype = dtype , keepdims = keepdims )
661
- np_out = _np .mean (x .asnumpy (), axis = axis , dtype = acc_type [itype ], keepdims = keepdims ).astype (dtype )
662
- assert_almost_equal (mx_out .asnumpy (), np_out , rtol = 1e-3 , atol = 1e-5 )
656
+ # test imperative
657
+ mx_out = np .mean (x , axis = axis , dtype = dtype , keepdims = keepdims )
658
+ np_out = _np .mean (x .asnumpy (), axis = axis , dtype = acc_type [itype ], keepdims = keepdims ).astype (dtype )
659
+ assert_almost_equal (mx_out .asnumpy (), np_out , rtol = 1e-3 , atol = 1e-5 )
660
+
661
+ for itype , dtype in itertools .product (it_types , [None ] + ft_types + it_types ):
662
+ if dtype == 'bool' :
663
+ continue
664
+ # test gluon
665
+ test_mean = TestMean (axis = axis , dtype = dtype , keepdims = keepdims )
666
+ if hybridize :
667
+ test_mean .hybridize ()
668
+
669
+ if itype == 'bool' :
670
+ x = np .random .uniform (size = shape ) > 0.5
671
+ else :
672
+ x = np .random .uniform (- 128 , 127 , size = shape ).astype (itype )
673
+
674
+ expected_ret = _np .mean (x .asnumpy (), axis = axis , dtype = dtype , keepdims = keepdims )
675
+ y = test_mean (x )
676
+ assert y .shape == expected_ret .shape
677
+ assert_almost_equal (y .asnumpy (), expected_ret , rtol = 1e-3 if dtype == 'float16' else 1e-3 ,
678
+ atol = 1e-5 if dtype == 'float16' else 1e-5 )
679
+
680
+ # test imperative
681
+ mx_out = np .mean (x , axis = axis , dtype = dtype , keepdims = keepdims )
682
+ np_out = _np .mean (x .asnumpy (), axis = axis , dtype = dtype , keepdims = keepdims ).astype (dtype )
683
+ assert_almost_equal (mx_out .asnumpy (), np_out , rtol = 1e-3 , atol = 1e-5 )
663
684
664
685
665
686
@with_seed ()
0 commit comments