Skip to content

Commit cf0b98d

Browse files
authored
Merge pull request #185 from disktnk/fix/bn-param-name
Provide more appropriate naming for the fourth and fifth arguments of BN
2 parents fd8e29a + 7711fec commit cf0b98d

File tree

3 files changed

+91
-34
lines changed

3 files changed

+91
-34
lines changed

onnx_chainer/context.py

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,15 @@ class Context(object):
2020
def __init__(self, model):
2121
self.name_list = dict()
2222
self.parameters = []
23+
namedlink = {n: l for n, l in model.namedlinks()}
24+
self.param_to_link = {}
2325
for name, param in model.namedparams():
26+
owned_link_name = name[:name.rindex('/')]
27+
if owned_link_name in namedlink:
28+
onnx_owned_link_name = onnx_helper.cleanse_param_name(
29+
owned_link_name)
30+
self.param_to_link[id(param)] = (
31+
onnx_owned_link_name, namedlink[owned_link_name])
2432
onnx_name = onnx_helper.cleanse_param_name(name)
2533
self.set_name(param, onnx_name)
2634

@@ -56,19 +64,33 @@ def is_pinned(self, variable):
5664
return False
5765
return self.name_list[str_id][1]
5866

59-
def add_param(self, array, name):
67+
def add_param(self, array, name, use_original_name=False):
6068
"""Add array to context parameter
6169
6270
To be converted as ONNX tensor.
6371
64-
Return:
65-
(str) registered name.
72+
Returns:
73+
str: registered name.
6674
"""
67-
param = chainer.Parameter(array)
68-
if not (name.startswith('/') or name.startswith('_')):
69-
name = '/' + name
70-
onnx_name = '{}_{}'.format(
71-
onnx_helper.get_func_name(), onnx_helper.cleanse_param_name(name))
72-
self.set_name(param, onnx_name)
73-
self.parameters.append(param)
75+
if use_original_name:
76+
onnx_name = name
77+
else:
78+
if not (name.startswith('/') or name.startswith('_')):
79+
name = '/' + name
80+
onnx_name = '{}_{}'.format(
81+
onnx_helper.get_func_name(),
82+
onnx_helper.cleanse_param_name(name))
83+
self.set_name(array, onnx_name)
84+
self.parameters.append(array)
7485
return onnx_name
86+
87+
def get_link(self, param):
88+
"""Return link with name which has the param.
89+
90+
Arguments:
91+
param(chainer.Parameter): the target param.
92+
93+
Returns:
94+
tuple: name and link. returns ``None`` when not found.
95+
"""
96+
return self.param_to_link.get(id(param), None)

onnx_chainer/functions/normalization.py

Lines changed: 46 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -10,36 +10,58 @@
1010
@support((1, 6, 7))
1111
def convert_BatchNormalization(func, opset_version, input_names,
1212
output_names, context, parameters):
13-
if len(func.inputs) <= 3:
14-
# expect this `func` is F.batch_normalization
15-
x = func.inputs[0].get_variable().array
16-
mean = x.mean(axis=func.axis)
17-
param_mean_name = context.add_param(mean, 'mean')
18-
input_names.append(param_mean_name)
19-
param_var_name = context.add_param(x.var(axis=func.axis), 'var')
20-
input_names.append(param_var_name)
13+
is_fixed_bn = len(func.inputs) > 3
14+
15+
# NOTE(disktnk):
16+
# if `use_beta=False`, beta_param is None, `use_gamma=False` is same.
17+
beta_param = func.inputs[2].get_variable_or_none()
18+
gamma_param = func.inputs[1].get_variable_or_none()
19+
namedlink = context.get_link(beta_param) or context.get_link(gamma_param)
20+
21+
if namedlink is not None:
22+
prefix, link = namedlink
23+
if is_fixed_bn:
24+
mean = link.avg_mean
25+
var = link.avg_var
26+
else:
27+
# on train mode, avg_mean would be updated, so make them from x
28+
x = func.inputs[0].get_variable().array
29+
mean = x.mean(axis=func.axis)
30+
var = x.var(axis=func.axis)
2131
else:
22-
# expect this `func` is F.fixed_batch_normalization
23-
mean = func.inputs[3].get_variable().array
24-
param_mean_name = context.add_param(mean, 'mean')
25-
input_names[3] = param_mean_name
26-
param_var_name = context.add_param(
27-
func.inputs[4].get_variable().array, 'var')
28-
input_names[4] = param_var_name
32+
prefix = None
33+
if is_fixed_bn:
34+
mean = func.inputs[3].get_variable().array
35+
var = func.inputs[4].get_variable().array
36+
else:
37+
x = func.inputs[0].get_variable().array
38+
mean = x.mean(axis=func.axis)
39+
var = x.var(axis=func.axis)
2940

30-
momentum = getattr(func, 'decay', 0.)
41+
def add_param(v, suffix):
42+
if prefix is None:
43+
return context.add_param(v, suffix)
44+
else:
45+
return context.add_param(
46+
v, '{}_{}'.format(prefix, suffix), use_original_name=True)
3147

32-
# if `use_beta=False`, passed None value to the functions
33-
if func.inputs[2].get_variable_or_none() is None:
34-
beta_name = context.add_param(
35-
np.zeros_like(mean, dtype=mean.dtype), 'beta')
48+
maen_name = add_param(mean, 'avg_mean')
49+
var_name = add_param(var, 'avg_var')
50+
if is_fixed_bn:
51+
input_names[3] = maen_name
52+
input_names[4] = var_name
53+
else:
54+
input_names.extend([maen_name, var_name])
55+
56+
if beta_param is None:
57+
beta_name = add_param(np.zeros_like(mean, dtype=mean.dtype), 'beta')
3658
input_names[2] = beta_name
37-
# `use_gamma=False` is same
38-
if func.inputs[1].get_variable_or_none() is None:
39-
gamma_name = context.add_param(
40-
np.ones_like(mean, dtype=mean.dtype), 'gamma')
59+
if gamma_param is None:
60+
gamma_name = add_param(np.ones_like(mean, dtype=mean.dtype), 'gamma')
4161
input_names[1] = gamma_name
4262

63+
momentum = getattr(func, 'decay', 0.)
64+
4365
# TODO(disktnk): On definition of ONNX's BatchNormalization operator,
4466
# outputs one required output and four optional outputs. This converter
4567
# must make 5 values for output and return them.

tests/functions_tests/test_normalizations.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
import chainer.functions as F
33
import chainer.links as L
44
from chainer import testing
5+
import onnx
56

7+
import onnx_chainer
68
from onnx_chainer.testing import input_generator
79
from tests.helper import ONNXModelTest
810

@@ -82,3 +84,14 @@ def test_output(self):
8284
if hasattr(self, 'condition'):
8385
name += '_' + self.condition
8486
self.expect(self.model, self.x, name=name, train=train)
87+
88+
def test_input_names(self):
89+
for opset_version in range(
90+
onnx_chainer.MINIMUM_OPSET_VERSION,
91+
onnx.defs.onnx_opset_version() + 1):
92+
onnx_model = onnx_chainer.export(
93+
self.model, self.x, opset_version=opset_version)
94+
input_names = set(v.name for v in onnx_model.graph.input)
95+
96+
assert 'param_bn_avg_mean' in input_names
97+
assert 'param_bn_avg_var' in input_names

0 commit comments

Comments
 (0)