This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 273e19c: fix mean output type for integer inputs
Parent: 3885bbe

File tree: 6 files changed (+70, -47 lines)
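For context, a minimal sketch of the user-facing behavior this commit targets, written against the mxnet.numpy front end that the updated test file exercises. The float32 result dtype for integer inputs follows from the NumpyMeanType change below; the snippet is an illustration, not part of the commit.

import numpy as _np
from mxnet import np, npx

npx.set_np()  # enable numpy-compatible semantics

x = np.array([[1, 2, 3], [4, 5, 6]], dtype='int32')
y = np.mean(x)   # no dtype requested: an integer input now yields a floating-point result
print(y.dtype)   # expected: float32 (plain NumPy would promote to float64 instead)

ref = _np.mean(x.asnumpy())            # NumPy reference value
assert _np.allclose(y.asnumpy(), ref)  # values agree even though the dtypes differ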


src/ndarray/ndarray_function-inl.h

Lines changed: 1 addition & 1 deletion
@@ -379,7 +379,7 @@ void EvalRandom<DEVICE, GenNegBinomialDistribution>(
 template<>
 void Eval<DEVICE>(const real_t &rhs, TBlob *ret, RunContext ctx) {
   mshadow::Stream<DEVICE> *s = ctx.get_stream<DEVICE>();
-  MSHADOW_TYPE_SWITCH(ret->type_flag_, DType, {
+  MSHADOW_TYPE_SWITCH_WITH_BOOL(ret->type_flag_, DType, {
     ret->FlatTo2D<DEVICE, DType>(s) = DType(rhs);
   });
 }

src/operator/numpy/np_broadcast_reduce_op.h

Lines changed: 1 addition & 0 deletions
@@ -52,6 +52,7 @@ struct NumpyReduceAxesParam : public dmlc::Parameter<NumpyReduceAxesParam> {
       .add_enum("int8", mshadow::kInt8)
       .add_enum("int32", mshadow::kInt32)
       .add_enum("int64", mshadow::kInt64)
+      .add_enum("bool", mshadow::kBool)
       .set_default(dmlc::optional<int>())
       .describe("The type of the returned array and of the accumulator in which the elements are "
                 "summed. The dtype of a is used by default unless a has an integer dtype of less "

src/operator/numpy/np_broadcast_reduce_op_value.cc

Lines changed: 6 additions & 5 deletions
@@ -257,13 +257,14 @@ inline bool NumpyMeanType(const nnvm::NodeAttrs& attrs,
   const NumpyReduceAxesParam &param = nnvm::get<NumpyReduceAxesParam>(attrs.parsed);

   if (param.dtype.has_value()) {
-    if (IsIntType(in_attrs->at(0)) && !IsIntType(param.dtype.value())) {
-      LOG(FATAL) << "Output cannot be float type when input is integer type for now";
-    }
     TYPE_ASSIGN_CHECK(*out_attrs, 0, param.dtype.value());
   } else {
-    TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
-    TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
+    if (common::is_float(in_attrs->at(0))) {
+      TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
+      TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
+    } else {
+      TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kFloat32);
+    }
   }

   return out_attrs->at(0) != -1 && in_attrs->at(0) != -1;
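The rewritten branch encodes a simple rule: an explicit dtype always wins, a floating-point input keeps its own dtype, and an integer or bool input gets a float32 output. A hedged paraphrase of that rule in plain NumPy (the helper name mean_output_dtype is invented for illustration; the real logic is the C++ above):

import numpy as np

def mean_output_dtype(in_dtype, requested_dtype=None):
    """Output dtype a numpy-style mean reports, mirroring NumpyMeanType."""
    if requested_dtype is not None:
        return np.dtype(requested_dtype)            # param.dtype.has_value(): use it as given
    if np.issubdtype(np.dtype(in_dtype), np.floating):
        return np.dtype(in_dtype)                   # float input keeps its dtype
    return np.dtype(np.float32)                     # int/bool input promotes to float32

assert mean_output_dtype('float16') == np.float16
assert mean_output_dtype('int64') == np.float32
assert mean_output_dtype('bool') == np.float32
assert mean_output_dtype('int32', 'float64') == np.float64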

src/operator/tensor/broadcast_reduce-inl.h

Lines changed: 2 additions & 2 deletions
@@ -245,7 +245,7 @@ void Reduce(Stream<cpu>* s, const TBlob& small, const OpReqType req,
 #ifndef _WIN32
   MXNET_ACC_TYPE_SWITCH(mshadow::DataType<DType>::kFlag, DataType, AType, {
     typedef typename std::conditional<safe_acc, AType, DataType>::type AccType;
-    MSHADOW_TYPE_SWITCH(small.type_flag_, OType, {
+    MSHADOW_TYPE_SWITCH_WITH_BOOL(small.type_flag_, OType, {
       typedef typename std::conditional<safe_acc, OType, DataType>::type OutType;
       seq_reduce_compute<Reducer, ndim, AccType, DataType, OutType, OP>(
         N, M, req == kAddTo, big.dptr<DataType>(), small.dptr<OutType>(),
@@ -255,7 +255,7 @@ void Reduce(Stream<cpu>* s, const TBlob& small, const OpReqType req,
 #else
   MXNET_REAL_ACC_TYPE_SWITCH(mshadow::DataType<DType>::kFlag, DataType, AType, {
     typedef typename std::conditional<safe_acc, AType, DataType>::type AccType;
-    MSHADOW_TYPE_SWITCH(small.type_flag_, OType, {
+    MSHADOW_TYPE_SWITCH_WITH_BOOL(small.type_flag_, OType, {
       typedef typename std::conditional<safe_acc, OType, DataType>::type OutType;
       seq_reduce_compute<Reducer, ndim, AccType, DataType, OutType, OP>(
         N, M, req == kAddTo, big.dptr<DataType>(), small.dptr<OutType>(),

src/operator/tensor/broadcast_reduce_op.h

Lines changed: 3 additions & 3 deletions
@@ -617,7 +617,7 @@ void ReduceAxesComputeImpl(const OpContext& ctx,
   BroadcastReduceShapeCompact(inputs[0].shape_, small, &src_shape, &dst_shape);
   Stream<xpu> *s = ctx.get_stream<xpu>();
   MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, {
-    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+    MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, OType, {
       const TBlob in_data = inputs[0].reshape(src_shape);
       const TBlob out_data = outputs[0].reshape(dst_shape);
       BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, {
@@ -1045,8 +1045,8 @@ inline void BroadcastComputeImpl(const nnvm::NodeAttrs& attrs,
   mxnet::TShape src_shape, dst_shape;
   BroadcastReduceShapeCompact(outputs[0].shape_, small, &dst_shape, &src_shape);
   Stream<xpu> *s = ctx.get_stream<xpu>();
-  MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, {
-    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+  MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, IType, {
+    MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, OType, {
       mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> in_shape;
       mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> out_shape;
       for (int i = 0; i < MXNET_SPECIAL_MAX_NDIM; ++i) {
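With bool now accepted by the input/output type switches above, reducing a boolean array dispatches like any other dtype. A small hedged example in the style of the updated test, assuming the mxnet.numpy API; the float32 result dtype comes from the mean type-inference change:

import numpy as _np
from mxnet import np, npx

npx.set_np()

mask = np.random.uniform(size=(3, 4)) > 0.5   # boolean ndarray, built the same way as in the test
frac = np.mean(mask)                          # fraction of True entries
print(frac.dtype)                             # expected: float32
assert _np.isclose(frac.asnumpy(), _np.mean(mask.asnumpy()))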

tests/python/unittest/test_numpy_op.py

Lines changed: 57 additions & 36 deletions
@@ -617,49 +617,70 @@ def is_int(dtype):
     in_data_dim = random.choice([2, 3, 4])
     shape = rand_shape_nd(in_data_dim, dim=3)
     acc_type = {'float16': 'float32', 'float32': 'float64', 'float64': 'float64',
-                'int8': 'int32', 'int32': 'int64', 'int64': 'int64'}
+                'bool': 'int64', 'int8': 'int32', 'int32': 'int64', 'int64': 'int64'}
+    ft_types = ['float16', 'float32', 'float64']
+    it_types = ['bool', 'int8', 'int32', 'int64']
     for hybridize in [False, True]:
         for keepdims in [True, False]:
             for axis in ([i for i in range(in_data_dim)] + [(), None]):
-                for itype in ['float16', 'float32', 'float64']:
-                    for dtype in ['float16', 'float32', 'float64']:
-                        if is_int(dtype) and not is_int(itype):
-                            continue
-                        # test gluon
-                        test_mean = TestMean(axis=axis, dtype=dtype, keepdims=keepdims)
-                        if hybridize:
-                            test_mean.hybridize()
-                        if is_int(itype):
-                            x = _np.random.randint(-128, 128, shape, dtype=itype)
-                            x = mx.nd.array(x, dtype=itype)
-                        else:
-                            x = mx.nd.random.uniform(-1.0, 1.0, shape=shape, dtype=itype)
-                        x = x.as_np_ndarray()
-                        x.attach_grad()
+                for itype, dtype in itertools.product(ft_types, [None] + ft_types + it_types):
+                    if dtype == 'bool':
+                        continue
+                    # test gluon
+                    test_mean = TestMean(axis=axis, dtype=dtype, keepdims=keepdims)
+                    if hybridize:
+                        test_mean.hybridize()
+                    x = np.random.uniform(-1.0, 1.0, size=shape).astype(itype)
+                    x = x.as_np_ndarray()
+                    x.attach_grad()

-                        expected_ret = _np.mean(x.asnumpy(), axis=axis, dtype=acc_type[itype], keepdims=keepdims)
-                        expected_ret = expected_ret.astype(dtype)
-                        with mx.autograd.record():
-                            y = test_mean(x)
-                        assert y.shape == expected_ret.shape
-                        assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3 if dtype == 'float16' else 1e-3,
-                                            atol=1e-5 if dtype == 'float16' else 1e-5)
+                    expected_ret = _np.mean(x.asnumpy(), axis=axis, dtype=acc_type[itype], keepdims=keepdims)
+                    expected_ret = expected_ret.astype(dtype)
+                    with mx.autograd.record():
+                        y = test_mean(x)
+                    assert y.shape == expected_ret.shape
+                    assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3 if dtype == 'float16' else 1e-3,
+                                        atol=1e-5 if dtype == 'float16' else 1e-5)

-                        y.backward()
-                        N = x.size / y.size
-                        assert same(x.grad.asnumpy(), _np.ones(shape=x.shape, dtype=x.dtype) / N)
+                    y.backward()
+                    N = x.size / y.size
+                    assert same(x.grad.asnumpy(), _np.ones(shape=x.shape, dtype=x.dtype) / N)

-                        # test numeric
-                        if itype == 'float32' and dtype == 'float32':
-                            x_sym = mx.sym.Variable("x").as_np_ndarray()
-                            mx_sym = mx.sym.np.mean(x_sym, axis=axis, dtype=dtype, keepdims=keepdims).as_nd_ndarray()
-                            check_numeric_gradient(mx_sym, [x.as_nd_ndarray()],
-                                                   numeric_eps=1e-3, rtol=1e-3, atol=1e-4, dtype=_np.float32)
+                    # test numeric
+                    if itype == 'float32' and dtype == 'float32':
+                        x_sym = mx.sym.Variable("x").as_np_ndarray()
+                        mx_sym = mx.sym.np.mean(x_sym, axis=axis, dtype=dtype, keepdims=keepdims).as_nd_ndarray()
+                        check_numeric_gradient(mx_sym, [x.as_nd_ndarray()],
+                                               numeric_eps=1e-3, rtol=1e-3, atol=1e-4, dtype=_np.float32)

-                        # test imperative
-                        mx_out = np.mean(x, axis=axis, dtype=dtype, keepdims=keepdims)
-                        np_out = _np.mean(x.asnumpy(), axis=axis, dtype=acc_type[itype], keepdims=keepdims).astype(dtype)
-                        assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+                    # test imperative
+                    mx_out = np.mean(x, axis=axis, dtype=dtype, keepdims=keepdims)
+                    np_out = _np.mean(x.asnumpy(), axis=axis, dtype=acc_type[itype], keepdims=keepdims).astype(dtype)
+                    assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+
+                for itype, dtype in itertools.product(it_types, [None] + ft_types + it_types):
+                    if dtype == 'bool':
+                        continue
+                    # test gluon
+                    test_mean = TestMean(axis=axis, dtype=dtype, keepdims=keepdims)
+                    if hybridize:
+                        test_mean.hybridize()
+
+                    if itype == 'bool':
+                        x = np.random.uniform(size=shape) > 0.5
+                    else:
+                        x = np.random.uniform(-128, 127, size=shape).astype(itype)
+
+                    expected_ret = _np.mean(x.asnumpy(), axis=axis, dtype=dtype, keepdims=keepdims)
+                    y = test_mean(x)
+                    assert y.shape == expected_ret.shape
+                    assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3 if dtype == 'float16' else 1e-3,
+                                        atol=1e-5 if dtype == 'float16' else 1e-5)
+
+                    # test imperative
+                    mx_out = np.mean(x, axis=axis, dtype=dtype, keepdims=keepdims)
+                    np_out = _np.mean(x.asnumpy(), axis=axis, dtype=dtype, keepdims=keepdims).astype(dtype)
+                    assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)


 @with_seed()
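One note on the acc_type table at the top of this hunk: the NumPy reference for floating-point inputs is built with a wider accumulator dtype so that low-precision inputs do not lose accuracy during summation. A plain-NumPy illustration (indicative only; the exact drift depends on the data):

import numpy as _np

x = _np.full((2, 512 * 512), 0.1, dtype=_np.float16)
print(_np.mean(x))                     # accumulated at float16 precision
print(_np.mean(x, dtype=_np.float32))  # wider accumulator, as acc_type['float16'] = 'float32' prescribes
# The two results can differ slightly; the test builds its reference with the wider dtype.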
