Skip to content

Commit

Permalink
feat: basic Windows support
Browse files Browse the repository at this point in the history
Signed-off-by: Graeme McHale <[email protected]>
  • Loading branch information
gmchale79 committed Oct 11, 2023
1 parent 5328059 commit f67f6ae
Show file tree
Hide file tree
Showing 23 changed files with 338 additions and 86 deletions.
11 changes: 8 additions & 3 deletions src/deadline_worker_agent/aws_credentials/aws_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
from __future__ import annotations

import stat

import os
import logging
from abc import ABC, abstractmethod
from configparser import ConfigParser
from pathlib import Path
from typing import Optional
from openjd.sessions import PosixSessionUser, SessionUser
from subprocess import run, DEVNULL, PIPE, STDOUT
from ..set_windows_permissions import grant_full_control

__all__ = [
"AWSConfig",
Expand All @@ -28,8 +29,12 @@ def _run_cmd_as(*, user: PosixSessionUser, cmd: list[str]) -> None:

def _setup_parent_dir(*, dir_path: Path, owner: SessionUser | None = None) -> None:
if owner is None:
create_perms: int = stat.S_IRWXU
dir_path.mkdir(mode=create_perms, exist_ok=True)
if os.name == "posix":
create_perms: int = stat.S_IRWXU
dir_path.mkdir(mode=create_perms, exist_ok=True)
else:
dir_path.mkdir(exist_ok=True)
grant_full_control(dir_path.name)
else:
assert isinstance(owner, PosixSessionUser)
_run_cmd_as(user=owner, cmd=["mkdir", "-p", str(dir_path)])
Expand Down
19 changes: 13 additions & 6 deletions src/deadline_worker_agent/aws_credentials/queue_boto3_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,6 @@ def __init__(
interrupt_event: Event,
worker_persistence_dir: Path,
) -> None:
if os.name != "posix":
raise NotImplementedError("Windows not supported.")
super().__init__()

self._deadline_client = deadline_client
Expand All @@ -110,7 +108,11 @@ def __init__(
self._credentials_filename = (
"aws_credentials" # note: .json extension added by JSONFileCache
)
self._credentials_process_script_path = self._credential_dir / "get_aws_credentials.sh"

if os.name == "posix":
self._credentials_process_script_path = self._credential_dir / "get_aws_credentials.sh"
else:
self._credentials_process_script_path = self._credential_dir / "get_aws_credentials.cmd"

self._aws_config = AWSConfig(self._os_user)
self._aws_credentials = AWSCredentials(self._os_user)
Expand Down Expand Up @@ -321,9 +323,14 @@ def _generate_credential_process_script(self) -> str:
Generates the bash script which generates the credentials as JSON output on STDOUT.
This script will be used by the installed credential process.
"""
return ("#!/bin/bash\nset -eu\ncat {0}\n").format(
(self._credential_dir / self._credentials_filename).with_suffix(".json")
)
if os.name == "posix":
return ("#!/bin/bash\nset -eu\ncat {0}\n").format(
(self._credential_dir / self._credentials_filename).with_suffix(".json")
)
else:
return ("@echo off\ntype {0}\n").format(
(self._credential_dir / self._credentials_filename).with_suffix(".json")
)

def _uninstall_credential_process(self) -> None:
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
from __future__ import annotations

import os
import logging
from typing import Any, cast

Expand Down Expand Up @@ -34,8 +33,6 @@ def __init__(
config: Configuration,
worker_id: str,
) -> None:
if os.name != "posix":
raise NotImplementedError("Windows not supported.")
super().__init__()

self._bootstrap_session = bootstrap_session
Expand Down
13 changes: 9 additions & 4 deletions src/deadline_worker_agent/scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@
from .session_queue import SessionActionQueue, SessionActionStatus
from ..startup.config import ImpersonationOverrides
from ..utils import MappingWithCallbacks

from ..set_windows_permissions import grant_full_control
import subprocess

logger = LOGGER

Expand Down Expand Up @@ -219,7 +220,7 @@ def run(self) -> None:
The Worker begins by hydrating its assigned work using the UpdateWorkerSchedule API.
The scheduler then enters a loop of processing assigned actions - creating and deleting
Worker sessions as required. If no actions are assigned, the Worke idles for 5 seconds.
Worker sessions as required. If no actions are assigned, the Worker idles for 5 seconds.
If any action completes, finishes cancelation, or if the Worker is done idling, an
UpdateWorkerSchedule API request is made with any relevant changes specified in the request.
Expand Down Expand Up @@ -636,8 +637,12 @@ def _create_new_sessions(
if self._worker_logs_dir:
queue_log_dir = self._queue_log_dir_path(queue_id=session_spec["queueId"])
try:
queue_log_dir.mkdir(mode=stat.S_IRWXU, exist_ok=True)
except OSError:
if os.name == "posix":
queue_log_dir.mkdir(mode=stat.S_IRWXU, exist_ok=True)
else:
queue_log_dir.mkdir(exist_ok=True)
grant_full_control(queue_log_dir.name)
except (OSError, subprocess.CalledProcessError):
error_msg = (
f"Failed to create local session log directory on worker: {queue_log_dir}"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def job_run_as_user_api_model_to_worker_agent(
)
else:
# TODO: windows support
raise NotImplementedError(f"{os.name} is not supported")
return None

return job_run_as_user

Expand Down
41 changes: 41 additions & 0 deletions src/deadline_worker_agent/set_windows_permissions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.

from typing import Optional
import subprocess
import getpass


def grant_full_control(path: str, username: Optional[str] = None):
"""
Set permissions for a specified file or directory (and any child objects)
to give full control only to the specified user.
Args:
path (str): The path of the file or directory for which permissions will be set.
username (str, optional): The username for whom permissions will be granted. If none is
provided the current username will be used.
Example:
path = "C:\\example_directory_or_file"
username = "a_username"
grant_full_control(path, username)
"""

if not username:
username = getpass.getuser()

subprocess.run(
[
"icacls",
path,
# Remove any existing permissions
"/inheritance:r",
# OI - Contained objects will inherit
# CI - Sub-directories will inherit
# F - Full control
"/grant",
("{0}:(OI)(CI)(F)").format(username),
"/T", # Apply recursively for directories
],
check=True,
)
3 changes: 2 additions & 1 deletion src/deadline_worker_agent/startup/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import os
import logging as _logging
from dataclasses import dataclass
from pathlib import Path
Expand Down Expand Up @@ -119,7 +120,7 @@ def __init__(

settings = WorkerSettings(**settings_kwargs)

if settings.posix_job_user is not None:
if os.name == "posix" and settings.posix_job_user is not None:
user, group = self._get_user_and_group_from_posix_job_user(settings.posix_job_user)
self.impersonation = ImpersonationOverrides(
inactive=not settings.impersonation,
Expand Down
2 changes: 2 additions & 0 deletions src/deadline_worker_agent/startup/config_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pathlib import Path
from typing import Any, Optional
import sys
import os

from pydantic import BaseModel, BaseSettings, Field

Expand All @@ -20,6 +21,7 @@
DEFAULT_CONFIG_PATH: dict[str, Path] = {
"darwin": Path("/etc/amazon/deadline/worker.toml"),
"linux": Path("/etc/amazon/deadline/worker.toml"),
"win32": Path(os.path.expandvars(r"%PROGRAMDATA%/Amazon/Deadline/Config/worker.toml")),
}


Expand Down
7 changes: 6 additions & 1 deletion src/deadline_worker_agent/startup/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,14 @@ def detect_system_capabilities() -> Capabilities:
"linux": "linux",
"windows": "windows",
}
platform_machine = platform.machine().lower()
python_machine_to_openjd_cpu_arch = {"x86_64": "x86_64", "amd64": "x86_64"}
if openjd_os_family := python_system_to_openjd_os_family.get(platform_system):
attributes[AttributeCapabilityName("attr.worker.os.family")] = [openjd_os_family]
attributes[AttributeCapabilityName("attr.worker.cpu.arch")] = [platform.machine()]
if openjd_cpu_arch := python_machine_to_openjd_cpu_arch.get(platform_machine):
attributes[AttributeCapabilityName("attr.worker.cpu.arch")] = [openjd_cpu_arch]
else:
raise NotImplementedError(f"{platform_machine} not supported")
amounts[AmountCapabilityName("amount.worker.vcpu")] = float(psutil.cpu_count())
amounts[AmountCapabilityName("amount.worker.memory")] = float(psutil.virtual_memory().total) / (
1024.0**2
Expand Down
16 changes: 14 additions & 2 deletions src/deadline_worker_agent/startup/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,20 @@
from .capabilities import Capabilities
from .config_file import ConfigFile

import os


# Default path for the worker's logs.
DEFAULT_POSIX_WORKER_LOGS_DIR = Path("/var/log/amazon/deadline")
DEFAULT_WINDOWS_WORKER_LOGS_DIR = Path(os.path.expandvars(r"%PROGRAMDATA%/Amazon/Deadline/Logs"))
# Default path for the worker persistence directory.
# The persistence directory is expected to be located on a file-system that is local to the Worker
# Node. The Worker's ID and credentials are persisted and these should not be accessible by other
# Worker Nodes.
DEFAULT_POSIX_WORKER_PERSISTENCE_DIR = Path("/var/lib/deadline")
DEFAULT_WINDOWS_WORKER_PERSISTENCE_DIR = Path(
os.path.expandvars(r"%PROGRAMDATA%/Amazon/Deadline/Cache")
)


class WorkerSettings(BaseSettings):
Expand Down Expand Up @@ -80,8 +86,14 @@ class WorkerSettings(BaseSettings):
capabilities: Capabilities = Field(
default_factory=lambda: Capabilities(amounts={}, attributes={})
)
worker_logs_dir: Path = DEFAULT_POSIX_WORKER_LOGS_DIR
worker_persistence_dir: Path = DEFAULT_POSIX_WORKER_PERSISTENCE_DIR
worker_logs_dir: Path = (
DEFAULT_WINDOWS_WORKER_LOGS_DIR if os.name == "nt" else DEFAULT_POSIX_WORKER_LOGS_DIR
)
worker_persistence_dir: Path = (
DEFAULT_WINDOWS_WORKER_PERSISTENCE_DIR
if os.name == "nt"
else DEFAULT_POSIX_WORKER_PERSISTENCE_DIR
)
local_session_logs: bool = True

class Config:
Expand Down
14 changes: 9 additions & 5 deletions src/deadline_worker_agent/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import json
import signal
import os
import sys
import traceback
from concurrent.futures import Executor, Future, ThreadPoolExecutor, wait
Expand Down Expand Up @@ -108,8 +109,10 @@ def __init__(

signal.signal(signal.SIGTERM, self._signal_handler)
signal.signal(signal.SIGINT, self._signal_handler)
# TODO: Remove this once WA is stable or put behind a debug flag
signal.signal(signal.SIGUSR1, self._output_thread_stacks)

if os.name == "posix":
# TODO: Remove this once WA is stable or put behind a debug flag
signal.signal(signal.SIGUSR1, self._output_thread_stacks) # type: ignore

def _signal_handler(self, signum: int, frame: FrameType | None = None) -> None:
"""
Expand All @@ -134,7 +137,7 @@ def _output_thread_stacks(self, signum: int, frame: FrameType | None = None) ->
This signal is designated for application-defined behaviors. In our case, we want to output
stack traces for all running threads.
"""
if signum in (signal.SIGUSR1,):
if signum in (signal.SIGUSR1,): # type: ignore
logger.info(f"Received signal {signum}. Initiating application shutdown.")
# OUTPUT STACK TRACE FOR ALL THREADS
print("\n*** STACKTRACE - START ***\n", file=sys.stderr)
Expand All @@ -156,7 +159,7 @@ def id(self) -> str:

@property
def sessions(self) -> WorkerSessionCollection:
raise NotImplementedError("Worker.sessions property not implemeneted")
raise NotImplementedError("Worker.sessions property not implemented")

def run(self) -> None:
"""Runs the main Worker loop for processing sessions."""
Expand Down Expand Up @@ -373,7 +376,8 @@ def _get_spot_instance_shutdown_action_timeout(self, *, imdsv2_token: str) -> ti
logger.info(f"Spot {action} happening at {shutdown_time}")
# Spot gives the time in UTC with a trailing Z, but Python can't handle
# the Z so we strip it
shutdown_time = datetime.fromisoformat(shutdown_time[:-1]).astimezone(timezone.utc)
shutdown_time = datetime.fromisoformat(shutdown_time[:-1])
shutdown_time = shutdown_time.replace(tzinfo=timezone.utc)
current_time = datetime.now(timezone.utc)
time_delta = shutdown_time - current_time
time_delta_seconds = int(time_delta.total_seconds())
Expand Down
34 changes: 25 additions & 9 deletions test/unit/aws_credentials/test_aws_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
_setup_parent_dir,
)
from openjd.sessions import PosixSessionUser, SessionUser
import os


@pytest.fixture
Expand All @@ -27,9 +28,13 @@ def mock_run_cmd_as() -> Generator[MagicMock, None, None]:
yield mock_run_cmd_as


@pytest.fixture(params=(PosixSessionUser(user="some-user", group="some-group"), None))
def os_user(request: pytest.FixtureRequest) -> Optional[SessionUser]:
return request.param
@pytest.fixture
def os_user() -> Optional[SessionUser]:
if os.name == "posix":
return PosixSessionUser(user="user", group="group")
else:
# TODO: Revisit when Windows impersonation is added
return None


class TestSetupParentDir:
Expand All @@ -39,6 +44,10 @@ class TestSetupParentDir:
def dir_path(self) -> MagicMock:
return MagicMock()

@pytest.fixture
def set_windows_permissions(self) -> MagicMock:
return MagicMock()

def test_creates_dir(
self,
dir_path: MagicMock,
Expand All @@ -51,7 +60,8 @@ def test_creates_dir(
assert isinstance(os_user, PosixSessionUser) or os_user is None

# WHEN
_setup_parent_dir(dir_path=dir_path, owner=os_user)
with (patch.object(aws_configs_mod, "grant_full_control") as mock_grant_full_control,):
_setup_parent_dir(dir_path=dir_path, owner=os_user)

# THEN
if os_user:
Expand All @@ -62,10 +72,16 @@ def test_creates_dir(
)
mock_run_cmd_as.assert_any_call(user=os_user, cmd=["chmod", "770", str(dir_path)])
else:
mkdir.assert_called_once_with(
mode=0o700,
exist_ok=True,
)
if os.name == "posix":
mkdir.assert_called_once_with(
mode=0o700,
exist_ok=True,
)
else:
mkdir.assert_called_once_with(
exist_ok=True,
)
mock_grant_full_control.assert_called_once()

def test_sets_group_ownership(
self,
Expand Down Expand Up @@ -383,7 +399,7 @@ def test_write(

class TestAWSConfig(AWSConfigTestBase):
"""
Test class derrived from AWSConfigTestBase for AWSConfig.
Test class derived from AWSConfigTestBase for AWSConfig.
All tests are defined in the base class. This class defines the fixtures that feed into those tests.
"""
Expand Down
Loading

0 comments on commit f67f6ae

Please sign in to comment.