This repository was archived by the owner on Jul 1, 2024. It is now read-only.

Commit 2efb9fc
vreis authored and facebook-github-bot committed
Remove local_variables from on_step (#411)
Summary: local_variables makes the code in train_step really hard to read. Killing it from all hooks will take time, so start from a single hook (on_step).

Pull Request resolved: #411
Differential Revision: D20171981
fbshipit-source-id: a6f158003926425d9f3e0d5e8489447d49bb2443
1 parent 04a99c8 commit 2efb9fc

14 files changed: +70 -60 lines
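
For hook authors the migration is mechanical: on_step loses its second parameter, and anything a hook used to read out of local_variables now comes off the task object. A minimal before/after sketch (the hook class around these methods is elided; the loss lookup is illustrative):

    # Before this commit:
    def on_step(self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]) -> None:
        loss = local_variables["loss"]

    # After this commit:
    def on_step(self, task: "tasks.ClassyTask") -> None:
        loss = task.last_batch.loss  # populated by train_step/eval_step, see classification_task.py below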

classy_vision/hooks/classy_hook.py (+2 -4)

@@ -51,7 +51,7 @@ def __init__(self, a, b):
     def __init__(self):
         self.state = ClassyHookState()

-    def _noop(self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]) -> None:
+    def _noop(self, *args, **kwargs) -> None:
         """Derived classes can set their hook functions to this.

         This is useful if they want those hook functions to not do anything.
@@ -79,9 +79,7 @@ def on_phase_start(
         pass

     @abstractmethod
-    def on_step(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def on_step(self, task: "tasks.ClassyTask") -> None:
         """Called each time after parameters have been updated by the optimizer."""
         pass
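
In practice a user-defined hook now implements the one-argument on_step and stubs the events it does not need with ClassyHook._noop, which is why _noop was relaxed to accept *args, **kwargs above. A minimal sketch following the TestHook pattern in test/optim_param_scheduler_test.py at the bottom of this diff; the hook name and its body are illustrative, and if your version of ClassyHook declares further abstract events they can be stubbed the same way:

    from classy_vision import tasks
    from classy_vision.hooks.classy_hook import ClassyHook


    class LogLastLossHook(ClassyHook):
        # Events this hook does not care about become no-ops.
        on_start = ClassyHook._noop
        on_phase_start = ClassyHook._noop
        on_phase_end = ClassyHook._noop
        on_end = ClassyHook._noop

        def on_step(self, task: "tasks.ClassyTask") -> None:
            # Called after each optimizer update; all state comes from the task.
            if task.train and task.losses:
                print("step loss: {}".format(task.losses[-1]))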

classy_vision/hooks/exponential_moving_average_model_hook.py (+1 -1)

@@ -103,7 +103,7 @@ def on_phase_end(self, task: ClassyTask, local_variables: Dict[str, Any]) -> None:
         # state in the test phase
         self._save_current_model_state(task.base_model, self.state.model_state)

-    def on_step(self, task: ClassyTask, local_variables: Dict[str, Any]) -> None:
+    def on_step(self, task: ClassyTask) -> None:
         if not task.train:
             return

classy_vision/hooks/loss_lr_meter_logging_hook.py (+7 -13)

@@ -45,36 +45,30 @@ def on_phase_end(
         # trainer to implement an unsynced end of phase meter or
         # for meters to not provide a sync function.
         logging.info("End of phase metric values:")
-        self._log_loss_meters(task, local_variables)
+        self._log_loss_meters(task)
         if task.train:
-            self._log_lr(task, local_variables)
+            self._log_lr(task)

-    def on_step(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def on_step(self, task: "tasks.ClassyTask") -> None:
         """
         Log the LR every log_freq batches, if log_freq is not None.
         """
         if self.log_freq is None or not task.train:
             return
         batches = len(task.losses)
         if batches and batches % self.log_freq == 0:
-            self._log_lr(task, local_variables)
+            self._log_lr(task)
             logging.info("Local unsynced metric values:")
-            self._log_loss_meters(task, local_variables)
+            self._log_loss_meters(task)

-    def _log_lr(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def _log_lr(self, task: "tasks.ClassyTask") -> None:
         """
         Compute and log the optimizer LR.
         """
         optimizer_lr = task.optimizer.parameters.lr
         logging.info("Learning Rate: {}\n".format(optimizer_lr))

-    def _log_loss_meters(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def _log_loss_meters(self, task: "tasks.ClassyTask") -> None:
         """
         Compute and log the loss and meters.
         """

classy_vision/hooks/progress_bar_hook.py (+1 -3)

@@ -51,9 +51,7 @@ def on_phase_start(
         self.progress_bar = progressbar.ProgressBar(self.bar_size)
         self.progress_bar.start()

-    def on_step(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def on_step(self, task: "tasks.ClassyTask") -> None:
         """Update the progress bar with the batch size."""
         if task.train and is_master() and self.progress_bar is not None:
             self.batches += 1

classy_vision/hooks/tensorboard_plot_hook.py (+1 -3)

@@ -64,9 +64,7 @@ def on_phase_start(
         self.wall_times = []
         self.num_steps_global = []

-    def on_step(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def on_step(self, task: "tasks.ClassyTask") -> None:
         """Store the observed learning rates."""
         if self.learning_rates is None:
             logging.warning("learning_rates is not initialized")

classy_vision/hooks/time_metrics_hook.py (+8 -12)

@@ -40,19 +40,17 @@ def on_phase_start(
         Initialize start time and reset perf stats
         """
         self.start_time = time.time()
-        local_variables["perf_stats"] = PerfStats()
+        task.perf_stats = PerfStats()

-    def on_step(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def on_step(self, task: "tasks.ClassyTask") -> None:
         """
         Log metrics every log_freq batches, if log_freq is not None.
         """
         if self.log_freq is None:
             return
         batches = len(task.losses)
         if batches and batches % self.log_freq == 0:
-            self._log_performance_metrics(task, local_variables)
+            self._log_performance_metrics(task)

     def on_phase_end(
         self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
@@ -62,11 +60,9 @@ def on_phase_end(
         """
         batches = len(task.losses)
         if batches:
-            self._log_performance_metrics(task, local_variables)
+            self._log_performance_metrics(task)

-    def _log_performance_metrics(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def _log_performance_metrics(self, task: "tasks.ClassyTask") -> None:
         """
         Compute and log performance metrics.
         """
@@ -85,11 +81,11 @@ def _log_performance_metrics(
         )

         # Train step time breakdown
-        if local_variables.get("perf_stats") is None:
-            logging.warning('"perf_stats" not set in local_variables')
+        if not hasattr(task, "perf_stats") or task.perf_stats is None:
+            logging.warning('"perf_stats" not set in task')
         elif task.train:
             logging.info(
                 "Train step time breakdown (rank {}):\n{}".format(
-                    get_rank(), local_variables["perf_stats"].report_str()
+                    get_rank(), task.perf_stats.report_str()
                 )
             )
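
perf_stats now lives on the task instead of in the per-step dict, so code that records timings must write to the same place the hook reads from. A sketch of the producer side, assuming the PerfTimer context manager from classy_vision.generic.perf_stats with a (name, stats) signature; the timer label and the forward call are illustrative:

    from classy_vision.generic.perf_stats import PerfTimer

    # Inside a step, accumulate timings on the task-level PerfStats object
    # that TimeMetricsHook.on_phase_start() now creates:
    with PerfTimer("forward", task.perf_stats):
        output = task.model(sample["input"])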

classy_vision/tasks/classification_task.py (+26 -1)

@@ -8,7 +8,7 @@
 import enum
 import logging
 import time
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, NamedTuple, Optional, Union

 import torch
 from classy_vision.dataset import ClassyDataset, build_dataset
@@ -54,6 +54,13 @@ class BroadcastBuffersMode(enum.Enum):
     BEFORE_EVAL = enum.auto()


+class LastBatchInfo(NamedTuple):
+    loss: torch.Tensor
+    output: torch.Tensor
+    target: torch.Tensor
+    sample: Dict[str, Any]
+
+
 @register_task("classification_task")
 class ClassificationTask(ClassyTask):
     """Basic classification training task.
@@ -672,6 +679,14 @@ def eval_step(self, use_gpu, local_variables=None):

         self.update_meters(local_variables["output"], local_variables["sample"])

+        # Move some data to the task so hooks get a chance to access it
+        self.last_batch = LastBatchInfo(
+            loss=local_variables["loss"],
+            output=local_variables["output"],
+            target=local_variables["target"],
+            sample=local_variables["sample"],
+        )
+
     def train_step(self, use_gpu, local_variables=None):
         """Train step to be executed in train loop
@@ -684,6 +699,8 @@ def train_step(self, use_gpu, local_variables=None):
         if local_variables is None:
             local_variables = {}

+        self.last_batch = None
+
         # Process next sample
         sample = next(self.get_data_iterator())
         local_variables["sample"] = sample
@@ -738,6 +755,14 @@ def train_step(self, use_gpu, local_variables=None):

         self.num_updates += self.get_global_batchsize()

+        # Move some data to the task so hooks get a chance to access it
+        self.last_batch = LastBatchInfo(
+            loss=local_variables["loss"],
+            output=local_variables["output"],
+            target=local_variables["target"],
+            sample=local_variables["sample"],
+        )
+
     def compute_loss(self, model_output, sample):
         return self.loss(model_output, sample["target"])
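
LastBatchInfo is what replaces handing local_variables to hooks: train_step and eval_step now publish the last batch's loss, output, target and sample on the task, and train_step resets last_batch to None before processing the next sample. A hook can therefore consume per-batch data without the dict; a hypothetical on_step body:

    def on_step(self, task: "tasks.ClassyTask") -> None:
        # last_batch is None until the current step has produced outputs.
        if getattr(task, "last_batch", None) is not None:
            batch_size = task.last_batch.target.size(0)
            logging.info(
                "loss={:.4f} over {} samples".format(
                    task.last_batch.loss.item(), batch_size
                )
            )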

classy_vision/tasks/classy_task.py (+2 -1)

@@ -178,7 +178,8 @@ def step(self, use_gpu, local_variables: Optional[Dict] = None) -> None:
         else:
             self.eval_step(use_gpu, local_variables)

-        self.run_hooks(local_variables, ClassyHookFunctions.on_step.name)
+        for hook in self.hooks:
+            hook.on_step(self)

     def run_hooks(self, local_variables: Dict[str, Any], hook_function: str) -> None:
         """

test/hooks_exponential_moving_average_model_hook_test.py (+1 -1)

@@ -53,7 +53,7 @@ def _test_exponential_moving_average_hook(self, model_device, hook_device):
         task.base_model.update_fc_weight()
         fc_weight = model.fc.weight.clone()
         for _ in range(num_updates):
-            exponential_moving_average_hook.on_step(task, local_variables)
+            exponential_moving_average_hook.on_step(task)
         exponential_moving_average_hook.on_phase_end(task, local_variables)
         # the model weights shouldn't have changed
         self.assertTrue(torch.allclose(model.fc.weight, fc_weight))

test/hooks_loss_lr_meter_logging_hook_test.py (+8 -8)

@@ -51,27 +51,27 @@ def test_logging(self, mock_get_rank: mock.MagicMock) -> None:

         for i in range(num_batches):
             task.losses = list(range(i))
-            loss_lr_meter_hook.on_step(task, local_variables)
+            loss_lr_meter_hook.on_step(task)
             if log_freq is not None and i and i % log_freq == 0:
-                mock_fn.assert_called_with(task, local_variables)
+                mock_fn.assert_called_with(task)
                 mock_fn.reset_mock()
-                mock_lr_fn.assert_called_with(task, local_variables)
+                mock_lr_fn.assert_called_with(task)
                 mock_lr_fn.reset_mock()
                 continue
             mock_fn.assert_not_called()
             mock_lr_fn.assert_not_called()

         loss_lr_meter_hook.on_phase_end(task, local_variables)
-        mock_fn.assert_called_with(task, local_variables)
+        mock_fn.assert_called_with(task)
         if task.train:
-            mock_lr_fn.assert_called_with(task, local_variables)
+            mock_lr_fn.assert_called_with(task)

         # test _log_loss_lr_meters()
         task.losses = losses

         with self.assertLogs():
-            loss_lr_meter_hook._log_loss_meters(task, local_variables)
-            loss_lr_meter_hook._log_lr(task, local_variables)
+            loss_lr_meter_hook._log_loss_meters(task)
+            loss_lr_meter_hook._log_lr(task)

         task.phase_idx += 1

@@ -95,7 +95,7 @@ def scheduler_mock(where):
         lr_order = [0.0, 1 / 6, 1 / 6, 2 / 6, 3 / 6, 3 / 6, 4 / 6, 5 / 6, 5 / 6]
         lr_list = []

-        def mock_log_lr(task: ClassyTask, local_variables) -> None:
+        def mock_log_lr(task: ClassyTask) -> None:
             lr_list.append(task.optimizer.parameters.lr)

         with mock.patch.object(

test/hooks_time_metrics_hook_test.py (+6 -6)

@@ -49,7 +49,7 @@ def test_time_metrics(
         mock_time.return_value = start_time
         time_metrics_hook.on_phase_start(task, local_variables)
         self.assertEqual(time_metrics_hook.start_time, start_time)
-        self.assertTrue(isinstance(local_variables.get("perf_stats"), PerfStats))
+        self.assertTrue(isinstance(task.perf_stats, PerfStats))

         # test that the code doesn't raise an exception if losses is empty
         try:
@@ -66,15 +66,15 @@ def test_time_metrics(

         for i in range(num_batches):
             task.losses = list(range(i))
-            time_metrics_hook.on_step(task, local_variables)
+            time_metrics_hook.on_step(task)
             if log_freq is not None and i and i % log_freq == 0:
-                mock_fn.assert_called_with(task, local_variables)
+                mock_fn.assert_called_with(task)
                 mock_fn.reset_mock()
                 continue
             mock_fn.assert_not_called()

         time_metrics_hook.on_phase_end(task, local_variables)
-        mock_fn.assert_called_with(task, local_variables)
+        mock_fn.assert_called_with(task)

         task.losses = [0.23, 0.45, 0.34, 0.67]

@@ -84,7 +84,7 @@ def test_time_metrics(

         # test _log_performance_metrics()
         with self.assertLogs() as log_watcher:
-            time_metrics_hook._log_performance_metrics(task, local_variables)
+            time_metrics_hook._log_performance_metrics(task)

         # there should 2 be info logs for train and 1 for test
         self.assertEqual(len(log_watcher.output), 2 if train else 1)
@@ -112,7 +112,7 @@ def test_time_metrics(

         # if on_phase_start() is not called, 2 warnings should be logged
         # create a new time metrics hook
-        local_variables = {}
+        task.perf_stats = None
         time_metrics_hook_new = TimeMetricsHook()

         with self.assertLogs() as log_watcher:

test/manual/hooks_progress_bar_hook_test.py (+4 -4)

@@ -48,14 +48,14 @@ def test_progress_bar(

         # on_step should update the progress bar correctly
         for i in range(num_batches):
-            progress_bar_hook.on_step(task, local_variables)
+            progress_bar_hook.on_step(task)
             mock_progress_bar.update.assert_called_once_with(i + 1)
             mock_progress_bar.update.reset_mock()

         # check that even if on_step is called again, the progress bar is
         # only updated with num_batches
         for _ in range(num_batches):
-            progress_bar_hook.on_step(task, local_variables)
+            progress_bar_hook.on_step(task)
             mock_progress_bar.update.assert_called_once_with(num_batches)
             mock_progress_bar.update.reset_mock()

@@ -68,7 +68,7 @@ def test_progress_bar(
         # crash
         progress_bar_hook = ProgressBarHook()
         try:
-            progress_bar_hook.on_step(task, local_variables)
+            progress_bar_hook.on_step(task)
             progress_bar_hook.on_phase_end(task, local_variables)
         except Exception as e:
             self.fail(
@@ -81,7 +81,7 @@ def test_progress_bar(
         progress_bar_hook = ProgressBarHook()
         try:
             progress_bar_hook.on_phase_start(task, local_variables)
-            progress_bar_hook.on_step(task, local_variables)
+            progress_bar_hook.on_step(task)
             progress_bar_hook.on_phase_end(task, local_variables)
         except Exception as e:
             self.fail("Received Exception when is_master() is False: {}".format(e))

test/manual/hooks_tensorboard_plot_hook_test.py (+2 -2)

@@ -62,7 +62,7 @@ def test_writer(self, mock_is_master_func: mock.MagicMock) -> None:
         # the writer if on_phase_start() is not called for initialization
         # before on_step() is called.
         with self.assertLogs() as log_watcher:
-            tensorboard_plot_hook.on_step(task, local_variables)
+            tensorboard_plot_hook.on_step(task)

         self.assertTrue(
             len(log_watcher.records) == 1
@@ -88,7 +88,7 @@ def test_writer(self, mock_is_master_func: mock.MagicMock) -> None:

         for loss in losses:
             task.losses.append(loss)
-            tensorboard_plot_hook.on_step(task, local_variables)
+            tensorboard_plot_hook.on_step(task)

         tensorboard_plot_hook.on_phase_end(task, local_variables)

test/optim_param_scheduler_test.py (+1 -1)

@@ -207,7 +207,7 @@ class TestHook(ClassyHook):
     on_phase_end = ClassyHook._noop
     on_end = ClassyHook._noop

-    def on_step(self, task: ClassyTask, local_variables) -> None:
+    def on_step(self, task: ClassyTask) -> None:
         if not task.train:
             return
