Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit a2ce871

Browse files
author
Fan
committed
Add NumPy-compatible vstack operator
1 parent 57927a9 commit a2ce871

File tree

7 files changed

+525
-3
lines changed

7 files changed

+525
-3
lines changed

python/mxnet/ndarray/numpy/_op.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from ...context import current_context
2626
from . import _internal as _npi
2727

28-
__all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power']
28+
__all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'vstack']
2929

3030

3131
@set_module('mxnet.ndarray.numpy')
@@ -293,3 +293,56 @@ def power(x1, x2, out=None):
293293
This is a scalar if both x1 and x2 are scalars.
294294
"""
295295
return _ufunc_helper(x1, x2, _npi.power, _np.power, _npi.power_scalar, _npi.rpower_scalar, out)
296+
297+
298+
@set_module('mxnet.ndarray.numpy')
def vstack(arrays):
    r"""vstack(tup)

    Stack arrays in sequence vertically (row wise).

    This is equivalent to concatenation along the first axis after 1-D arrays
    of shape `(N,)` have been reshaped to `(1,N)`. Rebuilds arrays divided by
    `vsplit`.

    This function makes most sense for arrays with up to 3 dimensions. For
    instance, for pixel-data with a height (first axis), width (second axis),
    and r/g/b channels (third axis). The functions `concatenate` and `stack`
    provide more general stacking and concatenation operations.

    Parameters
    ----------
    tup : sequence of ndarrays
        The arrays must have the same shape along all but the first axis.
        1-D arrays must have the same length.

    Returns
    -------
    stacked : ndarray
        The array formed by stacking the given arrays, will be at least 2-D.

    Raises
    ------
    ValueError
        If `arrays` is neither indexable nor iterable.

    Examples
    --------
    >>> a = np.array([1, 2, 3])
    >>> b = np.array([2, 3, 4])
    >>> np.vstack((a, b))
    array([[1., 2., 3.],
           [2., 3., 4.]])

    >>> a = np.array([[1], [2], [3]])
    >>> b = np.array([[2], [3], [4]])
    >>> np.vstack((a, b))
    array([[1.],
           [2.],
           [3.],
           [2.],
           [3.],
           [4.]])
    """
    def get_list(arrays):
        # Reject inputs that support neither indexing nor iteration; anything
        # iterable (list, tuple, generator, ...) is materialized into a list.
        # NOTE: the previous check raised for iter-only iterables (which do
        # work) and let non-iterables fall through to a confusing TypeError.
        if not hasattr(arrays, '__getitem__') and not hasattr(arrays, '__iter__'):
            raise ValueError("expected iterable for arrays but got {}".format(type(arrays)))
        return [arr for arr in arrays]

    arrays = get_list(arrays)
    return _npi.vstack(*arrays)

python/mxnet/numpy/multiarray.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
from ..ndarray.numpy import _internal as _npi
4545

4646
__all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'add', 'subtract', 'multiply', 'divide',
47-
'mod', 'power']
47+
'mod', 'power', 'vstack']
4848

4949

5050
# This function is copied from ndarray.py since pylint
@@ -1549,3 +1549,50 @@ def power(x1, x2, out=None):
15491549
This is a scalar if both x1 and x2 are scalars.
15501550
"""
15511551
return _mx_nd_np.power(x1, x2, out=out)
1552+
1553+
1554+
@set_module('mxnet.numpy')
def vstack(arrays):
    r"""vstack(tup)

    Stack arrays in sequence vertically (row wise).

    Equivalent to concatenating along axis 0 once every 1-D input of shape
    `(N,)` has been promoted to a row vector of shape `(1, N)`. This is the
    inverse of `vsplit`.

    The function is most useful for arrays of up to 3 dimensions — for
    example pixel data laid out as height (axis 0), width (axis 1) and
    r/g/b channels (axis 2). For more general use cases see `concatenate`
    and `stack`.

    Parameters
    ----------
    tup : sequence of ndarrays
        All arrays must agree in shape on every axis except the first;
        1-D arrays must all have the same length.

    Returns
    -------
    stacked : ndarray
        The stacked result, always at least 2-D.

    Examples
    --------
    >>> a = np.array([1, 2, 3])
    >>> b = np.array([2, 3, 4])
    >>> np.vstack((a, b))
    array([[1., 2., 3.],
           [2., 3., 4.]])

    >>> a = np.array([[1], [2], [3]])
    >>> b = np.array([[2], [3], [4]])
    >>> np.vstack((a, b))
    array([[1.],
           [2.],
           [3.],
           [2.],
           [3.],
           [4.]])
    """
    # Pure delegation: the ndarray-level implementation does the work.
    return _mx_nd_np.vstack(arrays)

python/mxnet/symbol/numpy/_symbol.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from .._internal import _set_np_symbol_class
2929
from . import _internal as _npi
3030

31-
__all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power']
31+
__all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'vstack']
3232

3333

3434
def _num_outputs(sym):
@@ -1010,4 +1010,39 @@ def power(x1, x2, out=None):
10101010
return _ufunc_helper(x1, x2, _npi.power, _np.power, _npi.power_scalar, _npi.rpower_scalar, out)
10111011

10121012

1013+
@set_module('mxnet.symbol.numpy')
def vstack(arrays):
    r"""vstack(tup)

    Stack arrays in sequence vertically (row wise).

    This is equivalent to concatenation along the first axis after 1-D arrays
    of shape `(N,)` have been reshaped to `(1,N)`. Rebuilds arrays divided by
    `vsplit`.

    This function makes most sense for arrays with up to 3 dimensions. For
    instance, for pixel-data with a height (first axis), width (second axis),
    and r/g/b channels (third axis). The functions `concatenate` and `stack`
    provide more general stacking and concatenation operations.

    Parameters
    ----------
    tup : sequence of _Symbol
        The arrays must have the same shape along all but the first axis.
        1-D arrays must have the same length.

    Returns
    -------
    stacked : _Symbol
        The array formed by stacking the given arrays, will be at least 2-D.

    Raises
    ------
    ValueError
        If `arrays` is neither indexable nor iterable.
    """
    def get_list(arrays):
        # Reject inputs that support neither indexing nor iteration; anything
        # iterable (list, tuple, generator, ...) is materialized into a list.
        # NOTE: the previous check raised for iter-only iterables (which do
        # work) and let non-iterables fall through to a confusing TypeError.
        if not hasattr(arrays, '__getitem__') and not hasattr(arrays, '__iter__'):
            raise ValueError("expected iterable for arrays but got {}".format(type(arrays)))
        return [arr for arr in arrays]

    arrays = get_list(arrays)
    return _npi.vstack(*arrays)
1046+
1047+
10131048
_set_np_symbol_class(_Symbol)

src/operator/numpy/np_matrix_op-inl.h

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* Copyright (c) 2019 by Contributors
22+
* \file np_matrix_op-inl.h
23+
* \brief Function definition of matrix related operators
24+
*/
25+
#ifndef MXNET_OPERATOR_NUMPY_NP_MATRIX_OP_INL_H_
26+
#define MXNET_OPERATOR_NUMPY_NP_MATRIX_OP_INL_H_
27+
28+
#include <vector>
29+
#include "../tensor/matrix_op-inl.h"
30+
#include "../nn/concat-inl.h"
31+
32+
namespace mxnet {
33+
namespace op {
34+
35+
// Operator parameter for the np.vstack-compatible operator: its only
// attribute is how many input arrays are concatenated along axis 0.
struct NumpyVstackParam : public dmlc::Parameter<NumpyVstackParam> {
  int num_args;  // number of inputs to stack; lower bound of 1 enforced below
  DMLC_DECLARE_PARAMETER(NumpyVstackParam) {
    DMLC_DECLARE_FIELD(num_args).set_lower_bound(1)
    .describe("Number of inputs to be vstacked.");
  }
};
42+
43+
template<typename xpu>
44+
void NumpyVstackForward(const nnvm::NodeAttrs& attrs,
45+
const OpContext& ctx,
46+
const std::vector<TBlob>& inputs,
47+
const std::vector<OpReqType>& req,
48+
const std::vector<TBlob>& outputs) {
49+
using namespace mshadow;
50+
using namespace mshadow_op;
51+
52+
const NumpyVstackParam& param = nnvm::get<NumpyVstackParam>(attrs.parsed);
53+
CHECK_EQ(inputs.size(), param.num_args);
54+
CHECK_EQ(outputs.size(), 1);
55+
CHECK_EQ(req.size(), 1);
56+
57+
// reshape if necessary
58+
std::vector<TBlob> data(param.num_args);
59+
for (int i = 0; i < param.num_args; i++) {
60+
if (inputs[i].shape_.ndim() == 0 || inputs[i].shape_.ndim() == 1) {
61+
TShape shape = Shape2(1, inputs[i].shape_.Size());
62+
data[i] = inputs[i].reshape(shape);
63+
} else {
64+
data[i] = inputs[i];
65+
}
66+
}
67+
68+
// initialize ConcatOp
69+
ConcatParam cparam;
70+
cparam.num_args = param.num_args;
71+
cparam.dim = 0;
72+
MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
73+
ConcatOp<xpu, DType> op;
74+
op.Init(cparam);
75+
op.Forward(ctx, data, req, outputs);
76+
});
77+
}
78+
79+
template<typename xpu>
80+
void NumpyVstackBackward(const nnvm::NodeAttrs& attrs,
81+
const OpContext& ctx,
82+
const std::vector<TBlob>& inputs,
83+
const std::vector<OpReqType>& req,
84+
const std::vector<TBlob>& outputs) {
85+
using namespace mshadow;
86+
using namespace mshadow_op;
87+
88+
const NumpyVstackParam& param = nnvm::get<NumpyVstackParam>(attrs.parsed);
89+
CHECK_EQ(inputs.size(), 1);
90+
CHECK_EQ(outputs.size(), param.num_args);
91+
CHECK_EQ(req.size(), param.num_args);
92+
93+
// reshape if necessary
94+
std::vector<TBlob> data(param.num_args);
95+
for (int i = 0; i < param.num_args; i++) {
96+
if (outputs[i].shape_.ndim() == 0 || outputs[i].shape_.ndim() == 1) {
97+
TShape shape = Shape2(1, outputs[i].shape_.Size());
98+
data[i] = outputs[i].reshape(shape);
99+
} else {
100+
data[i] = outputs[i];
101+
}
102+
}
103+
104+
// initialize ConcatOp
105+
ConcatParam cparam;
106+
cparam.num_args = param.num_args;
107+
cparam.dim = 0;
108+
MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
109+
ConcatOp<xpu, DType> op;
110+
op.Init(cparam);
111+
op.Backward(ctx, inputs[0], req, data);
112+
});
113+
}
114+
115+
} // namespace op
116+
} // namespace mxnet
117+
118+
#endif // MXNET_OPERATOR_NUMPY_NP_MATRIX_OP_INL_H_

0 commit comments

Comments
 (0)