diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py index e1b1a95838c5..3dc80cd5a0b4 100644 --- a/python/mxnet/__init__.py +++ b/python/mxnet/__init__.py @@ -97,6 +97,5 @@ # fact that kvstore-server module is imported before the __version__ attr is set. from . import kvstore_server -from .numpy_dispatch_protocol import _register_array_function, _register_array_ufunc -_register_array_function() -_register_array_ufunc() +from . import numpy_op_signature +from . import numpy_dispatch_protocol diff --git a/python/mxnet/_numpy_op_doc.py b/python/mxnet/_numpy_op_doc.py index 6d2776e98bc9..6e2f5fa15919 100644 --- a/python/mxnet/_numpy_op_doc.py +++ b/python/mxnet/_numpy_op_doc.py @@ -22,8 +22,6 @@ def _np_ones_like(a): """ - ones_like(a) - Return an array of ones with the same shape and type as a given array. Parameters @@ -42,8 +40,6 @@ def _np_ones_like(a): def _np_zeros_like(a): """ - zeros_like(a) - Return an array of zeros with the same shape and type as a given array. Parameters @@ -62,8 +58,6 @@ def _np_zeros_like(a): def _np_cumsum(a, axis=None, dtype=None, out=None): """ - cumsum(a, axis=None, dtype=None, out=None) - Return the cumulative sum of the elements along a given axis. Parameters @@ -115,8 +109,6 @@ def _np_cumsum(a, axis=None, dtype=None, out=None): def _npx_nonzero(a): """ - nonzero(a) - Return the indices of the elements that are non-zero. Returns a ndarray with ndim is 2. Each row contains the indices @@ -164,8 +156,6 @@ def _npx_nonzero(a): def _np_repeat(a, repeats, axis=None): """ - repeat(a, repeats, axis=None) - Repeat elements of an array. Parameters @@ -213,8 +203,6 @@ def _np_repeat(a, repeats, axis=None): def _np_transpose(a, axes=None): """ - transpose(a, axes=None) - Permute the dimensions of an array. Parameters @@ -256,8 +244,7 @@ def _np_transpose(a, axes=None): def _np_dot(a, b, out=None): - """dot(a, b, out=None) - + """ Dot product of two arrays. Specifically, - If both `a` and `b` are 1-D arrays, it is inner product of vectors @@ -318,10 +305,8 @@ def _np_dot(a, b, out=None): pass -def _np_sum(a, axis=0, dtype=None, keepdims=None, initial=None, out=None): +def _np_sum(a, axis=None, dtype=None, keepdims=False, initial=None, out=None): r""" - sum(a, axis=None, dtype=None, keepdims=_Null, initial=_Null, out=None) - Sum of array elements over a given axis. Parameters @@ -414,8 +399,6 @@ def _np_sum(a, axis=0, dtype=None, keepdims=None, initial=None, out=None): def _np_copy(a, out=None): """ - copy(a, out=None) - Return an array copy of the given object. Parameters @@ -463,8 +446,6 @@ def _np_copy(a, out=None): def _np_reshape(a, newshape, order='C', out=None): """ - reshape(a, newshape, order='C') - Gives a new shape to an array without changing its data. This function always returns a copy of the input array if ``out`` is not provided. @@ -501,8 +482,6 @@ def _np_reshape(a, newshape, order='C', out=None): def _np__linalg_svd(a): r""" - svd(a) - Singular Value Decomposition. When `a` is a 2D array, it is factorized as ``ut @ np.diag(s) @ v``, @@ -568,8 +547,6 @@ def _np__linalg_svd(a): def _np_roll(a, shift, axis=None): """ - roll(a, shift, axis=None): - Roll array elements along a given axis. Elements that roll beyond the last position are re-introduced at @@ -633,8 +610,7 @@ def _np_roll(a, shift, axis=None): def _np_trace(a, offset=0, axis1=0, axis2=1, out=None): - """trace(a, offset=0, axis1=0, axis2=1, out=None) - + """ Return the sum along diagonals of the array. If `a` is 2-D, the sum along its diagonal with the given offset is returned, i.e., the sum of elements ``a[i,i+offset]`` for all i. diff --git a/python/mxnet/numpy_dispatch_protocol.py b/python/mxnet/numpy_dispatch_protocol.py index 93b81b1a868d..f483e299ffed 100644 --- a/python/mxnet/numpy_dispatch_protocol.py +++ b/python/mxnet/numpy_dispatch_protocol.py @@ -214,3 +214,7 @@ def _register_array_ufunc(): _NUMPY_ARRAY_UFUNC_DICT[op_name] = mx_np_op except AttributeError: raise AttributeError('mxnet.numpy does not have operator named {}'.format(op_name)) + + +_register_array_function() +_register_array_ufunc() diff --git a/python/mxnet/numpy_op_signature.py b/python/mxnet/numpy_op_signature.py new file mode 100644 index 000000000000..e42ba264b37e --- /dev/null +++ b/python/mxnet/numpy_op_signature.py @@ -0,0 +1,72 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Make builtin ops' signatures compatible with NumPy.""" + +from __future__ import absolute_import +import sys +import warnings +from . import _numpy_op_doc +from . import numpy as mx_np +from . import numpy_extension as mx_npx +from .base import _NP_OP_SUBMODULE_LIST, _NP_EXT_OP_SUBMODULE_LIST, _get_op_submodule_name + + +def _get_builtin_op(op_name): + if op_name.startswith('_np_'): + root_module = mx_np + op_name_prefix = '_np_' + submodule_name_list = _NP_OP_SUBMODULE_LIST + elif op_name.startswith('_npx_'): + root_module = mx_npx + op_name_prefix = '_npx_' + submodule_name_list = _NP_EXT_OP_SUBMODULE_LIST + else: + return None + + submodule_name = _get_op_submodule_name(op_name, op_name_prefix, submodule_name_list) + if len(submodule_name) > 0: + op_module = getattr(root_module, submodule_name[1:-1], None) + if op_module is None: + raise ValueError('Cannot find submodule {} in module {}' + .format(submodule_name[1:-1], root_module.__name__)) + else: + op_module = root_module + + op = getattr(op_module, op_name[(len(op_name_prefix)+len(submodule_name)):], None) + if op is None: + raise ValueError('Cannot find operator {} in module {}' + .format(op_name[op_name_prefix:], root_module.__name__)) + return op + + +def _register_op_signatures(): + if sys.version_info.major < 3 or sys.version_info.minor < 5: + warnings.warn('Some mxnet.numpy operator signatures may not be displayed consistently with ' + 'their counterparts in the official NumPy package due to too-low Python ' + 'version {}. Python >= 3.5 is required to make the signatures display correctly.' + .format(str(sys.version))) + return + + import inspect + for op_name in dir(_numpy_op_doc): + op = _get_builtin_op(op_name) + if op is not None: + op.__signature__ = inspect.signature(getattr(_numpy_op_doc, op_name)) + + +_register_op_signatures() diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 7e3d9655f771..eaf3032d526d 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -17,6 +17,8 @@ # pylint: skip-file from __future__ import absolute_import +import sys +import unittest import numpy as _np import mxnet as mx from mxnet import np, npx @@ -29,6 +31,7 @@ import scipy.stats as ss from mxnet.test_utils import verify_generator, gen_buckets_probs_with_ppf, retry from mxnet.runtime import Features +from mxnet.numpy_op_signature import _get_builtin_op import platform @@ -2810,6 +2813,18 @@ def check_output_n_grad(data_shape, idx_shape, axis, mode): check_output_n_grad(config[0], config[1], config[2], mode) +@unittest.skipUnless(sys.version_info.major >= 3 and sys.version_info.minor >= 5, + 'inspect package requires Python >= 3.5 to work properly') +@with_seed() +def test_np_builtin_op_signature(): + import inspect + from mxnet import _numpy_op_doc + for op_name in dir(_numpy_op_doc): + op = _get_builtin_op(op_name) + if op is not None: + assert str(op.__signature__) == str(inspect.signature(getattr(_numpy_op_doc, op_name))) + + if __name__ == '__main__': import nose nose.runmodule()