
Move rename_tensor to function_hook #136

Merged
merged 4 commits into from Mar 15, 2019

108 changes: 64 additions & 44 deletions onnx_chainer/export.py
@@ -1,6 +1,7 @@
from __future__ import print_function

import collections
from collections import OrderedDict
import warnings

import chainer
@@ -48,38 +49,6 @@ def convert_parameter(parameter, context):
return numpy_helper.from_array(array, context.get_name(parameter))


def rename_tensors(model, force_rename_network_out=True, network_outputs=None):
names = {v.name: v.name for v in model.graph.initializer}
network_output_names = set() if force_rename_network_out else\
{o.name for o in model.graph.output}
op_counts = collections.defaultdict(int)

for op in model.graph.node:
for i, input_name in enumerate(op.input):
if input_name not in names:
names[input_name] = input_name
op.input[i] = names[input_name]

op_name = '{}_{}'.format(op.op_type, op_counts[op.op_type])
op_counts[op.op_type] += 1
for i, output_name in enumerate(op.output):
if output_name in network_output_names:
continue
if len(op.output) == 1:
names[output_name] = op_name
else:
names[output_name] = '{}_{}'.format(op_name, i)
op.output[i] = names[output_name]
if output_name in network_outputs:
var = network_outputs[output_name]
del network_outputs[output_name]
network_outputs[names[output_name]] = var

for v in tuple(model.graph.input) + tuple(model.graph.output):
if v.name in names:
v.name = names[v.name]


def rename_variable_name(
context, variables, named_vars, new_names, prefix='Input'):
# Update ``named_vars`` keys to ``new_names``
@@ -136,14 +105,21 @@ def rename_variable_name(

class ONNXExport(chainer.FunctionHook):

def __init__(self, context, converters, opset_version=None):
def __init__(
self, context, converters, opset_version, is_output_renamed,
network_outputs):
self.context = context
self.converters = converters

self.graph = []
# Converter nodes keyed by "number:func_name"
self.converted_nodes = OrderedDict()
self.func_name_counts = collections.defaultdict(int)
self.inputs = {} # Input `Variable` objects keyed by string IDs
self.additional_parameters = []
self.specified_opset_version = opset_version
self.is_output_renamed = is_output_renamed
self.network_outputs = network_outputs

def create_node(
self, func_name, func, input_names, output_names, parameters):
@@ -164,6 +140,10 @@ def backward_postprocess(self, function, in_data, out_grad):
if isinstance(function, chainer.function.FunctionAdapter):
function = function.function
func_name = function.__class__.__name__
temp_node_name = '{}:{}'.format(
self.func_name_counts[func_name], func_name)
self.func_name_counts[func_name] += 1

input_names = []
for i in function.inputs:
# 'i' is a VariableNode, so check if it has a Variable/Parameter
@@ -191,7 +171,52 @@ def backward_postprocess(self, function, in_data, out_grad):
nodes = self.create_node(
func_name, function, input_names, output_names,
self.additional_parameters)
self.graph.extend(nodes)
self.converted_nodes[temp_node_name] = nodes

def deleted(self, function=None):
Review comment (Member): Could you add a brief comment to explain why we need to use deleted?

"""Rename output names.

When renaming an output name, another node can reference the same value
as input, so the input name must be renamed at once. So this renaming
process should be run after all functions are converted and this
`deleted` function is called when function hook is done, which means
all functions are converted.

If input/output names are given externally, these given names take
priority over named by this process.
"""
func_name_counts = collections.defaultdict(int)
names = {}
for temp_func_name, nodes in reversed(self.converted_nodes.items()):
func_name = temp_func_name[temp_func_name.index(':')+1:]
base_node_name = '{}_{}'.format(
func_name, func_name_counts[func_name])
func_name_counts[func_name] += 1
for num, node in enumerate(reversed(nodes)):
if len(nodes) > 1 and num != len(nodes)-1:
node_name = '{}_tmp_{}'.format(base_node_name, num)
else:
node_name = base_node_name
node.name = node_name

for i, input_name in enumerate(node.input):
if input_name not in names:
names[input_name] = input_name
node.input[i] = names[input_name]

for i, output_name in enumerate(node.output):
if self.is_output_renamed:
continue
elif len(node.output) == 1:
names[output_name] = node_name
else:
names[output_name] = '{}_{}'.format(node_name, i)
node.output[i] = names[output_name]
if output_name in self.network_outputs:
var = self.network_outputs[output_name]
del self.network_outputs[output_name]
self.network_outputs[names[output_name]] = var
self.graph.append(node)


def export(model, args, filename=None, export_params=True,
@@ -365,10 +390,11 @@ def _export(model, args, filename, export_params, graph_name, save_text,
'Unexpected output type from the model: {}'.format(type(outputs)))
network_outputs = {context.get_name(var): var for var in flat_outputs}
if output_names:
rename_variable_name(
context, outputs, network_outputs, output_names)
rename_variable_name(context, outputs, network_outputs, output_names)
# Backward computation to construct graph
with ONNXExport(context, converters, opset_version) as o:
with ONNXExport(
context, converters, opset_version, (output_names is not None),
network_outputs) as o:
chainer.grad(flat_outputs, list(model.params()) + flat_args)

implicit_input_names = set(o.inputs.keys()) - param_names -\
@@ -387,9 +413,6 @@ def _export(model, args, filename, export_params, graph_name, save_text,
input_tensors.append(helper.make_tensor_value_info(
context.get_name(param), tensor.data_type, tensor.dims))

# The graph must be topologically sorted
graph = reversed(o.graph)

# Convert output tensors
output_tensors = []
for name, var in network_outputs.items():
@@ -400,7 +423,7 @@ def _export(model, args, filename, export_params, graph_name, save_text,
initializers = []

onnx_graph = helper.make_graph(
graph, graph_name, input_tensors, output_tensors,
o.graph, graph_name, input_tensors, output_tensors,
initializer=initializers)

opset_imports = [helper.make_operatorsetid('', opset_version)]
@@ -417,9 +440,6 @@ def _export(model, args, filename, export_params, graph_name, save_text,

model.ir_version = onnx.IR_VERSION

rename_tensors(
model, force_rename_network_out=not output_names,
network_outputs=network_outputs)
try:
checker.check_model(model)
except onnx.checker.ValidationError as e:
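
To illustrate the pattern discussed in the review thread above (deferring renaming to the hook's deleted callback), here is a minimal, hypothetical sketch of a chainer.FunctionHook that records nodes during backward and only assigns final '{func_name}_{count}' names once the hook is released. The NamingHook class, its bookkeeping attributes, and the commented usage are illustrative assumptions, not code from this PR.

import collections

import chainer


class NamingHook(chainer.FunctionHook):
    """Assumed example: collect nodes during backward, name them on release."""

    name = 'NamingHook'

    def __init__(self):
        self.func_name_counts = collections.defaultdict(int)
        self.converted_nodes = collections.OrderedDict()
        self.final_names = []

    def backward_postprocess(self, function, in_data, out_grad):
        if isinstance(function, chainer.function.FunctionAdapter):
            function = function.function
        func_name = function.__class__.__name__
        # Temporary "count:FuncName" key; final naming is deferred because a
        # later-converted node may still reference this node's outputs.
        temp_name = '{}:{}'.format(
            self.func_name_counts[func_name], func_name)
        self.func_name_counts[func_name] += 1
        self.converted_nodes[temp_name] = function

    def deleted(self, function=None):
        # Called when the hook is released (e.g. on leaving the ``with``
        # block), so every function has been converted and names can be
        # assigned consistently, walking the recorded nodes in reverse.
        counts = collections.defaultdict(int)
        for temp_name in reversed(self.converted_nodes):
            func_name = temp_name.split(':', 1)[1]
            self.final_names.append(
                '{}_{}'.format(func_name, counts[func_name]))
            counts[func_name] += 1


# Hypothetical usage: ``model`` and ``x`` are placeholders.
# with NamingHook() as hook:
#     chainer.grad([model(x)], list(model.params()) + [x])
# print(hook.final_names)  # e.g. ['LinearFunction_0', 'LinearFunction_1']
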
3 changes: 2 additions & 1 deletion tests/test_export_testcase.py
@@ -47,7 +47,8 @@ def test_export_testcase(
output_pb_path = os.path.join(path, 'test_data_set_0', 'output_0.pb')
assert os.path.isfile(output_pb_path)
output_tensor = onnx.load_tensor(output_pb_path)
assert output_tensor.name == (out_names[0] if out_names else 'Gemm_1')
assert output_tensor.name == (
out_names[0] if out_names else 'LinearFunction_1')


def test_output_grad(tmpdir, model, x, desable_experimental_warning):
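
Note on the test change above: because node naming now happens in the hook's deleted callback, default names are derived from the Chainer function class observed during backward (function.__class__.__name__, e.g. LinearFunction) rather than from the ONNX op type used by the removed rename_tensors (op.op_type, e.g. Gemm). Hence the expected default output name changes from 'Gemm_1' to 'LinearFunction_1' under the '{func_name}_{count}' scheme.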