Skip to content

Commit

Permalink
Update default behaviour when retreiving STS creds
Browse files Browse the repository at this point in the history
We only want to default to using STS to retrieve AWS credentials when
not using a service role that already has permission to call services.
These changes attempt to make that more explicit. For now, we should
only default to using STS when running locally.
  • Loading branch information
michaeljcollinsuk committed Aug 13, 2024
1 parent 00e618c commit 1ad9ff0
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 5 deletions.
2 changes: 2 additions & 0 deletions ap/auth/views/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ def get(self, request):
self._login_success(request, user, token)
return redirect("/")
except OAuthError as error:
if settings.DEBUG:
raise error
sentry_sdk.capture_exception(error)
return self._login_failure()

Expand Down
2 changes: 1 addition & 1 deletion ap/aws/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class AWSService:
aws_service_name: str = ""

def __init__(self, assume_role_name=None, profile_name=None, region_name=None):
self.assume_role_name = assume_role_name or settings.DEFAULT_ROLE_ARN
self.assume_role_name = assume_role_name
self.profile_name = profile_name
self.region_name = region_name or settings.AWS_DEFAULT_REGION

Expand Down
6 changes: 6 additions & 0 deletions ap/aws/glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ def __init__(self, catalog_id=None):
super().__init__()
self.catalog_id = catalog_id or settings.GLUE_CATALOG_ID

def get_database_list(self):
databases = self._request("get_databases")
if not databases:
return []
return databases["DatabaseList"]

def get_table_list(self, database_name):
tables = self._request("get_tables", CatalogId=self.catalog_id, DatabaseName=database_name)
if not tables:
Expand Down
7 changes: 5 additions & 2 deletions ap/aws/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@

class BotoSession:
def __init__(self, assume_role_name=None, profile_name=None, region_name=None):
self.assume_role_name = assume_role_name or settings.DEFAULT_ROLE_ARN
self.assume_role_name = assume_role_name
if self.assume_role_name is None and settings.DEFAULT_ROLE_ARN is not None:
self.assume_role_name = settings.DEFAULT_ROLE_ARN

self.region_name = region_name or settings.AWS_DEFAULT_REGION
self.profile_name = profile_name

Expand All @@ -30,7 +33,7 @@ def refreshable_credentials(self):

def get_sts_credentials(self) -> dict:
log.info("Getting credentials using STS")
boto3_ini_session = boto3.Session(region_name=self.region_name)
boto3_ini_session = boto3.Session(region_name=settings.AWS_DEFAULT_REGION)
sts = boto3_ini_session.client("sts")
response = sts.assume_role(
RoleArn=self.assume_role_name,
Expand Down
2 changes: 1 addition & 1 deletion ap/database_access/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class DatabaseListView(OIDCLoginRequiredMixin, TemplateView):

def get_context_data(self, **kwargs: Any) -> dict[str, Any]:
context = super().get_context_data(**kwargs)
context["databases"] = aws.GlueService().client.get_databases()["DatabaseList"]
context["databases"] = aws.GlueService().get_database_list()
return context


Expand Down
3 changes: 2 additions & 1 deletion ap/settings/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,7 @@

QUICKSIGHT_DOMAINS = os.environ.get("QUICKSIGHT_DOMAINS", "").split(",")

DEFAULT_ROLE_ARN = os.environ.get("DEFAULT_ROLE_ARN", None)
# should not be required when using a service role e.g. in dev/prod
DEFAULT_STS_ROLE_TO_ASSUME = os.environ.get("DEFAULT_STS_ROLE_TO_ASSUME", None)

AWS_DEFAULT_REGION = os.environ.get("AWS_DEFAULT_REGION", "eu-west-2")

0 comments on commit 1ad9ff0

Please sign in to comment.