diff --git a/src/openjd/adaptor_runtime/_http/sockets.py b/src/openjd/adaptor_runtime/_http/sockets.py index 58e3e34..f057f6d 100644 --- a/src/openjd/adaptor_runtime/_http/sockets.py +++ b/src/openjd/adaptor_runtime/_http/sockets.py @@ -60,15 +60,22 @@ def get_process_socket_path(self, namespace: str | None = None, *, create_dir: b len(socket_name) <= _PID_MAX_LENGTH ), f"PID too long. Only PIDs up to {_PID_MAX_LENGTH} digits are supported." - return os.path.join(self.get_socket_dir(namespace, create=create_dir), socket_name) - - def get_socket_dir(self, namespace: str | None = None, *, create: bool = False) -> str: + return self.get_socket_path(socket_name, namespace, create_dir=create_dir) + + def get_socket_path( + self, + base_socket_name: str, + namespace: str | None = None, + *, + create_dir: bool = False, + ) -> str: """ Gets the base directory for sockets used in Adaptor IPC Args: + base_socket_name (str): The name of the socket namespace (Optional[str]): The optional namespace (subdirectory) where the sockets go - create (bool): Whether to create the directory or not. Default is false. + create_dir (bool): Whether to create the directory or not. Default is false. Raises: NonvalidSocketPathException: Raised if the user has configured a socket base directory @@ -77,11 +84,19 @@ def get_socket_dir(self, namespace: str | None = None, *, create: bool = False) not be raised if the user has configured a socket base directory. """ - def create_dir(path: str) -> str: - if create: + def mkdir(path: str) -> str: + if create_dir: os.makedirs(path, mode=0o700, exist_ok=True) return path + def gen_socket_path(dir: str, base_name: str): + i = 0 + name = base_name + while os.path.exists(os.path.join(dir, name)): + i = i + 1 + name = f"{base_name}_{i}" + return os.path.join(dir, name) + rel_path = os.path.join(".openjd", "adaptors", "sockets") if namespace: rel_path = os.path.join(rel_path, namespace) @@ -91,18 +106,21 @@ def create_dir(path: str) -> str: # First try home directory home_dir = os.path.expanduser("~") socket_dir = os.path.join(home_dir, rel_path) + socket_path = gen_socket_path(socket_dir, base_socket_name) try: - self.verify_socket_path(socket_dir) + self.verify_socket_path(socket_path) except NonvalidSocketPathException as e: reasons.append(f"Cannot create sockets directory in the home directory because: {e}") else: - return create_dir(socket_dir) + mkdir(socket_dir) + return socket_path # Last resort is the temp directory temp_dir = tempfile.gettempdir() socket_dir = os.path.join(temp_dir, rel_path) + socket_path = gen_socket_path(socket_dir, base_socket_name) try: - self.verify_socket_path(socket_dir) + self.verify_socket_path(socket_path) except NonvalidSocketPathException as e: reasons.append(f"Cannot create sockets directory in the temp directory because: {e}") else: @@ -113,7 +131,8 @@ def create_dir(path: str) -> str: "sticky bit (restricted deletion flag) set" ) else: - return create_dir(socket_dir) + mkdir(socket_dir) + return socket_path raise NoSocketPathFoundException( "Failed to find a suitable base directory to create sockets in for the following " diff --git a/test/openjd/adaptor_runtime/unit/http/test_sockets.py b/test/openjd/adaptor_runtime/unit/http/test_sockets.py index f6e9c5f..c0cb336 100644 --- a/test/openjd/adaptor_runtime/unit/http/test_sockets.py +++ b/test/openjd/adaptor_runtime/unit/http/test_sockets.py @@ -29,40 +29,21 @@ class TestGetProcessSocketPath: Tests for SocketDirectories.get_process_socket_path() """ - @pytest.fixture - def socket_dir(self) -> str: - return "/path/to/socket/dir" - - @pytest.fixture(autouse=True) - def mock_socket_dir(self, socket_dir: str) -> Generator[MagicMock, None, None]: - with patch.object(SocketDirectories, "get_socket_dir") as mock: - mock.return_value = socket_dir - yield mock - - @pytest.mark.parametrize( - argnames=["create_dir"], - argvalues=[[True], [False]], - ids=["creates dir", "does not create dir"], - ) @patch.object(sockets.os, "getpid", return_value=1234) def test_gets_path( self, mock_getpid: MagicMock, - socket_dir: str, - mock_socket_dir: MagicMock, - create_dir: bool, ) -> None: # GIVEN namespace = "my-namespace" subject = SocketDirectoriesStub() # WHEN - result = subject.get_process_socket_path(namespace, create_dir=create_dir) + result = subject.get_process_socket_path(namespace) # THEN - assert result == os.path.join(socket_dir, str(mock_getpid.return_value)) + assert result.endswith(os.path.join(namespace, str(mock_getpid.return_value))) mock_getpid.assert_called_once() - mock_socket_dir.assert_called_once_with(namespace, create=create_dir) @patch.object(sockets.os, "getpid", return_value="a" * (sockets._PID_MAX_LENGTH + 1)) def test_asserts_max_pid_length(self, mock_getpid: MagicMock): @@ -79,11 +60,17 @@ def test_asserts_max_pid_length(self, mock_getpid: MagicMock): ) mock_getpid.assert_called_once() - class TestGetSocketDir: + class TestGetSocketPath: """ - Tests for SocketDirectories.get_socket_dir() + Tests for SocketDirectories.get_socket_path() """ + @pytest.fixture(autouse=True) + def mock_exists(self) -> Generator[MagicMock, None, None]: + with patch.object(sockets.os.path, "exists") as mock: + mock.return_value = False + yield mock + @pytest.fixture(autouse=True) def mock_makedirs(self) -> Generator[MagicMock, None, None]: with patch.object(sockets.os, "makedirs") as mock: @@ -116,7 +103,7 @@ def test_gets_home_dir( subject = SocketDirectoriesStub() # WHEN - result = subject.get_socket_dir() + result = subject.get_socket_path("sock") # THEN mock_expanduser.assert_called_once_with("~") @@ -138,7 +125,7 @@ def test_gets_temp_dir( subject = SocketDirectoriesStub() # WHEN - result = subject.get_socket_dir() + result = subject.get_socket_path("sock") # THEN mock_gettempdir.assert_called_once() @@ -160,11 +147,13 @@ def test_create_dir(self, mock_makedirs: MagicMock, create: bool) -> None: subject = SocketDirectoriesStub() # WHEN - result = subject.get_socket_dir(create=create) + result = subject.get_socket_path("sock", create_dir=create) # THEN if create: - mock_makedirs.assert_called_once_with(result, mode=0o700, exist_ok=True) + mock_makedirs.assert_called_once_with( + os.path.dirname(result), mode=0o700, exist_ok=True + ) else: mock_makedirs.assert_not_called() @@ -174,20 +163,20 @@ def test_uses_namespace(self) -> None: subject = SocketDirectoriesStub() # WHEN - result = subject.get_socket_dir(namespace) + result = subject.get_socket_path("sock", namespace) # THEN - assert result.endswith(namespace) + assert os.path.dirname(result).endswith(namespace) @patch.object(SocketDirectoriesStub, "verify_socket_path") - def test_raises_when_no_valid_dir_found(self, mock_verify_socket_path: MagicMock) -> None: + def test_raises_when_no_valid_path_found(self, mock_verify_socket_path: MagicMock) -> None: # GIVEN mock_verify_socket_path.side_effect = NonvalidSocketPathException() subject = SocketDirectoriesStub() # WHEN with pytest.raises(NoSocketPathFoundException) as raised_exc: - subject.get_socket_dir() + subject.get_socket_path("sock") # THEN assert raised_exc.match( @@ -211,7 +200,7 @@ def test_raises_when_no_tmpdir_sticky_bit( # WHEN with pytest.raises(NoSocketPathFoundException) as raised_exc: - subject.get_socket_dir() + subject.get_socket_path("sock") # THEN assert raised_exc.match( @@ -221,6 +210,24 @@ def test_raises_when_no_tmpdir_sticky_bit( ) ) + @patch.object(sockets.os.path, "exists") + def test_handles_socket_name_collisions( + self, + mock_exists: MagicMock, + ) -> None: + # GIVEN + sock_name = "sock" + existing_sock_names = [sock_name, f"{sock_name}_1", f"{sock_name}_2"] + mock_exists.side_effect = ([True] * len(existing_sock_names)) + [False] + subject = SocketDirectoriesStub() + + # WHEN + result = subject.get_socket_path(sock_name) + + # THEN + assert result.endswith(f"{sock_name}_3") + mock_exists.call_count == len(existing_sock_names) + 1 + class TestLinuxSocketDirectories: @pytest.mark.parametrize(