
Commit 2af01ba

haojin2 authored and ptrendx committed
Fix numpy-compatible mean output type for integer inputs (apache#16792)
* fix mean output type for integer inputs
* enable for windows
1 parent d07b2f4 commit 2af01ba
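For context, `numpy.mean` promotes integer and boolean inputs to a floating-point result unless an explicit `dtype` is requested; matching that promotion for integer inputs is what this commit targets. A small reference sketch of the behaviour in plain NumPy (not MXNet):

```python
import numpy as np

a = np.array([1, 2, 3, 4], dtype=np.int32)

# With no dtype argument, integer inputs produce a floating-point mean.
print(np.mean(a).dtype)                  # float64
# An explicit dtype overrides the default promotion.
print(np.mean(a, dtype=np.int64).dtype)  # int64
```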

File tree: 7 files changed (+88, −74 lines)


src/ndarray/ndarray_function-inl.h

Lines changed: 1 addition & 1 deletion
@@ -379,7 +379,7 @@ void EvalRandom<DEVICE, GenNegBinomialDistribution>(
 template<>
 void Eval<DEVICE>(const real_t &rhs, TBlob *ret, RunContext ctx) {
   mshadow::Stream<DEVICE> *s = ctx.get_stream<DEVICE>();
-  MSHADOW_TYPE_SWITCH(ret->type_flag_, DType, {
+  MSHADOW_TYPE_SWITCH_WITH_BOOL(ret->type_flag_, DType, {
     ret->FlatTo2D<DEVICE, DType>(s) = DType(rhs);
   });
 }
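Switching `Eval<DEVICE>` to `MSHADOW_TYPE_SWITCH_WITH_BOOL` adds `kBool` to the element types the scalar-assignment path can handle. A hedged sketch of the kind of front-end call this is meant to support, assuming the `mxnet.numpy` front end routes scalar fills for boolean ndarrays through this path:

```python
from mxnet import np, npx
npx.set_np()

# Filling an ndarray with a scalar eventually reaches Eval<DEVICE>(const real_t&, ...);
# with kBool in the type switch, boolean ndarrays can take that path as well.
mask = np.ones((2, 3), dtype=bool)
print(mask.dtype)  # bool
```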

src/operator/numpy/np_broadcast_reduce_op.h

Lines changed: 10 additions & 2 deletions
@@ -52,6 +52,7 @@ struct NumpyReduceAxesParam : public dmlc::Parameter<NumpyReduceAxesParam> {
       .add_enum("int8", mshadow::kInt8)
       .add_enum("int32", mshadow::kInt32)
       .add_enum("int64", mshadow::kInt64)
+      .add_enum("bool", mshadow::kBool)
       .set_default(dmlc::optional<int>())
       .describe("The type of the returned array and of the accumulator in which the elements are "
                 "summed. The dtype of a is used by default unless a has an integer dtype of less "
@@ -221,15 +222,15 @@ void NumpyReduceAxesCompute(const nnvm::NodeAttrs& attrs,
                             const std::vector<TBlob>& inputs,
                             const std::vector<OpReqType>& req,
                             const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
   if (req[0] == kNullOp) return;
   const NumpyReduceAxesParam& param = nnvm::get<NumpyReduceAxesParam>(attrs.parsed);
   if (param.initial.has_value()) {
     LOG(FATAL) << "initial is not supported yet";
   }
+  Stream<xpu>* s = ctx.get_stream<xpu>();
   if (inputs[0].shape_.Size() == 0 && outputs[0].shape_.Size() != 0) {
     using namespace mxnet_op;
-    using namespace mshadow;
-    Stream<xpu>* s = ctx.get_stream<xpu>();
     MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
       Kernel<set_zero, xpu>::Launch(s, outputs[0].shape_.Size(), outputs[0].dptr<DType>());
     });
@@ -246,6 +247,13 @@ void NumpyReduceAxesCompute(const nnvm::NodeAttrs& attrs,
       LOG(FATAL) << "Only reduce op: `sum` is supported for boolean ndarrays";
     }
     TVMOpReduce(ctx, inputs[0], param.axis, outputs[0], req[0], reducer_name);
+    if (normalize) {
+      using namespace mshadow::expr;
+      MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+        auto out = outputs[0].FlatTo2D<xpu, OType>(s);
+        out /= scalar<OType>(inputs[0].Size()/outputs[0].Size());
+      });
+    }
     return;
   }
 #endif
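The TVM reduce kernel on this path only produces a sum, so for a mean (the `normalize` flag set by the enclosing function, not shown in this hunk) the new branch divides the reduced output elementwise by the number of reduced elements, `inputs[0].Size() / outputs[0].Size()`. The identity it relies on, sketched in plain NumPy:

```python
import numpy as np

x = np.arange(24).reshape(2, 3, 4)
axis = 1

s = np.sum(x, axis=axis)   # what the reduce kernel produces
n = x.size // s.size       # elements folded into each output value, here 3
assert np.allclose(s / n, np.mean(x, axis=axis))
```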

src/operator/numpy/np_broadcast_reduce_op_value.cc

Lines changed: 6 additions & 5 deletions
@@ -257,13 +257,14 @@ inline bool NumpyMeanType(const nnvm::NodeAttrs& attrs,
   const NumpyReduceAxesParam &param = nnvm::get<NumpyReduceAxesParam>(attrs.parsed);
 
   if (param.dtype.has_value()) {
-    if (IsIntType(in_attrs->at(0)) && !IsIntType(param.dtype.value())) {
-      LOG(FATAL) << "Output cannot be float type when input is integer type for now";
-    }
     TYPE_ASSIGN_CHECK(*out_attrs, 0, param.dtype.value());
   } else {
-    TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
-    TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
+    if (common::is_float(in_attrs->at(0))) {
+      TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
+      TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
+    } else {
+      TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kFloat32);
+    }
   }
 
   return out_attrs->at(0) != -1 && in_attrs->at(0) != -1;
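With this change, `NumpyMeanType` no longer rejects a floating-point `dtype` for integer inputs, and when no `dtype` is given it keeps the input type only for float inputs, defaulting everything else to `float32`. A sketch of the intended front-end behaviour, assuming the `mxnet.numpy` interface:

```python
from mxnet import np, npx
npx.set_np()

x = np.array([1, 2, 3, 4], dtype='int32')
print(np.mean(x).dtype)                   # float32: integer input now defaults to a float output
print(np.mean(x, dtype='float64').dtype)  # float64: an explicit float dtype is no longer an error

f = np.array([1.0, 2.0, 3.0], dtype='float64')
print(np.mean(f).dtype)                   # float64: float inputs keep their own dtype
```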

src/operator/tensor/broadcast_reduce-inl.cuh

Lines changed: 0 additions & 13 deletions
@@ -619,8 +619,6 @@ void Reduce(Stream<gpu> *s, const TBlob& small, const OpReqType req,
   ReduceImplConfig<ndim> config =
     ConfigureReduceImpl<ndim, DType>(small.shape_, big.shape_, NULL, NULL);
   if (safe_acc) {
-    // TODO(haojin2): Use real-only type swtich for windows temporarily due to CI issues.
-#ifndef _WIN32
     MXNET_ACC_TYPE_SWITCH(mshadow::DataType<DType>::kFlag, DataType, AType, {
       typedef typename std::conditional<safe_acc, AType, DataType>::type AccType;
       MSHADOW_TYPE_SWITCH(small.type_flag_, OType, {
@@ -630,17 +628,6 @@ void Reduce(Stream<gpu> *s, const TBlob& small, const OpReqType req,
           stream, small, req, big, workspace, config);
       });
     });
-#else
-    MXNET_REAL_ACC_TYPE_SWITCH(mshadow::DataType<DType>::kFlag, DataType, AType, {
-      typedef typename std::conditional<safe_acc, AType, DataType>::type AccType;
-      MSHADOW_TYPE_SWITCH(small.type_flag_, OType, {
-        typedef typename std::conditional<safe_acc, OType, DataType>::type OutType;
-        config = ConfigureReduceImpl<ndim, AccType>(small.shape_, big.shape_, NULL, NULL);
-        ReduceImpl<Reducer, ndim, AccType, DataType, OutType, OP>(
-          stream, small, req, big, workspace, config);
-      });
-    });
-#endif
   } else {
     ReduceImpl<Reducer, ndim, DType, DType, DType, OP>(stream, small, req, big, workspace, config);
   }

src/operator/tensor/broadcast_reduce-inl.h

Lines changed: 1 addition & 14 deletions
@@ -241,28 +241,15 @@ void Reduce(Stream<cpu>* s, const TBlob& small, const OpReqType req,
         N, M, req == kAddTo, big.dptr<DType>(), small.dptr<DType>(),
         big.shape_.get<ndim>(), small.shape_.get<ndim>(), rshape, rstride);
   } else {
-    // TODO(haojin2): Use real-only type swtich for windows temporarily due to CI issues.
-#ifndef _WIN32
     MXNET_ACC_TYPE_SWITCH(mshadow::DataType<DType>::kFlag, DataType, AType, {
       typedef typename std::conditional<safe_acc, AType, DataType>::type AccType;
-      MSHADOW_TYPE_SWITCH(small.type_flag_, OType, {
+      MSHADOW_TYPE_SWITCH_WITH_BOOL(small.type_flag_, OType, {
         typedef typename std::conditional<safe_acc, OType, DataType>::type OutType;
         seq_reduce_compute<Reducer, ndim, AccType, DataType, OutType, OP>(
           N, M, req == kAddTo, big.dptr<DataType>(), small.dptr<OutType>(),
           big.shape_.get<ndim>(), small.shape_.get<ndim>(), rshape, rstride);
       });
     });
-#else
-    MXNET_REAL_ACC_TYPE_SWITCH(mshadow::DataType<DType>::kFlag, DataType, AType, {
-      typedef typename std::conditional<safe_acc, AType, DataType>::type AccType;
-      MSHADOW_TYPE_SWITCH(small.type_flag_, OType, {
-        typedef typename std::conditional<safe_acc, OType, DataType>::type OutType;
-        seq_reduce_compute<Reducer, ndim, AccType, DataType, OutType, OP>(
-          N, M, req == kAddTo, big.dptr<DataType>(), small.dptr<OutType>(),
-          big.shape_.get<ndim>(), small.shape_.get<ndim>(), rshape, rstride);
-      });
-    });
-#endif
   }
 }
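Both the CPU and GPU reduce paths now use `MXNET_ACC_TYPE_SWITCH` unconditionally (the Windows-only `MXNET_REAL_ACC_TYPE_SWITCH` fallback is removed), so integer inputs accumulate in a wider type, matching the `acc_type` table in the test below (e.g. int8 accumulates in int32). Why a wider accumulator matters, sketched in plain NumPy:

```python
import numpy as np

x = np.full(1000, 100, dtype=np.int8)

# Accumulating in the input type overflows the 8-bit range...
print(x.sum(dtype=np.int8))   # wraps around
# ...while a wider accumulator gives the exact sum used for the mean.
print(x.sum(dtype=np.int32))  # 100000
```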

src/operator/tensor/broadcast_reduce_op.h

Lines changed: 3 additions & 3 deletions
@@ -617,7 +617,7 @@ void ReduceAxesComputeImpl(const OpContext& ctx,
   BroadcastReduceShapeCompact(inputs[0].shape_, small, &src_shape, &dst_shape);
   Stream<xpu> *s = ctx.get_stream<xpu>();
   MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, {
-    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+    MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, OType, {
       const TBlob in_data = inputs[0].reshape(src_shape);
       const TBlob out_data = outputs[0].reshape(dst_shape);
       BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, {
@@ -1045,8 +1045,8 @@ inline void BroadcastComputeImpl(const nnvm::NodeAttrs& attrs,
   mxnet::TShape src_shape, dst_shape;
   BroadcastReduceShapeCompact(outputs[0].shape_, small, &dst_shape, &src_shape);
   Stream<xpu> *s = ctx.get_stream<xpu>();
-  MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, {
-    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+  MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, IType, {
+    MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, OType, {
       mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> in_shape;
       mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> out_shape;
       for (int i = 0; i < MXNET_SPECIAL_MAX_NDIM; ++i) {
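Using the `_WITH_BOOL` variants here lets the generic reduce and broadcast kernels be instantiated for boolean input and output tensors. For a mean, NumPy treats booleans as 0/1, which is what the boolean branch of the test below checks against; in plain NumPy for reference:

```python
import numpy as np

b = np.array([True, False, True, True])
print(np.mean(b))  # 0.75: booleans are averaged as 0 and 1
```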

tests/python/unittest/test_numpy_op.py

Lines changed: 67 additions & 36 deletions
@@ -614,52 +614,83 @@ def hybrid_forward(self, F, a, *args, **kwargs):
     def is_int(dtype):
         return 'int' in dtype
 
+    is_windows = sys.platform.startswith('win')
     in_data_dim = random.choice([2, 3, 4])
     shape = rand_shape_nd(in_data_dim, dim=3)
     acc_type = {'float16': 'float32', 'float32': 'float64', 'float64': 'float64',
-                'int8': 'int32', 'int32': 'int64', 'int64': 'int64'}
+                'bool': 'int64', 'int8': 'int32', 'int32': 'int64', 'int64': 'int64'}
+    ft_types = ['float16', 'float32', 'float64']
+    it_types = ['bool', 'int8', 'int32', 'int64']
     for hybridize in [False, True]:
         for keepdims in [True, False]:
             for axis in ([i for i in range(in_data_dim)] + [(), None]):
-                for itype in ['float16', 'float32', 'float64']:
-                    for dtype in ['float16', 'float32', 'float64']:
-                        if is_int(dtype) and not is_int(itype):
-                            continue
-                        # test gluon
-                        test_mean = TestMean(axis=axis, dtype=dtype, keepdims=keepdims)
-                        if hybridize:
-                            test_mean.hybridize()
-                        if is_int(itype):
-                            x = _np.random.randint(-128, 128, shape, dtype=itype)
-                            x = mx.nd.array(x, dtype=itype)
-                        else:
-                            x = mx.nd.random.uniform(-1.0, 1.0, shape=shape, dtype=itype)
-                        x = x.as_np_ndarray()
-                        x.attach_grad()
+                for itype, dtype in itertools.product(ft_types, [None] + ft_types + it_types):
+                    if dtype == 'bool':
+                        continue
+                    # test gluon
+                    test_mean = TestMean(axis=axis, dtype=dtype, keepdims=keepdims)
+                    if hybridize:
+                        test_mean.hybridize()
+                    x = np.random.uniform(-1.0, 1.0, size=shape).astype(itype)
+                    x = x.as_np_ndarray()
+                    x.attach_grad()
 
-                        expected_ret = _np.mean(x.asnumpy(), axis=axis, dtype=acc_type[itype], keepdims=keepdims)
-                        expected_ret = expected_ret.astype(dtype)
-                        with mx.autograd.record():
-                            y = test_mean(x)
-                        assert y.shape == expected_ret.shape
-                        assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3 if dtype == 'float16' else 1e-3,
-                                            atol=1e-5 if dtype == 'float16' else 1e-5)
+                    expected_ret = _np.mean(x.asnumpy(), axis=axis, dtype=acc_type[itype], keepdims=keepdims)
+                    expected_ret = expected_ret.astype(dtype)
+                    with mx.autograd.record():
+                        y = test_mean(x)
+                    assert y.shape == expected_ret.shape
+                    assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3 if dtype == 'float16' else 1e-3,
+                                        atol=1e-5 if dtype == 'float16' else 1e-5)
 
-                        y.backward()
-                        N = x.size / y.size
-                        assert same(x.grad.asnumpy(), _np.ones(shape=x.shape, dtype=x.dtype) / N)
+                    y.backward()
+                    N = x.size / y.size
+                    assert same(x.grad.asnumpy(), _np.ones(shape=x.shape, dtype=x.dtype) / N)
 
-                        # test numeric
-                        if itype == 'float32' and dtype == 'float32':
-                            x_sym = mx.sym.Variable("x").as_np_ndarray()
-                            mx_sym = mx.sym.np.mean(x_sym, axis=axis, dtype=dtype, keepdims=keepdims).as_nd_ndarray()
-                            check_numeric_gradient(mx_sym, [x.as_nd_ndarray()],
-                                                   numeric_eps=1e-3, rtol=1e-3, atol=1e-4, dtype=_np.float32)
+                    # test numeric
+                    if itype == 'float32' and dtype == 'float32':
+                        x_sym = mx.sym.Variable("x").as_np_ndarray()
+                        mx_sym = mx.sym.np.mean(x_sym, axis=axis, dtype=dtype, keepdims=keepdims).as_nd_ndarray()
+                        check_numeric_gradient(mx_sym, [x.as_nd_ndarray()],
+                                               numeric_eps=1e-3, rtol=1e-3, atol=1e-4, dtype=_np.float32)
 
-                        # test imperative
-                        mx_out = np.mean(x, axis=axis, dtype=dtype, keepdims=keepdims)
-                        np_out = _np.mean(x.asnumpy(), axis=axis, dtype=acc_type[itype], keepdims=keepdims).astype(dtype)
-                        assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+                    # test imperative
+                    mx_out = np.mean(x, axis=axis, dtype=dtype, keepdims=keepdims)
+                    np_out = _np.mean(x.asnumpy(), axis=axis, dtype=acc_type[itype], keepdims=keepdims).astype(dtype)
+                    assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+
+                for itype, dtype in itertools.product(it_types, [None] + ft_types + it_types):
+                    if dtype == 'bool':
+                        continue
+                    # test gluon
+                    test_mean = TestMean(axis=axis, dtype=dtype, keepdims=keepdims)
+                    if hybridize:
+                        test_mean.hybridize()
+
+                    if itype == 'bool':
+                        x = np.array(_np.random.uniform(size=shape) > 0.5)
+                    else:
+                        x = np.random.uniform(-128, 127, size=shape).astype(itype)
+
+                    expected_ret = _np.mean(x.asnumpy(), axis=axis, dtype=dtype, keepdims=keepdims)
+
+                    if itype == 'bool':
+                        if is_op_runnable() and (not is_windows) and dtype not in ['float16', 'int8']:  # special handling of boolean ndarray
+                            y = test_mean(x)
+                            assert y.shape == expected_ret.shape
+                            assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3 if dtype == 'float16' else 1e-3,
+                                                atol=1e-5 if dtype == 'float16' else 1e-5)
+                        continue
+
+                    y = test_mean(x)
+                    assert y.shape == expected_ret.shape
+                    assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3 if dtype == 'float16' else 1e-3,
+                                        atol=1e-5 if dtype == 'float16' else 1e-5)
+
+                    # test imperative
+                    mx_out = np.mean(x, axis=axis, dtype=dtype, keepdims=keepdims)
+                    np_out = _np.mean(x.asnumpy(), axis=axis, dtype=dtype, keepdims=keepdims).astype(dtype)
+                    assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
 
 
 @with_seed()
