@@ -51,8 +51,9 @@ def network_structure_3():
51
51
ret = ret1 + ret2
52
52
ret = mx .sym .BatchNorm (ret )
53
53
ret = mx .sym .BatchNorm (ret )
54
- return (ret , ['data' ], [(2 , 3 , 10 , 10 )])
55
-
54
+ # Return the same and shape of 'data' and auxiliary states
55
+ return (ret , ['data' , * ret .list_auxiliary_states ()], [(2 ,3 ,10 ,10 ), (3 ,), (3 ,), (3 ,), (3 ,)])
56
+
56
57
def network_structure_4 ():
57
58
# the last op has multiple duplicate outputs
58
59
data = mx .sym .var ('data' , shape = (2 , 3 , 10 , 10 ))
@@ -84,23 +85,24 @@ def network_structure_7():
84
85
return (ret , ['data' ], [(1 ,)])
85
86
86
87
def get_graphs ():
87
- return [(network_structure_1 (), ['Convolution' ]),
88
+ return [
89
+ (network_structure_1 (), ['Convolution' ]),
88
90
(network_structure_2 (), ['exp' , 'sin' , '_Plus' , 'elemwise_add' , '_plus' ]),
89
91
(network_structure_2 (), ['exp' , 'cos' , '_Plus' , 'elemwise_add' , '_plus' ]),
90
- # To do: fix batch norm issue for gluon tests.
91
- #(network_structure_3(), ['exp', 'sin', '_Plus', 'elemwise_add', '_plus']),
92
- #(network_structure_3(), ['exp', 'cos', '_Plus', 'elemwise_add', '_plus']),
93
- #(network_structure_3(), ['exp', 'sin', '_Plus', 'elemwise_add', '_plus', 'BatchNorm']),
94
- #(network_structure_3(), ['exp', 'cos', '_Plus', 'elemwise_add', '_plus', 'BatchNorm']),
95
- #(network_structure_3(), ['exp', 'BatchNorm']),
96
- #(network_structure_3(), ['BatchNorm']),
92
+ (network_structure_3 (), ['exp' , 'sin' , '_Plus' , 'elemwise_add' , '_plus' ]),
93
+ (network_structure_3 (), ['exp' , 'cos' , '_Plus' , 'elemwise_add' , '_plus' ]),
94
+ (network_structure_3 (), ['exp' , 'sin' , '_Plus' , 'elemwise_add' , '_plus' , 'BatchNorm' ]),
95
+ (network_structure_3 (), ['exp' , 'cos' , '_Plus' , 'elemwise_add' , '_plus' , 'BatchNorm' ]),
96
+ (network_structure_3 (), ['exp' , 'BatchNorm' ]),
97
+ (network_structure_3 (), ['BatchNorm' ]),
97
98
(network_structure_4 (), ['exp' ]),
98
99
(network_structure_5 (), ['_plus' , '_Plus' , 'elemwise_add' ]),
99
100
(network_structure_6 (), []),
100
101
(network_structure_6 (), [mx .sym .sin .__name__ ]),
101
102
(network_structure_6 (), [mx .sym .Convolution .__name__ ]),
102
103
(network_structure_6 (), [mx .sym .sin .__name__ , mx .sym .Convolution .__name__ ]),
103
- (network_structure_7 (), ['sin' , 'elemwise_add' , '_plus' , '_Plus' ])]
104
+ (network_structure_7 (), ['sin' , 'elemwise_add' , '_plus' , '_Plus' ])
105
+ ]
104
106
105
107
def check_subgraph_exe1 (sym , subgraph_backend , op_names ):
106
108
"""Use the partitioned sym to simple_bind an executor and compare the outputs
0 commit comments