
Commit 144eea4

numpy compatible max min

1 parent 47f8ceb commit 144eea4

File tree: 4 files changed, +284 −0 lines changed
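For context, this is the user-facing behaviour the commit adds, sketched from the calls the new tests below make (mx.np.max / mx.np.min on np-ndarrays); treat it as an illustration, not API documentation:

    import mxnet as mx
    a = mx.nd.array([[1, 2, 3], [4, 5, 6]]).as_np_ndarray()
    mx.np.max(a)                         # 6.0, global reduction
    mx.np.min(a, axis=0)                 # [1. 2. 3.], reduce along the first axis
    mx.np.max(a, axis=1, keepdims=True)  # [[3.] [6.]], reduced axis kept with size 1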

src/operator/numpy/np_broadcast_reduce_op.h

Lines changed: 93 additions & 0 deletions
@@ -64,6 +64,24 @@ struct NumpyReduceAxesParam : public dmlc::Parameter<NumpyReduceAxesParam> {
   }
 };
 
+struct NumpyReduceAxesNoDTypeParam : public dmlc::Parameter<NumpyReduceAxesNoDTypeParam> {
+  dmlc::optional<mxnet::Tuple<int>> axis;
+  bool keepdims;
+  dmlc::optional<double> initial;
+  DMLC_DECLARE_PARAMETER(NumpyReduceAxesNoDTypeParam) {
+    DMLC_DECLARE_FIELD(axis)
+      .set_default(dmlc::optional<mxnet::Tuple<int>>())
+      .describe("Axis or axes along which the reduction is performed. The default, axis=None, "
+                "reduces all of the elements of the input array. If axis is negative it counts "
+                "from the last to the first axis.");
+    DMLC_DECLARE_FIELD(keepdims).set_default(false)
+      .describe("If this is set to `True`, the reduced axes are left "
+                "in the result as dimensions with size one.");
+    DMLC_DECLARE_FIELD(initial).set_default(dmlc::optional<double>())
+      .describe("Starting value for the reduction.");
+  }
+};
+
 inline TShape NumpyReduceAxesShapeImpl(const TShape& ishape,
                                        const dmlc::optional<mxnet::Tuple<int>>& axis,
                                        bool keepdims) {
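The three fields mirror numpy's amax/amin signature (axis may be an int or a tuple, negative axes count from the end, keepdims preserves reduced dimensions as size 1). A plain-numpy illustration of the semantics the descriptions promise:

    import numpy as np
    a = np.arange(24).reshape(2, 3, 4)
    np.max(a, axis=(0, 2), keepdims=True).shape  # (1, 3, 1): axis may be a tuple
    np.max(a, axis=-1).shape                     # (2, 3): negative axis counts from the end
    np.max(a, axis=()).shape                     # (2, 3, 4): an empty tuple reduces nothing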
@@ -152,6 +170,39 @@ inline bool NumpyReduceAxesShape(const nnvm::NodeAttrs& attrs,
   return shape_is_known(out_attrs->at(0));
 }
 
+inline bool NumpyReduceAxesNoDTypeShape(const nnvm::NodeAttrs& attrs,
+                                        std::vector<TShape> *in_attrs,
+                                        std::vector<TShape> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  if (!shape_is_known(in_attrs->at(0))) {
+    return false;
+  }
+  const NumpyReduceAxesNoDTypeParam& param = nnvm::get<NumpyReduceAxesNoDTypeParam>(attrs.parsed);
+  // every axis being reduced must have a non-zero size
+  bool is_all_reduced_axes_not_zero = true;
+  const TShape& ishape = (*in_attrs)[0];
+  if (param.axis.has_value()) {
+    const mxnet::Tuple<int>& axes = param.axis.value();
+    for (int i = 0; i < axes.ndim(); ++i) {
+      if (ishape[axes[i]] == 0) {
+        is_all_reduced_axes_not_zero = false;
+        break;
+      }
+    }
+  } else {
+    if (ishape.Size() == 0) {
+      // a global reduction is only defined when the input has at least one element
+      is_all_reduced_axes_not_zero = false;
+    }
+  }
+  CHECK(is_all_reduced_axes_not_zero)
+    << "zero-size array to reduction operation maximum which has no identity";
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0,
+                     NumpyReduceAxesShapeImpl((*in_attrs)[0], param.axis, param.keepdims));
+  return shape_is_known(out_attrs->at(0));
+}
+
 template<bool safe_acc_hint = false>
 inline bool NeedSafeAcc(int itype, int otype) {
   bool rule = (itype != otype) || (itype != mshadow::kFloat32 && itype != mshadow::kFloat64);
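The CHECK above matches numpy's error for reductions over zero-size data, which have no identity element for max/min. For reference, in plain numpy:

    import numpy as np
    x = np.ones((2, 0))
    np.max(x, axis=0)   # OK: result has shape (0,); only a non-reduced axis is empty
    try:
        np.max(x)       # reducing over a zero-size axis is rejected
    except ValueError as e:
        print(e)        # zero-size array to reduction operation maximum which has no identity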
@@ -186,6 +237,30 @@ void NumpyReduceAxesCompute(const nnvm::NodeAttrs& attrs,
   }
 }
 
+template<typename xpu, typename reducer, typename OP = op::mshadow_op::identity>
+void NumpyReduceAxesNoDTypeCompute(const nnvm::NodeAttrs& attrs,
+                                   const OpContext& ctx,
+                                   const std::vector<TBlob>& inputs,
+                                   const std::vector<OpReqType>& req,
+                                   const std::vector<TBlob>& outputs) {
+  const NumpyReduceAxesNoDTypeParam& param = nnvm::get<NumpyReduceAxesNoDTypeParam>(attrs.parsed);
+  if (param.initial.has_value()) {
+    LOG(FATAL) << "initial is not supported yet";
+  }
+  if (inputs[0].shape_.Size() == 0U || outputs[0].shape_.Size() == 0U) return;  // zero-size tensor
+  if (param.axis.has_value() && param.axis.value().ndim() == 0) {
+    UnaryOp::IdentityCompute<xpu>(attrs, ctx, inputs, req, outputs);
+    return;  // axis=() reduces nothing; the identity copy is the full result
+  }
+  TShape small;
+  if (param.keepdims) {
+    small = outputs[0].shape_;
+  } else {
+    small = NumpyReduceAxesShapeImpl(inputs[0].shape_, param.axis, true);
+  }
+  ReduceAxesComputeImpl<xpu, reducer, false, false, OP>(ctx, inputs, req, outputs, small);
+}
+
 template<typename xpu, bool normalize = false>
 inline void NumpyReduceAxesBackwardUseNone(const nnvm::NodeAttrs& attrs,
                                            const OpContext& ctx,
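The kernel always reduces into the keepdims-style `small` shape; dropping the unit axes afterwards is only a reshape of the same buffer. A minimal numpy mirror of that shape handling (an illustration, not the MXNet kernel):

    import numpy as np

    def reduce_max(x, axis=None, keepdims=False):
        small = np.maximum.reduce(x, axis=axis, keepdims=True)  # reduce into the "small" shape
        return small if keepdims else np.squeeze(small, axis=axis)

    x = np.arange(24).reshape(2, 3, 4)
    assert reduce_max(x, axis=(0, 2)).shape == (3,)
    assert reduce_max(x, axis=1, keepdims=True).shape == (2, 1, 4)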
@@ -273,6 +348,24 @@ void NumpyBroadcastToBackward(const nnvm::NodeAttrs& attrs,
   }
 }
 
+template<typename xpu, typename OP>
+void NumpyReduceAxesNoDTypeBackward(const nnvm::NodeAttrs& attrs,
+                                    const OpContext& ctx,
+                                    const std::vector<TBlob>& inputs,
+                                    const std::vector<OpReqType>& req,
+                                    const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  const NumpyReduceAxesNoDTypeParam& param = nnvm::get<NumpyReduceAxesNoDTypeParam>(attrs.parsed);
+  TShape small;
+  if (param.keepdims) {
+    small = inputs[0].shape_;
+  } else {
+    small = NumpyReduceAxesShapeImpl(outputs[0].shape_, param.axis, true);
+  }
+  ReduceAxesBackwardUseInOutImpl<xpu, OP, false>(ctx, small, inputs, req, outputs);
+}
+
 }  // namespace op
 }  // namespace mxnet
 #endif  // MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_H_
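Instantiated with OP = mshadow_op::eq (see the registrations below), ReduceAxesBackwardUseInOutImpl broadcasts the stored forward output back against the input and passes gradient only where the two are equal. A numpy sketch of that rule, assuming a single int axis for brevity (note that ties receive gradient at every tied element):

    import numpy as np

    def max_backward(grad_out, x, y, axis, keepdims=False):
        if not keepdims:
            grad_out = np.expand_dims(grad_out, axis)  # restore the reduced axis
            y = np.expand_dims(y, axis)
        return grad_out * (x == y)                     # mshadow_op::eq as a 0/1 mask

    x = np.array([[1., 5.], [7., 2.]])
    y = np.max(x, axis=0)
    print(max_backward(np.ones_like(y), x, y, axis=0))
    # [[0. 1.]
    #  [1. 0.]]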

src/operator/numpy/np_broadcast_reduce_op_value.cc

Lines changed: 66 additions & 0 deletions
@@ -29,6 +29,7 @@ namespace mxnet {
 namespace op {
 
 DMLC_REGISTER_PARAMETER(NumpyReduceAxesParam);
+DMLC_REGISTER_PARAMETER(NumpyReduceAxesNoDTypeParam);
 
 inline bool NumpySumType(const nnvm::NodeAttrs& attrs,
                          std::vector<int> *in_attrs,
@@ -74,6 +75,71 @@ NNVM_REGISTER_OP(_backward_np_sum)
 .set_num_inputs(1)
 .set_attr<FCompute>("FCompute<cpu>", NumpyReduceAxesBackwardUseNone<cpu>);
 
+inline bool NumpyReduceAxesNoDTypeType(const nnvm::NodeAttrs& attrs,
+                                       std::vector<int> *in_attrs,
+                                       std::vector<int> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
+  TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
+
+  return out_attrs->at(0) != -1 && in_attrs->at(0) != -1;
+}
+
+NNVM_REGISTER_OP(_np_max)
+.describe(R"code()code" ADD_FILELINE)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<NumpyReduceAxesNoDTypeParam>)
+.set_attr<mxnet::FInferShape>("FInferShape", NumpyReduceAxesNoDTypeShape)
+.set_attr<nnvm::FInferType>("FInferType", NumpyReduceAxesNoDTypeType)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"a"};
+  })
+.add_argument("a", "NDArray-or-Symbol", "The input")
+.add_arguments(NumpyReduceAxesNoDTypeParam::__FIELDS__())
+.set_attr<FCompute>("FCompute<cpu>", NumpyReduceAxesNoDTypeCompute<cpu, mshadow::red::maximum>)
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
+.set_attr<nnvm::FGradient>("FGradient", ReduceGrad{"_backward_np_max"});
+
+NNVM_REGISTER_OP(_backward_np_max)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<NumpyReduceAxesNoDTypeParam>)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_num_inputs(3)
+.set_attr<FCompute>("FCompute<cpu>", NumpyReduceAxesNoDTypeBackward<cpu, mshadow_op::eq>);
+
+NNVM_REGISTER_OP(_np_min)
+.describe(R"code()code" ADD_FILELINE)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<NumpyReduceAxesNoDTypeParam>)
+.set_attr<mxnet::FInferShape>("FInferShape", NumpyReduceAxesNoDTypeShape)
+.set_attr<nnvm::FInferType>("FInferType", NumpyReduceAxesNoDTypeType)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"a"};
+  })
+.add_argument("a", "NDArray-or-Symbol", "The input")
+.add_arguments(NumpyReduceAxesNoDTypeParam::__FIELDS__())
+.set_attr<FCompute>("FCompute<cpu>", NumpyReduceAxesNoDTypeCompute<cpu, mshadow::red::minimum>)
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
+.set_attr<nnvm::FGradient>("FGradient", ReduceGrad{"_backward_np_min"});
+
 NNVM_REGISTER_OP(_np_prod)
 .set_num_inputs(1)
 .set_num_outputs(1)
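NumpyReduceAxesNoDTypeType simply mirrors the dtype between input and output in both directions, which is why these ops carry no dtype parameter: numpy's max/min preserve the input dtype, unlike sum. In plain numpy:

    import numpy as np
    x = np.arange(6, dtype=np.int32).reshape(2, 3)
    assert x.max(axis=0).dtype == np.int32              # max preserves the input dtype
    assert x.sum(dtype=np.float64).dtype == np.float64  # sum, by contrast, accepts dtype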

src/operator/numpy/np_broadcast_reduce_op_value.cu

Lines changed: 12 additions & 0 deletions
@@ -32,6 +32,18 @@ NNVM_REGISTER_OP(_np_sum)
 NNVM_REGISTER_OP(_backward_np_sum)
 .set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesBackwardUseNone<gpu>);
 
+NNVM_REGISTER_OP(_np_max)
+.set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesNoDTypeCompute<gpu, mshadow::red::maximum>);
+
+NNVM_REGISTER_OP(_backward_np_max)
+.set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesNoDTypeBackward<gpu, mshadow_op::eq>);
+
+NNVM_REGISTER_OP(_np_min)
+.set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesNoDTypeCompute<gpu, mshadow::red::minimum>);
+
+NNVM_REGISTER_OP(_backward_np_min)
+.set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesNoDTypeBackward<gpu, mshadow_op::eq>);
+
 NNVM_REGISTER_OP(_np_prod)
 .set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesCompute<gpu, mshadow_op::product, true>);

tests/python/unittest/test_numpy_op.py

Lines changed: 113 additions & 0 deletions
@@ -277,6 +277,119 @@ def is_int(dtype):
     assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)
 
 
+@with_seed()
+@use_np
+def test_np_max_min():
+    class TestMax(HybridBlock):
+        def __init__(self, axis=None, keepdims=False):
+            super(TestMax, self).__init__()
+            self._axis = axis
+            self._keepdims = keepdims
+
+        def hybrid_forward(self, F, a, *args, **kwargs):
+            return F.np.max(a, axis=self._axis, keepdims=self._keepdims)
+
+    class TestMin(HybridBlock):
+        def __init__(self, axis=None, keepdims=False):
+            super(TestMin, self).__init__()
+            self._axis = axis
+            self._keepdims = keepdims
+
+        def hybrid_forward(self, F, a, *args, **kwargs):
+            return F.np.min(a, axis=self._axis, keepdims=self._keepdims)
+
+    def is_int(dtype):
+        return 'int' == dtype
+
+    def get_grad(axis, func_name):
+        index = -1 if func_name == 'max' else 0
+        if axis == ():
+            return _np.ones((2, 3, 4, 5))
+        else:
+            temp = _np.zeros((2, 3, 4, 5))
+            if axis == 0:
+                temp[index, :, :, :] = 1
+                return temp
+            elif axis == 1:
+                temp[:, index, :, :] = 1
+                return temp
+            elif axis == 2:
+                temp[:, :, index, :] = 1
+                return temp
+            elif axis == 3:
+                temp[:, :, :, index] = 1
+                return temp
+            elif not axis:
+                temp[index, index, index, index] = 1
+                return temp
+            raise ValueError('axis should be int or None or ()')
+
+    def _test_np_exception(func, shape, dim):
+        x = _np.random.uniform(-1.0, 1.0, shape)
+        x = mx.nd.array(x).as_np_ndarray()
+        if func == 'max':
+            out = mx.np.max(x)
+        else:
+            out = mx.np.min(x)
+        assert out.ndim == dim, 'dimension mismatch, out.ndim={}, dim={}'.format(out.ndim, dim)
+
+    in_data_dim = random.choice([2, 3, 4])
+    shape = rand_shape_nd(in_data_dim, dim=3)
+    for func in ['max', 'min']:
+        for hybridize in [False, True]:
+            for keepdims in [True, False]:
+                for axis in ([i for i in range(in_data_dim)] + [(), None]):
+                    for itype in ['float16', 'float32', 'float64', 'int']:
+                        # test gluon
+                        if func == 'max':
+                            test_gluon = TestMax(axis=axis, keepdims=keepdims)
+                        else:
+                            test_gluon = TestMin(axis=axis, keepdims=keepdims)
+                        if hybridize:
+                            test_gluon.hybridize()
+                        if is_int(itype):
+                            x = mx.nd.arange(120).reshape((2, 3, 4, 5))
+                        else:
+                            x = mx.nd.random.uniform(-1.0, 1.0, shape=shape, dtype=itype)
+                        x = x.as_np_ndarray()
+                        x.attach_grad()
+                        if func == 'max':
+                            expected_ret = _np.amax(x.asnumpy(), axis=axis, keepdims=keepdims)
+                        else:
+                            expected_ret = _np.amin(x.asnumpy(), axis=axis, keepdims=keepdims)
+                        with mx.autograd.record():
+                            y = test_gluon(x)
+                        assert y.shape == expected_ret.shape
+                        assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3, atol=1e-5)
+                        y.backward()
+                        # only check the gradient with hardcoded input
+                        if is_int(itype):
+                            assert same(x.grad.asnumpy(), get_grad(axis, func)), \
+                                'x={}\ny={}\nx.grad={}\nnumpy={}'.format(
+                                    x.asnumpy(), y.asnumpy(), x.grad.asnumpy(), get_grad(axis, func))
+
+                        # test imperative
+                        if func == 'max':
+                            mx_out = np.max(x, axis=axis, keepdims=keepdims)
+                            np_out = _np.amax(x.asnumpy(), axis=axis, keepdims=keepdims)
+                        else:
+                            mx_out = np.min(x, axis=axis, keepdims=keepdims)
+                            np_out = _np.amin(x.asnumpy(), axis=axis, keepdims=keepdims)
+                        assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+
+    # zero-size and zero-dim inputs: every zero-size shape should raise
+    shapes = [(), (0,), (2, 0), (0, 2, 1)]
+    exceptions = [False, True, True, True]
+    dims = [0] * len(shapes)
+    for func in ['max', 'min']:
+        for shape, exception, dim in zip(shapes, exceptions, dims):
+            if exception:
+                assertRaises(MXNetError, _test_np_exception, func, shape, dim)
+            else:
+                _test_np_exception(func, shape, dim)
+
+
 @with_seed()
 @use_np
 def test_np_linspace():
