Skip to content

Commit

Permalink
fix: provide region in queue AWS configuration (#289)
Browse files Browse the repository at this point in the history
Signed-off-by: Josh Usiskin <[email protected]>
  • Loading branch information
jusiskin authored Apr 3, 2024
1 parent 0587b60 commit efeecfe
Show file tree
Hide file tree
Showing 6 changed files with 182 additions and 51 deletions.
52 changes: 50 additions & 2 deletions src/deadline_worker_agent/aws_credentials/aws_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,11 @@ def __init__(
# finally, read the config
self._config_parser.read(self.path)

def install_credential_process(self, profile_name: str, script_path: Path) -> None:
def install_credential_process(
self,
profile_name: str,
script_path: Path,
) -> None:
"""
Installs a credential process given the profile name and script path
Expand All @@ -116,7 +120,7 @@ def install_credential_process(self, profile_name: str, script_path: Path) -> No
script_path (Path): The script to call in the process
"""
self._config_parser[self._get_profile_name(profile_name)] = {
"credential_process": str(script_path.absolute())
"credential_process": str(script_path.absolute()),
}
self._write()

Expand Down Expand Up @@ -173,13 +177,57 @@ class AWSConfig(_AWSConfigBase):
Implementation of _AWSConfigBase to represent the ~/.aws/config file
"""

_region: str

def __init__(
self,
*,
os_user: Optional[SessionUser],
parent_dir: Path,
region: str,
) -> None:
"""
Constructor for the AWSConfigBase class
Args:
os_user (Optional[SessionUser]): If non-None, then this is the os user to add read
permissions for. If None, then the only the process user will be able to read
the credentials files.
parent_dir (Path): The directory where the AWS config and credentials files will be
written to.
region (str): The target region where the credentials are for
"""
super(AWSConfig, self).__init__(
os_user=os_user,
parent_dir=parent_dir,
)
self._region = region

def _get_profile_name(self, profile_name: str) -> str:
return f"profile {profile_name}"

@property
def path(self) -> Path:
return self._parent_dir / "config"

def install_credential_process(
self,
profile_name: str,
script_path: Path,
) -> None:
"""
Installs a credential process given the profile name and script path
Args:
profile_name (str): The profile name to install under
script_path (Path): The script to call in the process
"""
self._config_parser[self._get_profile_name(profile_name)] = {
"credential_process": str(script_path.absolute()),
"region": self._region,
}
self._write()


class AWSCredentials(_AWSConfigBase):
"""
Expand Down
12 changes: 10 additions & 2 deletions src/deadline_worker_agent/aws_credentials/queue_boto3_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ class QueueBoto3Session(BaseBoto3Session):
_role_arn: str
_os_user: Optional[SessionUser]
_interrupt_event: Event
_region: str

# Name of the profile written to the user's AWS configuration for the
# credentials process
Expand Down Expand Up @@ -135,6 +136,7 @@ def __init__(
os_user: Optional[SessionUser] = None,
interrupt_event: Event,
worker_persistence_dir: Path,
region: str,
) -> None:
super().__init__()

Expand All @@ -146,6 +148,7 @@ def __init__(
self._role_arn = role_arn
self._os_user = os_user
self._interrupt_event = interrupt_event
self._region = region

self._profile_name = f"deadline-{self._queue_id}"

Expand All @@ -162,9 +165,14 @@ def __init__(

self._create_credentials_directory(os_user)

self._aws_config = AWSConfig(os_user=self._os_user, parent_dir=self._credential_dir)
self._aws_config = AWSConfig(
os_user=self._os_user,
parent_dir=self._credential_dir,
region=self._region,
)
self._aws_credentials = AWSCredentials(
os_user=self._os_user, parent_dir=self._credential_dir
os_user=self._os_user,
parent_dir=self._credential_dir,
)

self._install_credential_process()
Expand Down
1 change: 1 addition & 0 deletions src/deadline_worker_agent/scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1151,6 +1151,7 @@ def _get_queue_aws_credentials(
os_user=os_user,
interrupt_event=self._shutdown,
worker_persistence_dir=self._worker_persistence_dir,
region=self._boto_session.region_name,
)
except (DeadlineRequestWorkerOfflineError, DeadlineRequestUnrecoverableError):
# These are terminal errors for the Session. We need to fail it, without attempting,
Expand Down
132 changes: 86 additions & 46 deletions test/unit/aws_credentials/test_aws_configs.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.

import pytest
from unittest.mock import patch, MagicMock, PropertyMock
from unittest.mock import ANY, patch, MagicMock, PropertyMock
from pathlib import Path
from typing import Type, Generator, Optional
from typing import Callable, Generator, Optional, cast

import deadline_worker_agent.aws_credentials.aws_configs as aws_configs_mod
from deadline_worker_agent.aws_credentials.aws_configs import (
Expand Down Expand Up @@ -46,6 +46,11 @@ def os_user(request: pytest.FixtureRequest) -> Optional[SessionUser]:
return None


@pytest.fixture
def region() -> str:
return "us-west-2"


class TestSetupFile:
"""Tests for the _setup_file() function"""

Expand Down Expand Up @@ -232,10 +237,9 @@ def parent_dir(self) -> MagicMock:

def test_init(
self,
config_class: Type[_AWSConfigBase],
create_config_class: Callable[[], _AWSConfigBase],
os_user: Optional[SessionUser],
mock_config_parser: MagicMock,
parent_dir: MagicMock,
) -> None:
# GIVEN
if os.name == "posix":
Expand All @@ -246,10 +250,7 @@ def test_init(

with patch.object(aws_configs_mod, "_setup_file") as setup_file_mock:
# WHEN
config = config_class(
os_user=os_user,
parent_dir=parent_dir,
)
config = create_config_class()

# THEN
setup_file_mock.assert_called_once_with(
Expand All @@ -260,7 +261,7 @@ def test_init(

def test_path(
self,
config_class: Type[_AWSConfigBase],
create_config_class: Callable[[], _AWSConfigBase],
expected_path: Path,
os_user: Optional[SessionUser],
parent_dir: MagicMock,
Expand All @@ -271,10 +272,7 @@ def test_path(
else:
assert isinstance(os_user, WindowsSessionUser) or os_user is None

config = config_class(
os_user=os_user,
parent_dir=parent_dir,
)
config = create_config_class()
result = config.path

# THEN
Expand All @@ -284,19 +282,18 @@ def test_path(
def test_install_credential_process(
self,
mock_absolute: MagicMock,
config_class: Type[_AWSConfigBase],
create_config_class: Callable[[], _AWSConfigBase],
profile_name: str,
expected_profile_name_section: str,
os_user: Optional[SessionUser],
mock_config_parser: MagicMock,
parent_dir: MagicMock,
) -> None:
# GIVEN
if os.name == "posix":
assert isinstance(os_user, PosixSessionUser) or os_user is None
else:
assert isinstance(os_user, WindowsSessionUser) or os_user is None
config = config_class(os_user=os_user, parent_dir=parent_dir)
config = create_config_class()
script_path = Path("/path/to/installdir/echo_them_credentials.sh")
with patch.object(config, "_write") as write_mock:
# WHEN
Expand All @@ -315,30 +312,24 @@ def test_install_credential_process(
def test_uninstall_credential_process(
self,
mock_absolute: MagicMock,
config_class: Type[_AWSConfigBase],
create_config_class: Callable[[], _AWSConfigBase],
profile_name: str,
expected_profile_name_section: str,
os_user: Optional[SessionUser],
mock_config_parser: MagicMock,
parent_dir: MagicMock,
) -> None:
# GIVEN
if os.name == "posix":
assert isinstance(os_user, PosixSessionUser) or os_user is None
else:
assert isinstance(os_user, WindowsSessionUser) or os_user is None
config = config_class(
os_user=os_user,
parent_dir=parent_dir,
)
config = create_config_class()
script_path = Path("/path/to/installdir/echo_them_credentials.sh")
with patch.object(config, "_write") as write_mock:
config.install_credential_process(profile_name=profile_name, script_path=script_path)
mock_config_parser.__setitem__.assert_called_once_with(
expected_profile_name_section,
{
"credential_process": mock_absolute.return_value.__str__.return_value,
},
ANY,
)
write_mock.assert_called_once_with()
write_mock.reset_mock()
Expand All @@ -353,25 +344,18 @@ def test_uninstall_credential_process(

def test_write(
self,
config_class: Type[_AWSConfigBase],
create_config_class: Callable[[], _AWSConfigBase],
os_user: Optional[SessionUser],
mock_config_parser: MagicMock,
parent_dir: MagicMock,
) -> None:
# GIVEN
if os.name == "posix":
assert isinstance(os_user, PosixSessionUser) or os_user is None
else:
assert isinstance(os_user, WindowsSessionUser) or os_user is None
with (
patch.object(aws_configs_mod, "_logger") as logger_mock,
patch.object(config_class, "path", new_callable=PropertyMock) as path_prop_mock,
):
path: MagicMock = path_prop_mock.return_value
config = config_class(
os_user=os_user,
parent_dir=parent_dir,
)
with patch.object(aws_configs_mod, "_logger") as logger_mock:
config = create_config_class()

info_mock: MagicMock = logger_mock.info

# WHEN
Expand All @@ -381,9 +365,9 @@ def test_write(
info_mock.assert_called_once()
assert isinstance(info_mock.call_args.args[0], FilesystemLogEvent)
assert info_mock.call_args.args[0].subtype == FilesystemLogEventOp.WRITE
path.open.assert_called_once_with(mode="w")
cast(MagicMock, config.path.open).assert_called_once_with(mode="w")
mock_config_parser.write.assert_called_once_with(
fp=path.open.return_value.__enter__.return_value,
fp=cast(MagicMock, config.path.open).return_value.__enter__.return_value,
space_around_delimiters=False,
)

Expand All @@ -396,16 +380,59 @@ class TestAWSConfig(AWSConfigTestBase):
"""

@pytest.fixture
def config_class(self) -> Type[_AWSConfigBase]:
return AWSConfig
def create_config_class(
self,
os_user: Optional[SessionUser],
parent_dir: Path,
region: str,
) -> Callable[[], _AWSConfigBase]:
def creator() -> AWSConfig:
return AWSConfig(
os_user=os_user,
parent_dir=parent_dir,
region=region,
)

return creator

@pytest.fixture
def expected_profile_name_section(self, profile_name: str) -> str:
return f"profile {profile_name}"

@pytest.fixture
def expected_path(self, parent_dir: MagicMock) -> str:
return parent_dir / "config"
def expected_path(
self,
parent_dir: MagicMock,
) -> str:
return parent_dir.__truediv__.return_value

@patch.object(aws_configs_mod.Path, "absolute")
def test_install_credential_process(
self,
mock_absolute: MagicMock,
create_config_class: Callable[[], _AWSConfigBase],
profile_name: str,
expected_profile_name_section: str,
mock_config_parser: MagicMock, # type: ignore[override]
region: str,
) -> None:
"""Tests that the region is added to the config file"""
# GIVEN
config = create_config_class()
script_path = Path("/path/to/installdir/echo_them_credentials.sh")
with patch.object(config, "_write") as write_mock:
# WHEN
config.install_credential_process(profile_name=profile_name, script_path=script_path)

# THEN
mock_config_parser.__setitem__.assert_called_once_with(
expected_profile_name_section,
{
"credential_process": mock_absolute.return_value.__str__.return_value,
"region": region,
},
)
write_mock.assert_called_once_with()


class TestAWSCredentials(AWSConfigTestBase):
Expand All @@ -416,13 +443,26 @@ class TestAWSCredentials(AWSConfigTestBase):
"""

@pytest.fixture
def config_class(self) -> Type[_AWSConfigBase]:
return AWSCredentials
def create_config_class(
self,
os_user: Optional[SessionUser],
parent_dir: Path,
) -> Callable[[], _AWSConfigBase]:
def creator() -> AWSCredentials:
return AWSCredentials(
os_user=os_user,
parent_dir=parent_dir,
)

return creator

@pytest.fixture
def expected_profile_name_section(self, profile_name: str) -> str:
return f"{profile_name}"

@pytest.fixture
def expected_path(self, parent_dir: MagicMock) -> str:
return parent_dir / "credentials"
def expected_path(
self,
parent_dir: PropertyMock,
) -> str:
return parent_dir.__truediv__.return_value
Loading

0 comments on commit efeecfe

Please sign in to comment.