Skip to content

Commit

Permalink
feat(installer)!: detect default AWS region on EC2 (#250)
Browse files Browse the repository at this point in the history
Signed-off-by: Jericho Tolentino <[email protected]>
  • Loading branch information
jericht authored Mar 25, 2024
1 parent 677fda6 commit 3db8685
Show file tree
Hide file tree
Showing 5 changed files with 181 additions and 6 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ dependencies = [
"typing_extensions ~= 4.8",
"psutil ~= 5.9",
"pydantic ~= 1.10.0",
"requests == 2.31.*",
]
requires-python = ">=3.7"

Expand Down
57 changes: 54 additions & 3 deletions src/deadline_worker_agent/installer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from argparse import ArgumentParser, Namespace
from pathlib import Path
from subprocess import CalledProcessError, run
import re
import requests
import sys
import sysconfig

Expand All @@ -18,6 +20,45 @@
}


def _get_ec2_region() -> Optional[str]:
"""
Gets the AWS region if running on EC2 by querying IMDS.
Returns None if region could not be detected.
"""
try:
# Create IMDSv2 token
token_response = requests.put(
url="http://169.254.169.254/latest/api/token",
headers={"X-aws-ec2-metadata-token-ttl-seconds": "10"}, # 10 second expiry
)
token = token_response.text
if not token:
raise RuntimeError("Received empty IMDSv2 token")

# Get AZ
az_response = requests.get(
url="http://169.254.169.254/latest/meta-data/placement/availability-zone",
headers={"X-aws-ec2-metadata-token": token},
)
az = az_response.text
except Exception as e:
print(f"Failed to detect AWS region: {e}")
return None
else:
if not az:
print("AWS region could not be detected, received empty response from IMDS")
return None

match = re.match(r"^([a-z-]+-[0-9])([a-z])?$", az)
if not match:
print(
f"AWS region could not be detected, got unexpected availability zone from IMDS: {az}"
)
return None

return match.group(1)


def install() -> None:
"""Installer entrypoint for the AWS Deadline Cloud Worker Agent"""

Expand All @@ -28,6 +69,13 @@ def install() -> None:
arg_parser = get_argument_parser()
args = arg_parser.parse_args(namespace=ParsedCommandLineArguments)
scripts_path = Path(sysconfig.get_path("scripts"))

if args.region is None:
args.region = _get_ec2_region()
if args.region is None:
print("ERROR: Unable to detect AWS region. Please provide a value for --region.")
sys.exit(1)

if sys.platform == "win32":
installer_args: dict[str, Any] = dict(
farm_id=args.farm_id,
Expand Down Expand Up @@ -94,7 +142,7 @@ class ParsedCommandLineArguments(Namespace):

farm_id: str
fleet_id: str
region: str
region: Optional[str] = None
user: str
password: Optional[str] = None
group: Optional[str] = None
Expand Down Expand Up @@ -126,8 +174,11 @@ def get_argument_parser() -> ArgumentParser: # pragma: no cover
)
parser.add_argument(
"--region",
help='The AWS region of the AWS Deadline Cloud farm. Defaults to "us-west-2".',
default="us-west-2",
help=(
"The AWS region of the AWS Deadline Cloud farm. "
"If on EC2, this is optional and the region will be automatically detected. Otherwise, this option is required."
),
default=None,
)

# Windows local usernames are restricted to 20 characters in length.
Expand Down
10 changes: 7 additions & 3 deletions src/deadline_worker_agent/installer/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ farm_id=unset
fleet_id=unset
wa_user=$default_wa_user
confirm=""
region="us-west-2"
region="unset"
scripts_path="unset"
worker_agent_program="deadline-worker-agent"
client_library_program="deadline"
Expand All @@ -62,7 +62,7 @@ usage()
echo " --fleet-id FLEET_ID"
echo " The AWS Deadline Cloud Fleet ID that the Worker belongs to."
echo " --region REGION"
echo " The AWS region of the AWS Deadline Cloud farm. Defaults to $region."
echo " The AWS region of the AWS Deadline Cloud farm."
echo " --user USER"
echo " A user name that the AWS Deadline Cloud Worker Agent will run as. Defaults to $default_wa_user."
echo " --group GROUP"
Expand Down Expand Up @@ -197,7 +197,11 @@ else
set -e
fi

if [[ ! -z "${region}" ]] && [[ ! "${region}" =~ ^[a-z]+-[a-z]+-[0-9]+$ ]]; then
if [[ "${region}" == "unset" ]]; then
echo "ERROR: --region not specified"
usage
fi
if [[ ! "${region}" =~ ^[a-z]+-[a-z]+-([a-z]+-)?[0-9]+$ ]]; then
echo "ERROR: Not a valid value for --region: ${region}"
usage
fi
Expand Down
118 changes: 118 additions & 0 deletions test/unit/install/test_install.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Generator
from unittest.mock import MagicMock, patch
import sysconfig
import typing

import pytest

Expand Down Expand Up @@ -49,6 +50,7 @@ def expected_cmd(
parsed_args: ParsedCommandLineArguments,
platform: str,
) -> list[str]:
assert parsed_args.region is not None, "Region is required"
expected_cmd = [
"sudo",
str(installer_mod.INSTALLER_PATH[platform]),
Expand Down Expand Up @@ -212,3 +214,119 @@ def test_unsupported_platform_raises(platform: str, capsys: pytest.CaptureFixtur
capture = capsys.readouterr()

assert capture.out == f"ERROR: Unsupported platform {platform}\n"


class TestGetEc2Region:
"""Tests for _get_ec2_region function"""

@pytest.fixture(autouse=True)
def mock_requests_get(self) -> Generator[MagicMock, None, None]:
with patch.object(installer_mod.requests, "get") as m:
yield m

@pytest.fixture(autouse=True)
def mock_requests_put(self) -> Generator[MagicMock, None, None]:
with patch.object(installer_mod.requests, "put") as m:
yield m

def test_gets_ec2_region(self, mock_requests_get: MagicMock, mock_requests_put: MagicMock):
# GIVEN
region = "us-east-2"
az = f"{region}a"

mock_requests_get.return_value.text = az

# WHEN
actual = installer_mod._get_ec2_region()

# THEN
assert actual == region
mock_requests_put.assert_called_once_with(
url="http://169.254.169.254/latest/api/token",
headers={"X-aws-ec2-metadata-token-ttl-seconds": "10"},
)
mock_requests_get.assert_called_once_with(
url="http://169.254.169.254/latest/meta-data/placement/availability-zone",
headers={"X-aws-ec2-metadata-token": mock_requests_put.return_value.text},
)

@pytest.mark.parametrize(
["put_side_effect", "get_side_effect"],
[
[Exception(), None], # token request fails
[None, Exception()], # az request fails
],
)
def test_fails_if_request_raises(
self,
put_side_effect: typing.Optional[Exception],
get_side_effect: typing.Optional[Exception],
mock_requests_put: MagicMock,
mock_requests_get: MagicMock,
capfd: pytest.CaptureFixture,
):
# GIVEN
if put_side_effect:
mock_requests_put.side_effect = put_side_effect
if get_side_effect:
mock_requests_get.side_effect = get_side_effect

# WHEN
retval = installer_mod._get_ec2_region()

# THEN
assert retval is None
out, _ = capfd.readouterr()
assert "Failed to detect AWS region: " in out

def test_raises_if_empty_token_received(
self,
mock_requests_put: MagicMock,
capfd: pytest.CaptureFixture,
):
# GIVEN
mock_requests_put.return_value.text = None

# WHEN
retval = installer_mod._get_ec2_region()

# THEN
assert retval is None
out, _ = capfd.readouterr()
assert "Failed to detect AWS region: Received empty IMDSv2 token" in out

def test_fails_if_empty_az_received(
self,
mock_requests_get: MagicMock,
capfd: pytest.CaptureFixture,
):
# GIVEN
mock_requests_get.return_value.text = ""

# WHEN
retval = installer_mod._get_ec2_region()

# THEN
assert retval is None
out, _ = capfd.readouterr()
assert "AWS region could not be detected, received empty response from IMDS" in out

def test_fails_if_nonvalid_az_received(
self,
mock_requests_get: MagicMock,
capfd: pytest.CaptureFixture,
):
# GIVEN
az = "Not-A-Region-Code-123"
mock_requests_get.return_value.text = az

# WHEN
retval = installer_mod._get_ec2_region()

# THEN
assert retval is None
out, _ = capfd.readouterr()
assert (
f"AWS region could not be detected, got unexpected availability zone from IMDS: {az}"
in out
)
1 change: 1 addition & 0 deletions test/unit/install/test_windows_installer.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def test_start_windows_installer_fails_when_run_as_non_admin_user(
) -> None:
# GIVEN
is_user_an_admin.return_value = False
assert parsed_args.region is not None, "Region is required"

with (patch.object(installer_mod, "get_argument_parser") as mock_get_arg_parser,):
with pytest.raises(SystemExit):
Expand Down

0 comments on commit 3db8685

Please sign in to comment.