
Allow to configure the depth of the MLP in output modules #314

Merged · 4 commits · Apr 5, 2024
Changes from 3 commits
8 changes: 8 additions & 0 deletions torchmdnet/models/model.py
@@ -127,6 +127,7 @@ def create_model(args, prior_model=None, mean=None, std=None):
activation=args["activation"],
reduce_op=args["reduce_op"],
dtype=dtype,
num_hidden_layers=args.get("output_mlp_num_layers", 0),
)

# combine representation and output network
@@ -232,6 +233,13 @@ def load_model(filepath, args=None, device="cpu", return_std=False, **kwargs):
model.prior_model[-1].enable = True

state_dict = {re.sub(r"^model\.", "", k): v for k, v in ckpt["state_dict"].items()}
# In ET, checkpoints used to store output_model.output_network.{0,1}.update_net.[0-9].{weight,bias}
# They now store output_model.output_network.{0,1}.update_net.layers.[0-9].{weight,bias}
# This change was introduced in https://github.com/torchmd/torchmd-net/pull/314
state_dict = {
re.sub(r"update_net\.(\d+)\.", r"update_net.layers.\1.", k): v
for k, v in state_dict.items()
}
model.load_state_dict(state_dict)
return model.to(device)

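Note on the remapping above: the regex only touches ET checkpoints saved before this PR, rewriting the old `update_net` indices into the new `update_net.layers` ones. A minimal sketch of the substitution (the example key is hypothetical but follows the old naming scheme):

```python
import re

# Same regex as in load_model above; the key below is an illustrative example of
# the old ET naming scheme, not one taken from a real checkpoint.
old_key = "output_model.output_network.0.update_net.0.weight"
new_key = re.sub(r"update_net\.(\d+)\.", r"update_net.layers.\1.", old_key)
print(new_key)  # output_model.output_network.0.update_net.layers.0.weight
```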
78 changes: 54 additions & 24 deletions torchmdnet/models/output_modules.py
@@ -6,7 +6,12 @@
from typing import Optional
import torch
from torch import nn
from torchmdnet.models.utils import act_class_mapping, GatedEquivariantBlock, scatter
from torchmdnet.models.utils import (
act_class_mapping,
GatedEquivariantBlock,
scatter,
MLP,
)
from torchmdnet.utils import atomic_masses
from torchmdnet.extensions import is_current_stream_capturing
from warnings import warn
@@ -20,6 +25,7 @@ class OutputModel(nn.Module, metaclass=ABCMeta):
Derive this class to make custom output models.
As an example, have a look at the :py:mod:`torchmdnet.output_modules.Scalar` output model.
"""

def __init__(self, allow_prior_model, reduce_op):
super(OutputModel, self).__init__()
self.allow_prior_model = allow_prior_model
@@ -60,24 +66,23 @@ def __init__(
allow_prior_model=True,
reduce_op="sum",
dtype=torch.float,
**kwargs
):
super(Scalar, self).__init__(
allow_prior_model=allow_prior_model, reduce_op=reduce_op
)
act_class = act_class_mapping[activation]
self.output_network = nn.Sequential(
nn.Linear(hidden_channels, hidden_channels // 2, dtype=dtype),
act_class(),
nn.Linear(hidden_channels // 2, 1, dtype=dtype),
self.output_network = MLP(
in_channels=hidden_channels,
out_channels=1,
hidden_channels=hidden_channels // 2,
activation=activation,
num_hidden_layers=kwargs.get("num_layers", 0),
dtype=dtype,
)

self.reset_parameters()

def reset_parameters(self):
nn.init.xavier_uniform_(self.output_network[0].weight)
self.output_network[0].bias.data.fill_(0)
nn.init.xavier_uniform_(self.output_network[2].weight)
self.output_network[2].bias.data.fill_(0)
self.output_network.reset_parameters()

def pre_reduce(self, x, v: Optional[torch.Tensor], z, pos, batch):
return self.output_network(x)
@@ -91,10 +96,13 @@ def __init__(
allow_prior_model=True,
reduce_op="sum",
dtype=torch.float,
**kwargs
):
super(EquivariantScalar, self).__init__(
allow_prior_model=allow_prior_model, reduce_op=reduce_op
)
if kwargs.get("num_layers", 0) > 0:
warn("num_layers is not used in EquivariantScalar")
self.output_network = nn.ModuleList(
[
GatedEquivariantBlock(
@@ -125,14 +133,20 @@ def pre_reduce(self, x, v, z, pos, batch):

class DipoleMoment(Scalar):
def __init__(
self, hidden_channels, activation="silu", reduce_op="sum", dtype=torch.float
self,
hidden_channels,
activation="silu",
reduce_op="sum",
dtype=torch.float,
**kwargs
):
super(DipoleMoment, self).__init__(
hidden_channels,
activation,
allow_prior_model=False,
reduce_op=reduce_op,
dtype=dtype,
**kwargs
)
atomic_mass = torch.from_numpy(atomic_masses).to(dtype)
self.register_buffer("atomic_mass", atomic_mass)
@@ -152,14 +166,20 @@ def post_reduce(self, x):

class EquivariantDipoleMoment(EquivariantScalar):
def __init__(
self, hidden_channels, activation="silu", reduce_op="sum", dtype=torch.float
self,
hidden_channels,
activation="silu",
reduce_op="sum",
dtype=torch.float,
**kwargs
):
super(EquivariantDipoleMoment, self).__init__(
hidden_channels,
activation,
allow_prior_model=False,
reduce_op=reduce_op,
dtype=dtype,
**kwargs
)
atomic_mass = torch.from_numpy(atomic_masses).to(dtype)
self.register_buffer("atomic_mass", atomic_mass)
@@ -180,27 +200,31 @@ def post_reduce(self, x):

class ElectronicSpatialExtent(OutputModel):
def __init__(
self, hidden_channels, activation="silu", reduce_op="sum", dtype=torch.float
self,
hidden_channels,
activation="silu",
reduce_op="sum",
dtype=torch.float,
**kwargs
):
super(ElectronicSpatialExtent, self).__init__(
allow_prior_model=False, reduce_op=reduce_op
)
act_class = act_class_mapping[activation]
self.output_network = nn.Sequential(
nn.Linear(hidden_channels, hidden_channels // 2, dtype=dtype),
act_class(),
nn.Linear(hidden_channels // 2, 1, dtype=dtype),
self.output_network = MLP(
in_channels=hidden_channels,
out_channels=1,
hidden_channels=hidden_channels // 2,
activation=activation,
num_hidden_layers=kwargs.get("num_layers", 0),
dtype=dtype,
)
atomic_mass = torch.from_numpy(atomic_masses).to(dtype)
self.register_buffer("atomic_mass", atomic_mass)

self.reset_parameters()

def reset_parameters(self):
nn.init.xavier_uniform_(self.output_network[0].weight)
self.output_network[0].bias.data.fill_(0)
nn.init.xavier_uniform_(self.output_network[2].weight)
self.output_network[2].bias.data.fill_(0)
self.output_network.reset_parameters()

def pre_reduce(self, x, v: Optional[torch.Tensor], z, pos, batch):
x = self.output_network(x)
@@ -219,14 +243,20 @@ class EquivariantElectronicSpatialExtent(ElectronicSpatialExtent):

class EquivariantVectorOutput(EquivariantScalar):
def __init__(
self, hidden_channels, activation="silu", reduce_op="sum", dtype=torch.float
self,
hidden_channels,
activation="silu",
reduce_op="sum",
dtype=torch.float,
**kwargs
):
super(EquivariantVectorOutput, self).__init__(
hidden_channels,
activation,
allow_prior_model=False,
reduce_op="sum",
dtype=dtype,
**kwargs
)

def pre_reduce(self, x, v, z, pos, batch):
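Aside on backward compatibility of the refactor above: with the default num_hidden_layers=0, the new MLP builds the same Linear → activation → Linear stack that Scalar and ElectronicSpatialExtent previously assembled by hand, so existing configurations keep the old architecture. A hedged sketch (the hidden_channels value is illustrative):

```python
import torch
from torchmdnet.models.utils import MLP

# With num_hidden_layers=0 the MLP reduces to the previous two-layer head.
hidden_channels = 128  # illustrative value, not taken from the PR
net = MLP(
    in_channels=hidden_channels,
    out_channels=1,
    hidden_channels=hidden_channels // 2,
    activation="silu",
    num_hidden_layers=0,
)
print(net.layers)                                   # Linear(128->64) -> SiLU -> Linear(64->1)
print(net(torch.randn(4, hidden_channels)).shape)   # torch.Size([4, 1])
```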
69 changes: 60 additions & 9 deletions torchmdnet/models/utils.py
@@ -434,6 +434,58 @@ def forward(self, distances: Tensor) -> Tensor:
return cutoffs


class MLP(nn.Module):
"""A simple multi-layer perceptron with a given number of layers and hidden channels.

The simplest MLP has no hidden layers and is composed of two linear layers with a non-linear activation function in between:

.. math::

\text{MLP}(x) = \text{Linear}_o(\text{act}(\text{Linear}_i(x)))

Where :math:`\text{Linear}_i` has input size :math:`\text{in_channels}` and output size :math:`\text{hidden_channels}` and :math:`\text{Linear}_o` has input size :math:`\text{hidden_channels}` and output size :math:`\text{out_channels}`.


Args:
in_channels (int): Number of input features.
out_channels (int): Number of output features.
hidden_channels (int): Number of hidden features.
activation (str): Activation function to use.
num_hidden_layers (int, optional): Number of hidden layers. Defaults to 0.
dtype (torch.dtype, optional): Data type to use. Defaults to torch.float32.
"""

def __init__(
self,
in_channels,
out_channels,
hidden_channels,
activation,
num_hidden_layers=0,
dtype=torch.float32,
):
super(MLP, self).__init__()
act_class = act_class_mapping[activation]
self.act = act_class()
self.layers = nn.Sequential()
self.layers.append(nn.Linear(in_channels, hidden_channels, dtype=dtype))
self.layers.append(self.act)
for _ in range(num_hidden_layers):
self.layers.append(nn.Linear(hidden_channels, hidden_channels, dtype=dtype))
self.layers.append(self.act)
self.layers.append(nn.Linear(hidden_channels, out_channels, dtype=dtype))

def reset_parameters(self):
for layer in self.layers:
if isinstance(layer, nn.Linear):
nn.init.xavier_uniform_(layer.weight)
layer.bias.data.fill_(0)

def forward(self, x):
x = self.layers(x)
return x


class GatedEquivariantBlock(nn.Module):
"""Gated Equivariant Block as defined in Schütt et al. (2021):
Equivariant message passing for the prediction of tensorial properties and molecular spectra
@@ -462,21 +514,20 @@ def __init__(
)

act_class = act_class_mapping[activation]
self.update_net = nn.Sequential(
nn.Linear(hidden_channels * 2, intermediate_channels, dtype=dtype),
act_class(),
nn.Linear(intermediate_channels, out_channels * 2, dtype=dtype),
self.update_net = MLP(
in_channels=hidden_channels * 2,
out_channels=out_channels * 2,
hidden_channels=intermediate_channels,
activation=activation,
num_hidden_layers=0,
dtype=dtype,
)

self.act = act_class() if scalar_activation else None

def reset_parameters(self):
nn.init.xavier_uniform_(self.vec1_proj.weight)
nn.init.xavier_uniform_(self.vec2_proj.weight)
nn.init.xavier_uniform_(self.update_net[0].weight)
self.update_net[0].bias.data.fill_(0)
nn.init.xavier_uniform_(self.update_net[2].weight)
self.update_net[2].bias.data.fill_(0)
self.update_net.reset_parameters()

def forward(self, x, v):
vec1_buffer = self.vec1_proj(v)
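For reference, each extra hidden layer requested through num_hidden_layers adds one Linear(hidden_channels, hidden_channels) plus an activation between the input and output projections, so the MLP always contains num_hidden_layers + 2 linear layers. A small sketch (channel sizes are illustrative):

```python
import torch.nn as nn
from torchmdnet.models.utils import MLP

# Count the Linear layers for a few depths; channel sizes are examples only.
for depth in (0, 1, 3):
    net = MLP(in_channels=64, out_channels=1, hidden_channels=32,
              activation="silu", num_hidden_layers=depth)
    n_linear = sum(isinstance(m, nn.Linear) for m in net.layers)
    print(depth, n_linear)  # -> (0, 2), (1, 3), (3, 5)
```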
1 change: 1 addition & 0 deletions torchmdnet/scripts/train.py
@@ -74,6 +74,7 @@ def get_argparse():
# model architecture
parser.add_argument('--model', type=str, default='graph-network', choices=models.__all_models__, help='Which model to train')
parser.add_argument('--output-model', type=str, default='Scalar', choices=output_modules.__all__, help='The type of output model')
parser.add_argument('--output-mlp-num-layers', type=int, default=0, help='If the output model uses an MLP this will be the number of inner layers.')
parser.add_argument('--prior-model', type=str, default=None, help='Which prior model to use. It can be a string, a dict if you want to add arguments for it or a dicts to add more than one prior. e.g. {"Atomref": {"max_z":100}, "Coulomb":{"max_num_neighs"=100, "lower_switch_distance"=4, "upper_switch_distance"=8}', action="extend", nargs="*")

# architectural args
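The new flag is parsed like any other model option; a hedged sketch of how it would be read (this assumes no other argument of get_argparse is mandatory, and the flag values are illustrative):

```python
from torchmdnet.scripts.train import get_argparse

# Illustrative invocation of the parser extended above; flag values are examples only.
args = get_argparse().parse_args(
    ["--output-model", "Scalar", "--output-mlp-num-layers", "2"]
)
print(args.output_mlp_num_layers)  # 2
```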