diff --git a/skylark/__init__.py b/skylark/__init__.py index d80cae8b2..f364a3312 100644 --- a/skylark/__init__.py +++ b/skylark/__init__.py @@ -14,6 +14,10 @@ else: config_path = config_root / "config" +aws_config_path = config_root / "aws_config" +azure_config_path = config_root / "azure_config" +gcp_config_path = config_root / "gcp_config" + key_root = config_root / "keys" tmp_log_dir = Path("/tmp/skylark") tmp_log_dir.mkdir(exist_ok=True) diff --git a/skylark/cli/cli.py b/skylark/cli/cli.py index a21149e38..f89788a49 100644 --- a/skylark/cli/cli.py +++ b/skylark/cli/cli.py @@ -43,6 +43,7 @@ copy_s3_local, copy_gcs_local, copy_local_gcs, + create_aws_region_config, deprovision_skylark_instances, load_aws_config, load_azure_config, @@ -274,11 +275,12 @@ def init(reinit_azure: bool = False, reinit_gcp: bool = False): if config_path.exists(): cloud_config = SkylarkConfig.load_config(config_path) else: - cloud_config = SkylarkConfig() + cloud_config = SkylarkConfig.default_config() # load AWS config typer.secho("\n(1) Configuring AWS:", fg="yellow", bold=True) cloud_config = load_aws_config(cloud_config) + create_aws_region_config(cloud_config) # load Azure config typer.secho("\n(2) Configuring Azure:", fg="yellow", bold=True) diff --git a/skylark/cli/cli_helper.py b/skylark/cli/cli_helper.py index ce40d2681..15ebbd978 100644 --- a/skylark/cli/cli_helper.py +++ b/skylark/cli/cli_helper.py @@ -345,14 +345,18 @@ def load_aws_config(config: SkylarkConfig) -> SkylarkConfig: credentials = session.get_credentials() credentials = credentials.get_frozen_credentials() if credentials.access_key is None or credentials.secret_key is None: + config.aws_enabled = False typer.secho(" AWS credentials not found in boto3 session, please use the AWS CLI to set them via `aws configure`", fg="red") typer.secho(" https://docs.aws.amazon.com/cli/latest/userguide/cli-chap-getting-started.html", fg="red") typer.secho(" Disabling AWS support", fg="blue") return config typer.secho(f" Loaded AWS credentials from the AWS CLI [IAM access key ID: ...{credentials.access_key[-6:]}]", fg="blue") + config.aws_enabled = True return config +def create_aws_region_config(config): + AWSAuthentication.save_region_config(config) def load_azure_config(config: SkylarkConfig, force_init: bool = False) -> SkylarkConfig: if force_init: @@ -365,7 +369,7 @@ def load_azure_config(config: SkylarkConfig, force_init: bool = False) -> Skylar # check if Azure is enabled logging.disable(logging.WARNING) # disable Azure logging, we have our own - auth = AzureAuthentication() + auth = AzureAuthentication(config=config) try: auth.credential.get_token("https://management.azure.com/") azure_enabled = True @@ -376,14 +380,17 @@ def load_azure_config(config: SkylarkConfig, force_init: bool = False) -> Skylar typer.secho(" No local Azure credentials! Run `az login` to set them up.", fg="red") typer.secho(" https://docs.microsoft.com/en-us/azure/developer/python/azure-sdk-authenticate", fg="red") typer.secho(" Disabling Azure support", fg="blue") + config.azure_enabled = False return config typer.secho(" Azure credentials found in Azure CLI", fg="blue") inferred_subscription_id = AzureAuthentication.infer_subscription_id() if typer.confirm(" Azure credentials found, do you want to enable Azure support in Skylark?", default=True): config.azure_subscription_id = typer.prompt(" Enter the Azure subscription ID:", default=inferred_subscription_id) + config.azure_enabled = True else: config.azure_subscription_id = None typer.secho(" Disabling Azure support", fg="blue") + config.azure_enabled = False return config @@ -394,10 +401,11 @@ def load_gcp_config(config: SkylarkConfig, force_init: bool = False) -> SkylarkC if config.gcp_project_id is not None: typer.secho(" GCP already configured! To reconfigure GCP, run `skylark init --reinit-gcp`.", fg="blue") + config.gcp_enabled = True return config # check if GCP is enabled - auth = GCPAuthentication() + auth = GCPAuthentication(config=config) if not auth.credentials: typer.secho( " Default GCP credentials are not set up yet. Run `gcloud auth application-default login`.", @@ -405,14 +413,17 @@ def load_gcp_config(config: SkylarkConfig, force_init: bool = False) -> SkylarkC ) typer.secho(" https://cloud.google.com/docs/authentication/getting-started", fg="red") typer.secho(" Disabling GCP support", fg="blue") + config.gcp_enabled = False return config else: typer.secho(" GCP credentials found in GCP CLI", fg="blue") if typer.confirm(" GCP credentials found, do you want to enable GCP support in Skylark?", default=True): config.gcp_project_id = typer.prompt(" Enter the GCP project ID:", default=auth.project_id) assert config.gcp_project_id is not None, "GCP project ID must not be None" + config.gcp_enabled = True return config else: config.gcp_project_id = None typer.secho(" Disabling GCP support", fg="blue") + config.gcp_enabled = False return config diff --git a/skylark/compute/aws/aws_auth.py b/skylark/compute/aws/aws_auth.py index 4010e9daa..1b310cb3c 100644 --- a/skylark/compute/aws/aws_auth.py +++ b/skylark/compute/aws/aws_auth.py @@ -1,14 +1,24 @@ import threading from typing import Optional +import typer import boto3 +from skylark.config import SkylarkConfig +from skylark import config_path +from skylark import aws_config_path + class AWSAuthentication: __cached_credentials = threading.local() - def __init__(self, access_key: Optional[str] = None, secret_key: Optional[str] = None): + def __init__(self, config: Optional[SkylarkConfig] = None, access_key: Optional[str] = None, secret_key: Optional[str] = None): """Loads AWS authentication details. If no access key is provided, it will try to load credentials using boto3""" + if not config == None: + self.config = config + else: + self.config = SkylarkConfig.load_config(config_path) + if access_key and secret_key: self.config_mode = "manual" self._access_key = access_key @@ -17,6 +27,41 @@ def __init__(self, access_key: Optional[str] = None, secret_key: Optional[str] = self.config_mode = "iam_inferred" self._access_key = None self._secret_key = None + + @staticmethod + def save_region_config(config): + f = open(aws_config_path, "w") + if config.aws_enabled == False: + f.write("") + f.close() + return + + region_list = [] + for region in boto3.client('ec2').describe_regions()['Regions']: + if region['OptInStatus'] == 'opt-in-not-required' or region['OptInStatus'] == 'opted-in': + region_text = region['Endpoint'] + region_name = region_text[region_text.find('.') + 1 :region_text.find(".amazon")] + region_list.append(region_name) + + config = "" + for i in range(len(region_list) - 1): + config += region_list[i] + "\n" + config += region_list[len(region_list) - 1] + f.write(config) + typer.secho(f"\nConfig file saved to {aws_config_path}", fg="green") + f.close() + + @staticmethod + def get_region_config(): + try: + f = open(aws_config_path, "r") + except FileNotFoundError: + typer.echo("No AWS config detected! Consquently, the AWS region list is empty. Run 'skylark init' to remedy this.") + return [] + region_list = [] + for region in f.read().split("\n"): + region_list.append(region) + return region_list @property def access_key(self): @@ -31,7 +76,7 @@ def secret_key(self): return self._secret_key def enabled(self): - return self.config_mode != "disabled" + return self.config.aws_enabled def infer_credentials(self): # todo load temporary credentials from STS diff --git a/skylark/compute/aws/aws_cloud_provider.py b/skylark/compute/aws/aws_cloud_provider.py index f0a0a3be7..2106efb4b 100644 --- a/skylark/compute/aws/aws_cloud_provider.py +++ b/skylark/compute/aws/aws_cloud_provider.py @@ -8,6 +8,7 @@ import pandas as pd from skylark.compute.aws.aws_auth import AWSAuthentication from skylark.utils import logger +import typer from skylark import key_root from oslo_concurrency import lockutils @@ -29,31 +30,8 @@ def name(self): @staticmethod def region_list() -> List[str]: # todo query AWS for list of active regions - all_regions = [ - "af-south-1", - "ap-east-1", - "ap-northeast-1", - "ap-northeast-2", - "ap-northeast-3", - "ap-south-1", - "ap-southeast-1", - "ap-southeast-2", - "ap-southeast-3", - "ca-central-1", - "eu-central-1", - "eu-north-1", - "eu-south-1", - "eu-west-1", - "eu-west-2", - "eu-west-3", - "me-south-1", - "sa-east-1", - "us-east-1", - "us-east-2", - "us-west-1", - "us-west-2", - ] - return all_regions + region_list = AWSAuthentication.get_region_config() + return region_list @staticmethod def get_transfer_cost(src_key, dst_key, premium_tier=True): diff --git a/skylark/compute/azure/azure_auth.py b/skylark/compute/azure/azure_auth.py index 05032875c..7252b66bd 100644 --- a/skylark/compute/azure/azure_auth.py +++ b/skylark/compute/azure/azure_auth.py @@ -12,12 +12,18 @@ from skylark import cloud_config from skylark.compute.utils import query_which_cloud +from skylark.config import SkylarkConfig +from skylark import config_path class AzureAuthentication: __cached_credentials = threading.local() - def __init__(self, subscription_id: str = cloud_config.azure_subscription_id): + def __init__(self, config: Optional[SkylarkConfig] = None, subscription_id: str = cloud_config.azure_subscription_id): + if not config == None: + self.config = config + else: + self.config = SkylarkConfig.load_config(config_path) self.subscription_id = subscription_id self.credential = self.get_credential(subscription_id) @@ -33,7 +39,7 @@ def get_credential(self, subscription_id: str): return cached_credential def enabled(self) -> bool: - return self.subscription_id is not None + return self.config.azure_enabled and self.subscription_id is not None @staticmethod def infer_subscription_id() -> Optional[str]: diff --git a/skylark/compute/gcp/gcp_auth.py b/skylark/compute/gcp/gcp_auth.py index 520521284..aa99e8b68 100644 --- a/skylark/compute/gcp/gcp_auth.py +++ b/skylark/compute/gcp/gcp_auth.py @@ -4,13 +4,19 @@ import google.auth from skylark import cloud_config +from skylark.config import SkylarkConfig +from skylark import config_path class GCPAuthentication: __cached_credentials = threading.local() - def __init__(self, project_id: Optional[str] = cloud_config.gcp_project_id): + def __init__(self, config: Optional[SkylarkConfig] = None, project_id: Optional[str] = cloud_config.gcp_project_id): # load credentials lazily and then cache across threads + if not config == None: + self.config = config + else: + self.config = SkylarkConfig.load_config(config_path) self.inferred_project_id = project_id self._credentials = None self._project_id = None @@ -30,15 +36,12 @@ def project_id(self): def make_credential(self, project_id): cached_credential = getattr(self.__cached_credentials, f"credential_{project_id}", (None, None)) if cached_credential == (None, None): - try: - cached_credential = google.auth.default(quota_project_id=project_id) - setattr(self.__cached_credentials, f"credential_{project_id}", cached_credential) - except: - pass + cached_credential = google.auth.default(quota_project_id=project_id) + setattr(self.__cached_credentials, f"credential_{project_id}", cached_credential) return cached_credential def enabled(self): - return self.credentials is not None and self.project_id is not None + return self.config.gcp_enabled and self.credentials is not None and self.project_id is not None def get_gcp_client(self, service_name="compute", version="v1"): return googleapiclient.discovery.build(service_name, version) diff --git a/skylark/config.py b/skylark/config.py index d43ed0f34..1424d2d5d 100644 --- a/skylark/config.py +++ b/skylark/config.py @@ -9,9 +9,22 @@ @dataclass class SkylarkConfig: + aws_enabled: bool + azure_enabled: bool + gcp_enabled: bool azure_subscription_id: Optional[str] = None gcp_project_id: Optional[str] = None + @staticmethod + def default_config() -> "SkylarkConfig": + return SkylarkConfig( + aws_enabled=False, + azure_enabled=False, + gcp_enabled=False, + azure_subscription_id=None, + gcp_project_id=None, + ) + @staticmethod def load_config(path) -> "SkylarkConfig": """Load from a config file.""" @@ -22,15 +35,31 @@ def load_config(path) -> "SkylarkConfig": raise FileNotFoundError(f"Config file not found: {path}") config.read(path) + aws_enabled = False + if "aws" in config: + if "aws_enabled" in config["aws"]: + aws_enabled = config.getboolean("aws", "aws_enabled") + + azure_enabled = False azure_subscription_id = None - if "azure" in config and "subscription_id" in config["azure"]: - azure_subscription_id = config.get("azure", "subscription_id") + if "azure" in config: + if "azure_enabled" in config["azure"]: + azure_enabled = config.getboolean("azure", "azure_enabled") + if "subscription_id" in config["azure"]: + azure_subscription_id = config.get("azure", "subscription_id") + gcp_enabled = False gcp_project_id = None - if "gcp" in config and "project_id" in config["gcp"]: - gcp_project_id = config.get("gcp", "project_id") + if "gcp" in config: + if "gcp_enabled" in config["gcp"]: + gcp_enabled = config.getboolean("gcp", "gcp_enabled") + if "project_id" in config["gcp"]: + gcp_project_id = config.get("gcp", "project_id") return SkylarkConfig( + aws_enabled=aws_enabled, + azure_enabled=azure_enabled, + gcp_enabled=gcp_enabled, azure_subscription_id=azure_subscription_id, gcp_project_id=gcp_project_id, ) @@ -41,14 +70,22 @@ def to_config_file(self, path): if path.exists(): config.read(os.path.expanduser(path)) + if "aws" not in config: + config.add_section("aws") + config.set("aws", "aws_enabled", str(self.aws_enabled)) + + if "azure" not in config: + config.add_section("azure") + config.set("azure", "azure_enabled", str(self.azure_enabled)) + if self.azure_subscription_id: - if "azure" not in config: - config.add_section("azure") config.set("azure", "subscription_id", self.azure_subscription_id) + if "gcp" not in config: + config.add_section("gcp") + config.set("gcp", "gcp_enabled", str(self.gcp_enabled)) + if self.gcp_project_id: - if "gcp" not in config: - config.add_section("gcp") config.set("gcp", "project_id", self.gcp_project_id) with path.open("w") as f: