Skip to content

Commit

Permalink
fix windows tests
Browse files Browse the repository at this point in the history
Signed-off-by: Jericho Tolentino <[email protected]>
  • Loading branch information
jericht committed May 28, 2024
1 parent 6ff4dab commit 302f267
Show file tree
Hide file tree
Showing 10 changed files with 292 additions and 167 deletions.
6 changes: 6 additions & 0 deletions src/openjd/adaptor_runtime/_background/backend_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import signal
from pathlib import Path
from threading import Thread, Event
import traceback
from types import FrameType
from typing import Callable, List, Optional, Union

Expand Down Expand Up @@ -126,6 +127,11 @@ def run(self, *, on_connection_file_written: List[Callable[[], None]] | None = N
_logger.info("Shutting down server...")
shutdown_event.set()
raise
except Exception as e:
_logger.critical(f"Unexpected error occurred when writing to connection file: {e}")
_logger.critical(traceback.format_exc())
_logger.info("Shutting down server")
shutdown_event.set()
else:
if on_connection_file_written:
callbacks = list(on_connection_file_written)
Expand Down
84 changes: 65 additions & 19 deletions src/openjd/adaptor_runtime/_background/frontend_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from threading import Event
from types import FrameType
from types import ModuleType
from typing import Optional, Dict
from typing import Optional, Callable, Dict

from .._osname import OSName
from ..process._logging import _ADAPTOR_OUTPUT_LEVEL
Expand Down Expand Up @@ -169,7 +169,7 @@ def init(

# Wait for backend process to create connection file
try:
_wait_for_file(str(connection_file_path), timeout_s=5)
_wait_for_connection_file(str(connection_file_path), max_retries=5, interval_s=1)
except TimeoutError:
_logger.error(
"Backend process failed to write connection file in time at: "
Expand Down Expand Up @@ -414,38 +414,84 @@ def _sigint_handler(self, signum: int, frame: Optional[FrameType]) -> None:
self.cancel()


def _wait_for_file(filepath: str, timeout_s: float, interval_s: float = 1) -> None:
def _wait_for_connection_file(
filepath: str, max_retries: int, interval_s: float = 1
) -> ConnectionSettings:
"""
Waits for a file at the specified path to exist and to be openable.
Waits for a connection file at the specified path to exist, be openable, and have connection settings.
Args:
filepath (str): The file path to check.
timeout_s (float): The max duration to wait before timing out, in seconds.
max_retries (int): The max number of retries before timing out.
interval_s (float, optional): The interval between checks, in seconds. Default is 0.01s.
Raises:
TimeoutError: Raised when the file does not exist after timeout_s seconds.
"""
wait_for(
description=f"File '{filepath}' to exist",
predicate=lambda: os.path.exists(filepath),
interval_s=interval_s,
max_retries=max_retries,
)

def _wait():
if time.time() - start < timeout_s:
time.sleep(interval_s)
else:
raise TimeoutError(f"Timed out after {timeout_s}s waiting for file at {filepath}")

start = time.time()
while not os.path.exists(filepath):
_wait()
# Wait before opening to give the backend time to open it first
time.sleep(interval_s)

while True:
# Wait before opening to give the backend time to open it first
_wait()
def file_is_openable() -> bool:
try:
open(filepath, mode="r").close()
break
except IOError:
# File is not available yet
pass
return False
else:
return True

wait_for(
description=f"File '{filepath}' to be openable",
predicate=file_is_openable,
interval_s=interval_s,
max_retries=max_retries,
)

def connection_file_loadable() -> bool:
try:
ConnectionSettingsFileLoader(Path(filepath)).load()
except Exception:
return False
else:
return True

wait_for(
description=f"File '{filepath}' to have valid ConnectionSettings",
predicate=connection_file_loadable,
interval_s=interval_s,
max_retries=max_retries,
)

return ConnectionSettingsFileLoader(Path(filepath)).load()


def wait_for(
*,
description: str,
predicate: Callable[[], bool],
interval_s: float,
max_retries: int | None = None,
) -> None:
if max_retries is not None:
assert max_retries >= 0, "max_retries must be a non-negative integer"
assert interval_s > 0, "interval_s must be a positive number"

_logger.info(f"Waiting for {description}")
retry_count = 0
while not predicate():
if max_retries is not None and retry_count >= max_retries:
raise TimeoutError(f"Timed out waiting for {description}")

_logger.info(f"Retrying in {interval_s}s...")
retry_count += 1
time.sleep(interval_s)


class AdaptorFailedException(Exception):
Expand Down
4 changes: 2 additions & 2 deletions src/openjd/adaptor_runtime/_background/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@ def load(self) -> ConnectionSettings:
with open(self.file_path) as conn_file:
loaded_settings = json.load(conn_file)
except OSError as e:
errmsg = f"Failed to open connection file: {e}"
errmsg = f"Failed to open connection file '{self.file_path}': {e}"
_logger.error(errmsg)
raise ConnectionSettingsLoadingError(errmsg) from e
except json.JSONDecodeError as e:
errmsg = f"Failed to decode connection file: {e}"
errmsg = f"Failed to decode connection file '{self.file_path}': {e}"
_logger.error(errmsg)
raise ConnectionSettingsLoadingError(errmsg) from e
return DataclassMapper(ConnectionSettings).map(loaded_settings)
Expand Down
48 changes: 37 additions & 11 deletions src/openjd/adaptor_runtime/_http/sockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,25 +116,18 @@ def gen_socket_path(dir: str, base_name: str):
else:
socket_dir = os.path.realpath(base_dir)

# Check that the sticky bit is set if the dir is world writable
socket_dir_stat = os.stat(socket_dir)
if socket_dir_stat.st_mode & stat.S_IWOTH and not socket_dir_stat.st_mode & stat.S_ISVTX:
raise NoSocketPathFoundException(
f"Cannot use directory {socket_dir} because it is world writable and does not "
"have the sticky bit (restricted deletion flag) set"
)

if namespace:
socket_dir = os.path.join(socket_dir, namespace)

mkdir(socket_dir)

socket_path = gen_socket_path(socket_dir, base_socket_name)
try:
self.verify_socket_path(socket_path)
except NonvalidSocketPathException as e:
raise NoSocketPathFoundException(
f"Socket path '{socket_path}' failed verification: {e}"
) from e
mkdir(socket_dir)

return socket_path

Expand All @@ -150,7 +143,37 @@ def verify_socket_path(self, path: str) -> None: # pragma: no cover
pass


class LinuxSocketPaths(SocketPaths):
class WindowsSocketPaths(SocketPaths):
"""
Specialization for verifying socket paths on Windows systems.
"""

def verify_socket_path(self, path: str) -> None:
# TODO: Verify Windows permissions of parent directories are least privileged
pass


class UnixSocketPaths(SocketPaths):
"""
Specialization for verifying socket paths on Unix systems.
"""

def verify_socket_path(self, path: str) -> None:
# Walk up directories and check that the sticky bit is set if the dir is world writable
prev_path = path
curr_path = os.path.dirname(path)
while prev_path != curr_path and len(curr_path) > 0:
path_stat = os.stat(curr_path)
if path_stat.st_mode & stat.S_IWOTH and not path_stat.st_mode & stat.S_ISVTX:
raise NoSocketPathFoundException(
f"Cannot use directory {curr_path} because it is world writable and does not "
"have the sticky bit (restricted deletion flag) set"
)
prev_path = curr_path
curr_path = os.path.dirname(curr_path)


class LinuxSocketPaths(UnixSocketPaths):
"""
Specialization for socket paths in Linux systems.
"""
Expand All @@ -161,6 +184,7 @@ class LinuxSocketPaths(SocketPaths):
_socket_name_max_length = 108 - 1

def verify_socket_path(self, path: str) -> None:
super().verify_socket_path(path)
path_length = len(path.encode("utf-8"))
if path_length > self._socket_name_max_length:
raise NonvalidSocketPathException(
Expand All @@ -170,7 +194,7 @@ def verify_socket_path(self, path: str) -> None:
)


class MacOSSocketPaths(SocketPaths):
class MacOSSocketPaths(UnixSocketPaths):
"""
Specialization for socket paths in macOS systems.
"""
Expand All @@ -181,6 +205,7 @@ class MacOSSocketPaths(SocketPaths):
_socket_name_max_length = 104 - 1

def verify_socket_path(self, path: str) -> None:
super().verify_socket_path(path)
path_length = len(path.encode("utf-8"))
if path_length > self._socket_name_max_length:
raise NonvalidSocketPathException(
Expand All @@ -193,6 +218,7 @@ def verify_socket_path(self, path: str) -> None:
_os_map: dict[str, type[SocketPaths]] = {
OSName.LINUX: LinuxSocketPaths,
OSName.MACOS: MacOSSocketPaths,
OSName.WINDOWS: WindowsSocketPaths,
}


Expand Down
6 changes: 3 additions & 3 deletions src/openjd/adaptor_runtime/_utils/_secure_open.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def get_file_owner_in_windows(filepath: "StrOrBytesPath") -> str: # pragma: is-
Returns:
str: A string in the format 'DOMAIN\\Username' representing the file's owner.
"""
sd = win32security.GetFileSecurity(filepath, win32security.OWNER_SECURITY_INFORMATION)
sd = win32security.GetFileSecurity(str(filepath), win32security.OWNER_SECURITY_INFORMATION)
owner_sid = sd.GetSecurityDescriptorOwner()
name, domain, _ = win32security.LookupAccountSid(None, owner_sid)
return f"{domain}\\{name}"
Expand All @@ -108,13 +108,13 @@ def set_file_permissions_in_windows(filepath: "StrOrBytesPath") -> None: # prag
dacl.AddAccessAllowedAce(win32security.ACL_REVISION, win32con.DELETE, user_sid)

# Apply the DACL to the file
sd = win32security.GetFileSecurity(filepath, win32security.DACL_SECURITY_INFORMATION)
sd = win32security.GetFileSecurity(str(filepath), win32security.DACL_SECURITY_INFORMATION)
sd.SetSecurityDescriptorDacl(
1, # A flag that indicates the presence of a DACL in the security descriptor.
dacl, # An ACL structure that specifies the DACL for the security descriptor.
0, # Don't retrieve the default DACL
)
win32security.SetFileSecurity(filepath, win32security.DACL_SECURITY_INFORMATION, sd)
win32security.SetFileSecurity(str(filepath), win32security.DACL_SECURITY_INFORMATION, sd)


def _get_flags_from_mode_str(open_mode: str) -> int:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@
FrontendRunner,
HTTPError,
)
from openjd.adaptor_runtime._background.loaders import ConnectionSettingsFileLoader
from openjd.adaptor_runtime._background.loaders import (
ConnectionSettingsLoadingError,
ConnectionSettingsFileLoader,
)
from openjd.adaptor_runtime._osname import OSName

mod_path = (Path(__file__).parent.parent).resolve()
Expand Down Expand Up @@ -77,7 +80,6 @@ def initialized_setup(
adaptor_module=sys.modules[AdaptorExample.__module__],
connection_file_path=connection_file_path,
)
conn_settings = ConnectionSettingsFileLoader(connection_file_path).load()

match = re.search("Started backend process. PID: ([0-9]+)", caplog.text)
assert match is not None
Expand All @@ -96,9 +98,16 @@ def initialized_setup(
# Once all handles are closed, the system automatically cleans up the named pipe.
if OSName.is_posix():
try:
os.remove(conn_settings.socket)
except FileNotFoundError:
pass # Already deleted
conn_settings = ConnectionSettingsFileLoader(connection_file_path).load()
except ConnectionSettingsLoadingError as e:
print(
f"Failed to load connection settings, socket file cleanup will be skipped: {e}"
)
else:
try:
os.remove(conn_settings.socket)
except FileNotFoundError:
pass # Already deleted

def test_init(
self,
Expand Down
Loading

0 comments on commit 302f267

Please sign in to comment.