 from vissl.utils.perf_stats import PerfStats


+class LogGpuMemoryHook(ClassyHook):
+    """
+    Hook executed at a specified iteration number; prints the GPU memory
+    summary for the primary device at several stages of training
+    (phase start, forward, backward, update).
+    """
+
+    on_start = ClassyHook._noop
+    on_loss_and_meter = ClassyHook._noop
+    on_step = ClassyHook._noop
+    on_phase_end = ClassyHook._noop
+    on_end = ClassyHook._noop
+
+    def __init__(
+        self,
+        log_iteration_num: int = 1,
+    ) -> None:
+        super().__init__()
+        self.log_iteration_num = log_iteration_num
+
+    def on_phase_start(self, task: "tasks.ClassyTask") -> None:
+        """
+        Print the stats just before the training epoch starts
+        """
+        self._print_memory_summary(task, "on_phase_start")
+
+    def on_forward(self, task: "tasks.ClassyTask") -> None:
+        """
+        Print the stats after the model forward pass is done
+        """
+        self._print_memory_summary(task, "on_forward")
+
+    def on_backward(self, task: "tasks.ClassyTask") -> None:
+        """
+        Print the stats just after model.backward() is done
+        """
+        self._print_memory_summary(task, "on_backward")
+
+    def on_update(self, task: "tasks.ClassyTask") -> None:
+        """
+        Print the stats just after model params are updated
+        """
+        self._print_memory_summary(task, "on_update")
+
+    def _print_memory_summary(self, task: "tasks.ClassyTask", stage_name: str) -> None:
+        if (
+            is_primary()
+            and (task.device.type == "cuda")
+            and task.local_iteration_num == self.log_iteration_num
+        ):
+            logging.info(
+                f"========= Memory Summary at {stage_name} ======="
+                f"\n{torch.cuda.memory_summary()}\n"
+            )
+
+
 class LogGpuStatsHook(ClassyHook):
     """
     Hook executed at the start of training and after every training iteration is done.
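
For orientation, here is a rough usage sketch for the new `LogGpuMemoryHook`. The import path and the manual hook wiring are assumptions for illustration; in VISSL the hook list is normally assembled from the config rather than built by hand.

```python
import logging

import torch

# Module path assumed from where this diff appears to live.
from vissl.hooks.log_hooks import LogGpuMemoryHook

logging.basicConfig(level=logging.INFO)

# Hypothetical manual wiring: the hook is just one more entry in the list
# of ClassyHooks attached to the training task.
hooks = [
    LogGpuMemoryHook(log_iteration_num=1),  # log only at local iteration 1
    # ... other hooks (checkpointing, loss/LR logging, etc.) ...
]

# On the primary CUDA worker, at the requested iteration, the hook's
# _print_memory_summary reduces to this call:
if torch.cuda.is_available():
    logging.info(torch.cuda.memory_summary())  # caching-allocator stats per device
```
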
@@ -92,8 +147,8 @@ def on_update(self, task: "tasks.ClassyTask") -> None:
         monitoring the stats (optionally) for every N iterations to get better
         idea about the batch time and training eta.

-        Set the btime_freq input using cfg.PERF_STAT_FREQUENCY=N ensuring that
-        cfg.MONITOR_PERF_STATS = True.
+        Set the btime_freq input using cfg.HOOKS.PERF_STATS.PERF_STAT_FREQUENCY=N
+        ensuring that cfg.HOOKS.PERF_STATS.MONITOR_PERF_STATS = True.
         """
         phase_type = "train" if task.train else "test"
         if is_primary() and phase_type == "train":
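
As a reading aid for the corrected docstring above, a minimal sketch of the two config keys it refers to. The nested-dict layout is only a stand-in for VISSL's real YAML-backed config object, and the frequency value is an arbitrary example.

```python
# Stand-in for the real VISSL config, showing the two keys the docstring
# points at.
cfg = {
    "HOOKS": {
        "PERF_STATS": {
            # Must be True for perf stats to be collected at all.
            "MONITOR_PERF_STATS": True,
            # btime_freq: log batch-time stats every N iterations (N=10 here).
            "PERF_STAT_FREQUENCY": 10,
        }
    }
}

print(cfg["HOOKS"]["PERF_STATS"]["PERF_STAT_FREQUENCY"])  # -> 10
```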