Skip to content

Introduce GraphBuilder helper class #113

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 27, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 9 additions & 11 deletions onnx_chainer/functions/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,22 +93,20 @@ def convert_GetItem(func, opset_version, input_names,
raise ValueError(
'GetItem with type {} cannot handle in ONNX Slice, so that '
'ONNX-Chainer does not accept the type'.format(type(idx)))
nodes = []
nodes.append(onnx_helper.make_node(
'Slice', input_names, 1,
axes=axes, starts=starts, ends=ends))

gb = onnx_helper.GraphBuilder()
output = gb.op('Slice', input_names,
axes=axes, starts=starts, ends=ends)

if squeeze_idxs:
nodes.append(onnx_helper.make_node(
'Squeeze', nodes[-1].output, 1,
axes=squeeze_idxs))
output = gb.op('Squeeze', [output],
axes=squeeze_idxs)

if unsqueeze_idxs:
nodes.append(onnx_helper.make_node(
'Unsqueeze', nodes[-1].output, 1,
axes=unsqueeze_idxs))
output = gb.op('Unsqueeze', [output],
axes=unsqueeze_idxs)

return tuple(nodes)
return gb.nodes()


def convert_Pad(func, opset_version, input_names, num_outputs,
Expand Down
41 changes: 10 additions & 31 deletions onnx_chainer/functions/loss.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import chainer
import numpy as np
from onnx import helper
from onnx.numpy_helper import from_array

from onnx_chainer import onnx_helper
from onnx_chainer.onnx_helper import gensym


def convert_SoftmaxCrossEntropy(
Expand All @@ -31,33 +28,15 @@ def convert_SoftmaxCrossEntropy(
'argument parameters are default setting.')

# create intermediate values
nodes = []
gb = onnx_helper.GraphBuilder()
x, t = input_names
y_log = gensym()
th = gensym()
s0 = gensym()
sn = gensym()
sr = gensym()
depth = gensym()
zeroone = gensym()
y_log = gb.op('LogSoftmax', [x])
depth = gb.const(np.array([x_var.shape[1]], dtype=np.int32))
zeroone = gb.const(np.array([0, 1], dtype=x_var.dtype))
th = gb.op('OneHot', [t, depth, zeroone])
s0 = gb.op('Mul', [y_log, th])
sn = gb.op('Neg', [s0])
sr = gb.op('ReduceSum', [sn], axes=[1], keepdims=0)
gb.op('ReduceMean', [sr], axes=[0], keepdims=0)

nodes.append(helper.make_node(
'LogSoftmax', [x], [y_log]))
nodes.append(helper.make_node(
'Constant', [], [depth], value=from_array(
np.array([x_var.shape[1]], dtype=np.int32))))
nodes.append(helper.make_node(
'Constant', [], [zeroone], value=from_array(
np.array([0, 1], dtype=x_var.dtype))))
nodes.append(helper.make_node(
'OneHot', [t, depth, zeroone], [th]))
nodes.append(helper.make_node(
'Mul', [y_log, th], [s0]))
nodes.append(helper.make_node(
'Neg', [s0], [sn]))
nodes.append(helper.make_node(
'ReduceSum', [sn], [sr], axes=[1], keepdims=0))
nodes.append(onnx_helper.make_node(
'ReduceMean', [sr], num_outputs, axes=[0], keepdims=0))

return tuple(nodes)
return gb.nodes()
25 changes: 11 additions & 14 deletions onnx_chainer/functions/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,20 +240,17 @@ def convert_LinearInterpolate(func, opset_version, input_names,
one = chainer.Parameter(np.array(1, dtype=typ))
parameters.append(one)

kwargs = {"consumed_inputs": [1, 1]} if opset_version == 1 else {}
kwargs2 = {} if opset_version >= 7 else {"broadcast": 1}

n1 = onnx_helper.make_node(
"Sub", [str(id(one)), input_names[0]], 1,
**kwargs, **kwargs2)
n2 = onnx_helper.make_node(
"Mul", [input_names[0], input_names[1]], 1, **kwargs)
n3 = onnx_helper.make_node(
"Mul", [n1.output[0], input_names[2]], 1, **kwargs)
n4 = onnx_helper.make_node(
"Add", [n2.output[0], n3.output[0]], num_outputs, **kwargs)

return n1, n2, n3, n4
kwargs = {'consumed_inputs': [1, 1]} if opset_version == 1 else {}
kwargs2 = {} if opset_version >= 7 else {'broadcast': 1}

gb = onnx_helper.GraphBuilder()
p, x, y = input_names
n1 = gb.op('Sub', [str(id(one)), p], **kwargs, **kwargs2)
n2 = gb.op('Mul', [p, x], **kwargs)
n3 = gb.op('Mul', [n1, y], **kwargs)
gb.op('Add', [n2, n3], num_outputs, **kwargs)

return gb.nodes()


def convert_Square(func, opset_version, input_names,
Expand Down
51 changes: 51 additions & 0 deletions onnx_chainer/onnx_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,54 @@ def make_node(op_name, input_names, num_outputs, **kwargs):
"""
output_names = [gensym() for i in range(num_outputs)]
return onnx.helper.make_node(op_name, input_names, output_names, **kwargs)


class GraphBuilder(object):
"""A helper class to build consecutive ONNX nodes."""

def __init__(self):
self._nodes = []

def op(self, op_name, input_names, num_outputs=1, **kwargs):
"""Creates a new ONNX node and returns its outputs.

Args:
op_name (str): The name of an ONNX op.
input_names (list of str): The names of input values.
num_outputs (int): The number of output values.
**kwargs (dict): ONNX attributes of the node.

Returns:
A str of the output name when `num_outputs` is 1.
A tuple of str of the output names otherwise.
"""
# Prevent a common mistake. `input_names="input"` creates a
# node with 5 inputs.
assert not isinstance(input_names, str)
node = make_node(op_name, input_names, num_outputs, **kwargs)
self._nodes.append(node)
if num_outputs == 1:
return node.output[0]
else:
return tuple(node.output)

def const(self, array):
"""Creates a Constant node of ONNX.

Args:
array (numpy.ndarray): A numpy array.

Returns:
A str of the name of the constant value.
"""
tensor = onnx.numpy_helper.from_array(array)
return self.op('Constant', [], 1, value=tensor)

def nodes(self):
"""Returns all nodes created so far.

Returns:
A list of `onnx.NodeProto` objects, suitable as the return
value of converter functions.
"""
return tuple(self._nodes)