Skip to content

Commit

Permalink
fix!: Update get_queue_user_boto3_session to work with default creds
Browse files Browse the repository at this point in the history
- Also noticed that the name `get_queue_boto3_session` was ambiguous,
  so clarified this name by changing it to
  `get_queue_user_boto3_session`.
- Change the default for the aws_profile_name to "(default)" as a more
  obvious name for the default credentials, matching error messages
  that e.g. the AWS CLI produces about this.
- Added a unit test confirming the correct value going to boto3.Session
- Update the config dialog to work properly with the default profile
  configuration, showing it as "(default)"

Signed-off-by: Mark Wiebe <[email protected]>
  • Loading branch information
mwiebe committed Sep 21, 2023
1 parent 05d06ec commit cb66de1
Show file tree
Hide file tree
Showing 15 changed files with 85 additions and 63 deletions.
4 changes: 2 additions & 2 deletions src/deadline/client/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"list_jobs",
"list_fleets",
"list_storage_profiles_for_queue",
"get_queue_boto3_session",
"get_queue_user_boto3_session",
"get_queue_parameter_definitions",
"get_telemetry_client",
]
Expand All @@ -31,7 +31,7 @@
from ._session import (
AwsCredentialsStatus,
AwsCredentialsType,
get_queue_boto3_session,
get_queue_user_boto3_session,
check_credentials_status,
get_boto3_client,
get_boto3_session,
Expand Down
54 changes: 24 additions & 30 deletions src/deadline/client/api/_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,19 +66,18 @@ def get_boto3_session(

# If the default AWS profile name is either not set, or set to "default",
# use the default credentials provider chain instead of a named profile.
if profile_name in ("", "default"):
if profile_name in ("(default)", "default", ""):
profile_name = None

# If a config was provided, don't use the Session caching mechanism.
if config:
return boto3.Session(profile_name=profile_name)

if force_refresh:
invalidate_boto3_session_cache()

# If this is the first call or the profile name has changed, make a fresh Session
if (
force_refresh
or not __cached_boto3_session
or __cached_boto3_session_profile_name != profile_name
):
if not __cached_boto3_session or __cached_boto3_session_profile_name != profile_name:
__cached_boto3_session = boto3.Session(profile_name=profile_name)
__cached_boto3_session_profile_name = profile_name

Expand All @@ -87,13 +86,19 @@ def get_boto3_session(

def invalidate_boto3_session_cache() -> None:
"""
Invalidates the cached boto3 session.
Invalidates the cached boto3 session and boto3 queue session.
"""
global __cached_boto3_session
global __cached_boto3_session_profile_name
global __cached_boto3_queue_session
global __cached_farm_id_for_queue_session
global __cached_queue_id_for_queue_session

__cached_boto3_session = None
__cached_boto3_session_profile_name = None
__cached_boto3_queue_session = None
__cached_farm_id_for_queue_session = None
__cached_queue_id_for_queue_session = None


def get_boto3_client(service_name: str, config: Optional[ConfigParser] = None) -> BaseClient:
Expand Down Expand Up @@ -164,7 +169,7 @@ def get_studio_id(
return profile_config.get("studio_id", None)


def get_queue_boto3_session(
def get_queue_user_boto3_session(
deadline: BaseClient,
config: Optional[ConfigParser] = None,
farm_id: Optional[str] = None,
Expand All @@ -189,7 +194,7 @@ def get_queue_boto3_session(
global __cached_farm_id_for_queue_session
global __cached_queue_id_for_queue_session

base_session = get_boto3_session(config=config)
base_session = get_boto3_session(config=config, force_refresh=force_refresh)

if farm_id is None:
farm_id = get_setting("defaults.farm_id")
Expand All @@ -198,18 +203,17 @@ def get_queue_boto3_session(

# If a config was provided, don't use the Session caching mechanism.
if config:
return _get_queue_boto3_session(
return _get_queue_user_boto3_session(
deadline, base_session, farm_id, queue_id, queue_display_name
)

# If this is the first call or the farm ID/queue ID has changed, make a fresh Session and cache it
if (
force_refresh
or not __cached_boto3_queue_session
not __cached_boto3_queue_session
or __cached_farm_id_for_queue_session != farm_id
or __cached_queue_id_for_queue_session != queue_id
):
__cached_boto3_queue_session = _get_queue_boto3_session(
__cached_boto3_queue_session = _get_queue_user_boto3_session(
deadline, base_session, farm_id, queue_id, queue_display_name
)

Expand All @@ -219,14 +223,14 @@ def get_queue_boto3_session(
return __cached_boto3_queue_session


def _get_queue_boto3_session(
def _get_queue_user_boto3_session(
deadline: BaseClient,
base_session: boto3.Session,
farm_id: str,
queue_id: str,
queue_display_name: Optional[str] = None,
):
queue_credential_provider = QueueCredentialProvider(
queue_credential_provider = QueueUserCredentialProvider(
deadline,
farm_id,
queue_id,
Expand All @@ -236,27 +240,17 @@ def _get_queue_boto3_session(
botocore_session = get_botocore_session()
credential_provider = botocore_session.get_component("credential_provider")
credential_provider.insert_before("env", queue_credential_provider)
aws_profile_name: Optional[str] = None
if base_session.profile_name != "default":
aws_profile_name = base_session.profile_name

return boto3.Session(
botocore_session=botocore_session,
profile_name=base_session.profile_name,
profile_name=aws_profile_name,
region_name=base_session.region_name,
)


def invalidate_boto3_queue_session_cache() -> None:
"""
Invalidates the cached boto3 queue session.
"""
global __cached_boto3_queue_session
global __cached_farm_id_for_queue_session
global __cached_queue_id_for_queue_session

__cached_boto3_queue_session = None
__cached_farm_id_for_queue_session = None
__cached_queue_id_for_queue_session = None


@contextmanager
def _modified_logging_level(logger, level):
old_level = logger.getEffectiveLevel()
Expand Down Expand Up @@ -427,7 +421,7 @@ def method(*args, **kwargs) -> Any:
return method


class QueueCredentialProvider(CredentialProvider):
class QueueUserCredentialProvider(CredentialProvider):
"""A custom botocore CredentialProvider for handling AssumeQueueRoleForUser API
credentials. If the credentials expire, the provider will automatically refresh
them using the _get_queue_credentials method.
Expand Down
2 changes: 1 addition & 1 deletion src/deadline/client/api/_submit_job_bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def create_job_from_job_bundle(
)
asset_references.input_directories.clear()

queue_role_session = api.get_queue_boto3_session(
queue_role_session = api.get_queue_user_boto3_session(
deadline=deadline,
config=config,
farm_id=create_job_args["farmId"],
Expand Down
4 changes: 2 additions & 2 deletions src/deadline/client/cli/_groups/bundle_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from botocore.exceptions import ClientError # type: ignore[import]

from deadline.client import api
from deadline.client.api import get_boto3_client, get_queue_boto3_session
from deadline.client.api import get_boto3_client, get_queue_user_boto3_session
from deadline.client.api._session import _modified_logging_level
from deadline.client.config import config_file, get_setting, set_setting
from deadline.client.job_bundle.loader import read_yaml_or_json, read_yaml_or_json_object
Expand Down Expand Up @@ -172,7 +172,7 @@ def bundle_submit(job_bundle_dir, asset_loading_method, parameter, yes, **args):
)
asset_references.input_directories.clear()

queue_role_session = get_queue_boto3_session(
queue_role_session = get_queue_user_boto3_session(
deadline=deadline,
config=config,
farm_id=farm_id,
Expand Down
2 changes: 1 addition & 1 deletion src/deadline/client/cli/_groups/job_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def _download_job_output(

queue = deadline.get_queue(farmId=farm_id, queueId=queue_id)

queue_role_session = api.get_queue_boto3_session(
queue_role_session = api.get_queue_user_boto3_session(
deadline=deadline,
config=config,
farm_id=farm_id,
Expand Down
2 changes: 1 addition & 1 deletion src/deadline/client/config/config_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
"description": "The filesystem path to Deadline Cloud Monitor, set during login process.",
},
"defaults.aws_profile_name": {
"default": "",
"default": "(default)",
"section_format": "profile-{}",
"description": "The AWS profile name to use by default. Set to '' to use the default credentials."
+ " Other settings are saved with the profile.",
Expand Down
7 changes: 5 additions & 2 deletions src/deadline/client/ui/dialogs/deadline_config_dialog.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ def _fill_aws_profiles_box(self):
# if the configured profile does not exist.
try:
session = boto3.Session()
aws_profile_names = session._session.full_config["profiles"].keys()
aws_profile_names = ["(default)", *(name for name in session._session.full_config["profiles"].keys() if name != "default")]
except ProfileNotFound:
logger.exception("Failed to create boto3.Session for AWS profile list")
aws_profile_names = f"{NOT_VALID_MARKER} <failed to retrieve AWS profile names>"
Expand Down Expand Up @@ -402,7 +402,10 @@ def refresh(self):
aws_profile_name = config_file.get_setting(
"defaults.aws_profile_name", config=self.config
)
if aws_profile_name not in self.aws_profile_names:
# Change the values representing the default to the UI value representing the default
if aws_profile_name in ("(default)", "default", ""):
aws_profile_name = "(default)"
elif aws_profile_name not in self.aws_profile_names:
aws_profile_name = f"{NOT_VALID_MARKER} {aws_profile_name}"
index = self.aws_profiles_box.findText(aws_profile_name, Qt.MatchFixedString)
if index >= 0:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ def on_submit(self):

queue = deadline.get_queue(farmId=farm_id, queueId=queue_id)

queue_role_session = api.get_queue_boto3_session(
queue_role_session = api.get_queue_user_boto3_session(
deadline=deadline,
farm_id=farm_id,
queue_id=queue_id,
Expand Down
39 changes: 31 additions & 8 deletions test/unit/deadline_client/api/test_api_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
tests the deadline.client.api functions relating to boto3.Client
"""

from unittest.mock import call, patch, MagicMock
from unittest.mock import call, patch, MagicMock, ANY

import boto3 # type: ignore[import]
import pytest
Expand Down Expand Up @@ -95,7 +95,7 @@ def test_get_check_credentials_status_configuration_error(fresh_deadline_config)
assert api.check_credentials_status() == api.AwsCredentialsStatus.CONFIGURATION_ERROR


def test_get_queue_boto3_session_cache(fresh_deadline_config):
def test_get_queue_user_boto3_session_cache(fresh_deadline_config):
session_mock = MagicMock()
session_mock.profile_name = "test_profile"
session_mock.region_name = "us-west-2"
Expand All @@ -107,23 +107,46 @@ def test_get_queue_boto3_session_cache(fresh_deadline_config):

with patch.object(api._session, "get_boto3_session", return_value=session_mock), patch(
"botocore.session.Session", return_value=mock_botocore_session
), patch.object(api._session, "_get_queue_boto3_session") as _get_queue_boto3_session_mock:
_ = api.get_queue_boto3_session(
), patch.object(api._session, "_get_queue_user_boto3_session") as _get_queue_user_boto3_session_mock:
_ = api.get_queue_user_boto3_session(
deadline_mock, farm_id="farm-1234", queue_id="queue-1234", queue_display_name="queue"
)
# Same farm ID and queue ID, returns cached session
_ = api.get_queue_boto3_session(
_ = api.get_queue_user_boto3_session(
deadline_mock, farm_id="farm-1234", queue_id="queue-1234", queue_display_name="queue"
)
# Different queue ID, makes a fresh session
_ = api.get_queue_boto3_session(
_ = api.get_queue_user_boto3_session(
deadline_mock, farm_id="farm-1234", queue_id="queue-5678", queue_display_name="queue"
)
# Different queue ID, makes a fresh session
_ = api.get_queue_boto3_session(
_ = api.get_queue_user_boto3_session(
deadline_mock, farm_id="farm-5678", queue_id="queue-1234", queue_display_name="queue"
)
assert _get_queue_boto3_session_mock.call_count == 3
assert _get_queue_user_boto3_session_mock.call_count == 3


def test_get_queue_user_boto3_session_no_profile(fresh_deadline_config):
"""Make sure that boto3.Session gets called with profile_name=None for the default profile."""
session_mock = MagicMock()
# The value returned when no profile was selected is "default"
session_mock.profile_name = "default"
session_mock.region_name = "us-west-2"
deadline_mock = MagicMock()
mock_botocore_session = MagicMock()
mock_botocore_session.get_config_variable = (
lambda name: "default" if name == "profile" else None
)

with patch.object(api._session, "get_boto3_session", return_value=session_mock), patch(
"botocore.session.Session", return_value=mock_botocore_session
), patch("boto3.Session") as boto3_session_mock:
api.get_queue_user_boto3_session(
deadline_mock, farm_id="farm-1234", queue_id="queue-1234", queue_display_name="queue"
)
boto3_session_mock.assert_called_once_with(
botocore_session=ANY, profile_name=None, region_name="us-west-2"
)


def test_check_deadline_api_available(fresh_deadline_config):
Expand Down
4 changes: 2 additions & 2 deletions test/unit/deadline_client/api/test_job_bundle_submission.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@ def test_create_job_from_job_bundle_job_attachments(
# Use a temporary directory for the job bundle
with patch.object(_submit_job_bundle.api, "get_boto3_session"), patch.object(
_submit_job_bundle.api, "get_boto3_client"
) as client_mock, patch.object(_submit_job_bundle.api, "get_queue_boto3_session"), patch.object(
) as client_mock, patch.object(_submit_job_bundle.api, "get_queue_user_boto3_session"), patch.object(
api._submit_job_bundle, "_hash_attachments"
) as mock_hash_attachments, patch.object(
S3AssetManager, "upload_assets"
Expand Down Expand Up @@ -580,7 +580,7 @@ def test_create_job_from_job_bundle_with_single_asset_file(
# Use a temporary directory for the job bundle
with patch.object(_submit_job_bundle.api, "get_boto3_session"), patch.object(
_submit_job_bundle.api, "get_boto3_client"
) as client_mock, patch.object(_submit_job_bundle.api, "get_queue_boto3_session"), patch.object(
) as client_mock, patch.object(_submit_job_bundle.api, "get_queue_user_boto3_session"), patch.object(
api._submit_job_bundle, "_hash_attachments"
) as mock_hash_attachments, patch.object(
S3AssetManager, "upload_assets"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def test_create_job_from_job_bundle_with_all_asset_ref_variants(
# Use a temporary directory for the job bundle
with patch.object(_submit_job_bundle.api, "get_boto3_session"), patch.object(
_submit_job_bundle.api, "get_boto3_client"
) as client_mock, patch.object(_submit_job_bundle.api, "get_queue_boto3_session"), patch.object(
) as client_mock, patch.object(_submit_job_bundle.api, "get_queue_user_boto3_session"), patch.object(
S3AssetManager, "hash_assets_and_create_manifest"
) as mock_hash_assets, patch.object(
S3AssetManager, "upload_assets"
Expand Down
10 changes: 6 additions & 4 deletions test/unit/deadline_client/cli/test_cli_bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def test_cli_bundle_submit(fresh_deadline_config, temp_job_bundle_dir):
) as qp_boto3_client_mock, patch.object(
bundle_group, "_hash_attachments", return_value=[]
), patch.object(
bundle_group, "get_queue_boto3_session"
bundle_group, "get_queue_user_boto3_session"
), patch.object(
bundle_group, "_upload_attachments"
), patch.object(
Expand Down Expand Up @@ -247,7 +247,7 @@ def test_cli_bundle_asset_load_method(fresh_deadline_config, temp_job_bundle_dir
), patch.object(
bundle_group.api, "get_boto3_session"
), patch.object(
bundle_group, "get_queue_boto3_session"
bundle_group, "get_queue_user_boto3_session"
), patch.object(
bundle_group.api, "get_telemetry_client"
):
Expand Down Expand Up @@ -412,7 +412,9 @@ def test_cli_bundle_accept_upload_confirmation(fresh_deadline_config, temp_job_b
), patch.object(bundle_group, "_upload_attachments"), patch.object(
bundle_group.api, "get_boto3_session"
), patch.object(
bundle_group, "get_queue_boto3_session"
bundle_group.api, "get_queue_parameter_definitions", return_value=[]
), patch.object(
bundle_group, "get_queue_user_boto3_session"
), patch.object(
bundle_group.api, "get_telemetry_client"
):
Expand Down Expand Up @@ -485,7 +487,7 @@ def test_cli_bundle_reject_upload_confirmation(fresh_deadline_config, temp_job_b
) as upload_attachments_mock, patch.object(
bundle_group.api, "get_boto3_session"
), patch.object(
bundle_group, "get_queue_boto3_session"
bundle_group, "get_queue_user_boto3_session"
), patch.object(
bundle_group.api, "get_telemetry_client"
):
Expand Down
4 changes: 2 additions & 2 deletions test/unit/deadline_client/cli/test_cli_handle_web_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def test_cli_handle_web_url_download_output_only_required_input(fresh_deadline_c
with patch.object(api, "get_boto3_client") as boto3_client_mock, patch.object(
job_group, "OutputDownloader"
) as MockOutputDownloader, patch.object(job_group, "round", return_value=0), patch.object(
api, "get_queue_boto3_session"
api, "get_queue_user_boto3_session"
):
mock_download = MagicMock()
MockOutputDownloader.return_value.download_job_output = mock_download
Expand Down Expand Up @@ -298,7 +298,7 @@ def test_cli_handle_web_url_download_output_with_optional_input(fresh_deadline_c
with patch.object(api, "get_boto3_client") as boto3_client_mock, patch.object(
job_group, "OutputDownloader"
) as MockOutputDownloader, patch.object(job_group, "round", return_value=0), patch.object(
api, "get_queue_boto3_session"
api, "get_queue_user_boto3_session"
):
mock_download = MagicMock()
MockOutputDownloader.return_value.download_job_output = mock_download
Expand Down
Loading

0 comments on commit cb66de1

Please sign in to comment.