@@ -471,7 +471,7 @@ def infer_shape(self, in_shape):
             List of aux shapes calculated from in_shape,
             in the same order as declared in list_auxiliary_states.
         """
-        return in_shape, [in_shape[0]], []
+        return in_shape, (in_shape[0],) * len(self.list_outputs()), ()
 
     def infer_type(self, in_type):
         """infer_type interface. override to create new operators
@@ -753,9 +753,7 @@ def forward_entry(num_ndarray, ndarraies, tags, reqs, is_train, _):
                                                      NDArrayHandle),
                                                 writable=False))
         reqs = [req_enum[reqs[i]] for i in range(len(tensors[1]))]
-        op.forward(is_train=is_train, req=reqs,
-                   in_data=tensors[0], out_data=tensors[1],
-                   aux=tensors[4])
+        op.forward(is_train, reqs, tensors[0], tensors[1], tensors[4])
     except Exception:
         print('Error in CustomOp.forward: %s' % traceback.format_exc())
         return False
@@ -776,10 +774,8 @@ def backward_entry(num_ndarray, ndarraies, tags, reqs, is_train, _):
                                                      NDArrayHandle),
                                                 writable=False))
         reqs = [req_enum[reqs[i]] for i in range(len(tensors[2]))]
-        op.backward(req=reqs,
-                    in_data=tensors[0], out_data=tensors[1],
-                    in_grad=tensors[2], out_grad=tensors[3],
-                    aux=tensors[4])
+        op.backward(reqs, tensors[0], tensors[1], tensors[2],
+                    tensors[3], tensors[4])
     except Exception:
         print('Error in CustomOp.backward: %s' % traceback.format_exc())
         return False
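Both entry points now call the user's operator positionally rather than with keyword arguments, so overrides must accept the arguments in the order used above: forward(is_train, req, in_data, out_data, aux) and backward(req, in_data, out_data, in_grad, out_grad, aux), since tensors[0] through tensors[4] hold in_data, out_data, in_grad, out_grad and aux respectively. A minimal sketch of a custom operator written against that convention follows; the identity operation is a made-up example, not part of this change.

import mxnet as mx

class IdentityOp(mx.operator.CustomOp):
    """Hypothetical pass-through operator matching the positional calls
    made by forward_entry and backward_entry."""
    def forward(self, is_train, req, in_data, out_data, aux):
        # Copy the single input to the single output.
        out_data[0][:] = in_data[0]

    def backward(self, req, in_data, out_data, in_grad, out_grad, aux):
        # The gradient of the identity passes out_grad straight through.
        in_grad[0][:] = out_grad[0]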