Skip to content

Commit

Permalink
Added callable options for iteration_log and epoch_log in StatsHandler
Browse files Browse the repository at this point in the history
Fixes #5964
  • Loading branch information
vfdev-5 committed Feb 9, 2023
1 parent 94feae5 commit 9f52d54
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 63 deletions.
28 changes: 22 additions & 6 deletions monai/handlers/stats_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ class StatsHandler:

def __init__(
self,
iteration_log: bool = True,
epoch_log: bool = True,
iteration_log: bool | Callable[[Engine, int], bool] = True,
epoch_log: bool | Callable[[Engine, int], bool] = True,
epoch_print_logger: Callable[[Engine], Any] | None = None,
iteration_print_logger: Callable[[Engine], Any] | None = None,
output_transform: Callable = lambda x: x[0],
Expand All @@ -80,8 +80,14 @@ def __init__(
"""
Args:
iteration_log: whether to log data when iteration completed, default to `True`.
epoch_log: whether to log data when epoch completed, default to `True`.
iteration_log: whether to log data when iteration completed, default to `True`. ``iteration_log`` can
be also a function and it will be interpreted as an event filter
(see https://pytorch.org/ignite/generated/ignite.engine.events.Events.html for details).
Event filter function accepts as input engine and event value (iteration) and should return True/False.
Event filtering can be helpful to customize iteration logging frequency.
epoch_log: whether to log data when epoch completed, default to `True`. ``epoch_log`` can be
also a function and it will be interpreted as an event filter. See ``iteration_log`` argument for more
details.
epoch_print_logger: customized callable printer for epoch level logging.
Must accept parameter "engine", use default printer if None.
iteration_print_logger: customized callable printer for iteration level logging.
Expand Down Expand Up @@ -135,9 +141,19 @@ def attach(self, engine: Engine) -> None:
" please call `logging.basicConfig(stream=sys.stdout, level=logging.INFO)` to enable it."
)
if self.iteration_log and not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED):
engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed)
event = (
Events.ITERATION_COMPLETED(event_filter=self.iteration_log)
if callable(self.iteration_log)
else Events.ITERATION_COMPLETED
)
engine.add_event_handler(event, self.iteration_completed)
if self.epoch_log and not engine.has_event_handler(self.epoch_completed, Events.EPOCH_COMPLETED):
engine.add_event_handler(Events.EPOCH_COMPLETED, self.epoch_completed)
event = (
Events.EPOCH_COMPLETED(event_filter=self.epoch_log)
if callable(self.epoch_log)
else Events.EPOCH_COMPLETED
)
engine.add_event_handler(event, self.epoch_completed)
if not engine.has_event_handler(self.exception_raised, Events.EXCEPTION_RAISED):
engine.add_event_handler(Events.EXCEPTION_RAISED, self.exception_raised)

Expand Down
137 changes: 80 additions & 57 deletions tests/test_handler_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,74 +26,97 @@

class TestHandlerStats(unittest.TestCase):
def test_metrics_print(self):
log_stream = StringIO()
log_handler = logging.StreamHandler(log_stream)
log_handler.setLevel(logging.INFO)
key_to_handler = "test_logging"
key_to_print = "testing_metric"
def event_filter(_, event):
if event in [1, 2]:
return True
return False

for epoch_log in [True, event_filter]:
log_stream = StringIO()
log_handler = logging.StreamHandler(log_stream)
log_handler.setLevel(logging.INFO)
key_to_handler = "test_logging"
key_to_print = "testing_metric"

# set up engine
def _train_func(engine, batch):
return [torch.tensor(0.0)]
# set up engine
def _train_func(engine, batch):
return [torch.tensor(0.0)]

engine = Engine(_train_func)
engine = Engine(_train_func)

# set up dummy metric
@engine.on(Events.EPOCH_COMPLETED)
def _update_metric(engine):
current_metric = engine.state.metrics.get(key_to_print, 0.1)
engine.state.metrics[key_to_print] = current_metric + 0.1
# set up dummy metric
@engine.on(Events.EPOCH_COMPLETED)
def _update_metric(engine):
current_metric = engine.state.metrics.get(key_to_print, 0.1)
engine.state.metrics[key_to_print] = current_metric + 0.1

# set up testing handler
logger = logging.getLogger(key_to_handler)
logger.setLevel(logging.INFO)
logger.addHandler(log_handler)
stats_handler = StatsHandler(iteration_log=False, epoch_log=True, name=key_to_handler)
stats_handler.attach(engine)

engine.run(range(3), max_epochs=2)
# set up testing handler
logger = logging.getLogger(key_to_handler)
logger.setLevel(logging.INFO)
logger.addHandler(log_handler)
stats_handler = StatsHandler(iteration_log=False, epoch_log=epoch_log, name=key_to_handler)
stats_handler.attach(engine)

# check logging output
output_str = log_stream.getvalue()
log_handler.close()
has_key_word = re.compile(f".*{key_to_print}.*")
content_count = 0
for line in output_str.split("\n"):
if has_key_word.match(line):
content_count += 1
self.assertTrue(content_count > 0)
max_epochs = 4
engine.run(range(3), max_epochs=max_epochs)

# check logging output
output_str = log_stream.getvalue()
log_handler.close()
has_key_word = re.compile(f".*{key_to_print}.*")
content_count = 0
for line in output_str.split("\n"):
if has_key_word.match(line):
content_count += 1
if epoch_log is True:
self.assertTrue(content_count == max_epochs)
else:
self.assertTrue(content_count == 2) # 2 = len([1, 2]) from event_filter

def test_loss_print(self):
log_stream = StringIO()
log_handler = logging.StreamHandler(log_stream)
log_handler.setLevel(logging.INFO)
key_to_handler = "test_logging"
key_to_print = "myLoss"

# set up engine
def _train_func(engine, batch):
return [torch.tensor(0.0)]
def event_filter(_, event):
if event in [1, 3]:
return True
return False

for iteration_log in [True, event_filter]:
log_stream = StringIO()
log_handler = logging.StreamHandler(log_stream)
log_handler.setLevel(logging.INFO)
key_to_handler = "test_logging"
key_to_print = "myLoss"

engine = Engine(_train_func)
# set up engine
def _train_func(engine, batch):
return [torch.tensor(0.0)]

# set up testing handler
logger = logging.getLogger(key_to_handler)
logger.setLevel(logging.INFO)
logger.addHandler(log_handler)
stats_handler = StatsHandler(iteration_log=True, epoch_log=False, name=key_to_handler, tag_name=key_to_print)
stats_handler.attach(engine)
engine = Engine(_train_func)

engine.run(range(3), max_epochs=2)
# set up testing handler
logger = logging.getLogger(key_to_handler)
logger.setLevel(logging.INFO)
logger.addHandler(log_handler)
stats_handler = StatsHandler(
iteration_log=iteration_log, epoch_log=False, name=key_to_handler, tag_name=key_to_print
)
stats_handler.attach(engine)

# check logging output
output_str = log_stream.getvalue()
log_handler.close()
has_key_word = re.compile(f".*{key_to_print}.*")
content_count = 0
for line in output_str.split("\n"):
if has_key_word.match(line):
content_count += 1
self.assertTrue(content_count > 0)
num_iters = 3
max_epochs = 2
engine.run(range(num_iters), max_epochs=max_epochs)

# check logging output
output_str = log_stream.getvalue()
log_handler.close()
has_key_word = re.compile(f".*{key_to_print}.*")
content_count = 0
for line in output_str.split("\n"):
if has_key_word.match(line):
content_count += 1
if iteration_log is True:
self.assertTrue(content_count == num_iters * max_epochs)
else:
self.assertTrue(content_count == 2) # 2 = len([1, 3]) from event_filter

def test_loss_dict(self):
log_stream = StringIO()
Expand Down

0 comments on commit 9f52d54

Please sign in to comment.