Skip to content

Provide more appropriate naming for the fourth and fifth arguments of BN #185

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jun 20, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 32 additions & 10 deletions onnx_chainer/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,15 @@ class Context(object):
def __init__(self, model):
    """Register a name for every parameter of ``model``.

    Each parameter yielded by ``model.namedparams()`` is given its
    cleansed name via ``set_name``. In addition, when the part of the
    parameter name before its last ``/`` matches a name yielded by
    ``model.namedlinks()``, ``self.param_to_link`` maps ``id(param)`` to
    a ``(cleansed link name, link)`` tuple.

    Arguments:
        model: object providing ``namedlinks()`` and ``namedparams()``.
    """
    self.name_list = dict()
    self.parameters = []
    self.param_to_link = {}
    links_by_name = dict(model.namedlinks())
    for param_path, param in model.namedparams():
        # Everything before the last '/' is the path of the owning link.
        link_path = param_path[:param_path.rindex('/')]
        if link_path in links_by_name:
            link_entry = (
                onnx_helper.cleanse_param_name(link_path),
                links_by_name[link_path],
            )
            self.param_to_link[id(param)] = link_entry
        self.set_name(param, onnx_helper.cleanse_param_name(param_path))

Expand Down Expand Up @@ -56,19 +64,33 @@ def is_pinned(self, variable):
return False
return self.name_list[str_id][1]

def add_param(self, array, name):
def add_param(self, array, name, use_original_name=False):
"""Add array to context parameter

To be converted as ONNX tensor.

Return:
(str) registered name.
Returns:
str: registered name.
"""
param = chainer.Parameter(array)
if not (name.startswith('/') or name.startswith('_')):
name = '/' + name
onnx_name = '{}_{}'.format(
onnx_helper.get_func_name(), onnx_helper.cleanse_param_name(name))
self.set_name(param, onnx_name)
self.parameters.append(param)
if use_original_name:
onnx_name = name
else:
if not (name.startswith('/') or name.startswith('_')):
name = '/' + name
onnx_name = '{}_{}'.format(
onnx_helper.get_func_name(),
onnx_helper.cleanse_param_name(name))
self.set_name(array, onnx_name)
self.parameters.append(array)
return onnx_name

def get_link(self, param):
    """Look up the link registered for ``param``.

    The lookup is keyed on ``id(param)`` in ``self.param_to_link``.

    Arguments:
        param(chainer.Parameter): the target param.

    Returns:
        tuple: the ``(name, link)`` entry stored for ``param``, or
        ``None`` when there is no entry.
    """
    key = id(param)
    if key in self.param_to_link:
        return self.param_to_link[key]
    return None
70 changes: 46 additions & 24 deletions onnx_chainer/functions/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,36 +10,58 @@
@support((1, 6, 7))
def convert_BatchNormalization(func, opset_version, input_names,
output_names, context, parameters):
if len(func.inputs) <= 3:
# Expect this `func` to be F.batch_normalization
x = func.inputs[0].get_variable().array
mean = x.mean(axis=func.axis)
param_mean_name = context.add_param(mean, 'mean')
input_names.append(param_mean_name)
param_var_name = context.add_param(x.var(axis=func.axis), 'var')
input_names.append(param_var_name)
is_fixed_bn = len(func.inputs) > 3

# NOTE(disktnk):
# If `use_beta=False`, beta_param is None; likewise, gamma_param is None
# when `use_gamma=False`.
beta_param = func.inputs[2].get_variable_or_none()
gamma_param = func.inputs[1].get_variable_or_none()
namedlink = context.get_link(beta_param) or context.get_link(gamma_param)

if namedlink is not None:
prefix, link = namedlink
if is_fixed_bn:
mean = link.avg_mean
var = link.avg_var
else:
# In train mode, avg_mean would be updated, so compute mean and var from x
x = func.inputs[0].get_variable().array
mean = x.mean(axis=func.axis)
var = x.var(axis=func.axis)
else:
# Expect this `func` to be F.fixed_batch_normalization
mean = func.inputs[3].get_variable().array
param_mean_name = context.add_param(mean, 'mean')
input_names[3] = param_mean_name
param_var_name = context.add_param(
func.inputs[4].get_variable().array, 'var')
input_names[4] = param_var_name
prefix = None
if is_fixed_bn:
mean = func.inputs[3].get_variable().array
var = func.inputs[4].get_variable().array
else:
x = func.inputs[0].get_variable().array
mean = x.mean(axis=func.axis)
var = x.var(axis=func.axis)

momentum = getattr(func, 'decay', 0.)
def add_param(v, suffix):
if prefix is None:
return context.add_param(v, suffix)
else:
return context.add_param(
v, '{}_{}'.format(prefix, suffix), use_original_name=True)

# If `use_beta=False`, a None value is passed to the function
if func.inputs[2].get_variable_or_none() is None:
beta_name = context.add_param(
np.zeros_like(mean, dtype=mean.dtype), 'beta')
maen_name = add_param(mean, 'avg_mean')
var_name = add_param(var, 'avg_var')
if is_fixed_bn:
input_names[3] = maen_name
input_names[4] = var_name
else:
input_names.extend([maen_name, var_name])

if beta_param is None:
beta_name = add_param(np.zeros_like(mean, dtype=mean.dtype), 'beta')
input_names[2] = beta_name
# The same applies when `use_gamma=False`
if func.inputs[1].get_variable_or_none() is None:
gamma_name = context.add_param(
np.ones_like(mean, dtype=mean.dtype), 'gamma')
if gamma_param is None:
gamma_name = add_param(np.ones_like(mean, dtype=mean.dtype), 'gamma')
input_names[1] = gamma_name

momentum = getattr(func, 'decay', 0.)

# TODO(disktnk): By definition, ONNX's BatchNormalization operator has one
# required output and four optional outputs. This converter must produce
# all 5 output values and return them.
Expand Down
13 changes: 13 additions & 0 deletions tests/functions_tests/test_normalizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import chainer.functions as F
import chainer.links as L
from chainer import testing
import onnx

import onnx_chainer
from onnx_chainer.testing import input_generator
from tests.helper import ONNXModelTest

Expand Down Expand Up @@ -82,3 +84,14 @@ def test_output(self):
if hasattr(self, 'condition'):
name += '_' + self.condition
self.expect(self.model, self.x, name=name, train=train)

def test_input_names(self):
    """Check the exported graph inputs for the BN statistics names.

    For every opset version from ``onnx_chainer.MINIMUM_OPSET_VERSION``
    up to and including ``onnx.defs.onnx_opset_version()``, the model is
    exported and the graph input names must contain
    ``param_bn_avg_mean`` and ``param_bn_avg_var``.
    """
    first_opset = onnx_chainer.MINIMUM_OPSET_VERSION
    last_opset = onnx.defs.onnx_opset_version()
    for opset_version in range(first_opset, last_opset + 1):
        exported = onnx_chainer.export(
            self.model, self.x, opset_version=opset_version)
        graph_inputs = {value.name for value in exported.graph.input}

        assert 'param_bn_avg_mean' in graph_inputs
        assert 'param_bn_avg_var' in graph_inputs