S3 boto3 migration (#198) (#255)
Migrated S3 interfaces from awscrt to boto3. The current implementation is single-threaded.
milesturin authored Apr 27, 2022
1 parent c13472a commit fc7d8cb
Showing 3 changed files with 43 additions and 68 deletions.
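
As the message notes, the migrated transfers run single-threaded. A minimal sketch of the underlying boto3 pattern, with placeholder bucket, key, and file names, is shown below; multithreading could presumably be restored later by passing use_threads=True (and a max_concurrency) to TransferConfig.

# Sketch (not from this commit): single-threaded boto3 transfers via TransferConfig,
# mirroring the pattern adopted by the new S3Interface. All names are placeholders.
import boto3
from boto3.s3.transfer import TransferConfig

s3_client = boto3.client("s3", region_name="us-east-1")
config = TransferConfig(use_threads=False)  # keep each transfer on a single thread

# Upload a local file, then download it back, one request at a time.
s3_client.upload_file("local.txt", "example-bucket", "remote.txt", Config=config)
s3_client.download_file("example-bucket", "remote.txt", "local_copy.txt", Config=config)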
4 changes: 2 additions & 2 deletions skylark/cli/cli_helper.py
@@ -187,8 +187,8 @@ def copy_azure_local(src_account_name: str, src_container_name: str, src_key: st
return copy_objstore_local(azure, src_key, dst)


def copy_local_s3(src: Path, dst_bucket: str, dst_key: str, use_tls: bool = True):
s3 = S3Interface(None, dst_bucket, use_tls=use_tls)
def copy_local_s3(src: Path, dst_bucket: str, dst_key: str):
s3 = S3Interface(None, dst_bucket)
return copy_local_objstore(s3, src, dst_key)


64 changes: 11 additions & 53 deletions skylark/obj_store/s3_interface.py
@@ -1,17 +1,11 @@
import mimetypes
import os
from typing import Iterator, List

import botocore.exceptions
from awscrt.auth import AwsCredentialsProvider
from awscrt.http import HttpHeaders, HttpRequest
from awscrt.io import ClientBootstrap, DefaultHostResolver, EventLoopGroup
from awscrt.s3 import S3Client, S3RequestTlsMode, S3RequestType
from boto3.s3.transfer import TransferConfig
from skylark import exceptions
from skylark.compute.aws.aws_auth import AWSAuthentication
from skylark.utils import logger

from skylark.obj_store.object_store_interface import NoSuchObjectException, ObjectStoreInterface, ObjectStoreObject
from skylark.utils import logger


class S3Object(ObjectStoreObject):
@@ -20,24 +14,13 @@ def full_path(self):


class S3Interface(ObjectStoreInterface):
def __init__(self, aws_region, bucket_name, use_tls=True, part_size=None, throughput_target_gbps=10, num_threads=4):
def __init__(self, aws_region, bucket_name):
self.auth = AWSAuthentication()
self.aws_region = self.infer_s3_region(bucket_name) if aws_region is None or aws_region == "infer" else aws_region
self.bucket_name = bucket_name
if not self.bucket_exists():
logger.error("Specified bucket does not exist.")
raise exceptions.MissingBucketException()
event_loop_group = EventLoopGroup(num_threads=num_threads, cpu_group=None)
host_resolver = DefaultHostResolver(event_loop_group)
bootstrap = ClientBootstrap(event_loop_group, host_resolver)
self._s3_client = S3Client(
bootstrap=bootstrap,
region=self.aws_region,
credential_provider=AwsCredentialsProvider.new_default_chain(bootstrap),
throughput_target_gbps=throughput_target_gbps,
part_size=part_size,
tls_mode=S3RequestTlsMode.ENABLED if use_tls else S3RequestTlsMode.DISABLED,
)

def region_tag(self):
return "aws:" + self.aws_region
@@ -65,7 +48,6 @@ def create_bucket(self, premium_tier=True):
assert self.bucket_exists()

def list_objects(self, prefix="") -> Iterator[S3Object]:
prefix = prefix if not prefix.startswith("/") else prefix[1:]
s3_client = self.auth.get_boto3_client("s3", self.aws_region)
paginator = s3_client.get_paginator("list_objects_v2")
page_iterator = paginator.paginate(Bucket=self.bucket_name, Prefix=prefix)
@@ -82,7 +64,7 @@ def delete_objects(self, keys: List[str]):
def get_obj_metadata(self, obj_name):
s3_resource = self.auth.get_boto3_resource("s3", self.aws_region).Bucket(self.bucket_name)
try:
return s3_resource.Object(str(obj_name).lstrip("/"))
return s3_resource.Object(str(obj_name))
except botocore.exceptions.ClientError as e:
raise NoSuchObjectException(f"Object {obj_name} does not exist, or you do not have permission to access it") from e

@@ -96,36 +78,12 @@ def exists(self, obj_name):
except NoSuchObjectException:
return False

# todo: implement range request for download
def download_object(self, src_object_name, dst_file_path):
src_object_name, dst_file_path = str(src_object_name), str(dst_file_path)
src_object_name = "/" + src_object_name if src_object_name[0] != "/" else src_object_name
download_headers = HttpHeaders([("host", self.bucket_name + ".s3." + self.aws_region + ".amazonaws.com")])
request = HttpRequest("GET", src_object_name, download_headers)

def _on_body_download(offset, chunk, **kwargs):
if not os.path.exists(dst_file_path):
open(dst_file_path, "a").close()
with open(dst_file_path, "rb+") as f:
f.seek(offset)
f.write(chunk)

self._s3_client.make_request(
recv_filepath=dst_file_path,
request=request,
type=S3RequestType.GET_OBJECT,
on_body=_on_body_download,
).finished_future.result()

def upload_object(self, src_file_path, dst_object_name, content_type="infer"):
src_file_path, dst_object_name = str(src_file_path), str(dst_object_name)
dst_object_name = "/" + dst_object_name if dst_object_name[0] != "/" else dst_object_name
content_len = os.path.getsize(src_file_path)
if content_type == "infer":
content_type = mimetypes.guess_type(src_file_path)[0] or "application/octet-stream"
upload_headers = HttpHeaders()
upload_headers.add("host", self.bucket_name + ".s3." + self.aws_region + ".amazonaws.com")
upload_headers.add("Content-Type", content_type)
upload_headers.add("Content-Length", str(content_len))
request = HttpRequest("PUT", dst_object_name, upload_headers)
self._s3_client.make_request(send_filepath=src_file_path, request=request, type=S3RequestType.PUT_OBJECT).finished_future.result()
s3_client = self.auth.get_boto3_client("s3", self.aws_region)
s3_client.download_file(self.bucket_name, src_object_name, dst_file_path, Config=TransferConfig(use_threads=False))

def upload_object(self, src_file_path, dst_object_name):
dst_object_name, src_file_path = str(dst_object_name), str(src_file_path)
s3_client = self.auth.get_boto3_client("s3", self.aws_region)
s3_client.upload_file(src_file_path, self.bucket_name, dst_object_name, Config=TransferConfig(use_threads=False))
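
For reference, a hypothetical round trip through the migrated interface might look like the sketch below. The bucket name and file paths are placeholders, and the bucket is assumed to already exist, since the constructor raises MissingBucketException otherwise.

# Hypothetical usage of the migrated S3Interface; bucket and paths are placeholders.
from skylark.obj_store.s3_interface import S3Interface

interface = S3Interface("us-east-1", "example-bucket")  # None or "infer" derives the region from the bucket
interface.upload_object("/tmp/report.csv", "reports/report.csv")
print(interface.get_obj_size("reports/report.csv"))
interface.download_object("reports/report.csv", "/tmp/report_copy.csv")
for obj in interface.list_objects(prefix="reports/"):
    print(obj)  # S3Object entries under the prefix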
43 changes: 30 additions & 13 deletions skylark/test/test_s3_interface.py
@@ -5,36 +5,53 @@

from skylark.obj_store.s3_interface import S3Interface
from skylark.utils.utils import Timer
from skylark.utils import logger

def test_s3_interface(region="us-east-1", bucket="sky-us-east-1"):
logger.debug("creating interfaces...")
interface = S3Interface(region, bucket)
assert interface.aws_region == region
assert interface.bucket_name == bucket
interface.create_bucket()

def test_s3_interface():
s3_interface = S3Interface("us-east-1", "skylark-test-us-east-1", True)
assert s3_interface.aws_region == "us-east-1"
assert s3_interface.bucket_name == "skylark-test-us-east-1"
s3_interface.create_bucket()
debug_time = lambda n, s, e: logger.debug(f"{n} {s}MB in {round(e, 2)}s ({round(s / e, 2)}MB/s)")

# generate file and upload
obj_name = "/test.txt"
obj_name = "test.txt"
file_size_mb = 128
with tempfile.NamedTemporaryFile() as tmp:
fpath = tmp.name
with open(fpath, "wb") as f:
with open(fpath, "rb+") as f:
logger.debug("writing...")
f.write(os.urandom(int(file_size_mb * MB)))
file_md5 = hashlib.md5(open(fpath, "rb").read()).hexdigest()
f.seek(0)
logger.debug("verifying...")
file_md5 = hashlib.md5(f.read()).hexdigest()

logger.debug("uploading...")
with Timer() as t:
s3_interface.upload_object(fpath, obj_name)
assert s3_interface.get_obj_size(obj_name) == os.path.getsize(fpath)
interface.upload_object(fpath, obj_name)
debug_time("uploaded", file_size_mb, t.elapsed)

assert interface.get_obj_size(obj_name) == os.path.getsize(fpath)

# download object
with tempfile.NamedTemporaryFile() as tmp:
fpath = tmp.name
if os.path.exists(fpath):
os.remove(fpath)

logger.debug("downloading...")
with Timer() as t:
s3_interface.download_object(obj_name, fpath)
assert s3_interface.get_obj_size(obj_name) == os.path.getsize(fpath)
interface.download_object(obj_name, fpath)
debug_time("downloaded", file_size_mb, t.elapsed)

assert interface.get_obj_size(obj_name) == os.path.getsize(fpath)

# check md5
dl_file_md5 = hashlib.md5(open(fpath, "rb").read()).hexdigest()
with open(fpath, "rb") as f:
logger.debug("verifying...")
dl_file_md5 = hashlib.md5(f.read()).hexdigest()

assert dl_file_md5 == file_md5
logger.debug("done.")
