Skip to content

Commit

Permalink
add metrics to training monitor title
Browse files Browse the repository at this point in the history
* add mean time per epoch
* add ETA to finish next 10 epochs
* add plateau patience fraction (when in plateau)
  • Loading branch information
roomrys committed Mar 30, 2022
1 parent d7bfef8 commit bf00897
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 5 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ jobs:
python-version: 3.7
- name: Install Dependencies
run: |
pip install black==20.8b1
pip install click==8.0.4
pip install black==21.6b0
- name: Run Black
run: |
black --check sleap tests
Expand Down
3 changes: 2 additions & 1 deletion dev_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,5 @@ twine==3.3.0
PyGithub
jupyterlab
jedi==0.17.2
ipykernel
ipykernel
click==8.0.4
2 changes: 1 addition & 1 deletion sleap/gui/learning/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,7 @@ def run_gui_training(

if gui:
print("Resetting monitor window.")
win.reset(what=str(model_type))
win.reset(what=str(model_type), config=job)
win.setWindowTitle(f"Training Model - {str(model_type)}")
win.set_message(f"Preparing to run training...")
if save_viz:
Expand Down
53 changes: 52 additions & 1 deletion sleap/gui/widgets/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@

import numpy as np
from time import perf_counter
from sleap.nn.config.training_job import TrainingJobConfig
import zmq
import jsonpickle
import logging
from typing import Optional
from PySide2 import QtCore, QtWidgets, QtGui
from PySide2.QtCharts import QtCharts
import attr


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -67,7 +69,11 @@ def unbind(self):
self.ctx.term()
self.ctx = None

def reset(self, what: str = ""):
def reset(
self,
what: str = "",
config: TrainingJobConfig = attr.ib(factory=TrainingJobConfig),
):
"""Reset all chart series.
Args:
Expand Down Expand Up @@ -221,16 +227,24 @@ def reset(self, what: str = ""):
wid.setLayout(layout)
self.setCentralWidget(wid)

self.config = config
self.X = []
self.Y = []
self.best_val_x = None
self.best_val_y = None

self.t0 = None
self.mean_epoch_time_min = None
self.mean_epoch_time_sec = None
self.eta_ten_epochs_min = None

self.current_job_output_type = what
self.epoch = 0
self.epoch_size = 1
self.epochs_in_plateau = 0
self.last_epoch_val_loss = None
self.penultimate_epoch_val_loss = None
self.epoch_in_plateau_flag = False
self.last_batch_number = 0
self.is_running = False

Expand Down Expand Up @@ -417,6 +431,18 @@ def update_runtime(self):
title = f"Training Epoch <b>{self.epoch + 1}</b> / "
title += f"Runtime: <b>{int(dt_min):02}:{int(dt_sec):02}</b>"
if self.last_epoch_val_loss is not None:
if self.penultimate_epoch_val_loss is not None:
title += (
f"<br />Mean Time per Epoch: "
f"<b>{int(self.mean_epoch_time_min):02}:{int(self.mean_epoch_time_sec):02}</b> / "
f"ETA Next 10 Epochs: <b>{int(self.eta_ten_epochs_min)} min</b>"
)
if self.epoch_in_plateau_flag:
title += (
f"<br />Epochs in Plateau: "
f"<b>{self.epochs_in_plateau} / "
f"{self.config.optimization.early_stopping.plateau_patience}</b>"
)
title += (
f"<br />Last Epoch Validation Loss: "
f"<b>{self.last_epoch_val_loss:.3e}</b>"
Expand Down Expand Up @@ -492,12 +518,37 @@ def check_messages(
"epoch_loss",
)
if "val_loss" in msg["logs"].keys():
# update variables and add points to plot
self.penultimate_epoch_val_loss = self.last_epoch_val_loss
self.last_epoch_val_loss = msg["logs"]["val_loss"]
self.add_datapoint(
(self.epoch + 1) * self.epoch_size,
msg["logs"]["val_loss"],
"val_loss",
)
# calculate timing and flags at new epoch
if self.penultimate_epoch_val_loss is not None:
mean_epoch_time = (perf_counter() - self.t0) / (
self.epoch + 1
)
self.mean_epoch_time_min, self.mean_epoch_time_sec = divmod(
mean_epoch_time, 60
)
self.eta_ten_epochs_min = (mean_epoch_time * 10) // 60

val_loss_delta = (
self.penultimate_epoch_val_loss
- self.last_epoch_val_loss
)
self.epoch_in_plateau_flag = (
val_loss_delta
< self.config.optimization.early_stopping.plateau_min_delta
) or (self.best_val_y < self.last_epoch_val_loss)
self.epochs_in_plateau = (
self.epochs_in_plateau + 1
if self.epoch_in_plateau_flag
else 0
)
self.on_epoch.emit()
elif msg["event"] == "batch_end":
self.last_batch_number = msg["batch"]
Expand Down
2 changes: 1 addition & 1 deletion sleap/nn/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __del__(self):
self.context.term()

def on_batch_end(self, batch, logs=None):
""" Called at the end of a training batch. """
"""Called at the end of a training batch."""
if self.socket.poll(self.timeout, zmq.POLLIN):
msg = jsonpickle.decode(self.socket.recv_string())
logger.info(f"Received control message: {msg}")
Expand Down

0 comments on commit bf00897

Please sign in to comment.