Skip to content

Commit 8de1ef3

Browse files
authored
Merge pull request #224 from durswd/impl_triangles
implemented triangle functions
2 parents 7eebbf2 + 9cebb92 commit 8de1ef3

File tree

5 files changed

+86
-0
lines changed

5 files changed

+86
-0
lines changed

README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,15 +134,21 @@ Currently 82 Chainer Functions are supported to export in ONNX format.
134134
- Absolute
135135
- Add
136136
- AddConstant
137+
- ArcCos
138+
- ArcSin
139+
- ArcTan
137140
- ArgMax
138141
- ArgMin
139142
- BroadcastTo
143+
- Cos
144+
- Cosh
140145
- Clip
141146
- Div
142147
- DivFromConstant
143148
- Exp
144149
- Identity
145150
- LinearInterpolate
151+
- Log
146152
- LogSumExp
147153
- MatMul
148154
- Max
@@ -156,11 +162,14 @@ Currently 82 Chainer Functions are supported to export in ONNX format.
156162
- PowVarConst
157163
- Prod
158164
- RsqrtGPU
165+
- Sin
166+
- Sinh
159167
- Sqrt
160168
- Square
161169
- Sub
162170
- SubFromConstant
163171
- Sum
172+
- Tan
164173

165174
### Noise
166175

onnx_chainer/functions/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,21 @@
4848
from onnx_chainer.functions.math import convert_Absolute # NOQA
4949
from onnx_chainer.functions.math import convert_Add # NOQA
5050
from onnx_chainer.functions.math import convert_AddConstant # NOQA
51+
from onnx_chainer.functions.math import convert_Arccos # NOQA
52+
from onnx_chainer.functions.math import convert_Arcsin # NOQA
53+
from onnx_chainer.functions.math import convert_Arctan # NOQA
5154
from onnx_chainer.functions.math import convert_ArgMax # NOQA
5255
from onnx_chainer.functions.math import convert_ArgMin # NOQA
5356
from onnx_chainer.functions.math import convert_BroadcastTo # NOQA
5457
from onnx_chainer.functions.math import convert_Clip # NOQA
58+
from onnx_chainer.functions.math import convert_Cos # NOQA
59+
from onnx_chainer.functions.math import convert_Cosh # NOQA
5560
from onnx_chainer.functions.math import convert_Div # NOQA
5661
from onnx_chainer.functions.math import convert_DivFromConstant # NOQA
5762
from onnx_chainer.functions.math import convert_Exp # NOQA
5863
from onnx_chainer.functions.math import convert_Identity # NOQA
5964
from onnx_chainer.functions.math import convert_LinearInterpolate # NOQA
65+
from onnx_chainer.functions.math import convert_Log # NOQA
6066
from onnx_chainer.functions.math import convert_LogSumExp # NOQA
6167
from onnx_chainer.functions.math import convert_MatMul # NOQA
6268
from onnx_chainer.functions.math import convert_Max # NOQA
@@ -70,11 +76,14 @@
7076
from onnx_chainer.functions.math import convert_PowVarConst # NOQA
7177
from onnx_chainer.functions.math import convert_Prod # NOQA
7278
from onnx_chainer.functions.math import convert_RsqrtGPU # NOQA
79+
from onnx_chainer.functions.math import convert_Sin # NOQA
80+
from onnx_chainer.functions.math import convert_Sinh # NOQA
7381
from onnx_chainer.functions.math import convert_Sqrt # NOQA
7482
from onnx_chainer.functions.math import convert_Square # NOQA
7583
from onnx_chainer.functions.math import convert_Sub # NOQA
7684
from onnx_chainer.functions.math import convert_SubFromConstant # NOQA
7785
from onnx_chainer.functions.math import convert_Sum # NOQA
86+
from onnx_chainer.functions.math import convert_Tan # NOQA
7887

7988
from onnx_chainer.functions.noise import convert_Dropout # NOQA
8089

onnx_chainer/functions/math.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,21 @@ def convert_Absolute(func, opset_version, input_names, output_names, context):
115115
return onnx_helper.make_node('Abs', input_names, output_names),
116116

117117

118+
@support((7,))
119+
def convert_Arccos(func, opset_version, input_names, output_names, context):
120+
return onnx_helper.make_node('Acos', input_names, output_names),
121+
122+
123+
@support((7,))
124+
def convert_Arcsin(func, opset_version, input_names, output_names, context):
125+
return onnx_helper.make_node('Asin', input_names, output_names),
126+
127+
128+
@support((7,))
129+
def convert_Arctan(func, opset_version, input_names, output_names, context):
130+
return onnx_helper.make_node('Atan', input_names, output_names),
131+
132+
118133
@support((1, 7))
119134
def convert_PowVarConst(
120135
func, opset_version, input_names, output_names, context):
@@ -143,6 +158,16 @@ def convert_Clip(func, opset_version, input_names, output_names, context):
143158
),
144159

145160

161+
@support((7,))
162+
def convert_Cos(func, opset_version, input_names, output_names, context):
163+
return onnx_helper.make_node('Cos', input_names, output_names),
164+
165+
166+
@support((9,))
167+
def convert_Cosh(func, opset_version, input_names, output_names, context):
168+
return onnx_helper.make_node('Cosh', input_names, output_names),
169+
170+
146171
@support((1, 6))
147172
def convert_Exp(func, opset_version, input_names, output_names, context):
148173
if opset_version == 1:
@@ -191,6 +216,16 @@ def convert_Minimum(func, opset_version, input_names, output_names, context):
191216
return onnx_helper.make_node('Min', input_names, output_names),
192217

193218

219+
@support((7,))
220+
def convert_Sin(func, opset_version, input_names, output_names, context):
221+
return onnx_helper.make_node('Sin', input_names, output_names),
222+
223+
224+
@support((9,))
225+
def convert_Sinh(func, opset_version, input_names, output_names, context):
226+
return onnx_helper.make_node('Sinh', input_names, output_names),
227+
228+
194229
@support((1, 6))
195230
def convert_Sqrt(func, opset_version, input_names, output_names, context):
196231
if opset_version == 1:
@@ -207,6 +242,11 @@ def convert_RsqrtGPU(func, opset_version, input_names, output_names, context):
207242
return gb.nodes(output_names)
208243

209244

245+
@support((6,))
246+
def convert_Log(func, opset_version, input_names, output_names, context):
247+
return onnx_helper.make_node('Log', input_names, output_names),
248+
249+
210250
def convert_LogSumExp(func, opset_version, input_names, output_names, context):
211251
# Use keepdims=False by default
212252
# since the chainer does not support keepdims option
@@ -259,6 +299,11 @@ def convert_Sum(func, opset_version, input_names, output_names, context):
259299
'ReduceSum', input_names, output_names, **kwargs),
260300

261301

302+
@support((7,))
303+
def convert_Tan(func, opset_version, input_names, output_names, context):
304+
return onnx_helper.make_node('Tan', input_names, output_names),
305+
306+
262307
@support((1, 6, 7))
263308
def convert_LinearInterpolate(
264309
func, opset_version, input_names, output_names, context):

onnx_chainer/mapping.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,21 @@
5858
'Absolute',
5959
'Add',
6060
'AddConstant',
61+
'Arccos',
62+
'Arcsin',
63+
'Arctan',
6164
'ArgMax',
6265
'ArgMin',
6366
'BroadcastTo',
6467
'Clip',
68+
'Cos',
69+
'Cosh',
6570
'Div',
6671
'DivFromConstant',
6772
'Exp',
6873
'Identity',
6974
'LinearInterpolate',
75+
'Log',
7076
'LogSumExp',
7177
'MatMul',
7278
'Max',
@@ -80,11 +86,14 @@
8086
'PowVarConst',
8187
'Prod',
8288
'RsqrtGPU',
89+
'Sin',
90+
'Sinh',
8391
'Sqrt',
8492
'Square',
8593
'Sub',
8694
'SubFromConstant',
8795
'Sum',
96+
'Tan',
8897

8998
# Noise
9099
'Dropout',

tests/functions_tests/test_maths.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@
1212
@testing.parameterize(
1313
{'op_name': 'Neg', 'ops': '-a'},
1414
{'op_name': 'Absolute', 'ops': 'abs(a)'},
15+
{'op_name': 'Arccos', 'ops': 'chainer.functions.arccos(a)'},
16+
{'op_name': 'Arcsin', 'ops': 'chainer.functions.arcsin(a)'},
17+
{'op_name': 'Arctan', 'ops': 'chainer.functions.arctan(a)'},
18+
{'op_name': 'Cos', 'ops': 'chainer.functions.cos(a)'},
19+
{'op_name': 'Cosh', 'ops': 'chainer.functions.cosh(a)'},
1520
{'op_name': 'Clip', 'ops': 'chainer.functions.clip(a, 0.1, 0.2)'},
1621
{'op_name': 'Exp', 'ops': 'chainer.functions.exp(a)'},
1722
{'op_name': 'Sqrt', 'ops': 'chainer.functions.sqrt(a)'},
@@ -47,10 +52,14 @@
4752
'condition': 'axis0'},
4853
{'op_name': 'Prod', 'ops': 'chainer.functions.prod(a, keepdims=True)',
4954
'condition': 'keepdims'},
55+
{'op_name': 'Log', 'ops': 'chainer.functions.log(a)'},
5056
{'op_name': 'LogSumExp', 'ops': 'chainer.functions.logsumexp(a)'},
5157
{'op_name': 'LogSumExp', 'ops': 'chainer.functions.logsumexp(a, axis=0)',
5258
'condition': 'axis0'},
59+
{'op_name': 'Sin', 'ops': 'chainer.functions.sin(a)'},
60+
{'op_name': 'Sinh', 'ops': 'chainer.functions.sinh(a)'},
5361
{'op_name': 'Square', 'ops': 'chainer.functions.square(a)'},
62+
{'op_name': 'Tan', 'ops': 'chainer.functions.tan(a)'},
5463
{'op_name': 'BroadcastTo',
5564
'ops': 'chainer.functions.broadcast_to(a, (2,2,3))'},
5665
)
@@ -78,8 +87,13 @@ def __call__(self, a):
7887
self.name = name
7988

8089
skip_opset_version = []
90+
if self.op_name == 'Cosh' or self.op_name == 'Sinh':
91+
skip_opset_version.append(7)
92+
skip_opset_version.append(8)
93+
8194
if self.op_name == 'BroadcastTo':
8295
skip_opset_version.append(7)
96+
8397
self.skip_opset_version = skip_opset_version
8498

8599
def test_output(self):

0 commit comments

Comments
 (0)