Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 7228343

Browse files
committed
Fixed auxiliary state issue for Gluon partition
1 parent 434c3c7 commit 7228343

File tree

2 files changed

+19
-14
lines changed

2 files changed

+19
-14
lines changed

python/mxnet/gluon/block.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -962,7 +962,7 @@ def _build_cache(self, *args):
962962
if name in data_names.keys():
963963
arg_array.append(args[data_names[name]])
964964
else:
965-
if ctx == None:
965+
if ctx is None:
966966
ctx = params.get(name)._ctx_list[0]
967967
arg_array.append(ndarray.random.uniform(shape=params.get(name)._shape))
968968
# Partition the graph.
@@ -1053,9 +1053,12 @@ def register_child(self, block, name=None):
10531053
super(HybridBlock, self).register_child(block, name)
10541054
self._clear_cached_op()
10551055

1056-
def hybridize(self, active=True, backend=None, backend_args={}, **kwargs):
1056+
def hybridize(self, active=True, backend=None, backend_args=None, **kwargs):
10571057
self._backend = backend
1058-
self._backend_args = backend_args
1058+
if backend_args is None:
1059+
self._backend_args = {}
1060+
else:
1061+
self._backend_args = backend_args
10591062
self._active = active
10601063
self._flags = list(kwargs.items())
10611064
self._clear_cached_op()

tests/python/unittest/test_subgraph_op.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,9 @@ def network_structure_3():
5151
ret = ret1 + ret2
5252
ret = mx.sym.BatchNorm(ret)
5353
ret = mx.sym.BatchNorm(ret)
54-
return (ret, ['data'], [(2, 3, 10, 10)])
55-
54+
# Return the names and shapes of 'data' and the auxiliary states
55+
return (ret, ['data', *ret.list_auxiliary_states()], [(2,3,10,10), (3,), (3,), (3,), (3,)])
56+
5657
def network_structure_4():
5758
# the last op has multiple duplicate outputs
5859
data = mx.sym.var('data', shape=(2, 3, 10, 10))
@@ -84,23 +85,24 @@ def network_structure_7():
8485
return (ret, ['data'], [(1,)])
8586

8687
def get_graphs():
87-
return [(network_structure_1(), ['Convolution']),
88+
return [
89+
(network_structure_1(), ['Convolution']),
8890
(network_structure_2(), ['exp', 'sin', '_Plus', 'elemwise_add', '_plus']),
8991
(network_structure_2(), ['exp', 'cos', '_Plus', 'elemwise_add', '_plus']),
90-
# To do: fix batch norm issue for gluon tests.
91-
#(network_structure_3(), ['exp', 'sin', '_Plus', 'elemwise_add', '_plus']),
92-
#(network_structure_3(), ['exp', 'cos', '_Plus', 'elemwise_add', '_plus']),
93-
#(network_structure_3(), ['exp', 'sin', '_Plus', 'elemwise_add', '_plus', 'BatchNorm']),
94-
#(network_structure_3(), ['exp', 'cos', '_Plus', 'elemwise_add', '_plus', 'BatchNorm']),
95-
#(network_structure_3(), ['exp', 'BatchNorm']),
96-
#(network_structure_3(), ['BatchNorm']),
92+
(network_structure_3(), ['exp', 'sin', '_Plus', 'elemwise_add', '_plus']),
93+
(network_structure_3(), ['exp', 'cos', '_Plus', 'elemwise_add', '_plus']),
94+
(network_structure_3(), ['exp', 'sin', '_Plus', 'elemwise_add', '_plus', 'BatchNorm']),
95+
(network_structure_3(), ['exp', 'cos', '_Plus', 'elemwise_add', '_plus', 'BatchNorm']),
96+
(network_structure_3(), ['exp', 'BatchNorm']),
97+
(network_structure_3(), ['BatchNorm']),
9798
(network_structure_4(), ['exp']),
9899
(network_structure_5(), ['_plus', '_Plus', 'elemwise_add']),
99100
(network_structure_6(), []),
100101
(network_structure_6(), [mx.sym.sin.__name__]),
101102
(network_structure_6(), [mx.sym.Convolution.__name__]),
102103
(network_structure_6(), [mx.sym.sin.__name__, mx.sym.Convolution.__name__]),
103-
(network_structure_7(), ['sin', 'elemwise_add', '_plus', '_Plus'])]
104+
(network_structure_7(), ['sin', 'elemwise_add', '_plus', '_Plus'])
105+
]
104106

105107
def check_subgraph_exe1(sym, subgraph_backend, op_names):
106108
"""Use the partitioned sym to simple_bind an executor and compare the outputs

0 commit comments

Comments
 (0)