From cd3ad3f2e2108dfd2c416765c33a2e6dca576846 Mon Sep 17 00:00:00 2001
From: Niels Zeilemaker
Date: Sun, 22 Oct 2017 20:01:18 +0200
Subject: [PATCH] [AIRFLOW-1520] Boto3 S3Hook, S3Log

Closes #2532 from NielsZeilemaker/AIRFLOW-1520
---
 airflow/contrib/hooks/aws_hook.py | 122 ++++++---
 airflow/hooks/S3_hook.py          | 402 ++++++++----------------------
 tests/utils/log/test_logging.py   |   5 +
 3 files changed, 208 insertions(+), 321 deletions(-)

diff --git a/airflow/contrib/hooks/aws_hook.py b/airflow/contrib/hooks/aws_hook.py
index 61d0eb425e788..ca2ee054e2d97 100644
--- a/airflow/contrib/hooks/aws_hook.py
+++ b/airflow/contrib/hooks/aws_hook.py
@@ -14,11 +14,67 @@
 
 import boto3
+import configparser
+import logging
 
 from airflow.exceptions import AirflowException
 from airflow.hooks.base_hook import BaseHook
 
 
+def _parse_s3_config(config_file_name, config_format='boto', profile=None):
+    """
+    Parses a config file for s3 credentials. Can currently
+    parse boto, s3cmd.conf and AWS SDK config formats.
+
+    :param config_file_name: path to the config file
+    :type config_file_name: str
+    :param config_format: config type. One of "boto", "s3cmd" or "aws".
+        Defaults to "boto"
+    :type config_format: str
+    :param profile: profile name in AWS type config file
+    :type profile: str
+    """
+    Config = configparser.ConfigParser()
+    if Config.read(config_file_name):  # pragma: no cover
+        sections = Config.sections()
+    else:
+        raise AirflowException("Couldn't read {0}".format(config_file_name))
+    # Setting option names depending on file format
+    if config_format is None:
+        config_format = 'boto'
+    conf_format = config_format.lower()
+    if conf_format == 'boto':  # pragma: no cover
+        if profile is not None and 'profile ' + profile in sections:
+            cred_section = 'profile ' + profile
+        else:
+            cred_section = 'Credentials'
+    elif conf_format == 'aws' and profile is not None:
+        cred_section = profile
+    else:
+        cred_section = 'default'
+    # Option names
+    if conf_format in ('boto', 'aws'):  # pragma: no cover
+        key_id_option = 'aws_access_key_id'
+        secret_key_option = 'aws_secret_access_key'
+        # security_token_option = 'aws_security_token'
+    else:
+        key_id_option = 'access_key'
+        secret_key_option = 'secret_key'
+    # Actual Parsing
+    if cred_section not in sections:
+        raise AirflowException(
+            "Couldn't find the credentials section {0} in {1}".format(
+                cred_section, config_file_name))
+    else:
+        try:
+            access_key = Config.get(cred_section, key_id_option)
+            secret_key = Config.get(cred_section, secret_key_option)
+        except configparser.NoOptionError:
+            logging.warning("Option Error in parsing s3 config file")
+            raise
+    return (access_key, secret_key)
+
+
 class AwsHook(BaseHook):
     """
     Interact with AWS.
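For reference, a minimal sketch of the credentials file `_parse_s3_config` reads, in the
boto format (the file name and credential values here are hypothetical):

    import configparser

    sample = (
        "[Credentials]\n"
        "aws_access_key_id = AKIAEXAMPLE\n"
        "aws_secret_access_key = example-secret\n"
    )

    config = configparser.ConfigParser()
    config.read_string(sample)  # same parsing _parse_s3_config does via read()
    # For this content, _parse_s3_config('/path/to/boto.cfg') would return
    # ('AKIAEXAMPLE', 'example-secret'):
    print(config.get('Credentials', 'aws_access_key_id'))

The s3cmd format uses a `default` section with `access_key`/`secret_key` options instead.
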
@@ -28,46 +81,59 @@ class AwsHook(BaseHook):
     def __init__(self, aws_conn_id='aws_default'):
         self.aws_conn_id = aws_conn_id
 
-    def get_client_type(self, client_type, region_name=None):
-        try:
-            connection_object = self.get_connection(self.aws_conn_id)
-            aws_access_key_id = connection_object.login
-            aws_secret_access_key = connection_object.password
+    def _get_credentials(self, region_name):
+        aws_access_key_id = None
+        aws_secret_access_key = None
+        s3_endpoint_url = None
+
+        if self.aws_conn_id:
+            try:
+                connection_object = self.get_connection(self.aws_conn_id)
+                if connection_object.login:
+                    aws_access_key_id = connection_object.login
+                    aws_secret_access_key = connection_object.password
+
+                elif 'aws_secret_access_key' in connection_object.extra_dejson:
+                    aws_access_key_id = connection_object.extra_dejson['aws_access_key_id']
+                    aws_secret_access_key = connection_object.extra_dejson['aws_secret_access_key']
+
+                elif 's3_config_file' in connection_object.extra_dejson:
+                    aws_access_key_id, aws_secret_access_key = \
+                        _parse_s3_config(connection_object.extra_dejson['s3_config_file'],
+                                         connection_object.extra_dejson.get('s3_config_format'))
+
+                if region_name is None:
+                    region_name = connection_object.extra_dejson.get('region_name')
+
+                s3_endpoint_url = connection_object.extra_dejson.get('host')
+
+            except AirflowException:
+                # No connection found: fall back on the boto3 credential strategy
+                # http://boto3.readthedocs.io/en/latest/guide/configuration.html
+                pass
 
-        if region_name is None:
-            region_name = connection_object.extra_dejson.get('region_name')
+        return aws_access_key_id, aws_secret_access_key, region_name, s3_endpoint_url
 
-        except AirflowException:
-            # No connection found: fallback on boto3 credential strategy
-            # http://boto3.readthedocs.io/en/latest/guide/configuration.html
-            aws_access_key_id = None
-            aws_secret_access_key = None
+    def get_client_type(self, client_type, region_name=None):
+        aws_access_key_id, aws_secret_access_key, region_name, endpoint_url = \
+            self._get_credentials(region_name)
 
         return boto3.client(
             client_type,
             region_name=region_name,
             aws_access_key_id=aws_access_key_id,
-            aws_secret_access_key=aws_secret_access_key
+            aws_secret_access_key=aws_secret_access_key,
+            endpoint_url=endpoint_url
         )
 
     def get_resource_type(self, resource_type, region_name=None):
-        try:
-            connection_object = self.get_connection(self.aws_conn_id)
-            aws_access_key_id = connection_object.login
-            aws_secret_access_key = connection_object.password
-
-            if region_name is None:
-                region_name = connection_object.extra_dejson.get('region_name')
-
-        except AirflowException:
-            # No connection found: fallback on boto3 credential strategy
-            # http://boto3.readthedocs.io/en/latest/guide/configuration.html
-            aws_access_key_id = None
-            aws_secret_access_key = None
-
+        aws_access_key_id, aws_secret_access_key, region_name, endpoint_url = \
+            self._get_credentials(region_name)
+
         return boto3.resource(
             resource_type,
             region_name=region_name,
             aws_access_key_id=aws_access_key_id,
-            aws_secret_access_key=aws_secret_access_key
+            aws_secret_access_key=aws_secret_access_key,
+            endpoint_url=endpoint_url
         )
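The new `_get_credentials` resolves credentials in one place, in this order: the
connection's login/password, then `aws_access_key_id`/`aws_secret_access_key` from the
connection's extra JSON, then an `s3_config_file` parsed with `_parse_s3_config`.
`region_name` and `host` (passed to boto3 as `endpoint_url`) also come from the extras.
A sketch of an extras payload exercising the second branch (all values hypothetical):

    import json

    # Hypothetical "Extra" field for an Airflow connection with conn_id 'aws_default'.
    extra = json.dumps({
        "aws_access_key_id": "AKIAEXAMPLE",
        "aws_secret_access_key": "example-secret",
        "region_name": "eu-west-1",
        "host": "http://localhost:9000",  # custom endpoint, e.g. an S3-compatible store
    })
    print(extra)
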
diff --git a/airflow/hooks/S3_hook.py b/airflow/hooks/S3_hook.py
index c405001816984..f8052ca4ba6d1 100644
--- a/airflow/hooks/S3_hook.py
+++ b/airflow/hooks/S3_hook.py
@@ -12,142 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from __future__ import division
-
-from future import standard_library
-
-from airflow.utils.log.logging_mixin import LoggingMixin
+from airflow.exceptions import AirflowException
+from airflow.contrib.hooks.aws_hook import AwsHook
 
-standard_library.install_aliases()
+from botocore.exceptions import ClientError
+from six import BytesIO
+from six.moves.urllib.parse import urlparse
 
 import re
 import fnmatch
-import configparser
-import math
-import os
-from urllib.parse import urlparse
-import warnings
-
-import boto
-from boto.s3.connection import S3Connection, NoHostProvided
-from boto.sts import STSConnection
-
-boto.set_stream_logger('boto')
-
-from airflow.exceptions import AirflowException
-from airflow.hooks.base_hook import BaseHook
-
 
-def _parse_s3_config(config_file_name, config_format='boto', profile=None):
+class S3Hook(AwsHook):
     """
-    Parses a config file for s3 credentials. Can currently
-    parse boto, s3cmd.conf and AWS SDK config formats
-
-    :param config_file_name: path to the config file
-    :type config_file_name: str
-    :param config_format: config type. One of "boto", "s3cmd" or "aws".
-        Defaults to "boto"
-    :type config_format: str
-    :param profile: profile name in AWS type config file
-    :type profile: str
-    """
-    Config = configparser.ConfigParser()
-    if Config.read(config_file_name):  # pragma: no cover
-        sections = Config.sections()
-    else:
-        raise AirflowException("Couldn't read {0}".format(config_file_name))
-    # Setting option names depending on file format
-    if config_format is None:
-        config_format = 'boto'
-    conf_format = config_format.lower()
-    if conf_format == 'boto':  # pragma: no cover
-        if profile is not None and 'profile ' + profile in sections:
-            cred_section = 'profile ' + profile
-        else:
-            cred_section = 'Credentials'
-    elif conf_format == 'aws' and profile is not None:
-        cred_section = profile
-    else:
-        cred_section = 'default'
-    # Option names
-    if conf_format in ('boto', 'aws'):  # pragma: no cover
-        key_id_option = 'aws_access_key_id'
-        secret_key_option = 'aws_secret_access_key'
-        # security_token_option = 'aws_security_token'
-    else:
-        key_id_option = 'access_key'
-        secret_key_option = 'secret_key'
-    # Actual Parsing
-    if cred_section not in sections:
-        raise AirflowException("This config file format is not recognized")
-    else:
-        try:
-            access_key = Config.get(cred_section, key_id_option)
-            secret_key = Config.get(cred_section, secret_key_option)
-            calling_format = None
-            if Config.has_option(cred_section, 'calling_format'):
-                calling_format = Config.get(cred_section, 'calling_format')
-        except:
-            log = LoggingMixin().log
-            log.warning("Option Error in parsing s3 config file")
-            raise
-    return (access_key, secret_key, calling_format)
-
-
-class S3Hook(BaseHook):
-    """
-    Interact with S3. This class is a wrapper around the boto library.
+    Interact with AWS S3, using the boto3 library.
""" - def __init__( - self, - s3_conn_id='s3_default'): - self.s3_conn_id = s3_conn_id - self.s3_conn = self.get_connection(s3_conn_id) - self.extra_params = self.s3_conn.extra_dejson - self.profile = self.extra_params.get('profile') - self.calling_format = None - self.s3_host = None - self._creds_in_conn = 'aws_secret_access_key' in self.extra_params - self._creds_in_config_file = 's3_config_file' in self.extra_params - self._default_to_boto = False - if 'host' in self.extra_params: - self.s3_host = self.extra_params['host'] - if self._creds_in_conn: - self._a_key = self.extra_params['aws_access_key_id'] - self._s_key = self.extra_params['aws_secret_access_key'] - if 'calling_format' in self.extra_params: - self.calling_format = self.extra_params['calling_format'] - elif self._creds_in_config_file: - self.s3_config_file = self.extra_params['s3_config_file'] - # The format can be None and will default to boto in the parser - self.s3_config_format = self.extra_params.get('s3_config_format') - else: - self._default_to_boto = True - # STS support for cross account resource access - self._sts_conn_required = ('aws_account_id' in self.extra_params or - 'role_arn' in self.extra_params) - if self._sts_conn_required: - self.role_arn = (self.extra_params.get('role_arn') or - "arn:aws:iam::" + - self.extra_params['aws_account_id'] + - ":role/" + - self.extra_params['aws_iam_role']) - self.connection = self.get_conn() - - def __getstate__(self): - pickled_dict = dict(self.__dict__) - del pickled_dict['connection'] - return pickled_dict - def __setstate__(self, d): - self.__dict__.update(d) - self.__dict__['connection'] = self.get_conn() - - def _parse_s3_url(self, s3url): - warnings.warn( - 'Please note: S3Hook._parse_s3_url() is now ' - 'S3Hook.parse_s3_url() (no leading underscore).', - DeprecationWarning) - return self.parse_s3_url(s3url) + def get_conn(self): + return self.get_client_type('s3') @staticmethod def parse_s3_url(s3url): @@ -159,62 +38,6 @@ def parse_s3_url(s3url): key = parsed_url.path.strip('/') return (bucket_name, key) - def get_conn(self): - """ - Returns the boto S3Connection object. 
- """ - if self._default_to_boto: - return S3Connection(profile_name=self.profile) - a_key = s_key = None - if self._creds_in_config_file: - a_key, s_key, calling_format = _parse_s3_config(self.s3_config_file, - self.s3_config_format, - self.profile) - elif self._creds_in_conn: - a_key = self._a_key - s_key = self._s_key - calling_format = self.calling_format - s3_host = self.s3_host - - if calling_format is None: - calling_format = 'boto.s3.connection.SubdomainCallingFormat' - - if s3_host is None: - s3_host = NoHostProvided - - if self._sts_conn_required: - sts_connection = STSConnection(aws_access_key_id=a_key, - aws_secret_access_key=s_key, - profile_name=self.profile) - assumed_role_object = sts_connection.assume_role( - role_arn=self.role_arn, - role_session_name="Airflow_" + self.s3_conn_id - ) - creds = assumed_role_object.credentials - connection = S3Connection( - aws_access_key_id=creds.access_key, - aws_secret_access_key=creds.secret_key, - calling_format=calling_format, - security_token=creds.session_token - ) - else: - connection = S3Connection(aws_access_key_id=a_key, - aws_secret_access_key=s_key, - calling_format=calling_format, - host=s3_host, - profile_name=self.profile) - return connection - - def get_credentials(self): - if self._creds_in_config_file: - a_key, s_key, calling_format = _parse_s3_config(self.s3_config_file, - self.s3_config_format, - self.profile) - elif self._creds_in_conn: - a_key = self._a_key - s_key = self._s_key - return a_key, s_key - def check_for_bucket(self, bucket_name): """ Check if bucket_name exists. @@ -222,20 +45,35 @@ def check_for_bucket(self, bucket_name): :param bucket_name: the name of the bucket :type bucket_name: str """ - return self.connection.lookup(bucket_name) is not None + try: + self.get_conn().head_bucket(Bucket=bucket_name) + return True + except: + return False def get_bucket(self, bucket_name): """ - Returns a boto.s3.bucket.Bucket object + Returns a boto3.S3.Bucket object :param bucket_name: the name of the bucket :type bucket_name: str """ - return self.connection.get_bucket(bucket_name) + s3 = self.get_resource('s3') + return s3.Bucket(bucket_name) - def list_keys(self, bucket_name, prefix='', delimiter=''): + def check_for_prefix(self, bucket_name, prefix, delimiter): """ - Lists keys in a bucket under prefix and not containing delimiter + Checks that a prefix exists in a bucket + """ + prefix = prefix + delimiter if prefix[-1] != delimiter else prefix + prefix_split = re.split(r'(\w+[{d}])$'.format(d=delimiter), prefix, 1) + previous_level = prefix_split[0] + plist = self.list_prefixes(bucket_name, previous_level, delimiter) + return False if plist is None else prefix in plist + + def list_prefixes(self, bucket_name, prefix='', delimiter=''): + """ + Lists prefixes in a bucket under prefix :param bucket_name: the name of the bucket :type bucket_name: str @@ -244,13 +82,14 @@ def list_keys(self, bucket_name, prefix='', delimiter=''): :param delimiter: the delimiter marks key hierarchy. 
+
+    def list_prefixes(self, bucket_name, prefix='', delimiter=''):
+        """
+        Lists prefixes in a bucket under prefix
 
         :param bucket_name: the name of the bucket
         :type bucket_name: str
@@ -244,13 +82,14 @@ def list_keys(self, bucket_name, prefix='', delimiter=''):
         :param delimiter: the delimiter marks key hierarchy.
         :type delimiter: str
         """
-        b = self.get_bucket(bucket_name)
-        keylist = list(b.list(prefix=prefix, delimiter=delimiter))
-        return [k.name for k in keylist] if keylist != [] else None
+        response = self.get_conn().list_objects_v2(Bucket=bucket_name,
+                                                   Prefix=prefix,
+                                                   Delimiter=delimiter)
+        return [p['Prefix'] for p in response['CommonPrefixes']] if response.get('CommonPrefixes') else None
 
-    def list_prefixes(self, bucket_name, prefix='', delimiter=''):
+    def list_keys(self, bucket_name, prefix='', delimiter=''):
         """
-        Lists prefixes in a bucket under prefix
+        Lists keys in a bucket under prefix and not containing delimiter
 
         :param bucket_name: the name of the bucket
         :type bucket_name: str
@@ -259,24 +98,32 @@ def list_prefixes(self, bucket_name, prefix='', delimiter=''):
         :param delimiter: the delimiter marks key hierarchy.
         :type delimiter: str
         """
-        b = self.get_bucket(bucket_name)
-        plist = b.list(prefix=prefix, delimiter=delimiter)
-        prefix_names = [p.name for p in plist
-                        if isinstance(p, boto.s3.prefix.Prefix)]
-        return prefix_names if prefix_names != [] else None
+        response = self.get_conn().list_objects_v2(Bucket=bucket_name,
+                                                   Prefix=prefix,
+                                                   Delimiter=delimiter)
+        return [k['Key'] for k in response['Contents']] if response.get('Contents') else None
 
     def check_for_key(self, key, bucket_name=None):
         """
-        Checks that a key exists in a bucket
+        Checks if a key exists in a bucket
+
+        :param key: S3 key that will point to the file
+        :type key: str
+        :param bucket_name: Name of the bucket in which the file is stored
+        :type bucket_name: str
         """
         if not bucket_name:
             (bucket_name, key) = self.parse_s3_url(key)
-        bucket = self.get_bucket(bucket_name)
-        return bucket.get_key(key) is not None
+
+        try:
+            self.get_conn().head_object(Bucket=bucket_name, Key=key)
+            return True
+        except ClientError:
+            return False
 
     def get_key(self, key, bucket_name=None):
         """
-        Returns a boto.s3.key.Key object
+        Returns the boto3 GetObject response for the given key
 
         :param key: the path to the key
         :type key: str
@@ -285,8 +132,21 @@ def get_key(self, key, bucket_name=None):
         """
         if not bucket_name:
             (bucket_name, key) = self.parse_s3_url(key)
-        bucket = self.get_bucket(bucket_name)
-        return bucket.get_key(key)
+
+        return self.get_conn().get_object(Bucket=bucket_name, Key=key)
+
+    def read_key(self, key, bucket_name=None):
+        """
+        Reads a key from S3
+
+        :param key: S3 key that will point to the file
+        :type key: str
+        :param bucket_name: Name of the bucket in which the file is stored
+        :type bucket_name: str
+        """
+
+        obj = self.get_key(key, bucket_name)
+        return obj['Body'].read().decode('utf-8')
 
     def check_for_wildcard_key(self,
                                wildcard_key, bucket_name=None,
                                delimiter=''):
@@ -299,7 +159,7 @@ def check_for_wildcard_key(self,
 
     def get_wildcard_key(self, wildcard_key, bucket_name=None, delimiter=''):
         """
-        Returns a boto.s3.key.Key object matching the regular expression
+        Returns the first key matching the wildcard expression
 
-        :param regex_key: the path to the key
-        :type regex_key: str
+        :param wildcard_key: the path to the key
+        :type wildcard_key: str
@@ -308,32 +168,20 @@ def get_wildcard_key(self, wildcard_key, bucket_name=None, delimiter=''):
         """
         if not bucket_name:
             (bucket_name, wildcard_key) = self.parse_s3_url(wildcard_key)
-        bucket = self.get_bucket(bucket_name)
+
         prefix = re.split(r'[*]', wildcard_key, 1)[0]
         klist = self.list_keys(bucket_name, prefix=prefix, delimiter=delimiter)
-        if not klist:
-            return None
-        key_matches = [k for k in klist if fnmatch.fnmatch(k, wildcard_key)]
-        return bucket.get_key(key_matches[0]) if key_matches else None
-
-    def check_for_prefix(self, bucket_name, prefix, delimiter):
-        """
-        Checks that a prefix exists in a bucket
-        """
-        prefix = prefix + delimiter if prefix[-1] != delimiter else prefix
-        prefix_split = re.split(r'(\w+[{d}])$'.format(d=delimiter), prefix, 1)
-        previous_level = prefix_split[0]
-        plist = self.list_prefixes(bucket_name, previous_level, delimiter)
-        return False if plist is None else prefix in plist
-
-    def load_file(
-        self,
-        filename,
-        key,
-        bucket_name=None,
-        replace=False,
-        multipart_bytes=5 * (1024 ** 3),
-        encrypt=False):
+        if klist:
+            key_matches = [k for k in klist if fnmatch.fnmatch(k, wildcard_key)]
+            if key_matches:
+                return self.get_key(key_matches[0], bucket_name)
+
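`get_wildcard_key` narrows the listing to everything before the first `*`, then filters
with `fnmatch`. A sketch of that filtering step in isolation (key names hypothetical):

    import fnmatch

    klist = ['data/2017-10-01.csv', 'data/2017-10-02.csv', 'notes.txt']
    wildcard_key = 'data/2017-10-*.csv'
    key_matches = [k for k in klist if fnmatch.fnmatch(k, wildcard_key)]
    print(key_matches[0])  # 'data/2017-10-01.csv' is what get_key() would then fetch
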
+    def load_file(self,
+                  filename,
+                  key,
+                  bucket_name=None,
+                  replace=False,
+                  encrypt=False):
         """
         Loads a local file to S3
 
@@ -347,64 +195,34 @@ def load_file(
             if it already exists. If replace is False and the key exists, an
             error will be raised.
         :type replace: bool
-        :param multipart_bytes: If provided, the file is uploaded in parts of
-            this size (minimum 5242880). The default value is 5GB, since S3
-            cannot accept non-multipart uploads for files larger than 5GB. If
-            the file is smaller than the specified limit, the option will be
-            ignored.
-        :type multipart_bytes: int
         :param encrypt: If True, the file will be encrypted on the server-side
             by S3 and will be stored in an encrypted form while at rest in S3.
         :type encrypt: bool
         """
         if not bucket_name:
             (bucket_name, key) = self.parse_s3_url(key)
-        bucket = self.get_bucket(bucket_name)
-        key_obj = bucket.get_key(key)
-        if not replace and key_obj:
-            raise ValueError("The key {key} already exists.".format(
-                **locals()))
-
-        key_size = os.path.getsize(filename)
-        if multipart_bytes and key_size >= multipart_bytes:
-            # multipart upload
-            from filechunkio import FileChunkIO
-            mp = bucket.initiate_multipart_upload(key_name=key,
-                                                  encrypt_key=encrypt)
-            total_chunks = int(math.ceil(key_size / multipart_bytes))
-            sent_bytes = 0
-            try:
-                for chunk in range(total_chunks):
-                    offset = chunk * multipart_bytes
-                    bytes = min(multipart_bytes, key_size - offset)
-                    with FileChunkIO(filename, 'r', offset=offset, bytes=bytes) as fp:
-                        self.log.info('Sending chunk %s of %s...', chunk + 1, total_chunks)
-                        mp.upload_part_from_file(fp, part_num=chunk + 1)
-            except:
-                mp.cancel_upload()
-                raise
-            mp.complete_upload()
-        else:
-            # regular upload
-            if not key_obj:
-                key_obj = bucket.new_key(key_name=key)
-            key_size = key_obj.set_contents_from_filename(filename,
-                                                          replace=replace,
-                                                          encrypt_key=encrypt)
-        self.log.info(
-            "The key {key} now contains {key_size} bytes".format(**locals())
-        )
-
-    def load_string(self, string_data,
-                    key, bucket_name=None,
+
+        if not replace and self.check_for_key(key, bucket_name):
+            raise ValueError("The key {key} already exists.".format(key=key))
+
+        extra_args = {}
+        if encrypt:
+            extra_args['ServerSideEncryption'] = "AES256"
+
+        client = self.get_conn()
+        client.upload_file(filename, bucket_name, key, ExtraArgs=extra_args)
+
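With boto3's `upload_file`, multipart handling is delegated to the transfer manager, so
the old `multipart_bytes` plumbing goes away. A usage sketch against a live connection
(connection id, bucket and paths are hypothetical):

    from airflow.hooks.S3_hook import S3Hook

    hook = S3Hook(aws_conn_id='aws_default')
    hook.load_file(filename='/tmp/report.csv',
                   key='reports/report.csv',
                   bucket_name='my-bucket',
                   replace=True,
                   encrypt=True)  # stores the object with SSE (AES256)
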
+    def load_string(self,
+                    string_data,
+                    key,
+                    bucket_name=None,
                     replace=False,
                     encrypt=False):
         """
-        Loads a local file to S3
+        Loads a string to S3
 
-        This is provided as a convenience to drop a file in S3. It uses the
-        boto infrastructure to ship a file to s3. It is currently using only
-        a single part download, and should not be used to move large files.
+        This is provided as a convenience to drop a string in S3. It uses the
+        boto3 infrastructure to ship the string to S3.
 
         :param string_data: string to set as content for the key.
         :type string_data: str
@@ -418,20 +236,19 @@ def load_string(self, string_data,
         :param encrypt: If True, the file will be encrypted on the server-side
             by S3 and will be stored in an encrypted form while at rest in S3.
         :type encrypt: bool
-
         """
         if not bucket_name:
             (bucket_name, key) = self.parse_s3_url(key)
-        bucket = self.get_bucket(bucket_name)
-        key_obj = bucket.get_key(key)
-        if not replace and key_obj:
-            raise ValueError("The key {key} already exists.".format(
-                **locals()))
-        if not key_obj:
-            key_obj = bucket.new_key(key_name=key)
-        key_size = key_obj.set_contents_from_string(string_data,
-                                                    replace=replace,
-                                                    encrypt_key=encrypt)
-        self.log.info(
-            "The key {key} now contains {key_size} bytes".format(**locals())
-        )
+
+        if not replace and self.check_for_key(key, bucket_name):
+            raise ValueError("The key {key} already exists.".format(key=key))
+
+        extra_args = {}
+        if encrypt:
+            extra_args['ServerSideEncryption'] = "AES256"
+
+        # upload_fileobj expects a binary file-like object, so encode the string
+        filelike_buffer = BytesIO(string_data.encode('utf-8'))
+
+        client = self.get_conn()
+        client.upload_fileobj(filelike_buffer, bucket_name, key, ExtraArgs=extra_args)
diff --git a/tests/utils/log/test_logging.py b/tests/utils/log/test_logging.py
index 8df6dfc8f6116..57f869fdd8a19 100644
--- a/tests/utils/log/test_logging.py
+++ b/tests/utils/log/test_logging.py
@@ -60,6 +60,10 @@ def test_log_exists_raises(self):
         self.hook_inst_mock.get_key.side_effect = Exception('error')
         self.assertFalse(S3TaskHandler().log_exists(self.remote_log_location))
 
+    def test_log_exists_false(self):
+        self.hook_inst_mock.check_for_key.return_value = False
+        self.assertFalse(S3TaskHandler().log_exists(self.remote_log_location))
+
     def test_log_exists_no_hook(self):
         self.hook_mock.side_effect = Exception('Failed to connect')
         self.assertFalse(S3TaskHandler().log_exists(self.remote_log_location))
@@ -100,6 +104,7 @@ def test_write(self):
         )
 
     def test_write_raises(self):
+        self.hook_inst_mock.read_key.return_value = ''
         self.hook_inst_mock.load_string.side_effect = Exception('error')
         handler = S3TaskHandler()
         with mock.patch.object(handler.log, 'error') as mock_error:
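The S3TaskHandler tests above exercise the new hook surface (`check_for_key`,
`read_key`, `load_string`). An end-to-end sketch of that surface against a live
connection (connection id, bucket and key are hypothetical):

    from airflow.hooks.S3_hook import S3Hook

    hook = S3Hook(aws_conn_id='aws_default')
    hook.load_string('task log line\n', key='logs/dag/task/1.log',
                     bucket_name='my-log-bucket', replace=True)
    if hook.check_for_key('logs/dag/task/1.log', bucket_name='my-log-bucket'):
        print(hook.read_key('logs/dag/task/1.log', bucket_name='my-log-bucket'))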