Commit 8dca199

Added an MLP module.
Allow the number of hidden MLP layers in the Scalar OutputModel to be configured from the YAML input.
1 parent 8b47246 commit 8dca199

4 files changed (+106, -33 lines)

torchmdnet/models/model.py (+1)
@@ -127,6 +127,7 @@ def create_model(args, prior_model=None, mean=None, std=None):
         activation=args["activation"],
         reduce_op=args["reduce_op"],
         dtype=dtype,
+        num_layers=args.get("output_mlp_num_layers", 0),
     )

     # combine representation and output network
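
Because create_model reads the option with args.get("output_mlp_num_layers", 0), configurations that do not define the key fall back to 0 extra hidden layers, which reproduces the previous Linear -> activation -> Linear output head.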

torchmdnet/models/output_modules.py (+53, -24)
@@ -6,7 +6,12 @@
 from typing import Optional
 import torch
 from torch import nn
-from torchmdnet.models.utils import act_class_mapping, GatedEquivariantBlock, scatter
+from torchmdnet.models.utils import (
+    act_class_mapping,
+    GatedEquivariantBlock,
+    scatter,
+    MLP,
+)
 from torchmdnet.utils import atomic_masses
 from torchmdnet.extensions import is_current_stream_capturing
 from warnings import warn
@@ -60,24 +65,23 @@ def __init__(
         allow_prior_model=True,
         reduce_op="sum",
         dtype=torch.float,
+        **kwargs
     ):
         super(Scalar, self).__init__(
             allow_prior_model=allow_prior_model, reduce_op=reduce_op
         )
-        act_class = act_class_mapping[activation]
-        self.output_network = nn.Sequential(
-            nn.Linear(hidden_channels, hidden_channels // 2, dtype=dtype),
-            act_class(),
-            nn.Linear(hidden_channels // 2, 1, dtype=dtype),
+        self.output_network = MLP(
+            in_channels=hidden_channels,
+            out_channels=1,
+            hidden_channels=hidden_channels // 2,
+            activation=activation,
+            num_layers=kwargs.get("num_layers", 0),
+            dtype=dtype,
         )
-
         self.reset_parameters()

     def reset_parameters(self):
-        nn.init.xavier_uniform_(self.output_network[0].weight)
-        self.output_network[0].bias.data.fill_(0)
-        nn.init.xavier_uniform_(self.output_network[2].weight)
-        self.output_network[2].bias.data.fill_(0)
+        self.output_network.reset_parameters()

     def pre_reduce(self, x, v: Optional[torch.Tensor], z, pos, batch):
         return self.output_network(x)
@@ -91,10 +95,13 @@ def __init__(
         allow_prior_model=True,
         reduce_op="sum",
         dtype=torch.float,
+        **kwargs
     ):
         super(EquivariantScalar, self).__init__(
             allow_prior_model=allow_prior_model, reduce_op=reduce_op
         )
+        if kwargs.get("num_layers", 0) > 0:
+            warn("num_layers is not used in EquivariantScalar")
         self.output_network = nn.ModuleList(
             [
                 GatedEquivariantBlock(
@@ -125,14 +132,20 @@ def pre_reduce(self, x, v, z, pos, batch):

 class DipoleMoment(Scalar):
     def __init__(
-        self, hidden_channels, activation="silu", reduce_op="sum", dtype=torch.float
+        self,
+        hidden_channels,
+        activation="silu",
+        reduce_op="sum",
+        dtype=torch.float,
+        **kwargs
     ):
         super(DipoleMoment, self).__init__(
             hidden_channels,
             activation,
             allow_prior_model=False,
             reduce_op=reduce_op,
             dtype=dtype,
+            **kwargs
         )
         atomic_mass = torch.from_numpy(atomic_masses).to(dtype)
         self.register_buffer("atomic_mass", atomic_mass)
@@ -152,14 +165,20 @@ def post_reduce(self, x):

 class EquivariantDipoleMoment(EquivariantScalar):
     def __init__(
-        self, hidden_channels, activation="silu", reduce_op="sum", dtype=torch.float
+        self,
+        hidden_channels,
+        activation="silu",
+        reduce_op="sum",
+        dtype=torch.float,
+        **kwargs
     ):
         super(EquivariantDipoleMoment, self).__init__(
             hidden_channels,
             activation,
             allow_prior_model=False,
             reduce_op=reduce_op,
             dtype=dtype,
+            **kwargs
         )
         atomic_mass = torch.from_numpy(atomic_masses).to(dtype)
         self.register_buffer("atomic_mass", atomic_mass)
@@ -180,27 +199,31 @@ def post_reduce(self, x):

 class ElectronicSpatialExtent(OutputModel):
     def __init__(
-        self, hidden_channels, activation="silu", reduce_op="sum", dtype=torch.float
+        self,
+        hidden_channels,
+        activation="silu",
+        reduce_op="sum",
+        dtype=torch.float,
+        **kwargs
     ):
         super(ElectronicSpatialExtent, self).__init__(
             allow_prior_model=False, reduce_op=reduce_op
         )
-        act_class = act_class_mapping[activation]
-        self.output_network = nn.Sequential(
-            nn.Linear(hidden_channels, hidden_channels // 2, dtype=dtype),
-            act_class(),
-            nn.Linear(hidden_channels // 2, 1, dtype=dtype),
+        self.output_network = MLP(
+            in_channels=hidden_channels,
+            out_channels=1,
+            hidden_channels=hidden_channels // 2,
+            activation=activation,
+            num_layers=kwargs.get("num_layers", 0),
+            dtype=dtype,
         )
         atomic_mass = torch.from_numpy(atomic_masses).to(dtype)
         self.register_buffer("atomic_mass", atomic_mass)

         self.reset_parameters()

     def reset_parameters(self):
-        nn.init.xavier_uniform_(self.output_network[0].weight)
-        self.output_network[0].bias.data.fill_(0)
-        nn.init.xavier_uniform_(self.output_network[2].weight)
-        self.output_network[2].bias.data.fill_(0)
+        self.output_network.reset_parameters()

     def pre_reduce(self, x, v: Optional[torch.Tensor], z, pos, batch):
         x = self.output_network(x)
@@ -219,14 +242,20 @@ class EquivariantElectronicSpatialExtent(ElectronicSpatialExtent):

 class EquivariantVectorOutput(EquivariantScalar):
     def __init__(
-        self, hidden_channels, activation="silu", reduce_op="sum", dtype=torch.float
+        self,
+        hidden_channels,
+        activation="silu",
+        reduce_op="sum",
+        dtype=torch.float,
+        **kwargs
     ):
         super(EquivariantVectorOutput, self).__init__(
             hidden_channels,
             activation,
             allow_prior_model=False,
             reduce_op="sum",
             dtype=dtype,
+            **kwargs
         )

     def pre_reduce(self, x, v, z, pos, batch):
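
A minimal usage sketch (not part of this commit) showing how the new keyword reaches the Scalar head; the 128-dimensional feature size and the dummy tensors are illustrative only:

import torch
from torchmdnet.models.output_modules import Scalar

# One extra hidden layer in the output MLP; num_layers travels through **kwargs.
head = Scalar(hidden_channels=128, activation="silu", num_layers=1)

x = torch.randn(10, 128)  # per-atom features for 10 atoms
y = head.pre_reduce(x, None, None, None, None)  # pre_reduce just applies the MLP to x
print(y.shape)  # torch.Size([10, 1])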

torchmdnet/models/utils.py (+51, -9)
@@ -434,6 +434,49 @@ def forward(self, distances: Tensor) -> Tensor:
         return cutoffs


+class MLP(nn.Module):
+    """A simple multi-layer perceptron with a given number of layers and hidden channels.
+
+    Args:
+        in_channels (int): Number of input features.
+        out_channels (int): Number of output features.
+        hidden_channels (int): Number of hidden features.
+        activation (str): Activation function to use.
+        num_layers (int, optional): Number of layers. Defaults to 0.
+        dtype (torch.dtype, optional): Data type to use. Defaults to torch.float32.
+    """
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        hidden_channels,
+        activation,
+        num_layers=0,
+        dtype=torch.float32,
+    ):
+        super(MLP, self).__init__()
+        act_class = act_class_mapping[activation]
+        self.act = act_class()
+        self.layers = nn.Sequential()
+        self.layers.append(nn.Linear(in_channels, hidden_channels, dtype=dtype))
+        self.layers.append(self.act)
+        for _ in range(num_layers):
+            self.layers.append(nn.Linear(hidden_channels, hidden_channels, dtype=dtype))
+            self.layers.append(self.act)
+        self.layers.append(nn.Linear(hidden_channels, out_channels, dtype=dtype))
+
+    def reset_parameters(self):
+        for layer in self.layers:
+            if isinstance(layer, nn.Linear):
+                nn.init.xavier_uniform_(layer.weight)
+                layer.bias.data.fill_(0)
+
+    def forward(self, x):
+        x = self.layers(x)
+        return x
+
+
 class GatedEquivariantBlock(nn.Module):
     """Gated Equivariant Block as defined in Schütt et al. (2021):
     Equivariant message passing for the prediction of tensorial properties and molecular spectra
@@ -462,21 +505,20 @@ def __init__(
         )

         act_class = act_class_mapping[activation]
-        self.update_net = nn.Sequential(
-            nn.Linear(hidden_channels * 2, intermediate_channels, dtype=dtype),
-            act_class(),
-            nn.Linear(intermediate_channels, out_channels * 2, dtype=dtype),
+        self.update_net = MLP(
+            in_channels=hidden_channels * 2,
+            out_channels=out_channels * 2,
+            hidden_channels=intermediate_channels,
+            activation=activation,
+            num_layers=0,
+            dtype=dtype,
         )
-
         self.act = act_class() if scalar_activation else None

     def reset_parameters(self):
         nn.init.xavier_uniform_(self.vec1_proj.weight)
         nn.init.xavier_uniform_(self.vec2_proj.weight)
-        nn.init.xavier_uniform_(self.update_net[0].weight)
-        self.update_net[0].bias.data.fill_(0)
-        nn.init.xavier_uniform_(self.update_net[2].weight)
-        self.update_net[2].bias.data.fill_(0)
+        self.update_net.reset_parameters()

     def forward(self, x, v):
         vec1_buffer = self.vec1_proj(v)
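
The num_layers argument counts the extra hidden Linear/activation pairs, so num_layers=0 keeps the two-linear-layer shape of the old output heads. A quick sketch (not part of the diff) of the new module in isolation:

import torch
from torchmdnet.models.utils import MLP

# Linear(64, 32) -> SiLU -> Linear(32, 32) -> SiLU -> Linear(32, 32) -> SiLU -> Linear(32, 1)
mlp = MLP(in_channels=64, out_channels=1, hidden_channels=32, activation="silu", num_layers=2)
out = mlp(torch.randn(5, 64))
print(out.shape)  # torch.Size([5, 1])

Note that MLP reuses a single activation module instance between layers, which is harmless for stateless activations such as SiLU.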

torchmdnet/scripts/train.py (+1)
@@ -74,6 +74,7 @@ def get_argparse():
     # model architecture
     parser.add_argument('--model', type=str, default='graph-network', choices=models.__all_models__, help='Which model to train')
     parser.add_argument('--output-model', type=str, default='Scalar', choices=output_modules.__all__, help='The type of output model')
+    parser.add_argument('--output-mlp-num-layers', type=int, default=0, help='If the output model uses an MLP this will be the number of inner layers.')
     parser.add_argument('--prior-model', type=str, default=None, help='Which prior model to use. It can be a string, a dict if you want to add arguments for it or a dicts to add more than one prior. e.g. {"Atomref": {"max_z":100}, "Coulomb":{"max_num_neighs"=100, "lower_switch_distance"=4, "upper_switch_distance"=8}', action="extend", nargs="*")

     # architectural args
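
With this flag the extra depth can be requested on the command line (--output-mlp-num-layers 2) or, as the commit message notes, from the YAML input as output_mlp_num_layers: 2, which is the key that create_model looks up.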
