Commit 26206eb

Merge pull request #334 from RaulPPelaez/swiglu
Add GLU, Swish, Mish and SwiGLU
2 parents 6dea4b6 + 352da8a commit 26206eb

1 file changed (+76 -0):

torchmdnet/models/utils.py

@@ -388,6 +388,36 @@ def forward(self, dist):
         )


+class GLU(nn.Module):
+    r"""Applies the gated linear unit (GLU) function:
+
+    .. math::
+
+        \text{GLU}(x) = \text{Linear}_1(x) \otimes \sigma(\text{Linear}_2(x))
+
+    where :math:`\otimes` is the element-wise multiplication operator and
+    :math:`\sigma` is an activation function.
+
+    Args:
+        in_channels (int): Number of input features.
+        hidden_channels (int, optional): Number of hidden features. Defaults to None, meaning hidden_channels=in_channels.
+        activation (nn.Module, optional): Activation function to use. Defaults to Sigmoid.
+    """
+
+    def __init__(
+        self, in_channels, hidden_channels=None, activation: Optional[nn.Module] = None
+    ):
+        super(GLU, self).__init__()
+        self.act = nn.Sigmoid() if activation is None else activation
+        hidden_channels = hidden_channels or in_channels
+        self.W = nn.Linear(in_channels, hidden_channels)
+        self.V = nn.Linear(in_channels, hidden_channels)
+
+    def forward(self, x):
+        return self.W(x) * self.act(self.V(x))
+
+
 class ShiftedSoftplus(nn.Module):
     r"""Applies the ShiftedSoftplus function :math:`\text{ShiftedSoftplus}(x) = \frac{1}{\beta} *
     \log(1 + \exp(\beta * x))-\log(2)` element-wise.
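
For reference, a minimal usage sketch of the new GLU block. This is a hypothetical snippet, not part of the diff; it assumes the class is importable from torchmdnet.models.utils once this commit is merged.

import torch
import torch.nn as nn
from torchmdnet.models.utils import GLU

x = torch.randn(10, 64)                          # batch of 10 feature vectors
glu = GLU(in_channels=64, hidden_channels=128)   # Linear_1 and Linear_2: 64 -> 128
out = glu(x)                                     # W(x) * sigmoid(V(x))
assert out.shape == (10, 128)

# The default Sigmoid gate can be swapped for any element-wise nn.Module;
# hidden_channels defaults to in_channels when omitted.
glu_tanh = GLU(64, activation=nn.Tanh())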
@@ -404,6 +434,50 @@ def forward(self, x):
         return F.softplus(x) - self.shift


+class Swish(nn.Module):
+    r"""Swish activation function as defined in https://arxiv.org/pdf/1710.05941 :
+
+    .. math::
+
+        \text{Swish}(x) = x \cdot \sigma(\beta x)
+
+    Args:
+        beta (float, optional): Scaling factor for Swish activation. Defaults to 1.0.
+    """
+
+    def __init__(self, beta=1.0):
+        super(Swish, self).__init__()
+        self.beta = beta
+
+    def forward(self, x):
+        return x * torch.sigmoid(self.beta * x)
+
+
+class SwiGLU(nn.Module):
+    r"""SwiGLU activation function as defined in https://arxiv.org/pdf/2002.05202 :
+
+    .. math::
+
+        \text{SwiGLU}(x) = \text{Linear}_1(x) \otimes \text{Swish}(\text{Linear}_2(x))
+
+    W and V have shape (in_features, hidden_features).
+
+    Args:
+        in_features (int): Number of input features.
+        hidden_features (int, optional): Number of hidden features. Defaults to None, meaning hidden_features=in_features.
+        beta (float, optional): Scaling factor for Swish activation. Defaults to 1.0.
+    """
+
+    def __init__(self, in_features, hidden_features=None, beta=1.0):
+        super().__init__()
+        hidden_features = hidden_features or in_features
+        act = Swish(beta)
+        self.glu = GLU(in_features, hidden_features, activation=act)
+
+    def forward(self, x):
+        return self.glu(x)
+
+
 class CosineCutoff(nn.Module):
     def __init__(self, cutoff_lower=0.0, cutoff_upper=5.0):
         super(CosineCutoff, self).__init__()
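
A small check of the new Swish and SwiGLU blocks, again as a hedged sketch assuming the classes are exposed by torchmdnet.models.utils: with beta=1, Swish coincides with PyTorch's built-in SiLU, and SwiGLU is simply GLU with a Swish gate.

import torch
import torch.nn.functional as F
from torchmdnet.models.utils import Swish, SwiGLU

x = torch.randn(8, 32)

swish = Swish(beta=1.0)
assert torch.allclose(swish(x), F.silu(x))   # Swish with beta=1 equals SiLU

swiglu = SwiGLU(in_features=32, hidden_features=64)
assert swiglu(x).shape == (8, 64)            # output width follows hidden_features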
@@ -615,6 +689,8 @@ def scatter(
     "silu": nn.SiLU,
     "tanh": nn.Tanh,
     "sigmoid": nn.Sigmoid,
+    "swish": Swish,
+    "mish": nn.Mish,
 }

 dtype_mapping = {16: torch.float16, 32: torch.float, 64: torch.float64}
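
The two new keys extend the module's activation-by-name lookup. A sketch of how that mapping might be used follows; the dictionary name act_class_mapping and the lookup pattern are assumptions here, only the "swish" and "mish" entries come from this commit.

import torch
from torchmdnet.models.utils import act_class_mapping  # assumed name of the dict above

act = act_class_mapping["swish"]()    # instantiate Swish with its default beta=1.0
y = act(torch.randn(4, 16))

mish = act_class_mapping["mish"]()    # nn.Mish, shipped with PyTorch >= 1.9
z = mish(torch.randn(4, 16))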
