Skip to content

Commit e1e0291

Browse files
reminisceyajiedesign
authored andcommitted
numpy doc enhancement (apache#16637)
* Change NDArray to ndarray for npx ops Add nonzero boolean mask supports boolean ndarray Add argmin op and interoperability test for nonzero Fix vdot, inner, outter docs Add nonzero to mx.nd.np Add docs Fix * Fix lint * Fix * Fix * Fix get_constant
1 parent c6dc4b3 commit e1e0291

24 files changed

+1062
-102
lines changed

python/mxnet/_numpy_op_doc.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,24 @@ def _np_ones_like(a):
3434
-------
3535
out : ndarray
3636
Array of ones with the same shape and type as `a`.
37+
38+
Examples
39+
--------
40+
>>> x = np.arange(6)
41+
>>> x = x.reshape((2, 3))
42+
>>> x
43+
array([[0., 1., 2.],
44+
[3., 4., 5.]])
45+
>>> np.ones_like(x)
46+
array([[1., 1., 1.],
47+
[1., 1., 1.]])
48+
49+
>>> y = np.arange(3, dtype=float)
50+
>>> y
51+
array([0., 1., 2.], dtype=float64)
52+
>>>
53+
>>> np.ones_like(y)
54+
array([1., 1., 1.], dtype=float64)
3755
"""
3856
pass
3957

@@ -52,6 +70,23 @@ def _np_zeros_like(a):
5270
-------
5371
out : ndarray
5472
Array of zeros with the same shape and type as `a`.
73+
74+
Examples
75+
--------
76+
>>> x = np.arange(6)
77+
>>> x = x.reshape((2, 3))
78+
>>> x
79+
array([[0., 1., 2.],
80+
[3., 4., 5.]])
81+
>>> np.zeros_like(x)
82+
array([[0., 0., 0.],
83+
[0., 0., 0.]])
84+
>>> y = np.arange(3, dtype=float)
85+
>>> y
86+
array([0., 1., 2.], dtype=float64)
87+
>>>
88+
>>> np.zeros_like(y)
89+
array([0., 0., 0.], dtype=float64)
5590
"""
5691
pass
5792

@@ -477,6 +512,31 @@ def _np_reshape(a, newshape, order='C', out=None):
477512
See Also
478513
--------
479514
ndarray.reshape : Equivalent method.
515+
516+
Examples
517+
--------
518+
>>> a = np.arange(6).reshape((3, 2))
519+
>>> a
520+
array([[0., 1.],
521+
[2., 3.],
522+
[4., 5.]])
523+
524+
>>> np.reshape(a, (2, 3)) # C-like index ordering
525+
array([[0., 1., 2.],
526+
[3., 4., 5.]])
527+
528+
>>> np.reshape(np.ravel(a), (2, 3)) # equivalent to C ravel then C reshape
529+
array([[0., 1., 2.],
530+
[3., 4., 5.]])
531+
532+
>>> a = np.array([[1,2,3], [4,5,6]])
533+
>>> np.reshape(a, 6)
534+
array([1., 2., 3., 4., 5., 6.])
535+
536+
>>> np.reshape(a, (3,-1)) # the unspecified value is inferred to be 2
537+
array([[1., 2.],
538+
[3., 4.],
539+
[5., 6.]])
480540
"""
481541

482542

python/mxnet/base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
"""ctypes library of mxnet and helper functions."""
2121
from __future__ import absolute_import
2222

23+
import re
2324
import atexit
2425
import ctypes
2526
import os
@@ -853,3 +854,5 @@ def _init_np_op_module(root_module_name, np_module_name, mx_module_name, make_op
853854

854855
if hasattr(_np_op_doc, name):
855856
function.__doc__ = getattr(_np_op_doc, name).__doc__
857+
else:
858+
function.__doc__ = re.sub('NDArray', 'ndarray', function.__doc__)

python/mxnet/gluon/parameter.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -674,7 +674,8 @@ def __init__(self, **kwargs):
674674
"""
675675
def __init__(self, name, value):
676676
if not isinstance(value, ndarray.NDArray):
677-
value = ndarray.array(value)
677+
array_fn = _mx_np.array if is_np_array() else ndarray.array
678+
value = array_fn(value)
678679
self.value = value
679680

680681
class Init(initializer.Initializer):

python/mxnet/ndarray/numpy/_op.py

Lines changed: 152 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,11 @@
3434
'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
3535
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram', 'eye',
3636
'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate',
37-
'stack', 'vstack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax',
37+
'stack', 'vstack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'argmin',
3838
'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip',
3939
'around', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take',
4040
'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal',
41-
'hsplit', 'rot90', 'einsum', 'true_divide']
41+
'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero']
4242

4343

4444
@set_module('mxnet.ndarray.numpy')
@@ -3165,8 +3165,6 @@ def clip(a, a_min, a_max, out=None):
31653165
@set_module('mxnet.ndarray.numpy')
31663166
def argmax(a, axis=None, out=None):
31673167
r"""
3168-
argmax(a, axis=None, out=None)
3169-
31703168
Returns the indices of the maximum values along an axis.
31713169
31723170
Parameters
@@ -3234,6 +3232,75 @@ def argmax(a, axis=None, out=None):
32343232
return _npi.argmax(a, axis=axis, keepdims=False, out=out)
32353233

32363234

3235+
@set_module('mxnet.ndarray.numpy')
3236+
def argmin(a, axis=None, out=None):
3237+
r"""
3238+
Returns the indices of the maximum values along an axis.
3239+
3240+
Parameters
3241+
----------
3242+
a : ndarray
3243+
Input array. Only support ndarrays of dtype `float16`, `float32`, and `float64`.
3244+
axis : int, optional
3245+
By default, the index is into the flattened array, otherwise
3246+
along the specified axis.
3247+
out : ndarray or None, optional
3248+
If provided, the result will be inserted into this array. It should
3249+
be of the appropriate shape and dtype.
3250+
3251+
Returns
3252+
-------
3253+
index_array : ndarray of indices whose dtype is same as the input ndarray.
3254+
Array of indices into the array. It has the same shape as `a.shape`
3255+
with the dimension along `axis` removed.
3256+
3257+
Notes
3258+
-----
3259+
In case of multiple occurrences of the maximum values, the indices
3260+
corresponding to the first occurrence are returned.
3261+
3262+
This function differs from the original `numpy.argmax
3263+
<https://docs.scipy.org/doc/numpy/reference/generated/numpy.argmax.html>`_ in
3264+
the following aspects:
3265+
3266+
- Input type does not support Python native iterables(list, tuple, ...).
3267+
- Output has dtype that is same as the input ndarray.
3268+
- ``out`` param: cannot perform auto broadcasting. ``out`` ndarray's shape must be the same as the expected output.
3269+
- ``out`` param: cannot perform auto type cast. ``out`` ndarray's dtype must be the same as the expected output.
3270+
- ``out`` param does not support scalar input case.
3271+
3272+
Examples
3273+
--------
3274+
>>> a = np.arange(6).reshape(2,3) + 10
3275+
>>> a
3276+
array([[10., 11., 12.],
3277+
[13., 14., 15.]])
3278+
>>> np.argmin(a)
3279+
array(0.)
3280+
>>> np.argmin(a, axis=0)
3281+
array([0., 0., 0.])
3282+
>>> np.argmin(a, axis=1)
3283+
array([0., 0.])
3284+
3285+
>>> b = np.arange(6)
3286+
>>> b[2] = 0
3287+
>>> b
3288+
array([0., 1., 0., 3., 4., 5.])
3289+
>>> np.argmax(b) # Only the first occurrence is returned.
3290+
array(0.)
3291+
3292+
Specify ``out`` ndarray:
3293+
3294+
>>> a = np.arange(6).reshape(2,3) + 10
3295+
>>> b = np.zeros((2,))
3296+
>>> np.argmin(a, axis=1, out=b)
3297+
array([0., 0.])
3298+
>>> b
3299+
array([0., 0.])
3300+
"""
3301+
return _npi.argmin(a, axis=axis, keepdims=False, out=out)
3302+
3303+
32373304
@set_module('mxnet.ndarray.numpy')
32383305
def mean(a, axis=None, dtype=None, out=None, keepdims=False): # pylint: disable=arguments-differ
32393306
"""
@@ -4761,3 +4828,84 @@ def einsum(*operands, **kwargs):
47614828
subscripts = operands[0]
47624829
operands = operands[1:]
47634830
return _npi.einsum(*operands, subscripts=subscripts, out=out, optimize=int(optimize_arg))
4831+
4832+
4833+
@set_module('mxnet.ndarray.numpy')
4834+
def nonzero(a):
4835+
"""
4836+
Return the indices of the elements that are non-zero.
4837+
4838+
Returns a tuple of arrays, one for each dimension of `a`,
4839+
containing the indices of the non-zero elements in that
4840+
dimension. The values in `a` are always returned in
4841+
row-major, C-style order.
4842+
4843+
To group the indices by element, rather than dimension, use `argwhere`,
4844+
which returns a row for each non-zero element.
4845+
4846+
Parameters
4847+
----------
4848+
a : ndarray
4849+
Input array.
4850+
4851+
Returns
4852+
-------
4853+
tuple_of_arrays : tuple
4854+
Indices of elements that are non-zero.
4855+
4856+
See Also
4857+
--------
4858+
ndarray.nonzero :
4859+
Equivalent ndarray method.
4860+
4861+
Notes
4862+
-----
4863+
While the nonzero values can be obtained with ``a[nonzero(a)]``, it is
4864+
recommended to use ``x[x.astype(bool)]`` or ``x[x != 0]`` instead, which
4865+
will correctly handle 0-d arrays.
4866+
4867+
Examples
4868+
--------
4869+
>>> x = np.array([[3, 0, 0], [0, 4, 0], [5, 6, 0]])
4870+
>>> x
4871+
array([[3, 0, 0],
4872+
[0, 4, 0],
4873+
[5, 6, 0]], dtype=int32)
4874+
>>> np.nonzero(x)
4875+
(array([0, 1, 2, 2], dtype=int64), array([0, 1, 0, 1], dtype=int64))
4876+
4877+
>>> x[np.nonzero(x)]
4878+
array([3, 4, 5, 6])
4879+
>>> np.transpose(np.stack(np.nonzero(x)))
4880+
array([[0, 0],
4881+
[1, 1],
4882+
[2, 0],
4883+
[2, 1]], dtype=int64)
4884+
4885+
A common use for ``nonzero`` is to find the indices of an array, where
4886+
a condition is True. Given an array `a`, the condition `a` > 3 is a
4887+
boolean array and since False is interpreted as 0, np.nonzero(a > 3)
4888+
yields the indices of the `a` where the condition is true.
4889+
4890+
>>> a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
4891+
>>> a > 3
4892+
array([[False, False, False],
4893+
[ True, True, True],
4894+
[ True, True, True]])
4895+
>>> np.nonzero(a > 3)
4896+
(array([1, 1, 1, 2, 2, 2], dtype=int64), array([0, 1, 2, 0, 1, 2], dtype=int64))
4897+
4898+
Using this result to index `a` is equivalent to using the mask directly:
4899+
4900+
>>> a[np.nonzero(a > 3)]
4901+
array([4, 5, 6, 7, 8, 9], dtype=int32)
4902+
>>> a[a > 3]
4903+
array([4, 5, 6, 7, 8, 9], dtype=int32)
4904+
4905+
``nonzero`` can also be called as a method of the array.
4906+
4907+
>>> (a > 3).nonzero()
4908+
(array([1, 1, 1, 2, 2, 2], dtype=int64), array([0, 1, 2, 0, 1, 2], dtype=int64))
4909+
"""
4910+
out = _npi.nonzero(a).transpose()
4911+
return tuple([out[i] for i in range(len(out))])

python/mxnet/ndarray/numpy/random.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,11 @@
2323
from ..ndarray import NDArray
2424

2525

26-
__all__ = ['randint', 'uniform', 'normal', "choice", "rand"]
26+
__all__ = ['randint', 'uniform', 'normal', "choice", "rand", "multinomial"]
2727

2828

2929
def randint(low, high=None, size=None, dtype=None, ctx=None, out=None):
30-
"""Return random integers from `low` (inclusive) to `high` (exclusive).
30+
r"""Return random integers from `low` (inclusive) to `high` (exclusive).
3131
3232
Return random integers from the "discrete uniform" distribution of
3333
the specified dtype in the "half-open" interval [`low`, `high`). If
@@ -88,7 +88,7 @@ def randint(low, high=None, size=None, dtype=None, ctx=None, out=None):
8888

8989

9090
def uniform(low=0.0, high=1.0, size=None, dtype=None, ctx=None, out=None):
91-
"""Draw samples from a uniform distribution.
91+
r"""Draw samples from a uniform distribution.
9292
9393
Samples are uniformly distributed over the half-open interval
9494
``[low, high)`` (includes low, but excludes high). In other words,
@@ -143,7 +143,7 @@ def uniform(low=0.0, high=1.0, size=None, dtype=None, ctx=None, out=None):
143143

144144

145145
def normal(loc=0.0, scale=1.0, size=None, dtype=None, ctx=None, out=None):
146-
"""Draw random samples from a normal (Gaussian) distribution.
146+
r"""Draw random samples from a normal (Gaussian) distribution.
147147
148148
Samples are distributed according to a normal distribution parametrized
149149
by *loc* (mean) and *scale* (standard deviation).
@@ -194,7 +194,7 @@ def normal(loc=0.0, scale=1.0, size=None, dtype=None, ctx=None, out=None):
194194

195195

196196
def multinomial(n, pvals, size=None):
197-
"""multinomial(n, pvals, size=None)
197+
r"""multinomial(n, pvals, size=None)
198198
199199
Draw samples from a multinomial distribution.
200200
@@ -246,7 +246,7 @@ def multinomial(n, pvals, size=None):
246246

247247

248248
def choice(a, size=None, replace=True, p=None, ctx=None, out=None):
249-
"""Generates a random sample from a given 1-D array
249+
r"""Generates a random sample from a given 1-D array
250250
251251
Parameters
252252
-----------

python/mxnet/numpy/linalg.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,35 @@ def norm(x, ord=None, axis=None, keepdims=False):
5454
n : float or ndarray
5555
Norm of the matrix or vector(s).
5656
57+
Notes
58+
-----
59+
This operator differs from NumPy in the aspect that it always returns a
60+
zero-dim tensor for the cases where Python float values are expected
61+
in NumPy.
62+
5763
References
5864
----------
5965
.. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*,
6066
Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15
67+
68+
Examples
69+
--------
70+
>>> from numpy import linalg as LA
71+
>>> a = np.arange(9) - 4
72+
>>> a
73+
array([-4., -3., -2., -1., 0., 1., 2., 3., 4.])
74+
>>> b = a.reshape((3, 3))
75+
>>> b
76+
array([[-4., -3., -2.],
77+
[-1., 0., 1.],
78+
[ 2., 3., 4.]])
79+
>>> LA.norm(a)
80+
array(7.745967)
81+
>>>
82+
>>> LA.norm(b)
83+
array(7.745967)
84+
>>> LA.norm(b, 'fro')
85+
array(7.745967)
6186
"""
6287
return _mx_nd_np.linalg.norm(x, ord, axis, keepdims)
6388

0 commit comments

Comments
 (0)