Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 2d86c70

Browse files
reminiscehaojin2
authored andcommitted
Port ops from np branch (#16018)
* Add maximum, minimum, swapaxes, argmax, clip in python * Add backend * Fix pylint * Add unit test decorators back * Fix gpu compile * Add np.random.normal and npx.seed * Expose seed through npx.random * Add rtol atol in seed testing
1 parent 3f7b6ee commit 2d86c70

16 files changed

+1050
-104
lines changed

python/mxnet/ndarray/numpy/_op.py

Lines changed: 186 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@
3232
'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p',
3333
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
3434
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
35-
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack']
35+
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack',
36+
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax']
3637

3738

3839
@set_module('mxnet.ndarray.numpy')
@@ -1960,3 +1961,187 @@ def get_list(arrays):
19601961

19611962
arrays = get_list(arrays)
19621963
return _npi.stack(*arrays, axis=axis, out=out)
1964+
1965+
1966+
@set_module('mxnet.ndarray.numpy')
1967+
def maximum(x1, x2, out=None):
1968+
"""Returns element-wise maximum of the input arrays with broadcasting.
1969+
1970+
Parameters
1971+
----------
1972+
x1, x2 : scalar or mxnet.numpy.ndarray
1973+
The arrays holding the elements to be compared. They must have the same shape,
1974+
or shapes that can be broadcast to a single shape.
1975+
1976+
Returns
1977+
-------
1978+
out : mxnet.numpy.ndarray or scalar
1979+
The maximum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars."""
1980+
return _ufunc_helper(x1, x2, _npi.maximum, _np.maximum, _npi.maximum_scalar, None, out)
1981+
1982+
1983+
@set_module('mxnet.ndarray.numpy')
1984+
def minimum(x1, x2, out=None):
1985+
"""Returns element-wise minimum of the input arrays with broadcasting.
1986+
1987+
Parameters
1988+
----------
1989+
x1, x2 : scalar or mxnet.numpy.ndarray
1990+
The arrays holding the elements to be compared. They must have the same shape,
1991+
or shapes that can be broadcast to a single shape.
1992+
1993+
Returns
1994+
-------
1995+
out : mxnet.numpy.ndarray or scalar
1996+
The minimum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars."""
1997+
return _ufunc_helper(x1, x2, _npi.minimum, _np.minimum, _npi.minimum_scalar, None, out)
1998+
1999+
2000+
@set_module('mxnet.ndarray.numpy')
2001+
def swapaxes(a, axis1, axis2):
2002+
"""Interchange two axes of an array.
2003+
2004+
Parameters
2005+
----------
2006+
a : ndarray
2007+
Input array.
2008+
axis1 : int
2009+
First axis.
2010+
axis2 : int
2011+
Second axis.
2012+
2013+
Returns
2014+
-------
2015+
a_swapped : ndarray
2016+
Swapped array. This is always a copy of the input array.
2017+
"""
2018+
return _npi.swapaxes(a, dim1=axis1, dim2=axis2)
2019+
2020+
2021+
@set_module('mxnet.ndarray.numpy')
2022+
def clip(a, a_min, a_max, out=None):
2023+
"""clip(a, a_min, a_max, out=None)
2024+
2025+
Clip (limit) the values in an array.
2026+
Given an interval, values outside the interval are clipped to
2027+
the interval edges. For example, if an interval of ``[0, 1]``
2028+
is specified, values smaller than 0 become 0, and values larger
2029+
than 1 become 1.
2030+
2031+
Parameters
2032+
----------
2033+
a : ndarray
2034+
Array containing elements to clip.
2035+
a_min : scalar or `None`
2036+
Minimum value. If `None`, clipping is not performed on lower
2037+
interval edge. Not more than one of `a_min` and `a_max` may be
2038+
`None`.
2039+
a_max : scalar or `None`
2040+
Maximum value. If `None`, clipping is not performed on upper
2041+
interval edge. Not more than one of `a_min` and `a_max` may be
2042+
`None`.
2043+
out : ndarray, optional
2044+
The results will be placed in this array. It may be the input
2045+
array for in-place clipping. `out` must be of the right shape
2046+
to hold the output. Its type is preserved.
2047+
2048+
Returns
2049+
-------
2050+
clipped_array : ndarray
2051+
An array with the elements of `a`, but where values
2052+
< `a_min` are replaced with `a_min`, and those > `a_max`
2053+
with `a_max`.
2054+
2055+
Notes
2056+
-----
2057+
array_like `a_min` and `a_max` are not supported.
2058+
2059+
Examples
2060+
--------
2061+
>>> a = np.arange(10)
2062+
>>> np.clip(a, 1, 8)
2063+
array([1., 1., 2., 3., 4., 5., 6., 7., 8., 8.], dtype=float32)
2064+
>>> a
2065+
array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.], dtype=float32)
2066+
>>> np.clip(a, 3, 6, out=a)
2067+
array([3., 3., 3., 3., 4., 5., 6., 6., 6., 6.], dtype=float32)
2068+
"""
2069+
if a_min is None and a_max is None:
2070+
raise ValueError('array_clip: must set either max or min')
2071+
if a_min is None:
2072+
a_min = float('-inf')
2073+
if a_max is None:
2074+
a_max = float('inf')
2075+
return _npi.clip(a, a_min, a_max, out=out)
2076+
2077+
2078+
@set_module('mxnet.ndarray.numpy')
2079+
def argmax(a, axis=None, out=None):
2080+
r"""
2081+
argmax(a, axis=None, out=None)
2082+
2083+
Returns the indices of the maximum values along an axis.
2084+
2085+
Parameters
2086+
----------
2087+
a : ndarray
2088+
Input array. Only support ndarrays of dtype `float16`, `float32`, and `float64`.
2089+
axis : int, optional
2090+
By default, the index is into the flattened array, otherwise
2091+
along the specified axis.
2092+
out : ndarray or None, optional
2093+
A location into which the result is stored.
2094+
If provided, it must have the same shape and dtype as input ndarray.
2095+
If not provided or `None`, a freshly-allocated array is returned.
2096+
2097+
Returns
2098+
-------
2099+
index_array : ndarray of indices whose dtype is same as the input ndarray.
2100+
Array of indices into the array. It has the same shape as `a.shape`
2101+
with the dimension along `axis` removed.
2102+
2103+
Notes
2104+
-----
2105+
In case of multiple occurrences of the maximum values, the indices
2106+
corresponding to the first occurrence are returned.
2107+
2108+
This function differs from the original `numpy.argmax
2109+
<https://docs.scipy.org/doc/numpy/reference/generated/numpy.argmax.html>`_ in
2110+
the following aspects:
2111+
2112+
- Input type does not support Python native iterables(list, tuple, ...).
2113+
- Output has dtype that is same as the input ndarray.
2114+
- ``out`` param: cannot perform auto broadcasting. ``out`` ndarray's shape must be the same as the expected output.
2115+
- ``out`` param: cannot perform auto type cast. ``out`` ndarray's dtype must be the same as the expected output.
2116+
- ``out`` param does not support scalar input case.
2117+
2118+
Examples
2119+
--------
2120+
>>> a = np.arange(6).reshape(2,3) + 10
2121+
>>> a
2122+
array([[10., 11., 12.],
2123+
[13., 14., 15.]])
2124+
>>> np.argmax(a)
2125+
array(5.)
2126+
>>> np.argmax(a, axis=0)
2127+
array([1., 1., 1.])
2128+
>>> np.argmax(a, axis=1)
2129+
array([2., 2.])
2130+
2131+
>>> b = np.arange(6)
2132+
>>> b[1] = 5
2133+
>>> b
2134+
array([0., 5., 2., 3., 4., 5.])
2135+
>>> np.argmax(b) # Only the first occurrence is returned.
2136+
array(1.)
2137+
2138+
Specify ``out`` ndarray:
2139+
2140+
>>> a = np.arange(6).reshape(2,3) + 10
2141+
>>> b = np.zeros((2,))
2142+
>>> np.argmax(a, axis=1, out=b)
2143+
array([2., 2.])
2144+
>>> b
2145+
array([2., 2.])
2146+
"""
2147+
return _npi.argmax(a, axis=axis, keepdims=False, out=out)

python/mxnet/ndarray/numpy/random.py

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,10 @@
1919
from __future__ import absolute_import
2020
from ...context import current_context
2121
from . import _internal as _npi
22+
from ...base import numeric_types
2223

2324

24-
__all__ = ['randint', 'uniform']
25+
__all__ = ['randint', 'uniform', 'normal']
2526

2627

2728
def randint(low, high=None, size=None, dtype=None, **kwargs):
@@ -141,5 +142,50 @@ def uniform(low=0.0, high=1.0, size=None, dtype=None, ctx=None, out=None):
141142
return _npi.uniform(low=low, high=high, size=size,
142143
ctx=ctx, dtype=dtype, out=out)
143144

144-
raise ValueError(
145-
"Distribution parameters must be either mxnet.numpy.ndarray or numbers")
145+
146+
def normal(loc=0.0, scale=1.0, size=None, **kwargs):
147+
"""Draw random samples from a normal (Gaussian) distribution.
148+
149+
Samples are distributed according to a normal distribution parametrized
150+
by *loc* (mean) and *scale* (standard deviation).
151+
152+
153+
Parameters
154+
----------
155+
loc : float, optional
156+
Mean (centre) of the distribution.
157+
scale : float, optional
158+
Standard deviation (spread or "width") of the distribution.
159+
size : int or tuple of ints, optional
160+
Output shape. If the given shape is, e.g., `(m, n, k)`, then `m * n * k`
161+
samples are drawn. If size is `None` (default), a scalar tensor containing
162+
a single value is returned if loc and scale are both scalars.
163+
dtype : {'float16', 'float32', 'float64'}, optional
164+
Data type of output samples. Default is 'float32'
165+
ctx : Context, optional
166+
Device context of output. Default is current context.
167+
out : ``ndarray``, optional
168+
Store output to an existing ``ndarray``.
169+
170+
Returns
171+
-------
172+
out : ndarray
173+
Drawn samples from the parameterized normal distribution.
174+
175+
Notes
176+
-----
177+
This function currently does not support ``loc`` and ``scale`` as ndarrays.
178+
"""
179+
dtype = kwargs.pop('dtype', None)
180+
if dtype is None:
181+
dtype = 'float32'
182+
ctx = kwargs.pop('ctx', None)
183+
if ctx is None:
184+
ctx = current_context()
185+
out = kwargs.pop('out', None)
186+
if size is None and out is None:
187+
size = ()
188+
if (not isinstance(loc, numeric_types)) or (not isinstance(scale, numeric_types)):
189+
raise NotImplementedError('np.random.normal only supports loc and scale of '
190+
'numeric types for now')
191+
return _npi.random_normal(loc, scale, shape=size, dtype=dtype, ctx=ctx, out=out, **kwargs)

0 commit comments

Comments
 (0)