@@ -50,12 +50,14 @@ def __call__(self, data):
 
 # This wrapper is here in order to permit Lightning to serialize the loss function.
 class LossFunction:
-    def __init__(self, loss_fn, **kwargs):
+    def __init__(self, loss_fn, extra_args=None):
         self.loss_fn = loss_fn
-        self.kwargs = kwargs
+        self.extra_args = extra_args
+        if self.extra_args is None:
+            self.extra_args = {}
 
     def __call__(self, x, batch):
-        return self.loss_fn(x, batch, **self.kwargs)
+        return self.loss_fn(x, batch, **self.extra_args)
 
 
 class LNNP(LightningModule):
@@ -109,12 +111,10 @@ def __init__(self, hparams, prior_model=None, mean=None, std=None):
             raise ValueError(
                 f"Training loss {self.hparams.train_loss} not supported. Supported losses are {list(loss_class_mapping.keys())}"
             )
-        if self.hparams.train_loss_arg is None:
-            self.hparams.train_loss_arg = {}
 
         self.train_loss_fn = LossFunction(
             loss_class_mapping[self.hparams.train_loss],
-            **self.hparams.train_loss_arg,
+            self.hparams.train_loss_arg,
         )
 
     def configure_optimizers(self):
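
For context, here is a minimal self-contained sketch of the wrapper as it reads after this change, with a hypothetical usage. `torch.nn.functional.huber_loss` and the raw tensors are stand-ins for the real entries of `loss_class_mapping` and the batch objects used in training, which are not shown in this diff; the point of storing `extra_args` as a plain dict (instead of captured `**kwargs`) is presumably that it keeps the saved hyperparameters serializable, as the wrapper's comment suggests.

```python
import torch
import torch.nn.functional as F


class LossFunction:
    """Copy of the wrapper as it reads after this diff."""

    def __init__(self, loss_fn, extra_args=None):
        self.loss_fn = loss_fn
        self.extra_args = extra_args
        if self.extra_args is None:
            self.extra_args = {}

    def __call__(self, x, batch):
        # Extra arguments are kept as a plain dict and unpacked at call time.
        return self.loss_fn(x, batch, **self.extra_args)


# Hypothetical usage: huber_loss and random tensors stand in for the actual
# training losses and batches handled by LNNP.
pred, target = torch.randn(8, 1), torch.randn(8, 1)
loss_fn = LossFunction(F.huber_loss, extra_args={"delta": 0.5})
print(loss_fn(pred, target))  # scalar loss tensor
```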