Skip to content

Introduce doctest #245

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 8, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ script:
- flake8
- autopep8 -r . --dif --exit-code
- pytest -m "not gpu" -x -s -vvvs tests/ --cov onnx_chainer
- if [[ $ONNX_CHAINER_DEPLOY_JOB == 1 ]]; then pushd docs && make doctest && popd; fi

after_success:
- if [[ $ONNX_CHAINER_DEPLOY_JOB == 1 ]]; then codecov; fi
Expand Down
8 changes: 8 additions & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
# ones.
extensions = ['sphinx.ext.autodoc',
'sphinx.ext.autosummary',
'sphinx.ext.doctest',
'sphinx.ext.viewcode',
'sphinx.ext.napoleon']

Expand Down Expand Up @@ -135,6 +136,13 @@
'chainercv': ('https://chainercv.readthedocs.io/en/latest/', None),
}

doctest_global_setup = '''
import chainer
import chainer.functions as F
import numpy as np
import onnx_chainer
'''

# -- Own configuration for this project -----------------------------------

html_scaled_image_link = False
15 changes: 9 additions & 6 deletions onnx_chainer/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,16 +242,19 @@ def export(model, args, filename=None, export_params=True,

>>> import onnx
>>> def custom_converter(param):
>>> return onnx.helper.make_node(
>>> 'CustomizedRelu', param.input_names, param.output_names,
>>> domain='chainer'),
... return onnx.helper.make_node(
... 'CustomizedRelu', param.input_names, param.output_names,
... domain='chainer'),
>>>
>>> external_converters = {'ReLU': custom_converter}
>>> external_imports = {'chainer': 0}
>>>
>>> export(model, args,
>>> external_converters=external_converters,
>>> external_opset_imports=external_imports)
>>> model = chainer.Sequential(F.relu) # set the target model
>>> args = chainer.Variable(np.random.rand(1,10)) # set dummy input
>>> onnx_graph = onnx_chainer.export(
... model, args,
... external_converters=external_converters,
... external_opset_imports=external_imports)

Returned model has ``CustomizedRelu`` node.

Expand Down
4 changes: 2 additions & 2 deletions onnx_chainer/functions/opset_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ def support(opset_versions):
from 6 and updated on 8, add this function as decorator like the below.

>>> @support((6, 8))
>>> def own_converter(func, opset_version, *args):
>>> print(opset_version)
... def own_converter(func, opset_version, *args):
... print(opset_version)
>>>
>>> own_converter(None, 6)
6
Expand Down
10 changes: 5 additions & 5 deletions onnx_chainer/replace_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ def fake_as_funcnode(alt_func, name, rename_attributes=None):

>>> def func(x, a, b, c=1, d=2): pass
>>> # x is variable
>>> func = fake_as_funcnode(
>>> func = onnx_chainer.replace_func.fake_as_funcnode(
... func, 'CustomNode',
... rename_attributes=[(1, 'value'), ('c': 'y')])
... rename_attributes=[(1, 'value'), ('c', 'y')])

Then ``func`` will be operated as a function node named "CustomNode", and
``'value'``, ``'b'``, ``'y'``, ``'d'`` are set as function's attributes.
Expand Down Expand Up @@ -160,9 +160,9 @@ def as_funcnode(name, rename_attributes=None):

Example:

>>> @as_funcnode(
... 'CustomNode', rename_attributes=[(1, 'value'), ('c': 'y')])
>>> def func(x, a, b, c=1, d=2): pass
>>> @onnx_chainer.replace_func.as_funcnode(
... 'CustomNode', rename_attributes=[(1, 'value'), ('c', 'y')])
... def func(x, a, b, c=1, d=2): pass

Args:
name (str): function name. This name is used for what ONNX operator
Expand Down
6 changes: 3 additions & 3 deletions onnx_chainer/testing/input_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def increasing(*shape, dtype=np.float32):

Example:

>>> increasing(3, 4)
>>> onnx_chainer.testing.input_generator.increasing(3, 4)
array([[-3. , -2.5, -2. , -1.5],
[-1. , -0.5, 0. , 0.5],
[ 1. , 1.5, 2. , 2.5]], dtype=float32)
Expand All @@ -49,7 +49,7 @@ def nonzero_increasing(*shape, dtype=np.float32, bias=1e-7):

Example:

>>> nonzero_increasing(3, 4)
>>> onnx_chainer.testing.input_generator.nonzero_increasing(3, 4)
array([[-3.0000000e+00, -2.5000000e+00, -1.9999999e+00, -1.4999999e+00],
[-9.9999988e-01, -4.9999991e-01, 1.0000000e-07, 5.0000012e-01],
[ 1.0000001e+00, 1.5000001e+00, 2.0000000e+00, 2.5000000e+00]],
Expand All @@ -75,7 +75,7 @@ def positive_increasing(*shape, dtype=np.float32, bias=1e-7):

Example:

>>> positive_increasing(3, 4)
>>> onnx_chainer.testing.input_generator.positive_increasing(3, 4)
array([[1.0000000e-07, 5.0000012e-01, 1.0000001e+00, 1.5000001e+00],
[2.0000000e+00, 2.5000000e+00, 3.0000000e+00, 3.5000000e+00],
[4.0000000e+00, 4.5000000e+00, 5.0000000e+00, 5.5000000e+00]],
Expand Down
4 changes: 4 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,13 @@
# CUDA9.1 is not supported latest cuDNN, so decided to use CPU version
'onnxruntime==0.4.0',
],
'doctest': [
'sphinx==1.8.2',
],
'travis': [
'-r stylecheck',
'-r test-cpu',
'-r doctest',
'pytest-cov',
'codecov',
],
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def target_opsets(request):
else:
try:
versions = [int(i) for i in opsets.split(',')]
except ValueError as e:
except ValueError:
raise ValueError('cannot convert {} to versions list'.format(
opsets))
return versions