From 309b116452ebc43f92ba039b59b151d464b96ec0 Mon Sep 17 00:00:00 2001
From: Paras Jain
Date: Tue, 8 Mar 2022 18:29:48 +0000
Subject: [PATCH 01/34] Squashed commit of the following:

commit f6558105942e0f7047c41264760abfb8b2c2803a
Merge: 74b659c 5d1a748
Author: Paras Jain
Date:   Tue Mar 8 18:27:16 2022 +0000

    Merge branch 'main' into dev/paras/better_init_config

commit 74b659cacb86f821224f3c09e1ed4be82c60651d
Author: Paras Jain
Date:   Tue Mar 8 18:15:53 2022 +0000

    Fix notimplementederror

commit 1793523a51143b502313b550ffeccce421f6013f
Author: Paras Jain
Date:   Fri Mar 4 07:35:19 2022 +0000

    Push

commit 75d80ccd0f3438b0513b7e053d069a5c1f8fd113
Merge: fb5bc72 5f0cf91
Author: Paras Jain
Date:   Thu Mar 3 22:31:39 2022 +0000

    Merge branch 'main' into dev/paras/better_init_config

commit fb5bc72ac6a70d1a6db4f871502f7a7b702f9e08
Author: Paras Jain
Date:   Thu Mar 3 22:30:48 2022 +0000

    Polished skylark init workflow

commit e625dc917986cdf3b1473f4606a0a792e2e25833
Author: Paras Jain
Date:   Thu Mar 3 05:09:44 2022 +0000

    Write chunk status log dataframe to transfer log dir

commit dfeb3fad246338a58c1fb87c5942ed6342bbf0e3
Author: Paras Jain
Date:   Wed Mar 2 20:13:34 2022 +0000

    AWS config loader
---
 skylark/__init__.py                           |  14 +-
 skylark/benchmark/network/latency.py          |  92 ----------
 skylark/benchmark/network/traceroute.py       |   2 +-
 skylark/benchmark/profile_solver.py           |   0
 .../replicate/benchmark_triangles.py          | 102 ------------
 skylark/benchmark/replicate/test_direct.py    | 136 ---------------
 skylark/benchmark/stop_all_instances.py       |  43 -----
 skylark/cli/cli.py                            | 137 +++-----------
 skylark/cli/cli_azure.py                      |  10 +-
 skylark/cli/cli_helper.py                     | 156 ++++++++++++++--
 skylark/cli/experiments/throughput.py         |  15 +-
 skylark/compute/aws/aws_cloud_provider.py     |   4 +-
 skylark/compute/azure/azure_cloud_provider.py |  27 ++-
 skylark/compute/azure/azure_server.py         |  19 +--
 skylark/compute/gcp/gcp_cloud_provider.py     |   7 +-
 skylark/compute/server.py                     |  12 +-
 skylark/config.py                             | 152 ++++++++++++++---
 skylark/obj_store/azure_interface.py          |  16 +-
 skylark/obj_store/gcs_interface.py            |   5 +-
 skylark/obj_store/object_store_interface.py   |  16 +-
 skylark/replicate/replicator_client.py        |  11 +-
 skylark/test/test_replicator_client.py        |   7 -
 22 files changed, 369 insertions(+), 614 deletions(-)
 delete mode 100644 skylark/benchmark/network/latency.py
 delete mode 100644 skylark/benchmark/profile_solver.py
 delete mode 100644 skylark/benchmark/replicate/benchmark_triangles.py
 delete mode 100644 skylark/benchmark/replicate/test_direct.py
 delete mode 100644 skylark/benchmark/stop_all_instances.py

diff --git a/skylark/__init__.py b/skylark/__init__.py
index 401632eed..b2947778f 100644
--- a/skylark/__init__.py
+++ b/skylark/__init__.py
@@ -1,10 +1,20 @@
+import os
+
 from pathlib import Path

 # paths
 skylark_root = Path(__file__).parent.parent
-key_root = skylark_root / "data" / "keys"
-config_file = skylark_root / "data" / "config.json"
+config_root = Path("~/.skylark").expanduser()
+config_root.mkdir(exist_ok=True)
+
+if "SKYLARK_CONFIG" in os.environ:
+    config_path = Path(os.environ["SKYLARK_CONFIG"]).expanduser()
+else:
+    config_path = config_root / "config"
+
+key_root = config_root / "keys"
 tmp_log_dir = Path("/tmp/skylark")
+tmp_log_dir.mkdir(exist_ok=True)

 # header
 def print_header():

diff --git a/skylark/benchmark/network/latency.py b/skylark/benchmark/network/latency.py
deleted file mode 100644
index 21bc8d157..000000000
--- a/skylark/benchmark/network/latency.py
+++ /dev/null
@@ -1,92 +0,0 @@
-import argparse
-import json
-import re
-from typing import
List, Tuple - -from skylark.utils import logger -from tqdm import tqdm - -from skylark import skylark_root -from skylark.benchmark.utils import provision -from skylark.compute.aws.aws_cloud_provider import AWSCloudProvider -from skylark.compute.aws.aws_server import AWSServer -from skylark.compute.gcp.gcp_cloud_provider import GCPCloudProvider -from skylark.compute.gcp.gcp_server import GCPServer -from skylark.compute.server import Server -from skylark.utils.utils import do_parallel - - -def parse_args(): - aws_regions = AWSCloudProvider.region_list() - gcp_regions = GCPCloudProvider.region_list() - parser = argparse.ArgumentParser(description="Provision EC2 instances") - parser.add_argument("--aws_instance_class", type=str, default="i3en.large", help="Instance class") - parser.add_argument("--aws_region_list", type=str, nargs="+", default=aws_regions) - parser.add_argument("--gcp_instance_class", type=str, default="n1-highcpu-8", help="Instance class") - parser.add_argument("--use-premium-network", action="store_true", help="Use premium network") - parser.add_argument("--gcp_project", type=str, default="bair-commons-307400", help="GCP project") - parser.add_argument("--gcp_region_list", type=str, nargs="+", default=gcp_regions) - return parser.parse_args() - - -def main(args): - data_dir = skylark_root / "data" - log_dir = data_dir / "logs" - log_dir.mkdir(exist_ok=True, parents=True) - - aws = AWSCloudProvider() - gcp = GCPCloudProvider(args.gcp_project) - aws_instances: dict[str, list[AWSServer]] - gcp_instances: dict[str, list[GCPServer]] - aws_instances, gcp_instances = provision( - aws=aws, - gcp=gcp, - aws_regions_to_provision=args.aws_region_list, - gcp_regions_to_provision=args.gcp_region_list, - aws_instance_class=args.aws_instance_class, - gcp_instance_class=args.gcp_instance_class, - ) - instance_list: List[Server] = [i for ilist in aws_instances.values() for i in ilist] - instance_list.extend([i for ilist in gcp_instances.values() for i in ilist]) - - # compute pairwise latency by running ping - def compute_latency(arg_pair: Tuple[Server, Server]) -> str: - instance_src, instance_dst = arg_pair - stdout, stderr = instance_src.run_command(f"ping -c 10 {instance_dst.public_ip()}") - latency_result = stdout.strip().split("\n")[-1] - tqdm.write(f"Latency from {instance_src.region_tag} to {instance_dst.region_tag} is {latency_result}") - return latency_result - - instance_pairs = [(i1, i2) for i1 in instance_list for i2 in instance_list if i1 != i2] - latency_results = do_parallel( - compute_latency, - instance_pairs, - progress_bar=True, - n=24, - desc="Latency", - arg_fmt=lambda x: f"{x[0].region_tag} to {x[1].region_tag}", - ) - - def parse_ping_result(string): - """make regex with named groups""" - try: - regex = r"rtt min/avg/max/mdev = (?P\d+\.\d+)/(?P\d+\.\d+)/(?P\d+\.\d+)/(?P\d+\.\d+) ms" - m = re.search(regex, string) - return dict(min=float(m.group("min")), avg=float(m.group("avg")), max=float(m.group("max")), mdev=float(m.group("mdev"))) - except Exception as e: - logger.exception(e) - return {} - - # save results - latency_results_out = [] - for (i1, i2), r in latency_results: - row = dict(src=i1.region_tag, dst=i2.region_tag, ping_str=r, **parse_ping_result(r)) - logger.info(row) - latency_results_out.append(row) - - with open(str(data_dir / "latency.json"), "w") as f: - json.dump(latency_results_out, f) - - -if __name__ == "__main__": - main(parse_args()) diff --git a/skylark/benchmark/network/traceroute.py b/skylark/benchmark/network/traceroute.py index 
7462aa128..7a64e0778 100644 --- a/skylark/benchmark/network/traceroute.py +++ b/skylark/benchmark/network/traceroute.py @@ -46,7 +46,7 @@ def main(args): log_dir.mkdir(exist_ok=True, parents=True) aws = AWSCloudProvider() - gcp = GCPCloudProvider(args.gcp_project) + gcp = GCPCloudProvider() aws_instances, gcp_instances = provision( aws=aws, gcp=gcp, diff --git a/skylark/benchmark/profile_solver.py b/skylark/benchmark/profile_solver.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/skylark/benchmark/replicate/benchmark_triangles.py b/skylark/benchmark/replicate/benchmark_triangles.py deleted file mode 100644 index 53c33c63c..000000000 --- a/skylark/benchmark/replicate/benchmark_triangles.py +++ /dev/null @@ -1,102 +0,0 @@ -import atexit -from datetime import datetime -import pickle -from pathlib import Path - -import typer -from skylark.utils import logger -from skylark import GB, MB, skylark_root -from skylark.replicate.replication_plan import ReplicationJob, ReplicationTopology -from skylark.replicate.replicator_client import ReplicatorClient - - -def bench_triangle( - src_region: str, - dst_region: str, - inter_region: str = None, - log_dir: Path = None, - num_gateways: int = 1, - num_outgoing_connections: int = 16, - chunk_size_mb: int = 8, - n_chunks: int = 2048, - gcp_project: str = "skylark-333700", - gateway_docker_image: str = "ghcr.io/parasj/skylark:main", - aws_instance_class: str = "m5.8xlarge", - gcp_instance_class: str = None, - gcp_use_premium_network: bool = False, - key_prefix: str = "/test/benchmark_triangles", -): - if log_dir is None: - log_dir = skylark_root / "data" / "experiments" / "benchmark_triangles" / datetime.now().strftime("%Y-%m-%d_%H-%M-%S") - log_dir.mkdir(exist_ok=True, parents=True) - result_dir = log_dir / "results" - result_dir.mkdir(exist_ok=True, parents=True) - - try: - if inter_region: - topo = ReplicationTopology() - for i in range(num_gateways): - topo.add_edge(src_region, i, inter_region, i, num_outgoing_connections) - topo.add_edge(inter_region, i, dst_region, i, num_outgoing_connections) - else: - topo = ReplicationTopology() - for i in range(num_gateways): - topo.add_edge(src_region, i, dst_region, i, num_outgoing_connections) - rc = ReplicatorClient( - topo, - gcp_project=gcp_project, - gateway_docker_image=gateway_docker_image, - aws_instance_class=aws_instance_class, - gcp_instance_class=gcp_instance_class, - gcp_use_premium_network=gcp_use_premium_network, - ) - - rc.provision_gateways(reuse_instances=False) - atexit.register(rc.deprovision_gateways) - for node, gw in rc.bound_nodes.items(): - logger.info(f"Provisioned {node}: {gw.gateway_log_viewer_url}") - - job = ReplicationJob( - source_region=src_region, - source_bucket=None, - dest_region=dst_region, - dest_bucket=None, - objs=[f"{key_prefix}/{i}" for i in range(n_chunks)], - random_chunk_size_mb=chunk_size_mb, - ) - - total_bytes = n_chunks * chunk_size_mb * MB - job = rc.run_replication_plan(job) - logger.info(f"{total_bytes / GB:.2f}GByte replication job launched") - stats = rc.monitor_transfer(job, show_pbar=False, time_limit_seconds=600) - stats["success"] = True - stats["log"] = rc.get_chunk_status_log_df() - rc.deprovision_gateways() - except Exception as e: - logger.error(f"Failed to benchmark triangle {src_region} -> {dst_region}") - logger.exception(e) - - stats = {} - stats["error"] = str(e) - stats["success"] = False - - stats["src_region"] = src_region - stats["dst_region"] = dst_region - stats["inter_region"] = inter_region - stats["num_gateways"] = 
num_gateways - stats["num_outgoing_connections"] = num_outgoing_connections - stats["chunk_size_mb"] = chunk_size_mb - stats["n_chunks"] = n_chunks - - logger.info(f"Stats:") - for k, v in stats.items(): - if k not in ["log", "completed_chunk_ids"]: - logger.info(f"\t{k}: {v}") - - arg_hash = hash((src_region, dst_region, inter_region, num_gateways, num_outgoing_connections, chunk_size_mb, n_chunks)) - with open(result_dir / f"{arg_hash}.pkl", "wb") as f: - pickle.dump(stats, f) - - -if __name__ == "__main__": - typer.run(bench_triangle) diff --git a/skylark/benchmark/replicate/test_direct.py b/skylark/benchmark/replicate/test_direct.py deleted file mode 100644 index e623a255c..000000000 --- a/skylark/benchmark/replicate/test_direct.py +++ /dev/null @@ -1,136 +0,0 @@ -import argparse -import json -import time -from datetime import datetime - -from skylark.utils import logger - -from skylark import skylark_root -from skylark.benchmark.utils import provision -from skylark.compute.aws.aws_cloud_provider import AWSCloudProvider -from skylark.compute.gcp.gcp_cloud_provider import GCPCloudProvider -from skylark.utils.utils import do_parallel - - -def parse_args(): - full_region_list = [] - full_region_list += [f"aws:{r}" for r in AWSCloudProvider.region_list()] - full_region_list += [f"gcp:{r}" for r in GCPCloudProvider.region_list()] - parser = argparse.ArgumentParser(description="Test throughput with Skylark Gateway") - parser.add_argument("--aws_instance_class", type=str, default="c5.4xlarge", help="Instance class") - parser.add_argument("--gcp_instance_class", type=str, default="n1-highcpu-8", help="Instance class") - parser.add_argument("--gcp_project", type=str, default="skylark-333700", help="GCP project") - parser.add_argument( - "--gcp_test_standard_network", action="store_true", help="Test GCP standard network in addition to premium (default)" - ) - parser.add_argument("--src_region", default="aws:us-east-1", choices=full_region_list, help="Source region") - parser.add_argument("--dst_region", default="aws:us-east-2", choices=full_region_list, help="Destination region") - parser.add_argument("--gateway_docker_image", type=str, default="ghcr.io/parasj/skylark:latest", help="Gateway docker image") - return parser.parse_args() - - -def setup(tup): - server, docker_image = tup - server.run_command("sudo apt-get update && sudo apt-get install -y iperf3") - docker_installed = "Docker version" in server.run_command(f"sudo docker --version")[0] - if not docker_installed: - logger.debug(f"[{server.region_tag}] Installing docker") - server.run_command("curl -fsSL https://get.docker.com -o get-docker.sh && sudo sh get-docker.sh") - out, err = server.run_command("sudo docker run --rm hello-world") - assert "Hello from Docker!" 
in out - server.run_command("sudo docker pull {}".format(docker_image)) - - -def parse_output(output): - stdout, stderr = output - last_line = stdout.strip().split("\n") - if len(last_line) > 0: - try: - return json.loads(last_line[-1]) - except json.decoder.JSONDecodeError: - logger.error(f"JSON parse error, stdout = '{stdout}', stderr = '{stderr}'") - else: - logger.error(f"No output from server, stderr = {stderr}") - return None - - -def main(args): - data_dir = skylark_root / "data" - log_dir = data_dir / "logs" / "gateway_test" / datetime.now().strftime("%Y-%m-%d_%H-%M-%S") - log_dir.mkdir(exist_ok=True, parents=True) - - aws = AWSCloudProvider() - gcp = GCPCloudProvider(args.gcp_project) - - # provision and setup servers - aws_regions = [r.split(":")[1] for r in [args.src_region, args.dst_region] if r.startswith("aws:")] - gcp_regions = [r.split(":")[1] for r in [args.src_region, args.dst_region] if r.startswith("gcp:")] - aws_instances, gcp_instances = provision( - aws=aws, - gcp=gcp, - aws_regions_to_provision=aws_regions, - gcp_regions_to_provision=gcp_regions, - aws_instance_class=args.aws_instance_class, - gcp_instance_class=args.gcp_instance_class, - gcp_use_premium_network=not args.gcp_test_standard_network, - log_dir=str(log_dir), - ) - - # select servers - src_cloud_region = args.src_region.split(":")[1] - if args.src_region.startswith("aws:"): - src_server = aws_instances[src_cloud_region][0] - elif args.src_region.startswith("gcp:"): - src_server = gcp_instances[src_cloud_region][0] - else: - raise ValueError(f"Unknown region {args.src_region}") - dst_cloud_region = args.dst_region.split(":")[1] - if args.dst_region.startswith("aws:"): - dst_server = aws_instances[dst_cloud_region][0] - elif args.dst_region.startswith("gcp:"): - dst_server = gcp_instances[dst_cloud_region][0] - else: - raise ValueError(f"Unknown region {args.dst_region}") - do_parallel( - setup, - [(src_server, args.gateway_docker_image), (dst_server, args.gateway_docker_image)], - progress_bar=True, - arg_fmt=lambda tup: tup[0].region_tag, - ) - - # generate random 1GB file on src server in /dev/shm/skylark/chunks_in - src_server.run_command("mkdir -p /dev/shm/skylark/chunks_in") - src_server.run_command("sudo dd if=/dev/urandom of=/dev/shm/skylark/chunks_in/1 bs=100M count=10 iflag=fullblock") - assert src_server.run_command("ls /dev/shm/skylark/chunks_in/1 | wc -l")[0].strip() == "1" - - # stop existing gateway containers - src_server.run_command("sudo docker kill gateway_server") - dst_server.run_command("sudo docker kill gateway_server") - - # start gateway on dst server - dst_server.run_command("dig +short myip.opendns.com @resolver1.opendns.com")[0].strip() - server_cmd = f"sudo docker run -d --rm --ipc=host --network=host --name=gateway_server {args.gateway_docker_image} /env/bin/python /pkg/skylark/replicate/gateway_server.py --port 3333 --num_connections 1" - dst_server.run_command(server_cmd) - - # wait for port to appear on dst server - while True: - if dst_server.run_command("sudo netstat -tulpn | grep 3333")[0].strip() != "": - break - time.sleep(1) - - # benchmark src to dst copy - client_cmd = f"sudo docker run --rm --ipc=host --network=host --name=gateway_client {args.gateway_docker_image} /env/bin/python /pkg/skylark/replicate/gateway_client.py --dst_host {dst_server.public_ip} --dst_port 3333 --chunk_id 1" - dst_data = parse_output(src_server.run_command(client_cmd)) - src_server.run_command("sudo docker kill gateway_server") - logger.info(f"Src to dst copy: {dst_data}") - - # run iperf 
server on dst - out, err = dst_server.run_command("iperf3 -s -D") - out, err = src_server.run_command("iperf3 -c {} -t 10".format(dst_server.public_ip)) - dst_server.run_command("sudo pkill iperf3") - print(out) - - -if __name__ == "__main__": - args = parse_args() - main(args) diff --git a/skylark/benchmark/stop_all_instances.py b/skylark/benchmark/stop_all_instances.py deleted file mode 100644 index 0836e3ffd..000000000 --- a/skylark/benchmark/stop_all_instances.py +++ /dev/null @@ -1,43 +0,0 @@ -import argparse - -from skylark.utils import logger -from tqdm import tqdm - -from skylark.compute.aws.aws_cloud_provider import AWSCloudProvider -from skylark.compute.azure.azure_cloud_provider import AzureCloudProvider -from skylark.compute.gcp.gcp_cloud_provider import GCPCloudProvider -from skylark.compute.server import Server -from skylark.utils.utils import do_parallel - - -def stop_instance(instance: Server): - instance.terminate_instance() - tqdm.write(f"Terminated instance {instance}") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Stop all instances") - parser.add_argument("--disable-aws", action="store_true", help="Disables AWS operations if present") - parser.add_argument("--gcp-project", type=str, help="GCP project", default=None) - parser.add_argument("--azure-subscription", type=str, help="Microsoft Azure Subscription", default=None) - args = parser.parse_args() - - instances = [] - - if not args.disable_aws: - logger.info("Getting matching AWS instances") - aws = AWSCloudProvider() - for _, instance_list in do_parallel(aws.get_matching_instances, aws.region_list(), progress_bar=True): - instances += instance_list - - if args.gcp_project: - logger.info("Getting matching GCP instances") - gcp = GCPCloudProvider(gcp_project=args.gcp_project) - instances += gcp.get_matching_instances() - - if args.azure_subscription: - logger.info("Getting matching Azure instances") - azure = AzureCloudProvider(azure_subscription=args.azure_subscription) - instances += azure.get_matching_instances() - - do_parallel(stop_instance, instances, progress_bar=True) diff --git a/skylark/cli/cli.py b/skylark/cli/cli.py index 8b4439ddd..c6b1a4aaf 100644 --- a/skylark/cli/cli.py +++ b/skylark/cli/cli.py @@ -17,15 +17,17 @@ import json import os from pathlib import Path +import pprint from typing import Optional +import boto3 import skylark.cli.cli_aws import skylark.cli.cli_azure import skylark.cli.cli_solver import skylark.cli.experiments import typer from skylark.utils import logger -from skylark import GB, MB, config_file, print_header +from skylark import config_path, GB, MB, print_header from skylark.cli.cli_helper import ( check_ulimit, copy_azure_local, @@ -38,11 +40,14 @@ copy_gcs_local, copy_local_gcs, deprovision_skylark_instances, + load_aws_config, + load_azure_config, + load_gcp_config, ls_local, ls_s3, parse_path, ) -from skylark.config import load_config +from skylark.config import SkylarkConfig, load_config from skylark.replicate.replication_plan import ReplicationJob, ReplicationTopology from skylark.replicate.replicator_client import ReplicatorClient from skylark.obj_store.object_store_interface import ObjectStoreInterface @@ -57,10 +62,6 @@ @app.command() def ls(directory: str): """List objects in the object store.""" - config = load_config() - gcp_project = config.get("gcp_project_id") - azure_subscription = config.get("azure_subscription_id") - logger.debug(f"Loaded gcp_project: {gcp_project}, azure_subscription: {azure_subscription}") check_ulimit() 
provider, bucket, key = parse_path(directory) if provider == "local": @@ -75,10 +76,6 @@ def ls(directory: str): def cp(src: str, dst: str): """Copy objects from the object store to the local filesystem.""" print_header() - config = load_config() - gcp_project = config.get("gcp_project_id") - azure_subscription = config.get("azure_subscription_id") - logger.debug(f"Loaded gcp_project: {gcp_project}, azure_subscription: {azure_subscription}") check_ulimit() provider_src, bucket_src, path_src = parse_path(src) @@ -114,8 +111,6 @@ def replicate_random( total_transfer_size_mb: int = typer.Option(2048, "--size-total-mb", "-s", help="Total transfer size in MB."), chunk_size_mb: int = typer.Option(8, "--chunk-size-mb", help="Chunk size in MB."), reuse_gateways: bool = False, - azure_subscription: Optional[str] = None, - gcp_project: Optional[str] = None, gateway_docker_image: str = os.environ.get("SKYLARK_DOCKER_IMAGE", "ghcr.io/parasj/skylark:main"), aws_instance_class: str = "m5.8xlarge", azure_instance_class: str = "Standard_D32_v5", @@ -127,10 +122,6 @@ def replicate_random( ): """Replicate objects from remote object store to another remote object store.""" print_header() - config = load_config() - gcp_project = gcp_project or config.get("gcp_project_id") - azure_subscription = azure_subscription or config.get("azure_subscription_id") - logger.debug(f"Loaded gcp_project: {gcp_project}, azure_subscription: {azure_subscription}") check_ulimit() if inter_region: @@ -147,8 +138,6 @@ def replicate_random( rc = ReplicatorClient( topo, - azure_subscription=azure_subscription, - gcp_project=gcp_project, gateway_docker_image=gateway_docker_image, aws_instance_class=aws_instance_class, azure_instance_class=azure_instance_class, @@ -202,8 +191,6 @@ def replicate_json( reuse_gateways: bool = False, gateway_docker_image: str = os.environ.get("SKYLARK_DOCKER_IMAGE", "ghcr.io/parasj/skylark:main"), # cloud provider specific options - azure_subscription: Optional[str] = None, - gcp_project: Optional[str] = None, aws_instance_class: str = "m5.8xlarge", azure_instance_class: str = "Standard_D32_v5", gcp_instance_class: Optional[str] = "n2-standard-32", @@ -214,10 +201,6 @@ def replicate_json( ): """Replicate objects from remote object store to another remote object store.""" print_header() - config = load_config() - gcp_project = gcp_project or config.get("gcp_project_id") - azure_subscription = azure_subscription or config.get("azure_subscription_id") - logger.debug(f"Loaded gcp_project: {gcp_project}, azure_subscription: {azure_subscription}") check_ulimit() with path.open("r") as f: @@ -225,8 +208,6 @@ def replicate_json( rc = ReplicatorClient( topo, - azure_subscription=azure_subscription, - gcp_project=gcp_project, gateway_docker_image=gateway_docker_image, aws_instance_class=aws_instance_class, azure_instance_class=azure_instance_class, @@ -291,94 +272,30 @@ def replicate_json( @app.command() -def deprovision(azure_subscription: Optional[str] = None, gcp_project: Optional[str] = None): +def deprovision(): """Deprovision gateways.""" - config = load_config() - gcp_project = gcp_project or config.get("gcp_project_id") - azure_subscription = azure_subscription or config.get("azure_subscription_id") - logger.debug(f"Loaded from config file: gcp_project={gcp_project}, azure_subscription={azure_subscription}") - deprovision_skylark_instances(azure_subscription=azure_subscription, gcp_project_id=gcp_project) + deprovision_skylark_instances() @app.command() -def init( - azure_tenant_id: str = typer.Option(None, 
envvar="AZURE_TENANT_ID", prompt="Azure tenant ID"), - azure_client_id: str = typer.Option(None, envvar="AZURE_CLIENT_ID", prompt="Azure client ID"), - azure_client_secret: str = typer.Option(None, envvar="AZURE_CLIENT_SECRET", prompt="Azure client secret"), - azure_subscription_id: str = typer.Option(None, envvar="AZURE_SUBSCRIPTION_ID", prompt="Azure subscription ID"), - gcp_application_credentials_file: Path = typer.Option( - None, - envvar="GOOGLE_APPLICATION_CREDENTIALS", - exists=True, - file_okay=True, - dir_okay=False, - readable=True, - help="Path to GCP application credentials file (usually a JSON file)", - ), - gcp_project: str = typer.Option(None, envvar="GCP_PROJECT_ID", prompt="GCP project ID"), -): - out_config = {} - - # AWS config - def load_aws_credentials(): - if "AWS_ACCESS_KEY_ID" in os.environ and "AWS_SECRET_ACCESS_KEY" in os.environ: - return os.environ["AWS_ACCESS_KEY_ID"], os.environ["AWS_SECRET_ACCESS_KEY"] - if (Path.home() / ".aws" / "credentials").exists(): - with open(Path.home() / ".aws" / "credentials") as f: - access_key, secret_key = None, None - lines = f.readlines() - for line in lines: - if line.startswith("aws_access_key_id"): - access_key = line.split("=")[1].strip() - if line.startswith("aws_secret_access_key"): - secret_key = line.split("=")[1].strip() - if access_key and secret_key: - return access_key, secret_key - return None, None - - aws_access_key, aws_secret_key = load_aws_credentials() - if aws_access_key is None: - aws_access_key = typer.prompt("AWS access key") - assert aws_access_key is not None and aws_access_key != "" - if aws_secret_key is None: - aws_secret_key = typer.prompt("AWS secret key") - assert aws_secret_key is not None and aws_secret_key != "" - - if config_file.exists(): - typer.confirm("Config file already exists. 
Overwrite?", abort=True) - - out_config["aws_access_key_id"] = aws_access_key - out_config["aws_secret_access_key"] = aws_secret_key - - # Azure config - typer.secho("Azure config can be generated using: az ad sp create-for-rbac -n api://skylark --sdk-auth", fg=typer.colors.GREEN) - if azure_tenant_id is not None or len(azure_tenant_id) > 0: - logger.info(f"Setting Azure tenant ID to {azure_tenant_id}") - out_config["azure_tenant_id"] = azure_tenant_id - if azure_client_id is not None or len(azure_client_id) > 0: - logger.info(f"Setting Azure client ID to {azure_client_id}") - out_config["azure_client_id"] = azure_client_id - if azure_client_secret is not None or len(azure_client_secret) > 0: - logger.info(f"Setting Azure client secret to {azure_client_secret}") - out_config["azure_client_secret"] = azure_client_secret - if azure_subscription_id is not None or len(azure_subscription_id) > 0: - logger.info(f"Setting Azure subscription ID to {azure_subscription_id}") - out_config["azure_subscription_id"] = azure_subscription_id - - # GCP config - if gcp_application_credentials_file is not None and gcp_application_credentials_file.exists(): - logger.info(f"Setting GCP application credentials file to {gcp_application_credentials_file}") - out_config["gcp_application_credentials_file"] = str(gcp_application_credentials_file) - if gcp_project is not None or len(gcp_project) > 0: - logger.info(f"Setting GCP project ID to {gcp_project}") - out_config["gcp_project_id"] = gcp_project - - # write to config file - config_file.parent.mkdir(parents=True, exist_ok=True) - with config_file.open("w") as f: - json.dump(out_config, f) - typer.secho(f"Config: {out_config}", fg=typer.colors.GREEN) - typer.secho(f"Wrote config to {config_file}", fg=typer.colors.GREEN) +def init(reinit_aws: bool = False, reinit_azure: bool = False, reinit_gcp: bool = False): + print_header() + config = SkylarkConfig.load() + + # load AWS config + typer.secho("\n(1) Configuring AWS:", fg="yellow", bold=True) + config = load_aws_config(config, force_init=reinit_aws) + + # load Azure config + typer.secho("\n(2) Configuring Azure:", fg="yellow", bold=True) + config = load_azure_config(config, force_init=reinit_azure) + + # load GCP config + typer.secho("\n(3) Configuring GCP:", fg="yellow", bold=True) + config = load_gcp_config(config, force_init=reinit_gcp) + + config.to_config_file(config_path) + typer.secho(f"\nConfig file saved to {config_path}", fg="green") return 0 diff --git a/skylark/cli/cli_azure.py b/skylark/cli/cli_azure.py index 8f21c109f..778d9a931 100644 --- a/skylark/cli/cli_azure.py +++ b/skylark/cli/cli_azure.py @@ -9,7 +9,7 @@ import typer from azure.identity import DefaultAzureCredential from azure.mgmt.compute import ComputeManagementClient -from skylark.config import load_config +from skylark.config import SkylarkConfig from skylark.compute.azure.azure_cloud_provider import AzureCloudProvider from skylark.utils.utils import do_parallel @@ -18,20 +18,16 @@ @app.command() def get_valid_skus( - azure_subscription: str = typer.Option("", "--azure-subscription", help="Azure subscription ID"), regions: List[str] = typer.Option(AzureCloudProvider.region_list(), "--regions", "-r"), prefix: str = typer.Option("", "--prefix", help="Filter by prefix"), top_k: int = typer.Option(-1, "--top-k", help="Print top k entries"), ): - config = load_config() - azure_subscription = azure_subscription or config.get("azure_subscription_id") - typer.secho(f"Loaded from config file: azure_subscription={azure_subscription}", fg="blue") - 
+ config = SkylarkConfig.load() credential = DefaultAzureCredential() # query azure API for each region to get available SKUs for each resource type def get_skus(region): - client = ComputeManagementClient(credential, azure_subscription) + client = ComputeManagementClient(credential, config.azure_subscription_id) valid_skus = [] for sku in client.resource_skus.list(filter="location eq '{}'".format(region)): if len(sku.restrictions) == 0 and (not prefix or sku.name.startswith(prefix)): diff --git a/skylark/cli/cli_helper.py b/skylark/cli/cli_helper.py index 2875c12bc..9f37adadb 100644 --- a/skylark/cli/cli_helper.py +++ b/skylark/cli/cli_helper.py @@ -7,7 +7,9 @@ from shutil import copyfile from typing import Dict, List, Optional +import boto3 import typer +from skylark.config import SkylarkConfig from skylark.utils import logger from skylark.compute.aws.aws_cloud_provider import AWSCloudProvider from skylark.compute.azure.azure_cloud_provider import AzureCloudProvider @@ -241,7 +243,8 @@ def check_ulimit(hard_limit=1024 * 1024, soft_limit=1024 * 1024): subprocess.check_output(increase_soft_limit) -def deprovision_skylark_instances(azure_subscription: Optional[str] = None, gcp_project_id: Optional[str] = None): +def deprovision_skylark_instances(): + config = SkylarkConfig.load() instances = [] aws = AWSCloudProvider() @@ -250,22 +253,149 @@ def deprovision_skylark_instances(azure_subscription: Optional[str] = None, gcp_ ): instances += instance_list - if not azure_subscription: - typer.secho( - "No Microsoft Azure subscription given, so Azure instances will not be terminated", color=typer.colors.YELLOW, bold=True - ) - else: - azure = AzureCloudProvider(azure_subscription=azure_subscription) + if config.azure_enabled: + azure = AzureCloudProvider() instances += azure.get_matching_instances() - if not gcp_project_id: - typer.secho("No GCP project ID given, so GCP instances will not be deprovisioned", color=typer.colors.YELLOW, bold=True) - else: - gcp = GCPCloudProvider(gcp_project=gcp_project_id) + if config.gcp_enabled: + gcp = GCPCloudProvider() instances += gcp.get_matching_instances() if instances: - typer.secho(f"Deprovisioning {len(instances)} instances", color=typer.colors.YELLOW, bold=True) + typer.secho(f"Deprovisioning {len(instances)} instances", fg="yellow", bold=True) do_parallel(lambda instance: instance.terminate_instance(), instances, progress_bar=True, desc="Deprovisioning") else: - typer.secho("No instances to deprovision, exiting...", color=typer.colors.YELLOW, bold=True) + typer.secho("No instances to deprovision, exiting...", fg="yellow", bold=True) + + +def load_aws_config(config: SkylarkConfig, force_init: bool = False) -> SkylarkConfig: + if force_init: + typer.secho(" AWS credentials will be re-initialized", fg="red") + config.aws_enabled = False + config.aws_access_key_id = None + config.aws_secret_access_key = None + + if config.aws_enabled and config.aws_access_key_id is not None and config.aws_secret_access_key is not None: + typer.secho(" AWS credentials already configured! 
To reconfigure AWS, run `skylark init --reinit-aws`.", fg="blue") + return config + + # get AWS credentials from boto3 + session = boto3.Session() + credentials = session.get_credentials() + credentials = credentials.get_frozen_credentials() + if credentials.access_key is None or credentials.secret_key is None: + typer.secho(" AWS credentials not found in boto3 session, please use the AWS CLI to set them via `aws configure`", fg="red") + typer.secho(" https://docs.aws.amazon.com/cli/latest/userguide/cli-chap-getting-started.html", fg="red") + typer.secho(" Disabling AWS support", fg="blue") + config.aws_enabled = False + config.aws_access_key_id = None + config.aws_secret_access_key = None + return config + + typer.secho(f" Loaded AWS credentials from the AWS CLI [IAM access key ID: ...{credentials.access_key[-6:]}]", fg="blue") + config.aws_enabled = True + config.aws_access_key_id = credentials.access_key + config.aws_secret_access_key = credentials.secret_key + return config + + +def load_azure_config(config: SkylarkConfig, force_init: bool = False) -> SkylarkConfig: + if force_init: + typer.secho(" Azure credentials will be re-initialized", fg="red") + config.azure_enabled = False + config.azure_tenant_id = None + config.azure_client_id = None + config.azure_client_secret = None + config.azure_subscription_id = None + + if ( + config.azure_enabled + and config.azure_tenant_id is not None + and config.azure_client_id is not None + and config.azure_client_secret is not None + and not force_init + ): + typer.secho(" Azure credentials already configured! To reconfigure Azure, run `skylark init --reinit-azure`.", fg="blue") + return config + + # get Azure credentials from Azure default credential provider + azure_tenant_id = os.environ.get("AZURE_TENANT_ID", config.azure_tenant_id) + azure_client_id = os.environ.get("AZURE_CLIENT_ID", config.azure_client_id) + azure_client_secret = os.environ.get("AZURE_CLIENT_SECRET", config.azure_client_secret) + azure_subscription_id = os.environ.get("AZURE_SUBSCRIPTION_ID", config.azure_subscription_id) + + # prompt for missing credentials + if not azure_tenant_id or not azure_client_id or not azure_client_secret or not azure_subscription_id: + typer.secho( + " Azure credentials not found in environment variables, please use the Azure CLI to set them via `az login`", fg="red" + ) + typer.secho(" Azure config can be generated using: az ad sp create-for-rbac -n api://skylark --sdk-auth", fg="red") + if not typer.confirm(" Do you want to manually enter your service principal keys?", default=False): + typer.secho(" Disabling Azure support in Skylark", fg="blue") + config.azure_enabled = False + config.azure_tenant_id = None + config.azure_client_id = None + config.azure_client_secret = None + return config + + if not azure_tenant_id: + azure_tenant_id = typer.prompt(" Azure tenant ID") + if not azure_client_id: + azure_client_id = typer.prompt(" Azure client ID") + if not azure_client_secret: + azure_client_secret = typer.prompt(" Azure client secret") + if not azure_subscription_id: + azure_subscription_id = typer.prompt(" Azure subscription ID") + + config.azure_enabled = True + config.azure_tenant_id = azure_tenant_id + config.azure_client_id = azure_client_id + config.azure_client_secret = azure_client_secret + config.azure_subscription_id = azure_subscription_id + return config + + +def load_gcp_config(config: SkylarkConfig, force_init: bool = False) -> SkylarkConfig: + if force_init: + typer.secho(" GCP credentials will be re-initialized", fg="red") + 
config.gcp_enabled = False + config.gcp_application_credentials_file = None + config.gcp_project_id = None + + if config.gcp_enabled and config.gcp_project_id is not None and config.gcp_application_credentials_file is not None: + typer.secho(" GCP already configured! To reconfigure GCP, run `skylark init --reinit-gcp`.", fg="blue") + return config + + # load from environment variables + gcp_application_credentials_file = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS", config.gcp_application_credentials_file) + if not gcp_application_credentials_file: + typer.secho( + " GCP credentials not found in environment variables, please use the GCP CLI to set them via `gcloud auth application-default login`", + fg="red", + ) + typer.secho(" https://cloud.google.com/docs/authentication/getting-started", fg="red") + if not typer.confirm(" Do you want to manually enter your service account key?", default=False): + typer.secho(" Disabling GCP support in Skylark", fg="blue") + config.gcp_enabled = False + config.gcp_project_id = None + config.gcp_application_credentials_file = None + return config + gcp_application_credentials_file = typer.prompt(" GCP application credentials file path") + + # check if the file exists + gcp_application_credentials_file = Path(gcp_application_credentials_file).expanduser().resolve() + if not gcp_application_credentials_file.exists(): + typer.secho(f" GCP application credentials file not found at {gcp_application_credentials_file}", fg="red") + typer.secho(" Disabling GCP support in Skylark", fg="blue") + config.gcp_enabled = False + config.gcp_project_id = None + config.gcp_application_credentials_file = None + return config + + config.gcp_enabled = True + config.gcp_application_credentials_file = str(gcp_application_credentials_file) + project_id = os.environ.get("GOOGLE_CLOUD_PROJECT", config.gcp_project_id) + if not project_id: + project_id = typer.prompt(" GCP project ID") + config.gcp_project_id = project_id + return config diff --git a/skylark/cli/experiments/throughput.py b/skylark/cli/experiments/throughput.py index ff2d65636..5a85dd573 100644 --- a/skylark/cli/experiments/throughput.py +++ b/skylark/cli/experiments/throughput.py @@ -10,7 +10,7 @@ import typer from skylark import GB, skylark_root from skylark.benchmark.utils import provision, split_list -from skylark.config import load_config +from skylark.config import SkylarkConfig from skylark.compute.aws.aws_cloud_provider import AWSCloudProvider from skylark.compute.azure.azure_cloud_provider import AzureCloudProvider from skylark.compute.gcp.gcp_cloud_provider import GCPCloudProvider @@ -81,9 +81,6 @@ def throughput_grid( aws_instance_class: str = typer.Option("m5.8xlarge", help="AWS instance class to use"), azure_instance_class: str = typer.Option("Standard_D32_v5", help="Azure instance class to use"), gcp_instance_class: str = typer.Option("n2-standard-32", help="GCP instance class to use"), - # cloud options - gcp_project: Optional[str] = None, - azure_subscription: Optional[str] = None, # iperf3 options iperf3_runtime: int = typer.Option(5, help="Runtime for iperf3 in seconds"), iperf3_connections: int = typer.Option(64, help="Number of connections to test"), @@ -91,10 +88,8 @@ def throughput_grid( def check_stderr(tup): assert tup[1].strip() == "", f"Command failed, err: {tup[1]}" - config = load_config() - gcp_project = gcp_project or config.get("gcp_project_id") - azure_subscription = azure_subscription or config.get("azure_subscription_id") - logger.debug(f"Loaded from config file: 
gcp_project={gcp_project}, azure_subscription={azure_subscription}") + config = SkylarkConfig.load() + assert config.aws_enabled and config.azure_enabled and config.gcp_enabled, "All cloud providers must be enabled." if resume: index_key = [ @@ -151,8 +146,8 @@ def check_stderr(tup): # provision servers aws = AWSCloudProvider() - azure = AzureCloudProvider(azure_subscription) - gcp = GCPCloudProvider(gcp_project) + azure = AzureCloudProvider() + gcp = GCPCloudProvider() aws_instances, azure_instances, gcp_instances = provision( aws=aws, azure=azure, diff --git a/skylark/compute/aws/aws_cloud_provider.py b/skylark/compute/aws/aws_cloud_provider.py index 16f820653..846fba27d 100644 --- a/skylark/compute/aws/aws_cloud_provider.py +++ b/skylark/compute/aws/aws_cloud_provider.py @@ -66,10 +66,8 @@ def get_transfer_cost(src_key, dst_key, premium_tier=True): src_rows = transfer_df.loc[src] src_rows = src_rows[src_rows.index != "internet"] return src_rows.max()["cost"] - elif dst_provider == "gcp" or dst_provider == "azure": - return transfer_df.loc[src, "internet"]["cost"] else: - raise NotImplementedError + return transfer_df.loc[src, "internet"]["cost"] def get_instance_list(self, region: str) -> List[AWSServer]: ec2 = AWSServer.get_boto3_resource("ec2", region) diff --git a/skylark/compute/azure/azure_cloud_provider.py b/skylark/compute/azure/azure_cloud_provider.py index f650fec4c..b413bb6f9 100644 --- a/skylark/compute/azure/azure_cloud_provider.py +++ b/skylark/compute/azure/azure_cloud_provider.py @@ -5,7 +5,7 @@ from typing import List, Optional import paramiko -from skylark.config import load_config +from skylark.config import SkylarkConfig from skylark.utils import logger from skylark import key_root from skylark.compute.azure.azure_server import AzureServer @@ -20,19 +20,13 @@ class AzureCloudProvider(CloudProvider): - def __init__(self, azure_subscription, key_root=key_root / "azure", read_credential=True): + def __init__(self, key_root=key_root / "azure"): super().__init__() - if read_credential: - config = load_config() - self.subscription_id = azure_subscription if azure_subscription is not None else config["azure_subscription_id"] - self.credential = ClientSecretCredential( - tenant_id=config["azure_tenant_id"], - client_id=config["azure_client_id"], - client_secret=config["azure_client_secret"], - ) - else: - self.credential = DefaultAzureCredential() - self.subscription_id = azure_subscription + config = SkylarkConfig().load() + assert config.azure_enabled, "Azure cloud provider is not enabled in the config file." 
+ self.credential = DefaultAzureCredential() + self.subscription_id = config.azure_subscription_id + key_root.mkdir(parents=True, exist_ok=True) self.private_key_path = key_root / "azure_key" self.public_key_path = key_root / "azure_key.pub" @@ -213,6 +207,7 @@ def get_transfer_cost(src_key, dst_key, premium_tier=True): dst_provider, dst_region = dst_key.split(":") assert src_provider == "azure" if not premium_tier: + # TODO: tracked in https://github.com/parasj/skylark/issues/59 return NotImplementedError() src_continent = AzureCloudProvider.lookup_continent(src_region) @@ -271,7 +266,7 @@ def get_instance_list(self, region: str) -> List[AzureServer]: for vm in compute_client.virtual_machines.list(AzureServer.resource_group_name): if vm.tags.get("skylark", None) == "true" and AzureServer.is_valid_vm_name(vm.name) and vm.location == region: name = AzureServer.base_name_from_vm_name(vm.name) - s = AzureServer(self.subscription_id, name) + s = AzureServer(name) if s.is_valid(): server_list.append(s) else: @@ -302,7 +297,7 @@ def set_up_resource_group(self, clean_up_orphans=True): for vnet in network_client.virtual_networks.list(AzureServer.resource_group_name): if vnet.tags.get("skylark", None) == "true" and AzureServer.is_valid_vnet_name(vnet.name): name = AzureServer.base_name_from_vnet_name(vnet.name) - s = AzureServer(self.subscription_id, name, assume_exists=False) + s = AzureServer(name, assume_exists=False) if not s.is_valid(): logger.warning(f"Cleaning up orphaned Azure resources for {name}...") instances_to_terminate.append(s) @@ -445,4 +440,4 @@ def provision_instance( ) poller.result() - return AzureServer(self.subscription_id, name) + return AzureServer(name) diff --git a/skylark/compute/azure/azure_server.py b/skylark/compute/azure/azure_server.py index 716dd91fc..1f4f5fd3e 100644 --- a/skylark/compute/azure/azure_server.py +++ b/skylark/compute/azure/azure_server.py @@ -4,7 +4,7 @@ import paramiko from skylark import key_root -from skylark.config import load_config +from skylark.config import SkylarkConfig, load_config from skylark.compute.server import Server, ServerState from skylark.utils.cache import ignore_lru_cache from skylark.utils.utils import PathLike @@ -22,25 +22,16 @@ class AzureServer(Server): def __init__( self, - subscription_id: Optional[str], name: str, key_root: PathLike = key_root / "azure", log_dir=None, ssh_private_key=None, - read_credential=True, assume_exists=True, ): - if read_credential: - config = load_config() - self.subscription_id = subscription_id if subscription_id is not None else config["azure_subscription_id"] - self.credential = ClientSecretCredential( - tenant_id=config["azure_tenant_id"], - client_id=config["azure_client_id"], - client_secret=config["azure_client_secret"], - ) - else: - self.credential = DefaultAzureCredential() - self.subscription_id = subscription_id + config = SkylarkConfig.load() + assert config.azure_enabled, "Azure is not enabled in the config" + self.credential = DefaultAzureCredential() + self.subscription_id = config.azure_subscription_id self.name = name self.location = None diff --git a/skylark/compute/gcp/gcp_cloud_provider.py b/skylark/compute/gcp/gcp_cloud_provider.py index 49ba12a26..8a5fd168b 100644 --- a/skylark/compute/gcp/gcp_cloud_provider.py +++ b/skylark/compute/gcp/gcp_cloud_provider.py @@ -6,6 +6,7 @@ import googleapiclient import paramiko +from skylark.config import SkylarkConfig from skylark.utils import logger from oslo_concurrency import lockutils @@ -17,9 +18,11 @@ class 
GCPCloudProvider(CloudProvider): - def __init__(self, gcp_project, key_root=key_root / "gcp"): + def __init__(self, key_root=key_root / "gcp"): super().__init__() - self.gcp_project = gcp_project + config = SkylarkConfig.load() + assert config.gcp_enabled, "GCP is not enabled in the config" + self.gcp_project = config.gcp_project_id key_root.mkdir(parents=True, exist_ok=True) self.private_key_path = key_root / "gcp-cert.pem" self.public_key_path = key_root / "gcp-cert.pub" diff --git a/skylark/compute/server.py b/skylark/compute/server.py index f95291738..739f046a0 100644 --- a/skylark/compute/server.py +++ b/skylark/compute/server.py @@ -5,16 +5,13 @@ from pathlib import Path from typing import Dict import requests -from skylark import config_file from skylark.utils import logger from skylark.compute.utils import make_dozzle_command, make_sysctl_tcp_tuning_command from skylark.utils.utils import PathLike, Timer, retry_backoff, wait_for -import configparser +from skylark import config_path import os -from skylark import config_file - class ServerState(Enum): PENDING = auto() @@ -210,9 +207,9 @@ def check_stderr(tup): # read AWS config file to get credentials # TODO: Integrate this with updated skylark config file # copy config file - config = config_file.read_text()[:-2] + "}" + config = config_path.read_text()[:-2] + "}" config = json.dumps(config) # Convert to JSON string and remove trailing comma/new-line - self.run_command(f'mkdir -p /opt; echo "{config}" | sudo tee /opt/{config_file.name} > /dev/null') + self.run_command(f'mkdir -p /opt; echo "{config}" | sudo tee /opt/{config_path.name} > /dev/null') docker_envs = "" # If needed, add environment variables to docker command @@ -224,7 +221,8 @@ def check_stderr(tup): f"-d --rm --log-driver=local --log-opt max-file=16 --ipc=host --network=host --ulimit nofile={1024 * 1024} {docker_envs}" ) docker_run_flags += " --mount type=tmpfs,dst=/skylark,tmpfs-size=$(($(free -b | head -n2 | tail -n1 | awk '{print $2}')/2))" - docker_run_flags += f" -v /opt/{config_file.name}:/pkg/data/{config_file.name}" + docker_run_flags += f" -v /opt/{config_path.name}:/pkg/data/{config_path.name}" + docker_run_flags += f" -e SKYLARK_CONFIG=/pkg/data/{config_path.name}" gateway_daemon_cmd = f"python -u /pkg/skylark/gateway/gateway_daemon.py --chunk-dir /skylark/chunks --outgoing-ports '{json.dumps(outgoing_ports)}' --region {self.region_tag}" docker_launch_cmd = f"sudo docker run {docker_run_flags} --name skylark_gateway {gateway_docker_image} {gateway_daemon_cmd}" start_out, start_err = self.run_command(docker_launch_cmd) diff --git a/skylark/config.py b/skylark/config.py index 1254d0c04..b0b4f44c4 100644 --- a/skylark/config.py +++ b/skylark/config.py @@ -1,30 +1,134 @@ -import json +from dataclasses import dataclass +import functools import os -from skylark import config_file +from pathlib import Path +from typing import Optional from skylark.utils import logger +import configparser + +from skylark import config_path def load_config(): - if config_file.exists(): - try: - with config_file.open("r") as f: - config = json.load(f) - if "aws_access_key_id" in config: - os.environ["AWS_ACCESS_KEY_ID"] = config["aws_access_key_id"] - if "aws_secret_access_key" in config: - os.environ["AWS_SECRET_ACCESS_KEY"] = config["aws_secret_access_key"] - if "azure_tenant_id" in config: - os.environ["AZURE_TENANT_ID"] = config["azure_tenant_id"] - if "azure_client_id" in config: - os.environ["AZURE_CLIENT_ID"] = config["azure_client_id"] - if "azure_client_secret" in config: - 
os.environ["AZURE_CLIENT_SECRET"] = config["azure_client_secret"] - if "gcp_application_credentials_file" in config: - os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = config["gcp_application_credentials_file"] - - return config - except json.JSONDecodeError as e: - logger.error(f"Error decoding config file: {e}") - raise e - return {} + raise NotImplementedError() + + +@dataclass +class SkylarkConfig: + aws_enabled: bool = False + aws_access_key_id: Optional[str] = None + aws_secret_access_key: Optional[str] = None + + azure_enabled: bool = False + azure_tenant_id: Optional[str] = None + azure_client_id: Optional[str] = None + azure_client_secret: Optional[str] = None + azure_subscription_id: Optional[str] = None + + gcp_enabled: bool = False + gcp_project_id: Optional[str] = None + gcp_application_credentials_file: Optional[str] = None + + @staticmethod + def load() -> "SkylarkConfig": + if config_path.exists(): + config = SkylarkConfig.load_from_config_file(config_path) + else: + config = SkylarkConfig() + + # set environment variables + if config.gcp_enabled: + os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = config.gcp_application_credentials_file + + return config + + @staticmethod + @functools.lru_cache + def load_from_config_file(path=config_path) -> "SkylarkConfig": + path = Path(config_path) + config = configparser.ConfigParser() + if not path.exists(): + logger.error(f"Config file not found: {path}") + raise FileNotFoundError(f"Config file not found: {path}") + config.read(path) + + if "aws" in config: + aws_enabled = True + aws_access_key_id = config.get("aws", "access_key_id") + aws_secret_access_key = config.get("aws", "secret_access_key") + else: + aws_enabled = False + aws_access_key_id = None + aws_secret_access_key = None + + if "azure" in config: + azure_enabled = True + azure_tenant_id = config.get("azure", "tenant_id") + azure_client_id = config.get("azure", "client_id") + azure_client_secret = config.get("azure", "client_secret") + azure_subscription_id = config.get("azure", "subscription_id") + else: + azure_enabled = False + azure_tenant_id = None + azure_client_id = None + azure_client_secret = None + azure_subscription_id = None + + if "gcp" in config: + gcp_enabled = True + gcp_project_id = config.get("gcp", "project_id") + gcp_application_credentials_file = config.get("gcp", "application_credentials_file") + else: + gcp_enabled = False + gcp_project_id = None + gcp_application_credentials_file = None + + return SkylarkConfig( + aws_enabled=aws_enabled, + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + azure_enabled=azure_enabled, + azure_tenant_id=azure_tenant_id, + azure_client_id=azure_client_id, + azure_client_secret=azure_client_secret, + azure_subscription_id=azure_subscription_id, + gcp_enabled=gcp_enabled, + gcp_project_id=gcp_project_id, + gcp_application_credentials_file=gcp_application_credentials_file, + ) + + def to_config_file(self, path=config_path): + path = Path(path) + config = configparser.ConfigParser() + if path.exists(): + config.read(os.path.expanduser(path)) + + if self.aws_enabled: + if "aws" not in config: + config.add_section("aws") + config.set("aws", "access_key_id", self.aws_access_key_id) + config.set("aws", "secret_access_key", self.aws_secret_access_key) + else: + config.remove_section("aws") + + if self.azure_enabled: + if "azure" not in config: + config.add_section("azure") + config.set("azure", "tenant_id", self.azure_tenant_id) + config.set("azure", "client_id", self.azure_client_id) + 
config.set("azure", "client_secret", self.azure_client_secret) + config.set("azure", "subscription_id", self.azure_subscription_id) + else: + config.remove_section("azure") + + if self.gcp_enabled: + if "gcp" not in config: + config.add_section("gcp") + config.set("gcp", "project_id", self.gcp_project_id) + config.set("gcp", "application_credentials_file", self.gcp_application_credentials_file) + else: + config.remove_section("gcp") + + with path.open("w") as f: + config.write(f) diff --git a/skylark/obj_store/azure_interface.py b/skylark/obj_store/azure_interface.py index f578b2b03..9e3a86231 100644 --- a/skylark/obj_store/azure_interface.py +++ b/skylark/obj_store/azure_interface.py @@ -4,14 +4,14 @@ from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError from azure.identity import DefaultAzureCredential, ClientSecretCredential from azure.storage.blob import BlobServiceClient -from skylark.config import load_config +from skylark.config import SkylarkConfig, load_config from skylark.utils import logger from skylark.obj_store.object_store_interface import NoSuchObjectException, ObjectStoreInterface, ObjectStoreObject class AzureObject(ObjectStoreObject): def full_path(self): - raise NotImplementedError() + return os.path.join(f"https://{self.bucket}.blob.core.windows.net", self.key) class AzureInterface(ObjectStoreInterface): @@ -23,13 +23,11 @@ def __init__(self, azure_region, container_name): self.pending_downloads, self.completed_downloads = 0, 0 self.pending_uploads, self.completed_uploads = 0, 0 # Authenticate - config = load_config() - self.subscription_id = config["azure_subscription_id"] - self.credential = ClientSecretCredential( - tenant_id=config["azure_tenant_id"], - client_id=config["azure_client_id"], - client_secret=config["azure_client_secret"], - ) + config = SkylarkConfig.load() + assert config.azure_enabled, "Azure is not enabled in the config" + self.subscription_id = config.azure_subscription_id + raise NotImplementedError("TODO: COPY SESSION ID FROM CLIENT TO GATEWAY") + self.credential = None # Create a blob service client self.account_url = "https://{}.blob.core.windows.net".format("skylark" + self.azure_region) self.blob_service_client = BlobServiceClient(account_url=self.account_url, credential=self.credential) diff --git a/skylark/obj_store/gcs_interface.py b/skylark/obj_store/gcs_interface.py index 7316e3176..a10e28fe5 100644 --- a/skylark/obj_store/gcs_interface.py +++ b/skylark/obj_store/gcs_interface.py @@ -10,7 +10,7 @@ class GCSObject(ObjectStoreObject): def full_path(self): - raise NotImplementedError() + return os.path.join(f"gs://{self.bucket}", self.key) class GCSInterface(ObjectStoreInterface): @@ -37,7 +37,8 @@ def _on_done_upload(self, **kwargs): self.pending_uploads -= 1 def infer_gcs_region(self, bucket_name: str): - raise NotImplementedError() + bucket = self._gcs_client.get_bucket(bucket_name) + return bucket.location def bucket_exists(self): try: diff --git a/skylark/obj_store/object_store_interface.py b/skylark/obj_store/object_store_interface.py index 2efb0b6bb..a769d4e3d 100644 --- a/skylark/obj_store/object_store_interface.py +++ b/skylark/obj_store/object_store_interface.py @@ -12,30 +12,30 @@ class ObjectStoreObject: last_modified: str def full_path(self): - raise NotImplementedError + raise NotImplementedError() class ObjectStoreInterface: def bucket_exists(self): - raise NotImplementedError + raise NotImplementedError() def create_bucket(self): - raise NotImplementedError + raise NotImplementedError() def 
delete_bucket(self): - raise NotImplementedError + raise NotImplementedError() def list_objects(self, prefix=""): - raise NotImplementedError + raise NotImplementedError() def get_obj_size(self, obj_name): - raise NotImplementedError + raise NotImplementedError() def download_object(self, src_object_name, dst_file_path): - raise NotImplementedError + raise NotImplementedError() def upload_object(self, src_file_path, dst_object_name, content_type="infer"): - raise NotImplementedError + raise NotImplementedError() @staticmethod def create(region_tag: str, bucket: str): diff --git a/skylark/replicate/replicator_client.py b/skylark/replicate/replicator_client.py index 5c5026434..a7d7690b6 100644 --- a/skylark/replicate/replicator_client.py +++ b/skylark/replicate/replicator_client.py @@ -9,6 +9,7 @@ import uuid import requests +from skylark.config import SkylarkConfig from skylark.replicate.profiler import status_df_to_traceevent from skylark.utils import logger from tqdm import tqdm @@ -29,26 +30,24 @@ class ReplicatorClient: def __init__( self, topology: ReplicationTopology, - azure_subscription: Optional[str], - gcp_project: Optional[str], gateway_docker_image: str = "ghcr.io/parasj/skylark:latest", aws_instance_class: Optional[str] = "m5.4xlarge", # set to None to disable AWS azure_instance_class: Optional[str] = "Standard_D2_v5", # set to None to disable Azure gcp_instance_class: Optional[str] = "n2-standard-16", # set to None to disable GCP gcp_use_premium_network: bool = True, ): + config = SkylarkConfig.load() self.topology = topology self.gateway_docker_image = gateway_docker_image self.aws_instance_class = aws_instance_class self.azure_instance_class = azure_instance_class - self.azure_subscription = azure_subscription self.gcp_instance_class = gcp_instance_class self.gcp_use_premium_network = gcp_use_premium_network # provisioning - self.aws = AWSCloudProvider() if aws_instance_class != "None" else None - self.azure = AzureCloudProvider(azure_subscription) if azure_instance_class != "None" and azure_subscription is not None else None - self.gcp = GCPCloudProvider(gcp_project) if gcp_instance_class != "None" and gcp_project is not None else None + self.aws = AWSCloudProvider() if aws_instance_class != "None" and config.aws_enabled else None + self.azure = AzureCloudProvider() if azure_instance_class != "None" and config.azure_enabled else None + self.gcp = GCPCloudProvider() if gcp_instance_class != "None" and config.gcp_enabled else None self.bound_nodes: Dict[ReplicationTopologyGateway, Server] = {} def provision_gateways( diff --git a/skylark/test/test_replicator_client.py b/skylark/test/test_replicator_client.py index 407e63dce..2921a88a7 100644 --- a/skylark/test/test_replicator_client.py +++ b/skylark/test/test_replicator_client.py @@ -63,11 +63,6 @@ def parse_args(): def main(args): - config = load_config() - gcp_project = args.gcp_project or config.get("gcp_project_id") - azure_subscription = args.azure_subscription or config.get("azure_subscription_id") - logger.debug(f"Loaded gcp_project: {gcp_project}, azure_subscription: {azure_subscription}") - src_bucket = f"{args.bucket_prefix}-skylark-{args.src_region.split(':')[1]}" dst_bucket = f"{args.bucket_prefix}-skylark-{args.dest_region.split(':')[1]}" obj_store_interface_src = ObjectStoreInterface.create(args.src_region, src_bucket) @@ -133,8 +128,6 @@ def main(args): # Getting configs rc = ReplicatorClient( topo, - gcp_project=gcp_project, - azure_subscription=azure_subscription, 
gateway_docker_image=args.gateway_docker_image, aws_instance_class=args.aws_instance_class, azure_instance_class=args.azure_instance_class, From dc14afb5bd5b995956a2269851c7af14d17ac925 Mon Sep 17 00:00:00 2001 From: Paras Jain Date: Tue, 8 Mar 2022 11:34:56 -0800 Subject: [PATCH 02/34] skylark gcp ssh (#194) --- skylark/cli/cli.py | 2 ++ skylark/cli/cli_aws.py | 5 +-- skylark/cli/cli_gcp.py | 41 +++++++++++++++++++++++ skylark/compute/aws/aws_server.py | 3 ++ skylark/compute/gcp/gcp_cloud_provider.py | 2 +- skylark/compute/gcp/gcp_server.py | 9 +++-- skylark/compute/server.py | 7 ++-- 7 files changed, 61 insertions(+), 8 deletions(-) create mode 100644 skylark/cli/cli_gcp.py diff --git a/skylark/cli/cli.py b/skylark/cli/cli.py index c6b1a4aaf..d62d17fb4 100644 --- a/skylark/cli/cli.py +++ b/skylark/cli/cli.py @@ -23,6 +23,7 @@ import boto3 import skylark.cli.cli_aws import skylark.cli.cli_azure +import skylark.cli.cli_gcp import skylark.cli.cli_solver import skylark.cli.experiments import typer @@ -56,6 +57,7 @@ app.add_typer(skylark.cli.experiments.app, name="experiments") app.add_typer(skylark.cli.cli_aws.app, name="aws") app.add_typer(skylark.cli.cli_azure.app, name="azure") +app.add_typer(skylark.cli.cli_gcp.app, name="gcp") app.add_typer(skylark.cli.cli_solver.app, name="solver") diff --git a/skylark/cli/cli_aws.py b/skylark/cli/cli_aws.py index 2b658e675..155bbc871 100644 --- a/skylark/cli/cli_aws.py +++ b/skylark/cli/cli_aws.py @@ -40,7 +40,7 @@ def ssh(region: Optional[str] = None): typer.secho("Querying AWS for instances", fg="green") instances = aws.get_matching_instances(region=region) if len(instances) == 0: - typer.secho(f"No instancess found", fg="red") + typer.secho(f"No instances found", fg="red") typer.Abort() instance_map = {f"{i.region()}, {i.public_ip()} ({i.instance_state()})": i for i in instances} @@ -48,7 +48,8 @@ def ssh(region: Optional[str] = None): instance_name: AWSServer = questionary.select("Select an instance", choices=choices).ask() if instance_name is not None and instance_name in instance_map: instance = instance_map[instance_name] - proc = subprocess.Popen(split(f"ssh -i {str(instance.local_keyfile)} ec2-user@{instance.public_ip()}")) + cmd = instance.get_ssh_cmd() + proc = subprocess.Popen(split(cmd)) proc.wait() else: typer.secho(f"No instance selected", fg="red") diff --git a/skylark/cli/cli_gcp.py b/skylark/cli/cli_gcp.py new file mode 100644 index 000000000..5c3f3b692 --- /dev/null +++ b/skylark/cli/cli_gcp.py @@ -0,0 +1,41 @@ +import os +import subprocess +from shlex import split +from typing import Optional + +import questionary +import typer +from skylark.config import load_config + +from skylark.utils import logger +from skylark.compute.gcp.gcp_cloud_provider import GCPCloudProvider +from skylark.compute.gcp.gcp_server import GCPServer + +app = typer.Typer(name="skylark-gcp") + + +@app.command() +def ssh( + region: Optional[str] = None, + gcp_project: str = typer.Option("", "--gcp-project", help="GCP project ID"), +): + config = load_config() + gcp_project = gcp_project or config.get("gcp_project_id") + typer.secho(f"Loaded from config file: gcp_project={gcp_project}", fg="blue") + gcp = GCPCloudProvider(gcp_project) + typer.secho("Querying GCP for instances", fg="green") + instances = gcp.get_matching_instances(region=region) + if len(instances) == 0: + typer.secho(f"No instances found", fg="red") + typer.Abort() + + instance_map = {f"{i.region()}, {i.public_ip()} ({i.instance_state()})": i for i in instances} + choices = 
list(sorted(instance_map.keys())) + instance_name: GCPServer = questionary.select("Select an instance", choices=choices).ask() + if instance_name is not None and instance_name in instance_map: + cmd = instance_map[instance_name].get_ssh_cmd() + typer.secho(cmd, fg="green") + proc = subprocess.Popen(split(cmd)) + proc.wait() + else: + typer.secho(f"No instance selected", fg="red") diff --git a/skylark/compute/aws/aws_server.py b/skylark/compute/aws/aws_server.py index 7bed09007..1b3e0677b 100644 --- a/skylark/compute/aws/aws_server.py +++ b/skylark/compute/aws/aws_server.py @@ -122,3 +122,6 @@ def get_ssh_client_impl(self): banner_timeout=200, ) return client + + def get_ssh_cmd(self): + return f"ssh -i {self.local_keyfile} ec2-user@{self.public_ip()}" diff --git a/skylark/compute/gcp/gcp_cloud_provider.py b/skylark/compute/gcp/gcp_cloud_provider.py index 8a5fd168b..7737a12e6 100644 --- a/skylark/compute/gcp/gcp_cloud_provider.py +++ b/skylark/compute/gcp/gcp_cloud_provider.py @@ -250,7 +250,7 @@ def wait_for_operation_to_complete(self, zone, operation_name, timeout=120): time.sleep(time_intervals.pop(0)) def provision_instance( - self, region, instance_class, name=None, premium_network=False, uname=os.environ.get("USER"), tags={"skylark": "true"} + self, region, instance_class, name=None, premium_network=False, uname="skylark", tags={"skylark": "true"} ) -> GCPServer: assert not region.startswith("gcp:"), "Region should be GCP region" if name is None: diff --git a/skylark/compute/gcp/gcp_server.py b/skylark/compute/gcp/gcp_server.py index e79134f96..f1b2678fa 100644 --- a/skylark/compute/gcp/gcp_server.py +++ b/skylark/compute/gcp/gcp_server.py @@ -96,16 +96,19 @@ def terminate_instance_impl(self): compute = self.get_gcp_client() compute.instances().delete(project=self.gcp_project, zone=self.gcp_region, instance=self.instance_name()).execute() - def get_ssh_client_impl(self, uname=os.environ.get("USER"), ssh_key_password="skylark"): + def get_ssh_client_impl(self, uname="skylark", ssh_key_password="skylark"): """Return paramiko client that connects to this instance.""" ssh_client = paramiko.SSHClient() ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) ssh_client.connect( hostname=self.public_ip(), username=uname, - key_filename=str(self.ssh_private_key), - passphrase=ssh_key_password, + pkey=paramiko.RSAKey.from_private_key_file(str(self.ssh_private_key), password=ssh_key_password), look_for_keys=False, banner_timeout=200, ) return ssh_client + + def get_ssh_cmd(self, uname="skylark", ssh_key_password="skylark"): + # todo can we include the key password inline? 
+ return f"ssh -i {self.ssh_private_key} -o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no {uname}@{self.public_ip()}" diff --git a/skylark/compute/server.py b/skylark/compute/server.py index 739f046a0..f86ce026b 100644 --- a/skylark/compute/server.py +++ b/skylark/compute/server.py @@ -90,6 +90,9 @@ def init_log_files(self, log_dir): def get_ssh_client_impl(self): raise NotImplementedError() + def get_ssh_cmd(self) -> str: + raise NotImplementedError() + @property def ssh_client(self): """Create SSH client and cache.""" @@ -209,7 +212,7 @@ def check_stderr(tup): # copy config file config = config_path.read_text()[:-2] + "}" config = json.dumps(config) # Convert to JSON string and remove trailing comma/new-line - self.run_command(f'mkdir -p /opt; echo "{config}" | sudo tee /opt/{config_path.name} > /dev/null') + self.run_command(f'mkdir -p /tmp; echo "{config}" | sudo tee /tmp/{config_path.name} > /dev/null') docker_envs = "" # If needed, add environment variables to docker command @@ -221,7 +224,7 @@ def check_stderr(tup): f"-d --rm --log-driver=local --log-opt max-file=16 --ipc=host --network=host --ulimit nofile={1024 * 1024} {docker_envs}" ) docker_run_flags += " --mount type=tmpfs,dst=/skylark,tmpfs-size=$(($(free -b | head -n2 | tail -n1 | awk '{print $2}')/2))" - docker_run_flags += f" -v /opt/{config_path.name}:/pkg/data/{config_path.name}" + docker_run_flags += f" -v /tmp/{config_path.name}:/pkg/data/{config_path.name}" docker_run_flags += f" -e SKYLARK_CONFIG=/pkg/data/{config_path.name}" gateway_daemon_cmd = f"python -u /pkg/skylark/gateway/gateway_daemon.py --chunk-dir /skylark/chunks --outgoing-ports '{json.dumps(outgoing_ports)}' --region {self.region_tag}" docker_launch_cmd = f"sudo docker run {docker_run_flags} --name skylark_gateway {gateway_docker_image} {gateway_daemon_cmd}" From 954a2f1d6f7043b5414c346ccffad1b301098041 Mon Sep 17 00:00:00 2001 From: Paras Jain Date: Tue, 8 Mar 2022 19:43:41 +0000 Subject: [PATCH 03/34] Autoload config for GCP cli --- skylark/cli/cli_gcp.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/skylark/cli/cli_gcp.py b/skylark/cli/cli_gcp.py index 5c3f3b692..ccd4d7282 100644 --- a/skylark/cli/cli_gcp.py +++ b/skylark/cli/cli_gcp.py @@ -15,14 +15,8 @@ @app.command() -def ssh( - region: Optional[str] = None, - gcp_project: str = typer.Option("", "--gcp-project", help="GCP project ID"), -): - config = load_config() - gcp_project = gcp_project or config.get("gcp_project_id") - typer.secho(f"Loaded from config file: gcp_project={gcp_project}", fg="blue") - gcp = GCPCloudProvider(gcp_project) +def ssh(region: Optional[str] = None): + gcp = GCPCloudProvider() typer.secho("Querying GCP for instances", fg="green") instances = gcp.get_matching_instances(region=region) if len(instances) == 0: From e7821bca041aa4a6dee56d81c6b0309aeae4d2e5 Mon Sep 17 00:00:00 2001 From: Paras Jain Date: Tue, 8 Mar 2022 22:00:18 +0000 Subject: [PATCH 04/34] Misc changes --- skylark/cli/cli_aws.py | 2 +- skylark/cli/cli_gcp.py | 5 +---- skylark/cli/cli_helper.py | 24 ++++++++++++------------ skylark/cli/experiments/throughput.py | 10 +++++----- skylark/config.py | 9 +++++++++ skylark/replicate/replicator_client.py | 1 + 6 files changed, 29 insertions(+), 22 deletions(-) diff --git a/skylark/cli/cli_aws.py b/skylark/cli/cli_aws.py index 155bbc871..2e5e575d2 100644 --- a/skylark/cli/cli_aws.py +++ b/skylark/cli/cli_aws.py @@ -41,7 +41,7 @@ def ssh(region: Optional[str] = None): instances = aws.get_matching_instances(region=region) if 
len(instances) == 0: typer.secho(f"No instances found", fg="red") - typer.Abort() + raise typer.Abort() instance_map = {f"{i.region()}, {i.public_ip()} ({i.instance_state()})": i for i in instances} choices = list(sorted(instance_map.keys())) diff --git a/skylark/cli/cli_gcp.py b/skylark/cli/cli_gcp.py index ccd4d7282..78c9ec2a0 100644 --- a/skylark/cli/cli_gcp.py +++ b/skylark/cli/cli_gcp.py @@ -1,13 +1,10 @@ -import os import subprocess from shlex import split from typing import Optional import questionary import typer -from skylark.config import load_config -from skylark.utils import logger from skylark.compute.gcp.gcp_cloud_provider import GCPCloudProvider from skylark.compute.gcp.gcp_server import GCPServer @@ -21,7 +18,7 @@ def ssh(region: Optional[str] = None): instances = gcp.get_matching_instances(region=region) if len(instances) == 0: typer.secho(f"No instances found", fg="red") - typer.Abort() + raise typer.Abort() instance_map = {f"{i.region()}, {i.public_ip()} ({i.instance_state()})": i for i in instances} choices = list(sorted(instance_map.keys())) diff --git a/skylark/cli/cli_helper.py b/skylark/cli/cli_helper.py index 9f37adadb..8cda047bf 100644 --- a/skylark/cli/cli_helper.py +++ b/skylark/cli/cli_helper.py @@ -232,7 +232,7 @@ def check_ulimit(hard_limit=1024 * 1024, soft_limit=1024 * 1024): f"Failed to increase ulimit to {soft_limit}, please set manually with 'ulimit -n {soft_limit}'. Current limit is {new_limit}", fg="red", ) - typer.Abort() + raise typer.Abort() else: typer.secho(f"Successfully increased ulimit to {new_limit}", fg="green") if current_limit_soft < soft_limit: @@ -247,19 +247,19 @@ def deprovision_skylark_instances(): config = SkylarkConfig.load() instances = [] - aws = AWSCloudProvider() - for _, instance_list in do_parallel( - aws.get_matching_instances, aws.region_list(), progress_bar=True, leave_pbar=False, desc="Retrieve AWS instances" - ): - instances += instance_list - + query_jobs = [] + if config.aws_enabled: + aws = AWSCloudProvider() + for region in aws.region_list(): + query_jobs.append(lambda: aws.get_matching_instances(region)) if config.azure_enabled: - azure = AzureCloudProvider() - instances += azure.get_matching_instances() - + query_jobs.append(lambda: AzureCloudProvider().get_matching_instances()) if config.gcp_enabled: - gcp = GCPCloudProvider() - instances += gcp.get_matching_instances() + query_jobs.append(lambda: GCPCloudProvider().get_matching_instances()) + + # query in parallel + for _, instance_list in do_parallel(lambda f: f(), query_jobs, progress_bar=True, desc="Query instances", arg_fmt=None): + instances.extend(instance_list) if instances: typer.secho(f"Deprovisioning {len(instances)} instances", fg="yellow", bold=True) diff --git a/skylark/cli/experiments/throughput.py b/skylark/cli/experiments/throughput.py index 5a85dd573..1d4e2613f 100644 --- a/skylark/cli/experiments/throughput.py +++ b/skylark/cli/experiments/throughput.py @@ -113,21 +113,21 @@ def check_stderr(tup): gcp_region_list = gcp_region_list if enable_gcp else [] if not enable_aws and not enable_azure and not enable_gcp: logger.error("At least one of -aws, -azure, -gcp must be enabled.") - typer.Abort() + raise typer.Abort() # validate AWS regions if not enable_aws: aws_region_list = [] elif not all(r in all_aws_regions for r in aws_region_list): logger.error(f"Invalid AWS region list: {aws_region_list}") - typer.Abort() + raise typer.Abort() # validate Azure regions if not enable_azure: azure_region_list = [] elif not all(r in all_azure_regions for r in 
azure_region_list): logger.error(f"Invalid Azure region list: {azure_region_list}") - typer.Abort() + raise typer.Abort() # validate GCP regions assert not enable_gcp_standard or enable_gcp, f"GCP is disabled but GCP standard is enabled" @@ -135,14 +135,14 @@ def check_stderr(tup): gcp_region_list = [] elif not all(r in all_gcp_regions for r in gcp_region_list): logger.error(f"Invalid GCP region list: {gcp_region_list}") - typer.Abort() + raise typer.Abort() # validate GCP standard instances if not enable_gcp_standard: gcp_standard_region_list = [] if not all(r in all_gcp_regions_standard for r in gcp_standard_region_list): logger.error(f"Invalid GCP standard region list: {gcp_standard_region_list}") - typer.Abort() + raise typer.Abort() # provision servers aws = AWSCloudProvider() diff --git a/skylark/config.py b/skylark/config.py index b0b4f44c4..74b1ca9a0 100644 --- a/skylark/config.py +++ b/skylark/config.py @@ -38,6 +38,15 @@ def load() -> "SkylarkConfig": config = SkylarkConfig() # set environment variables + if config.aws_enabled: + # todo load AWS credentials from CLI + os.environ["AWS_ACCESS_KEY_ID"] = config.aws_access_key_id + os.environ["AWS_SECRET_ACCESS_KEY"] = config.aws_secret_access_key + if config.azure_enabled: + if config.azure_tenant_id and config.azure_client_id and config.azure_client_secret: + os.environ["AZURE_TENANT_ID"] = config.azure_tenant_id + os.environ["AZURE_CLIENT_ID"] = config.azure_client_id + os.environ["AZURE_CLIENT_SECRET"] = config.azure_client_secret if config.gcp_enabled: os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = config.gcp_application_credentials_file diff --git a/skylark/replicate/replicator_client.py b/skylark/replicate/replicator_client.py index a7d7690b6..cbb1ed7f7 100644 --- a/skylark/replicate/replicator_client.py +++ b/skylark/replicate/replicator_client.py @@ -188,6 +188,7 @@ def setup(server: Server): def deprovision_gateways(self): def deprovision_gateway_instance(server: Server): if server.instance_state() == ServerState.RUNNING: + logger.warning(f"Deprovisioning {server.uuid()}") server.terminate_instance() do_parallel(deprovision_gateway_instance, self.bound_nodes.values(), n=-1) From d69ed6fff10908be3b7305b77af39439d71f96cc Mon Sep 17 00:00:00 2001 From: Paras Jain Date: Tue, 8 Mar 2022 22:05:08 +0000 Subject: [PATCH 05/34] Print deprovision before final result log --- skylark/cli/cli.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/skylark/cli/cli.py b/skylark/cli/cli.py index d62d17fb4..920db0445 100644 --- a/skylark/cli/cli.py +++ b/skylark/cli/cli.py @@ -175,6 +175,11 @@ def replicate_random( stats = rc.monitor_transfer(job, show_pbar=True, log_interval_s=log_interval_s, time_limit_seconds=time_limit_seconds) stats["success"] = stats["monitor_status"] == "completed" out_json = {k: v for k, v in stats.items() if k not in ["log", "completed_chunk_ids"]} + + if not reuse_gateways: + atexit.unregister(rc.deprovision_gateways) + rc.deprovision_gateways() + typer.echo(f"\n{json.dumps(out_json)}") return 0 if stats["success"] else 1 @@ -231,8 +236,6 @@ def replicate_json( logger.warning(f"total_transfer_size_mb ({size_total_mb}) is not a multiple of n_chunks ({n_chunks})") chunk_size_mb = size_total_mb // n_chunks - print("REGION", topo.source_region()) - if use_random_data: job = ReplicationJob( source_region=topo.source_region(), @@ -269,6 +272,12 @@ def replicate_json( ) stats["success"] = stats["monitor_status"] == "completed" out_json = {k: v for k, v in stats.items() if k not in ["log", 
"completed_chunk_ids"]} + + # deprovision + if not reuse_gateways: + atexit.unregister(rc.deprovision_gateways) + rc.deprovision_gateways() + typer.echo(f"\n{json.dumps(out_json)}") return 0 if stats["success"] else 1 From 1523f83d1da3c90700bd62757796c5abb657fb34 Mon Sep 17 00:00:00 2001 From: Paras Jain Date: Wed, 9 Mar 2022 01:53:19 +0000 Subject: [PATCH 06/34] ? --- scripts/requirements-gateway.txt | 1 + setup.py | 1 + skylark/__init__.py | 9 ++ skylark/cli/cli.py | 2 - skylark/cli/cli_aws.py | 11 ++- skylark/cli/cli_helper.py | 63 ++++---------- skylark/compute/aws/aws_auth.py | 51 +++++++++++ skylark/compute/aws/aws_cloud_provider.py | 16 ++-- skylark/compute/aws/aws_server.py | 15 ---- skylark/compute/azure/azure_auth.py | 25 ++++++ skylark/compute/azure/azure_cloud_provider.py | 44 ++++++---- skylark/compute/azure/azure_server.py | 25 ++---- skylark/config.py | 86 ++++++++----------- 13 files changed, 188 insertions(+), 161 deletions(-) create mode 100644 skylark/compute/aws/aws_auth.py create mode 100644 skylark/compute/azure/azure_auth.py diff --git a/scripts/requirements-gateway.txt b/scripts/requirements-gateway.txt index 32793883e..c5d5c1660 100644 --- a/scripts/requirements-gateway.txt +++ b/scripts/requirements-gateway.txt @@ -3,6 +3,7 @@ azure-identity azure-mgmt-compute azure-mgmt-network azure-mgmt-resource +azure-mgmt-authorization azure-storage-blob>=12.0.0 boto3 click>=7.1.2 diff --git a/setup.py b/setup.py index df8575268..2a470fbf4 100644 --- a/setup.py +++ b/setup.py @@ -12,6 +12,7 @@ "azure-mgmt-compute", "azure-mgmt-network", "azure-mgmt-resource", + "azure-mgmt-authorization", "azure-storage-blob>=12.0.0", "boto3", "click>=7.1.2", diff --git a/skylark/__init__.py b/skylark/__init__.py index b2947778f..4c2792520 100644 --- a/skylark/__init__.py +++ b/skylark/__init__.py @@ -2,6 +2,8 @@ from pathlib import Path +from skylark.config import SkylarkConfig + # paths skylark_root = Path(__file__).parent.parent config_root = Path("~/.skylark").expanduser() @@ -36,3 +38,10 @@ def print_header(): KB = 1024 MB = 1024 * 1024 GB = 1024 * 1024 * 1024 + + +# cloud config +if config_path.exists(): + cloud_config = SkylarkConfig.load_from_config_file(config_path) +else: + cloud_config = SkylarkConfig() # empty config \ No newline at end of file diff --git a/skylark/cli/cli.py b/skylark/cli/cli.py index 920db0445..87daf4ad7 100644 --- a/skylark/cli/cli.py +++ b/skylark/cli/cli.py @@ -17,10 +17,8 @@ import json import os from pathlib import Path -import pprint from typing import Optional -import boto3 import skylark.cli.cli_aws import skylark.cli.cli_azure import skylark.cli.cli_gcp diff --git a/skylark/cli/cli_aws.py b/skylark/cli/cli_aws.py index 2e5e575d2..172d891a9 100644 --- a/skylark/cli/cli_aws.py +++ b/skylark/cli/cli_aws.py @@ -12,6 +12,7 @@ import questionary import typer from skylark import GB +from skylark.compute.aws.aws_auth import AWSAuthentication from skylark.compute.aws.aws_cloud_provider import AWSCloudProvider from skylark.compute.aws.aws_server import AWSServer from skylark.obj_store.s3_interface import S3Interface @@ -23,9 +24,10 @@ @app.command() def vcpu_limits(quota_code="L-1216C47A"): """List the vCPU limits for each region.""" + aws_auth = AWSAuthentication() def get_service_quota(region): - service_quotas = AWSServer.get_boto3_client("service-quotas", region) + service_quotas = aws_auth.get_boto3_client("service-quotas", region) response = service_quotas.get_service_quota(ServiceCode="ec2", QuotaCode=quota_code) return response["Quota"]["Value"] @@ -57,10 
+59,11 @@ def ssh(region: Optional[str] = None): @app.command() def cp_datasync(src_bucket: str, dst_bucket: str, path: str): + aws_auth = AWSAuthentication() src_region = S3Interface.infer_s3_region(src_bucket) dst_region = S3Interface.infer_s3_region(dst_bucket) - iam_client = AWSServer.get_boto3_client("iam", "us-east-1") + iam_client = aws_auth.get_boto3_client("iam", "us-east-1") try: response = iam_client.get_role(RoleName="datasync-role") typer.secho("IAM role exists datasync-role", fg="green") @@ -81,14 +84,14 @@ def cp_datasync(src_bucket: str, dst_bucket: str, path: str): iam_arn = response["Role"]["Arn"] typer.secho(f"IAM role ARN: {iam_arn}", fg="green") - ds_client_src = AWSServer.get_boto3_client("datasync", src_region) + ds_client_src = aws_auth.get_boto3_client("datasync", src_region) src_response = ds_client_src.create_location_s3( S3BucketArn=f"arn:aws:s3:::{src_bucket}", Subdirectory=path, S3Config={"BucketAccessRoleArn": iam_arn}, ) src_s3_arn = src_response["LocationArn"] - ds_client_dst = AWSServer.get_boto3_client("datasync", dst_region) + ds_client_dst = aws_auth.get_boto3_client("datasync", dst_region) dst_response = ds_client_dst.create_location_s3( S3BucketArn=f"arn:aws:s3:::{dst_bucket}", Subdirectory=path, diff --git a/skylark/cli/cli_helper.py b/skylark/cli/cli_helper.py index 8cda047bf..3ac1794bb 100644 --- a/skylark/cli/cli_helper.py +++ b/skylark/cli/cli_helper.py @@ -271,11 +271,12 @@ def deprovision_skylark_instances(): def load_aws_config(config: SkylarkConfig, force_init: bool = False) -> SkylarkConfig: if force_init: typer.secho(" AWS credentials will be re-initialized", fg="red") - config.aws_enabled = False + config.aws_config_mode = "disabled" config.aws_access_key_id = None config.aws_secret_access_key = None - if config.aws_enabled and config.aws_access_key_id is not None and config.aws_secret_access_key is not None: + aws_configured_iam = config.aws_config_mode == "iam_manual" and config.aws_access_key_id and config.aws_secret_access_key + if aws_configured_iam: typer.secho(" AWS credentials already configured! 
To reconfigure AWS, run `skylark init --reinit-aws`.", fg="blue") return config @@ -287,13 +288,13 @@ def load_aws_config(config: SkylarkConfig, force_init: bool = False) -> SkylarkC typer.secho(" AWS credentials not found in boto3 session, please use the AWS CLI to set them via `aws configure`", fg="red") typer.secho(" https://docs.aws.amazon.com/cli/latest/userguide/cli-chap-getting-started.html", fg="red") typer.secho(" Disabling AWS support", fg="blue") - config.aws_enabled = False + config.aws_config_mode = "disabled" config.aws_access_key_id = None config.aws_secret_access_key = None return config typer.secho(f" Loaded AWS credentials from the AWS CLI [IAM access key ID: ...{credentials.access_key[-6:]}]", fg="blue") - config.aws_enabled = True + config.aws_config_mode = "iam_manual" config.aws_access_key_id = credentials.access_key config.aws_secret_access_key = credentials.secret_key return config @@ -302,56 +303,22 @@ def load_aws_config(config: SkylarkConfig, force_init: bool = False) -> SkylarkC def load_azure_config(config: SkylarkConfig, force_init: bool = False) -> SkylarkConfig: if force_init: typer.secho(" Azure credentials will be re-initialized", fg="red") - config.azure_enabled = False - config.azure_tenant_id = None - config.azure_client_id = None - config.azure_client_secret = None + config.azure_config_mode = "disabled" config.azure_subscription_id = None - if ( - config.azure_enabled - and config.azure_tenant_id is not None - and config.azure_client_id is not None - and config.azure_client_secret is not None - and not force_init - ): + azure_configured_cli = config.azure_config_mode == "cli_auto" and config.azure_subscription_id + if azure_configured_cli: typer.secho(" Azure credentials already configured! To reconfigure Azure, run `skylark init --reinit-azure`.", fg="blue") return config - # get Azure credentials from Azure default credential provider - azure_tenant_id = os.environ.get("AZURE_TENANT_ID", config.azure_tenant_id) - azure_client_id = os.environ.get("AZURE_CLIENT_ID", config.azure_client_id) - azure_client_secret = os.environ.get("AZURE_CLIENT_SECRET", config.azure_client_secret) - azure_subscription_id = os.environ.get("AZURE_SUBSCRIPTION_ID", config.azure_subscription_id) + # check if DefaultAzureCredential is available + cred = DefaultAzureCredential(exclude_shared_token_cache_credential=True) + - # prompt for missing credentials - if not azure_tenant_id or not azure_client_id or not azure_client_secret or not azure_subscription_id: - typer.secho( - " Azure credentials not found in environment variables, please use the Azure CLI to set them via `az login`", fg="red" - ) - typer.secho(" Azure config can be generated using: az ad sp create-for-rbac -n api://skylark --sdk-auth", fg="red") - if not typer.confirm(" Do you want to manually enter your service principal keys?", default=False): - typer.secho(" Disabling Azure support in Skylark", fg="blue") - config.azure_enabled = False - config.azure_tenant_id = None - config.azure_client_id = None - config.azure_client_secret = None - return config - - if not azure_tenant_id: - azure_tenant_id = typer.prompt(" Azure tenant ID") - if not azure_client_id: - azure_client_id = typer.prompt(" Azure client ID") - if not azure_client_secret: - azure_client_secret = typer.prompt(" Azure client secret") - if not azure_subscription_id: - azure_subscription_id = typer.prompt(" Azure subscription ID") - - config.azure_enabled = True - config.azure_tenant_id = azure_tenant_id - config.azure_client_id = azure_client_id 
- config.azure_client_secret = azure_client_secret - config.azure_subscription_id = azure_subscription_id + # typer.secho( + # " Azure credentials not found in environment variables, please use the Azure CLI to set them via `az login`", fg="red" + # ) + # typer.secho(" Azure config can be generated using: az ad sp create-for-rbac -n api://skylark --sdk-auth", fg="red") return config diff --git a/skylark/compute/aws/aws_auth.py b/skylark/compute/aws/aws_auth.py new file mode 100644 index 000000000..e9520aa86 --- /dev/null +++ b/skylark/compute/aws/aws_auth.py @@ -0,0 +1,51 @@ +from typing import Optional + +import boto3 + + +class AWSAuthentication: + def __init__(self, access_key: Optional[str] = None, secret_key: Optional[str] = None): + """Loads AWS authentication details. If no access key is provided, it will try to load credentials using boto3""" + if access_key and secret_key: + self.config_mode = "manual" + self.access_key = access_key + self.secret_key = secret_key + else: + infer_access_key, infer_secret_key = self.infer_credentials() + if infer_access_key and infer_secret_key: + self.config_mode = "iam_inferred" + self.access_key = infer_access_key + self.secret_key = infer_secret_key + else: + self.config_mode = "disabled" + self.access_key = None + self.secret_key = None + + def enabled(self): + return self.config_mode != "disabled" + + def infer_credentials(self): + # todo load temporary credentials from STS + session = boto3.Session() + credentials = session.get_credentials() + credentials = credentials.get_frozen_credentials() + return credentials.access_key, credentials.secret_key + + def get_boto3_session(self, aws_region: str): + if self.config_mode == "manual": + return boto3.Session( + aws_access_key_id=self.access_key, + aws_secret_access_key=self.secret_key, + region_name=aws_region, + ) + else: + return boto3.Session(region_name=aws_region) + + def get_boto3_resource(self, service_name, aws_region=None): + return self.get_boto3_session(aws_region).resource(service_name, region_name=aws_region) + + def get_boto3_client(self, service_name, aws_region=None): + if aws_region is None: + return self.get_boto3_session(aws_region).client(service_name) + else: + return self.get_boto3_session(aws_region).client(service_name, region_name=aws_region) \ No newline at end of file diff --git a/skylark/compute/aws/aws_cloud_provider.py b/skylark/compute/aws/aws_cloud_provider.py index 846fba27d..add933bf7 100644 --- a/skylark/compute/aws/aws_cloud_provider.py +++ b/skylark/compute/aws/aws_cloud_provider.py @@ -3,6 +3,7 @@ import botocore import pandas as pd +from skylark.compute.aws.aws_auth import AWSAuthentication from skylark.utils import logger from oslo_concurrency import lockutils @@ -13,8 +14,9 @@ class AWSCloudProvider(CloudProvider): - def __init__(self): + def __init__(self, auth: AWSAuthentication): super().__init__() + self.auth: AWSAuthentication = auth @property def name(self): @@ -70,7 +72,7 @@ def get_transfer_cost(src_key, dst_key, premium_tier=True): return transfer_df.loc[src, "internet"]["cost"] def get_instance_list(self, region: str) -> List[AWSServer]: - ec2 = AWSServer.get_boto3_resource("ec2", region) + ec2 = self.auth.get_boto3_client("ec2", region) valid_states = ["pending", "running", "stopped", "stopping"] instances = ec2.instances.filter(Filters=[{"Name": "instance-state-name", "Values": valid_states}]) try: @@ -82,7 +84,7 @@ def get_instance_list(self, region: str) -> List[AWSServer]: return [AWSServer(f"aws:{region}", i) for i in instance_ids] def 
get_security_group(self, region: str, vpc_name="skylark", sg_name="skylark"): - ec2 = AWSServer.get_boto3_resource("ec2", region) + ec2 = self.auth.get_boto3_resource("ec2", region) vpcs = list(ec2.vpcs.filter(Filters=[{"Name": "tag:Name", "Values": [vpc_name]}]).all()) assert len(vpcs) == 1, f"Found {len(vpcs)} vpcs with name {vpc_name}" sgs = [sg for sg in vpcs[0].security_groups.all() if sg.group_name == sg_name] @@ -90,7 +92,7 @@ def get_security_group(self, region: str, vpc_name="skylark", sg_name="skylark") return sgs[0] def get_vpc(self, region: str, vpc_name="skylark"): - ec2 = AWSServer.get_boto3_resource("ec2", region) + ec2 = self.auth.get_boto3_resource("ec2", region) vpcs = list(ec2.vpcs.filter(Filters=[{"Name": "tag:Name", "Values": [vpc_name]}]).all()) if len(vpcs) == 0: return None @@ -98,7 +100,7 @@ def get_vpc(self, region: str, vpc_name="skylark"): return vpcs[0] def make_vpc(self, region: str, vpc_name="skylark"): - ec2 = AWSServer.get_boto3_resource("ec2", region) + ec2 = self.auth.get_boto3_resource("ec2", region) ec2client = ec2.meta.client vpcs = list(ec2.vpcs.filter(Filters=[{"Name": "tag:Name", "Values": [vpc_name]}]).all()) @@ -152,7 +154,7 @@ def make_vpc(self, region: str, vpc_name="skylark"): def delete_vpc(self, region: str, vpcid: str): """Delete VPC, from https://gist.github.com/vernhart/c6a0fc94c0aeaebe84e5cd6f3dede4ce""" logger.warning(f"[{region}] Deleting VPC {vpcid}") - ec2 = AWSServer.get_boto3_resource("ec2", region) + ec2 = self.auth.get_boto3_resource("ec2", region) ec2client = ec2.meta.client vpc = ec2.Vpc(vpcid) # detach and delete all gateways associated with the vpc @@ -221,7 +223,7 @@ def provision_instance( assert not region.startswith("aws:"), "Region should be AWS region" if name is None: name = f"skylark-aws-{str(uuid.uuid4()).replace('-', '')}" - ec2 = AWSServer.get_boto3_resource("ec2", region) + ec2 = self.auth.get_boto3_resource("ec2", region) AWSServer.ensure_keyfile_exists(region) vpc = self.get_vpc(region) diff --git a/skylark/compute/aws/aws_server.py b/skylark/compute/aws/aws_server.py index 1b3e0677b..03116abe7 100644 --- a/skylark/compute/aws/aws_server.py +++ b/skylark/compute/aws/aws_server.py @@ -26,21 +26,6 @@ def __init__(self, region_tag, instance_id, log_dir=None): def uuid(self): return f"{self.region_tag}:{self.instance_id}" - @classmethod - def get_boto3_session(cls, aws_region) -> boto3.Session: - return boto3.Session(region_name=aws_region) - - @classmethod - def get_boto3_resource(cls, service_name, aws_region=None): - return cls.get_boto3_session(aws_region).resource(service_name, region_name=aws_region) - - @classmethod - def get_boto3_client(cls, service_name, aws_region=None): - if aws_region is None: - return cls.get_boto3_session(aws_region).client(service_name) - else: - return cls.get_boto3_session(aws_region).client(service_name, region_name=aws_region) - def get_boto3_instance_resource(self): ec2 = AWSServer.get_boto3_resource("ec2", self.aws_region) return ec2.Instance(self.instance_id) diff --git a/skylark/compute/azure/azure_auth.py b/skylark/compute/azure/azure_auth.py new file mode 100644 index 000000000..8f4d32f9b --- /dev/null +++ b/skylark/compute/azure/azure_auth.py @@ -0,0 +1,25 @@ +from azure.identity import DefaultAzureCredential +from azure.mgmt.compute import ComputeManagementClient +from azure.mgmt.network import NetworkManagementClient +from azure.mgmt.resource import ResourceManagementClient +from azure.mgmt.authorization import AuthorizationManagementClient + +class AzureAuthentication: + 
def __init__(self, subscription_id: str): + self.subscription_id = subscription_id + self.credential = DefaultAzureCredential() + + def get_token(self, resource: str): + return self.credential.get_token(resource) + + def get_compute_client(self): + return ComputeManagementClient(self.credential, self.subscription_id) + + def get_resource_client(self): + return ResourceManagementClient(self.credential, self.subscription_id) + + def get_network_client(self): + return NetworkManagementClient(self.credential, self.subscription_id) + + def get_authorization_client(self): + return AuthorizationManagementClient(self.credential, self.subscription_id) \ No newline at end of file diff --git a/skylark/compute/azure/azure_cloud_provider.py b/skylark/compute/azure/azure_cloud_provider.py index b413bb6f9..1d29498bf 100644 --- a/skylark/compute/azure/azure_cloud_provider.py +++ b/skylark/compute/azure/azure_cloud_provider.py @@ -5,16 +5,14 @@ from typing import List, Optional import paramiko +from skylark.compute.azure.azure_auth import AzureAuthentication from skylark.config import SkylarkConfig from skylark.utils import logger from skylark import key_root from skylark.compute.azure.azure_server import AzureServer from skylark.compute.cloud_providers import CloudProvider -from azure.identity import DefaultAzureCredential, ClientSecretCredential -from azure.mgmt.compute import ComputeManagementClient -from azure.mgmt.network import NetworkManagementClient -from azure.mgmt.resource import ResourceManagementClient +from azure.mgmt.compute.models import ResourceIdentityType from skylark.utils.utils import do_parallel @@ -24,8 +22,7 @@ def __init__(self, key_root=key_root / "azure"): super().__init__() config = SkylarkConfig().load() assert config.azure_enabled, "Azure cloud provider is not enabled in the config file." - self.credential = DefaultAzureCredential() - self.subscription_id = config.azure_subscription_id + self.auth = AzureAuthentication(config.subscription_id) key_root.mkdir(parents=True, exist_ok=True) self.private_key_path = key_root / "azure_key" @@ -259,9 +256,7 @@ def get_transfer_cost(src_key, dst_key, premium_tier=True): raise ValueError(f"Unknown transfer cost for {src_key} -> {dst_key}") def get_instance_list(self, region: str) -> List[AzureServer]: - credential = self.credential - compute_client = ComputeManagementClient(credential, self.subscription_id) - + compute_client = self.auth.get_compute_client() server_list = [] for vm in compute_client.virtual_machines.list(AzureServer.resource_group_name): if vm.tags.get("skylark", None) == "true" and AzureServer.is_valid_vm_name(vm.name) and vm.location == region: @@ -286,12 +281,11 @@ def create_ssh_key(self): f.write(f"{key.get_name()} {key.get_base64()}\n") def set_up_resource_group(self, clean_up_orphans=True): - credential = self.credential - resource_client = ResourceManagementClient(credential, self.subscription_id) + resource_client = self.auth.get_resource_client() if resource_client.resource_groups.check_existence(AzureServer.resource_group_name): # Resource group already exists. # Take this moment to search for orphaned resources and clean them up... 
- network_client = NetworkManagementClient(credential, self.subscription_id) + network_client = self.auth.get_network_client() if clean_up_orphans: instances_to_terminate = [] for vnet in network_client.virtual_networks.list(AzureServer.resource_group_name): @@ -326,9 +320,8 @@ def provision_instance( pub_key = f.read() # Prepare for making Microsoft Azure API calls - credential = self.credential - compute_client = ComputeManagementClient(credential, self.subscription_id) - network_client = NetworkManagementClient(credential, self.subscription_id) + compute_client = self.auth.get_compute_client() + network_client = self.auth.get_network_client() # Use the common resource group for this instance resource_group = AzureServer.resource_group_name @@ -436,8 +429,27 @@ def provision_instance( }, }, "network_profile": {"network_interfaces": [{"id": nic_result.id}]}, + # give VM managed identity w/ system assigned identity + "identity": {"type": ResourceIdentityType.system_assigned}, }, ) - poller.result() + vm_result = poller.result() + + # Assign roles to system MSI, see https://docs.microsoft.com/en-us/samples/azure-samples/compute-python-msi-vm/compute-python-msi-vm/#role-assignment + # todo only grant storage-blob-data-reader and storage-blob-data-writer for specified buckets + auth_client = self.auth.get_authorization_client() + role_name = 'Contributor' + roles = list(auth_client.role_definitions.list(resource_group, filter="roleName eq '{}'".format(role_name))) + assert len(roles) == 1 + + # Add RG scope to the MSI identities: + role_assignment = auth_client.role_assignments.create( + resource_group, + uuid.uuid4(), # Role assignment random name + { + 'role_definition_id': roles[0].id, + 'principal_id': vm_result.identity.principal_id + } + ) return AzureServer(name) diff --git a/skylark/compute/azure/azure_server.py b/skylark/compute/azure/azure_server.py index 1f4f5fd3e..dd1448b40 100644 --- a/skylark/compute/azure/azure_server.py +++ b/skylark/compute/azure/azure_server.py @@ -4,16 +4,13 @@ import paramiko from skylark import key_root +from skylark.compute.azure.azure_auth import AzureAuthentication from skylark.config import SkylarkConfig, load_config from skylark.compute.server import Server, ServerState from skylark.utils.cache import ignore_lru_cache from skylark.utils.utils import PathLike import azure.core.exceptions -from azure.identity import DefaultAzureCredential, ClientSecretCredential -from azure.mgmt.compute import ComputeManagementClient -from azure.mgmt.network import NetworkManagementClient -from azure.mgmt.resource import ResourceManagementClient class AzureServer(Server): @@ -30,8 +27,7 @@ def __init__( ): config = SkylarkConfig.load() assert config.azure_enabled, "Azure is not enabled in the config" - self.credential = DefaultAzureCredential() - self.subscription_id = config.azure_subscription_id + self.auth = AzureAuthentication(config.azure_subscription_id) self.name = name self.location = None @@ -96,8 +92,7 @@ def nic_name(name): return AzureServer.vm_name(name) + "-nic" def get_resource_group(self): - credential = self.credential - resource_client = ResourceManagementClient(credential, self.subscription_id) + resource_client = self.auth.get_resource_client() rg = resource_client.resource_groups.get(AzureServer.resource_group_name) # Sanity checks @@ -108,8 +103,7 @@ def get_resource_group(self): return rg def get_virtual_machine(self): - credential = self.credential - compute_client = ComputeManagementClient(credential, self.subscription_id) + compute_client = 
self.auth.get_compute_client() vm = compute_client.virtual_machines.get(AzureServer.resource_group_name, AzureServer.vm_name(self.name)) # Sanity checks @@ -129,8 +123,7 @@ def uuid(self): return f"{self.subscription_id}:{self.region_tag}:{self.name}" def instance_state(self) -> ServerState: - credential = self.credential - compute_client = ComputeManagementClient(credential, self.subscription_id) + compute_client = self.auth.get_compute_client() vm_instance_view = compute_client.virtual_machines.instance_view(AzureServer.resource_group_name, AzureServer.vm_name(self.name)) statuses = vm_instance_view.statuses for status in statuses: @@ -140,8 +133,7 @@ def instance_state(self) -> ServerState: @ignore_lru_cache() def public_ip(self): - credential = self.credential - network_client = NetworkManagementClient(credential, self.subscription_id) + network_client = self.auth.get_network_client() public_ip = network_client.public_ip_addresses.get(AzureServer.resource_group_name, AzureServer.ip_name(self.name)) # Sanity checks @@ -169,9 +161,8 @@ def network_tier(self): return "PREMIUM" def terminate_instance_impl(self): - credential = self.credential - compute_client = ComputeManagementClient(credential, self.subscription_id) - network_client = NetworkManagementClient(credential, self.subscription_id) + compute_client = self.auth.get_compute_client() + network_client = self.auth.get_network_client() vm_poller = compute_client.virtual_machines.begin_delete(AzureServer.resource_group_name, self.vm_name(self.name)) _ = vm_poller.result() diff --git a/skylark/config.py b/skylark/config.py index 74b1ca9a0..7b0d7dd99 100644 --- a/skylark/config.py +++ b/skylark/config.py @@ -4,56 +4,40 @@ from pathlib import Path from typing import Optional +from azure.common.credentials import DefaultAzureCredential from skylark.utils import logger import configparser from skylark import config_path -def load_config(): - raise NotImplementedError() - @dataclass class SkylarkConfig: - aws_enabled: bool = False - aws_access_key_id: Optional[str] = None - aws_secret_access_key: Optional[str] = None - - azure_enabled: bool = False - azure_tenant_id: Optional[str] = None - azure_client_id: Optional[str] = None - azure_client_secret: Optional[str] = None + aws_config_mode: str = "disabled" # disabled, iam_manual + aws_access_key_id: Optional[str] = None # iam_manual + aws_secret_access_key: Optional[str] = None # iam_manual + + azure_config_mode: str = "disabled" # disabled, cli_auto azure_subscription_id: Optional[str] = None - gcp_enabled: bool = False + gcp_config_mode: str = "disabled" # disabled or service_account_file gcp_project_id: Optional[str] = None - gcp_application_credentials_file: Optional[str] = None - - @staticmethod - def load() -> "SkylarkConfig": - if config_path.exists(): - config = SkylarkConfig.load_from_config_file(config_path) - else: - config = SkylarkConfig() - - # set environment variables - if config.aws_enabled: - # todo load AWS credentials from CLI - os.environ["AWS_ACCESS_KEY_ID"] = config.aws_access_key_id - os.environ["AWS_SECRET_ACCESS_KEY"] = config.aws_secret_access_key - if config.azure_enabled: - if config.azure_tenant_id and config.azure_client_id and config.azure_client_secret: - os.environ["AZURE_TENANT_ID"] = config.azure_tenant_id - os.environ["AZURE_CLIENT_ID"] = config.azure_client_id - os.environ["AZURE_CLIENT_SECRET"] = config.azure_client_secret - if config.gcp_enabled: - os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = config.gcp_application_credentials_file - - return config 
+ gcp_application_credentials_file: Optional[str] = None # service_account_file + + @property + def aws_enabled(self) -> bool: + return self.aws_config_mode != "disabled" + + @property + def azure_enabled(self) -> bool: + return self.azure_config_mode != "disabled" + + @property + def gcp_enabled(self) -> bool: + return self.gcp_config_mode != "disabled" @staticmethod - @functools.lru_cache def load_from_config_file(path=config_path) -> "SkylarkConfig": path = Path(config_path) config = configparser.ConfigParser() @@ -63,7 +47,6 @@ def load_from_config_file(path=config_path) -> "SkylarkConfig": config.read(path) if "aws" in config: - aws_enabled = True aws_access_key_id = config.get("aws", "access_key_id") aws_secret_access_key = config.get("aws", "secret_access_key") else: @@ -113,31 +96,30 @@ def to_config_file(self, path=config_path): if path.exists(): config.read(os.path.expanduser(path)) - if self.aws_enabled: - if "aws" not in config: - config.add_section("aws") + if "aws" not in config: + config.add_section("aws") + config.set("aws", "config_mode", self.aws_config_mode) + if self.aws_config_mode == "iam_manual": config.set("aws", "access_key_id", self.aws_access_key_id) config.set("aws", "secret_access_key", self.aws_secret_access_key) - else: - config.remove_section("aws") - if self.azure_enabled: - if "azure" not in config: - config.add_section("azure") + if "azure" not in config: + config.add_section("azure") + config.set("azure", "config_mode", self.azure_config_mode) + if self.azure_config_mode == "cli_auto": + config.set("azure", "subscription_id", self.azure_subscription_id) + elif self.azure_config_mode == "ad_manual": config.set("azure", "tenant_id", self.azure_tenant_id) config.set("azure", "client_id", self.azure_client_id) config.set("azure", "client_secret", self.azure_client_secret) config.set("azure", "subscription_id", self.azure_subscription_id) - else: - config.remove_section("azure") - if self.gcp_enabled: - if "gcp" not in config: - config.add_section("gcp") + if "gcp" not in config: + config.add_section("gcp") + config.set("gcp", "config_mode", self.gcp_config_mode) + if self.gcp_config_mode == "service_account_file": config.set("gcp", "project_id", self.gcp_project_id) config.set("gcp", "application_credentials_file", self.gcp_application_credentials_file) - else: - config.remove_section("gcp") with path.open("w") as f: config.write(f) From f38dace1f0fb008435e829106d2fb0a7fc1e3af0 Mon Sep 17 00:00:00 2001 From: Paras Jain Date: Thu, 10 Mar 2022 00:27:12 +0000 Subject: [PATCH 07/34] Fix pytype errors --- skylark/__init__.py | 4 +- skylark/cli/cli.py | 19 +-- skylark/cli/cli_aws.py | 5 +- skylark/cli/cli_azure.py | 7 +- skylark/cli/cli_helper.py | 111 ++++++++---------- skylark/cli/experiments/throughput.py | 4 - skylark/compute/aws/aws_auth.py | 10 +- skylark/compute/aws/aws_cloud_provider.py | 6 +- skylark/compute/aws/aws_server.py | 17 +-- skylark/compute/azure/azure_auth.py | 37 ++++-- skylark/compute/azure/azure_cloud_provider.py | 11 +- skylark/compute/azure/azure_server.py | 5 +- skylark/compute/gcp/gcp_auth.py | 29 +++++ skylark/compute/gcp/gcp_cloud_provider.py | 35 +++--- skylark/compute/gcp/gcp_server.py | 25 ++-- skylark/config.py | 111 +++++------------- skylark/obj_store/azure_interface.py | 15 +-- skylark/obj_store/s3_interface.py | 36 ++---- skylark/replicate/replicator_client.py | 32 +++-- skylark/test/test_replicator_client.py | 2 - 20 files changed, 233 insertions(+), 288 deletions(-) create mode 100644 skylark/compute/gcp/gcp_auth.py diff 
--git a/skylark/__init__.py b/skylark/__init__.py index 4c2792520..9292957c6 100644 --- a/skylark/__init__.py +++ b/skylark/__init__.py @@ -42,6 +42,6 @@ def print_header(): # cloud config if config_path.exists(): - cloud_config = SkylarkConfig.load_from_config_file(config_path) + cloud_config = SkylarkConfig.load_config(config_path) else: - cloud_config = SkylarkConfig() # empty config \ No newline at end of file + cloud_config = SkylarkConfig.load_infer_cli() diff --git a/skylark/cli/cli.py b/skylark/cli/cli.py index 87daf4ad7..007d1c86c 100644 --- a/skylark/cli/cli.py +++ b/skylark/cli/cli.py @@ -25,6 +25,7 @@ import skylark.cli.cli_solver import skylark.cli.experiments import typer +from skylark.config import SkylarkConfig from skylark.utils import logger from skylark import config_path, GB, MB, print_header from skylark.cli.cli_helper import ( @@ -46,7 +47,6 @@ ls_s3, parse_path, ) -from skylark.config import SkylarkConfig, load_config from skylark.replicate.replication_plan import ReplicationJob, ReplicationTopology from skylark.replicate.replicator_client import ReplicatorClient from skylark.obj_store.object_store_interface import ObjectStoreInterface @@ -270,7 +270,7 @@ def replicate_json( ) stats["success"] = stats["monitor_status"] == "completed" out_json = {k: v for k, v in stats.items() if k not in ["log", "completed_chunk_ids"]} - + # deprovision if not reuse_gateways: atexit.unregister(rc.deprovision_gateways) @@ -287,23 +287,26 @@ def deprovision(): @app.command() -def init(reinit_aws: bool = False, reinit_azure: bool = False, reinit_gcp: bool = False): +def init(reinit_azure: bool = False, reinit_gcp: bool = False): print_header() - config = SkylarkConfig.load() + if config_path.exists(): + cloud_config = SkylarkConfig.load_config(config_path) + else: + cloud_config = SkylarkConfig() # load AWS config typer.secho("\n(1) Configuring AWS:", fg="yellow", bold=True) - config = load_aws_config(config, force_init=reinit_aws) + cloud_config = load_aws_config(cloud_config) # load Azure config typer.secho("\n(2) Configuring Azure:", fg="yellow", bold=True) - config = load_azure_config(config, force_init=reinit_azure) + cloud_config = load_azure_config(cloud_config, force_init=reinit_azure) # load GCP config typer.secho("\n(3) Configuring GCP:", fg="yellow", bold=True) - config = load_gcp_config(config, force_init=reinit_gcp) + cloud_config = load_gcp_config(cloud_config, force_init=reinit_gcp) - config.to_config_file(config_path) + cloud_config.to_config_file(config_path) typer.secho(f"\nConfig file saved to {config_path}", fg="green") return 0 diff --git a/skylark/cli/cli_aws.py b/skylark/cli/cli_aws.py index 172d891a9..aab6cd002 100644 --- a/skylark/cli/cli_aws.py +++ b/skylark/cli/cli_aws.py @@ -60,8 +60,9 @@ def ssh(region: Optional[str] = None): @app.command() def cp_datasync(src_bucket: str, dst_bucket: str, path: str): aws_auth = AWSAuthentication() - src_region = S3Interface.infer_s3_region(src_bucket) - dst_region = S3Interface.infer_s3_region(dst_bucket) + s3_interface = S3Interface("us-east-1", None) + src_region = s3_interface.infer_s3_region(src_bucket) + dst_region = s3_interface.infer_s3_region(dst_bucket) iam_client = aws_auth.get_boto3_client("iam", "us-east-1") try: diff --git a/skylark/cli/cli_azure.py b/skylark/cli/cli_azure.py index 778d9a931..0193a6cee 100644 --- a/skylark/cli/cli_azure.py +++ b/skylark/cli/cli_azure.py @@ -9,6 +9,7 @@ import typer from azure.identity import DefaultAzureCredential from azure.mgmt.compute import ComputeManagementClient +from 
skylark.compute.azure.azure_auth import AzureAuthentication from skylark.config import SkylarkConfig from skylark.compute.azure.azure_cloud_provider import AzureCloudProvider from skylark.utils.utils import do_parallel @@ -22,12 +23,10 @@ def get_valid_skus( prefix: str = typer.Option("", "--prefix", help="Filter by prefix"), top_k: int = typer.Option(-1, "--top-k", help="Print top k entries"), ): - config = SkylarkConfig.load() - credential = DefaultAzureCredential() + auth = AzureAuthentication() + client = auth.get_compute_client() - # query azure API for each region to get available SKUs for each resource type def get_skus(region): - client = ComputeManagementClient(credential, config.azure_subscription_id) valid_skus = [] for sku in client.resource_skus.list(filter="location eq '{}'".format(region)): if len(sku.restrictions) == 0 and (not prefix or sku.name.startswith(prefix)): diff --git a/skylark/cli/cli_helper.py b/skylark/cli/cli_helper.py index 3ac1794bb..0f5002a30 100644 --- a/skylark/cli/cli_helper.py +++ b/skylark/cli/cli_helper.py @@ -9,6 +9,9 @@ import boto3 import typer +from skylark.compute.aws.aws_auth import AWSAuthentication +from skylark.compute.azure.azure_auth import AzureAuthentication +from skylark.compute.gcp.gcp_auth import GCPAuthentication from skylark.config import SkylarkConfig from skylark.utils import logger from skylark.compute.aws.aws_cloud_provider import AWSCloudProvider @@ -244,17 +247,18 @@ def check_ulimit(hard_limit=1024 * 1024, soft_limit=1024 * 1024): def deprovision_skylark_instances(): - config = SkylarkConfig.load() instances = [] - query_jobs = [] - if config.aws_enabled: + if AWSAuthentication().enabled(): + logger.debug("AWS authentication enabled, querying for instances") aws = AWSCloudProvider() for region in aws.region_list(): query_jobs.append(lambda: aws.get_matching_instances(region)) - if config.azure_enabled: + if AzureAuthentication().enabled(): + logger.debug("Azure authentication enabled, querying for instances") query_jobs.append(lambda: AzureCloudProvider().get_matching_instances()) - if config.gcp_enabled: + if GCPAuthentication().enabled(): + logger.debug("GCP authentication enabled, querying for instances") query_jobs.append(lambda: GCPCloudProvider().get_matching_instances()) # query in parallel @@ -268,18 +272,7 @@ def deprovision_skylark_instances(): typer.secho("No instances to deprovision, exiting...", fg="yellow", bold=True) -def load_aws_config(config: SkylarkConfig, force_init: bool = False) -> SkylarkConfig: - if force_init: - typer.secho(" AWS credentials will be re-initialized", fg="red") - config.aws_config_mode = "disabled" - config.aws_access_key_id = None - config.aws_secret_access_key = None - - aws_configured_iam = config.aws_config_mode == "iam_manual" and config.aws_access_key_id and config.aws_secret_access_key - if aws_configured_iam: - typer.secho(" AWS credentials already configured! 
To reconfigure AWS, run `skylark init --reinit-aws`.", fg="blue") - return config - +def load_aws_config(config: SkylarkConfig) -> SkylarkConfig: # get AWS credentials from boto3 session = boto3.Session() credentials = session.get_credentials() @@ -288,37 +281,40 @@ def load_aws_config(config: SkylarkConfig, force_init: bool = False) -> SkylarkC typer.secho(" AWS credentials not found in boto3 session, please use the AWS CLI to set them via `aws configure`", fg="red") typer.secho(" https://docs.aws.amazon.com/cli/latest/userguide/cli-chap-getting-started.html", fg="red") typer.secho(" Disabling AWS support", fg="blue") - config.aws_config_mode = "disabled" - config.aws_access_key_id = None - config.aws_secret_access_key = None return config typer.secho(f" Loaded AWS credentials from the AWS CLI [IAM access key ID: ...{credentials.access_key[-6:]}]", fg="blue") - config.aws_config_mode = "iam_manual" - config.aws_access_key_id = credentials.access_key - config.aws_secret_access_key = credentials.secret_key return config def load_azure_config(config: SkylarkConfig, force_init: bool = False) -> SkylarkConfig: if force_init: typer.secho(" Azure credentials will be re-initialized", fg="red") - config.azure_config_mode = "disabled" config.azure_subscription_id = None - azure_configured_cli = config.azure_config_mode == "cli_auto" and config.azure_subscription_id - if azure_configured_cli: + if config.azure_subscription_id: typer.secho(" Azure credentials already configured! To reconfigure Azure, run `skylark init --reinit-azure`.", fg="blue") return config - # check if DefaultAzureCredential is available - cred = DefaultAzureCredential(exclude_shared_token_cache_credential=True) - - - # typer.secho( - # " Azure credentials not found in environment variables, please use the Azure CLI to set them via `az login`", fg="red" - # ) - # typer.secho(" Azure config can be generated using: az ad sp create-for-rbac -n api://skylark --sdk-auth", fg="red") + # check if Azure is enabled + auth = AzureAuthentication() + if not auth.enabled(): + typer.secho(" Default Azure credentials are not set up yet. Run `az login` to set them up.", fg="red") + typer.secho(" https://docs.microsoft.com/en-us/azure/developer/python/azure-sdk-authenticate", fg="red") + typer.secho(" Disabling Azure support", fg="blue") + return config + else: + typer.secho(" Azure credentials found in Azure CLI", fg="blue") + inferred_subscription_id = auth.infer_subscription_id() + if inferred_subscription_id: + typer.secho(f" Inferred Azure subscription ID: {inferred_subscription_id}", fg="blue") + if typer.confirm(f" Do you want to use the inferred subscription ID `{inferred_subscription_id}`?", default=True): + config.azure_subscription_id = inferred_subscription_id + return config + if typer.confirm( + " Azure credentials are configured, but no default subscription ID was set. Do you want to set one now?", default=True + ): + config.azure_subscription_id = typer.prompt(" Enter the Azure subscription ID:") return config @@ -329,40 +325,25 @@ def load_gcp_config(config: SkylarkConfig, force_init: bool = False) -> SkylarkC config.gcp_application_credentials_file = None config.gcp_project_id = None - if config.gcp_enabled and config.gcp_project_id is not None and config.gcp_application_credentials_file is not None: + if config.gcp_project_id is not None and config.gcp_application_credentials_file is not None: typer.secho(" GCP already configured! 
To reconfigure GCP, run `skylark init --reinit-gcp`.", fg="blue") return config - # load from environment variables - gcp_application_credentials_file = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS", config.gcp_application_credentials_file) - if not gcp_application_credentials_file: - typer.secho( - " GCP credentials not found in environment variables, please use the GCP CLI to set them via `gcloud auth application-default login`", - fg="red", - ) + # check if GCP is enabled + auth = GCPAuthentication() + if not auth.enabled(): + typer.secho(" Default GCP credentials are not set up yet. Run `gcloud auth login` to set them up.", fg="red") typer.secho(" https://cloud.google.com/docs/authentication/getting-started", fg="red") - if not typer.confirm(" Do you want to manually enter your service account key?", default=False): - typer.secho(" Disabling GCP support in Skylark", fg="blue") - config.gcp_enabled = False - config.gcp_project_id = None - config.gcp_application_credentials_file = None - return config - gcp_application_credentials_file = typer.prompt(" GCP application credentials file path") - - # check if the file exists - gcp_application_credentials_file = Path(gcp_application_credentials_file).expanduser().resolve() - if not gcp_application_credentials_file.exists(): - typer.secho(f" GCP application credentials file not found at {gcp_application_credentials_file}", fg="red") - typer.secho(" Disabling GCP support in Skylark", fg="blue") - config.gcp_enabled = False - config.gcp_project_id = None - config.gcp_application_credentials_file = None + typer.secho(" Disabling GCP support", fg="blue") return config - - config.gcp_enabled = True - config.gcp_application_credentials_file = str(gcp_application_credentials_file) - project_id = os.environ.get("GOOGLE_CLOUD_PROJECT", config.gcp_project_id) - if not project_id: - project_id = typer.prompt(" GCP project ID") - config.gcp_project_id = project_id + else: + typer.secho(" GCP credentials found in GCP CLI", fg="blue") + inferred_project_id = auth.infer_project_id() + if inferred_project_id: + typer.secho(f" Inferred GCP project ID: {inferred_project_id}", fg="blue") + if typer.confirm(f" Do you want to use the inferred project ID `{inferred_project_id}`?", default=True): + config.gcp_project_id = inferred_project_id + return config + if typer.confirm(" GCP credentials are configured, but no default project ID was set. Do you want to set one now?", default=True): + config.gcp_project_id = typer.prompt(" Enter the GCP project ID:") return config diff --git a/skylark/cli/experiments/throughput.py b/skylark/cli/experiments/throughput.py index 1d4e2613f..04c7fd955 100644 --- a/skylark/cli/experiments/throughput.py +++ b/skylark/cli/experiments/throughput.py @@ -10,7 +10,6 @@ import typer from skylark import GB, skylark_root from skylark.benchmark.utils import provision, split_list -from skylark.config import SkylarkConfig from skylark.compute.aws.aws_cloud_provider import AWSCloudProvider from skylark.compute.azure.azure_cloud_provider import AzureCloudProvider from skylark.compute.gcp.gcp_cloud_provider import GCPCloudProvider @@ -88,9 +87,6 @@ def throughput_grid( def check_stderr(tup): assert tup[1].strip() == "", f"Command failed, err: {tup[1]}" - config = SkylarkConfig.load() - assert config.aws_enabled and config.azure_enabled and config.gcp_enabled, "All cloud providers must be enabled." 
- if resume: index_key = [ "iperf3_connections", diff --git a/skylark/compute/aws/aws_auth.py b/skylark/compute/aws/aws_auth.py index e9520aa86..336010b0b 100644 --- a/skylark/compute/aws/aws_auth.py +++ b/skylark/compute/aws/aws_auth.py @@ -20,17 +20,17 @@ def __init__(self, access_key: Optional[str] = None, secret_key: Optional[str] = self.config_mode = "disabled" self.access_key = None self.secret_key = None - + def enabled(self): return self.config_mode != "disabled" - + def infer_credentials(self): # todo load temporary credentials from STS session = boto3.Session() credentials = session.get_credentials() credentials = credentials.get_frozen_credentials() return credentials.access_key, credentials.secret_key - + def get_boto3_session(self, aws_region: str): if self.config_mode == "manual": return boto3.Session( @@ -40,7 +40,7 @@ def get_boto3_session(self, aws_region: str): ) else: return boto3.Session(region_name=aws_region) - + def get_boto3_resource(self, service_name, aws_region=None): return self.get_boto3_session(aws_region).resource(service_name, region_name=aws_region) @@ -48,4 +48,4 @@ def get_boto3_client(self, service_name, aws_region=None): if aws_region is None: return self.get_boto3_session(aws_region).client(service_name) else: - return self.get_boto3_session(aws_region).client(service_name, region_name=aws_region) \ No newline at end of file + return self.get_boto3_session(aws_region).client(service_name, region_name=aws_region) diff --git a/skylark/compute/aws/aws_cloud_provider.py b/skylark/compute/aws/aws_cloud_provider.py index add933bf7..f3a63b824 100644 --- a/skylark/compute/aws/aws_cloud_provider.py +++ b/skylark/compute/aws/aws_cloud_provider.py @@ -14,9 +14,9 @@ class AWSCloudProvider(CloudProvider): - def __init__(self, auth: AWSAuthentication): + def __init__(self): super().__init__() - self.auth: AWSAuthentication = auth + self.auth = AWSAuthentication() @property def name(self): @@ -224,8 +224,6 @@ def provision_instance( if name is None: name = f"skylark-aws-{str(uuid.uuid4()).replace('-', '')}" ec2 = self.auth.get_boto3_resource("ec2", region) - AWSServer.ensure_keyfile_exists(region) - vpc = self.get_vpc(region) assert vpc is not None, "No VPC found" subnets = list(vpc.subnets.all()) diff --git a/skylark/compute/aws/aws_server.py b/skylark/compute/aws/aws_server.py index 03116abe7..797459633 100644 --- a/skylark/compute/aws/aws_server.py +++ b/skylark/compute/aws/aws_server.py @@ -4,6 +4,7 @@ import boto3 import paramiko +from skylark.compute.aws.aws_auth import AWSAuthentication from skylark.utils import logger from oslo_concurrency import lockutils @@ -18,6 +19,7 @@ class AWSServer(Server): def __init__(self, region_tag, instance_id, log_dir=None): super().__init__(region_tag, log_dir=log_dir) assert self.region_tag.split(":")[0] == "aws" + self.auth = AWSAuthentication() self.aws_region = self.region_tag.split(":")[1] self.instance_id = instance_id self.boto3_session = boto3.Session(region_name=self.aws_region) @@ -27,11 +29,10 @@ def uuid(self): return f"{self.region_tag}:{self.instance_id}" def get_boto3_instance_resource(self): - ec2 = AWSServer.get_boto3_resource("ec2", self.aws_region) + ec2 = self.auth.get_boto3_resource("ec2", self.aws_region) return ec2.Instance(self.instance_id) - @staticmethod - def ensure_keyfile_exists(aws_region, prefix=key_root / "aws"): + def ensure_keyfile_exists(self, aws_region, prefix=key_root / "aws"): prefix = Path(prefix) key_name = f"skylark-{aws_region}" local_key_file = prefix / f"{key_name}.pem" @@ -39,8 
+40,8 @@ def ensure_keyfile_exists(aws_region, prefix=key_root / "aws"): @lockutils.synchronized(f"aws_keyfile_lock_{aws_region}", external=True, lock_path="/tmp/skylark_locks") def create_keyfile(): if not local_key_file.exists(): # we have to check again since another process may have created it - ec2 = AWSServer.get_boto3_resource("ec2", aws_region) - ec2_client = AWSServer.get_boto3_client("ec2", aws_region) + ec2 = self.auth.get_boto3_resource("ec2", aws_region) + ec2_client = self.auth.get_boto3_client("ec2", aws_region) local_key_file.parent.mkdir(parents=True, exist_ok=True) # delete key pair from ec2 if it exists keys_in_region = set(p["KeyName"] for p in ec2_client.describe_key_pairs()["KeyPairs"]) @@ -92,8 +93,7 @@ def __repr__(self): return f"AWSServer(region_tag={self.region_tag}, instance_id={self.instance_id})" def terminate_instance_impl(self): - ec2 = AWSServer.get_boto3_resource("ec2", self.aws_region) - ec2.instances.filter(InstanceIds=[self.instance_id]).terminate() + self.auth.get_boto3_resource("ec2", self.aws_region).instances.filter(InstanceIds=[self.instance_id]).terminate() def get_ssh_client_impl(self): client = paramiko.SSHClient() @@ -101,7 +101,8 @@ def get_ssh_client_impl(self): client.connect( self.public_ip(), username="ec2-user", - pkey=paramiko.RSAKey.from_private_key_file(self.local_keyfile), + # todo generate keys with password "skylark" + pkey=paramiko.RSAKey.from_private_key_file(str(self.local_keyfile)), look_for_keys=False, allow_agent=False, banner_timeout=200, diff --git a/skylark/compute/azure/azure_auth.py b/skylark/compute/azure/azure_auth.py index 8f4d32f9b..a88ece351 100644 --- a/skylark/compute/azure/azure_auth.py +++ b/skylark/compute/azure/azure_auth.py @@ -1,25 +1,48 @@ +import os +import subprocess +from typing import Optional from azure.identity import DefaultAzureCredential from azure.mgmt.compute import ComputeManagementClient from azure.mgmt.network import NetworkManagementClient from azure.mgmt.resource import ResourceManagementClient from azure.mgmt.authorization import AuthorizationManagementClient +from azure.storage.blob import BlobServiceClient + +from skylark import cloud_config + class AzureAuthentication: - def __init__(self, subscription_id: str): + def __init__(self, subscription_id: str = cloud_config.azure_subscription_id): self.subscription_id = subscription_id self.credential = DefaultAzureCredential() - + + def enabled(self) -> bool: + return self.subscription_id is not None + + @staticmethod + def infer_subscription_id() -> Optional[str]: + if "AZURE_SUBSCRIPTION_ID" in os.environ: + return os.environ["AZURE_SUBSCRIPTION_ID"] + else: + try: + return subprocess.check_output(["az", "account", "show", "--query", "id"]).decode("utf-8").strip() + except subprocess.CalledProcessError: + return None + def get_token(self, resource: str): return self.credential.get_token(resource) - + def get_compute_client(self): return ComputeManagementClient(self.credential, self.subscription_id) - + def get_resource_client(self): return ResourceManagementClient(self.credential, self.subscription_id) - + def get_network_client(self): return NetworkManagementClient(self.credential, self.subscription_id) - + def get_authorization_client(self): - return AuthorizationManagementClient(self.credential, self.subscription_id) \ No newline at end of file + return AuthorizationManagementClient(self.credential, self.subscription_id) + + def get_storage_client(self, account_url: str): + return BlobServiceClient(account_url=account_url, 
credential=self.credential) diff --git a/skylark/compute/azure/azure_cloud_provider.py b/skylark/compute/azure/azure_cloud_provider.py index 1d29498bf..33f283068 100644 --- a/skylark/compute/azure/azure_cloud_provider.py +++ b/skylark/compute/azure/azure_cloud_provider.py @@ -20,9 +20,7 @@ class AzureCloudProvider(CloudProvider): def __init__(self, key_root=key_root / "azure"): super().__init__() - config = SkylarkConfig().load() - assert config.azure_enabled, "Azure cloud provider is not enabled in the config file." - self.auth = AzureAuthentication(config.subscription_id) + self.auth = AzureAuthentication() key_root.mkdir(parents=True, exist_ok=True) self.private_key_path = key_root / "azure_key" @@ -438,7 +436,7 @@ def provision_instance( # Assign roles to system MSI, see https://docs.microsoft.com/en-us/samples/azure-samples/compute-python-msi-vm/compute-python-msi-vm/#role-assignment # todo only grant storage-blob-data-reader and storage-blob-data-writer for specified buckets auth_client = self.auth.get_authorization_client() - role_name = 'Contributor' + role_name = "Contributor" roles = list(auth_client.role_definitions.list(resource_group, filter="roleName eq '{}'".format(role_name))) assert len(roles) == 1 @@ -446,10 +444,7 @@ def provision_instance( role_assignment = auth_client.role_assignments.create( resource_group, uuid.uuid4(), # Role assignment random name - { - 'role_definition_id': roles[0].id, - 'principal_id': vm_result.identity.principal_id - } + {"role_definition_id": roles[0].id, "principal_id": vm_result.identity.principal_id}, ) return AzureServer(name) diff --git a/skylark/compute/azure/azure_server.py b/skylark/compute/azure/azure_server.py index dd1448b40..ceb618e7d 100644 --- a/skylark/compute/azure/azure_server.py +++ b/skylark/compute/azure/azure_server.py @@ -5,7 +5,6 @@ import paramiko from skylark import key_root from skylark.compute.azure.azure_auth import AzureAuthentication -from skylark.config import SkylarkConfig, load_config from skylark.compute.server import Server, ServerState from skylark.utils.cache import ignore_lru_cache from skylark.utils.utils import PathLike @@ -25,9 +24,7 @@ def __init__( ssh_private_key=None, assume_exists=True, ): - config = SkylarkConfig.load() - assert config.azure_enabled, "Azure is not enabled in the config" - self.auth = AzureAuthentication(config.azure_subscription_id) + self.auth = AzureAuthentication() self.name = name self.location = None diff --git a/skylark/compute/gcp/gcp_auth.py b/skylark/compute/gcp/gcp_auth.py new file mode 100644 index 000000000..81740dec7 --- /dev/null +++ b/skylark/compute/gcp/gcp_auth.py @@ -0,0 +1,29 @@ +from functools import lru_cache +import os +import subprocess +import googleapiclient.discovery + +from skylark import cloud_config + + +class GCPAuthentication: + def __init__(self, project_id: str = cloud_config.gcp_project_id): + self.project_id = project_id + + def enabled(self): + return self.project_id is not None + + @staticmethod + def infer_project_id(): + if "GOOGLE_PROJECT_ID" in os.environ: + return os.environ["GOOGLE_PROJECT_ID"] + try: + return subprocess.check_output(["gcloud", "config", "get-value", "project"]).decode("utf-8").strip() + except subprocess.CalledProcessError: + return None + + def get_gcp_client(self, service_name="compute", version="v1"): + return googleapiclient.discovery.build(service_name, version) + + def get_gcp_instances(self, gcp_region: str): + return self.get_gcp_client().instances().list(project=self.project_id, zone=gcp_region).execute() diff 
--git a/skylark/compute/gcp/gcp_cloud_provider.py b/skylark/compute/gcp/gcp_cloud_provider.py index 7737a12e6..e5d9ca977 100644 --- a/skylark/compute/gcp/gcp_cloud_provider.py +++ b/skylark/compute/gcp/gcp_cloud_provider.py @@ -6,6 +6,7 @@ import googleapiclient import paramiko +from skylark.compute.gcp.gcp_auth import GCPAuthentication from skylark.config import SkylarkConfig from skylark.utils import logger @@ -20,9 +21,7 @@ class GCPCloudProvider(CloudProvider): def __init__(self, key_root=key_root / "gcp"): super().__init__() - config = SkylarkConfig.load() - assert config.gcp_enabled, "GCP is not enabled in the config" - self.gcp_project = config.gcp_project_id + self.auth = GCPAuthentication() key_root.mkdir(parents=True, exist_ok=True) self.private_key_path = key_root / "gcp-cert.pem" self.public_key_path = key_root / "gcp-cert.pub" @@ -154,11 +153,11 @@ def get_transfer_cost(src_key, dst_key, premium_tier=True): raise ValueError("Unknown src_continent: {}".format(src_continent)) def get_instance_list(self, region) -> List[GCPServer]: - gcp_instance_result = GCPServer.gcp_instances(self.gcp_project, region) + gcp_instance_result = self.auth.get_gcp_instances(region) if "items" in gcp_instance_result: instance_list = [] for i in gcp_instance_result["items"]: - instance_list.append(GCPServer(f"gcp:{region}", self.gcp_project, i["name"], ssh_private_key=self.private_key_path)) + instance_list.append(GCPServer(f"gcp:{region}", i["name"], ssh_private_key=self.private_key_path)) return instance_list else: return [] @@ -181,14 +180,14 @@ def create_ssh_key(self): f.write(f"{key.get_name()} {key.get_base64()}\n") def configure_default_network(self): - compute = GCPServer.get_gcp_client() + compute = self.auth.get_gcp_client() try: - compute.networks().get(project=self.gcp_project, network="default").execute() + compute.networks().get(project=self.auth.project_id, network="default").execute() except googleapiclient.errors.HttpError as e: if e.resp.status == 404: # create network op = ( compute.networks() - .insert(project=self.gcp_project, body={"name": "default", "subnetMode": "auto", "autoCreateSubnetworks": True}) + .insert(project=self.auth.project_id, body={"name": "default", "subnetMode": "auto", "autoCreateSubnetworks": True}) .execute() ) self.wait_for_operation_to_complete("global", op["name"]) @@ -197,18 +196,18 @@ def configure_default_network(self): def configure_default_firewall(self, ip="0.0.0.0/0"): """Configure default firewall to allow access from all ports from all IPs (if not exists).""" - compute = GCPServer.get_gcp_client() + compute = self.auth.get_gcp_client() @lockutils.synchronized(f"gcp_configure_default_firewall", external=True, lock_path="/tmp/skylark_locks") def create_firewall(body, update_firewall=False): if update_firewall: - op = compute.firewalls().update(project=self.gcp_project, firewall="default", body=fw_body).execute() + op = compute.firewalls().update(project=self.auth.project_id, firewall="default", body=fw_body).execute() else: - op = compute.firewalls().insert(project=self.gcp_project, body=fw_body).execute() + op = compute.firewalls().insert(project=self.auth.project_id, body=fw_body).execute() self.wait_for_operation_to_complete("global", op["name"]) try: - current_firewall = compute.firewalls().get(project=self.gcp_project, firewall="default").execute() + current_firewall = compute.firewalls().get(project=self.auth.project_id, firewall="default").execute() except googleapiclient.errors.HttpError as e: if e.resp.status == 404: current_firewall = 
None @@ -231,11 +230,11 @@ def create_firewall(body, update_firewall=False): logger.debug(f"[GCP] Updated firewall") def get_operation_state(self, zone, operation_name): - compute = GCPServer.get_gcp_client() + compute = self.auth.get_gcp_client() if zone == "global": - return compute.globalOperations().get(project=self.gcp_project, operation=operation_name).execute() + return compute.globalOperations().get(project=self.auth.project_id, operation=operation_name).execute() else: - return compute.zoneOperations().get(project=self.gcp_project, zone=zone, operation=operation_name).execute() + return compute.zoneOperations().get(project=self.auth.project_id, zone=zone, operation=operation_name).execute() def wait_for_operation_to_complete(self, zone, operation_name, timeout=120): time_intervals = [0.1] * 10 + [0.2] * 10 + [1.0] * int(timeout) # backoff @@ -255,7 +254,7 @@ def provision_instance( assert not region.startswith("gcp:"), "Region should be GCP region" if name is None: name = f"skylark-gcp-{str(uuid.uuid4()).replace('-', '')}" - compute = GCPServer.get_gcp_client("compute", "v1") + compute = self.auth.get_gcp_client() with open(os.path.expanduser(self.public_key_path)) as f: pub_key = f.read() @@ -289,9 +288,9 @@ def provision_instance( ] }, } - result = compute.instances().insert(project=self.gcp_project, zone=region, body=req_body).execute() + result = compute.instances().insert(project=self.auth.project_id, zone=region, body=req_body).execute() self.wait_for_operation_to_complete(region, result["name"]) - server = GCPServer(f"gcp:{region}", self.gcp_project, name) + server = GCPServer(f"gcp:{region}", name) server.wait_for_ready() server.run_command("sudo /sbin/iptables -A INPUT -j ACCEPT") return server diff --git a/skylark/compute/gcp/gcp_server.py b/skylark/compute/gcp/gcp_server.py index f1b2678fa..e5ad6d2b9 100644 --- a/skylark/compute/gcp/gcp_server.py +++ b/skylark/compute/gcp/gcp_server.py @@ -2,9 +2,9 @@ from functools import lru_cache from pathlib import Path -import googleapiclient.discovery import paramiko from skylark import key_root +from skylark.compute.gcp.gcp_auth import GCPAuthentication from skylark.compute.server import Server, ServerState from skylark.utils.cache import ignore_lru_cache from skylark.utils.utils import PathLike @@ -14,7 +14,6 @@ class GCPServer(Server): def __init__( self, region_tag: str, - gcp_project: str, instance_name: str, key_root: PathLike = key_root / "gcp", log_dir=None, @@ -23,7 +22,7 @@ def __init__( super().__init__(region_tag, log_dir=log_dir) assert self.region_tag.split(":")[0] == "gcp", f"Region name doesn't match pattern gcp: {self.region_tag}" self.gcp_region = self.region_tag.split(":")[1] - self.gcp_project = gcp_project + self.auth = GCPAuthentication() self.gcp_instance_name = instance_name key_root = Path(key_root) key_root.mkdir(parents=True, exist_ok=True) @@ -33,20 +32,11 @@ def __init__( self.ssh_private_key = ssh_private_key def uuid(self): - return f"{self.gcp_project}:{self.region_tag}:{self.gcp_instance_name}" - - @classmethod - def get_gcp_client(cls, service_name="compute", version="v1"): - return googleapiclient.discovery.build(service_name, version) - - @staticmethod - def gcp_instances(gcp_project, gcp_region): - compute = GCPServer.get_gcp_client() - return compute.instances().list(project=gcp_project, zone=gcp_region).execute() + return f"{self.region_tag}:{self.gcp_instance_name}" @lru_cache def get_gcp_instance(self): - instances = self.gcp_instances(self.gcp_project, self.gcp_region) + instances = 
self.auth.get_gcp_instances(self.gcp_region) if "items" in instances: for i in instances["items"]: if i["name"] == self.gcp_instance_name: @@ -90,11 +80,12 @@ def network_tier(self): return interface["accessConfigs"][0]["networkTier"] def __repr__(self): - return f"GCPServer(region_tag={self.region_tag}, gcp_project={self.gcp_project}, instance_name={self.gcp_instance_name})" + return f"GCPServer(region_tag={self.region_tag}, instance_name={self.gcp_instance_name})" def terminate_instance_impl(self): - compute = self.get_gcp_client() - compute.instances().delete(project=self.gcp_project, zone=self.gcp_region, instance=self.instance_name()).execute() + self.auth.get_gcp_client().instances().delete( + project=self.auth.project_id, zone=self.gcp_region, instance=self.instance_name() + ).execute() def get_ssh_client_impl(self, uname="skylark", ssh_key_password="skylark"): """Return paramiko client that connects to this instance.""" diff --git a/skylark/config.py b/skylark/config.py index 7b0d7dd99..885f0af51 100644 --- a/skylark/config.py +++ b/skylark/config.py @@ -1,44 +1,26 @@ from dataclasses import dataclass -import functools import os from pathlib import Path from typing import Optional +import subprocess +from skylark.compute.azure.azure_auth import AzureAuthentication +from skylark.compute.gcp.gcp_auth import GCPAuthentication -from azure.common.credentials import DefaultAzureCredential from skylark.utils import logger import configparser from skylark import config_path - @dataclass class SkylarkConfig: - aws_config_mode: str = "disabled" # disabled, iam_manual - aws_access_key_id: Optional[str] = None # iam_manual - aws_secret_access_key: Optional[str] = None # iam_manual - - azure_config_mode: str = "disabled" # disabled, cli_auto + is_inferred_credential: bool = False azure_subscription_id: Optional[str] = None - - gcp_config_mode: str = "disabled" # disabled or service_account_file gcp_project_id: Optional[str] = None - gcp_application_credentials_file: Optional[str] = None # service_account_file - - @property - def aws_enabled(self) -> bool: - return self.aws_config_mode != "disabled" - - @property - def azure_enabled(self) -> bool: - return self.azure_config_mode != "disabled" - - @property - def gcp_enabled(self) -> bool: - return self.gcp_config_mode != "disabled" @staticmethod - def load_from_config_file(path=config_path) -> "SkylarkConfig": + def load_config(path) -> "SkylarkConfig": + """Load from a config file.""" path = Path(config_path) config = configparser.ConfigParser() if not path.exists(): @@ -46,80 +28,47 @@ def load_from_config_file(path=config_path) -> "SkylarkConfig": raise FileNotFoundError(f"Config file not found: {path}") config.read(path) - if "aws" in config: - aws_access_key_id = config.get("aws", "access_key_id") - aws_secret_access_key = config.get("aws", "secret_access_key") - else: - aws_enabled = False - aws_access_key_id = None - aws_secret_access_key = None - - if "azure" in config: - azure_enabled = True - azure_tenant_id = config.get("azure", "tenant_id") - azure_client_id = config.get("azure", "client_id") - azure_client_secret = config.get("azure", "client_secret") + azure_subscription_id = None + if "azure" in config and "subscription_id" in config["azure"]: azure_subscription_id = config.get("azure", "subscription_id") - else: - azure_enabled = False - azure_tenant_id = None - azure_client_id = None - azure_client_secret = None - azure_subscription_id = None - if "gcp" in config: - gcp_enabled = True + gcp_project_id = None + if "gcp" in config 
and "project_id" in config["gcp"]: gcp_project_id = config.get("gcp", "project_id") - gcp_application_credentials_file = config.get("gcp", "application_credentials_file") - else: - gcp_enabled = False - gcp_project_id = None - gcp_application_credentials_file = None return SkylarkConfig( - aws_enabled=aws_enabled, - aws_access_key_id=aws_access_key_id, - aws_secret_access_key=aws_secret_access_key, - azure_enabled=azure_enabled, - azure_tenant_id=azure_tenant_id, - azure_client_id=azure_client_id, - azure_client_secret=azure_client_secret, + is_inferred_credential=False, azure_subscription_id=azure_subscription_id, - gcp_enabled=gcp_enabled, gcp_project_id=gcp_project_id, - gcp_application_credentials_file=gcp_application_credentials_file, ) - def to_config_file(self, path=config_path): + @staticmethod + def load_infer_cli() -> "SkylarkConfig": + """Attempt to infer configuration using cloud CLIs.""" + azure_subscription_id = AzureAuthentication.infer_subscription_id() + gcp_project_id = GCPAuthentication.infer_project_id() + return SkylarkConfig( + is_inferred_credential=True, + azure_subscription_id=azure_subscription_id, + gcp_project_id=gcp_project_id, + ) + + def to_config_file(self, path): + assert not self.is_inferred_credential, "Cannot write inferred config to file" path = Path(path) config = configparser.ConfigParser() if path.exists(): config.read(os.path.expanduser(path)) - if "aws" not in config: - config.add_section("aws") - config.set("aws", "config_mode", self.aws_config_mode) - if self.aws_config_mode == "iam_manual": - config.set("aws", "access_key_id", self.aws_access_key_id) - config.set("aws", "secret_access_key", self.aws_secret_access_key) - - if "azure" not in config: - config.add_section("azure") - config.set("azure", "config_mode", self.azure_config_mode) - if self.azure_config_mode == "cli_auto": - config.set("azure", "subscription_id", self.azure_subscription_id) - elif self.azure_config_mode == "ad_manual": - config.set("azure", "tenant_id", self.azure_tenant_id) - config.set("azure", "client_id", self.azure_client_id) - config.set("azure", "client_secret", self.azure_client_secret) + if self.azure_subscription_id: + if "azure" not in config: + config.add_section("azure") config.set("azure", "subscription_id", self.azure_subscription_id) - if "gcp" not in config: - config.add_section("gcp") - config.set("gcp", "config_mode", self.gcp_config_mode) - if self.gcp_config_mode == "service_account_file": + if self.gcp_project_id: + if "gcp" not in config: + config.add_section("gcp") config.set("gcp", "project_id", self.gcp_project_id) - config.set("gcp", "application_credentials_file", self.gcp_application_credentials_file) with path.open("w") as f: config.write(f) diff --git a/skylark/obj_store/azure_interface.py b/skylark/obj_store/azure_interface.py index 9e3a86231..e9b164dcd 100644 --- a/skylark/obj_store/azure_interface.py +++ b/skylark/obj_store/azure_interface.py @@ -2,9 +2,7 @@ from concurrent.futures import Future, ThreadPoolExecutor from typing import Iterator, List from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError -from azure.identity import DefaultAzureCredential, ClientSecretCredential -from azure.storage.blob import BlobServiceClient -from skylark.config import SkylarkConfig, load_config +from skylark.compute.azure.azure_auth import AzureAuthentication from skylark.utils import logger from skylark.obj_store.object_store_interface import NoSuchObjectException, ObjectStoreInterface, ObjectStoreObject @@ -16,21 +14,16 @@ def 
full_path(self): class AzureInterface(ObjectStoreInterface): def __init__(self, azure_region, container_name): - # TODO: the azure region should get corresponding os.getenv() self.azure_region = azure_region self.container_name = container_name self.bucket_name = self.container_name # For compatibility self.pending_downloads, self.completed_downloads = 0, 0 self.pending_uploads, self.completed_uploads = 0, 0 - # Authenticate - config = SkylarkConfig.load() - assert config.azure_enabled, "Azure is not enabled in the config" - self.subscription_id = config.azure_subscription_id - raise NotImplementedError("TODO: COPY SESSION ID FROM CLIENT TO GATEWAY") - self.credential = None + # Create a blob service client + self.auth = AzureAuthentication() self.account_url = "https://{}.blob.core.windows.net".format("skylark" + self.azure_region) - self.blob_service_client = BlobServiceClient(account_url=self.account_url, credential=self.credential) + self.blob_service_client = self.auth.get_storage_client(self.account_url) self.pool = ThreadPoolExecutor(max_workers=256) # TODO: This might need some tuning self.max_concurrency = 1 diff --git a/skylark/obj_store/s3_interface.py b/skylark/obj_store/s3_interface.py index 45a2c8a1f..f246e13e8 100644 --- a/skylark/obj_store/s3_interface.py +++ b/skylark/obj_store/s3_interface.py @@ -8,9 +8,8 @@ from awscrt.http import HttpHeaders, HttpRequest from awscrt.io import ClientBootstrap, DefaultHostResolver, EventLoopGroup from awscrt.s3 import S3Client, S3RequestTlsMode, S3RequestType +from skylark.compute.aws.aws_auth import AWSAuthentication -from skylark.compute.aws.aws_server import AWSServer -from skylark.config import load_config from skylark.obj_store.object_store_interface import NoSuchObjectException, ObjectStoreInterface, ObjectStoreObject @@ -20,33 +19,21 @@ def full_path(self): class S3Interface(ObjectStoreInterface): - def __init__(self, aws_region, bucket_name, use_tls=True, part_size=None, throughput_target_gbps=None): - + def __init__(self, aws_region, bucket_name, use_tls=True, part_size=None, throughput_target_gbps=None, num_threads=4): + self.auth = AWSAuthentication() self.aws_region = self.infer_s3_region(bucket_name) if aws_region is None or aws_region == "infer" else aws_region self.bucket_name = bucket_name self.pending_downloads, self.completed_downloads = 0, 0 self.pending_uploads, self.completed_uploads = 0, 0 self.s3_part_size = part_size self.s3_throughput_target_gbps = throughput_target_gbps - # num_threads=os.cpu_count() - # num_threads=256 - num_threads = 4 # 256 event_loop_group = EventLoopGroup(num_threads=num_threads, cpu_group=None) host_resolver = DefaultHostResolver(event_loop_group) bootstrap = ClientBootstrap(event_loop_group, host_resolver) - # Authenticate - config = load_config() - aws_access_key_id = config["aws_access_key_id"] - aws_secret_access_key = config["aws_secret_access_key"] - if aws_access_key_id and aws_secret_access_key: - credential_provider = AwsCredentialsProvider.new_static(aws_access_key_id, aws_secret_access_key) - else: # use default - credential_provider = AwsCredentialsProvider.new_default_chain(bootstrap) - self._s3_client = S3Client( bootstrap=bootstrap, region=self.aws_region, - credential_provider=credential_provider, + credential_provider=AwsCredentialsProvider.new_default_chain(bootstrap), throughput_target_gbps=100, part_size=None, tls_mode=S3RequestTlsMode.ENABLED if use_tls else S3RequestTlsMode.DISABLED, @@ -60,18 +47,17 @@ def _on_done_upload(self, **kwargs): self.completed_uploads += 1 
self.pending_uploads -= 1 - @staticmethod - def infer_s3_region(bucket_name: str): - s3_client = AWSServer.get_boto3_client("s3") + def infer_s3_region(self, bucket_name: str): + s3_client = self.auth.get_boto3_client("s3") region = s3_client.get_bucket_location(Bucket=bucket_name).get("LocationConstraint", "us-east-1") return region if region is not None else "us-east-1" def bucket_exists(self): - s3_client = AWSServer.get_boto3_client("s3", self.aws_region) + s3_client = self.auth.get_boto3_client("s3", self.aws_region) return self.bucket_name in [b["Name"] for b in s3_client.list_buckets()["Buckets"]] def create_bucket(self): - s3_client = AWSServer.get_boto3_client("s3", self.aws_region) + s3_client = self.auth.get_boto3_client("s3", self.aws_region) if not self.bucket_exists(): if self.aws_region == "us-east-1": s3_client.create_bucket(Bucket=self.bucket_name) @@ -81,7 +67,7 @@ def create_bucket(self): def list_objects(self, prefix="") -> Iterator[S3Object]: prefix = prefix if not prefix.startswith("/") else prefix[1:] - s3_client = AWSServer.get_boto3_client("s3", self.aws_region) + s3_client = self.auth.get_boto3_client("s3", self.aws_region) paginator = s3_client.get_paginator("list_objects_v2") page_iterator = paginator.paginate(Bucket=self.bucket_name, Prefix=prefix) for page in page_iterator: @@ -89,13 +75,13 @@ def list_objects(self, prefix="") -> Iterator[S3Object]: yield S3Object("s3", self.bucket_name, obj["Key"], obj["Size"], obj["LastModified"]) def delete_objects(self, keys: List[str]): - s3_client = AWSServer.get_boto3_client("s3", self.aws_region) + s3_client = self.auth.get_boto3_client("s3", self.aws_region) while keys: batch, keys = keys[:1000], keys[1000:] # take up to 1000 keys at a time s3_client.delete_objects(Bucket=self.bucket_name, Delete={"Objects": [{"Key": k} for k in batch]}) def get_obj_metadata(self, obj_name): - s3_resource = AWSServer.get_boto3_resource("s3", self.aws_region).Bucket(self.bucket_name) + s3_resource = self.auth.get_boto3_resource("s3", self.aws_region).Bucket(self.bucket_name) try: return s3_resource.Object(str(obj_name).lstrip("/")) except botocore.exceptions.ClientError as e: diff --git a/skylark/replicate/replicator_client.py b/skylark/replicate/replicator_client.py index cbb1ed7f7..f3bd4a68a 100644 --- a/skylark/replicate/replicator_client.py +++ b/skylark/replicate/replicator_client.py @@ -9,6 +9,7 @@ import uuid import requests +from skylark.compute.aws.aws_auth import AWSAuthentication from skylark.config import SkylarkConfig from skylark.replicate.profiler import status_df_to_traceevent from skylark.utils import logger @@ -36,7 +37,6 @@ def __init__( gcp_instance_class: Optional[str] = "n2-standard-16", # set to None to disable GCP gcp_use_premium_network: bool = True, ): - config = SkylarkConfig.load() self.topology = topology self.gateway_docker_image = gateway_docker_image self.aws_instance_class = aws_instance_class @@ -45,9 +45,9 @@ def __init__( self.gcp_use_premium_network = gcp_use_premium_network # provisioning - self.aws = AWSCloudProvider() if aws_instance_class != "None" and config.aws_enabled else None - self.azure = AzureCloudProvider() if azure_instance_class != "None" and config.azure_enabled else None - self.gcp = GCPCloudProvider() if gcp_instance_class != "None" and config.gcp_enabled else None + self.aws = AWSCloudProvider() + self.azure = AzureCloudProvider() + self.gcp = GCPCloudProvider() self.bound_nodes: Dict[ReplicationTopologyGateway, Server] = {} def provision_gateways( @@ -58,9 +58,15 @@ def 
provision_gateways( azure_regions_to_provision = [r for r in regions_to_provision if r.startswith("azure:")] gcp_regions_to_provision = [r for r in regions_to_provision if r.startswith("gcp:")] - assert len(aws_regions_to_provision) == 0 or self.aws is not None, "AWS not enabled" - assert len(azure_regions_to_provision) == 0 or self.azure is not None, "Azure not enabled" - assert len(gcp_regions_to_provision) == 0 or self.gcp is not None, "GCP not enabled" + assert ( + len(aws_regions_to_provision) == 0 or self.aws.auth.enabled() + ), "AWS credentials not configured but job provisions AWS gateways" + assert ( + len(azure_regions_to_provision) == 0 or self.azure.auth.enabled() + ), "Azure credentials not configured but job provisions Azure gateways" + assert ( + len(gcp_regions_to_provision) == 0 or self.gcp.auth.enabled() + ), "GCP credentials not configured but job provisions GCP gateways" # init clouds jobs = [] @@ -78,7 +84,7 @@ def provision_gateways( # reuse existing AWS instances if reuse_instances: - if self.aws is not None: + if self.aws.auth.enabled(): aws_instance_filter = { "tags": {"skylark": "true"}, "instance_type": self.aws_instance_class, @@ -94,7 +100,7 @@ def provision_gateways( else: current_aws_instances = {} - if self.azure is not None: + if self.azure.auth.enabled(): azure_instance_filter = { "tags": {"skylark": "true"}, "instance_type": self.azure_instance_class, @@ -110,7 +116,7 @@ def provision_gateways( else: current_azure_instances = {} - if self.gcp is not None: + if self.gcp.auth.enabled(): gcp_instance_filter = { "tags": {"skylark": "true"}, "instance_type": self.gcp_instance_class, @@ -131,13 +137,13 @@ def provision_gateways( def provision_gateway_instance(region: str) -> Server: provider, subregion = region.split(":") if provider == "aws": - assert self.aws is not None + assert self.aws.auth.enabled() server = self.aws.provision_instance(subregion, self.aws_instance_class) elif provider == "azure": - assert self.azure is not None + assert self.azure.auth.enabled() server = self.azure.provision_instance(subregion, self.azure_instance_class) elif provider == "gcp": - assert self.gcp is not None + assert self.gcp.auth.enabled() # todo specify network tier in ReplicationTopology server = self.gcp.provision_instance(subregion, self.gcp_instance_class, premium_network=self.gcp_use_premium_network) else: diff --git a/skylark/test/test_replicator_client.py b/skylark/test/test_replicator_client.py index 2921a88a7..3544e40a9 100644 --- a/skylark/test/test_replicator_client.py +++ b/skylark/test/test_replicator_client.py @@ -18,8 +18,6 @@ from skylark.replicate.replication_plan import ReplicationJob, ReplicationTopology from skylark.replicate.replicator_client import ReplicatorClient -from skylark.config import load_config - def parse_args(): parser = argparse.ArgumentParser(description="Run a replication job") From 7adbdf0ff8d6c778e0af3fa6e774b2686c02debc Mon Sep 17 00:00:00 2001 From: Paras Jain Date: Thu, 10 Mar 2022 02:22:26 +0000 Subject: [PATCH 08/34] Skylark init end to end --- skylark/__init__.py | 6 ++-- skylark/cli/cli_helper.py | 45 +++++++++++++++++------------ skylark/compute/azure/azure_auth.py | 7 ++++- skylark/compute/utils.py | 35 ++++++++++++++++++++++ skylark/config.py | 21 +------------- 5 files changed, 70 insertions(+), 44 deletions(-) diff --git a/skylark/__init__.py b/skylark/__init__.py index 9292957c6..874d670f0 100644 --- a/skylark/__init__.py +++ b/skylark/__init__.py @@ -1,6 +1,7 @@ import os from pathlib import Path +from 
skylark.compute.utils import query_which_cloud from skylark.config import SkylarkConfig @@ -38,10 +39,7 @@ def print_header(): KB = 1024 MB = 1024 * 1024 GB = 1024 * 1024 * 1024 - - -# cloud config if config_path.exists(): cloud_config = SkylarkConfig.load_config(config_path) else: - cloud_config = SkylarkConfig.load_infer_cli() + cloud_config = SkylarkConfig() diff --git a/skylark/cli/cli_helper.py b/skylark/cli/cli_helper.py index 0f5002a30..4ab9b7c15 100644 --- a/skylark/cli/cli_helper.py +++ b/skylark/cli/cli_helper.py @@ -5,6 +5,7 @@ import subprocess from pathlib import Path from shutil import copyfile +import sys from typing import Dict, List, Optional import boto3 @@ -298,47 +299,53 @@ def load_azure_config(config: SkylarkConfig, force_init: bool = False) -> Skylar # check if Azure is enabled auth = AzureAuthentication() - if not auth.enabled(): - typer.secho(" Default Azure credentials are not set up yet. Run `az login` to set them up.", fg="red") + try: + auth.credential.get_token("https://management.azure.com/") + azure_enabled = True + except Exception as e: + print(e) + print(e.message) + azure_enabled = False + if not azure_enabled: + typer.secho(" No local Azure credentials! Run `az login` to set them up.", fg="red") typer.secho(" https://docs.microsoft.com/en-us/azure/developer/python/azure-sdk-authenticate", fg="red") typer.secho(" Disabling Azure support", fg="blue") return config + inferred_subscription_id = AzureAuthentication.infer_subscription_id() + if typer.confirm(" Azure credentials found, do you want to enable Azure support in Skylark?", default=True): + config.azure_subscription_id = typer.prompt(" Enter the Azure subscription ID:", default=inferred_subscription_id) else: - typer.secho(" Azure credentials found in Azure CLI", fg="blue") - inferred_subscription_id = auth.infer_subscription_id() - if inferred_subscription_id: - typer.secho(f" Inferred Azure subscription ID: {inferred_subscription_id}", fg="blue") - if typer.confirm(f" Do you want to use the inferred subscription ID `{inferred_subscription_id}`?", default=True): - config.azure_subscription_id = inferred_subscription_id - return config - if typer.confirm( - " Azure credentials are configured, but no default subscription ID was set. Do you want to set one now?", default=True - ): - config.azure_subscription_id = typer.prompt(" Enter the Azure subscription ID:") + config.azure_subscription_id = None + typer.secho(" Disabling Azure support", fg="blue") return config def load_gcp_config(config: SkylarkConfig, force_init: bool = False) -> SkylarkConfig: if force_init: typer.secho(" GCP credentials will be re-initialized", fg="red") - config.gcp_enabled = False - config.gcp_application_credentials_file = None config.gcp_project_id = None - if config.gcp_project_id is not None and config.gcp_application_credentials_file is not None: + if config.gcp_project_id is not None: typer.secho(" GCP already configured! To reconfigure GCP, run `skylark init --reinit-gcp`.", fg="blue") return config # check if GCP is enabled auth = GCPAuthentication() - if not auth.enabled(): - typer.secho(" Default GCP credentials are not set up yet. Run `gcloud auth login` to set them up.", fg="red") + try: + auth.get_gcp_client("cloudbilling", "v1").services().list().execute() + gcp_enabled = True + except Exception as e: + print(e) + print(e.message) + gcp_enabled = False + if not gcp_enabled: + typer.secho(" Default GCP credentials are not set up yet. 
Run `gcloud auth application-default login` to set them up.", fg="red") typer.secho(" https://cloud.google.com/docs/authentication/getting-started", fg="red") typer.secho(" Disabling GCP support", fg="blue") return config else: typer.secho(" GCP credentials found in GCP CLI", fg="blue") - inferred_project_id = auth.infer_project_id() + inferred_project_id = GCPAuthentication.infer_project_id() if inferred_project_id: typer.secho(f" Inferred GCP project ID: {inferred_project_id}", fg="blue") if typer.confirm(f" Do you want to use the inferred project ID `{inferred_project_id}`?", default=True): diff --git a/skylark/compute/azure/azure_auth.py b/skylark/compute/azure/azure_auth.py index a88ece351..305fc5f00 100644 --- a/skylark/compute/azure/azure_auth.py +++ b/skylark/compute/azure/azure_auth.py @@ -9,12 +9,17 @@ from azure.storage.blob import BlobServiceClient from skylark import cloud_config +from skylark.compute.utils import query_which_cloud class AzureAuthentication: def __init__(self, subscription_id: str = cloud_config.azure_subscription_id): self.subscription_id = subscription_id - self.credential = DefaultAzureCredential() + self.credential = DefaultAzureCredential( + exclude_managed_identity_credential=query_which_cloud() != "azure", # exclude MSI if not Azure + exclude_powershell_credential=True, + exclude_visual_studio_code_credential=True, + ) def enabled(self) -> bool: return self.subscription_id is not None diff --git a/skylark/compute/utils.py b/skylark/compute/utils.py index b5e80f041..6a21f1a81 100644 --- a/skylark/compute/utils.py +++ b/skylark/compute/utils.py @@ -1,6 +1,41 @@ +from functools import lru_cache +import subprocess from skylark.utils import logger +@lru_cache +def query_which_cloud() -> str: + if ( + subprocess.call( + 'curl -f --noproxy "*" http://169.254.169.254/1.0/meta-data/instance-id'.split(), + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + == 0 + ): + return "aws" + elif ( + subprocess.call( + 'curl -f -H Metadata:true --noproxy "*" "http://169.254.169.254/metadata/instance?api-version=2021-02-01"'.split(), + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + == 0 + ): + return "azure" + elif ( + subprocess.call( + 'curl -f --noproxy "*" http://metadata.google.internal/computeMetadata/v1/instance/hostname'.split(), + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + == 0 + ): + return "gcp" + else: + return "unknown" + + def make_dozzle_command(port): cmd = """sudo docker run -d --rm --name dozzle \ -p {log_viewer_port}:8080 \ diff --git a/skylark/config.py b/skylark/config.py index 885f0af51..d43ed0f34 100644 --- a/skylark/config.py +++ b/skylark/config.py @@ -2,26 +2,20 @@ import os from pathlib import Path from typing import Optional -import subprocess -from skylark.compute.azure.azure_auth import AzureAuthentication -from skylark.compute.gcp.gcp_auth import GCPAuthentication from skylark.utils import logger import configparser -from skylark import config_path - @dataclass class SkylarkConfig: - is_inferred_credential: bool = False azure_subscription_id: Optional[str] = None gcp_project_id: Optional[str] = None @staticmethod def load_config(path) -> "SkylarkConfig": """Load from a config file.""" - path = Path(config_path) + path = Path(path) config = configparser.ConfigParser() if not path.exists(): logger.error(f"Config file not found: {path}") @@ -37,24 +31,11 @@ def load_config(path) -> "SkylarkConfig": gcp_project_id = config.get("gcp", "project_id") return SkylarkConfig( - is_inferred_credential=False, - 
azure_subscription_id=azure_subscription_id, - gcp_project_id=gcp_project_id, - ) - - @staticmethod - def load_infer_cli() -> "SkylarkConfig": - """Attempt to infer configuration using cloud CLIs.""" - azure_subscription_id = AzureAuthentication.infer_subscription_id() - gcp_project_id = GCPAuthentication.infer_project_id() - return SkylarkConfig( - is_inferred_credential=True, azure_subscription_id=azure_subscription_id, gcp_project_id=gcp_project_id, ) def to_config_file(self, path): - assert not self.is_inferred_credential, "Cannot write inferred config to file" path = Path(path) config = configparser.ConfigParser() if path.exists(): From ca199978da5d64d5d1ad5668a633325d6a58c8e7 Mon Sep 17 00:00:00 2001 From: Paras Jain Date: Thu, 10 Mar 2022 05:10:56 +0000 Subject: [PATCH 09/34] Remove quotes from auto azure config --- skylark/cli/cli_helper.py | 22 ++++++++++++---------- skylark/compute/aws/aws_cloud_provider.py | 2 +- skylark/compute/azure/azure_auth.py | 2 +- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/skylark/cli/cli_helper.py b/skylark/cli/cli_helper.py index 4ab9b7c15..ebde84ffb 100644 --- a/skylark/cli/cli_helper.py +++ b/skylark/cli/cli_helper.py @@ -311,6 +311,7 @@ def load_azure_config(config: SkylarkConfig, force_init: bool = False) -> Skylar typer.secho(" https://docs.microsoft.com/en-us/azure/developer/python/azure-sdk-authenticate", fg="red") typer.secho(" Disabling Azure support", fg="blue") return config + typer.secho(" Azure credentials found in Azure CLI", fg="blue") inferred_subscription_id = AzureAuthentication.infer_subscription_id() if typer.confirm(" Azure credentials found, do you want to enable Azure support in Skylark?", default=True): config.azure_subscription_id = typer.prompt(" Enter the Azure subscription ID:", default=inferred_subscription_id) @@ -336,21 +337,22 @@ def load_gcp_config(config: SkylarkConfig, force_init: bool = False) -> SkylarkC gcp_enabled = True except Exception as e: print(e) - print(e.message) gcp_enabled = False if not gcp_enabled: - typer.secho(" Default GCP credentials are not set up yet. Run `gcloud auth application-default login` to set them up.", fg="red") + typer.secho( + " Default GCP credentials are not set up yet. Run `gcloud auth application-default login` or set GOOGLE_APPLICATION_CREDENTIALS.", + fg="red", + ) typer.secho(" https://cloud.google.com/docs/authentication/getting-started", fg="red") typer.secho(" Disabling GCP support", fg="blue") return config else: typer.secho(" GCP credentials found in GCP CLI", fg="blue") inferred_project_id = GCPAuthentication.infer_project_id() - if inferred_project_id: - typer.secho(f" Inferred GCP project ID: {inferred_project_id}", fg="blue") - if typer.confirm(f" Do you want to use the inferred project ID `{inferred_project_id}`?", default=True): - config.gcp_project_id = inferred_project_id - return config - if typer.confirm(" GCP credentials are configured, but no default project ID was set. 
Do you want to set one now?", default=True): - config.gcp_project_id = typer.prompt(" Enter the GCP project ID:") - return config + if typer.confirm(" GCP credentials found, do you want to enable GCP support in Skylark?", default=True): + config.gcp_project_id = typer.prompt(" Enter the GCP project ID:", default=inferred_project_id) + return config + else: + config.gcp_project_id = None + typer.secho(" Disabling GCP support", fg="blue") + return config diff --git a/skylark/compute/aws/aws_cloud_provider.py b/skylark/compute/aws/aws_cloud_provider.py index f3a63b824..5c91d2834 100644 --- a/skylark/compute/aws/aws_cloud_provider.py +++ b/skylark/compute/aws/aws_cloud_provider.py @@ -72,7 +72,7 @@ def get_transfer_cost(src_key, dst_key, premium_tier=True): return transfer_df.loc[src, "internet"]["cost"] def get_instance_list(self, region: str) -> List[AWSServer]: - ec2 = self.auth.get_boto3_client("ec2", region) + ec2 = self.auth.get_boto3_resource("ec2", region) valid_states = ["pending", "running", "stopped", "stopping"] instances = ec2.instances.filter(Filters=[{"Name": "instance-state-name", "Values": valid_states}]) try: diff --git a/skylark/compute/azure/azure_auth.py b/skylark/compute/azure/azure_auth.py index 305fc5f00..f1d0421e1 100644 --- a/skylark/compute/azure/azure_auth.py +++ b/skylark/compute/azure/azure_auth.py @@ -30,7 +30,7 @@ def infer_subscription_id() -> Optional[str]: return os.environ["AZURE_SUBSCRIPTION_ID"] else: try: - return subprocess.check_output(["az", "account", "show", "--query", "id"]).decode("utf-8").strip() + return subprocess.check_output(["az", "account", "show", "--query", "id"]).decode("utf-8").replace('"', "").strip() except subprocess.CalledProcessError: return None From c412f848032d148566e999d65865ec89f62306d6 Mon Sep 17 00:00:00 2001 From: Paras Jain Date: Thu, 10 Mar 2022 08:26:01 +0000 Subject: [PATCH 10/34] Role assignment --- skylark/compute/azure/azure_cloud_provider.py | 226 +++++++++--------- skylark/compute/azure/azure_server.py | 2 +- 2 files changed, 118 insertions(+), 110 deletions(-) diff --git a/skylark/compute/azure/azure_cloud_provider.py b/skylark/compute/azure/azure_cloud_provider.py index 33f283068..b4ff6f7c8 100644 --- a/skylark/compute/azure/azure_cloud_provider.py +++ b/skylark/compute/azure/azure_cloud_provider.py @@ -11,10 +11,11 @@ from skylark import key_root from skylark.compute.azure.azure_server import AzureServer from skylark.compute.cloud_providers import CloudProvider +from azure.mgmt.authorization.models import RoleAssignmentCreateParameters from azure.mgmt.compute.models import ResourceIdentityType -from skylark.utils.utils import do_parallel +from skylark.utils.utils import Timer, do_parallel class AzureCloudProvider(CloudProvider): @@ -327,124 +328,131 @@ def provision_instance( # TODO: On future requests to create resources, check if a resource # with that name already exists, and fail the operation if so - # Create a Virtual Network for this instance - poller = network_client.virtual_networks.begin_create_or_update( - resource_group, - AzureServer.vnet_name(name), - {"location": location, "tags": {"skylark": "true"}, "address_space": {"address_prefixes": ["10.0.0.0/24"]}}, - ) - poller.result() - - # Create a Network Security Group for this instance - # NOTE: This is insecure. We should fix this soon. 
- poller = network_client.network_security_groups.begin_create_or_update( - resource_group, - AzureServer.nsg_name(name), - { - "location": location, - "tags": {"skylark": "true"}, - "security_rules": [ - { - "name": name + "-allow-all", - "protocol": "Tcp", - "source_port_range": "*", - "source_address_prefix": "*", - "destination_port_range": "*", - "destination_address_prefix": "*", - "access": "Allow", - "priority": 300, - "direction": "Inbound", - } - # Azure appears to add default rules for outbound connections - ], - }, - ) - nsg_result = poller.result() - - # Create a subnet for this instance with the above Network Security Group - subnet_poller = network_client.subnets.begin_create_or_update( - resource_group, - AzureServer.vnet_name(name), - AzureServer.subnet_name(name), - {"address_prefix": "10.0.0.0/26", "network_security_group": {"id": nsg_result.id}}, - ) + with Timer("Creating Azure network"): + # Create a Virtual Network for this instance + poller = network_client.virtual_networks.begin_create_or_update( + resource_group, + AzureServer.vnet_name(name), + {"location": location, "tags": {"skylark": "true"}, "address_space": {"address_prefixes": ["10.0.0.0/24"]}}, + ) + poller.result() - # Create a public IPv4 address for this instance - ip_poller = network_client.public_ip_addresses.begin_create_or_update( - resource_group, - AzureServer.ip_name(name), - { - "location": location, - "tags": {"skylark": "true"}, - "sku": {"name": "Standard"}, - "public_ip_allocation_method": "Static", - "public_ip_address_version": "IPV4", - }, - ) + # Create a Network Security Group for this instance + # NOTE: This is insecure. We should fix this soon. + poller = network_client.network_security_groups.begin_create_or_update( + resource_group, + AzureServer.nsg_name(name), + { + "location": location, + "tags": {"skylark": "true"}, + "security_rules": [ + { + "name": name + "-allow-all", + "protocol": "Tcp", + "source_port_range": "*", + "source_address_prefix": "*", + "destination_port_range": "*", + "destination_address_prefix": "*", + "access": "Allow", + "priority": 300, + "direction": "Inbound", + } + # Azure appears to add default rules for outbound connections + ], + }, + ) + nsg_result = poller.result() - subnet_result = subnet_poller.result() - public_ip_result = ip_poller.result() - - # Create a NIC for this instance, with accelerated networking enabled - poller = network_client.network_interfaces.begin_create_or_update( - resource_group, - AzureServer.nic_name(name), - { - "location": location, - "tags": {"skylark": "true"}, - "ip_configurations": [ - {"name": name + "-ip", "subnet": {"id": subnet_result.id}, "public_ip_address": {"id": public_ip_result.id}} - ], - "enable_accelerated_networking": True, - }, - ) - nic_result = poller.result() + # Create a subnet for this instance with the above Network Security Group + subnet_poller = network_client.subnets.begin_create_or_update( + resource_group, + AzureServer.vnet_name(name), + AzureServer.subnet_name(name), + {"address_prefix": "10.0.0.0/26", "network_security_group": {"id": nsg_result.id}}, + ) - # Create the VM - with self.provisioning_semaphore: - poller = compute_client.virtual_machines.begin_create_or_update( + # Create a public IPv4 address for this instance + ip_poller = network_client.public_ip_addresses.begin_create_or_update( + resource_group, + AzureServer.ip_name(name), + { + "location": location, + "tags": {"skylark": "true"}, + "sku": {"name": "Standard"}, + "public_ip_allocation_method": "Static", + 
"public_ip_address_version": "IPV4", + }, + ) + + subnet_result = subnet_poller.result() + public_ip_result = ip_poller.result() + + with Timer("Creating Azure NIC"): + # Create a NIC for this instance, with accelerated networking enabled + poller = network_client.network_interfaces.begin_create_or_update( resource_group, - AzureServer.vm_name(name), + AzureServer.nic_name(name), { "location": location, "tags": {"skylark": "true"}, - "hardware_profile": {"vm_size": self.lookup_valid_instance(location, vm_size)}, - "storage_profile": { - "image_reference": { - "publisher": "canonical", - "offer": "0001-com-ubuntu-server-focal", - "sku": "20_04-lts", - "version": "latest", + "ip_configurations": [ + {"name": name + "-ip", "subnet": {"id": subnet_result.id}, "public_ip_address": {"id": public_ip_result.id}} + ], + "enable_accelerated_networking": True, + }, + ) + nic_result = poller.result() + + # Create the VM + with Timer("Creating Azure VM"): + with self.provisioning_semaphore: + poller = compute_client.virtual_machines.begin_create_or_update( + resource_group, + AzureServer.vm_name(name), + { + "location": location, + "tags": {"skylark": "true"}, + "hardware_profile": {"vm_size": self.lookup_valid_instance(location, vm_size)}, + "storage_profile": { + "image_reference": { + "publisher": "canonical", + "offer": "0001-com-ubuntu-server-focal", + "sku": "20_04-lts", + "version": "latest", + }, + "os_disk": {"create_option": "FromImage", "delete_option": "Delete"}, }, - "os_disk": {"create_option": "FromImage", "delete_option": "Delete"}, - }, - "os_profile": { - "computer_name": AzureServer.vm_name(name), - "admin_username": uname, - "linux_configuration": { - "disable_password_authentication": True, - "ssh": {"public_keys": [{"path": f"/home/{uname}/.ssh/authorized_keys", "key_data": pub_key}]}, + "os_profile": { + "computer_name": AzureServer.vm_name(name), + "admin_username": uname, + "linux_configuration": { + "disable_password_authentication": True, + "ssh": {"public_keys": [{"path": f"/home/{uname}/.ssh/authorized_keys", "key_data": pub_key}]}, + }, }, + "network_profile": {"network_interfaces": [{"id": nic_result.id}]}, + # give VM managed identity w/ system assigned identity + "identity": {"type": ResourceIdentityType.system_assigned}, }, - "network_profile": {"network_interfaces": [{"id": nic_result.id}]}, - # give VM managed identity w/ system assigned identity - "identity": {"type": ResourceIdentityType.system_assigned}, - }, + ) + vm_result = poller.result() + + with Timer("Role assignment"): + # Assign roles to system MSI, see https://docs.microsoft.com/en-us/samples/azure-samples/compute-python-msi-vm/compute-python-msi-vm/#role-assignment + # todo only grant storage-blob-data-reader and storage-blob-data-writer for specified buckets + auth_client = self.auth.get_authorization_client() + scope = f"/subscriptions/{self.auth.subscription_id}" + role_name = "Contributor" + roles = list(auth_client.role_definitions.list(scope, filter="roleName eq '{}'".format(role_name))) + assert len(roles) == 1 + + # Add RG scope to the MSI identities: + role_assignment = auth_client.role_assignments.create( + scope, + uuid.uuid4(), # Role assignment random name + RoleAssignmentCreateParameters( + properties=dict(role_definition_id=roles[0].id, principal_id=vm_result.identity.principal_id) + ), ) - vm_result = poller.result() - - # Assign roles to system MSI, see https://docs.microsoft.com/en-us/samples/azure-samples/compute-python-msi-vm/compute-python-msi-vm/#role-assignment - # todo only grant 
storage-blob-data-reader and storage-blob-data-writer for specified buckets - auth_client = self.auth.get_authorization_client() - role_name = "Contributor" - roles = list(auth_client.role_definitions.list(resource_group, filter="roleName eq '{}'".format(role_name))) - assert len(roles) == 1 - - # Add RG scope to the MSI identities: - role_assignment = auth_client.role_assignments.create( - resource_group, - uuid.uuid4(), # Role assignment random name - {"role_definition_id": roles[0].id, "principal_id": vm_result.identity.principal_id}, - ) return AzureServer(name) diff --git a/skylark/compute/azure/azure_server.py b/skylark/compute/azure/azure_server.py index ceb618e7d..72fefe200 100644 --- a/skylark/compute/azure/azure_server.py +++ b/skylark/compute/azure/azure_server.py @@ -117,7 +117,7 @@ def is_valid(self): return False def uuid(self): - return f"{self.subscription_id}:{self.region_tag}:{self.name}" + return f"{self.region_tag}:{self.name}" def instance_state(self) -> ServerState: compute_client = self.auth.get_compute_client() From c4107c9aed926120e28de3fb32eb2e6e1ccc9825 Mon Sep 17 00:00:00 2001 From: Paras Jain Date: Thu, 10 Mar 2022 20:53:38 +0000 Subject: [PATCH 11/34] ?~ --- skylark/cli/cli.py | 8 ++++---- skylark/compute/aws/aws_server.py | 2 +- skylark/compute/azure/azure_cloud_provider.py | 12 ++++++++--- skylark/compute/azure/azure_server.py | 20 ++++++++----------- skylark/compute/gcp/gcp_server.py | 2 +- skylark/compute/server.py | 7 ++++--- skylark/replicate/replicator_client.py | 6 ++++-- 7 files changed, 31 insertions(+), 26 deletions(-) diff --git a/skylark/cli/cli.py b/skylark/cli/cli.py index 007d1c86c..fe48b40ed 100644 --- a/skylark/cli/cli.py +++ b/skylark/cli/cli.py @@ -113,7 +113,7 @@ def replicate_random( reuse_gateways: bool = False, gateway_docker_image: str = os.environ.get("SKYLARK_DOCKER_IMAGE", "ghcr.io/parasj/skylark:main"), aws_instance_class: str = "m5.8xlarge", - azure_instance_class: str = "Standard_D32_v5", + azure_instance_class: str = "Standard_D32_v4", gcp_instance_class: Optional[str] = "n2-standard-32", gcp_use_premium_network: bool = True, key_prefix: str = "/test/replicate_random", @@ -146,7 +146,7 @@ def replicate_random( ) if not reuse_gateways: - atexit.register(rc.deprovision_gateways) + atexit.register(rc.deprovision_gateways, block=False) else: logger.warning( f"Instances will remain up and may result in continued cloud billing. Remember to call `skylark deprovision` to deprovision gateways." @@ -197,7 +197,7 @@ def replicate_json( gateway_docker_image: str = os.environ.get("SKYLARK_DOCKER_IMAGE", "ghcr.io/parasj/skylark:main"), # cloud provider specific options aws_instance_class: str = "m5.8xlarge", - azure_instance_class: str = "Standard_D32_v5", + azure_instance_class: str = "Standard_D32_v4", gcp_instance_class: Optional[str] = "n2-standard-32", gcp_use_premium_network: bool = True, # logging options @@ -221,7 +221,7 @@ def replicate_json( ) if not reuse_gateways: - atexit.register(rc.deprovision_gateways) + atexit.register(rc.deprovision_gateways, block=False) else: logger.warning( f"Instances will remain up and may result in continued cloud billing. Remember to call `skylark deprovision` to deprovision gateways." 
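(A condensed usage sketch of the non-blocking teardown wired up above. Here `topology`, `gateway_docker_image`, and `reuse_gateways` stand in for the CLI options, and the surrounding code is illustrative rather than the literal body of `replicate_random` / `replicate_json`.)

    import atexit
    from skylark.replicate.replicator_client import ReplicatorClient

    rc = ReplicatorClient(topology=topology, gateway_docker_image=gateway_docker_image)
    if not reuse_gateways:
        # register cleanup at interpreter exit; block=False asks deprovision_gateways
        # to issue the terminate requests without waiting on the delete operations
        atexit.register(rc.deprovision_gateways, block=False)
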
diff --git a/skylark/compute/aws/aws_server.py b/skylark/compute/aws/aws_server.py index 797459633..720f0dbca 100644 --- a/skylark/compute/aws/aws_server.py +++ b/skylark/compute/aws/aws_server.py @@ -92,7 +92,7 @@ def instance_state(self): def __repr__(self): return f"AWSServer(region_tag={self.region_tag}, instance_id={self.instance_id})" - def terminate_instance_impl(self): + def terminate_instance_impl(self, block=True): self.auth.get_boto3_resource("ec2", self.aws_region).instances.filter(InstanceIds=[self.instance_id]).terminate() def get_ssh_client_impl(self): diff --git a/skylark/compute/azure/azure_cloud_provider.py b/skylark/compute/azure/azure_cloud_provider.py index b4ff6f7c8..83557b07e 100644 --- a/skylark/compute/azure/azure_cloud_provider.py +++ b/skylark/compute/azure/azure_cloud_provider.py @@ -414,10 +414,16 @@ def provision_instance( "tags": {"skylark": "true"}, "hardware_profile": {"vm_size": self.lookup_valid_instance(location, vm_size)}, "storage_profile": { + # "image_reference": { + # "publisher": "canonical", + # "offer": "0001-com-ubuntu-server-focal", + # "sku": "20_04-lts", + # "version": "latest", + # }, "image_reference": { - "publisher": "canonical", - "offer": "0001-com-ubuntu-server-focal", - "sku": "20_04-lts", + "publisher": "microsoft-aks", + "offer": "aks", + "sku": "aks-engine-ubuntu-1804-202112", "version": "latest", }, "os_disk": {"create_option": "FromImage", "delete_option": "Delete"}, diff --git a/skylark/compute/azure/azure_server.py b/skylark/compute/azure/azure_server.py index 72fefe200..f6c80cc9e 100644 --- a/skylark/compute/azure/azure_server.py +++ b/skylark/compute/azure/azure_server.py @@ -157,28 +157,24 @@ def tags(self): def network_tier(self): return "PREMIUM" - def terminate_instance_impl(self): + def terminate_instance_impl(self, block=True): compute_client = self.auth.get_compute_client() network_client = self.auth.get_network_client() - vm_poller = compute_client.virtual_machines.begin_delete(AzureServer.resource_group_name, self.vm_name(self.name)) - _ = vm_poller.result() - + vm_poller.result() nic_poller = network_client.network_interfaces.begin_delete(AzureServer.resource_group_name, self.nic_name(self.name)) - _ = nic_poller.result() - ip_poller = network_client.public_ip_addresses.begin_delete(AzureServer.resource_group_name, self.ip_name(self.name)) subnet_poller = network_client.subnets.begin_delete( AzureServer.resource_group_name, self.vnet_name(self.name), self.subnet_name(self.name) ) - _ = ip_poller.result() - _ = subnet_poller.result() - nsg_poller = network_client.network_security_groups.begin_delete(AzureServer.resource_group_name, self.nsg_name(self.name)) - _ = nsg_poller.result() - vnet_poller = network_client.virtual_networks.begin_delete(AzureServer.resource_group_name, self.vnet_name(self.name)) - _ = vnet_poller.result() + if block: + nsg_poller.result() + ip_poller.result() + subnet_poller.result() + nic_poller.result() + vnet_poller.result() def get_ssh_client_impl(self, uname=os.environ.get("USER"), ssh_key_password="skylark"): """Return paramiko client that connects to this instance.""" diff --git a/skylark/compute/gcp/gcp_server.py b/skylark/compute/gcp/gcp_server.py index e5ad6d2b9..6ead951e1 100644 --- a/skylark/compute/gcp/gcp_server.py +++ b/skylark/compute/gcp/gcp_server.py @@ -82,7 +82,7 @@ def network_tier(self): def __repr__(self): return f"GCPServer(region_tag={self.region_tag}, instance_name={self.gcp_instance_name})" - def terminate_instance_impl(self): + def terminate_instance_impl(self, 
block=True): self.auth.get_gcp_client().instances().delete( project=self.auth.project_id, zone=self.gcp_region, instance=self.instance_name() ).execute() diff --git a/skylark/compute/server.py b/skylark/compute/server.py index f86ce026b..016c6af4f 100644 --- a/skylark/compute/server.py +++ b/skylark/compute/server.py @@ -127,10 +127,10 @@ def tags(self): def network_tier(self): raise NotImplementedError() - def terminate_instance_impl(self): + def terminate_instance_impl(self, block=True): raise NotImplementedError() - def terminate_instance(self): + def terminate_instance(self, block=False): """Terminate instance""" self.close_server() self.terminate_instance_impl() @@ -204,7 +204,8 @@ def check_stderr(tup): # increase TCP connections, enable BBR optionally and raise file limits check_stderr(self.run_command(make_sysctl_tcp_tuning_command(cc="bbr" if use_bbr else "cubic"))) - retry_backoff(self.install_docker, exception_class=RuntimeError) + with Timer("Install docker"): + retry_backoff(self.install_docker, exception_class=RuntimeError) self.run_command(make_dozzle_command(log_viewer_port)) # read AWS config file to get credentials diff --git a/skylark/replicate/replicator_client.py b/skylark/replicate/replicator_client.py index f3bd4a68a..7c0d0ef82 100644 --- a/skylark/replicate/replicator_client.py +++ b/skylark/replicate/replicator_client.py @@ -191,12 +191,13 @@ def setup(server: Server): args.append((server, {self.bound_nodes[n].public_ip(): v for n, v in self.topology.get_outgoing_paths(node).items()})) do_parallel(lambda arg: arg[0].start_gateway(arg[1], gateway_docker_image=self.gateway_docker_image), args, n=-1) - def deprovision_gateways(self): + def deprovision_gateways(self, block=True): def deprovision_gateway_instance(server: Server): if server.instance_state() == ServerState.RUNNING: logger.warning(f"Deprovisioning {server.uuid()}") - server.terminate_instance() + server.terminate_instance(block=block) + logger.warning("Deprovisioning instances") do_parallel(deprovision_gateway_instance, self.bound_nodes.values(), n=-1) def run_replication_plan(self, job: ReplicationJob) -> ReplicationJob: @@ -355,6 +356,7 @@ def fn(s: Server): do_parallel(fn, self.bound_nodes.values(), n=-1) if cancel_pending: + logger.debug("Registering shutdown handler") atexit.register(shutdown_handler) with Timer() as t: From 94832a26765e5c19b0ae0c73087f00cc2a9d017f Mon Sep 17 00:00:00 2001 From: Paras Jain Date: Mon, 14 Mar 2022 23:16:46 +0000 Subject: [PATCH 12/34] GCP auth --- scripts/requirements-gateway.txt | 1 + setup.py | 1 + skylark/cli/cli_helper.py | 14 ++++---------- skylark/compute/gcp/gcp_auth.py | 20 +++++--------------- 4 files changed, 11 insertions(+), 25 deletions(-) diff --git a/scripts/requirements-gateway.txt b/scripts/requirements-gateway.txt index c5d5c1660..647c28580 100644 --- a/scripts/requirements-gateway.txt +++ b/scripts/requirements-gateway.txt @@ -9,6 +9,7 @@ boto3 click>=7.1.2 flask google-api-python-client +google-auth google-cloud-compute google-cloud-storage grpcio-status>=1.33.2 diff --git a/setup.py b/setup.py index 2a470fbf4..ad03a7789 100644 --- a/setup.py +++ b/setup.py @@ -18,6 +18,7 @@ "click>=7.1.2", "flask", "google-api-python-client", + "google-auth", "google-cloud-compute", "google-cloud-storage", "grpcio-status>=1.33.2", diff --git a/skylark/cli/cli_helper.py b/skylark/cli/cli_helper.py index ebde84ffb..a11c835ef 100644 --- a/skylark/cli/cli_helper.py +++ b/skylark/cli/cli_helper.py @@ -332,15 +332,9 @@ def load_gcp_config(config: SkylarkConfig, 
force_init: bool = False) -> SkylarkC # check if GCP is enabled auth = GCPAuthentication() - try: - auth.get_gcp_client("cloudbilling", "v1").services().list().execute() - gcp_enabled = True - except Exception as e: - print(e) - gcp_enabled = False - if not gcp_enabled: + if not auth.credentials: typer.secho( - " Default GCP credentials are not set up yet. Run `gcloud auth application-default login` or set GOOGLE_APPLICATION_CREDENTIALS.", + " Default GCP credentials are not set up yet. Run `gcloud auth application-default login`.", fg="red", ) typer.secho(" https://cloud.google.com/docs/authentication/getting-started", fg="red") @@ -348,9 +342,9 @@ def load_gcp_config(config: SkylarkConfig, force_init: bool = False) -> SkylarkC return config else: typer.secho(" GCP credentials found in GCP CLI", fg="blue") - inferred_project_id = GCPAuthentication.infer_project_id() if typer.confirm(" GCP credentials found, do you want to enable GCP support in Skylark?", default=True): - config.gcp_project_id = typer.prompt(" Enter the GCP project ID:", default=inferred_project_id) + config.gcp_project_id = typer.prompt(" Enter the GCP project ID:", default=auth.project_id) + assert config.gcp_project_id is not None, "GCP project ID must not be None" return config else: config.gcp_project_id = None diff --git a/skylark/compute/gcp/gcp_auth.py b/skylark/compute/gcp/gcp_auth.py index 81740dec7..1dc582e60 100644 --- a/skylark/compute/gcp/gcp_auth.py +++ b/skylark/compute/gcp/gcp_auth.py @@ -1,26 +1,16 @@ -from functools import lru_cache -import os -import subprocess +from typing import Optional import googleapiclient.discovery +import google.auth from skylark import cloud_config class GCPAuthentication: - def __init__(self, project_id: str = cloud_config.gcp_project_id): - self.project_id = project_id + def __init__(self, project_id: Optional[str] = cloud_config.gcp_project_id): + self.credentials, self.project_id = google.auth.default(quota_project_id=project_id) def enabled(self): - return self.project_id is not None - - @staticmethod - def infer_project_id(): - if "GOOGLE_PROJECT_ID" in os.environ: - return os.environ["GOOGLE_PROJECT_ID"] - try: - return subprocess.check_output(["gcloud", "config", "get-value", "project"]).decode("utf-8").strip() - except subprocess.CalledProcessError: - return None + return self.credentials is not None and self.project_id is not None def get_gcp_client(self, service_name="compute", version="v1"): return googleapiclient.discovery.build(service_name, version) From cc1a6a48652a5ac9dfc0c39e685d37b37c8a5784 Mon Sep 17 00:00:00 2001 From: Paras Jain Date: Tue, 15 Mar 2022 00:17:01 +0000 Subject: [PATCH 13/34] Fix service quota query --- skylark/cli/cli_aws.py | 8 +++++++- skylark/utils/logger.py | 7 ++++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/skylark/cli/cli_aws.py b/skylark/cli/cli_aws.py index aab6cd002..671e14a82 100644 --- a/skylark/cli/cli_aws.py +++ b/skylark/cli/cli_aws.py @@ -17,6 +17,7 @@ from skylark.compute.aws.aws_server import AWSServer from skylark.obj_store.s3_interface import S3Interface from skylark.utils.utils import Timer, do_parallel +from skylark.utils import logger app = typer.Typer(name="skylark-aws") @@ -28,7 +29,12 @@ def vcpu_limits(quota_code="L-1216C47A"): def get_service_quota(region): service_quotas = aws_auth.get_boto3_client("service-quotas", region) - response = service_quotas.get_service_quota(ServiceCode="ec2", QuotaCode=quota_code) + try: + response = service_quotas.get_service_quota(ServiceCode="ec2", 
QuotaCode=quota_code)
+        except Exception as e:
+            logger.exception(e, print_traceback=False)
+            logger.error(f"Failed to get service quota for {quota_code} in {region}")
+            return -1
         return response["Quota"]["Value"]

     quotas = do_parallel(get_service_quota, AWSCloudProvider.region_list())
diff --git a/skylark/utils/logger.py b/skylark/utils/logger.py
index b4bcbebf0..38f0dd8b1 100644
--- a/skylark/utils/logger.py
+++ b/skylark/utils/logger.py
@@ -20,8 +20,9 @@ def log(msg, LEVEL="INFO", color="white", *args, **kwargs):
 error = partial(log, LEVEL="ERROR", color="red")


-def exception(msg, *args, **kwargs):
+def exception(msg, print_traceback=True, *args, **kwargs):
     error(f"Exception: {msg}", *args, **kwargs)
-    import traceback
+    if print_traceback:
+        import traceback

-    traceback.print_exc()
+        traceback.print_exc()

From 84bf734a032b3c061130d6a59ffd472656124577 Mon Sep 17 00:00:00 2001
From: Paras Jain
Date: Tue, 15 Mar 2022 01:07:14 +0000
Subject: [PATCH 14/34] experiment paras

---
 scripts/experiment_paras.sh | 55 +++++++++++++++++++++++++++++++++++++
 scripts/setup_bucket.py     | 29 -------------------
 skylark/cli/cli_helper.py   |  4 +--
 3 files changed, 56 insertions(+), 32 deletions(-)
 create mode 100755 scripts/experiment_paras.sh

diff --git a/scripts/experiment_paras.sh b/scripts/experiment_paras.sh
new file mode 100755
index 000000000..81e34bf41
--- /dev/null
+++ b/scripts/experiment_paras.sh
@@ -0,0 +1,55 @@
+#!/bin/bash
+set -xe
+
+src=$1
+dest=$2
+
+key_prefix="fake_imagenet"
+bucket_prefix="exps-paras"
+src_bucket=(${src//:/ })
+src_bucket=${bucket_prefix}-skylark-${src_bucket[1]}
+dest_bucket=(${dest//:/ })
+dest_bucket=${bucket_prefix}-skylark-${dest_bucket[1]}
+echo $src_bucket
+echo $dest_bucket
+max_instance=1
+experiment=${src//[:]/-}_${dest//[:]/-}_${max_instance}_${key_prefix//[\/]/-}
+filename=data/plan/${experiment}.json
+echo $filename
+
+# creates buckets + bucket data and sets env variables
+# python scripts/setup_bucket.py --key-prefix ${key_prefix} --bucket-prefix ${bucket_prefix} --src-data-path ../${key_prefix}/ --src-region ${src} --dest-region ${dest}
+
+
+# TODO: artificially increase the number of chunks
+# TODO: try synthetic data
+
+source scripts/pack_docker.sh;
+
+## create plan
+throughput=$(($max_instance*3))
+# throughput=25
+skylark solver solve-throughput ${src} ${dest} ${throughput} -o ${filename} --max-instances ${max_instance};
+echo ${filename}
+
+# make exp directory
+mkdir -p data/results
+mkdir -p data/results/${experiment}
+
+# save copy of plan
+cp ${filename} data/results/${experiment}
+
+## run replication (random)
+#skylark replicate-json ${filename} \
+#    --use-random-data \
+#    --size-total-mb 73728 \
+#    --n-chunks 1152 &> data/results/${experiment}/random-logs.txt
+#tail -1 data/results/${experiment}/random-logs.txt;
+
+# run replication (obj store)
+skylark replicate-json ${filename} \
+    --source-bucket $src_bucket \
+    --dest-bucket $dest_bucket \
+    --key-prefix ${key_prefix} > data/results/${experiment}/obj-store-logs.txt
+tail -1 data/results/${experiment}/obj-store-logs.txt;
+echo ${experiment}
diff --git a/scripts/setup_bucket.py b/scripts/setup_bucket.py
index 7f4be455a..a9c6ffe42 100644
--- a/scripts/setup_bucket.py
+++ b/scripts/setup_bucket.py
@@ -14,21 +14,11 @@ def parse_args():
     parser = argparse.ArgumentParser(description="Setup replication experiment")
-    parser.add_argument("--src-data-path", default="../fake_imagenet", help="Data to upload to src bucket")
-
-    # gateway path parameters
     parser.add_argument("--src-region",
default="aws:us-east-1", help="AWS region of source bucket") parser.add_argument("--dest-region", default="aws:us-west-1", help="AWS region of destination bucket") - - # bucket namespace parser.add_argument("--bucket-prefix", default="sarah", help="Prefix for bucket to avoid naming collision") parser.add_argument("--key-prefix", default="", help="Prefix keys") - - # gateway provisioning - parser.add_argument("--gcp-project", default=None, help="GCP project ID") - parser.add_argument("--azure-subscription", default=None, help="Azure subscription") - parser.add_argument("--gateway-docker-image", default="ghcr.io/parasj/skylark:main", help="Docker image for gateway instances") args = parser.parse_args() return args @@ -56,13 +46,6 @@ def main(args): obj_store_interface_dst.create_bucket() print("running upload... (note: may need to chunk)") - - ## TODO: chunkify - # p = Pool(16) - # uploaded = p.starmap(upload, [(args.src_region, src_bucket, os.path.join(args.src_data_path, f), f"{args.key_prefix}/{f}") for f in os.listdir(args.src_data_path)]) - # p.close() - # print(f"uploaded {sum(uploaded)} files to {src_bucket}") - futures = [] for f in tqdm(os.listdir(args.src_data_path)): futures.append(obj_store_interface_src.upload_object(os.path.join(args.src_data_path, f), f"{args.key_prefix}/{f}")) @@ -71,18 +54,6 @@ def main(args): wait(futures) futures = [] - ### check files - # for f in tqdm(os.listdir(args.src_data_path)): - # assert obj_store_interface_src.exists(f"{args.key_prefix}/{f}") - - def done_uploading(): - bucket_size = len(list(obj_store_interface_src.list_objects(prefix=args.key_prefix))) - # f"Length mismatch {len(os.listdir(args.src_data_path))}, {bucket_size}" - print("bucket", bucket_size, len(os.listdir(args.src_data_path))) - return len(os.listdir(args.src_data_path)) == bucket_size - - ##wait_for(done_uploading, timeout=60, interval=0.1, desc=f"Waiting for files to upload") - if __name__ == "__main__": main(parse_args()) diff --git a/skylark/cli/cli_helper.py b/skylark/cli/cli_helper.py index a11c835ef..46bdd199b 100644 --- a/skylark/cli/cli_helper.py +++ b/skylark/cli/cli_helper.py @@ -302,9 +302,7 @@ def load_azure_config(config: SkylarkConfig, force_init: bool = False) -> Skylar try: auth.credential.get_token("https://management.azure.com/") azure_enabled = True - except Exception as e: - print(e) - print(e.message) + except: azure_enabled = False if not azure_enabled: typer.secho(" No local Azure credentials! 
Run `az login` to set them up.", fg="red") From 62cb42eb8db8e97665bf6a2b620eebb4bff1370a Mon Sep 17 00:00:00 2001 From: Shishir Patil Date: Tue, 15 Mar 2022 01:45:24 +0000 Subject: [PATCH 15/34] Catch resource group not available --- skylark/obj_store/azure_interface.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/skylark/obj_store/azure_interface.py b/skylark/obj_store/azure_interface.py index e9b164dcd..270055b1a 100644 --- a/skylark/obj_store/azure_interface.py +++ b/skylark/obj_store/azure_interface.py @@ -1,7 +1,7 @@ import os from concurrent.futures import Future, ThreadPoolExecutor from typing import Iterator, List -from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError +from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError, ServiceRequestError from skylark.compute.azure.azure_auth import AzureAuthentication from skylark.utils import logger from skylark.obj_store.object_store_interface import NoSuchObjectException, ObjectStoreInterface, ObjectStoreObject @@ -23,6 +23,7 @@ def __init__(self, azure_region, container_name): # Create a blob service client self.auth = AzureAuthentication() self.account_url = "https://{}.blob.core.windows.net".format("skylark" + self.azure_region) + print("===> Account URL:", self.account_url) self.blob_service_client = self.auth.get_storage_client(self.account_url) self.pool = ThreadPoolExecutor(max_workers=256) # TODO: This might need some tuning @@ -40,7 +41,8 @@ def _on_done_upload(self, **kwargs): self.completed_uploads += 1 self.pending_uploads -= 1 - def container_exists(self): # More like "is container empty?" + def container_exists(self): + # More like "is container empty?" # Get a client to interact with a specific container - though it may not yet exist if self.container_client is None: self.container_client = self.blob_service_client.get_container_client(self.container_name) @@ -49,6 +51,10 @@ def container_exists(self): # More like "is container empty?" return True except ResourceNotFoundError: return False + except ServiceRequestError: + logger.error("==> Unable to access storage account for region specified") + logger.error("==> Aborting. 
Please check your Azure credentials and region") + exit(-1) def create_container(self): try: From be4bfc2afe922a6a513a8c8c11294435f7a046c9 Mon Sep 17 00:00:00 2001 From: Paras Jain Date: Tue, 15 Mar 2022 20:50:48 +0000 Subject: [PATCH 16/34] Merge branch 'main' into dev/paras/better_init_config --- skylark/cli/cli.py | 4 +- skylark/cli/cli_helper.py | 5 ++- skylark/compute/aws/aws_cloud_provider.py | 51 ++++++++++++++++++++++- skylark/compute/cloud_providers.py | 7 ++-- skylark/replicate/replicator_client.py | 15 ++++--- skylark/test/test_replicator_client.py | 2 +- skylark/utils/utils.py | 12 ++++-- 7 files changed, 74 insertions(+), 22 deletions(-) diff --git a/skylark/cli/cli.py b/skylark/cli/cli.py index fe48b40ed..d325805fa 100644 --- a/skylark/cli/cli.py +++ b/skylark/cli/cli.py @@ -265,9 +265,7 @@ def replicate_json( total_bytes = sum([chunk_req.chunk.chunk_length_bytes for chunk_req in job.chunk_requests]) logger.info(f"{total_bytes / GB:.2f}GByte replication job launched") - stats = rc.monitor_transfer( - job, show_pbar=True, log_interval_s=log_interval_s, time_limit_seconds=time_limit_seconds, cancel_pending=False - ) + stats = rc.monitor_transfer(job, show_pbar=True, log_interval_s=log_interval_s, time_limit_seconds=time_limit_seconds) stats["success"] = stats["monitor_status"] == "completed" out_json = {k: v for k, v in stats.items() if k not in ["log", "completed_chunk_ids"]} diff --git a/skylark/cli/cli_helper.py b/skylark/cli/cli_helper.py index 46bdd199b..92c2a7aef 100644 --- a/skylark/cli/cli_helper.py +++ b/skylark/cli/cli_helper.py @@ -1,4 +1,5 @@ import concurrent.futures +from functools import partial import os import re import resource @@ -254,7 +255,7 @@ def deprovision_skylark_instances(): logger.debug("AWS authentication enabled, querying for instances") aws = AWSCloudProvider() for region in aws.region_list(): - query_jobs.append(lambda: aws.get_matching_instances(region)) + query_jobs.append(partial(aws.get_matching_instances, region)) if AzureAuthentication().enabled(): logger.debug("Azure authentication enabled, querying for instances") query_jobs.append(lambda: AzureCloudProvider().get_matching_instances()) @@ -263,7 +264,7 @@ def deprovision_skylark_instances(): query_jobs.append(lambda: GCPCloudProvider().get_matching_instances()) # query in parallel - for _, instance_list in do_parallel(lambda f: f(), query_jobs, progress_bar=True, desc="Query instances", arg_fmt=None): + for _, instance_list in do_parallel(lambda f: f(), query_jobs, progress_bar=True, desc="Query instances", hide_args=True): instances.extend(instance_list) if instances: diff --git a/skylark/compute/aws/aws_cloud_provider.py b/skylark/compute/aws/aws_cloud_provider.py index 5c91d2834..57d39dc5d 100644 --- a/skylark/compute/aws/aws_cloud_provider.py +++ b/skylark/compute/aws/aws_cloud_provider.py @@ -1,3 +1,4 @@ +import json import uuid from typing import List, Optional @@ -10,7 +11,7 @@ from skylark import skylark_root from skylark.compute.aws.aws_server import AWSServer from skylark.compute.cloud_providers import CloudProvider -from skylark.utils.utils import retry_backoff +from skylark.utils.utils import retry_backoff, wait_for class AWSCloudProvider(CloudProvider): @@ -194,6 +195,28 @@ def delete_vpc(self, region: str, vpcid: str): # finally, delete the vpc ec2client.delete_vpc(VpcId=vpcid) + def create_iam(self, iam_name: str = "skylark_gateway", attach_policy_arn: Optional[str] = None): + """Create IAM role if it doesn't exist and grant managed role if given.""" + + 
@lockutils.synchronized(f"aws_create_iam_{iam_name}", external=True, lock_path="/tmp/skylark_locks") + def fn(): + iam = self.auth.get_boto3_client("iam") + + # create IAM role + try: + iam.get_role(RoleName=iam_name) + except iam.exceptions.NoSuchEntityException: + doc = { + "Version": "2012-10-17", + "Statement": [{"Effect": "Allow", "Principal": {"Service": "ec2.amazonaws.com"}, "Action": "sts:AssumeRole"}], + } + iam.create_role(RoleName=iam_name, AssumeRolePolicyDocument=json.dumps(doc), Tags=[{"Key": "skylark", "Value": "true"}]) + if attach_policy_arn: + iam.attach_role_policy(RoleName=iam_name, PolicyArn=attach_policy_arn) + + + return fn() + def add_ip_to_security_group(self, aws_region: str): """Add IP to security group. If security group ID is None, use group named skylark (create if not exists).""" @@ -219,18 +242,40 @@ def provision_instance( # ami_id: Optional[str] = None, tags={"skylark": "true"}, ebs_volume_size: int = 128, + iam_name: str = "skylark_gateway", ) -> AWSServer: assert not region.startswith("aws:"), "Region should be AWS region" if name is None: name = f"skylark-aws-{str(uuid.uuid4()).replace('-', '')}" + iam_instance_profile_name = f"{name}_profile" + iam = self.auth.get_boto3_client("iam") ec2 = self.auth.get_boto3_resource("ec2", region) vpc = self.get_vpc(region) assert vpc is not None, "No VPC found" subnets = list(vpc.subnets.all()) assert len(subnets) > 0, "No subnets found" + + def check_iam_role(): + try: + iam.get_role(RoleName=iam_name) + return True + except iam.exceptions.NoSuchEntityException: + return False + + def check_instance_profile(): + try: + iam.get_instance_profile(InstanceProfileName=iam_instance_profile_name) + return True + except iam.exceptions.NoSuchEntityException: + return False + + # wait for iam_role to be created and create instance profile + wait_for(check_iam_role, timeout=60, interval=.5) + iam.create_instance_profile(InstanceProfileName=iam_instance_profile_name, Tags=[{"Key": "skylark", "Value": "true"}]) + iam.add_role_to_instance_profile(InstanceProfileName=iam_instance_profile_name, RoleName=iam_name) + wait_for(check_instance_profile, timeout=60, interval=.5) def start_instance(): - # todo instance-initiated-shutdown-behavior terminate return ec2.create_instances( ImageId="resolve:ssm:/aws/service/ecs/optimized-ami/amazon-linux-2/recommended/image_id", InstanceType=instance_class, @@ -258,6 +303,8 @@ def start_instance(): "DeleteOnTermination": True, } ], + IamInstanceProfile={"Name": iam_instance_profile_name}, + InstanceInitiatedShutdownBehavior='terminate', ) instance = retry_backoff(start_instance, initial_backoff=1) diff --git a/skylark/compute/cloud_providers.py b/skylark/compute/cloud_providers.py index b37d04cc2..69e7d8402 100644 --- a/skylark/compute/cloud_providers.py +++ b/skylark/compute/cloud_providers.py @@ -44,13 +44,12 @@ def get_matching_instances( tags={"skylark": "true"}, ) -> List[Server]: if isinstance(region, str): - region = [region] + results = [(region, self.get_instance_list(region))] elif region is None: - region = self.region_list() + results = do_parallel(self.get_instance_list, self.region_list(), n=-1) - results = do_parallel(self.get_instance_list, region, n=-1) matching_instances = [] - for r, instances in results: + for _, instances in results: for instance in instances: if not (instance_type is None or instance_type == instance.instance_class()): continue diff --git a/skylark/replicate/replicator_client.py b/skylark/replicate/replicator_client.py index b4805a4de..82e325863 100644 --- 
a/skylark/replicate/replicator_client.py +++ b/skylark/replicate/replicator_client.py @@ -70,6 +70,7 @@ def provision_gateways( # init clouds jobs = [] + jobs.append(partial(self.aws.create_iam, attach_policy_arn="arn:aws:iam::aws:policy/AmazonS3FullAccess")) for r in set(aws_regions_to_provision): jobs.append(partial(self.aws.add_ip_to_security_group, r.split(":")[1])) if azure_regions_to_provision: @@ -315,7 +316,7 @@ def monitor_transfer( show_pbar=False, log_interval_s: Optional[float] = None, time_limit_seconds: Optional[float] = None, - cancel_pending: bool = True, + cleanup_gateway: bool = True, save_log: bool = True, write_profile: bool = True, copy_gateway_logs: bool = True, @@ -335,9 +336,11 @@ def shutdown_handler(): if copy_gateway_logs: for instance in self.bound_nodes.values(): logger.info(f"Copying gateway logs from {instance.uuid()}") - instance.run_command("sudo docker logs -t skylark_gateway &> /tmp/gateway.log") - log_out = transfer_dir / f"gateway_{instance.uuid()}.log" - instance.download_file("/tmp/gateway.log", log_out) + instance.run_command("sudo docker logs -t skylark_gateway 2> /tmp/gateway.stderr > /tmp/gateway.stdout") + log_out = transfer_dir / f"gateway_{instance.uuid()}.stdout" + log_err = transfer_dir / f"gateway_{instance.uuid()}.stderr" + instance.download_file("/tmp/gateway.stdout", log_out) + instance.download_file("/tmp/gateway.stderr", log_err) logger.debug(f"Wrote gateway logs to {transfer_dir}") if write_profile: chunk_status_df = self.get_chunk_status_log_df() @@ -356,7 +359,7 @@ def fn(s: Server): do_parallel(fn, self.bound_nodes.values(), n=-1) - if cancel_pending: + if cleanup_gateway: logger.debug("Registering shutdown handler") atexit.register(shutdown_handler) @@ -389,7 +392,7 @@ def fn(s: Server): or time_limit_seconds is not None and t.elapsed > time_limit_seconds ): - if cancel_pending: + if cleanup_gateway: atexit.unregister(shutdown_handler) shutdown_handler() return dict( diff --git a/skylark/test/test_replicator_client.py b/skylark/test/test_replicator_client.py index 3544e40a9..9bb15e53b 100644 --- a/skylark/test/test_replicator_client.py +++ b/skylark/test/test_replicator_client.py @@ -147,7 +147,7 @@ def main(args): total_bytes = args.n_chunks * args.chunk_size_mb * MB job = rc.run_replication_plan(job) logger.info(f"{total_bytes / GB:.2f}GByte replication job launched") - stats = rc.monitor_transfer(job, show_pbar=True, cancel_pending=False) + stats = rc.monitor_transfer(job, show_pbar=True, cleanup_gateway=False) logger.info(f"Replication completed in {stats['total_runtime_s']:.2f}s ({stats['throughput_gbits']:.2f}Gbit/s)") diff --git a/skylark/utils/utils.py b/skylark/utils/utils.py index 1fb823932..ed5271c8c 100644 --- a/skylark/utils/utils.py +++ b/skylark/utils/utils.py @@ -48,7 +48,7 @@ def wait_for(fn: Callable[[], bool], timeout=60, interval=0.25, progress_bar=Fal def do_parallel( - func: Callable[[T], R], args_list: Iterable[T], n=-1, progress_bar=False, leave_pbar=True, desc=None, arg_fmt=None + func: Callable[[T], R], args_list: Iterable[T], n=-1, progress_bar=False, leave_pbar=True, desc=None, arg_fmt=None, hide_args=False ) -> List[Tuple[T, R]]: """Run list of jobs in parallel with tqdm progress bar""" args_list = list(args_list) @@ -71,15 +71,19 @@ def wrapped_fn(args): for future in as_completed(future_list): args, result = future.result() results.append((args, result)) - pbar.set_description(f"{desc} ({str(arg_fmt(args))})" if desc else str(arg_fmt(args))) + if not hide_args: + pbar.set_description(f"{desc} 
({str(arg_fmt(args))})" if desc else str(arg_fmt(args))) + else: + pbar.set_description(desc) pbar.update() return results def retry_backoff( fn: Callable[[], R], - max_retries=4, + max_retries=8, initial_backoff=0.1, + max_backoff=8, exception_class=Exception, ) -> R: """Retry fn until it does not raise an exception. @@ -97,4 +101,4 @@ def retry_backoff( else: logger.warning(f"Retrying {fn.__name__} due to {e} (attempt {i + 1}/{max_retries})") time.sleep(backoff) - backoff *= 2 + backoff = min(backoff * 2, max_backoff) From 46722198a154aa39268c1040268c50a01976c3a0 Mon Sep 17 00:00:00 2001 From: Paras Jain Date: Tue, 15 Mar 2022 20:53:41 +0000 Subject: [PATCH 17/34] Fix credential inference --- skylark/compute/aws/aws_auth.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/skylark/compute/aws/aws_auth.py b/skylark/compute/aws/aws_auth.py index 336010b0b..7e2aeb437 100644 --- a/skylark/compute/aws/aws_auth.py +++ b/skylark/compute/aws/aws_auth.py @@ -28,8 +28,11 @@ def infer_credentials(self): # todo load temporary credentials from STS session = boto3.Session() credentials = session.get_credentials() - credentials = credentials.get_frozen_credentials() - return credentials.access_key, credentials.secret_key + if credentials: + credentials = credentials.get_frozen_credentials() + return credentials.access_key, credentials.secret_key + else: + return None, None def get_boto3_session(self, aws_region: str): if self.config_mode == "manual": From af7e34a3b4e4b64347010553b696cd410d92370c Mon Sep 17 00:00:00 2001 From: Paras Jain Date: Tue, 15 Mar 2022 23:27:42 +0000 Subject: [PATCH 18/34] Regenerate key pairs --- scripts/experiment_paras.sh | 1 + skylark/compute/aws/aws_cloud_provider.py | 11 ++++---- skylark/compute/aws/aws_server.py | 12 ++++----- skylark/compute/server.py | 31 +++++++++++++---------- skylark/gateway/gateway_obj_store.py | 6 +++-- 5 files changed, 32 insertions(+), 29 deletions(-) diff --git a/scripts/experiment_paras.sh b/scripts/experiment_paras.sh index 81e34bf41..91887346c 100755 --- a/scripts/experiment_paras.sh +++ b/scripts/experiment_paras.sh @@ -50,6 +50,7 @@ cp ${filename} data/results/${experiment} skylark replicate-json ${filename} \ --source-bucket $src_bucket \ --dest-bucket $dest_bucket \ + --reuse-gateways \ --key-prefix ${key_prefix} > data/results/${experiment}/obj-store-logs.txt tail -1 data/results/${experiment}/obj-store-logs.txt; echo ${experiment} diff --git a/skylark/compute/aws/aws_cloud_provider.py b/skylark/compute/aws/aws_cloud_provider.py index 57d39dc5d..885bd4aa9 100644 --- a/skylark/compute/aws/aws_cloud_provider.py +++ b/skylark/compute/aws/aws_cloud_provider.py @@ -214,7 +214,6 @@ def fn(): if attach_policy_arn: iam.attach_role_policy(RoleName=iam_name, PolicyArn=attach_policy_arn) - return fn() def add_ip_to_security_group(self, aws_region: str): @@ -254,14 +253,14 @@ def provision_instance( assert vpc is not None, "No VPC found" subnets = list(vpc.subnets.all()) assert len(subnets) > 0, "No subnets found" - + def check_iam_role(): try: iam.get_role(RoleName=iam_name) return True except iam.exceptions.NoSuchEntityException: return False - + def check_instance_profile(): try: iam.get_instance_profile(InstanceProfileName=iam_instance_profile_name) @@ -270,10 +269,10 @@ def check_instance_profile(): return False # wait for iam_role to be created and create instance profile - wait_for(check_iam_role, timeout=60, interval=.5) + wait_for(check_iam_role, timeout=60, interval=0.5) 
iam.create_instance_profile(InstanceProfileName=iam_instance_profile_name, Tags=[{"Key": "skylark", "Value": "true"}]) iam.add_role_to_instance_profile(InstanceProfileName=iam_instance_profile_name, RoleName=iam_name) - wait_for(check_instance_profile, timeout=60, interval=.5) + wait_for(check_instance_profile, timeout=60, interval=0.5) def start_instance(): return ec2.create_instances( @@ -304,7 +303,7 @@ def start_instance(): } ], IamInstanceProfile={"Name": iam_instance_profile_name}, - InstanceInitiatedShutdownBehavior='terminate', + InstanceInitiatedShutdownBehavior="terminate", ) instance = retry_backoff(start_instance, initial_backoff=1) diff --git a/skylark/compute/aws/aws_server.py b/skylark/compute/aws/aws_server.py index 720f0dbca..500326aaa 100644 --- a/skylark/compute/aws/aws_server.py +++ b/skylark/compute/aws/aws_server.py @@ -1,3 +1,4 @@ +import subprocess import os from pathlib import Path from typing import Dict, Optional @@ -33,19 +34,17 @@ def get_boto3_instance_resource(self): return ec2.Instance(self.instance_id) def ensure_keyfile_exists(self, aws_region, prefix=key_root / "aws"): + ec2 = self.auth.get_boto3_resource("ec2", aws_region) + ec2_client = self.auth.get_boto3_client("ec2", aws_region) prefix = Path(prefix) key_name = f"skylark-{aws_region}" local_key_file = prefix / f"{key_name}.pem" @lockutils.synchronized(f"aws_keyfile_lock_{aws_region}", external=True, lock_path="/tmp/skylark_locks") def create_keyfile(): - if not local_key_file.exists(): # we have to check again since another process may have created it - ec2 = self.auth.get_boto3_resource("ec2", aws_region) - ec2_client = self.auth.get_boto3_client("ec2", aws_region) + if not local_key_file.exists(): local_key_file.parent.mkdir(parents=True, exist_ok=True) - # delete key pair from ec2 if it exists - keys_in_region = set(p["KeyName"] for p in ec2_client.describe_key_pairs()["KeyPairs"]) - if key_name in keys_in_region: + if key_name in set(p["KeyName"] for p in ec2_client.describe_key_pairs()["KeyPairs"]): logger.warning(f"Deleting key {key_name} in region {aws_region}") ec2_client.delete_key_pair(KeyName=key_name) key_pair = ec2.create_key_pair(KeyName=f"skylark-{aws_region}", KeyType="rsa") @@ -54,7 +53,6 @@ def create_keyfile(): if not key_str.endswith("\n"): key_str += "\n" f.write(key_str) - f.flush() # sometimes generates keys with zero bytes, so we flush to ensure it's written os.chmod(local_key_file, 0o600) logger.info(f"Created key file {local_key_file}") diff --git a/skylark/compute/server.py b/skylark/compute/server.py index 8ca3331aa..7b1c725b4 100644 --- a/skylark/compute/server.py +++ b/skylark/compute/server.py @@ -178,6 +178,12 @@ def download_file(self, remote_path, local_path): with client.open_sftp() as sftp: sftp.get(remote_path, local_path) + def upload_file(self, local_path, remote_path): + """Upload a file to the server""" + client = self.ssh_client + with client.open_sftp() as sftp: + sftp.put(local_path, remote_path) + def copy_public_key(self, pub_key_path: PathLike): """Append public key to authorized_keys file on server.""" pub_key_path = Path(pub_key_path) @@ -201,38 +207,35 @@ def start_gateway( log_viewer_port=8888, use_bbr=False, ): - self.wait_for_ready() - def check_stderr(tup): assert tup[1].strip() == "", f"Command failed, err: {tup[1]}" desc_prefix = f"Starting gateway {self.uuid()}, host: {self.public_ip()}" + self.wait_for_ready() # increase TCP connections, enable BBR optionally and raise file limits 
check_stderr(self.run_command(make_sysctl_tcp_tuning_command(cc="bbr" if use_bbr else "cubic"))) with Timer("Install docker"): retry_backoff(self.install_docker, exception_class=RuntimeError) - self.run_command(make_dozzle_command(log_viewer_port)) - # read AWS config file to get credentials - # TODO: Integrate this with updated skylark config file - # copy config file - config = config_path.read_text()[:-2] + "}" - config = json.dumps(config) # Convert to JSON string and remove trailing comma/new-line - self.run_command(f'mkdir -p /tmp; echo "{config}" | sudo tee /tmp/{config_path.name} > /dev/null') + # start log viewer + self.run_command(make_dozzle_command(log_viewer_port)) - docker_envs = "" # If needed, add environment variables to docker command + # copy cloud configuration + docker_envs = {} + if config_path.exists(): + self.upload_file(config_path, f"/tmp/{config_path.name}") + docker_envs["SKYLARK_CONFIG"] = f"/pkg/data/{config_path.name}" + # pull docker image and start container with Timer(f"{desc_prefix}: Docker pull"): docker_out, docker_err = self.run_command(f"sudo docker pull {gateway_docker_image}") assert "Status: Downloaded newer image" in docker_out or "Status: Image is up to date" in docker_out, (docker_out, docker_err) logger.debug(f"{desc_prefix}: Starting gateway container") - docker_run_flags = ( - f"-d --log-driver=local --log-opt max-file=16 --ipc=host --network=host --ulimit nofile={1024 * 1024} {docker_envs}" - ) + docker_run_flags = f"-d --log-driver=local --log-opt max-file=16 --ipc=host --network=host --ulimit nofile={1024 * 1024}" docker_run_flags += " --mount type=tmpfs,dst=/skylark,tmpfs-size=$(($(free -b | head -n2 | tail -n1 | awk '{print $2}')/2))" docker_run_flags += f" -v /tmp/{config_path.name}:/pkg/data/{config_path.name}" - docker_run_flags += f" -e SKYLARK_CONFIG=/pkg/data/{config_path.name}" + docker_run_flags += " " + " ".join(f"--env {k}={v}" for k, v in docker_envs.items()) gateway_daemon_cmd = f"python -u /pkg/skylark/gateway/gateway_daemon.py --chunk-dir /skylark/chunks --outgoing-ports '{json.dumps(outgoing_ports)}' --region {self.region_tag}" docker_launch_cmd = f"sudo docker run {docker_run_flags} --name skylark_gateway {gateway_docker_image} {gateway_daemon_cmd}" start_out, start_err = self.run_command(docker_launch_cmd) diff --git a/skylark/gateway/gateway_obj_store.py b/skylark/gateway/gateway_obj_store.py index dd6116085..712328f7d 100644 --- a/skylark/gateway/gateway_obj_store.py +++ b/skylark/gateway/gateway_obj_store.py @@ -12,6 +12,8 @@ from dataclasses import dataclass +from skylark.utils.utils import retry_backoff + @dataclass class ObjStoreRequest: @@ -84,7 +86,7 @@ def worker_loop(self, worker_id: int): def upload(region, bucket, fpath, key, chunk_id): obj_store_interface = self.get_obj_store_interface(region, bucket) - obj_store_interface.upload_object(fpath, key).result() + retry_backoff(lambda: obj_store_interface.upload_object(fpath, key).result(), max_retries=4) chunk_file_path = self.chunk_store.get_chunk_file_path(chunk_id) # update chunk state @@ -109,7 +111,7 @@ def upload(region, bucket, fpath, key, chunk_id): def download(region, bucket, fpath, key, chunk_id): obj_store_interface = self.get_obj_store_interface(region, bucket) - obj_store_interface.download_object(key, fpath).result() + retry_backoff(lambda: obj_store_interface.download_object(key, fpath).result(), max_retries=4) # update chunk state self.chunk_store.state_finish_download(chunk_id, f"obj_store:{self.worker_id}") From 
92c0996e04c5fdd46cdfcdaa6905c268dfcbbca9 Mon Sep 17 00:00:00 2001 From: Paras Jain Date: Tue, 15 Mar 2022 23:48:38 +0000 Subject: [PATCH 19/34] Update --- skylark/gateway/chunk_store.py | 10 +++++----- skylark/gateway/gateway_obj_store.py | 15 --------------- 2 files changed, 5 insertions(+), 20 deletions(-) diff --git a/skylark/gateway/chunk_store.py b/skylark/gateway/chunk_store.py index fda4ced81..ce67d106a 100644 --- a/skylark/gateway/chunk_store.py +++ b/skylark/gateway/chunk_store.py @@ -68,21 +68,21 @@ def state_finish_download(self, chunk_id: int, receiver_id: Optional[str] = None if state in [ChunkState.download_in_progress, ChunkState.downloaded]: self.set_chunk_state(chunk_id, ChunkState.downloaded, {"receiver_id": receiver_id}) else: - raise ValueError(f"Invalid transition finish_download from {self.get_chunk_state(chunk_id)}") + raise ValueError(f"Invalid transition finish_download from {self.get_chunk_state(chunk_id)} (id={chunk_id})") def state_queue_upload(self, chunk_id: int): state = self.get_chunk_state(chunk_id) if state in [ChunkState.downloaded, ChunkState.upload_queued]: self.set_chunk_state(chunk_id, ChunkState.upload_queued) else: - raise ValueError(f"Invalid transition upload_queued from {self.get_chunk_state(chunk_id)}") + raise ValueError(f"Invalid transition upload_queued from {self.get_chunk_state(chunk_id)} (id={chunk_id})") def state_start_upload(self, chunk_id: int, sender_id: Optional[str] = None): state = self.get_chunk_state(chunk_id) if state in [ChunkState.upload_queued, ChunkState.upload_in_progress]: self.set_chunk_state(chunk_id, ChunkState.upload_in_progress, {"sender_id": sender_id}) else: - raise ValueError(f"Invalid transition start_upload from {self.get_chunk_state(chunk_id)}") + raise ValueError(f"Invalid transition start_upload from {self.get_chunk_state(chunk_id)} (id={chunk_id})") def state_finish_upload(self, chunk_id: int, sender_id: Optional[str] = None): # todo log runtime to statistics store @@ -90,13 +90,13 @@ def state_finish_upload(self, chunk_id: int, sender_id: Optional[str] = None): if state in [ChunkState.upload_in_progress, ChunkState.upload_complete]: self.set_chunk_state(chunk_id, ChunkState.upload_complete, {"sender_id": sender_id}) else: - raise ValueError(f"Invalid transition finish_upload from {self.get_chunk_state(chunk_id)}") + raise ValueError(f"Invalid transition finish_upload from {self.get_chunk_state(chunk_id)} (id={chunk_id})") def state_fail(self, chunk_id: int): if self.get_chunk_state(chunk_id) != ChunkState.upload_complete: self.set_chunk_state(chunk_id, ChunkState.failed) else: - raise ValueError(f"Invalid transition fail from {self.get_chunk_state(chunk_id)}") + raise ValueError(f"Invalid transition fail from {self.get_chunk_state(chunk_id)} (id={chunk_id})") ### # Chunk management diff --git a/skylark/gateway/gateway_obj_store.py b/skylark/gateway/gateway_obj_store.py index 712328f7d..37c1f3b27 100644 --- a/skylark/gateway/gateway_obj_store.py +++ b/skylark/gateway/gateway_obj_store.py @@ -68,10 +68,8 @@ def worker_loop(self, worker_id: int): request = self.worker_queue.get_nowait() chunk_req = request.chunk_req req_type = request.req_type - except queue.Empty: continue - fpath = str(self.chunk_store.get_chunk_file_path(chunk_req.chunk.chunk_id).absolute()) logger.debug(f"[obj_store:{self.worker_id}] Received chunk ID {chunk_req.chunk.chunk_id}") @@ -79,47 +77,34 @@ def worker_loop(self, worker_id: int): assert chunk_req.dst_type == "object_store" region = chunk_req.dst_region bucket = 
chunk_req.dst_object_store_bucket - self.chunk_store.state_start_upload(chunk_req.chunk.chunk_id, f"obj_store:{self.worker_id}") - logger.debug(f"[obj_store:{self.worker_id}] Start upload {chunk_req.chunk.chunk_id} to {bucket}") def upload(region, bucket, fpath, key, chunk_id): obj_store_interface = self.get_obj_store_interface(region, bucket) retry_backoff(lambda: obj_store_interface.upload_object(fpath, key).result(), max_retries=4) chunk_file_path = self.chunk_store.get_chunk_file_path(chunk_id) - - # update chunk state self.chunk_store.state_finish_upload(chunk_id, f"obj_store:{self.worker_id}") - - # delete chunk chunk_file_path.unlink() - logger.debug(f"[obj_store:{self.worker_id}] Uploaded {chunk_id} to {bucket}") # wait for upload in seperate thread threading.Thread(target=upload, args=(region, bucket, fpath, chunk_req.chunk.key, chunk_req.chunk.chunk_id)).start() - elif req_type == "download": assert chunk_req.src_type == "object_store" region = chunk_req.src_region bucket = chunk_req.src_object_store_bucket - self.chunk_store.state_start_download(chunk_req.chunk.chunk_id, f"obj_store:{self.worker_id}") - logger.debug(f"[obj_store:{self.worker_id}] Starting download {chunk_req.chunk.chunk_id} from {bucket}") def download(region, bucket, fpath, key, chunk_id): obj_store_interface = self.get_obj_store_interface(region, bucket) retry_backoff(lambda: obj_store_interface.download_object(key, fpath).result(), max_retries=4) - - # update chunk state self.chunk_store.state_finish_download(chunk_id, f"obj_store:{self.worker_id}") logger.debug(f"[obj_store:{self.worker_id}] Downloaded {chunk_id} from {bucket}") # wait for request to return in sepearte thread, so we can update chunk state threading.Thread(target=download, args=(region, bucket, fpath, chunk_req.chunk.key, chunk_req.chunk.chunk_id)).start() - else: raise ValueError(f"Invalid location for chunk req, {req_type}: {chunk_req.src_type}->{chunk_req.dst_type}") From e25caec690b522fb330c0b335d21209d6c42ec35 Mon Sep 17 00:00:00 2001 From: Paras Jain Date: Wed, 16 Mar 2022 00:20:29 +0000 Subject: [PATCH 20/34] Cache credentials to avoid credential errors --- scripts/experiment_paras.sh | 1 - scripts/setup_bucket.py | 5 ----- skylark/cli/cli.py | 5 ++--- skylark/compute/aws/aws_auth.py | 19 ++++++++++++------- skylark/compute/aws/aws_server.py | 2 +- skylark/compute/azure/azure_auth.py | 20 +++++++++++++++----- skylark/compute/azure/azure_server.py | 13 ++++++------- skylark/compute/gcp/gcp_auth.py | 12 +++++++++++- skylark/compute/gcp/gcp_server.py | 2 +- skylark/compute/server.py | 4 ++-- skylark/replicate/replicator_client.py | 4 ++-- 11 files changed, 52 insertions(+), 35 deletions(-) diff --git a/scripts/experiment_paras.sh b/scripts/experiment_paras.sh index 91887346c..81e34bf41 100755 --- a/scripts/experiment_paras.sh +++ b/scripts/experiment_paras.sh @@ -50,7 +50,6 @@ cp ${filename} data/results/${experiment} skylark replicate-json ${filename} \ --source-bucket $src_bucket \ --dest-bucket $dest_bucket \ - --reuse-gateways \ --key-prefix ${key_prefix} > data/results/${experiment}/obj-store-logs.txt tail -1 data/results/${experiment}/obj-store-logs.txt; echo ${experiment} diff --git a/scripts/setup_bucket.py b/scripts/setup_bucket.py index a9c6ffe42..3d67819b4 100644 --- a/scripts/setup_bucket.py +++ b/scripts/setup_bucket.py @@ -7,10 +7,6 @@ from multiprocessing import Pool from concurrent.futures import wait -import ctypes - -libgcc_s = ctypes.CDLL("libgcc_s.so.1") - def parse_args(): parser = 
argparse.ArgumentParser(description="Setup replication experiment") @@ -20,7 +16,6 @@ def parse_args(): parser.add_argument("--bucket-prefix", default="sarah", help="Prefix for bucket to avoid naming collision") parser.add_argument("--key-prefix", default="", help="Prefix keys") args = parser.parse_args() - return args diff --git a/skylark/cli/cli.py b/skylark/cli/cli.py index d325805fa..5c5087da4 100644 --- a/skylark/cli/cli.py +++ b/skylark/cli/cli.py @@ -62,7 +62,6 @@ @app.command() def ls(directory: str): """List objects in the object store.""" - check_ulimit() provider, bucket, key = parse_path(directory) if provider == "local": for path in ls_local(Path(directory)): @@ -146,7 +145,7 @@ def replicate_random( ) if not reuse_gateways: - atexit.register(rc.deprovision_gateways, block=False) + atexit.register(rc.deprovision_gateways) else: logger.warning( f"Instances will remain up and may result in continued cloud billing. Remember to call `skylark deprovision` to deprovision gateways." @@ -221,7 +220,7 @@ def replicate_json( ) if not reuse_gateways: - atexit.register(rc.deprovision_gateways, block=False) + atexit.register(rc.deprovision_gateways) else: logger.warning( f"Instances will remain up and may result in continued cloud billing. Remember to call `skylark deprovision` to deprovision gateways." diff --git a/skylark/compute/aws/aws_auth.py b/skylark/compute/aws/aws_auth.py index 7e2aeb437..f47f51068 100644 --- a/skylark/compute/aws/aws_auth.py +++ b/skylark/compute/aws/aws_auth.py @@ -1,9 +1,12 @@ +import threading from typing import Optional import boto3 class AWSAuthentication: + __cached_credentials = threading.local() + def __init__(self, access_key: Optional[str] = None, secret_key: Optional[str] = None): """Loads AWS authentication details. 
If no access key is provided, it will try to load credentials using boto3""" if access_key and secret_key: @@ -26,13 +29,15 @@ def enabled(self): def infer_credentials(self): # todo load temporary credentials from STS - session = boto3.Session() - credentials = session.get_credentials() - if credentials: - credentials = credentials.get_frozen_credentials() - return credentials.access_key, credentials.secret_key - else: - return None, None + cached_credential = getattr(self.__cached_credentials, "boto3_credential", (None, None)) + if cached_credential is None: + session = boto3.Session() + credentials = session.get_credentials() + if credentials: + credentials = credentials.get_frozen_credentials() + cached_credential = (credentials.access_key, credentials.secret_key) + setattr(self.__cached_credentials, "boto3_credential", cached_credential) + return cached_credential def get_boto3_session(self, aws_region: str): if self.config_mode == "manual": diff --git a/skylark/compute/aws/aws_server.py b/skylark/compute/aws/aws_server.py index 500326aaa..fc4e38f9c 100644 --- a/skylark/compute/aws/aws_server.py +++ b/skylark/compute/aws/aws_server.py @@ -90,7 +90,7 @@ def instance_state(self): def __repr__(self): return f"AWSServer(region_tag={self.region_tag}, instance_id={self.instance_id})" - def terminate_instance_impl(self, block=True): + def terminate_instance_impl(self): self.auth.get_boto3_resource("ec2", self.aws_region).instances.filter(InstanceIds=[self.instance_id]).terminate() def get_ssh_client_impl(self): diff --git a/skylark/compute/azure/azure_auth.py b/skylark/compute/azure/azure_auth.py index f1d0421e1..7bdb2e71e 100644 --- a/skylark/compute/azure/azure_auth.py +++ b/skylark/compute/azure/azure_auth.py @@ -1,5 +1,6 @@ import os import subprocess +import threading from typing import Optional from azure.identity import DefaultAzureCredential from azure.mgmt.compute import ComputeManagementClient @@ -13,13 +14,22 @@ class AzureAuthentication: + __cached_credentials = threading.local() + def __init__(self, subscription_id: str = cloud_config.azure_subscription_id): self.subscription_id = subscription_id - self.credential = DefaultAzureCredential( - exclude_managed_identity_credential=query_which_cloud() != "azure", # exclude MSI if not Azure - exclude_powershell_credential=True, - exclude_visual_studio_code_credential=True, - ) + self.credential = self.get_credential(subscription_id) + + def get_credential(self, subscription_id: str): + cached_credential = getattr(self.__cached_credentials, f"credential_{subscription_id}", None) + if cached_credential is None: + cached_credential = DefaultAzureCredential( + exclude_managed_identity_credential=query_which_cloud() != "azure", # exclude MSI if not Azure + exclude_powershell_credential=True, + exclude_visual_studio_code_credential=True, + ) + setattr(self.__cached_credentials, f"credential_{subscription_id}", cached_credential) + return cached_credential def enabled(self) -> bool: return self.subscription_id is not None diff --git a/skylark/compute/azure/azure_server.py b/skylark/compute/azure/azure_server.py index f6c80cc9e..c29b8d62e 100644 --- a/skylark/compute/azure/azure_server.py +++ b/skylark/compute/azure/azure_server.py @@ -157,7 +157,7 @@ def tags(self): def network_tier(self): return "PREMIUM" - def terminate_instance_impl(self, block=True): + def terminate_instance_impl(self): compute_client = self.auth.get_compute_client() network_client = self.auth.get_network_client() vm_poller = 
compute_client.virtual_machines.begin_delete(AzureServer.resource_group_name, self.vm_name(self.name)) @@ -169,12 +169,11 @@ def terminate_instance_impl(self, block=True): ) nsg_poller = network_client.network_security_groups.begin_delete(AzureServer.resource_group_name, self.nsg_name(self.name)) vnet_poller = network_client.virtual_networks.begin_delete(AzureServer.resource_group_name, self.vnet_name(self.name)) - if block: - nsg_poller.result() - ip_poller.result() - subnet_poller.result() - nic_poller.result() - vnet_poller.result() + nsg_poller.result() + ip_poller.result() + subnet_poller.result() + nic_poller.result() + vnet_poller.result() def get_ssh_client_impl(self, uname=os.environ.get("USER"), ssh_key_password="skylark"): """Return paramiko client that connects to this instance.""" diff --git a/skylark/compute/gcp/gcp_auth.py b/skylark/compute/gcp/gcp_auth.py index 1dc582e60..a7ad526d2 100644 --- a/skylark/compute/gcp/gcp_auth.py +++ b/skylark/compute/gcp/gcp_auth.py @@ -1,3 +1,4 @@ +import threading from typing import Optional import googleapiclient.discovery import google.auth @@ -6,8 +7,17 @@ class GCPAuthentication: + __cached_credentials = threading.local() + def __init__(self, project_id: Optional[str] = cloud_config.gcp_project_id): - self.credentials, self.project_id = google.auth.default(quota_project_id=project_id) + self.credentials, self.project_id = self.make_credential(project_id) + + def make_credential(self, project_id): + cached_credential = getattr(self.__cached_credentials, f"credential_{project_id}", (None, None)) + if cached_credential == (None, None): + cached_credential = google.auth.default(quota_project_id=project_id) + setattr(self.__cached_credentials, f"credential_{project_id}", cached_credential) + return cached_credential def enabled(self): return self.credentials is not None and self.project_id is not None diff --git a/skylark/compute/gcp/gcp_server.py b/skylark/compute/gcp/gcp_server.py index 6ead951e1..e5ad6d2b9 100644 --- a/skylark/compute/gcp/gcp_server.py +++ b/skylark/compute/gcp/gcp_server.py @@ -82,7 +82,7 @@ def network_tier(self): def __repr__(self): return f"GCPServer(region_tag={self.region_tag}, instance_name={self.gcp_instance_name})" - def terminate_instance_impl(self, block=True): + def terminate_instance_impl(self): self.auth.get_gcp_client().instances().delete( project=self.auth.project_id, zone=self.gcp_region, instance=self.instance_name() ).execute() diff --git a/skylark/compute/server.py b/skylark/compute/server.py index 7b1c725b4..e53f2d8c2 100644 --- a/skylark/compute/server.py +++ b/skylark/compute/server.py @@ -127,10 +127,10 @@ def tags(self): def network_tier(self): raise NotImplementedError() - def terminate_instance_impl(self, block=True): + def terminate_instance_impl(self): raise NotImplementedError() - def terminate_instance(self, block=False): + def terminate_instance(self): """Terminate instance""" self.close_server() self.terminate_instance_impl() diff --git a/skylark/replicate/replicator_client.py b/skylark/replicate/replicator_client.py index 82e325863..c082dcf5d 100644 --- a/skylark/replicate/replicator_client.py +++ b/skylark/replicate/replicator_client.py @@ -192,11 +192,11 @@ def setup(server: Server): args.append((server, {self.bound_nodes[n].public_ip(): v for n, v in self.topology.get_outgoing_paths(node).items()})) do_parallel(lambda arg: arg[0].start_gateway(arg[1], gateway_docker_image=self.gateway_docker_image), args, n=-1) - def deprovision_gateways(self, block=True): + def 
deprovision_gateways(self): def deprovision_gateway_instance(server: Server): if server.instance_state() == ServerState.RUNNING: logger.warning(f"Deprovisioning {server.uuid()}") - server.terminate_instance(block=block) + server.terminate_instance() logger.warning("Deprovisioning instances") do_parallel(deprovision_gateway_instance, self.bound_nodes.values(), n=-1) From b4431ca58891353e3163b887948a5e9d1f2a2e50 Mon Sep 17 00:00:00 2001 From: Paras Jain Date: Wed, 16 Mar 2022 02:25:56 +0000 Subject: [PATCH 21/34] Storage account successfully created --- scripts/experiment_paras.sh | 5 +- scripts/requirements-gateway.txt | 1 + setup.py | 1 + skylark/cli/cli.py | 3 +- skylark/cli/cli_helper.py | 24 +++++----- skylark/compute/aws/aws_auth.py | 2 +- skylark/compute/azure/azure_auth.py | 13 +++-- skylark/compute/gcp/gcp_auth.py | 4 +- skylark/obj_store/azure_interface.py | 71 +++++++++++++--------------- skylark/test/test_azure_interface.py | 2 +- 10 files changed, 63 insertions(+), 63 deletions(-) diff --git a/scripts/experiment_paras.sh b/scripts/experiment_paras.sh index 81e34bf41..efc20ccd5 100755 --- a/scripts/experiment_paras.sh +++ b/scripts/experiment_paras.sh @@ -18,13 +18,12 @@ filename=data/plan/${experiment}.json echo $filename # creats buckets + bucket data and sets env variables -# python scripts/setup_bucket.py --key-prefix ${key_prefix} --bucket-prefix ${bucket_prefix} --src-data-path ../${key_prefix}/ --src-region ${src} --dest-region ${dest} - +python scripts/setup_bucket.py --key-prefix ${key_prefix} --bucket-prefix ${bucket_prefix} --src-data-path ../${key_prefix}/ --src-region ${src} --dest-region ${dest} # TODO:artificially increase the number of chunks # TODO: try synthetic data -source scripts/pack_docker.sh; +source scripts/pack_docker.sh ## create plan throughput=$(($max_instance*3)) diff --git a/scripts/requirements-gateway.txt b/scripts/requirements-gateway.txt index 647c28580..afa7dbc09 100644 --- a/scripts/requirements-gateway.txt +++ b/scripts/requirements-gateway.txt @@ -3,6 +3,7 @@ azure-identity azure-mgmt-compute azure-mgmt-network azure-mgmt-resource +azure-mgmt-storage azure-mgmt-authorization azure-storage-blob>=12.0.0 boto3 diff --git a/setup.py b/setup.py index ad03a7789..8d88bb9c5 100644 --- a/setup.py +++ b/setup.py @@ -12,6 +12,7 @@ "azure-mgmt-compute", "azure-mgmt-network", "azure-mgmt-resource", + "azure-mgmt-storage", "azure-mgmt-authorization", "azure-storage-blob>=12.0.0", "boto3", diff --git a/skylark/cli/cli.py b/skylark/cli/cli.py index 5c5087da4..e584947ec 100644 --- a/skylark/cli/cli.py +++ b/skylark/cli/cli.py @@ -91,7 +91,8 @@ def cp(src: str, dst: str): elif provider_src == "gs" and provider_dst == "local": copy_gcs_local(bucket_src, path_src, Path(path_dst)) elif provider_src == "local" and provider_dst == "azure": - copy_local_azure(Path(path_src), bucket_dst, path_dst) + account_name, container_name = bucket_dst + copy_local_azure(Path(path_src), account_name, container_name, path_dst) elif provider_src == "azure" and provider_dst == "local": copy_azure_local(bucket_src, path_src, Path(path_dst)) else: diff --git a/skylark/cli/cli_helper.py b/skylark/cli/cli_helper.py index 92c2a7aef..6e3f9391f 100644 --- a/skylark/cli/cli_helper.py +++ b/skylark/cli/cli_helper.py @@ -51,7 +51,7 @@ def parse_path(path: str): if match is None: raise ValueError(f"Invalid Azure path: {path}") account, container, blob_path = match.groups() - return "azure", account, container + return "azure", (account, container), blob_path elif path.startswith("azure://"): 
bucket_name = path[8:] region = path[8:].split("-", 2)[-1] @@ -96,7 +96,7 @@ def copy_local_local(src: Path, dst: Path): copyfile(src, dst) -def copy_local_objstore(object_interface: ObjectStoreInterface, src: Path, dst_bucket: str, dst_key: str): +def copy_local_objstore(object_interface: ObjectStoreInterface, src: Path, dst_key: str): ops: List[concurrent.futures.Future] = [] path_mapping: Dict[concurrent.futures.Future, Path] = {} @@ -120,7 +120,7 @@ def _copy(path: Path, dst_key: str, total_size=0.0): pbar.update(path_mapping[op].stat().st_size) -def copy_objstore_local(object_interface: ObjectStoreInterface, src_bucket: str, src_key: str, dst: Path): +def copy_objstore_local(object_interface: ObjectStoreInterface, src_key: str, dst: Path): ops: List[concurrent.futures.Future] = [] obj_mapping: Dict[concurrent.futures.Future, ObjectStoreObject] = {} @@ -148,24 +148,22 @@ def _copy(src_obj: ObjectStoreObject, dst: Path): def copy_local_gcs(src: Path, dst_bucket: str, dst_key: str): gcs = GCSInterface(None, dst_bucket) - return copy_local_objstore(gcs, src, dst_bucket, dst_key) + return copy_local_objstore(gcs, src, dst_key) def copy_gcs_local(src_bucket: str, src_key: str, dst: Path): gcs = GCSInterface(None, src_bucket) - return copy_objstore_local(gcs, src_bucket, src_key, dst) + return copy_objstore_local(gcs, src_key, dst) -def copy_local_azure(src: Path, dst_bucket: str, dst_key: str): - # Note that dst_key is infact azure region - azure = AzureInterface(dst_key, dst_bucket) - return copy_local_objstore(azure, src, dst_bucket, dst_key) +def copy_local_azure(src: Path, dst_account_name: str, dst_container_name: str, dst_key: str): + azure = AzureInterface(dst_key, dst_account_name, dst_container_name) + return copy_local_objstore(azure, src, dst_key) -def copy_azure_local(src_bucket: str, src_key: str, dst: Path): - # Note that src_key is infact azure region - azure = AzureInterface(src_key, src_bucket) - return copy_objstore_local(azure, src_bucket, src_key, dst) +def copy_azure_local(src_account_name: str, src_container_name: str, src_key: str, dst: Path): + azure = AzureInterface(src_key, src_account_name, src_container_name) + return copy_objstore_local(azure, src_key, dst) def copy_local_s3(src: Path, dst_bucket: str, dst_key: str, use_tls: bool = True): diff --git a/skylark/compute/aws/aws_auth.py b/skylark/compute/aws/aws_auth.py index f47f51068..037d3f139 100644 --- a/skylark/compute/aws/aws_auth.py +++ b/skylark/compute/aws/aws_auth.py @@ -30,7 +30,7 @@ def enabled(self): def infer_credentials(self): # todo load temporary credentials from STS cached_credential = getattr(self.__cached_credentials, "boto3_credential", (None, None)) - if cached_credential is None: + if cached_credential == (None, None): session = boto3.Session() credentials = session.get_credentials() if credentials: diff --git a/skylark/compute/azure/azure_auth.py b/skylark/compute/azure/azure_auth.py index 7bdb2e71e..05032875c 100644 --- a/skylark/compute/azure/azure_auth.py +++ b/skylark/compute/azure/azure_auth.py @@ -7,7 +7,8 @@ from azure.mgmt.network import NetworkManagementClient from azure.mgmt.resource import ResourceManagementClient from azure.mgmt.authorization import AuthorizationManagementClient -from azure.storage.blob import BlobServiceClient +from azure.mgmt.storage import StorageManagementClient +from azure.storage.blob import BlobServiceClient, ContainerClient from skylark import cloud_config from skylark.compute.utils import query_which_cloud @@ -19,7 +20,7 @@ class AzureAuthentication: def 
__init__(self, subscription_id: str = cloud_config.azure_subscription_id): self.subscription_id = subscription_id self.credential = self.get_credential(subscription_id) - + def get_credential(self, subscription_id: str): cached_credential = getattr(self.__cached_credentials, f"credential_{subscription_id}", None) if cached_credential is None: @@ -59,5 +60,11 @@ def get_network_client(self): def get_authorization_client(self): return AuthorizationManagementClient(self.credential, self.subscription_id) - def get_storage_client(self, account_url: str): + def get_storage_management_client(self): + return StorageManagementClient(self.credential, self.subscription_id) + + def get_container_client(self, account_url: str, container_name: str): + return ContainerClient(account_url, container_name, credential=self.credential) + + def get_blob_service_client(self, account_url: str): return BlobServiceClient(account_url=account_url, credential=self.credential) diff --git a/skylark/compute/gcp/gcp_auth.py b/skylark/compute/gcp/gcp_auth.py index a7ad526d2..face08a36 100644 --- a/skylark/compute/gcp/gcp_auth.py +++ b/skylark/compute/gcp/gcp_auth.py @@ -8,10 +8,10 @@ class GCPAuthentication: __cached_credentials = threading.local() - + def __init__(self, project_id: Optional[str] = cloud_config.gcp_project_id): self.credentials, self.project_id = self.make_credential(project_id) - + def make_credential(self, project_id): cached_credential = getattr(self.__cached_credentials, f"credential_{project_id}", (None, None)) if cached_credential == (None, None): diff --git a/skylark/obj_store/azure_interface.py b/skylark/obj_store/azure_interface.py index 270055b1a..16ed6c4b1 100644 --- a/skylark/obj_store/azure_interface.py +++ b/skylark/obj_store/azure_interface.py @@ -3,88 +3,81 @@ from typing import Iterator, List from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError, ServiceRequestError from skylark.compute.azure.azure_auth import AzureAuthentication +from skylark.compute.azure.azure_server import AzureServer from skylark.utils import logger from skylark.obj_store.object_store_interface import NoSuchObjectException, ObjectStoreInterface, ObjectStoreObject class AzureObject(ObjectStoreObject): def full_path(self): - return os.path.join(f"https://{self.bucket}.blob.core.windows.net", self.key) + account_name, container_name = self.bucket.split("/") + return os.path.join(f"https://{account_name}.blob.core.windows.net", container_name, self.key) class AzureInterface(ObjectStoreInterface): def __init__(self, azure_region, container_name): self.azure_region = azure_region + self.account_name = f"skylark{azure_region.replace(' ', '').lower()}" self.container_name = container_name - self.bucket_name = self.container_name # For compatibility - self.pending_downloads, self.completed_downloads = 0, 0 - self.pending_uploads, self.completed_uploads = 0, 0 # Create a blob service client self.auth = AzureAuthentication() - self.account_url = "https://{}.blob.core.windows.net".format("skylark" + self.azure_region) - print("===> Account URL:", self.account_url) - self.blob_service_client = self.auth.get_storage_client(self.account_url) + self.account_url = f"https://{self.account_name}.blob.core.windows.net" + self.storage_management_client = self.auth.get_storage_management_client() + self.container_client = self.auth.get_container_client(self.account_url, self.container_name) + self.blob_service_client = self.auth.get_blob_service_client(self.account_url) + # parallel upload/downloads self.pool = 
ThreadPoolExecutor(max_workers=256) # TODO: This might need some tuning self.max_concurrency = 1 - self.container_client = None - if not self.container_exists(): - self.create_container() - logger.info(f"==> Creating Azure container {self.container_name}") - def _on_done_download(self, **kwargs): - self.completed_downloads += 1 - self.pending_downloads -= 1 - - def _on_done_upload(self, **kwargs): - self.completed_uploads += 1 - self.pending_uploads -= 1 + def storage_account_exists(self): + try: + self.storage_management_client.storage_accounts.get_properties(AzureServer.resource_group_name, self.account_name) + return True + except ResourceNotFoundError: + return False def container_exists(self): - # More like "is container empty?" - # Get a client to interact with a specific container - though it may not yet exist - if self.container_client is None: - self.container_client = self.blob_service_client.get_container_client(self.container_name) try: - for blob in self.container_client.list_blobs(): + for _ in self.container_client.list_blobs(): return True except ResourceNotFoundError: return False - except ServiceRequestError: - logger.error("==> Unable to access storage account for region specified") - logger.error("==> Aborting. Please check your Azure credentials and region") - exit(-1) + + def create_storage_account(self, tier="Standard_LRS"): + try: + operation = self.storage_management_client.storage_accounts.begin_create( + AzureServer.resource_group_name, + self.account_name, + {"sku": {"name": tier}, "kind": "Storage", "location": self.azure_region}, + ) + operation.result() + except ResourceExistsError: + logger.warning("Unable to create storage account as it already exists") def create_container(self): try: - self.container_client = self.blob_service_client.create_container(self.container_name) - self.properties = self.container_client.get_container_properties() + self.container_client.create_container() except ResourceExistsError: - logger.warning("==> Container might already exist, in which case blobs are re-written") - # logger.warning("==> Alternatively use a diff bucket name with `--bucket-prefix`") - return + logger.warning("Unable to create container as it already exists") def create_bucket(self): - return self.create_container() + return self.create_storage_account() and self.create_container() def delete_container(self): - if self.container_client is None: - self.container_client = self.blob_service_client.get_container_client(self.container_name) try: self.container_client.delete_container() except ResourceNotFoundError: - logger.warning("Container doesn't exists. 
Unable to delete") + logger.warning("Unable to delete container as it doesn't exists") def delete_bucket(self): return self.delete_container() def list_objects(self, prefix="") -> Iterator[AzureObject]: - if self.container_client is None: - self.container_client = self.blob_service_client.get_container_client(self.container_name) blobs = self.container_client.list_blobs() for blob in blobs: - yield AzureObject("azure", blob.container, blob.name, blob.size, blob.last_modified) + yield AzureObject("azure", f"{self.account_name}/{blob.container}", blob.name, blob.size, blob.last_modified) def delete_objects(self, keys: List[str]): for key in keys: diff --git a/skylark/test/test_azure_interface.py b/skylark/test/test_azure_interface.py index 92f3faa2f..42c311e28 100644 --- a/skylark/test/test_azure_interface.py +++ b/skylark/test/test_azure_interface.py @@ -9,7 +9,7 @@ def test_azure_interface(): azure_interface = AzureInterface(f"eastus", f"sky-us-east-1") - assert azure_interface.bucket_name == "sky-us-east-1" + assert azure_interface.container_name == "sky-us-east-1" assert azure_interface.azure_region == "eastus" azure_interface.create_bucket() From 715d24c3f8dd040cdfdb73a7ff85aa0b6951357b Mon Sep 17 00:00:00 2001 From: Shishir Patil Date: Wed, 16 Mar 2022 03:26:59 +0000 Subject: [PATCH 22/34] Changing Azure to premium storage --- scripts/experiment.sh | 3 +-- skylark/obj_store/azure_interface.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/scripts/experiment.sh b/scripts/experiment.sh index 3c93094cd..6db6a31cf 100755 --- a/scripts/experiment.sh +++ b/scripts/experiment.sh @@ -36,7 +36,7 @@ echo $filename export GOOGLE_APPLICATION_CREDENTIALS="/home/ubuntu/.skylark-shishir-42be5f375b7a.json" # creats buckets + bucket data and sets env variables -python scripts/setup_bucket.py --key-prefix ${key_prefix} --bucket-prefix ${bucket_prefix} --gcp-project skylark-shishir --src-data-path ../${key_prefix}/ --src-region ${src} --dest-region ${dest} +python scripts/setup_bucket.py --key-prefix ${key_prefix} --bucket-prefix ${bucket_prefix} --src-data-path ../${key_prefix}/ --src-region ${src} --dest-region ${dest} # TODO:artificially increase the number of chunks @@ -67,7 +67,6 @@ cp ${filename} data/results/${experiment} # run replication (obj store) skylark replicate-json ${filename} \ - --gcp-project skylark-shishir \ --source-bucket $src_bucket \ --dest-bucket $dest_bucket \ --key-prefix ${key_prefix} > data/results/${experiment}/obj-store-logs.txt diff --git a/skylark/obj_store/azure_interface.py b/skylark/obj_store/azure_interface.py index 16ed6c4b1..35f6be641 100644 --- a/skylark/obj_store/azure_interface.py +++ b/skylark/obj_store/azure_interface.py @@ -45,12 +45,12 @@ def container_exists(self): except ResourceNotFoundError: return False - def create_storage_account(self, tier="Standard_LRS"): + def create_storage_account(self, tier="Premium_LRS"): try: operation = self.storage_management_client.storage_accounts.begin_create( AzureServer.resource_group_name, self.account_name, - {"sku": {"name": tier}, "kind": "Storage", "location": self.azure_region}, + {"sku": {"name": tier}, "kind": "BlockBlobStorage", "location": self.azure_region}, ) operation.result() except ResourceExistsError: From 98fe25b9491e0f2074f85cc6434496ac989d90b0 Mon Sep 17 00:00:00 2001 From: Paras Jain Date: Wed, 16 Mar 2022 04:05:36 +0000 Subject: [PATCH 23/34] AWS and GCP works --- scripts/experiment_paras.sh | 2 +- skylark/compute/aws/aws_auth.py | 36 ++++++++++++++++------------ 
skylark/compute/gcp/gcp_auth.py | 17 ++++++++++++++++- skylark/obj_store/azure_interface.py | 5 ++-- skylark/obj_store/gcs_interface.py | 4 ++-- skylark/obj_store/s3_interface.py | 2 +- 6 files changed, 44 insertions(+), 22 deletions(-) diff --git a/scripts/experiment_paras.sh b/scripts/experiment_paras.sh index efc20ccd5..71c8fbc82 100755 --- a/scripts/experiment_paras.sh +++ b/scripts/experiment_paras.sh @@ -18,7 +18,7 @@ filename=data/plan/${experiment}.json echo $filename # creates buckets + bucket data and sets env variables -python scripts/setup_bucket.py --key-prefix ${key_prefix} --bucket-prefix ${bucket_prefix} --src-data-path ../${key_prefix}/ --src-region ${src} --dest-region ${dest} +# python scripts/setup_bucket.py --key-prefix ${key_prefix} --bucket-prefix ${bucket_prefix} --src-data-path ../${key_prefix}/ --src-region ${src} --dest-region ${dest} # TODO: artificially increase the number of chunks # TODO: try synthetic data diff --git a/skylark/compute/aws/aws_auth.py b/skylark/compute/aws/aws_auth.py index 037d3f139..614932a91 100644 --- a/skylark/compute/aws/aws_auth.py +++ b/skylark/compute/aws/aws_auth.py @@ -11,33 +11,39 @@ def __init__(self, access_key: Optional[str] = None, secret_key: Optional[str] = """Loads AWS authentication details. If no access key is provided, it will try to load credentials using boto3""" if access_key and secret_key: self.config_mode = "manual" - self.access_key = access_key - self.secret_key = secret_key + self._access_key = access_key + self._secret_key = secret_key else: - infer_access_key, infer_secret_key = self.infer_credentials() - if infer_access_key and infer_secret_key: - self.config_mode = "iam_inferred" - self.access_key = infer_access_key - self.secret_key = infer_secret_key - else: - self.config_mode = "disabled" - self.access_key = None - self.secret_key = None + self.config_mode = "iam_inferred" + self._access_key = None + self._secret_key = None + + @property + def access_key(self): + if self._access_key is None: + self._access_key, self._secret_key = self.infer_credentials() + return self._access_key + + @property + def secret_key(self): + if self._secret_key is None: + self._access_key, self._secret_key = self.infer_credentials() + return self._secret_key def enabled(self): return self.config_mode != "disabled" def infer_credentials(self): # todo load temporary credentials from STS - cached_credential = getattr(self.__cached_credentials, "boto3_credential", (None, None)) - if cached_credential == (None, None): + cached_credential = getattr(self.__cached_credentials, "boto3_credential", None) + if cached_credential is None: session = boto3.Session() credentials = session.get_credentials() if credentials: credentials = credentials.get_frozen_credentials() cached_credential = (credentials.access_key, credentials.secret_key) - setattr(self.__cached_credentials, "boto3_credential", cached_credential) - return cached_credential + setattr(self.__cached_credentials, "boto3_credential", cached_credential) + return cached_credential if cached_credential else (None, None) def get_boto3_session(self, aws_region: str): if self.config_mode == "manual": diff --git a/skylark/compute/gcp/gcp_auth.py b/skylark/compute/gcp/gcp_auth.py index face08a36..a4dd7491a 100644 --- a/skylark/compute/gcp/gcp_auth.py +++ b/skylark/compute/gcp/gcp_auth.py @@ -10,7 +10,22 @@ class GCPAuthentication: __cached_credentials = threading.local() def __init__(self, project_id: Optional[str] = cloud_config.gcp_project_id): - self.credentials, self.project_id = 
self.make_credential(project_id) + # load credentials lazily and then cache across threads + self.inferred_project_id = project_id + self._credentials = None + self._project_id = None + + @property + def credentials(self): + if self._credentials is None: + self._credentials, self._project_id = self.make_credential(self.inferred_project_id) + return self._credentials + + @property + def project_id(self): + if self._project_id is None: + self._credentials, self._project_id = self.make_credential(self.inferred_project_id) + return self._project_id def make_credential(self, project_id): cached_credential = getattr(self.__cached_credentials, f"credential_{project_id}", (None, None)) if cached_credential == (None, None): diff --git a/skylark/obj_store/azure_interface.py b/skylark/obj_store/azure_interface.py index 35f6be641..9782a71d9 100644 --- a/skylark/obj_store/azure_interface.py +++ b/skylark/obj_store/azure_interface.py @@ -62,8 +62,9 @@ def create_container(self): except ResourceExistsError: logger.warning("Unable to create container as it already exists") - def create_bucket(self): - return self.create_storage_account() and self.create_container() + def create_bucket(self, premium_tier=True): + tier = "Premium_LRS" if premium_tier else "Standard_LRS" + return self.create_storage_account(tier=tier) and self.create_container() def delete_container(self): try: diff --git a/skylark/obj_store/gcs_interface.py b/skylark/obj_store/gcs_interface.py index a10e28fe5..b40ce5220 100644 --- a/skylark/obj_store/gcs_interface.py +++ b/skylark/obj_store/gcs_interface.py @@ -47,10 +47,10 @@ def bucket_exists(self): except Exception: return False - def create_bucket(self, storage_class: str = "STANDARD"): + def create_bucket(self, premium_tier=True): if not self.bucket_exists(): bucket = self._gcs_client.bucket(self.bucket_name) - bucket.storage_class = storage_class + bucket.storage_class = "STANDARD" new_bucket = self._gcs_client.create_bucket(bucket, location=self.gcp_region) assert self.bucket_exists() diff --git a/skylark/obj_store/s3_interface.py b/skylark/obj_store/s3_interface.py index f246e13e8..c5aa0e069 100644 --- a/skylark/obj_store/s3_interface.py +++ b/skylark/obj_store/s3_interface.py @@ -56,7 +56,7 @@ def bucket_exists(self): s3_client = self.auth.get_boto3_client("s3", self.aws_region) return self.bucket_name in [b["Name"] for b in s3_client.list_buckets()["Buckets"]] - def create_bucket(self): + def create_bucket(self, premium_tier=True): s3_client = self.auth.get_boto3_client("s3", self.aws_region) if not self.bucket_exists(): if self.aws_region == "us-east-1": From 6eeedcc7cb101d29f39a5b3f466117528924cff1 Mon Sep 17 00:00:00 2001 From: Paras Jain Date: Wed, 16 Mar 2022 04:09:05 +0000 Subject: [PATCH 24/34] Make container --- scripts/experiment_paras.sh | 2 +- skylark/obj_store/azure_interface.py | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/scripts/experiment_paras.sh b/scripts/experiment_paras.sh index 71c8fbc82..efc20ccd5 100755 --- a/scripts/experiment_paras.sh +++ b/scripts/experiment_paras.sh @@ -18,7 +18,7 @@ filename=data/plan/${experiment}.json echo $filename # creates buckets + bucket data and sets env variables -# python scripts/setup_bucket.py --key-prefix ${key_prefix} --bucket-prefix ${bucket_prefix} --src-data-path ../${key_prefix}/ --src-region ${src} --dest-region ${dest} +python scripts/setup_bucket.py --key-prefix ${key_prefix} --bucket-prefix ${bucket_prefix} --src-data-path ../${key_prefix}/ --src-region ${src} --dest-region ${dest} # TODO: artificially increase the number of 
chunks # TODO: try synthetic data diff --git a/skylark/obj_store/azure_interface.py b/skylark/obj_store/azure_interface.py index 9782a71d9..f839145b2 100644 --- a/skylark/obj_store/azure_interface.py +++ b/skylark/obj_store/azure_interface.py @@ -40,8 +40,8 @@ def storage_account_exists(self): def container_exists(self): try: - for _ in self.container_client.list_blobs(): - return True + self.container_client.get_container_properties() + return True except ResourceNotFoundError: return False @@ -64,7 +64,10 @@ def create_container(self): def create_bucket(self, premium_tier=True): tier = "Premium_LRS" if premium_tier else "Standard_LRS" - return self.create_storage_account(tier=tier) and self.create_container() + if not self.storage_account_exists(): + self.create_storage_account(tier=tier) + if not self.container_exists(): + self.create_container() def delete_container(self): try: From 7384d20caf04e127612c3a11013ec986ec32ae56 Mon Sep 17 00:00:00 2001 From: Paras Jain Date: Wed, 16 Mar 2022 21:42:16 +0000 Subject: [PATCH 25/34] Autoflake --- skylark/__init__.py | 1 - skylark/cli/cli_azure.py | 3 --- skylark/cli/cli_gcp.py | 1 - skylark/cli/cli_helper.py | 3 +-- skylark/compute/aws/aws_auth.py | 2 +- skylark/compute/aws/aws_server.py | 1 - skylark/compute/azure/azure_cloud_provider.py | 1 - skylark/compute/azure/azure_server.py | 1 - skylark/compute/gcp/gcp_auth.py | 2 +- skylark/compute/gcp/gcp_cloud_provider.py | 1 - skylark/compute/gcp/gcp_server.py | 1 - skylark/compute/server.py | 2 -- skylark/gateway/chunk_store.py | 2 +- skylark/gateway/gateway_daemon_api.py | 3 --- skylark/obj_store/azure_interface.py | 2 +- skylark/replicate/replicator_client.py | 2 -- 16 files changed, 5 insertions(+), 23 deletions(-) diff --git a/skylark/__init__.py b/skylark/__init__.py index 874d670f0..d80cae8b2 100644 --- a/skylark/__init__.py +++ b/skylark/__init__.py @@ -1,7 +1,6 @@ import os from pathlib import Path -from skylark.compute.utils import query_which_cloud from skylark.config import SkylarkConfig diff --git a/skylark/cli/cli_azure.py b/skylark/cli/cli_azure.py index 0193a6cee..97ea6130f 100644 --- a/skylark/cli/cli_azure.py +++ b/skylark/cli/cli_azure.py @@ -7,10 +7,7 @@ from typing import List import typer -from azure.identity import DefaultAzureCredential -from azure.mgmt.compute import ComputeManagementClient from skylark.compute.azure.azure_auth import AzureAuthentication -from skylark.config import SkylarkConfig from skylark.compute.azure.azure_cloud_provider import AzureCloudProvider from skylark.utils.utils import do_parallel diff --git a/skylark/cli/cli_gcp.py b/skylark/cli/cli_gcp.py index 62d98c88c..78c9ec2a0 100644 --- a/skylark/cli/cli_gcp.py +++ b/skylark/cli/cli_gcp.py @@ -5,7 +5,6 @@ import questionary import typer -from skylark.utils import logger from skylark.compute.gcp.gcp_cloud_provider import GCPCloudProvider from skylark.compute.gcp.gcp_server import GCPServer diff --git a/skylark/cli/cli_helper.py b/skylark/cli/cli_helper.py index 6e3f9391f..9c83050da 100644 --- a/skylark/cli/cli_helper.py +++ b/skylark/cli/cli_helper.py @@ -6,8 +6,7 @@ import subprocess from pathlib import Path from shutil import copyfile -import sys -from typing import Dict, List, Optional +from typing import Dict, List import boto3 import typer diff --git a/skylark/compute/aws/aws_auth.py b/skylark/compute/aws/aws_auth.py index 614932a91..4010e9daa 100644 --- a/skylark/compute/aws/aws_auth.py +++ b/skylark/compute/aws/aws_auth.py @@ -23,7 +23,7 @@ def access_key(self): if self._access_key is None: 
self._access_key, self._secret_key = self.infer_credentials() return self._access_key - + @property def secret_key(self): if self._secret_key is None: diff --git a/skylark/compute/aws/aws_server.py b/skylark/compute/aws/aws_server.py index fc4e38f9c..c80eaaf84 100644 --- a/skylark/compute/aws/aws_server.py +++ b/skylark/compute/aws/aws_server.py @@ -1,4 +1,3 @@ -import subprocess import os from pathlib import Path from typing import Dict, Optional diff --git a/skylark/compute/azure/azure_cloud_provider.py b/skylark/compute/azure/azure_cloud_provider.py index 83557b07e..740edb884 100644 --- a/skylark/compute/azure/azure_cloud_provider.py +++ b/skylark/compute/azure/azure_cloud_provider.py @@ -6,7 +6,6 @@ import paramiko from skylark.compute.azure.azure_auth import AzureAuthentication -from skylark.config import SkylarkConfig from skylark.utils import logger from skylark import key_root from skylark.compute.azure.azure_server import AzureServer diff --git a/skylark/compute/azure/azure_server.py b/skylark/compute/azure/azure_server.py index c29b8d62e..0bd7384e8 100644 --- a/skylark/compute/azure/azure_server.py +++ b/skylark/compute/azure/azure_server.py @@ -1,6 +1,5 @@ import os from pathlib import Path -from typing import Optional import paramiko from skylark import key_root diff --git a/skylark/compute/gcp/gcp_auth.py b/skylark/compute/gcp/gcp_auth.py index a4dd7491a..641f90a91 100644 --- a/skylark/compute/gcp/gcp_auth.py +++ b/skylark/compute/gcp/gcp_auth.py @@ -20,7 +20,7 @@ def credentials(self): if self._credentials is None: self._credentials, self._project_id = self.make_credential(self.inferred_project_id) return self._credentials - + @property def project_id(self): if self._project_id is None: diff --git a/skylark/compute/gcp/gcp_cloud_provider.py b/skylark/compute/gcp/gcp_cloud_provider.py index e5d9ca977..4aad3804e 100644 --- a/skylark/compute/gcp/gcp_cloud_provider.py +++ b/skylark/compute/gcp/gcp_cloud_provider.py @@ -7,7 +7,6 @@ import googleapiclient import paramiko from skylark.compute.gcp.gcp_auth import GCPAuthentication -from skylark.config import SkylarkConfig from skylark.utils import logger from oslo_concurrency import lockutils diff --git a/skylark/compute/gcp/gcp_server.py b/skylark/compute/gcp/gcp_server.py index e5ad6d2b9..e90fbb782 100644 --- a/skylark/compute/gcp/gcp_server.py +++ b/skylark/compute/gcp/gcp_server.py @@ -1,4 +1,3 @@ -import os from functools import lru_cache from pathlib import Path diff --git a/skylark/compute/server.py b/skylark/compute/server.py index e53f2d8c2..decded968 100644 --- a/skylark/compute/server.py +++ b/skylark/compute/server.py @@ -1,5 +1,4 @@ import json -import os import subprocess from enum import Enum, auto from pathlib import Path @@ -10,7 +9,6 @@ from skylark.utils.utils import PathLike, Timer, retry_backoff, wait_for from skylark import config_path -import os class ServerState(Enum): diff --git a/skylark/gateway/chunk_store.py b/skylark/gateway/chunk_store.py index ce67d106a..e212b7dad 100644 --- a/skylark/gateway/chunk_store.py +++ b/skylark/gateway/chunk_store.py @@ -1,5 +1,5 @@ from datetime import datetime -from multiprocessing import Lock, Manager, Queue +from multiprocessing import Manager, Queue from os import PathLike from pathlib import Path from queue import Empty diff --git a/skylark/gateway/gateway_daemon_api.py b/skylark/gateway/gateway_daemon_api.py index 7d7bf8571..41d0bb15f 100644 --- a/skylark/gateway/gateway_daemon_api.py +++ b/skylark/gateway/gateway_daemon_api.py @@ -1,13 +1,10 @@ import logging import 
logging.handlers -from multiprocessing import Process import os -import signal import threading from typing import Dict, List from flask import Flask, jsonify, request -import setproctitle from skylark.utils import logger from skylark.chunk import ChunkRequest, ChunkState from skylark.gateway.chunk_store import ChunkStore diff --git a/skylark/obj_store/azure_interface.py b/skylark/obj_store/azure_interface.py index f839145b2..4ace201d4 100644 --- a/skylark/obj_store/azure_interface.py +++ b/skylark/obj_store/azure_interface.py @@ -1,7 +1,7 @@ import os from concurrent.futures import Future, ThreadPoolExecutor from typing import Iterator, List -from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError, ServiceRequestError +from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError from skylark.compute.azure.azure_auth import AzureAuthentication from skylark.compute.azure.azure_server import AzureServer from skylark.utils import logger diff --git a/skylark/replicate/replicator_client.py b/skylark/replicate/replicator_client.py index c082dcf5d..06159377f 100644 --- a/skylark/replicate/replicator_client.py +++ b/skylark/replicate/replicator_client.py @@ -9,8 +9,6 @@ import uuid import requests -from skylark.compute.aws.aws_auth import AWSAuthentication -from skylark.config import SkylarkConfig from skylark.replicate.profiler import status_df_to_traceevent from skylark.utils import logger from tqdm import tqdm From a226c56180abb36b0b58ac8dbb1bbbbe1a47d591 Mon Sep 17 00:00:00 2001 From: Paras Jain Date: Wed, 16 Mar 2022 21:42:48 +0000 Subject: [PATCH 26/34] Autoflake --- skylark/cli/cli_gcp.py | 2 -- skylark/compute/gcp/gcp_server.py | 1 - skylark/compute/server.py | 3 --- skylark/gateway/chunk_store.py | 2 +- skylark/gateway/gateway_daemon_api.py | 3 --- skylark/obj_store/azure_interface.py | 2 +- 6 files changed, 2 insertions(+), 11 deletions(-) diff --git a/skylark/cli/cli_gcp.py b/skylark/cli/cli_gcp.py index 5c3f3b692..56faeb1cf 100644 --- a/skylark/cli/cli_gcp.py +++ b/skylark/cli/cli_gcp.py @@ -1,4 +1,3 @@ -import os import subprocess from shlex import split from typing import Optional @@ -7,7 +6,6 @@ import typer from skylark.config import load_config -from skylark.utils import logger from skylark.compute.gcp.gcp_cloud_provider import GCPCloudProvider from skylark.compute.gcp.gcp_server import GCPServer diff --git a/skylark/compute/gcp/gcp_server.py b/skylark/compute/gcp/gcp_server.py index f1b2678fa..c3d919742 100644 --- a/skylark/compute/gcp/gcp_server.py +++ b/skylark/compute/gcp/gcp_server.py @@ -1,4 +1,3 @@ -import os from functools import lru_cache from pathlib import Path diff --git a/skylark/compute/server.py b/skylark/compute/server.py index a63b5e715..a330eca91 100644 --- a/skylark/compute/server.py +++ b/skylark/compute/server.py @@ -1,5 +1,4 @@ import json -import os import subprocess from enum import Enum, auto from pathlib import Path @@ -10,8 +9,6 @@ from skylark.compute.utils import make_dozzle_command, make_sysctl_tcp_tuning_command from skylark.utils.utils import PathLike, Timer, retry_backoff, wait_for -import configparser -import os from skylark import config_file diff --git a/skylark/gateway/chunk_store.py b/skylark/gateway/chunk_store.py index fda4ced81..677c4f1d5 100644 --- a/skylark/gateway/chunk_store.py +++ b/skylark/gateway/chunk_store.py @@ -1,5 +1,5 @@ from datetime import datetime -from multiprocessing import Lock, Manager, Queue +from multiprocessing import Manager, Queue from os import PathLike from pathlib import Path 
from queue import Empty diff --git a/skylark/gateway/gateway_daemon_api.py b/skylark/gateway/gateway_daemon_api.py index 7d7bf8571..41d0bb15f 100644 --- a/skylark/gateway/gateway_daemon_api.py +++ b/skylark/gateway/gateway_daemon_api.py @@ -1,13 +1,10 @@ import logging import logging.handlers -from multiprocessing import Process import os -import signal import threading from typing import Dict, List from flask import Flask, jsonify, request -import setproctitle from skylark.utils import logger from skylark.chunk import ChunkRequest, ChunkState from skylark.gateway.chunk_store import ChunkStore diff --git a/skylark/obj_store/azure_interface.py b/skylark/obj_store/azure_interface.py index f578b2b03..3948762ae 100644 --- a/skylark/obj_store/azure_interface.py +++ b/skylark/obj_store/azure_interface.py @@ -2,7 +2,7 @@ from concurrent.futures import Future, ThreadPoolExecutor from typing import Iterator, List from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError -from azure.identity import DefaultAzureCredential, ClientSecretCredential +from azure.identity import ClientSecretCredential from azure.storage.blob import BlobServiceClient from skylark.config import load_config from skylark.utils import logger From d7e2df8c736b7fbf8c63be28045f4b83c59ab55f Mon Sep 17 00:00:00 2001 From: Paras Jain Date: Wed, 16 Mar 2022 21:46:07 +0000 Subject: [PATCH 27/34] Remove unused --- skylark/benchmark/network/latency.py | 92 --------- skylark/benchmark/profile_solver.py | 0 .../replicate/benchmark_triangles.py | 102 ---------- skylark/benchmark/replicate/test_direct.py | 136 -------------- skylark/benchmark/stop_all_instances.py | 43 ----- .../{network => traceroute}/traceroute.py | 0 .../tracerouteparser.py | 0 skylark/control_plane/README.md | 12 -- skylark/control_plane/config.yaml | 175 ------------------ skylark/control_plane/script.py | 31 ---- 10 files changed, 591 deletions(-) delete mode 100644 skylark/benchmark/network/latency.py delete mode 100644 skylark/benchmark/profile_solver.py delete mode 100644 skylark/benchmark/replicate/benchmark_triangles.py delete mode 100644 skylark/benchmark/replicate/test_direct.py delete mode 100644 skylark/benchmark/stop_all_instances.py rename skylark/benchmark/{network => traceroute}/traceroute.py (100%) rename skylark/benchmark/{network => traceroute}/tracerouteparser.py (100%) delete mode 100644 skylark/control_plane/README.md delete mode 100644 skylark/control_plane/config.yaml delete mode 100644 skylark/control_plane/script.py diff --git a/skylark/benchmark/network/latency.py b/skylark/benchmark/network/latency.py deleted file mode 100644 index 21bc8d157..000000000 --- a/skylark/benchmark/network/latency.py +++ /dev/null @@ -1,92 +0,0 @@ -import argparse -import json -import re -from typing import List, Tuple - -from skylark.utils import logger -from tqdm import tqdm - -from skylark import skylark_root -from skylark.benchmark.utils import provision -from skylark.compute.aws.aws_cloud_provider import AWSCloudProvider -from skylark.compute.aws.aws_server import AWSServer -from skylark.compute.gcp.gcp_cloud_provider import GCPCloudProvider -from skylark.compute.gcp.gcp_server import GCPServer -from skylark.compute.server import Server -from skylark.utils.utils import do_parallel - - -def parse_args(): - aws_regions = AWSCloudProvider.region_list() - gcp_regions = GCPCloudProvider.region_list() - parser = argparse.ArgumentParser(description="Provision EC2 instances") - parser.add_argument("--aws_instance_class", type=str, 
default="i3en.large", help="Instance class") - parser.add_argument("--aws_region_list", type=str, nargs="+", default=aws_regions) - parser.add_argument("--gcp_instance_class", type=str, default="n1-highcpu-8", help="Instance class") - parser.add_argument("--use-premium-network", action="store_true", help="Use premium network") - parser.add_argument("--gcp_project", type=str, default="bair-commons-307400", help="GCP project") - parser.add_argument("--gcp_region_list", type=str, nargs="+", default=gcp_regions) - return parser.parse_args() - - -def main(args): - data_dir = skylark_root / "data" - log_dir = data_dir / "logs" - log_dir.mkdir(exist_ok=True, parents=True) - - aws = AWSCloudProvider() - gcp = GCPCloudProvider(args.gcp_project) - aws_instances: dict[str, list[AWSServer]] - gcp_instances: dict[str, list[GCPServer]] - aws_instances, gcp_instances = provision( - aws=aws, - gcp=gcp, - aws_regions_to_provision=args.aws_region_list, - gcp_regions_to_provision=args.gcp_region_list, - aws_instance_class=args.aws_instance_class, - gcp_instance_class=args.gcp_instance_class, - ) - instance_list: List[Server] = [i for ilist in aws_instances.values() for i in ilist] - instance_list.extend([i for ilist in gcp_instances.values() for i in ilist]) - - # compute pairwise latency by running ping - def compute_latency(arg_pair: Tuple[Server, Server]) -> str: - instance_src, instance_dst = arg_pair - stdout, stderr = instance_src.run_command(f"ping -c 10 {instance_dst.public_ip()}") - latency_result = stdout.strip().split("\n")[-1] - tqdm.write(f"Latency from {instance_src.region_tag} to {instance_dst.region_tag} is {latency_result}") - return latency_result - - instance_pairs = [(i1, i2) for i1 in instance_list for i2 in instance_list if i1 != i2] - latency_results = do_parallel( - compute_latency, - instance_pairs, - progress_bar=True, - n=24, - desc="Latency", - arg_fmt=lambda x: f"{x[0].region_tag} to {x[1].region_tag}", - ) - - def parse_ping_result(string): - """make regex with named groups""" - try: - regex = r"rtt min/avg/max/mdev = (?P\d+\.\d+)/(?P\d+\.\d+)/(?P\d+\.\d+)/(?P\d+\.\d+) ms" - m = re.search(regex, string) - return dict(min=float(m.group("min")), avg=float(m.group("avg")), max=float(m.group("max")), mdev=float(m.group("mdev"))) - except Exception as e: - logger.exception(e) - return {} - - # save results - latency_results_out = [] - for (i1, i2), r in latency_results: - row = dict(src=i1.region_tag, dst=i2.region_tag, ping_str=r, **parse_ping_result(r)) - logger.info(row) - latency_results_out.append(row) - - with open(str(data_dir / "latency.json"), "w") as f: - json.dump(latency_results_out, f) - - -if __name__ == "__main__": - main(parse_args()) diff --git a/skylark/benchmark/profile_solver.py b/skylark/benchmark/profile_solver.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/skylark/benchmark/replicate/benchmark_triangles.py b/skylark/benchmark/replicate/benchmark_triangles.py deleted file mode 100644 index 53c33c63c..000000000 --- a/skylark/benchmark/replicate/benchmark_triangles.py +++ /dev/null @@ -1,102 +0,0 @@ -import atexit -from datetime import datetime -import pickle -from pathlib import Path - -import typer -from skylark.utils import logger -from skylark import GB, MB, skylark_root -from skylark.replicate.replication_plan import ReplicationJob, ReplicationTopology -from skylark.replicate.replicator_client import ReplicatorClient - - -def bench_triangle( - src_region: str, - dst_region: str, - inter_region: str = None, - log_dir: Path = None, - 
num_gateways: int = 1, - num_outgoing_connections: int = 16, - chunk_size_mb: int = 8, - n_chunks: int = 2048, - gcp_project: str = "skylark-333700", - gateway_docker_image: str = "ghcr.io/parasj/skylark:main", - aws_instance_class: str = "m5.8xlarge", - gcp_instance_class: str = None, - gcp_use_premium_network: bool = False, - key_prefix: str = "/test/benchmark_triangles", -): - if log_dir is None: - log_dir = skylark_root / "data" / "experiments" / "benchmark_triangles" / datetime.now().strftime("%Y-%m-%d_%H-%M-%S") - log_dir.mkdir(exist_ok=True, parents=True) - result_dir = log_dir / "results" - result_dir.mkdir(exist_ok=True, parents=True) - - try: - if inter_region: - topo = ReplicationTopology() - for i in range(num_gateways): - topo.add_edge(src_region, i, inter_region, i, num_outgoing_connections) - topo.add_edge(inter_region, i, dst_region, i, num_outgoing_connections) - else: - topo = ReplicationTopology() - for i in range(num_gateways): - topo.add_edge(src_region, i, dst_region, i, num_outgoing_connections) - rc = ReplicatorClient( - topo, - gcp_project=gcp_project, - gateway_docker_image=gateway_docker_image, - aws_instance_class=aws_instance_class, - gcp_instance_class=gcp_instance_class, - gcp_use_premium_network=gcp_use_premium_network, - ) - - rc.provision_gateways(reuse_instances=False) - atexit.register(rc.deprovision_gateways) - for node, gw in rc.bound_nodes.items(): - logger.info(f"Provisioned {node}: {gw.gateway_log_viewer_url}") - - job = ReplicationJob( - source_region=src_region, - source_bucket=None, - dest_region=dst_region, - dest_bucket=None, - objs=[f"{key_prefix}/{i}" for i in range(n_chunks)], - random_chunk_size_mb=chunk_size_mb, - ) - - total_bytes = n_chunks * chunk_size_mb * MB - job = rc.run_replication_plan(job) - logger.info(f"{total_bytes / GB:.2f}GByte replication job launched") - stats = rc.monitor_transfer(job, show_pbar=False, time_limit_seconds=600) - stats["success"] = True - stats["log"] = rc.get_chunk_status_log_df() - rc.deprovision_gateways() - except Exception as e: - logger.error(f"Failed to benchmark triangle {src_region} -> {dst_region}") - logger.exception(e) - - stats = {} - stats["error"] = str(e) - stats["success"] = False - - stats["src_region"] = src_region - stats["dst_region"] = dst_region - stats["inter_region"] = inter_region - stats["num_gateways"] = num_gateways - stats["num_outgoing_connections"] = num_outgoing_connections - stats["chunk_size_mb"] = chunk_size_mb - stats["n_chunks"] = n_chunks - - logger.info(f"Stats:") - for k, v in stats.items(): - if k not in ["log", "completed_chunk_ids"]: - logger.info(f"\t{k}: {v}") - - arg_hash = hash((src_region, dst_region, inter_region, num_gateways, num_outgoing_connections, chunk_size_mb, n_chunks)) - with open(result_dir / f"{arg_hash}.pkl", "wb") as f: - pickle.dump(stats, f) - - -if __name__ == "__main__": - typer.run(bench_triangle) diff --git a/skylark/benchmark/replicate/test_direct.py b/skylark/benchmark/replicate/test_direct.py deleted file mode 100644 index e623a255c..000000000 --- a/skylark/benchmark/replicate/test_direct.py +++ /dev/null @@ -1,136 +0,0 @@ -import argparse -import json -import time -from datetime import datetime - -from skylark.utils import logger - -from skylark import skylark_root -from skylark.benchmark.utils import provision -from skylark.compute.aws.aws_cloud_provider import AWSCloudProvider -from skylark.compute.gcp.gcp_cloud_provider import GCPCloudProvider -from skylark.utils.utils import do_parallel - - -def parse_args(): - full_region_list = 
[] - full_region_list += [f"aws:{r}" for r in AWSCloudProvider.region_list()] - full_region_list += [f"gcp:{r}" for r in GCPCloudProvider.region_list()] - parser = argparse.ArgumentParser(description="Test throughput with Skylark Gateway") - parser.add_argument("--aws_instance_class", type=str, default="c5.4xlarge", help="Instance class") - parser.add_argument("--gcp_instance_class", type=str, default="n1-highcpu-8", help="Instance class") - parser.add_argument("--gcp_project", type=str, default="skylark-333700", help="GCP project") - parser.add_argument( - "--gcp_test_standard_network", action="store_true", help="Test GCP standard network in addition to premium (default)" - ) - parser.add_argument("--src_region", default="aws:us-east-1", choices=full_region_list, help="Source region") - parser.add_argument("--dst_region", default="aws:us-east-2", choices=full_region_list, help="Destination region") - parser.add_argument("--gateway_docker_image", type=str, default="ghcr.io/parasj/skylark:latest", help="Gateway docker image") - return parser.parse_args() - - -def setup(tup): - server, docker_image = tup - server.run_command("sudo apt-get update && sudo apt-get install -y iperf3") - docker_installed = "Docker version" in server.run_command(f"sudo docker --version")[0] - if not docker_installed: - logger.debug(f"[{server.region_tag}] Installing docker") - server.run_command("curl -fsSL https://get.docker.com -o get-docker.sh && sudo sh get-docker.sh") - out, err = server.run_command("sudo docker run --rm hello-world") - assert "Hello from Docker!" in out - server.run_command("sudo docker pull {}".format(docker_image)) - - -def parse_output(output): - stdout, stderr = output - last_line = stdout.strip().split("\n") - if len(last_line) > 0: - try: - return json.loads(last_line[-1]) - except json.decoder.JSONDecodeError: - logger.error(f"JSON parse error, stdout = '{stdout}', stderr = '{stderr}'") - else: - logger.error(f"No output from server, stderr = {stderr}") - return None - - -def main(args): - data_dir = skylark_root / "data" - log_dir = data_dir / "logs" / "gateway_test" / datetime.now().strftime("%Y-%m-%d_%H-%M-%S") - log_dir.mkdir(exist_ok=True, parents=True) - - aws = AWSCloudProvider() - gcp = GCPCloudProvider(args.gcp_project) - - # provision and setup servers - aws_regions = [r.split(":")[1] for r in [args.src_region, args.dst_region] if r.startswith("aws:")] - gcp_regions = [r.split(":")[1] for r in [args.src_region, args.dst_region] if r.startswith("gcp:")] - aws_instances, gcp_instances = provision( - aws=aws, - gcp=gcp, - aws_regions_to_provision=aws_regions, - gcp_regions_to_provision=gcp_regions, - aws_instance_class=args.aws_instance_class, - gcp_instance_class=args.gcp_instance_class, - gcp_use_premium_network=not args.gcp_test_standard_network, - log_dir=str(log_dir), - ) - - # select servers - src_cloud_region = args.src_region.split(":")[1] - if args.src_region.startswith("aws:"): - src_server = aws_instances[src_cloud_region][0] - elif args.src_region.startswith("gcp:"): - src_server = gcp_instances[src_cloud_region][0] - else: - raise ValueError(f"Unknown region {args.src_region}") - dst_cloud_region = args.dst_region.split(":")[1] - if args.dst_region.startswith("aws:"): - dst_server = aws_instances[dst_cloud_region][0] - elif args.dst_region.startswith("gcp:"): - dst_server = gcp_instances[dst_cloud_region][0] - else: - raise ValueError(f"Unknown region {args.dst_region}") - do_parallel( - setup, - [(src_server, args.gateway_docker_image), (dst_server, 
args.gateway_docker_image)], - progress_bar=True, - arg_fmt=lambda tup: tup[0].region_tag, - ) - - # generate random 1GB file on src server in /dev/shm/skylark/chunks_in - src_server.run_command("mkdir -p /dev/shm/skylark/chunks_in") - src_server.run_command("sudo dd if=/dev/urandom of=/dev/shm/skylark/chunks_in/1 bs=100M count=10 iflag=fullblock") - assert src_server.run_command("ls /dev/shm/skylark/chunks_in/1 | wc -l")[0].strip() == "1" - - # stop existing gateway containers - src_server.run_command("sudo docker kill gateway_server") - dst_server.run_command("sudo docker kill gateway_server") - - # start gateway on dst server - dst_server.run_command("dig +short myip.opendns.com @resolver1.opendns.com")[0].strip() - server_cmd = f"sudo docker run -d --rm --ipc=host --network=host --name=gateway_server {args.gateway_docker_image} /env/bin/python /pkg/skylark/replicate/gateway_server.py --port 3333 --num_connections 1" - dst_server.run_command(server_cmd) - - # wait for port to appear on dst server - while True: - if dst_server.run_command("sudo netstat -tulpn | grep 3333")[0].strip() != "": - break - time.sleep(1) - - # benchmark src to dst copy - client_cmd = f"sudo docker run --rm --ipc=host --network=host --name=gateway_client {args.gateway_docker_image} /env/bin/python /pkg/skylark/replicate/gateway_client.py --dst_host {dst_server.public_ip} --dst_port 3333 --chunk_id 1" - dst_data = parse_output(src_server.run_command(client_cmd)) - src_server.run_command("sudo docker kill gateway_server") - logger.info(f"Src to dst copy: {dst_data}") - - # run iperf server on dst - out, err = dst_server.run_command("iperf3 -s -D") - out, err = src_server.run_command("iperf3 -c {} -t 10".format(dst_server.public_ip)) - dst_server.run_command("sudo pkill iperf3") - print(out) - - -if __name__ == "__main__": - args = parse_args() - main(args) diff --git a/skylark/benchmark/stop_all_instances.py b/skylark/benchmark/stop_all_instances.py deleted file mode 100644 index 0836e3ffd..000000000 --- a/skylark/benchmark/stop_all_instances.py +++ /dev/null @@ -1,43 +0,0 @@ -import argparse - -from skylark.utils import logger -from tqdm import tqdm - -from skylark.compute.aws.aws_cloud_provider import AWSCloudProvider -from skylark.compute.azure.azure_cloud_provider import AzureCloudProvider -from skylark.compute.gcp.gcp_cloud_provider import GCPCloudProvider -from skylark.compute.server import Server -from skylark.utils.utils import do_parallel - - -def stop_instance(instance: Server): - instance.terminate_instance() - tqdm.write(f"Terminated instance {instance}") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Stop all instances") - parser.add_argument("--disable-aws", action="store_true", help="Disables AWS operations if present") - parser.add_argument("--gcp-project", type=str, help="GCP project", default=None) - parser.add_argument("--azure-subscription", type=str, help="Microsoft Azure Subscription", default=None) - args = parser.parse_args() - - instances = [] - - if not args.disable_aws: - logger.info("Getting matching AWS instances") - aws = AWSCloudProvider() - for _, instance_list in do_parallel(aws.get_matching_instances, aws.region_list(), progress_bar=True): - instances += instance_list - - if args.gcp_project: - logger.info("Getting matching GCP instances") - gcp = GCPCloudProvider(gcp_project=args.gcp_project) - instances += gcp.get_matching_instances() - - if args.azure_subscription: - logger.info("Getting matching Azure instances") - azure = 
AzureCloudProvider(azure_subscription=args.azure_subscription) - instances += azure.get_matching_instances() - - do_parallel(stop_instance, instances, progress_bar=True) diff --git a/skylark/benchmark/network/traceroute.py b/skylark/benchmark/traceroute/traceroute.py similarity index 100% rename from skylark/benchmark/network/traceroute.py rename to skylark/benchmark/traceroute/traceroute.py diff --git a/skylark/benchmark/network/tracerouteparser.py b/skylark/benchmark/traceroute/tracerouteparser.py similarity index 100% rename from skylark/benchmark/network/tracerouteparser.py rename to skylark/benchmark/traceroute/tracerouteparser.py diff --git a/skylark/control_plane/README.md b/skylark/control_plane/README.md deleted file mode 100644 index 346147cc8..000000000 --- a/skylark/control_plane/README.md +++ /dev/null @@ -1,12 +0,0 @@ -## Ray Set-up Instructions - -- Make sure you update the latest docker image in the `config.yaml` -- Make sure you have the AWS key-pair as defined in the `config.yaml` - -To launch the Ray Autoscaler cluster: `time ray up -y config.yaml` - -To attach to the head: `ray attach config.yaml` - -To run a python script `script.py`: `ray submit config.yaml script.py` - -To teardown the system: `ray down -y config.yaml` diff --git a/skylark/control_plane/config.yaml b/skylark/control_plane/config.yaml deleted file mode 100644 index 891d3ed69..000000000 --- a/skylark/control_plane/config.yaml +++ /dev/null @@ -1,175 +0,0 @@ -# An unique identifier for the head node and workers of this cluster. -cluster_name: skylark-control-plane - -# The maximum number of workers nodes to launch in addition to the head -# node. -max_workers: 2 - -# The autoscaler will scale up the cluster faster with higher upscaling speed. -# E.g., if the task requires adding more nodes then autoscaler will gradually -# scale up the cluster in chunks of upscaling_speed*currently_running_nodes. -# This number should be > 0. -upscaling_speed: 1.0 - -# This executes all commands on all nodes in the docker container, -# and opens all the necessary ports to support the Ray cluster. -# Empty string means disabled. -docker: - image: "ghcr.io/parasj/skylark:local-acecc236e40ed040d94d5463a337fd70" - container_name: "skylark_container" - pull_before_run: True - run_options: # Extra options to pass into "docker run" - - --ulimit nofile=65536:65536 - - - worker_image: "ghcr.io/parasj/skylark:local-acecc236e40ed040d94d5463a337fd70" - # worker_run_options: [] - -# If a node is idle for this many minutes, it will be removed. -idle_timeout_minutes: 5 - -# Cloud-provider specific configuration. -provider: - type: aws - region: us-west-2 - # Availability zone(s), comma-separated, that nodes may be launched in. - # Nodes are currently spread between zones by a round-robin approach, - # however this implementation detail should not be relied upon. - availability_zone: us-west-2a,us-west-2b - # Whether to allow node reuse. If set to False, nodes will be terminated - # instead of stopped. - cache_stopped_nodes: False # If not present, the default is True. - -# How Ray will authenticate with newly launched nodes. -auth: - ssh_user: ubuntu - ssh_private_key: /home/ubuntu/.skylark-us-west-2.pem -# By default Ray creates a new private keypair, but you can also use your own. -# If you do so, make sure to also set "KeyName" in the head and worker node -# configurations below. - -# Tell the autoscaler the allowed node types and the resources they provide. 
-# The key is the name of the node type, which is just for debugging purposes. -# The node config specifies the launch config and physical instance type. -available_node_types: - ray.head.default: - # The node type's CPU and GPU resources are auto-detected based on AWS instance type. - # If desired, you can override the autodetected CPU and GPU resources advertised to the autoscaler. - # You can also set custom resources. - # For example, to mark a node type as having 1 CPU, 1 GPU, and 5 units of a resource called "custom", set - # resources: {"CPU": 1, "GPU": 1, "custom": 5} - resources: {} - # Provider-specific config for this node type, e.g. instance type. By default - # Ray will auto-configure unspecified fields such as SubnetId and KeyName. - # For more documentation on available fields, see: - # http://boto3.readthedocs.io/en/latest/reference/services/ec2.html#EC2.ServiceResource.create_instances - node_config: - InstanceType: c5.24xlarge - ImageId: ami-074251216af698218 # Ubuntu 18.04 no Deep Learning AMI since no GPU - KeyName: skylark-us-west-2 - # You can provision additional disk space with a conf as follows - BlockDeviceMappings: - - DeviceName: /dev/sda1 - Ebs: - VolumeSize: 300 - # Additional options in the boto docs. - ray.worker.default: - # The minimum number of worker nodes of this type to launch. - # This number should be >= 0. - min_workers: 0 - # The maximum number of worker nodes of this type to launch. - # This takes precedence over min_workers. - max_workers: 2 - # The node type's CPU and GPU resources are auto-detected based on AWS instance type. - # If desired, you can override the autodetected CPU and GPU resources advertised to the autoscaler. - # You can also set custom resources. - # For example, to mark a node type as having 1 CPU, 1 GPU, and 5 units of a resource called "custom", set - # resources: {"CPU": 1, "GPU": 1, "custom": 5} - resources: {} - # Provider-specific config for this node type, e.g. instance type. By default - # Ray will auto-configure unspecified fields such as SubnetId and KeyName. - # For more documentation on available fields, see: - # http://boto3.readthedocs.io/en/latest/reference/services/ec2.html#EC2.ServiceResource.create_instances - node_config: - InstanceType: c5.24xlarge - ImageId: ami-074251216af698218 # ubuntu 18.04. No Deep Learning AMI since no GPU - KeyName: skylark-us-west-2 - BlockDeviceMappings: - - DeviceName: /dev/sda1 - Ebs: - VolumeSize: 300 - # Run workers on spot by default. Comment this out to use on-demand. - # NOTE: If relying on spot instances, it is best to specify multiple different instance - # types to avoid interruption when one instance type is experiencing heightened demand. - # Demand information can be found at https://aws.amazon.com/ec2/spot/instance-advisor/ - # InstanceMarketOptions: - # MarketType: spot - # Additional options can be found in the boto docs, e.g. - # SpotOptions: - # MaxPrice: MAX_HOURLY_PRICE - # Additional options in the boto docs. - -# Specify the node type of the head node (as configured above). -head_node_type: ray.head.default - -# Files or directories to copy to the head and worker nodes. The format is a -# dictionary from REMOTE_PATH: LOCAL_PATH, e.g. -file_mounts: { -# "/path1/on/remote/machine": "/path1/on/local/machine", -# "/path2/on/remote/machine": "/path2/on/local/machine", -} - -# Files or directories to copy from the head node to the worker nodes. The format is a -# list of paths. The same path on the head node will be copied to the worker node. 
-# This behavior is a subset of the file_mounts behavior. In the vast majority of cases -# you should just use file_mounts. Only use this if you know what you're doing! -cluster_synced_files: [] - -# Whether changes to directories in file_mounts or cluster_synced_files in the head node -# should sync to the worker node continuously -file_mounts_sync_continuously: False - -# Patterns for files to exclude when running rsync up or rsync down -rsync_exclude: - - "**/.git" - - "**/.git/**" - -# Pattern files to use for filtering out files when running rsync up or rsync down. The file is searched for -# in the source directory and recursively through all subdirectories. For example, if .gitignore is provided -# as a value, the behavior will match git's behavior for finding and using .gitignore files. -rsync_filter: - - ".gitignore" - -# List of commands that will be run before `setup_commands`. If docker is -# enabled, these commands will run outside the container and before docker -# is setup. -initialization_commands: [] - -# List of shell commands to run to set up nodes. -setup_commands: [] - # Note: if you're developing Ray, you probably want to create a Docker image that - # has your Ray repo pre-cloned. Then, you can replace the pip installs - # below with a git checkout (and possibly a recompile). - # To run the nightly version of ray (as opposed to the latest), either use a rayproject docker image - # that has the "nightly" (e.g. "rayproject/ray-ml:nightly-gpu") or uncomment the following line: - # - pip install -U "ray[default] @ https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-2.0.0.dev0-cp37-cp37m-manylinux2014_x86_64.whl" - -# Custom commands that will be run on the head node after common setup. -head_setup_commands: [] - -# Custom commands that will be run on worker nodes after common setup. -worker_setup_commands: [] - -# Command to start ray on the head node. You don't need to change this. -head_start_ray_commands: - - ray stop - - ray start --head --port=6379 --object-manager-port=8076 --autoscaling-config=~/ray_bootstrap_config.yaml - -# Command to start ray on worker nodes. You don't need to change this. -worker_start_ray_commands: - - ray stop - - ray start --address=$RAY_HEAD_IP:6379 --object-manager-port=8076 - -head_node: {} -worker_nodes: {} - diff --git a/skylark/control_plane/script.py b/skylark/control_plane/script.py deleted file mode 100644 index e7802228b..000000000 --- a/skylark/control_plane/script.py +++ /dev/null @@ -1,31 +0,0 @@ -from collections import Counter -import socket -import time - -import ray - -ray.init(address="auto") - -print( - """This cluster consists of - {} nodes in total - {} CPU resources in total -""".format( - len(ray.nodes()), ray.cluster_resources()["CPU"] - ) -) - - -@ray.remote -def f(): - time.sleep(0.005) - # Return IP address. 
- return socket.gethostbyname(socket.gethostname()) - - -object_ids = [f.remote() for _ in range(50000)] -ip_addresses = ray.get(object_ids) - -print("Tasks executed") -for ip_address, num_tasks in Counter(ip_addresses).items(): - print(" {} tasks on {}".format(num_tasks, ip_address)) From 90b61c399545eee98f50e5b24281bfdb9b7efbf9 Mon Sep 17 00:00:00 2001 From: Paras Jain Date: Wed, 16 Mar 2022 21:54:19 +0000 Subject: [PATCH 28/34] Dead code elimination --- skylark/chunk.py | 11 -------- skylark/compute/azure/azure_cloud_provider.py | 4 --- skylark/compute/azure/azure_server.py | 12 --------- skylark/gateway/gateway_sender.py | 1 - skylark/obj_store/azure_interface.py | 10 ------- skylark/obj_store/gcs_interface.py | 16 +---------- skylark/obj_store/s3_interface.py | 27 ++++--------------- 7 files changed, 6 insertions(+), 75 deletions(-) diff --git a/skylark/chunk.py b/skylark/chunk.py index 0ed899f0e..bb0d51bf0 100644 --- a/skylark/chunk.py +++ b/skylark/chunk.py @@ -73,17 +73,6 @@ class ChunkState(Enum): def from_str(s: str): return ChunkState[s.lower()] - def to_short_str(self): - return { - ChunkState.registered: "REG", - ChunkState.download_in_progress: "DL", - ChunkState.downloaded: "DL_DONE", - ChunkState.upload_queued: "UL_QUE", - ChunkState.upload_in_progress: "UL", - ChunkState.upload_complete: "UL_DONE", - ChunkState.failed: "FAILED", - } - def __lt__(self, other): return self.value < other.value diff --git a/skylark/compute/azure/azure_cloud_provider.py b/skylark/compute/azure/azure_cloud_provider.py index f650fec4c..73988ae96 100644 --- a/skylark/compute/azure/azure_cloud_provider.py +++ b/skylark/compute/azure/azure_cloud_provider.py @@ -202,10 +202,6 @@ def lookup_valid_instance(region: str, instance_name: str) -> Optional[str]: logger.error(f"Cannot confirm availability of {instance_name} in {region}") return instance_name - @staticmethod - def get_resource_group_name(name): - return name - @staticmethod def get_transfer_cost(src_key, dst_key, premium_tier=True): """Assumes <10TB transfer tier.""" diff --git a/skylark/compute/azure/azure_server.py b/skylark/compute/azure/azure_server.py index 716dd91fc..4b7ae6f37 100644 --- a/skylark/compute/azure/azure_server.py +++ b/skylark/compute/azure/azure_server.py @@ -104,18 +104,6 @@ def ip_name(name): def nic_name(name): return AzureServer.vm_name(name) + "-nic" - def get_resource_group(self): - credential = self.credential - resource_client = ResourceManagementClient(credential, self.subscription_id) - rg = resource_client.resource_groups.get(AzureServer.resource_group_name) - - # Sanity checks - assert rg.name == AzureServer.resource_group_name - assert rg.location == AzureServer.resource_group_location - assert rg.tags.get("skylark", None) == "true" - - return rg - def get_virtual_machine(self): credential = self.credential compute_client = ComputeManagementClient(credential, self.subscription_id) diff --git a/skylark/gateway/gateway_sender.py b/skylark/gateway/gateway_sender.py index a0a0feb82..fcdf8bb19 100644 --- a/skylark/gateway/gateway_sender.py +++ b/skylark/gateway/gateway_sender.py @@ -21,7 +21,6 @@ def __init__(self, chunk_store: ChunkStore, outgoing_ports: Dict[str, int]): # shared state self.manager = Manager() - self.next_worker_id = Value("i", 0) self.worker_queue: queue.Queue[int] = self.manager.Queue() self.exit_flags = [Event() for _ in range(self.n_processes)] diff --git a/skylark/obj_store/azure_interface.py b/skylark/obj_store/azure_interface.py index 3948762ae..157bb6f7a 100644 --- 
a/skylark/obj_store/azure_interface.py +++ b/skylark/obj_store/azure_interface.py @@ -20,8 +20,6 @@ def __init__(self, azure_region, container_name): self.azure_region = azure_region self.container_name = container_name self.bucket_name = self.container_name # For compatibility - self.pending_downloads, self.completed_downloads = 0, 0 - self.pending_uploads, self.completed_uploads = 0, 0 # Authenticate config = load_config() self.subscription_id = config["azure_subscription_id"] @@ -41,14 +39,6 @@ def __init__(self, azure_region, container_name): self.create_container() logger.info(f"==> Creating Azure container {self.container_name}") - def _on_done_download(self, **kwargs): - self.completed_downloads += 1 - self.pending_downloads -= 1 - - def _on_done_upload(self, **kwargs): - self.completed_uploads += 1 - self.pending_uploads -= 1 - def container_exists(self): # More like "is container empty?" # Get a client to interact with a specific container - though it may not yet exist if self.container_client is None: diff --git a/skylark/obj_store/gcs_interface.py b/skylark/obj_store/gcs_interface.py index 7316e3176..d95c5ee83 100644 --- a/skylark/obj_store/gcs_interface.py +++ b/skylark/obj_store/gcs_interface.py @@ -17,10 +17,7 @@ class GCSInterface(ObjectStoreInterface): def __init__(self, gcp_region, bucket_name): # TODO: infer region? self.gcp_region = gcp_region - self.bucket_name = bucket_name - self.pending_downloads, self.completed_downloads = 0, 0 - self.pending_uploads, self.completed_uploads = 0, 0 # TODO - figure out how paralllelism handled self._gcs_client = storage.Client() @@ -28,17 +25,6 @@ def __init__(self, gcp_region, bucket_name): # TODO: set number of threads self.pool = ThreadPoolExecutor(max_workers=4) - def _on_done_download(self, **kwargs): - self.completed_downloads += 1 - self.pending_downloads -= 1 - - def _on_done_upload(self, **kwargs): - self.completed_uploads += 1 - self.pending_uploads -= 1 - - def infer_gcs_region(self, bucket_name: str): - raise NotImplementedError() - def bucket_exists(self): try: self._gcs_client.get_bucket(self.bucket_name) @@ -50,7 +36,7 @@ def create_bucket(self, storage_class: str = "STANDARD"): if not self.bucket_exists(): bucket = self._gcs_client.bucket(self.bucket_name) bucket.storage_class = storage_class - new_bucket = self._gcs_client.create_bucket(bucket, location=self.gcp_region) + self._gcs_client.create_bucket(bucket, location=self.gcp_region) assert self.bucket_exists() def list_objects(self, prefix="") -> Iterator[GCSObject]: diff --git a/skylark/obj_store/s3_interface.py b/skylark/obj_store/s3_interface.py index 45a2c8a1f..beee995c2 100644 --- a/skylark/obj_store/s3_interface.py +++ b/skylark/obj_store/s3_interface.py @@ -20,17 +20,10 @@ def full_path(self): class S3Interface(ObjectStoreInterface): - def __init__(self, aws_region, bucket_name, use_tls=True, part_size=None, throughput_target_gbps=None): - + def __init__(self, aws_region, bucket_name, use_tls=True, part_size=None, throughput_target_gbps=100): self.aws_region = self.infer_s3_region(bucket_name) if aws_region is None or aws_region == "infer" else aws_region self.bucket_name = bucket_name - self.pending_downloads, self.completed_downloads = 0, 0 - self.pending_uploads, self.completed_uploads = 0, 0 - self.s3_part_size = part_size - self.s3_throughput_target_gbps = throughput_target_gbps - # num_threads=os.cpu_count() - # num_threads=256 - num_threads = 4 # 256 + num_threads = 4 event_loop_group = EventLoopGroup(num_threads=num_threads, cpu_group=None) 
host_resolver = DefaultHostResolver(event_loop_group) bootstrap = ClientBootstrap(event_loop_group, host_resolver) @@ -47,19 +40,11 @@ def __init__(self, aws_region, bucket_name, use_tls=True, part_size=None, throug bootstrap=bootstrap, region=self.aws_region, credential_provider=credential_provider, - throughput_target_gbps=100, - part_size=None, + throughput_target_gbps=throughput_target_gbps, + part_size=part_size, tls_mode=S3RequestTlsMode.ENABLED if use_tls else S3RequestTlsMode.DISABLED, ) - def _on_done_download(self, **kwargs): - self.completed_downloads += 1 - self.pending_downloads -= 1 - - def _on_done_upload(self, **kwargs): - self.completed_uploads += 1 - self.pending_uploads -= 1 - @staticmethod def infer_s3_region(bucket_name: str): s3_client = AWSServer.get_boto3_client("s3") @@ -129,8 +114,6 @@ def _on_body_download(offset, chunk, **kwargs): recv_filepath=dst_file_path, request=request, type=S3RequestType.GET_OBJECT, - on_done=self._on_done_download, - on_body=_on_body_download, ).finished_future def upload_object(self, src_file_path, dst_object_name, content_type="infer") -> Future: @@ -145,5 +128,5 @@ def upload_object(self, src_file_path, dst_object_name, content_type="infer") -> upload_headers.add("Content-Length", str(content_len)) request = HttpRequest("PUT", dst_object_name, upload_headers) return self._s3_client.make_request( - send_filepath=src_file_path, request=request, type=S3RequestType.PUT_OBJECT, on_done=self._on_done_upload + send_filepath=src_file_path, request=request, type=S3RequestType.PUT_OBJECT ).finished_future From 688b21b26b13ce2b1d8e4caf81d09397f18b21b5 Mon Sep 17 00:00:00 2001 From: Paras Jain Date: Wed, 16 Mar 2022 21:55:29 +0000 Subject: [PATCH 29/34] Cast keyfile to string --- skylark/compute/aws/aws_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skylark/compute/aws/aws_server.py b/skylark/compute/aws/aws_server.py index 1b3e0677b..276bd5fe7 100644 --- a/skylark/compute/aws/aws_server.py +++ b/skylark/compute/aws/aws_server.py @@ -116,7 +116,7 @@ def get_ssh_client_impl(self): client.connect( self.public_ip(), username="ec2-user", - pkey=paramiko.RSAKey.from_private_key_file(self.local_keyfile), + pkey=paramiko.RSAKey.from_private_key_file(str(self.local_keyfile)), look_for_keys=False, allow_agent=False, banner_timeout=200, From 370188aa721ea5de1b5b11e7f9368bf20fa9ede5 Mon Sep 17 00:00:00 2001 From: Paras Jain Date: Wed, 16 Mar 2022 21:57:06 +0000 Subject: [PATCH 30/34] Unused tests --- skylark/test/test_gateway_daemon.py | 39 ------ skylark/test/test_replicator_client.py | 165 ------------------------- 2 files changed, 204 deletions(-) delete mode 100644 skylark/test/test_gateway_daemon.py delete mode 100644 skylark/test/test_replicator_client.py diff --git a/skylark/test/test_gateway_daemon.py b/skylark/test/test_gateway_daemon.py deleted file mode 100644 index 7c8545120..000000000 --- a/skylark/test/test_gateway_daemon.py +++ /dev/null @@ -1,39 +0,0 @@ -from pathlib import Path - -from skylark.utils import logger - -from skylark.chunk import Chunk, ChunkRequest, ChunkRequestHop, ChunkState -from skylark.gateway.gateway_daemon import GatewayDaemon -from skylark.obj_store.s3_interface import S3Interface - -if __name__ == "__main__": - daemon = GatewayDaemon("/dev/shm/skylark/chunks", debug=True) - - # make obj store interfaces - src_obj_interface = S3Interface("us-east-1", "skylark-us-east-1") - dst_obj_interface = S3Interface("us-west-1", "skylark-us-west-1") - obj = "/test.txt" - - # make random test.txt file 
and upload it if it doesn't exist - if not src_obj_interface.exists(obj): - logger.info(f"Uploading {obj} to {src_obj_interface.bucket_name}") - test_file = Path("/tmp/test.txt") - test_file.write_text("test") - src_obj_interface.upload_object(test_file, obj).result() - - # make chunk request - file_size_bytes = src_obj_interface.get_obj_size(obj) - chunk = Chunk(key=obj, chunk_id=0, file_offset_bytes=0, chunk_length_bytes=file_size_bytes) - src_path = ChunkRequestHop( - hop_cloud_region="aws:us-east-1", - hop_ip_address="localhost", - chunk_location_type="src_object_store", - # src_object_store_region="us-east-1", - # src_object_store_bucket="skylark-us-east-1", - ) - req = ChunkRequest(chunk=chunk, path=[src_path]) - logger.debug(f"Chunk request: {req}") - daemon.chunk_store.add_chunk_request(req, ChunkState.registered) - - # run gateway daemon - daemon.run() diff --git a/skylark/test/test_replicator_client.py b/skylark/test/test_replicator_client.py deleted file mode 100644 index 407e63dce..000000000 --- a/skylark/test/test_replicator_client.py +++ /dev/null @@ -1,165 +0,0 @@ -import argparse -from skylark.obj_store.object_store_interface import ObjectStoreInterface - -from skylark.utils import logger -from skylark import GB, MB, print_header - -import tempfile -import concurrent -import os - -# from skylark.obj_store.azure_interface import AzureInterface - -import tempfile -import concurrent -import os -from shutil import copyfile - -from skylark.replicate.replication_plan import ReplicationJob, ReplicationTopology -from skylark.replicate.replicator_client import ReplicatorClient - -from skylark.config import load_config - - -def parse_args(): - parser = argparse.ArgumentParser(description="Run a replication job") - - # gateway path parameters - parser.add_argument("--src-region", default="aws:us-east-1", help="AWS region of source bucket") - parser.add_argument("--inter-region", default=None, help="AWS region of intermediate bucket") - parser.add_argument("--dest-region", default="aws:us-west-1", help="AWS region of destination bucket") - parser.add_argument("--num-gateways", default=1, type=int, help="Number of gateways to use") - - # object information - parser.add_argument("--key-prefix", default="/test/direct_replication", help="S3 key prefix for all objects") - parser.add_argument("--chunk-size-mb", default=128, type=int, help="Chunk size in MB") - parser.add_argument("--n-chunks", default=512, type=int, help="Number of chunks in bucket") - parser.add_argument("--skip-upload", action="store_true", help="Skip uploading objects to S3") - - # bucket namespace - parser.add_argument("--bucket-prefix", default="sarah", help="Prefix for bucket to avoid naming collision") - - # gateway provisioning - parser.add_argument("--gcp-project", default=None, help="GCP project ID") - parser.add_argument("--azure-subscription", default=None, help="Azure subscription") - parser.add_argument("--gateway-docker-image", default="ghcr.io/parasj/skylark:main", help="Docker image for gateway instances") - parser.add_argument("--aws-instance-class", default="m5.4xlarge", help="AWS instance class") - parser.add_argument("--azure-instance-class", default="Standard_D2_v5", help="Azure instance class") - parser.add_argument("--gcp-instance-class", default="n2-standard-16", help="GCP instance class") - parser.add_argument("--copy-ssh-key", default=None, help="SSH public key to add to gateways") - parser.add_argument("--log-dir", default=None, help="Directory to write instance SSH logs to") - 
parser.add_argument("--gcp-use-premium-network", action="store_true", help="Use GCP premium network") - args = parser.parse_args() - - # add support for None arguments - if args.aws_instance_class == "None": - args.aws_instance_class = None - if args.azure_instance_class == "None": - args.azure_instance_class = None - if args.gcp_instance_class == "None": - args.gcp_instance_class = None - - return args - - -def main(args): - config = load_config() - gcp_project = args.gcp_project or config.get("gcp_project_id") - azure_subscription = args.azure_subscription or config.get("azure_subscription_id") - logger.debug(f"Loaded gcp_project: {gcp_project}, azure_subscription: {azure_subscription}") - - src_bucket = f"{args.bucket_prefix}-skylark-{args.src_region.split(':')[1]}" - dst_bucket = f"{args.bucket_prefix}-skylark-{args.dest_region.split(':')[1]}" - obj_store_interface_src = ObjectStoreInterface.create(args.src_region, src_bucket) - obj_store_interface_src.create_bucket() - obj_store_interface_dst = ObjectStoreInterface.create(args.dest_region, dst_bucket) - obj_store_interface_dst.create_bucket() - - # TODO: fix this to get the key instead of S3Object - if not args.skip_upload: - logger.info(f"Not skipping upload, source bucket is {src_bucket}, destination bucket is {dst_bucket}") - - # TODO: fix this to get the key instead of S3Object - matching_src_keys = list([obj.key for obj in obj_store_interface_src.list_objects(prefix=args.key_prefix)]) - if matching_src_keys and not args.skip_upload: - logger.warning(f"Deleting {len(matching_src_keys)} objects from source bucket") - obj_store_interface_src.delete_objects(matching_src_keys) - - # create test objects w/ random data - logger.info("Creating test objects") - obj_keys = [] - futures = [] - tmp_files = [] - - # TODO: for n_chunks > 880, get syscall error - with tempfile.NamedTemporaryFile() as f: - f.write(os.urandom(int(MB * args.chunk_size_mb))) - f.seek(0) - for i in range(args.n_chunks): - k = f"{args.key_prefix}/{i}" - tmp_file = f"{f.name}-{i}" - # need to copy, since GCP API will open file and cause to delete - copyfile(f.name, f"{f.name}-{i}") - futures.append(obj_store_interface_src.upload_object(tmp_file, k)) - obj_keys.append(k) - tmp_files.append(tmp_file) - - logger.info(f"Uploading {len(obj_keys)} to bucket {src_bucket}") - concurrent.futures.wait(futures) - - matching_dst_keys = list([obj.key for obj in obj_store_interface_dst.list_objects(prefix=args.key_prefix)]) - if matching_dst_keys: - logger.warning(f"Deleting {len(matching_dst_keys)} objects from destination bucket") - obj_store_interface_dst.delete_objects(matching_dst_keys) - - # cleanup temp files once done - for f in tmp_files: - os.remove(f) - else: - obj_keys = [f"{args.key_prefix}/{i}" for i in range(args.n_chunks)] - - # define the replication job and topology - if args.inter_region: - topo = ReplicationTopology() - for i in range(args.num_gateways): - topo.add_edge(args.src_region, i, args.inter_region, i, args.num_outgoing_connections) - topo.add_edge(args.inter_region, i, args.dest_region, i, args.num_outgoing_connections) - else: - topo = ReplicationTopology() - for i in range(args.num_gateways): - topo.add_edge(args.src_region, i, args.dest_region, i, args.num_outgoing_connections) - logger.info("Creating replication client") - - # Getting configs - rc = ReplicatorClient( - topo, - gcp_project=gcp_project, - azure_subscription=azure_subscription, - gateway_docker_image=args.gateway_docker_image, - aws_instance_class=args.aws_instance_class, - 
azure_instance_class=args.azure_instance_class, - gcp_instance_class=args.gcp_instance_class, - gcp_use_premium_network=args.gcp_use_premium_network, - ) - - # provision the gateway instances - logger.info("Provisioning gateway instances") - rc.provision_gateways(reuse_instances=True, log_dir=args.log_dir, authorize_ssh_pub_key=args.copy_ssh_key) - for node, gw in rc.bound_nodes.items(): - logger.info(f"Provisioned {node}: {gw.gateway_log_viewer_url}") - - # run replication, monitor progress - job = ReplicationJob( - source_region=args.src_region, source_bucket=src_bucket, dest_region=args.dest_region, dest_bucket=dst_bucket, objs=obj_keys - ) - - total_bytes = args.n_chunks * args.chunk_size_mb * MB - job = rc.run_replication_plan(job) - logger.info(f"{total_bytes / GB:.2f}GByte replication job launched") - stats = rc.monitor_transfer(job, show_pbar=True, cancel_pending=False) - logger.info(f"Replication completed in {stats['total_runtime_s']:.2f}s ({stats['throughput_gbits']:.2f}Gbit/s)") - - -if __name__ == "__main__": - print_header() - main(parse_args()) From 5aa04930d36b8403393072d2304c558aa6275ea2 Mon Sep 17 00:00:00 2001 From: Paras Jain Date: Wed, 16 Mar 2022 22:11:59 +0000 Subject: [PATCH 31/34] Partial cleanup from patch --- scripts/experiment.sh | 3 +- scripts/experiment_paras.sh | 54 +++++++++++++++++++++ scripts/setup_bucket.py | 34 ------------- skylark/compute/utils.py | 35 +++++++++++++ skylark/gateway/chunk_store.py | 10 ++-- skylark/gateway/gateway_obj_store.py | 21 ++------ skylark/obj_store/object_store_interface.py | 16 +++--- skylark/utils/logger.py | 7 +-- skylark/utils/utils.py | 12 +++-- 9 files changed, 119 insertions(+), 73 deletions(-) create mode 100755 scripts/experiment_paras.sh diff --git a/scripts/experiment.sh b/scripts/experiment.sh index 3c93094cd..6db6a31cf 100755 --- a/scripts/experiment.sh +++ b/scripts/experiment.sh @@ -36,7 +36,7 @@ echo $filename export GOOGLE_APPLICATION_CREDENTIALS="/home/ubuntu/.skylark-shishir-42be5f375b7a.json" # creats buckets + bucket data and sets env variables -python scripts/setup_bucket.py --key-prefix ${key_prefix} --bucket-prefix ${bucket_prefix} --gcp-project skylark-shishir --src-data-path ../${key_prefix}/ --src-region ${src} --dest-region ${dest} +python scripts/setup_bucket.py --key-prefix ${key_prefix} --bucket-prefix ${bucket_prefix} --src-data-path ../${key_prefix}/ --src-region ${src} --dest-region ${dest} # TODO:artificially increase the number of chunks @@ -67,7 +67,6 @@ cp ${filename} data/results/${experiment} # run replication (obj store) skylark replicate-json ${filename} \ - --gcp-project skylark-shishir \ --source-bucket $src_bucket \ --dest-bucket $dest_bucket \ --key-prefix ${key_prefix} > data/results/${experiment}/obj-store-logs.txt diff --git a/scripts/experiment_paras.sh b/scripts/experiment_paras.sh new file mode 100755 index 000000000..f72078d33 --- /dev/null +++ b/scripts/experiment_paras.sh @@ -0,0 +1,54 @@ +#!/bin/bash +set -xe + +src=$1 +dest=$2 + +key_prefix="fake_imagenet" +bucket_prefix="exps-paras" +src_bucket=(${src//:/ }) +src_bucket=${bucket_prefix}-skylark-${src_bucket[1]} +dest_bucket=(${dest//:/ }) +dest_bucket=${bucket_prefix}-skylark-${dest_bucket[1]} +echo $src_bucket +echo $dest_bucket +max_instance=1 +experiment=${src//[:]/-}_${dest//[:]/-}_${max_instance}_${key_prefix//[\/]/-} +filename=data/plan/${experiment}.json +echo $filename + +# creats buckets + bucket data and sets env variables +python scripts/setup_bucket.py --key-prefix ${key_prefix} --bucket-prefix 
${bucket_prefix} --src-data-path ../${key_prefix}/ --src-region ${src} --dest-region ${dest} + +# TODO:artificially increase the number of chunks +# TODO: try synthetic data + +source scripts/pack_docker.sh + +## create plan +throughput=$(($max_instance*3)) +# throughput=25 +skylark solver solve-throughput ${src} ${dest} ${throughput} -o ${filename} --max-instances ${max_instance}; +echo ${filename} + +# make exp directory +mkdir -p data/results +mkdir -p data/results/${experiment} + +# save copy of plan +cp ${filename} data/results/${experiment} + +## run replication (random) +#skylark replicate-json ${filename} \ +# --use-random-data \ +# --size-total-mb 73728 \ +# --n-chunks 1152 &> data/results/${experiment}/random-logs.txt +#tail -1 data/results/${experiment}/random-logs.txt; + +# run replication (obj store) +skylark replicate-json ${filename} \ + --source-bucket $src_bucket \ + --dest-bucket $dest_bucket \ + --key-prefix ${key_prefix} > data/results/${experiment}/obj-store-logs.txt +tail -1 data/results/${experiment}/obj-store-logs.txt; +echo ${experiment} diff --git a/scripts/setup_bucket.py b/scripts/setup_bucket.py index 7f4be455a..3d67819b4 100644 --- a/scripts/setup_bucket.py +++ b/scripts/setup_bucket.py @@ -7,30 +7,15 @@ from multiprocessing import Pool from concurrent.futures import wait -import ctypes - -libgcc_s = ctypes.CDLL("libgcc_s.so.1") - def parse_args(): parser = argparse.ArgumentParser(description="Setup replication experiment") - parser.add_argument("--src-data-path", default="../fake_imagenet", help="Data to upload to src bucket") - - # gateway path parameters parser.add_argument("--src-region", default="aws:us-east-1", help="AWS region of source bucket") parser.add_argument("--dest-region", default="aws:us-west-1", help="AWS region of destination bucket") - - # bucket namespace parser.add_argument("--bucket-prefix", default="sarah", help="Prefix for bucket to avoid naming collision") parser.add_argument("--key-prefix", default="", help="Prefix keys") - - # gateway provisioning - parser.add_argument("--gcp-project", default=None, help="GCP project ID") - parser.add_argument("--azure-subscription", default=None, help="Azure subscription") - parser.add_argument("--gateway-docker-image", default="ghcr.io/parasj/skylark:main", help="Docker image for gateway instances") args = parser.parse_args() - return args @@ -56,13 +41,6 @@ def main(args): obj_store_interface_dst.create_bucket() print("running upload... 
(note: may need to chunk)") - - ## TODO: chunkify - # p = Pool(16) - # uploaded = p.starmap(upload, [(args.src_region, src_bucket, os.path.join(args.src_data_path, f), f"{args.key_prefix}/{f}") for f in os.listdir(args.src_data_path)]) - # p.close() - # print(f"uploaded {sum(uploaded)} files to {src_bucket}") - futures = [] for f in tqdm(os.listdir(args.src_data_path)): futures.append(obj_store_interface_src.upload_object(os.path.join(args.src_data_path, f), f"{args.key_prefix}/{f}")) @@ -71,18 +49,6 @@ def main(args): wait(futures) futures = [] - ### check files - # for f in tqdm(os.listdir(args.src_data_path)): - # assert obj_store_interface_src.exists(f"{args.key_prefix}/{f}") - - def done_uploading(): - bucket_size = len(list(obj_store_interface_src.list_objects(prefix=args.key_prefix))) - # f"Length mismatch {len(os.listdir(args.src_data_path))}, {bucket_size}" - print("bucket", bucket_size, len(os.listdir(args.src_data_path))) - return len(os.listdir(args.src_data_path)) == bucket_size - - ##wait_for(done_uploading, timeout=60, interval=0.1, desc=f"Waiting for files to upload") - if __name__ == "__main__": main(parse_args()) diff --git a/skylark/compute/utils.py b/skylark/compute/utils.py index b5e80f041..6a21f1a81 100644 --- a/skylark/compute/utils.py +++ b/skylark/compute/utils.py @@ -1,6 +1,41 @@ +from functools import lru_cache +import subprocess from skylark.utils import logger +@lru_cache +def query_which_cloud() -> str: + if ( + subprocess.call( + 'curl -f --noproxy "*" http://169.254.169.254/1.0/meta-data/instance-id'.split(), + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + == 0 + ): + return "aws" + elif ( + subprocess.call( + 'curl -f -H Metadata:true --noproxy "*" "http://169.254.169.254/metadata/instance?api-version=2021-02-01"'.split(), + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + == 0 + ): + return "azure" + elif ( + subprocess.call( + 'curl -f --noproxy "*" http://metadata.google.internal/computeMetadata/v1/instance/hostname'.split(), + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + == 0 + ): + return "gcp" + else: + return "unknown" + + def make_dozzle_command(port): cmd = """sudo docker run -d --rm --name dozzle \ -p {log_viewer_port}:8080 \ diff --git a/skylark/gateway/chunk_store.py b/skylark/gateway/chunk_store.py index 677c4f1d5..e212b7dad 100644 --- a/skylark/gateway/chunk_store.py +++ b/skylark/gateway/chunk_store.py @@ -68,21 +68,21 @@ def state_finish_download(self, chunk_id: int, receiver_id: Optional[str] = None if state in [ChunkState.download_in_progress, ChunkState.downloaded]: self.set_chunk_state(chunk_id, ChunkState.downloaded, {"receiver_id": receiver_id}) else: - raise ValueError(f"Invalid transition finish_download from {self.get_chunk_state(chunk_id)}") + raise ValueError(f"Invalid transition finish_download from {self.get_chunk_state(chunk_id)} (id={chunk_id})") def state_queue_upload(self, chunk_id: int): state = self.get_chunk_state(chunk_id) if state in [ChunkState.downloaded, ChunkState.upload_queued]: self.set_chunk_state(chunk_id, ChunkState.upload_queued) else: - raise ValueError(f"Invalid transition upload_queued from {self.get_chunk_state(chunk_id)}") + raise ValueError(f"Invalid transition upload_queued from {self.get_chunk_state(chunk_id)} (id={chunk_id})") def state_start_upload(self, chunk_id: int, sender_id: Optional[str] = None): state = self.get_chunk_state(chunk_id) if state in [ChunkState.upload_queued, ChunkState.upload_in_progress]: self.set_chunk_state(chunk_id, 
ChunkState.upload_in_progress, {"sender_id": sender_id}) else: - raise ValueError(f"Invalid transition start_upload from {self.get_chunk_state(chunk_id)}") + raise ValueError(f"Invalid transition start_upload from {self.get_chunk_state(chunk_id)} (id={chunk_id})") def state_finish_upload(self, chunk_id: int, sender_id: Optional[str] = None): # todo log runtime to statistics store @@ -90,13 +90,13 @@ def state_finish_upload(self, chunk_id: int, sender_id: Optional[str] = None): if state in [ChunkState.upload_in_progress, ChunkState.upload_complete]: self.set_chunk_state(chunk_id, ChunkState.upload_complete, {"sender_id": sender_id}) else: - raise ValueError(f"Invalid transition finish_upload from {self.get_chunk_state(chunk_id)}") + raise ValueError(f"Invalid transition finish_upload from {self.get_chunk_state(chunk_id)} (id={chunk_id})") def state_fail(self, chunk_id: int): if self.get_chunk_state(chunk_id) != ChunkState.upload_complete: self.set_chunk_state(chunk_id, ChunkState.failed) else: - raise ValueError(f"Invalid transition fail from {self.get_chunk_state(chunk_id)}") + raise ValueError(f"Invalid transition fail from {self.get_chunk_state(chunk_id)} (id={chunk_id})") ### # Chunk management diff --git a/skylark/gateway/gateway_obj_store.py b/skylark/gateway/gateway_obj_store.py index dd6116085..37c1f3b27 100644 --- a/skylark/gateway/gateway_obj_store.py +++ b/skylark/gateway/gateway_obj_store.py @@ -12,6 +12,8 @@ from dataclasses import dataclass +from skylark.utils.utils import retry_backoff + @dataclass class ObjStoreRequest: @@ -66,10 +68,8 @@ def worker_loop(self, worker_id: int): request = self.worker_queue.get_nowait() chunk_req = request.chunk_req req_type = request.req_type - except queue.Empty: continue - fpath = str(self.chunk_store.get_chunk_file_path(chunk_req.chunk.chunk_id).absolute()) logger.debug(f"[obj_store:{self.worker_id}] Received chunk ID {chunk_req.chunk.chunk_id}") @@ -77,47 +77,34 @@ def worker_loop(self, worker_id: int): assert chunk_req.dst_type == "object_store" region = chunk_req.dst_region bucket = chunk_req.dst_object_store_bucket - self.chunk_store.state_start_upload(chunk_req.chunk.chunk_id, f"obj_store:{self.worker_id}") - logger.debug(f"[obj_store:{self.worker_id}] Start upload {chunk_req.chunk.chunk_id} to {bucket}") def upload(region, bucket, fpath, key, chunk_id): obj_store_interface = self.get_obj_store_interface(region, bucket) - obj_store_interface.upload_object(fpath, key).result() + retry_backoff(lambda: obj_store_interface.upload_object(fpath, key).result(), max_retries=4) chunk_file_path = self.chunk_store.get_chunk_file_path(chunk_id) - - # update chunk state self.chunk_store.state_finish_upload(chunk_id, f"obj_store:{self.worker_id}") - - # delete chunk chunk_file_path.unlink() - logger.debug(f"[obj_store:{self.worker_id}] Uploaded {chunk_id} to {bucket}") # wait for upload in seperate thread threading.Thread(target=upload, args=(region, bucket, fpath, chunk_req.chunk.key, chunk_req.chunk.chunk_id)).start() - elif req_type == "download": assert chunk_req.src_type == "object_store" region = chunk_req.src_region bucket = chunk_req.src_object_store_bucket - self.chunk_store.state_start_download(chunk_req.chunk.chunk_id, f"obj_store:{self.worker_id}") - logger.debug(f"[obj_store:{self.worker_id}] Starting download {chunk_req.chunk.chunk_id} from {bucket}") def download(region, bucket, fpath, key, chunk_id): obj_store_interface = self.get_obj_store_interface(region, bucket) - obj_store_interface.download_object(key, fpath).result() - - # 
update chunk state + retry_backoff(lambda: obj_store_interface.download_object(key, fpath).result(), max_retries=4) self.chunk_store.state_finish_download(chunk_id, f"obj_store:{self.worker_id}") logger.debug(f"[obj_store:{self.worker_id}] Downloaded {chunk_id} from {bucket}") # wait for request to return in sepearte thread, so we can update chunk state threading.Thread(target=download, args=(region, bucket, fpath, chunk_req.chunk.key, chunk_req.chunk.chunk_id)).start() - else: raise ValueError(f"Invalid location for chunk req, {req_type}: {chunk_req.src_type}->{chunk_req.dst_type}") diff --git a/skylark/obj_store/object_store_interface.py b/skylark/obj_store/object_store_interface.py index 2efb0b6bb..a769d4e3d 100644 --- a/skylark/obj_store/object_store_interface.py +++ b/skylark/obj_store/object_store_interface.py @@ -12,30 +12,30 @@ class ObjectStoreObject: last_modified: str def full_path(self): - raise NotImplementedError + raise NotImplementedError() class ObjectStoreInterface: def bucket_exists(self): - raise NotImplementedError + raise NotImplementedError() def create_bucket(self): - raise NotImplementedError + raise NotImplementedError() def delete_bucket(self): - raise NotImplementedError + raise NotImplementedError() def list_objects(self, prefix=""): - raise NotImplementedError + raise NotImplementedError() def get_obj_size(self, obj_name): - raise NotImplementedError + raise NotImplementedError() def download_object(self, src_object_name, dst_file_path): - raise NotImplementedError + raise NotImplementedError() def upload_object(self, src_file_path, dst_object_name, content_type="infer"): - raise NotImplementedError + raise NotImplementedError() @staticmethod def create(region_tag: str, bucket: str): diff --git a/skylark/utils/logger.py b/skylark/utils/logger.py index b4bcbebf0..38f0dd8b1 100644 --- a/skylark/utils/logger.py +++ b/skylark/utils/logger.py @@ -20,8 +20,9 @@ def log(msg, LEVEL="INFO", color="white", *args, **kwargs): error = partial(log, LEVEL="ERROR", color="red") -def exception(msg, *args, **kwargs): +def exception(msg, print_traceback=True, *args, **kwargs): error(f"Exception: {msg}", *args, **kwargs) - import traceback + if print_traceback: + import traceback - traceback.print_exc() + traceback.print_exc() diff --git a/skylark/utils/utils.py b/skylark/utils/utils.py index 1fb823932..ed5271c8c 100644 --- a/skylark/utils/utils.py +++ b/skylark/utils/utils.py @@ -48,7 +48,7 @@ def wait_for(fn: Callable[[], bool], timeout=60, interval=0.25, progress_bar=Fal def do_parallel( - func: Callable[[T], R], args_list: Iterable[T], n=-1, progress_bar=False, leave_pbar=True, desc=None, arg_fmt=None + func: Callable[[T], R], args_list: Iterable[T], n=-1, progress_bar=False, leave_pbar=True, desc=None, arg_fmt=None, hide_args=False ) -> List[Tuple[T, R]]: """Run list of jobs in parallel with tqdm progress bar""" args_list = list(args_list) @@ -71,15 +71,19 @@ def wrapped_fn(args): for future in as_completed(future_list): args, result = future.result() results.append((args, result)) - pbar.set_description(f"{desc} ({str(arg_fmt(args))})" if desc else str(arg_fmt(args))) + if not hide_args: + pbar.set_description(f"{desc} ({str(arg_fmt(args))})" if desc else str(arg_fmt(args))) + else: + pbar.set_description(desc) pbar.update() return results def retry_backoff( fn: Callable[[], R], - max_retries=4, + max_retries=8, initial_backoff=0.1, + max_backoff=8, exception_class=Exception, ) -> R: """Retry fn until it does not raise an exception. 
@@ -97,4 +101,4 @@ def retry_backoff( else: logger.warning(f"Retrying {fn.__name__} due to {e} (attempt {i + 1}/{max_retries})") time.sleep(backoff) - backoff *= 2 + backoff = min(backoff * 2, max_backoff) From 77600f4969a30061c6e4c4f3078209a2763fad38 Mon Sep 17 00:00:00 2001 From: Paras Jain Date: Wed, 16 Mar 2022 22:20:37 +0000 Subject: [PATCH 32/34] Minor cleanup --- skylark/cli/cli_helper.py | 2 +- skylark/gateway/gateway_sender.py | 2 +- skylark/obj_store/azure_interface.py | 3 --- skylark/obj_store/s3_interface.py | 4 +--- 4 files changed, 3 insertions(+), 8 deletions(-) diff --git a/skylark/cli/cli_helper.py b/skylark/cli/cli_helper.py index d35d73950..ccd15d88a 100644 --- a/skylark/cli/cli_helper.py +++ b/skylark/cli/cli_helper.py @@ -8,7 +8,7 @@ from shutil import copyfile from typing import Dict, List from sys import platform -from typing import Dict, List, Optional +from typing import Dict, List import boto3 import typer diff --git a/skylark/gateway/gateway_sender.py b/skylark/gateway/gateway_sender.py index fcdf8bb19..9c9720ad2 100644 --- a/skylark/gateway/gateway_sender.py +++ b/skylark/gateway/gateway_sender.py @@ -1,6 +1,6 @@ import queue import socket -from multiprocessing import Event, Manager, Process, Value +from multiprocessing import Event, Manager, Process from typing import Dict, List, Optional import requests diff --git a/skylark/obj_store/azure_interface.py b/skylark/obj_store/azure_interface.py index b02cfda50..3118feb5f 100644 --- a/skylark/obj_store/azure_interface.py +++ b/skylark/obj_store/azure_interface.py @@ -4,9 +4,6 @@ from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError from skylark.compute.azure.azure_auth import AzureAuthentication from skylark.compute.azure.azure_server import AzureServer -from azure.identity import ClientSecretCredential -from azure.identity import ClientSecretCredential -from azure.storage.blob import BlobServiceClient from skylark.utils import logger from skylark.obj_store.object_store_interface import NoSuchObjectException, ObjectStoreInterface, ObjectStoreObject diff --git a/skylark/obj_store/s3_interface.py b/skylark/obj_store/s3_interface.py index b6b05bd90..8bb578894 100644 --- a/skylark/obj_store/s3_interface.py +++ b/skylark/obj_store/s3_interface.py @@ -116,6 +116,4 @@ def upload_object(self, src_file_path, dst_object_name, content_type="infer") -> upload_headers.add("Content-Type", content_type) upload_headers.add("Content-Length", str(content_len)) request = HttpRequest("PUT", dst_object_name, upload_headers) - return self._s3_client.make_request( - send_filepath=src_file_path, request=request, type=S3RequestType.PUT_OBJECT - ).finished_future + return self._s3_client.make_request(send_filepath=src_file_path, request=request, type=S3RequestType.PUT_OBJECT).finished_future From 51857f73c241c619a76d8fe26f51d69c80b6d032 Mon Sep 17 00:00:00 2001 From: Paras Jain Date: Wed, 16 Mar 2022 22:31:45 +0000 Subject: [PATCH 33/34] Fix AzureInterface --- skylark/cli/cli_helper.py | 4 ++-- skylark/obj_store/azure_interface.py | 15 +++++++++++++-- skylark/obj_store/object_store_interface.py | 3 ++- skylark/test/test_azure_interface.py | 3 ++- 4 files changed, 19 insertions(+), 6 deletions(-) diff --git a/skylark/cli/cli_helper.py b/skylark/cli/cli_helper.py index ccd15d88a..62d68f4c2 100644 --- a/skylark/cli/cli_helper.py +++ b/skylark/cli/cli_helper.py @@ -158,12 +158,12 @@ def copy_gcs_local(src_bucket: str, src_key: str, dst: Path): def copy_local_azure(src: Path, dst_account_name: str, dst_container_name: str, 
dst_key: str): - azure = AzureInterface(dst_key, dst_account_name, dst_container_name) + azure = AzureInterface(None, dst_account_name, dst_container_name) return copy_local_objstore(azure, src, dst_key) def copy_azure_local(src_account_name: str, src_container_name: str, src_key: str, dst: Path): - azure = AzureInterface(src_key, src_account_name, src_container_name) + azure = AzureInterface(None, src_account_name, src_container_name) return copy_objstore_local(azure, src_key, dst) diff --git a/skylark/obj_store/azure_interface.py b/skylark/obj_store/azure_interface.py index 3118feb5f..affd269f6 100644 --- a/skylark/obj_store/azure_interface.py +++ b/skylark/obj_store/azure_interface.py @@ -15,10 +15,11 @@ def full_path(self): class AzureInterface(ObjectStoreInterface): - def __init__(self, azure_region, container_name): - self.azure_region = azure_region + def __init__(self, azure_region, account_name, container_name): + # TODO (#210): should be configured via argument self.account_name = f"skylark{azure_region.replace(' ', '').lower()}" self.container_name = container_name + # Create a blob service client self.auth = AzureAuthentication() self.account_url = f"https://{self.account_name}.blob.core.windows.net" @@ -26,10 +27,20 @@ def __init__(self, azure_region, container_name): self.container_client = self.auth.get_container_client(self.account_url, self.container_name) self.blob_service_client = self.auth.get_blob_service_client(self.account_url) + # infer azure region from storage account + if azure_region is None: + self.azure_region = self.get_region_from_storage_account(self.account_name) + else: + self.azure_region = azure_region + # parallel upload/downloads self.pool = ThreadPoolExecutor(max_workers=256) # TODO: This might need some tuning self.max_concurrency = 1 + def get_region_from_storage_account(self, storage_account_name): + storage_account = self.storage_management_client.storage_accounts.get_properties(AzureServer.resource_group_name, storage_account_name) + return storage_account.location + def storage_account_exists(self): try: self.storage_management_client.storage_accounts.get_properties(AzureServer.resource_group_name, self.account_name) diff --git a/skylark/obj_store/object_store_interface.py b/skylark/obj_store/object_store_interface.py index a769d4e3d..4edba8b65 100644 --- a/skylark/obj_store/object_store_interface.py +++ b/skylark/obj_store/object_store_interface.py @@ -50,7 +50,8 @@ def create(region_tag: str, bucket: str): elif region_tag.startswith("azure"): from skylark.obj_store.azure_interface import AzureInterface - return AzureInterface(region_tag.split(":")[1], bucket) + # TODO (#210): should be configured via argument + return AzureInterface(region_tag.split(":")[1], None, bucket) else: raise ValueError(f"Invalid region_tag {region_tag} - could not create interface") diff --git a/skylark/test/test_azure_interface.py b/skylark/test/test_azure_interface.py index 42c311e28..68d8f6928 100644 --- a/skylark/test/test_azure_interface.py +++ b/skylark/test/test_azure_interface.py @@ -8,9 +8,10 @@ def test_azure_interface(): - azure_interface = AzureInterface(f"eastus", f"sky-us-east-1") + azure_interface = AzureInterface("eastus", "skyeastus", "sky-us-east-1") assert azure_interface.container_name == "sky-us-east-1" assert azure_interface.azure_region == "eastus" + assert azure_interface.account_name == "skyeastus" azure_interface.create_bucket() # generate file and upload From d861212c7e1e15cb559d7cac700e0877693da14c Mon Sep 17 00:00:00 2001 From: Paras Jain 
Date: Wed, 16 Mar 2022 22:33:57 +0000
Subject: [PATCH 34/34] Fix another Azure misconfiguration

---
 skylark/cli/cli.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/skylark/cli/cli.py b/skylark/cli/cli.py
index e584947ec..77c2d4737 100644
--- a/skylark/cli/cli.py
+++ b/skylark/cli/cli.py
@@ -94,7 +94,8 @@ def cp(src: str, dst: str):
             account_name, container_name = bucket_dst
             copy_local_azure(Path(path_src), account_name, container_name, path_dst)
         elif provider_src == "azure" and provider_dst == "local":
-            copy_azure_local(bucket_src, path_src, Path(path_dst))
+            account_name, container_name = bucket_dst
+            copy_azure_local(account_name, container_name, path_src, Path(path_dst))
         else:
             raise NotImplementedError(f"{provider_src} to {provider_dst} not supported yet")
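
The last two patches give AzureInterface an explicit storage-account argument and rewire the Azure branches of the `cp` command in skylark/cli/cli.py to pass an (account, container) pair plus a key into copy_local_azure and copy_azure_local. The following is a minimal sketch of that dispatch, assuming a hypothetical parse_azure_path helper and an "azure://account/container/key" path form; neither appears in this patch series, and the real CLI derives provider_src, bucket_src, and path_src from its own path parser. Note that in the Azure-to-local direction the (account, container) pair naturally comes from the source path rather than the destination.

from pathlib import Path

from skylark.cli.cli_helper import copy_azure_local, copy_local_azure


def parse_azure_path(uri: str):
    # Hypothetical helper (not part of the patch): turns
    # "azure://account/container/some/key" into (("account", "container"), "some/key").
    account, container, key = uri.split("://", 1)[1].split("/", 2)
    return (account, container), key


def cp_sketch(src: str, dst: str):
    # local -> Azure: the destination path carries the (account, container) pair.
    if not src.startswith("azure://") and dst.startswith("azure://"):
        (account, container), dst_key = parse_azure_path(dst)
        copy_local_azure(Path(src), account, container, dst_key)
    # Azure -> local: the source path carries the (account, container) pair.
    elif src.startswith("azure://") and not dst.startswith("azure://"):
        (account, container), src_key = parse_azure_path(src)
        copy_azure_local(account, container, src_key, Path(dst))
    else:
        raise NotImplementedError(f"{src} -> {dst} not supported in this sketch")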