diff --git a/redash/query_runner/athena.py b/redash/query_runner/athena.py
index d9d2736531..db13297caa 100644
--- a/redash/query_runner/athena.py
+++ b/redash/query_runner/athena.py
@@ -8,6 +8,7 @@ logger = logging.getLogger(__name__)
 
 ANNOTATE_QUERY = parse_boolean(os.environ.get('ATHENA_ANNOTATE_QUERY', 'true'))
 SHOW_EXTRA_SETTINGS = parse_boolean(os.environ.get('ATHENA_SHOW_EXTRA_SETTINGS', 'true'))
+ASSUME_ROLE = parse_boolean(os.environ.get('ATHENA_ASSUME_ROLE', 'false'))
 OPTIONAL_CREDENTIALS = parse_boolean(os.environ.get('ATHENA_OPTIONAL_CREDENTIALS', 'true'))
 
 try:
@@ -85,7 +86,7 @@ def configuration_schema(cls):
                 },
             },
             'required': ['region', 's3_staging_dir'],
-            'order': ['region', 'aws_access_key', 'aws_secret_key', 's3_staging_dir', 'schema', 'work_group'],
+            'order': ['region', 's3_staging_dir', 'schema', 'work_group'],
             'secret': ['aws_secret_key']
         }
 
@@ -101,8 +102,29 @@ def configuration_schema(cls):
                 },
             })
 
-        if not OPTIONAL_CREDENTIALS:
-            schema['required'] += ['aws_access_key', 'aws_secret_key']
+        if ASSUME_ROLE:
+            del schema['properties']['aws_access_key']
+            del schema['properties']['aws_secret_key']
+            schema['secret'] = []
+
+            schema['order'].insert(1, 'iam_role')
+            schema['order'].insert(2, 'external_id')
+            schema['properties'].update({
+                'iam_role': {
+                    'type': 'string',
+                    'title': 'IAM role to assume',
+                },
+                'external_id': {
+                    'type': 'string',
+                    'title': 'External ID to be used while STS assume role',
+                },
+            })
+        else:
+            schema['order'].insert(1, 'aws_access_key')
+            schema['order'].insert(2, 'aws_secret_key')
+
+        if not OPTIONAL_CREDENTIALS and not ASSUME_ROLE:
+            schema['required'] += ['aws_access_key', 'aws_secret_key']
 
         return schema
 
@@ -118,13 +140,41 @@ def annotate_query(cls):
     def type(cls):
         return "athena"
 
-    def __get_schema_from_glue(self):
-        client = boto3.client(
-                'glue',
-                aws_access_key_id=self.configuration.get('aws_access_key', None),
-                aws_secret_access_key=self.configuration.get('aws_secret_key', None),
-                region_name=self.configuration['region']
+    def _get_iam_credentials(self, user=None):
+        """Build the AWS credential keyword arguments for boto3 / pyathena.
+
+        When ATHENA_ASSUME_ROLE is enabled, the configured IAM role is assumed
+        through STS and its temporary credentials are returned; the session is
+        named after user.email, or 'redash' when no user is given. Otherwise the
+        statically configured access key pair is returned (values may be None).
+        """
+        if ASSUME_ROLE:
+            role_session_name = 'redash' if user is None else user.email
+            # 'external_id' is optional in the schema and boto3 rejects
+            # ExternalId=None, so pass it only when it is actually configured.
+            external_id = self.configuration.get('external_id')
+            optional_args = {'ExternalId': external_id} if external_id else {}
+            sts = boto3.client('sts')
+            creds = sts.assume_role(
+                RoleArn=self.configuration.get('iam_role'),
+                RoleSessionName=role_session_name,
+                **optional_args
                 )
+            return {
+                'aws_access_key_id': creds['Credentials']['AccessKeyId'],
+                'aws_secret_access_key': creds['Credentials']['SecretAccessKey'],
+                'aws_session_token': creds['Credentials']['SessionToken'],
+                'region_name': self.configuration['region']
+            }
+        else:
+            return {
+                'aws_access_key_id': self.configuration.get('aws_access_key', None),
+                'aws_secret_access_key': self.configuration.get('aws_secret_key', None),
+                'region_name': self.configuration['region']
+            }
+
+    def __get_schema_from_glue(self):
+        client = boto3.client('glue', **self._get_iam_credentials())
         schema = {}
 
         database_paginator = client.get_paginator('get_databases')
@@ -169,14 +219,12 @@ def get_schema(self, get_stats=False):
     def run_query(self, query, user):
         cursor = pyathena.connect(
             s3_staging_dir=self.configuration['s3_staging_dir'],
-            region_name=self.configuration['region'],
-            aws_access_key_id=self.configuration.get('aws_access_key', None),
-            aws_secret_access_key=self.configuration.get('aws_secret_key', None),
             schema_name=self.configuration.get('schema', 'default'),
             encryption_option=self.configuration.get('encryption_option', None),
             kms_key=self.configuration.get('kms_key', None),
             work_group=self.configuration.get('work_group', 'primary'),
-            formatter=SimpleFormatter()).cursor()
+            formatter=SimpleFormatter(),
+            **self._get_iam_credentials(user=user)).cursor()
 
         try:
             cursor.execute(query)