This repository was archived by the owner on Mar 19, 2024. It is now read-only.

Commit 43f230c

jingli9111 authored and facebook-github-bot committed
LARS exclude BN: use a flag instead of a function (#388)
Summary: This fixes the problem that the Barlow Twins model needs to save a function in the checkpoint.
Pull Request resolved: #388
Reviewed By: iseessel
Differential Revision: D30158877
Pulled By: prigoyal
fbshipit-source-id: 537d0686422148447a4a42e14b448eb6e592eec9
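To illustrate the motivation, here is a minimal, hypothetical sketch (the class and names are illustrative, not VISSL's actual API): a plain bool in the optimizer's param-group defaults is ordinary data and serializes cleanly, whereas a bound method stored there ties the checkpoint to live code and cannot be represented as plain data.

```python
import json

class Opt:
    # Stand-in for the optimizer class; assumption: 1-D params are
    # biases / norm weights (illustrative, not taken from VISSL).
    def _exclude_bias_and_norm(self, p):
        return p.ndim == 1

opt = Opt()

flag_defaults = {"exclude": True}                      # after the fix: plain data
fn_defaults = {"exclude": opt._exclude_bias_and_norm}  # before: a bound method

print(json.dumps(flag_defaults))  # serializes cleanly
try:
    json.dumps(fn_defaults)
except TypeError:
    print("a bound method in defaults is not representable as plain data")
```

The same distinction applies to any serialization of the optimizer's param groups: a flag is portable across code versions, while a stored callable is not.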
1 parent 554aa15 commit 43f230c

File tree: 1 file changed, 2 insertions(+), 2 deletions(-)

vissl/optimizers/lars.py (+2 −2)
@@ -87,7 +87,7 @@ def __init__(
             "weight_decay": weight_decay,
             "momentum": momentum,
             "eta": eta,
-            "exclude": self._exclude_bias_and_norm if exclude_bias_and_norm else None,
+            "exclude": exclude_bias_and_norm,
         }
         super().__init__(params, defaults)

@@ -106,7 +106,7 @@ def step(self):

                 dp = dp.add(p, alpha=g["weight_decay"])

-                if g["exclude"] is None or not g["exclude"](p):
+                if not g["exclude"] or not self._exclude_bias_and_norm(p):
                     param_norm = torch.norm(p)
                     update_norm = torch.norm(dp)

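The changed branch gates LARS's layer-wise trust scaling: excluded parameters (biases and normalization weights) skip the rescaling. A minimal sketch of that logic, with assumed function names and the common assumption that such parameters are 1-D tensors (this is an illustration, not the VISSL implementation):

```python
import torch

def exclude_bias_and_norm(p: torch.Tensor) -> bool:
    # Assumption: biases and BatchNorm/LayerNorm weights are 1-D tensors,
    # so a rank check separates them from 2-D+ weight matrices.
    return p.ndim == 1

def lars_trust_ratio(p: torch.Tensor, dp: torch.Tensor, eta: float = 0.001) -> float:
    # Layer-wise scaling factor applied only to non-excluded parameters.
    param_norm = torch.norm(p)
    update_norm = torch.norm(dp)
    if param_norm > 0 and update_norm > 0:
        return (eta * param_norm / update_norm).item()
    return 1.0

weight = torch.ones(4, 4)  # 2-D: receives trust-ratio scaling
bias = torch.ones(4)       # 1-D: skipped when the exclude flag is set
```

With the patch, the optimizer stores only the bool `exclude_bias_and_norm` in its defaults and calls the method on `self` at step time, which is what makes the checkpoint free of function references.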