Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Account for s3 paths not having a trailing slash #557

Merged
merged 2 commits into from
Dec 16, 2013
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions awscli/customizations/s3/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,13 +639,25 @@ def add_paths(self, paths):
the destination always have some value.
"""
self.check_path_type(paths)
self._normalize_s3_trailing_slash(paths)
src_path = paths[0]
self.parameters['src'] = src_path
if len(paths) == 2:
self.parameters['dest'] = paths[1]
elif len(paths) == 1:
self.parameters['dest'] = paths[0]

def _normalize_s3_trailing_slash(self, paths):
for i, path in enumerate(paths):
if path.startswith('s3://'):
bucket, key = find_bucket_key(path[5:])
if not key and not path.endswith('/'):
# If only a bucket was specified, we need
# to normalize the path and ensure it ends
# with a '/', s3://bucket -> s3://bucket/
path += '/'
paths[i] = path

def _verify_bucket_exists(self, bucket_name):
session = self.session
service = session.get_service('s3')
Expand Down
17 changes: 17 additions & 0 deletions tests/integration/customizations/s3/test_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,23 @@ def test_cp_to_and_from_s3(self):
with open(full_path, 'r') as f:
self.assertEqual(f.read(), 'this is foo.txt')

def test_cp_without_trailing_slash(self):
# There's a unit test for this, but we still want to verify this
# with an integration test.
bucket_name = self.create_bucket()

# copy file into bucket.
foo_txt = self.files.create_file('foo.txt', 'this is foo.txt')
# Note that the destination has no trailing slash.
p = aws('s3 cp %s s3://%s' % (foo_txt, bucket_name))
self.assert_no_errors(p)

# Make sure object is in bucket.
self.assertTrue(self.key_exists(bucket_name, key_name='foo.txt'))
self.assertEqual(
self.get_key_contents(bucket_name, key_name='foo.txt'),
'this is foo.txt')

def test_cp_s3_s3_multipart(self):
from_bucket = self.create_bucket()
to_bucket = self.create_bucket()
Expand Down
24 changes: 24 additions & 0 deletions tests/unit/customizations/s3/test_cp_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,30 @@ def test_operations_used_in_upload(self):
self.assertEqual(len(self.operations_called), 1, self.operations_called)
self.assertEqual(self.operations_called[0][0].name, 'PutObject')

def test_key_name_added_when_only_bucket_provided(self):
full_path = self.files.create_file('foo.txt', 'mycontent')
cmdline = '%s %s s3://bucket/' % (self.prefix, full_path)
self.parsed_responses = [{'ETag': '"c8afdb36c52cf4727836669019e69222"'}]
self.run_cmd(cmdline, expected_rc=0)
# The only operation we should have called is PutObject.
self.assertEqual(len(self.operations_called), 1, self.operations_called)
self.assertEqual(self.operations_called[0][0].name, 'PutObject')
self.assertEqual(self.operations_called[0][1]['key'], 'foo.txt')
self.assertEqual(self.operations_called[0][1]['bucket'], 'bucket')

def test_trailing_slash_appended(self):
full_path = self.files.create_file('foo.txt', 'mycontent')
# Here we're saying s3://bucket instead of s3://bucket/
# This should still work the same as if we added the trailing slash.
cmdline = '%s %s s3://bucket' % (self.prefix, full_path)
self.parsed_responses = [{'ETag': '"c8afdb36c52cf4727836669019e69222"'}]
self.run_cmd(cmdline, expected_rc=0)
# The only operation we should have called is PutObject.
self.assertEqual(len(self.operations_called), 1, self.operations_called)
self.assertEqual(self.operations_called[0][0].name, 'PutObject')
self.assertEqual(self.operations_called[0][1]['key'], 'foo.txt')
self.assertEqual(self.operations_called[0][1]['bucket'], 'bucket')

def test_operations_used_in_download_file(self):
self.parsed_responses = [
{"ContentLength": "100", "LastModified": "00:00:00Z"},
Expand Down