Commit d79c12d

Refer to PR "Allow operators with multiple outputs in get_atomic_symbol" (apache#15740)

1 parent 87fe065

3 files changed: +25 -11 lines

benchmark/python/gluon/benchmark_gluon.py

Lines changed: 7 additions & 7 deletions
@@ -46,9 +46,9 @@
                          'By default, use CPU only.')
 parser.add_argument('--type', type=str, default='inference', choices=['all', 'training', 'inference'])
 
-opt = parser.parse_args()
+args = parser.parse_args()
 
-num_batches = opt.num_batches
+num_batches = args.num_batches
 dry_run = 10  # use 10 iterations to warm up
 batch_inf = [1, 32, 64, 128, 256]
 batch_train = [1, 32, 64, 128, 256]
@@ -116,10 +116,10 @@ def train(network, batch_size, ctx):
     return bwd
 
 if __name__ == '__main__':
-    runtype = opt.type
-    bs = opt.batch_size
+    runtype = args.type
+    bs = args.batch_size
 
-    if opt.model == 'all':
+    if args.model == 'all':
         networks = ['alexnet', 'densenet121', 'densenet161', 'densenet169', 'densenet201',
                     'inceptionv3', 'mobilenet0.25', 'mobilenet0.5', 'mobilenet0.75',
                     'mobilenet1.0', 'mobilenetv2_0.25', 'mobilenetv2_0.5', 'mobilenetv2_0.75',
@@ -130,9 +130,9 @@ def train(network, batch_size, ctx):
         logging.info('It may take some time to run all models, '
                      'set --network to run a specific one')
     else:
-        networks = [opt.model]
+        networks = [args.model]
 
-    devs = [mx.gpu(int(i)) for i in opt.gpus.split(',')] if opt.gpus.strip() else [mx.cpu()]
+    devs = [mx.gpu(int(i)) for i in args.gpus.split(',')] if args.gpus.strip() else [mx.cpu()]
     num_gpus = len(devs)
 
     for network in networks:

src/c_api/c_api_symbolic.cc

Lines changed: 9 additions & 4 deletions
@@ -854,10 +854,15 @@ int MXGenAtomicSymbolFromSymbol(SymbolHandle sym_handle, SymbolHandle *ret_sym_h
   nnvm::Symbol *s = new nnvm::Symbol();
   API_BEGIN();
   nnvm::Symbol *source = static_cast<nnvm::Symbol *>(sym_handle);
-  CHECK_EQ(source->outputs.size(), 1U)
-      << "Generating atomic symbol from other symbol only works for nongrouped symbol.";
-  const auto& node = source->outputs[0];
-  const auto *op = node.node->op();
+  CHECK_GE(source->outputs.size(), 1) << "Input symbol does not have outputs.";
+  const auto &node = source->outputs[0].node;
+  for (const auto &other_node : source->outputs) {
+    if (node.get() != other_node.node.get()) {
+      LOG(FATAL)
+        << "Generating atomic symbol from other symbol only works for nongrouped symbol.";
+    }
+  }
+  const auto *op = node->op();
   const auto attrs = source->ListAttrs(nnvm::Symbol::ListAttrOption::kShallow);
   *s = nnvm::Symbol::CreateFunctor(op, attrs);
   *ret_sym_handle = s;
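For context, a minimal Python-level sketch of the behavior this hunk allows (the split, Group, relu, and sigmoid usages are illustrative, not part of this commit): a symbol whose outputs all come from one operator node can now be reduced to its atomic symbol, while a grouped symbol built from different nodes still ends in the "nongrouped symbol" error.

import mxnet as mx

# One operator node with several outputs: every entry in source->outputs
# points to the same node, so the new check passes and an atomic symbol
# for the underlying op is returned.
data = mx.sym.Variable('data')
parts = mx.sym.split(data, num_outputs=3)
atomic = parts._gen_atomic_symbol()

# A grouped symbol mixes outputs of different nodes, so the loop above
# still hits LOG(FATAL), which surfaces as an MXNetError in Python.
a = mx.sym.Variable('a')
b = mx.sym.Variable('b')
grouped = mx.sym.Group([mx.sym.relu(a), mx.sym.sigmoid(b)])
try:
    grouped._gen_atomic_symbol()
except mx.base.MXNetError as err:
    print('grouped symbol rejected:', err)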

tests/python/unittest/test_symbol.py

Lines changed: 9 additions & 0 deletions
@@ -373,6 +373,15 @@ def test_children_same_name():
     for c in b.get_children():
         pass
 
+def test_gen_atomic_symbol_multiple_outputs():
+    data=mx.sym.Variable('data')
+    p = mx.sym.Variable('param')
+    h0 = mx.sym.Variable('h0')
+    h1 = mx.sym.Variable('h1')
+    s = mx.sym.RNN(data, p, h0, h1, state_size=10, num_layers=2,
+                   bidirectional=True, state_outputs=True, mode='lstm')
+    atomic_sym = s._gen_atomic_symbol()
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
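As a note on why this test exercises the new code path (a small standalone sketch, not part of the commit): with state_outputs=True the LSTM RNN symbol exposes the output sequence plus its recurrent states, so it has more than one output and previously failed the CHECK_EQ(source->outputs.size(), 1U) in MXGenAtomicSymbolFromSymbol.

import mxnet as mx

data = mx.sym.Variable('data')
p = mx.sym.Variable('param')
h0 = mx.sym.Variable('h0')
h1 = mx.sym.Variable('h1')
s = mx.sym.RNN(data, p, h0, h1, state_size=10, num_layers=2,
               bidirectional=True, state_outputs=True, mode='lstm')

# The symbol has several outputs (output sequence plus the LSTM states),
# yet they all come from the single RNN node, so _gen_atomic_symbol()
# now succeeds instead of failing the old size check.
assert len(s.list_outputs()) > 1
atomic_sym = s._gen_atomic_symbol()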
