Skip to content

Commit

Permalink
[AIRFLOW-1797] S3Hook.load_string didn't work on Python3
Browse files Browse the repository at this point in the history
With the switch to Boto3 we now need the content
to be bytes, not a
string. On Python2 there is no difference, but for
Python3 this matters.

And since there were no real tests covering the
S3Hook I've added some
basic ones.

Closes #2771 from ashb/AIRFLOW-1797

(cherry picked from commit 28411b1)
Signed-off-by: Bolke de Bruin <[email protected]>
  • Loading branch information
ashb authored and bolkedebruin committed Nov 9, 2017
1 parent 84623da commit 6b7c17d
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 23 deletions.
7 changes: 4 additions & 3 deletions airflow/hooks/S3_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from airflow.exceptions import AirflowException
from airflow.contrib.hooks.aws_hook import AwsHook

from six import StringIO
from six import BytesIO
from urllib.parse import urlparse
import re
import fnmatch
Expand Down Expand Up @@ -217,7 +217,8 @@ def load_string(self,
key,
bucket_name=None,
replace=False,
encrypt=False):
encrypt=False,
encoding='utf-8'):
"""
Loads a string to S3
Expand Down Expand Up @@ -247,7 +248,7 @@ def load_string(self,
if encrypt:
extra_args['ServerSideEncryption'] = "AES256"

filelike_buffer = StringIO(string_data)
filelike_buffer = BytesIO(string_data.encode(encoding))

client = self.get_conn()
client.upload_fileobj(filelike_buffer, bucket_name, key, ExtraArgs=extra_args)
20 changes: 0 additions & 20 deletions tests/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2386,26 +2386,6 @@ def test_host_encoded_https_connection(self, mock_get_connection):
self.assertEqual(hook.base_url, 'https://localhost')


try:
from airflow.hooks.S3_hook import S3Hook
except ImportError:
S3Hook = None


@unittest.skipIf(S3Hook is None,
"Skipping test because S3Hook is not installed")
class S3HookTest(unittest.TestCase):
def setUp(self):
configuration.load_test_config()
self.s3_test_url = "s3://test/this/is/not/a-real-key.txt"

def test_parse_s3_url(self):
parsed = S3Hook.parse_s3_url(self.s3_test_url)
self.assertEqual(parsed,
("test", "this/is/not/a-real-key.txt"),
"Incorrect parsing of the s3 url")


send_email_test = mock.Mock()


Expand Down
74 changes: 74 additions & 0 deletions tests/hooks/test_s3_hook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import unittest

from airflow import configuration

try:
from airflow.hooks.S3_hook import S3Hook
except ImportError:
S3Hook = None


try:
import boto3
from moto import mock_s3
except ImportError:
mock_s3 = None


@unittest.skipIf(S3Hook is None,
"Skipping test because S3Hook is not available")
@unittest.skipIf(mock_s3 is None,
"Skipping test because moto.mock_s3 is not available")
class TestS3Hook(unittest.TestCase):
def setUp(self):
configuration.load_test_config()
self.s3_test_url = "s3://test/this/is/not/a-real-key.txt"

def test_parse_s3_url(self):
parsed = S3Hook.parse_s3_url(self.s3_test_url)
self.assertEqual(parsed,
("test", "this/is/not/a-real-key.txt"),
"Incorrect parsing of the s3 url")

@mock_s3
def test_load_string(self):
hook = S3Hook(aws_conn_id=None)
conn = hook.get_conn()
# We need to create the bucket since this is all in Moto's 'virtual'
# AWS account
conn.create_bucket(Bucket="mybucket")

hook.load_string(u"Contént", "my_key", "mybucket")
body = boto3.resource('s3').Object('mybucket', 'my_key').get()['Body'].read()

self.assertEqual(body, b'Cont\xC3\xA9nt')

@mock_s3
def test_read_key(self):
hook = S3Hook(aws_conn_id=None)
conn = hook.get_conn()
# We need to create the bucket since this is all in Moto's 'virtual'
# AWS account
conn.create_bucket(Bucket='mybucket')
conn.put_object(Bucket='mybucket', Key='my_key', Body=b'Cont\xC3\xA9nt')

self.assertEqual(hook.read_key('my_key', 'mybucket'), u'Contént')


if __name__ == '__main__':
unittest.main()

0 comments on commit 6b7c17d

Please sign in to comment.