Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 3c8ffc4

Browse files
author
Ying
committed
numpy operator ravel, derive from reshape
* it is the same as reshape(x, -1) * register reshape with prefix _npi_ * fix format error * edit examples in doc * fix error in review * add out in wrapper * remove out * test data type and add order
1 parent 287e3b5 commit 3c8ffc4

File tree

5 files changed

+212
-3
lines changed

5 files changed

+212
-3
lines changed

python/mxnet/ndarray/numpy/_op.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
3434
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
3535
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'mean',
36-
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices']
36+
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'ravel']
3737

3838

3939
@set_module('mxnet.ndarray.numpy')
@@ -2432,3 +2432,54 @@ def indices(dimensions, dtype=_np.int32, ctx=None):
24322432
else:
24332433
raise ValueError("The dimensions must be sequence of ints")
24342434
# pylint: enable=redefined-outer-name
2435+
2436+
2437+
@set_module('mxnet.ndarray.numpy')
2438+
def ravel(x, order='C'):
2439+
r"""
2440+
ravel(x)
2441+
2442+
Return a contiguous flattened array.
2443+
A 1-D array, containing the elements of the input, is returned. A copy is
2444+
made only if needed.
2445+
2446+
Parameters
2447+
----------
2448+
x : ndarray
2449+
Input array. The elements in `x` are read in row-major, C-style order and
2450+
packed as a 1-D array.
2451+
2452+
Returns
2453+
-------
2454+
y : ndarray
2455+
y is an array of the same subtype as `x`, with shape ``(x.size,)``.
2456+
Note that matrices are special cased for backward compatibility, if `x`
2457+
is a matrix, then y is a 1-D ndarray.
2458+
2459+
Notes
2460+
-----
2461+
This function differs from the original numpy.arange in the following aspects:
2462+
- Only support row-major, C-style order.
2463+
2464+
Examples
2465+
--------
2466+
It is equivalent to ``reshape(x, -1)``.
2467+
2468+
>>> x = np.array([[1, 2, 3], [4, 5, 6]])
2469+
>>> print(np.ravel(x))
2470+
[1. 2. 3. 4. 5. 6.]
2471+
2472+
>>> print(x.reshape(-1))
2473+
[1. 2. 3. 4. 5. 6.]
2474+
2475+
>>> print(np.ravel(x.T))
2476+
[1. 4. 2. 5. 3. 6.]
2477+
"""
2478+
if order is not 'C':
2479+
raise NotImplementedError('order {} is not supported'.format(order))
2480+
if isinstance(x, numeric_types):
2481+
return _np.reshape(x, -1)
2482+
elif isinstance(x, NDArray):
2483+
return _npi.reshape(x, -1)
2484+
else:
2485+
raise TypeError('type {} not supported'.format(str(type(x))))

python/mxnet/numpy/multiarray.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative',
5353
'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh',
5454
'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate',
55-
'stack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices']
55+
'stack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'ravel']
5656

5757
# Return code for dispatching indexing function call
5858
_NDARRAY_UNSUPPORTED_INDEXING = -1
@@ -3873,3 +3873,47 @@ def indices(dimensions, dtype=_np.int32, ctx=None):
38733873
"""
38743874
return _mx_nd_np.indices(dimensions=dimensions, dtype=dtype, ctx=ctx)
38753875
# pylint: enable=redefined-outer-name
3876+
3877+
3878+
@set_module('mxnet.numpy')
3879+
def ravel(x, order='C'):
3880+
r"""
3881+
ravel(x)
3882+
3883+
Return a contiguous flattened array.
3884+
A 1-D array, containing the elements of the input, is returned. A copy is
3885+
made only if needed.
3886+
3887+
Parameters
3888+
----------
3889+
x : ndarray
3890+
Input array. The elements in `x` are read in row-major, C-style order and
3891+
packed as a 1-D array.
3892+
3893+
Returns
3894+
-------
3895+
y : ndarray
3896+
y is an array of the same subtype as `x`, with shape ``(x.size,)``.
3897+
Note that matrices are special cased for backward compatibility, if `x`
3898+
is a matrix, then y is a 1-D ndarray.
3899+
3900+
Notes
3901+
-----
3902+
This function differs from the original numpy.arange in the following aspects:
3903+
- Only support row-major, C-style order.
3904+
3905+
Examples
3906+
--------
3907+
It is equivalent to ``reshape(x, -1)``.
3908+
3909+
>>> x = np.array([[1, 2, 3], [4, 5, 6]])
3910+
>>> print(np.ravel(x))
3911+
[1. 2. 3. 4. 5. 6.]
3912+
3913+
>>> print(x.reshape(-1))
3914+
[1. 2. 3. 4. 5. 6.]
3915+
3916+
>>> print(np.ravel(x.T))
3917+
[1. 4. 2. 5. 3. 6.]
3918+
"""
3919+
return _mx_nd_np.ravel(x, order)

python/mxnet/symbol/numpy/_symbol.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
3636
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
3737
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'mean',
38-
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices']
38+
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'ravel']
3939

4040

4141
def _num_outputs(sym):
@@ -2748,4 +2748,44 @@ def indices(dimensions, dtype=_np.int32, ctx=None):
27482748
# pylint: enable=redefined-outer-name
27492749

27502750

2751+
@set_module('mxnet.symbol.numpy')
2752+
def ravel(x, order='C'):
2753+
r"""
2754+
ravel(x)
2755+
2756+
Return a contiguous flattened array.
2757+
A 1-D array, containing the elements of the input, is returned. A copy is
2758+
made only if needed.
2759+
2760+
Parameters
2761+
----------
2762+
x : ndarray
2763+
Input array. The elements in `x` are read in row-major, C-style order and
2764+
packed as a 1-D array.
2765+
out : ndarray or None, optional
2766+
A location into which the result is stored. If not provided or `None`,
2767+
a freshly-allocated array is returned.
2768+
2769+
Returns
2770+
-------
2771+
y : ndarray
2772+
y is an array of the same subtype as `x`, with shape ``(x.size,)``.
2773+
Note that matrices are special cased for backward compatibility, if `x`
2774+
is a matrix, then y is a 1-D ndarray.
2775+
2776+
Notes
2777+
-----
2778+
This function differs from the original numpy.arange in the following aspects:
2779+
- Only support row-major, C-style order.
2780+
"""
2781+
if order is not 'C':
2782+
raise NotImplementedError('order {} is not supported'.format(order))
2783+
if isinstance(x, numeric_types):
2784+
return _np.reshape(x, -1)
2785+
elif isinstance(x, _Symbol):
2786+
return _npi.reshape(x, -1)
2787+
else:
2788+
raise TypeError('type {} not supported'.format(str(type(x))))
2789+
2790+
27512791
_set_np_symbol_class(_Symbol)

src/operator/numpy/np_matrix_op.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@ bool NumpyReshapeShape(const nnvm::NodeAttrs& attrs,
163163

164164
NNVM_REGISTER_OP(_np_reshape)
165165
.describe(R"code()code" ADD_FILELINE)
166+
.add_alias("_npi_reshape")
166167
.set_num_inputs(1)
167168
.set_num_outputs(1)
168169
.set_attr_parser(ParamParser<NumpyReshapeParam>)

tests/python/unittest/test_numpy_op.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1324,6 +1324,39 @@ def hybrid_forward(self, F, a, *args):
13241324
assert same(mx_out.asnumpy(), np_out)
13251325

13261326

1327+
@with_seed()
1328+
@use_np
1329+
def test_np_ravel():
1330+
class TestRavel(HybridBlock):
1331+
def __init__(self):
1332+
super(TestRavel, self).__init__()
1333+
1334+
def hybrid_forward(self, F, a):
1335+
return F.np.ravel(a)
1336+
1337+
types = ['float64', 'float32', 'float16', 'int64', 'int32', 'int8']
1338+
for oneType in types:
1339+
for hybridize in [True, False]:
1340+
for shape in [(), (2,), (2, 2), (1, 2, 3), (3, 0), (1, 0, 2)]:
1341+
test_ravel = TestRavel()
1342+
if hybridize:
1343+
test_ravel.hybridize()
1344+
x = rand_ndarray(shape, dtype=oneType).as_np_ndarray()
1345+
x.attach_grad()
1346+
np_out = _np.ravel(x.asnumpy())
1347+
with mx.autograd.record():
1348+
mx_out = test_ravel(x)
1349+
assert mx_out.shape == np_out.shape
1350+
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
1351+
mx_out.backward()
1352+
np_backward = _np.ones(shape)
1353+
assert_almost_equal(x.grad.asnumpy(), np_backward, rtol=1e-3, atol=1e-5)
1354+
1355+
mx_out = np.ravel(x)
1356+
np_out = _np.ravel(x.asnumpy())
1357+
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
1358+
1359+
13271360
@with_seed()
13281361
@use_np
13291362
def test_np_randint():
@@ -1790,6 +1823,46 @@ def hybrid_forward(self, F, x):
17901823
assert mx_out.shape == np_out.shape
17911824

17921825

1826+
@with_seed()
1827+
@use_np
1828+
def test_np_ravel():
1829+
class TestRavel(HybridBlock):
1830+
def __init__(self):
1831+
super(TestRavel, self).__init__()
1832+
1833+
def hybrid_forward(self, F, a):
1834+
return F.np.ravel(a)
1835+
1836+
types = ['float64', 'float32', 'float16', 'int64', 'int32', 'int8']
1837+
for oneType in types:
1838+
for hybridize in [True, False]:
1839+
for shape in [(),
1840+
(2,),
1841+
(2, 2),
1842+
(1, 2, 3),
1843+
(3, 0),
1844+
(1, 0, 2)
1845+
]:
1846+
test_ravel = TestRavel()
1847+
if hybridize:
1848+
test_ravel.hybridize()
1849+
x = rand_ndarray(shape, dtype=oneType).as_np_ndarray()
1850+
x.attach_grad()
1851+
np_out = _np.ravel(x.asnumpy())
1852+
with mx.autograd.record():
1853+
mx_out = test_ravel(x)
1854+
assert mx_out.shape == np_out.shape
1855+
assert mx_out.dtype == np_out.dtype
1856+
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
1857+
mx_out.backward()
1858+
np_backward = _np.ones(shape)
1859+
assert_almost_equal(x.grad.asnumpy(), np_backward, rtol=1e-3, atol=1e-5)
1860+
1861+
mx_out = np.ravel(x)
1862+
np_out = _np.ravel(x.asnumpy())
1863+
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
1864+
1865+
17931866
if __name__ == '__main__':
17941867
import nose
17951868
nose.runmodule()

0 commit comments

Comments
 (0)