
Commit e2c924b

hgt312 authored and yajiedesign committed
[NumPy][Operator] NumPy operator may_share_memory and shares_memory (apache#16533)
* init
* finish & fix bug of 'take'
* fix bug
* add dispatch
1 parent 0b1ec85 commit e2c924b

File tree

9 files changed: +406, -7 lines


python/mxnet/ndarray/numpy/_op.py

Lines changed: 75 additions & 1 deletion

@@ -38,7 +38,7 @@
     'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip',
     'around', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take',
     'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal',
-    'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero']
+    'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory', 'may_share_memory']
 
 
 @set_module('mxnet.ndarray.numpy')

@@ -4909,3 +4909,77 @@ def nonzero(a):
     """
     out = _npi.nonzero(a).transpose()
     return tuple([out[i] for i in range(len(out))])
+
+
+@set_module('mxnet.ndarray.numpy')
+def shares_memory(a, b, max_work=None):
+    """
+    Determine if two arrays share memory.
+
+    Parameters
+    ----------
+    a, b : ndarray
+        Input arrays
+
+    Returns
+    -------
+    out : bool
+
+    See Also
+    --------
+    may_share_memory
+
+    Examples
+    --------
+    >>> np.shares_memory(np.array([1,2]), np.array([5,8,9]))
+    False
+
+    This function differs from the original `numpy.shares_memory
+    <https://docs.scipy.org/doc/numpy/reference/generated/numpy.shares_memory.html>`_ in
+    the following ways:
+
+    - Does not support `max_work`; it is a dummy argument
+    - Actually, it is the same as `may_share_memory` in MXNet DeepNumPy
+    """
+    return _npi.share_memory(a, b).item()
+
+
+@set_module('mxnet.ndarray.numpy')
+def may_share_memory(a, b, max_work=None):
+    """
+    Determine if two arrays might share memory.
+
+    A return of True does not necessarily mean that the two arrays
+    share any element. It just means that they *might*.
+
+    Only the memory bounds of a and b are checked by default.
+
+    Parameters
+    ----------
+    a, b : ndarray
+        Input arrays
+
+    Returns
+    -------
+    out : bool
+
+    See Also
+    --------
+    shares_memory
+
+    Examples
+    --------
+    >>> np.may_share_memory(np.array([1,2]), np.array([5,8,9]))
+    False
+    >>> x = np.zeros([3, 4])
+    >>> np.may_share_memory(x[:,0], x[:,1])
+    True
+
+    This function differs from the original `numpy.may_share_memory
+    <https://docs.scipy.org/doc/numpy/reference/generated/numpy.may_share_memory.html>`_ in
+    the following ways:
+
+    - Does not support `max_work`; it is a dummy argument
+    - Actually, it is the same as `shares_memory` in MXNet DeepNumPy
+    """
+    return _npi.share_memory(a, b).item()
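
Both front-end functions lower to the same `_npi.share_memory` operator and call `.item()` to turn its boolean scalar output into a Python bool, so in this commit `shares_memory` and `may_share_memory` return identical results. Below is a minimal usage sketch, assuming an MXNet build that contains this commit, with NumPy semantics enabled via `npx.set_np()`:

    # Minimal usage sketch (assumes an MXNet build containing this commit).
    from mxnet import np, npx

    npx.set_np()  # enable NumPy-compatible array semantics

    a = np.array([1, 2])
    b = np.array([5, 8, 9])
    print(np.shares_memory(a, b))                  # False: separate buffers

    x = np.zeros((3, 4))
    print(np.may_share_memory(x[:, 0], x[:, 1]))   # True, as in the docstring:
                                                   # only memory bounds are checked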

python/mxnet/numpy/multiarray.py

Lines changed: 77 additions & 2 deletions

@@ -55,7 +55,8 @@
     'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming',
     'blackman', 'flip', 'around', 'arctan2', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril',
     'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less',
-    'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero']
+    'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory',
+    'may_share_memory']
 
 # Return code for dispatching indexing function call
 _NDARRAY_UNSUPPORTED_INDEXING = -1

@@ -1330,7 +1331,7 @@ def take(self, indices, axis=None, mode='raise'):  # pylint: disable=arguments-differ
         The arguments are the same as for :py:func:`take`, with
         this array as data.
         """
-        take(self, indices, axis, mode=mode)
+        return take(self, indices, axis, mode=mode)
 
     def one_hot(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`one_hot`.

@@ -6900,3 +6901,77 @@ def nonzero(a):
     (array([1, 1, 1, 2, 2, 2], dtype=int64), array([0, 1, 2, 0, 1, 2], dtype=int64))
     """
     return _mx_nd_np.nonzero(a)
+
+
+@set_module('mxnet.numpy')
+def shares_memory(a, b, max_work=None):
+    """
+    Determine if two arrays share memory.
+
+    Parameters
+    ----------
+    a, b : ndarray
+        Input arrays
+
+    Returns
+    -------
+    out : bool
+
+    See Also
+    --------
+    may_share_memory
+
+    Examples
+    --------
+    >>> np.shares_memory(np.array([1,2]), np.array([5,8,9]))
+    False
+
+    This function differs from the original `numpy.shares_memory
+    <https://docs.scipy.org/doc/numpy/reference/generated/numpy.shares_memory.html>`_ in
+    the following ways:
+
+    - Does not support `max_work`; it is a dummy argument
+    - Actually, it is the same as `may_share_memory` in MXNet DeepNumPy
+    """
+    return _mx_nd_np.shares_memory(a, b, max_work)
+
+
+@set_module('mxnet.numpy')
+def may_share_memory(a, b, max_work=None):
+    """
+    Determine if two arrays might share memory.
+
+    A return of True does not necessarily mean that the two arrays
+    share any element. It just means that they *might*.
+
+    Only the memory bounds of a and b are checked by default.
+
+    Parameters
+    ----------
+    a, b : ndarray
+        Input arrays
+
+    Returns
+    -------
+    out : bool
+
+    See Also
+    --------
+    shares_memory
+
+    Examples
+    --------
+    >>> np.may_share_memory(np.array([1,2]), np.array([5,8,9]))
+    False
+    >>> x = np.zeros([3, 4])
+    >>> np.may_share_memory(x[:,0], x[:,1])
+    True
+
+    This function differs from the original `numpy.may_share_memory
+    <https://docs.scipy.org/doc/numpy/reference/generated/numpy.may_share_memory.html>`_ in
+    the following ways:
+
+    - Does not support `max_work`; it is a dummy argument
+    - Actually, it is the same as `shares_memory` in MXNet DeepNumPy
+    """
+    return _mx_nd_np.may_share_memory(a, b, max_work)
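
The second hunk above also fixes a bug unrelated to the new operators: the fluent `ndarray.take` method computed its result but never returned it, so every fluent call yielded None. A short sketch of the fixed behavior (the values and index dtype are illustrative, not from this commit):

    # Sketch of the fixed fluent `take`; before this commit, `out` was None.
    from mxnet import np, npx

    npx.set_np()

    data = np.array([10, 20, 30, 40])
    indices = np.array([0, 2], dtype='int64')   # assumed index dtype
    out = data.take(indices)                    # now returns the gathered values
    print(out)                                  # expected: [10., 30.]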

python/mxnet/numpy_dispatch_protocol.py

Lines changed: 3 additions & 1 deletion

@@ -127,7 +127,9 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
     'tril',
     'meshgrid',
     'outer',
-    'einsum'
+    'einsum',
+    'shares_memory',
+    'may_share_memory',
 ]
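
The surrounding context of this list is not shown in the diff, but registering the two names here ties them into MXNet's NumPy array-function dispatch checks. A hedged sketch of what that protocol enables, assuming NumPy >= 1.17 (where `__array_function__` dispatch is on by default) and an MXNet build with this commit:

    # Hedged sketch: with array-function dispatch, calling the *official* NumPy
    # function on MXNet ndarrays routes to MXNet's implementation.
    import numpy as onp          # official NumPy, >= 1.17 assumed
    from mxnet import np, npx

    npx.set_np()

    a = np.array([1, 2])
    b = np.array([5, 8, 9])
    # a and b implement __array_function__, so this call dispatches to
    # mxnet.numpy.shares_memory rather than running NumPy's own check.
    print(onp.shares_memory(a, b))   # False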

python/mxnet/symbol/numpy/_symbol.py

Lines changed: 40 additions & 1 deletion

@@ -40,7 +40,7 @@
     'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip',
     'around', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take',
     'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal',
-    'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide']
+    'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'shares_memory', 'may_share_memory']
 
 
 def _num_outputs(sym):

@@ -4590,4 +4590,43 @@ def einsum(*operands, **kwargs):
     return _npi.einsum(*operands, subscripts=subscripts, out=out, optimize=int(optimize_arg))
 
 
+@set_module('mxnet.symbol.numpy')
+def shares_memory(a, b, max_work=None):
+    """
+    Determine if two arrays share memory.
+
+    Parameters
+    ----------
+    a, b : _Symbol
+        Input arrays
+
+    Returns
+    -------
+    out : _Symbol
+    """
+    return _npi.share_memory(a, b)
+
+
+@set_module('mxnet.symbol.numpy')
+def may_share_memory(a, b, max_work=None):
+    """
+    Determine if two arrays might share memory.
+
+    A return of True does not necessarily mean that the two arrays
+    share any element. It just means that they *might*.
+
+    Only the memory bounds of a and b are checked by default.
+
+    Parameters
+    ----------
+    a, b : _Symbol
+        Input arrays
+
+    Returns
+    -------
+    out : _Symbol
+    """
+    return _npi.share_memory(a, b)
+
+
 _set_np_symbol_class(_Symbol)
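
Unlike the imperative versions, the symbolic versions cannot call `.item()`: they return a `_Symbol` that produces the boolean scalar only when the graph is executed. A hedged sketch of driving this front end directly (the variable names are illustrative, not from this commit):

    # Hedged sketch: building the check symbolically; the result is available
    # only after the symbol is bound and executed.
    import mxnet as mx
    from mxnet import npx

    npx.set_np()

    a = mx.sym.var('a').as_np_ndarray()    # plain Symbols converted to _Symbol
    b = mx.sym.var('b').as_np_ndarray()
    out = mx.sym.np.shares_memory(a, b)    # a _Symbol yielding a boolean scalar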

src/operator/numpy/np_memory_op.cc

Lines changed: 62 additions & 0 deletions

@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file np_memory_op.cc
+ */
+
+#include "./np_memory_op.h"
+
+namespace mxnet {
+namespace op {
+
+inline bool NumpyShareMemoryType(const nnvm::NodeAttrs& attrs,
+                                 std::vector<int> *in_attrs,
+                                 std::vector<int> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 2U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kBool);
+  return out_attrs->at(0) != -1;
+}
+
+inline bool NumpyShareMemoryShape(const nnvm::NodeAttrs& attrs,
+                                  mxnet::ShapeVector *in_attrs,
+                                  mxnet::ShapeVector *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 2U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape(0, -1));
+  return true;
+}
+
+NNVM_REGISTER_OP(_npi_share_memory)
+.set_num_inputs(2)
+.set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"a", "b"};
+  })
+.set_attr<mxnet::FInferShape>("FInferShape", NumpyShareMemoryShape)
+.set_attr<nnvm::FInferType>("FInferType", NumpyShareMemoryType)
+.set_attr<FCompute>("FCompute<cpu>", NumpyShareMemoryCompute<cpu>)
+.add_argument("a", "NDArray-or-Symbol", "First input")
+.add_argument("b", "NDArray-or-Symbol", "Second input");
+
+}  // namespace op
+}  // namespace mxnet
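
The compute kernel, `NumpyShareMemoryCompute`, lives in `np_memory_op.h`, which is not part of this diff. Given that the shape function infers a 0-dimensional output (`TShape(0, -1)`) and the type function infers `kBool`, a plausible reading is a simple interval-overlap test on the two arrays' memory bounds, which would also explain why `shares_memory` and `may_share_memory` coincide in this implementation. A Python sketch of that idea, purely illustrative:

    # Illustrative sketch only; the actual kernel in np_memory_op.h is not shown.
    def share_memory_bounds(start_a, nbytes_a, start_b, nbytes_b):
        """True iff [start_a, start_a + nbytes_a) overlaps [start_b, start_b + nbytes_b)."""
        return start_a < start_b + nbytes_b and start_b < start_a + nbytes_a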

src/operator/numpy/np_memory_op.cu

Lines changed: 34 additions & 0 deletions

@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file np_memory_op.cu
+ */
+
+#include "./np_memory_op.h"
+
+namespace mxnet {
+namespace op {
+
+NNVM_REGISTER_OP(_npi_share_memory)
+.set_attr<FCompute>("FCompute<gpu>", NumpyShareMemoryCompute<gpu>);
+
+}  // namespace op
+}  // namespace mxnet
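
The CUDA translation unit only re-registers the same compute template for GPU, so the check also works on device arrays. A hedged usage sketch, assuming a CUDA-enabled MXNet build:

    # Hedged sketch: the same check on GPU arrays (requires a CUDA build).
    import mxnet as mx
    from mxnet import np, npx

    npx.set_np()

    x = np.zeros((3, 4), ctx=mx.gpu(0))
    print(np.may_share_memory(x[:, 0], x[:, 1]))   # True, same as on CPU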
