From cebff9f7c965404d34f760d33752d2b0ca76c9d2 Mon Sep 17 00:00:00 2001 From: Sam Kumar Date: Thu, 13 Jan 2022 16:13:43 -0800 Subject: [PATCH] Add support for Microsoft Azure --- setup.py | 4 + skylark/benchmark/stop_all_instances.py | 17 +- skylark/cli/cli.py | 8 +- skylark/cli/cli_helper.py | 23 +- skylark/compute/aws/aws_cloud_provider.py | 5 +- skylark/compute/azure/azure_cloud_provider.py | 244 ++++++++++++++++++ skylark/compute/azure/azure_server.py | 154 +++++++++++ skylark/compute/server.py | 12 + skylark/replicate/replication_plan.py | 3 + skylark/replicate/replicator_client.py | 42 ++- skylark/test/test_replicator_client.py | 6 + 11 files changed, 497 insertions(+), 21 deletions(-) create mode 100644 skylark/compute/azure/azure_cloud_provider.py create mode 100644 skylark/compute/azure/azure_server.py diff --git a/setup.py b/setup.py index 1d0205198..558a7cfa3 100644 --- a/setup.py +++ b/setup.py @@ -7,6 +7,10 @@ packages=["skylark"], python_requires=">=3.8", install_requires=[ + "azure-mgmt-resource", + "azure-mgmt-compute", + "azure-mgmt-network", + "azure-identity", "awscrt", "boto3", "flask", diff --git a/skylark/benchmark/stop_all_instances.py b/skylark/benchmark/stop_all_instances.py index 20950eaca..f1de52b5a 100644 --- a/skylark/benchmark/stop_all_instances.py +++ b/skylark/benchmark/stop_all_instances.py @@ -4,6 +4,7 @@ from tqdm import tqdm from skylark.compute.aws.aws_cloud_provider import AWSCloudProvider +from skylark.compute.azure.azure_cloud_provider import AzureCloudProvider from skylark.compute.gcp.gcp_cloud_provider import GCPCloudProvider from skylark.compute.server import Server from skylark.utils.utils import do_parallel @@ -16,19 +17,27 @@ def stop_instance(instance: Server): if __name__ == "__main__": parser = argparse.ArgumentParser(description="Stop all instances") + parser.add_argument("--disable-aws", action="store_true", help="Disables AWS operations if present") parser.add_argument("--gcp-project", type=str, help="GCP project", default=None) + parser.add_argument("--azure-subscription", type=str, help="Microsoft Azure Subscription", default=None) args = parser.parse_args() instances = [] - logger.info("Getting matching AWS instances") - aws = AWSCloudProvider() - for _, instance_list in do_parallel(aws.get_matching_instances, aws.region_list(), progress_bar=True): - instances += instance_list + if not args.disable_aws: + logger.info("Getting matching AWS instances") + aws = AWSCloudProvider() + for _, instance_list in do_parallel(aws.get_matching_instances, aws.region_list(), progress_bar=True): + instances += instance_list if args.gcp_project: logger.info("Getting matching GCP instances") gcp = GCPCloudProvider(gcp_project=args.gcp_project) instances += gcp.get_matching_instances() + if args.azure_subscription: + logger.info("Getting matching Azure instances") + azure = AzureCloudProvider(azure_subscription=args.azure_subscription) + instances += azure.get_matching_instances() + do_parallel(stop_instance, instances, progress_bar=True) diff --git a/skylark/cli/cli.py b/skylark/cli/cli.py index dab8e435a..6b6282bea 100644 --- a/skylark/cli/cli.py +++ b/skylark/cli/cli.py @@ -84,9 +84,11 @@ def replicate_random( chunk_size_mb: int = 8, n_chunks: int = 2048, reuse_gateways: bool = True, + azure_subscription: str = "skylark-invalid-subscription", gcp_project: str = "skylark-333700", gateway_docker_image: str = os.environ.get("SKYLARK_DOCKER_IMAGE", "ghcr.io/parasj/skylark:main"), aws_instance_class: str = "m5.8xlarge", + azure_instance_class: str = "Standard_D2_v5", gcp_instance_class: Optional[str] = "n2-highmem-4", gcp_use_premium_network: bool = False, key_prefix: str = "/test/replicate_random", @@ -105,9 +107,11 @@ def replicate_random( num_conn = num_outgoing_connections rc = ReplicatorClient( topo, + azure_subscription=azure_subscription, gcp_project=gcp_project, gateway_docker_image=gateway_docker_image, aws_instance_class=aws_instance_class, + azure_instance_class=azure_instance_class, gcp_instance_class=gcp_instance_class, gcp_use_premium_network=gcp_use_premium_network, ) @@ -155,9 +159,9 @@ def replicate_random( @app.command() -def deprovision(gcp_project: Optional[str] = None): +def deprovision(azure_subscription: Optional[str] = None, gcp_project: Optional[str] = None): """Deprovision gateways.""" - deprovision_skylark_instances(gcp_project_id=gcp_project) + deprovision_skylark_instances(azure_subscription=azure_subscription, gcp_project_id=gcp_project) if __name__ == "__main__": diff --git a/skylark/cli/cli_helper.py b/skylark/cli/cli_helper.py index 4e34df7e5..3fdebfbb1 100644 --- a/skylark/cli/cli_helper.py +++ b/skylark/cli/cli_helper.py @@ -8,6 +8,7 @@ from tqdm import tqdm import typer from skylark.compute.aws.aws_cloud_provider import AWSCloudProvider +from skylark.compute.azure.azure_cloud_provider import AzureCloudProvider from skylark.compute.gcp.gcp_cloud_provider import GCPCloudProvider from skylark.obj_store.object_store_interface import ObjectStoreObject @@ -135,21 +136,29 @@ def _copy(src_obj: ObjectStoreObject, dst: Path): # utility functions -def deprovision_skylark_instances(gcp_project_id: Optional[str] = None): +def deprovision_skylark_instances(azure_subscription: Optional[str] = None, gcp_project_id: Optional[str] = None): instances = [] - if not gcp_project_id: - typer.secho("No GCP project ID given, so will only deprovision AWS instances", color=typer.colors.YELLOW, bold=True) - else: - gcp = GCPCloudProvider(gcp_project=gcp_project_id) - instances += gcp.get_matching_instances() - aws = AWSCloudProvider() for _, instance_list in do_parallel( aws.get_matching_instances, aws.region_list(), progress_bar=True, leave_pbar=False, desc="Retrieve AWS instances" ): instances += instance_list + if not azure_subscription: + typer.secho( + "No Microsoft Azure subscription given, so Azure instances will not be terminated", color=typer.colors.YELLOW, bold=True + ) + else: + azure = AzureCloudProvider(azure_subscription=azure_subscription) + instances += azure.get_matching_instances() + + if not gcp_project_id: + typer.secho("No GCP project ID given, so GCP instances will not be deprovisioned", color=typer.colors.YELLOW, bold=True) + else: + gcp = GCPCloudProvider(gcp_project=gcp_project_id) + instances += gcp.get_matching_instances() + if instances: typer.secho(f"Deprovisioning {len(instances)} instances", color=typer.colors.YELLOW, bold=True) do_parallel(lambda instance: instance.terminate_instance(), instances, progress_bar=True, desc="Deprovisioning") diff --git a/skylark/compute/aws/aws_cloud_provider.py b/skylark/compute/aws/aws_cloud_provider.py index 1eafa74ec..ca36dccd8 100644 --- a/skylark/compute/aws/aws_cloud_provider.py +++ b/skylark/compute/aws/aws_cloud_provider.py @@ -17,9 +17,6 @@ class AWSCloudProvider(CloudProvider): def __init__(self): super().__init__() - # Ubuntu deep learning AMI - # https://aws.amazon.com/marketplace/pp/prodview-dxk3xpeg6znhm - self.ami_alias = "resolve:ssm:/aws/service/marketplace/prod-oivea5digmbj6/latest" @property def name(self): @@ -264,7 +261,7 @@ def provision_instance( for i in range(4): try: instance = ec2.create_instances( - ImageId=self.ami_alias, + ImageId=self.get_ubuntu_ami_id(region), InstanceType=instance_class, MinCount=1, MaxCount=1, diff --git a/skylark/compute/azure/azure_cloud_provider.py b/skylark/compute/azure/azure_cloud_provider.py new file mode 100644 index 000000000..e2ca91029 --- /dev/null +++ b/skylark/compute/azure/azure_cloud_provider.py @@ -0,0 +1,244 @@ +import os +import uuid + +from pathlib import Path +from typing import List, Optional + +from azure.identity import DefaultAzureCredential +from azure.mgmt.resource import ResourceManagementClient +from azure.mgmt.network import NetworkManagementClient +from azure.mgmt.compute import ComputeManagementClient + +import paramiko + +from loguru import logger + +from skylark import key_root +from skylark.compute.cloud_providers import CloudProvider +from skylark.compute.azure.azure_server import AzureServer +from skylark.utils.utils import Timer + + +class AzureCloudProvider(CloudProvider): + def __init__(self, azure_subscription, key_root=key_root / "azure"): + super().__init__() + self.subscription_id = azure_subscription + key_root.mkdir(parents=True, exist_ok=True) + self.private_key_path = key_root / "azure_key" + self.public_key_path = key_root / "azure_key.pub" + + @property + def name(self): + return "azure" + + @staticmethod + def region_list(): + return [ + "eastasia", + "southeastasia", + "centralus", + "eastus", + "eastus2", + "westus", + "northcentralus", + "southcentralus", + "northeurope", + "westeurope", + "japanwest", + "japaneast", + "brazilsouth", + "australiaeast", + "australiasoutheast", + "southindia", + "centralindia", + "westindia", + "jioindiawest", + "jioindiacentral", + "canadacentral", + "canadaeast", + "uksouth", + "ukwest", + "westcentralus", + "westus2", + "koreacentral", + "koreasouth", + "francecentral", + "francesouth", + "australiacentral", + "australiacentral2", + "uaecentral", + "uaenorth", + "southafricanorth", + "southafricawest", + "switzerlandnorth", + "switzerlandwest", + "germanynorth", + "germanywestcentral", + "norwaywest", + "norwayeast", + "brazilsoutheast", + "westus3", + "swedencentral", + ] + + @staticmethod + def get_resource_group_name(name): + return name + + @staticmethod + def get_transfer_cost(src_key, dst_key): + raise NotImplementedError + + def get_instance_list(self, region: str) -> List[AzureServer]: + credential = DefaultAzureCredential() + resource_client = ResourceManagementClient(credential, self.subscription_id) + resource_group_list_iterator = resource_client.resource_groups.list(filter="tagName eq 'skylark' and tagValue eq 'true'") + + server_list = [] + for resource_group in resource_group_list_iterator: + if resource_group.location == region: + s = AzureServer(self.subscription_id, resource_group.name) + if s.is_valid(): + server_list.append(s) + else: + logger.warning( + f"Warning: malformed Azure resource group {resource_group.name} found and ignored. You should go to the Microsoft Azure portal, investigate this manually, and delete any orphaned resources that may be allocated." + ) + return server_list + + # Copied from gcp_cloud_provider.py --- consolidate later? + def create_ssh_key(self): + private_key_path = Path(self.private_key_path) + if not private_key_path.exists(): + private_key_path.parent.mkdir(parents=True, exist_ok=True) + key = paramiko.RSAKey.generate(4096) + key.write_private_key_file(self.private_key_path, password="skylark") + with open(self.public_key_path, "w") as f: + f.write(f"{key.get_name()} {key.get_base64()}\n") + + # This code, along with some code in azure_server.py, is based on + # https://github.com/ucbrise/mage-scripts/blob/main/azure_cloud.py. + def provision_instance( + self, + location: str, + vm_size: str, + name: Optional[str] = None, + uname: str = os.environ.get("USER"), + ) -> AzureServer: + assert ":" not in location, "invalid colon in Azure location" + if name is None: + name = f"skylark-azure-{str(uuid.uuid4()).replace('-', '')}" + + with open(os.path.expanduser(self.public_key_path)) as f: + pub_key = f.read() + + # Prepare for making Microsoft Azure API calls + credential = DefaultAzureCredential() + compute_client = ComputeManagementClient(credential, self.subscription_id) + network_client = NetworkManagementClient(credential, self.subscription_id) + resource_client = ResourceManagementClient(credential, self.subscription_id) + + # Create a resource group for this instance + resource_group = name + if resource_client.resource_groups.check_existence(resource_group): + raise RuntimeError('Cannot spawn instance "{0}": instance already exists'.format(name)) + rg_result = resource_client.resource_groups.create_or_update(resource_group, {"location": location, "tags": {"skylark": "true"}}) + assert rg_result.name == resource_group + + # Create a Virtual Network for this instance + poller = network_client.virtual_networks.begin_create_or_update( + resource_group, AzureServer.vnet_name(name), {"location": location, "address_space": {"address_prefixes": ["10.0.0.0/24"]}} + ) + vnet_result = poller.result() + + # Create a Network Security Group for this instance + # NOTE: This is insecure. We should fix this soon. + poller = network_client.network_security_groups.begin_create_or_update( + resource_group, + AzureServer.nsg_name(name), + { + "location": location, + "security_rules": [ + { + "name": name + "-allow-all", + "protocol": "Tcp", + "source_port_range": "*", + "source_address_prefix": "*", + "destination_port_range": "*", + "destination_address_prefix": "*", + "access": "Allow", + "priority": 300, + "direction": "Inbound", + } + # Azure appears to add default rules for outbound connections + ], + }, + ) + nsg_result = poller.result() + + # Create a subnet for this instance with the above Network Security Group + poller = network_client.subnets.begin_create_or_update( + resource_group, + AzureServer.vnet_name(name), + AzureServer.subnet_name(name), + {"address_prefix": "10.0.0.0/26", "network_security_group": {"id": nsg_result.id}}, + ) + subnet_result = poller.result() + + # Create a public IPv4 address for this instance + poller = network_client.public_ip_addresses.begin_create_or_update( + resource_group, + AzureServer.ip_name(name), + { + "location": location, + "sku": {"name": "Standard"}, + "public_ip_allocation_method": "Static", + "public_ip_address_version": "IPV4", + }, + ) + public_ip_result = poller.result() + + # Create a NIC for this instance, with accelerated networking enabled + poller = network_client.network_interfaces.begin_create_or_update( + resource_group, + AzureServer.nic_name(name), + { + "location": location, + "ip_configurations": [ + {"name": name + "-ip", "subnet": {"id": subnet_result.id}, "public_ip_address": {"id": public_ip_result.id}} + ], + "enable_accelerated_networking": True, + }, + ) + nic_result = poller.result() + + # Create the VM + poller = compute_client.virtual_machines.begin_create_or_update( + resource_group, + AzureServer.vm_name(name), + { + "location": location, + "zones": ["1"], + "hardware_profile": {"vm_size": vm_size}, + "storage_profile": { + "image_reference": { + "publisher": "canonical", + "offer": "0001-com-ubuntu-server-focal", + "sku": "20_04-lts", + "version": "latest", + }, + }, + "os_profile": { + "computer_name": AzureServer.vm_name(name), + "admin_username": uname, + "linux_configuration": { + "disable_password_authentication": True, + "ssh": {"public_keys": [{"path": f"/home/{uname}/.ssh/authorized_keys", "key_data": pub_key}]}, + }, + }, + "network_profile": {"network_interfaces": [{"id": nic_result.id}]}, + }, + ) + vm_result = poller.result() + + return AzureServer(self.subscription_id, resource_group) diff --git a/skylark/compute/azure/azure_server.py b/skylark/compute/azure/azure_server.py new file mode 100644 index 000000000..9226700d1 --- /dev/null +++ b/skylark/compute/azure/azure_server.py @@ -0,0 +1,154 @@ +import os +from pathlib import Path + +import azure.core.exceptions +from azure.identity import DefaultAzureCredential +from azure.mgmt.resource import ResourceManagementClient +from azure.mgmt.network import NetworkManagementClient +from azure.mgmt.compute import ComputeManagementClient + +import paramiko + +from skylark import key_root +from skylark.compute.server import Server, ServerState +from skylark.utils.utils import PathLike + + +class AzureServer(Server): + def __init__(self, subscription_id: str, name: str, key_root: PathLike = key_root / "azure", log_dir=None, ssh_private_key=None): + self.subscription_id = subscription_id + self.name = name + self.location = None + + resource_group = self.get_resource_group() + + self.location = resource_group.location + region_tag = f"azure:{self.location}" + + super().__init__(region_tag, log_dir=log_dir) + + key_root = Path(key_root) + key_root.mkdir(parents=True, exist_ok=True) + if ssh_private_key is None: + self.ssh_private_key = key_root / "azure_key" + else: + self.ssh_private_key = ssh_private_key + + @staticmethod + def vnet_name(name): + return name + "-vnet" + + @staticmethod + def nsg_name(name): + return name + "-nsg" + + @staticmethod + def subnet_name(name): + return name + "-subnet" + + @staticmethod + def vm_name(name): + return name + "-vm" + + @staticmethod + def wdisk_name(name): + return AzureServer.vm_name(name) + "-wdisk" + + @staticmethod + def ip_name(name): + return AzureServer.vm_name(name) + "-ip" + + @staticmethod + def nic_name(name): + return AzureServer.vm_name(name) + "-nic" + + def get_resource_group(self): + credential = DefaultAzureCredential() + resource_client = ResourceManagementClient(credential, self.subscription_id) + rg = resource_client.resource_groups.get(self.name) + + # Sanity checks + assert self.location is None or rg.location == self.location + assert rg.name == self.name + assert rg.tags.get("skylark", None) == "true" + + return rg + + def get_virtual_machine(self): + credential = DefaultAzureCredential() + compute_client = ComputeManagementClient(credential, self.subscription_id) + vm = compute_client.virtual_machines.get(self.name, AzureServer.vm_name(self.name)) + + # Sanity checks + assert vm.location == self.location + assert vm.name == AzureServer.vm_name(self.name) + + return vm + + def is_valid(self): + try: + _ = self.get_virtual_machine() + return True + except azure.core.exceptions.ResourceNotFoundError: + return False + + def uuid(self): + return f"{self.subscription_id}:{self.region_tag}:{self.name}" + + def instance_state(self) -> ServerState: + credential = DefaultAzureCredential() + compute_client = ComputeManagementClient(credential, self.subscription_id) + vm_instance_view = compute_client.virtual_machines.instance_view(self.name, AzureServer.vm_name(self.name)) + statuses = vm_instance_view.statuses + for status in statuses: + if status.code.startswith("PowerState"): + return ServerState.from_azure_state(status.code) + return ServerState.UNKNOWN + + def public_ip(self): + credential = DefaultAzureCredential() + network_client = NetworkManagementClient(credential, self.subscription_id) + public_ip = network_client.public_ip_addresses.get(self.name, AzureServer.ip_name(self.name)) + + # Sanity checks + assert public_ip.location == self.location + assert public_ip.name == AzureServer.ip_name(self.name) + + return public_ip.ip_address + + def instance_class(self): + vm = self.get_virtual_machine() + return vm.hardware_profile.vm_size + + def region(self): + return self.location + + def instance_name(self): + return self.name + + def tags(self): + resource_group = self.get_resource_group() + return resource_group.tags + + def network_tier(self): + return "PREMIUM" + + def terminate_instance_impl(self): + credential = DefaultAzureCredential() + resource_client = ResourceManagementClient(credential, self.subscription_id) + _ = self.get_resource_group() # for the sanity checks + poller = resource_client.resource_groups.begin_delete(self.name) + _ = poller.result() + + def get_ssh_client_impl(self, uname=os.environ.get("USER"), ssh_key_password="skylark"): + """Return paramiko client that connects to this instance.""" + ssh_client = paramiko.SSHClient() + ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + ssh_client.connect( + hostname=self.public_ip(), + username=uname, + key_filename=str(self.ssh_private_key), + passphrase=ssh_key_password, + look_for_keys=False, + ) + return ssh_client diff --git a/skylark/compute/server.py b/skylark/compute/server.py index 93c507c59..f917f2dbe 100644 --- a/skylark/compute/server.py +++ b/skylark/compute/server.py @@ -38,6 +38,18 @@ def from_gcp_state(gcp_state): } return mapping.get(gcp_state, ServerState.UNKNOWN) + @staticmethod + def from_azure_state(azure_state): + mapping = { + "PowerState/starting": ServerState.PENDING, + "PowerState/running": ServerState.RUNNING, + "PowerState/stopping": ServerState.SUSPENDED, + "PowerState/stopped": ServerState.SUSPENDED, + "PowerState/deallocating": ServerState.TERMINATED, + "PowerState/deallocated": ServerState.TERMINATED, + } + return mapping.get(azure_state, ServerState.UNKNOWN) + @staticmethod def from_aws_state(aws_state): mapping = { diff --git a/skylark/replicate/replication_plan.py b/skylark/replicate/replication_plan.py index 8365b3d55..21c35efc0 100644 --- a/skylark/replicate/replication_plan.py +++ b/skylark/replicate/replication_plan.py @@ -2,6 +2,7 @@ from typing import List, Optional from skylark.compute.aws.aws_cloud_provider import AWSCloudProvider +from skylark.compute.azure.azure_cloud_provider import AzureCloudProvider from skylark.compute.gcp.gcp_cloud_provider import GCPCloudProvider from skylark.obj_store.s3_interface import S3Interface from skylark.utils.utils import do_parallel @@ -47,6 +48,8 @@ def add_path(self, path: List[str]): for p in path: if p.startswith("aws:"): assert p.split(":")[1] in AWSCloudProvider.region_list(), f"{p} is not a valid AWS region" + elif p.startswith("azure:"): + assert p.split(":")[1] in AzureCloudProvider.region_list(), f"{p} is not a valid Azure region" elif p.startswith("gcp:"): assert p.split(":")[1] in GCPCloudProvider.region_list(), f"{p} is not a valid GCP region" else: diff --git a/skylark/replicate/replicator_client.py b/skylark/replicate/replicator_client.py index 2bd924874..af62fcf1a 100644 --- a/skylark/replicate/replicator_client.py +++ b/skylark/replicate/replicator_client.py @@ -17,6 +17,7 @@ from skylark.benchmark.utils import refresh_instance_list from skylark.compute.aws.aws_cloud_provider import AWSCloudProvider +from skylark.compute.azure.azure_cloud_provider import AzureCloudProvider from skylark.compute.gcp.gcp_cloud_provider import GCPCloudProvider from skylark.compute.server import Server, ServerState from skylark.chunk import Chunk, ChunkRequest, ChunkRequestHop, ChunkState @@ -30,20 +31,25 @@ class ReplicatorClient: def __init__( self, topology: ReplicationTopology, + azure_subscription: str, gcp_project: str, gateway_docker_image: str = "ghcr.io/parasj/skylark:latest", aws_instance_class: Optional[str] = "m5.4xlarge", # set to None to disable AWS + azure_instance_class: Optional[str] = "Standard_D2_v5", # set to None to disable Azure gcp_instance_class: Optional[str] = "n2-standard-16", # set to None to disable GCP gcp_use_premium_network: bool = True, ): self.topology = topology self.gateway_docker_image = gateway_docker_image self.aws_instance_class = aws_instance_class + self.azure_instance_class = azure_instance_class + self.azure_subscription = azure_subscription self.gcp_instance_class = gcp_instance_class self.gcp_use_premium_network = gcp_use_premium_network # provisioning self.aws = AWSCloudProvider() if aws_instance_class is not None else None + self.azure = AzureCloudProvider(azure_subscription) if azure_instance_class is not None else None self.gcp = GCPCloudProvider(gcp_project) if gcp_instance_class is not None else None self.bound_paths: Optional[List[List[Server]]] = None @@ -52,6 +58,8 @@ def __init__( if self.aws is not None: for r in self.aws.region_list(): jobs.append(partial(self.aws.add_ip_to_security_group, r)) + if self.azure is not None: + jobs.append(self.azure.create_ssh_key) if self.gcp is not None: jobs.append(self.gcp.create_ssh_key) jobs.append(self.gcp.configure_default_network) @@ -68,9 +76,11 @@ def provision_gateways( ): regions_to_provision = [r for path in self.topology.paths for r in path] aws_regions_to_provision = [r for r in regions_to_provision if r.startswith("aws:")] + azure_regions_to_provision = [r for r in regions_to_provision if r.startswith("azure:")] gcp_regions_to_provision = [r for r in regions_to_provision if r.startswith("gcp:")] assert len(aws_regions_to_provision) == 0 or self.aws is not None, "AWS not enabled" + assert len(azure_regions_to_provision) == 0 or self.azure is not None, "Azure not enabled" assert len(gcp_regions_to_provision) == 0 or self.gcp is not None, "GCP not enabled" # reuse existing AWS instances @@ -91,6 +101,22 @@ def provision_gateways( else: current_aws_instances = {} + if self.azure is not None: + azure_instance_filter = { + "tags": {"skylark": "true"}, + "instance_type": self.azure_instance_class, + "state": [ServerState.PENDING, ServerState.RUNNING], + } + current_azure_instances = refresh_instance_list( + self.azure, set([r.split(":")[1] for r in azure_regions_to_provision]), azure_instance_filter + ) + for r, ilist in current_azure_instances.items(): + for i in ilist: + if f"azure:{r}" in azure_regions_to_provision: + azure_regions_to_provision.remove(f"azure:{r}") + else: + current_azure_instances = {} + if self.gcp is not None: gcp_instance_filter = { "tags": {"skylark": "true"}, @@ -114,6 +140,9 @@ def provision_gateway_instance(region: str) -> Server: if provider == "aws": assert self.aws is not None server = self.aws.provision_instance(subregion, self.aws_instance_class) + elif provider == "azure": + assert self.azure is not None + server = self.azure.provision_instance(subregion, self.azure_instance_class) elif provider == "gcp": assert self.gcp is not None # todo specify network tier in ReplicationTopology @@ -122,7 +151,9 @@ def provision_gateway_instance(region: str) -> Server: raise NotImplementedError(f"Unknown provider {provider}") return server - results = do_parallel(provision_gateway_instance, list(aws_regions_to_provision + gcp_regions_to_provision)) + results = do_parallel( + provision_gateway_instance, list(aws_regions_to_provision + azure_regions_to_provision + gcp_regions_to_provision) + ) instances_by_region = { r: [instance for instance_region, instance in results if instance_region == r] for r in set(regions_to_provision) } @@ -133,6 +164,10 @@ def provision_gateway_instance(region: str) -> Server: if f"aws:{r}" not in instances_by_region: instances_by_region[f"aws:{r}"] = [] instances_by_region[f"aws:{r}"].extend(ilist) + for r, ilist in current_azure_instances.items(): + if f"azure:{r}" not in instances_by_region: + instances_by_region[f"azure:{r}"] = [] + instances_by_region[f"azure:{r}"].extend(ilist) for r, ilist in current_gcp_instances.items(): if f"gcp:{r}" not in instances_by_region: instances_by_region[f"gcp:{r}"] = [] @@ -173,9 +208,8 @@ def deprovision_gateway_instance(server: Server): def run_replication_plan(self, job: ReplicationJob): # assert all(len(path) == 2 for path in self.bound_paths), f"Only two-hop replication is supported" - # todo support GCP - assert job.source_region.split(":")[0] in ["aws", "gcp"], f"Only AWS and GCP is supported for now, got {job.source_region}" - assert job.dest_region.split(":")[0] in ["aws", "gcp"], f"Only AWS and GCP is supported for now, got {job.dest_region}" + assert job.source_region.split(":")[0] in ["aws", "azure", "gcp"], f"Only AWS, Azure, and GCP are supported, but got {job.source_region}" + assert job.dest_region.split(":")[0] in ["aws", "azure", "gcp"], f"Only AWS, Azure, and GCP are supported, but got {job.dest_region}" # make list of chunks chunks = [] diff --git a/skylark/test/test_replicator_client.py b/skylark/test/test_replicator_client.py index aa98493e6..2aba2f392 100644 --- a/skylark/test/test_replicator_client.py +++ b/skylark/test/test_replicator_client.py @@ -29,8 +29,10 @@ def parse_args(): # gateway provisioning parser.add_argument("--gcp-project", default="skylark-333700", help="GCP project ID") + parser.add_argument("--azure-subscription", default="", help="Azure subscription") parser.add_argument("--gateway-docker-image", default="ghcr.io/parasj/skylark:main", help="Docker image for gateway instances") parser.add_argument("--aws-instance-class", default="m5.4xlarge", help="AWS instance class") + parser.add_argument("--azure-instance-class", default="Standard_D2_v5", help="Azure instance class") parser.add_argument("--gcp-instance-class", default="n2-standard-16", help="GCP instance class") parser.add_argument("--copy-ssh-key", default=None, help="SSH public key to add to gateways") parser.add_argument("--log-dir", default=None, help="Directory to write instance SSH logs to") @@ -40,6 +42,8 @@ def parse_args(): # add support for None arguments if args.aws_instance_class == "None": args.aws_instance_class = None + if args.azure_instance_class == "None": + args.azure_instance_class = None if args.gcp_instance_class == "None": args.gcp_instance_class = None @@ -89,8 +93,10 @@ def main(args): rc = ReplicatorClient( topo, gcp_project=args.gcp_project, + azure_subscription=args.azure_subscription, gateway_docker_image=args.gateway_docker_image, aws_instance_class=args.aws_instance_class, + azure_instance_class=args.azure_instance_class, gcp_instance_class=args.gcp_instance_class, gcp_use_premium_network=args.gcp_use_premium_network, )