diff --git a/ap/auth/views/views.py b/ap/auth/views/views.py index 3a16d5d5..e2692e1b 100644 --- a/ap/auth/views/views.py +++ b/ap/auth/views/views.py @@ -60,6 +60,8 @@ def get(self, request): self._login_success(request, user, token) return redirect("/") except OAuthError as error: + if settings.DEBUG: + raise error sentry_sdk.capture_exception(error) return self._login_failure() diff --git a/ap/aws/base.py b/ap/aws/base.py index ec108129..34a84e9a 100644 --- a/ap/aws/base.py +++ b/ap/aws/base.py @@ -11,7 +11,7 @@ class AWSService: aws_service_name: str = "" def __init__(self, assume_role_name=None, profile_name=None, region_name=None): - self.assume_role_name = assume_role_name or settings.DEFAULT_ROLE_ARN + self.assume_role_name = assume_role_name self.profile_name = profile_name self.region_name = region_name or settings.AWS_DEFAULT_REGION diff --git a/ap/aws/glue.py b/ap/aws/glue.py index 4a201649..49449bc7 100644 --- a/ap/aws/glue.py +++ b/ap/aws/glue.py @@ -10,6 +10,12 @@ def __init__(self, catalog_id=None): super().__init__() self.catalog_id = catalog_id or settings.GLUE_CATALOG_ID + def get_database_list(self): + databases = self._request("get_databases") + if not databases: + return [] + return databases["DatabaseList"] + def get_table_list(self, database_name): tables = self._request("get_tables", CatalogId=self.catalog_id, DatabaseName=database_name) if not tables: diff --git a/ap/aws/session.py b/ap/aws/session.py index 64c8feeb..59b000a6 100644 --- a/ap/aws/session.py +++ b/ap/aws/session.py @@ -12,7 +12,10 @@ class BotoSession: def __init__(self, assume_role_name=None, profile_name=None, region_name=None): - self.assume_role_name = assume_role_name or settings.DEFAULT_ROLE_ARN + self.assume_role_name = assume_role_name + if self.assume_role_name is None and settings.DEFAULT_ROLE_ARN is not None: + self.assume_role_name = settings.DEFAULT_ROLE_ARN + self.region_name = region_name or settings.AWS_DEFAULT_REGION self.profile_name = profile_name @@ -30,7 +33,7 @@ def refreshable_credentials(self): def get_sts_credentials(self) -> dict: log.info("Getting credentials using STS") - boto3_ini_session = boto3.Session(region_name=self.region_name) + boto3_ini_session = boto3.Session(region_name=settings.AWS_DEFAULT_REGION) sts = boto3_ini_session.client("sts") response = sts.assume_role( RoleArn=self.assume_role_name, diff --git a/ap/database_access/views.py b/ap/database_access/views.py index a69f1fd0..d8c23a3d 100644 --- a/ap/database_access/views.py +++ b/ap/database_access/views.py @@ -11,7 +11,7 @@ class DatabaseListView(OIDCLoginRequiredMixin, TemplateView): def get_context_data(self, **kwargs: Any) -> dict[str, Any]: context = super().get_context_data(**kwargs) - context["databases"] = aws.GlueService().client.get_databases()["DatabaseList"] + context["databases"] = aws.GlueService().get_database_list() return context diff --git a/ap/settings/common.py b/ap/settings/common.py index f802c13c..d17be6eb 100644 --- a/ap/settings/common.py +++ b/ap/settings/common.py @@ -320,6 +320,7 @@ QUICKSIGHT_DOMAINS = os.environ.get("QUICKSIGHT_DOMAINS", "").split(",") -DEFAULT_ROLE_ARN = os.environ.get("DEFAULT_ROLE_ARN", None) +# should not be required when using a service role e.g. in dev/prod +DEFAULT_STS_ROLE_TO_ASSUME = os.environ.get("DEFAULT_STS_ROLE_TO_ASSUME", None) AWS_DEFAULT_REGION = os.environ.get("AWS_DEFAULT_REGION", "eu-west-2")