|
22 | 22 | import numpy as np
|
23 | 23 | from common import assertRaises, models
|
24 | 24 | from mxnet.base import NotImplementedForSymbol
|
25 |
| -from mxnet.test_utils import discard_stderr |
| 25 | +from mxnet.test_utils import discard_stderr, rand_shape_nd |
26 | 26 | import pickle as pkl
|
27 | 27 |
|
28 | 28 | def test_symbol_basic():
|
@@ -188,6 +188,21 @@ def test_symbol_infer_shape_var():
|
188 | 188 | assert arg_shapes[1] == overwrite_shape
|
189 | 189 | assert out_shapes[0] == overwrite_shape
|
190 | 190 |
|
| 191 | + |
| 192 | +def test_symbol_magic_abs(): |
| 193 | + for dim in range(1, 7): |
| 194 | + with mx.name.NameManager(): |
| 195 | + data = mx.symbol.Variable('data') |
| 196 | + method = data.abs(name='abs0') |
| 197 | + magic = abs(data) |
| 198 | + regular = mx.symbol.abs(data, name='abs0') |
| 199 | + ctx = {'ctx': mx.context.current_context(), 'data': rand_shape_nd(dim)} |
| 200 | + mx.test_utils.check_consistency( |
| 201 | + [method, magic], ctx_list=[ctx, ctx]) |
| 202 | + mx.test_utils.check_consistency( |
| 203 | + [regular, magic], ctx_list=[ctx, ctx]) |
| 204 | + |
| 205 | + |
191 | 206 | def test_symbol_fluent():
|
192 | 207 | has_grad = set(['flatten', 'expand_dims', 'flip', 'tile', 'transpose', 'sum', 'nansum', 'prod',
|
193 | 208 | 'nanprod', 'mean', 'max', 'min', 'reshape', 'broadcast_to', 'split',
|
|
0 commit comments