This repository was archived by the owner on Jul 1, 2024. It is now read-only.

Add hooks building to classification task (#62) #402

Closed · wants to merge 1 commit
2 changes: 1 addition & 1 deletion classy_vision/generic/util.py
@@ -432,7 +432,7 @@ def compute_pr_curves(class_hist, total_hist):


def get_checkpoint_dict(task, input_args, deep_copy=False):
-    assert isinstance(
+    assert input_args is None or isinstance(
input_args, dict
), f"Unexpected input_args of type: {type(input_args)}"
return {
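With the relaxed assertion, get_checkpoint_dict now accepts input_args=None in addition to a dict; anything else still trips the assert. A minimal sketch of the new contract (the task object is simply whatever task is being checkpointed, not code from this PR):

get_checkpoint_dict(task, None)              # now accepted
get_checkpoint_dict(task, {"lr": 0.1})       # a dict still works
get_checkpoint_dict(task, "not-a-dict")      # still raises AssertionError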
5 changes: 2 additions & 3 deletions classy_vision/hooks/checkpoint_hook.py
@@ -7,7 +7,6 @@
import logging
from typing import Any, Collection, Dict, Optional

-from classy_vision import tasks
from classy_vision.generic.distributed_util import is_master
from classy_vision.generic.util import get_checkpoint_dict, save_checkpoint
from classy_vision.hooks import register_hook
@@ -85,7 +84,7 @@ def _save_checkpoint(self, task, filename):
if checkpoint_file:
PathManager.copy(checkpoint_file, f"{self.checkpoint_folder}/{filename}")

-    def on_start(self, task: "tasks.ClassyTask") -> None:
+    def on_start(self, task) -> None:
if not is_master() or getattr(task, "test_only", False):
return
if not PathManager.exists(self.checkpoint_folder):
@@ -94,7 +93,7 @@ def on_start(self, task: "tasks.ClassyTask") -> None:
)
raise FileNotFoundError(err_msg)

-    def on_phase_end(self, task: "tasks.ClassyTask") -> None:
+    def on_phase_end(self, task) -> None:
"""Checkpoint the task every checkpoint_period phases.

We do not necessarily checkpoint the task at the end of every phase.
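Dropping the tasks import and the task annotations keeps the hooks package from importing classy_vision.tasks, which would now be circular because the tasks package imports build_hooks (see classification_task.py below). The relaxed assert in util.py is relevant here too, since this hook presumably forwards its stored input_args to get_checkpoint_dict when saving. A rough sketch, where the keyword arguments are assumptions rather than the exact constructor signature:

from classy_vision.hooks.checkpoint_hook import CheckpointHook

# input_args=None is now a legal value to forward to get_checkpoint_dict
hook = CheckpointHook("/tmp/checkpoints", input_args=None, checkpoint_period=2)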
12 changes: 5 additions & 7 deletions classy_vision/hooks/classy_hook.py
@@ -7,8 +7,6 @@
from abc import ABC, abstractmethod
from typing import Any, Dict

-from classy_vision import tasks
-

class ClassyHookState:
"""Class to store state within instances of ClassyHook.
@@ -69,27 +67,27 @@ def name(cls) -> str:
return cls.__name__

@abstractmethod
-    def on_start(self, task: "tasks.ClassyTask") -> None:
+    def on_start(self, task) -> None:
"""Called at the start of training."""
pass

@abstractmethod
-    def on_phase_start(self, task: "tasks.ClassyTask") -> None:
+    def on_phase_start(self, task) -> None:
"""Called at the start of each phase."""
pass

@abstractmethod
-    def on_step(self, task: "tasks.ClassyTask") -> None:
+    def on_step(self, task) -> None:
"""Called each time after parameters have been updated by the optimizer."""
pass

@abstractmethod
-    def on_phase_end(self, task: "tasks.ClassyTask") -> None:
+    def on_phase_end(self, task) -> None:
"""Called at the end of each phase (epoch)."""
pass

@abstractmethod
-    def on_end(self, task: "tasks.ClassyTask") -> None:
+    def on_end(self, task) -> None:
"""Called at the end of training."""
pass

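For hook authors the contract is unchanged apart from the dropped annotation: subclass ClassyHook, override the callbacks you need, and stub out the rest with ClassyHook._noop, as the built-in hooks do. A minimal sketch under those assumptions (the registered name and the body are made up for illustration):

from classy_vision.hooks import register_hook
from classy_vision.hooks.classy_hook import ClassyHook


@register_hook("print_phase")  # hypothetical name
class PrintPhaseHook(ClassyHook):
    on_start = ClassyHook._noop
    on_phase_start = ClassyHook._noop
    on_step = ClassyHook._noop
    on_end = ClassyHook._noop

    def on_phase_end(self, task) -> None:
        # `task` is duck-typed; no import from classy_vision.tasks is needed
        print("finished a", "train" if task.train else "test", "phase")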
11 changes: 5 additions & 6 deletions classy_vision/hooks/exponential_moving_average_model_hook.py
@@ -12,7 +12,6 @@
import torch.nn as nn
from classy_vision.hooks import register_hook
from classy_vision.hooks.classy_hook import ClassyHook
-from classy_vision.tasks import ClassyTask


@register_hook("ema_model_weights")
@@ -78,7 +77,7 @@ def _save_current_model_state(self, model: nn.Module, model_state: Dict[str, Any
for name, param in self.get_model_state_iterator(model):
model_state[name] = param.detach().clone().to(device=self.device)

-    def on_start(self, task: ClassyTask) -> None:
+    def on_start(self, task) -> None:
if self.state.model_state:
# loaded state from checkpoint, do not re-initialize, only move the state
# to the right device
@@ -93,17 +92,17 @@ def on_start(self, task: ClassyTask) -> None:
self._save_current_model_state(task.base_model, self.state.model_state)
self._save_current_model_state(task.base_model, self.state.ema_model_state)

-    def on_phase_start(self, task: ClassyTask) -> None:
+    def on_phase_start(self, task) -> None:
# restore the right state depending on the phase type
self.set_model_state(task, use_ema=not task.train)

-    def on_phase_end(self, task: ClassyTask) -> None:
+    def on_phase_end(self, task) -> None:
if task.train:
# save the current model state since this will be overwritten by the ema
# state in the test phase
self._save_current_model_state(task.base_model, self.state.model_state)

-    def on_step(self, task: ClassyTask) -> None:
+    def on_step(self, task) -> None:
if not task.train:
return

@@ -117,7 +116,7 @@ def on_step(self, task: ClassyTask) -> None:
device=self.device
)

-    def set_model_state(self, task: ClassyTask, use_ema: bool) -> None:
+    def set_model_state(self, task, use_ema: bool) -> None:
"""
Depending on use_ema, set the appropriate state for the model.
"""
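As context for the on_step logic above: the hook keeps a shadow copy of the weights in state.ema_model_state and, after every train step, blends it toward the live parameters. Conceptually the per-parameter update is the usual exponential moving average (the decay symbol below is illustrative; its exact name in the hook's constructor is not shown in this diff):

# for each (name, param) yielded by get_model_state_iterator(model):
ema_state[name] = decay * ema_state[name] + (1.0 - decay) * param.to(device=self.device)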
9 changes: 4 additions & 5 deletions classy_vision/hooks/loss_lr_meter_logging_hook.py
@@ -7,7 +7,6 @@
import logging
from typing import Any, Dict, Optional

-from classy_vision import tasks
from classy_vision.generic.distributed_util import get_rank
from classy_vision.hooks import register_hook
from classy_vision.hooks.classy_hook import ClassyHook
@@ -36,7 +35,7 @@ def __init__(self, log_freq: Optional[int] = None) -> None:
), "log_freq must be an int or None"
self.log_freq: Optional[int] = log_freq

-    def on_phase_end(self, task: "tasks.ClassyTask") -> None:
+    def on_phase_end(self, task) -> None:
"""
Log the loss, optimizer LR, and meters for the phase.
"""
@@ -51,7 +50,7 @@ def on_phase_end(self, task: "tasks.ClassyTask") -> None:
if task.train:
self._log_lr(task)

-    def on_step(self, task: "tasks.ClassyTask") -> None:
+    def on_step(self, task) -> None:
"""
Log the LR every log_freq batches, if log_freq is not None.
"""
@@ -63,14 +62,14 @@ def on_step(self, task: "tasks.ClassyTask") -> None:
logging.info("Local unsynced metric values:")
self._log_loss_meters(task)

-    def _log_lr(self, task: "tasks.ClassyTask") -> None:
+    def _log_lr(self, task) -> None:
"""
Compute and log the optimizer LR.
"""
optimizer_lr = task.optimizer.parameters.lr
logging.info("Learning Rate: {}\n".format(optimizer_lr))

-    def _log_loss_meters(self, task: "tasks.ClassyTask") -> None:
+    def _log_loss_meters(self, task) -> None:
"""
Compute and log the loss and meters.
"""
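With hooks now buildable from the task config (see classification_task.py below), this hook can be requested by its registered name; log_freq matches the optional constructor argument above, assuming the hook's from_config forwards it. A sketch of the config entry, with an arbitrary value:

"hooks": [{"name": "loss_lr_meter_logging", "log_freq": 10}]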
3 changes: 1 addition & 2 deletions classy_vision/hooks/model_complexity_hook.py
@@ -7,7 +7,6 @@
import logging
from typing import Any, Dict

-from classy_vision import tasks
from classy_vision.generic.profiler import (
compute_activations,
compute_flops,
@@ -28,7 +27,7 @@ class ModelComplexityHook(ClassyHook):
on_phase_end = ClassyHook._noop
on_end = ClassyHook._noop

-    def on_start(self, task: "tasks.ClassyTask") -> None:
+    def on_start(self, task) -> None:
"""Measure number of parameters, FLOPs and activations."""
self.num_flops = 0
self.num_activations = 0
3 changes: 1 addition & 2 deletions classy_vision/hooks/model_tensorboard_hook.py
@@ -7,7 +7,6 @@
import logging
from typing import Any, Dict

-from classy_vision import tasks
from classy_vision.generic.distributed_util import is_master
from classy_vision.generic.visualize import plot_model
from classy_vision.hooks import register_hook
@@ -62,7 +61,7 @@ def from_config(cls, config: [Dict[str, Any]]) -> "ModelTensorboardHook":
tb_writer = SummaryWriter(**config["summary_writer"])
return cls(tb_writer=tb_writer)

-    def on_start(self, task: "tasks.ClassyTask") -> None:
+    def on_start(self, task) -> None:
"""
Plot the model on Tensorboard.
"""
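The from_config above builds a SummaryWriter from config["summary_writer"], so a config-driven setup would look roughly like this (the registered hook name and the log_dir value are assumptions, not taken from this diff):

"hooks": [{"name": "model_tensorboard", "summary_writer": {"log_dir": "/tmp/tensorboard"}}]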
3 changes: 1 addition & 2 deletions classy_vision/hooks/profiler_hook.py
@@ -7,7 +7,6 @@
import logging
from typing import Any, Dict

-from classy_vision import tasks
from classy_vision.generic.profiler import profile, summarize_profiler_info
from classy_vision.hooks import register_hook
from classy_vision.hooks.classy_hook import ClassyHook
@@ -25,7 +24,7 @@ class ProfilerHook(ClassyHook):
on_phase_end = ClassyHook._noop
on_end = ClassyHook._noop

-    def on_start(self, task: "tasks.ClassyTask") -> None:
+    def on_start(self, task) -> None:
"""Profile the forward pass."""
logging.info("Profiling forward pass...")
batchsize_per_replica = getattr(
7 changes: 3 additions & 4 deletions classy_vision/hooks/progress_bar_hook.py
@@ -6,7 +6,6 @@

from typing import Any, Dict, Optional

-from classy_vision import tasks
from classy_vision.generic.distributed_util import is_master
from classy_vision.hooks import register_hook
from classy_vision.hooks.classy_hook import ClassyHook
@@ -36,7 +35,7 @@ def __init__(self) -> None:
self.bar_size: int = 0
self.batches: int = 0

-    def on_phase_start(self, task: "tasks.ClassyTask") -> None:
+    def on_phase_start(self, task) -> None:
"""Create and display a progress bar with 0 progress."""
if not progressbar_available:
raise RuntimeError(
Expand All @@ -49,13 +48,13 @@ def on_phase_start(self, task: "tasks.ClassyTask") -> None:
self.progress_bar = progressbar.ProgressBar(self.bar_size)
self.progress_bar.start()

-    def on_step(self, task: "tasks.ClassyTask") -> None:
+    def on_step(self, task) -> None:
"""Update the progress bar with the batch size."""
if task.train and is_master() and self.progress_bar is not None:
self.batches += 1
self.progress_bar.update(min(self.batches, self.bar_size))

-    def on_phase_end(self, task: "tasks.ClassyTask") -> None:
+    def on_phase_end(self, task) -> None:
"""Clear the progress bar at the end of the phase."""
if is_master() and self.progress_bar is not None:
self.progress_bar.finish()
7 changes: 3 additions & 4 deletions classy_vision/hooks/tensorboard_plot_hook.py
@@ -8,7 +8,6 @@
import time
from typing import Any, Dict, List, Optional

-from classy_vision import tasks
from classy_vision.generic.distributed_util import is_master
from classy_vision.hooks import register_hook
from classy_vision.hooks.classy_hook import ClassyHook
@@ -70,7 +69,7 @@ def from_config(cls, config: Dict[str, Any]) -> "TensorboardPlotHook":
log_period = config.get("log_period", 10)
return cls(tb_writer=tb_writer, log_period=log_period)

-    def on_phase_start(self, task: "tasks.ClassyTask") -> None:
+    def on_phase_start(self, task) -> None:
"""Initialize losses and learning_rates."""
self.learning_rates = []
self.wall_times = []
Expand All @@ -87,7 +86,7 @@ def on_phase_start(self, task: "tasks.ClassyTask") -> None:
f"Parameters/{name}", parameter, global_step=-1
)

-    def on_step(self, task: "tasks.ClassyTask") -> None:
+    def on_step(self, task) -> None:
"""Store the observed learning rates."""
if self.learning_rates is None:
logging.warning("learning_rates is not initialized")
Expand All @@ -106,7 +105,7 @@ def on_step(self, task: "tasks.ClassyTask") -> None:

self.step_idx += 1

-    def on_phase_end(self, task: "tasks.ClassyTask") -> None:
+    def on_phase_end(self, task) -> None:
"""Add the losses and learning rates to tensorboard."""
if self.learning_rates is None:
logging.warning("learning_rates is not initialized")
3 changes: 1 addition & 2 deletions classy_vision/hooks/visdom_hook.py
@@ -8,7 +8,6 @@
import logging
from typing import Any, Dict

-from classy_vision import tasks
from classy_vision.generic.distributed_util import is_master
from classy_vision.generic.util import flatten_dict
from classy_vision.generic.visualize import plot_learning_curves
@@ -60,7 +59,7 @@ def __init__(
self.metrics: Dict = {}
self.visdom: Visdom = Visdom(self.server, self.port)

-    def on_phase_end(self, task: "tasks.ClassyTask") -> None:
+    def on_phase_end(self, task) -> None:
"""
Plot the metrics on visdom.
"""
10 changes: 10 additions & 0 deletions classy_vision/tasks/classification_task.py
@@ -24,6 +24,7 @@
recursive_copy_to_gpu,
update_classy_state,
)
+from classy_vision.hooks import build_hooks
from classy_vision.losses import ClassyLoss, build_loss
from classy_vision.meters import build_meters
from classy_vision.models import ClassyModel, build_model
@@ -328,6 +329,13 @@ def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
amp_args = config.get("amp_args")
meters = build_meters(config.get("meters", {}))
model = build_model(config["model"])

+        # hooks config is optional
+        hooks_config = config.get("hooks")
+        hooks = []
+        if hooks_config is not None:
+            hooks = build_hooks(hooks_config)

optimizer = build_optimizer(optimizer_config)

task = (
@@ -348,7 +356,9 @@ def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
config.get("batch_norm_sync_mode", "disabled").upper()
],
)
+            .set_hooks(hooks)
)

for phase_type in phase_types:
task.set_dataset(datasets[phase_type], phase_type)

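Putting it together: from_config now reads an optional "hooks" list, builds hook objects with build_hooks, and attaches them through the new .set_hooks(...) call in the builder chain. A sketch of the resulting flow (only the "hooks" key is specific to this PR; the rest of the config is whatever the task already required):

config["hooks"] = [{"name": "loss_lr_meter_logging"}]
task = ClassificationTask.from_config(config)
# task.hooks now holds the built LossLrMeterLoggingHook instance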
7 changes: 7 additions & 0 deletions test/tasks_classification_task_test.py
@@ -44,6 +44,13 @@ def test_build_task(self):
task = build_task(config)
self.assertTrue(isinstance(task, ClassificationTask))

+    def test_hooks_config_builds_correctly(self):
+        config = get_test_task_config()
+        config["hooks"] = [{"name": "loss_lr_meter_logging"}]
+        task = build_task(config)
+        self.assertTrue(len(task.hooks) == 1)
+        self.assertTrue(isinstance(task.hooks[0], LossLrMeterLoggingHook))

def test_get_state(self):
config = get_test_task_config()
loss = build_loss(config["loss"])
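The same wiring can also be exercised without a config by handing already-built hook instances to set_hooks, which the from_config change above now calls internally. A small sketch along the lines of the new test (the import path is assumed to match what the test file uses):

from classy_vision.hooks import LossLrMeterLoggingHook

task = build_task(get_test_task_config())
task.set_hooks([LossLrMeterLoggingHook(log_freq=5)])
assert isinstance(task.hooks[0], LossLrMeterLoggingHook)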