@@ -388,6 +388,36 @@ def forward(self, dist):
        )


+class GLU(nn.Module):
+    r"""Applies the gated linear unit (GLU) function:
+
+    .. math::
+
+        \text{GLU}(x) = \text{Linear}_1(x) \otimes \sigma(\text{Linear}_2(x))
+
+
+    where :math:`\otimes` is the element-wise multiplication operator and
+    :math:`\sigma` is an activation function.
+
+    Args:
+        in_channels (int): Number of input features.
+        hidden_channels (int, optional): Number of hidden features. Defaults to None, meaning hidden_channels=in_channels.
+        activation (nn.Module, optional): Activation function to use. Defaults to Sigmoid.
+    """
+
+    def __init__(
+        self, in_channels, hidden_channels=None, activation: Optional[nn.Module] = None
+    ):
+        super(GLU, self).__init__()
+        self.act = nn.Sigmoid() if activation is None else activation
+        hidden_channels = hidden_channels or in_channels
+        self.W = nn.Linear(in_channels, hidden_channels)
+        self.V = nn.Linear(in_channels, hidden_channels)
+
+    def forward(self, x):
+        return self.W(x) * self.act(self.V(x))
+
+
class ShiftedSoftplus(nn.Module):
    r"""Applies the ShiftedSoftplus function :math:`\text{ShiftedSoftplus}(x) = \frac{1}{\beta} *
    \log(1 + \exp(\beta * x))-\log(2)` element-wise.
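A quick usage sketch for the new GLU block (not part of the diff; the tensor shapes and the `nn.Tanh` gate below are illustrative assumptions, and the class is assumed to be importable from the module this hunk patches):

```python
import torch
import torch.nn as nn

# With the GLU class from the hunk above in scope:
glu = GLU(in_channels=32, hidden_channels=64)   # sigmoid gate by default
x = torch.randn(8, 32)                          # batch of 8, 32 input features
out = glu(x)                                    # W(x) * sigmoid(V(x))
print(out.shape)                                # torch.Size([8, 64])

# Any nn.Module can replace the default sigmoid gate:
glu_tanh = GLU(32, 64, activation=nn.Tanh())
```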
@@ -404,6 +434,50 @@ def forward(self, x):
        return F.softplus(x) - self.shift


+class Swish(nn.Module):
+    r"""Swish activation function as defined in https://arxiv.org/pdf/1710.05941:
+
+    .. math::
+
+        \text{Swish}(x) = x \cdot \sigma(\beta x)
+
+    Args:
+        beta (float, optional): Scaling factor for the Swish activation. Defaults to 1.0.
+
+    """
+
+    def __init__(self, beta=1.0):
+        super(Swish, self).__init__()
+        self.beta = beta
+
+    def forward(self, x):
+        return x * torch.sigmoid(self.beta * x)
+
+
+class SwiGLU(nn.Module):
+    r"""SwiGLU activation function as defined in https://arxiv.org/pdf/2002.05202:
+
+    .. math::
+
+        \text{SwiGLU}(x) = \text{Linear}_1(x) \otimes \text{Swish}(\text{Linear}_2(x))
+
+    The underlying linear layers W and V each map in_features to hidden_features.
+    Args:
+        in_features (int): Number of input features.
+        hidden_features (int, optional): Number of hidden features. Defaults to None, meaning hidden_features=in_features.
+        beta (float, optional): Scaling factor for the Swish activation. Defaults to 1.0.
+    """
+
+    def __init__(self, in_features, hidden_features=None, beta=1.0):
+        super().__init__()
+        hidden_features = hidden_features or in_features
+        act = Swish(beta)
+        self.glu = GLU(in_features, hidden_features, activation=act)
+
+    def forward(self, x):
+        return self.glu(x)
+
+
class CosineCutoff(nn.Module):
    def __init__(self, cutoff_lower=0.0, cutoff_upper=5.0):
        super(CosineCutoff, self).__init__()
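A minimal sketch (not part of the diff) of how the added Swish and SwiGLU modules compose; the shapes and beta value are arbitrary. It also spells out that SwiGLU(x) reduces to W(x) * Swish(V(x)) via the wrapped GLU:

```python
import torch

# With the GLU, Swish and SwiGLU classes from the hunks above in scope:
swiglu = SwiGLU(in_features=16, hidden_features=32, beta=1.0)
x = torch.randn(4, 16)
out = swiglu(x)                                       # shape (4, 32)

# Same computation written out via the wrapped GLU's linear layers:
W, V = swiglu.glu.W, swiglu.glu.V
manual = W(x) * (V(x) * torch.sigmoid(1.0 * V(x)))    # W(x) * Swish(V(x))
assert torch.allclose(out, manual)
```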
@@ -615,6 +689,8 @@ def scatter(
    "silu": nn.SiLU,
    "tanh": nn.Tanh,
    "sigmoid": nn.Sigmoid,
+    "swish": Swish,
+    "mish": nn.Mish,
}

dtype_mapping = {16: torch.float16, 32: torch.float, 64: torch.float64}
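A sketch of looking the new activations up by name. Only the mapping's entries are visible in this hunk, so the variable name `act_class_mapping` below is an assumption; the classes are assumed to be in scope from the patched module:

```python
import torch

# Assumed dict name; only its entries appear in the hunk above.
swish = act_class_mapping["swish"]()   # Swish with beta defaulting to 1.0
mish = act_class_mapping["mish"]()     # nn.Mish, shipped with PyTorch >= 1.9

y = swish(torch.randn(3))
```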