Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit f0e7911

Browse files
JiangZhaohreminisce
authored andcommitted
all changes
fix sanity problem change is_op_runnable's location Fix
1 parent 746cbc5 commit f0e7911

File tree

2 files changed

+92
-0
lines changed

2 files changed

+92
-0
lines changed

python/mxnet/numpy_dispatch_protocol.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,8 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
117117
'transpose',
118118
'unique',
119119
'var',
120+
'vdot',
121+
'vstack',
120122
'zeros_like',
121123
'linalg.norm',
122124
'trace',
@@ -214,6 +216,12 @@ def _register_array_function():
214216
'trunc',
215217
'floor',
216218
'logical_not',
219+
'equal',
220+
'not_equal',
221+
'less',
222+
'less_equal',
223+
'greater',
224+
'greater_equal'
217225
]
218226

219227

tests/python/unittest/test_numpy_interoperability.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,14 @@
3232
_INT_DTYPES = [np.int8, np.int32, np.int64, np.uint8]
3333
_FLOAT_DTYPES = [np.float16, np.float32, np.float64]
3434
_DTYPES = _INT_DTYPES + _FLOAT_DTYPES
35+
_TVM_OPS = [
36+
'equal',
37+
'not_equal',
38+
'less',
39+
'less_equal',
40+
'greater',
41+
'greater_equal'
42+
]
3543

3644

3745
class OpArgMngr(object):
@@ -535,6 +543,13 @@ def _add_workload_roll():
535543

536544
def _add_workload_stack(array_pool):
537545
OpArgMngr.add_workload('stack', [array_pool['4x1']] * 2)
546+
OpArgMngr.add_workload('stack', [array_pool['4x1']] * 2, 1)
547+
OpArgMngr.add_workload('stack', [array_pool['4x1']] * 2, -1)
548+
OpArgMngr.add_workload('stack', [array_pool['4x1']] * 2, -2)
549+
OpArgMngr.add_workload('stack', np.random.normal(size=(2, 4, 3)), 2)
550+
OpArgMngr.add_workload('stack', np.random.normal(size=(2, 4, 3)), -3)
551+
OpArgMngr.add_workload('stack', np.array([[], [], []]), 1)
552+
OpArgMngr.add_workload('stack', np.array([[], [], []]))
538553

539554

540555
def _add_workload_sum():
@@ -590,10 +605,22 @@ def _add_workload_unique():
590605

591606
def _add_workload_var(array_pool):
592607
OpArgMngr.add_workload('var', array_pool['4x1'])
608+
OpArgMngr.add_workload('var', np.array([np.float16(1.)]))
609+
OpArgMngr.add_workload('var', np.array([1]))
610+
OpArgMngr.add_workload('var', np.array([1.]))
611+
OpArgMngr.add_workload('var', np.array([[1, 2, 3], [4, 5, 6]]))
612+
OpArgMngr.add_workload('var', np.array([[1, 2, 3], [4, 5, 6]]), 0)
613+
OpArgMngr.add_workload('var', np.array([[1, 2, 3], [4, 5, 6]]), 1)
614+
OpArgMngr.add_workload('var', np.array([np.nan]))
615+
OpArgMngr.add_workload('var', np.array([1, -1, 1, -1]))
616+
OpArgMngr.add_workload('var', np.array([1,2,3,4], dtype='f8'))
593617

594618

595619
def _add_workload_zeros_like(array_pool):
596620
OpArgMngr.add_workload('zeros_like', array_pool['4x1'])
621+
OpArgMngr.add_workload('zeros_like', np.random.uniform(size=(3, 3)).astype(np.float64))
622+
OpArgMngr.add_workload('zeros_like', np.random.uniform(size=(3, 3)).astype(np.float32))
623+
OpArgMngr.add_workload('zeros_like', np.random.randint(2, size = (3, 3)))
597624

598625

599626
def _add_workload_outer():
@@ -933,6 +960,53 @@ def _add_workload_logical_not(array_pool):
933960
OpArgMngr.add_workload('logical_not', np.array([True, False, True, False], dtype=np.bool))
934961

935962

963+
def _add_workload_vdot():
964+
OpArgMngr.add_workload('vdot', np.random.normal(size=(2, 4)), np.random.normal(size=(4, 2)))
965+
OpArgMngr.add_workload('vdot', np.random.normal(size=(2, 4)).astype(np.float64), np.random.normal(size=(2, 4)).astype(np.float64))
966+
967+
968+
def _add_workload_vstack(array_pool):
969+
OpArgMngr.add_workload('vstack', (array_pool['4x1'], np.random.uniform(size=(5, 1))))
970+
OpArgMngr.add_workload('vstack', array_pool['4x1'])
971+
OpArgMngr.add_workload('vstack', array_pool['1x1x0'])
972+
973+
974+
def _add_workload_equal(array_pool):
975+
OpArgMngr.add_workload('equal', np.array([0, 1, 2, 4, 2], dtype=np.float16), np.array([-2, 5, 1, 4, 3], dtype=np.float16))
976+
OpArgMngr.add_workload('equal', np.array([np.nan]), np.array([np.nan]))
977+
OpArgMngr.add_workload('equal', array_pool['4x1'], array_pool['1x2'])
978+
979+
980+
def _add_workload_not_equal(array_pool):
981+
OpArgMngr.add_workload('not_equal', np.array([0, 1, 2, 4, 2], dtype=np.float16), np.array([-2, 5, 1, 4, 3], dtype=np.float16))
982+
OpArgMngr.add_workload('not_equal', np.array([np.nan]), np.array([np.nan]))
983+
OpArgMngr.add_workload('not_equal', array_pool['4x1'], array_pool['1x2'])
984+
985+
986+
def _add_workload_greater(array_pool):
987+
OpArgMngr.add_workload('greater', np.array([0, 1, 2, 4, 2], dtype=np.float16), np.array([-2, 5, 1, 4, 3], dtype=np.float16))
988+
OpArgMngr.add_workload('greater', array_pool['4x1'], array_pool['1x2'])
989+
OpArgMngr.add_workload('greater', np.array([np.nan]), np.array([np.nan]))
990+
991+
992+
def _add_workload_greater_equal(array_pool):
993+
OpArgMngr.add_workload('greater_equal', np.array([0, 1, 2, 4, 2], dtype=np.float16), np.array([-2, 5, 1, 4, 3], dtype=np.float16))
994+
OpArgMngr.add_workload('greater_equal', array_pool['4x1'], array_pool['1x2'])
995+
OpArgMngr.add_workload('greater_equal', np.array([np.nan]), np.array([np.nan]))
996+
997+
998+
def _add_workload_less(array_pool):
999+
OpArgMngr.add_workload('less', np.array([0, 1, 2, 4, 2], dtype=np.float16), np.array([-2, 5, 1, 4, 3], dtype=np.float16))
1000+
OpArgMngr.add_workload('less', array_pool['4x1'], array_pool['1x2'])
1001+
OpArgMngr.add_workload('less', np.array([np.nan]), np.array([np.nan]))
1002+
1003+
1004+
def _add_workload_less_equal(array_pool):
1005+
OpArgMngr.add_workload('less_equal', np.array([0, 1, 2, 4, 2], dtype=np.float16), np.array([-2, 5, 1, 4, 3], dtype=np.float16))
1006+
OpArgMngr.add_workload('less_equal', array_pool['4x1'], array_pool['1x2'])
1007+
OpArgMngr.add_workload('less_equal', np.array([np.nan]), np.array([np.nan]))
1008+
1009+
9361010
@use_np
9371011
def _prepare_workloads():
9381012
array_pool = {
@@ -1028,6 +1102,14 @@ def _prepare_workloads():
10281102
_add_workload_turnc(array_pool)
10291103
_add_workload_floor(array_pool)
10301104
_add_workload_logical_not(array_pool)
1105+
_add_workload_vdot()
1106+
_add_workload_vstack(array_pool)
1107+
_add_workload_equal(array_pool)
1108+
_add_workload_not_equal(array_pool)
1109+
_add_workload_greater(array_pool)
1110+
_add_workload_greater_equal(array_pool)
1111+
_add_workload_less(array_pool)
1112+
_add_workload_less_equal(array_pool)
10311113

10321114

10331115
_prepare_workloads()
@@ -1070,6 +1152,8 @@ def _check_interoperability_helper(op_name, *args, **kwargs):
10701152

10711153
def check_interoperability(op_list):
10721154
for name in op_list:
1155+
if name in _TVM_OPS and not is_op_runnable():
1156+
continue
10731157
print('Dispatch test:', name)
10741158
workloads = OpArgMngr.get_workloads(name)
10751159
assert workloads is not None, 'Workloads for operator `{}` has not been ' \

0 commit comments

Comments
 (0)