Skip to content

Commit

Permalink
Merge pull request #1623 from kyleknap/kms-ssec
Browse files Browse the repository at this point in the history
Add support for --sse aws:kms and --sse-c
  • Loading branch information
kyleknap committed Nov 18, 2015
2 parents e86e7f1 + f0bb239 commit f1ed24b
Show file tree
Hide file tree
Showing 16 changed files with 1,010 additions and 148 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ Next Release (TBD)
* bugfix:``aws opsworks register``: Support ``--no-verify-ssl``
argument for the ``aws opsworks register`` command
(`issue 1632 <https://github.com/aws/aws-cli/pull/1632>`__)
* feature:``s3``: Add support for Server-Side Encryption with KMS
and Server-Side Encryption with Customer-Provided Keys.
(`issue 1623 <https://github.com/aws/aws-cli/pull/1623>`__)


1.9.7
Expand Down
6 changes: 5 additions & 1 deletion awscli/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,14 +206,16 @@ class CustomArgument(BaseCLIArgument):
def __init__(self, name, help_text='', dest=None, default=None,
action=None, required=None, choices=None, nargs=None,
cli_type_name=None, group_name=None, positional_arg=False,
no_paramfile=False, argument_model=None, synopsis=''):
no_paramfile=False, argument_model=None, synopsis='',
const=None):
self._name = name
self._help = help_text
self._dest = dest
self._default = default
self._action = action
self._required = required
self._nargs = nargs
self._const = const
self._cli_type_name = cli_type_name
self._group_name = group_name
self._positional_arg = positional_arg
Expand Down Expand Up @@ -275,6 +277,8 @@ def add_to_parser(self, parser):
kwargs['required'] = self._required
if self._nargs is not None:
kwargs['nargs'] = self._nargs
if self._const is not None:
kwargs['const'] = self._const
parser.add_argument(cli_name, **kwargs)

@property
Expand Down
9 changes: 7 additions & 2 deletions awscli/customizations/s3/filegenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,17 @@ class FileGenerator(object):
``FileInfo`` objects to send to a ``Comparator`` or ``S3Handler``.
"""
def __init__(self, client, operation_name, follow_symlinks=True,
page_size=None, result_queue=None):
page_size=None, result_queue=None, request_parameters=None):
self._client = client
self.operation_name = operation_name
self.follow_symlinks = follow_symlinks
self.page_size = page_size
self.result_queue = result_queue
if not result_queue:
self.result_queue = queue.Queue()
self.request_parameters = {}
if request_parameters is not None:
self.request_parameters = request_parameters

def call(self, files):
"""
Expand Down Expand Up @@ -320,7 +323,9 @@ def _list_single_object(self, s3_path):
# instead use a HeadObject request.
bucket, key = find_bucket_key(s3_path)
try:
response = self._client.head_object(Bucket=bucket, Key=key)
params = {'Bucket': bucket, 'Key': key}
params.update(self.request_parameters.get('HeadObject', {}))
response = self._client.head_object(**params)
except ClientError as e:
# We want to try to give a more helpful error message.
# This is what the customer is going to see so we want to
Expand Down
90 changes: 27 additions & 63 deletions awscli/customizations/s3/fileinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@

from botocore.compat import quote
from awscli.customizations.s3.utils import find_bucket_key, \
uni_print, guess_content_type, MD5Error, bytes_print, set_file_utime
uni_print, guess_content_type, MD5Error, bytes_print, set_file_utime, \
RequestParamsMapper


LOGGER = logging.getLogger(__name__)
Expand All @@ -38,7 +39,6 @@ def save_file(filename, response_data, last_update, is_stream=False):
"""
body = response_data['Body']
etag = response_data['ETag'][1:-1]
sse = response_data.get('ServerSideEncryption', None)
if not is_stream:
d = os.path.dirname(filename)
try:
Expand All @@ -59,7 +59,7 @@ def save_file(filename, response_data, last_update, is_stream=False):
with open(filename, 'wb') as out_file:
write_to_file(out_file, etag, md5, file_chunks)

if not _is_multipart_etag(etag) and sse != 'aws:kms':
if _can_validate_md5_with_etag(etag, response_data):
if etag != md5.hexdigest():
if not is_stream:
os.remove(filename)
Expand All @@ -75,6 +75,15 @@ def save_file(filename, response_data, last_update, is_stream=False):
sys.stdout.flush()


def _can_validate_md5_with_etag(etag, response_data):
    """Return True when the object's ETag can be checked against a local MD5.

    Validation is skipped for multipart uploads (the ETag is not a plain
    MD5 digest there) and for responses reporting SSE-KMS
    (``ServerSideEncryption == 'aws:kms'``) or SSE-C
    (``SSECustomerAlgorithm`` present), where the ETag does not match the
    plaintext MD5 either.
    """
    if _is_multipart_etag(etag):
        return False
    if response_data.get('ServerSideEncryption') == 'aws:kms':
        return False
    if response_data.get('SSECustomerAlgorithm') is not None:
        return False
    return True


def write_to_file(out_file, etag, md5, file_chunks, is_stream=False):
"""
Updates the etag for each file chunk. It will write to the file if it a
Expand Down Expand Up @@ -188,11 +197,9 @@ def __init__(self, src, dest=None, compare_key=None, size=None,
self.size = size
self.last_update = last_update
# Usually inject ``parameters`` from ``BasicTask`` class.
self.parameters = {}
if parameters is not None:
self.parameters = parameters
else:
self.parameters = {'acl': None,
'sse': None}
self.source_client = source_client
self.is_stream = is_stream
self.associated_response_data = associated_response_data
Expand All @@ -204,60 +211,10 @@ def set_size_from_s3(self):
bucket, key = find_bucket_key(self.src)
params = {'Bucket': bucket,
'Key': key}
RequestParamsMapper.map_head_object_params(params, self.parameters)
response_data = self.client.head_object(**params)
self.size = int(response_data['ContentLength'])

def _permission_to_param(self, permission):
if permission == 'read':
return 'GrantRead'
if permission == 'full':
return 'GrantFullControl'
if permission == 'readacl':
return 'GrantReadACP'
if permission == 'writeacl':
return 'GrantWriteACP'
raise ValueError('permission must be one of: '
'read|readacl|writeacl|full')

def _handle_object_params(self, params):
if self.parameters['acl']:
params['ACL'] = self.parameters['acl'][0]
if self.parameters['grants']:
for grant in self.parameters['grants']:
try:
permission, grantee = grant.split('=', 1)
except ValueError:
raise ValueError('grants should be of the form '
'permission=principal')
params[self._permission_to_param(permission)] = grantee
if self.parameters['sse']:
params['ServerSideEncryption'] = 'AES256'
if self.parameters['storage_class']:
params['StorageClass'] = self.parameters['storage_class'][0]
if self.parameters['website_redirect']:
params['WebsiteRedirectLocation'] = \
self.parameters['website_redirect'][0]
if self.parameters['guess_mime_type']:
self._inject_content_type(params, self.src)
if self.parameters['content_type']:
params['ContentType'] = self.parameters['content_type'][0]
if self.parameters['cache_control']:
params['CacheControl'] = self.parameters['cache_control'][0]
if self.parameters['content_disposition']:
params['ContentDisposition'] = \
self.parameters['content_disposition'][0]
if self.parameters['content_encoding']:
params['ContentEncoding'] = self.parameters['content_encoding'][0]
if self.parameters['content_language']:
params['ContentLanguage'] = self.parameters['content_language'][0]
if self.parameters['expires']:
params['Expires'] = self.parameters['expires'][0]

def _handle_metadata_directive(self, params):
if self.parameters['metadata_directive']:
params['MetadataDirective'] = \
self.parameters['metadata_directive'][0]

def is_glacier_compatible(self):
"""Determines if a file info object is glacier compatible
Expand Down Expand Up @@ -301,10 +258,14 @@ def _handle_upload(self, body):
'Key': key,
'Body': body,
}
self._handle_object_params(params)
self._inject_content_type(params)
RequestParamsMapper.map_put_object_params(params, self.parameters)
response_data = self.client.put_object(**params)

def _inject_content_type(self, params, filename):
def _inject_content_type(self, params):
if not self.parameters['guess_mime_type']:
return
filename = self.src
# Add a content type param if we can guess the type.
try:
guessed_type = guess_content_type(filename)
Expand All @@ -331,6 +292,7 @@ def download(self):
"""
bucket, key = find_bucket_key(self.src)
params = {'Bucket': bucket, 'Key': key}
RequestParamsMapper.map_get_object_params(params, self.parameters)
response_data = self.client.get_object(**params)
save_file(self.dest, response_data, self.last_update,
self.is_stream)
Expand All @@ -343,9 +305,9 @@ def copy(self):
bucket, key = find_bucket_key(self.dest)
params = {'Bucket': bucket,
'CopySource': copy_source, 'Key': key}
self._handle_object_params(params)
self._handle_metadata_directive(params)
self.client.copy_object(**params)
self._inject_content_type(params)
RequestParamsMapper.map_copy_object_params(params, self.parameters)
response_data = self.client.copy_object(**params)

def delete(self):
"""
Expand Down Expand Up @@ -378,7 +340,9 @@ def move(self):
def create_multipart_upload(self):
    """Start a multipart upload to this file's destination.

    Builds the CreateMultipartUpload request from the destination path,
    the guessed content type, and the mapped CLI parameters, then returns
    the upload id issued by S3.
    """
    bucket, key = find_bucket_key(self.dest)
    request = {'Bucket': bucket, 'Key': key}
    self._inject_content_type(request)
    RequestParamsMapper.map_create_multipart_upload_params(
        request, self.parameters)
    response = self.client.create_multipart_upload(**request)
    return response['UploadId']
29 changes: 17 additions & 12 deletions awscli/customizations/s3/s3handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,19 @@ def __init__(self, session, params, result_queue=None,
self.result_queue = result_queue
if not self.result_queue:
self.result_queue = queue.Queue()
self.params = {'dryrun': False, 'quiet': False, 'acl': None,
'guess_mime_type': True, 'sse': False,
'storage_class': None, 'website_redirect': None,
'content_type': None, 'cache_control': None,
'content_disposition': None, 'content_encoding': None,
'content_language': None, 'expires': None,
'grants': None, 'only_show_errors': False,
'is_stream': False, 'paths_type': None,
'expected_size': None, 'metadata_directive': None,
'ignore_glacier_warnings': False}
self.params = {
'dryrun': False, 'quiet': False, 'acl': None,
'guess_mime_type': True, 'sse_c_copy_source': None,
'sse_c_copy_source_key': None, 'sse': None,
'sse_c': None, 'sse_c_key': None, 'sse_kms_key_id': None,
'storage_class': None, 'website_redirect': None,
'content_type': None, 'cache_control': None,
'content_disposition': None, 'content_encoding': None,
'content_language': None, 'expires': None, 'grants': None,
'only_show_errors': False, 'is_stream': False,
'paths_type': None, 'expected_size': None,
'metadata_directive': None, 'ignore_glacier_warnings': False
}
self.params['region'] = params['region']
for key in self.params.keys():
if key in params:
Expand Down Expand Up @@ -287,7 +290,8 @@ def _do_enqueue_range_download_tasks(self, filename, chunksize,
task = tasks.DownloadPartTask(
part_number=i, chunk_size=chunksize,
result_queue=self.result_queue, filename=filename,
context=context, io_queue=self.write_queue)
context=context, io_queue=self.write_queue,
params=self.params)
self.executor.submit(task)

def _enqueue_multipart_upload_tasks(self, filename,
Expand Down Expand Up @@ -350,7 +354,8 @@ def _enqueue_upload_single_part_task(self, part_number, chunk_size,
payload=None):
kwargs = {'part_number': part_number, 'chunk_size': chunk_size,
'result_queue': self.result_queue,
'upload_context': upload_context, 'filename': filename}
'upload_context': upload_context, 'filename': filename,
'params': self.params}
if payload:
kwargs['payload'] = payload
task = task_class(**kwargs)
Expand Down
Loading

0 comments on commit f1ed24b

Please sign in to comment.