
Commit bfed435

Fix hparam
1 parent 5a88579 commit bfed435

File tree

1 file changed: +6 -6 lines changed

1 file changed

+6
-6
lines changed

torchmdnet/module.py (+6 -6)
@@ -50,12 +50,14 @@ def __call__(self, data):
 
 # This wrapper is here in order to permit Lightning to serialize the loss function.
 class LossFunction:
-    def __init__(self, loss_fn, **kwargs):
+    def __init__(self, loss_fn, extra_args=None):
         self.loss_fn = loss_fn
-        self.kwargs = kwargs
+        self.extra_args = extra_args
+        if self.extra_args is None:
+            self.extra_args = {}
 
     def __call__(self, x, batch):
-        return self.loss_fn(x, batch, **self.kwargs)
+        return self.loss_fn(x, batch, **self.extra_args)
 
 
 class LNNP(LightningModule):
@@ -109,12 +111,10 @@ def __init__(self, hparams, prior_model=None, mean=None, std=None):
             raise ValueError(
                 f"Training loss {self.hparams.train_loss} not supported. Supported losses are {list(loss_class_mapping.keys())}"
             )
-        if self.hparams.train_loss_arg is None:
-            self.hparams.train_loss_arg = {}
 
         self.train_loss_fn = LossFunction(
             loss_class_mapping[self.hparams.train_loss],
-            **self.hparams.train_loss_arg,
+            self.hparams.train_loss_arg,
         )
 
     def configure_optimizers(self):
