
Commit 61597a5: revise activations (#18700)
Parent: c4c7b11

File tree: 2 files changed, 63 additions & 5 deletions


python/mxnet/gluon/nn/activations.py

Lines changed: 13 additions & 5 deletions
@@ -139,7 +139,8 @@ def __init__(self, alpha_initializer=initializer.Constant(0.25),
                                          init=alpha_initializer)

     def hybrid_forward(self, F, x, alpha):
-        return F.LeakyReLU(x, gamma=alpha, act_type='prelu', name='fwd')
+        leaky_relu = F.npx.leaky_relu if is_np_array() else F.LeakyReLU
+        return leaky_relu(x, gamma=alpha, act_type='prelu', name='fwd')


 class ELU(HybridBlock):
@@ -167,7 +168,8 @@ def __init__(self, alpha=1.0, **kwargs):
         self._alpha = alpha

     def hybrid_forward(self, F, x):
-        return F.LeakyReLU(x, act_type='elu', slope=self._alpha)
+        leaky_relu = F.npx.leaky_relu if is_np_array() else F.LeakyReLU
+        return leaky_relu(x, act_type='elu', slope=self._alpha)


 class SELU(HybridBlock):
@@ -187,7 +189,9 @@ def __init__(self, **kwargs):
         super(SELU, self).__init__(**kwargs)

     def hybrid_forward(self, F, x):
-        return F.LeakyReLU(x, act_type='selu', name='fwd')
+        leaky_relu = F.npx.leaky_relu if is_np_array() else F.LeakyReLU
+        return leaky_relu(x, act_type='selu', name='fwd')
+

 class GELU(HybridBlock):
     r"""
@@ -206,7 +210,8 @@ def __init__(self, **kwargs):
         super(GELU, self).__init__(**kwargs)

     def hybrid_forward(self, F, x):
-        return F.LeakyReLU(x, act_type='gelu', name='fwd')
+        leaky_relu = F.npx.leaky_relu if is_np_array() else F.LeakyReLU
+        return leaky_relu(x, act_type='gelu', name='fwd')


 class Swish(HybridBlock):
@@ -232,4 +237,7 @@ def __init__(self, beta=1.0, **kwargs):
         self._beta = beta

     def hybrid_forward(self, F, x):
-        return x * F.sigmoid(self._beta * x, name='fwd')
+        if is_np_array():
+            return x * F.npx.sigmoid(self._beta * x)
+        else:
+            return x * F.sigmoid(self._beta * x, name='fwd')
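
Every hunk in activations.py follows the same pattern: the operator is chosen at call time with is_np_array(), dispatching to the numpy-compatible F.npx.* function when numpy semantics are enabled and to the legacy F.LeakyReLU / F.sigmoid symbol otherwise. A minimal usage sketch (not part of this commit) of how the updated layers behave once numpy semantics are switched on, assuming the standard mxnet.npx.set_np() toggle:

import mxnet as mx
from mxnet import npx
from mxnet.gluon import nn

# Enable numpy semantics so is_np_array() returns True and the layers
# dispatch to the F.npx.* operators used in this commit.
npx.set_np()

layer = nn.SELU()                       # parameterless block, called directly as in the tests below
x = mx.np.random.uniform(size=(4, 8))   # an mxnet.numpy ndarray
y = layer(x)                            # routed through npx.leaky_relu(..., act_type='selu')
print(y.shape)                          # (4, 8)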

tests/python/unittest/test_numpy_gluon.py

Lines changed: 50 additions & 0 deletions
@@ -25,6 +25,7 @@
 import mxnet as mx
 from mxnet import gluon, autograd, np
 from mxnet.test_utils import use_np, assert_almost_equal, check_gluon_hybridize_consistency
+from mxnet.gluon import nn
 from common import with_seed
 import random

@@ -422,6 +423,55 @@ def hybrid_forward(self, F, valid_length):
     assert mx.test_utils.same(out1.asnumpy(), out2.asnumpy())


+@with_seed()
+@use_np
+def test_activations_leakyrelu():
+    # Currently, all the activation tests just check that the layer is runnable.
+    act_layer = nn.LeakyReLU(0.1)
+    out = act_layer(mx.np.random.uniform(size=(10,)))
+    out.asnumpy()
+
+
+@with_seed()
+@use_np
+def test_activations_prelu():
+    act_layer = nn.PReLU()
+    act_layer.initialize()
+    out = act_layer(mx.np.random.uniform(size=(10,)))
+    out.asnumpy()
+
+
+@with_seed()
+@use_np
+def test_activations_elu():
+    act_layer = nn.ELU(1.0)
+    out = act_layer(mx.np.random.uniform(size=(10,)))
+    out.asnumpy()
+
+
+@with_seed()
+@use_np
+def test_activations_selu():
+    act_layer = nn.SELU()
+    out = act_layer(mx.np.random.uniform(size=(10,)))
+    out.asnumpy()
+
+
+@with_seed()
+@use_np
+def test_activations_gelu():
+    act_layer = nn.GELU()
+    out = act_layer(mx.np.random.uniform(size=(10,)))
+    out.asnumpy()
+
+
+@with_seed()
+@use_np
+def test_activations_swish():
+    act_layer = nn.Swish()
+    out = act_layer(mx.np.random.uniform(size=(10,)))
+    out.asnumpy()
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
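
The new tests only verify that each activation is runnable under @use_np. As a follow-up sketch (not part of this commit; the test name is hypothetical and it relies on the imports already present in test_numpy_gluon.py), assert_almost_equal could additionally be used to check that imperative and hybridized execution agree:

@with_seed()
@use_np
def test_swish_hybridize_consistency():
    # Hypothetical extra check: hybridized and imperative Swish should
    # produce the same values when numpy semantics are active.
    x = mx.np.random.uniform(size=(10,))
    eager = nn.Swish()
    hybrid = nn.Swish()
    hybrid.hybridize()
    assert_almost_equal(eager(x).asnumpy(), hybrid(x).asnumpy())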
