converter for SoftmaxCrossEntropy #107


Merged: 10 commits into chainer:master on Feb 27, 2019

Conversation

@ir5 (Contributor) commented Feb 26, 2019

I implemented a converter for SoftmaxCrossEntropy. Since I used the OneHot node in this implementation, onnx==1.4.1 is required. (A lower version may suffice, but I'm not sure.)
Testing with onnxruntime is not done, because onnxruntime currently does not support the OneHot node.
To check the validity of the implementation, I temporarily wrote the following script to generate a test case.

import chainer
import numpy as np

import onnx_chainer

class Model(chainer.Chain):
    def __init__(self):
        super(Model, self).__init__()

    def __call__(self, x, t):
        return chainer.functions.softmax_cross_entropy(x, t)

model = Model()
# Logits of shape (batch=3, classes=50) and integer labels in [0, 50).
x = np.random.uniform(size=(3, 50)).astype('f')
t = np.random.randint(size=3, low=0, high=50).astype(np.int32)
x = chainer.as_variable(x)
t = chainer.as_variable(t)
# Export the model and an input/output test case under ./softmax.
onnx_chainer.export_testcase(model, (x, t), 'softmax')

Then, I ran run_onnx in chainer-compiler. The result was "OK!", so I suppose that the converter works correctly.

./build/tools/run_onnx --test </path/to/softmax> --trace -d cuda
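(For context, here is a minimal NumPy reference of what the converter has to reproduce with ONNX nodes; a sketch for illustration, not code from the PR: a LogSoftmax, a OneHot mask over the class axis, and a negative mean reduction over the batch.)

import numpy as np

def softmax_cross_entropy_ref(x, t):
    # Numerically stabilized log-softmax over the class axis.
    y_log = x - x.max(axis=1, keepdims=True)
    y_log = y_log - np.log(np.exp(y_log).sum(axis=1, keepdims=True))
    # One-hot mask of the integer labels (what the ONNX OneHot node builds).
    t_onehot = np.eye(x.shape[1], dtype=x.dtype)[t]
    # Negative mean log-likelihood over the batch.
    return -(t_onehot * y_log).sum(axis=1).mean()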

@shinh (Member) left a comment:
Defer to tanaka-san

            np.array([x_var.shape[1]], dtype=np.int32))))
    nodes.append(helper.make_node(
        'Constant', [], [zeroone], value=from_array(
            np.array([0, 1], dtype='f'))))
Member: I think you can use dtype=x_var.dtype?
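(For illustration, the suggested change would make the constant reuse the input's dtype; a sketch of the revised lines, not the exact merged code.)

    nodes.append(helper.make_node(
        'Constant', [], [zeroone], value=from_array(
            np.array([0, 1], dtype=x_var.dtype))))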

Contributor Author: Fixed.

@@ -0,0 +1,38 @@
# Currently, the test for SoftmaxCrossEntropy is disabled since onnxruntime
Member: Could you create a bug and add its URL here?

Contributor Author: Sorry, what do you mean by "create a bug and add its URL here"?

Member:
Contributor Author: Understood.

@disktnk self-requested a review February 26, 2019 07:14
@disktnk added this to the 1.3.3 milestone Feb 26, 2019
from onnx.numpy_helper import from_array

from onnx_chainer.onnx_helper import gensym
from onnx_chainer import onnx_helper
Member: Please change the import order?

onnx_chainer/functions/loss.py:6:1: H306  imports not in alphabetical order (onnx_chainer.onnx_helper.gensym, onnx_chainer.onnx_helper)

Chainer's coding style guide follows the "OpenStack Style Guidelines", but the hacking check is currently not set up in CI because of a dependency conflict; sorry for the inconvenience.
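(For illustration, the alphabetized order that H306 asks for would be the following; modules are sorted by their full dotted path.)

from onnx.numpy_helper import from_array

from onnx_chainer import onnx_helper
from onnx_chainer.onnx_helper import gensym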

Contributor Author: I know that hacking was used in Chainer for a long time, but it seems hacking is no longer a requirement in Chainer's coding style guide? https://github.com/chainer/chainer/blob/master/setup.py#L34-L39
Without hacking, neither flake8 nor autopep8 reports a format error.

Member: Exactly. Chainer removed hacking from the requirements, so they now fix import order and other issues manually to fit the OpenStack Style Guidelines, as in chainer/chainer#6128.

Contributor Author: Understood. Thank you for clarifying the rule.

        func, opset_version, input_names,
        num_outputs, parameters):
    # obtain input variable
    x_var, t_var = func.get_retained_inputs()
Member: On Chainer 5.2.0, the current stable version, SoftmaxCrossEntropy does not support the get_retained_inputs method, because F.SoftmaxCrossEntropy is a subclass of Function, not of FunctionNode. ONNX-Chainer doesn't have to support every old version, I think, but I would like to support at least the stable version.

Contributor Author: Thank you for noting this important point. Unfortunately, it seems there is no easy way to retrieve the input variables (x and t) just from the interface of Function... https://github.com/chainer/chainer/blob/v5.2.0/chainer/functions/loss/softmax_cross_entropy.py#L165
Since this converter is an experimental feature, I would rather add a version requirement for it to keep the implementation simple. I may put an explicit error message when a non-latest version of Chainer is used.
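(A minimal sketch of what such an explicit error might look like; hypothetical, the converter name and check are illustrative and this is not the code that was merged.)

def convert_SoftmaxCrossEntropy(func, opset_version, input_names,
                                num_outputs, parameters):
    if not hasattr(func, 'get_retained_inputs'):
        # On Chainer 5.x, F.SoftmaxCrossEntropy is a Function rather than
        # a FunctionNode, so its inputs cannot be retrieved here.
        raise RuntimeError(
            'ONNX-Chainer: converting SoftmaxCrossEntropy requires a '
            'Chainer version in which it is implemented as a FunctionNode')
    x_var, t_var = func.get_retained_inputs()
    ...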

Member: Nice!

    x_var, t_var = func.get_retained_inputs()
    if len(x_var.shape) != 2:
        raise NotImplementedError(
            'onnx-chainer currently handles SoftmaxCrossEntropy only when '
Member: Currently, error and warning messages are unified as "ONNX-Chainer"; could you follow that? Same for L20 and L25.

    depth = gensym()
    zeroone = gensym()

    nodes.append(helper.make_node(
Member: Is there any special reason not to use onnx_helper? Ref: https://github.com/chainer/onnx-chainer/blob/master/onnx_chainer/functions/array.py#L53

I think the code below would work without calling the gensym function, but I'd like to know if you have a reason.

nodes.append(onnx_helper.make_node('LogSoftmax', [x], 1))
y_log = nodes[-1].output[0]

Contributor Author: In this case, there are lots of intermediate values. This is probably a personal preference, but I like the current style because it makes the inputs and outputs of a node easier to understand. But if gensym is not supposed to be called from outside the helper module, then I will change this part.

Member: I agree nodes[-1] is not easy to read. One idea would be to define a helper function and use it like:

def add_node(*args, **kwargs):
    node = onnx_helper.make_node(*args, **kwargs)
    nodes.append(node)
    return node.output[0]

y_log = add_node('LogSoftmax', [x], 1)

However, I think it's better to move this kind of helper function into onnx_helper. Actually, I was planning to do this kind of refactoring myself. So, how about keeping the code as is in this PR? We will see if I can make it better :)

Contributor Author: I agree that it would be more helpful if such a function were implemented in the helper module. It sounds good to keep the code as is and try to improve it later.

Member: The helper function idea looks nice, so I approve the current implementation, thx.

# @testing.parameterize(
# {'in_shape': (3, 5)},
# )
# class TestSoftmaxCrossEntropy(unittest.TestCase):
@disktnk (Member) commented Feb 26, 2019: I prefer using the @unittest.skip("OneHot operator is not supported on test runtime") decorator to commenting the test out.
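(For illustration, the suggested form would be roughly the following; a sketch, with the parameterize decorator and test body elided.)

import unittest

@unittest.skip('OneHot operator is not supported on test runtime')
class TestSoftmaxCrossEntropy(unittest.TestCase):
    pass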

@ir5 (Contributor Author) commented Feb 27, 2019

I fixed the code except for the gensym part.

@disktnk merged commit 848a074 into chainer:master on Feb 27, 2019