From a0fc35c419bba964ffa5146f8bb65c54064fc929 Mon Sep 17 00:00:00 2001 From: Josh Usiskin <56369778+jusiskin@users.noreply.github.com> Date: Tue, 26 Nov 2024 15:33:02 -0600 Subject: [PATCH] feat: directly send cancel OS signals on Linux (#479) Signed-off-by: Josh Usiskin <56369778+jusiskin@users.noreply.github.com> --- pyproject.toml | 4 + .../installer/install.sh | 13 + src/deadline_worker_agent/linux/__init__.py | 1 + .../linux/capabilities.py | 236 ++++++++++++++ .../startup/entrypoint.py | 6 + test/e2e/test_cap_kill.py | 129 ++++++++ test/unit/linux/test_capabilities.py | 288 ++++++++++++++++++ 7 files changed, 677 insertions(+) create mode 100644 src/deadline_worker_agent/linux/__init__.py create mode 100644 src/deadline_worker_agent/linux/capabilities.py create mode 100644 test/e2e/test_cap_kill.py create mode 100644 test/unit/linux/test_capabilities.py diff --git a/pyproject.toml b/pyproject.toml index 8cef33b0..78144c53 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -204,6 +204,10 @@ fail_under = 78 "src/deadline_worker_agent/installer/win_installer.py" ] +"sys_platform != 'linux'" = [ + "src/deadline_worker_agent/linux/*.py", +] + [tool.coverage.coverage_conditional_plugin.rules] # This cannot be empty otherwise coverage-conditional-plugin crashes with: # AttributeError: 'NoneType' object has no attribute 'items' diff --git a/src/deadline_worker_agent/installer/install.sh b/src/deadline_worker_agent/installer/install.sh index fb38b59c..fb7b90cb 100755 --- a/src/deadline_worker_agent/installer/install.sh +++ b/src/deadline_worker_agent/installer/install.sh @@ -456,15 +456,28 @@ Environment=AWS_REGION=$region AWS_DEFAULT_REGION=$region EOF fi + ############################################################### + ############### NOTE FOR CODE REVIEWERS ################# + ############################################################### + # Review changes to AmbientCapabilities below carefully # + ############################################################### + cat >> /etc/systemd/system/deadline-worker.service < ctypes.c_int: # pragma: nocover + if result != 0: + errno = ctypes.get_errno() + raise OSError(errno, os.strerror(errno)) + return result + + +def _cap_get_proc_err_check( + result: cap_t, + func: Any, + args: Tuple[cap_t, cap_flag_t, ctypes.c_int, cap_value_ptr, cap_flag_value_t], +) -> cap_t: # pragma: nocover + if not result: + errno = ctypes.get_errno() + raise OSError(errno, os.strerror(errno)) + return result + + +def _cap_to_text_errcheck( + result: ctypes.c_char_p, + func: Any, + args: Tuple[cap_t, ssize_ptr_t], +) -> ctypes.c_char_p: # pragma: nocover + if not result: + errno = ctypes.get_errno() + raise OSError(errno, os.strerror(errno)) + return result + + +def _cap_get_flag_errcheck( + result: ctypes.c_int, func: Any, args: Tuple[cap_t, cap_value_t, cap_flag_t, cap_flag_value_ptr] +) -> ctypes.c_int: # pragma: nocover + if result != 0: + errno = ctypes.get_errno() + raise OSError(errno, os.strerror(errno)) + return result + + +@cache +def _get_libcap() -> Optional[ctypes.CDLL]: # pragma: nocover + if not sys.platform.startswith("linux"): + raise OSError(f"libcap is only available on Linux, but found platform: {sys.platform}") + + libcap_path = find_library("cap") + if not libcap_path: + return None + + libcap = ctypes.CDLL(libcap_path, use_errno=True) + + # https://man7.org/linux/man-pages/man3/cap_set_proc.3.html + libcap.cap_set_proc.restype = ctypes.c_int + libcap.cap_set_proc.argtypes = [ + ctypes.POINTER(Cap), + ] + libcap.cap_set_proc.errcheck = _cap_set_proc_err_check # type: ignore + + # https://man7.org/linux/man-pages/man3/cap_get_proc.3.html + libcap.cap_get_proc.restype = cap_t + libcap.cap_get_proc.argtypes = [] + libcap.cap_get_proc.errcheck = _cap_get_proc_err_check # type: ignore + + # https://man7.org/linux/man-pages/man3/cap_set_flag.3.html + libcap.cap_set_flag.restype = ctypes.c_int + libcap.cap_set_flag.argtypes = [ + cap_t, + cap_flag_t, + ctypes.c_int, + cap_value_ptr, + cap_flag_value_t, + ] + + # https://man7.org/linux/man-pages/man3/cap_get_flag.3.html + libcap.cap_get_flag.restype = ctypes.c_int + libcap.cap_get_flag.argtypes = ( + cap_t, + cap_value_t, + cap_flag_t, + cap_flag_value_ptr, + ) + libcap.cap_get_flag.errcheck = _cap_get_flag_errcheck # type: ignore + + # https://man7.org/linux/man-pages/man3/cap_to_text.3.html + libcap.cap_to_text.restype = ctypes.c_char_p + libcap.cap_to_text.argtypes = [ + cap_t, + ssize_ptr_t, + ] + libcap.cap_to_text.errcheck = _cap_to_text_errcheck # type: ignore + + return libcap + + +def _get_caps_str( + *, + libcap: ctypes.CDLL, + caps: cap_t, +) -> str: + cap_text = libcap.cap_to_text(caps, None).decode() + return cap_text + + +def _has_cap_kill_inheritable( + *, + libcap: ctypes.CDLL, + caps: cap_t, +) -> bool: + flag_value = cap_flag_value_t() + libcap.cap_get_flag(caps, CAP_KILL, CAP_INHERITABLE, ctypes.byref(flag_value)) + return flag_value.value == CAP_SET + + +def drop_kill_cap_from_inheritable() -> None: + if not sys.platform.startswith("linux"): + return + libcap = _get_libcap() + if not libcap: + logger.warning( + "Unable to locate libcap. The worker agent will run without Linux capability awareness." + ) + return + + caps = libcap.cap_get_proc() + caps_str = _get_caps_str(libcap=libcap, caps=caps) + if _has_cap_kill_inheritable(libcap=libcap, caps=caps): + logger.info( + "CAP_KILL was found in the thread's inheritable capability set (%s). Dropping CAP_KILL from the thread's inheritable capability set", + caps_str, + ) + cap_value_arr_t = cap_value_t * 1 + cap_value_arr = cap_value_arr_t() + cap_value_arr[0] = CAP_KILL + libcap.cap_set_flag( + caps, + CAP_INHERITABLE, + len(cap_value_arr), + cap_value_arr, + CAP_CLEAR, + ) + libcap.cap_set_proc(caps) + caps_str_after = _get_caps_str(libcap=libcap, caps=caps) + logger.info("Capabilites are: %s", caps_str_after) + else: + logger.info( + "CAP_KILL was not found in the thread's inheritable capability set (%s)", caps_str + ) + + +def main() -> None: + libcap = _get_libcap() + if not libcap: + print("ERROR: libcap not found") + sys.exit(1) + caps = libcap.cap_get_proc() + print(_get_caps_str(libcap=libcap, caps=caps)) + + +if __name__ == "__main__": + main() diff --git a/src/deadline_worker_agent/startup/entrypoint.py b/src/deadline_worker_agent/startup/entrypoint.py index 37753e0c..0f10ccfd 100644 --- a/src/deadline_worker_agent/startup/entrypoint.py +++ b/src/deadline_worker_agent/startup/entrypoint.py @@ -18,6 +18,7 @@ from ..api_models import WorkerStatus from ..boto import DEADLINE_BOTOCORE_CONFIG, OTHER_BOTOCORE_CONFIG, DeadlineClient from ..errors import ServiceShutdown +from ..linux.capabilities import drop_kill_cap_from_inheritable from ..log_sync.cloudwatch import stream_cloudwatch_logs from ..log_sync.loggers import ROOT_LOGGER, logger as log_sync_logger from ..worker import Worker @@ -80,6 +81,11 @@ def entrypoint(cli_args: Optional[list[str]] = None, *, stop: Optional[Event] = # Log the configuration (logs to DEBUG by default) config.log() + # If we have the CAP_KILL Linux capability, we must programmatically + # remove it from the inheritable capability set so it is not inherited + # by session action subprocesses + drop_kill_cap_from_inheritable() + # Register the Worker try: worker_bootstrap = bootstrap_worker(config=config) diff --git a/test/e2e/test_cap_kill.py b/test/e2e/test_cap_kill.py new file mode 100644 index 00000000..b2077e99 --- /dev/null +++ b/test/e2e/test_cap_kill.py @@ -0,0 +1,129 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +""" +This test module contains tests that verify the Worker agent removes CAP_KILL from its inheritable +capability set +""" + +import os +import re +from typing import Generator + +import boto3 +import botocore +import pytest + +from deadline_test_fixtures import ( + DeadlineClient, + EC2InstanceWorker, + Job, + TaskStatus, +) +from e2e.conftest import DeadlineResources + + +@pytest.fixture +def sleep_job_in_bg_pid(session_worker: EC2InstanceWorker) -> Generator[int, None, None]: + """Context manager that runs a sleep command in the background and yields the process ID of the + sleep process. The context-manager will do a best-effort to kill the sleep job when exiting the + context""" + + # Send SSM command to write and run a bash script + # The script creates a detached sleep process and outputs that process' PID + # This sleep process will run as the ssm-user which is different from the job user + result = session_worker.send_command( + " ; ".join( + [ + "echo '#!/bin/bash' > script.sh", + "echo 'set -euo pipefail' >> script.sh", + "echo 'nohup sleep 240 < /dev/null 2> /dev/null > /dev/null &' >> script.sh", + "echo 'echo $!' >> script.sh", + "chmod +x script.sh", + "./script.sh", + "rm script.sh", + ] + ) + ) + + # Capture the PID from the SSM command output + sleep_pid = int(result.stdout) + yield sleep_pid + + # Clean up the background sleep job if needed + try: + session_worker.send_command(f"kill -9 {sleep_pid} || true") + except Exception as e: + print(f"Failed to cleanup background sleep job {sleep_pid}: {e}") + + +@pytest.mark.skipif( + os.environ["OPERATING_SYSTEM"] == "windows", + reason="Linux specific test", +) +@pytest.mark.usefixtures("session_worker") +def test_cap_kill_not_inherited_by_running_jobs( + deadline_client: DeadlineClient, + deadline_resources: DeadlineResources, + sleep_job_in_bg_pid: int, +) -> None: + """Tests that the worker agent drops CAP_KILL from its inheritable capability set and that + session actions are not able to signal processes belonging to different OS users""" + + # WHEN + # Submit a job that tries to send a SIGTERM to the process owned by another user + job: Job = Job.submit( + client=deadline_client, + farm=deadline_resources.farm, + queue=deadline_resources.queue_a, + priority=98, + max_retries_per_task=1, + template={ + "specificationVersion": "jobtemplate-2023-09", + "name": "JobSessionActionTimeoutFail", + "steps": [ + { + "hostRequirements": { + "attributes": [ + { + "name": "attr.worker.os.family", + "allOf": [os.environ["OPERATING_SYSTEM"]], + } + ] + }, + "name": "Step0", + "script": { + "actions": { + "onRun": { + "command": "kill", + "args": [ + "-s", + "term", + str(sleep_job_in_bg_pid), + ], + "timeout": 1, # Times out in 1 second + "cancelation": { + "mode": "NOTIFY_THEN_TERMINATE", + "notifyPeriodInSeconds": 1, + }, + }, + }, + }, + }, + ], + }, + ) + job.wait_until_complete(client=deadline_client) + + # THEN + job.refresh_job_info(client=deadline_client) + assert job.task_run_status == TaskStatus.FAILED + logs_client = boto3.client( + "logs", + config=botocore.config.Config(retries={"max_attempts": 10, "mode": "adaptive"}), + ) + job.assert_single_task_log_contains( + deadline_client=deadline_client, + logs_client=logs_client, + expected_pattern=re.escape( + f"kill: sending signal to {sleep_job_in_bg_pid} failed: Operation not permitted" + ), + ) diff --git a/test/unit/linux/test_capabilities.py b/test/unit/linux/test_capabilities.py new file mode 100644 index 00000000..4b4be612 --- /dev/null +++ b/test/unit/linux/test_capabilities.py @@ -0,0 +1,288 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +from __future__ import annotations +from typing import Generator, Any + +import pytest +import sys +from unittest.mock import MagicMock, patch, call, ANY + +from deadline_worker_agent.linux import capabilities as test_mod + + +pytestmark = pytest.mark.skipif(sys.platform != "linux", reason="Linux-specific tests") + + +@pytest.fixture +def libcap() -> MagicMock: + return MagicMock() + + +@pytest.fixture +def caps() -> MagicMock: + return MagicMock() + + +@pytest.fixture(autouse=True) +def mock_get_libcap( + libcap: MagicMock, +) -> Generator[MagicMock, None, None]: + with patch.object(test_mod, "_get_libcap", return_value=libcap) as mock_get_libcap: + yield mock_get_libcap + + +@pytest.fixture +def mock_module_logger() -> Generator[MagicMock, None, None]: + with patch.object(test_mod, "logger") as mock_module_logger: + yield mock_module_logger + + +class TestGetCapsStr: + """Tests for _get_caps_str""" + + def test_success_case( + self, + libcap: MagicMock, + caps: MagicMock, + ) -> None: + # GIVEN + mock_cap_to_text: MagicMock = libcap.cap_to_text + mock_cap_to_text_return: MagicMock = mock_cap_to_text.return_value + mock_cap_to_text_return_decode: MagicMock = mock_cap_to_text_return.decode + + # WHEN + result = test_mod._get_caps_str(libcap=libcap, caps=caps) + + # THEN + mock_cap_to_text.assert_called_once_with(caps, None) + mock_cap_to_text_return_decode.assert_called_once_with() + assert result == mock_cap_to_text_return_decode.return_value + + def test_exception( + self, + libcap: MagicMock, + caps: MagicMock, + ) -> None: + """When libcap.cap_to_text raises an OSError it should not be handled""" + # GIVEN + mock_cap_to_text: MagicMock = libcap.cap_to_text + error_raised = OSError(5, "some error") + mock_cap_to_text.side_effect = [error_raised] + + # WHEN + def when() -> None: + test_mod._get_caps_str(libcap=libcap, caps=caps) + + # THEN + with pytest.raises(OSError) as raise_ctx: + when() + assert raise_ctx.value is error_raised + + +class TestHasCapKillInheritable: + """Test cases for _has_cap_kill_inheritable""" + + @pytest.mark.parametrize( + argnames="cap_get_flag_return_value", + argvalues=( + True, + False, + ), + ) + def test_behaviour( + self, + libcap: MagicMock, + cap_get_flag_return_value: bool, + caps: MagicMock, + ) -> None: + """Tests that _has_cap_kill_inheritable returns the correct value""" + # GIVEN + mock_cap_get_flag: MagicMock = libcap.cap_get_flag + with ( + patch.object(test_mod.ctypes, "byref") as mock_ctypes_byref, + patch.object(test_mod, "cap_flag_value_t") as mock_cap_flag_value_t, + ): + + def cap_get_flag_side_effect( + caps: test_mod.cap_t, + cap: int, + cap_set: int, + flag_value: Any, + ) -> None: + mock_cap_flag_value_t.return_value.value = cap_get_flag_return_value + + mock_cap_get_flag.side_effect = cap_get_flag_side_effect + + # WHEN + result = test_mod._has_cap_kill_inheritable( + libcap=libcap, + caps=caps, + ) + + # THEN + mock_cap_flag_value_t.assert_called_once_with() + mock_ctypes_byref.assert_called_once_with(mock_cap_flag_value_t.return_value) + mock_cap_get_flag.assert_called_once_with( + caps, + # Value for CAP_KILL + # See https://github.com/torvalds/linux/blob/28eb75e178d389d325f1666e422bc13bbbb9804c/include/uapi/linux/capability.h#L147 + 5, + # Value for CAP_INHERITABLE + # See https://ddnet.org/codebrowser/include/sys/capability.h.html#CAP_INHERITABLE + 2, + mock_ctypes_byref.return_value, + ) + assert result == cap_get_flag_return_value + + def test_exception( + self, + libcap: MagicMock, + caps: MagicMock, + ) -> None: + """Tests that when cap_get_flag returns an exception the exception is unhandled and + propagated to the caller""" + + # GIVEN + mock_cap_get_flag: MagicMock = libcap.cap_get_flag + exception_to_raise = OSError(3, "error msg") + mock_cap_get_flag.side_effect = [exception_to_raise] + + # WHEN + def when() -> None: + test_mod._has_cap_kill_inheritable( + libcap=libcap, + caps=caps, + ) + + # THEN + with pytest.raises(OSError) as raise_ctx: + when() + assert raise_ctx.value is exception_to_raise + + +class TestDropKillCapFromInheritable: + """Tests for drop_kill_cap_from_inheritable()""" + + def test_no_libcap_warns_and_continues( + self, + mock_get_libcap: MagicMock, + mock_module_logger: MagicMock, + ) -> None: + """Tests that when libcap is not found, the drop_kill_cap_from_inheritable function logs a + warning and continues""" + + # GIVEN + mock_get_libcap.return_value = None + module_logger_warning_mock: MagicMock = mock_module_logger.warning + + # WHEN + test_mod.drop_kill_cap_from_inheritable() + + # THEN + module_logger_warning_mock.assert_called_once_with( + "Unable to locate libcap. The worker agent will run without Linux capability awareness." + ) + + def test_has_cap_kill_inheritable( + self, + libcap: MagicMock, + caps: MagicMock, + mock_module_logger: MagicMock, + ) -> None: + """Tests that when CAP_KILL is in the thead's inheritable set, the + drop_kill_cap_from_inheritable() removes it""" + + # GIVEN + mock_cap_get_proc: MagicMock = libcap.cap_get_proc + mock_cap_set_flag: MagicMock = libcap.cap_set_flag + mock_cap_set_proc: MagicMock = libcap.cap_set_proc + mock_cap_get_proc.return_value = caps + cap_str_before = "before" + cap_str_after = "after" + module_logger_info_mock: MagicMock = mock_module_logger.info + with ( + patch.object( + test_mod, "_has_cap_kill_inheritable", return_value=True + ) as mock_has_cap_kill_inheritable, + patch.object( + test_mod, "_get_caps_str", side_effect=[cap_str_before, cap_str_after] + ) as mock_get_caps_str, + ): + # WHEN + test_mod.drop_kill_cap_from_inheritable() + + # THEN + mock_cap_get_proc.assert_called_once_with() + mock_get_caps_str.assert_has_calls( + [ + # cap str before + call(libcap=libcap, caps=caps), + # cap str after + call(libcap=libcap, caps=caps), + ] + ) + mock_has_cap_kill_inheritable.assert_called_once_with(libcap=libcap, caps=caps) + module_logger_info_mock.assert_has_calls( + [ + call( + "CAP_KILL was found in the thread's inheritable capability set (%s). Dropping CAP_KILL from the thread's inheritable capability set", + cap_str_before, + ), + call("Capabilites are: %s", cap_str_after), + ] + ) + mock_cap_set_flag.assert_called_once_with( + caps, + # CAP_INHERITABLE, see https://ddnet.org/codebrowser/include/sys/capability.h.html#cap_flag_t + 2, + # Number of caps cleared + 1, + ANY, + # CAP_CLEAR, see # See https://ddnet.org/codebrowser/include/sys/capability.h.html#cap_flag_value_t + 0, + ) + mock_cap_set_proc.assert_called_once_with(caps) + # Third arg is cap_value_arr_t (a C struct) containing the list of capabilities to clear from the capability set + assert len(mock_cap_set_flag.call_args.args[3]) == 1 + # CAP_KILL is 5, see https://github.com/torvalds/linux/blob/28eb75e178d389d325f1666e422bc13bbbb9804c/include/uapi/linux/capability.h#L147 + assert mock_cap_set_flag.call_args.args[3][0] == 5 + + def test_does_not_have_cap_kill_inheritable( + self, + libcap: MagicMock, + caps: MagicMock, + mock_module_logger: MagicMock, + ) -> None: + """Test that when drop_kill_cap_from_inheritable() does not detect CAP_KILL in the + inheritable capability set, it does not attempt to remove it and logs the capability + str""" + # GIVEN + mock_cap_get_proc: MagicMock = libcap.cap_get_proc + mock_cap_get_proc.return_value = caps + mock_cap_set_flag: MagicMock = libcap.cap_set_flag + mock_cap_set_proc: MagicMock = libcap.cap_set_proc + cap_str = "before" + module_logger_info_mock: MagicMock = mock_module_logger.info + + with ( + patch.object( + test_mod, "_has_cap_kill_inheritable", return_value=False + ) as mock_has_cap_kill_inheritable, + patch.object( + test_mod, + "_get_caps_str", + return_value=cap_str, + ) as mock_get_caps_str, + ): + # WHEN + test_mod.drop_kill_cap_from_inheritable() + + # THEN + mock_cap_get_proc.assert_called_once_with() + mock_get_caps_str.assert_called_once_with(libcap=libcap, caps=caps) + mock_has_cap_kill_inheritable.assert_called_once_with(libcap=libcap, caps=caps) + module_logger_info_mock.assert_called_once_with( + "CAP_KILL was not found in the thread's inheritable capability set (%s)", cap_str + ) + mock_cap_set_flag.assert_not_called() + mock_cap_set_proc.assert_not_called()