Local gateway testing (#848)
This allows transfers to be tested locally, both with synthetic data and with real object store data (see the usage sketch after the file summary below).
sarahwooders authored Jun 12, 2023
1 parent 6d6404e commit 6b0f52b
Showing 25 changed files with 734 additions and 143 deletions.
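
For orientation, here is a rough sketch of how the pieces introduced in this commit might be exercised together for a local transfer test. Only the Dataplane "local" flag, TestCopyJob, and the per-container gateway log collection are taken from this diff; the SkyplaneClient wiring, the "test://" paths, and the exact call signatures are assumptions for illustration, not APIs confirmed by this change.

# Hypothetical usage sketch: the client wiring and call signatures below are
# assumed, not taken from this diff. This commit adds the `local` flag,
# TestCopyJob, and per-container gateway log collection used here.
from skyplane.api.client import SkyplaneClient        # assumed entry point
from skyplane.api.transfer_job import TestCopyJob     # added in this commit

client = SkyplaneClient()
# Assumption: a dataplane provisioned with local=True starts its gateways as
# local Docker containers (distinct container_name/port per gateway) instead
# of cloud VMs.
dataplane = client.dataplane("aws:us-east-1", "aws:us-west-2", n_vms=1, local=True)
dataplane.provision()

# Synthetic-data transfer: chunks are generated randomly on the gateways
# (num_chunks x chunk_size_bytes), so no object store credentials are needed.
job = TestCopyJob("test://source", ["test://dest"], num_chunks=10, chunk_size_bytes=1024)
dataplane.run([job])             # assumed dispatch call
dataplane.copy_gateway_logs()    # from this diff: docker logs copied to transfer_dir
dataplane.deprovision()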
29 changes: 15 additions & 14 deletions Dockerfile
@@ -7,19 +7,20 @@ RUN --mount=type=cache,target=/var/cache/apt apt-get update \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*

#install HDFS Onprem Packages
RUN apt-get update && \
apt-get install -y openjdk-11-jdk && \
apt-get clean

ENV JAVA_HOME /usr/lib/jvm/java-11-openjdk-amd64

RUN wget https://archive.apache.org/dist/hadoop/core/hadoop-3.3.0/hadoop-3.3.0.tar.gz -P /tmp \
&& tar -xzf /tmp/hadoop-3.3.0.tar.gz -C /tmp \
&& mv /tmp/hadoop-3.3.0 /usr/local/hadoop \
&& rm /tmp/hadoop-3.3.0.tar.gz

ENV HADOOP_HOME /usr/local/hadoop
# TODO: uncomment when on-prem is re-enabled
##install HDFS Onprem Packages
#RUN apt-get update && \
# apt-get install -y openjdk-11-jdk && \
# apt-get clean
#
#ENV JAVA_HOME /usr/lib/jvm/java-11-openjdk-amd64
#
#RUN wget https://archive.apache.org/dist/hadoop/core/hadoop-3.3.0/hadoop-3.3.0.tar.gz -P /tmp \
# && tar -xzf /tmp/hadoop-3.3.0.tar.gz -C /tmp \
# && mv /tmp/hadoop-3.3.0 /usr/local/hadoop \
# && rm /tmp/hadoop-3.3.0.tar.gz
#
#ENV HADOOP_HOME /usr/local/hadoop

# configure stunnel
RUN mkdir -p /etc/stunnel \
@@ -47,7 +48,7 @@ RUN (echo 'net.ipv4.ip_local_port_range = 12000 65535' >> /etc/sysctl.conf) \
COPY scripts/requirements-gateway.txt /tmp/requirements-gateway.txt

#Onprem: Install Hostname Resolution for HDFS
COPY scripts/on_prem/hostname /tmp/hostname
#COPY scripts/on_prem/hostname /tmp/hostname

RUN --mount=type=cache,target=/root/.cache/pip pip3 install --no-cache-dir -r /tmp/requirements-gateway.txt && rm -r /tmp/requirements-gateway.txt

45 changes: 27 additions & 18 deletions skyplane/api/dataplane.py
@@ -13,6 +13,7 @@
from typing import TYPE_CHECKING, Dict, List, Optional

from skyplane import compute
from skyplane.api.tracker import TransferProgressTracker
from skyplane.exceptions import GatewayContainerStartException
from skyplane.api.tracker import TransferProgressTracker, TransferHook
from skyplane.api.transfer_job import CopyJob, SyncJob, TransferJob
@@ -49,6 +50,7 @@ def __init__(
transfer_config: TransferConfig,
log_dir: str,
debug: bool = True,
local: bool = False,
):
"""
:param clientid: the uuid of the local host to create the dataplane
@@ -70,6 +72,7 @@ def __init__(
self.log_dir = Path(log_dir)
self.transfer_dir = tmp_log_dir / "transfer_logs" / datetime.now().strftime("%Y%m%d_%H%M%S")
self.transfer_dir.mkdir(exist_ok=True, parents=True)
self.local = local

# transfer logs
self.transfer_dir = tmp_log_dir / "transfer_logs" / datetime.now().strftime("%Y%m%d_%H%M%S")
@@ -88,6 +91,8 @@ def _start_gateway(
gateway_log_dir: Optional[PathLike],
authorize_ssh_pub_key: Optional[str] = None,
e2ee_key_bytes: Optional[str] = None,
container_name: Optional[str] = "skyplane_gateway",
port: Optional[int] = 8081,
):
# map outgoing ports
setup_args = {}
@@ -106,21 +111,30 @@
if authorize_ssh_pub_key:
gateway_server.copy_public_key(authorize_ssh_pub_key)

# write gateway info file
gateway_info_path = Path(f"{gateway_log_dir}/gateway_info.json")
with open(gateway_info_path, "w") as f:
json.dump(self.topology.get_gateway_info_json(), f, indent=4)
logger.fs.info(f"Writing gateway info to {gateway_info_path}")

# write gateway programs
gateway_program_filename = Path(f"{gateway_log_dir}/gateway_program_{gateway_node.gateway_id}.json")
gateway_program_filename = Path(f"{gateway_log_dir}/gateway_program_{gateway_node.gateway_id}.json".replace(":", "-"))
with open(gateway_program_filename, "w") as f:
f.write(gateway_node.gateway_program.to_json())

# start gateway
gateway_server.start_gateway(
# setup_args,
gateway_docker_image=gateway_docker_image,
gateway_program_path=str(gateway_program_filename),
gateway_info_path=f"{gateway_log_dir}/gateway_info.json",
gateway_program_path=os.path.abspath(str(gateway_program_filename)),
gateway_info_path=os.path.abspath(os.path.join(gateway_log_dir, "gateway_info.json")),
e2ee_key_bytes=None, # TODO: remove
use_bbr=self.transfer_config.use_bbr, # TODO: remove
use_compression=self.transfer_config.use_compression,
use_socket_tls=self.transfer_config.use_socket_tls,
local=self.local,
container_name=container_name,
port=port,
)

def provision(
@@ -206,12 +220,6 @@ def provision(
Path(gateway_program_dir).mkdir(exist_ok=True, parents=True)
logger.fs.info(f"Writing gateway programs to {gateway_program_dir}")

# write gateway info file
gateway_info_path = f"{gateway_program_dir}/gateway_info.json"
with open(gateway_info_path, "w") as f:
json.dump(self.topology.get_gateway_info_json(), f, indent=4)
logger.fs.info(f"Writing gateway info to {gateway_info_path}")

# start gateways in parallel
jobs = []
for node, server in gateway_bound_nodes.items():
@@ -225,17 +233,18 @@
self.copy_gateway_logs()
raise GatewayContainerStartException(f"Error starting gateways. Please check gateway logs {self.transfer_dir}")

def copy_gateway_log(self, instance, container_name: Optional[str] = "skyplane_gateway"):
# copy log from single gateway
out_file = self.transfer_dir / f"gateway_{instance.uuid()}.stdout"
err_file = self.transfer_dir / f"gateway_{instance.uuid()}.stderr"
logger.fs.info(f"[Dataplane.copy_gateway_logs] Copying logs from {instance.uuid()}: {out_file}")
instance.run_command(f"sudo docker logs -t {container_name} 2> /tmp/gateway.stderr > /tmp/gateway.stdout")
instance.download_file("/tmp/gateway.stdout", out_file)
instance.download_file("/tmp/gateway.stderr", err_file)

def copy_gateway_logs(self):
# copy logs from all gateways in parallel
def copy_log(instance):
out_file = self.transfer_dir / f"gateway_{instance.uuid()}.stdout"
err_file = self.transfer_dir / f"gateway_{instance.uuid()}.stderr"
logger.fs.info(f"[Dataplane.copy_gateway_logs] Copying logs from {instance.uuid()}: {out_file}")
instance.run_command("sudo docker logs -t skyplane_gateway 2> /tmp/gateway.stderr > /tmp/gateway.stdout")
instance.download_file("/tmp/gateway.stdout", out_file)
instance.download_file("/tmp/gateway.stderr", err_file)

do_parallel(copy_log, self.bound_nodes.values(), n=-1)
do_parallel(self.copy_gateway_log, self.bound_nodes.values(), n=-1)

def deprovision(self, max_jobs: int = 64, spinner: bool = False):
"""
38 changes: 26 additions & 12 deletions skyplane/api/transfer_job.py
@@ -286,16 +286,7 @@ def transfer_pair_generator(
logger.fs.exception(e)
raise e from None

if dest_provider == "aws":
dest_obj = S3Object(provider=dest_provider, bucket=dst_iface.bucket(), key=dest_key)
elif dest_provider == "azure":
dest_obj = AzureBlobObject(provider=dest_provider, bucket=dst_iface.bucket(), key=dest_key)
elif dest_provider == "gcp":
dest_obj = GCSObject(provider=dest_provider, bucket=dst_iface.bucket(), key=dest_key)
elif dest_provider == "cloudflare":
dest_obj = R2Object(provider=dest_provider, bucket=dst_iface.bucket(), key=dest_key)
else:
raise ValueError(f"Invalid dest_region {dest_region}, unknown provider")
dest_obj = dst_iface.create_object_repr(dest_key)
dest_objs[dst_iface.region_tag()] = dest_obj

# assert that all destinations share the same post-fix key
@@ -498,7 +489,7 @@ def dst_prefixes(self) -> List[str]:
@property
def dst_ifaces(self) -> List[StorageInterface]:
"""Return the destination object store interface"""
if not hasattr(self, "_dst_iface"):
if not hasattr(self, "_dst_ifaces"):
if self.transfer_type == "unicast":
provider_dst, bucket_dst, _ = parse_path(self.dst_paths[0])
self._dst_ifaces = [StorageInterface.create(f"{provider_dst}:infer", bucket_dst)]
@@ -646,6 +637,7 @@ def dispatch(
assert Chunk.from_dict(chunk_batch[0].as_dict()) == chunk_batch[0], f"Invalid chunk request: {chunk_batch[0].as_dict}"

# TODO: make async
st = time.time()
reply = self.http_pool.request(
"POST",
f"{server.gateway_api_url}/api/v1/chunk_requests",
@@ -654,9 +646,10 @@
)
if reply.status != 200:
raise Exception(f"Failed to dispatch chunk requests {server.instance_name()}: {reply.data.decode('utf-8')}")
et = time.time()
reply_json = json.loads(reply.data.decode("utf-8"))
logger.fs.debug(f"Added {n_added} chunks to server {server}: {reply_json}")
n_added += reply_json["n_added"]
logger.fs.debug(f"Added {n_added} chunks to server {server} in {et-st}: {reply_json}")
queue_size[min_idx] = reply_json["qsize"] # update queue size
# dont try again with some gateway
min_idx = (min_idx + 1) % len(src_gateways)
@@ -734,6 +727,27 @@ def size_gb(self):
return total_size / 1e9


@dataclass
class TestCopyJob(CopyJob):
# TODO: remove this class (unnecessary since we have TestObjectStore object)

"""Test copy which does not interact with object stores but uses random data generation on gateways"""

def __init__(
self,
src_path: str,
dst_paths: List[str] or str,
recursive: bool = False,
requester_pays: bool = False,
uuid: str = field(init=False, default_factory=lambda: str(uuid.uuid4())),
num_chunks: int = 10,
chunk_size_bytes: int = 1024,
):
super().__init__(src_path, dst_paths, recursive, requester_pays, uuid)
self.num_chunks = num_chunks
self.chunk_size_bytes = chunk_size_bytes


@dataclass
class SyncJob(CopyJob):
"""sync job that copies the source objects that does not exist in the destination bucket to the destination"""
2 changes: 1 addition & 1 deletion skyplane/cli/cli_cloud.py
@@ -312,7 +312,7 @@ def gcp_check(
rprint(f"[bright_black]GCP Python SDK auth: {auth}[/bright_black]")
check_assert(auth, "GCP Python SDK auth created")
cred = auth.credentials
sa_cred = auth.service_account_credentials
sa_cred = auth.service_account_key_path
if debug:
rprint(f"[bright_black]GCP Python SDK credentials: {cred}[/bright_black]")
rprint(f"[bright_black]GCP Python SDK service account credentials: {sa_cred}[/bright_black]")
6 changes: 3 additions & 3 deletions skyplane/cli/cli_init.py
@@ -535,9 +535,9 @@ def init(
cloud_config = load_gcp_config(cloud_config, force_init=reinit_gcp, non_interactive=non_interactive)

# load cloudflare config
if not reinit_cloudflare:
typer.secho("\n(1) Configuring cloudflare R2:", fg="yellow", bold=True)
if not disable_config_cloudflare:
if not reinit_cloudflare: # TODO: fix reinit logic
typer.secho("\n(1) Configuring Cloudflare R2:", fg="yellow", bold=True)
if not disable_config_aws:
cloud_config = load_cloudflare_config(cloud_config, non_interactive=non_interactive)

# load IBMCloud config
22 changes: 9 additions & 13 deletions skyplane/compute/aws/aws_auth.py
@@ -23,6 +23,10 @@ def __init__(self, config: Optional[SkyplaneConfig] = None, access_key: Optional
self._access_key = None
self._secret_key = None

# cached credentials
# TODO: fix this as its very messy
self.__cached_credentials = {}

def _get_ec2_vm_quota(self, region) -> Dict[str, int]:
"""Given the region, get the maximum number of vCPU that can be launched.
@@ -103,22 +107,14 @@ def enabled(self):
@imports.inject("boto3", pip_extra="aws")
def infer_credentials(boto3, self):
# todo load temporary credentials from STS
cached_credential = getattr(self.__cached_credentials, "boto3_credential", None)
if cached_credential is None:
session = boto3.Session()
credentials = session.get_credentials()
if credentials:
credentials = credentials.get_frozen_credentials()
cached_credential = (credentials.access_key, credentials.secret_key)
setattr(self.__cached_credentials, "boto3_credential", cached_credential)
return cached_credential if cached_credential else (None, None)
session = boto3.Session()
credentials = session.get_credentials()
credentials = credentials.get_frozen_credentials()
return credentials.access_key, credentials.secret_key

@imports.inject("boto3", pip_extra="aws")
def get_boto3_session(boto3, self, aws_region: Optional[str] = None):
if self.config_mode == "manual":
return boto3.Session(aws_access_key_id=self.access_key, aws_secret_access_key=self.secret_key, region_name=aws_region)
else:
return boto3.Session(region_name=aws_region)
return boto3.Session(aws_access_key_id=self.access_key, aws_secret_access_key=self.secret_key, region_name=aws_region)

def get_boto3_resource(self, service_name, aws_region=None):
return self.get_boto3_session().resource(service_name, region_name=aws_region)
2 changes: 2 additions & 0 deletions skyplane/compute/cloud_provider.py
@@ -46,6 +46,8 @@ def get_transfer_cost(src_key, dst_key, premium_tier=True):
from skyplane.compute.gcp.gcp_cloud_provider import GCPCloudProvider

return GCPCloudProvider.get_transfer_cost(f"gcp:{_}", dst_key, premium_tier)
elif src_provider == "test":
return 0
else:
raise ValueError(f"Unknown provider {src_provider}")

52 changes: 31 additions & 21 deletions skyplane/compute/gcp/gcp_auth.py
@@ -20,20 +20,25 @@ def __init__(self, config: Optional[SkyplaneConfig] = None):
else:
self.config = SkyplaneConfig.load_config(config_path)
self._credentials = None
self._service_credentials_file = None
self._service_account_email = None

@imports.inject("googleapiclient.discovery", pip_extra="gcp")
def save_region_config(discovery, self):
if self.project_id is None:
print(
f" No project ID detected when trying to save GCP region list! Consquently, the GCP region list is empty. Run 'skyplane init --reinit-gcp' or file an issue to remedy this."
)
self.clear_region_config()
return
if self.service_account_key_path:
service_json = json.load(open(self.service_account_key_path))
assert "project_id" in service_json, f"Service account file {self.service_account_key_path} does not container project_id"
self.project_id = service_json["project_id"]
else:
print(
f" No project ID detected when trying to save GCP region list! Also no {self.service_account_key_path}. Consquently, the GCP region list is empty. Run 'skyplane init --reinit-gcp' or file an issue to remedy this."
)
self.clear_region_config()
return
with gcp_config_path.open("w") as f:
region_list = []
credentials = self.credentials
service_account_credentials_file = self.service_account_credentials # force creation of file
service_account_key_path = self.get_service_account_key(self.service_account_email) # force creation of file
service = discovery.build("compute", "beta", credentials=credentials)
request = service.zones().list(project=self.project_id)
while request is not None:
@@ -85,20 +90,20 @@ def credentials(self):
self._credentials, _ = self.get_adc_credential(self.project_id)
return self._credentials

@property
def service_account_credentials(self):
if self._service_credentials_file is None:
self._service_account_email = self.create_service_account(self.service_account_name)
# create service key
self._service_credentials_file = self.get_service_account_key(self._service_account_email)
# @property
# def service_account_credentials(self):
# if self._service_credentials_file is None:
# self._service_account_email = self.create_service_account(self.service_account_name)
# # create service key
# self._service_credentials_file = self.get_service_account_key(self._service_account_email)

return self._service_credentials_file
# return self._service_credentials_file

@property
def project_id(self):
assert (
self.config.gcp_project_id is not None
), "No project ID detected. Run 'skyplane init --reinit-gcp' or file an issue to remedy this."
), f"No project ID detected. Run 'skyplane init --reinit-gcp' or file an issue to remedy this {self.config}."
return self.config.gcp_project_id

@staticmethod
@@ -141,6 +146,12 @@ def service_account_name(self):
# TODO: append skyplane client ID
return self.config.get_flag("gcp_service_account_name")

@property
def service_account_email(self):
if not self._service_account_email:
self._service_account_email = self.create_service_account(self.service_account_name)
return self._service_account_email

@property
def service_account_key_path(self):
if "GCP_SERVICE_ACCOUNT_FILE" in os.environ:
@@ -150,14 +161,13 @@ def service_account_key_path(self):
key_path = key_root / "gcp" / self.project_id / "service_account_key.json"
return key_path

def get_service_account_key_path(self):
return self.service_account_key_path

def get_service_account_key(self, service_account_email):
service = self.get_gcp_client(service_name="iam")
def get_service_account_key(self, service_account_email: str):
"""Get service account key file for a given service account email."""

# write key file
if not os.path.exists(self.service_account_key_path):
print(f" Creating service account key: {self.service_account_key_path}")
service = self.get_gcp_client(service_name="iam")
# list existing keys
keys = service.projects().serviceAccounts().keys().list(name="projects/-/serviceAccounts/" + service_account_email).execute()

@@ -256,7 +266,7 @@ def get_gcp_client(discovery, self, service_name="compute", version="v1"):
def get_storage_client(storage, self):
# TODO: cache storage account client
# check that storage account works
return storage.Client.from_service_account_json(self.service_account_credentials)
return storage.Client.from_service_account_json(self.service_account_key_path)

def get_gcp_instances(self, gcp_region: str):
return self.get_gcp_client().instances().list(project=self.project_id, zone=gcp_region).execute()