Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit b5751e9

Browse files
committed
fix
1 parent 9884b32 commit b5751e9

File tree

2 files changed

+7
-9
lines changed

2 files changed

+7
-9
lines changed

python/mxnet/ndarray.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,9 @@ def __le__(self, other):
258258
return lesser_equal(self, other)
259259

260260
def __bool__(self):
261-
raise ValueError("The truth value of an NDArray with more than one element is ambiguous.")
261+
raise ValueError("The truth value of an NDArray is ambiguous. " \
262+
"Please convert to number with asscalar() first.")
263+
262264
__nonzero__ = __bool__
263265

264266
def __getstate__(self):

python/mxnet/operator.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -471,7 +471,7 @@ def infer_shape(self, in_shape):
471471
List of aux shapes calculated from in_shape,
472472
in the same order as declared in list_auxiliary_states.
473473
"""
474-
return in_shape, [in_shape[0]], []
474+
return in_shape, (in_shape[0],)*len(self.list_outputs()), ()
475475

476476
def infer_type(self, in_type):
477477
"""infer_type interface. override to create new operators
@@ -753,9 +753,7 @@ def forward_entry(num_ndarray, ndarraies, tags, reqs, is_train, _):
753753
NDArrayHandle),
754754
writable=False))
755755
reqs = [req_enum[reqs[i]] for i in range(len(tensors[1]))]
756-
op.forward(is_train=is_train, req=reqs,
757-
in_data=tensors[0], out_data=tensors[1],
758-
aux=tensors[4])
756+
op.forward(is_train, reqs, tensors[0], tensors[1], tensors[4])
759757
except Exception:
760758
print('Error in CustomOp.forward: %s' % traceback.format_exc())
761759
return False
@@ -776,10 +774,8 @@ def backward_entry(num_ndarray, ndarraies, tags, reqs, is_train, _):
776774
NDArrayHandle),
777775
writable=False))
778776
reqs = [req_enum[reqs[i]] for i in range(len(tensors[2]))]
779-
op.backward(req=reqs,
780-
in_data=tensors[0], out_data=tensors[1],
781-
in_grad=tensors[2], out_grad=tensors[3],
782-
aux=tensors[4])
777+
op.backward(reqs, tensors[0], tensors[1], tensors[2],
778+
tensors[3], tensors[4])
783779
except Exception:
784780
print('Error in CustomOp.backward: %s' % traceback.format_exc())
785781
return False

0 commit comments

Comments
 (0)