Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit ea440c7

Browse files
reminiscehaojin2
authored andcommitted
[numpy] Cosmetic improvement on mxnet.numpy builtin op signature in documentation (#16305)
* Init checkin * Clean up * Add test * Fix * Local import inspect * Fix lint
1 parent 512d25a commit ea440c7

File tree

5 files changed

+96
-30
lines changed

5 files changed

+96
-30
lines changed

python/mxnet/__init__.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,5 @@
9797
# fact that kvstore-server module is imported before the __version__ attr is set.
9898
from . import kvstore_server
9999

100-
from .numpy_dispatch_protocol import _register_array_function, _register_array_ufunc
101-
_register_array_function()
102-
_register_array_ufunc()
100+
from . import numpy_op_signature
101+
from . import numpy_dispatch_protocol

python/mxnet/_numpy_op_doc.py

Lines changed: 3 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@
2222

2323
def _np_ones_like(a):
2424
"""
25-
ones_like(a)
26-
2725
Return an array of ones with the same shape and type as a given array.
2826
2927
Parameters
@@ -42,8 +40,6 @@ def _np_ones_like(a):
4240

4341
def _np_zeros_like(a):
4442
"""
45-
zeros_like(a)
46-
4743
Return an array of zeros with the same shape and type as a given array.
4844
4945
Parameters
@@ -62,8 +58,6 @@ def _np_zeros_like(a):
6258

6359
def _np_cumsum(a, axis=None, dtype=None, out=None):
6460
"""
65-
cumsum(a, axis=None, dtype=None, out=None)
66-
6761
Return the cumulative sum of the elements along a given axis.
6862
6963
Parameters
@@ -115,8 +109,6 @@ def _np_cumsum(a, axis=None, dtype=None, out=None):
115109

116110
def _npx_nonzero(a):
117111
"""
118-
nonzero(a)
119-
120112
Return the indices of the elements that are non-zero.
121113
122114
Returns a ndarray with ndim is 2. Each row contains the indices
@@ -164,8 +156,6 @@ def _npx_nonzero(a):
164156

165157
def _np_repeat(a, repeats, axis=None):
166158
"""
167-
repeat(a, repeats, axis=None)
168-
169159
Repeat elements of an array.
170160
171161
Parameters
@@ -213,8 +203,6 @@ def _np_repeat(a, repeats, axis=None):
213203

214204
def _np_transpose(a, axes=None):
215205
"""
216-
transpose(a, axes=None)
217-
218206
Permute the dimensions of an array.
219207
220208
Parameters
@@ -256,8 +244,7 @@ def _np_transpose(a, axes=None):
256244

257245

258246
def _np_dot(a, b, out=None):
259-
"""dot(a, b, out=None)
260-
247+
"""
261248
Dot product of two arrays. Specifically,
262249
263250
- If both `a` and `b` are 1-D arrays, it is inner product of vectors
@@ -318,10 +305,8 @@ def _np_dot(a, b, out=None):
318305
pass
319306

320307

321-
def _np_sum(a, axis=0, dtype=None, keepdims=None, initial=None, out=None):
308+
def _np_sum(a, axis=None, dtype=None, keepdims=False, initial=None, out=None):
322309
r"""
323-
sum(a, axis=None, dtype=None, keepdims=_Null, initial=_Null, out=None)
324-
325310
Sum of array elements over a given axis.
326311
327312
Parameters
@@ -414,8 +399,6 @@ def _np_sum(a, axis=0, dtype=None, keepdims=None, initial=None, out=None):
414399

415400
def _np_copy(a, out=None):
416401
"""
417-
copy(a, out=None)
418-
419402
Return an array copy of the given object.
420403
421404
Parameters
@@ -463,8 +446,6 @@ def _np_copy(a, out=None):
463446

464447
def _np_reshape(a, newshape, order='C', out=None):
465448
"""
466-
reshape(a, newshape, order='C')
467-
468449
Gives a new shape to an array without changing its data.
469450
This function always returns a copy of the input array if
470451
``out`` is not provided.
@@ -501,8 +482,6 @@ def _np_reshape(a, newshape, order='C', out=None):
501482

502483
def _np__linalg_svd(a):
503484
r"""
504-
svd(a)
505-
506485
Singular Value Decomposition.
507486
508487
When `a` is a 2D array, it is factorized as ``ut @ np.diag(s) @ v``,
@@ -568,8 +547,6 @@ def _np__linalg_svd(a):
568547

569548
def _np_roll(a, shift, axis=None):
570549
"""
571-
roll(a, shift, axis=None):
572-
573550
Roll array elements along a given axis.
574551
575552
Elements that roll beyond the last position are re-introduced at
@@ -633,8 +610,7 @@ def _np_roll(a, shift, axis=None):
633610

634611

635612
def _np_trace(a, offset=0, axis1=0, axis2=1, out=None):
636-
"""trace(a, offset=0, axis1=0, axis2=1, out=None)
637-
613+
"""
638614
Return the sum along diagonals of the array.
639615
If `a` is 2-D, the sum along its diagonal with the given offset
640616
is returned, i.e., the sum of elements ``a[i,i+offset]`` for all i.

python/mxnet/numpy_dispatch_protocol.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,3 +214,7 @@ def _register_array_ufunc():
214214
_NUMPY_ARRAY_UFUNC_DICT[op_name] = mx_np_op
215215
except AttributeError:
216216
raise AttributeError('mxnet.numpy does not have operator named {}'.format(op_name))
217+
218+
219+
_register_array_function()
220+
_register_array_ufunc()

python/mxnet/numpy_op_signature.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
"""Make builtin ops' signatures compatible with NumPy."""
19+
20+
from __future__ import absolute_import
21+
import sys
22+
import warnings
23+
from . import _numpy_op_doc
24+
from . import numpy as mx_np
25+
from . import numpy_extension as mx_npx
26+
from .base import _NP_OP_SUBMODULE_LIST, _NP_EXT_OP_SUBMODULE_LIST, _get_op_submodule_name
27+
28+
29+
def _get_builtin_op(op_name):
30+
if op_name.startswith('_np_'):
31+
root_module = mx_np
32+
op_name_prefix = '_np_'
33+
submodule_name_list = _NP_OP_SUBMODULE_LIST
34+
elif op_name.startswith('_npx_'):
35+
root_module = mx_npx
36+
op_name_prefix = '_npx_'
37+
submodule_name_list = _NP_EXT_OP_SUBMODULE_LIST
38+
else:
39+
return None
40+
41+
submodule_name = _get_op_submodule_name(op_name, op_name_prefix, submodule_name_list)
42+
if len(submodule_name) > 0:
43+
op_module = getattr(root_module, submodule_name[1:-1], None)
44+
if op_module is None:
45+
raise ValueError('Cannot find submodule {} in module {}'
46+
.format(submodule_name[1:-1], root_module.__name__))
47+
else:
48+
op_module = root_module
49+
50+
op = getattr(op_module, op_name[(len(op_name_prefix)+len(submodule_name)):], None)
51+
if op is None:
52+
raise ValueError('Cannot find operator {} in module {}'
53+
.format(op_name[op_name_prefix:], root_module.__name__))
54+
return op
55+
56+
57+
def _register_op_signatures():
58+
if sys.version_info.major < 3 or sys.version_info.minor < 5:
59+
warnings.warn('Some mxnet.numpy operator signatures may not be displayed consistently with '
60+
'their counterparts in the official NumPy package due to too-low Python '
61+
'version {}. Python >= 3.5 is required to make the signatures display correctly.'
62+
.format(str(sys.version)))
63+
return
64+
65+
import inspect
66+
for op_name in dir(_numpy_op_doc):
67+
op = _get_builtin_op(op_name)
68+
if op is not None:
69+
op.__signature__ = inspect.signature(getattr(_numpy_op_doc, op_name))
70+
71+
72+
_register_op_signatures()

tests/python/unittest/test_numpy_op.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
# pylint: skip-file
1919
from __future__ import absolute_import
20+
import sys
21+
import unittest
2022
import numpy as _np
2123
import mxnet as mx
2224
from mxnet import np, npx
@@ -29,6 +31,7 @@
2931
import scipy.stats as ss
3032
from mxnet.test_utils import verify_generator, gen_buckets_probs_with_ppf, retry
3133
from mxnet.runtime import Features
34+
from mxnet.numpy_op_signature import _get_builtin_op
3235
import platform
3336

3437

@@ -2810,6 +2813,18 @@ def check_output_n_grad(data_shape, idx_shape, axis, mode):
28102813
check_output_n_grad(config[0], config[1], config[2], mode)
28112814

28122815

2816+
@unittest.skipUnless(sys.version_info.major >= 3 and sys.version_info.minor >= 5,
2817+
'inspect package requires Python >= 3.5 to work properly')
2818+
@with_seed()
2819+
def test_np_builtin_op_signature():
2820+
import inspect
2821+
from mxnet import _numpy_op_doc
2822+
for op_name in dir(_numpy_op_doc):
2823+
op = _get_builtin_op(op_name)
2824+
if op is not None:
2825+
assert str(op.__signature__) == str(inspect.signature(getattr(_numpy_op_doc, op_name)))
2826+
2827+
28132828
if __name__ == '__main__':
28142829
import nose
28152830
nose.runmodule()

0 commit comments

Comments
 (0)