|
32 | 32 | _INT_DTYPES = [np.int8, np.int32, np.int64, np.uint8]
|
33 | 33 | _FLOAT_DTYPES = [np.float16, np.float32, np.float64]
|
34 | 34 | _DTYPES = _INT_DTYPES + _FLOAT_DTYPES
|
| 35 | +_TVM_OPS = [ |
| 36 | + 'equal', |
| 37 | + 'not_equal', |
| 38 | + 'less', |
| 39 | + 'less_equal', |
| 40 | + 'greater', |
| 41 | + 'greater_equal' |
| 42 | +] |
35 | 43 |
|
36 | 44 |
|
37 | 45 | class OpArgMngr(object):
|
@@ -535,6 +543,13 @@ def _add_workload_roll():
|
535 | 543 |
|
536 | 544 | def _add_workload_stack(array_pool):
|
537 | 545 | OpArgMngr.add_workload('stack', [array_pool['4x1']] * 2)
|
| 546 | + OpArgMngr.add_workload('stack', [array_pool['4x1']] * 2, 1) |
| 547 | + OpArgMngr.add_workload('stack', [array_pool['4x1']] * 2, -1) |
| 548 | + OpArgMngr.add_workload('stack', [array_pool['4x1']] * 2, -2) |
| 549 | + OpArgMngr.add_workload('stack', np.random.normal(size=(2, 4, 3)), 2) |
| 550 | + OpArgMngr.add_workload('stack', np.random.normal(size=(2, 4, 3)), -3) |
| 551 | + OpArgMngr.add_workload('stack', np.array([[], [], []]), 1) |
| 552 | + OpArgMngr.add_workload('stack', np.array([[], [], []])) |
538 | 553 |
|
539 | 554 |
|
540 | 555 | def _add_workload_sum():
|
@@ -590,10 +605,22 @@ def _add_workload_unique():
|
590 | 605 |
|
591 | 606 | def _add_workload_var(array_pool):
|
592 | 607 | OpArgMngr.add_workload('var', array_pool['4x1'])
|
| 608 | + OpArgMngr.add_workload('var', np.array([np.float16(1.)])) |
| 609 | + OpArgMngr.add_workload('var', np.array([1])) |
| 610 | + OpArgMngr.add_workload('var', np.array([1.])) |
| 611 | + OpArgMngr.add_workload('var', np.array([[1, 2, 3], [4, 5, 6]])) |
| 612 | + OpArgMngr.add_workload('var', np.array([[1, 2, 3], [4, 5, 6]]), 0) |
| 613 | + OpArgMngr.add_workload('var', np.array([[1, 2, 3], [4, 5, 6]]), 1) |
| 614 | + OpArgMngr.add_workload('var', np.array([np.nan])) |
| 615 | + OpArgMngr.add_workload('var', np.array([1, -1, 1, -1])) |
| 616 | + OpArgMngr.add_workload('var', np.array([1,2,3,4], dtype='f8')) |
593 | 617 |
|
594 | 618 |
|
595 | 619 | def _add_workload_zeros_like(array_pool):
|
596 | 620 | OpArgMngr.add_workload('zeros_like', array_pool['4x1'])
|
| 621 | + OpArgMngr.add_workload('zeros_like', np.random.uniform(size=(3, 3)).astype(np.float64)) |
| 622 | + OpArgMngr.add_workload('zeros_like', np.random.uniform(size=(3, 3)).astype(np.float32)) |
| 623 | + OpArgMngr.add_workload('zeros_like', np.random.randint(2, size = (3, 3))) |
597 | 624 |
|
598 | 625 |
|
599 | 626 | def _add_workload_outer():
|
@@ -933,6 +960,53 @@ def _add_workload_logical_not(array_pool):
|
933 | 960 | OpArgMngr.add_workload('logical_not', np.array([True, False, True, False], dtype=np.bool))
|
934 | 961 |
|
935 | 962 |
|
| 963 | +def _add_workload_vdot(): |
| 964 | + OpArgMngr.add_workload('vdot', np.random.normal(size=(2, 4)), np.random.normal(size=(4, 2))) |
| 965 | + OpArgMngr.add_workload('vdot', np.random.normal(size=(2, 4)).astype(np.float64), np.random.normal(size=(2, 4)).astype(np.float64)) |
| 966 | + |
| 967 | + |
| 968 | +def _add_workload_vstack(array_pool): |
| 969 | + OpArgMngr.add_workload('vstack', (array_pool['4x1'], np.random.uniform(size=(5, 1)))) |
| 970 | + OpArgMngr.add_workload('vstack', array_pool['4x1']) |
| 971 | + OpArgMngr.add_workload('vstack', array_pool['1x1x0']) |
| 972 | + |
| 973 | + |
| 974 | +def _add_workload_equal(array_pool): |
| 975 | + OpArgMngr.add_workload('equal', np.array([0, 1, 2, 4, 2], dtype=np.float16), np.array([-2, 5, 1, 4, 3], dtype=np.float16)) |
| 976 | + OpArgMngr.add_workload('equal', np.array([np.nan]), np.array([np.nan])) |
| 977 | + OpArgMngr.add_workload('equal', array_pool['4x1'], array_pool['1x2']) |
| 978 | + |
| 979 | + |
| 980 | +def _add_workload_not_equal(array_pool): |
| 981 | + OpArgMngr.add_workload('not_equal', np.array([0, 1, 2, 4, 2], dtype=np.float16), np.array([-2, 5, 1, 4, 3], dtype=np.float16)) |
| 982 | + OpArgMngr.add_workload('not_equal', np.array([np.nan]), np.array([np.nan])) |
| 983 | + OpArgMngr.add_workload('not_equal', array_pool['4x1'], array_pool['1x2']) |
| 984 | + |
| 985 | + |
| 986 | +def _add_workload_greater(array_pool): |
| 987 | + OpArgMngr.add_workload('greater', np.array([0, 1, 2, 4, 2], dtype=np.float16), np.array([-2, 5, 1, 4, 3], dtype=np.float16)) |
| 988 | + OpArgMngr.add_workload('greater', array_pool['4x1'], array_pool['1x2']) |
| 989 | + OpArgMngr.add_workload('greater', np.array([np.nan]), np.array([np.nan])) |
| 990 | + |
| 991 | + |
| 992 | +def _add_workload_greater_equal(array_pool): |
| 993 | + OpArgMngr.add_workload('greater_equal', np.array([0, 1, 2, 4, 2], dtype=np.float16), np.array([-2, 5, 1, 4, 3], dtype=np.float16)) |
| 994 | + OpArgMngr.add_workload('greater_equal', array_pool['4x1'], array_pool['1x2']) |
| 995 | + OpArgMngr.add_workload('greater_equal', np.array([np.nan]), np.array([np.nan])) |
| 996 | + |
| 997 | + |
| 998 | +def _add_workload_less(array_pool): |
| 999 | + OpArgMngr.add_workload('less', np.array([0, 1, 2, 4, 2], dtype=np.float16), np.array([-2, 5, 1, 4, 3], dtype=np.float16)) |
| 1000 | + OpArgMngr.add_workload('less', array_pool['4x1'], array_pool['1x2']) |
| 1001 | + OpArgMngr.add_workload('less', np.array([np.nan]), np.array([np.nan])) |
| 1002 | + |
| 1003 | + |
| 1004 | +def _add_workload_less_equal(array_pool): |
| 1005 | + OpArgMngr.add_workload('less_equal', np.array([0, 1, 2, 4, 2], dtype=np.float16), np.array([-2, 5, 1, 4, 3], dtype=np.float16)) |
| 1006 | + OpArgMngr.add_workload('less_equal', array_pool['4x1'], array_pool['1x2']) |
| 1007 | + OpArgMngr.add_workload('less_equal', np.array([np.nan]), np.array([np.nan])) |
| 1008 | + |
| 1009 | + |
936 | 1010 | @use_np
|
937 | 1011 | def _prepare_workloads():
|
938 | 1012 | array_pool = {
|
@@ -1028,6 +1102,14 @@ def _prepare_workloads():
|
1028 | 1102 | _add_workload_turnc(array_pool)
|
1029 | 1103 | _add_workload_floor(array_pool)
|
1030 | 1104 | _add_workload_logical_not(array_pool)
|
| 1105 | + _add_workload_vdot() |
| 1106 | + _add_workload_vstack(array_pool) |
| 1107 | + _add_workload_equal(array_pool) |
| 1108 | + _add_workload_not_equal(array_pool) |
| 1109 | + _add_workload_greater(array_pool) |
| 1110 | + _add_workload_greater_equal(array_pool) |
| 1111 | + _add_workload_less(array_pool) |
| 1112 | + _add_workload_less_equal(array_pool) |
1031 | 1113 |
|
1032 | 1114 |
|
1033 | 1115 | _prepare_workloads()
|
@@ -1070,6 +1152,8 @@ def _check_interoperability_helper(op_name, *args, **kwargs):
|
1070 | 1152 |
|
1071 | 1153 | def check_interoperability(op_list):
|
1072 | 1154 | for name in op_list:
|
| 1155 | + if name in _TVM_OPS and not is_op_runnable(): |
| 1156 | + continue |
1073 | 1157 | print('Dispatch test:', name)
|
1074 | 1158 | workloads = OpArgMngr.get_workloads(name)
|
1075 | 1159 | assert workloads is not None, 'Workloads for operator `{}` has not been ' \
|
|
0 commit comments