 @support((1, 6, 7))
 def convert_BatchNormalization(func, opset_version, input_names,
                                output_names, context, parameters):
-    if len(func.inputs) <= 3:
-        # expect this `func` is F.batch_normalization
-        x = func.inputs[0].get_variable().array
-        mean = x.mean(axis=func.axis)
-        param_mean_name = context.add_param(mean, 'mean')
-        input_names.append(param_mean_name)
-        param_var_name = context.add_param(x.var(axis=func.axis), 'var')
-        input_names.append(param_var_name)
+    is_fixed_bn = len(func.inputs) > 3
+
+    # NOTE(disktnk):
+    # if `use_beta=False`, beta_param is None, `use_gamma=False` is same.
+    beta_param = func.inputs[2].get_variable_or_none()
+    gamma_param = func.inputs[1].get_variable_or_none()
+    namedlink = context.get_link(beta_param) or context.get_link(gamma_param)
+
+    if namedlink is not None:
+        prefix, link = namedlink
+        if is_fixed_bn:
+            mean = link.avg_mean
+            var = link.avg_var
+        else:
+            # on train mode, avg_mean would be updated, so make them from x
+            x = func.inputs[0].get_variable().array
+            mean = x.mean(axis=func.axis)
+            var = x.var(axis=func.axis)
     else:
-        # expect this `func` is F.fixed_batch_normalization
-        mean = func.inputs[3].get_variable().array
-        param_mean_name = context.add_param(mean, 'mean')
-        input_names[3] = param_mean_name
-        param_var_name = context.add_param(
-            func.inputs[4].get_variable().array, 'var')
-        input_names[4] = param_var_name
+        prefix = None
+        if is_fixed_bn:
+            mean = func.inputs[3].get_variable().array
+            var = func.inputs[4].get_variable().array
+        else:
+            x = func.inputs[0].get_variable().array
+            mean = x.mean(axis=func.axis)
+            var = x.var(axis=func.axis)

-    momentum = getattr(func, 'decay', 0.)
+    def add_param(v, suffix):
+        if prefix is None:
+            return context.add_param(v, suffix)
+        else:
+            return context.add_param(
+                v, '{}_{}'.format(prefix, suffix), use_original_name=True)

-    # if `use_beta=False`, passed None value to the functions
-    if func.inputs[2].get_variable_or_none() is None:
-        beta_name = context.add_param(
-            np.zeros_like(mean, dtype=mean.dtype), 'beta')
+    maen_name = add_param(mean, 'avg_mean')
+    var_name = add_param(var, 'avg_var')
+    if is_fixed_bn:
+        input_names[3] = maen_name
+        input_names[4] = var_name
+    else:
+        input_names.extend([maen_name, var_name])
+
+    if beta_param is None:
+        beta_name = add_param(np.zeros_like(mean, dtype=mean.dtype), 'beta')
         input_names[2] = beta_name
-    # `use_gamma=False` is same
-    if func.inputs[1].get_variable_or_none() is None:
-        gamma_name = context.add_param(
-            np.ones_like(mean, dtype=mean.dtype), 'gamma')
+    if gamma_param is None:
+        gamma_name = add_param(np.ones_like(mean, dtype=mean.dtype), 'gamma')
         input_names[1] = gamma_name

+    momentum = getattr(func, 'decay', 0.)
+
     # TODO(disktnk): On definition of ONNX's BatchNormalization operator,
     # outputs one required output and four optional outputs. This converter
     # must make 5 values for output and return them.
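
A minimal, standalone sketch of the rule this diff implements (not part of the patch; the helpers pick_statistics and param_name are hypothetical, for illustration only): train-mode batch normalization (F.batch_normalization, three inputs) derives mean/var from the input batch, fixed-mode (F.fixed_batch_normalization, five inputs) reuses the stored running statistics, and the exported parameter name gains a link-name prefix only when the beta/gamma parameters belong to a named link. The real converter does this through context.get_link and context.add_param.

import numpy as np


def pick_statistics(x, axis, avg_mean=None, avg_var=None, is_fixed_bn=False):
    # Fixed-mode BN reuses the stored running statistics; train-mode BN
    # computes them from the current batch, mirroring the branches above.
    if is_fixed_bn:
        return avg_mean, avg_var
    return x.mean(axis=axis), x.var(axis=axis)


def param_name(prefix, suffix):
    # Parameters belonging to a named link are exported as '<prefix>_<suffix>';
    # anonymous parameters keep the bare suffix.
    return suffix if prefix is None else '{}_{}'.format(prefix, suffix)


x = np.random.randn(8, 16, 4, 4).astype(np.float32)
mean, var = pick_statistics(x, axis=(0, 2, 3))   # train mode
print(param_name(None, 'avg_mean'))              # -> 'avg_mean'
print(param_name('bn1', 'avg_mean'))             # -> 'bn1_avg_mean'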