
Commit 74836b8
Author: Ying

Numpy flip operator

* Implement flip
* fix some bugs and add gpu test
* register param and edit test
* add testcase for backward
* remove print
* optimize 0-dim and 0-shape
* adjust format and add doc in _symbol.py
* fix bug in symbol
* add flip in __all__
* fix format error
* import ndarray
* move flip implementation to np_matrix_op and remove test in gpu

1 parent 40593c6 commit 74836b8
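For context, the new operator mirrors numpy.flip through MXNet's NumPy-compatible front end. A minimal usage sketch, assuming a build that includes this commit and exposes the interface as mxnet.np (otherwise import mxnet.numpy as np):

    from mxnet import np

    a = np.arange(6).reshape(2, 3)   # [[0, 1, 2], [3, 4, 5]]
    np.flip(a, axis=0)               # reverse rows    -> [[3, 4, 5], [0, 1, 2]]
    np.flip(a, axis=1)               # reverse columns -> [[2, 1, 0], [5, 4, 3]]
    np.flip(a)                       # axis=None flips every axis -> [[5, 4, 3], [2, 1, 0]]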

File tree

8 files changed, +388 -3 lines changed


python/mxnet/ndarray/numpy/_op.py

Lines changed: 69 additions & 1 deletion
@@ -27,7 +27,7 @@
 from ..ndarray import NDArray

 __all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'tensordot',
-           'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate']
+           'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'flip']


 @set_module('mxnet.ndarray.numpy')
@@ -705,3 +705,71 @@ def concatenate(seq, axis=0, out=None):
         The concatenated array.
     """
     return _npi.concatenate(*seq, dim=axis, out=out)
+
+
+@set_module('mxnet.ndarray.numpy')
+def flip(x, axis=None, out=None, **kwargs):
+    r"""
+    flip(x, axis=None, out=None)
+
+    Reverse the order of elements in an array along the given axis.
+
+    The shape of the array is preserved, but the elements are reordered.
+
+    Parameters
+    ----------
+    x : ndarray or scalar
+        Input array.
+    axis : None or int or tuple of ints, optional
+        Axis or axes along which to flip over. The default,
+        axis=None, will flip over all of the axes of the input array.
+        If axis is negative it counts from the last to the first axis.
+
+        If axis is a tuple of ints, flipping is performed on all of the axes
+        specified in the tuple.
+    out : ndarray or scalar, optional
+        Alternative output array in which to place the result. It must have
+        the same shape and type as the expected output.
+
+    Returns
+    -------
+    out : ndarray or scalar
+        A view of `x` with the entries of axis reversed. Since a view is
+        returned, this operation is done in constant time.
+
+    Examples
+    --------
+    >>> A = np.arange(8).reshape((2,2,2))
+    >>> A
+    array([[[0, 1],
+            [2, 3]],
+           [[4, 5],
+            [6, 7]]])
+    >>> np.flip(A, 0)
+    array([[[4, 5],
+            [6, 7]],
+           [[0, 1],
+            [2, 3]]])
+    >>> np.flip(A, 1)
+    array([[[2, 3],
+            [0, 1]],
+           [[6, 7],
+            [4, 5]]])
+    >>> np.flip(A)
+    array([[[7, 6],
+            [5, 4]],
+           [[3, 2],
+            [1, 0]]])
+    >>> np.flip(A, (0, 2))
+    array([[[5, 4],
+            [7, 6]],
+           [[1, 0],
+            [3, 2]]])
+    """
+    from ...numpy import ndarray
+    if isinstance(x, numeric_types):
+        return _np.flip(x, axis, **kwargs)
+    elif isinstance(x, ndarray):
+        return _npi.flip(x, axis, out=out, **kwargs)
+    else:
+        raise TypeError('type {} not supported'.format(str(type(x))))
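In the spirit of the commit's "add testcase" step, the imperative front end can be sanity-checked against the reference NumPy implementation. A rough sketch, not the test that ships with this commit (assumes the mxnet.np front end and official NumPy imported as onp):

    import numpy as onp
    from mxnet import np

    def check_np_flip(shape, axis):
        # The forward result should match reference numpy.flip elementwise.
        data = onp.random.uniform(-1.0, 1.0, size=shape).astype('float32')
        out = np.flip(np.array(data), axis)
        assert onp.allclose(out.asnumpy(), onp.flip(data, axis))

    for axis in (None, 0, 1, (0, 2)):
        check_np_flip((2, 3, 4), axis)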

python/mxnet/numpy/multiarray.py

Lines changed: 63 additions & 1 deletion
@@ -45,7 +45,7 @@

 __all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'add', 'subtract', 'multiply', 'divide',
            'mod', 'power', 'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split',
-           'concatenate']
+           'concatenate', 'flip']


 # This function is copied from ndarray.py since pylint
@@ -1877,3 +1877,65 @@ def concatenate(seq, axis=0, out=None):
         The concatenated array.
     """
     return _mx_nd_np.concatenate(seq, axis=axis, out=out)
+
+
+@set_module('mxnet.numpy')
+def flip(x, axis=None, out=None, **kwargs):
+    r"""
+    flip(x, axis=None, out=None)
+
+    Reverse the order of elements in an array along the given axis.
+
+    The shape of the array is preserved, but the elements are reordered.
+
+    Parameters
+    ----------
+    x : ndarray or scalar
+        Input array.
+    axis : None or int or tuple of ints, optional
+        Axis or axes along which to flip over. The default,
+        axis=None, will flip over all of the axes of the input array.
+        If axis is negative it counts from the last to the first axis.
+
+        If axis is a tuple of ints, flipping is performed on all of the axes
+        specified in the tuple.
+    out : ndarray or scalar, optional
+        Alternative output array in which to place the result. It must have
+        the same shape and type as the expected output.
+
+    Returns
+    -------
+    out : ndarray or scalar
+        A view of `x` with the entries of axis reversed. Since a view is
+        returned, this operation is done in constant time.
+
+    Examples
+    --------
+    >>> A = np.arange(8).reshape((2,2,2))
+    >>> A
+    array([[[0, 1],
+            [2, 3]],
+           [[4, 5],
+            [6, 7]]])
+    >>> np.flip(A, 0)
+    array([[[4, 5],
+            [6, 7]],
+           [[0, 1],
+            [2, 3]]])
+    >>> np.flip(A, 1)
+    array([[[2, 3],
+            [0, 1]],
+           [[6, 7],
+            [4, 5]]])
+    >>> np.flip(A)
+    array([[[7, 6],
+            [5, 4]],
+           [[3, 2],
+            [1, 0]]])
+    >>> np.flip(A, (0, 2))
+    array([[[5, 4],
+            [7, 6]],
+           [[1, 0],
+            [3, 2]]])
+    """
+    return _mx_nd_np.flip(x, axis, out=out, **kwargs)

python/mxnet/symbol/numpy/_symbol.py

Lines changed: 39 additions & 1 deletion
@@ -30,7 +30,7 @@
 from . import _internal as _npi

 __all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'tensordot',
-           'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate']
+           'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'flip']


 def _num_outputs(sym):
@@ -1335,4 +1335,42 @@ def concatenate(seq, axis=0, out=None):
     return _npi.concatenate(*seq, dim=axis, out=out)


+@set_module('mxnet.symbol.numpy')
+def flip(x, axis=None, out=None, **kwargs):
+    r"""
+    flip(x, axis=None, out=None)
+
+    Reverse the order of elements in an array along the given axis.
+
+    The shape of the array is preserved, but the elements are reordered.
+
+    Parameters
+    ----------
+    x : _Symbol or scalar
+        Input array.
+    axis : None or int or tuple of ints, optional
+        Axis or axes along which to flip over. The default,
+        axis=None, will flip over all of the axes of the input array.
+        If axis is negative it counts from the last to the first axis.
+
+        If axis is a tuple of ints, flipping is performed on all of the axes
+        specified in the tuple.
+    out : _Symbol or scalar, optional
+        Alternative output array in which to place the result. It must have
+        the same shape and type as the expected output.
+
+    Returns
+    -------
+    out : _Symbol or scalar
+        A view of `x` with the entries of axis reversed. Since a view is
+        returned, this operation is done in constant time.
+    """
+    if isinstance(x, numeric_types):
+        return _np.flip(x, axis, **kwargs)
+    elif isinstance(x, _Symbol):
+        return _npi.flip(x, axis, out=out, **kwargs)
+    else:
+        raise TypeError('type {} not supported'.format(str(type(x))))
+
+
 _set_np_symbol_class(_Symbol)
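The _Symbol version above is the code path Gluon takes once a block is hybridized. A hedged sketch of exercising it, assuming the Gluon NumPy interface of this era (F resolves to mxnet.ndarray before hybridization and to mxnet.symbol afterwards, and npx.set_np() enables NumPy semantics):

    from mxnet import np, npx
    from mxnet.gluon import HybridBlock

    npx.set_np()

    class FlipRows(HybridBlock):
        def hybrid_forward(self, F, x):
            # After hybridize(), F is mxnet.symbol, so this call goes through
            # the flip added to python/mxnet/symbol/numpy/_symbol.py.
            return F.np.flip(x, axis=0)

    net = FlipRows()
    net.hybridize()
    print(net(np.arange(6).reshape(2, 3)))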

src/operator/numpy/np_matrix_op-inl.h

Lines changed: 76 additions & 0 deletions
@@ -60,6 +60,82 @@ void NumpyTranspose(const nnvm::NodeAttrs& attrs,
   }
 }

+struct FlipParam : public dmlc::Parameter<FlipParam> {
+  mxnet::Tuple<int> axis;
+  DMLC_DECLARE_PARAMETER(FlipParam) {
+    DMLC_DECLARE_FIELD(axis)
+    .describe("The axis along which to flip elements.");
+  }
+};
+
+struct flip0dim_shared_kernel {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i,
+                                  DType* out_data,
+                                  const DType* in_data) {
+    out_data[i] = in_data[i];
+  }
+};
+
+#define FLIP_MAX_DIM 10
+#define FLIP_MIN_DIM -1
+
+template<typename xpu>
+void NumpyFlipForwardImpl(const OpContext& ctx,
+                          const std::vector<TBlob>& inputs,
+                          const std::vector<TBlob>& outputs,
+                          const std::vector<index_t>& stride_,
+                          const std::vector<index_t>& trailing_,
+                          const index_t& flip_index);
+
+
+template<typename xpu>
+void NumpyFlipForward(const nnvm::NodeAttrs& attrs,
+                      const OpContext& ctx,
+                      const std::vector<TBlob>& inputs,
+                      const std::vector<OpReqType>& req,
+                      const std::vector<TBlob>& outputs) {
+  const FlipParam& param = nnvm::get<FlipParam>(attrs.parsed);
+  mxnet::Tuple<int> axistemp;
+  CHECK_EQ(inputs[0].type_flag_, outputs[0].type_flag_);
+  CHECK_LT(param.axis.ndim(), FLIP_MAX_DIM);
+  CHECK_GE(param.axis.ndim(), FLIP_MIN_DIM);
+  if (param.axis.ndim() == FLIP_MIN_DIM) {
+    if (inputs[0].shape_.ndim() == 0) {
+      mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+      MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+        mxnet_op::Kernel<flip0dim_shared_kernel, xpu>::Launch(s, inputs[0].Size(),
+          outputs[0].dptr<DType>(), inputs[0].dptr<DType>());
+      });
+      return;
+    }
+    std::vector<int> temp;
+    for (int i = 0; i < inputs[0].shape_.ndim(); i++) {
+      temp.push_back(i);
+    }
+    axistemp.assign(temp.begin(), temp.end());
+  } else {
+    axistemp = param.axis;
+  }
+
+  const mxnet::TShape& ishape = inputs[0].shape_;
+  if (ishape.ProdShape(0, ishape.ndim()) == 0) {
+    return;  // zero shape
+  }
+  std::vector<index_t> stride_(axistemp.ndim());
+  std::vector<index_t> trailing_(axistemp.ndim());
+  index_t flip_index = 0;
+  for (int axis : axistemp) {
+    CHECK_LT(axis, ishape.ndim());
+    stride_[flip_index] = ishape[axis];
+    trailing_[flip_index] = 1;
+    for (int i2 = axis + 1; i2 < ishape.ndim(); ++i2) {
+      trailing_[flip_index] *= ishape[i2];
+    }
+    flip_index++;
+  }
+  NumpyFlipForwardImpl<xpu>(ctx, inputs, outputs, stride_, trailing_, flip_index);
+}
 }  // namespace op
 }  // namespace mxnet
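In NumpyFlipForward above, each flipped axis contributes stride_ = ishape[axis] (the axis length) and trailing_ = the product of the dimensions after it, and both arrays are handed to the reverse kernel. A rough Python illustration of the flat-index mapping such a kernel performs (an illustration of the arithmetic only, not MXNet's actual reverse kernel):

    import numpy as onp

    def reverse_flat_index(idx, strides, trailings):
        # Map a flat index to the flat index it mirrors, reversing one axis per
        # (stride, trailing) pair: stride is the axis length, trailing the
        # number of elements that follow one step along that axis.
        for stride, trailing in zip(strides, trailings):
            low = idx % trailing           # offset inside the trailing block
            high = idx // trailing
            pos = high % stride            # coordinate along the flipped axis
            high //= stride
            idx = (high * stride + (stride - 1 - pos)) * trailing + low
        return idx

    # Check the mapping against numpy.flip for axis 1 of a (2, 3, 4) array:
    # stride = shape[1] = 3, trailing = shape[2] = 4.
    a = onp.arange(24).reshape(2, 3, 4)
    flat = a.ravel()
    out = onp.array([flat[reverse_flat_index(i, [3], [4])] for i in range(a.size)])
    assert (out.reshape(a.shape) == onp.flip(a, 1)).all()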

src/operator/numpy/np_matrix_op.cc

Lines changed: 46 additions & 0 deletions
@@ -304,5 +304,51 @@ NNVM_REGISTER_OP(_backward_np_concat)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
 .set_attr<FCompute>("FCompute<cpu>", ConcatGradCompute<cpu>);

+template<>
+void NumpyFlipForwardImpl<cpu>(const OpContext& ctx,
+                               const std::vector<TBlob>& inputs,
+                               const std::vector<TBlob>& outputs,
+                               const std::vector<index_t>& stride_,
+                               const std::vector<index_t>& trailing_,
+                               const index_t& flip_index) {
+  mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
+  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+    mxnet_op::Kernel<reverse, cpu>::Launch(s, inputs[0].Size(), flip_index,
+      inputs[0].dptr<DType>(), outputs[0].dptr<DType>(),
+      stride_.data(), trailing_.data());
+  });
+}
+
+DMLC_REGISTER_PARAMETER(FlipParam);
+
+NNVM_REGISTER_OP(_npi_flip)
+.set_num_outputs(1)
+.set_num_inputs(1)
+.set_attr_parser(ParamParser<FlipParam>)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string> {"data"};
+  })
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest> {ResourceRequest::kTempSpace};
+  })
+.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<FCompute>("FCompute<cpu>", NumpyFlipForward<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_npi_flip"})
+.add_argument("data", "NDArray-or-Symbol", "Input data array")
+.add_arguments(FlipParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_backward_npi_flip)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<FlipParam>)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest> {ResourceRequest::kTempSpace};
+  })
+.set_attr<FCompute>("FCompute<cpu>", NumpyFlipForward<cpu>);
 }  // namespace op
 }  // namespace mxnet
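Note that _backward_npi_flip is wired to the same NumpyFlipForward kernel: reversing along an axis is linear and self-inverse, so the input gradient is simply the incoming gradient flipped along the same axes, and the forward kernel computes exactly that. The self-inverse property is easy to confirm from the front end (a small sketch, assuming the mxnet.np interface):

    import numpy as onp
    from mxnet import np

    x = np.array(onp.random.uniform(size=(2, 3, 4)).astype('float32'))
    # flip is its own inverse: flipping twice along the same axes restores x,
    # which is why the backward pass can run the forward kernel on the
    # incoming gradient.
    twice = np.flip(np.flip(x, axis=(0, 2)), axis=(0, 2))
    assert onp.allclose(twice.asnumpy(), x.asnumpy())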

src/operator/numpy/np_matrix_op.cu

Lines changed: 34 additions & 0 deletions
@@ -42,5 +42,39 @@ NNVM_REGISTER_OP(_npi_concatenate)
 NNVM_REGISTER_OP(_backward_np_concat)
 .set_attr<FCompute>("FCompute<gpu>", ConcatGradCompute<gpu>);

+template<>
+void NumpyFlipForwardImpl<gpu>(const OpContext& ctx,
+                               const std::vector<TBlob>& inputs,
+                               const std::vector<TBlob>& outputs,
+                               const std::vector<index_t>& stride_,
+                               const std::vector<index_t>& trailing_,
+                               const index_t& flip_index) {
+  mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
+  mshadow::Tensor<gpu, 1, uint8_t> workspace =
+    ctx.requested[0].get_space_typed<gpu, 1, uint8_t>(
+      mshadow::Shape1(flip_index * sizeof(index_t) * 2), s);
+
+  auto stride_workspace = workspace.dptr_;
+  auto trailing_workspace = workspace.dptr_ + flip_index * sizeof(index_t);
+
+  cudaMemcpyAsync(stride_workspace, thrust::raw_pointer_cast(stride_.data()),
+                  stride_.size() * sizeof(index_t),
+                  cudaMemcpyHostToDevice, mshadow::Stream<gpu>::GetStream(s));
+  cudaMemcpyAsync(trailing_workspace, thrust::raw_pointer_cast(trailing_.data()),
+                  trailing_.size() * sizeof(index_t),
+                  cudaMemcpyHostToDevice, mshadow::Stream<gpu>::GetStream(s));
+
+  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+    mxnet_op::Kernel<reverse, gpu>::Launch(s, inputs[0].Size(), flip_index,
+      inputs[0].dptr<DType>(), outputs[0].dptr<DType>(),
+      reinterpret_cast<index_t*>(stride_workspace), reinterpret_cast<index_t*>(trailing_workspace));
+  });
+}
+
+NNVM_REGISTER_OP(_npi_flip)
+.set_attr<FCompute>("FCompute<gpu>", NumpyFlipForward<gpu>);
+
+NNVM_REGISTER_OP(_backward_npi_flip)
+.set_attr<FCompute>("FCompute<gpu>", NumpyFlipForward<gpu>);
 }  // namespace op
 }  // namespace mxnet

tests/python/gpu/test_operator_gpu.py

Lines changed: 1 addition & 0 deletions
@@ -2331,6 +2331,7 @@ def test_math():
     for op in ops:
         run_math(op, shape, dtype, check_value=check_value)

+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
