diff --git a/sleap/gui/widgets/monitor.py b/sleap/gui/widgets/monitor.py index 5b0ce1ae8..fff8a0327 100644 --- a/sleap/gui/widgets/monitor.py +++ b/sleap/gui/widgets/monitor.py @@ -619,33 +619,47 @@ def __init__( self.canvas = None self.reset() - self.setup_zmq(zmq_context) + self._setup_zmq(zmq_context) def __del__(self): - self.unbind() + self._unbind() - def close(self): - """Disconnect from ZMQ ports and close the window.""" - self.unbind() - super().close() + @property + def is_timer_running(self) -> bool: + """Return True if the timer has started.""" + return self.t0 is not None and self.is_running - def unbind(self): - """Disconnect from all ZMQ sockets.""" - if self.sub is not None: - self.sub.unbind(self.sub.LAST_ENDPOINT) - self.sub.close() - self.sub = None + @property + def log_scale(self): + """Returns True if the plot has a log scale for y-axis.""" - if self.zmq_ctrl is not None: - url = self.zmq_ctrl.LAST_ENDPOINT - self.zmq_ctrl.unbind(url) - self.zmq_ctrl.close() - self.zmq_ctrl = None + return self._log_scale - # If we started out own zmq context, terminate it. - if not self.ctx_given and self.ctx is not None: - self.ctx.term() - self.ctx = None + @log_scale.setter + def log_scale(self, val): + """Sets the scale of the y axis to log if True else linear.""" + + if isinstance(val, bool): + self._log_scale = val + + # Set the log scale on the canvas + self.canvas.log_scale = self._log_scale + + @property + def ignore_outliers(self): + """Returns True if the plot ignores outliers.""" + + return self._ignore_outliers + + @ignore_outliers.setter + def ignore_outliers(self, val): + """Sets whether to ignore outliers in the plot.""" + + if isinstance(val, bool): + self._ignore_outliers = val + + # Set the ignore_outliers on the canvas + self.canvas.ignore_outliers = self._ignore_outliers def reset( self, @@ -680,12 +694,12 @@ def reset( field = QtWidgets.QCheckBox("Log Scale") field.setChecked(self.log_scale) - field.stateChanged.connect(self.toggle_log_scale) + field.stateChanged.connect(self._toggle_log_scale) control_layout.addWidget(field) field = QtWidgets.QCheckBox("Ignore Outliers") field.setChecked(self.ignore_outliers) - field.stateChanged.connect(self.toggle_ignore_outliers) + field.stateChanged.connect(self._toggle_ignore_outliers) control_layout.addWidget(field) control_layout.addWidget(QtWidgets.QLabel("Batches to Show:")) @@ -703,7 +717,7 @@ def reset( # Set connection action for when user selects another option. field.currentIndexChanged.connect( - lambda x: self.set_batches_to_show(self.batch_options[x]) + lambda x: self._set_batches_to_show(self.batch_options[x]) ) # Store field as property and add to layout. @@ -713,10 +727,10 @@ def reset( control_layout.addStretch(1) self.stop_button = QtWidgets.QPushButton("Stop Early") - self.stop_button.clicked.connect(self.stop) + self.stop_button.clicked.connect(self._stop) control_layout.addWidget(self.stop_button) self.cancel_button = QtWidgets.QPushButton("Cancel Training") - self.cancel_button.clicked.connect(self.cancel) + self.cancel_button.clicked.connect(self._cancel) control_layout.addWidget(self.cancel_button) widget = QtWidgets.QWidget() @@ -748,62 +762,16 @@ def reset( self.last_batch_number = 0 self.is_running = False - @property - def log_scale(self): - """Returns True if the plot has a log scale for y-axis.""" - - return self._log_scale - - @log_scale.setter - def log_scale(self, val): - """Sets the scale of the y axis to log if True else linear.""" - - if isinstance(val, bool): - self._log_scale = val - - # Set the log scale on the canvas - self.canvas.log_scale = self._log_scale - - @property - def ignore_outliers(self): - """Returns True if the plot ignores outliers.""" - - return self._ignore_outliers - - @ignore_outliers.setter - def ignore_outliers(self, val): - """Sets whether to ignore outliers in the plot.""" - - if isinstance(val, bool): - self._ignore_outliers = val - - # Set the ignore_outliers on the canvas - self.canvas.ignore_outliers = self._ignore_outliers - - def toggle_ignore_outliers(self): - """Toggles whether to ignore outliers in chart scaling.""" - - self.ignore_outliers = not self.ignore_outliers - - def toggle_log_scale(self): - """Toggle whether to use log-scaled y-axis.""" - - self.log_scale = not self.log_scale - - def set_batches_to_show(self, batches: str): - """Set the number of batches to show on the x-axis. + def set_message(self, text: str): + """Set the chart title text.""" + self.canvas.set_title(text) - Args: - batches: Number of batches as a string. If numeric, this will be converted - to an integer. If non-numeric string (e.g., "All"), then all batches - will be shown. - """ - if batches.isdigit(): - self.batches_to_show = int(batches) - else: - self.batches_to_show = -1 + def close(self): + """Disconnect from ZMQ ports and close the window.""" + self._unbind() + super().close() - def setup_zmq(self, zmq_context: Optional[zmq.Context] = None): + def _setup_zmq(self, zmq_context: Optional[zmq.Context] = None): """Connect to ZMQ ports that listen to commands and updates. Args: @@ -865,124 +833,23 @@ def find_free_port(port: int, zmq_context: zmq.Context): # Set timer to poll for messages. self.timer = QtCore.QTimer() - self.timer.timeout.connect(self.check_messages) + self.timer.timeout.connect(self._check_messages) self.timer.start(self.message_poll_time_ms) - def cancel(self): - """Set the cancel flag.""" - self.canceled = True - if self.cancel_button is not None: - self.cancel_button.setText("Canceling...") - self.cancel_button.setEnabled(False) - - def stop(self): - """Send command to stop training.""" - if self.zmq_ctrl is not None: - # Send command to stop training. - logger.info("Sending command to stop training.") - self.zmq_ctrl.send_string(jsonpickle.encode(dict(command="stop"))) - - # Disable the button to prevent double messages. - if self.stop_button is not None: - self.stop_button.setText("Stopping...") - self.stop_button.setEnabled(False) - - def add_datapoint(self, x: int, y: float, which: str): - """Add a data point to graph. + def _set_batches_to_show(self, batches: str): + """Set the number of batches to show on the x-axis. Args: - x: The batch number (out of all epochs, not just current), or epoch. - y: The loss value. - which: Type of data point we're adding. Possible values are: - * "batch" (loss for the batch) - * "epoch_loss" (loss for the entire epoch) - * "val_loss" (validation loss for the epoch) + batches: Number of batches as a string. If numeric, this will be converted + to an integer. If non-numeric string (e.g., "All"), then all batches + will be shown. """ - if which == "batch": - self.X.append(x) - self.Y.append(y) - - # Redraw batch at intervals (faster than plotting every batch). - draw_batch = False - if self.last_redraw_batch is None: - draw_batch = True - else: - dt = perf_counter() - self.last_redraw_batch - draw_batch = (dt * 1000) >= self.redraw_batch_time_ms - - if draw_batch: - self.last_redraw_batch = perf_counter() - if self.batches_to_show < 0 or len(self.X) < self.batches_to_show: - xs, ys = self.X, self.Y - else: - xs, ys = ( - self.X[-self.batches_to_show :], - self.Y[-self.batches_to_show :], - ) - - # Set data, resize and redraw the plot - self._set_data_on_scatter(xs, ys, which) - self._resize_axes(xs, ys) - + if batches.isdigit(): + self.batches_to_show = int(batches) else: + self.batches_to_show = -1 - if which == "val_loss": - if self.best_val_y is None or y < self.best_val_y: - self.best_val_x = x - self.best_val_y = y - self._set_data_on_scatter([x], [y], "val_loss_best") - - # Add data and redraw the plot - self._add_data_to_plot(x, y, which) - self._redraw_plot() - - def _set_data_on_scatter(self, xs, ys, which): - """Add data to a scatter plot. - - Not to be used with line plots. - - Args: - xs: The x-coordinates of the data points. - ys: The y-coordinates of the data points. - which: The type of data point. Possible values are: - * "batch" - * "val_loss_best" - """ - - self.canvas.set_data_on_scatter(xs, ys, which) - - def _add_data_to_plot(self, x, y, which): - """Add data to a line plot. - - Not to be used with scatter plots. - - Args: - x: The x-coordinate of the data point. - y: The y-coordinate of the data point. - which: The type of data point. Possible values are: - * "epoch_loss" - * "val_loss" - """ - - self.canvas.add_data_to_plot(x, y, which) - - def _redraw_plot(self): - """Redraw the plot.""" - - self.canvas.redraw_plot() - - def _resize_axes(self, x, y): - """Resize axes to fit data. - - This is only called when plotting batches. - - Args: - x: The x-coordinates of the data points. - y: The y-coordinates of the data points. - """ - self.canvas.resize_axes(x, y) - - def set_start_time(self, t0: float): + def _set_start_time(self, t0: float): """Mark the start flag and time of the run. Args: @@ -991,11 +858,7 @@ def set_start_time(self, t0: float): self.t0 = t0 self.is_running = True - def set_end(self): - """Mark the end of the run.""" - self.is_running = False - - def update_runtime(self): + def _update_runtime(self): """Update the title text with the current running time.""" if self.is_timer_running: @@ -1019,16 +882,7 @@ def update_runtime(self): epoch_size=self.epoch_size, ) - @property - def is_timer_running(self) -> bool: - """Return True if the timer has started.""" - return self.t0 is not None and self.is_running - - def set_message(self, text: str): - """Set the chart title text.""" - self.canvas.set_title(text) - - def check_messages( + def _check_messages( self, timeout: int = 10, times_to_check: int = 10, do_update: bool = True ): """Poll for ZMQ messages and adds any received data to graph. @@ -1060,7 +914,7 @@ def check_messages( msg = jsonpickle.decode(self.sub.recv_string()) if msg["event"] == "train_begin": - self.set_start_time(perf_counter()) + self._set_start_time(perf_counter()) self.current_job_output_type = msg["what"] # Make sure message matches current training job. @@ -1068,15 +922,15 @@ def check_messages( if not self.is_timer_running: # We must have missed the train_begin message, so start timer now. - self.set_start_time(perf_counter()) + self._set_start_time(perf_counter()) if msg["event"] == "train_end": - self.set_end() + self._set_end() elif msg["event"] == "epoch_begin": self.epoch = msg["epoch"] elif msg["event"] == "epoch_end": self.epoch_size = max(self.epoch_size, self.last_batch_number + 1) - self.add_datapoint( + self._add_datapoint( (self.epoch + 1) * self.epoch_size, msg["logs"]["loss"], "epoch_loss", @@ -1085,7 +939,7 @@ def check_messages( # update variables and add points to plot self.penultimate_epoch_val_loss = self.last_epoch_val_loss self.last_epoch_val_loss = msg["logs"]["val_loss"] - self.add_datapoint( + self._add_datapoint( (self.epoch + 1) * self.epoch_size, msg["logs"]["val_loss"], "val_loss", @@ -1116,7 +970,7 @@ def check_messages( self.on_epoch.emit() elif msg["event"] == "batch_end": self.last_batch_number = msg["batch"] - self.add_datapoint( + self._add_datapoint( (self.epoch * self.epoch_size) + msg["batch"], msg["logs"]["loss"], "batch", @@ -1124,9 +978,155 @@ def check_messages( # Check for messages again (up to times_to_check times). if times_to_check > 0: - self.check_messages( + self._check_messages( timeout=timeout, times_to_check=times_to_check - 1, do_update=False ) if do_update: - self.update_runtime() + self._update_runtime() + + def _add_datapoint(self, x: int, y: float, which: str): + """Add a data point to graph. + + Args: + x: The batch number (out of all epochs, not just current), or epoch. + y: The loss value. + which: Type of data point we're adding. Possible values are: + * "batch" (loss for the batch) + * "epoch_loss" (loss for the entire epoch) + * "val_loss" (validation loss for the epoch) + """ + if which == "batch": + self.X.append(x) + self.Y.append(y) + + # Redraw batch at intervals (faster than plotting every batch). + draw_batch = False + if self.last_redraw_batch is None: + draw_batch = True + else: + dt = perf_counter() - self.last_redraw_batch + draw_batch = (dt * 1000) >= self.redraw_batch_time_ms + + if draw_batch: + self.last_redraw_batch = perf_counter() + if self.batches_to_show < 0 or len(self.X) < self.batches_to_show: + xs, ys = self.X, self.Y + else: + xs, ys = ( + self.X[-self.batches_to_show :], + self.Y[-self.batches_to_show :], + ) + + # Set data, resize and redraw the plot + self._set_data_on_scatter(xs, ys, which) + self._resize_axes(xs, ys) + + else: + + if which == "val_loss": + if self.best_val_y is None or y < self.best_val_y: + self.best_val_x = x + self.best_val_y = y + self._set_data_on_scatter([x], [y], "val_loss_best") + + # Add data and redraw the plot + self._add_data_to_plot(x, y, which) + self._redraw_plot() + + def _set_data_on_scatter(self, xs, ys, which): + """Add data to a scatter plot. + + Not to be used with line plots. + + Args: + xs: The x-coordinates of the data points. + ys: The y-coordinates of the data points. + which: The type of data point. Possible values are: + * "batch" + * "val_loss_best" + """ + + self.canvas.set_data_on_scatter(xs, ys, which) + + def _add_data_to_plot(self, x, y, which): + """Add data to a line plot. + + Not to be used with scatter plots. + + Args: + x: The x-coordinate of the data point. + y: The y-coordinate of the data point. + which: The type of data point. Possible values are: + * "epoch_loss" + * "val_loss" + """ + + self.canvas.add_data_to_plot(x, y, which) + + def _redraw_plot(self): + """Redraw the plot.""" + + self.canvas.redraw_plot() + + def _resize_axes(self, x, y): + """Resize axes to fit data. + + This is only called when plotting batches. + + Args: + x: The x-coordinates of the data points. + y: The y-coordinates of the data points. + """ + self.canvas.resize_axes(x, y) + + def _toggle_ignore_outliers(self): + """Toggles whether to ignore outliers in chart scaling.""" + + self.ignore_outliers = not self.ignore_outliers + + def _toggle_log_scale(self): + """Toggle whether to use log-scaled y-axis.""" + + self.log_scale = not self.log_scale + + def _stop(self): + """Send command to stop training.""" + if self.zmq_ctrl is not None: + # Send command to stop training. + logger.info("Sending command to stop training.") + self.zmq_ctrl.send_string(jsonpickle.encode(dict(command="stop"))) + + # Disable the button to prevent double messages. + if self.stop_button is not None: + self.stop_button.setText("Stopping...") + self.stop_button.setEnabled(False) + + def _cancel(self): + """Set the cancel flag.""" + self.canceled = True + if self.cancel_button is not None: + self.cancel_button.setText("Canceling...") + self.cancel_button.setEnabled(False) + + def _unbind(self): + """Disconnect from all ZMQ sockets.""" + if self.sub is not None: + self.sub.unbind(self.sub.LAST_ENDPOINT) + self.sub.close() + self.sub = None + + if self.zmq_ctrl is not None: + url = self.zmq_ctrl.LAST_ENDPOINT + self.zmq_ctrl.unbind(url) + self.zmq_ctrl.close() + self.zmq_ctrl = None + + # If we started out own zmq context, terminate it. + if not self.ctx_given and self.ctx is not None: + self.ctx.term() + self.ctx = None + + def _set_end(self): + """Mark the end of the run.""" + self.is_running = False diff --git a/tests/gui/test_monitor.py b/tests/gui/test_monitor.py index 7ea81d6dc..e0abea692 100644 --- a/tests/gui/test_monitor.py +++ b/tests/gui/test_monitor.py @@ -30,7 +30,7 @@ def test_monitor_release(qtbot, min_centroid_model_path): # Enter "bes_val_x" conditional win.best_val_x = 0 win.best_val_y = win.last_epoch_val_loss - win.update_runtime() + win._update_runtime() win.close()