Skip to content

Commit 9f52d54

Browse files
committed
Added callable options for iteration_log and epoch_log in StatsHandler
Fixes #5964
1 parent 94feae5 commit 9f52d54

File tree

2 files changed

+102
-63
lines changed

2 files changed

+102
-63
lines changed

monai/handlers/stats_handler.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ class StatsHandler:
6666

6767
def __init__(
6868
self,
69-
iteration_log: bool = True,
70-
epoch_log: bool = True,
69+
iteration_log: bool | Callable[[Engine, int], bool] = True,
70+
epoch_log: bool | Callable[[Engine, int], bool] = True,
7171
epoch_print_logger: Callable[[Engine], Any] | None = None,
7272
iteration_print_logger: Callable[[Engine], Any] | None = None,
7373
output_transform: Callable = lambda x: x[0],
@@ -80,8 +80,14 @@ def __init__(
8080
"""
8181
8282
Args:
83-
iteration_log: whether to log data when iteration completed, default to `True`.
84-
epoch_log: whether to log data when epoch completed, default to `True`.
83+
iteration_log: whether to log data when iteration completed, default to `True`. ``iteration_log`` can
84+
be also a function and it will be interpreted as an event filter
85+
(see https://pytorch.org/ignite/generated/ignite.engine.events.Events.html for details).
86+
Event filter function accepts as input engine and event value (iteration) and should return True/False.
87+
Event filtering can be helpful to customize iteration logging frequency.
88+
epoch_log: whether to log data when epoch completed, default to `True`. ``epoch_log`` can be
89+
also a function and it will be interpreted as an event filter. See ``iteration_log`` argument for more
90+
details.
8591
epoch_print_logger: customized callable printer for epoch level logging.
8692
Must accept parameter "engine", use default printer if None.
8793
iteration_print_logger: customized callable printer for iteration level logging.
@@ -135,9 +141,19 @@ def attach(self, engine: Engine) -> None:
135141
" please call `logging.basicConfig(stream=sys.stdout, level=logging.INFO)` to enable it."
136142
)
137143
if self.iteration_log and not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED):
138-
engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed)
144+
event = (
145+
Events.ITERATION_COMPLETED(event_filter=self.iteration_log)
146+
if callable(self.iteration_log)
147+
else Events.ITERATION_COMPLETED
148+
)
149+
engine.add_event_handler(event, self.iteration_completed)
139150
if self.epoch_log and not engine.has_event_handler(self.epoch_completed, Events.EPOCH_COMPLETED):
140-
engine.add_event_handler(Events.EPOCH_COMPLETED, self.epoch_completed)
151+
event = (
152+
Events.EPOCH_COMPLETED(event_filter=self.epoch_log)
153+
if callable(self.epoch_log)
154+
else Events.EPOCH_COMPLETED
155+
)
156+
engine.add_event_handler(event, self.epoch_completed)
141157
if not engine.has_event_handler(self.exception_raised, Events.EXCEPTION_RAISED):
142158
engine.add_event_handler(Events.EXCEPTION_RAISED, self.exception_raised)
143159

tests/test_handler_stats.py

Lines changed: 80 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -26,74 +26,97 @@
2626

2727
class TestHandlerStats(unittest.TestCase):
2828
def test_metrics_print(self):
29-
log_stream = StringIO()
30-
log_handler = logging.StreamHandler(log_stream)
31-
log_handler.setLevel(logging.INFO)
32-
key_to_handler = "test_logging"
33-
key_to_print = "testing_metric"
29+
def event_filter(_, event):
30+
if event in [1, 2]:
31+
return True
32+
return False
33+
34+
for epoch_log in [True, event_filter]:
35+
log_stream = StringIO()
36+
log_handler = logging.StreamHandler(log_stream)
37+
log_handler.setLevel(logging.INFO)
38+
key_to_handler = "test_logging"
39+
key_to_print = "testing_metric"
3440

35-
# set up engine
36-
def _train_func(engine, batch):
37-
return [torch.tensor(0.0)]
41+
# set up engine
42+
def _train_func(engine, batch):
43+
return [torch.tensor(0.0)]
3844

39-
engine = Engine(_train_func)
45+
engine = Engine(_train_func)
4046

41-
# set up dummy metric
42-
@engine.on(Events.EPOCH_COMPLETED)
43-
def _update_metric(engine):
44-
current_metric = engine.state.metrics.get(key_to_print, 0.1)
45-
engine.state.metrics[key_to_print] = current_metric + 0.1
47+
# set up dummy metric
48+
@engine.on(Events.EPOCH_COMPLETED)
49+
def _update_metric(engine):
50+
current_metric = engine.state.metrics.get(key_to_print, 0.1)
51+
engine.state.metrics[key_to_print] = current_metric + 0.1
4652

47-
# set up testing handler
48-
logger = logging.getLogger(key_to_handler)
49-
logger.setLevel(logging.INFO)
50-
logger.addHandler(log_handler)
51-
stats_handler = StatsHandler(iteration_log=False, epoch_log=True, name=key_to_handler)
52-
stats_handler.attach(engine)
53-
54-
engine.run(range(3), max_epochs=2)
53+
# set up testing handler
54+
logger = logging.getLogger(key_to_handler)
55+
logger.setLevel(logging.INFO)
56+
logger.addHandler(log_handler)
57+
stats_handler = StatsHandler(iteration_log=False, epoch_log=epoch_log, name=key_to_handler)
58+
stats_handler.attach(engine)
5559

56-
# check logging output
57-
output_str = log_stream.getvalue()
58-
log_handler.close()
59-
has_key_word = re.compile(f".*{key_to_print}.*")
60-
content_count = 0
61-
for line in output_str.split("\n"):
62-
if has_key_word.match(line):
63-
content_count += 1
64-
self.assertTrue(content_count > 0)
60+
max_epochs = 4
61+
engine.run(range(3), max_epochs=max_epochs)
62+
63+
# check logging output
64+
output_str = log_stream.getvalue()
65+
log_handler.close()
66+
has_key_word = re.compile(f".*{key_to_print}.*")
67+
content_count = 0
68+
for line in output_str.split("\n"):
69+
if has_key_word.match(line):
70+
content_count += 1
71+
if epoch_log is True:
72+
self.assertTrue(content_count == max_epochs)
73+
else:
74+
self.assertTrue(content_count == 2) # 2 = len([1, 2]) from event_filter
6575

6676
def test_loss_print(self):
67-
log_stream = StringIO()
68-
log_handler = logging.StreamHandler(log_stream)
69-
log_handler.setLevel(logging.INFO)
70-
key_to_handler = "test_logging"
71-
key_to_print = "myLoss"
72-
73-
# set up engine
74-
def _train_func(engine, batch):
75-
return [torch.tensor(0.0)]
77+
def event_filter(_, event):
78+
if event in [1, 3]:
79+
return True
80+
return False
81+
82+
for iteration_log in [True, event_filter]:
83+
log_stream = StringIO()
84+
log_handler = logging.StreamHandler(log_stream)
85+
log_handler.setLevel(logging.INFO)
86+
key_to_handler = "test_logging"
87+
key_to_print = "myLoss"
7688

77-
engine = Engine(_train_func)
89+
# set up engine
90+
def _train_func(engine, batch):
91+
return [torch.tensor(0.0)]
7892

79-
# set up testing handler
80-
logger = logging.getLogger(key_to_handler)
81-
logger.setLevel(logging.INFO)
82-
logger.addHandler(log_handler)
83-
stats_handler = StatsHandler(iteration_log=True, epoch_log=False, name=key_to_handler, tag_name=key_to_print)
84-
stats_handler.attach(engine)
93+
engine = Engine(_train_func)
8594

86-
engine.run(range(3), max_epochs=2)
95+
# set up testing handler
96+
logger = logging.getLogger(key_to_handler)
97+
logger.setLevel(logging.INFO)
98+
logger.addHandler(log_handler)
99+
stats_handler = StatsHandler(
100+
iteration_log=iteration_log, epoch_log=False, name=key_to_handler, tag_name=key_to_print
101+
)
102+
stats_handler.attach(engine)
87103

88-
# check logging output
89-
output_str = log_stream.getvalue()
90-
log_handler.close()
91-
has_key_word = re.compile(f".*{key_to_print}.*")
92-
content_count = 0
93-
for line in output_str.split("\n"):
94-
if has_key_word.match(line):
95-
content_count += 1
96-
self.assertTrue(content_count > 0)
104+
num_iters = 3
105+
max_epochs = 2
106+
engine.run(range(num_iters), max_epochs=max_epochs)
107+
108+
# check logging output
109+
output_str = log_stream.getvalue()
110+
log_handler.close()
111+
has_key_word = re.compile(f".*{key_to_print}.*")
112+
content_count = 0
113+
for line in output_str.split("\n"):
114+
if has_key_word.match(line):
115+
content_count += 1
116+
if iteration_log is True:
117+
self.assertTrue(content_count == num_iters * max_epochs)
118+
else:
119+
self.assertTrue(content_count == 2) # 2 = len([1, 3]) from event_filter
97120

98121
def test_loss_dict(self):
99122
log_stream = StringIO()

0 commit comments

Comments
 (0)