diff --git a/src/deadline_worker_agent/aws_credentials/aws_configs.py b/src/deadline_worker_agent/aws_credentials/aws_configs.py index a9f3cd5d..c4356265 100644 --- a/src/deadline_worker_agent/aws_credentials/aws_configs.py +++ b/src/deadline_worker_agent/aws_credentials/aws_configs.py @@ -3,7 +3,7 @@ from __future__ import annotations import stat - +import os import logging from abc import ABC, abstractmethod from configparser import ConfigParser @@ -11,6 +11,7 @@ from typing import Optional from openjd.sessions import PosixSessionUser, SessionUser from subprocess import run, DEVNULL, PIPE, STDOUT +from ..set_windows_permissions import set_user_restricted_path_permissions __all__ = [ "AWSConfig", @@ -28,8 +29,12 @@ def _run_cmd_as(*, user: PosixSessionUser, cmd: list[str]) -> None: def _setup_parent_dir(*, dir_path: Path, owner: SessionUser | None = None) -> None: if owner is None: - create_perms: int = stat.S_IRWXU - dir_path.mkdir(mode=create_perms, exist_ok=True) + if os.name == "posix": + create_perms: int = stat.S_IRWXU + dir_path.mkdir(mode=create_perms, exist_ok=True) + else: + dir_path.mkdir(exist_ok=True) + set_user_restricted_path_permissions(dir_path.name) else: assert isinstance(owner, PosixSessionUser) _run_cmd_as(user=owner, cmd=["mkdir", "-p", str(dir_path)]) diff --git a/src/deadline_worker_agent/aws_credentials/queue_boto3_session.py b/src/deadline_worker_agent/aws_credentials/queue_boto3_session.py index e80eb79d..17eef653 100644 --- a/src/deadline_worker_agent/aws_credentials/queue_boto3_session.py +++ b/src/deadline_worker_agent/aws_credentials/queue_boto3_session.py @@ -91,8 +91,6 @@ def __init__( interrupt_event: Event, worker_persistence_dir: Path, ) -> None: - if os.name != "posix": - raise NotImplementedError("Windows not supported.") super().__init__() self._deadline_client = deadline_client @@ -110,7 +108,11 @@ def __init__( self._credentials_filename = ( "aws_credentials" # note: .json extension added by JSONFileCache ) - self._credentials_process_script_path = self._credential_dir / "get_aws_credentials.sh" + + if os.name == "posix": + self._credentials_process_script_path = self._credential_dir / "get_aws_credentials.sh" + else: + self._credentials_process_script_path = self._credential_dir / "get_aws_credentials.cmd" self._aws_config = AWSConfig(self._os_user) self._aws_credentials = AWSCredentials(self._os_user) @@ -321,9 +323,14 @@ def _generate_credential_process_script(self) -> str: Generates the bash script which generates the credentials as JSON output on STDOUT. This script will be used by the installed credential process. """ - return ("#!/bin/bash\nset -eu\ncat {0}\n").format( - (self._credential_dir / self._credentials_filename).with_suffix(".json") - ) + if os.name == "posix": + return ("#!/bin/bash\nset -eu\ncat {0}\n").format( + (self._credential_dir / self._credentials_filename).with_suffix(".json") + ) + else: + return ('@echo off\ntype "{0}"\n').format( + (self._credential_dir / self._credentials_filename).with_suffix(".json") + ) def _uninstall_credential_process(self) -> None: """ diff --git a/src/deadline_worker_agent/aws_credentials/worker_boto3_session.py b/src/deadline_worker_agent/aws_credentials/worker_boto3_session.py index df452785..f1d20e57 100644 --- a/src/deadline_worker_agent/aws_credentials/worker_boto3_session.py +++ b/src/deadline_worker_agent/aws_credentials/worker_boto3_session.py @@ -1,7 +1,6 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. from __future__ import annotations -import os import logging from typing import Any, cast @@ -34,8 +33,6 @@ def __init__( config: Configuration, worker_id: str, ) -> None: - if os.name != "posix": - raise NotImplementedError("Windows not supported.") super().__init__() self._bootstrap_session = bootstrap_session diff --git a/src/deadline_worker_agent/scheduler/scheduler.py b/src/deadline_worker_agent/scheduler/scheduler.py index c5377599..cecffc85 100644 --- a/src/deadline_worker_agent/scheduler/scheduler.py +++ b/src/deadline_worker_agent/scheduler/scheduler.py @@ -54,7 +54,8 @@ from .session_queue import SessionActionQueue, SessionActionStatus from ..startup.config import ImpersonationOverrides from ..utils import MappingWithCallbacks - +from ..set_windows_permissions import set_user_restricted_path_permissions +import subprocess logger = LOGGER @@ -220,7 +221,7 @@ def run(self) -> None: The Worker begins by hydrating its assigned work using the UpdateWorkerSchedule API. The scheduler then enters a loop of processing assigned actions - creating and deleting - Worker sessions as required. If no actions are assigned, the Worke idles for 5 seconds. + Worker sessions as required. If no actions are assigned, the Worker idles for 5 seconds. If any action completes, finishes cancelation, or if the Worker is done idling, an UpdateWorkerSchedule API request is made with any relevant changes specified in the request. @@ -628,8 +629,12 @@ def _create_new_sessions( if self._worker_logs_dir: queue_log_dir = self._queue_log_dir_path(queue_id=session_spec["queueId"]) try: - queue_log_dir.mkdir(mode=stat.S_IRWXU, exist_ok=True) - except OSError: + if os.name == "posix": + queue_log_dir.mkdir(mode=stat.S_IRWXU, exist_ok=True) + else: + queue_log_dir.mkdir(exist_ok=True) + set_user_restricted_path_permissions(queue_log_dir.name) + except (OSError, subprocess.CalledProcessError): error_msg = ( f"Failed to create local session log directory on worker: {queue_log_dir}" ) diff --git a/src/deadline_worker_agent/sessions/job_entities/job_details.py b/src/deadline_worker_agent/sessions/job_entities/job_details.py index b1e3bbb0..8dd1f7df 100644 --- a/src/deadline_worker_agent/sessions/job_entities/job_details.py +++ b/src/deadline_worker_agent/sessions/job_entities/job_details.py @@ -106,7 +106,7 @@ def job_run_as_user_api_model_to_worker_agent( ) else: # TODO: windows support - raise NotImplementedError(f"{os.name} is not supported") + return None return job_run_as_user diff --git a/src/deadline_worker_agent/set_windows_permissions.py b/src/deadline_worker_agent/set_windows_permissions.py new file mode 100644 index 00000000..adbbdcfb --- /dev/null +++ b/src/deadline_worker_agent/set_windows_permissions.py @@ -0,0 +1,41 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +from typing import Optional +import subprocess +import getpass + + +def set_user_restricted_path_permissions(path: str, username: Optional[str] = None): + """ + Set permissions for a specified file or directory (and any child objects) + to give full control only to the specified user. + + Args: + path (str): The path of the file or directory for which permissions will be set. + username (str, optional): The username for whom permissions will be granted. If none is + provided the current username will be used. + + Example: + path = "C:\\example_directory_or_file" + username = "a_username" + set_user_restricted_path_permissions(path, username) + """ + + if not username: + username = getpass.getuser() + + subprocess.run( + [ + "icacls", + path, + # Remove any existing permissions + "/inheritance:r", + # OI - Contained objects will inherit + # CI - Sub-directories will inherit + # F - Full control + "/grant", + ("{0}:(OI)(CI)(F)").format(username), + "/T", # Apply recursively for directories + ], + check=True, + ) diff --git a/src/deadline_worker_agent/startup/config.py b/src/deadline_worker_agent/startup/config.py index 4b7c90d7..1757b0a9 100644 --- a/src/deadline_worker_agent/startup/config.py +++ b/src/deadline_worker_agent/startup/config.py @@ -2,6 +2,7 @@ from __future__ import annotations +import os import logging as _logging from dataclasses import dataclass from pathlib import Path @@ -131,7 +132,7 @@ def __init__( settings = WorkerSettings(**settings_kwargs) - if settings.posix_job_user is not None: + if os.name == "posix" and settings.posix_job_user is not None: user, group = self._get_user_and_group_from_posix_job_user(settings.posix_job_user) self.impersonation = ImpersonationOverrides( inactive=not settings.impersonation, diff --git a/src/deadline_worker_agent/startup/config_file.py b/src/deadline_worker_agent/startup/config_file.py index 99c2df69..6b0f6638 100644 --- a/src/deadline_worker_agent/startup/config_file.py +++ b/src/deadline_worker_agent/startup/config_file.py @@ -4,6 +4,7 @@ from pathlib import Path from typing import Any, Optional import sys +import os from pydantic import BaseModel, BaseSettings, Field @@ -20,6 +21,7 @@ DEFAULT_CONFIG_PATH: dict[str, Path] = { "darwin": Path("/etc/amazon/deadline/worker.toml"), "linux": Path("/etc/amazon/deadline/worker.toml"), + "win32": Path(os.path.expandvars(r"%PROGRAMDATA%/Amazon/Deadline/Config/worker.toml")), } diff --git a/src/deadline_worker_agent/startup/entrypoint.py b/src/deadline_worker_agent/startup/entrypoint.py index d0bcb79c..8fec46ed 100644 --- a/src/deadline_worker_agent/startup/entrypoint.py +++ b/src/deadline_worker_agent/startup/entrypoint.py @@ -50,9 +50,14 @@ def detect_system_capabilities() -> Capabilities: "linux": "linux", "windows": "windows", } + platform_machine = platform.machine().lower() + python_machine_to_openjd_cpu_arch = {"x86_64": "x86_64", "amd64": "x86_64"} if openjd_os_family := python_system_to_openjd_os_family.get(platform_system): attributes[AttributeCapabilityName("attr.worker.os.family")] = [openjd_os_family] - attributes[AttributeCapabilityName("attr.worker.cpu.arch")] = [platform.machine()] + if openjd_cpu_arch := python_machine_to_openjd_cpu_arch.get(platform_machine): + attributes[AttributeCapabilityName("attr.worker.cpu.arch")] = [openjd_cpu_arch] + else: + raise NotImplementedError(f"{platform_machine} not supported") amounts[AmountCapabilityName("amount.worker.vcpu")] = float(psutil.cpu_count()) amounts[AmountCapabilityName("amount.worker.memory")] = float(psutil.virtual_memory().total) / ( 1024.0**2 diff --git a/src/deadline_worker_agent/startup/settings.py b/src/deadline_worker_agent/startup/settings.py index 8032060e..d94d8e8e 100644 --- a/src/deadline_worker_agent/startup/settings.py +++ b/src/deadline_worker_agent/startup/settings.py @@ -11,14 +11,20 @@ from .capabilities import Capabilities from .config_file import ConfigFile +import os + # Default path for the worker's logs. DEFAULT_POSIX_WORKER_LOGS_DIR = Path("/var/log/amazon/deadline") +DEFAULT_WINDOWS_WORKER_LOGS_DIR = Path(os.path.expandvars(r"%PROGRAMDATA%/Amazon/Deadline/Logs")) # Default path for the worker persistence directory. # The persistence directory is expected to be located on a file-system that is local to the Worker # Node. The Worker's ID and credentials are persisted and these should not be accessible by other # Worker Nodes. DEFAULT_POSIX_WORKER_PERSISTENCE_DIR = Path("/var/lib/deadline") +DEFAULT_WINDOWS_WORKER_PERSISTENCE_DIR = Path( + os.path.expandvars(r"%PROGRAMDATA%/Amazon/Deadline/Cache") +) class WorkerSettings(BaseSettings): @@ -84,8 +90,14 @@ class WorkerSettings(BaseSettings): capabilities: Capabilities = Field( default_factory=lambda: Capabilities(amounts={}, attributes={}) ) - worker_logs_dir: Path = DEFAULT_POSIX_WORKER_LOGS_DIR - worker_persistence_dir: Path = DEFAULT_POSIX_WORKER_PERSISTENCE_DIR + worker_logs_dir: Path = ( + DEFAULT_WINDOWS_WORKER_LOGS_DIR if os.name == "nt" else DEFAULT_POSIX_WORKER_LOGS_DIR + ) + worker_persistence_dir: Path = ( + DEFAULT_WINDOWS_WORKER_PERSISTENCE_DIR + if os.name == "nt" + else DEFAULT_POSIX_WORKER_PERSISTENCE_DIR + ) local_session_logs: bool = True host_metrics_logging: bool = True host_metrics_logging_interval_seconds: float = 60 diff --git a/src/deadline_worker_agent/worker.py b/src/deadline_worker_agent/worker.py index 899f59b3..98c0feda 100644 --- a/src/deadline_worker_agent/worker.py +++ b/src/deadline_worker_agent/worker.py @@ -4,6 +4,7 @@ import json import signal +import os import sys import traceback from contextlib import nullcontext @@ -121,8 +122,10 @@ def __init__( signal.signal(signal.SIGTERM, self._signal_handler) signal.signal(signal.SIGINT, self._signal_handler) - # TODO: Remove this once WA is stable or put behind a debug flag - signal.signal(signal.SIGUSR1, self._output_thread_stacks) + + if os.name == "posix": + # TODO: Remove this once WA is stable or put behind a debug flag + signal.signal(signal.SIGUSR1, self._output_thread_stacks) # type: ignore def _signal_handler(self, signum: int, frame: FrameType | None = None) -> None: """ @@ -147,7 +150,7 @@ def _output_thread_stacks(self, signum: int, frame: FrameType | None = None) -> This signal is designated for application-defined behaviors. In our case, we want to output stack traces for all running threads. """ - if signum in (signal.SIGUSR1,): + if signum in (signal.SIGUSR1,): # type: ignore logger.info(f"Received signal {signum}. Initiating application shutdown.") # OUTPUT STACK TRACE FOR ALL THREADS print("\n*** STACKTRACE - START ***\n", file=sys.stderr) @@ -169,7 +172,7 @@ def id(self) -> str: @property def sessions(self) -> WorkerSessionCollection: - raise NotImplementedError("Worker.sessions property not implemeneted") + raise NotImplementedError("Worker.sessions property not implemented") def run(self) -> None: """Runs the main Worker loop for processing sessions.""" @@ -388,9 +391,10 @@ def _get_spot_instance_shutdown_action_timeout(self, *, imdsv2_token: str) -> ti ) return None logger.info(f"Spot {action} happening at {shutdown_time}") - # Spot gives the time in UTC with a trailing Z, but Python can't handle + # Spot gives the time in UTC with a trailing Z, but Prior to Python 3.11 Python can't handle # the Z so we strip it - shutdown_time = datetime.fromisoformat(shutdown_time[:-1]).astimezone(timezone.utc) + shutdown_time = datetime.fromisoformat(shutdown_time[:-1]) + shutdown_time = shutdown_time.replace(tzinfo=timezone.utc) current_time = datetime.now(timezone.utc) time_delta = shutdown_time - current_time time_delta_seconds = int(time_delta.total_seconds()) diff --git a/test/unit/aws_credentials/test_aws_configs.py b/test/unit/aws_credentials/test_aws_configs.py index 2323fa55..c6e98fcd 100644 --- a/test/unit/aws_credentials/test_aws_configs.py +++ b/test/unit/aws_credentials/test_aws_configs.py @@ -14,6 +14,7 @@ _setup_parent_dir, ) from openjd.sessions import PosixSessionUser, SessionUser +import os @pytest.fixture @@ -27,9 +28,13 @@ def mock_run_cmd_as() -> Generator[MagicMock, None, None]: yield mock_run_cmd_as -@pytest.fixture(params=(PosixSessionUser(user="some-user", group="some-group"), None)) -def os_user(request: pytest.FixtureRequest) -> Optional[SessionUser]: - return request.param +@pytest.fixture +def os_user() -> Optional[SessionUser]: + if os.name == "posix": + return PosixSessionUser(user="user", group="group") + else: + # TODO: Revisit when Windows impersonation is added + return None class TestSetupParentDir: @@ -39,6 +44,10 @@ class TestSetupParentDir: def dir_path(self) -> MagicMock: return MagicMock() + @pytest.fixture + def set_windows_permissions(self) -> MagicMock: + return MagicMock() + def test_creates_dir( self, dir_path: MagicMock, @@ -51,7 +60,12 @@ def test_creates_dir( assert isinstance(os_user, PosixSessionUser) or os_user is None # WHEN - _setup_parent_dir(dir_path=dir_path, owner=os_user) + with ( + patch.object( + aws_configs_mod, "set_user_restricted_path_permissions" + ) as mock_set_user_restricted_path_permissions, + ): + _setup_parent_dir(dir_path=dir_path, owner=os_user) # THEN if os_user: @@ -62,10 +76,16 @@ def test_creates_dir( ) mock_run_cmd_as.assert_any_call(user=os_user, cmd=["chmod", "770", str(dir_path)]) else: - mkdir.assert_called_once_with( - mode=0o700, - exist_ok=True, - ) + if os.name == "posix": + mkdir.assert_called_once_with( + mode=0o700, + exist_ok=True, + ) + else: + mkdir.assert_called_once_with( + exist_ok=True, + ) + mock_set_user_restricted_path_permissions.assert_called_once() def test_sets_group_ownership( self, @@ -383,7 +403,7 @@ def test_write( class TestAWSConfig(AWSConfigTestBase): """ - Test class derrived from AWSConfigTestBase for AWSConfig. + Test class derived from AWSConfigTestBase for AWSConfig. All tests are defined in the base class. This class defines the fixtures that feed into those tests. """ diff --git a/test/unit/aws_credentials/test_queue_boto3_session.py b/test/unit/aws_credentials/test_queue_boto3_session.py index 8748efad..42fc65ea 100644 --- a/test/unit/aws_credentials/test_queue_boto3_session.py +++ b/test/unit/aws_credentials/test_queue_boto3_session.py @@ -51,9 +51,13 @@ def deadline_client() -> MagicMock: return MagicMock() -@pytest.fixture(params=(PosixSessionUser(user="some-user", group="some-group"), None)) -def os_user(request: pytest.FixtureRequest) -> Optional[SessionUser]: - return request.param +@pytest.fixture +def os_user() -> Optional[SessionUser]: + if os.name == "posix": + return PosixSessionUser(user="user", group="group") + else: + # TODO: Revisit when Windows impersonation is added + return None class TestInit: @@ -614,9 +618,14 @@ def test_success( session._install_credential_process() # THEN - credentials_process_script_path = ( - Path(tmpdir) / "queues" / queue_id / "get_aws_credentials.sh" - ) + if os.name == "posix": + credentials_process_script_path = ( + Path(tmpdir) / "queues" / queue_id / "get_aws_credentials.sh" + ) + else: + credentials_process_script_path = ( + Path(tmpdir) / "queues" / queue_id / "get_aws_credentials.cmd" + ) mock_os_open.assert_called_once_with( path=str(credentials_process_script_path), flags=os.O_WRONLY | os.O_CREAT | os.O_TRUNC, diff --git a/test/unit/conftest.py b/test/unit/conftest.py index e5f34420..65ba98e2 100644 --- a/test/unit/conftest.py +++ b/test/unit/conftest.py @@ -51,16 +51,24 @@ def logs_client() -> MagicMock: return MagicMock() -@pytest.fixture(params=(PosixSessionUser(user="some-user", group="some-group"),)) -def posix_job_user(request: pytest.FixtureRequest) -> Optional[SessionUser]: - return request.param +@pytest.fixture() +def os_user() -> Optional[SessionUser]: + if os.name == "posix": + return PosixSessionUser(user="some-user", group="some-group") + else: + return None -@pytest.fixture(params=(False,)) +@pytest.fixture(params=[(os.name == "posix",)]) def impersonation( - request: pytest.FixtureRequest, posix_job_user: Optional[SessionUser] + request: pytest.FixtureRequest, os_user: Optional[SessionUser] ) -> ImpersonationOverrides: - return ImpersonationOverrides(inactive=request.param, posix_job_user=posix_job_user) + (posix_os,) = request.param + + if posix_os: + return ImpersonationOverrides(inactive=False, posix_job_user=os_user) + else: + return ImpersonationOverrides(inactive=True) @pytest.fixture @@ -242,7 +250,7 @@ def job_run_as_user() -> JobRunAsUser | None: """The OS user/group associated with the job's queue""" # TODO: windows support if os.name != "posix": - raise NotImplementedError(f"{os.name} is not supported") + return None return JobRunAsUser(posix=PosixSessionUser(user="job-user", group="job-user")) diff --git a/test/unit/install/test_install.py b/test/unit/install/test_install.py index e708e130..f0cd4dc0 100644 --- a/test/unit/install/test_install.py +++ b/test/unit/install/test_install.py @@ -5,7 +5,6 @@ from subprocess import CalledProcessError from typing import Generator, Optional from unittest.mock import MagicMock, patch -import os import sysconfig import pytest @@ -288,4 +287,4 @@ def test_unsupported_platform_raises(platform: str, capsys: pytest.CaptureFixtur assert raise_ctx.value.code == 1 capture = capsys.readouterr() - assert capture.out == f"ERROR: Unsupported platform {platform}{os.linesep}" + assert capture.out == f"ERROR: Unsupported platform {platform}\n" diff --git a/test/unit/log_sync/test_cloudwatch.py b/test/unit/log_sync/test_cloudwatch.py index 9740ed01..6290edb1 100644 --- a/test/unit/log_sync/test_cloudwatch.py +++ b/test/unit/log_sync/test_cloudwatch.py @@ -36,7 +36,7 @@ def mock_module_logger() -> Generator[MagicMock, None, None]: class TestCloudWatchLogEventBatch: @fixture(autouse=True) def now(self) -> datetime: - return datetime.fromtimestamp(123) + return datetime(2000, 1, 1) @fixture(autouse=True) def datetime_mock(self, now: datetime) -> Generator[MagicMock, None, None]: @@ -51,7 +51,7 @@ def datetime_mock(self, now: datetime) -> Generator[MagicMock, None, None]: @fixture def event(self, now: datetime) -> PartitionedCloudWatchLogEvent: return PartitionedCloudWatchLogEvent( - log_event=CloudWatchLogEvent(timestamp=int(now.timestamp()), message="abc"), + log_event=CloudWatchLogEvent(timestamp=int(now.timestamp() * 1000), message="abc"), size=len("abc".encode("utf-8")), ) @@ -168,10 +168,12 @@ def test_valid_log_event( datetime_mock: MagicMock, ): # GIVEN - now = datetime.fromtimestamp(1) + now = datetime(2000, 1, 1) datetime_mock.now.return_value = now event = PartitionedCloudWatchLogEvent( - log_event=CloudWatchLogEvent(message="abc", timestamp=int(now.timestamp())), + log_event=CloudWatchLogEvent( + message="abc", timestamp=(int(now.timestamp()) * 1000) + ), size=3, ) batch = CloudWatchLogEventBatch() diff --git a/test/unit/scheduler/test_scheduler.py b/test/unit/scheduler/test_scheduler.py index 9eda2926..03954e03 100644 --- a/test/unit/scheduler/test_scheduler.py +++ b/test/unit/scheduler/test_scheduler.py @@ -10,6 +10,7 @@ from openjd.sessions import ActionState, ActionStatus from botocore.exceptions import ClientError import pytest +import os from deadline_worker_agent.api_models import ( AssignedSession, @@ -506,6 +507,9 @@ def test_local_logging( session_log_file_path = MagicMock() with ( + patch.object( + scheduler_mod, "set_user_restricted_path_permissions" + ) as mock_set_user_restricted_path_permissions, patch.object(scheduler, "_executor"), patch.object(scheduler_mod.LogConfiguration, "from_boto") as mock_log_config_from_boto, patch.object( @@ -520,7 +524,11 @@ def test_local_logging( # THEN mock_queue_log_dir.assert_called_once_with(queue_id=queue_id) - queue_log_dir_path.mkdir.assert_called_once_with(mode=0o700, exist_ok=True) + if os.name == "posix": + queue_log_dir_path.mkdir.assert_called_once_with(mode=0o700, exist_ok=True) + else: + queue_log_dir_path.mkdir.assert_called_once_with(exist_ok=True) + mock_set_user_restricted_path_permissions.assert_called_once() mock_queue_session_log_file_path.assert_called_once_with( session_id=session_id, queue_log_dir=queue_log_dir_path ) @@ -587,6 +595,9 @@ def test_local_logging_os_error( with ( patch.object(scheduler, "_executor"), + patch.object( + scheduler_mod, "set_user_restricted_path_permissions" + ) as mock_set_user_restricted_path_permissions, patch.object(scheduler_mod.LogConfiguration, "from_boto") as mock_log_config_from_boto, patch.object( scheduler, "_queue_log_dir_path", return_value=queue_log_dir_path @@ -604,11 +615,17 @@ def test_local_logging_os_error( # THEN mock_queue_log_dir.assert_called_once_with(queue_id=queue_id) - queue_log_dir_path.mkdir.assert_called_once_with(mode=0o700, exist_ok=True) + if os.name == "posix": + queue_log_dir_path.mkdir.assert_called_once_with(mode=0o700, exist_ok=True) + else: + queue_log_dir_path.mkdir.assert_called_once_with(exist_ok=True) if mkdir_side_effect: mock_queue_session_log_file_path.assert_not_called() + mock_set_user_restricted_path_permissions.assert_not_called() else: mock_queue_session_log_file_path.assert_called_once() + if os.name != "posix": + mock_set_user_restricted_path_permissions.assert_called_once() if mkdir_side_effect: session_log_file_path.touch.asset_not_called() else: diff --git a/test/unit/scheduler/test_session_cleanup.py b/test/unit/scheduler/test_session_cleanup.py index b37da445..2a8280d3 100644 --- a/test/unit/scheduler/test_session_cleanup.py +++ b/test/unit/scheduler/test_session_cleanup.py @@ -8,6 +8,7 @@ from openjd.sessions import SessionUser, PosixSessionUser import pytest +import os from deadline_worker_agent.scheduler.session_cleanup import ( SessionUserCleanupManager, @@ -20,6 +21,11 @@ def __init__(self, user: str): self.user = user +class WindowsSessionUser(SessionUser): + def __init__(self, user: str): + self.user = user + + class TestSessionUserCleanupManager: @pytest.fixture def manager(self) -> SessionUserCleanupManager: @@ -34,8 +40,11 @@ def user_session_map_lock_mock( yield mock @pytest.fixture - def os_user(self) -> PosixSessionUser: - return PosixSessionUser(user="user", group="group") + def os_user(self) -> SessionUser: + if os.name == "posix": + return PosixSessionUser(user="user", group="group") + else: + return WindowsSessionUser(user="user") @pytest.fixture def session(self, os_user: PosixSessionUser) -> MagicMock: @@ -45,6 +54,7 @@ def session(self, os_user: PosixSessionUser) -> MagicMock: return session_stub class TestRegister: + @pytest.mark.skipif(os.name != "posix", reason="Posix-only test.") def test_registers_session( self, manager: SessionUserCleanupManager, @@ -99,6 +109,7 @@ def test_register_raises_windows_not_supported( user_session_map_lock_mock.__exit__.assert_not_called() class TestDeregister: + @pytest.mark.skipif(os.name != "posix", reason="Posix-only test.") def test_deregisters_session( self, manager: SessionUserCleanupManager, @@ -119,6 +130,7 @@ def test_deregisters_session( user_session_map_lock_mock.__enter__.assert_called_once() user_session_map_lock_mock.__exit__.assert_called_once() + @pytest.mark.skipif(os.name != "posix", reason="Posix-only test.") def test_deregister_skipped_no_user( self, manager: SessionUserCleanupManager, @@ -166,9 +178,10 @@ def cleanup_session_user_processes_mock(self) -> Generator[MagicMock, None, None with patch.object(SessionUserCleanupManager, "cleanup_session_user_processes") as mock: yield mock + @pytest.mark.skipif(os.name != "posix", reason="Posix-only test.") def test_calls_cleanup_session_user_processes( self, - os_user: PosixSessionUser, + os_user: SessionUser, manager: SessionUserCleanupManager, cleanup_session_user_processes_mock: MagicMock, ): @@ -178,9 +191,10 @@ def test_calls_cleanup_session_user_processes( # THEN cleanup_session_user_processes_mock.assert_called_once_with(os_user) + @pytest.mark.skipif(os.name != "posix", reason="Posix-only test.") def test_skips_cleanup_when_configured_to( self, - os_user: PosixSessionUser, + os_user: SessionUser, cleanup_session_user_processes_mock: MagicMock, ): # GIVEN @@ -196,27 +210,32 @@ class TestCleanupSessionUserProcesses: @pytest.fixture def agent_user( self, - os_user: PosixSessionUser, - ) -> PosixSessionUser: - return PosixSessionUser(user=f"agent_{os_user.user}", group=f"agent_{os_user.group}") + ) -> SessionUser: + if os.name == "posix": + return PosixSessionUser(user="agent_user", group="agent_group") + else: + return WindowsSessionUser(user="user") + @pytest.mark.skipif(os.name != "posix", reason="Posix-only test.") @pytest.fixture(autouse=True) def subprocess_check_output_mock( self, - agent_user: PosixSessionUser, + agent_user: SessionUser, ) -> Generator[MagicMock, None, None]: with patch.object( session_cleanup_mod.subprocess, "check_output", - return_value=agent_user.user, + return_value=agent_user.user, # type: ignore ) as mock: yield mock + @pytest.mark.skipif(os.name != "posix", reason="Posix-only test.") @pytest.fixture(autouse=True) def subprocess_run_mock(self) -> Generator[MagicMock, None, None]: with patch.object(session_cleanup_mod.subprocess, "run") as mock: yield mock + @pytest.mark.skipif(os.name != "posix", reason="Posix-only test.") def test_cleans_up_processes( self, os_user: PosixSessionUser, @@ -258,6 +277,7 @@ def test_not_posix_user( assert str(raised_err.value) == "Windows not supported" subprocess_run_mock.assert_not_called() + @pytest.mark.skipif(os.name != "posix", reason="Posix-only test.") def test_no_processes_to_clean_up( self, os_user: PosixSessionUser, @@ -278,6 +298,7 @@ def test_no_processes_to_clean_up( in caplog.text ) + @pytest.mark.skipif(os.name != "posix", reason="Posix-only test.") def test_fails_to_clean_up_processes( self, os_user: PosixSessionUser, @@ -298,6 +319,7 @@ def test_fails_to_clean_up_processes( assert f"Failed to stop processes running as '{os_user.user}': {err}" in caplog.text assert raised_err.value is err + @pytest.mark.skipif(os.name != "posix", reason="Posix-only test.") def test_skips_if_session_user_is_agent_user( self, subprocess_run_mock: MagicMock, diff --git a/test/unit/sessions/job_entities/test_job_entities.py b/test/unit/sessions/job_entities/test_job_entities.py index 93305af4..781d4efd 100644 --- a/test/unit/sessions/job_entities/test_job_entities.py +++ b/test/unit/sessions/job_entities/test_job_entities.py @@ -18,6 +18,7 @@ import pytest +import os from deadline_worker_agent.api_models import ( Attachments, @@ -174,6 +175,7 @@ def test_has_path_mapping_rules( assert job_details.path_mapping_rules not in (None, []) assert len(job_details.path_mapping_rules) == len(path_mapping_rules) + @pytest.mark.skipif(os.name != "posix", reason="Posix-only test.") def test_job_run_as_user(self) -> None: """Ensures that if we receive a job_run_as_user field in the response, that the created entity has a (Posix) SessionUser created with the @@ -250,10 +252,11 @@ def test_old_jobs_run_as_existence(self, jobs_run_as_data: dict[str, str]) -> No # THEN assert not hasattr(entity_obj, "jobs_run_as") - assert entity_obj.job_run_as_user is not None - assert isinstance(entity_obj.job_run_as_user.posix, PosixSessionUser) - assert entity_obj.job_run_as_user.posix.user == expected_user - assert entity_obj.job_run_as_user.posix.group == expected_group + if os.name == "posix": + assert entity_obj.job_run_as_user is not None + assert isinstance(entity_obj.job_run_as_user.posix, PosixSessionUser) + assert entity_obj.job_run_as_user.posix.user == expected_user + assert entity_obj.job_run_as_user.posix.group == expected_group # TODO: remove once service no longer sends jobsRunAs def test_only_old_jobs_run_as(self) -> None: @@ -280,10 +283,11 @@ def test_only_old_jobs_run_as(self) -> None: # THEN assert not hasattr(entity_obj, "jobs_run_as") - assert entity_obj.job_run_as_user is not None - assert isinstance(entity_obj.job_run_as_user.posix, PosixSessionUser) - assert entity_obj.job_run_as_user.posix.user == expected_user - assert entity_obj.job_run_as_user.posix.group == expected_group + if os.name == "posix": + assert entity_obj.job_run_as_user is not None + assert isinstance(entity_obj.job_run_as_user.posix, PosixSessionUser) + assert entity_obj.job_run_as_user.posix.user == expected_user + assert entity_obj.job_run_as_user.posix.group == expected_group # TODO: remove once service no longer sends jobsRunAs def test_only_empty_old_jobs_run_as(self) -> None: diff --git a/test/unit/sessions/test_session.py b/test/unit/sessions/test_session.py index 59543405..6bc1e91b 100644 --- a/test/unit/sessions/test_session.py +++ b/test/unit/sessions/test_session.py @@ -10,6 +10,8 @@ from unittest.mock import patch, MagicMock, ANY import pytest +import os + from openjd.model.v2023_09 import ( Action, Environment, @@ -52,9 +54,12 @@ import deadline_worker_agent.sessions.session as session_mod -@pytest.fixture(params=(PosixSessionUser(user="some-user", group="some-group"),)) -def os_user(request: pytest.FixtureRequest) -> Optional[SessionUser]: - return request.param +@pytest.fixture +def os_user() -> Optional[SessionUser]: + if os.name == "posix": + return PosixSessionUser(user="some-user", group="some-group") + else: + return None @pytest.fixture @@ -516,6 +521,7 @@ def mock_asset_sync(self, session: Session) -> Generator[MagicMock, None, None]: @pytest.mark.parametrize( "job_attachments_file_system", [e.value for e in JobAttachmentsFileSystem] ) + @pytest.mark.skipif(os.name != "posix", reason="Posix-only test.") def test_asset_loading_method( self, session: Session, diff --git a/test/unit/startup/test_config.py b/test/unit/startup/test_config.py index 6ddb7fc4..a39e351a 100644 --- a/test/unit/startup/test_config.py +++ b/test/unit/startup/test_config.py @@ -9,8 +9,9 @@ from typing import Any, Generator, List, Optional import logging import pytest +import os -from openjd.sessions import PosixSessionUser +from openjd.sessions import SessionUser, PosixSessionUser from deadline_worker_agent.startup.cli_args import ParsedCommandLineArguments from deadline_worker_agent.startup import config as config_mod @@ -75,6 +76,14 @@ def arg_parser( yield arg_parser +@pytest.fixture +def os_user() -> Optional[SessionUser]: + if os.name == "posix": + return PosixSessionUser(user="user", group="group") + else: + return None + + class TestLoad: """Tests for Configuration.load()""" @@ -266,8 +275,18 @@ def test_uses_parsed_cleanup_session_user_processes( @pytest.mark.parametrize( argnames="worker_logs_dir", argvalues=( - Path("/foo"), - Path("/bar"), + pytest.param( + Path("/foo"), marks=pytest.mark.skipif(os.name != "posix", reason="Not posix") + ), + pytest.param( + Path("/bar"), marks=pytest.mark.skipif(os.name != "posix", reason="Not posix") + ), + pytest.param( + Path("D:\\foo"), marks=pytest.mark.skipif(os.name != "nt", reason="Not windows") + ), + pytest.param( + Path("D:\\bar"), marks=pytest.mark.skipif(os.name != "nt", reason="Not windows") + ), ), ) def test_uses_worker_logs_dir( @@ -289,8 +308,18 @@ def test_uses_worker_logs_dir( @pytest.mark.parametrize( argnames="persistence_dir", argvalues=( - Path("/foo"), - Path("/bar"), + pytest.param( + Path("/foo"), marks=pytest.mark.skipif(os.name != "posix", reason="Not posix") + ), + pytest.param( + Path("/bar"), marks=pytest.mark.skipif(os.name != "posix", reason="Not posix") + ), + pytest.param( + Path("D:\\foo"), marks=pytest.mark.skipif(os.name != "nt", reason="Not windows") + ), + pytest.param( + Path("D:\\bar"), marks=pytest.mark.skipif(os.name != "nt", reason="Not windows") + ), ), ) def test_uses_worker_persistence_dir( @@ -598,6 +627,7 @@ def test_impersonation_passed_to_settings_initializer( else: assert "impersonation" not in call.kwargs + @pytest.mark.skipif(os.name != "posix", reason="Posix-only test.") @pytest.mark.parametrize( argnames="posix_job_user", argvalues=("user:group", None), @@ -776,8 +806,9 @@ def test_local_session_logs_passed_to_settings_initializer( argvalues=( pytest.param( "user:group", - PosixSessionUser(group="group", user="user"), + "os_user", id="has-posix-job-user-setting", + marks=pytest.mark.skipif(os.name != "posix", reason="Posix-only test."), ), pytest.param( None, @@ -792,6 +823,7 @@ def test_uses_worker_settings( expected_config_posix_job_user: PosixSessionUser | None, parsed_args: ParsedCommandLineArguments, mock_worker_settings_cls: MagicMock, + request, ) -> None: """Tests that any parsed_cli_args without a value of None are passed as kwargs when creating a WorkerSettings instance""" @@ -822,9 +854,12 @@ def test_uses_worker_settings( assert config.capabilities is mock_worker_settings.capabilities assert config.impersonation.inactive == (not mock_worker_settings.impersonation) if expected_config_posix_job_user: + # TODO: This is needed because we are using a fixture with a parameterized call + # but let's revisit whether this can be simplified when Windows impersonation is added + posix_user: PosixSessionUser = request.getfixturevalue(expected_config_posix_job_user) assert isinstance(config.impersonation.posix_job_user, PosixSessionUser) - assert config.impersonation.posix_job_user.group == expected_config_posix_job_user.group - assert config.impersonation.posix_job_user.user == expected_config_posix_job_user.user + assert config.impersonation.posix_job_user.group == posix_user.group + assert config.impersonation.posix_job_user.user == posix_user.user else: assert config.impersonation.posix_job_user is None assert config.worker_logs_dir is mock_worker_settings.worker_logs_dir diff --git a/test/unit/startup/test_settings.py b/test/unit/startup/test_settings.py index 1f4c8e57..d6c709dd 100644 --- a/test/unit/startup/test_settings.py +++ b/test/unit/startup/test_settings.py @@ -6,6 +6,7 @@ from unittest.mock import MagicMock, Mock, patch from typing import Any, Generator, NamedTuple, Type import pytest +import os from pathlib import Path from pydantic import ConstrainedStr @@ -105,14 +106,18 @@ class FieldTestCaseParams(NamedTuple): field_name="worker_logs_dir", expected_type=Path, expected_required=False, - expected_default=Path("/var/log/amazon/deadline"), + expected_default=Path("/var/log/amazon/deadline") + if os.name == "posix" + else Path(os.path.expandvars(r"%PROGRAMDATA%/Amazon/Deadline/Logs")), expected_default_factory_return_value=None, ), FieldTestCaseParams( field_name="worker_persistence_dir", expected_type=Path, expected_required=False, - expected_default=Path("/var/lib/deadline"), + expected_default=Path("/var/lib/deadline") + if os.name == "posix" + else Path(os.path.expandvars(r"%PROGRAMDATA%/Amazon/Deadline/Cache")), expected_default_factory_return_value=None, ), FieldTestCaseParams( diff --git a/test/unit/test_set_windows_permissions.py b/test/unit/test_set_windows_permissions.py new file mode 100644 index 00000000..60c7be0c --- /dev/null +++ b/test/unit/test_set_windows_permissions.py @@ -0,0 +1,58 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +import unittest +from unittest.mock import patch +from deadline_worker_agent.set_windows_permissions import set_user_restricted_path_permissions + + +class TestSetWindowsPermissions(unittest.TestCase): + @patch("subprocess.run") + @patch("getpass.getuser", return_value="testuser") + def test_set_user_restricted_path_permissions_with_default_username( + self, mock_getuser, mock_subprocess_run + ): + # GIVEN + path = "C:\\example_directory_or_file" + + # WHEN + set_user_restricted_path_permissions(path) + + # THEN + expected_command = [ + "icacls", + path, + "/inheritance:r", + "/grant", + "{0}:(OI)(CI)(F)".format("testuser"), + "/T", + ] + mock_subprocess_run.assert_called_once_with(expected_command, check=True) + mock_getuser.assert_called_once() + + @patch("subprocess.run") + @patch("getpass.getuser") + def test_set_user_restricted_path_permissions_with_specific_username( + self, mock_getuser, mock_subprocess_run + ): + # GIVEN + path = "C:\\example_directory_or_file" + custom_username = "customuser" + + # WHEN + set_user_restricted_path_permissions(path, username=custom_username) + + # THEN + expected_command = [ + "icacls", + path, + "/inheritance:r", + "/grant", + "{0}:(OI)(CI)(F)".format(custom_username), + "/T", + ] + mock_subprocess_run.assert_called_once_with(expected_command, check=True) + mock_getuser.assert_not_called() + + +if __name__ == "__main__": + unittest.main()