Commit 8a7b520

Merge pull request #335 from RaulPPelaez/huberloss
Add Huber loss, allow choosing training loss function from the yaml
2 parents 26206eb + bfed435 commit 8a7b520

File tree: 3 files changed (+66, -21 lines)

torchmdnet/loss.py

+7 lines (new file)

@@ -0,0 +1,7 @@
+from torch.nn.functional import mse_loss, l1_loss, huber_loss
+
+loss_class_mapping = {
+    "mse_loss": mse_loss,
+    "l1_loss": l1_loss,
+    "huber_loss": huber_loss,
+}
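The new module is a small name-to-callable registry over torch.nn.functional losses. A minimal sketch of how a name from the registry resolves to a callable that accepts loss-specific keyword arguments; the tensors and the delta value below are made up for illustration:

```python
import torch
from torchmdnet.loss import loss_class_mapping

pred = torch.randn(8, 1)
target = torch.randn(8, 1)

# Look up the loss by the same name used in the yaml/CLI and pass extra kwargs through.
loss_fn = loss_class_mapping["huber_loss"]
value = loss_fn(pred, target, delta=1.0)  # torch.nn.functional.huber_loss accepts `delta`
```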

torchmdnet/module.py

+51 -15 lines
@@ -6,13 +6,13 @@
 import torch
 from torch.optim import AdamW
 from torch.optim.lr_scheduler import ReduceLROnPlateau
-from torch.nn.functional import local_response_norm, mse_loss, l1_loss
+from torch.nn.functional import local_response_norm
 from torch import Tensor
 from typing import Optional, Dict, Tuple
-
 from lightning import LightningModule
 from torchmdnet.models.model import create_model, load_model
 from torchmdnet.models.utils import dtype_mapping
+from torchmdnet.loss import l1_loss, loss_class_mapping
 import torch_geometric.transforms as T


@@ -48,6 +48,18 @@ def __call__(self, data):
         return data


+# This wrapper is here in order to permit Lightning to serialize the loss function.
+class LossFunction:
+    def __init__(self, loss_fn, extra_args=None):
+        self.loss_fn = loss_fn
+        self.extra_args = extra_args
+        if self.extra_args is None:
+            self.extra_args = {}
+
+    def __call__(self, x, batch):
+        return self.loss_fn(x, batch, **self.extra_args)
+
+
 class LNNP(LightningModule):
     """
     Lightning wrapper for the Neural Network Potentials in TorchMD-Net.
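A rough usage sketch of the wrapper above, assuming it is imported from torchmdnet.module as added in this commit; the delta value is illustrative:

```python
import torch
from torch.nn.functional import huber_loss
from torchmdnet.module import LossFunction

# Bind loss-specific kwargs once; afterwards the wrapper is called like a plain loss function.
train_loss_fn = LossFunction(huber_loss, {"delta": 2.0})

pred = torch.randn(4, 1)
ref = torch.randn(4, 1)
loss = train_loss_fn(pred, ref)  # same as huber_loss(pred, ref, delta=2.0)
```

Keeping the loss and its extra arguments in a plain attribute-holding object is, per the in-code comment, what lets Lightning serialize the loss function.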
@@ -65,7 +77,10 @@ def __init__(self, hparams, prior_model=None, mean=None, std=None):
             hparams["charge"] = False
         if "spin" not in hparams:
             hparams["spin"] = False
-
+        if "train_loss" not in hparams:
+            hparams["train_loss"] = "mse_loss"
+        if "train_loss_arg" not in hparams:
+            hparams["train_loss_arg"] = {}
         self.save_hyperparameters(hparams)

         if self.hparams.load_model:
@@ -92,6 +107,16 @@ def __init__(self, hparams, prior_model=None, mean=None, std=None):
             ]
         )

+        if self.hparams.train_loss not in loss_class_mapping:
+            raise ValueError(
+                f"Training loss {self.hparams.train_loss} not supported. Supported losses are {list(loss_class_mapping.keys())}"
+            )
+
+        self.train_loss_fn = LossFunction(
+            loss_class_mapping[self.hparams.train_loss],
+            self.hparams.train_loss_arg,
+        )
+
     def configure_optimizers(self):
         optimizer = AdamW(
             self.model.parameters(),
@@ -105,9 +130,12 @@ def configure_optimizers(self):
             patience=self.hparams.lr_patience,
             min_lr=self.hparams.lr_min,
         )
+        lr_metric = getattr(self.hparams, "lr_metric", "val")
+        monitor = f"{lr_metric}_total_{self.hparams.train_loss}"
         lr_scheduler = {
             "scheduler": scheduler,
-            "monitor": getattr(self.hparams, "lr_metric", "val_loss"),
+            "strict": True,
+            "monitor": monitor,
             "interval": "epoch",
             "frequency": 1,
         }
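For example, with the new defaults (--lr-metric val, --train-loss mse_loss) the scheduler still monitors val_total_mse_loss, while --train-loss huber_loss switches the monitored metric to val_total_huber_loss.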
@@ -126,7 +154,9 @@ def forward(
         return self.model(z, pos, batch=batch, box=box, q=q, s=s, extra_args=extra_args)

     def training_step(self, batch, batch_idx):
-        return self.step(batch, [mse_loss], "train")
+        return self.step(
+            batch, [(self.hparams.train_loss, self.train_loss_fn)], "train"
+        )

     def validation_step(self, batch, batch_idx, *args):
         # If args is not empty the first (and only) element is the dataloader_idx
@@ -135,28 +165,34 @@ def validation_step(self, batch, batch_idx, *args):
         # The dataloader takes care of sending the two sets only when the second one is needed.
         is_val = len(args) == 0 or (len(args) > 0 and args[0] == 0)
         if is_val:
-            step_type = {"loss_fn_list": [l1_loss, mse_loss], "stage": "val"}
+            step_type = {
+                "loss_fn_list": [
+                    ("l1_loss", l1_loss),
+                    (self.hparams.train_loss, self.train_loss_fn),
+                ],
+                "stage": "val",
+            }
         else:
-            step_type = {"loss_fn_list": [l1_loss], "stage": "test"}
+            step_type = {"loss_fn_list": [("l1_loss", l1_loss)], "stage": "test"}
         return self.step(batch, **step_type)

     def test_step(self, batch, batch_idx):
-        return self.step(batch, [l1_loss], "test")
+        return self.step(batch, [("l1_loss", l1_loss)], "test")

-    def _compute_losses(self, y, neg_y, batch, loss_fn, stage):
+    def _compute_losses(self, y, neg_y, batch, loss_fn, loss_name, stage):
         # Compute the loss for the predicted value and the negative derivative (if available)
         # Args:
         #   y: predicted value
         #   neg_y: predicted negative derivative
         #   batch: batch of data
-        #   loss_fn: loss function to compute
+        #   loss_fn: The loss function to compute
+        #   loss_name: The name of the loss function
         # Returns:
         #   loss_y: loss for the predicted value
         #   loss_neg_y: loss for the predicted negative derivative
         loss_y, loss_neg_y = torch.tensor(0.0, device=self.device), torch.tensor(
             0.0, device=self.device
         )
-        loss_name = loss_fn.__name__
         if self.hparams.derivative and "neg_dy" in batch:
             loss_neg_y = loss_fn(neg_y, batch.neg_dy)
             loss_neg_y = self._update_loss_with_ema(
@@ -221,10 +257,10 @@ def step(self, batch, loss_fn_list, stage):
             neg_dy = neg_dy + y.sum() * 0
         if "y" in batch and batch.y.ndim == 1:
             batch.y = batch.y.unsqueeze(1)
-        for loss_fn in loss_fn_list:
-            step_losses = self._compute_losses(y, neg_dy, batch, loss_fn, stage)
-
-            loss_name = loss_fn.__name__
+        for loss_name, loss_fn in loss_fn_list:
+            step_losses = self._compute_losses(
+                y, neg_dy, batch, loss_fn, loss_name, stage
+            )
             if self.hparams.neg_dy_weight > 0:
                 self.losses[stage]["neg_dy"][loss_name].append(
                     step_losses["neg_dy"].detach()

torchmdnet/scripts/train.py

+8 -6 lines
@@ -17,6 +17,7 @@
 from torchmdnet.module import LNNP
 from torchmdnet import datasets, priors, models
 from torchmdnet.data import DataModule
+from torchmdnet.loss import loss_class_mapping
 from torchmdnet.models import output_modules
 from torchmdnet.models.model import create_prior_models
 from torchmdnet.models.utils import rbf_class_mapping, act_class_mapping, dtype_mapping
@@ -34,7 +35,7 @@ def get_argparse():
     parser.add_argument('--inference-batch-size', default=None, type=int, help='Batchsize for validation and tests.')
     parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
     parser.add_argument('--lr-patience', type=int, default=10, help='Patience for lr-schedule. Patience per eval-interval of validation')
-    parser.add_argument('--lr-metric', type=str, default='val_total_mse_loss', choices=['train_total_mse_loss', 'val_total_mse_loss'], help='Metric to monitor when deciding whether to reduce learning rate')
+    parser.add_argument('--lr-metric', type=str, default='val', choices=['train', 'val'], help='Metric to monitor when deciding whether to reduce learning rate')
     parser.add_argument('--lr-min', type=float, default=1e-6, help='Minimum learning rate before early stop')
     parser.add_argument('--lr-factor', type=float, default=0.8, help='Factor by which to multiply the learning rate when the metric stops improving')
     parser.add_argument('--lr-warmup-steps', type=int, default=0, help='How many steps to warm-up over. Defaults to 0 for no warm-up')
@@ -69,6 +70,8 @@ def get_argparse():
     parser.add_argument('--dataset-preload-limit', default=1024, type=int, help='Custom and HDF5 datasets will preload to RAM datasets that are less than this size in MB')
     parser.add_argument('--y-weight', default=1.0, type=float, help='Weighting factor for y label in the loss function')
     parser.add_argument('--neg-dy-weight', default=1.0, type=float, help='Weighting factor for neg_dy label in the loss function')
+    parser.add_argument('--train-loss', default='mse_loss', type=str, choices=loss_class_mapping.keys(), help='Loss function to use during training')
+    parser.add_argument('--train-loss-arg', default=None, help='Additional arguments for the loss function. Needs to be a dictionary.')

     # model architecture
     parser.add_argument('--model', type=str, default='graph-network', choices=models.__all_models__, help='Which model to train')
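A sketch of how these new options could be selected from a yaml configuration, per the PR title; it assumes the usual TorchMD-Net convention that each command-line option maps to a top-level yaml key, and the delta value is illustrative:

```yaml
# Train with the Huber loss; loss-specific kwargs go through train_loss_arg.
train_loss: huber_loss
train_loss_arg:
  delta: 0.5
```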
@@ -165,17 +168,16 @@ def main():
     # initialize lightning module
     model = LNNP(args, prior_model=prior_models, mean=data.mean, std=data.std)

+    val_loss_name = f"val_total_{args.train_loss}"
     checkpoint_callback = ModelCheckpoint(
         dirpath=args.log_dir,
-        monitor="val_total_mse_loss",
+        monitor=val_loss_name,
         save_top_k=10, # -1 to save all
         every_n_epochs=args.save_interval,
-        filename="epoch={epoch}-val_loss={val_total_mse_loss:.4f}-test_loss={test_total_l1_loss:.4f}",
+        filename=f"epoch={{epoch}}-val_loss={{{val_loss_name}:.4f}}-test_loss={{test_total_l1_loss:.4f}}",
         auto_insert_metric_name=False,
     )
-    early_stopping = EarlyStopping(
-        "val_total_mse_loss", patience=args.early_stopping_patience
-    )
+    early_stopping = EarlyStopping(val_loss_name, patience=args.early_stopping_patience)

     csv_logger = CSVLogger(args.log_dir, name="", version="")
     _logger = [csv_logger]
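From the command line, the equivalent choice would be along the lines of torchmd-train --conf my-config.yaml --train-loss huber_loss (entry-point and config name assumed for illustration); since --train-loss-arg must be a dictionary, it is most naturally set in the yaml file.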
