Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 708276d

Browse files
Yingtingying2020
authored and committed
Numpy flip operator
* Implement flip * fix some bug and add gpu test * register param and edit test * add testcase for backward * remove print * optimize 0-dim and 0-shape * adjust format and add doc in _symbol.py * fix bug in symbol * add flip in __all__ * fix format error * import ndarray * move flip implementation to np_matrix_op and remove test in gpu * delate redundant blank line * fix error in review * remove **kwargs and change copy * fix error in review
1 parent a37a76c commit 708276d

File tree

7 files changed

+357
-3
lines changed

7 files changed

+357
-3
lines changed

python/mxnet/ndarray/numpy/_op.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
3535
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'mean',
3636
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
37-
'ravel']
37+
'ravel', 'flip']
3838

3939

4040
@set_module('mxnet.ndarray.numpy')
@@ -2537,3 +2537,71 @@ def ravel(x, order='C'):
25372537
return _npi.reshape(x, -1)
25382538
else:
25392539
raise TypeError('type {} not supported'.format(str(type(x))))
2540+
2541+
2542+
@set_module('mxnet.ndarray.numpy')
def flip(x, axis=None, out=None):
    r"""
    flip(x, axis=None, out=None)

    Reverse the order of elements in an array along the given axis.

    The shape of the array is preserved, but the elements are reordered.

    Parameters
    ----------
    x : ndarray or scalar
        Input array.
    axis : None or int or tuple of ints, optional
        Axis or axes along which to flip over. The default,
        axis=None, will flip over all of the axes of the input array.
        If axis is negative it counts from the last to the first axis.

        If axis is a tuple of ints, flipping is performed on all of the axes
        specified in the tuple.
    out : ndarray or scalar, optional
        Alternative output array in which to place the result. It must have
        the same shape and type as the expected output.

    Returns
    -------
    out : ndarray or scalar
        A new array with the entries of `x` reversed along the given axes.
        Unlike official NumPy, the result is a copy, not a view.

    Examples
    --------
    >>> A = np.arange(8).reshape((2,2,2))
    >>> A
    array([[[0, 1],
            [2, 3]],
           [[4, 5],
            [6, 7]]])
    >>> np.flip(A, 0)
    array([[[4, 5],
            [6, 7]],
           [[0, 1],
            [2, 3]]])
    >>> np.flip(A, 1)
    array([[[2, 3],
            [0, 1]],
           [[6, 7],
            [4, 5]]])
    >>> np.flip(A)
    array([[[7, 6],
            [5, 4]],
           [[3, 2],
            [1, 0]]])
    >>> np.flip(A, (0, 2))
    array([[[5, 4],
            [7, 6]],
           [[1, 0],
            [3, 2]]])
    """
    # Imported lazily to avoid a circular import between ndarray and numpy modules.
    from ...numpy import ndarray
    if isinstance(x, numeric_types):
        # Python scalars are delegated to the official NumPy implementation.
        return _np.flip(x, axis)
    elif isinstance(x, ndarray):
        # MXNet ndarrays dispatch to the backend _npi.flip operator.
        return _npi.flip(x, axis, out=out)
    else:
        raise TypeError('type {} not supported'.format(str(type(x))))

python/mxnet/numpy/multiarray.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh',
5454
'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate',
5555
'stack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
56-
'ravel']
56+
'ravel', 'flip']
5757

5858
# Return code for dispatching indexing function call
5959
_NDARRAY_UNSUPPORTED_INDEXING = -1
@@ -4088,3 +4088,65 @@ def ravel(x, order='C'):
40884088
[1. 4. 2. 5. 3. 6.]
40894089
"""
40904090
return _mx_nd_np.ravel(x, order)
4091+
4092+
4093+
@set_module('mxnet.numpy')
def flip(x, axis=None, out=None):
    r"""
    flip(x, axis=None, out=None)

    Reverse the order of elements in an array along the given axis.

    The shape of the array is preserved, but the elements are reordered.

    Parameters
    ----------
    x : ndarray or scalar
        Input array.
    axis : None or int or tuple of ints, optional
        Axis or axes along which to flip over. The default,
        axis=None, will flip over all of the axes of the input array.
        If axis is negative it counts from the last to the first axis.

        If axis is a tuple of ints, flipping is performed on all of the axes
        specified in the tuple.
    out : ndarray or scalar, optional
        Alternative output array in which to place the result. It must have
        the same shape and type as the expected output.

    Returns
    -------
    out : ndarray or scalar
        A new array with the entries of `x` reversed along the given axes.
        Unlike official NumPy, the result is a copy, not a view.

    Examples
    --------
    >>> A = np.arange(8).reshape((2,2,2))
    >>> A
    array([[[0, 1],
            [2, 3]],
           [[4, 5],
            [6, 7]]])
    >>> np.flip(A, 0)
    array([[[4, 5],
            [6, 7]],
           [[0, 1],
            [2, 3]]])
    >>> np.flip(A, 1)
    array([[[2, 3],
            [0, 1]],
           [[6, 7],
            [4, 5]]])
    >>> np.flip(A)
    array([[[7, 6],
            [5, 4]],
           [[3, 2],
            [1, 0]]])
    >>> np.flip(A, (0, 2))
    array([[[5, 4],
            [7, 6]],
           [[1, 0],
            [3, 2]]])
    """
    # Thin front-end wrapper: all dispatch logic lives in the ndarray-level flip.
    return _mx_nd_np.flip(x, axis, out=out)

python/mxnet/symbol/numpy/_symbol.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
3737
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'mean',
3838
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
39-
'ravel']
39+
'ravel', 'flip']
4040

4141

4242
def _num_outputs(sym):
@@ -2818,4 +2818,42 @@ def ravel(x, order='C'):
28182818
raise TypeError('type {} not supported'.format(str(type(x))))
28192819

28202820

2821+
@set_module('mxnet.symbol.numpy')
def flip(x, axis=None, out=None):
    r"""
    flip(x, axis=None, out=None)

    Reverse the order of elements in an array along the given axis.

    The shape of the array is preserved, but the elements are reordered.

    Parameters
    ----------
    x : _Symbol or scalar
        Input array.
    axis : None or int or tuple of ints, optional
        Axis or axes along which to flip over. The default,
        axis=None, will flip over all of the axes of the input array.
        If axis is negative it counts from the last to the first axis.

        If axis is a tuple of ints, flipping is performed on all of the axes
        specified in the tuple.
    out : _Symbol or scalar, optional
        Alternative output array in which to place the result. It must have
        the same shape and type as the expected output.

    Returns
    -------
    out : _Symbol or scalar
        A symbol producing an array with the entries of `x` reversed along
        the given axes.
    """
    if isinstance(x, numeric_types):
        # Python scalars are evaluated eagerly via the official NumPy flip.
        return _np.flip(x, axis)
    elif isinstance(x, _Symbol):
        # Symbols build a graph node for the backend _npi.flip operator.
        return _npi.flip(x, axis, out=out)
    else:
        raise TypeError('type {} not supported'.format(str(type(x))))
2857+
2858+
28212859
_set_np_symbol_class(_Symbol)

src/operator/numpy/np_matrix_op-inl.h

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,71 @@ void NumpyTranspose(const nnvm::NodeAttrs& attrs,
6060
}
6161
}
6262

63+
// Operator parameter for _npi_flip: the set of axes to reverse.
// A tuple with ndim() == -1 encodes Python-side axis=None (flip every axis).
struct FlipParam : public dmlc::Parameter<FlipParam> {
  mxnet::Tuple<int> axis;
  DMLC_DECLARE_PARAMETER(FlipParam) {
    DMLC_DECLARE_FIELD(axis)
    .describe("The axis which to flip elements.");
  }
};

// Upper bound on the number of flipped axes accepted by the kernel.
#define FLIP_MAX_DIM 10
// Sentinel ndim() value of mxnet::Tuple that represents axis=None.
#define FLIP_MIN_DIM -1
74+
template<typename xpu>
75+
void NumpyFlipForwardImpl(const OpContext& ctx,
76+
const std::vector<TBlob>& inputs,
77+
const std::vector<TBlob>& outputs,
78+
const std::vector<index_t>& stride_,
79+
const std::vector<index_t>& trailing_,
80+
const index_t& flip_index);
81+
82+
template<typename xpu>
83+
void NumpyFlipForward(const nnvm::NodeAttrs& attrs,
84+
const OpContext& ctx,
85+
const std::vector<TBlob>& inputs,
86+
const std::vector<OpReqType>& req,
87+
const std::vector<TBlob>& outputs) {
88+
const FlipParam& param = nnvm::get<FlipParam>(attrs.parsed);
89+
mxnet::Tuple<int> axistemp;
90+
CHECK_EQ(inputs[0].type_flag_, outputs[0].type_flag_);
91+
CHECK_LT(param.axis.ndim(), FLIP_MAX_DIM);
92+
CHECK_GE(param.axis.ndim(), FLIP_MIN_DIM);
93+
if (param.axis.ndim() == FLIP_MIN_DIM) {
94+
if (inputs[0].shape_.ndim() == 0) {
95+
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
96+
MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
97+
mshadow::Copy(outputs[0].FlatTo1D<xpu, DType>(s), inputs[0].FlatTo1D<xpu, DType>(s), s);
98+
});
99+
return;
100+
}
101+
std::vector<int> temp;
102+
for (int i = 0; i < inputs[0].shape_.ndim(); i++) {
103+
temp.push_back(i);
104+
}
105+
axistemp.assign(temp.begin(), temp.end());
106+
} else {
107+
axistemp = param.axis;
108+
}
109+
110+
const mxnet::TShape& ishape = inputs[0].shape_;
111+
if (ishape.ProdShape(0, ishape.ndim()) == 0) {
112+
return; // zero shape
113+
}
114+
std::vector<index_t> stride_(axistemp.ndim());
115+
std::vector<index_t> trailing_(axistemp.ndim());
116+
index_t flip_index = 0;
117+
for (int axis : axistemp) {
118+
CHECK_LT(axis, ishape.ndim());
119+
stride_[flip_index] = ishape[axis];
120+
trailing_[flip_index] = 1;
121+
for (int i2 = axis + 1; i2 < ishape.ndim(); ++i2) {
122+
trailing_[flip_index] *= ishape[i2];
123+
}
124+
flip_index++;
125+
}
126+
NumpyFlipForwardImpl<xpu>(ctx, inputs, outputs, stride_, trailing_, flip_index);
127+
}
63128
} // namespace op
64129
} // namespace mxnet
65130

src/operator/numpy/np_matrix_op.cc

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,5 +346,51 @@ Examples::
346346
.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to stack")
347347
.add_arguments(StackParam::__FIELDS__());
348348

349+
// CPU specialization: stride_/trailing_ already live in host memory, so the
// kernel can read the vectors' buffers directly (no staging copy needed).
template<>
void NumpyFlipForwardImpl<cpu>(const OpContext& ctx,
                               const std::vector<TBlob>& inputs,
                               const std::vector<TBlob>& outputs,
                               const std::vector<index_t>& stride_,
                               const std::vector<index_t>& trailing_,
                               const index_t& flip_index) {
  mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
    mxnet_op::Kernel<reverse, cpu>::Launch(s, inputs[0].Size(), flip_index,
      inputs[0].dptr<DType>(), outputs[0].dptr<DType>(),
      stride_.data(), trailing_.data());
  });
}

DMLC_REGISTER_PARAMETER(FlipParam);

// Forward operator registration: 1 input -> 1 output, same shape and dtype.
NNVM_REGISTER_OP(_npi_flip)
.set_num_outputs(1)
.set_num_inputs(1)
.set_attr_parser(ParamParser<FlipParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
  [](const NodeAttrs& attrs) {
    return std::vector<std::string> {"data"};
  })
.set_attr<FResourceRequest>("FResourceRequest",
  [](const NodeAttrs& attrs) {
    return std::vector<ResourceRequest> {ResourceRequest::kTempSpace};
  })
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FCompute>("FCompute<cpu>", NumpyFlipForward<cpu>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_npi_flip"})
.add_argument("data", "NDArray-or-Symbol", "Input data array")
.add_arguments(FlipParam::__FIELDS__());

// Backward reuses the forward kernel: flipping the output gradient along the
// same axes yields the input gradient.
NNVM_REGISTER_OP(_backward_npi_flip)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser(ParamParser<FlipParam>)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FResourceRequest>("FResourceRequest",
  [](const NodeAttrs& attrs) {
    return std::vector<ResourceRequest> {ResourceRequest::kTempSpace};
  })
.set_attr<FCompute>("FCompute<cpu>", NumpyFlipForward<cpu>);
349395
} // namespace op
350396
} // namespace mxnet

src/operator/numpy/np_matrix_op.cu

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,5 +46,39 @@ NNVM_REGISTER_OP(_backward_np_concat)
4646
NNVM_REGISTER_OP(_npi_stack)
4747
.set_attr<FCompute>("FCompute<gpu>", StackOpForward<gpu>);
4848

49+
template<>
50+
void NumpyFlipForwardImpl<gpu>(const OpContext& ctx,
51+
const std::vector<TBlob>& inputs,
52+
const std::vector<TBlob>& outputs,
53+
const std::vector<index_t>& stride_,
54+
const std::vector<index_t>& trailing_,
55+
const index_t& flip_index) {
56+
mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
57+
mshadow::Tensor<gpu, 1, uint8_t> workspace =
58+
ctx.requested[0].get_space_typed<gpu, 1, uint8_t>(
59+
mshadow::Shape1(flip_index * sizeof(index_t) * 2), s);
60+
61+
auto stride_workspace = workspace.dptr_;
62+
auto trailing_workspace = workspace.dptr_ + flip_index * sizeof(index_t);
63+
64+
cudaMemcpyAsync(stride_workspace, thrust::raw_pointer_cast(stride_.data()),
65+
stride_.size() * sizeof(index_t),
66+
cudaMemcpyHostToDevice, mshadow::Stream<gpu>::GetStream(s));
67+
cudaMemcpyAsync(trailing_workspace, thrust::raw_pointer_cast(trailing_.data()),
68+
trailing_.size() * sizeof(index_t),
69+
cudaMemcpyHostToDevice, mshadow::Stream<gpu>::GetStream(s));
70+
71+
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
72+
mxnet_op::Kernel<reverse, gpu>::Launch(s, inputs[0].Size(), flip_index,
73+
inputs[0].dptr<DType>(), outputs[0].dptr<DType>(),
74+
reinterpret_cast<index_t*>(stride_workspace), reinterpret_cast<index_t*>(trailing_workspace));
75+
});
76+
}
77+
78+
NNVM_REGISTER_OP(_npi_flip)
79+
.set_attr<FCompute>("FCompute<gpu>", NumpyFlipForward<gpu>);
80+
81+
NNVM_REGISTER_OP(_backward_npi_flip)
82+
.set_attr<FCompute>("FCompute<gpu>", NumpyFlipForward<gpu>);
4983
} // namespace op
5084
} // namespace mxnet

0 commit comments

Comments
 (0)