Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 6e64ee9

Browse files
committed
add relevant tests
1 parent d962df3 commit 6e64ee9

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

tests/python/unittest/test_symbol.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import numpy as np
2323
from common import assertRaises, models
2424
from mxnet.base import NotImplementedForSymbol
25-
from mxnet.test_utils import discard_stderr
25+
from mxnet.test_utils import discard_stderr, rand_shape_nd
2626
import pickle as pkl
2727

2828
def test_symbol_basic():
@@ -188,6 +188,21 @@ def test_symbol_infer_shape_var():
188188
assert arg_shapes[1] == overwrite_shape
189189
assert out_shapes[0] == overwrite_shape
190190

191+
192+
def test_symbol_magic_abs():
193+
for dim in range(1, 7):
194+
with mx.name.NameManager():
195+
data = mx.symbol.Variable('data')
196+
method = data.abs(name='abs0')
197+
magic = abs(data)
198+
regular = mx.symbol.abs(data, name='abs0')
199+
ctx = {'ctx': mx.context.current_context(), 'data': rand_shape_nd(dim)}
200+
mx.test_utils.check_consistency(
201+
[method, magic], ctx_list=[ctx, ctx])
202+
mx.test_utils.check_consistency(
203+
[regular, magic], ctx_list=[ctx, ctx])
204+
205+
191206
def test_symbol_fluent():
192207
has_grad = set(['flatten', 'expand_dims', 'flip', 'tile', 'transpose', 'sum', 'nansum', 'prod',
193208
'nanprod', 'mean', 'max', 'min', 'reshape', 'broadcast_to', 'split',

0 commit comments

Comments
 (0)