Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix port already in use when running inference #2064

Open
wants to merge 13 commits into
base: develop
Choose a base branch
from
27 changes: 27 additions & 0 deletions sleap/gui/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Generic module containing utilities used for the GUI."""

import zmq
import time
from typing import Optional


Expand All @@ -12,6 +13,7 @@ def is_port_free(port: int, zmq_context: Optional[zmq.Context] = None) -> bool:
try:
socket.bind(address)
socket.unbind(address)
time.sleep(0.1)
7174Andy marked this conversation as resolved.
Show resolved Hide resolved
7174Andy marked this conversation as resolved.
Show resolved Hide resolved
return True
except zmq.error.ZMQError:
return False
Expand All @@ -26,3 +28,28 @@ def select_zmq_port(zmq_context: Optional[zmq.Context] = None) -> int:
port = socket.bind_to_random_port("tcp://127.0.0.1")
socket.close()
return port


def find_free_port(port: int, zmq_context: zmq.Context):
"""Find free port to bind to.

Args:
port: The port to start searching from.
zmq_context: The ZMQ context to use.

Returns:
The free port.
"""
attempts = 0
max_attempts = 10
while not is_port_free(port=port, zmq_context=zmq_context):
if attempts >= max_attempts:
raise RuntimeError(
f"Could not find free port to display training progress after "
f"{max_attempts} attempts. Please check your network settings "
"or use the CLI `sleap-train` command."
)
port = select_zmq_port(zmq_context=zmq_context)
attempts += 1

Comment on lines +43 to +54
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Add input validation and improve error handling

The function should validate the port number and provide more specific error messages about TIME_WAIT states.

+    if not isinstance(port, int) or port < 1 or port > 65535:
+        raise ValueError(f"Invalid port number: {port}. Must be between 1 and 65535.")
+
     attempts = 0
     max_attempts = 10
     while not is_port_free(port=port, zmq_context=zmq_context):
         if attempts >= max_attempts:
             raise RuntimeError(
-                f"Could not find free port to display training progress after "
-                f"{max_attempts} attempts. Please check your network settings "
-                "or use the CLI `sleap-train` command."
+                f"Could not find free port after {max_attempts} attempts. "
+                "This might be due to ports in TIME_WAIT state. "
+                "Please wait a few minutes and try again, or use a different "
+                "port range. Alternatively, use the CLI `sleap-train` command."
             )
         port = select_zmq_port(zmq_context=zmq_context)
         attempts += 1
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
attempts = 0
max_attempts = 10
while not is_port_free(port=port, zmq_context=zmq_context):
if attempts >= max_attempts:
raise RuntimeError(
f"Could not find free port to display training progress after "
f"{max_attempts} attempts. Please check your network settings "
"or use the CLI `sleap-train` command."
)
port = select_zmq_port(zmq_context=zmq_context)
attempts += 1
if not isinstance(port, int) or port < 1 or port > 65535:
raise ValueError(f"Invalid port number: {port}. Must be between 1 and 65535.")
attempts = 0
max_attempts = 10
while not is_port_free(port=port, zmq_context=zmq_context):
if attempts >= max_attempts:
raise RuntimeError(
f"Could not find free port after {max_attempts} attempts. "
"This might be due to ports in TIME_WAIT state. "
"Please wait a few minutes and try again, or use a different "
"port range. Alternatively, use the CLI `sleap-train` command."
)
port = select_zmq_port(zmq_context=zmq_context)
attempts += 1

return port
26 changes: 1 addition & 25 deletions sleap/gui/widgets/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import matplotlib.transforms as mtransforms
from qtpy import QtCore, QtWidgets

from sleap.gui.utils import is_port_free, select_zmq_port
from sleap.gui.utils import find_free_port
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Consider adding retry logic for port binding.

While using find_free_port helps, there is still a potential race condition between finding a free port and binding to it: the port is released again once the free-port check completes, so another process could claim it before self.sub.bind(...) runs.

Consider implementing a retry mechanism:

 def _setup_zmq(self, zmq_context: Optional[zmq.Context] = None):
     # ... existing setup code ...
     
     # Find a free port and bind to it.
     self.zmq_ports["publish_port"] = find_free_port(
         port=self.zmq_ports["publish_port"], zmq_context=self.ctx
     )
     publish_address = f"tcp://127.0.0.1:{self.zmq_ports['publish_port']}"
-    self.sub.bind(publish_address)
+    max_retries = 3
+    for attempt in range(max_retries):
+        try:
+            self.sub.bind(publish_address)
+            break
+        except zmq.error.ZMQError as e:
+            if attempt == max_retries - 1:
+                raise RuntimeError(f"Failed to bind to {publish_address} after {max_retries} attempts: {e}")
+            self.zmq_ports["publish_port"] = find_free_port(
+                port=self.zmq_ports["publish_port"], zmq_context=self.ctx
+            )
+            publish_address = f"tcp://127.0.0.1:{self.zmq_ports['publish_port']}"

Also applies to: 1093-1100

from sleap.gui.widgets.mpl import MplCanvas
from sleap.nn.config.training_job import TrainingJobConfig

Expand Down Expand Up @@ -788,30 +788,6 @@ def _setup_zmq(self, zmq_context: Optional[zmq.Context] = None):
self.sub = self.ctx.socket(zmq.SUB)
self.sub.subscribe("")

def find_free_port(port: int, zmq_context: zmq.Context) -> int:
    """Return a port that is free to bind to, preferring the requested one.

    The requested port is checked first with `is_port_free`. While the current
    candidate is not free, a replacement candidate is obtained from
    `select_zmq_port` and checked in turn.

    Args:
        port: The preferred port. It is returned unchanged if it is free.
        zmq_context: The ZMQ context used both to probe candidate ports and to
            select replacement ports.

    Returns:
        The first candidate port that `is_port_free` reports as free.

    Raises:
        RuntimeError: If neither the requested port nor any of the
            `max_attempts` replacement ports is reported as free.
    """
    max_attempts = 10
    attempts = 0
    while not is_port_free(port=port, zmq_context=zmq_context):
        if attempts >= max_attempts:
            raise RuntimeError(
                f"Could not find free port to display training progress after "
                f"{max_attempts} attempts. Please check your network settings "
                "or use the CLI `sleap-train` command."
            )
        # Select with the caller-supplied context so that the `zmq_context`
        # argument is honored for selection as well as for probing.
        port = select_zmq_port(zmq_context=zmq_context)
        attempts += 1

    return port

# Find a free port and bind to it.
self.zmq_ports["publish_port"] = find_free_port(
port=self.zmq_ports["publish_port"], zmq_context=self.ctx
Expand Down
Loading