From ab4f2a304e7780c13a6fef8ba17725013b75056c Mon Sep 17 00:00:00 2001 From: Talmo Pereira Date: Wed, 30 Mar 2022 15:01:31 -0700 Subject: [PATCH] Training monitor enhancements (#691) * Training monitor enhancements - Cleaned up imports - Docstrings - Now update based on time, not epochs - Added markers for epoch-level losses - Added best validation loss marker and text - Reduced minimum possible y-axis value when log scaling - Marker colors, alpha, sizes and line widths adjusted * Move training monitor to gui submodule * add metrics to training monitor title * add mean time per epoch * add ETA to finish next 10 epochs * add plateau patience fraction (when in plateau) * update dev_requirements to install version of click that does not break black * Add code coverage * add coverage for all lines within LossViewer.update_runtime() Co-authored-by: roomrys <38435167+roomrys@users.noreply.github.com> --- .github/workflows/ci.yml | 3 +- dev_requirements.txt | 3 +- sleap/gui/learning/runners.py | 4 +- sleap/{nn => gui/widgets}/monitor.py | 364 ++++++++++++++++++--------- sleap/nn/callbacks.py | 2 +- tests/gui/test_monitor.py | 30 ++- 6 files changed, 278 insertions(+), 128 deletions(-) rename sleap/{nn => gui/widgets}/monitor.py (50%) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2b272438d..603d79b50 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -40,7 +40,8 @@ jobs: python-version: 3.7 - name: Install Dependencies run: | - pip install black==20.8b1 + pip install click==8.0.4 + pip install black==21.6b0 - name: Run Black run: | black --check sleap tests diff --git a/dev_requirements.txt b/dev_requirements.txt index e0b85d4de..d230f9492 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -18,4 +18,5 @@ twine==3.3.0 PyGithub jupyterlab jedi==0.17.2 -ipykernel \ No newline at end of file +ipykernel +click==8.0.4 \ No newline at end of file diff --git a/sleap/gui/learning/runners.py b/sleap/gui/learning/runners.py index 5cd6b28e8..052d53b2a 100644 --- a/sleap/gui/learning/runners.py +++ b/sleap/gui/learning/runners.py @@ -556,7 +556,7 @@ def run_gui_training( trained_job_paths = dict() if gui: - from sleap.nn.monitor import LossViewer + from sleap.gui.widgets.monitor import LossViewer from sleap.gui.widgets.imagedir import QtImageDirectoryWidget # open training monitor window @@ -603,7 +603,7 @@ def run_gui_training( if gui: print("Resetting monitor window.") - win.reset(what=str(model_type)) + win.reset(what=str(model_type), config=job) win.setWindowTitle(f"Training Model - {str(model_type)}") win.set_message(f"Preparing to run training...") if save_viz: diff --git a/sleap/nn/monitor.py b/sleap/gui/widgets/monitor.py similarity index 50% rename from sleap/nn/monitor.py rename to sleap/gui/widgets/monitor.py index 0e60469df..cab998b84 100644 --- a/sleap/nn/monitor.py +++ b/sleap/gui/widgets/monitor.py @@ -1,14 +1,16 @@ """GUI for monitoring training progress interactively.""" -from collections import deque import numpy as np -from time import time, sleep +from time import perf_counter +from sleap.nn.config.training_job import TrainingJobConfig import zmq import jsonpickle import logging from typing import Optional +from PySide2 import QtCore, QtWidgets, QtGui +from PySide2.QtCharts import QtCharts +import attr -from PySide2 import QtCore, QtWidgets, QtGui, QtCharts logger = logging.getLogger(__name__) @@ -24,17 +26,19 @@ def __init__( show_controller=True, parent=None, ): - super(LossViewer, self).__init__(parent) + super().__init__(parent) self.show_controller = show_controller self.stop_button = None self.cancel_button = None self.canceled = False - self.redraw_batch_interval = 40 self.batches_to_show = -1 # -1 to show all self.ignore_outliers = False self.log_scale = True + self.message_poll_time_ms = 20 # ms + self.redraw_batch_time_ms = 500 # ms + self.last_redraw_batch = None self.reset() self.setup_zmq(zmq_context) @@ -43,86 +47,130 @@ def __del__(self): self.unbind() def close(self): + """Disconnect from ZMQ ports and close the window.""" self.unbind() - super(LossViewer, self).close() + super().close() def unbind(self): - # close the zmq socket + """Disconnect from all ZMQ sockets.""" if self.sub is not None: self.sub.unbind(self.sub.LAST_ENDPOINT) self.sub.close() self.sub = None + if self.zmq_ctrl is not None: url = self.zmq_ctrl.LAST_ENDPOINT self.zmq_ctrl.unbind(url) self.zmq_ctrl.close() self.zmq_ctrl = None - # if we started out own zmq context, terminate it + + # If we started out own zmq context, terminate it. if not self.ctx_given and self.ctx is not None: self.ctx.term() self.ctx = None - def reset(self, what=""): - self.chart = QtCharts.QtCharts.QChart() + def reset( + self, + what: str = "", + config: TrainingJobConfig = attr.ib(factory=TrainingJobConfig), + ): + """Reset all chart series. + + Args: + what: String identifier indicating which job type the current run + corresponds to. + """ + self.chart = QtCharts.QChart() self.series = dict() - self.color = dict() - self.series["batch"] = QtCharts.QtCharts.QScatterSeries() - self.series["epoch_loss"] = QtCharts.QtCharts.QLineSeries() - self.series["val_loss"] = QtCharts.QtCharts.QLineSeries() + COLOR_TRAIN = (18, 158, 220) + COLOR_VAL = (248, 167, 52) + COLOR_BEST_VAL = (151, 204, 89) + self.series["batch"] = QtCharts.QScatterSeries() self.series["batch"].setName("Batch Training Loss") - self.series["epoch_loss"].setName("Epoch Training Loss") - self.series["val_loss"].setName("Epoch Validation Loss") - - self.color["batch"] = QtGui.QColor("blue") - self.color["epoch_loss"] = QtGui.QColor("green") - self.color["val_loss"] = QtGui.QColor("red") - - for s in self.series: - self.series[s].pen().setColor(self.color[s]) + self.series["batch"].setColor(QtGui.QColor(*COLOR_TRAIN, 48)) self.series["batch"].setMarkerSize(8.0) self.series["batch"].setBorderColor(QtGui.QColor(255, 255, 255, 25)) - self.chart.addSeries(self.series["batch"]) + + self.series["epoch_loss"] = QtCharts.QLineSeries() + self.series["epoch_loss"].setName("Epoch Training Loss") + self.series["epoch_loss"].setColor(QtGui.QColor(*COLOR_TRAIN, 255)) + pen = self.series["epoch_loss"].pen() + pen.setWidth(4) + self.series["epoch_loss"].setPen(pen) self.chart.addSeries(self.series["epoch_loss"]) + + self.series["epoch_loss_scatter"] = QtCharts.QScatterSeries() + self.series["epoch_loss_scatter"].setColor(QtGui.QColor(*COLOR_TRAIN, 255)) + self.series["epoch_loss_scatter"].setMarkerSize(12.0) + self.series["epoch_loss_scatter"].setBorderColor( + QtGui.QColor(255, 255, 255, 25) + ) + self.chart.addSeries(self.series["epoch_loss_scatter"]) + + self.series["val_loss"] = QtCharts.QLineSeries() + self.series["val_loss"].setName("Epoch Validation Loss") + self.series["val_loss"].setColor(QtGui.QColor(*COLOR_VAL, 255)) + pen = self.series["val_loss"].pen() + pen.setWidth(4) + self.series["val_loss"].setPen(pen) self.chart.addSeries(self.series["val_loss"]) - axisX = QtCharts.QtCharts.QValueAxis() + self.series["val_loss_scatter"] = QtCharts.QScatterSeries() + self.series["val_loss_scatter"].setColor(QtGui.QColor(*COLOR_VAL, 255)) + self.series["val_loss_scatter"].setMarkerSize(12.0) + self.series["val_loss_scatter"].setBorderColor(QtGui.QColor(255, 255, 255, 25)) + self.chart.addSeries(self.series["val_loss_scatter"]) + + self.series["val_loss_best"] = QtCharts.QScatterSeries() + self.series["val_loss_best"].setName("Best Validation Loss") + self.series["val_loss_best"].setColor(QtGui.QColor(*COLOR_BEST_VAL, 255)) + self.series["val_loss_best"].setMarkerSize(12.0) + self.series["val_loss_best"].setBorderColor(QtGui.QColor(32, 32, 32, 25)) + self.chart.addSeries(self.series["val_loss_best"]) + + axisX = QtCharts.QValueAxis() axisX.setLabelFormat("%d") axisX.setTitleText("Batches") self.chart.addAxis(axisX, QtCore.Qt.AlignBottom) - # create the different Y axes that can be used + # Create the different Y axes that can be used. self.axisY = dict() - self.axisY["log"] = QtCharts.QtCharts.QLogValueAxis() + self.axisY["log"] = QtCharts.QLogValueAxis() self.axisY["log"].setBase(10) - self.axisY["linear"] = QtCharts.QtCharts.QValueAxis() + self.axisY["linear"] = QtCharts.QValueAxis() - # settings that apply to all Y axes + # Apply settings that apply to all Y axes. for axisY in self.axisY.values(): axisY.setLabelFormat("%f") axisY.setLabelsVisible(True) axisY.setMinorTickCount(1) axisY.setTitleText("Loss") - # use the default Y axis + # Use the default Y axis. axisY = self.axisY["log"] if self.log_scale else self.axisY["linear"] + # Add axes to chart and series. self.chart.addAxis(axisY, QtCore.Qt.AlignLeft) - for series in self.chart.series(): series.attachAxis(axisX) series.attachAxis(axisY) - # self.chart.legend().hide() + # Setup legend. self.chart.legend().setVisible(True) self.chart.legend().setAlignment(QtCore.Qt.AlignTop) + self.chart.legend().setMarkerShape(QtCharts.QLegend.MarkerShapeCircle) - self.chartView = QtCharts.QtCharts.QChartView(self.chart) + # Hide scatters for epoch and val loss from legend. + for s in ("epoch_loss_scatter", "val_loss_scatter"): + self.chart.legend().markers(self.series[s])[0].setVisible(False) + + self.chartView = QtCharts.QChartView(self.chart) self.chartView.setRenderHint(QtGui.QPainter.Antialiasing) layout = QtWidgets.QVBoxLayout() layout.addWidget(self.chartView) @@ -132,33 +180,33 @@ def reset(self, what=""): field = QtWidgets.QCheckBox("Log Scale") field.setChecked(self.log_scale) - field.stateChanged.connect(lambda x: self.toggle("log_scale")) + field.stateChanged.connect(self.toggle_log_scale) control_layout.addWidget(field) field = QtWidgets.QCheckBox("Ignore Outliers") field.setChecked(self.ignore_outliers) - field.stateChanged.connect(lambda x: self.toggle("ignore_outliers")) + field.stateChanged.connect(self.toggle_ignore_outliers) control_layout.addWidget(field) control_layout.addWidget(QtWidgets.QLabel("Batches to Show:")) - # add field for how many batches to show in chart + # Add field for how many batches to show in chart. field = QtWidgets.QComboBox() - # add options self.batch_options = "200,1000,5000,All".split(",") for opt in self.batch_options: field.addItem(opt) - # set field to currently set value cur_opt_str = ( "All" if self.batches_to_show < 0 else str(self.batches_to_show) ) if cur_opt_str in self.batch_options: field.setCurrentText(cur_opt_str) - # connection action for when user selects another option + + # Set connection action for when user selects another option. field.currentIndexChanged.connect( lambda x: self.set_batches_to_show(self.batch_options[x]) ) - # store field as property and add to layout + + # Store field as property and add to layout. self.batches_to_show_field = field control_layout.addWidget(self.batches_to_show_field) @@ -179,38 +227,54 @@ def reset(self, what=""): wid.setLayout(layout) self.setCentralWidget(wid) + self.config = config self.X = [] self.Y = [] + self.best_val_x = None + self.best_val_y = None self.t0 = None + self.mean_epoch_time_min = None + self.mean_epoch_time_sec = None + self.eta_ten_epochs_min = None + self.current_job_output_type = what self.epoch = 0 self.epoch_size = 1 + self.epochs_in_plateau = 0 self.last_epoch_val_loss = None + self.penultimate_epoch_val_loss = None + self.epoch_in_plateau_flag = False self.last_batch_number = 0 self.is_running = False - def toggle(self, what): - if what == "log_scale": - self.log_scale = not self.log_scale - self.update_y_axis() - elif what == "ignore_outliers": - self.ignore_outliers = not self.ignore_outliers - elif what == "entire_history": - if self.batches_to_show > 0: - self.batches_to_show = -1 - else: - self.batches_to_show = 200 + def toggle_ignore_outliers(self): + """Toggles whether to ignore outliers in chart scaling.""" + self.ignore_outliers = not self.ignore_outliers + + def toggle_log_scale(self): + """Toggle whether to use log-scaled y-axis.""" + self.log_scale = not self.log_scale + self.update_y_axis() - def set_batches_to_show(self, val): - if val.isdigit(): - self.batches_to_show = int(val) + def set_batches_to_show(self, batches: str): + """Set the number of batches to show on the x-axis. + + Args: + batches: Number of batches as a string. If numeric, this will be converted + to an integer. If non-numeric string (e.g., "All"), then all batches + will be shown. + """ + if batches.isdigit(): + self.batches_to_show = int(batches) else: self.batches_to_show = -1 def update_y_axis(self): + """Update the y-axis when scale changes.""" to = "log" if self.log_scale else "linear" - # remove other axes + + # Remove other axes. for name, axisY in self.axisY.items(): if name != to: if axisY in self.chart.axes(): @@ -218,16 +282,23 @@ def update_y_axis(self): for series in self.chart.series(): if axisY in series.attachedAxes(): series.detachAxis(axisY) - # add axis + + # Add axis. axisY = self.axisY[to] self.chart.addAxis(axisY, QtCore.Qt.AlignLeft) for series in self.chart.series(): series.attachAxis(axisY) - def setup_zmq(self, zmq_context: Optional[zmq.Context]): + def setup_zmq(self, zmq_context: Optional[zmq.Context] = None): + """Connect to ZMQ ports that listen to commands and updates. - # Keep track of whether we're using an existing context (which we won't - # close when done) or are creating our own (which we should close). + Args: + zmq_context: The `zmq.Context` object to use for connections. A new one is + created if not specified and will be closed when the monitor exits. If + an existing one is provided, it will NOT be closed. + """ + # Keep track of whether we're using an existing context (which we won't close + # when done) or are creating our own (which we should close). self.ctx_given = zmq_context is not None self.ctx = zmq.Context() if zmq_context is None else zmq_context @@ -242,10 +313,10 @@ def setup_zmq(self, zmq_context: Optional[zmq.Context]): self.zmq_ctrl = self.ctx.socket(zmq.PUB) self.zmq_ctrl.bind("tcp://127.0.0.1:9000") - # Set timer to poll for messages every 20 milliseconds + # Set timer to poll for messages. self.timer = QtCore.QTimer() self.timer.timeout.connect(self.check_messages) - self.timer.start(20) + self.timer.start(self.message_poll_time_ms) def cancel(self): """Set the cancel flag.""" @@ -255,39 +326,42 @@ def cancel(self): self.cancel_button.setEnabled(False) def stop(self): - """Action to stop training.""" - + """Send command to stop training.""" if self.zmq_ctrl is not None: - # send command to stop training + # Send command to stop training. logger.info("Sending command to stop training.") self.zmq_ctrl.send_string(jsonpickle.encode(dict(command="stop"))) - # Disable the button + # Disable the button to prevent double messages. if self.stop_button is not None: self.stop_button.setText("Stopping...") self.stop_button.setEnabled(False) - def add_datapoint(self, x, y, which="batch"): - """ - Adds data point to graph. + def add_datapoint(self, x: int, y: float, which: str): + """Add a data point to graph. Args: - x: typically the batch number (out of all epochs, not just current) - y: typically the loss value - which: type of data point we're adding, possible values are - * batch (loss for batch) - * epoch_loss (loss for entire epoch) - * val_loss (validation loss for for epoch) + x: The batch number (out of all epochs, not just current), or epoch. + y: The loss value. + which: Type of data point we're adding. Possible values are: + * "batch" (loss for the batch) + * "epoch_loss" (loss for the entire epoch) + * "val_loss" (validation loss for the epoch) """ - - # Keep track of all batch points if which == "batch": self.X.append(x) self.Y.append(y) - # Redraw batch at intervals (faster than plotting each) - if x % self.redraw_batch_interval == 0: + # Redraw batch at intervals (faster than plotting every batch). + draw_batch = False + if self.last_redraw_batch is None: + draw_batch = True + else: + dt = perf_counter() - self.last_redraw_batch + draw_batch = (dt * 1000) >= self.redraw_batch_time_ms + if draw_batch: + self.last_redraw_batch = perf_counter() if self.batches_to_show < 0 or len(self.X) < self.batches_to_show: xs, ys = self.X, self.Y else: @@ -320,41 +394,80 @@ def add_datapoint(self, x, y, which="batch"): high = max(ys) + dy if self.log_scale: - low = max(low, 1e-5) # for log scale, low cannot be 0 + low = max(low, 1e-8) # for log scale, low cannot be 0 self.chart.axisY().setRange(low, high) else: - self.series[which].append(x, y) + if which == "epoch_loss": + self.series["epoch_loss"].append(x, y) + self.series["epoch_loss_scatter"].append(x, y) + elif which == "val_loss": + self.series["val_loss"].append(x, y) + self.series["val_loss_scatter"].append(x, y) + if self.best_val_y is None or y < self.best_val_y: + self.best_val_x = x + self.best_val_y = y + self.series["val_loss_best"].replace([QtCore.QPointF(x, y)]) + + def set_start_time(self, t0: float): + """Mark the start flag and time of the run. - def set_start_time(self, t0): + Args: + t0: Start time in seconds. + """ self.t0 = t0 self.is_running = True def set_end(self): + """Mark the end of the run.""" self.is_running = False def update_runtime(self): - if self.is_timer_running(): - dt = time() - self.t0 + """Update the title text with the current running time.""" + if self.is_timer_running: + dt = perf_counter() - self.t0 dt_min, dt_sec = divmod(dt, 60) - title = f"Training Epoch {self.epoch+1} / " + title = f"Training Epoch {self.epoch + 1} / " title += f"Runtime: {int(dt_min):02}:{int(dt_sec):02}" if self.last_epoch_val_loss is not None: - title += f"
Last Epoch Validation Loss: {self.last_epoch_val_loss:.3e}" + if self.penultimate_epoch_val_loss is not None: + title += ( + f"
Mean Time per Epoch: " + f"{int(self.mean_epoch_time_min):02}:{int(self.mean_epoch_time_sec):02} / " + f"ETA Next 10 Epochs: {int(self.eta_ten_epochs_min)} min" + ) + if self.epoch_in_plateau_flag: + title += ( + f"
Epochs in Plateau: " + f"{self.epochs_in_plateau} / " + f"{self.config.optimization.early_stopping.plateau_patience}" + ) + title += ( + f"
Last Epoch Validation Loss: " + f"{self.last_epoch_val_loss:.3e}" + ) + if self.best_val_x is not None: + best_epoch = (self.best_val_x // self.epoch_size) + 1 + title += ( + f"
Best Epoch Validation Loss: " + f"{self.best_val_y:.3e} (epoch {best_epoch})" + ) self.set_message(title) - def is_timer_running(self): + @property + def is_timer_running(self) -> bool: + """Return True if the timer has started.""" return self.t0 is not None and self.is_running - def set_message(self, text): + def set_message(self, text: str): + """Set the chart title text.""" self.chart.setTitle(text) def check_messages( - self, timeout=10, times_to_check: int = 10, do_update: bool = True + self, timeout: int = 10, times_to_check: int = 10, do_update: bool = True ): - """ - Polls for ZMQ messages and adds any received data to graph. + """Poll for ZMQ messages and adds any received data to graph. The message is a dictionary encoded as JSON: * event - options include @@ -371,22 +484,27 @@ def check_messages( * loss * val_loss + Args: + timeout: Message polling timeout in milliseconds. This is how often we will + check for new command messages. + times_to_check: How many times to check for new messages in the queue before + going back to polling with a timeout. Helps to clear backlogs of + messages if necessary. + do_update: If True (the default), update the GUI text. """ if self.sub and self.sub.poll(timeout, zmq.POLLIN): msg = jsonpickle.decode(self.sub.recv_string()) - # logger.info(msg) - if msg["event"] == "train_begin": - self.set_start_time(time()) + self.set_start_time(perf_counter()) self.current_job_output_type = msg["what"] - # make sure message matches current training job + # Make sure message matches current training job. if msg.get("what", "") == self.current_job_output_type: - if not self.is_timer_running(): - # We must have missed the train_begin message, so start timer now - self.set_start_time(time()) + if not self.is_timer_running: + # We must have missed the train_begin message, so start timer now. + self.set_start_time(perf_counter()) if msg["event"] == "train_end": self.set_end() @@ -400,47 +518,51 @@ def check_messages( "epoch_loss", ) if "val_loss" in msg["logs"].keys(): + # update variables and add points to plot + self.penultimate_epoch_val_loss = self.last_epoch_val_loss self.last_epoch_val_loss = msg["logs"]["val_loss"] self.add_datapoint( (self.epoch + 1) * self.epoch_size, msg["logs"]["val_loss"], "val_loss", ) + # calculate timing and flags at new epoch + if self.penultimate_epoch_val_loss is not None: + mean_epoch_time = (perf_counter() - self.t0) / ( + self.epoch + 1 + ) + self.mean_epoch_time_min, self.mean_epoch_time_sec = divmod( + mean_epoch_time, 60 + ) + self.eta_ten_epochs_min = (mean_epoch_time * 10) // 60 + + val_loss_delta = ( + self.penultimate_epoch_val_loss + - self.last_epoch_val_loss + ) + self.epoch_in_plateau_flag = ( + val_loss_delta + < self.config.optimization.early_stopping.plateau_min_delta + ) or (self.best_val_y < self.last_epoch_val_loss) + self.epochs_in_plateau = ( + self.epochs_in_plateau + 1 + if self.epoch_in_plateau_flag + else 0 + ) self.on_epoch.emit() elif msg["event"] == "batch_end": self.last_batch_number = msg["batch"] self.add_datapoint( (self.epoch * self.epoch_size) + msg["batch"], msg["logs"]["loss"], + "batch", ) - # Check for messages again (up to times_to_check times) - if times_to_check: + # Check for messages again (up to times_to_check times). + if times_to_check > 0: self.check_messages( timeout=timeout, times_to_check=times_to_check - 1, do_update=False ) if do_update: self.update_runtime() - - -if __name__ == "__main__": - app = QtWidgets.QApplication([]) - win = LossViewer() - win.show() - - def test_point(x=[0]): - x[0] += 1 - i = x[0] + 1 - win.add_datapoint(i, i % 30 + 1) - - t = QtCore.QTimer() - t.timeout.connect(test_point) - t.start(20) - - win.set_message("Waiting for 3 seconds...") - t2 = QtCore.QTimer() - t2.timeout.connect(lambda: win.set_message("Running demo...")) - t2.start(3000) - - app.exec_() diff --git a/sleap/nn/callbacks.py b/sleap/nn/callbacks.py index 40384e04b..ed420408e 100644 --- a/sleap/nn/callbacks.py +++ b/sleap/nn/callbacks.py @@ -41,7 +41,7 @@ def __del__(self): self.context.term() def on_batch_end(self, batch, logs=None): - """ Called at the end of a training batch. """ + """Called at the end of a training batch.""" if self.socket.poll(self.timeout, zmq.POLLIN): msg = jsonpickle.decode(self.socket.recv_string()) logger.info(f"Received control message: {msg}") diff --git a/tests/gui/test_monitor.py b/tests/gui/test_monitor.py index 72673ed11..51af0ca92 100644 --- a/tests/gui/test_monitor.py +++ b/tests/gui/test_monitor.py @@ -1,9 +1,35 @@ -from sleap.nn.monitor import LossViewer +from turtle import title +from sleap.gui.widgets.monitor import LossViewer +from sleap import TrainingJobConfig -def test_monitor_release(qtbot): +def test_monitor_release(qtbot, min_centroid_model_path): win = LossViewer() win.show() + + # Ensure win loads config correctly + config = TrainingJobConfig.load_json(min_centroid_model_path, False) + win.reset(what="Model Type", config=config) + assert win.config.optimization.early_stopping.plateau_patience == 10 + + # Ensure all lines of update_runtime() are run error-free + win.is_running = True + win.t0 = 0 + # Enter "last_epoch_val_loss is not None" conditional + win.last_epoch_val_loss = win.config.optimization.early_stopping.plateau_min_delta + # Enter "penultimate_epoch_val_loss is not None" conditional + win.penultimate_epoch_val_loss = win.last_epoch_val_loss + win.mean_epoch_time_min = 0 + win.mean_epoch_time_sec = 10 + win.eta_ten_epochs_min = 2 + # Enter "epoch_in_plateau_flag" conditional + win.epoch_in_plateau_flag = True + win.epochs_in_plateau = 1 + # Enter "bes_val_x" conditional + win.best_val_x = 0 + win.best_val_y = win.last_epoch_val_loss + win.update_runtime() + win.close() # Make sure the first monitor released its zmq socket