Skip to content

Commit 1db8813

Browse files
Aaron Adcock and facebook-github-bot
Aaron Adcock
authored and committed
Add hooks building to classification task (facebookresearch#402)
Summary: Pull Request resolved: facebookresearch#402 Pull Request resolved: fairinternal/ClassyVision#62 Add configurable hooks to classification task, had to remove the typehints from ClassyHook to avoid a circular dependency. Reviewed By: mannatsingh, vreis Differential Revision: D19770583 fbshipit-source-id: 0011c3519bf0af5c5e317319e80ede70127e754c
1 parent 2ed4394 commit 1db8813

File tree

4 files changed

+17
-7
lines changed

4 files changed

+17
-7
lines changed

classy_vision/hooks/classy_hook.py

+5-7
Original file line number | Diff line number | Diff line change
@@ -7,8 +7,6 @@
77
from abc import ABC, abstractmethod
88
from typing import Any, Dict
99

10-
from classy_vision import tasks
11-
1210

1311
class ClassyHookState:
1412
"""Class to store state within instances of ClassyHook.
@@ -69,27 +67,27 @@ def name(cls) -> str:
6967
return cls.__name__
7068

7169
@abstractmethod
72-
def on_start(self, task: "tasks.ClassyTask") -> None:
70+
def on_start(self, task) -> None:
7371
"""Called at the start of training."""
7472
pass
7573

7674
@abstractmethod
77-
def on_phase_start(self, task: "tasks.ClassyTask") -> None:
75+
def on_phase_start(self, task) -> None:
7876
"""Called at the start of each phase."""
7977
pass
8078

8179
@abstractmethod
82-
def on_step(self, task: "tasks.ClassyTask") -> None:
80+
def on_step(self, task) -> None:
8381
"""Called each time after parameters have been updated by the optimizer."""
8482
pass
8583

8684
@abstractmethod
87-
def on_phase_end(self, task: "tasks.ClassyTask") -> None:
85+
def on_phase_end(self, task) -> None:
8886
"""Called at the end of each phase (epoch)."""
8987
pass
9088

9189
@abstractmethod
92-
def on_end(self, task: "tasks.ClassyTask") -> None:
90+
def on_end(self, task) -> None:
9391
"""Called at the end of training."""
9492
pass
9593

classy_vision/tasks/classification_task.py

+6
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,7 @@
2424
recursive_copy_to_gpu,
2525
update_classy_state,
2626
)
27+
from classy_vision.hooks import build_hooks
2728
from classy_vision.losses import ClassyLoss, build_loss
2829
from classy_vision.meters import build_meters
2930
from classy_vision.models import ClassyModel, build_model
@@ -328,6 +329,10 @@ def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
328329
amp_args = config.get("amp_args")
329330
meters = build_meters(config.get("meters", {}))
330331
model = build_model(config["model"])
332+
333+
# hooks config is optional
334+
hooks_config = config.get("hooks", [])
335+
hooks = build_hooks(hooks_config)
331336
optimizer = build_optimizer(optimizer_config)
332337

333338
task = (
@@ -348,6 +353,7 @@ def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
348353
config.get("batch_norm_sync_mode", "disabled").upper()
349354
],
350355
)
356+
.set_hooks(hooks)
351357
)
352358
for phase_type in phase_types:
353359
task.set_dataset(datasets[phase_type], phase_type)

test/generic/config_utils.py

+1
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@ def get_test_task_config(head_num_classes=1000):
1212
"name": "classification_task",
1313
"num_epochs": 12,
1414
"loss": {"name": "CrossEntropyLoss"},
15+
"hooks": [{"name": "loss_lr_meter_logging"}],
1516
"dataset": {
1617
"train": {
1718
"name": "synthetic_image",

test/tasks_classification_task_test.py

+5
Original file line number | Diff line number | Diff line change
@@ -44,6 +44,11 @@ def test_build_task(self):
4444
task = build_task(config)
4545
self.assertTrue(isinstance(task, ClassificationTask))
4646

47+
# Hooks configuration is optional
48+
del config["hooks"]
49+
task = build_task(config)
50+
self.assertTrue(isinstance(task, ClassificationTask))
51+
4752
def test_get_state(self):
4853
config = get_test_task_config()
4954
loss = build_loss(config["loss"])

0 commit comments

Comments (0)