Skip to content

Commit 38116fd

Browse files
reminiscezheng-da
authored andcommitted
[MXNET-283] Error handling for non-positive reps of tile op (apache#10417)
* Error handling for non-positive reps of tile op * Address cr * Fix unit test
1 parent 13b44af commit 38116fd

File tree

2 files changed

+25
-25
lines changed

2 files changed

+25
-25
lines changed

src/operator/tensor/matrix_op-inl.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1442,7 +1442,8 @@ struct TileParam : public dmlc::Parameter<TileParam> {
14421442
TShape reps;
14431443
DMLC_DECLARE_PARAMETER(TileParam) {
14441444
DMLC_DECLARE_FIELD(reps)
1445-
.describe("The number of times for repeating the tensor a."
1445+
.describe("The number of times for repeating the tensor a. Each dim size of reps"
1446+
" must be a positive integer."
14461447
" If reps has length d, the result will have dimension of max(d, a.ndim);"
14471448
" If a.ndim < d, a is promoted to be d-dimensional by prepending new axes."
14481449
" If a.ndim > d, reps is promoted to a.ndim by pre-pending 1's to it.");
@@ -1462,6 +1463,9 @@ inline bool TileOpShape(const nnvm::NodeAttrs& attrs,
14621463
SHAPE_ASSIGN_CHECK(*out_attrs, 0, ishape);
14631464
return true;
14641465
}
1466+
for (size_t i = 0; i < reps.ndim(); ++i) {
1467+
CHECK_GT(reps[i], 0) << "invalid reps=" << i << ", dim size must be greater than zero";
1468+
}
14651469
TShape oshape(std::max(ishape.ndim(), reps.ndim()));
14661470
int i1 = static_cast<int>(ishape.ndim()) - 1;
14671471
int i2 = static_cast<int>(reps.ndim()) - 1;

tests/python/unittest/test_operator.py

Lines changed: 20 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import itertools
2525
from numpy.testing import assert_allclose, assert_array_equal
2626
from mxnet.test_utils import *
27-
from mxnet.base import py_str
27+
from mxnet.base import py_str, MXNetError
2828
from common import setup_module, with_seed
2929
import unittest
3030

@@ -3485,24 +3485,23 @@ def test_reverse():
34853485
@with_seed()
34863486
def test_tile():
34873487
def test_normal_case():
3488-
ndim_max = 3 # max number of dims of the ndarray
3489-
size_max = 10 # max number of elements in each dim
3490-
length_max = 3 # max length of reps
3491-
rep_max = 10 # max number of tiling in each dim
3492-
for ndim in range(ndim_max, ndim_max+1):
3493-
shape = ()
3494-
for i in range(0, ndim):
3495-
shape += (np.random.randint(1, size_max+1), )
3488+
ndim_min = 1
3489+
ndim_max = 5 # max number of dims of the ndarray
3490+
size_max = 10 # max number of elements in each dim
3491+
length_max = 3 # max length of reps
3492+
rep_max = 10 # max number of tiling in each dim
3493+
for ndim in range(ndim_min, ndim_max+1):
3494+
shape = []
3495+
for i in range(1, ndim+1):
3496+
shape.append(np.random.randint(1, size_max+1))
3497+
shape = tuple(shape)
34963498
a = np.random.randint(0, 100, shape)
3497-
a = np.asarray(a, dtype=np.int32)
3498-
if ndim == 0:
3499-
a = np.array([])
3500-
b = mx.nd.array(a, ctx=default_context(), dtype=a.dtype)
3499+
b = mx.nd.array(a, dtype=a.dtype)
35013500

3502-
reps_len = np.random.randint(0, length_max+1)
3501+
reps_len = np.random.randint(1, length_max+1)
35033502
reps_tuple = ()
35043503
for i in range(1, reps_len):
3505-
reps_tuple += (np.random.randint(0, rep_max), )
3504+
reps_tuple += (np.random.randint(1, rep_max), )
35063505
reps_array = np.asarray(reps_tuple)
35073506

35083507
a_tiled = np.tile(a, reps_array)
@@ -3526,14 +3525,6 @@ def test_empty_reps():
35263525
b_tiled = mx.nd.tile(b, ()).asnumpy()
35273526
assert same(a_tiled, b_tiled)
35283527

3529-
def test_zero_reps():
3530-
a = np.array([[2, 3, 4], [5, 6, 7]], dtype=np.int32)
3531-
b = mx.nd.array(a, ctx=default_context(), dtype=a.dtype)
3532-
reps = (2, 0, 4, 5)
3533-
a_tiled = np.tile(a, reps)
3534-
b_tiled = mx.nd.tile(b, reps).asnumpy()
3535-
assert same(a_tiled, b_tiled)
3536-
35373528
def test_tile_backward():
35383529
data = mx.sym.Variable('data')
35393530
n1 = 2
@@ -3570,12 +3561,17 @@ def test_tile_numeric_gradient():
35703561
test = mx.sym.tile(data, reps=reps)
35713562
check_numeric_gradient(test, [data_tmp], numeric_eps=1e-2, rtol=1e-2)
35723563

3564+
def test_invalid_reps():
3565+
data = mx.nd.arange(16).reshape((4, 4))
3566+
assert_exception(mx.nd.tile, MXNetError, data, (1, 2, -3))
3567+
assert_exception(mx.nd.tile, MXNetError, data, (1, 0, 3))
3568+
35733569
test_normal_case()
35743570
test_empty_tensor()
35753571
test_empty_reps()
3576-
test_zero_reps()
35773572
test_tile_backward()
35783573
test_tile_numeric_gradient()
3574+
test_invalid_reps()
35793575

35803576

35813577
@with_seed()

0 commit comments

Comments
 (0)