Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added callable options for iteration_log and epoch_log in StatsHandler #5965

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions monai/handlers/stats_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ class StatsHandler:

def __init__(
self,
iteration_log: bool = True,
epoch_log: bool = True,
iteration_log: bool | Callable[[Engine, int], bool] = True,
epoch_log: bool | Callable[[Engine, int], bool] = True,
epoch_print_logger: Callable[[Engine], Any] | None = None,
iteration_print_logger: Callable[[Engine], Any] | None = None,
output_transform: Callable = lambda x: x[0],
Expand All @@ -80,8 +80,14 @@ def __init__(
"""

Args:
iteration_log: whether to log data when iteration completed, default to `True`.
epoch_log: whether to log data when epoch completed, default to `True`.
iteration_log: whether to log data when iteration completed, default to `True`. ``iteration_log`` can
be also a function and it will be interpreted as an event filter
(see https://pytorch.org/ignite/generated/ignite.engine.events.Events.html for details).
Event filter function accepts as input engine and event value (iteration) and should return True/False.
Event filtering can be helpful to customize iteration logging frequency.
epoch_log: whether to log data when epoch completed, default to `True`. ``epoch_log`` can be
also a function and it will be interpreted as an event filter. See ``iteration_log`` argument for more
details.
epoch_print_logger: customized callable printer for epoch level logging.
Must accept parameter "engine", use default printer if None.
iteration_print_logger: customized callable printer for iteration level logging.
Expand Down Expand Up @@ -135,9 +141,15 @@ def attach(self, engine: Engine) -> None:
" please call `logging.basicConfig(stream=sys.stdout, level=logging.INFO)` to enable it."
)
if self.iteration_log and not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED):
engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed)
event = Events.ITERATION_COMPLETED
if callable(self.iteration_log): # substitute event with new one using filter callable
event = event(event_filter=self.iteration_log)
engine.add_event_handler(event, self.iteration_completed)
if self.epoch_log and not engine.has_event_handler(self.epoch_completed, Events.EPOCH_COMPLETED):
engine.add_event_handler(Events.EPOCH_COMPLETED, self.epoch_completed)
event = Events.EPOCH_COMPLETED
if callable(self.epoch_log): # substitute event with new one using filter callable
event = event(event_filter=self.epoch_log)
engine.add_event_handler(event, self.epoch_completed)
if not engine.has_event_handler(self.exception_raised, Events.EXCEPTION_RAISED):
engine.add_event_handler(Events.EXCEPTION_RAISED, self.exception_raised)

Expand Down
39 changes: 31 additions & 8 deletions tests/test_handler_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,23 @@

import torch
from ignite.engine import Engine, Events
from parameterized import parameterized

from monai.handlers import StatsHandler


def get_event_filter(e):
    """Build an ignite-style event filter that fires only for the given event values.

    Args:
        e: collection of event values (e.g. epoch or iteration numbers) for which
            the returned filter should trigger.

    Returns:
        A callable ``event_filter(engine, event) -> bool`` suitable for
        ``Events.X(event_filter=...)``; the ``engine`` argument is ignored.
    """

    def event_filter(_, event):
        # Return the membership test directly instead of the verbose
        # if/return-True/return-False pattern.
        return event in e

    return event_filter


class TestHandlerStats(unittest.TestCase):
def test_metrics_print(self):
@parameterized.expand([[True], [get_event_filter([1, 2])]])
def test_metrics_print(self, epoch_log):
log_stream = StringIO()
log_handler = logging.StreamHandler(log_stream)
log_handler.setLevel(logging.INFO)
Expand All @@ -48,10 +59,11 @@ def _update_metric(engine):
logger = logging.getLogger(key_to_handler)
logger.setLevel(logging.INFO)
logger.addHandler(log_handler)
stats_handler = StatsHandler(iteration_log=False, epoch_log=True, name=key_to_handler)
stats_handler = StatsHandler(iteration_log=False, epoch_log=epoch_log, name=key_to_handler)
stats_handler.attach(engine)

engine.run(range(3), max_epochs=2)
max_epochs = 4
engine.run(range(3), max_epochs=max_epochs)

# check logging output
output_str = log_stream.getvalue()
Expand All @@ -61,9 +73,13 @@ def _update_metric(engine):
for line in output_str.split("\n"):
if has_key_word.match(line):
content_count += 1
self.assertTrue(content_count > 0)
if epoch_log is True:
self.assertTrue(content_count == max_epochs)
else:
self.assertTrue(content_count == 2) # 2 = len([1, 2]) from event_filter

def test_loss_print(self):
@parameterized.expand([[True], [get_event_filter([1, 3])]])
def test_loss_print(self, iteration_log):
log_stream = StringIO()
log_handler = logging.StreamHandler(log_stream)
log_handler.setLevel(logging.INFO)
Expand All @@ -80,10 +96,14 @@ def _train_func(engine, batch):
logger = logging.getLogger(key_to_handler)
logger.setLevel(logging.INFO)
logger.addHandler(log_handler)
stats_handler = StatsHandler(iteration_log=True, epoch_log=False, name=key_to_handler, tag_name=key_to_print)
stats_handler = StatsHandler(
iteration_log=iteration_log, epoch_log=False, name=key_to_handler, tag_name=key_to_print
)
stats_handler.attach(engine)

engine.run(range(3), max_epochs=2)
num_iters = 3
max_epochs = 2
engine.run(range(num_iters), max_epochs=max_epochs)

# check logging output
output_str = log_stream.getvalue()
Expand All @@ -93,7 +113,10 @@ def _train_func(engine, batch):
for line in output_str.split("\n"):
if has_key_word.match(line):
content_count += 1
self.assertTrue(content_count > 0)
if iteration_log is True:
self.assertTrue(content_count == num_iters * max_epochs)
else:
self.assertTrue(content_count == 2) # 2 = len([1, 3]) from event_filter

def test_loss_dict(self):
log_stream = StringIO()
Expand Down