diff --git a/scripts/setup_bucket.py b/scripts/setup_bucket.py index a3ebc4b05..f6acb7dc9 100644 --- a/scripts/setup_bucket.py +++ b/scripts/setup_bucket.py @@ -1,5 +1,6 @@ from skylark.obj_store.object_store_interface import ObjectStoreInterface from skylark.utils.utils import do_parallel +from skylark.utils import logger from tqdm import tqdm import os import argparse @@ -28,6 +29,15 @@ def main(args): obj_store_interface_dst = ObjectStoreInterface.create(args.dest_region, dst_bucket) obj_store_interface_dst.create_bucket() + # check for read access + try: + next(obj_store_interface_src.list_objects(args.key_prefix)) + except StopIteration: + pass + except Exception as e: + logger.error(f"Failed to list objects in source bucket {src_bucket}, do you have read access?: {e}") + exit(1) + # query for all keys under key_prefix objs = {obj.key: obj.size for obj in obj_store_interface_src.list_objects(args.key_prefix)} fn_args = [] diff --git a/skylark/__init__.py b/skylark/__init__.py index f364a3312..d2bd93d98 100644 --- a/skylark/__init__.py +++ b/skylark/__init__.py @@ -45,4 +45,4 @@ def print_header(): if config_path.exists(): cloud_config = SkylarkConfig.load_config(config_path) else: - cloud_config = SkylarkConfig() + cloud_config = SkylarkConfig(False, False, False) diff --git a/skylark/cli/cli.py b/skylark/cli/cli.py index f89788a49..ada18f3e5 100644 --- a/skylark/cli/cli.py +++ b/skylark/cli/cli.py @@ -22,16 +22,13 @@ import skylark.cli.cli_gcp import skylark.cli.cli_solver import skylark.cli.experiments -from skylark.obj_store.azure_interface import AzureInterface -from skylark.obj_store.gcs_interface import GCSInterface from skylark.obj_store.object_store_interface import ObjectStoreInterface -from skylark.obj_store.s3_interface import S3Interface from skylark.replicate.solver import ThroughputProblem, ThroughputSolverILP import typer from skylark.config import SkylarkConfig from skylark.utils import logger from skylark.utils.utils import Timer -from skylark import config_path, GB, MB, print_header +from skylark import GB, config_path, print_header from skylark.cli.cli_helper import ( check_ulimit, copy_azure_local, diff --git a/skylark/cli/cli_helper.py b/skylark/cli/cli_helper.py index ef90114ba..24ce46c5c 100644 --- a/skylark/cli/cli_helper.py +++ b/skylark/cli/cli_helper.py @@ -1,6 +1,5 @@ import concurrent.futures from functools import partial -import atexit import json import os import re @@ -12,12 +11,11 @@ from typing import Dict, List from sys import platform from typing import Dict, List -from urllib.parse import ParseResultBytes, parse_qs import boto3 import typer -from skylark import config_path, GB, MB, print_header +from skylark import GB, MB from skylark.compute.aws.aws_auth import AWSAuthentication from skylark.compute.azure.azure_auth import AzureAuthentication from skylark.compute.gcp.gcp_auth import GCPAuthentication diff --git a/skylark/compute/azure/azure_auth.py b/skylark/compute/azure/azure_auth.py index 050e01129..01c633bc0 100644 --- a/skylark/compute/azure/azure_auth.py +++ b/skylark/compute/azure/azure_auth.py @@ -80,7 +80,8 @@ def get_network_client(self): return NetworkManagementClient(self.credential, self.subscription_id) def get_authorization_client(self): - return AuthorizationManagementClient(self.credential, self.subscription_id) + # set API version to avoid UnsupportedApiVersionForRoleDefinitionHasDataActions error + return AuthorizationManagementClient(self.credential, self.subscription_id, api_version="2018-01-01-preview") def get_storage_management_client(self): return StorageManagementClient(self.credential, self.subscription_id) diff --git a/skylark/compute/azure/azure_cloud_provider.py b/skylark/compute/azure/azure_cloud_provider.py index c701efba1..fd992772b 100644 --- a/skylark/compute/azure/azure_cloud_provider.py +++ b/skylark/compute/azure/azure_cloud_provider.py @@ -9,6 +9,7 @@ from skylark.compute.azure.azure_auth import AzureAuthentication from skylark.compute.azure.azure_server import AzureServer from skylark.compute.cloud_providers import CloudProvider +from azure.mgmt.authorization.models import RoleAssignmentCreateParameters, RoleAssignmentProperties from skylark.utils import logger from skylark.utils.utils import Timer, do_parallel @@ -435,22 +436,23 @@ def provision_instance(self, location: str, vm_size: str, name: Optional[str] = ) vm_result = poller.result() - with Timer("Role assignment"): - # Assign roles to system MSI, see https://docs.microsoft.com/en-us/samples/azure-samples/compute-python-msi-vm/compute-python-msi-vm/#role-assignment - # todo only grant storage-blob-data-reader and storage-blob-data-writer for specified buckets + def grant_vm_role(scope, role_name): auth_client = self.auth.get_authorization_client() - scope = f"/subscriptions/{self.auth.subscription_id}" - role_name = "Contributor" roles = list(auth_client.role_definitions.list(scope, filter="roleName eq '{}'".format(role_name))) assert len(roles) == 1 - # Add RG scope to the MSI identities: - role_assignment = auth_client.role_assignments.create( + auth_client.role_assignments.create( scope, uuid.uuid4(), # Role assignment random name RoleAssignmentCreateParameters( - properties=dict(role_definition_id=roles[0].id, principal_id=vm_result.identity.principal_id) + properties=RoleAssignmentProperties(role_definition_id=roles[0].id, principal_id=vm_result.identity.principal_id) ), ) + with Timer("Role assignment"): + # Assign roles to system MSI, see https://docs.microsoft.com/en-us/samples/azure-samples/compute-python-msi-vm/compute-python-msi-vm/#role-assignment + # todo only grant storage-blob-data-reader and storage-blob-data-writer for specified buckets + scope = f"/subscriptions/{self.auth.subscription_id}" + grant_vm_role(scope, "Storage Blob Data Contributor") + return AzureServer(name) diff --git a/skylark/compute/azure/azure_server.py b/skylark/compute/azure/azure_server.py index 4be2b5ffd..a14948d95 100644 --- a/skylark/compute/azure/azure_server.py +++ b/skylark/compute/azure/azure_server.py @@ -1,4 +1,3 @@ -import os from pathlib import Path import paramiko diff --git a/skylark/gateway/gateway_obj_store.py b/skylark/gateway/gateway_obj_store.py index d0d710cfe..f24b65bb2 100644 --- a/skylark/gateway/gateway_obj_store.py +++ b/skylark/gateway/gateway_obj_store.py @@ -1,6 +1,5 @@ from functools import partial import queue -import threading from multiprocessing import Event, Manager, Process, Value from typing import Dict, Optional diff --git a/skylark/obj_store/azure_interface.py b/skylark/obj_store/azure_interface.py index 46cb54cf5..12dd3d24f 100644 --- a/skylark/obj_store/azure_interface.py +++ b/skylark/obj_store/azure_interface.py @@ -1,11 +1,14 @@ import os -from concurrent.futures import Future, ThreadPoolExecutor +import subprocess from typing import Iterator, List +import uuid from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError from skylark.compute.azure.azure_auth import AzureAuthentication from skylark.compute.azure.azure_server import AzureServer from skylark.utils import logger from skylark.obj_store.object_store_interface import NoSuchObjectException, ObjectStoreInterface, ObjectStoreObject +from azure.mgmt.authorization.models import RoleAssignmentCreateParameters, RoleAssignmentProperties +from azure.identity import AzureCliCredential class AzureObject(ObjectStoreObject): @@ -70,17 +73,55 @@ def create_storage_account(self, tier="Premium_LRS"): except ResourceExistsError: logger.warning("Unable to create storage account as it already exists") + def grant_storage_account_access(self, role_name: str, principal_id: str = None): + # lookup role + auth_client = self.auth.get_authorization_client() + scope = f"/subscriptions/{self.auth.subscription_id}/resourceGroups/{AzureServer.resource_group_name}/providers/Microsoft.Storage/storageAccounts/{self.account_name}" + roles = list(auth_client.role_definitions.list(scope, filter="roleName eq '{}'".format(role_name))) + assert len(roles) == 1 + + # lookup principal + if principal_id is None: + self.auth.credential.get_token("https://graph.windows.net") # must request token to attempt to load credential + if isinstance(self.auth.credential._successful_credential, AzureCliCredential): + principal_id = ( + subprocess.check_output(["az", "ad", "signed-in-user", "show", "--query", "objectId", "-o", "tsv"]) + .decode("utf-8") + .strip() + ) + else: + logger.error(f"Unable to determine principal ID for role assignment for {scope}, cannot automatically grant access") + return + + # query for existing role assignment + matches = [] + for assignment in auth_client.role_assignments.list_for_scope(scope, filter="principalId eq '{}'".format(principal_id)): + if assignment.role_definition_id == roles[0].id: + matches.append(assignment) + if len(matches) == 0: + logger.debug(f"Granting access to {principal_id} for role {role_name} on storage account {self.account_name}") + role_assignment = auth_client.role_assignments.create( + scope, + uuid.uuid4(), # Role assignment random name + RoleAssignmentCreateParameters( + properties=RoleAssignmentProperties(role_definition_id=roles[0].id, principal_id=principal_id) + ), + ) + def create_container(self): try: self.container_client.create_container() except ResourceExistsError: - logger.warning("Unable to create container as it already exists") + logger.warning(f"Unable to create container {self.container_name} as it already exists") def create_bucket(self, premium_tier=True): tier = "Premium_LRS" if premium_tier else "Standard_LRS" if not self.storage_account_exists(): + logger.debug(f"Creating storage account {self.account_name}") self.create_storage_account(tier=tier) + self.grant_storage_account_access("Storage Blob Data Contributor") if not self.container_exists(): + logger.debug(f"Creating container {self.container_name}") self.create_container() def delete_container(self): diff --git a/skylark/obj_store/gcs_interface.py b/skylark/obj_store/gcs_interface.py index 80b811dc5..3ef643a34 100644 --- a/skylark/obj_store/gcs_interface.py +++ b/skylark/obj_store/gcs_interface.py @@ -1,6 +1,5 @@ import mimetypes import os -from concurrent.futures import Future, ThreadPoolExecutor from typing import Iterator, List from google.cloud import storage # pytype: disable=import-error diff --git a/skylark/obj_store/s3_interface.py b/skylark/obj_store/s3_interface.py index d91aa623c..2faaa26ad 100644 --- a/skylark/obj_store/s3_interface.py +++ b/skylark/obj_store/s3_interface.py @@ -2,7 +2,6 @@ import os from typing import Iterator, List -from concurrent.futures import Future import botocore.exceptions from awscrt.auth import AwsCredentialsProvider from awscrt.http import HttpHeaders, HttpRequest