diff --git a/pyproject.toml b/pyproject.toml index 75f14790..fa4c53d1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ dependencies = [ "typing_extensions ~= 4.8", "psutil ~= 5.9", "pydantic ~= 1.10.0", + "requests == 2.31.*", ] requires-python = ">=3.7" diff --git a/src/deadline_worker_agent/installer/__init__.py b/src/deadline_worker_agent/installer/__init__.py index e0be2d93..5fea8523 100644 --- a/src/deadline_worker_agent/installer/__init__.py +++ b/src/deadline_worker_agent/installer/__init__.py @@ -5,6 +5,8 @@ from argparse import ArgumentParser, Namespace from pathlib import Path from subprocess import CalledProcessError, run +import re +import requests import sys import sysconfig @@ -18,6 +20,45 @@ } +def _get_ec2_region() -> Optional[str]: + """ + Gets the AWS region if running on EC2 by querying IMDS. + Returns None if region could not be detected. + """ + try: + # Create IMDSv2 token + token_response = requests.put( + url="http://169.254.169.254/latest/api/token", + headers={"X-aws-ec2-metadata-token-ttl-seconds": "10"}, # 10 second expiry + ) + token = token_response.text + if not token: + raise RuntimeError("Received empty IMDSv2 token") + + # Get AZ + az_response = requests.get( + url="http://169.254.169.254/latest/meta-data/placement/availability-zone", + headers={"X-aws-ec2-metadata-token": token}, + ) + az = az_response.text + except Exception as e: + print(f"Failed to detect AWS region: {e}") + return None + else: + if not az: + print("AWS region could not be detected, received empty response from IMDS") + return None + + match = re.match(r"^([a-z-]+-[0-9])([a-z])?$", az) + if not match: + print( + f"AWS region could not be detected, got unexpected availability zone from IMDS: {az}" + ) + return None + + return match.group(1) + + def install() -> None: """Installer entrypoint for the AWS Deadline Cloud Worker Agent""" @@ -28,6 +69,13 @@ def install() -> None: arg_parser = get_argument_parser() args = arg_parser.parse_args(namespace=ParsedCommandLineArguments) scripts_path = Path(sysconfig.get_path("scripts")) + + if args.region is None: + args.region = _get_ec2_region() + if args.region is None: + print("ERROR: Unable to detect AWS region. Please provide a value for --region.") + sys.exit(1) + if sys.platform == "win32": installer_args: dict[str, Any] = dict( farm_id=args.farm_id, @@ -94,7 +142,7 @@ class ParsedCommandLineArguments(Namespace): farm_id: str fleet_id: str - region: str + region: Optional[str] = None user: str password: Optional[str] = None group: Optional[str] = None @@ -126,8 +174,11 @@ def get_argument_parser() -> ArgumentParser: # pragma: no cover ) parser.add_argument( "--region", - help='The AWS region of the AWS Deadline Cloud farm. Defaults to "us-west-2".', - default="us-west-2", + help=( + "The AWS region of the AWS Deadline Cloud farm. " + "If on EC2, this is optional and the region will be automatically detected. Otherwise, this option is required." + ), + default=None, ) # Windows local usernames are restricted to 20 characters in length. diff --git a/src/deadline_worker_agent/installer/install.sh b/src/deadline_worker_agent/installer/install.sh index fe62d8c0..db6f9e07 100755 --- a/src/deadline_worker_agent/installer/install.sh +++ b/src/deadline_worker_agent/installer/install.sh @@ -36,7 +36,7 @@ farm_id=unset fleet_id=unset wa_user=$default_wa_user confirm="" -region="us-west-2" +region="unset" scripts_path="unset" worker_agent_program="deadline-worker-agent" client_library_program="deadline" @@ -62,7 +62,7 @@ usage() echo " --fleet-id FLEET_ID" echo " The AWS Deadline Cloud Fleet ID that the Worker belongs to." echo " --region REGION" - echo " The AWS region of the AWS Deadline Cloud farm. Defaults to $region." + echo " The AWS region of the AWS Deadline Cloud farm." echo " --user USER" echo " A user name that the AWS Deadline Cloud Worker Agent will run as. Defaults to $default_wa_user." echo " --group GROUP" @@ -197,7 +197,11 @@ else set -e fi -if [[ ! -z "${region}" ]] && [[ ! "${region}" =~ ^[a-z]+-[a-z]+-[0-9]+$ ]]; then +if [[ "${region}" == "unset" ]]; then + echo "ERROR: --region not specified" + usage +fi +if [[ ! "${region}" =~ ^[a-z]+-[a-z]+-([a-z]+-)?[0-9]+$ ]]; then echo "ERROR: Not a valid value for --region: ${region}" usage fi diff --git a/test/unit/install/test_install.py b/test/unit/install/test_install.py index 2eac08b7..51a241f5 100644 --- a/test/unit/install/test_install.py +++ b/test/unit/install/test_install.py @@ -7,6 +7,7 @@ from typing import Generator from unittest.mock import MagicMock, patch import sysconfig +import typing import pytest @@ -49,6 +50,7 @@ def expected_cmd( parsed_args: ParsedCommandLineArguments, platform: str, ) -> list[str]: + assert parsed_args.region is not None, "Region is required" expected_cmd = [ "sudo", str(installer_mod.INSTALLER_PATH[platform]), @@ -212,3 +214,119 @@ def test_unsupported_platform_raises(platform: str, capsys: pytest.CaptureFixtur capture = capsys.readouterr() assert capture.out == f"ERROR: Unsupported platform {platform}\n" + + +class TestGetEc2Region: + """Tests for _get_ec2_region function""" + + @pytest.fixture(autouse=True) + def mock_requests_get(self) -> Generator[MagicMock, None, None]: + with patch.object(installer_mod.requests, "get") as m: + yield m + + @pytest.fixture(autouse=True) + def mock_requests_put(self) -> Generator[MagicMock, None, None]: + with patch.object(installer_mod.requests, "put") as m: + yield m + + def test_gets_ec2_region(self, mock_requests_get: MagicMock, mock_requests_put: MagicMock): + # GIVEN + region = "us-east-2" + az = f"{region}a" + + mock_requests_get.return_value.text = az + + # WHEN + actual = installer_mod._get_ec2_region() + + # THEN + assert actual == region + mock_requests_put.assert_called_once_with( + url="http://169.254.169.254/latest/api/token", + headers={"X-aws-ec2-metadata-token-ttl-seconds": "10"}, + ) + mock_requests_get.assert_called_once_with( + url="http://169.254.169.254/latest/meta-data/placement/availability-zone", + headers={"X-aws-ec2-metadata-token": mock_requests_put.return_value.text}, + ) + + @pytest.mark.parametrize( + ["put_side_effect", "get_side_effect"], + [ + [Exception(), None], # token request fails + [None, Exception()], # az request fails + ], + ) + def test_fails_if_request_raises( + self, + put_side_effect: typing.Optional[Exception], + get_side_effect: typing.Optional[Exception], + mock_requests_put: MagicMock, + mock_requests_get: MagicMock, + capfd: pytest.CaptureFixture, + ): + # GIVEN + if put_side_effect: + mock_requests_put.side_effect = put_side_effect + if get_side_effect: + mock_requests_get.side_effect = get_side_effect + + # WHEN + retval = installer_mod._get_ec2_region() + + # THEN + assert retval is None + out, _ = capfd.readouterr() + assert "Failed to detect AWS region: " in out + + def test_raises_if_empty_token_received( + self, + mock_requests_put: MagicMock, + capfd: pytest.CaptureFixture, + ): + # GIVEN + mock_requests_put.return_value.text = None + + # WHEN + retval = installer_mod._get_ec2_region() + + # THEN + assert retval is None + out, _ = capfd.readouterr() + assert "Failed to detect AWS region: Received empty IMDSv2 token" in out + + def test_fails_if_empty_az_received( + self, + mock_requests_get: MagicMock, + capfd: pytest.CaptureFixture, + ): + # GIVEN + mock_requests_get.return_value.text = "" + + # WHEN + retval = installer_mod._get_ec2_region() + + # THEN + assert retval is None + out, _ = capfd.readouterr() + assert "AWS region could not be detected, received empty response from IMDS" in out + + def test_fails_if_nonvalid_az_received( + self, + mock_requests_get: MagicMock, + capfd: pytest.CaptureFixture, + ): + # GIVEN + az = "Not-A-Region-Code-123" + mock_requests_get.return_value.text = az + + # WHEN + retval = installer_mod._get_ec2_region() + + # THEN + assert retval is None + out, _ = capfd.readouterr() + assert ( + f"AWS region could not be detected, got unexpected availability zone from IMDS: {az}" + in out + ) diff --git a/test/unit/install/test_windows_installer.py b/test/unit/install/test_windows_installer.py index 495c1326..ced0f19d 100644 --- a/test/unit/install/test_windows_installer.py +++ b/test/unit/install/test_windows_installer.py @@ -61,6 +61,7 @@ def test_start_windows_installer_fails_when_run_as_non_admin_user( ) -> None: # GIVEN is_user_an_admin.return_value = False + assert parsed_args.region is not None, "Region is required" with (patch.object(installer_mod, "get_argument_parser") as mock_get_arg_parser,): with pytest.raises(SystemExit):