From fc7d8cbb0129bd53f94871d178f40164e535ceeb Mon Sep 17 00:00:00 2001
From: Miles Turin <31150941+milesturin@users.noreply.github.com>
Date: Tue, 26 Apr 2022 20:02:45 -0700
Subject: [PATCH] S3 boto3 migration (#198) (#255)

Migrated S3 interfaces from awscrt to boto3. The current implementation
is single-threaded.
---
 skylark/cli/cli_helper.py         |  4 +-
 skylark/obj_store/s3_interface.py | 64 ++++++---------------
 skylark/test/test_s3_interface.py | 43 ++++++++++++++-------
 3 files changed, 43 insertions(+), 68 deletions(-)

diff --git a/skylark/cli/cli_helper.py b/skylark/cli/cli_helper.py
index f079336b9..088a9442f 100644
--- a/skylark/cli/cli_helper.py
+++ b/skylark/cli/cli_helper.py
@@ -187,8 +187,8 @@ def copy_azure_local(src_account_name: str, src_container_name: str, src_key: st
     return copy_objstore_local(azure, src_key, dst)
 
 
-def copy_local_s3(src: Path, dst_bucket: str, dst_key: str, use_tls: bool = True):
-    s3 = S3Interface(None, dst_bucket, use_tls=use_tls)
+def copy_local_s3(src: Path, dst_bucket: str, dst_key: str):
+    s3 = S3Interface(None, dst_bucket)
     return copy_local_objstore(s3, src, dst_key)
 
 
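Note on the cli_helper change: copy_local_s3 drops its use_tls flag because
boto3 clients negotiate HTTPS by default, so there is nothing left to toggle
at this layer. If a plaintext endpoint were ever needed again, a minimal
sketch of the standard boto3 client option (hypothetical usage, not part of
this patch):

    import boto3

    # use_ssl is a standard boto3/botocore client argument; it defaults to True.
    s3_client = boto3.client("s3", region_name="us-east-1", use_ssl=False)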
diff --git a/skylark/obj_store/s3_interface.py b/skylark/obj_store/s3_interface.py
index 214079f18..9fa533cee 100644
--- a/skylark/obj_store/s3_interface.py
+++ b/skylark/obj_store/s3_interface.py
@@ -1,17 +1,11 @@
-import mimetypes
-import os
 from typing import Iterator, List
 
 import botocore.exceptions
-from awscrt.auth import AwsCredentialsProvider
-from awscrt.http import HttpHeaders, HttpRequest
-from awscrt.io import ClientBootstrap, DefaultHostResolver, EventLoopGroup
-from awscrt.s3 import S3Client, S3RequestTlsMode, S3RequestType
+from boto3.s3.transfer import TransferConfig
 
 from skylark import exceptions
 from skylark.compute.aws.aws_auth import AWSAuthentication
-from skylark.utils import logger
-
 from skylark.obj_store.object_store_interface import NoSuchObjectException, ObjectStoreInterface, ObjectStoreObject
+from skylark.utils import logger
 
 class S3Object(ObjectStoreObject):
@@ -20,24 +14,13 @@ def full_path(self):
 
 
 class S3Interface(ObjectStoreInterface):
-    def __init__(self, aws_region, bucket_name, use_tls=True, part_size=None, throughput_target_gbps=10, num_threads=4):
+    def __init__(self, aws_region, bucket_name):
         self.auth = AWSAuthentication()
         self.aws_region = self.infer_s3_region(bucket_name) if aws_region is None or aws_region == "infer" else aws_region
         self.bucket_name = bucket_name
         if not self.bucket_exists():
             logger.error("Specified bucket does not exist.")
             raise exceptions.MissingBucketException()
-        event_loop_group = EventLoopGroup(num_threads=num_threads, cpu_group=None)
-        host_resolver = DefaultHostResolver(event_loop_group)
-        bootstrap = ClientBootstrap(event_loop_group, host_resolver)
-        self._s3_client = S3Client(
-            bootstrap=bootstrap,
-            region=self.aws_region,
-            credential_provider=AwsCredentialsProvider.new_default_chain(bootstrap),
-            throughput_target_gbps=throughput_target_gbps,
-            part_size=part_size,
-            tls_mode=S3RequestTlsMode.ENABLED if use_tls else S3RequestTlsMode.DISABLED,
-        )
 
     def region_tag(self):
         return "aws:" + self.aws_region
@@ -65,7 +48,6 @@ def create_bucket(self, premium_tier=True):
         assert self.bucket_exists()
 
     def list_objects(self, prefix="") -> Iterator[S3Object]:
-        prefix = prefix if not prefix.startswith("/") else prefix[1:]
        s3_client = self.auth.get_boto3_client("s3", self.aws_region)
         paginator = s3_client.get_paginator("list_objects_v2")
         page_iterator = paginator.paginate(Bucket=self.bucket_name, Prefix=prefix)
@@ -82,7 +64,7 @@ def delete_objects(self, keys: List[str]):
     def get_obj_metadata(self, obj_name):
         s3_resource = self.auth.get_boto3_resource("s3", self.aws_region).Bucket(self.bucket_name)
         try:
-            return s3_resource.Object(str(obj_name).lstrip("/"))
+            return s3_resource.Object(str(obj_name))
         except botocore.exceptions.ClientError as e:
             raise NoSuchObjectException(f"Object {obj_name} does not exist, or you do not have permission to access it") from e
 
@@ -96,36 +78,12 @@ def exists(self, obj_name):
         except NoSuchObjectException:
             return False
 
-    # todo: implement range request for download
     def download_object(self, src_object_name, dst_file_path):
         src_object_name, dst_file_path = str(src_object_name), str(dst_file_path)
-        src_object_name = "/" + src_object_name if src_object_name[0] != "/" else src_object_name
-        download_headers = HttpHeaders([("host", self.bucket_name + ".s3." + self.aws_region + ".amazonaws.com")])
-        request = HttpRequest("GET", src_object_name, download_headers)
-
-        def _on_body_download(offset, chunk, **kwargs):
-            if not os.path.exists(dst_file_path):
-                open(dst_file_path, "a").close()
-            with open(dst_file_path, "rb+") as f:
-                f.seek(offset)
-                f.write(chunk)
-
-        self._s3_client.make_request(
-            recv_filepath=dst_file_path,
-            request=request,
-            type=S3RequestType.GET_OBJECT,
-            on_body=_on_body_download,
-        ).finished_future.result()
-
-    def upload_object(self, src_file_path, dst_object_name, content_type="infer"):
-        src_file_path, dst_object_name = str(src_file_path), str(dst_object_name)
-        dst_object_name = "/" + dst_object_name if dst_object_name[0] != "/" else dst_object_name
-        content_len = os.path.getsize(src_file_path)
-        if content_type == "infer":
-            content_type = mimetypes.guess_type(src_file_path)[0] or "application/octet-stream"
-        upload_headers = HttpHeaders()
-        upload_headers.add("host", self.bucket_name + ".s3." + self.aws_region + ".amazonaws.com")
-        upload_headers.add("Content-Type", content_type)
-        upload_headers.add("Content-Length", str(content_len))
-        request = HttpRequest("PUT", dst_object_name, upload_headers)
-        self._s3_client.make_request(send_filepath=src_file_path, request=request, type=S3RequestType.PUT_OBJECT).finished_future.result()
+        s3_client = self.auth.get_boto3_client("s3", self.aws_region)
+        s3_client.download_file(self.bucket_name, src_object_name, dst_file_path, Config=TransferConfig(use_threads=False))
+
+    def upload_object(self, src_file_path, dst_object_name):
+        dst_object_name, src_file_path = str(dst_object_name), str(src_file_path)
+        s3_client = self.auth.get_boto3_client("s3", self.aws_region)
+        s3_client.upload_file(src_file_path, self.bucket_name, dst_object_name, Config=TransferConfig(use_threads=False))
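Both transfer methods above pass TransferConfig(use_threads=False), which is
what the commit message means by single-threaded: boto3's managed transfer
runs entirely on the calling thread. A sketch of how parallelism could later
be restored through the same TransferConfig knobs (the values below are
illustrative assumptions, not part of this patch):

    from boto3.s3.transfer import TransferConfig

    from skylark import MB

    # Objects larger than multipart_threshold are split into
    # multipart_chunksize parts, and up to max_concurrency parts are
    # transferred concurrently on worker threads.
    parallel_config = TransferConfig(
        use_threads=True,
        max_concurrency=16,
        multipart_threshold=int(64 * MB),
        multipart_chunksize=int(16 * MB),
    )
    # e.g. s3_client.upload_file(src_file_path, bucket_name, dst_object_name,
    #                            Config=parallel_config)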
diff --git a/skylark/test/test_s3_interface.py b/skylark/test/test_s3_interface.py
index b30b45aa2..843da6c7b 100644
--- a/skylark/test/test_s3_interface.py
+++ b/skylark/test/test_s3_interface.py
@@ -5,36 +5,53 @@
 from skylark.obj_store.s3_interface import S3Interface
 from skylark.utils.utils import Timer
+from skylark.utils import logger
 
 
+def test_s3_interface(region="us-east-1", bucket="sky-us-east-1"):
+    logger.debug("creating interfaces...")
+    interface = S3Interface(region, bucket)
+    assert interface.aws_region == region
+    assert interface.bucket_name == bucket
+    interface.create_bucket()
 
-def test_s3_interface():
-    s3_interface = S3Interface("us-east-1", "skylark-test-us-east-1", True)
-    assert s3_interface.aws_region == "us-east-1"
-    assert s3_interface.bucket_name == "skylark-test-us-east-1"
-    s3_interface.create_bucket()
+    debug_time = lambda n, s, e: logger.debug(f"{n} {s}MB in {round(e, 2)}s ({round(s / e, 2)}MB/s)")
 
     # generate file and upload
-    obj_name = "/test.txt"
+    obj_name = "test.txt"
     file_size_mb = 128
     with tempfile.NamedTemporaryFile() as tmp:
         fpath = tmp.name
-        with open(fpath, "wb") as f:
+        with open(fpath, "rb+") as f:
+            logger.debug("writing...")
             f.write(os.urandom(int(file_size_mb * MB)))
-        file_md5 = hashlib.md5(open(fpath, "rb").read()).hexdigest()
+            f.seek(0)
+            logger.debug("verifying...")
+            file_md5 = hashlib.md5(f.read()).hexdigest()
 
+        logger.debug("uploading...")
         with Timer() as t:
-            s3_interface.upload_object(fpath, obj_name)
-        assert s3_interface.get_obj_size(obj_name) == os.path.getsize(fpath)
+            interface.upload_object(fpath, obj_name)
+            debug_time("uploaded", file_size_mb, t.elapsed)
+
+        assert interface.get_obj_size(obj_name) == os.path.getsize(fpath)
 
     # download object
     with tempfile.NamedTemporaryFile() as tmp:
         fpath = tmp.name
         if os.path.exists(fpath):
             os.remove(fpath)
+
+        logger.debug("downloading...")
         with Timer() as t:
-            s3_interface.download_object(obj_name, fpath)
-        assert s3_interface.get_obj_size(obj_name) == os.path.getsize(fpath)
+            interface.download_object(obj_name, fpath)
+            debug_time("downloaded", file_size_mb, t.elapsed)
+
+        assert interface.get_obj_size(obj_name) == os.path.getsize(fpath)
 
         # check md5
-        dl_file_md5 = hashlib.md5(open(fpath, "rb").read()).hexdigest()
+        with open(fpath, "rb") as f:
+            logger.debug("verifying...")
+            dl_file_md5 = hashlib.md5(f.read()).hexdigest()
+
         assert dl_file_md5 == file_md5
+    logger.debug("done.")
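For reference, a minimal round trip through the migrated interface, assuming
AWS credentials are configured and the bucket already exists (the bucket name
and keys below are placeholders, and S3Object is assumed to expose the listed
object's key):

    from skylark.obj_store.s3_interface import S3Interface

    # Pass the region explicitly, or "infer" to resolve it from the bucket.
    interface = S3Interface("us-east-1", "my-existing-bucket")
    interface.upload_object("/tmp/payload.bin", "payload.bin")
    interface.download_object("payload.bin", "/tmp/payload_copy.bin")
    print([obj.key for obj in interface.list_objects(prefix="payload")])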