Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Microsoft Azure #55

Merged
merged 1 commit into from
Jan 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
packages=["skylark"],
python_requires=">=3.8",
install_requires=[
"azure-mgmt-resource",
"azure-mgmt-compute",
"azure-mgmt-network",
"azure-identity",
"awscrt",
"boto3",
"flask",
Expand Down
17 changes: 13 additions & 4 deletions skylark/benchmark/stop_all_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from tqdm import tqdm

from skylark.compute.aws.aws_cloud_provider import AWSCloudProvider
from skylark.compute.azure.azure_cloud_provider import AzureCloudProvider
from skylark.compute.gcp.gcp_cloud_provider import GCPCloudProvider
from skylark.compute.server import Server
from skylark.utils.utils import do_parallel
Expand All @@ -16,19 +17,27 @@ def stop_instance(instance: Server):

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Stop all instances")
parser.add_argument("--disable-aws", action="store_true", help="Disables AWS operations if present")
parser.add_argument("--gcp-project", type=str, help="GCP project", default=None)
parser.add_argument("--azure-subscription", type=str, help="Microsoft Azure Subscription", default=None)
args = parser.parse_args()

instances = []

logger.info("Getting matching AWS instances")
aws = AWSCloudProvider()
for _, instance_list in do_parallel(aws.get_matching_instances, aws.region_list(), progress_bar=True):
instances += instance_list
if not args.disable_aws:
logger.info("Getting matching AWS instances")
aws = AWSCloudProvider()
for _, instance_list in do_parallel(aws.get_matching_instances, aws.region_list(), progress_bar=True):
instances += instance_list

if args.gcp_project:
logger.info("Getting matching GCP instances")
gcp = GCPCloudProvider(gcp_project=args.gcp_project)
instances += gcp.get_matching_instances()

if args.azure_subscription:
logger.info("Getting matching Azure instances")
azure = AzureCloudProvider(azure_subscription=args.azure_subscription)
instances += azure.get_matching_instances()

do_parallel(stop_instance, instances, progress_bar=True)
8 changes: 6 additions & 2 deletions skylark/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,11 @@ def replicate_random(
chunk_size_mb: int = 8,
n_chunks: int = 2048,
reuse_gateways: bool = True,
azure_subscription: str = "skylark-invalid-subscription",
gcp_project: str = "skylark-333700",
gateway_docker_image: str = os.environ.get("SKYLARK_DOCKER_IMAGE", "ghcr.io/parasj/skylark:main"),
aws_instance_class: str = "m5.8xlarge",
azure_instance_class: str = "Standard_D2_v5",
gcp_instance_class: Optional[str] = "n2-highmem-4",
gcp_use_premium_network: bool = False,
key_prefix: str = "/test/replicate_random",
Expand All @@ -105,9 +107,11 @@ def replicate_random(
num_conn = num_outgoing_connections
rc = ReplicatorClient(
topo,
azure_subscription=azure_subscription,
gcp_project=gcp_project,
gateway_docker_image=gateway_docker_image,
aws_instance_class=aws_instance_class,
azure_instance_class=azure_instance_class,
gcp_instance_class=gcp_instance_class,
gcp_use_premium_network=gcp_use_premium_network,
)
Expand Down Expand Up @@ -155,9 +159,9 @@ def replicate_random(


@app.command()
def deprovision(gcp_project: Optional[str] = None):
def deprovision(azure_subscription: Optional[str] = None, gcp_project: Optional[str] = None):
"""Deprovision gateways."""
deprovision_skylark_instances(gcp_project_id=gcp_project)
deprovision_skylark_instances(azure_subscription=azure_subscription, gcp_project_id=gcp_project)


if __name__ == "__main__":
Expand Down
23 changes: 16 additions & 7 deletions skylark/cli/cli_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from tqdm import tqdm
import typer
from skylark.compute.aws.aws_cloud_provider import AWSCloudProvider
from skylark.compute.azure.azure_cloud_provider import AzureCloudProvider
from skylark.compute.gcp.gcp_cloud_provider import GCPCloudProvider
from skylark.obj_store.object_store_interface import ObjectStoreObject

Expand Down Expand Up @@ -135,21 +136,29 @@ def _copy(src_obj: ObjectStoreObject, dst: Path):
# utility functions


def deprovision_skylark_instances(gcp_project_id: Optional[str] = None):
def deprovision_skylark_instances(azure_subscription: Optional[str] = None, gcp_project_id: Optional[str] = None):
instances = []

if not gcp_project_id:
typer.secho("No GCP project ID given, so will only deprovision AWS instances", color=typer.colors.YELLOW, bold=True)
else:
gcp = GCPCloudProvider(gcp_project=gcp_project_id)
instances += gcp.get_matching_instances()

aws = AWSCloudProvider()
for _, instance_list in do_parallel(
aws.get_matching_instances, aws.region_list(), progress_bar=True, leave_pbar=False, desc="Retrieve AWS instances"
):
instances += instance_list

if not azure_subscription:
typer.secho(
"No Microsoft Azure subscription given, so Azure instances will not be terminated", color=typer.colors.YELLOW, bold=True
)
else:
azure = AzureCloudProvider(azure_subscription=azure_subscription)
instances += azure.get_matching_instances()

if not gcp_project_id:
typer.secho("No GCP project ID given, so GCP instances will not be deprovisioned", color=typer.colors.YELLOW, bold=True)
else:
gcp = GCPCloudProvider(gcp_project=gcp_project_id)
instances += gcp.get_matching_instances()

if instances:
typer.secho(f"Deprovisioning {len(instances)} instances", color=typer.colors.YELLOW, bold=True)
do_parallel(lambda instance: instance.terminate_instance(), instances, progress_bar=True, desc="Deprovisioning")
Expand Down
5 changes: 1 addition & 4 deletions skylark/compute/aws/aws_cloud_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@
class AWSCloudProvider(CloudProvider):
def __init__(self):
super().__init__()
# Ubuntu deep learning AMI
# https://aws.amazon.com/marketplace/pp/prodview-dxk3xpeg6znhm
self.ami_alias = "resolve:ssm:/aws/service/marketplace/prod-oivea5digmbj6/latest"

@property
def name(self):
Expand Down Expand Up @@ -264,7 +261,7 @@ def provision_instance(
for i in range(4):
try:
instance = ec2.create_instances(
ImageId=self.ami_alias,
ImageId=self.get_ubuntu_ami_id(region),
InstanceType=instance_class,
MinCount=1,
MaxCount=1,
Expand Down
244 changes: 244 additions & 0 deletions skylark/compute/azure/azure_cloud_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
import os
import uuid

from pathlib import Path
from typing import List, Optional

from azure.identity import DefaultAzureCredential
from azure.mgmt.resource import ResourceManagementClient
from azure.mgmt.network import NetworkManagementClient
from azure.mgmt.compute import ComputeManagementClient

import paramiko

from loguru import logger

from skylark import key_root
from skylark.compute.cloud_providers import CloudProvider
from skylark.compute.azure.azure_server import AzureServer
from skylark.utils.utils import Timer


class AzureCloudProvider(CloudProvider):
def __init__(self, azure_subscription, key_root=key_root / "azure"):
super().__init__()
self.subscription_id = azure_subscription
key_root.mkdir(parents=True, exist_ok=True)
self.private_key_path = key_root / "azure_key"
self.public_key_path = key_root / "azure_key.pub"

@property
def name(self):
return "azure"

@staticmethod
def region_list():
return [
"eastasia",
"southeastasia",
"centralus",
"eastus",
"eastus2",
"westus",
"northcentralus",
"southcentralus",
"northeurope",
"westeurope",
"japanwest",
"japaneast",
"brazilsouth",
"australiaeast",
"australiasoutheast",
"southindia",
"centralindia",
"westindia",
"jioindiawest",
"jioindiacentral",
"canadacentral",
"canadaeast",
"uksouth",
"ukwest",
"westcentralus",
"westus2",
"koreacentral",
"koreasouth",
"francecentral",
"francesouth",
"australiacentral",
"australiacentral2",
"uaecentral",
"uaenorth",
"southafricanorth",
"southafricawest",
"switzerlandnorth",
"switzerlandwest",
"germanynorth",
"germanywestcentral",
"norwaywest",
"norwayeast",
"brazilsoutheast",
"westus3",
"swedencentral",
]

@staticmethod
def get_resource_group_name(name):
return name

@staticmethod
def get_transfer_cost(src_key, dst_key):
raise NotImplementedError

def get_instance_list(self, region: str) -> List[AzureServer]:
credential = DefaultAzureCredential()
resource_client = ResourceManagementClient(credential, self.subscription_id)
resource_group_list_iterator = resource_client.resource_groups.list(filter="tagName eq 'skylark' and tagValue eq 'true'")

server_list = []
for resource_group in resource_group_list_iterator:
if resource_group.location == region:
s = AzureServer(self.subscription_id, resource_group.name)
if s.is_valid():
server_list.append(s)
else:
logger.warning(
f"Warning: malformed Azure resource group {resource_group.name} found and ignored. You should go to the Microsoft Azure portal, investigate this manually, and delete any orphaned resources that may be allocated."
)
return server_list

# Copied from gcp_cloud_provider.py --- consolidate later?
def create_ssh_key(self):
private_key_path = Path(self.private_key_path)
if not private_key_path.exists():
private_key_path.parent.mkdir(parents=True, exist_ok=True)
key = paramiko.RSAKey.generate(4096)
key.write_private_key_file(self.private_key_path, password="skylark")
with open(self.public_key_path, "w") as f:
f.write(f"{key.get_name()} {key.get_base64()}\n")

# This code, along with some code in azure_server.py, is based on
# https://github.com/ucbrise/mage-scripts/blob/main/azure_cloud.py.
def provision_instance(
self,
location: str,
vm_size: str,
name: Optional[str] = None,
uname: str = os.environ.get("USER"),
) -> AzureServer:
assert ":" not in location, "invalid colon in Azure location"
if name is None:
name = f"skylark-azure-{str(uuid.uuid4()).replace('-', '')}"

with open(os.path.expanduser(self.public_key_path)) as f:
pub_key = f.read()

# Prepare for making Microsoft Azure API calls
credential = DefaultAzureCredential()
compute_client = ComputeManagementClient(credential, self.subscription_id)
network_client = NetworkManagementClient(credential, self.subscription_id)
resource_client = ResourceManagementClient(credential, self.subscription_id)

# Create a resource group for this instance
resource_group = name
if resource_client.resource_groups.check_existence(resource_group):
raise RuntimeError('Cannot spawn instance "{0}": instance already exists'.format(name))
rg_result = resource_client.resource_groups.create_or_update(resource_group, {"location": location, "tags": {"skylark": "true"}})
assert rg_result.name == resource_group

# Create a Virtual Network for this instance
poller = network_client.virtual_networks.begin_create_or_update(
resource_group, AzureServer.vnet_name(name), {"location": location, "address_space": {"address_prefixes": ["10.0.0.0/24"]}}
)
vnet_result = poller.result()

# Create a Network Security Group for this instance
# NOTE: This is insecure. We should fix this soon.
poller = network_client.network_security_groups.begin_create_or_update(
resource_group,
AzureServer.nsg_name(name),
{
"location": location,
"security_rules": [
{
"name": name + "-allow-all",
"protocol": "Tcp",
"source_port_range": "*",
"source_address_prefix": "*",
"destination_port_range": "*",
"destination_address_prefix": "*",
"access": "Allow",
"priority": 300,
"direction": "Inbound",
}
# Azure appears to add default rules for outbound connections
],
},
)
nsg_result = poller.result()

# Create a subnet for this instance with the above Network Security Group
poller = network_client.subnets.begin_create_or_update(
resource_group,
AzureServer.vnet_name(name),
AzureServer.subnet_name(name),
{"address_prefix": "10.0.0.0/26", "network_security_group": {"id": nsg_result.id}},
)
subnet_result = poller.result()

# Create a public IPv4 address for this instance
poller = network_client.public_ip_addresses.begin_create_or_update(
resource_group,
AzureServer.ip_name(name),
{
"location": location,
"sku": {"name": "Standard"},
"public_ip_allocation_method": "Static",
"public_ip_address_version": "IPV4",
},
)
public_ip_result = poller.result()

# Create a NIC for this instance, with accelerated networking enabled
poller = network_client.network_interfaces.begin_create_or_update(
resource_group,
AzureServer.nic_name(name),
{
"location": location,
"ip_configurations": [
{"name": name + "-ip", "subnet": {"id": subnet_result.id}, "public_ip_address": {"id": public_ip_result.id}}
],
"enable_accelerated_networking": True,
},
)
nic_result = poller.result()

# Create the VM
poller = compute_client.virtual_machines.begin_create_or_update(
resource_group,
AzureServer.vm_name(name),
{
"location": location,
"zones": ["1"],
"hardware_profile": {"vm_size": vm_size},
"storage_profile": {
"image_reference": {
"publisher": "canonical",
"offer": "0001-com-ubuntu-server-focal",
"sku": "20_04-lts",
"version": "latest",
},
},
"os_profile": {
"computer_name": AzureServer.vm_name(name),
"admin_username": uname,
"linux_configuration": {
"disable_password_authentication": True,
"ssh": {"public_keys": [{"path": f"/home/{uname}/.ssh/authorized_keys", "key_data": pub_key}]},
},
},
"network_profile": {"network_interfaces": [{"id": nic_result.id}]},
},
)
vm_result = poller.result()

return AzureServer(self.subscription_id, resource_group)
Loading