
Add Huber loss, allow choosing training loss function from the yaml #335


Merged · 12 commits · Aug 16, 2024
7 changes: 7 additions & 0 deletions torchmdnet/loss.py
@@ -0,0 +1,7 @@
from torch.nn.functional import mse_loss, l1_loss, huber_loss

loss_class_mapping = {
"mse_loss": mse_loss,
"l1_loss": l1_loss,
"huber_loss": huber_loss,
}
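For context, a minimal sketch of how the new mapping can be used directly; the tensors and the `delta` value below are illustrative, not part of the PR:

```python
import torch
from torchmdnet.loss import loss_class_mapping

# Pick a loss by name, exactly as the trainer does with hparams["train_loss"].
loss_fn = loss_class_mapping["huber_loss"]

pred = torch.randn(8, 1)
target = torch.randn(8, 1)

# huber_loss accepts the usual functional kwargs; delta sets the transition
# between the quadratic and linear regimes (value chosen for illustration).
loss = loss_fn(pred, target, delta=1.5)
print(loss.item())
```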
66 changes: 51 additions & 15 deletions torchmdnet/module.py
@@ -6,13 +6,13 @@
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.nn.functional import local_response_norm, mse_loss, l1_loss
from torch.nn.functional import local_response_norm
from torch import Tensor
from typing import Optional, Dict, Tuple

from lightning import LightningModule
from torchmdnet.models.model import create_model, load_model
from torchmdnet.models.utils import dtype_mapping
from torchmdnet.loss import l1_loss, loss_class_mapping
import torch_geometric.transforms as T


@@ -48,6 +48,18 @@ def __call__(self, data):
return data


# This wrapper is here in order to permit Lightning to serialize the loss function.
class LossFunction:
def __init__(self, loss_fn, extra_args=None):
self.loss_fn = loss_fn
self.extra_args = extra_args
if self.extra_args is None:
self.extra_args = {}

def __call__(self, x, batch):
return self.loss_fn(x, batch, **self.extra_args)


class LNNP(LightningModule):
"""
Lightning wrapper for the Neural Network Potentials in TorchMD-Net.
@@ -65,7 +77,10 @@ def __init__(self, hparams, prior_model=None, mean=None, std=None):
hparams["charge"] = False
if "spin" not in hparams:
hparams["spin"] = False

if "train_loss" not in hparams:
hparams["train_loss"] = "mse_loss"
if "train_loss_arg" not in hparams:
hparams["train_loss_arg"] = {}
self.save_hyperparameters(hparams)

if self.hparams.load_model:
@@ -92,6 +107,16 @@ def __init__(self, hparams, prior_model=None, mean=None, std=None):
]
)

if self.hparams.train_loss not in loss_class_mapping:
raise ValueError(
f"Training loss {self.hparams.train_loss} not supported. Supported losses are {list(loss_class_mapping.keys())}"
)

self.train_loss_fn = LossFunction(
loss_class_mapping[self.hparams.train_loss],
self.hparams.train_loss_arg,
)

def configure_optimizers(self):
optimizer = AdamW(
self.model.parameters(),
@@ -105,9 +130,12 @@ def configure_optimizers(self):
patience=self.hparams.lr_patience,
min_lr=self.hparams.lr_min,
)
lr_metric = getattr(self.hparams, "lr_metric", "val")
monitor = f"{lr_metric}_total_{self.hparams.train_loss}"
lr_scheduler = {
"scheduler": scheduler,
"monitor": getattr(self.hparams, "lr_metric", "val_loss"),
"strict": True,
"monitor": monitor,
"interval": "epoch",
"frequency": 1,
}
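As a quick illustration of the new scheduler monitor naming; the hparams values here are assumptions for the example, not defaults taken from the PR:

```python
# The monitored metric is now derived from lr_metric and the chosen loss name.
hparams = {"lr_metric": "val", "train_loss": "huber_loss"}  # assumed values

lr_metric = hparams.get("lr_metric", "val")
monitor = f"{lr_metric}_total_{hparams['train_loss']}"
print(monitor)  # val_total_huber_loss
```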
@@ -126,7 +154,9 @@ def forward(
return self.model(z, pos, batch=batch, box=box, q=q, s=s, extra_args=extra_args)

def training_step(self, batch, batch_idx):
return self.step(batch, [mse_loss], "train")
return self.step(
batch, [(self.hparams.train_loss, self.train_loss_fn)], "train"
)

def validation_step(self, batch, batch_idx, *args):
# If args is not empty the first (and only) element is the dataloader_idx
@@ -135,28 +165,34 @@ def validation_step(self, batch, batch_idx, *args):
# The dataloader takes care of sending the two sets only when the second one is needed.
is_val = len(args) == 0 or (len(args) > 0 and args[0] == 0)
if is_val:
step_type = {"loss_fn_list": [l1_loss, mse_loss], "stage": "val"}
step_type = {
"loss_fn_list": [
("l1_loss", l1_loss),
(self.hparams.train_loss, self.train_loss_fn),
],
"stage": "val",
}
else:
step_type = {"loss_fn_list": [l1_loss], "stage": "test"}
step_type = {"loss_fn_list": [("l1_loss", l1_loss)], "stage": "test"}
return self.step(batch, **step_type)

def test_step(self, batch, batch_idx):
return self.step(batch, [l1_loss], "test")
return self.step(batch, [("l1_loss", l1_loss)], "test")

def _compute_losses(self, y, neg_y, batch, loss_fn, stage):
def _compute_losses(self, y, neg_y, batch, loss_fn, loss_name, stage):
# Compute the loss for the predicted value and the negative derivative (if available)
# Args:
# y: predicted value
# neg_y: predicted negative derivative
# batch: batch of data
# loss_fn: loss function to compute
# loss_fn: The loss function to compute
# loss_name: The name of the loss function
# Returns:
# loss_y: loss for the predicted value
# loss_neg_y: loss for the predicted negative derivative
loss_y, loss_neg_y = torch.tensor(0.0, device=self.device), torch.tensor(
0.0, device=self.device
)
loss_name = loss_fn.__name__
if self.hparams.derivative and "neg_dy" in batch:
loss_neg_y = loss_fn(neg_y, batch.neg_dy)
loss_neg_y = self._update_loss_with_ema(
@@ -221,10 +257,10 @@ def step(self, batch, loss_fn_list, stage):
neg_dy = neg_dy + y.sum() * 0
if "y" in batch and batch.y.ndim == 1:
batch.y = batch.y.unsqueeze(1)
for loss_fn in loss_fn_list:
step_losses = self._compute_losses(y, neg_dy, batch, loss_fn, stage)

loss_name = loss_fn.__name__
for loss_name, loss_fn in loss_fn_list:
step_losses = self._compute_losses(
y, neg_dy, batch, loss_fn, loss_name, stage
)
if self.hparams.neg_dy_weight > 0:
self.losses[stage]["neg_dy"][loss_name].append(
step_losses["neg_dy"].detach()
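A self-contained sketch of the serializable wrapper in use, assuming this PR is installed; the `delta` value and the random tensors are illustrative:

```python
import torch
from torchmdnet.loss import loss_class_mapping
from torchmdnet.module import LossFunction  # wrapper introduced in this PR

# Bind the extra kwargs once; the wrapped callable then behaves like the plain
# functional loss but carries its arguments with it, so Lightning can store it.
train_loss_fn = LossFunction(loss_class_mapping["huber_loss"], {"delta": 2.0})

pred, ref = torch.randn(4, 1), torch.randn(4, 1)
print(train_loss_fn(pred, ref).item())

# step() now receives (name, callable) pairs instead of bare functions, so the
# logged keys no longer depend on loss_fn.__name__:
loss_fn_list = [("huber_loss", train_loss_fn)]
```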
14 changes: 8 additions & 6 deletions torchmdnet/scripts/train.py
@@ -17,6 +17,7 @@
from torchmdnet.module import LNNP
from torchmdnet import datasets, priors, models
from torchmdnet.data import DataModule
from torchmdnet.loss import loss_class_mapping
from torchmdnet.models import output_modules
from torchmdnet.models.model import create_prior_models
from torchmdnet.models.utils import rbf_class_mapping, act_class_mapping, dtype_mapping
@@ -34,7 +35,7 @@ def get_argparse():
parser.add_argument('--inference-batch-size', default=None, type=int, help='Batchsize for validation and tests.')
parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
parser.add_argument('--lr-patience', type=int, default=10, help='Patience for lr-schedule. Patience per eval-interval of validation')
parser.add_argument('--lr-metric', type=str, default='val_total_mse_loss', choices=['train_total_mse_loss', 'val_total_mse_loss'], help='Metric to monitor when deciding whether to reduce learning rate')
parser.add_argument('--lr-metric', type=str, default='val', choices=['train', 'val'], help='Metric to monitor when deciding whether to reduce learning rate')
parser.add_argument('--lr-min', type=float, default=1e-6, help='Minimum learning rate before early stop')
parser.add_argument('--lr-factor', type=float, default=0.8, help='Factor by which to multiply the learning rate when the metric stops improving')
parser.add_argument('--lr-warmup-steps', type=int, default=0, help='How many steps to warm-up over. Defaults to 0 for no warm-up')
@@ -69,6 +70,8 @@ def get_argparse():
parser.add_argument('--dataset-preload-limit', default=1024, type=int, help='Custom and HDF5 datasets will preload to RAM datasets that are less than this size in MB')
parser.add_argument('--y-weight', default=1.0, type=float, help='Weighting factor for y label in the loss function')
parser.add_argument('--neg-dy-weight', default=1.0, type=float, help='Weighting factor for neg_dy label in the loss function')
parser.add_argument('--train-loss', default='mse_loss', type=str, choices=loss_class_mapping.keys(), help='Loss function to use during training')
parser.add_argument('--train-loss-arg', default=None, help='Additional arguments for the loss function. Needs to be a dictionary.')

# model architecture
parser.add_argument('--model', type=str, default='graph-network', choices=models.__all_models__, help='Which model to train')
@@ -165,17 +168,16 @@ def main():
# initialize lightning module
model = LNNP(args, prior_model=prior_models, mean=data.mean, std=data.std)

val_loss_name = f"val_total_{args.train_loss}"
checkpoint_callback = ModelCheckpoint(
dirpath=args.log_dir,
monitor="val_total_mse_loss",
monitor=val_loss_name,
save_top_k=10, # -1 to save all
every_n_epochs=args.save_interval,
filename="epoch={epoch}-val_loss={val_total_mse_loss:.4f}-test_loss={test_total_l1_loss:.4f}",
filename=f"epoch={{epoch}}-val_loss={{{val_loss_name}:.4f}}-test_loss={{test_total_l1_loss:.4f}}",
auto_insert_metric_name=False,
)
early_stopping = EarlyStopping(
"val_total_mse_loss", patience=args.early_stopping_patience
)
early_stopping = EarlyStopping(val_loss_name, patience=args.early_stopping_patience)

csv_logger = CSVLogger(args.log_dir, name="", version="")
_logger = [csv_logger]
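Finally, a hedged sketch of how the new options would be set, either in the yaml configuration (the keys map one-to-one to these hparams) or via the new `--train-loss` / `--train-loss-arg` flags; the `delta` value is purely illustrative:

```python
# Fragment of the training configuration with the two new keys; in a yaml
# config these appear as "train_loss" and "train_loss_arg" (delta here is an
# illustrative value, not a recommended setting).
hparams_fragment = {
    "train_loss": "huber_loss",        # any key of loss_class_mapping
    "train_loss_arg": {"delta": 1.0},  # forwarded as kwargs to the loss
}
```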