Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 4c82f3d

Browse files
committed
change is_op_runnable's location
1 parent afb6dab commit 4c82f3d

File tree

2 files changed

+14
-14
lines changed

2 files changed

+14
-14
lines changed

python/mxnet/numpy_dispatch_protocol.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from __future__ import absolute_import
2121
import functools
2222
import numpy as _np
23-
from mxnet.test_utils import is_op_runnable
2423
from . import numpy as mx_np # pylint: disable=reimported
2524
from .numpy.multiarray import _NUMPY_ARRAY_FUNCTION_DICT, _NUMPY_ARRAY_UFUNC_DICT
2625

@@ -214,6 +213,12 @@ def _register_array_function():
214213
'trunc',
215214
'floor',
216215
'logical_not',
216+
'equal',
217+
'not_equal',
218+
'less',
219+
'less_equal',
220+
'greater',
221+
'greater_equal'
217222
]
218223

219224

@@ -225,13 +230,6 @@ def _register_array_ufunc():
225230
----------
226231
https://numpy.org/neps/nep-0013-ufunc-overrides.html
227232
"""
228-
if is_op_runnable():
229-
_NUMPY_ARRAY_UFUNC_LIST.extend(['equal',
230-
'not_equal',
231-
'less',
232-
'less_equal',
233-
'greater',
234-
'greater_equal'])
235233
dup = _find_duplicate(_NUMPY_ARRAY_UFUNC_LIST)
236234
if dup is not None:
237235
raise ValueError('Duplicate operator name {} in _NUMPY_ARRAY_UFUNC_LIST'.format(dup))

tests/python/unittest/test_numpy_interoperability.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from mxnet import np
2424
from mxnet.test_utils import assert_almost_equal
2525
from mxnet.test_utils import use_np
26+
from mxnet.test_utils import is_op_runnable
2627
from common import assertRaises, with_seed
2728
from mxnet.numpy_dispatch_protocol import with_array_function_protocol, with_array_ufunc_protocol
2829
from mxnet.numpy_dispatch_protocol import _NUMPY_ARRAY_FUNCTION_LIST, _NUMPY_ARRAY_UFUNC_LIST
@@ -934,12 +935,13 @@ def _prepare_workloads():
934935
_add_workload_logical_not(array_pool)
935936
_add_workload_vdot()
936937
_add_workload_vstack(array_pool)
937-
_add_workload_equal(array_pool)
938-
_add_workload_not_equal(array_pool)
939-
_add_workload_greater(array_pool)
940-
_add_workload_greater_equal(array_pool)
941-
_add_workload_less(array_pool)
942-
_add_workload_less_equal(array_pool)
938+
if is_op_runnable():
939+
_add_workload_equal(array_pool)
940+
_add_workload_not_equal(array_pool)
941+
_add_workload_greater(array_pool)
942+
_add_workload_greater_equal(array_pool)
943+
_add_workload_less(array_pool)
944+
_add_workload_less_equal(array_pool)
943945

944946

945947
_prepare_workloads()

0 commit comments

Comments
 (0)