Skip to content

Commit 15181fe

Browse files
tingying2020drivanov
authored andcommitted
numpy operator arctan2 (apache#15890)
* change the test code * add @use_np in test code * only support float16, float32 and float64. * fix format error * remove redundant backslash * change wrapper in symbol * delete gpu test * edit test * change infer type * remove redundant **kwargs * change atol and rtol in test * edit test shape
1 parent f3cdf49 commit 15181fe

File tree

9 files changed

+420
-8
lines changed

9 files changed

+420
-8
lines changed

python/mxnet/ndarray/numpy/_op.py

Lines changed: 89 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@
2828
from ..ndarray import NDArray
2929

3030
__all__ = ['zeros', 'ones', 'full', 'add', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power',
31-
'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', 'absolute',
32-
'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p',
33-
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
31+
'arctan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs',
32+
'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2',
33+
'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
3434
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
3535
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean',
3636
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
@@ -2953,3 +2953,89 @@ def around(x, decimals=0, out=None, **kwargs):
29532953
return _npi.around(x, decimals, out=out, **kwargs)
29542954
else:
29552955
raise TypeError('type {} not supported'.format(str(type(x))))
2956+
2957+
2958+
@set_module('mxnet.ndarray.numpy')
2959+
def arctan2(x1, x2, out=None):
2960+
r"""
2961+
arctan2(x1, x2, out=None)
2962+
2963+
Element-wise arc tangent of ``x1/x2`` choosing the quadrant correctly.
2964+
2965+
The quadrant (i.e., branch) is chosen so that ``arctan2(x1, x2)`` is
2966+
the signed angle in radians between the ray ending at the origin and
2967+
passing through the point (1,0), and the ray ending at the origin and
2968+
passing through the point (`x2`, `x1`). (Note the role reversal: the
2969+
"`y`-coordinate" is the first function parameter, the "`x`-coordinate"
2970+
is the second.) By IEEE convention, this function is defined for
2971+
`x2` = +/-0 and for either or both of `x1` and `x2` = +/-inf (see
2972+
Notes for specific values).
2973+
2974+
This function is not defined for complex-valued arguments; for the
2975+
so-called argument of complex values, use `angle`.
2976+
2977+
Parameters
2978+
----------
2979+
x1 : ndarray or scalar
2980+
`y`-coordinates.
2981+
x2 : ndarray or scalar
2982+
`x`-coordinates. `x2` must be broadcastable to match the shape of
2983+
`x1` or vice versa.
2984+
out : ndarray or None, optional
2985+
A location into which the result is stored. If provided, it must have
2986+
a shape that the inputs broadcast to. If not provided or `None`,
2987+
a freshly-allocated array is returned.
2988+
2989+
Returns
2990+
-------
2991+
out : ndarray or scalar
2992+
Array of angles in radians, in the range ``[-pi, pi]``. This is a scalar if
2993+
`x1` and `x2` are scalars.
2994+
2995+
Notes
2996+
-----
2997+
*arctan2* is identical to the `atan2` function of the underlying
2998+
C library. The following special values are defined in the C
2999+
standard: [1]_
3000+
3001+
====== ====== ================
3002+
`x1` `x2` `arctan2(x1,x2)`
3003+
====== ====== ================
3004+
+/- 0 +0 +/- 0
3005+
+/- 0 -0 +/- pi
3006+
> 0 +/-inf +0 / +pi
3007+
< 0 +/-inf -0 / -pi
3008+
+/-inf +inf +/- (pi/4)
3009+
+/-inf -inf +/- (3*pi/4)
3010+
====== ====== ================
3011+
3012+
Note that +0 and -0 are distinct floating point numbers, as are +inf
3013+
and -inf.
3014+
3015+
This function differs from the original numpy.arange in the following aspects:
3016+
- Only support float16, float32 and float64.
3017+
3018+
References
3019+
----------
3020+
.. [1] ISO/IEC standard 9899:1999, "Programming language C."
3021+
3022+
Examples
3023+
--------
3024+
Consider four points in different quadrants:
3025+
3026+
>>> x = np.array([-1, +1, +1, -1])
3027+
>>> y = np.array([-1, -1, +1, +1])
3028+
>>> np.arctan2(y, x) * 180 / np.pi
3029+
array([-135., -45., 45., 135.])
3030+
3031+
Note the order of the parameters. `arctan2` is defined also when `x2` = 0
3032+
and at several other special points, obtaining values in
3033+
the range ``[-pi, pi]``:
3034+
3035+
>>> x = np.array([1, -1])
3036+
>>> y = np.array([0, 0])
3037+
>>> np.arctan2(x, y)
3038+
array([ 1.5707964, -1.5707964])
3039+
"""
3040+
return _ufunc_helper(x1, x2, _npi.arctan2, _np.arctan2,
3041+
_npi.arctan2_scalar, _npi.rarctan2_scalar, out=out)

python/mxnet/numpy/multiarray.py

Lines changed: 88 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,13 @@
4747
from ..ndarray.numpy import _internal as _npi
4848

4949
__all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'full', 'add', 'subtract', 'multiply', 'divide',
50-
'mod', 'remainder', 'power', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt',
51-
'abs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log',
50+
'mod', 'remainder', 'power', 'arctan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10',
51+
'sqrt', 'cbrt', 'abs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log',
5252
'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative',
5353
'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh',
5454
'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate',
5555
'stack', 'vstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices',
56-
'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around']
56+
'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'arctan2']
5757

5858
# Return code for dispatching indexing function call
5959
_NDARRAY_UNSUPPORTED_INDEXING = -1
@@ -4481,3 +4481,88 @@ def around(x, decimals=0, out=None, **kwargs):
44814481
array([ 0, 0, 0, 10])
44824482
"""
44834483
return _mx_nd_np.around(x, decimals, out=out, **kwargs)
4484+
4485+
4486+
@set_module('mxnet.numpy')
4487+
def arctan2(x1, x2, out=None):
4488+
r"""
4489+
arctan2(x1, x2, out=None)
4490+
4491+
Element-wise arc tangent of ``x1/x2`` choosing the quadrant correctly.
4492+
4493+
The quadrant (i.e., branch) is chosen so that ``arctan2(x1, x2)`` is
4494+
the signed angle in radians between the ray ending at the origin and
4495+
passing through the point (1,0), and the ray ending at the origin and
4496+
passing through the point (`x2`, `x1`). (Note the role reversal: the
4497+
"`y`-coordinate" is the first function parameter, the "`x`-coordinate"
4498+
is the second.) By IEEE convention, this function is defined for
4499+
`x2` = +/-0 and for either or both of `x1` and `x2` = +/-inf (see
4500+
Notes for specific values).
4501+
4502+
This function is not defined for complex-valued arguments; for the
4503+
so-called argument of complex values, use `angle`.
4504+
4505+
Parameters
4506+
----------
4507+
x1 : ndarray or scalar
4508+
`y`-coordinates.
4509+
x2 : ndarray or scalar
4510+
`x`-coordinates. `x2` must be broadcastable to match the shape of
4511+
`x1` or vice versa.
4512+
out : ndarray or None, optional
4513+
A location into which the result is stored. If provided, it must have
4514+
a shape that the inputs broadcast to. If not provided or `None`,
4515+
a freshly-allocated array is returned.
4516+
4517+
Returns
4518+
-------
4519+
out : ndarray or scalar
4520+
Array of angles in radians, in the range ``[-pi, pi]``. This is a scalar if
4521+
`x1` and `x2` are scalars.
4522+
4523+
Notes
4524+
-----
4525+
*arctan2* is identical to the `atan2` function of the underlying
4526+
C library. The following special values are defined in the C
4527+
standard: [1]_
4528+
4529+
====== ====== ================
4530+
`x1` `x2` `arctan2(x1,x2)`
4531+
====== ====== ================
4532+
+/- 0 +0 +/- 0
4533+
+/- 0 -0 +/- pi
4534+
> 0 +/-inf +0 / +pi
4535+
< 0 +/-inf -0 / -pi
4536+
+/-inf +inf +/- (pi/4)
4537+
+/-inf -inf +/- (3*pi/4)
4538+
====== ====== ================
4539+
4540+
Note that +0 and -0 are distinct floating point numbers, as are +inf
4541+
and -inf.
4542+
4543+
This function differs from the original numpy.arange in the following aspects:
4544+
- Only support float16, float32 and float64.
4545+
4546+
References
4547+
----------
4548+
.. [1] ISO/IEC standard 9899:1999, "Programming language C."
4549+
4550+
Examples
4551+
--------
4552+
Consider four points in different quadrants:
4553+
4554+
>>> x = np.array([-1, +1, +1, -1])
4555+
>>> y = np.array([-1, -1, +1, +1])
4556+
>>> np.arctan2(y, x) * 180 / np.pi
4557+
array([-135., -45., 45., 135.])
4558+
4559+
Note the order of the parameters. `arctan2` is defined also when `x2` = 0
4560+
and at several other special points, obtaining values in
4561+
the range ``[-pi, pi]``:
4562+
4563+
>>> x = np.array([1, -1])
4564+
>>> y = np.array([0, 0])
4565+
>>> np.arctan2(x, y)
4566+
array([ 1.5707964, -1.5707964])
4567+
"""
4568+
return _mx_nd_np.arctan2(x1, x2, out=out)

python/mxnet/symbol/numpy/_symbol.py

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@
2929
from .._internal import _set_np_symbol_class
3030
from . import _internal as _npi
3131

32-
__all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power', 'sin',
33-
'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', 'absolute', 'exp',
32+
__all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power', 'arctan2',
33+
'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', 'absolute', 'exp',
3434
'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p',
3535
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
3636
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
@@ -3172,4 +3172,72 @@ def around(x, decimals=0, out=None, **kwargs):
31723172
raise TypeError('type {} not supported'.format(str(type(x))))
31733173

31743174

3175+
@set_module('mxnet.symbol.numpy')
3176+
def arctan2(x1, x2, out=None):
3177+
r"""
3178+
arctan2(x1, x2, out=None)
3179+
3180+
Element-wise arc tangent of ``x1/x2`` choosing the quadrant correctly.
3181+
3182+
The quadrant (i.e., branch) is chosen so that ``arctan2(x1, x2)`` is
3183+
the signed angle in radians between the ray ending at the origin and
3184+
passing through the point (1,0), and the ray ending at the origin and
3185+
passing through the point (`x2`, `x1`). (Note the role reversal: the
3186+
"`y`-coordinate" is the first function parameter, the "`x`-coordinate"
3187+
is the second.) By IEEE convention, this function is defined for
3188+
`x2` = +/-0 and for either or both of `x1` and `x2` = +/-inf (see
3189+
Notes for specific values).
3190+
3191+
This function is not defined for complex-valued arguments; for the
3192+
so-called argument of complex values, use `angle`.
3193+
3194+
Parameters
3195+
----------
3196+
x1 : _Symbol or scalar
3197+
`y`-coordinates.
3198+
x2 : _Symbol or scalar
3199+
`x`-coordinates. `x2` must be broadcastable to match the shape of
3200+
`x1` or vice versa.
3201+
out : _Symbol or None, optional
3202+
A location into which the result is stored. If provided, it must have
3203+
a shape that the inputs broadcast to. If not provided or `None`,
3204+
a freshly-allocated array is returned.
3205+
3206+
Returns
3207+
-------
3208+
out : _Symbol or scalar
3209+
Array of angles in radians, in the range ``[-pi, pi]``. This is a scalar if
3210+
`x1` and `x2` are scalars.
3211+
3212+
Notes
3213+
-----
3214+
*arctan2* is identical to the `atan2` function of the underlying
3215+
C library. The following special values are defined in the C
3216+
standard: [1]_
3217+
3218+
====== ====== ================
3219+
`x1` `x2` `arctan2(x1,x2)`
3220+
====== ====== ================
3221+
+/- 0 +0 +/- 0
3222+
+/- 0 -0 +/- pi
3223+
> 0 +/-inf +0 / +pi
3224+
< 0 +/-inf -0 / -pi
3225+
+/-inf +inf +/- (pi/4)
3226+
+/-inf -inf +/- (3*pi/4)
3227+
====== ====== ================
3228+
3229+
Note that +0 and -0 are distinct floating point numbers, as are +inf
3230+
and -inf.
3231+
3232+
This function differs from the original numpy.arange in the following aspects:
3233+
- Only support float16, float32 and float64.
3234+
3235+
References
3236+
----------
3237+
.. [1] ISO/IEC standard 9899:1999, "Programming language C."
3238+
"""
3239+
return _ufunc_helper(x1, x2, _npi.arctan2, _np.arctan2,
3240+
_npi.arctan2_scalar, _npi.rarctan2_scalar, out=out)
3241+
3242+
31753243
_set_np_symbol_class(_Symbol)

src/operator/math_functions-inl.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,8 @@ MXNET_BINARY_MATH_FUNC(hypot)
125125

126126
MXNET_BINARY_MATH_FUNC(pow)
127127

128+
MXNET_BINARY_MATH_FUNC(atan2)
129+
128130
template<typename DType> MSHADOW_XINLINE
129131
float id(DType a) {
130132
return static_cast<float>(a);

src/operator/mshadow_op.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,16 @@ MXNET_BINARY_MATH_OP(rpower, math::pow(b, a));
322322

323323
MXNET_BINARY_MATH_OP(rpower_grad, math::id(a) * math::log(b));
324324

325+
MXNET_BINARY_MATH_OP(arctan2, math::atan2(a, b));
326+
327+
MXNET_BINARY_MATH_OP(arctan2_grad, math::id(b) / (math::id(a * a + b * b)));
328+
329+
MXNET_BINARY_MATH_OP(arctan2_rgrad, -math::id(a) / (math::id(a * a + b * b)));
330+
331+
MXNET_BINARY_MATH_OP(rarctan2, math::atan2(b, a));
332+
333+
MXNET_BINARY_MATH_OP(rarctan2_grad, math::id(a) / (math::id(a * a + b * b)));
334+
325335
MXNET_UNARY_MATH_OP_NC(nt, a != DType(0) ? DType(0) : DType(1));
326336

327337
MXNET_BINARY_MATH_OP_NC(ge, a >= b ? DType(1) : DType(0));

src/operator/numpy/np_elemwise_broadcast_op.cc

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,5 +144,75 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_backward_npi_rcopysign_scalar)
144144
.set_attr<FCompute>("FCompute<cpu>",
145145
BinaryScalarOp::Backward<cpu, mshadow_op::rcopysign_grad>);
146146

147+
inline bool IsFloatType(const int dtype) {
148+
return (dtype == mshadow::kFloat16 ||
149+
dtype == mshadow::kFloat32 ||
150+
dtype == mshadow::kFloat64);
151+
}
152+
153+
inline bool Arctan2OpType(const nnvm::NodeAttrs& attrs,
154+
std::vector<int>* in_attrs,
155+
std::vector<int>* out_attrs) {
156+
CHECK_EQ(in_attrs->size(), 2U);
157+
CHECK_EQ(out_attrs->size(), 1U);
158+
159+
TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
160+
TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(1));
161+
TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
162+
TYPE_ASSIGN_CHECK(*in_attrs, 1, out_attrs->at(0));
163+
// check if it is float16, float32 or float64. If not, raise error.
164+
CHECK(IsFloatType(in_attrs->at(0))) << "Do not support `int` as input.\n";
165+
return out_attrs->at(0) != -1;
166+
}
167+
168+
NNVM_REGISTER_OP(_npi_arctan2)
169+
.set_num_inputs(2)
170+
.set_num_outputs(1)
171+
.set_attr<nnvm::FListInputNames>("FListInputNames",
172+
[](const NodeAttrs& attrs) {
173+
return std::vector<std::string>{"x1", "x2"};
174+
})
175+
.set_attr<mxnet::FInferShape>("FInferShape", BinaryBroadcastShape)
176+
.set_attr<nnvm::FInferType>("FInferType", Arctan2OpType)
177+
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::arctan2>)
178+
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_arctan2"})
179+
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
180+
[](const NodeAttrs& attrs) {
181+
return std::vector<std::pair<int, int> >{{0, 0}};
182+
})
183+
.add_argument("x1", "NDArray-or-Symbol", "The input array")
184+
.add_argument("x2", "NDArray-or-Symbol", "The input array");
185+
186+
NNVM_REGISTER_OP(_backward_npi_arctan2)
187+
.set_num_inputs(3)
188+
.set_num_outputs(2)
189+
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
190+
.set_attr<FResourceRequest>("FResourceRequest",
191+
[](const NodeAttrs& attrs) {
192+
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
193+
})
194+
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastBackwardUseIn<cpu, mshadow_op::arctan2_grad,
195+
mshadow_op::arctan2_rgrad>);
196+
197+
MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_arctan2_scalar)
198+
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::arctan2>)
199+
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_arctan2_scalar"});
200+
201+
MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rarctan2_scalar)
202+
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::rarctan2>)
203+
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_rarctan2_scalar"});
204+
205+
MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_arctan2_scalar)
206+
.add_argument("scalar", "float", "scalar value")
207+
.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); })
208+
.set_attr<FCompute>("FCompute<cpu>",
209+
BinaryScalarOp::Backward<cpu, mshadow_op::arctan2_grad>);
210+
211+
MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_rarctan2_scalar)
212+
.add_argument("scalar", "float", "scalar value")
213+
.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); })
214+
.set_attr<FCompute>("FCompute<cpu>",
215+
BinaryScalarOp::Backward<cpu, mshadow_op::arctan2_rgrad>);
216+
147217
} // namespace op
148218
} // namespace mxnet

0 commit comments

Comments
 (0)