Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit cf28b46

Browse files
kshitij12345wkcn
authored andcommitted
Add magic method abs to NDArray and Symbol. (#15680)
* add magic method abs to ndarray * add relevant tests * add magic method abs to symbol * add relevant tests * retrigger CI * retrigger CI
1 parent b07211f commit cf28b46

File tree

4 files changed

+33
-1
lines changed

4 files changed

+33
-1
lines changed

python/mxnet/ndarray/ndarray.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,10 @@ def _to_shared_mem(self):
205205
self.handle, ctypes.byref(shared_pid), ctypes.byref(shared_id)))
206206
return shared_pid.value, shared_id.value, self.shape, self.dtype
207207

208+
def __abs__(self):
209+
"""x.__abs__() <=> abs(x) <=> x.abs() <=> mx.nd.abs(x, y)"""
210+
return self.abs()
211+
208212
def __add__(self, other):
209213
"""x.__add__(y) <=> x+y <=> mx.nd.add(x, y) """
210214
return add(self, other)

python/mxnet/symbol/symbol.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,10 @@ def __iter__(self):
9393
"""
9494
return (self[i] for i in range(len(self)))
9595

96+
def __abs__(self):
97+
"""x.__abs__() <=> abs(x) <=> x.abs() <=> mx.symbol.abs(x, y)"""
98+
return self.abs()
99+
96100
def __add__(self, other):
97101
"""x.__add__(y) <=> x+y
98102

tests/python/unittest/test_ndarray.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,15 @@ def test_ndarray_negate():
172172
assert_almost_equal(npy, arr.asnumpy())
173173

174174

175+
@with_seed()
176+
def test_ndarray_magic_abs():
177+
for dim in range(1, 7):
178+
shape = rand_shape_nd(dim)
179+
npy = np.random.uniform(-10, 10, shape)
180+
arr = mx.nd.array(npy)
181+
assert_almost_equal(abs(arr).asnumpy(), arr.abs().asnumpy())
182+
183+
175184
@with_seed()
176185
def test_ndarray_reshape():
177186
tensor = (mx.nd.arange(30) + 1).reshape(2, 3, 5)

tests/python/unittest/test_symbol.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import numpy as np
2323
from common import assertRaises, models
2424
from mxnet.base import NotImplementedForSymbol
25-
from mxnet.test_utils import discard_stderr
25+
from mxnet.test_utils import discard_stderr, rand_shape_nd
2626
import pickle as pkl
2727

2828
def test_symbol_basic():
@@ -188,6 +188,21 @@ def test_symbol_infer_shape_var():
188188
assert arg_shapes[1] == overwrite_shape
189189
assert out_shapes[0] == overwrite_shape
190190

191+
192+
def test_symbol_magic_abs():
193+
for dim in range(1, 7):
194+
with mx.name.NameManager():
195+
data = mx.symbol.Variable('data')
196+
method = data.abs(name='abs0')
197+
magic = abs(data)
198+
regular = mx.symbol.abs(data, name='abs0')
199+
ctx = {'ctx': mx.context.current_context(), 'data': rand_shape_nd(dim)}
200+
mx.test_utils.check_consistency(
201+
[method, magic], ctx_list=[ctx, ctx])
202+
mx.test_utils.check_consistency(
203+
[regular, magic], ctx_list=[ctx, ctx])
204+
205+
191206
def test_symbol_fluent():
192207
has_grad = set(['flatten', 'expand_dims', 'flip', 'tile', 'transpose', 'sum', 'nansum', 'prod',
193208
'nanprod', 'mean', 'max', 'min', 'reshape', 'broadcast_to', 'split',

0 commit comments

Comments
 (0)