Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit d464a47

Browse files
ZhennanQinpengzhao-intel
authored andcommitted
Fix subgraph with custom_op (#15671)
1 parent 3d366a3 commit d464a47

File tree

5 files changed

+81
-3
lines changed

5 files changed

+81
-3
lines changed

src/c_api/c_api_symbolic.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1046,8 +1046,10 @@ int MXGenBackendSubgraph(SymbolHandle sym_handle, const char *backend_name,
10461046
for (auto property : subgraph_prop_list) {
10471047
nnvm::Graph g = Symbol2Graph(*s);
10481048
property->SetAttr("graph", g);
1049-
g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(std::move(property));
1049+
g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(property);
10501050
g = ApplyPass(std::move(g), "BuildSubgraph");
1051+
property->RemoveAttr("graph");
1052+
g.attrs.erase("subgraph_property");
10511053
s->outputs = g.outputs;
10521054
}
10531055
*ret_sym_handle = s;

src/c_api/c_api_test.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,10 @@ int MXBuildSubgraphByOpNames(SymbolHandle sym_handle,
5050
g.outputs = s->outputs;
5151
property->SetAttr("graph", g);
5252
property->SetAttr("op_names", op_name_set);
53-
g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(std::move(property));
53+
g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(property);
5454
g = nnvm::ApplyPass(std::move(g), "BuildSubgraph");
55+
property->RemoveAttr("graph");
56+
g.attrs.erase("subgraph_property");
5557
s->outputs = g.outputs;
5658
}
5759
}

src/executor/graph_executor.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1688,8 +1688,10 @@ static nnvm::Symbol BuildSubgraph(const nnvm::Symbol& src, op::SubgraphPropertyP
16881688
g = InferForwardAttrs(g, arg_shapes, arg_dtypes, arg_stypes, default_ctx, ctx_map, in_arg_ctxes,
16891689
aux_state_ctxes, true);
16901690
subgraph_prop->SetAttr("graph", g);
1691-
g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(std::move(subgraph_prop));
1691+
g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(subgraph_prop);
16921692
g = ApplyPass(std::move(g), "BuildSubgraph");
1693+
subgraph_prop->RemoveAttr("graph");
1694+
g.attrs.erase("subgraph_property");
16931695
ret.outputs = g.outputs;
16941696
return ret;
16951697
}

src/operator/subgraph/subgraph_property.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,15 @@ class SubgraphProperty {
330330
auto it = attrs_.find(name);
331331
return it != attrs_.end();
332332
}
333+
/*!
334+
* \brief Remove attr if the attr exists.
335+
*/
336+
void RemoveAttr(const std::string& name) {
337+
auto it = attrs_.find(name);
338+
if (it != attrs_.end()) {
339+
attrs_.erase(it);
340+
}
341+
}
333342
/*!
334343
* \brief Get the property type.
335344
*/
@@ -384,6 +393,16 @@ class SubgraphBackend {
384393
return it != attrs_.end();
385394
}
386395

396+
/*!
397+
* \brief Remove attr if the attr exists.
398+
*/
399+
void RemoveAttr(const std::string& name) {
400+
auto it = attrs_.find(name);
401+
if (it != attrs_.end()) {
402+
attrs_.erase(it);
403+
}
404+
}
405+
387406
SubgraphPropertyPtr& RegisterSubgraphProperty(const SubgraphPropertyPtr prop) {
388407
prop_ptr_.push_back(prop);
389408
return prop_ptr_.back();

tests/python/unittest/test_subgraph.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,59 @@ def make_subgraph4(stype):
144144
rtol=0.001, atol=0.0001)
145145

146146

147+
def test_subgraph_with_customOp():
148+
class MyAdd(mx.operator.CustomOp):
149+
def forward(self, is_train, req, in_data, out_data, aux):
150+
self.assign(out_data[0], req[0], in_data[0] + 1)
151+
152+
def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
153+
self.assign(in_grad[0], req[0], out_grad[0])
154+
155+
@mx.operator.register('MyAdd1')
156+
class MyAdd1Prop(mx.operator.CustomOpProp):
157+
def __init__(self):
158+
super(MyAdd1Prop, self).__init__(need_top_grad=True)
159+
160+
def list_arguments(self):
161+
return ['data']
162+
163+
def list_outputs(self):
164+
return ['output']
165+
166+
def infer_shape(self, in_shape):
167+
# inputs, outputs, aux
168+
return [in_shape[0]], [in_shape[0]], []
169+
170+
def create_operator(self, ctx, shapes, dtypes):
171+
return MyAdd()
172+
173+
@mx.operator.register('MyAdd2')
174+
class MyAdd2Prop(mx.operator.CustomOpProp):
175+
def __init__(self):
176+
super(MyAdd2Prop, self).__init__(need_top_grad=True)
177+
178+
def list_arguments(self):
179+
return ['data']
180+
181+
def list_outputs(self):
182+
return ['output']
183+
184+
def infer_shape(self, in_shape):
185+
# inputs, outputs, aux
186+
return [in_shape[0]], [in_shape[0]], []
187+
188+
def create_operator(self, ctx, shapes, dtypes):
189+
return MyAdd()
190+
191+
inp = mx.nd.zeros(shape=(100, 100))
192+
a = mx.symbol.Variable('a')
193+
b = a + 1
194+
b = mx.symbol.Custom(data=a, op_type='MyAdd1')
195+
c = mx.symbol.Custom(data=a, op_type='MyAdd2')
196+
b.bind(mx.cpu(), {'a': inp}).forward()
197+
c.bind(mx.cpu(), {'a': inp}).forward()
198+
mx.nd.waitall()
199+
147200
if __name__ == '__main__':
148201
import nose
149202
nose.runmodule()

0 commit comments

Comments
 (0)