Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 1612533

Browse files
zburninghaojin2
authored andcommitted
refactor gluon.utils.split_data() following np.array_split() (#17123)
* fix * fix & add test * add mx.numpy test * fix name * fix mis input * fix test
1 parent 83a23b0 commit 1612533

File tree

2 files changed

+39
-27
lines changed

2 files changed

+39
-27
lines changed

python/mxnet/gluon/utils.py

Lines changed: 11 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -70,30 +70,18 @@ def split_data(data, num_slice, batch_axis=0, even_split=True):
7070
"uneven partitioning of data."%(
7171
str(data.shape), num_slice, batch_axis, num_slice))
7272

73-
step = size // num_slice
74-
75-
# If size < num_slice, make fewer slices
76-
if not even_split and size < num_slice:
77-
step = 1
78-
num_slice = size
79-
80-
if batch_axis == 0:
81-
slices = [data[i*step:(i+1)*step] if i < num_slice - 1 else data[i*step:size]
82-
for i in range(num_slice)]
83-
elif even_split:
84-
if is_np_array():
85-
slices = _mx_np.split(data, indices_or_sections=num_slice, axis=batch_axis)
86-
else:
87-
slices = ndarray.split(data, num_outputs=num_slice, axis=batch_axis)
73+
n_each_section, extras = divmod(size, num_slice)
74+
section_sizes = [0] + (extras * [n_each_section + 1] +
75+
(num_slice - extras) * [n_each_section])
76+
div_points = np.array(section_sizes).cumsum()
77+
if is_np_array():
78+
slices = _mx_np.split(data, indices_or_sections=list(div_points[1: -1]), axis=batch_axis)
8879
else:
89-
if is_np_array():
90-
indices = [step * i for i in range(1, num_slice)]
91-
slices = _mx_np.split(data, indices_or_sections=indices, axis=batch_axis)
92-
else:
93-
slices = [ndarray.slice_axis(data, batch_axis, i*step, (i+1)*step)
94-
if i < num_slice - 1 else
95-
ndarray.slice_axis(data, batch_axis, i*step, size)
96-
for i in range(num_slice)]
80+
slices = []
81+
for i in range(num_slice):
82+
st = div_points[i]
83+
end = div_points[i + 1]
84+
slices.append(ndarray.slice_axis(data, axis=batch_axis, begin=st, end=end))
9785
return slices
9886

9987

tests/python/unittest/test_gluon.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@
2323
from mxnet.gluon import nn
2424
from mxnet.base import py_str
2525
from mxnet.test_utils import assert_almost_equal
26+
from mxnet.util import is_np_array
2627
from mxnet.ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID
28+
from mxnet.test_utils import use_np
29+
import mxnet.numpy as _mx_np
2730
from common import (setup_module, with_seed, assertRaises, teardown,
2831
assert_raises_cudnn_not_satisfied)
2932
import numpy as np
@@ -952,17 +955,39 @@ def test_deferred_init():
952955
layer(x)
953956

954957

958+
955959
def check_split_data(x, num_slice, batch_axis, **kwargs):
956960
res = gluon.utils.split_data(x, num_slice, batch_axis, **kwargs)
957961
assert len(res) == num_slice
958-
mx.test_utils.assert_almost_equal(mx.nd.concat(*res, dim=batch_axis).asnumpy(),
959-
x.asnumpy())
962+
if not is_np_array():
963+
mx.test_utils.assert_almost_equal(mx.nd.concat(*res, dim=batch_axis).asnumpy(),
964+
x.asnumpy())
965+
else:
966+
mx.test_utils.assert_almost_equal(_mx_np.concatenate(res, axis=batch_axis).asnumpy(),
967+
x.asnumpy())
968+
np_res = np.array_split(x.asnumpy(), num_slice, axis=batch_axis)
969+
res_asnp = [s.asnumpy() for s in res]
970+
for r1, r2 in zip(np_res, res_asnp):
971+
assert all(r1.reshape(-1) == r2.reshape(-1))
960972

961973

974+
@with_seed()
975+
@use_np
976+
def test_split_data_np():
977+
x = _mx_np.random.uniform(size=(128, 33, 64))
978+
check_split_data(x, 8, 0)
979+
check_split_data(x, 3, 1)
980+
check_split_data(x, 4, 1, even_split=False)
981+
check_split_data(x, 15, 1, even_split=False)
982+
try:
983+
check_split_data(x, 4, 1)
984+
except ValueError:
985+
return
986+
assert False, "Should have failed"
987+
962988
@with_seed()
963989
def test_split_data():
964990
x = mx.nd.random.uniform(shape=(128, 33, 64))
965-
966991
check_split_data(x, 8, 0)
967992
check_split_data(x, 3, 1)
968993
check_split_data(x, 4, 1, even_split=False)
@@ -973,7 +998,6 @@ def test_split_data():
973998
return
974999
assert False, "Should have failed"
9751000

976-
9771001
@with_seed()
9781002
def test_flatten():
9791003
flatten = nn.Flatten()

0 commit comments

Comments
 (0)