diff --git a/CHANGELOG.rst b/CHANGELOG.rst index a1de4f7e2a2f..c0b5afe6645f 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -5,6 +5,9 @@ CHANGELOG Next Release (TBD) ================== +* feature:``aws s3 cp``: Added ability to upload local + file streams from standard input to s3 and download s3 + objects as local file streams to standard output. * feature:Page Size: Add a ``--page-size`` option, that controls page size when perfoming an operation that uses pagination. diff --git a/awscli/customizations/s3/constants.py b/awscli/customizations/s3/constants.py index d0877eed26b2..7c0b7c4fcbc3 100644 --- a/awscli/customizations/s3/constants.py +++ b/awscli/customizations/s3/constants.py @@ -18,3 +18,4 @@ MAX_SINGLE_UPLOAD_SIZE = 5 * (1024 ** 3) MAX_UPLOAD_SIZE = 5 * (1024 ** 4) MAX_QUEUE_SIZE = 1000 +STREAM_INPUT_TIMEOUT = 0.1 diff --git a/awscli/customizations/s3/executor.py b/awscli/customizations/s3/executor.py index 872f181ef055..c6866246022a 100644 --- a/awscli/customizations/s3/executor.py +++ b/awscli/customizations/s3/executor.py @@ -15,8 +15,8 @@ import sys import threading -from awscli.customizations.s3.utils import uni_print, \ - IORequest, IOCloseRequest, StablePriorityQueue +from awscli.customizations.s3.utils import uni_print, bytes_print, \ + IORequest, IOCloseRequest, StablePriorityQueue from awscli.customizations.s3.tasks import OrderableTask @@ -50,8 +50,7 @@ def __init__(self, num_threads, result_queue, self.quiet = quiet self.threads_list = [] self.write_queue = write_queue - self.print_thread = PrintThread(self.result_queue, - self.quiet) + self.print_thread = PrintThread(self.result_queue, self.quiet) self.print_thread.daemon = True self.io_thread = IOWriterThread(self.write_queue) @@ -153,23 +152,28 @@ def run(self): self._cleanup() return elif isinstance(task, IORequest): - filename, offset, data = task - fileobj = self.fd_descriptor_cache.get(filename) - if fileobj is None: - fileobj = open(filename, 'rb+') - self.fd_descriptor_cache[filename] = fileobj - fileobj.seek(offset) + filename, offset, data, is_stream = task + if is_stream: + fileobj = sys.stdout + bytes_print(data) + else: + fileobj = self.fd_descriptor_cache.get(filename) + if fileobj is None: + fileobj = open(filename, 'rb+') + self.fd_descriptor_cache[filename] = fileobj + fileobj.seek(offset) + fileobj.write(data) LOGGER.debug("Writing data to: %s, offset: %s", filename, offset) - fileobj.write(data) fileobj.flush() elif isinstance(task, IOCloseRequest): LOGGER.debug("IOCloseRequest received for %s, closing file.", task.filename) - fileobj = self.fd_descriptor_cache.get(task.filename) - if fileobj is not None: - fileobj.close() - del self.fd_descriptor_cache[task.filename] + if not task.is_stream: + fileobj = self.fd_descriptor_cache.get(task.filename) + if fileobj is not None: + fileobj.close() + del self.fd_descriptor_cache[task.filename] def _cleanup(self): for fileobj in self.fd_descriptor_cache.values(): @@ -237,7 +241,7 @@ def __init__(self, result_queue, quiet): self._lock = threading.Lock() self._needs_newline = False - self._total_parts = 0 + self._total_parts = '...' self._total_files = '...' 
# This is a public attribute that clients can inspect to determine diff --git a/awscli/customizations/s3/filegenerator.py b/awscli/customizations/s3/filegenerator.py index b53be0c45939..f756bb4b3c62 100644 --- a/awscli/customizations/s3/filegenerator.py +++ b/awscli/customizations/s3/filegenerator.py @@ -95,7 +95,7 @@ def __init__(self, directory, filename): class FileStat(object): def __init__(self, src, dest=None, compare_key=None, size=None, last_update=None, src_type=None, dest_type=None, - operation_name=None): + operation_name=None, is_stream=False): self.src = src self.dest = dest self.compare_key = compare_key @@ -104,6 +104,7 @@ def __init__(self, src, dest=None, compare_key=None, size=None, self.src_type = src_type self.dest_type = dest_type self.operation_name = operation_name + self.is_stream = is_stream class FileGenerator(object): @@ -115,7 +116,8 @@ class FileGenerator(object): ``FileInfo`` objects to send to a ``Comparator`` or ``S3Handler``. """ def __init__(self, service, endpoint, operation_name, - follow_symlinks=True, page_size=None, result_queue=None): + follow_symlinks=True, page_size=None, result_queue=None, + is_stream=False): self._service = service self._endpoint = endpoint self.operation_name = operation_name @@ -124,6 +126,7 @@ def __init__(self, service, endpoint, operation_name, self.result_queue = result_queue if not result_queue: self.result_queue = queue.Queue() + self.is_stream = is_stream def call(self, files): """ @@ -135,7 +138,11 @@ def call(self, files): dest = files['dest'] src_type = src['type'] dest_type = dest['type'] - function_table = {'s3': self.list_objects, 'local': self.list_files} + function_table = {'s3': self.list_objects} + if self.is_stream: + function_table['local'] = self.list_local_file_stream + else: + function_table['local'] = self.list_files sep_table = {'s3': '/', 'local': os.sep} source = src['path'] file_list = function_table[src_type](source, files['dir_op']) @@ -155,7 +162,15 @@ def call(self, files): compare_key=compare_key, size=size, last_update=last_update, src_type=src_type, dest_type=dest_type, - operation_name=self.operation_name) + operation_name=self.operation_name, + is_stream=self.is_stream) + + def list_local_file_stream(self, path, dir_op): + """ + Yield some dummy values for a local file stream since it does not + actually have a file. + """ + yield '-', 0, None def list_files(self, path, dir_op): """ diff --git a/awscli/customizations/s3/fileinfo.py b/awscli/customizations/s3/fileinfo.py index fe482e64d13b..407fcc8837ad 100644 --- a/awscli/customizations/s3/fileinfo.py +++ b/awscli/customizations/s3/fileinfo.py @@ -11,7 +11,7 @@ from botocore.compat import quote from awscli.customizations.s3.utils import find_bucket_key, \ check_etag, check_error, operate, uni_print, \ - guess_content_type, MD5Error + guess_content_type, MD5Error, bytes_print class CreateDirectoryError(Exception): @@ -26,7 +26,7 @@ def read_file(filename): return in_file.read() -def save_file(filename, response_data, last_update): +def save_file(filename, response_data, last_update, is_stream=False): """ This writes to the file upon downloading. It reads the data in the response. 
Makes a new directory if needed and then writes the @@ -35,31 +35,57 @@ def save_file(filename, response_data, last_update): """ body = response_data['Body'] etag = response_data['ETag'][1:-1] - d = os.path.dirname(filename) - try: - if not os.path.exists(d): - os.makedirs(d) - except OSError as e: - if not e.errno == errno.EEXIST: - raise CreateDirectoryError( - "Could not create directory %s: %s" % (d, e)) + if not is_stream: + d = os.path.dirname(filename) + try: + if not os.path.exists(d): + os.makedirs(d) + except OSError as e: + if not e.errno == errno.EEXIST: + raise CreateDirectoryError( + "Could not create directory %s: %s" % (d, e)) md5 = hashlib.md5() file_chunks = iter(partial(body.read, 1024 * 1024), b'') - with open(filename, 'wb') as out_file: - if not _is_multipart_etag(etag): - for chunk in file_chunks: - md5.update(chunk) - out_file.write(chunk) - else: - for chunk in file_chunks: - out_file.write(chunk) + if is_stream: + # Need to save the data to be able to check the etag for a stream + # because once the data is written to the stream there is no + # undoing it. + payload = write_to_file(None, etag, md5, file_chunks, True) + else: + with open(filename, 'wb') as out_file: + write_to_file(out_file, etag, md5, file_chunks) + if not _is_multipart_etag(etag): if etag != md5.hexdigest(): - os.remove(filename) + if not is_stream: + os.remove(filename) raise MD5Error(filename) - last_update_tuple = last_update.timetuple() - mod_timestamp = time.mktime(last_update_tuple) - os.utime(filename, (int(mod_timestamp), int(mod_timestamp))) + + if not is_stream: + last_update_tuple = last_update.timetuple() + mod_timestamp = time.mktime(last_update_tuple) + os.utime(filename, (int(mod_timestamp), int(mod_timestamp))) + else: + # Now write the output to stdout since the md5 is correct. + bytes_print(payload) + sys.stdout.flush() + + +def write_to_file(out_file, etag, md5, file_chunks, is_stream=False): + """ + Updates the md5 for each file chunk. It will write the chunks to the file + if given a file, but if it is a stream it will return a byte string to be + later written to the stream. + """ + body = b'' + for chunk in file_chunks: + if not _is_multipart_etag(etag): + md5.update(chunk) + if is_stream: + body += chunk + else: + out_file.write(chunk) + return body def _is_multipart_etag(etag): @@ -140,7 +166,7 @@ class FileInfo(TaskInfo): def __init__(self, src, dest=None, compare_key=None, size=None, last_update=None, src_type=None, dest_type=None, operation_name=None, service=None, endpoint=None, - parameters=None, source_endpoint=None): + parameters=None, source_endpoint=None, is_stream=False): super(FileInfo, self).__init__(src, src_type=src_type, operation_name=operation_name, service=service, @@ -157,6 +183,7 @@ def __init__(self, src, dest=None, compare_key=None, size=None, self.parameters = {'acl': None, 'sse': None} self.source_endpoint = source_endpoint + self.is_stream = is_stream def _permission_to_param(self, permission): if permission == 'read': @@ -204,24 +231,30 @@ def _handle_object_params(self, params): if self.parameters['expires']: params['expires'] = self.parameters['expires'][0] - def upload(self): + def upload(self, payload=None): """ Redirects the file to the multipart upload function if the file is large. If it is small enough, it puts the file as an object in s3.
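A rough, self-contained sketch of the verify-then-emit idea used in the streaming branch of ``save_file`` above: the payload is buffered in memory and its MD5 checked before anything touches stdout, since stream output cannot be rolled back. The ``emit_verified`` helper and the sample data are hypothetical, not part of the patch::

    import hashlib
    import sys


    def emit_verified(chunks, expected_md5):
        # Buffer everything first; bytes written to stdout cannot be undone.
        digest = hashlib.md5()
        payload = b''
        for chunk in chunks:
            digest.update(chunk)
            payload += chunk
        if digest.hexdigest() != expected_md5:
            raise ValueError("md5 mismatch; nothing was written to stdout")
        # Prefer the raw byte buffer when it exists (Python 3).
        out = getattr(sys.stdout, 'buffer', sys.stdout)
        out.write(payload)
        out.flush()


    if __name__ == '__main__':
        emit_verified([b'foo', b'bar'], hashlib.md5(b'foobar').hexdigest())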
""" - with open(self.src, 'rb') as body: - bucket, key = find_bucket_key(self.dest) - params = { - 'endpoint': self.endpoint, - 'bucket': bucket, - 'key': key, - 'body': body, - } - self._handle_object_params(params) - response_data, http = operate(self.service, 'PutObject', params) - etag = response_data['ETag'][1:-1] - body.seek(0) - check_etag(etag, body) + if payload: + self._handle_upload(payload) + else: + with open(self.src, 'rb') as body: + self._handle_upload(body) + + def _handle_upload(self, body): + bucket, key = find_bucket_key(self.dest) + params = { + 'endpoint': self.endpoint, + 'bucket': bucket, + 'key': key, + 'body': body, + } + self._handle_object_params(params) + response_data, http = operate(self.service, 'PutObject', params) + etag = response_data['ETag'][1:-1] + body.seek(0) + check_etag(etag, body) def _inject_content_type(self, params, filename): # Add a content type param if we can guess the type. @@ -237,7 +270,8 @@ def download(self): bucket, key = find_bucket_key(self.src) params = {'endpoint': self.endpoint, 'bucket': bucket, 'key': key} response_data, http = operate(self.service, 'GetObject', params) - save_file(self.dest, response_data, self.last_update) + save_file(self.dest, response_data, self.last_update, + self.is_stream) def copy(self): """ diff --git a/awscli/customizations/s3/fileinfobuilder.py b/awscli/customizations/s3/fileinfobuilder.py index 8bc2042615ef..b220565b61cc 100644 --- a/awscli/customizations/s3/fileinfobuilder.py +++ b/awscli/customizations/s3/fileinfobuilder.py @@ -42,6 +42,7 @@ def _inject_info(self, file_base): file_info_attr['src_type'] = file_base.src_type file_info_attr['dest_type'] = file_base.dest_type file_info_attr['operation_name'] = file_base.operation_name + file_info_attr['is_stream'] = file_base.is_stream file_info_attr['service'] = self._service file_info_attr['endpoint'] = self._endpoint file_info_attr['source_endpoint'] = self._source_endpoint diff --git a/awscli/customizations/s3/s3handler.py b/awscli/customizations/s3/s3handler.py index 91f701bbd83d..46dc2e6c897d 100644 --- a/awscli/customizations/s3/s3handler.py +++ b/awscli/customizations/s3/s3handler.py @@ -14,10 +14,13 @@ import logging import math import os +import six from six.moves import queue +import sys +import time from awscli.customizations.s3.constants import MULTI_THRESHOLD, CHUNKSIZE, \ - NUM_THREADS, MAX_UPLOAD_SIZE, MAX_QUEUE_SIZE + NUM_THREADS, MAX_UPLOAD_SIZE, MAX_QUEUE_SIZE, STREAM_INPUT_TIMEOUT from awscli.customizations.s3.utils import find_chunksize, \ operate, find_bucket_key, relative_path, PrintTask, create_warning from awscli.customizations.s3.executor import Executor @@ -53,16 +56,24 @@ def __init__(self, session, params, result_queue=None, 'content_type': None, 'cache_control': None, 'content_disposition': None, 'content_encoding': None, 'content_language': None, 'expires': None, - 'grants': None} + 'grants': None, 'is_stream': False, 'paths_type': None, + 'expected_size': None} self.params['region'] = params['region'] for key in self.params.keys(): if key in params: self.params[key] = params[key] self.multi_threshold = multi_threshold self.chunksize = chunksize + self._max_executer_queue_size = MAX_QUEUE_SIZE + if self.params['is_stream']: + # This ensures that at most the number of multipart chunks + # waiting in the executor queue from a stream read in from stdin + # is the same as the number of threads needed to upload it. 
+ self._max_executer_queue_size = NUM_THREADS self.executor = Executor( num_threads=NUM_THREADS, result_queue=self.result_queue, - quiet=self.params['quiet'], max_queue_size=MAX_QUEUE_SIZE, + quiet=self.params['quiet'], + max_queue_size=self._max_executer_queue_size, write_queue=self.write_queue ) self._multipart_uploads = [] @@ -162,7 +173,15 @@ def _enqueue_tasks(self, files): total_parts = 0 for filename in files: num_uploads = 1 - is_multipart_task = self._is_multipart_task(filename) + # If uploading stream, it is required to read from the stream + # to determine if the stream needs to be multipart uploaded. + payload = None + if getattr(filename, 'is_stream', False) and \ + filename.operation_name == 'upload': + payload, is_multipart_task = \ + self._pull_from_stream(self.multi_threshold) + else: + is_multipart_task = self._is_multipart_task(filename) too_large = False if hasattr(filename, 'size'): too_large = filename.size > MAX_UPLOAD_SIZE @@ -178,17 +197,42 @@ def _enqueue_tasks(self, files): # fact that it's transferring a file rather than # the specific part tasks required to perform the # transfer. - num_uploads = self._enqueue_multipart_tasks(filename) + num_uploads = self._enqueue_multipart_tasks(filename, payload) else: task = tasks.BasicTask( session=self.session, filename=filename, parameters=self.params, - result_queue=self.result_queue) + result_queue=self.result_queue, + payload=payload) self.executor.submit(task) total_files += 1 total_parts += num_uploads return total_files, total_parts + def _pull_from_stream(self, initial_amount_requested): + size = 0 + amount_requested = initial_amount_requested + total_retries = 0 + payload = b'' + stream_filein = sys.stdin + if six.PY3: + stream_filein = sys.stdin.buffer + while True: + payload_chunk = stream_filein.read(amount_requested) + payload_chunk_size = len(payload_chunk) + payload += payload_chunk + size += payload_chunk_size + amount_requested -= payload_chunk_size + if payload_chunk_size == 0: + time.sleep(STREAM_INPUT_TIMEOUT) + total_retries += 1 + else: + total_retries = 0 + if amount_requested == 0 or total_retries == 5: + break + payload_file = six.BytesIO(payload) + return payload_file, size == initial_amount_requested + def _is_multipart_task(self, filename): # First we need to determine if it's an operation that even # qualifies for multipart upload. 
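``_pull_from_stream`` above reads exactly the requested number of bytes from stdin and reports whether that request was completely filled; the caller treats a fully filled read of ``multi_threshold`` bytes as the signal to switch to a multipart upload, while anything shorter fits in a single ``PutObject``. A condensed sketch of that decision, with a hypothetical ``read_exact`` helper standing in for ``_pull_from_stream`` (the retry/timeout handling for slow pipes is omitted)::

    import io
    import sys


    def read_exact(stream, amount):
        """Read up to ``amount`` bytes and report whether the request was filled."""
        data = b''
        while len(data) < amount:
            chunk = stream.read(amount - len(data))
            if not chunk:
                break
            data += chunk
        return io.BytesIO(data), len(data) == amount


    if __name__ == '__main__':
        multi_threshold = 8 * 1024 * 1024
        stdin = getattr(sys.stdin, 'buffer', sys.stdin)
        payload, needs_multipart = read_exact(stdin, multi_threshold)
        if needs_multipart:
            # The threshold was filled, so more data may follow: start a
            # multipart upload and submit ``payload`` as part 1.
            print('multipart upload')
        else:
            # The whole stream fit below the threshold: one PutObject is enough.
            print('single PutObject of %d bytes' % len(payload.getvalue()))

For instance, piping a file larger than 8 MB into this sketch prints ``multipart upload``; a smaller file takes the single-object branch.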
@@ -203,10 +247,11 @@ def _is_multipart_task(self, filename): else: return False - def _enqueue_multipart_tasks(self, filename): + def _enqueue_multipart_tasks(self, filename, payload=None): num_uploads = 1 if filename.operation_name == 'upload': - num_uploads = self._enqueue_multipart_upload_tasks(filename) + num_uploads = self._enqueue_multipart_upload_tasks(filename, + payload=payload) elif filename.operation_name == 'move': if filename.src_type == 'local' and filename.dest_type == 's3': num_uploads = self._enqueue_multipart_upload_tasks( @@ -231,9 +276,12 @@ def _enqueue_range_download_tasks(self, filename, remove_remote_file=False): chunksize = find_chunksize(filename.size, self.chunksize) num_downloads = int(filename.size / chunksize) context = tasks.MultipartDownloadContext(num_downloads) - create_file_task = tasks.CreateLocalFileTask(context=context, - filename=filename) - self.executor.submit(create_file_task) + if not filename.is_stream: + create_file_task = tasks.CreateLocalFileTask(context=context, + filename=filename) + self.executor.submit(create_file_task) + else: + context.announce_file_created() for i in range(num_downloads): task = tasks.DownloadPartTask( part_number=i, chunk_size=chunksize, @@ -252,17 +300,27 @@ def _enqueue_range_download_tasks(self, filename, remove_remote_file=False): return num_downloads def _enqueue_multipart_upload_tasks(self, filename, - remove_local_file=False): + remove_local_file=False, + payload=None): # First we need to create a CreateMultipartUpload task, # then create UploadTask objects for each of the parts. # And finally enqueue a CompleteMultipartUploadTask. - chunksize = find_chunksize(filename.size, self.chunksize) - num_uploads = int(math.ceil(filename.size / - float(chunksize))) + chunksize = self.chunksize + if not filename.is_stream: + chunksize = find_chunksize(filename.size, self.chunksize) + num_uploads = int(math.ceil(filename.size / + float(chunksize))) + else: + if self.params['expected_size']: + chunksize = find_chunksize(int(self.params['expected_size']), + self.chunksize) + num_uploads = '...' 
upload_context = self._enqueue_upload_start_task( - chunksize, num_uploads, filename) - self._enqueue_upload_tasks( - num_uploads, chunksize, upload_context, filename, tasks.UploadPartTask) + chunksize, num_uploads, filename, payload) + num_uploads = self._enqueue_upload_tasks( + num_uploads, chunksize, upload_context, + filename, tasks.UploadPartTask + ) self._enqueue_upload_end_task(filename, upload_context) if remove_local_file: remove_task = tasks.RemoveFileTask(local_filename=filename.src, @@ -276,8 +334,7 @@ def _enqueue_multipart_copy_tasks(self, filename, num_uploads = int(math.ceil(filename.size / float(chunksize))) upload_context = self._enqueue_upload_start_task( chunksize, num_uploads, filename) - self._enqueue_upload_tasks( - num_uploads, chunksize, upload_context, filename, tasks.CopyPartTask) + self._enqueue_upload_tasks(num_uploads, chunksize, upload_context, filename, tasks.CopyPartTask) self._enqueue_upload_end_task(filename, upload_context) if remove_remote_file: remove_task = tasks.RemoveRemoteObjectTask( @@ -285,7 +342,8 @@ def _enqueue_multipart_copy_tasks(self, filename, self.executor.submit(remove_task) return num_uploads - def _enqueue_upload_start_task(self, chunksize, num_uploads, filename): + def _enqueue_upload_start_task(self, chunksize, num_uploads, filename, + payload=None): upload_context = tasks.MultipartUploadContext( expected_parts=num_uploads) create_multipart_upload_task = tasks.CreateMultipartUploadTask( @@ -293,16 +351,56 @@ def _enqueue_upload_start_task(self, chunksize, num_uploads, filename): parameters=self.params, result_queue=self.result_queue, upload_context=upload_context) self.executor.submit(create_multipart_upload_task) + if filename.is_stream and filename.operation_name == 'upload': + # Upload the part that was initially pulled from the stream. + self._enqueue_upload_single_part_task( + part_number=1, chunk_size=chunksize, + upload_context=upload_context, filename=filename, + task_class=tasks.UploadPartTask, payload=payload + ) return upload_context - def _enqueue_upload_tasks(self, num_uploads, chunksize, upload_context, filename, - task_class): - for i in range(1, (num_uploads + 1)): - task = task_class( - part_number=i, chunk_size=chunksize, - result_queue=self.result_queue, upload_context=upload_context, - filename=filename) - self.executor.submit(task) + def _enqueue_upload_tasks(self, num_uploads, chunksize, upload_context, + filename, task_class): + if filename.is_stream and filename.operation_name == 'upload': + # The previous upload occurred right after the multipart + # upload started for a stream.
+ num_uploads = 1 + while True: + payload, is_remaining = self._pull_from_stream(chunksize) + self._enqueue_upload_single_part_task( + part_number=num_uploads+1, + chunk_size=chunksize, + upload_context=upload_context, + filename=filename, + task_class=task_class, + payload=payload + ) + num_uploads += 1 + if not is_remaining: + break + upload_context.announce_total_parts(num_uploads) + else: + for i in range(1, (num_uploads + 1)): + self._enqueue_upload_single_part_task( + part_number=i, + chunk_size=chunksize, + upload_context=upload_context, + filename=filename, + task_class=task_class + ) + return num_uploads + + def _enqueue_upload_single_part_task(self, part_number, chunk_size, + upload_context, filename, task_class, + payload=None): + kwargs = {'part_number': part_number, 'chunk_size': chunk_size, + 'result_queue': self.result_queue, + 'upload_context': upload_context, 'filename': filename} + if payload: + kwargs['payload'] = payload + task = task_class(**kwargs) + self.executor.submit(task) def _enqueue_upload_end_task(self, filename, upload_context): complete_multipart_upload_task = tasks.CompleteMultipartUploadTask( diff --git a/awscli/customizations/s3/subcommands.py b/awscli/customizations/s3/subcommands.py index 6ec6856294d9..0df2f8c1e94f 100644 --- a/awscli/customizations/s3/subcommands.py +++ b/awscli/customizations/s3/subcommands.py @@ -206,6 +206,15 @@ 'The object key name to use when ' 'a 4XX class error occurs.')} +EXPECTED_SIZE = {'name': 'expected-size', + 'help_text': ( + 'This argument specifies the expected size of a stream ' + 'in terms of bytes. Note that this argument is needed ' + 'only when a stream is being uploaded to s3 and the size ' + 'is larger than 5GB. Failure to include this argument ' + 'under these conditions may result in a failed upload ' + 'due to too many parts in the upload.')} + TRANSFER_ARGS = [DRYRUN, QUIET, RECURSIVE, INCLUDE, EXCLUDE, ACL, FOLLOW_SYMLINKS, NO_FOLLOW_SYMLINKS, NO_GUESS_MIME_TYPE, SSE, STORAGE_CLASS, GRANTS, WEBSITE_REDIRECT, CONTENT_TYPE, @@ -413,7 +422,7 @@ class CpCommand(S3TransferCommand): USAGE = " or " \ "or " ARG_TABLE = [{'name': 'paths', 'nargs': 2, 'positional_arg': True, - 'synopsis': USAGE}] + TRANSFER_ARGS + 'synopsis': USAGE}] + TRANSFER_ARGS + [EXPECTED_SIZE] EXAMPLES = BasicCommand.FROM_FILE('s3/cp.rst') @@ -566,7 +575,8 @@ def run(self): operation_name, self.parameters['follow_symlinks'], self.parameters['page_size'], - result_queue=result_queue) + result_queue=result_queue, + is_stream=self.parameters['is_stream']) rev_generator = FileGenerator(self._service, self._endpoint, '', self.parameters['follow_symlinks'], self.parameters['page_size'], @@ -683,8 +693,19 @@ def add_paths(self, paths): self.parameters['dest'] = paths[1] elif len(paths) == 1: self.parameters['dest'] = paths[0] + self._validate_streaming_paths() self._validate_path_args() + def _validate_streaming_paths(self): + self.parameters['is_stream'] = False + if self.parameters['src'] == '-' or self.parameters['dest'] == '-': + self.parameters['is_stream'] = True + self.parameters['dir_op'] = False + self.parameters['quiet'] = True + if self.parameters['is_stream'] and self.cmd != 'cp': + raise ValueError("Streaming currently is only compatible with " + "single file cp commands") + def _validate_path_args(self): # If we're using a mv command, you can't copy the object onto itself.
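The streaming branch of ``_enqueue_upload_tasks`` above is the unknown-part-count trick in miniature: because the total size of stdin is not known up front, parts are pulled and submitted until a short read signals the end of the stream, and only then is the final count announced so the completion task knows when it has everything. A stripped-down sketch of that loop, where ``submit_part`` and ``announce_total`` are placeholders for the executor and upload-context plumbing in the patch::

    import io


    def upload_stream_in_parts(stream, chunksize, submit_part, announce_total):
        # Part 1 was already pulled and submitted when deciding whether the
        # upload needed to be multipart, so numbering starts at 2 here.
        num_uploads = 1
        while True:
            data = stream.read(chunksize)
            num_uploads += 1
            submit_part(num_uploads, io.BytesIO(data))
            if len(data) < chunksize:
                # Short (possibly empty) read: the stream is exhausted.
                break
        announce_total(num_uploads)
        return num_uploads


    if __name__ == '__main__':
        parts = []
        total = upload_stream_in_parts(
            io.BytesIO(b'abcdefghij'), 4,
            submit_part=lambda n, body: parts.append((n, body.read())),
            announce_total=lambda n: parts.append(('total', n)))
        # parts is [(2, b'abcd'), (3, b'efgh'), (4, b'ij'), ('total', 4)]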
params = self.parameters diff --git a/awscli/customizations/s3/tasks.py b/awscli/customizations/s3/tasks.py index 37089c42a2b9..07e6547bb619 100644 --- a/awscli/customizations/s3/tasks.py +++ b/awscli/customizations/s3/tasks.py @@ -63,7 +63,8 @@ class BasicTask(OrderableTask): attributes like ``session`` object in order for the filename to perform its designated operation. """ - def __init__(self, session, filename, parameters, result_queue): + def __init__(self, session, filename, parameters, + result_queue, payload=None): self.session = session self.service = self.session.get_service('s3') @@ -72,6 +73,7 @@ def __init__(self, session, filename, parameters, result_queue): self.parameters = parameters self.result_queue = result_queue + self.payload = payload def __call__(self): self._execute_task(attempts=3) @@ -84,9 +86,12 @@ def _execute_task(self, attempts, last_error=''): error_message=last_error) return filename = self.filename + kwargs = {} + if self.payload: + kwargs['payload'] = self.payload try: if not self.parameters['dryrun']: - getattr(filename, filename.operation_name)() + getattr(filename, filename.operation_name)(**kwargs) except requests.ConnectionError as e: connect_error = str(e) LOGGER.debug("%s %s failure: %s", @@ -195,13 +200,14 @@ class UploadPartTask(OrderableTask): complete the multipart upload initiated by the ``FileInfo`` object. """ - def __init__(self, part_number, chunk_size, - result_queue, upload_context, filename): + def __init__(self, part_number, chunk_size, result_queue, upload_context, + filename, payload=None): self._result_queue = result_queue self._upload_context = upload_context self._part_number = part_number self._chunk_size = chunk_size self._filename = filename + self._payload = payload def _read_part(self): actual_filename = self._filename.src @@ -216,9 +222,13 @@ def __call__(self): LOGGER.debug("Waiting for upload id.") upload_id = self._upload_context.wait_for_upload_id() bucket, key = find_bucket_key(self._filename.dest) - total = int(math.ceil( - self._filename.size/float(self._chunk_size))) - body = self._read_part() + if self._filename.is_stream: + body = self._payload + total = self._upload_context.expected_parts + else: + total = int(math.ceil( + self._filename.size/float(self._chunk_size))) + body = self._read_part() params = {'endpoint': self._filename.endpoint, 'bucket': bucket, 'key': key, 'part_number': self._part_number, @@ -298,14 +308,17 @@ def __call__(self): # 3) Queue an IO request to the IO thread letting it know we're # done with the file. 
self._context.wait_for_completion() - last_update_tuple = self._filename.last_update.timetuple() - mod_timestamp = time.mktime(last_update_tuple) - os.utime(self._filename.dest, (int(mod_timestamp), int(mod_timestamp))) + if not self._filename.is_stream: + last_update_tuple = self._filename.last_update.timetuple() + mod_timestamp = time.mktime(last_update_tuple) + os.utime(self._filename.dest, + (int(mod_timestamp), int(mod_timestamp))) message = print_operation(self._filename, False, self._parameters['dryrun']) print_task = {'message': message, 'error': False} self._result_queue.put(PrintTask(**print_task)) - self._io_queue.put(IOCloseRequest(self._filename.dest)) + self._io_queue.put(IOCloseRequest(self._filename.dest, + self._filename.is_stream)) class DownloadPartTask(OrderableTask): @@ -393,16 +406,23 @@ def _queue_writes(self, body): body.set_socket_timeout(self.READ_TIMEOUT) amount_read = 0 current = body.read(iterate_chunk_size) + if self._filename.is_stream: + self._context.wait_for_turn(self._part_number) while current: offset = self._part_number * self._chunk_size + amount_read LOGGER.debug("Submitting IORequest to write queue.") - self._io_queue.put(IORequest(self._filename.dest, offset, current)) + self._io_queue.put( + IORequest(self._filename.dest, offset, current, + self._filename.is_stream) + ) LOGGER.debug("Request successfully submitted.") amount_read += len(current) current = body.read(iterate_chunk_size) # Change log message. LOGGER.debug("Done queueing writes for part number %s to file: %s", self._part_number, self._filename.dest) + if self._filename.is_stream: + self._context.done_with_turn() class CreateMultipartUploadTask(BasicTask): @@ -530,7 +550,7 @@ class MultipartUploadContext(object): _CANCELLED = '_CANCELLED' _COMPLETED = '_COMPLETED' - def __init__(self, expected_parts): + def __init__(self, expected_parts='...'): self._upload_id = None self._expected_parts = expected_parts self._parts = [] @@ -540,6 +560,10 @@ def __init__(self, expected_parts): self._upload_complete_condition = threading.Condition(self._lock) self._state = self._UNSTARTED + @property + def expected_parts(self): + return self._expected_parts + def announce_upload_id(self, upload_id): with self._upload_id_condition: self._upload_id = upload_id @@ -551,9 +575,15 @@ def announce_finished_part(self, etag, part_number): self._parts.append({'ETag': etag, 'PartNumber': part_number}) self._parts_condition.notifyAll() + def announce_total_parts(self, total_parts): + with self._parts_condition: + self._expected_parts = total_parts + self._parts_condition.notifyAll() + def wait_for_parts_to_finish(self): with self._parts_condition: - while len(self._parts) < self._expected_parts: + while self._expected_parts == '...' 
or \ + len(self._parts) < self._expected_parts: if self._state == self._CANCELLED: raise UploadCancelledError("Upload has been cancelled.") self._parts_condition.wait(timeout=1) @@ -653,9 +683,11 @@ def __init__(self, num_parts, lock=None): lock = threading.Lock() self._lock = lock self._created_condition = threading.Condition(self._lock) + self._submit_write_condition = threading.Condition(self._lock) self._completed_condition = threading.Condition(self._lock) self._state = self._STATES['UNSTARTED'] self._finished_parts = set() + self._current_stream_part_number = 0 def announce_completed_part(self, part_number): with self._completed_condition: @@ -685,6 +717,19 @@ def wait_for_completion(self): "Download has been cancelled.") self._completed_condition.wait(timeout=1) + def wait_for_turn(self, part_number): + with self._submit_write_condition: + while self._current_stream_part_number != part_number: + if self._state == self._STATES['CANCELLED']: + raise DownloadCancelledError( + "Download has been cancelled.") + self._submit_write_condition.wait(timeout=0.2) + + def done_with_turn(self): + with self._submit_write_condition: + self._current_stream_part_number += 1 + self._submit_write_condition.notifyAll() + def cancel(self): with self._lock: self._state = self._STATES['CANCELLED'] diff --git a/awscli/customizations/s3/utils.py b/awscli/customizations/s3/utils.py index eea51a5fbdbc..76cf68b39bd6 100644 --- a/awscli/customizations/s3/utils.py +++ b/awscli/customizations/s3/utils.py @@ -243,6 +243,21 @@ def uni_print(statement): sys.stdout.write(statement.encode('utf-8')) +def bytes_print(statement): + """ + This function is used to properly write bytes to standard out. + """ + if PY3: + if getattr(sys.stdout, 'buffer', None): + sys.stdout.buffer.write(statement) + else: + # If it is not possible to write to the standard out buffer. + # The next best option is to decode and write to standard out. + sys.stdout.write(statement.decode('utf-8')) + else: + sys.stdout.write(statement) + + def guess_content_type(filename): """Given a filename, guess it's content type. @@ -396,7 +411,8 @@ def __new__(cls, message, error=False, total_parts=None, warning=None): warning) -IORequest = namedtuple('IORequest', ['filename', 'offset', 'data']) +IORequest = namedtuple('IORequest', + ['filename', 'offset', 'data', 'is_stream']) # Used to signal that IO for the filename is finished, and that # any associated resources may be cleaned up. 
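``wait_for_turn`` and ``done_with_turn`` above serialize the streaming download parts: offsets are meaningless on a pipe, so part N only queues its bytes for stdout once parts 0 through N-1 have finished, enforced through a shared condition variable. A self-contained sketch of that turnstile idea; the class name and the demo threads are made up for illustration::

    import threading


    class WriteTurnstile(object):
        """Let writers proceed strictly in part-number order."""
        def __init__(self):
            self._condition = threading.Condition(threading.Lock())
            self._current_part = 0

        def wait_for_turn(self, part_number):
            with self._condition:
                while self._current_part != part_number:
                    self._condition.wait(timeout=0.2)

        def done_with_turn(self):
            with self._condition:
                self._current_part += 1
                self._condition.notify_all()


    if __name__ == '__main__':
        turnstile = WriteTurnstile()
        written = []

        def writer(part_number, data):
            turnstile.wait_for_turn(part_number)
            written.append(data)  # stands in for writing the part to stdout
            turnstile.done_with_turn()

        # Start the writers deliberately out of order.
        threads = [threading.Thread(target=writer, args=(i, 'part%d' % i))
                   for i in (2, 0, 1)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()
        print(written)  # always ['part0', 'part1', 'part2']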
-IOCloseRequest = namedtuple('IOCloseRequest', ['filename']) +IOCloseRequest = namedtuple('IOCloseRequest', ['filename', 'is_stream']) diff --git a/awscli/examples/s3/cp.rst b/awscli/examples/s3/cp.rst index 6bdf25dc0dc1..1fe488bc7751 100644 --- a/awscli/examples/s3/cp.rst +++ b/awscli/examples/s3/cp.rst @@ -101,3 +101,15 @@ Output:: upload: file.txt to s3://mybucket/file.txt +**Uploading a local file stream to S3** + +The following ``cp`` command uploads a local file stream from standard input to a specified bucket and key:: + + aws s3 cp - s3://mybucket/stream.txt + + +**Downloading a S3 object as a local file stream** + +The following ``cp`` command downloads a S3 object locally as a stream to standard output:: + + aws s3 cp s3://mybucket/stream.txt - diff --git a/awscli/testutils.py b/awscli/testutils.py index b6f4bd2abcc0..2f4b55e018c1 100644 --- a/awscli/testutils.py +++ b/awscli/testutils.py @@ -395,7 +395,7 @@ def _escape_quotes(command): def aws(command, collect_memory=False, env_vars=None, - wait_for_finish=True): + wait_for_finish=True, input_data=None): """Run an aws command. This help function abstracts the differences of running the "aws" @@ -421,7 +421,7 @@ def aws(command, collect_memory=False, env_vars=None, else: aws_command = 'python %s' % get_aws_cmd() full_command = '%s %s' % (aws_command, command) - stdout_encoding = _get_stdout_encoding() + stdout_encoding = get_stdout_encoding() if isinstance(full_command, six.text_type) and not six.PY3: full_command = full_command.encode(stdout_encoding) INTEG_LOG.debug("Running command: %s", full_command) @@ -429,13 +429,16 @@ def aws(command, collect_memory=False, env_vars=None, env['AWS_DEFAULT_REGION'] = "us-east-1" if env_vars is not None: env = env_vars - process = Popen(full_command, stdout=PIPE, stderr=PIPE, shell=True, - env=env) + process = Popen(full_command, stdout=PIPE, stderr=PIPE, stdin=PIPE, + shell=True, env=env) if not wait_for_finish: return process memory = None if not collect_memory: - stdout, stderr = process.communicate() + kwargs = {} + if input_data: + kwargs = {'input': input_data} + stdout, stderr = process.communicate(**kwargs) else: stdout, stderr, memory = _wait_and_collect_mem(process) return Result(process.returncode, @@ -444,7 +447,7 @@ def aws(command, collect_memory=False, env_vars=None, memory) -def _get_stdout_encoding(): +def get_stdout_encoding(): encoding = getattr(sys.__stdout__, 'encoding', None) if encoding is None: encoding = 'utf-8' diff --git a/tests/integration/customizations/s3/test_plugin.py b/tests/integration/customizations/s3/test_plugin.py index 65e9d0f42217..93882b7f5598 100644 --- a/tests/integration/customizations/s3/test_plugin.py +++ b/tests/integration/customizations/s3/test_plugin.py @@ -28,7 +28,7 @@ import botocore.session import six -from awscli.testutils import unittest, FileCreator +from awscli.testutils import unittest, FileCreator, get_stdout_encoding from awscli.testutils import aws as _aws from tests.unit.customizations.s3 import create_bucket as _create_bucket from awscli.customizations.s3 import constants @@ -44,12 +44,13 @@ def cd(directory): os.chdir(original) -def aws(command, collect_memory=False, env_vars=None, wait_for_finish=True): +def aws(command, collect_memory=False, env_vars=None, wait_for_finish=True, + input_data=None): if not env_vars: env_vars = os.environ.copy() env_vars['AWS_DEFAULT_REGION'] = "us-west-2" return _aws(command, collect_memory=collect_memory, env_vars=env_vars, - wait_for_finish=wait_for_finish) + wait_for_finish=wait_for_finish, 
input_data=input_data) class BaseS3CLICommand(unittest.TestCase): @@ -1222,5 +1223,99 @@ def test_sync_file_with_spaces(self): self.assertEqual(p2.rc, 0) +class TestStreams(BaseS3CLICommand): + def test_upload(self): + """ + This tests uploading a small stream from stdin. + """ + bucket_name = self.create_bucket() + p = aws('s3 cp - s3://%s/stream' % bucket_name, + input_data=b'This is a test') + self.assert_no_errors(p) + self.assertTrue(self.key_exists(bucket_name, 'stream')) + self.assertEqual(self.get_key_contents(bucket_name, 'stream'), + 'This is a test') + + def test_unicode_upload(self): + """ + This tests being able to upload unicode from stdin. + """ + unicode_str = u'\u00e9 This is a test' + byte_str = unicode_str.encode('utf-8') + bucket_name = self.create_bucket() + p = aws('s3 cp - s3://%s/stream' % bucket_name, + input_data=byte_str) + self.assert_no_errors(p) + self.assertTrue(self.key_exists(bucket_name, 'stream')) + self.assertEqual(self.get_key_contents(bucket_name, 'stream'), + unicode_str) + + def test_multipart_upload(self): + """ + This tests the ability to multipart upload streams from stdin. + The data has some unicode in it to avoid having to do a separate + multipart upload test just for unicode. + """ + + bucket_name = self.create_bucket() + data = u'\u00e9bcd' * (1024 * 1024 * 10) + data_encoded = data.encode('utf-8') + p = aws('s3 cp - s3://%s/stream' % bucket_name, + input_data=data_encoded) + self.assert_no_errors(p) + self.assertTrue(self.key_exists(bucket_name, 'stream')) + self.assert_key_contents_equal(bucket_name, 'stream', data) + + def test_download(self): + """ + This tests downloading a small stream to stdout. + """ + bucket_name = self.create_bucket() + p = aws('s3 cp - s3://%s/stream' % bucket_name, + input_data=b'This is a test') + self.assert_no_errors(p) + + p = aws('s3 cp s3://%s/stream -' % bucket_name) + self.assert_no_errors(p) + self.assertEqual(p.stdout, 'This is a test') + + def test_unicode_download(self): + """ + This tests downloading a small unicode stream to stdout. + """ + bucket_name = self.create_bucket() + + data = u'\u00e9 This is a test' + data_encoded = data.encode('utf-8') + p = aws('s3 cp - s3://%s/stream' % bucket_name, + input_data=data_encoded) + self.assert_no_errors(p) + + # Downloading the unicode stream to standard out. + p = aws('s3 cp s3://%s/stream -' % bucket_name) + self.assert_no_errors(p) + self.assertEqual(p.stdout, data_encoded.decode(get_stdout_encoding())) + + def test_multipart_download(self): + """ + This tests the ability to multipart download streams to stdout. + The data has some unicode in it to avoid having to do a separate + multipart download test just for unicode. + """ + bucket_name = self.create_bucket() + + # First let's upload some data via streaming since + # it's faster and we do not have to write to a file! + data = u'\u00e9bcd' * (1024 * 1024 * 10) + data_encoded = data.encode('utf-8') + p = aws('s3 cp - s3://%s/stream' % bucket_name, + input_data=data_encoded) + + # Download the unicode stream to standard out.
+ p = aws('s3 cp s3://%s/stream -' % bucket_name) + self.assert_no_errors(p) + self.assertEqual(p.stdout, data_encoded.decode(get_stdout_encoding())) + + if __name__ == "__main__": unittest.main() diff --git a/tests/unit/customizations/s3/__init__.py b/tests/unit/customizations/s3/__init__.py index 5ada082edd3e..3216cba61451 100644 --- a/tests/unit/customizations/s3/__init__.py +++ b/tests/unit/customizations/s3/__init__.py @@ -16,7 +16,7 @@ import string import six -from mock import patch +from mock import patch, Mock class S3HandlerBaseTest(unittest.TestCase): @@ -33,7 +33,6 @@ def setUp(self): def tearDown(self): self.wait_timeout_patch.stop() - def make_loc_files(): """ This sets up the test by making a directory named some_directory. It @@ -161,6 +160,7 @@ def compare_files(self, result_file, ref_file): self.assertEqual(result_file.src_type, ref_file.src_type) self.assertEqual(result_file.dest_type, ref_file.dest_type) self.assertEqual(result_file.operation_name, ref_file.operation_name) + self.assertEqual(result_file.is_stream, ref_file.is_stream) def list_contents(bucket, session): @@ -188,3 +188,24 @@ def list_buckets(session): html_response, response_data = operation.call(endpoint) contents = response_data['Buckets'] return contents + + +class MockStdIn(object): + """ + This class patches stdin in order to write a stream of bytes into + stdin. + """ + def __init__(self, input_bytes=b''): + input_data = six.BytesIO(input_bytes) + if six.PY3: + mock_object = Mock() + mock_object.buffer = input_data + else: + mock_object = input_data + self._patch = patch('sys.stdin', mock_object) + + def __enter__(self): + self._patch.__enter__() + + def __exit__(self, exc_type, exc_value, traceback): + self._patch.__exit__() diff --git a/tests/unit/customizations/s3/test_executor.py b/tests/unit/customizations/s3/test_executor.py index 9afaacd3ba22..46eecbae6b63 100644 --- a/tests/unit/customizations/s3/test_executor.py +++ b/tests/unit/customizations/s3/test_executor.py @@ -15,6 +15,7 @@ import shutil import six from six.moves import queue +import sys import mock @@ -41,17 +42,17 @@ def tearDown(self): shutil.rmtree(self.temp_dir) def test_handles_io_request(self): - self.queue.put(IORequest(self.filename, 0, b'foobar')) - self.queue.put(IOCloseRequest(self.filename)) + self.queue.put(IORequest(self.filename, 0, b'foobar', False)) + self.queue.put(IOCloseRequest(self.filename, False)) self.queue.put(ShutdownThreadRequest()) self.io_thread.run() with open(self.filename, 'rb') as f: self.assertEqual(f.read(), b'foobar') def test_out_of_order_io_requests(self): - self.queue.put(IORequest(self.filename, 6, b'morestuff')) - self.queue.put(IORequest(self.filename, 0, b'foobar')) - self.queue.put(IOCloseRequest(self.filename)) + self.queue.put(IORequest(self.filename, 6, b'morestuff', False)) + self.queue.put(IORequest(self.filename, 0, b'foobar', False)) + self.queue.put(IOCloseRequest(self.filename, False)) self.queue.put(ShutdownThreadRequest()) self.io_thread.run() with open(self.filename, 'rb') as f: @@ -60,10 +61,10 @@ def test_out_of_order_io_requests(self): def test_multiple_files_in_queue(self): second_file = os.path.join(self.temp_dir, 'bar') open(second_file, 'w').close() - self.queue.put(IORequest(self.filename, 0, b'foobar')) - self.queue.put(IORequest(second_file, 0, b'otherstuff')) - self.queue.put(IOCloseRequest(second_file)) - self.queue.put(IOCloseRequest(self.filename)) + self.queue.put(IORequest(self.filename, 0, b'foobar', False)) + self.queue.put(IORequest(second_file, 0, 
b'otherstuff', False)) + self.queue.put(IOCloseRequest(second_file, False)) + self.queue.put(IOCloseRequest(self.filename, False)) self.queue.put(ShutdownThreadRequest()) self.io_thread.run() @@ -72,6 +73,21 @@ def test_multiple_files_in_queue(self): with open(second_file, 'rb') as f: self.assertEqual(f.read(), b'otherstuff') + def test_stream_requests(self): + # Test that offset has no affect on the order in which requests + # are written to stdout. The order of requests for a stream are + # first in first out. + self.queue.put(IORequest('nonexistant-file', 10, b'foobar', True)) + self.queue.put(IORequest('nonexistant-file', 6, b'otherstuff', True)) + # The thread should not try to close the file name because it is + # writing to stdout. If it does, the thread will fail because + # the file does not exist. + self.queue.put(IOCloseRequest('nonexistant-file', True)) + self.queue.put(ShutdownThreadRequest()) + with mock.patch('sys.stdout', new=six.StringIO()) as mock_stdout: + self.io_thread.run() + self.assertEqual(mock_stdout.getvalue(), 'foobarotherstuff') + class TestExecutor(unittest.TestCase): def test_shutdown_does_not_hang(self): @@ -84,12 +100,14 @@ class FloodIOQueueTask(object): def __call__(self): for i in range(50): - executor.write_queue.put(IORequest(f.name, 0, b'foobar')) + executor.write_queue.put(IORequest(f.name, 0, + b'foobar', False)) executor.submit(FloodIOQueueTask()) executor.initiate_shutdown() executor.wait_until_shutdown() self.assertEqual(open(f.name, 'rb').read(), b'foobar') + class TestPrintThread(unittest.TestCase): def test_print_warning(self): result_queue = queue.Queue() diff --git a/tests/unit/customizations/s3/test_filegenerator.py b/tests/unit/customizations/s3/test_filegenerator.py index d38a48424ce2..3c199c858954 100644 --- a/tests/unit/customizations/s3/test_filegenerator.py +++ b/tests/unit/customizations/s3/test_filegenerator.py @@ -486,6 +486,24 @@ def test_normalize_sort_backslash(self): self.assertEqual(ref_names[i], names[i]) +class TestLocalStreams(unittest.TestCase): + def test_local_stream(self): + file_input = {'src': {'path': '-', 'type': 'local'}, + 'dest': {'path': 'mybucket/', 'type': 's3'}, + 'dir_op': False, 'use_src_name': True} + file_generator = FileGenerator(None, None, None, is_stream=True) + files = file_generator.call(file_input) + result_list = [] + for file_stat in files: + result_list.append(file_stat) + ref_list = [FileStat(src='-', dest='mybucket/-', compare_key='-', + size=0, last_update=None, src_type='local', + dest_type='s3', operation_name=None, + is_stream=True)] + for i in range(len(result_list)): + compare_files(self, result_list[i], ref_list[i]) + + class S3FileGeneratorTest(unittest.TestCase): def setUp(self): self.session = FakeSession() diff --git a/tests/unit/customizations/s3/test_fileinfo.py b/tests/unit/customizations/s3/test_fileinfo.py index 48a6651f42fb..6a31e3edb1d5 100644 --- a/tests/unit/customizations/s3/test_fileinfo.py +++ b/tests/unit/customizations/s3/test_fileinfo.py @@ -21,6 +21,7 @@ from awscli.testutils import unittest from awscli.customizations.s3 import fileinfo +from awscli.customizations.s3.utils import MD5Error class TestSaveFile(unittest.TestCase): @@ -58,3 +59,16 @@ def test_makedir_other_exception(self, makedirs): fileinfo.save_file(self.filename, self.response_data, self.last_update) self.assertFalse(os.path.isfile(self.filename)) + + def test_stream_file(self): + with mock.patch('sys.stdout', new=six.StringIO()) as mock_stdout: + fileinfo.save_file(None, self.response_data, None, 
True) + self.assertEqual(mock_stdout.getvalue(), "foobar") + + def test_stream_file_md5_error(self): + with mock.patch('sys.stdout', new=six.StringIO()) as mock_stdout: + self.response_data['ETag'] = '"0"' + with self.assertRaises(MD5Error): + fileinfo.save_file(None, self.response_data, None, True) + # Make sure nothing is written to stdout. + self.assertEqual(mock_stdout.getvalue(), "") diff --git a/tests/unit/customizations/s3/test_fileinfobuilder.py b/tests/unit/customizations/s3/test_fileinfobuilder.py index 439c006ad136..1791fd93f571 100644 --- a/tests/unit/customizations/s3/test_fileinfobuilder.py +++ b/tests/unit/customizations/s3/test_fileinfobuilder.py @@ -26,7 +26,8 @@ def test_info_setter(self): files = [FileStat(src='src', dest='dest', compare_key='compare_key', size='size', last_update='last_update', src_type='src_type', dest_type='dest_type', - operation_name='operation_name')] + operation_name='operation_name', + is_stream='is_stream')] file_infos = info_setter.call(files) for file_info in file_infos: attributes = file_info.__dict__.keys() diff --git a/tests/unit/customizations/s3/test_s3handler.py b/tests/unit/customizations/s3/test_s3handler.py index 20bc3a62a858..6bf368d45bbb 100644 --- a/tests/unit/customizations/s3/test_s3handler.py +++ b/tests/unit/customizations/s3/test_s3handler.py @@ -14,15 +14,19 @@ import os import random import sys -from awscli.testutils import unittest +import mock + +from awscli.testutils import unittest from awscli import EnvironmentVariables from awscli.customizations.s3.s3handler import S3Handler from awscli.customizations.s3.fileinfo import FileInfo +from awscli.customizations.s3.tasks import CreateMultipartUploadTask, \ + UploadPartTask, CreateLocalFileTask from tests.unit.customizations.s3.fake_session import FakeSession from tests.unit.customizations.s3 import make_loc_files, clean_loc_files, \ make_s3_files, s3_cleanup, create_bucket, list_contents, list_buckets, \ - S3HandlerBaseTest + S3HandlerBaseTest, MockStdIn class S3HandlerTestDeleteList(S3HandlerBaseTest): @@ -612,5 +616,164 @@ def test_bucket(self): self.assertEqual(orig_number_buckets, number_buckets) +class TestStreams(S3HandlerBaseTest): + def setUp(self): + super(TestStreams, self).setUp() + self.session = FakeSession() + self.service = self.session.get_service('s3') + self.endpoint = self.service.get_endpoint('us-east-1') + self.params = {'is_stream': True, 'region': 'us-east-1'} + stream_timeout = 'awscli.customizations.s3.constants.STREAM_INPUT_TIMEOUT' + self.stream_timeout_patch = mock.patch(stream_timeout, 0.001) + self.stream_timeout_patch.start() + + def tearDown(self): + super(TestStreams, self).tearDown() + self.stream_timeout_patch.stop() + + def test_pull_from_stream(self): + s3handler = S3Handler(self.session, self.params, chunksize=2) + input_to_stdin = b'This is a test' + size = len(input_to_stdin) + # Retrieve the entire string. + with MockStdIn(input_to_stdin): + payload, is_amount_requested = s3handler._pull_from_stream(size) + data = payload.read() + self.assertTrue(is_amount_requested) + self.assertEqual(data, input_to_stdin) + # Ensure the function exits when there is nothing to read. + with MockStdIn(): + payload, is_amount_requested = s3handler._pull_from_stream(size) + data = payload.read() + self.assertFalse(is_amount_requested) + self.assertEqual(data, b'') + # Ensure the function does not grab too much out of stdin. 
+ with MockStdIn(input_to_stdin): + payload, is_amount_requested = s3handler._pull_from_stream(size-2) + data = payload.read() + self.assertTrue(is_amount_requested) + self.assertEqual(data, input_to_stdin[:-2]) + # Retrieve the rest of standard in. + payload, is_amount_requested = s3handler._pull_from_stream(size) + data = payload.read() + self.assertFalse(is_amount_requested) + self.assertEqual(data, input_to_stdin[-2:]) + + def test_upload_stream_not_multipart_task(self): + s3handler = S3Handler(self.session, self.params) + s3handler.executor = mock.Mock() + fileinfos = [FileInfo('filename', operation_name='upload', + is_stream=True, size=0)] + with MockStdIn(b'bar'): + s3handler._enqueue_tasks(fileinfos) + submitted_tasks = s3handler.executor.submit.call_args_list + # No multipart upload should have been submitted. + self.assertEqual(len(submitted_tasks), 1) + self.assertEqual(submitted_tasks[0][0][0].payload.read(), + b'bar') + + def test_upload_stream_is_multipart_task(self): + s3handler = S3Handler(self.session, self.params, + multi_threshold=1) + s3handler.executor = mock.Mock() + fileinfos = [FileInfo('filename', operation_name='upload', + is_stream=True, size=0)] + with MockStdIn(b'bar'): + s3handler._enqueue_tasks(fileinfos) + submitted_tasks = s3handler.executor.submit.call_args_list + # This should be a multipart upload so multiple tasks + # should have been submitted. + self.assertEqual(len(submitted_tasks), 4) + self.assertEqual(submitted_tasks[1][0][0]._payload.read(), + b'b') + self.assertEqual(submitted_tasks[2][0][0]._payload.read(), + b'ar') + + def test_upload_stream_with_expected_size(self): + self.params['expected_size'] = 100000 + # With this large of expected size, the chunksize of 2 will have + # to change. + s3handler = S3Handler(self.session, self.params, chunksize=2) + s3handler.executor = mock.Mock() + fileinfo = FileInfo('filename', operation_name='upload', + is_stream=True) + with MockStdIn(b'bar'): + s3handler._enqueue_multipart_upload_tasks(fileinfo, False, b'') + submitted_tasks = s3handler.executor.submit.call_args_list + # Determine what the chunksize was changed to from one of the + # UploadPartTasks. + changed_chunk_size = submitted_tasks[1][0][0]._chunk_size + # New chunksize should have a total parts under 1000. + self.assertTrue(100000/changed_chunk_size < 1000) + + def test_upload_stream_enqueue_upload_start_task(self): + s3handler = S3Handler(self.session, self.params) + s3handler.executor = mock.Mock() + fileinfo = FileInfo('filename', operation_name='upload', + is_stream=True) + s3handler._enqueue_upload_start_task(None, None, fileinfo, b'foo') + submitted_tasks = s3handler.executor.submit.call_args_list + self.assertEqual(len(submitted_tasks), 2) + self.assertEqual(type(submitted_tasks[0][0][0]), + CreateMultipartUploadTask) + # Check that the initially pulled part of the stream gets submitted + # after the instantiating the CreateMultipartTask. 
+ self.assertEqual(type(submitted_tasks[1][0][0]), + UploadPartTask) + # Check that the payload is correct + self.assertEqual(submitted_tasks[1][0][0]._payload, b'foo') + + def test_upload_stream_enqueue_upload_task(self): + s3handler = S3Handler(self.session, self.params) + s3handler.executor = mock.Mock() + fileinfo = FileInfo('filename', operation_name='upload', + is_stream=True) + stdin_input = b'This is a test' + with MockStdIn(stdin_input): + num_parts = s3handler._enqueue_upload_tasks(None, 2, mock.Mock(), + fileinfo, + UploadPartTask) + submitted_tasks = s3handler.executor.submit.call_args_list + # Ensure the returned number of parts is correct. + self.assertEqual(num_parts, len(submitted_tasks) + 1) + # Ensure the number of tasks uploaded are as expected + self.assertEqual(len(submitted_tasks), 8) + index = 0 + for i in range(len(submitted_tasks)-1): + self.assertEqual(submitted_tasks[i][0][0]._payload.read(), + stdin_input[index:index+2]) + index += 2 + # Ensure that the last part is an empty string as expected. + self.assertEqual(submitted_tasks[7][0][0]._payload.read(), b'') + + def test_enqueue_upload_single_part_task_stream(self): + """ + This test ensures that a payload gets attached to a task when + it is submitted to the executor. + """ + s3handler = S3Handler(self.session, self.params) + s3handler.executor = mock.Mock() + mock_task_class = mock.Mock() + s3handler._enqueue_upload_single_part_task( + part_number=1, chunk_size=2, upload_context=None, + filename=None, task_class=mock_task_class, + payload=b'This is a test' + ) + args, kwargs = mock_task_class.call_args + self.assertIn('payload', kwargs.keys()) + self.assertEqual(kwargs['payload'], b'This is a test') + + def test_enqueue_range_download_tasks_stream(self): + s3handler = S3Handler(self.session, self.params) + s3handler.executor = mock.Mock() + fileinfo = FileInfo('filename', operation_name='download', + is_stream=True, size=100) + s3handler._enqueue_range_download_tasks(fileinfo) + # Ensure that no request was sent to make a file locally. 
+ submitted_tasks = s3handler.executor.submit.call_args_list + self.assertNotEqual(type(submitted_tasks[0][0][0]), + CreateLocalFileTask) + + if __name__ == "__main__": unittest.main() diff --git a/tests/unit/customizations/s3/test_subcommands.py b/tests/unit/customizations/s3/test_subcommands.py index 363fc496889f..92b4e8b1e18f 100644 --- a/tests/unit/customizations/s3/test_subcommands.py +++ b/tests/unit/customizations/s3/test_subcommands.py @@ -197,7 +197,8 @@ def test_run_cp_put(self): 'src': local_file, 'dest': s3_file, 'filters': filters, 'paths_type': 'locals3', 'region': 'us-east-1', 'endpoint_url': None, 'verify_ssl': None, - 'follow_symlinks': True, 'page_size': None} + 'follow_symlinks': True, 'page_size': None, + 'is_stream': False} cmd_arc = CommandArchitecture(self.session, 'cp', params) cmd_arc.create_instructions() cmd_arc.run() @@ -213,7 +214,8 @@ def test_error_on_same_line_as_status(self): 'src': local_file, 'dest': s3_file, 'filters': filters, 'paths_type': 'locals3', 'region': 'us-east-1', 'endpoint_url': None, 'verify_ssl': None, - 'follow_symlinks': True, 'page_size': None} + 'follow_symlinks': True, 'page_size': None, + 'is_stream': False} cmd_arc = CommandArchitecture(self.session, 'cp', params) cmd_arc.create_instructions() cmd_arc.run() @@ -236,7 +238,8 @@ def test_run_cp_get(self): 'src': s3_file, 'dest': local_file, 'filters': filters, 'paths_type': 's3local', 'region': 'us-east-1', 'endpoint_url': None, 'verify_ssl': None, - 'follow_symlinks': True, 'page_size': None} + 'follow_symlinks': True, 'page_size': None, + 'is_stream': False} cmd_arc = CommandArchitecture(self.session, 'cp', params) cmd_arc.create_instructions() cmd_arc.run() @@ -253,7 +256,8 @@ def test_run_cp_copy(self): 'src': s3_file, 'dest': s3_file, 'filters': filters, 'paths_type': 's3s3', 'region': 'us-east-1', 'endpoint_url': None, 'verify_ssl': None, - 'follow_symlinks': True, 'page_size': None} + 'follow_symlinks': True, 'page_size': None, + 'is_stream': False} cmd_arc = CommandArchitecture(self.session, 'cp', params) cmd_arc.create_instructions() cmd_arc.run() @@ -270,7 +274,8 @@ def test_run_mv(self): 'src': s3_file, 'dest': s3_file, 'filters': filters, 'paths_type': 's3s3', 'region': 'us-east-1', 'endpoint_url': None, 'verify_ssl': None, - 'follow_symlinks': True, 'page_size': None} + 'follow_symlinks': True, 'page_size': None, + 'is_stream': False} cmd_arc = CommandArchitecture(self.session, 'mv', params) cmd_arc.create_instructions() cmd_arc.run() @@ -287,7 +292,8 @@ def test_run_remove(self): 'src': s3_file, 'dest': s3_file, 'filters': filters, 'paths_type': 's3', 'region': 'us-east-1', 'endpoint_url': None, 'verify_ssl': None, - 'follow_symlinks': True, 'page_size': None} + 'follow_symlinks': True, 'page_size': None, + 'is_stream': False} cmd_arc = CommandArchitecture(self.session, 'rm', params) cmd_arc.create_instructions() cmd_arc.run() @@ -308,7 +314,8 @@ def test_run_sync(self): 'src': local_dir, 'dest': s3_prefix, 'filters': filters, 'paths_type': 'locals3', 'region': 'us-east-1', 'endpoint_url': None, 'verify_ssl': None, - 'follow_symlinks': True, 'page_size': None} + 'follow_symlinks': True, 'page_size': None, + 'is_stream': False} cmd_arc = CommandArchitecture(self.session, 'sync', params) cmd_arc.create_instructions() cmd_arc.run() @@ -324,7 +331,7 @@ def test_run_mb(self): 'src': s3_prefix, 'dest': s3_prefix, 'paths_type': 's3', 'region': 'us-east-1', 'endpoint_url': None, 'verify_ssl': None, 'follow_symlinks': True, - 'page_size': None} + 'page_size': None, 'is_stream': 
@@ -324,7 +331,7 @@ def test_run_mb(self):
                   'src': s3_prefix, 'dest': s3_prefix, 'paths_type': 's3',
                   'region': 'us-east-1', 'endpoint_url': None,
                   'verify_ssl': None, 'follow_symlinks': True,
-                  'page_size': None}
+                  'page_size': None, 'is_stream': False}
         cmd_arc = CommandArchitecture(self.session, 'mb', params)
         cmd_arc.create_instructions()
         cmd_arc.run()
@@ -340,7 +347,7 @@ def test_run_rb(self):
                   'src': s3_prefix, 'dest': s3_prefix, 'paths_type': 's3',
                   'region': 'us-east-1', 'endpoint_url': None,
                   'verify_ssl': None, 'follow_symlinks': True,
-                  'page_size': None}
+                  'page_size': None, 'is_stream': False}
         cmd_arc = CommandArchitecture(self.session, 'rb', params)
         cmd_arc.create_instructions()
         rc = cmd_arc.run()
@@ -357,7 +364,7 @@ def test_run_rb_nonzero_rc(self):
                   'src': s3_prefix, 'dest': s3_prefix, 'paths_type': 's3',
                   'region': 'us-east-1', 'endpoint_url': None,
                   'verify_ssl': None, 'follow_symlinks': True,
-                  'page_size': None}
+                  'page_size': None, 'is_stream': False}
         cmd_arc = CommandArchitecture(self.session, 'rb', params)
         cmd_arc.create_instructions()
         rc = cmd_arc.run()
@@ -468,6 +475,34 @@ def test_check_force(self):
         cmd_params.parameters['src'] = 's3://mybucket'
         cmd_params.check_force(None)
 
+    def test_validate_streaming_paths_upload(self):
+        parameters = {'src': '-', 'dest': 's3://bucket'}
+        cmd_params = CommandParameters(self.session, 'cp', parameters, '')
+        cmd_params._validate_streaming_paths()
+        self.assertTrue(cmd_params.parameters['is_stream'])
+        self.assertTrue(cmd_params.parameters['quiet'])
+        self.assertFalse(cmd_params.parameters['dir_op'])
+
+    def test_validate_streaming_paths_download(self):
+        parameters = {'src': 'localfile', 'dest': '-'}
+        cmd_params = CommandParameters(self.session, 'cp', parameters, '')
+        cmd_params._validate_streaming_paths()
+        self.assertTrue(cmd_params.parameters['is_stream'])
+        self.assertTrue(cmd_params.parameters['quiet'])
+        self.assertFalse(cmd_params.parameters['dir_op'])
+
+    def test_validate_no_streaming_paths(self):
+        parameters = {'src': 'localfile', 'dest': 's3://bucket'}
+        cmd_params = CommandParameters(self.session, 'cp', parameters, '')
+        cmd_params._validate_streaming_paths()
+        self.assertFalse(cmd_params.parameters['is_stream'])
+
+    def test_validate_streaming_paths_error(self):
+        parameters = {'src': '-', 'dest': 's3://bucket'}
+        cmd_params = CommandParameters(self.session, 'sync', parameters, '')
+        with self.assertRaises(ValueError):
+            cmd_params._validate_streaming_paths()
+
 
 class HelpDocTest(BaseAWSHelpOutputTest):
     def setUp(self):
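The four ``_validate_streaming_paths`` tests above pin down the intended
behavior: a ``-`` source means a streaming upload, a ``-`` destination means a
streaming download, streaming forces a quiet single-file operation, and only
``cp`` may stream. A rough sketch of validation logic that would satisfy these
tests (a sketch only, not necessarily the shipped code; it assumes the command
name is available on the instance as ``self.cmd``) is::

    def _validate_streaming_paths(self):
        self.parameters['is_stream'] = False
        if self.parameters['src'] == '-' or self.parameters['dest'] == '-':
            # Streaming to/from stdin/stdout: operate on a single object and
            # suppress progress output, which would corrupt the stream.
            self.parameters['is_stream'] = True
            self.parameters['dir_op'] = False
            self.parameters['quiet'] = True
        if self.parameters['is_stream'] and self.cmd != 'cp':
            raise ValueError("Streaming is only compatible with "
                             "single-file cp commands")

In CLI terms this is what allows invocations such as
``aws s3 cp - s3://bucket/key`` (upload from standard input) and
``aws s3 cp s3://bucket/key -`` (download to standard output).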
diff --git a/tests/unit/customizations/s3/test_tasks.py b/tests/unit/customizations/s3/test_tasks.py
index 4451c85cb569..eda16f765778 100644
--- a/tests/unit/customizations/s3/test_tasks.py
+++ b/tests/unit/customizations/s3/test_tasks.py
@@ -22,6 +22,7 @@
 from awscli.customizations.s3.tasks import CompleteDownloadTask
 from awscli.customizations.s3.tasks import DownloadPartTask
 from awscli.customizations.s3.tasks import MultipartUploadContext
+from awscli.customizations.s3.tasks import MultipartDownloadContext
 from awscli.customizations.s3.tasks import UploadCancelledError
 from awscli.customizations.s3.tasks import print_operation
 from awscli.customizations.s3.tasks import RetriesExeededError
@@ -163,6 +164,58 @@ def test_basic_threaded_parts(self):
             self.calls[2][1:],
             ('my_upload_id', [{'ETag': 'etag1', 'PartNumber': 1}]))
 
+    def test_streaming_threaded_parts(self):
+        # This is similar to the basic threaded parts test, but here
+        # the thread has to wait to learn exactly how many parts are
+        # expected from the stream.  This is indicated when the expected
+        # parts of the context changes from ... to an integer.
+
+        self.context = MultipartUploadContext(expected_parts='...')
+        upload_part_thread = threading.Thread(target=self.upload_part,
+                                              args=(1,))
+        # Once this thread starts it will immediately block.
+        self.start_thread(upload_part_thread)
+
+        # Also, let's start the thread that will do the complete
+        # multipart upload.  It will also block because it needs all
+        # the parts, so it's blocked on the upload_part_thread.  It also
+        # needs the upload_id, so it's blocked on that as well.
+        complete_upload_thread = threading.Thread(target=self.complete_upload)
+        self.start_thread(complete_upload_thread)
+
+        # Then finally the CreateMultipartUpload completes and we
+        # announce the upload id.
+        self.create_upload('my_upload_id')
+        # The complete upload thread should still be waiting for the
+        # expected number of parts.
+        with self.call_lock:
+            was_completed = (len(self.calls) > 2)
+
+        # The upload_part thread can now proceed, as well as the complete
+        # multipart upload thread.
+        self.context.announce_total_parts(1)
+        self.join_threads()
+
+        self.assertIsNone(self.caught_exception)
+
+        # Make sure that the complete task was never called early, since
+        # it was still waiting for the total parts to be announced.
+        self.assertFalse(was_completed)
+
+        # We can verify that the invariants still hold.
+        self.assertEqual(len(self.calls), 3)
+        # There should be three calls: create, upload, complete.
+        self.assertEqual(self.calls[0][0], 'create_multipart_upload')
+        self.assertEqual(self.calls[1][0], 'upload_part')
+        self.assertEqual(self.calls[2][0], 'complete_upload')
+
+        # Verify the correct args were used.
+        self.assertEqual(self.calls[0][1], 'my_upload_id')
+        self.assertEqual(self.calls[1][1:], (1, 'my_upload_id'))
+        self.assertEqual(
+            self.calls[2][1:],
+            ('my_upload_id', [{'ETag': 'etag1', 'PartNumber': 1}]))
+
     def test_randomized_stress_test(self):
         # Now given that we've verified the functionality from
         # the two tests above, we randomize the threading to ensure
@@ -279,6 +332,7 @@ def setUp(self):
         self.filename.size = 10 * 1024 * 1024
         self.filename.src = 'bucket/key'
         self.filename.dest = 'local/file'
+        self.filename.is_stream = False
         self.filename.service = self.service
         self.filename.operation_name = 'download'
         self.context = mock.Mock()
@@ -325,9 +379,9 @@ def test_download_queues_io_properly(self):
         call_args_list = self.io_queue.put.call_args_list
         self.assertEqual(len(call_args_list), 2)
         self.assertEqual(call_args_list[0],
-                         mock.call(('local/file', 0, b'foobar')))
+                         mock.call(('local/file', 0, b'foobar', False)))
         self.assertEqual(call_args_list[1],
-                         mock.call(('local/file', 6, b'morefoobar')))
+                         mock.call(('local/file', 6, b'morefoobar', False)))
 
     def test_incomplete_read_is_retried(self):
         self.service.get_operation.return_value.call.side_effect = \
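The ``TestMultipartDownloadContext`` class added in the next hunk drives a
``wait_for_turn``/``done_with_turn`` protocol that forces ranged download
parts to reach the output stream in order. Conceptually, this kind of
turn-taking can be built from a condition variable guarding a current-part
counter, along these lines (a sketch of the idea only, not necessarily how
``MultipartDownloadContext`` is implemented)::

    import threading


    class InOrderTurns(object):
        """Let threads proceed one at a time, in part-number order."""

        def __init__(self, num_parts):
            self._num_parts = num_parts
            self._current_part = 0
            self._turn = threading.Condition()

        def wait_for_turn(self, part_number):
            # Block until every lower-numbered part has finished its turn.
            with self._turn:
                while self._current_part != part_number:
                    self._turn.wait()

        def done_with_turn(self):
            # Hand the turn to the next part and wake any waiters.
            with self._turn:
                self._current_part += 1
                self._turn.notify_all()

With such a scheme, the thread handling part 1 in ``test_stream_context``
blocks until the thread handling part 0 has finished its turn, which is
exactly the ordering the test asserts.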
@@ -342,6 +396,61 @@ def test_incomplete_read_is_retried(self):
             self.service.get_operation.call_count)
 
 
+class TestMultipartDownloadContext(unittest.TestCase):
+    def setUp(self):
+        self.context = MultipartDownloadContext(num_parts=2)
+        self.calls = []
+        self.threads = []
+        self.call_lock = threading.Lock()
+        self.caught_exception = None
+
+    def tearDown(self):
+        self.join_threads()
+
+    def join_threads(self):
+        for thread in self.threads:
+            thread.join()
+
+    def download_stream_part(self, part_number):
+        try:
+            self.context.wait_for_turn(part_number)
+            with self.call_lock:
+                self.calls.append(('download_part', str(part_number)))
+            self.context.done_with_turn()
+        except Exception as e:
+            self.caught_exception = e
+            return
+
+    def start_thread(self, thread):
+        thread.start()
+        self.threads.append(thread)
+
+    def test_stream_context(self):
+        part_thread = threading.Thread(target=self.download_stream_part,
+                                       args=(1,))
+        # Once this thread starts it will immediately block because it is
+        # waiting for part zero to finish submitting its task.
+        self.start_thread(part_thread)
+
+        # Now create the thread that should submit its task first.
+        part_thread2 = threading.Thread(target=self.download_stream_part,
+                                        args=(0,))
+        self.start_thread(part_thread2)
+        self.join_threads()
+
+        self.assertIsNone(self.caught_exception)
+
+        # We can verify that the invariants still hold.
+        self.assertEqual(len(self.calls), 2)
+        # Both calls should be download_part calls.
+        self.assertEqual(self.calls[0][0], 'download_part')
+        self.assertEqual(self.calls[1][0], 'download_part')
+
+        # Verify the parts were submitted in the correct order.
+        self.assertEqual(self.calls[0][1], '0')
+        self.assertEqual(self.calls[1][1], '1')
+
+
 class TestTaskOrdering(unittest.TestCase):
     def setUp(self):
         self.q = StablePriorityQueue(maxsize=10, max_priority=20)
diff --git a/tests/unit/test_completer.py b/tests/unit/test_completer.py
index fc5365b40a69..aab65f01a5d9 100644
--- a/tests/unit/test_completer.py
+++ b/tests/unit/test_completer.py
@@ -73,7 +73,8 @@
                      '--cache-control', '--content-type', '--content-disposition',
                      '--source-region', '--content-encoding', '--content-language',
-                     '--expires', '--grants'] + GLOBALOPTS)),
+                     '--expires', '--grants', '--expected-size'] +
+                     GLOBALOPTS)),
     ('aws s3 cp --quiet -', -1, set(['--no-guess-mime-type', '--dryrun',
                                      '--recursive', '--content-type',
                                      '--follow-symlinks', '--no-follow-symlinks',
@@ -82,7 +83,7 @@
                                      '--expires', '--website-redirect', '--acl',
                                      '--storage-class', '--sse',
                                      '--exclude', '--include',
-                                     '--source-region',
+                                     '--source-region','--expected-size',
                                      '--grants'] + GLOBALOPTS)),
     ('aws emr ', -1, set(['add-instance-groups', 'add-steps', 'add-tags',
                           'create-cluster', 'create-default-roles',