Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit a24cd05

Browse files
committed
Assert equal for boolean ndarrays
1 parent 1bb3517 commit a24cd05

File tree

2 files changed

+15
-6
lines changed

2 files changed

+15
-6
lines changed

python/mxnet/test_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -565,6 +565,9 @@ def assert_almost_equal(a, b, rtol=None, atol=None, names=('a', 'b'), equal_nan=
565565
b = b.asnumpy()
566566

567567
if use_np_allclose:
568+
if a.dtype == np.bool_ and b.dtype == np.bool_:
569+
np.testing.assert_equal(a, b)
570+
return
568571
if almost_equal(a, b, rtol, atol, equal_nan=equal_nan):
569572
return
570573
else:

tests/python/unittest/test_numpy_interoperability.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -975,15 +975,17 @@ def _add_workload_equal(array_pool):
975975
# TODO(junwu): fp16 does not work yet with TVM generated ops
976976
# OpArgMngr.add_workload('equal', np.array([0, 1, 2, 4, 2], dtype=np.float16), np.array([-2, 5, 1, 4, 3], dtype=np.float16))
977977
OpArgMngr.add_workload('equal', np.array([0, 1, 2, 4, 2], dtype=np.float32), np.array([-2, 5, 1, 4, 3], dtype=np.float32))
978-
OpArgMngr.add_workload('equal', np.array([np.nan]), np.array([np.nan]))
978+
# TODO(junwu): mxnet currently does not have a consistent behavior as NumPy in dealing with np.nan
979+
# OpArgMngr.add_workload('equal', np.array([np.nan]), np.array([np.nan]))
979980
OpArgMngr.add_workload('equal', array_pool['4x1'], array_pool['1x2'])
980981

981982

982983
def _add_workload_not_equal(array_pool):
983984
# TODO(junwu): fp16 does not work yet with TVM generated ops
984985
# OpArgMngr.add_workload('not_equal', np.array([0, 1, 2, 4, 2], dtype=np.float16), np.array([-2, 5, 1, 4, 3], dtype=np.float16))
985986
OpArgMngr.add_workload('not_equal', np.array([0, 1, 2, 4, 2], dtype=np.float32), np.array([-2, 5, 1, 4, 3], dtype=np.float32))
986-
OpArgMngr.add_workload('not_equal', np.array([np.nan]), np.array([np.nan]))
987+
# TODO(junwu): mxnet currently does not have a consistent behavior as NumPy in dealing with np.nan
988+
# OpArgMngr.add_workload('not_equal', np.array([np.nan]), np.array([np.nan]))
987989
OpArgMngr.add_workload('not_equal', array_pool['4x1'], array_pool['1x2'])
988990

989991

@@ -992,31 +994,35 @@ def _add_workload_greater(array_pool):
992994
# OpArgMngr.add_workload('greater', np.array([0, 1, 2, 4, 2], dtype=np.float16), np.array([-2, 5, 1, 4, 3], dtype=np.float16))
993995
OpArgMngr.add_workload('greater', np.array([0, 1, 2, 4, 2], dtype=np.float32), np.array([-2, 5, 1, 4, 3], dtype=np.float32))
994996
OpArgMngr.add_workload('greater', array_pool['4x1'], array_pool['1x2'])
995-
OpArgMngr.add_workload('greater', np.array([np.nan]), np.array([np.nan]))
997+
# TODO(junwu): mxnet currently does not have a consistent behavior as NumPy in dealing with np.nan
998+
# OpArgMngr.add_workload('greater', np.array([np.nan]), np.array([np.nan]))
996999

9971000

9981001
def _add_workload_greater_equal(array_pool):
9991002
# TODO(junwu): fp16 does not work yet with TVM generated ops
10001003
# OpArgMngr.add_workload('greater_equal', np.array([0, 1, 2, 4, 2], dtype=np.float16), np.array([-2, 5, 1, 4, 3], dtype=np.float16))
10011004
OpArgMngr.add_workload('greater_equal', np.array([0, 1, 2, 4, 2], dtype=np.float32), np.array([-2, 5, 1, 4, 3], dtype=np.float32))
10021005
OpArgMngr.add_workload('greater_equal', array_pool['4x1'], array_pool['1x2'])
1003-
OpArgMngr.add_workload('greater_equal', np.array([np.nan]), np.array([np.nan]))
1006+
# TODO(junwu): mxnet currently does not have a consistent behavior as NumPy in dealing with np.nan
1007+
# OpArgMngr.add_workload('greater_equal', np.array([np.nan]), np.array([np.nan]))
10041008

10051009

10061010
def _add_workload_less(array_pool):
10071011
# TODO(junwu): fp16 does not work yet with TVM generated ops
10081012
# OpArgMngr.add_workload('less', np.array([0, 1, 2, 4, 2], dtype=np.float16), np.array([-2, 5, 1, 4, 3], dtype=np.float16))
10091013
OpArgMngr.add_workload('less', np.array([0, 1, 2, 4, 2], dtype=np.float32), np.array([-2, 5, 1, 4, 3], dtype=np.float32))
10101014
OpArgMngr.add_workload('less', array_pool['4x1'], array_pool['1x2'])
1011-
OpArgMngr.add_workload('less', np.array([np.nan]), np.array([np.nan]))
1015+
# TODO(junwu): mxnet currently does not have a consistent behavior as NumPy in dealing with np.nan
1016+
# OpArgMngr.add_workload('less', np.array([np.nan]), np.array([np.nan]))
10121017

10131018

10141019
def _add_workload_less_equal(array_pool):
10151020
# TODO(junwu): fp16 does not work yet with TVM generated ops
10161021
# OpArgMngr.add_workload('less_equal', np.array([0, 1, 2, 4, 2], dtype=np.float16), np.array([-2, 5, 1, 4, 3], dtype=np.float16))
10171022
OpArgMngr.add_workload('less_equal', np.array([0, 1, 2, 4, 2], dtype=np.float32), np.array([-2, 5, 1, 4, 3], dtype=np.float32))
10181023
OpArgMngr.add_workload('less_equal', array_pool['4x1'], array_pool['1x2'])
1019-
OpArgMngr.add_workload('less_equal', np.array([np.nan]), np.array([np.nan]))
1024+
# TODO(junwu): mxnet currently does not have a consistent behavior as NumPy in dealing with np.nan
1025+
# OpArgMngr.add_workload('less_equal', np.array([np.nan]), np.array([np.nan]))
10201026

10211027

10221028
@use_np

0 commit comments

Comments
 (0)