Skip to content
This repository was archived by the owner on Mar 19, 2024. It is now read-only.

Commit b8e30eb

Browse files
prigoyalfacebook-github-bot
authored andcommitted
Remove the need to iterate the model parameters for freezing and speedup (#176)
Summary: Pull Request resolved: #176 in swav, there is a parameter that's frozen for 313 iterations. the previous logic iterated through all model named_parameters and checked if the parameter is frozen + iterations <313. This is inefficient as the model grows bigger in size and iterating through all parameters for the check will be slow. alternately, now we check whats the max frozen iteration to which a parameter is frozen and we only iterate named_parameters() of the model if the current iteration <= max frozen iteration inspired by blefaudeux Reviewed By: min-xu-ai Differential Revision: D26276803 fbshipit-source-id: 6d91658273e5ea7fb56ea2c796e0707e90c801b7
1 parent 94598ba commit b8e30eb

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

vissl/hooks/state_update_hooks.py

+12
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,18 @@ def on_backward(self, task: "tasks.ClassyTask") -> None:
260260
map_params_to_iters = {}
261261
for to_map in task.config.MODEL.TEMP_FROZEN_PARAMS_ITER_MAP:
262262
map_params_to_iters[to_map[0]] = to_map[1]
263+
264+
# get the maximum iterations until which the params are frozen.
265+
# if the iterations are past the maximum iterations freezing any
266+
# param, we simply return.
267+
max_iterations = max(list(map_params_to_iters.values()))
268+
if task.iteration >= max_iterations:
269+
if task.iteration == max_iterations:
270+
logging.info(
271+
f"No parameters grad removed from now on: {task.iteration}"
272+
)
273+
return
274+
263275
for name, p in task.model.named_parameters():
264276
if (
265277
name in map_params_to_iters

0 commit comments

Comments
 (0)