Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 554517b

Browse files
author
Ying
committed
Numpy flip operator
* Implement flip * fix some bug and add gpu test * register param and edit test * add testcase for backward * remove print * optimize 0-dim and 0-shape * adjust format and add doc in _symbol.py * fix bug in symbol * add flip in __all__ * fix format error * import ndarray * move flip implementation to np_matrix_op and remove test in gpu * delete redundant blank line
1 parent 9675a2d commit 554517b

File tree

7 files changed

+386
-3
lines changed

7 files changed

+386
-3
lines changed

python/mxnet/ndarray/numpy/_op.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
3434
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
3535
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'mean',
36-
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var']
36+
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'flip']
3737

3838

3939
@set_module('mxnet.ndarray.numpy')
@@ -2363,3 +2363,71 @@ def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False): # pylint:
23632363
0.2025
23642364
"""
23652365
return _npi.var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, out=out)
2366+
2367+
2368+
@set_module('mxnet.ndarray.numpy')
def flip(x, axis=None, out=None, **kwargs):
    r"""
    flip(x, axis=None, out=None)

    Reverse the order of elements in an array along the given axis.

    The shape of the array is preserved, but the elements are reordered.

    Parameters
    ----------
    x : ndarray or scalar
        Input array.
    axis : None or int or tuple of ints, optional
        Axis or axes along which to flip over. The default,
        axis=None, will flip over all of the axes of the input array.
        If axis is negative it counts from the last to the first axis.

        If axis is a tuple of ints, flipping is performed on all of the axes
        specified in the tuple.
    out : ndarray or scalar, optional
        Alternative output array in which to place the result. It must have
        the same shape and type as the expected output.

    Returns
    -------
    out : ndarray or scalar
        An array with the entries of `axis` reversed. Note that unlike
        official NumPy, a new array is produced rather than a view.

    Examples
    --------
    >>> A = np.arange(8).reshape((2,2,2))
    >>> A
    array([[[0, 1],
            [2, 3]],
           [[4, 5],
            [6, 7]]])
    >>> np.flip(A, 0)
    array([[[4, 5],
            [6, 7]],
           [[0, 1],
            [2, 3]]])
    >>> np.flip(A, 1)
    array([[[2, 3],
            [0, 1]],
           [[6, 7],
            [4, 5]]])
    >>> np.flip(A)
    array([[[7, 6],
            [5, 4]],
           [[3, 2],
            [1, 0]]])
    >>> np.flip(A, (0, 2))
    array([[[5, 4],
            [7, 6]],
           [[1, 0],
            [3, 2]]])
    """
    # Imported locally to avoid a circular import at module load time.
    from ...numpy import ndarray
    if isinstance(x, numeric_types):
        # Plain Python scalars are delegated to official NumPy.
        return _np.flip(x, axis, **kwargs)
    elif isinstance(x, ndarray):
        return _npi.flip(x, axis, out=out, **kwargs)
    else:
        raise TypeError('type {} not supported'.format(str(type(x))))

python/mxnet/numpy/multiarray.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative',
5353
'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh',
5454
'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate',
55-
'stack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var']
55+
'stack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'flip']
5656

5757
# Return code for dispatching indexing function call
5858
_NDARRAY_UNSUPPORTED_INDEXING = -1
@@ -3808,3 +3808,65 @@ def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=None):
38083808
0.2025
38093809
"""
38103810
return _npi.var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, out=out)
3811+
3812+
3813+
@set_module('mxnet.numpy')
def flip(x, axis=None, out=None, **kwargs):
    r"""
    flip(x, axis=None, out=None)

    Reverse the order of elements in an array along the given axis.

    The shape of the array is preserved, but the elements are reordered.

    Parameters
    ----------
    x : ndarray or scalar
        Input array.
    axis : None or int or tuple of ints, optional
        Axis or axes along which to flip over. The default,
        axis=None, will flip over all of the axes of the input array.
        If axis is negative it counts from the last to the first axis.

        If axis is a tuple of ints, flipping is performed on all of the axes
        specified in the tuple.
    out : ndarray or scalar, optional
        Alternative output array in which to place the result. It must have
        the same shape and type as the expected output.

    Returns
    -------
    out : ndarray or scalar
        An array with the entries of `axis` reversed. Note that unlike
        official NumPy, a new array is produced rather than a view.

    Examples
    --------
    >>> A = np.arange(8).reshape((2,2,2))
    >>> A
    array([[[0, 1],
            [2, 3]],
           [[4, 5],
            [6, 7]]])
    >>> np.flip(A, 0)
    array([[[4, 5],
            [6, 7]],
           [[0, 1],
            [2, 3]]])
    >>> np.flip(A, 1)
    array([[[2, 3],
            [0, 1]],
           [[6, 7],
            [4, 5]]])
    >>> np.flip(A)
    array([[[7, 6],
            [5, 4]],
           [[3, 2],
            [1, 0]]])
    >>> np.flip(A, (0, 2))
    array([[[5, 4],
            [7, 6]],
           [[1, 0],
            [3, 2]]])
    """
    # Thin front-end wrapper: all dispatching lives in the ndarray namespace.
    return _mx_nd_np.flip(x, axis, out=out, **kwargs)

python/mxnet/symbol/numpy/_symbol.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
3636
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
3737
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'mean',
38-
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var']
38+
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'flip']
3939

4040

4141
def _num_outputs(sym):
@@ -2678,4 +2678,42 @@ def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False): # pylint:
26782678
return _npi.var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, out=out)
26792679

26802680

2681+
@set_module('mxnet.symbol.numpy')
def flip(x, axis=None, out=None, **kwargs):
    r"""
    flip(x, axis=None, out=None)

    Reverse the order of elements in an array along the given axis.

    The shape of the array is preserved, but the elements are reordered.

    Parameters
    ----------
    x : _Symbol or scalar
        Input array.
    axis : None or int or tuple of ints, optional
        Axis or axes along which to flip over. The default,
        axis=None, will flip over all of the axes of the input array.
        If axis is negative it counts from the last to the first axis.

        If axis is a tuple of ints, flipping is performed on all of the axes
        specified in the tuple.
    out : _Symbol or scalar, optional
        Alternative output array in which to place the result. It must have
        the same shape and type as the expected output.

    Returns
    -------
    out : _Symbol or scalar
        A symbol whose output has the entries of `axis` reversed.
    """
    if isinstance(x, numeric_types):
        # Plain Python scalars are delegated to official NumPy.
        return _np.flip(x, axis, **kwargs)
    elif isinstance(x, _Symbol):
        return _npi.flip(x, axis, out=out, **kwargs)
    else:
        raise TypeError('type {} not supported'.format(str(type(x))))
2717+
2718+
26812719
_set_np_symbol_class(_Symbol)

src/operator/numpy/np_matrix_op-inl.h

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,81 @@ void NumpyTranspose(const nnvm::NodeAttrs& attrs,
6060
}
6161
}
6262

63+
// Operator parameter for numpy-compatible flip: the axes along which to
// reverse elements. The Python frontend encodes axis=None as ndim() == -1
// (see the FLIP_MIN_DIM handling in NumpyFlipForward).
struct FlipParam : public dmlc::Parameter<FlipParam> {
  mxnet::Tuple<int> axis;
  DMLC_DECLARE_PARAMETER(FlipParam) {
    DMLC_DECLARE_FIELD(axis)
    .describe("The axis which to flip elements.");
  }
};
70+
71+
// Element-wise identity copy kernel, used for the degenerate 0-dim input
// case where flipping is a no-op: out[i] = in[i].
struct flip0dim_shared_kernel {
  template<typename DType>
  MSHADOW_XINLINE static void Map(int i,
                                  DType* out_data,
                                  const DType* in_data) {
    out_data[i] = in_data[i];
  }
};
79+
80+
#define FLIP_MAX_DIM 10
81+
#define FLIP_MIN_DIM -1
82+
83+
template<typename xpu>
84+
void NumpyFlipForwardImpl(const OpContext& ctx,
85+
const std::vector<TBlob>& inputs,
86+
const std::vector<TBlob>& outputs,
87+
const std::vector<index_t>& stride_,
88+
const std::vector<index_t>& trailing_,
89+
const index_t& flip_index);
90+
91+
template<typename xpu>
92+
void NumpyFlipForward(const nnvm::NodeAttrs& attrs,
93+
const OpContext& ctx,
94+
const std::vector<TBlob>& inputs,
95+
const std::vector<OpReqType>& req,
96+
const std::vector<TBlob>& outputs) {
97+
const FlipParam& param = nnvm::get<FlipParam>(attrs.parsed);
98+
mxnet::Tuple<int> axistemp;
99+
CHECK_EQ(inputs[0].type_flag_, outputs[0].type_flag_);
100+
CHECK_LT(param.axis.ndim(), FLIP_MAX_DIM);
101+
CHECK_GE(param.axis.ndim(), FLIP_MIN_DIM);
102+
if (param.axis.ndim() == FLIP_MIN_DIM) {
103+
if (inputs[0].shape_.ndim() == 0) {
104+
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
105+
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
106+
mxnet_op::Kernel<flip0dim_shared_kernel, xpu>::Launch(s, inputs[0].Size(),
107+
outputs[0].dptr<DType>(), inputs[0].dptr<DType>());
108+
});
109+
return;
110+
}
111+
std::vector<int> temp;
112+
for (int i = 0; i < inputs[0].shape_.ndim(); i++) {
113+
temp.push_back(i);
114+
}
115+
axistemp.assign(temp.begin(), temp.end());
116+
} else {
117+
axistemp = param.axis;
118+
}
119+
120+
const mxnet::TShape& ishape = inputs[0].shape_;
121+
if (ishape.ProdShape(0, ishape.ndim()) == 0) {
122+
return; // zero shape
123+
}
124+
std::vector<index_t> stride_(axistemp.ndim());
125+
std::vector<index_t> trailing_(axistemp.ndim());
126+
index_t flip_index = 0;
127+
for (int axis : axistemp) {
128+
CHECK_LT(axis, ishape.ndim());
129+
stride_[flip_index] = ishape[axis];
130+
trailing_[flip_index] = 1;
131+
for (int i2 = axis + 1; i2 < ishape.ndim(); ++i2) {
132+
trailing_[flip_index] *= ishape[i2];
133+
}
134+
flip_index++;
135+
}
136+
NumpyFlipForwardImpl<xpu>(ctx, inputs, outputs, stride_, trailing_, flip_index);
137+
}
63138
} // namespace op
64139
} // namespace mxnet
65140

src/operator/numpy/np_matrix_op.cc

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,5 +345,51 @@ Examples::
345345
.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to stack")
346346
.add_arguments(StackParam::__FIELDS__());
347347

348+
// CPU specialization: stride_/trailing_ stay in host memory, so the kernel
// can read them directly via .data().
template<>
void NumpyFlipForwardImpl<cpu>(const OpContext& ctx,
                               const std::vector<TBlob>& inputs,
                               const std::vector<TBlob>& outputs,
                               const std::vector<index_t>& stride_,
                               const std::vector<index_t>& trailing_,
                               const index_t& flip_index) {
  mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
    mxnet_op::Kernel<reverse, cpu>::Launch(s, inputs[0].Size(), flip_index,
      inputs[0].dptr<DType>(), outputs[0].dptr<DType>(),
      stride_.data(), trailing_.data());
  });
}

DMLC_REGISTER_PARAMETER(FlipParam);

NNVM_REGISTER_OP(_npi_flip)
.set_num_outputs(1)
.set_num_inputs(1)
.set_attr_parser(ParamParser<FlipParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
  [](const NodeAttrs& attrs) {
    return std::vector<std::string> {"data"};
  })
// Temp space is requested for the GPU impl's device-side stride/trailing
// tables; the CPU impl does not use it.
.set_attr<FResourceRequest>("FResourceRequest",
  [](const NodeAttrs& attrs) {
    return std::vector<ResourceRequest> {ResourceRequest::kTempSpace};
  })
// flip preserves shape and dtype, so elemwise inference applies.
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FCompute>("FCompute<cpu>", NumpyFlipForward<cpu>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_npi_flip"})
.add_argument("data", "NDArray-or-Symbol", "Input data array")
.add_arguments(FlipParam::__FIELDS__());

// The gradient of flip is flip along the same axes, so the backward op
// reuses NumpyFlipForward.
NNVM_REGISTER_OP(_backward_npi_flip)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser(ParamParser<FlipParam>)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FResourceRequest>("FResourceRequest",
  [](const NodeAttrs& attrs) {
    return std::vector<ResourceRequest> {ResourceRequest::kTempSpace};
  })
.set_attr<FCompute>("FCompute<cpu>", NumpyFlipForward<cpu>);
348394
} // namespace op
349395
} // namespace mxnet

src/operator/numpy/np_matrix_op.cu

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,5 +46,39 @@ NNVM_REGISTER_OP(_backward_np_concat)
4646
NNVM_REGISTER_OP(_npi_stack)
4747
.set_attr<FCompute>("FCompute<gpu>", StackOpForward<gpu>);
4848

49+
// GPU specialization: the reverse kernel reads stride_/trailing_ on device,
// so both tables are staged into requested temp space first.
template<>
void NumpyFlipForwardImpl<gpu>(const OpContext& ctx,
                               const std::vector<TBlob>& inputs,
                               const std::vector<TBlob>& outputs,
                               const std::vector<index_t>& stride_,
                               const std::vector<index_t>& trailing_,
                               const index_t& flip_index) {
  mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
  // One workspace blob holding both tables back-to-back:
  // [stride_ (flip_index entries) | trailing_ (flip_index entries)].
  mshadow::Tensor<gpu, 1, uint8_t> workspace =
    ctx.requested[0].get_space_typed<gpu, 1, uint8_t>(
      mshadow::Shape1(flip_index * sizeof(index_t) * 2), s);

  auto stride_workspace = workspace.dptr_;
  auto trailing_workspace = workspace.dptr_ + flip_index * sizeof(index_t);

  // NOTE(review): stride_/trailing_ are pageable host vectors; with pageable
  // memory cudaMemcpyAsync is staged through a driver buffer, which is what
  // makes copying from these soon-to-die locals safe — confirm this holds for
  // all supported CUDA versions.
  cudaMemcpyAsync(stride_workspace, thrust::raw_pointer_cast(stride_.data()),
                  stride_.size() * sizeof(index_t),
                  cudaMemcpyHostToDevice, mshadow::Stream<gpu>::GetStream(s));
  cudaMemcpyAsync(trailing_workspace, thrust::raw_pointer_cast(trailing_.data()),
                  trailing_.size() * sizeof(index_t),
                  cudaMemcpyHostToDevice, mshadow::Stream<gpu>::GetStream(s));

  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
    mxnet_op::Kernel<reverse, gpu>::Launch(s, inputs[0].Size(), flip_index,
      inputs[0].dptr<DType>(), outputs[0].dptr<DType>(),
      reinterpret_cast<index_t*>(stride_workspace), reinterpret_cast<index_t*>(trailing_workspace));
  });
}

NNVM_REGISTER_OP(_npi_flip)
.set_attr<FCompute>("FCompute<gpu>", NumpyFlipForward<gpu>);

// flip is self-inverse, so the backward pass reuses the forward compute.
NNVM_REGISTER_OP(_backward_npi_flip)
.set_attr<FCompute>("FCompute<gpu>", NumpyFlipForward<gpu>);
4983
} // namespace op
5084
} // namespace mxnet

0 commit comments

Comments
 (0)