Skip to content

Commit

Permalink
fix: handle unix socket name collisions
Browse files Browse the repository at this point in the history
Signed-off-by: Jericho Tolentino <[email protected]>
  • Loading branch information
jericht committed Apr 30, 2024
1 parent d4a224f commit 77cce05
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 42 deletions.
39 changes: 29 additions & 10 deletions src/openjd/adaptor_runtime/_http/sockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,22 @@ def get_process_socket_path(self, namespace: str | None = None, *, create_dir: b
len(socket_name) <= _PID_MAX_LENGTH
), f"PID too long. Only PIDs up to {_PID_MAX_LENGTH} digits are supported."

return os.path.join(self.get_socket_dir(namespace, create=create_dir), socket_name)

def get_socket_dir(self, namespace: str | None = None, *, create: bool = False) -> str:
return self.get_socket_path(socket_name, namespace, create_dir=create_dir)

def get_socket_path(
self,
base_socket_name: str,
namespace: str | None = None,
*,
create_dir: bool = False,
) -> str:
"""
Gets the base directory for sockets used in Adaptor IPC
Args:
base_socket_name (str): The name of the socket
namespace (Optional[str]): The optional namespace (subdirectory) where the sockets go
create (bool): Whether to create the directory or not. Default is false.
create_dir (bool): Whether to create the directory or not. Default is false.
Raises:
NonvalidSocketPathException: Raised if the user has configured a socket base directory
Expand All @@ -77,11 +84,19 @@ def get_socket_dir(self, namespace: str | None = None, *, create: bool = False)
not be raised if the user has configured a socket base directory.
"""

def create_dir(path: str) -> str:
if create:
def mkdir(path: str) -> str:
if create_dir:
os.makedirs(path, mode=0o700, exist_ok=True)
return path

def gen_socket_path(dir: str, base_name: str):
i = 0
name = base_name
while os.path.exists(os.path.join(dir, name)):
i = i + 1
name = f"{base_name}_{i}"
return os.path.join(dir, name)

rel_path = os.path.join(".openjd", "adaptors", "sockets")
if namespace:
rel_path = os.path.join(rel_path, namespace)
Expand All @@ -91,18 +106,21 @@ def create_dir(path: str) -> str:
# First try home directory
home_dir = os.path.expanduser("~")
socket_dir = os.path.join(home_dir, rel_path)
socket_path = gen_socket_path(socket_dir, base_socket_name)
try:
self.verify_socket_path(socket_dir)
self.verify_socket_path(socket_path)
except NonvalidSocketPathException as e:
reasons.append(f"Cannot create sockets directory in the home directory because: {e}")
else:
return create_dir(socket_dir)
mkdir(socket_dir)
return socket_path

# Last resort is the temp directory
temp_dir = tempfile.gettempdir()
socket_dir = os.path.join(temp_dir, rel_path)
socket_path = gen_socket_path(socket_dir, base_socket_name)
try:
self.verify_socket_path(socket_dir)
self.verify_socket_path(socket_path)
except NonvalidSocketPathException as e:
reasons.append(f"Cannot create sockets directory in the temp directory because: {e}")
else:
Expand All @@ -113,7 +131,8 @@ def create_dir(path: str) -> str:
"sticky bit (restricted deletion flag) set"
)
else:
return create_dir(socket_dir)
mkdir(socket_dir)
return socket_path

raise NoSocketPathFoundException(
"Failed to find a suitable base directory to create sockets in for the following "
Expand Down
71 changes: 39 additions & 32 deletions test/openjd/adaptor_runtime/unit/http/test_sockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,40 +29,21 @@ class TestGetProcessSocketPath:
Tests for SocketDirectories.get_process_socket_path()
"""

@pytest.fixture
def socket_dir(self) -> str:
return "/path/to/socket/dir"

@pytest.fixture(autouse=True)
def mock_socket_dir(self, socket_dir: str) -> Generator[MagicMock, None, None]:
with patch.object(SocketDirectories, "get_socket_dir") as mock:
mock.return_value = socket_dir
yield mock

@pytest.mark.parametrize(
argnames=["create_dir"],
argvalues=[[True], [False]],
ids=["creates dir", "does not create dir"],
)
@patch.object(sockets.os, "getpid", return_value=1234)
def test_gets_path(
self,
mock_getpid: MagicMock,
socket_dir: str,
mock_socket_dir: MagicMock,
create_dir: bool,
) -> None:
# GIVEN
namespace = "my-namespace"
subject = SocketDirectoriesStub()

# WHEN
result = subject.get_process_socket_path(namespace, create_dir=create_dir)
result = subject.get_process_socket_path(namespace)

# THEN
assert result == os.path.join(socket_dir, str(mock_getpid.return_value))
assert result.endswith(os.path.join(namespace, str(mock_getpid.return_value)))
mock_getpid.assert_called_once()
mock_socket_dir.assert_called_once_with(namespace, create=create_dir)

@patch.object(sockets.os, "getpid", return_value="a" * (sockets._PID_MAX_LENGTH + 1))
def test_asserts_max_pid_length(self, mock_getpid: MagicMock):
Expand All @@ -79,11 +60,17 @@ def test_asserts_max_pid_length(self, mock_getpid: MagicMock):
)
mock_getpid.assert_called_once()

class TestGetSocketDir:
class TestGetSocketPath:
"""
Tests for SocketDirectories.get_socket_dir()
Tests for SocketDirectories.get_socket_path()
"""

@pytest.fixture(autouse=True)
def mock_exists(self) -> Generator[MagicMock, None, None]:
with patch.object(sockets.os.path, "exists") as mock:
mock.return_value = False
yield mock

@pytest.fixture(autouse=True)
def mock_makedirs(self) -> Generator[MagicMock, None, None]:
with patch.object(sockets.os, "makedirs") as mock:
Expand Down Expand Up @@ -116,7 +103,7 @@ def test_gets_home_dir(
subject = SocketDirectoriesStub()

# WHEN
result = subject.get_socket_dir()
result = subject.get_socket_path("sock")

# THEN
mock_expanduser.assert_called_once_with("~")
Expand All @@ -138,7 +125,7 @@ def test_gets_temp_dir(
subject = SocketDirectoriesStub()

# WHEN
result = subject.get_socket_dir()
result = subject.get_socket_path("sock")

# THEN
mock_gettempdir.assert_called_once()
Expand All @@ -160,11 +147,13 @@ def test_create_dir(self, mock_makedirs: MagicMock, create: bool) -> None:
subject = SocketDirectoriesStub()

# WHEN
result = subject.get_socket_dir(create=create)
result = subject.get_socket_path("sock", create_dir=create)

# THEN
if create:
mock_makedirs.assert_called_once_with(result, mode=0o700, exist_ok=True)
mock_makedirs.assert_called_once_with(
os.path.dirname(result), mode=0o700, exist_ok=True
)
else:
mock_makedirs.assert_not_called()

Expand All @@ -174,20 +163,20 @@ def test_uses_namespace(self) -> None:
subject = SocketDirectoriesStub()

# WHEN
result = subject.get_socket_dir(namespace)
result = subject.get_socket_path("sock", namespace)

# THEN
assert result.endswith(namespace)
assert os.path.dirname(result).endswith(namespace)

@patch.object(SocketDirectoriesStub, "verify_socket_path")
def test_raises_when_no_valid_dir_found(self, mock_verify_socket_path: MagicMock) -> None:
def test_raises_when_no_valid_path_found(self, mock_verify_socket_path: MagicMock) -> None:
# GIVEN
mock_verify_socket_path.side_effect = NonvalidSocketPathException()
subject = SocketDirectoriesStub()

# WHEN
with pytest.raises(NoSocketPathFoundException) as raised_exc:
subject.get_socket_dir()
subject.get_socket_path("sock")

# THEN
assert raised_exc.match(
Expand All @@ -211,7 +200,7 @@ def test_raises_when_no_tmpdir_sticky_bit(

# WHEN
with pytest.raises(NoSocketPathFoundException) as raised_exc:
subject.get_socket_dir()
subject.get_socket_path("sock")

# THEN
assert raised_exc.match(
Expand All @@ -221,6 +210,24 @@ def test_raises_when_no_tmpdir_sticky_bit(
)
)

@patch.object(sockets.os.path, "exists")
def test_handles_socket_name_collisions(
self,
mock_exists: MagicMock,
) -> None:
# GIVEN
sock_name = "sock"
existing_sock_names = [sock_name, f"{sock_name}_1", f"{sock_name}_2"]
mock_exists.side_effect = ([True] * len(existing_sock_names)) + [False]
subject = SocketDirectoriesStub()

# WHEN
result = subject.get_socket_path(sock_name)

# THEN
assert result.endswith(f"{sock_name}_3")
mock_exists.call_count == len(existing_sock_names) + 1


class TestLinuxSocketDirectories:
@pytest.mark.parametrize(
Expand Down

0 comments on commit 77cce05

Please sign in to comment.