Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit db81cc1

Browse files
committed
add dispatch
1 parent 69ddf4d commit db81cc1

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

python/mxnet/numpy_dispatch_protocol.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,9 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
127127
'tril',
128128
'meshgrid',
129129
'outer',
130-
'einsum'
130+
'einsum',
131+
'shares_memory',
132+
'may_share_memory',
131133
]
132134

133135

tests/python/unittest/test_numpy_interoperability.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1234,6 +1234,8 @@ def check_interoperability(op_list):
12341234
for name in op_list:
12351235
if name in _TVM_OPS and not is_op_runnable():
12361236
continue
1237+
if name in ['shares_memory', 'may_share_memory']: # skip list
1238+
continue
12371239
print('Dispatch test:', name)
12381240
workloads = OpArgMngr.get_workloads(name)
12391241
assert workloads is not None, 'Workloads for operator `{}` has not been ' \
@@ -1243,6 +1245,19 @@ def check_interoperability(op_list):
12431245
_check_interoperability_helper(name, *workload['args'], **workload['kwargs'])
12441246

12451247

1248+
@with_seed()
1249+
@use_np
1250+
@with_array_function_protocol
1251+
def test_np_memory_array_function():
1252+
ops = [_np.shares_memory, _np.may_share_memory]
1253+
for op in ops:
1254+
data_mx = np.zeros([13, 21, 23, 22], dtype=np.float32)
1255+
data_np = _np.zeros([13, 21, 23, 22], dtype=np.float32)
1256+
assert op(data_mx[0,:,:,:], data_mx[1,:,:,:]) == op(data_np[0,:,:,:], data_np[1,:,:,:])
1257+
assert op(data_mx[0,0,0,2:5], data_mx[0,0,0,4:7]) == op(data_np[0,0,0,2:5], data_np[0,0,0,4:7])
1258+
assert op(data_mx, np.ones((5, 0))) == op(data_np, _np.ones((5, 0)))
1259+
1260+
12461261
@with_seed()
12471262
@use_np
12481263
@with_array_function_protocol

0 commit comments

Comments
 (0)