Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 2e30b6d

Browse files
author
Fan
committed
np compatible vstack
1 parent 3dacabe commit 2e30b6d

File tree

7 files changed

+419
-4
lines changed

7 files changed

+419
-4
lines changed

python/mxnet/ndarray/numpy/_op.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p',
3333
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
3434
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
35-
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'mean',
35+
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean',
3636
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
3737
'ravel']
3838

@@ -1990,6 +1990,57 @@ def get_list(arrays):
19901990
return _npi.stack(*arrays, axis=axis, out=out)
19911991

19921992

1993+
@set_module('mxnet.ndarray.numpy')
1994+
def vstack(arrays, out=None):
1995+
r"""Stack arrays in sequence vertically (row wise).
1996+
1997+
This is equivalent to concatenation along the first axis after 1-D arrays
1998+
of shape `(N,)` have been reshaped to `(1,N)`. Rebuilds arrays divided by
1999+
`vsplit`.
2000+
2001+
This function makes most sense for arrays with up to 3 dimensions. For
2002+
instance, for pixel-data with a height (first axis), width (second axis),
2003+
and r/g/b channels (third axis). The functions `concatenate` and `stack`
2004+
provide more general stacking and concatenation operations.
2005+
2006+
Parameters
2007+
----------
2008+
tup : sequence of ndarrays
2009+
The arrays must have the same shape along all but the first axis.
2010+
1-D arrays must have the same length.
2011+
2012+
Returns
2013+
-------
2014+
stacked : ndarray
2015+
The array formed by stacking the given arrays, will be at least 2-D.
2016+
2017+
Examples
2018+
--------
2019+
>>> a = np.array([1, 2, 3])
2020+
>>> b = np.array([2, 3, 4])
2021+
>>> np.vstack((a, b))
2022+
array([[1., 2., 3.],
2023+
[2., 3., 4.]])
2024+
2025+
>>> a = np.array([[1], [2], [3]])
2026+
>>> b = np.array([[2], [3], [4]])
2027+
>>> np.vstack((a, b))
2028+
array([[1.],
2029+
[2.],
2030+
[3.],
2031+
[2.],
2032+
[3.],
2033+
[4.]])
2034+
"""
2035+
def get_list(arrays):
2036+
if not hasattr(arrays, '__getitem__') and hasattr(arrays, '__iter__'):
2037+
raise ValueError("expected iterable for arrays but got {}".format(type(arrays)))
2038+
return [arr for arr in arrays]
2039+
2040+
arrays = get_list(arrays)
2041+
return _npi.vstack(*arrays)
2042+
2043+
19932044
@set_module('mxnet.ndarray.numpy')
19942045
def maximum(x1, x2, out=None):
19952046
"""Returns element-wise maximum of the input arrays with broadcasting.

python/mxnet/numpy/multiarray.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@
5252
'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative',
5353
'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh',
5454
'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate',
55-
'stack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
56-
'ravel']
55+
'stack', 'vstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices',
56+
'copysign', 'ravel']
5757

5858
# Return code for dispatching indexing function call
5959
_NDARRAY_UNSUPPORTED_INDEXING = -1
@@ -3505,6 +3505,51 @@ def stack(arrays, axis=0, out=None):
35053505
return _mx_nd_np.stack(arrays, axis=axis, out=out)
35063506

35073507

3508+
@set_module('mxnet.numpy')
3509+
def vstack(arrays, out=None):
3510+
r"""Stack arrays in sequence vertically (row wise).
3511+
3512+
This is equivalent to concatenation along the first axis after 1-D arrays
3513+
of shape `(N,)` have been reshaped to `(1,N)`. Rebuilds arrays divided by
3514+
`vsplit`.
3515+
3516+
This function makes most sense for arrays with up to 3 dimensions. For
3517+
instance, for pixel-data with a height (first axis), width (second axis),
3518+
and r/g/b channels (third axis). The functions `concatenate` and `stack`
3519+
provide more general stacking and concatenation operations.
3520+
3521+
Parameters
3522+
----------
3523+
tup : sequence of ndarrays
3524+
The arrays must have the same shape along all but the first axis.
3525+
1-D arrays must have the same length.
3526+
3527+
Returns
3528+
-------
3529+
stacked : ndarray
3530+
The array formed by stacking the given arrays, will be at least 2-D.
3531+
3532+
Examples
3533+
--------
3534+
>>> a = np.array([1, 2, 3])
3535+
>>> b = np.array([2, 3, 4])
3536+
>>> np.vstack((a, b))
3537+
array([[1., 2., 3.],
3538+
[2., 3., 4.]])
3539+
3540+
>>> a = np.array([[1], [2], [3]])
3541+
>>> b = np.array([[2], [3], [4]])
3542+
>>> np.vstack((a, b))
3543+
array([[1.],
3544+
[2.],
3545+
[3.],
3546+
[2.],
3547+
[3.],
3548+
[4.]])
3549+
"""
3550+
return _mx_nd_np.vstack(arrays)
3551+
3552+
35083553
@set_module('mxnet.numpy')
35093554
def maximum(x1, x2, out=None):
35103555
"""Returns element-wise maximum of the input arrays with broadcasting.

python/mxnet/symbol/numpy/_symbol.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p',
3535
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
3636
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
37-
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'mean',
37+
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean',
3838
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
3939
'ravel']
4040

@@ -2396,6 +2396,39 @@ def get_list(arrays):
23962396
return _npi.stack(*arrays, axis=axis, out=out)
23972397

23982398

2399+
@set_module('mxnet.symbol.numpy')
2400+
def vstack(arrays, out=None):
2401+
r"""Stack arrays in sequence vertically (row wise).
2402+
2403+
This is equivalent to concatenation along the first axis after 1-D arrays
2404+
of shape `(N,)` have been reshaped to `(1,N)`. Rebuilds arrays divided by
2405+
`vsplit`.
2406+
2407+
This function makes most sense for arrays with up to 3 dimensions. For
2408+
instance, for pixel-data with a height (first axis), width (second axis),
2409+
and r/g/b channels (third axis). The functions `concatenate` and `stack`
2410+
provide more general stacking and concatenation operations.
2411+
2412+
Parameters
2413+
----------
2414+
tup : sequence of _Symbol
2415+
The arrays must have the same shape along all but the first axis.
2416+
1-D arrays must have the same length.
2417+
2418+
Returns
2419+
-------
2420+
stacked : _Symbol
2421+
The array formed by stacking the given arrays, will be at least 2-D.
2422+
"""
2423+
def get_list(arrays):
2424+
if not hasattr(arrays, '__getitem__') and hasattr(arrays, '__iter__'):
2425+
raise ValueError("expected iterable for arrays but got {}".format(type(arrays)))
2426+
return [arr for arr in arrays]
2427+
2428+
arrays = get_list(arrays)
2429+
return _npi.vstack(*arrays)
2430+
2431+
23992432
@set_module('mxnet.symbol.numpy')
24002433
def maximum(x1, x2, out=None):
24012434
return _ufunc_helper(x1, x2, _npi.maximum, _np.maximum, _npi.maximum_scalar, None, out)

src/operator/numpy/np_matrix_op-inl.h

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,14 @@ struct NumpyTransposeParam : public dmlc::Parameter<NumpyTransposeParam> {
4141
}
4242
};
4343

44+
struct NumpyVstackParam : public dmlc::Parameter<NumpyVstackParam> {
45+
int num_args;
46+
DMLC_DECLARE_PARAMETER(NumpyVstackParam) {
47+
DMLC_DECLARE_FIELD(num_args).set_lower_bound(1)
48+
.describe("Number of inputs to be vstacked.");
49+
}
50+
};
51+
4452
template<typename xpu>
4553
void NumpyTranspose(const nnvm::NodeAttrs& attrs,
4654
const OpContext& ctx,
@@ -60,6 +68,78 @@ void NumpyTranspose(const nnvm::NodeAttrs& attrs,
6068
}
6169
}
6270

71+
template<typename xpu>
72+
void NumpyVstackForward(const nnvm::NodeAttrs& attrs,
73+
const OpContext& ctx,
74+
const std::vector<TBlob>& inputs,
75+
const std::vector<OpReqType>& req,
76+
const std::vector<TBlob>& outputs) {
77+
using namespace mshadow;
78+
using namespace mshadow_op;
79+
80+
const NumpyVstackParam& param = nnvm::get<NumpyVstackParam>(attrs.parsed);
81+
CHECK_EQ(inputs.size(), param.num_args);
82+
CHECK_EQ(outputs.size(), 1);
83+
CHECK_EQ(req.size(), 1);
84+
85+
// reshape if necessary
86+
std::vector<TBlob> data(param.num_args);
87+
for (int i = 0; i < param.num_args; i++) {
88+
if (inputs[i].shape_.ndim() == 0 || inputs[i].shape_.ndim() == 1) {
89+
TShape shape = Shape2(1, inputs[i].shape_.Size());
90+
data[i] = inputs[i].reshape(shape);
91+
} else {
92+
data[i] = inputs[i];
93+
}
94+
}
95+
96+
// initialize ConcatOp
97+
ConcatParam cparam;
98+
cparam.num_args = param.num_args;
99+
cparam.dim = 0;
100+
MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
101+
ConcatOp<xpu, DType> op;
102+
op.Init(cparam);
103+
op.Forward(ctx, data, req, outputs);
104+
});
105+
}
106+
107+
template<typename xpu>
108+
void NumpyVstackBackward(const nnvm::NodeAttrs& attrs,
109+
const OpContext& ctx,
110+
const std::vector<TBlob>& inputs,
111+
const std::vector<OpReqType>& req,
112+
const std::vector<TBlob>& outputs) {
113+
using namespace mshadow;
114+
using namespace mshadow_op;
115+
116+
const NumpyVstackParam& param = nnvm::get<NumpyVstackParam>(attrs.parsed);
117+
CHECK_EQ(inputs.size(), 1);
118+
CHECK_EQ(outputs.size(), param.num_args);
119+
CHECK_EQ(req.size(), param.num_args);
120+
121+
// reshape if necessary
122+
std::vector<TBlob> data(param.num_args);
123+
for (int i = 0; i < param.num_args; i++) {
124+
if (outputs[i].shape_.ndim() == 0 || outputs[i].shape_.ndim() == 1) {
125+
TShape shape = Shape2(1, outputs[i].shape_.Size());
126+
data[i] = outputs[i].reshape(shape);
127+
} else {
128+
data[i] = outputs[i];
129+
}
130+
}
131+
132+
// initialize ConcatOp
133+
ConcatParam cparam;
134+
cparam.num_args = param.num_args;
135+
cparam.dim = 0;
136+
MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
137+
ConcatOp<xpu, DType> op;
138+
op.Init(cparam);
139+
op.Backward(ctx, inputs[0], req, data);
140+
});
141+
}
142+
63143
} // namespace op
64144
} // namespace mxnet
65145

0 commit comments

Comments
 (0)