This repository was archived by the owner on Jul 1, 2024. It is now read-only.

Commit eef376f

vreis authored and facebook-github-bot committed
Revamp logging (#478)
Summary: This is a bunch of changes to make our training logs more meaningful and easier to understand. We print the task config at the beginning of training, make it clear which values are approximate or final, suppress verbose logs by default, and format floats accordingly.

Before this diff: P128995674
After this diff: P128995244

Pull Request resolved: #478

Reviewed By: mannatsingh

Differential Revision: D21022171

Pulled By: vreis

fbshipit-source-id: d63d5ac9b4b3b3cc9abb359914141bf6821dc14a
1 parent d16fc33 · commit eef376f

9 files changed, +74 -78 lines changed

classy_vision/hooks/loss_lr_meter_logging_hook.py (+11, -26)

@@ -18,7 +18,6 @@ class LossLrMeterLoggingHook(ClassyHook):
     Logs the loss, optimizer LR, and meters. Logs at the end of a phase.
     """
 
-    on_start = ClassyHook._noop
     on_phase_start = ClassyHook._noop
     on_end = ClassyHook._noop
 
@@ -35,6 +34,9 @@ def __init__(self, log_freq: Optional[int] = None) -> None:
         ), "log_freq must be an int or None"
         self.log_freq: Optional[int] = log_freq
 
+    def on_start(self, task) -> None:
+        logging.info(f"Starting training. Task: {task}")
+
     def on_phase_end(self, task) -> None:
         """
         Log the loss, optimizer LR, and meters for the phase.
@@ -45,10 +47,7 @@ def on_phase_end(self, task) -> None:
         # do not explicitly state this since it is possible for a
         # trainer to implement an unsynced end of phase meter or
         # for meters to not provide a sync function.
-        logging.info("End of phase metric values:")
-        self._log_loss_meters(task)
-        if task.train:
-            self._log_lr(task)
+        self._log_loss_meters(task, prefix="Synced meters: ")
 
     def on_step(self, task) -> None:
         """
@@ -58,18 +57,9 @@ def on_step(self, task) -> None:
             return
         batches = len(task.losses)
         if batches and batches % self.log_freq == 0:
-            self._log_lr(task)
-            logging.info("Local unsynced metric values:")
-            self._log_loss_meters(task)
-
-    def _log_lr(self, task) -> None:
-        """
-        Compute and log the optimizer LR.
-        """
-        optimizer_lr = task.optimizer.parameters.lr
-        logging.info("Learning Rate: {}\n".format(optimizer_lr))
+            self._log_loss_meters(task, prefix="Approximate meters: ")
 
-    def _log_loss_meters(self, task) -> None:
+    def _log_loss_meters(self, task, prefix="") -> None:
         """
         Compute and log the loss and meters.
         """
@@ -80,14 +70,9 @@ def _log_loss_meters(self, task) -> None:
 
         # Loss for the phase
         loss = sum(task.losses) / (batches * task.get_batchsize_per_replica())
+        phase_pct = batches / task.num_batches_per_phase
 
-        log_strs = [
-            "Rank: {}, {} phase: {}, processed batches: {}".format(
-                get_rank(), phase_type, phase_type_idx, batches
-            ),
-            "{} loss: {}".format(phase_type, loss),
-            "Meters:",
-        ]
-        for meter in task.meters:
-            log_strs.append("{}".format(meter))
-        logging.info("\n".join(log_strs))
+        logging.info(
+            f"{prefix}[{get_rank()}] {phase_type} phase {phase_type_idx} "
+            f"({phase_pct*100:.2f}% done), loss: {loss:.4f}, meters: {task.meters}"
+        )
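
For orientation, a minimal self-contained sketch of the single consolidated log line the reworked _log_loss_meters emits; the rank, phase index, loss, and meter string below are made-up placeholders rather than output from a real Classy Vision task:

import logging

logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

# Hypothetical stand-ins for the task state at logging time.
prefix = "Approximate meters: "
rank = 0  # would normally come from get_rank()
phase_type, phase_type_idx = "train", 3
batches, num_batches_per_phase = 50, 200
loss = 1.234567
meters = ["accuracy_meter(top_1=0.512346,top_5=0.812346)"]

phase_pct = batches / num_batches_per_phase
logging.info(
    f"{prefix}[{rank}] {phase_type} phase {phase_type_idx} "
    f"({phase_pct*100:.2f}% done), loss: {loss:.4f}, meters: {meters}"
)
# INFO: Approximate meters: [0] train phase 3 (25.00% done), loss: 1.2346, meters: ['accuracy_meter(top_1=0.512346,top_5=0.812346)']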

classy_vision/hooks/model_complexity_hook.py (+3, -3)

@@ -53,10 +53,10 @@ def on_start(self, task) -> None:
             )
         except NotImplementedError:
             logging.warning(
-                """Model contains unsupported modules:
-                Could not compute FLOPs for model forward pass. Exception:""",
-                exc_info=True,
+                "Model contains unsupported modules, "
+                "could not compute FLOPs for model forward pass."
             )
+            logging.debug("Exception:", exc_info=True)
         try:
             self.num_activations = compute_activations(
                 task.base_model,
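
To see how this reduces log verbosity, a small standalone sketch of the same warning/debug split using only the standard logging module; the exception is raised artificially here as a stand-in for a failed FLOPs computation:

import logging

logging.basicConfig(level=logging.INFO)  # at the default level the traceback stays hidden

try:
    raise NotImplementedError("unsupported module")  # stand-in for compute_flops()
except NotImplementedError:
    logging.warning(
        "Model contains unsupported modules, "
        "could not compute FLOPs for model forward pass."
    )
    # The traceback is only emitted when the logger level is DEBUG or lower.
    logging.debug("Exception:", exc_info=True)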

classy_vision/meters/accuracy_meter.py (-3)

@@ -126,9 +126,6 @@ def set_classy_state(self, state):
         self._curr_correct_predictions_k = state["curr_correct_predictions_k"].clone()
         self._curr_sample_count = state["curr_sample_count"].clone()
 
-    def __repr__(self):
-        return repr({"name": self.name, "value": self.value})
-
     def update(self, model_output, target, **kwargs):
         """
         args:

classy_vision/meters/classy_meter.py (+13)

@@ -114,3 +114,16 @@ def set_classy_state(self, state: Dict[str, Any]) -> None:
         This is used to load the state of the meter from a checkpoint.
         """
         raise NotImplementedError
+
+    def __repr__(self):
+        """Returns a string representation of the meter, used for logging.
+
+        The default implementation assumes value is a dict. value is not
+        required to be a dict, and in that case you should override this
+        method."""
+
+        if not isinstance(self.value, dict):
+            return super().__repr__()
+
+        values = ",".join([f"{key}={value:.6f}" for key, value in self.value.items()])
+        return f"{self.name}_meter({values})"
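
As a rough illustration of the formatting this new default produces, a toy meter with a dict-valued value property; it mirrors the logic above without depending on the real ClassyMeter base class, so the name and values are illustrative only:

class ToyMeter:
    """Minimal stand-in exposing the same name/value interface as a Classy meter."""

    name = "accuracy"

    @property
    def value(self):
        return {"top_1": 0.5123456, "top_5": 0.8123456}

    def __repr__(self):
        # Same formatting rule as the new ClassyMeter.__repr__ default.
        if not isinstance(self.value, dict):
            return super().__repr__()
        values = ",".join(f"{key}={value:.6f}" for key, value in self.value.items())
        return f"{self.name}_meter({values})"


print(ToyMeter())  # accuracy_meter(top_1=0.512346,top_5=0.812346)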

classy_vision/meters/precision_meter.py (-3)

@@ -127,9 +127,6 @@ def set_classy_state(self, state):
         self._curr_correct_predictions_k = state["curr_correct_predictions_k"].clone()
         self._curr_sample_count = state["curr_sample_count"].clone()
 
-    def __repr__(self):
-        return repr({"name": self.name, "value": self.value})
-
     def update(self, model_output, target, **kwargs):
         """
         args:

classy_vision/meters/recall_meter.py (-3)

@@ -126,9 +126,6 @@ def set_classy_state(self, state):
         self._curr_correct_predictions_k = state["curr_correct_predictions_k"].clone()
         self._curr_correct_targets = state["curr_correct_targets"].clone()
 
-    def __repr__(self):
-        return repr({"name": self.name, "value": self.value})
-
     def update(self, model_output, target, **kwargs):
         """
         args:

classy_vision/meters/video_meter.py (-3)

@@ -76,9 +76,6 @@ def set_classy_state(self, state):
         self.reset()
         self.meter.set_classy_state(state["meter_state"])
 
-    def __repr__(self):
-        return repr({"name": self.name, "value": self.value})
-
     def update(self, model_output, target, is_train, **kwargs):
         """Updates any internal state of meter with new model output and target.
 
classy_vision/tasks/classification_task.py (+16, -4)

@@ -6,6 +6,7 @@
 
 import copy
 import enum
+import json
 import logging
 import math
 import time
@@ -413,6 +414,10 @@ def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
         for phase_type in phase_types:
             task.set_dataset(datasets[phase_type], phase_type)
 
+        # NOTE: this is a private member and only meant to be used for
+        # logging/debugging purposes. See __repr__ implementation
+        task._config = config
+
         return task
 
     @property
@@ -854,7 +859,7 @@ def advance_phase(self):
         resets counters, shuffles dataset, rebuilds iterators, and
         sets the train / test state for phase.
         """
-        logging.info("Advancing phase")
+        logging.debug("Advancing phase")
         # Reset meters for next phase / epoch
         for meter in self.meters:
             meter.reset()
@@ -893,7 +898,7 @@ def _recreate_data_loader_from_dataset(self, phase_type=None):
         if phase_type is None:
             phase_type = self.phase_type
 
-        logging.info("Recreating data loader for new phase")
+        logging.debug("Recreating data loader for new phase")
         num_workers = 0
         if hasattr(self.dataloaders[phase_type], "num_workers"):
             num_workers = self.dataloaders[phase_type].num_workers
@@ -979,10 +984,10 @@ def on_phase_start(self):
     def on_phase_end(self):
         self.log_phase_end("train")
 
-        logging.info("Syncing meters on phase end...")
+        logging.debug("Syncing meters on phase end...")
         for meter in self.meters:
             meter.sync_state()
-        logging.info("...meters synced")
+        logging.debug("...meters synced")
         barrier()
 
         for hook in self.hooks:
@@ -1016,3 +1021,10 @@ def log_phase_end(self, tag):
                 "im_per_sec": im_per_sec,
             }
         )
+
+    def __repr__(self):
+        if hasattr(self, "_config"):
+            config = json.dumps(self._config, indent=4)
+            return f"{super().__repr__()} initialized with config:\n{config}"
+
+        return super().__repr__()
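
For a sense of what the new task __repr__ prints at the start of training, a toy class mirroring the logic above; the config dict is a made-up example, not a real Classy Vision task config:

import json


class ToyTask:
    """Stand-in mirroring the __repr__ added to ClassificationTask."""

    def __repr__(self):
        if hasattr(self, "_config"):
            config = json.dumps(self._config, indent=4)
            return f"{super().__repr__()} initialized with config:\n{config}"
        return super().__repr__()


task = ToyTask()
task._config = {"name": "classification_task", "num_epochs": 2}  # hypothetical config
print(task)
# <__main__.ToyTask object at 0x...> initialized with config:
# {
#     "name": "classification_task",
#     "num_epochs": 2
# }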

test/hooks_loss_lr_meter_logging_hook_test.py (+31, -33)

@@ -11,7 +11,7 @@
 from test.generic.config_utils import get_test_mlp_task_config, get_test_task_config
 from test.generic.hook_test_utils import HookTestBase
 
-from classy_vision.hooks import LossLrMeterLoggingHook
+from classy_vision.hooks import ClassyHook, LossLrMeterLoggingHook
 from classy_vision.optim.param_scheduler import UpdateInterval
 from classy_vision.tasks import ClassyTask, build_task
 from classy_vision.trainer import LocalTrainer
@@ -48,6 +48,8 @@ def test_logging(self, mock_get_rank: mock.MagicMock) -> None:
         config["dataset"]["test"]["batchsize_per_replica"] = 5
         task = build_task(config)
         task.prepare()
+        task.on_start()
+        task.on_phase_start()
 
         losses = [1.2, 2.3, 3.4, 4.5]
 
@@ -62,32 +64,25 @@ def test_logging(self, mock_get_rank: mock.MagicMock) -> None:
         # and _log_lr() is called after on_step() every log_freq batches
         # and after on_phase_end()
         with mock.patch.object(loss_lr_meter_hook, "_log_loss_meters") as mock_fn:
-            with mock.patch.object(loss_lr_meter_hook, "_log_lr") as mock_lr_fn:
-                num_batches = 20
-
-                for i in range(num_batches):
-                    task.losses = list(range(i))
-                    loss_lr_meter_hook.on_step(task)
-                    if log_freq is not None and i and i % log_freq == 0:
-                        mock_fn.assert_called_with(task)
-                        mock_fn.reset_mock()
-                        mock_lr_fn.assert_called_with(task)
-                        mock_lr_fn.reset_mock()
-                        continue
-                    mock_fn.assert_not_called()
-                    mock_lr_fn.assert_not_called()
-
-                loss_lr_meter_hook.on_phase_end(task)
-                mock_fn.assert_called_with(task)
-                if task.train:
-                    mock_lr_fn.assert_called_with(task)
+            num_batches = 20
+
+            for i in range(num_batches):
+                task.losses = list(range(i))
+                loss_lr_meter_hook.on_step(task)
+                if log_freq is not None and i and i % log_freq == 0:
+                    mock_fn.assert_called()
+                    mock_fn.reset_mock()
+                    continue
+                mock_fn.assert_not_called()
+
+            loss_lr_meter_hook.on_phase_end(task)
+            mock_fn.assert_called()
 
         # test _log_loss_lr_meters()
         task.losses = losses
 
         with self.assertLogs():
             loss_lr_meter_hook._log_loss_meters(task)
-            loss_lr_meter_hook._log_lr(task)
 
         task.phase_idx += 1
 
@@ -106,18 +101,21 @@ def scheduler_mock(where):
         task.optimizer.param_schedulers["lr"] = mock_lr_scheduler
         trainer = LocalTrainer()
 
-        # 2 LR updates per epoch
-        # At end of each epoch for train, LR is logged an additional time
-        lr_order = [0.0, 1 / 6, 1 / 6, 2 / 6, 3 / 6, 3 / 6, 4 / 6, 5 / 6, 5 / 6]
+        # 2 LR updates per epoch = 6
+        lr_order = [0.0, 1 / 6, 2 / 6, 3 / 6, 4 / 6, 5 / 6]
         lr_list = []
 
-        def mock_log_lr(task: ClassyTask) -> None:
-            lr_list.append(task.optimizer.parameters.lr)
+        class LRLoggingHook(ClassyHook):
+            on_end = ClassyHook._noop
+            on_phase_end = ClassyHook._noop
+            on_phase_start = ClassyHook._noop
+            on_start = ClassyHook._noop
+
+            def on_step(self, task):
+                if task.train:
+                    lr_list.append(task.optimizer.parameters.lr)
 
-        with mock.patch.object(
-            LossLrMeterLoggingHook, "_log_lr", side_effect=mock_log_lr
-        ):
-            hook = LossLrMeterLoggingHook(1)
-            task.set_hooks([hook])
-            trainer.train(task)
-            self.assertEqual(lr_list, lr_order)
+        hook = LRLoggingHook()
+        task.set_hooks([hook])
+        trainer.train(task)
+        self.assertEqual(lr_list, lr_order)
