diff --git a/scripts/requirements-gateway.txt b/scripts/requirements-gateway.txt
index 32793883e..afa7dbc09 100644
--- a/scripts/requirements-gateway.txt
+++ b/scripts/requirements-gateway.txt
@@ -3,11 +3,14 @@ azure-identity
 azure-mgmt-compute
 azure-mgmt-network
 azure-mgmt-resource
+azure-mgmt-storage
+azure-mgmt-authorization
 azure-storage-blob>=12.0.0
 boto3
 click>=7.1.2
 flask
 google-api-python-client
+google-auth
 google-cloud-compute
 google-cloud-storage
 grpcio-status>=1.33.2
diff --git a/setup.py b/setup.py
index df8575268..8d88bb9c5 100644
--- a/setup.py
+++ b/setup.py
@@ -12,11 +12,14 @@
         "azure-mgmt-compute",
         "azure-mgmt-network",
         "azure-mgmt-resource",
+        "azure-mgmt-storage",
+        "azure-mgmt-authorization",
         "azure-storage-blob>=12.0.0",
         "boto3",
         "click>=7.1.2",
         "flask",
         "google-api-python-client",
+        "google-auth",
         "google-cloud-compute",
         "google-cloud-storage",
         "grpcio-status>=1.33.2",
diff --git a/skylark/__init__.py b/skylark/__init__.py
index 401632eed..d80cae8b2 100644
--- a/skylark/__init__.py
+++ b/skylark/__init__.py
@@ -1,10 +1,22 @@
+import os
+
 from pathlib import Path
 
+from skylark.config import SkylarkConfig
+
 # paths
 skylark_root = Path(__file__).parent.parent
-key_root = skylark_root / "data" / "keys"
-config_file = skylark_root / "data" / "config.json"
+config_root = Path("~/.skylark").expanduser()
+config_root.mkdir(exist_ok=True)
+
+if "SKYLARK_CONFIG" in os.environ:
+    config_path = Path(os.environ["SKYLARK_CONFIG"]).expanduser()
+else:
+    config_path = config_root / "config"
+
+key_root = config_root / "keys"
 tmp_log_dir = Path("/tmp/skylark")
+tmp_log_dir.mkdir(exist_ok=True)
 
 # header
 def print_header():
@@ -26,3 +38,7 @@ def print_header():
 KB = 1024
 MB = 1024 * 1024
 GB = 1024 * 1024 * 1024
+if config_path.exists():
+    cloud_config = SkylarkConfig.load_config(config_path)
+else:
+    cloud_config = SkylarkConfig()
diff --git a/skylark/benchmark/traceroute/traceroute.py b/skylark/benchmark/traceroute/traceroute.py
index 7462aa128..7a64e0778 100644
--- a/skylark/benchmark/traceroute/traceroute.py
+++ b/skylark/benchmark/traceroute/traceroute.py
@@ -46,7 +46,7 @@ def main(args):
     log_dir.mkdir(exist_ok=True, parents=True)
 
     aws = AWSCloudProvider()
-    gcp = GCPCloudProvider(args.gcp_project)
+    gcp = GCPCloudProvider()
     aws_instances, gcp_instances = provision(
         aws=aws,
         gcp=gcp,
diff --git a/skylark/cli/cli.py b/skylark/cli/cli.py
index 98a5e9013..77c2d4737 100644
--- a/skylark/cli/cli.py
+++ b/skylark/cli/cli.py
@@ -25,8 +25,9 @@
 import skylark.cli.cli_solver
 import skylark.cli.experiments
 import typer
+from skylark.config import SkylarkConfig
 from skylark.utils import logger
-from skylark import GB, MB, config_file, print_header
+from skylark import config_path, GB, MB, print_header
 from skylark.cli.cli_helper import (
     check_ulimit,
     copy_azure_local,
@@ -39,11 +40,13 @@
     copy_gcs_local,
     copy_local_gcs,
     deprovision_skylark_instances,
+    load_aws_config,
+    load_azure_config,
+    load_gcp_config,
     ls_local,
     ls_s3,
     parse_path,
 )
-from skylark.config import load_config
 from skylark.replicate.replication_plan import ReplicationJob, ReplicationTopology
 from skylark.replicate.replicator_client import ReplicatorClient
 from skylark.obj_store.object_store_interface import ObjectStoreInterface
@@ -59,11 +62,6 @@
 @app.command()
 def ls(directory: str):
     """List objects in the object store."""
-    config = load_config()
-    gcp_project = config.get("gcp_project_id")
-    azure_subscription = config.get("azure_subscription_id")
-    logger.debug(f"Loaded gcp_project: {gcp_project}, azure_subscription: 
{azure_subscription}") - check_ulimit() provider, bucket, key = parse_path(directory) if provider == "local": for path in ls_local(Path(directory)): @@ -77,10 +75,6 @@ def ls(directory: str): def cp(src: str, dst: str): """Copy objects from the object store to the local filesystem.""" print_header() - config = load_config() - gcp_project = config.get("gcp_project_id") - azure_subscription = config.get("azure_subscription_id") - logger.debug(f"Loaded gcp_project: {gcp_project}, azure_subscription: {azure_subscription}") check_ulimit() provider_src, bucket_src, path_src = parse_path(src) @@ -97,9 +91,11 @@ def cp(src: str, dst: str): elif provider_src == "gs" and provider_dst == "local": copy_gcs_local(bucket_src, path_src, Path(path_dst)) elif provider_src == "local" and provider_dst == "azure": - copy_local_azure(Path(path_src), bucket_dst, path_dst) + account_name, container_name = bucket_dst + copy_local_azure(Path(path_src), account_name, container_name, path_dst) elif provider_src == "azure" and provider_dst == "local": - copy_azure_local(bucket_src, path_src, Path(path_dst)) + account_name, container_name = bucket_dst + copy_azure_local(account_name, container_name, path_src, Path(path_dst)) else: raise NotImplementedError(f"{provider_src} to {provider_dst} not supported yet") @@ -116,11 +112,9 @@ def replicate_random( total_transfer_size_mb: int = typer.Option(2048, "--size-total-mb", "-s", help="Total transfer size in MB."), chunk_size_mb: int = typer.Option(8, "--chunk-size-mb", help="Chunk size in MB."), reuse_gateways: bool = False, - azure_subscription: Optional[str] = None, - gcp_project: Optional[str] = None, gateway_docker_image: str = os.environ.get("SKYLARK_DOCKER_IMAGE", "ghcr.io/parasj/skylark:main"), aws_instance_class: str = "m5.8xlarge", - azure_instance_class: str = "Standard_D32_v5", + azure_instance_class: str = "Standard_D32_v4", gcp_instance_class: Optional[str] = "n2-standard-32", gcp_use_premium_network: bool = True, key_prefix: str = "/test/replicate_random", @@ -129,10 +123,6 @@ def replicate_random( ): """Replicate objects from remote object store to another remote object store.""" print_header() - config = load_config() - gcp_project = gcp_project or config.get("gcp_project_id") - azure_subscription = azure_subscription or config.get("azure_subscription_id") - logger.debug(f"Loaded gcp_project: {gcp_project}, azure_subscription: {azure_subscription}") check_ulimit() if inter_region: @@ -149,8 +139,6 @@ def replicate_random( rc = ReplicatorClient( topo, - azure_subscription=azure_subscription, - gcp_project=gcp_project, gateway_docker_image=gateway_docker_image, aws_instance_class=aws_instance_class, azure_instance_class=azure_instance_class, @@ -186,6 +174,11 @@ def replicate_random( stats = rc.monitor_transfer(job, show_pbar=True, log_interval_s=log_interval_s, time_limit_seconds=time_limit_seconds) stats["success"] = stats["monitor_status"] == "completed" out_json = {k: v for k, v in stats.items() if k not in ["log", "completed_chunk_ids"]} + + if not reuse_gateways: + atexit.unregister(rc.deprovision_gateways) + rc.deprovision_gateways() + typer.echo(f"\n{json.dumps(out_json)}") return 0 if stats["success"] else 1 @@ -204,10 +197,8 @@ def replicate_json( reuse_gateways: bool = False, gateway_docker_image: str = os.environ.get("SKYLARK_DOCKER_IMAGE", "ghcr.io/parasj/skylark:main"), # cloud provider specific options - azure_subscription: Optional[str] = None, - gcp_project: Optional[str] = None, aws_instance_class: str = "m5.8xlarge", - 
azure_instance_class: str = "Standard_D32_v5", + azure_instance_class: str = "Standard_D32_v4", gcp_instance_class: Optional[str] = "n2-standard-32", gcp_use_premium_network: bool = True, # logging options @@ -216,10 +207,6 @@ def replicate_json( ): """Replicate objects from remote object store to another remote object store.""" print_header() - config = load_config() - gcp_project = gcp_project or config.get("gcp_project_id") - azure_subscription = azure_subscription or config.get("azure_subscription_id") - logger.debug(f"Loaded gcp_project: {gcp_project}, azure_subscription: {azure_subscription}") check_ulimit() with path.open("r") as f: @@ -227,8 +214,6 @@ def replicate_json( rc = ReplicatorClient( topo, - azure_subscription=azure_subscription, - gcp_project=gcp_project, gateway_docker_image=gateway_docker_image, aws_instance_class=aws_instance_class, azure_instance_class=azure_instance_class, @@ -250,8 +235,6 @@ def replicate_json( logger.warning(f"total_transfer_size_mb ({size_total_mb}) is not a multiple of n_chunks ({n_chunks})") chunk_size_mb = size_total_mb // n_chunks - print("REGION", topo.source_region()) - if use_random_data: job = ReplicationJob( source_region=topo.source_region(), @@ -283,104 +266,47 @@ def replicate_json( total_bytes = sum([chunk_req.chunk.chunk_length_bytes for chunk_req in job.chunk_requests]) logger.info(f"{total_bytes / GB:.2f}GByte replication job launched") - stats = rc.monitor_transfer( - job, show_pbar=True, log_interval_s=log_interval_s, time_limit_seconds=time_limit_seconds, cancel_pending=False - ) + stats = rc.monitor_transfer(job, show_pbar=True, log_interval_s=log_interval_s, time_limit_seconds=time_limit_seconds) stats["success"] = stats["monitor_status"] == "completed" out_json = {k: v for k, v in stats.items() if k not in ["log", "completed_chunk_ids"]} + + # deprovision + if not reuse_gateways: + atexit.unregister(rc.deprovision_gateways) + rc.deprovision_gateways() + typer.echo(f"\n{json.dumps(out_json)}") return 0 if stats["success"] else 1 @app.command() -def deprovision(azure_subscription: Optional[str] = None, gcp_project: Optional[str] = None): +def deprovision(): """Deprovision gateways.""" - config = load_config() - gcp_project = gcp_project or config.get("gcp_project_id") - azure_subscription = azure_subscription or config.get("azure_subscription_id") - logger.debug(f"Loaded from config file: gcp_project={gcp_project}, azure_subscription={azure_subscription}") - deprovision_skylark_instances(azure_subscription=azure_subscription, gcp_project_id=gcp_project) + deprovision_skylark_instances() @app.command() -def init( - azure_tenant_id: str = typer.Option(None, envvar="AZURE_TENANT_ID", prompt="Azure tenant ID"), - azure_client_id: str = typer.Option(None, envvar="AZURE_CLIENT_ID", prompt="Azure client ID"), - azure_client_secret: str = typer.Option(None, envvar="AZURE_CLIENT_SECRET", prompt="Azure client secret"), - azure_subscription_id: str = typer.Option(None, envvar="AZURE_SUBSCRIPTION_ID", prompt="Azure subscription ID"), - gcp_application_credentials_file: Path = typer.Option( - None, - envvar="GOOGLE_APPLICATION_CREDENTIALS", - exists=True, - file_okay=True, - dir_okay=False, - readable=True, - help="Path to GCP application credentials file (usually a JSON file)", - ), - gcp_project: str = typer.Option(None, envvar="GCP_PROJECT_ID", prompt="GCP project ID"), -): - out_config = {} - - # AWS config - def load_aws_credentials(): - if "AWS_ACCESS_KEY_ID" in os.environ and "AWS_SECRET_ACCESS_KEY" in os.environ: - return 
os.environ["AWS_ACCESS_KEY_ID"], os.environ["AWS_SECRET_ACCESS_KEY"] - if (Path.home() / ".aws" / "credentials").exists(): - with open(Path.home() / ".aws" / "credentials") as f: - access_key, secret_key = None, None - lines = f.readlines() - for line in lines: - if line.startswith("aws_access_key_id"): - access_key = line.split("=")[1].strip() - if line.startswith("aws_secret_access_key"): - secret_key = line.split("=")[1].strip() - if access_key and secret_key: - return access_key, secret_key - return None, None - - aws_access_key, aws_secret_key = load_aws_credentials() - if aws_access_key is None: - aws_access_key = typer.prompt("AWS access key") - assert aws_access_key is not None and aws_access_key != "" - if aws_secret_key is None: - aws_secret_key = typer.prompt("AWS secret key") - assert aws_secret_key is not None and aws_secret_key != "" - - if config_file.exists(): - typer.confirm("Config file already exists. Overwrite?", abort=True) - - out_config["aws_access_key_id"] = aws_access_key - out_config["aws_secret_access_key"] = aws_secret_key - - # Azure config - typer.secho("Azure config can be generated using: az ad sp create-for-rbac -n api://skylark --sdk-auth", fg=typer.colors.GREEN) - if azure_tenant_id is not None or len(azure_tenant_id) > 0: - logger.info(f"Setting Azure tenant ID to {azure_tenant_id}") - out_config["azure_tenant_id"] = azure_tenant_id - if azure_client_id is not None or len(azure_client_id) > 0: - logger.info(f"Setting Azure client ID to {azure_client_id}") - out_config["azure_client_id"] = azure_client_id - if azure_client_secret is not None or len(azure_client_secret) > 0: - logger.info(f"Setting Azure client secret to {azure_client_secret}") - out_config["azure_client_secret"] = azure_client_secret - if azure_subscription_id is not None or len(azure_subscription_id) > 0: - logger.info(f"Setting Azure subscription ID to {azure_subscription_id}") - out_config["azure_subscription_id"] = azure_subscription_id - - # GCP config - if gcp_application_credentials_file is not None and gcp_application_credentials_file.exists(): - logger.info(f"Setting GCP application credentials file to {gcp_application_credentials_file}") - out_config["gcp_application_credentials_file"] = str(gcp_application_credentials_file) - if gcp_project is not None or len(gcp_project) > 0: - logger.info(f"Setting GCP project ID to {gcp_project}") - out_config["gcp_project_id"] = gcp_project - - # write to config file - config_file.parent.mkdir(parents=True, exist_ok=True) - with config_file.open("w") as f: - json.dump(out_config, f) - typer.secho(f"Config: {out_config}", fg=typer.colors.GREEN) - typer.secho(f"Wrote config to {config_file}", fg=typer.colors.GREEN) +def init(reinit_azure: bool = False, reinit_gcp: bool = False): + print_header() + if config_path.exists(): + cloud_config = SkylarkConfig.load_config(config_path) + else: + cloud_config = SkylarkConfig() + + # load AWS config + typer.secho("\n(1) Configuring AWS:", fg="yellow", bold=True) + cloud_config = load_aws_config(cloud_config) + + # load Azure config + typer.secho("\n(2) Configuring Azure:", fg="yellow", bold=True) + cloud_config = load_azure_config(cloud_config, force_init=reinit_azure) + + # load GCP config + typer.secho("\n(3) Configuring GCP:", fg="yellow", bold=True) + cloud_config = load_gcp_config(cloud_config, force_init=reinit_gcp) + + cloud_config.to_config_file(config_path) + typer.secho(f"\nConfig file saved to {config_path}", fg="green") return 0 diff --git a/skylark/cli/cli_aws.py b/skylark/cli/cli_aws.py 
index 155bbc871..671e14a82 100644 --- a/skylark/cli/cli_aws.py +++ b/skylark/cli/cli_aws.py @@ -12,10 +12,12 @@ import questionary import typer from skylark import GB +from skylark.compute.aws.aws_auth import AWSAuthentication from skylark.compute.aws.aws_cloud_provider import AWSCloudProvider from skylark.compute.aws.aws_server import AWSServer from skylark.obj_store.s3_interface import S3Interface from skylark.utils.utils import Timer, do_parallel +from skylark.utils import logger app = typer.Typer(name="skylark-aws") @@ -23,10 +25,16 @@ @app.command() def vcpu_limits(quota_code="L-1216C47A"): """List the vCPU limits for each region.""" + aws_auth = AWSAuthentication() def get_service_quota(region): - service_quotas = AWSServer.get_boto3_client("service-quotas", region) - response = service_quotas.get_service_quota(ServiceCode="ec2", QuotaCode=quota_code) + service_quotas = aws_auth.get_boto3_client("service-quotas", region) + try: + response = service_quotas.get_service_quota(ServiceCode="ec2", QuotaCode=quota_code) + except Exception as e: + logger.exception(e, print_traceback=False) + logger.error(f"Failed to get service quota for {quota_code} in {region}") + return -1 return response["Quota"]["Value"] quotas = do_parallel(get_service_quota, AWSCloudProvider.region_list()) @@ -41,7 +49,7 @@ def ssh(region: Optional[str] = None): instances = aws.get_matching_instances(region=region) if len(instances) == 0: typer.secho(f"No instances found", fg="red") - typer.Abort() + raise typer.Abort() instance_map = {f"{i.region()}, {i.public_ip()} ({i.instance_state()})": i for i in instances} choices = list(sorted(instance_map.keys())) @@ -57,10 +65,12 @@ def ssh(region: Optional[str] = None): @app.command() def cp_datasync(src_bucket: str, dst_bucket: str, path: str): - src_region = S3Interface.infer_s3_region(src_bucket) - dst_region = S3Interface.infer_s3_region(dst_bucket) + aws_auth = AWSAuthentication() + s3_interface = S3Interface("us-east-1", None) + src_region = s3_interface.infer_s3_region(src_bucket) + dst_region = s3_interface.infer_s3_region(dst_bucket) - iam_client = AWSServer.get_boto3_client("iam", "us-east-1") + iam_client = aws_auth.get_boto3_client("iam", "us-east-1") try: response = iam_client.get_role(RoleName="datasync-role") typer.secho("IAM role exists datasync-role", fg="green") @@ -81,14 +91,14 @@ def cp_datasync(src_bucket: str, dst_bucket: str, path: str): iam_arn = response["Role"]["Arn"] typer.secho(f"IAM role ARN: {iam_arn}", fg="green") - ds_client_src = AWSServer.get_boto3_client("datasync", src_region) + ds_client_src = aws_auth.get_boto3_client("datasync", src_region) src_response = ds_client_src.create_location_s3( S3BucketArn=f"arn:aws:s3:::{src_bucket}", Subdirectory=path, S3Config={"BucketAccessRoleArn": iam_arn}, ) src_s3_arn = src_response["LocationArn"] - ds_client_dst = AWSServer.get_boto3_client("datasync", dst_region) + ds_client_dst = aws_auth.get_boto3_client("datasync", dst_region) dst_response = ds_client_dst.create_location_s3( S3BucketArn=f"arn:aws:s3:::{dst_bucket}", Subdirectory=path, diff --git a/skylark/cli/cli_azure.py b/skylark/cli/cli_azure.py index 8f21c109f..97ea6130f 100644 --- a/skylark/cli/cli_azure.py +++ b/skylark/cli/cli_azure.py @@ -7,9 +7,7 @@ from typing import List import typer -from azure.identity import DefaultAzureCredential -from azure.mgmt.compute import ComputeManagementClient -from skylark.config import load_config +from skylark.compute.azure.azure_auth import AzureAuthentication from 
skylark.compute.azure.azure_cloud_provider import AzureCloudProvider from skylark.utils.utils import do_parallel @@ -18,20 +16,14 @@ @app.command() def get_valid_skus( - azure_subscription: str = typer.Option("", "--azure-subscription", help="Azure subscription ID"), regions: List[str] = typer.Option(AzureCloudProvider.region_list(), "--regions", "-r"), prefix: str = typer.Option("", "--prefix", help="Filter by prefix"), top_k: int = typer.Option(-1, "--top-k", help="Print top k entries"), ): - config = load_config() - azure_subscription = azure_subscription or config.get("azure_subscription_id") - typer.secho(f"Loaded from config file: azure_subscription={azure_subscription}", fg="blue") + auth = AzureAuthentication() + client = auth.get_compute_client() - credential = DefaultAzureCredential() - - # query azure API for each region to get available SKUs for each resource type def get_skus(region): - client = ComputeManagementClient(credential, azure_subscription) valid_skus = [] for sku in client.resource_skus.list(filter="location eq '{}'".format(region)): if len(sku.restrictions) == 0 and (not prefix or sku.name.startswith(prefix)): diff --git a/skylark/cli/cli_gcp.py b/skylark/cli/cli_gcp.py index 56faeb1cf..78c9ec2a0 100644 --- a/skylark/cli/cli_gcp.py +++ b/skylark/cli/cli_gcp.py @@ -4,7 +4,6 @@ import questionary import typer -from skylark.config import load_config from skylark.compute.gcp.gcp_cloud_provider import GCPCloudProvider from skylark.compute.gcp.gcp_server import GCPServer @@ -13,19 +12,13 @@ @app.command() -def ssh( - region: Optional[str] = None, - gcp_project: str = typer.Option("", "--gcp-project", help="GCP project ID"), -): - config = load_config() - gcp_project = gcp_project or config.get("gcp_project_id") - typer.secho(f"Loaded from config file: gcp_project={gcp_project}", fg="blue") - gcp = GCPCloudProvider(gcp_project) +def ssh(region: Optional[str] = None): + gcp = GCPCloudProvider() typer.secho("Querying GCP for instances", fg="green") instances = gcp.get_matching_instances(region=region) if len(instances) == 0: typer.secho(f"No instances found", fg="red") - typer.Abort() + raise typer.Abort() instance_map = {f"{i.region()}, {i.public_ip()} ({i.instance_state()})": i for i in instances} choices = list(sorted(instance_map.keys())) diff --git a/skylark/cli/cli_helper.py b/skylark/cli/cli_helper.py index 802143b1a..62d68f4c2 100644 --- a/skylark/cli/cli_helper.py +++ b/skylark/cli/cli_helper.py @@ -1,14 +1,21 @@ import concurrent.futures +from functools import partial import os import re import resource import subprocess from pathlib import Path from shutil import copyfile +from typing import Dict, List from sys import platform -from typing import Dict, List, Optional +from typing import Dict, List +import boto3 import typer +from skylark.compute.aws.aws_auth import AWSAuthentication +from skylark.compute.azure.azure_auth import AzureAuthentication +from skylark.compute.gcp.gcp_auth import GCPAuthentication +from skylark.config import SkylarkConfig from skylark.utils import logger from skylark.compute.aws.aws_cloud_provider import AWSCloudProvider from skylark.compute.azure.azure_cloud_provider import AzureCloudProvider @@ -45,7 +52,7 @@ def parse_path(path: str): if match is None: raise ValueError(f"Invalid Azure path: {path}") account, container, blob_path = match.groups() - return "azure", account, container + return "azure", (account, container), blob_path elif path.startswith("azure://"): bucket_name = path[8:] region = path[8:].split("-", 2)[-1] @@ -90,7 
+97,7 @@ def copy_local_local(src: Path, dst: Path): copyfile(src, dst) -def copy_local_objstore(object_interface: ObjectStoreInterface, src: Path, dst_bucket: str, dst_key: str): +def copy_local_objstore(object_interface: ObjectStoreInterface, src: Path, dst_key: str): ops: List[concurrent.futures.Future] = [] path_mapping: Dict[concurrent.futures.Future, Path] = {} @@ -114,7 +121,7 @@ def _copy(path: Path, dst_key: str, total_size=0.0): pbar.update(path_mapping[op].stat().st_size) -def copy_objstore_local(object_interface: ObjectStoreInterface, src_bucket: str, src_key: str, dst: Path): +def copy_objstore_local(object_interface: ObjectStoreInterface, src_key: str, dst: Path): ops: List[concurrent.futures.Future] = [] obj_mapping: Dict[concurrent.futures.Future, ObjectStoreObject] = {} @@ -142,24 +149,22 @@ def _copy(src_obj: ObjectStoreObject, dst: Path): def copy_local_gcs(src: Path, dst_bucket: str, dst_key: str): gcs = GCSInterface(None, dst_bucket) - return copy_local_objstore(gcs, src, dst_bucket, dst_key) + return copy_local_objstore(gcs, src, dst_key) def copy_gcs_local(src_bucket: str, src_key: str, dst: Path): gcs = GCSInterface(None, src_bucket) - return copy_objstore_local(gcs, src_bucket, src_key, dst) + return copy_objstore_local(gcs, src_key, dst) -def copy_local_azure(src: Path, dst_bucket: str, dst_key: str): - # Note that dst_key is infact azure region - azure = AzureInterface(dst_key, dst_bucket) - return copy_local_objstore(azure, src, dst_bucket, dst_key) +def copy_local_azure(src: Path, dst_account_name: str, dst_container_name: str, dst_key: str): + azure = AzureInterface(None, dst_account_name, dst_container_name) + return copy_local_objstore(azure, src, dst_key) -def copy_azure_local(src_bucket: str, src_key: str, dst: Path): - # Note that src_key is infact azure region - azure = AzureInterface(src_key, src_bucket) - return copy_objstore_local(azure, src_bucket, src_key, dst) +def copy_azure_local(src_account_name: str, src_container_name: str, src_key: str, dst: Path): + azure = AzureInterface(None, src_account_name, src_container_name) + return copy_objstore_local(azure, src_key, dst) def copy_local_s3(src: Path, dst_bucket: str, dst_key: str, use_tls: bool = True): @@ -231,7 +236,7 @@ def check_ulimit(hard_limit=1024 * 1024, soft_limit=1024 * 1024): f"Failed to increase ulimit to {soft_limit}, please set manually with 'ulimit -n {soft_limit}'. 
Current limit is {new_limit}", fg="red", ) - typer.Abort() + raise typer.Abort() else: typer.secho(f"Successfully increased ulimit to {new_limit}", fg="green") if current_limit_soft < soft_limit and (platform == "linux" or platform == "linux2"): @@ -242,31 +247,104 @@ def check_ulimit(hard_limit=1024 * 1024, soft_limit=1024 * 1024): subprocess.check_output(increase_soft_limit) -def deprovision_skylark_instances(azure_subscription: Optional[str] = None, gcp_project_id: Optional[str] = None): +def deprovision_skylark_instances(): instances = [] + query_jobs = [] + if AWSAuthentication().enabled(): + logger.debug("AWS authentication enabled, querying for instances") + aws = AWSCloudProvider() + for region in aws.region_list(): + query_jobs.append(partial(aws.get_matching_instances, region)) + if AzureAuthentication().enabled(): + logger.debug("Azure authentication enabled, querying for instances") + query_jobs.append(lambda: AzureCloudProvider().get_matching_instances()) + if GCPAuthentication().enabled(): + logger.debug("GCP authentication enabled, querying for instances") + query_jobs.append(lambda: GCPCloudProvider().get_matching_instances()) + + # query in parallel + for _, instance_list in do_parallel(lambda f: f(), query_jobs, progress_bar=True, desc="Query instances", hide_args=True): + instances.extend(instance_list) - aws = AWSCloudProvider() - for _, instance_list in do_parallel( - aws.get_matching_instances, aws.region_list(), progress_bar=True, leave_pbar=False, desc="Retrieve AWS instances" - ): - instances += instance_list - - if not azure_subscription: - typer.secho( - "No Microsoft Azure subscription given, so Azure instances will not be terminated", color=typer.colors.YELLOW, bold=True - ) + if instances: + typer.secho(f"Deprovisioning {len(instances)} instances", fg="yellow", bold=True) + do_parallel(lambda instance: instance.terminate_instance(), instances, progress_bar=True, desc="Deprovisioning") else: - azure = AzureCloudProvider(azure_subscription=azure_subscription) - instances += azure.get_matching_instances() - - if not gcp_project_id: - typer.secho("No GCP project ID given, so GCP instances will not be deprovisioned", color=typer.colors.YELLOW, bold=True) + typer.secho("No instances to deprovision, exiting...", fg="yellow", bold=True) + + +def load_aws_config(config: SkylarkConfig) -> SkylarkConfig: + # get AWS credentials from boto3 + session = boto3.Session() + credentials = session.get_credentials() + credentials = credentials.get_frozen_credentials() + if credentials.access_key is None or credentials.secret_key is None: + typer.secho(" AWS credentials not found in boto3 session, please use the AWS CLI to set them via `aws configure`", fg="red") + typer.secho(" https://docs.aws.amazon.com/cli/latest/userguide/cli-chap-getting-started.html", fg="red") + typer.secho(" Disabling AWS support", fg="blue") + return config + + typer.secho(f" Loaded AWS credentials from the AWS CLI [IAM access key ID: ...{credentials.access_key[-6:]}]", fg="blue") + return config + + +def load_azure_config(config: SkylarkConfig, force_init: bool = False) -> SkylarkConfig: + if force_init: + typer.secho(" Azure credentials will be re-initialized", fg="red") + config.azure_subscription_id = None + + if config.azure_subscription_id: + typer.secho(" Azure credentials already configured! 
To reconfigure Azure, run `skylark init --reinit-azure`.", fg="blue") + return config + + # check if Azure is enabled + auth = AzureAuthentication() + try: + auth.credential.get_token("https://management.azure.com/") + azure_enabled = True + except: + azure_enabled = False + if not azure_enabled: + typer.secho(" No local Azure credentials! Run `az login` to set them up.", fg="red") + typer.secho(" https://docs.microsoft.com/en-us/azure/developer/python/azure-sdk-authenticate", fg="red") + typer.secho(" Disabling Azure support", fg="blue") + return config + typer.secho(" Azure credentials found in Azure CLI", fg="blue") + inferred_subscription_id = AzureAuthentication.infer_subscription_id() + if typer.confirm(" Azure credentials found, do you want to enable Azure support in Skylark?", default=True): + config.azure_subscription_id = typer.prompt(" Enter the Azure subscription ID:", default=inferred_subscription_id) else: - gcp = GCPCloudProvider(gcp_project=gcp_project_id) - instances += gcp.get_matching_instances() + config.azure_subscription_id = None + typer.secho(" Disabling Azure support", fg="blue") + return config - if instances: - typer.secho(f"Deprovisioning {len(instances)} instances", color=typer.colors.YELLOW, bold=True) - do_parallel(lambda instance: instance.terminate_instance(), instances, progress_bar=True, desc="Deprovisioning") + +def load_gcp_config(config: SkylarkConfig, force_init: bool = False) -> SkylarkConfig: + if force_init: + typer.secho(" GCP credentials will be re-initialized", fg="red") + config.gcp_project_id = None + + if config.gcp_project_id is not None: + typer.secho(" GCP already configured! To reconfigure GCP, run `skylark init --reinit-gcp`.", fg="blue") + return config + + # check if GCP is enabled + auth = GCPAuthentication() + if not auth.credentials: + typer.secho( + " Default GCP credentials are not set up yet. 
Run `gcloud auth application-default login`.", + fg="red", + ) + typer.secho(" https://cloud.google.com/docs/authentication/getting-started", fg="red") + typer.secho(" Disabling GCP support", fg="blue") + return config else: - typer.secho("No instances to deprovision, exiting...", color=typer.colors.YELLOW, bold=True) + typer.secho(" GCP credentials found in GCP CLI", fg="blue") + if typer.confirm(" GCP credentials found, do you want to enable GCP support in Skylark?", default=True): + config.gcp_project_id = typer.prompt(" Enter the GCP project ID:", default=auth.project_id) + assert config.gcp_project_id is not None, "GCP project ID must not be None" + return config + else: + config.gcp_project_id = None + typer.secho(" Disabling GCP support", fg="blue") + return config diff --git a/skylark/cli/experiments/throughput.py b/skylark/cli/experiments/throughput.py index ff2d65636..04c7fd955 100644 --- a/skylark/cli/experiments/throughput.py +++ b/skylark/cli/experiments/throughput.py @@ -10,7 +10,6 @@ import typer from skylark import GB, skylark_root from skylark.benchmark.utils import provision, split_list -from skylark.config import load_config from skylark.compute.aws.aws_cloud_provider import AWSCloudProvider from skylark.compute.azure.azure_cloud_provider import AzureCloudProvider from skylark.compute.gcp.gcp_cloud_provider import GCPCloudProvider @@ -81,9 +80,6 @@ def throughput_grid( aws_instance_class: str = typer.Option("m5.8xlarge", help="AWS instance class to use"), azure_instance_class: str = typer.Option("Standard_D32_v5", help="Azure instance class to use"), gcp_instance_class: str = typer.Option("n2-standard-32", help="GCP instance class to use"), - # cloud options - gcp_project: Optional[str] = None, - azure_subscription: Optional[str] = None, # iperf3 options iperf3_runtime: int = typer.Option(5, help="Runtime for iperf3 in seconds"), iperf3_connections: int = typer.Option(64, help="Number of connections to test"), @@ -91,11 +87,6 @@ def throughput_grid( def check_stderr(tup): assert tup[1].strip() == "", f"Command failed, err: {tup[1]}" - config = load_config() - gcp_project = gcp_project or config.get("gcp_project_id") - azure_subscription = azure_subscription or config.get("azure_subscription_id") - logger.debug(f"Loaded from config file: gcp_project={gcp_project}, azure_subscription={azure_subscription}") - if resume: index_key = [ "iperf3_connections", @@ -118,21 +109,21 @@ def check_stderr(tup): gcp_region_list = gcp_region_list if enable_gcp else [] if not enable_aws and not enable_azure and not enable_gcp: logger.error("At least one of -aws, -azure, -gcp must be enabled.") - typer.Abort() + raise typer.Abort() # validate AWS regions if not enable_aws: aws_region_list = [] elif not all(r in all_aws_regions for r in aws_region_list): logger.error(f"Invalid AWS region list: {aws_region_list}") - typer.Abort() + raise typer.Abort() # validate Azure regions if not enable_azure: azure_region_list = [] elif not all(r in all_azure_regions for r in azure_region_list): logger.error(f"Invalid Azure region list: {azure_region_list}") - typer.Abort() + raise typer.Abort() # validate GCP regions assert not enable_gcp_standard or enable_gcp, f"GCP is disabled but GCP standard is enabled" @@ -140,19 +131,19 @@ def check_stderr(tup): gcp_region_list = [] elif not all(r in all_gcp_regions for r in gcp_region_list): logger.error(f"Invalid GCP region list: {gcp_region_list}") - typer.Abort() + raise typer.Abort() # validate GCP standard instances if not enable_gcp_standard: 
        gcp_standard_region_list = []
     if not all(r in all_gcp_regions_standard for r in gcp_standard_region_list):
         logger.error(f"Invalid GCP standard region list: {gcp_standard_region_list}")
-        typer.Abort()
+        raise typer.Abort()
 
     # provision servers
     aws = AWSCloudProvider()
-    azure = AzureCloudProvider(azure_subscription)
-    gcp = GCPCloudProvider(gcp_project)
+    azure = AzureCloudProvider()
+    gcp = GCPCloudProvider()
     aws_instances, azure_instances, gcp_instances = provision(
         aws=aws,
         azure=azure,
diff --git a/skylark/compute/aws/aws_auth.py b/skylark/compute/aws/aws_auth.py
new file mode 100644
index 000000000..4010e9daa
--- /dev/null
+++ b/skylark/compute/aws/aws_auth.py
@@ -0,0 +1,65 @@
+import threading
+from typing import Optional
+
+import boto3
+
+
+class AWSAuthentication:
+    __cached_credentials = threading.local()
+
+    def __init__(self, access_key: Optional[str] = None, secret_key: Optional[str] = None):
+        """Loads AWS authentication details. If no access key is provided, it will try to load credentials using boto3"""
+        if access_key and secret_key:
+            self.config_mode = "manual"
+            self._access_key = access_key
+            self._secret_key = secret_key
+        else:
+            self.config_mode = "iam_inferred"
+            self._access_key = None
+            self._secret_key = None
+
+    @property
+    def access_key(self):
+        if self._access_key is None:
+            self._access_key, self._secret_key = self.infer_credentials()
+        return self._access_key
+
+    @property
+    def secret_key(self):
+        if self._secret_key is None:
+            self._access_key, self._secret_key = self.infer_credentials()
+        return self._secret_key
+
+    def enabled(self):
+        return self.config_mode != "disabled"
+
+    def infer_credentials(self):
+        # todo load temporary credentials from STS
+        cached_credential = getattr(self.__cached_credentials, "boto3_credential", None)
+        if cached_credential == None:
+            session = boto3.Session()
+            credentials = session.get_credentials()
+            if credentials:
+                credentials = credentials.get_frozen_credentials()
+                cached_credential = (credentials.access_key, credentials.secret_key)
+            setattr(self.__cached_credentials, "boto3_credential", cached_credential)
+        return cached_credential if cached_credential else (None, None)
+
+    def get_boto3_session(self, aws_region: str):
+        if self.config_mode == "manual":
+            return boto3.Session(
+                aws_access_key_id=self.access_key,
+                aws_secret_access_key=self.secret_key,
+                region_name=aws_region,
+            )
+        else:
+            return boto3.Session(region_name=aws_region)
+
+    def get_boto3_resource(self, service_name, aws_region=None):
+        return self.get_boto3_session(aws_region).resource(service_name, region_name=aws_region)
+
+    def get_boto3_client(self, service_name, aws_region=None):
+        if aws_region is None:
+            return self.get_boto3_session(aws_region).client(service_name)
+        else:
+            return self.get_boto3_session(aws_region).client(service_name, region_name=aws_region)
diff --git a/skylark/compute/aws/aws_cloud_provider.py b/skylark/compute/aws/aws_cloud_provider.py
index 16f820653..885bd4aa9 100644
--- a/skylark/compute/aws/aws_cloud_provider.py
+++ b/skylark/compute/aws/aws_cloud_provider.py
@@ -1,20 +1,23 @@
+import json
 import uuid
 from typing import List, Optional
 
 import botocore
 import pandas as pd
+from skylark.compute.aws.aws_auth import AWSAuthentication
 from skylark.utils import logger
 from oslo_concurrency import lockutils
 
 from skylark import skylark_root
 from skylark.compute.aws.aws_server import AWSServer
 from skylark.compute.cloud_providers import CloudProvider
-from skylark.utils.utils import retry_backoff
+from skylark.utils.utils 
import retry_backoff, wait_for class AWSCloudProvider(CloudProvider): def __init__(self): super().__init__() + self.auth = AWSAuthentication() @property def name(self): @@ -66,13 +69,11 @@ def get_transfer_cost(src_key, dst_key, premium_tier=True): src_rows = transfer_df.loc[src] src_rows = src_rows[src_rows.index != "internet"] return src_rows.max()["cost"] - elif dst_provider == "gcp" or dst_provider == "azure": - return transfer_df.loc[src, "internet"]["cost"] else: - raise NotImplementedError + return transfer_df.loc[src, "internet"]["cost"] def get_instance_list(self, region: str) -> List[AWSServer]: - ec2 = AWSServer.get_boto3_resource("ec2", region) + ec2 = self.auth.get_boto3_resource("ec2", region) valid_states = ["pending", "running", "stopped", "stopping"] instances = ec2.instances.filter(Filters=[{"Name": "instance-state-name", "Values": valid_states}]) try: @@ -84,7 +85,7 @@ def get_instance_list(self, region: str) -> List[AWSServer]: return [AWSServer(f"aws:{region}", i) for i in instance_ids] def get_security_group(self, region: str, vpc_name="skylark", sg_name="skylark"): - ec2 = AWSServer.get_boto3_resource("ec2", region) + ec2 = self.auth.get_boto3_resource("ec2", region) vpcs = list(ec2.vpcs.filter(Filters=[{"Name": "tag:Name", "Values": [vpc_name]}]).all()) assert len(vpcs) == 1, f"Found {len(vpcs)} vpcs with name {vpc_name}" sgs = [sg for sg in vpcs[0].security_groups.all() if sg.group_name == sg_name] @@ -92,7 +93,7 @@ def get_security_group(self, region: str, vpc_name="skylark", sg_name="skylark") return sgs[0] def get_vpc(self, region: str, vpc_name="skylark"): - ec2 = AWSServer.get_boto3_resource("ec2", region) + ec2 = self.auth.get_boto3_resource("ec2", region) vpcs = list(ec2.vpcs.filter(Filters=[{"Name": "tag:Name", "Values": [vpc_name]}]).all()) if len(vpcs) == 0: return None @@ -100,7 +101,7 @@ def get_vpc(self, region: str, vpc_name="skylark"): return vpcs[0] def make_vpc(self, region: str, vpc_name="skylark"): - ec2 = AWSServer.get_boto3_resource("ec2", region) + ec2 = self.auth.get_boto3_resource("ec2", region) ec2client = ec2.meta.client vpcs = list(ec2.vpcs.filter(Filters=[{"Name": "tag:Name", "Values": [vpc_name]}]).all()) @@ -154,7 +155,7 @@ def make_vpc(self, region: str, vpc_name="skylark"): def delete_vpc(self, region: str, vpcid: str): """Delete VPC, from https://gist.github.com/vernhart/c6a0fc94c0aeaebe84e5cd6f3dede4ce""" logger.warning(f"[{region}] Deleting VPC {vpcid}") - ec2 = AWSServer.get_boto3_resource("ec2", region) + ec2 = self.auth.get_boto3_resource("ec2", region) ec2client = ec2.meta.client vpc = ec2.Vpc(vpcid) # detach and delete all gateways associated with the vpc @@ -194,6 +195,27 @@ def delete_vpc(self, region: str, vpcid: str): # finally, delete the vpc ec2client.delete_vpc(VpcId=vpcid) + def create_iam(self, iam_name: str = "skylark_gateway", attach_policy_arn: Optional[str] = None): + """Create IAM role if it doesn't exist and grant managed role if given.""" + + @lockutils.synchronized(f"aws_create_iam_{iam_name}", external=True, lock_path="/tmp/skylark_locks") + def fn(): + iam = self.auth.get_boto3_client("iam") + + # create IAM role + try: + iam.get_role(RoleName=iam_name) + except iam.exceptions.NoSuchEntityException: + doc = { + "Version": "2012-10-17", + "Statement": [{"Effect": "Allow", "Principal": {"Service": "ec2.amazonaws.com"}, "Action": "sts:AssumeRole"}], + } + iam.create_role(RoleName=iam_name, AssumeRolePolicyDocument=json.dumps(doc), Tags=[{"Key": "skylark", "Value": "true"}]) + if attach_policy_arn: + 
iam.attach_role_policy(RoleName=iam_name, PolicyArn=attach_policy_arn) + + return fn() + def add_ip_to_security_group(self, aws_region: str): """Add IP to security group. If security group ID is None, use group named skylark (create if not exists).""" @@ -219,20 +241,40 @@ def provision_instance( # ami_id: Optional[str] = None, tags={"skylark": "true"}, ebs_volume_size: int = 128, + iam_name: str = "skylark_gateway", ) -> AWSServer: assert not region.startswith("aws:"), "Region should be AWS region" if name is None: name = f"skylark-aws-{str(uuid.uuid4()).replace('-', '')}" - ec2 = AWSServer.get_boto3_resource("ec2", region) - AWSServer.ensure_keyfile_exists(region) - + iam_instance_profile_name = f"{name}_profile" + iam = self.auth.get_boto3_client("iam") + ec2 = self.auth.get_boto3_resource("ec2", region) vpc = self.get_vpc(region) assert vpc is not None, "No VPC found" subnets = list(vpc.subnets.all()) assert len(subnets) > 0, "No subnets found" + def check_iam_role(): + try: + iam.get_role(RoleName=iam_name) + return True + except iam.exceptions.NoSuchEntityException: + return False + + def check_instance_profile(): + try: + iam.get_instance_profile(InstanceProfileName=iam_instance_profile_name) + return True + except iam.exceptions.NoSuchEntityException: + return False + + # wait for iam_role to be created and create instance profile + wait_for(check_iam_role, timeout=60, interval=0.5) + iam.create_instance_profile(InstanceProfileName=iam_instance_profile_name, Tags=[{"Key": "skylark", "Value": "true"}]) + iam.add_role_to_instance_profile(InstanceProfileName=iam_instance_profile_name, RoleName=iam_name) + wait_for(check_instance_profile, timeout=60, interval=0.5) + def start_instance(): - # todo instance-initiated-shutdown-behavior terminate return ec2.create_instances( ImageId="resolve:ssm:/aws/service/ecs/optimized-ami/amazon-linux-2/recommended/image_id", InstanceType=instance_class, @@ -260,6 +302,8 @@ def start_instance(): "DeleteOnTermination": True, } ], + IamInstanceProfile={"Name": iam_instance_profile_name}, + InstanceInitiatedShutdownBehavior="terminate", ) instance = retry_backoff(start_instance, initial_backoff=1) diff --git a/skylark/compute/aws/aws_server.py b/skylark/compute/aws/aws_server.py index 276bd5fe7..c80eaaf84 100644 --- a/skylark/compute/aws/aws_server.py +++ b/skylark/compute/aws/aws_server.py @@ -4,6 +4,7 @@ import boto3 import paramiko +from skylark.compute.aws.aws_auth import AWSAuthentication from skylark.utils import logger from oslo_concurrency import lockutils @@ -18,6 +19,7 @@ class AWSServer(Server): def __init__(self, region_tag, instance_id, log_dir=None): super().__init__(region_tag, log_dir=log_dir) assert self.region_tag.split(":")[0] == "aws" + self.auth = AWSAuthentication() self.aws_region = self.region_tag.split(":")[1] self.instance_id = instance_id self.boto3_session = boto3.Session(region_name=self.aws_region) @@ -26,40 +28,22 @@ def __init__(self, region_tag, instance_id, log_dir=None): def uuid(self): return f"{self.region_tag}:{self.instance_id}" - @classmethod - def get_boto3_session(cls, aws_region) -> boto3.Session: - return boto3.Session(region_name=aws_region) - - @classmethod - def get_boto3_resource(cls, service_name, aws_region=None): - return cls.get_boto3_session(aws_region).resource(service_name, region_name=aws_region) - - @classmethod - def get_boto3_client(cls, service_name, aws_region=None): - if aws_region is None: - return cls.get_boto3_session(aws_region).client(service_name) - else: - return 
cls.get_boto3_session(aws_region).client(service_name, region_name=aws_region)
-
     def get_boto3_instance_resource(self):
-        ec2 = AWSServer.get_boto3_resource("ec2", self.aws_region)
+        ec2 = self.auth.get_boto3_resource("ec2", self.aws_region)
         return ec2.Instance(self.instance_id)
 
-    @staticmethod
-    def ensure_keyfile_exists(aws_region, prefix=key_root / "aws"):
+    def ensure_keyfile_exists(self, aws_region, prefix=key_root / "aws"):
+        ec2 = self.auth.get_boto3_resource("ec2", aws_region)
+        ec2_client = self.auth.get_boto3_client("ec2", aws_region)
         prefix = Path(prefix)
         key_name = f"skylark-{aws_region}"
         local_key_file = prefix / f"{key_name}.pem"
 
         @lockutils.synchronized(f"aws_keyfile_lock_{aws_region}", external=True, lock_path="/tmp/skylark_locks")
         def create_keyfile():
-            if not local_key_file.exists():  # we have to check again since another process may have created it
-                ec2 = AWSServer.get_boto3_resource("ec2", aws_region)
-                ec2_client = AWSServer.get_boto3_client("ec2", aws_region)
+            if not local_key_file.exists():
                 local_key_file.parent.mkdir(parents=True, exist_ok=True)
-                # delete key pair from ec2 if it exists
-                keys_in_region = set(p["KeyName"] for p in ec2_client.describe_key_pairs()["KeyPairs"])
-                if key_name in keys_in_region:
+                if key_name in set(p["KeyName"] for p in ec2_client.describe_key_pairs()["KeyPairs"]):
                     logger.warning(f"Deleting key {key_name} in region {aws_region}")
                     ec2_client.delete_key_pair(KeyName=key_name)
                 key_pair = ec2.create_key_pair(KeyName=f"skylark-{aws_region}", KeyType="rsa")
@@ -68,7 +52,6 @@ def create_keyfile():
                     if not key_str.endswith("\n"):
                         key_str += "\n"
                     f.write(key_str)
-                    f.flush()  # sometimes generates keys with zero bytes, so we flush to ensure it's written
                 os.chmod(local_key_file, 0o600)
                 logger.info(f"Created key file {local_key_file}")
 
@@ -107,8 +90,7 @@ def __repr__(self):
         return f"AWSServer(region_tag={self.region_tag}, instance_id={self.instance_id})"
 
     def terminate_instance_impl(self):
-        ec2 = AWSServer.get_boto3_resource("ec2", self.aws_region)
-        ec2.instances.filter(InstanceIds=[self.instance_id]).terminate()
+        self.auth.get_boto3_resource("ec2", self.aws_region).instances.filter(InstanceIds=[self.instance_id]).terminate()
 
     def get_ssh_client_impl(self):
         client = paramiko.SSHClient()
@@ -116,6 +98,7 @@ def get_ssh_client_impl(self):
         client.connect(
             self.public_ip(),
             username="ec2-user",
+            # todo generate keys with password "skylark"
             pkey=paramiko.RSAKey.from_private_key_file(str(self.local_keyfile)),
             look_for_keys=False,
             allow_agent=False,
diff --git a/skylark/compute/azure/azure_auth.py b/skylark/compute/azure/azure_auth.py
new file mode 100644
index 000000000..05032875c
--- /dev/null
+++ b/skylark/compute/azure/azure_auth.py
@@ -0,0 +1,70 @@
+import os
+import subprocess
+import threading
+from typing import Optional
+from azure.identity import DefaultAzureCredential
+from azure.mgmt.compute import ComputeManagementClient
+from azure.mgmt.network import NetworkManagementClient
+from azure.mgmt.resource import ResourceManagementClient
+from azure.mgmt.authorization import AuthorizationManagementClient
+from azure.mgmt.storage import StorageManagementClient
+from azure.storage.blob import BlobServiceClient, ContainerClient
+
+from skylark import cloud_config
+from skylark.compute.utils import query_which_cloud
+
+
+class AzureAuthentication:
+    __cached_credentials = threading.local()
+
+    def __init__(self, subscription_id: str = cloud_config.azure_subscription_id):
+        self.subscription_id = subscription_id
+        self.credential = self.get_credential(subscription_id)
+
+    def get_credential(self, subscription_id: str):
+        cached_credential = getattr(self.__cached_credentials, f"credential_{subscription_id}", None)
+        if cached_credential is None:
+            cached_credential = DefaultAzureCredential(
+                exclude_managed_identity_credential=query_which_cloud() != "azure",  # exclude MSI if not Azure
+                exclude_powershell_credential=True,
+                exclude_visual_studio_code_credential=True,
+            )
+            setattr(self.__cached_credentials, f"credential_{subscription_id}", cached_credential)
+        return cached_credential
+
+    def enabled(self) -> bool:
+        return self.subscription_id is not None
+
+    @staticmethod
+    def infer_subscription_id() -> Optional[str]:
+        if "AZURE_SUBSCRIPTION_ID" in os.environ:
+            return os.environ["AZURE_SUBSCRIPTION_ID"]
+        else:
+            try:
+                return subprocess.check_output(["az", "account", "show", "--query", "id"]).decode("utf-8").replace('"', "").strip()
+            except subprocess.CalledProcessError:
+                return None
+
+    def get_token(self, resource: str):
+        return self.credential.get_token(resource)
+
+    def get_compute_client(self):
+        return ComputeManagementClient(self.credential, self.subscription_id)
+
+    def get_resource_client(self):
+        return ResourceManagementClient(self.credential, self.subscription_id)
+
+    def get_network_client(self):
+        return NetworkManagementClient(self.credential, self.subscription_id)
+
+    def get_authorization_client(self):
+        return AuthorizationManagementClient(self.credential, self.subscription_id)
+
+    def get_storage_management_client(self):
+        return StorageManagementClient(self.credential, self.subscription_id)
+
+    def get_container_client(self, account_url: str, container_name: str):
+        return ContainerClient(account_url, container_name, credential=self.credential)
+
+    def get_blob_service_client(self, account_url: str):
+        return BlobServiceClient(account_url=account_url, credential=self.credential)
diff --git a/skylark/compute/azure/azure_cloud_provider.py b/skylark/compute/azure/azure_cloud_provider.py
index 73988ae96..179d13354 100644
--- a/skylark/compute/azure/azure_cloud_provider.py
+++ b/skylark/compute/azure/azure_cloud_provider.py
@@ -5,34 +5,23 @@
 from typing import List, Optional
 
 import paramiko
-from skylark.config import load_config
+from skylark.compute.azure.azure_auth import AzureAuthentication
 from skylark.utils import logger
 
 from skylark import key_root
 from skylark.compute.azure.azure_server import AzureServer
 from skylark.compute.cloud_providers import CloudProvider
+from azure.mgmt.authorization.models import RoleAssignmentCreateParameters
 
-from azure.identity import DefaultAzureCredential, ClientSecretCredential
-from azure.mgmt.compute import ComputeManagementClient
-from azure.mgmt.network import NetworkManagementClient
-from azure.mgmt.resource import ResourceManagementClient
+from azure.mgmt.compute.models import ResourceIdentityType
 
-from skylark.utils.utils import do_parallel
+from skylark.utils.utils import Timer, do_parallel
 
 
 class AzureCloudProvider(CloudProvider):
-    def __init__(self, azure_subscription, key_root=key_root / "azure", read_credential=True):
+    def __init__(self, key_root=key_root / "azure"):
         super().__init__()
-        if read_credential:
-            config = load_config()
-            self.subscription_id = azure_subscription if azure_subscription is not None else config["azure_subscription_id"]
-            self.credential = ClientSecretCredential(
-                tenant_id=config["azure_tenant_id"],
-                client_id=config["azure_client_id"],
-                client_secret=config["azure_client_secret"],
-            )
-        else:
-            self.credential 
= DefaultAzureCredential() - self.subscription_id = azure_subscription + self.auth = AzureAuthentication() + key_root.mkdir(parents=True, exist_ok=True) self.private_key_path = key_root / "azure_key" self.public_key_path = key_root / "azure_key.pub" @@ -209,6 +198,7 @@ def get_transfer_cost(src_key, dst_key, premium_tier=True): dst_provider, dst_region = dst_key.split(":") assert src_provider == "azure" if not premium_tier: + # TODO: tracked in https://github.com/parasj/skylark/issues/59 return NotImplementedError() src_continent = AzureCloudProvider.lookup_continent(src_region) @@ -260,14 +250,12 @@ def get_transfer_cost(src_key, dst_key, premium_tier=True): raise ValueError(f"Unknown transfer cost for {src_key} -> {dst_key}") def get_instance_list(self, region: str) -> List[AzureServer]: - credential = self.credential - compute_client = ComputeManagementClient(credential, self.subscription_id) - + compute_client = self.auth.get_compute_client() server_list = [] for vm in compute_client.virtual_machines.list(AzureServer.resource_group_name): if vm.tags.get("skylark", None) == "true" and AzureServer.is_valid_vm_name(vm.name) and vm.location == region: name = AzureServer.base_name_from_vm_name(vm.name) - s = AzureServer(self.subscription_id, name) + s = AzureServer(name) if s.is_valid(): server_list.append(s) else: @@ -287,18 +275,17 @@ def create_ssh_key(self): f.write(f"{key.get_name()} {key.get_base64()}\n") def set_up_resource_group(self, clean_up_orphans=True): - credential = self.credential - resource_client = ResourceManagementClient(credential, self.subscription_id) + resource_client = self.auth.get_resource_client() if resource_client.resource_groups.check_existence(AzureServer.resource_group_name): # Resource group already exists. # Take this moment to search for orphaned resources and clean them up... - network_client = NetworkManagementClient(credential, self.subscription_id) + network_client = self.auth.get_network_client() if clean_up_orphans: instances_to_terminate = [] for vnet in network_client.virtual_networks.list(AzureServer.resource_group_name): if vnet.tags.get("skylark", None) == "true" and AzureServer.is_valid_vnet_name(vnet.name): name = AzureServer.base_name_from_vnet_name(vnet.name) - s = AzureServer(self.subscription_id, name, assume_exists=False) + s = AzureServer(name, assume_exists=False) if not s.is_valid(): logger.warning(f"Cleaning up orphaned Azure resources for {name}...") instances_to_terminate.append(s) @@ -327,9 +314,8 @@ def provision_instance( pub_key = f.read() # Prepare for making Microsoft Azure API calls - credential = self.credential - compute_client = ComputeManagementClient(credential, self.subscription_id) - network_client = NetworkManagementClient(credential, self.subscription_id) + compute_client = self.auth.get_compute_client() + network_client = self.auth.get_network_client() # Use the common resource group for this instance resource_group = AzureServer.resource_group_name @@ -337,108 +323,137 @@ def provision_instance( # TODO: On future requests to create resources, check if a resource # with that name already exists, and fail the operation if so - # Create a Virtual Network for this instance - poller = network_client.virtual_networks.begin_create_or_update( - resource_group, - AzureServer.vnet_name(name), - {"location": location, "tags": {"skylark": "true"}, "address_space": {"address_prefixes": ["10.0.0.0/24"]}}, - ) - poller.result() - - # Create a Network Security Group for this instance - # NOTE: This is insecure. 
We should fix this soon. - poller = network_client.network_security_groups.begin_create_or_update( - resource_group, - AzureServer.nsg_name(name), - { - "location": location, - "tags": {"skylark": "true"}, - "security_rules": [ - { - "name": name + "-allow-all", - "protocol": "Tcp", - "source_port_range": "*", - "source_address_prefix": "*", - "destination_port_range": "*", - "destination_address_prefix": "*", - "access": "Allow", - "priority": 300, - "direction": "Inbound", - } - # Azure appears to add default rules for outbound connections - ], - }, - ) - nsg_result = poller.result() - - # Create a subnet for this instance with the above Network Security Group - subnet_poller = network_client.subnets.begin_create_or_update( - resource_group, - AzureServer.vnet_name(name), - AzureServer.subnet_name(name), - {"address_prefix": "10.0.0.0/26", "network_security_group": {"id": nsg_result.id}}, - ) + with Timer("Creating Azure network"): + # Create a Virtual Network for this instance + poller = network_client.virtual_networks.begin_create_or_update( + resource_group, + AzureServer.vnet_name(name), + {"location": location, "tags": {"skylark": "true"}, "address_space": {"address_prefixes": ["10.0.0.0/24"]}}, + ) + poller.result() - # Create a public IPv4 address for this instance - ip_poller = network_client.public_ip_addresses.begin_create_or_update( - resource_group, - AzureServer.ip_name(name), - { - "location": location, - "tags": {"skylark": "true"}, - "sku": {"name": "Standard"}, - "public_ip_allocation_method": "Static", - "public_ip_address_version": "IPV4", - }, - ) + # Create a Network Security Group for this instance + # NOTE: This is insecure. We should fix this soon. + poller = network_client.network_security_groups.begin_create_or_update( + resource_group, + AzureServer.nsg_name(name), + { + "location": location, + "tags": {"skylark": "true"}, + "security_rules": [ + { + "name": name + "-allow-all", + "protocol": "Tcp", + "source_port_range": "*", + "source_address_prefix": "*", + "destination_port_range": "*", + "destination_address_prefix": "*", + "access": "Allow", + "priority": 300, + "direction": "Inbound", + } + # Azure appears to add default rules for outbound connections + ], + }, + ) + nsg_result = poller.result() - subnet_result = subnet_poller.result() - public_ip_result = ip_poller.result() - - # Create a NIC for this instance, with accelerated networking enabled - poller = network_client.network_interfaces.begin_create_or_update( - resource_group, - AzureServer.nic_name(name), - { - "location": location, - "tags": {"skylark": "true"}, - "ip_configurations": [ - {"name": name + "-ip", "subnet": {"id": subnet_result.id}, "public_ip_address": {"id": public_ip_result.id}} - ], - "enable_accelerated_networking": True, - }, - ) - nic_result = poller.result() + # Create a subnet for this instance with the above Network Security Group + subnet_poller = network_client.subnets.begin_create_or_update( + resource_group, + AzureServer.vnet_name(name), + AzureServer.subnet_name(name), + {"address_prefix": "10.0.0.0/26", "network_security_group": {"id": nsg_result.id}}, + ) - # Create the VM - with self.provisioning_semaphore: - poller = compute_client.virtual_machines.begin_create_or_update( + # Create a public IPv4 address for this instance + ip_poller = network_client.public_ip_addresses.begin_create_or_update( + resource_group, + AzureServer.ip_name(name), + { + "location": location, + "tags": {"skylark": "true"}, + "sku": {"name": "Standard"}, + "public_ip_allocation_method": 
"Static", + "public_ip_address_version": "IPV4", + }, + ) + + subnet_result = subnet_poller.result() + public_ip_result = ip_poller.result() + + with Timer("Creating Azure NIC"): + # Create a NIC for this instance, with accelerated networking enabled + poller = network_client.network_interfaces.begin_create_or_update( resource_group, - AzureServer.vm_name(name), + AzureServer.nic_name(name), { "location": location, "tags": {"skylark": "true"}, - "hardware_profile": {"vm_size": self.lookup_valid_instance(location, vm_size)}, - "storage_profile": { - "image_reference": { - "publisher": "canonical", - "offer": "0001-com-ubuntu-server-focal", - "sku": "20_04-lts", - "version": "latest", + "ip_configurations": [ + {"name": name + "-ip", "subnet": {"id": subnet_result.id}, "public_ip_address": {"id": public_ip_result.id}} + ], + "enable_accelerated_networking": True, + }, + ) + nic_result = poller.result() + + # Create the VM + with Timer("Creating Azure VM"): + with self.provisioning_semaphore: + poller = compute_client.virtual_machines.begin_create_or_update( + resource_group, + AzureServer.vm_name(name), + { + "location": location, + "tags": {"skylark": "true"}, + "hardware_profile": {"vm_size": self.lookup_valid_instance(location, vm_size)}, + "storage_profile": { + # "image_reference": { + # "publisher": "canonical", + # "offer": "0001-com-ubuntu-server-focal", + # "sku": "20_04-lts", + # "version": "latest", + # }, + "image_reference": { + "publisher": "microsoft-aks", + "offer": "aks", + "sku": "aks-engine-ubuntu-1804-202112", + "version": "latest", + }, + "os_disk": {"create_option": "FromImage", "delete_option": "Delete"}, }, - "os_disk": {"create_option": "FromImage", "delete_option": "Delete"}, - }, - "os_profile": { - "computer_name": AzureServer.vm_name(name), - "admin_username": uname, - "linux_configuration": { - "disable_password_authentication": True, - "ssh": {"public_keys": [{"path": f"/home/{uname}/.ssh/authorized_keys", "key_data": pub_key}]}, + "os_profile": { + "computer_name": AzureServer.vm_name(name), + "admin_username": uname, + "linux_configuration": { + "disable_password_authentication": True, + "ssh": {"public_keys": [{"path": f"/home/{uname}/.ssh/authorized_keys", "key_data": pub_key}]}, + }, }, + "network_profile": {"network_interfaces": [{"id": nic_result.id}]}, + # give VM managed identity w/ system assigned identity + "identity": {"type": ResourceIdentityType.system_assigned}, }, - "network_profile": {"network_interfaces": [{"id": nic_result.id}]}, - }, + ) + vm_result = poller.result() + + with Timer("Role assignment"): + # Assign roles to system MSI, see https://docs.microsoft.com/en-us/samples/azure-samples/compute-python-msi-vm/compute-python-msi-vm/#role-assignment + # todo only grant storage-blob-data-reader and storage-blob-data-writer for specified buckets + auth_client = self.auth.get_authorization_client() + scope = f"/subscriptions/{self.auth.subscription_id}" + role_name = "Contributor" + roles = list(auth_client.role_definitions.list(scope, filter="roleName eq '{}'".format(role_name))) + assert len(roles) == 1 + + # Add RG scope to the MSI identities: + role_assignment = auth_client.role_assignments.create( + scope, + uuid.uuid4(), # Role assignment random name + RoleAssignmentCreateParameters( + properties=dict(role_definition_id=roles[0].id, principal_id=vm_result.identity.principal_id) + ), ) - poller.result() - return AzureServer(self.subscription_id, name) + return AzureServer(name) diff --git a/skylark/compute/azure/azure_server.py 
b/skylark/compute/azure/azure_server.py index 4b7ae6f37..99e249690 100644 --- a/skylark/compute/azure/azure_server.py +++ b/skylark/compute/azure/azure_server.py @@ -1,19 +1,14 @@ import os from pathlib import Path -from typing import Optional import paramiko from skylark import key_root -from skylark.config import load_config +from skylark.compute.azure.azure_auth import AzureAuthentication from skylark.compute.server import Server, ServerState from skylark.utils.cache import ignore_lru_cache from skylark.utils.utils import PathLike import azure.core.exceptions -from azure.identity import DefaultAzureCredential, ClientSecretCredential -from azure.mgmt.compute import ComputeManagementClient -from azure.mgmt.network import NetworkManagementClient -from azure.mgmt.resource import ResourceManagementClient class AzureServer(Server): @@ -22,25 +17,13 @@ class AzureServer(Server): def __init__( self, - subscription_id: Optional[str], name: str, key_root: PathLike = key_root / "azure", log_dir=None, ssh_private_key=None, - read_credential=True, assume_exists=True, ): - if read_credential: - config = load_config() - self.subscription_id = subscription_id if subscription_id is not None else config["azure_subscription_id"] - self.credential = ClientSecretCredential( - tenant_id=config["azure_tenant_id"], - client_id=config["azure_client_id"], - client_secret=config["azure_client_secret"], - ) - else: - self.credential = DefaultAzureCredential() - self.subscription_id = subscription_id + self.auth = AzureAuthentication() self.name = name self.location = None @@ -105,8 +88,7 @@ def nic_name(name): return AzureServer.vm_name(name) + "-nic" def get_virtual_machine(self): - credential = self.credential - compute_client = ComputeManagementClient(credential, self.subscription_id) + compute_client = self.auth.get_compute_client() vm = compute_client.virtual_machines.get(AzureServer.resource_group_name, AzureServer.vm_name(self.name)) # Sanity checks @@ -123,11 +105,10 @@ def is_valid(self): return False def uuid(self): - return f"{self.subscription_id}:{self.region_tag}:{self.name}" + return f"{self.region_tag}:{self.name}" def instance_state(self) -> ServerState: - credential = self.credential - compute_client = ComputeManagementClient(credential, self.subscription_id) + compute_client = self.auth.get_compute_client() vm_instance_view = compute_client.virtual_machines.instance_view(AzureServer.resource_group_name, AzureServer.vm_name(self.name)) statuses = vm_instance_view.statuses for status in statuses: @@ -137,8 +118,7 @@ def instance_state(self) -> ServerState: @ignore_lru_cache() def public_ip(self): - credential = self.credential - network_client = NetworkManagementClient(credential, self.subscription_id) + network_client = self.auth.get_network_client() public_ip = network_client.public_ip_addresses.get(AzureServer.resource_group_name, AzureServer.ip_name(self.name)) # Sanity checks @@ -166,28 +146,22 @@ def network_tier(self): return "PREMIUM" def terminate_instance_impl(self): - credential = self.credential - compute_client = ComputeManagementClient(credential, self.subscription_id) - network_client = NetworkManagementClient(credential, self.subscription_id) - + compute_client = self.auth.get_compute_client() + network_client = self.auth.get_network_client() vm_poller = compute_client.virtual_machines.begin_delete(AzureServer.resource_group_name, self.vm_name(self.name)) - _ = vm_poller.result() - + vm_poller.result() nic_poller = 
network_client.network_interfaces.begin_delete(AzureServer.resource_group_name, self.nic_name(self.name)) - _ = nic_poller.result() - ip_poller = network_client.public_ip_addresses.begin_delete(AzureServer.resource_group_name, self.ip_name(self.name)) subnet_poller = network_client.subnets.begin_delete( AzureServer.resource_group_name, self.vnet_name(self.name), self.subnet_name(self.name) ) - _ = ip_poller.result() - _ = subnet_poller.result() - nsg_poller = network_client.network_security_groups.begin_delete(AzureServer.resource_group_name, self.nsg_name(self.name)) - _ = nsg_poller.result() - vnet_poller = network_client.virtual_networks.begin_delete(AzureServer.resource_group_name, self.vnet_name(self.name)) - _ = vnet_poller.result() + nsg_poller.result() + ip_poller.result() + subnet_poller.result() + nic_poller.result() + vnet_poller.result() def get_ssh_client_impl(self, uname=os.environ.get("USER"), ssh_key_password="skylark"): """Return paramiko client that connects to this instance.""" diff --git a/skylark/compute/cloud_providers.py b/skylark/compute/cloud_providers.py index b37d04cc2..69e7d8402 100644 --- a/skylark/compute/cloud_providers.py +++ b/skylark/compute/cloud_providers.py @@ -44,13 +44,12 @@ def get_matching_instances( tags={"skylark": "true"}, ) -> List[Server]: if isinstance(region, str): - region = [region] + results = [(region, self.get_instance_list(region))] elif region is None: - region = self.region_list() + results = do_parallel(self.get_instance_list, self.region_list(), n=-1) - results = do_parallel(self.get_instance_list, region, n=-1) matching_instances = [] - for r, instances in results: + for _, instances in results: for instance in instances: if not (instance_type is None or instance_type == instance.instance_class()): continue diff --git a/skylark/compute/gcp/gcp_auth.py b/skylark/compute/gcp/gcp_auth.py new file mode 100644 index 000000000..641f90a91 --- /dev/null +++ b/skylark/compute/gcp/gcp_auth.py @@ -0,0 +1,44 @@ +import threading +from typing import Optional +import googleapiclient.discovery +import google.auth + +from skylark import cloud_config + + +class GCPAuthentication: + __cached_credentials = threading.local() + + def __init__(self, project_id: Optional[str] = cloud_config.gcp_project_id): + # load credentials lazily and then cache across threads + self.inferred_project_id = project_id + self._credentials = None + self._project_id = None + + @property + def credentials(self): + if self._credentials is None: + self._credentials, self._project_id = self.make_credential(self.inferred_project_id) + return self._credentials + + @property + def project_id(self): + if self._project_id is None: + self._credentials, self._project_id = self.make_credential(self.inferred_project_id) + return self._project_id + + def make_credential(self, project_id): + cached_credential = getattr(self.__cached_credentials, f"credential_{project_id}", (None, None)) + if cached_credential == (None, None): + cached_credential = google.auth.default(quota_project_id=project_id) + setattr(self.__cached_credentials, f"credential_{project_id}", cached_credential) + return cached_credential + + def enabled(self): + return self.credentials is not None and self.project_id is not None + + def get_gcp_client(self, service_name="compute", version="v1"): + return googleapiclient.discovery.build(service_name, version) + + def get_gcp_instances(self, gcp_region: str): + return self.get_gcp_client().instances().list(project=self.project_id, zone=gcp_region).execute() diff --git 
a/skylark/compute/gcp/gcp_cloud_provider.py b/skylark/compute/gcp/gcp_cloud_provider.py index 872d2513f..4aad3804e 100644 --- a/skylark/compute/gcp/gcp_cloud_provider.py +++ b/skylark/compute/gcp/gcp_cloud_provider.py @@ -6,6 +6,7 @@ import googleapiclient import paramiko +from skylark.compute.gcp.gcp_auth import GCPAuthentication from skylark.utils import logger from oslo_concurrency import lockutils @@ -17,9 +18,9 @@ class GCPCloudProvider(CloudProvider): - def __init__(self, gcp_project, key_root=key_root / "gcp"): + def __init__(self, key_root=key_root / "gcp"): super().__init__() - self.gcp_project = gcp_project + self.auth = GCPAuthentication() key_root.mkdir(parents=True, exist_ok=True) self.private_key_path = key_root / "gcp-cert.pem" self.public_key_path = key_root / "gcp-cert.pub" @@ -151,11 +152,11 @@ def get_transfer_cost(src_key, dst_key, premium_tier=True): raise ValueError("Unknown src_continent: {}".format(src_continent)) def get_instance_list(self, region) -> List[GCPServer]: - gcp_instance_result = GCPServer.gcp_instances(self.gcp_project, region) + gcp_instance_result = self.auth.get_gcp_instances(region) if "items" in gcp_instance_result: instance_list = [] for i in gcp_instance_result["items"]: - instance_list.append(GCPServer(f"gcp:{region}", self.gcp_project, i["name"], ssh_private_key=self.private_key_path)) + instance_list.append(GCPServer(f"gcp:{region}", i["name"], ssh_private_key=self.private_key_path)) return instance_list else: return [] @@ -178,14 +179,14 @@ def create_ssh_key(self): f.write(f"{key.get_name()} {key.get_base64()}\n") def configure_default_network(self): - compute = GCPServer.get_gcp_client() + compute = self.auth.get_gcp_client() try: - compute.networks().get(project=self.gcp_project, network="default").execute() + compute.networks().get(project=self.auth.project_id, network="default").execute() except googleapiclient.errors.HttpError as e: if e.resp.status == 404: # create network op = ( compute.networks() - .insert(project=self.gcp_project, body={"name": "default", "subnetMode": "auto", "autoCreateSubnetworks": True}) + .insert(project=self.auth.project_id, body={"name": "default", "subnetMode": "auto", "autoCreateSubnetworks": True}) .execute() ) self.wait_for_operation_to_complete("global", op["name"]) @@ -194,18 +195,18 @@ def configure_default_network(self): def configure_default_firewall(self, ip="0.0.0.0/0"): """Configure default firewall to allow access from all ports from all IPs (if not exists).""" - compute = GCPServer.get_gcp_client() + compute = self.auth.get_gcp_client() @lockutils.synchronized(f"gcp_configure_default_firewall", external=True, lock_path="/tmp/skylark_locks") def create_firewall(body, update_firewall=False): if update_firewall: - op = compute.firewalls().update(project=self.gcp_project, firewall="default", body=fw_body).execute() + op = compute.firewalls().update(project=self.auth.project_id, firewall="default", body=fw_body).execute() else: - op = compute.firewalls().insert(project=self.gcp_project, body=fw_body).execute() + op = compute.firewalls().insert(project=self.auth.project_id, body=fw_body).execute() self.wait_for_operation_to_complete("global", op["name"]) try: - current_firewall = compute.firewalls().get(project=self.gcp_project, firewall="default").execute() + current_firewall = compute.firewalls().get(project=self.auth.project_id, firewall="default").execute() except googleapiclient.errors.HttpError as e: if e.resp.status == 404: current_firewall = None @@ -228,11 +229,11 @@ def 
create_firewall(body, update_firewall=False): logger.debug(f"[GCP] Updated firewall") def get_operation_state(self, zone, operation_name): - compute = GCPServer.get_gcp_client() + compute = self.auth.get_gcp_client() if zone == "global": - return compute.globalOperations().get(project=self.gcp_project, operation=operation_name).execute() + return compute.globalOperations().get(project=self.auth.project_id, operation=operation_name).execute() else: - return compute.zoneOperations().get(project=self.gcp_project, zone=zone, operation=operation_name).execute() + return compute.zoneOperations().get(project=self.auth.project_id, zone=zone, operation=operation_name).execute() def wait_for_operation_to_complete(self, zone, operation_name, timeout=120): time_intervals = [0.1] * 10 + [0.2] * 10 + [1.0] * int(timeout) # backoff @@ -252,7 +253,7 @@ def provision_instance( assert not region.startswith("gcp:"), "Region should be GCP region" if name is None: name = f"skylark-gcp-{str(uuid.uuid4()).replace('-', '')}" - compute = GCPServer.get_gcp_client("compute", "v1") + compute = self.auth.get_gcp_client() with open(os.path.expanduser(self.public_key_path)) as f: pub_key = f.read() @@ -286,9 +287,9 @@ def provision_instance( ] }, } - result = compute.instances().insert(project=self.gcp_project, zone=region, body=req_body).execute() + result = compute.instances().insert(project=self.auth.project_id, zone=region, body=req_body).execute() self.wait_for_operation_to_complete(region, result["name"]) - server = GCPServer(f"gcp:{region}", self.gcp_project, name) + server = GCPServer(f"gcp:{region}", name) server.wait_for_ready() server.run_command("sudo /sbin/iptables -A INPUT -j ACCEPT") return server diff --git a/skylark/compute/gcp/gcp_server.py b/skylark/compute/gcp/gcp_server.py index c3d919742..e90fbb782 100644 --- a/skylark/compute/gcp/gcp_server.py +++ b/skylark/compute/gcp/gcp_server.py @@ -1,9 +1,9 @@ from functools import lru_cache from pathlib import Path -import googleapiclient.discovery import paramiko from skylark import key_root +from skylark.compute.gcp.gcp_auth import GCPAuthentication from skylark.compute.server import Server, ServerState from skylark.utils.cache import ignore_lru_cache from skylark.utils.utils import PathLike @@ -13,7 +13,6 @@ class GCPServer(Server): def __init__( self, region_tag: str, - gcp_project: str, instance_name: str, key_root: PathLike = key_root / "gcp", log_dir=None, @@ -22,7 +21,7 @@ def __init__( super().__init__(region_tag, log_dir=log_dir) assert self.region_tag.split(":")[0] == "gcp", f"Region name doesn't match pattern gcp: {self.region_tag}" self.gcp_region = self.region_tag.split(":")[1] - self.gcp_project = gcp_project + self.auth = GCPAuthentication() self.gcp_instance_name = instance_name key_root = Path(key_root) key_root.mkdir(parents=True, exist_ok=True) @@ -32,20 +31,11 @@ def __init__( self.ssh_private_key = ssh_private_key def uuid(self): - return f"{self.gcp_project}:{self.region_tag}:{self.gcp_instance_name}" - - @classmethod - def get_gcp_client(cls, service_name="compute", version="v1"): - return googleapiclient.discovery.build(service_name, version) - - @staticmethod - def gcp_instances(gcp_project, gcp_region): - compute = GCPServer.get_gcp_client() - return compute.instances().list(project=gcp_project, zone=gcp_region).execute() + return f"{self.region_tag}:{self.gcp_instance_name}" @lru_cache def get_gcp_instance(self): - instances = self.gcp_instances(self.gcp_project, self.gcp_region) + instances = 
self.auth.get_gcp_instances(self.gcp_region) if "items" in instances: for i in instances["items"]: if i["name"] == self.gcp_instance_name: @@ -89,11 +79,12 @@ def network_tier(self): return interface["accessConfigs"][0]["networkTier"] def __repr__(self): - return f"GCPServer(region_tag={self.region_tag}, gcp_project={self.gcp_project}, instance_name={self.gcp_instance_name})" + return f"GCPServer(region_tag={self.region_tag}, instance_name={self.gcp_instance_name})" def terminate_instance_impl(self): - compute = self.get_gcp_client() - compute.instances().delete(project=self.gcp_project, zone=self.gcp_region, instance=self.instance_name()).execute() + self.auth.get_gcp_client().instances().delete( + project=self.auth.project_id, zone=self.gcp_region, instance=self.instance_name() + ).execute() def get_ssh_client_impl(self, uname="skylark", ssh_key_password="skylark"): """Return paramiko client that connects to this instance.""" diff --git a/skylark/compute/server.py b/skylark/compute/server.py index a330eca91..b6e86f9ee 100644 --- a/skylark/compute/server.py +++ b/skylark/compute/server.py @@ -4,13 +4,10 @@ from pathlib import Path from typing import Dict import requests -from skylark import config_file from skylark.utils import logger from skylark.compute.utils import make_dozzle_command, make_sysctl_tcp_tuning_command from skylark.utils.utils import PathLike, Timer, retry_backoff, wait_for - - -from skylark import config_file +from skylark import config_path class ServerState(Enum): @@ -178,6 +175,12 @@ def download_file(self, remote_path, local_path): with client.open_sftp() as sftp: sftp.get(remote_path, local_path) + def upload_file(self, local_path, remote_path): + """Upload a file to the server""" + client = self.ssh_client + with client.open_sftp() as sftp: + sftp.put(local_path, remote_path) + def copy_public_key(self, pub_key_path: PathLike): """Append public key to authorized_keys file on server.""" pub_key_path = Path(pub_key_path) @@ -201,36 +204,35 @@ def start_gateway( log_viewer_port=8888, use_bbr=False, ): - self.wait_for_ready() - def check_stderr(tup): assert tup[1].strip() == "", f"Command failed, err: {tup[1]}" desc_prefix = f"Starting gateway {self.uuid()}, host: {self.public_ip()}" + self.wait_for_ready() # increase TCP connections, enable BBR optionally and raise file limits check_stderr(self.run_command(make_sysctl_tcp_tuning_command(cc="bbr" if use_bbr else "cubic"))) - retry_backoff(self.install_docker, exception_class=RuntimeError) - self.run_command(make_dozzle_command(log_viewer_port)) + with Timer("Install docker"): + retry_backoff(self.install_docker, exception_class=RuntimeError) - # read AWS config file to get credentials - # TODO: Integrate this with updated skylark config file - # copy config file - config = config_file.read_text()[:-2] + "}" - config = json.dumps(config) # Convert to JSON string and remove trailing comma/new-line - self.run_command(f'mkdir -p /tmp; echo "{config}" | sudo tee /tmp/{config_file.name} > /dev/null') + # start log viewer + self.run_command(make_dozzle_command(log_viewer_port)) - docker_envs = "" # If needed, add environment variables to docker command + # copy cloud configuration + docker_envs = {} + if config_path.exists(): + self.upload_file(config_path, f"/tmp/{config_path.name}") + docker_envs["SKYLARK_CONFIG"] = f"/pkg/data/{config_path.name}" + # pull docker image and start container with Timer(f"{desc_prefix}: Docker pull"): docker_out, docker_err = self.run_command(f"sudo docker pull {gateway_docker_image}") assert 
"Status: Downloaded newer image" in docker_out or "Status: Image is up to date" in docker_out, (docker_out, docker_err) logger.debug(f"{desc_prefix}: Starting gateway container") - docker_run_flags = ( - f"-d --log-driver=local --log-opt max-file=16 --ipc=host --network=host --ulimit nofile={1024 * 1024} {docker_envs}" - ) + docker_run_flags = f"-d --log-driver=local --log-opt max-file=16 --ipc=host --network=host --ulimit nofile={1024 * 1024}" docker_run_flags += " --mount type=tmpfs,dst=/skylark,tmpfs-size=$(($(free -b | head -n2 | tail -n1 | awk '{print $2}')/2))" - docker_run_flags += f" -v /tmp/{config_file.name}:/pkg/data/{config_file.name}" + docker_run_flags += f" -v /tmp/{config_path.name}:/pkg/data/{config_path.name}" + docker_run_flags += " " + " ".join(f"--env {k}={v}" for k, v in docker_envs.items()) gateway_daemon_cmd = f"python -u /pkg/skylark/gateway/gateway_daemon.py --chunk-dir /skylark/chunks --outgoing-ports '{json.dumps(outgoing_ports)}' --region {self.region_tag}" docker_launch_cmd = f"sudo docker run {docker_run_flags} --name skylark_gateway {gateway_docker_image} {gateway_daemon_cmd}" start_out, start_err = self.run_command(docker_launch_cmd) diff --git a/skylark/config.py b/skylark/config.py index 1254d0c04..d43ed0f34 100644 --- a/skylark/config.py +++ b/skylark/config.py @@ -1,30 +1,55 @@ -import json +from dataclasses import dataclass import os -from skylark import config_file +from pathlib import Path +from typing import Optional from skylark.utils import logger +import configparser -def load_config(): - if config_file.exists(): - try: - with config_file.open("r") as f: - config = json.load(f) - if "aws_access_key_id" in config: - os.environ["AWS_ACCESS_KEY_ID"] = config["aws_access_key_id"] - if "aws_secret_access_key" in config: - os.environ["AWS_SECRET_ACCESS_KEY"] = config["aws_secret_access_key"] - if "azure_tenant_id" in config: - os.environ["AZURE_TENANT_ID"] = config["azure_tenant_id"] - if "azure_client_id" in config: - os.environ["AZURE_CLIENT_ID"] = config["azure_client_id"] - if "azure_client_secret" in config: - os.environ["AZURE_CLIENT_SECRET"] = config["azure_client_secret"] - if "gcp_application_credentials_file" in config: - os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = config["gcp_application_credentials_file"] - - return config - except json.JSONDecodeError as e: - logger.error(f"Error decoding config file: {e}") - raise e - return {} +@dataclass +class SkylarkConfig: + azure_subscription_id: Optional[str] = None + gcp_project_id: Optional[str] = None + + @staticmethod + def load_config(path) -> "SkylarkConfig": + """Load from a config file.""" + path = Path(path) + config = configparser.ConfigParser() + if not path.exists(): + logger.error(f"Config file not found: {path}") + raise FileNotFoundError(f"Config file not found: {path}") + config.read(path) + + azure_subscription_id = None + if "azure" in config and "subscription_id" in config["azure"]: + azure_subscription_id = config.get("azure", "subscription_id") + + gcp_project_id = None + if "gcp" in config and "project_id" in config["gcp"]: + gcp_project_id = config.get("gcp", "project_id") + + return SkylarkConfig( + azure_subscription_id=azure_subscription_id, + gcp_project_id=gcp_project_id, + ) + + def to_config_file(self, path): + path = Path(path) + config = configparser.ConfigParser() + if path.exists(): + config.read(os.path.expanduser(path)) + + if self.azure_subscription_id: + if "azure" not in config: + config.add_section("azure") + config.set("azure", "subscription_id", 
self.azure_subscription_id) + + if self.gcp_project_id: + if "gcp" not in config: + config.add_section("gcp") + config.set("gcp", "project_id", self.gcp_project_id) + + with path.open("w") as f: + config.write(f) diff --git a/skylark/gateway/gateway_sender.py b/skylark/gateway/gateway_sender.py index fcdf8bb19..9c9720ad2 100644 --- a/skylark/gateway/gateway_sender.py +++ b/skylark/gateway/gateway_sender.py @@ -1,6 +1,6 @@ import queue import socket -from multiprocessing import Event, Manager, Process, Value +from multiprocessing import Event, Manager, Process from typing import Dict, List, Optional import requests diff --git a/skylark/obj_store/azure_interface.py b/skylark/obj_store/azure_interface.py index 157bb6f7a..affd269f6 100644 --- a/skylark/obj_store/azure_interface.py +++ b/skylark/obj_store/azure_interface.py @@ -2,82 +2,96 @@ from concurrent.futures import Future, ThreadPoolExecutor from typing import Iterator, List from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError -from azure.identity import ClientSecretCredential -from azure.storage.blob import BlobServiceClient -from skylark.config import load_config +from skylark.compute.azure.azure_auth import AzureAuthentication +from skylark.compute.azure.azure_server import AzureServer from skylark.utils import logger from skylark.obj_store.object_store_interface import NoSuchObjectException, ObjectStoreInterface, ObjectStoreObject class AzureObject(ObjectStoreObject): def full_path(self): - raise NotImplementedError() + account_name, container_name = self.bucket.split("/") + return os.path.join(f"https://{account_name}.blob.core.windows.net", container_name, self.key) class AzureInterface(ObjectStoreInterface): - def __init__(self, azure_region, container_name): - # TODO: the azure region should get corresponding os.getenv() - self.azure_region = azure_region + def __init__(self, azure_region, account_name, container_name): + # TODO (#210): should be configured via argument + self.account_name = f"skylark{azure_region.replace(' ', '').lower()}" self.container_name = container_name - self.bucket_name = self.container_name # For compatibility - # Authenticate - config = load_config() - self.subscription_id = config["azure_subscription_id"] - self.credential = ClientSecretCredential( - tenant_id=config["azure_tenant_id"], - client_id=config["azure_client_id"], - client_secret=config["azure_client_secret"], - ) - # Create a blob service client - self.account_url = "https://{}.blob.core.windows.net".format("skylark" + self.azure_region) - self.blob_service_client = BlobServiceClient(account_url=self.account_url, credential=self.credential) + # Create a blob service client + self.auth = AzureAuthentication() + self.account_url = f"https://{self.account_name}.blob.core.windows.net" + self.storage_management_client = self.auth.get_storage_management_client() + self.container_client = self.auth.get_container_client(self.account_url, self.container_name) + self.blob_service_client = self.auth.get_blob_service_client(self.account_url) + + # infer azure region from storage account + if azure_region is None: + self.azure_region = self.get_region_from_storage_account(self.account_name) + else: + self.azure_region = azure_region + + # parallel upload/downloads self.pool = ThreadPoolExecutor(max_workers=256) # TODO: This might need some tuning self.max_concurrency = 1 - self.container_client = None - if not self.container_exists(): - self.create_container() - logger.info(f"==> Creating Azure container {self.container_name}") 
- def container_exists(self): # More like "is container empty?" - # Get a client to interact with a specific container - though it may not yet exist - if self.container_client is None: - self.container_client = self.blob_service_client.get_container_client(self.container_name) + def get_region_from_storage_account(self, storage_account_name): + storage_account = self.storage_management_client.storage_accounts.get_properties(AzureServer.resource_group_name, storage_account_name) + return storage_account.location + + def storage_account_exists(self): + try: + self.storage_management_client.storage_accounts.get_properties(AzureServer.resource_group_name, self.account_name) + return True + except ResourceNotFoundError: + return False + + def container_exists(self): try: - for blob in self.container_client.list_blobs(): - return True + self.container_client.get_container_properties() + return True except ResourceNotFoundError: return False + def create_storage_account(self, tier="Premium_LRS"): + try: + operation = self.storage_management_client.storage_accounts.begin_create( + AzureServer.resource_group_name, + self.account_name, + {"sku": {"name": tier}, "kind": "BlockBlobStorage", "location": self.azure_region}, + ) + operation.result() + except ResourceExistsError: + logger.warning("Unable to create storage account as it already exists") + def create_container(self): try: - self.container_client = self.blob_service_client.create_container(self.container_name) - self.properties = self.container_client.get_container_properties() + self.container_client.create_container() except ResourceExistsError: - logger.warning("==> Container might already exist, in which case blobs are re-written") - # logger.warning("==> Alternatively use a diff bucket name with `--bucket-prefix`") - return + logger.warning("Unable to create container as it already exists") - def create_bucket(self): - return self.create_container() + def create_bucket(self, premium_tier=True): + tier = "Premium_LRS" if premium_tier else "Standard_LRS" + if not self.storage_account_exists(): + self.create_storage_account(tier=tier) + if not self.container_exists(): + self.create_container() def delete_container(self): - if self.container_client is None: - self.container_client = self.blob_service_client.get_container_client(self.container_name) try: self.container_client.delete_container() except ResourceNotFoundError: - logger.warning("Container doesn't exists. 
Unable to delete") + logger.warning("Unable to delete container as it doesn't exists") def delete_bucket(self): return self.delete_container() def list_objects(self, prefix="") -> Iterator[AzureObject]: - if self.container_client is None: - self.container_client = self.blob_service_client.get_container_client(self.container_name) blobs = self.container_client.list_blobs() for blob in blobs: - yield AzureObject("azure", blob.container, blob.name, blob.size, blob.last_modified) + yield AzureObject("azure", f"{self.account_name}/{blob.container}", blob.name, blob.size, blob.last_modified) def delete_objects(self, keys: List[str]): for key in keys: diff --git a/skylark/obj_store/gcs_interface.py b/skylark/obj_store/gcs_interface.py index d95c5ee83..0a68167f4 100644 --- a/skylark/obj_store/gcs_interface.py +++ b/skylark/obj_store/gcs_interface.py @@ -10,7 +10,7 @@ class GCSObject(ObjectStoreObject): def full_path(self): - raise NotImplementedError() + return os.path.join(f"gs://{self.bucket}", self.key) class GCSInterface(ObjectStoreInterface): @@ -32,10 +32,10 @@ def bucket_exists(self): except Exception: return False - def create_bucket(self, storage_class: str = "STANDARD"): + def create_bucket(self, premium_tier=True): if not self.bucket_exists(): bucket = self._gcs_client.bucket(self.bucket_name) - bucket.storage_class = storage_class + bucket.storage_class = "STANDARD" self._gcs_client.create_bucket(bucket, location=self.gcp_region) assert self.bucket_exists() diff --git a/skylark/obj_store/object_store_interface.py b/skylark/obj_store/object_store_interface.py index a769d4e3d..4edba8b65 100644 --- a/skylark/obj_store/object_store_interface.py +++ b/skylark/obj_store/object_store_interface.py @@ -50,7 +50,8 @@ def create(region_tag: str, bucket: str): elif region_tag.startswith("azure"): from skylark.obj_store.azure_interface import AzureInterface - return AzureInterface(region_tag.split(":")[1], bucket) + # TODO (#210): should be configured via argument + return AzureInterface(region_tag.split(":")[1], None, bucket) else: raise ValueError(f"Invalid region_tag {region_tag} - could not create interface") diff --git a/skylark/obj_store/s3_interface.py b/skylark/obj_store/s3_interface.py index beee995c2..8bb578894 100644 --- a/skylark/obj_store/s3_interface.py +++ b/skylark/obj_store/s3_interface.py @@ -8,9 +8,8 @@ from awscrt.http import HttpHeaders, HttpRequest from awscrt.io import ClientBootstrap, DefaultHostResolver, EventLoopGroup from awscrt.s3 import S3Client, S3RequestTlsMode, S3RequestType +from skylark.compute.aws.aws_auth import AWSAuthentication -from skylark.compute.aws.aws_server import AWSServer -from skylark.config import load_config from skylark.obj_store.object_store_interface import NoSuchObjectException, ObjectStoreInterface, ObjectStoreObject @@ -20,43 +19,33 @@ def full_path(self): class S3Interface(ObjectStoreInterface): - def __init__(self, aws_region, bucket_name, use_tls=True, part_size=None, throughput_target_gbps=100): + def __init__(self, aws_region, bucket_name, use_tls=True, part_size=None, throughput_target_gbps=10, num_threads=4): + self.auth = AWSAuthentication() self.aws_region = self.infer_s3_region(bucket_name) if aws_region is None or aws_region == "infer" else aws_region self.bucket_name = bucket_name - num_threads = 4 event_loop_group = EventLoopGroup(num_threads=num_threads, cpu_group=None) host_resolver = DefaultHostResolver(event_loop_group) bootstrap = ClientBootstrap(event_loop_group, host_resolver) - # Authenticate - config = load_config() - 
aws_access_key_id = config["aws_access_key_id"] - aws_secret_access_key = config["aws_secret_access_key"] - if aws_access_key_id and aws_secret_access_key: - credential_provider = AwsCredentialsProvider.new_static(aws_access_key_id, aws_secret_access_key) - else: # use default - credential_provider = AwsCredentialsProvider.new_default_chain(bootstrap) - self._s3_client = S3Client( bootstrap=bootstrap, region=self.aws_region, - credential_provider=credential_provider, + credential_provider=AwsCredentialsProvider.new_default_chain(bootstrap), throughput_target_gbps=throughput_target_gbps, part_size=part_size, tls_mode=S3RequestTlsMode.ENABLED if use_tls else S3RequestTlsMode.DISABLED, ) - @staticmethod - def infer_s3_region(bucket_name: str): - s3_client = AWSServer.get_boto3_client("s3") + def infer_s3_region(self, bucket_name: str): + s3_client = self.auth.get_boto3_client("s3") region = s3_client.get_bucket_location(Bucket=bucket_name).get("LocationConstraint", "us-east-1") return region if region is not None else "us-east-1" def bucket_exists(self): - s3_client = AWSServer.get_boto3_client("s3", self.aws_region) + s3_client = self.auth.get_boto3_client("s3", self.aws_region) return self.bucket_name in [b["Name"] for b in s3_client.list_buckets()["Buckets"]] - def create_bucket(self): - s3_client = AWSServer.get_boto3_client("s3", self.aws_region) + def create_bucket(self, premium_tier=True): + s3_client = self.auth.get_boto3_client("s3", self.aws_region) if not self.bucket_exists(): if self.aws_region == "us-east-1": s3_client.create_bucket(Bucket=self.bucket_name) @@ -66,7 +55,7 @@ def create_bucket(self): def list_objects(self, prefix="") -> Iterator[S3Object]: prefix = prefix if not prefix.startswith("/") else prefix[1:] - s3_client = AWSServer.get_boto3_client("s3", self.aws_region) + s3_client = self.auth.get_boto3_client("s3", self.aws_region) paginator = s3_client.get_paginator("list_objects_v2") page_iterator = paginator.paginate(Bucket=self.bucket_name, Prefix=prefix) for page in page_iterator: @@ -74,13 +63,13 @@ def list_objects(self, prefix="") -> Iterator[S3Object]: yield S3Object("s3", self.bucket_name, obj["Key"], obj["Size"], obj["LastModified"]) def delete_objects(self, keys: List[str]): - s3_client = AWSServer.get_boto3_client("s3", self.aws_region) + s3_client = self.auth.get_boto3_client("s3", self.aws_region) while keys: batch, keys = keys[:1000], keys[1000:] # take up to 1000 keys at a time s3_client.delete_objects(Bucket=self.bucket_name, Delete={"Objects": [{"Key": k} for k in batch]}) def get_obj_metadata(self, obj_name): - s3_resource = AWSServer.get_boto3_resource("s3", self.aws_region).Bucket(self.bucket_name) + s3_resource = self.auth.get_boto3_resource("s3", self.aws_region).Bucket(self.bucket_name) try: return s3_resource.Object(str(obj_name).lstrip("/")) except botocore.exceptions.ClientError as e: @@ -127,6 +116,4 @@ def upload_object(self, src_file_path, dst_object_name, content_type="infer") -> upload_headers.add("Content-Type", content_type) upload_headers.add("Content-Length", str(content_len)) request = HttpRequest("PUT", dst_object_name, upload_headers) - return self._s3_client.make_request( - send_filepath=src_file_path, request=request, type=S3RequestType.PUT_OBJECT - ).finished_future + return self._s3_client.make_request(send_filepath=src_file_path, request=request, type=S3RequestType.PUT_OBJECT).finished_future diff --git a/skylark/replicate/replicator_client.py b/skylark/replicate/replicator_client.py index fe801f997..06159377f 100644 --- 
a/skylark/replicate/replicator_client.py +++ b/skylark/replicate/replicator_client.py @@ -29,8 +29,6 @@ class ReplicatorClient: def __init__( self, topology: ReplicationTopology, - azure_subscription: Optional[str], - gcp_project: Optional[str], gateway_docker_image: str = "ghcr.io/parasj/skylark:latest", aws_instance_class: Optional[str] = "m5.4xlarge", # set to None to disable AWS azure_instance_class: Optional[str] = "Standard_D2_v5", # set to None to disable Azure @@ -41,14 +39,13 @@ def __init__( self.gateway_docker_image = gateway_docker_image self.aws_instance_class = aws_instance_class self.azure_instance_class = azure_instance_class - self.azure_subscription = azure_subscription self.gcp_instance_class = gcp_instance_class self.gcp_use_premium_network = gcp_use_premium_network # provisioning - self.aws = AWSCloudProvider() if aws_instance_class != "None" else None - self.azure = AzureCloudProvider(azure_subscription) if azure_instance_class != "None" and azure_subscription is not None else None - self.gcp = GCPCloudProvider(gcp_project) if gcp_instance_class != "None" and gcp_project is not None else None + self.aws = AWSCloudProvider() + self.azure = AzureCloudProvider() + self.gcp = GCPCloudProvider() self.bound_nodes: Dict[ReplicationTopologyGateway, Server] = {} def provision_gateways( @@ -59,12 +56,19 @@ def provision_gateways( azure_regions_to_provision = [r for r in regions_to_provision if r.startswith("azure:")] gcp_regions_to_provision = [r for r in regions_to_provision if r.startswith("gcp:")] - assert len(aws_regions_to_provision) == 0 or self.aws is not None, "AWS not enabled" - assert len(azure_regions_to_provision) == 0 or self.azure is not None, "Azure not enabled" - assert len(gcp_regions_to_provision) == 0 or self.gcp is not None, "GCP not enabled" + assert ( + len(aws_regions_to_provision) == 0 or self.aws.auth.enabled() + ), "AWS credentials not configured but job provisions AWS gateways" + assert ( + len(azure_regions_to_provision) == 0 or self.azure.auth.enabled() + ), "Azure credentials not configured but job provisions Azure gateways" + assert ( + len(gcp_regions_to_provision) == 0 or self.gcp.auth.enabled() + ), "GCP credentials not configured but job provisions GCP gateways" # init clouds jobs = [] + jobs.append(partial(self.aws.create_iam, attach_policy_arn="arn:aws:iam::aws:policy/AmazonS3FullAccess")) for r in set(aws_regions_to_provision): jobs.append(partial(self.aws.add_ip_to_security_group, r.split(":")[1])) if azure_regions_to_provision: @@ -79,7 +83,7 @@ def provision_gateways( # reuse existing AWS instances if reuse_instances: - if self.aws is not None: + if self.aws.auth.enabled(): aws_instance_filter = { "tags": {"skylark": "true"}, "instance_type": self.aws_instance_class, @@ -95,7 +99,7 @@ def provision_gateways( else: current_aws_instances = {} - if self.azure is not None: + if self.azure.auth.enabled(): azure_instance_filter = { "tags": {"skylark": "true"}, "instance_type": self.azure_instance_class, @@ -111,7 +115,7 @@ def provision_gateways( else: current_azure_instances = {} - if self.gcp is not None: + if self.gcp.auth.enabled(): gcp_instance_filter = { "tags": {"skylark": "true"}, "instance_type": self.gcp_instance_class, @@ -132,13 +136,13 @@ def provision_gateways( def provision_gateway_instance(region: str) -> Server: provider, subregion = region.split(":") if provider == "aws": - assert self.aws is not None + assert self.aws.auth.enabled() server = self.aws.provision_instance(subregion, self.aws_instance_class) elif provider == 
"azure": - assert self.azure is not None + assert self.azure.auth.enabled() server = self.azure.provision_instance(subregion, self.azure_instance_class) elif provider == "gcp": - assert self.gcp is not None + assert self.gcp.auth.enabled() # todo specify network tier in ReplicationTopology server = self.gcp.provision_instance(subregion, self.gcp_instance_class, premium_network=self.gcp_use_premium_network) else: @@ -189,8 +193,10 @@ def setup(server: Server): def deprovision_gateways(self): def deprovision_gateway_instance(server: Server): if server.instance_state() == ServerState.RUNNING: + logger.warning(f"Deprovisioning {server.uuid()}") server.terminate_instance() + logger.warning("Deprovisioning instances") do_parallel(deprovision_gateway_instance, self.bound_nodes.values(), n=-1) def run_replication_plan(self, job: ReplicationJob) -> ReplicationJob: @@ -308,7 +314,7 @@ def monitor_transfer( show_pbar=False, log_interval_s: Optional[float] = None, time_limit_seconds: Optional[float] = None, - cancel_pending: bool = True, + cleanup_gateway: bool = True, save_log: bool = True, write_profile: bool = True, copy_gateway_logs: bool = True, @@ -328,9 +334,11 @@ def shutdown_handler(): if copy_gateway_logs: for instance in self.bound_nodes.values(): logger.info(f"Copying gateway logs from {instance.uuid()}") - instance.run_command("sudo docker logs -t skylark_gateway &> /tmp/gateway.log") - log_out = transfer_dir / f"gateway_{instance.uuid()}.log" - instance.download_file("/tmp/gateway.log", log_out) + instance.run_command("sudo docker logs -t skylark_gateway 2> /tmp/gateway.stderr > /tmp/gateway.stdout") + log_out = transfer_dir / f"gateway_{instance.uuid()}.stdout" + log_err = transfer_dir / f"gateway_{instance.uuid()}.stderr" + instance.download_file("/tmp/gateway.stdout", log_out) + instance.download_file("/tmp/gateway.stderr", log_err) logger.debug(f"Wrote gateway logs to {transfer_dir}") if write_profile: chunk_status_df = self.get_chunk_status_log_df() @@ -349,7 +357,8 @@ def fn(s: Server): do_parallel(fn, self.bound_nodes.values(), n=-1) - if cancel_pending: + if cleanup_gateway: + logger.debug("Registering shutdown handler") atexit.register(shutdown_handler) with Timer() as t: @@ -381,7 +390,7 @@ def fn(s: Server): or time_limit_seconds is not None and t.elapsed > time_limit_seconds ): - if cancel_pending: + if cleanup_gateway: atexit.unregister(shutdown_handler) shutdown_handler() return dict( diff --git a/skylark/test/test_azure_interface.py b/skylark/test/test_azure_interface.py index 92f3faa2f..68d8f6928 100644 --- a/skylark/test/test_azure_interface.py +++ b/skylark/test/test_azure_interface.py @@ -8,9 +8,10 @@ def test_azure_interface(): - azure_interface = AzureInterface(f"eastus", f"sky-us-east-1") - assert azure_interface.bucket_name == "sky-us-east-1" + azure_interface = AzureInterface("eastus", "skyeastus", "sky-us-east-1") + assert azure_interface.container_name == "sky-us-east-1" assert azure_interface.azure_region == "eastus" + assert azure_interface.account_name == "skyeastus" azure_interface.create_bucket() # generate file and upload