Skip to content

Add GLU, Swish, Mish and SwiGLU #334

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 16, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions torchmdnet/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,36 @@ def forward(self, dist):
)


class GLU(nn.Module):
r"""Applies the gated linear unit (GLU) function:

.. math::

\text{GLU}(x) = \text{Linear}_1(x) \otimes \sigma(\text{Linear}_2(x))


where :math:`\otimes` is the element-wise multiplication operator and
:math:`\sigma` is an activation function.

Args:
in_channels (int): Number of input features.
hidden_channels (int, optional): Number of hidden features. Defaults to None, meaning hidden_channels=in_channels.
activation (nn.Module, optional): Activation function to use. Defaults to Sigmoid.
"""

def __init__(
self, in_channels, hidden_channels=None, activation: Optional[nn.Module] = None
):
super(GLU, self).__init__()
self.act = nn.Sigmoid() if activation is None else activation
hidden_channels = hidden_channels or in_channels
self.W = nn.Linear(in_channels, hidden_channels)
self.V = nn.Linear(in_channels, hidden_channels)

def forward(self, x):
return self.W(x) * self.act(self.V(x))


class ShiftedSoftplus(nn.Module):
r"""Applies the ShiftedSoftplus function :math:`\text{ShiftedSoftplus}(x) = \frac{1}{\beta} *
\log(1 + \exp(\beta * x))-\log(2)` element-wise.
Expand All @@ -404,6 +434,50 @@ def forward(self, x):
return F.softplus(x) - self.shift


class Swish(nn.Module):
"""Swish activation function as defined in https://arxiv.org/pdf/1710.05941 :

.. math::

\text{Swish}(x) = x \cdot \sigma(\beta x)

Args:
beta (float, optional): Scaling factor for Swish activation. Defaults to 1.

"""

def __init__(self, beta=1.0):
super(Swish, self).__init__()
self.beta = beta

def forward(self, x):
return x * torch.sigmoid(self.beta * x)


class SwiGLU(nn.Module):
"""SwiGLU activation function as defined in https://arxiv.org/pdf/2002.05202 :

.. math::

\text{SwiGLU}(x) = \text{Linear}_1(x) \otimes \text{Swish}(\text{Linear}_2(x))

W1, V have shape (in_features, hidden_features)
Args:
in_features (int): Number of input features.
hidden_features (int, optional): Number of hidden features. Defaults to None, meaning hidden_features=in_features.
beta (float, optional): Scaling factor for Swish activation. Defaults to 1.0.
"""

def __init__(self, in_features, hidden_features=None, beta=1.0):
super().__init__()
hidden_features = hidden_features or in_features
act = Swish(beta)
self.glu = GLU(in_features, hidden_features, activation=act)

def forward(self, x):
return self.glu(x)


class CosineCutoff(nn.Module):
def __init__(self, cutoff_lower=0.0, cutoff_upper=5.0):
super(CosineCutoff, self).__init__()
Expand Down Expand Up @@ -615,6 +689,8 @@ def scatter(
"silu": nn.SiLU,
"tanh": nn.Tanh,
"sigmoid": nn.Sigmoid,
"swish": Swish,
"mish": nn.Mish,
}

dtype_mapping = {16: torch.float16, 32: torch.float, 64: torch.float64}
Loading