Skip to content

Reset inputs after exported for out_grad #238

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Sep 27, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 55 additions & 49 deletions onnx_chainer/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ class RetainInputHook(chainer.LinkHook):
def __init__(self):
self.link_inputs = set()
self.retain_inputs = []
self.replaced_inputs = []

self.org_apply = chainer.function_node.FunctionNode.apply

Expand All @@ -181,6 +182,7 @@ def hooked_apply(_self, inputs):
# of function node. To avoid to lose reference out
# of the forward, retain the variable.
self.retain_inputs.append(referenced_var)
self.replaced_inputs.append((_self, _self.inputs))
_self.inputs = tuple(func_inodes)
return ret
self.hooked_apply = hooked_apply
Expand Down Expand Up @@ -213,6 +215,8 @@ def __enter__(self):

def __exit__(self, *exc_details):
    """Tear down the hook, undoing all monkeypatching done while active.

    Restores the original ``FunctionNode.apply`` (saved in ``__init__``
    as ``self.org_apply``) and gives every function node back its
    original ``inputs`` tuple, which the hooked ``apply`` had replaced
    (the pre-replacement tuples were recorded in
    ``self.replaced_inputs`` as ``(node, original_inputs)`` pairs).
    This leaves the model's graph unmodified after export.
    """
    # Undo the class-level patch of FunctionNode.apply.
    chainer.function_node.FunctionNode.apply = self.org_apply
    # Restore each node's original inputs tuple saved before replacement.
    for _self, inputs in self.replaced_inputs:
        _self.inputs = inputs
    super().__exit__(*exc_details)


Expand Down Expand Up @@ -327,7 +331,7 @@ def _export(model, args, filename, export_params, graph_name, save_text,
# if input shapes are invalid, raise exception before forwarding.
input_shapes = format_customized_shapes(args, input_shapes)

with RetainInputHook() as hook: # NOQA hook is not used, to keep retained value
with RetainInputHook():
# Forward computation
context = Context(model)
network_inputs = OrderedDict()
Expand Down Expand Up @@ -357,55 +361,57 @@ def _export(model, args, filename, export_params, graph_name, save_text,
'The \'args\' argument should be a list, tuple, dict, '
'numpy array, or Chainer Variable. But a {} object was '
'given.'.format(type(args)))
rename_variable_name(context, args, network_inputs, input_names)

initializers = []
input_tensors = []
param_names = set()
for org_name, param in model.namedparams():
# `model.namedparams()` has `include_uninit` flag but not use, to
# output user warning
if param.array is None:
warnings.warn(
'The parameter \'{}\' is not initialized, skip setting to '
'ONNX graph'.format(org_name))
continue
name = context.get_name(param)
param_names.add(name)
tensor = convert_parameter(param, context)
initializers.append(tensor)
input_tensors.append(helper.make_tensor_value_info(
name, tensor.data_type, tensor.dims))

for i, (name, var) in enumerate(network_inputs.items()):
shape = var.shape if input_shapes is None else input_shapes[i]
input_tensors.append(helper.make_tensor_value_info(
name, NP_TYPE_TO_TENSOR_TYPE[var.dtype], shape))
rename_variable_name(context, args, network_inputs, input_names)

if external_converters:
chainer.utils.experimental('external_converters')
converters = dict(mapping.converters, **external_converters)
else:
converters = mapping.converters

if isinstance(outputs, (list, tuple)):
flat_outputs = outputs
elif isinstance(outputs, dict):
flat_outputs = list(outputs.values())
elif isinstance(outputs, chainer.Variable):
flat_outputs = [outputs]
else:
raise RuntimeError(
'Unexpected output type from the model: {}'.format(type(outputs)))
if not all([isinstance(o, chainer.Variable) for o in flat_outputs]):
raise ValueError('The all \'outputs\' must be Chainer Variable')
network_outputs = OrderedDict(
[(context.get_name(var), var) for var in flat_outputs])
if output_names:
rename_variable_name(context, outputs, network_outputs, output_names)

o = Graph(context, converters, opset_version, network_outputs)
o.to_onnx_graph()
initializers = []
input_tensors = []
param_names = set()
for org_name, param in model.namedparams():
# `model.namedparams()` has `include_uninit` flag but not use, to
# output user warning
if param.array is None:
warnings.warn(
'The parameter \'{}\' is not initialized, skip setting to '
'ONNX graph'.format(org_name))
continue
name = context.get_name(param)
param_names.add(name)
tensor = convert_parameter(param, context)
initializers.append(tensor)
input_tensors.append(helper.make_tensor_value_info(
name, tensor.data_type, tensor.dims))

for i, (name, var) in enumerate(network_inputs.items()):
shape = var.shape if input_shapes is None else input_shapes[i]
input_tensors.append(helper.make_tensor_value_info(
name, NP_TYPE_TO_TENSOR_TYPE[var.dtype], shape))

if external_converters:
chainer.utils.experimental('external_converters')
converters = dict(mapping.converters, **external_converters)
else:
converters = mapping.converters

if isinstance(outputs, (list, tuple)):
flat_outputs = outputs
elif isinstance(outputs, dict):
flat_outputs = list(outputs.values())
elif isinstance(outputs, chainer.Variable):
flat_outputs = [outputs]
else:
raise RuntimeError(
'Unexpected output type from the model: {}'.format(
type(outputs)))
if not all([isinstance(o, chainer.Variable) for o in flat_outputs]):
raise ValueError('The all \'outputs\' must be Chainer Variable')
network_outputs = OrderedDict(
[(context.get_name(var), var) for var in flat_outputs])
if output_names:
rename_variable_name(
context, outputs, network_outputs, output_names)

o = Graph(context, converters, opset_version, network_outputs)
o.to_onnx_graph()

implicit_input_names = set(context.implicit_inputs.keys()) - param_names -\
set(network_inputs.keys())
Expand Down
10 changes: 6 additions & 4 deletions tests/test_export_testcase.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def model():
L.Convolution2D(8, 5, 5, 1, 2),
F.relu,
L.Linear(None, 100),
L.BatchNormalization(100),
F.relu,
L.Linear(100, 10)
)
Expand Down Expand Up @@ -51,9 +52,10 @@ def test_export_testcase(
out_names[0] if out_names else 'LinearFunction_1')


def test_output_grad(tmpdir, model, x, disable_experimental_warning):
@pytest.mark.parametrize('train', [True, False])
def test_output_grad(tmpdir, model, x, train, disable_experimental_warning):
path = str(tmpdir)
export_testcase(model, (x,), path, output_grad=True, train=True)
export_testcase(model, (x,), path, output_grad=True, train=train)

model_filename = os.path.join(path, 'model.onnx')
assert os.path.isfile(model_filename)
Expand All @@ -64,12 +66,12 @@ def test_output_grad(tmpdir, model, x, disable_experimental_warning):
initializer_names = {i.name for i in onnx_model.graph.initializer}

# 10 gradient files should be there
for i in range(10):
for i in range(12):
tensor_filename = os.path.join(
path, 'test_data_set_0', 'gradient_{}.pb'.format(i))
assert os.path.isfile(tensor_filename)
tensor = onnx.load_tensor(tensor_filename)
assert tensor.name.startswith('param_')
assert tensor.name in initializer_names
assert not os.path.isfile(
os.path.join(path, 'test_data_set_0', 'gradient_10.pb'))
os.path.join(path, 'test_data_set_0', 'gradient_12.pb'))