diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc index 020c0d17f0d1..74ab0e885277 100644 --- a/src/c_api/c_api_symbolic.cc +++ b/src/c_api/c_api_symbolic.cc @@ -1060,10 +1060,15 @@ int MXGenAtomicSymbolFromSymbol(SymbolHandle sym_handle, SymbolHandle *ret_sym_h nnvm::Symbol *s = new nnvm::Symbol(); API_BEGIN(); nnvm::Symbol *source = static_cast(sym_handle); - CHECK_EQ(source->outputs.size(), 1U) - << "Generating atomic symbol from other symbol only works for nongrouped symbol."; - const auto &node = source->outputs[0]; - const auto *op = node.node->op(); + CHECK_GE(source->outputs.size(), 1) << "Input symbol does not have outputs."; + const auto &node = source->outputs[0].node; + for (const auto &other_node : source->outputs) { + if (node.get() != other_node.node.get()) { + LOG(FATAL) + << "Generating atomic symbol from other symbol only works for nongrouped symbol."; + } + } + const auto *op = node->op(); const auto attrs = source->ListAttrs(nnvm::Symbol::ListAttrOption::kShallow); *s = nnvm::Symbol::CreateFunctor(op, attrs); *ret_sym_handle = s; diff --git a/tests/python/unittest/test_symbol.py b/tests/python/unittest/test_symbol.py index 963b32493b44..28f302f9ec15 100644 --- a/tests/python/unittest/test_symbol.py +++ b/tests/python/unittest/test_symbol.py @@ -389,6 +389,15 @@ def test_children_same_name(): for c in b.get_children(): pass +def test_gen_atomic_symbol_multiple_outputs(): + data=mx.sym.Variable('data') + p = mx.sym.Variable('param') + h0 = mx.sym.Variable('h0') + h1 = mx.sym.Variable('h1') + s = mx.sym.RNN(data, p, h0, h1, state_size=10, num_layers=2, + bidirectional=True, state_outputs=True, mode='lstm') + atomic_sym = s._gen_atomic_symbol() + if __name__ == '__main__': import nose nose.runmodule()