Skip to content

Commit

Permalink
Support for Requester Pays (#520)
Browse files Browse the repository at this point in the history
* Requester pay working with devrel-delta-datasets

* More documentation and cleaning up code

* Change to config instead of flag

* Linting for github pipeline

* Fix merge conflict and typing

* Change exception and activate_requester

* Change bolded
  • Loading branch information
HaileyJang authored Sep 14, 2022
1 parent 6a2930b commit f149ef9
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 12 deletions.
6 changes: 6 additions & 0 deletions skyplane/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import skyplane.cli.usage.definitions
import skyplane.cli.usage.client
from skyplane.cli.usage.client import UsageClient, UsageStatsStatus
from skyplane.obj_store.s3_interface import S3Interface
from skyplane.replicate.replicator_client import ReplicatorClient

import typer
Expand Down Expand Up @@ -133,6 +134,8 @@ def cp(
)
raise typer.Exit(1)

requester_pays: bool = cloud_config.get_flag("requester_pays")

if provider_src == "local" or provider_dst == "local":
typer.secho("Local transfers are not yet supported (but will be soon!)", fg="red", err=True)
typer.secho("Skyplane is currently most optimized for cloud to cloud transfers.", fg="yellow", err=True)
Expand All @@ -149,6 +152,9 @@ def cp(
src_region = src_client.region_tag()
dst_region = dst_client.region_tag()

if requester_pays:
src_client.set_requester_bool(True)

transfer_pairs = generate_full_transferobjlist(
src_region, bucket_src, path_src, dst_region, bucket_dst, path_dst, recursive=recursive
)
Expand Down
12 changes: 8 additions & 4 deletions skyplane/cli/cli_impl/cp_replicate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
import typer
from rich import print as rprint

from skyplane import exceptions, GB, format_bytes, gateway_docker_image, skyplane_root
from skyplane import exceptions, GB, format_bytes, gateway_docker_image, skyplane_root, cloud_config
from skyplane.compute.cloud_providers import CloudProvider
from skyplane.obj_store.object_store_interface import ObjectStoreInterface, ObjectStoreObject
from skyplane.obj_store.s3_interface import S3Object
from skyplane.obj_store.s3_interface import S3Interface, S3Object
from skyplane.obj_store.gcs_interface import GCSObject
from skyplane.obj_store.azure_blob_interface import AzureBlobObject
from skyplane.replicate.replication_plan import ReplicationTopology, ReplicationJob
Expand Down Expand Up @@ -159,11 +159,15 @@ def generate_full_transferobjlist(
source_iface = ObjectStoreInterface.create(source_region, source_bucket)
dest_iface = ObjectStoreInterface.create(dest_region, dest_bucket)

requester_pays = cloud_config.get_flag("requester_pays")
if requester_pays:
source_iface.set_requester_bool(True)

# ensure buckets exist
if not source_iface.bucket_exists():
raise exceptions.MissingBucketException(f"Source bucket {source_bucket} does not exist")
if not dest_iface.bucket_exists():
raise exceptions.MissingBucketException(f"Destination bucket {dest_bucket} does not exist")

source_objs, dest_objs = [], []

# query all source region objects
Expand All @@ -173,7 +177,7 @@ def generate_full_transferobjlist(
source_objs.append(obj)
status.update(f"Querying objects in {source_bucket} (found {len(source_objs)} objects so far)")
if not source_objs:
logger.error("Specified object does not exist.")
logger.error("Specified object does not exist.\n")
raise exceptions.MissingObjectException(f"No objects were found in the specified prefix {source_prefix} in {source_bucket}")

# map objects to destination object paths
Expand Down
2 changes: 2 additions & 0 deletions skyplane/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
"gcp_use_premium_network": bool,
"usage_stats": bool,
"gcp_service_account_name": str,
"requester_pays": bool,
}

_DEFAULT_FLAGS = {
Expand All @@ -54,6 +55,7 @@
"gcp_use_premium_network": True,
"usage_stats": True,
"gcp_service_account_name": "skyplane-manual",
"requester_pays": False,
}


Expand Down
2 changes: 1 addition & 1 deletion skyplane/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def pretty_print_str(self):
class MissingBucketException(SkyplaneException):
    """Raised when a source or destination bucket cannot be found or accessed.

    NOTE(review): the scraped diff left both the pre- and post-change message
    lines in place; this is the post-change version (with the FAQ link).
    """

    def pretty_print_str(self):
        """Return a rich-markup error string pointing the user at the troubleshooting FAQ."""
        err = f"[red][bold]:x: MissingBucketException:[/bold] {str(self)}[/red]"
        err += "\n[red][bold]Please ensure that the bucket exists and is accessible.[/bold] See https://skyplane.org/en/latest/faq.html#TroubleshootingMissingBucketException.[/red]"
        return err


Expand Down
6 changes: 6 additions & 0 deletions skyplane/gateway/gateway_obj_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
from multiprocessing import Event, Manager, Process, Value, Queue
from typing import Dict, Optional

from skyplane import cloud_config
from skyplane.chunk import ChunkRequest
from skyplane.gateway.chunk_store import ChunkStore
from skyplane.obj_store.object_store_interface import ObjectStoreInterface
from skyplane.obj_store.s3_interface import S3Interface
from skyplane.utils import logger
from skyplane.utils.retry import retry_backoff

Expand All @@ -26,6 +28,7 @@ def __init__(self, chunk_store: ChunkStore, error_event, error_queue: Queue, max
self.error_queue = error_queue
self.n_processes = max_conn
self.processes = []
self.src_requester_pays = cloud_config.get_flag("requester_pays")

# shared state
self.manager = Manager()
Expand Down Expand Up @@ -117,6 +120,9 @@ def worker_loop(self, worker_id: int):
logger.debug(f"[obj_store:{self.worker_id}] Start download {chunk_req.chunk.chunk_id} from {bucket}")

obj_store_interface = self.get_obj_store_interface(chunk_req.src_region, bucket)

if self.src_requester_pays:
obj_store_interface.set_requester_bool(True)
md5sum = retry_backoff(
partial(
obj_store_interface.download_object,
Expand Down
3 changes: 3 additions & 0 deletions skyplane/obj_store/object_store_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ def path(self) -> str:
def region_tag(self) -> str:
    """Return the provider-prefixed region tag (e.g. "aws:us-east-1").

    Abstract: concrete object-store interfaces must override this.
    """
    raise NotImplementedError

def set_requester_bool(self, requester: bool):
    """Opt this interface in or out of Requester Pays billing.

    Default implementation is a deliberate no-op for providers that have no
    Requester Pays concept; S3Interface overrides it to record the flag.
    """
    pass

def create_bucket(self, region_tag: str):
    """Create the backing bucket in the region named by `region_tag`.

    Abstract: concrete object-store interfaces must override this.
    """
    raise NotImplementedError

Expand Down
26 changes: 19 additions & 7 deletions skyplane/obj_store/s3_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Iterator, List, Optional

import botocore.exceptions
import botocore.client

from skyplane import exceptions
from skyplane.compute.aws.aws_auth import AWSAuthentication
Expand All @@ -21,6 +22,7 @@ def full_path(self):
class S3Interface(ObjectStoreInterface):
def __init__(self, bucket_name: str):
    """Create an S3 interface bound to `bucket_name`.

    Requester Pays starts disabled; callers enable it later via
    set_requester_bool(True) when the config flag is set.
    """
    self.auth = AWSAuthentication()
    self.bucket_name = bucket_name
    self.requester_pays = False

def path(self):
Expand All @@ -43,17 +45,21 @@ def aws_region(self):
def region_tag(self):
    """Return this bucket's region tag in the "aws:<region>" form."""
    return f"aws:{self.aws_region}"

def set_requester_bool(self, requester: bool):
    """Record whether subsequent S3 requests should send the Requester Pays header.

    When True, request-building methods on this interface add
    RequestPayer="requester" to their boto3 calls.
    """
    self.requester_pays = requester

def _s3_client(self, region=None):
region = region if region is not None else self.aws_region
return self.auth.get_boto3_client("s3", region)

def bucket_exists(self):
    """Return True if the bucket exists and is accessible, else False.

    Uses a one-key list_objects_v2 call as a cheap existence probe, adding
    the Requester Pays header when self.requester_pays is set. AccessDenied
    is treated the same as NoSuchBucket (returns False); any other client
    error is re-raised.

    NOTE(review): the scraped diff left both the pre- and post-change probe
    call and `if` condition in place; this is the post-change version.
    The client is created in us-east-1 — presumably list_objects_v2 follows
    redirects for buckets in other regions; confirm against callers.
    """
    s3_client = self._s3_client("us-east-1")
    try:
        requester_pays = {"RequestPayer": "requester"} if self.requester_pays else {}
        # list one object to check if bucket exists
        s3_client.list_objects_v2(Bucket=self.bucket_name, MaxKeys=1, **requester_pays)
        return True
    except botocore.exceptions.ClientError as e:
        if e.response["Error"]["Code"] == "NoSuchBucket" or e.response["Error"]["Code"] == "AccessDenied":
            return False
        raise e

Expand All @@ -70,7 +76,8 @@ def delete_bucket(self):

def list_objects(self, prefix="") -> Iterator["S3Object"]:
    """Yield an S3Object for every object in the bucket under `prefix`.

    Paginates list_objects_v2 so arbitrarily large buckets are streamed
    lazily. When self.requester_pays is set, each page request carries
    RequestPayer="requester" so the caller's account is billed.

    NOTE(review): the scraped diff left both the pre- and post-change
    paginate call in place (the first result was discarded); this is the
    post-change version with a single paginate call.
    """
    paginator = self._s3_client().get_paginator("list_objects_v2")
    requester_pays = {"RequestPayer": "requester"} if self.requester_pays else {}
    page_iterator = paginator.paginate(Bucket=self.bucket_name, Prefix=prefix, **requester_pays)
    for page in page_iterator:
        for obj in page.get("Contents", []):
            yield S3Object("aws", self.bucket_name, obj["Key"], obj["Size"], obj["LastModified"])
Expand Down Expand Up @@ -113,14 +120,19 @@ def download_object(
write_block_size=2**16,
) -> Optional[bytes]:
src_object_name, dst_file_path = str(src_object_name), str(dst_file_path)

s3_client = self._s3_client()
assert len(src_object_name) > 0, f"Source object name must be non-empty: '{src_object_name}'"

args = {"Bucket": self.bucket_name, "Key": src_object_name}

if size_bytes:
byte_range = f"bytes={offset_bytes}-{offset_bytes + size_bytes - 1}"
response = s3_client.get_object(Bucket=self.bucket_name, Key=src_object_name, Range=byte_range)
else:
response = s3_client.get_object(Bucket=self.bucket_name, Key=src_object_name)
args["Range"] = f"bytes={offset_bytes}-{offset_bytes + size_bytes - 1}"

if self.requester_pays:
args["RequestPayer"] = "requester"

response = s3_client.get_object(**args)

# write response data
if not os.path.exists(dst_file_path):
Expand Down

0 comments on commit f149ef9

Please sign in to comment.