Skip to content

Commit

Permalink
Cleanup __init__.py (#658)
Browse files Browse the repository at this point in the history
  • Loading branch information
parasj authored Nov 2, 2022
1 parent f064cab commit 2e3b2c0
Show file tree
Hide file tree
Showing 35 changed files with 141 additions and 105 deletions.
4 changes: 2 additions & 2 deletions scripts/plot_socket_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import matplotlib.pyplot as plt # type: ignore
from tqdm import tqdm

from skyplane import skyplane_root
from skyplane import __root__


def plot(file):
Expand All @@ -28,7 +28,7 @@ def plot(file):
parser = argparse.ArgumentParser()
parser.add_argument("profile_file", help="Path to the profile file")
parser.add_argument(
"--plot_dir", default=skyplane_root / "data" / "figures" / "socket_profiles", help="Path to the directory where to save the plot"
"--plot_dir", default=__root__ / "data" / "figures" / "socket_profiles", help="Path to the directory where to save the plot"
)
args = parser.parse_args()

Expand Down
60 changes: 14 additions & 46 deletions skyplane/__init__.py
Original file line number Diff line number Diff line change
@@ -1,54 +1,22 @@
import os
from pathlib import Path

from skyplane.config import SkyplaneConfig
from skyplane.gateway_version import gateway_version

# version
__version__ = "0.2.1"

# paths
skyplane_root = Path(__file__).parent.parent
config_root = Path("~/.skyplane").expanduser()
config_root.mkdir(exist_ok=True)

if "SKYPLANE_CONFIG" in os.environ:
config_path = Path(os.environ["SKYPLANE_CONFIG"]).expanduser()
else:
config_path = config_root / "config"

aws_config_path = config_root / "aws_config"
azure_config_path = config_root / "azure_config"
azure_sku_path = config_root / "azure_sku_mapping"
gcp_config_path = config_root / "gcp_config"

key_root = config_root / "keys"
__root__ = Path(__file__).parent.parent
__config_root__ = Path("~/.skyplane").expanduser()
__config_root__.mkdir(exist_ok=True)
tmp_log_dir = Path("/tmp/skyplane")
tmp_log_dir.mkdir(exist_ok=True)

# definitions
KB = 1024
MB = 1024 * 1024
GB = 1024 * 1024 * 1024


def format_bytes(bytes: int):
if bytes < KB:
return f"{bytes}B"
elif bytes < MB:
return f"{bytes / KB:.2f}KB"
elif bytes < GB:
return f"{bytes / MB:.2f}MB"
else:
return f"{bytes / GB:.2f}GB"


if config_path.exists():
cloud_config = SkyplaneConfig.load_config(config_path)
else:
cloud_config = SkyplaneConfig.default_config()
is_gateway_env = os.environ.get("SKYPLANE_IS_GATEWAY", None) == "1"

# load gateway docker image version
def gateway_docker_image():
return "public.ecr.aws/s6m1p0n8/skyplane:" + gateway_version
__all__ = [
"__root__",
"__config_root__",
"__version__",
"SkyplaneClient",
"Dataplane",
"TransferConfig",
"AWSConfig",
"AzureConfig",
"GCPConfig",
]
26 changes: 26 additions & 0 deletions skyplane/cli/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import functools
import os
from pathlib import Path

from skyplane import __config_root__
from skyplane.config import SkyplaneConfig


@functools.lru_cache
def load_config_path():
if "SKYPLANE_CONFIG" in os.environ:
return Path(os.environ["SKYPLANE_CONFIG"]).expanduser()
else:
return __config_root__ / "config"


@functools.lru_cache
def load_cloud_config(path):
if path.exists():
return SkyplaneConfig.load_config(path)
else:
return SkyplaneConfig.default_config()


config_path = load_config_path()
cloud_config = load_cloud_config(config_path)
12 changes: 5 additions & 7 deletions skyplane/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
import skyplane.cli.usage.client
import skyplane.cli.usage.definitions
import skyplane.cli.usage.definitions
from skyplane import GB, cloud_config, config_path, exceptions, skyplane_root
from skyplane import exceptions, __root__
from skyplane.cli import config_path, cloud_config
from skyplane.utils.definitions import GB
from skyplane.cli.cli_impl.cp_replicate import (
confirm_transfer,
enrich_dest_objs,
Expand Down Expand Up @@ -76,9 +78,7 @@ def cp(
# solver
solve: bool = typer.Option(False, help="If true, will use solver to optimize transfer, else direct path is chosen"),
solver_target_tput_per_vm_gbits: float = typer.Option(4, help="Solver option: Required throughput in Gbps"),
solver_throughput_grid: Path = typer.Option(
skyplane_root / "profiles" / "throughput.csv", "--throughput-grid", help="Throughput grid file"
),
solver_throughput_grid: Path = typer.Option(__root__ / "profiles" / "throughput.csv", "--throughput-grid", help="Throughput grid file"),
solver_verbose: bool = False,
):
"""
Expand Down Expand Up @@ -278,9 +278,7 @@ def sync(
# solver
solve: bool = typer.Option(False, help="If true, will use solver to optimize transfer, else direct path is chosen"),
solver_target_tput_per_vm_gbits: float = typer.Option(4, help="Solver option: Required throughput in Gbps per instance"),
solver_throughput_grid: Path = typer.Option(
skyplane_root / "profiles" / "throughput.csv", "--throughput-grid", help="Throughput grid file"
),
solver_throughput_grid: Path = typer.Option(__root__ / "profiles" / "throughput.csv", "--throughput-grid", help="Throughput grid file"),
solver_verbose: bool = False,
):
"""
Expand Down
2 changes: 1 addition & 1 deletion skyplane/cli/cli_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import typer

from skyplane import GB
from skyplane.utils.definitions import GB
from skyplane.compute.aws.aws_auth import AWSAuthentication
from skyplane.compute.aws.aws_cloud_provider import AWSCloudProvider
from skyplane.obj_store.s3_interface import S3Interface
Expand Down
2 changes: 1 addition & 1 deletion skyplane/cli/cli_azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from skyplane.compute.azure.azure_cloud_provider import AzureCloudProvider
from skyplane.utils.fn import do_parallel
from skyplane.utils import logger
from skyplane import cloud_config
from skyplane.cli import cloud_config
from rich import print as rprint

from skyplane.compute.azure.azure_auth import AzureAuthentication
Expand Down
2 changes: 1 addition & 1 deletion skyplane/cli/cli_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import typer

from skyplane import cloud_config, config_path
from skyplane.cli import config_path, cloud_config
from skyplane.cli.common import console
from skyplane.cli.usage.client import UsageClient

Expand Down
6 changes: 4 additions & 2 deletions skyplane/cli/cli_impl/cp_replicate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
import typer
from rich import print as rprint

from skyplane import exceptions, GB, format_bytes, gateway_docker_image, skyplane_root, cloud_config
from skyplane import exceptions, __root__
from skyplane.cli import cloud_config
from skyplane.utils.definitions import GB, format_bytes, gateway_docker_image
from skyplane.cli.common import console
from skyplane.cli.usage.client import UsageClient
from skyplane.compute.cloud_providers import CloudProvider
Expand All @@ -31,7 +33,7 @@ def generate_topology(
solver_class: str = "ILP",
solver_total_gbyte_to_transfer: Optional[float] = None,
solver_target_tput_per_vm_gbits: Optional[float] = None,
solver_throughput_grid: Optional[pathlib.Path] = skyplane_root / "profiles" / "throughput.csv",
solver_throughput_grid: Optional[pathlib.Path] = __root__ / "profiles" / "throughput.csv",
solver_verbose: Optional[bool] = False,
args: Optional[Dict] = None,
) -> ReplicationTopology:
Expand Down
6 changes: 3 additions & 3 deletions skyplane/cli/cli_impl/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
from rich.progress import Progress, SpinnerColumn, TextColumn
import questionary

from skyplane import SkyplaneConfig, aws_config_path, gcp_config_path
from skyplane.compute.aws.aws_auth import AWSAuthentication
from skyplane.compute.aws.aws_auth import AWSAuthentication, aws_config_path
from skyplane.compute.azure.azure_auth import AzureAuthentication
from skyplane.compute.azure.azure_server import AzureServer
from skyplane.compute.gcp.gcp_auth import GCPAuthentication
from skyplane.compute.gcp.gcp_auth import GCPAuthentication, gcp_config_path
from skyplane.config import SkyplaneConfig


def load_aws_config(config: SkyplaneConfig, non_interactive: bool = False) -> SkyplaneConfig:
Expand Down
6 changes: 2 additions & 4 deletions skyplane/cli/cli_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import typer

from skyplane import skyplane_root
from skyplane import __root__
from skyplane.cli.cli_impl.cp_replicate import confirm_transfer, launch_replication_job
from skyplane.cli.common import print_header
from skyplane.obj_store.object_store_interface import ObjectStoreObject
Expand Down Expand Up @@ -91,9 +91,7 @@ def replicate_random_solve(
reuse_gateways: bool = False,
solve: bool = typer.Option(False, help="If true, will use solver to optimize transfer, else direct path is chosen"),
throughput_per_instance_gbits: float = typer.Option(2, help="Solver option: Required throughput in gbps."),
solver_throughput_grid: Path = typer.Option(
skyplane_root / "profiles" / "throughput.csv", "--throughput-grid", help="Throughput grid file"
),
solver_throughput_grid: Path = typer.Option(__root__ / "profiles" / "throughput.csv", "--throughput-grid", help="Throughput grid file"),
solver_verbose: bool = False,
debug: bool = False,
):
Expand Down
7 changes: 4 additions & 3 deletions skyplane/cli/experiments/cli_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
import typer
from rich.progress import Progress

from skyplane import GB, skyplane_root
from skyplane import __root__
from skyplane.utils.definitions import GB
from skyplane.cli.experiments.provision import provision
from skyplane.compute.aws.aws_cloud_provider import AWSCloudProvider
from skyplane.compute.azure.azure_cloud_provider import AzureCloudProvider
Expand Down Expand Up @@ -240,7 +241,7 @@ def setup(server: Server):
experiment_tag_words = os.popen("bash scripts/get_random_word_hash.sh").read().strip()
timestamp = datetime.now(timezone.utc).strftime("%Y.%m.%d_%H.%M")
experiment_tag = f"{timestamp}_{experiment_tag_words}_{iperf3_runtime}s_{iperf3_connections}c"
data_dir = skyplane_root / "data"
data_dir = __root__ / "data"
log_dir = data_dir / "logs" / "throughput_grid" / f"{experiment_tag}"
raw_iperf3_log_dir = log_dir / "raw_iperf3_logs"

Expand Down Expand Up @@ -433,7 +434,7 @@ def setup(server: Server):
experiment_tag_words = os.popen("bash scripts/get_random_word_hash.sh").read().strip()
timestamp = datetime.now(timezone.utc).strftime("%Y.%m.%d_%H.%M")
experiment_tag = f"{timestamp}_{experiment_tag_words}"
data_dir = skyplane_root / "data"
data_dir = __root__ / "data"
log_dir = data_dir / "logs" / "latency_grid" / f"{experiment_tag}"

# ask for confirmation
Expand Down
8 changes: 4 additions & 4 deletions skyplane/cli/experiments/cli_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import typer

from skyplane import skyplane_root
from skyplane import __root__
from skyplane.replicate.solver import ThroughputSolver


Expand All @@ -11,7 +11,7 @@ def util_grid_throughput(
dest: str,
src_tier: str = "PREMIUM",
dest_tier: str = "PREMIUM",
throughput_grid: Path = typer.Option(skyplane_root / "profiles" / "throughput.csv", help="Throughput grid file"),
throughput_grid: Path = typer.Option(__root__ / "profiles" / "throughput.csv", help="Throughput grid file"),
):
solver = ThroughputSolver(throughput_grid)
print(solver.get_path_throughput(src, dest, src_tier, dest_tier) / 2**30)
Expand All @@ -22,7 +22,7 @@ def util_grid_cost(
dest: str,
src_tier: str = "PREMIUM",
dest_tier: str = "PREMIUM",
throughput_grid: Path = typer.Option(skyplane_root / "profiles" / "throughput.csv", help="Throughput grid file"),
throughput_grid: Path = typer.Option(__root__ / "profiles" / "throughput.csv", help="Throughput grid file"),
):
solver = ThroughputSolver(throughput_grid)
print(solver.get_path_cost(src, dest, src_tier, dest_tier))
Expand All @@ -42,7 +42,7 @@ def get_max_throughput(region_tag: str):


def dump_full_util_cost_grid(
throughput_grid: Path = typer.Option(skyplane_root / "profiles" / "throughput.csv", help="Throughput grid file"),
throughput_grid: Path = typer.Option(__root__ / "profiles" / "throughput.csv", help="Throughput grid file"),
):
solver = ThroughputSolver(throughput_grid)
regions = solver.get_regions()
Expand Down
3 changes: 2 additions & 1 deletion skyplane/cli/usage/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
from rich import print as rprint

import skyplane.cli.usage.definitions
from skyplane import cloud_config, config_path, tmp_log_dir
from skyplane import tmp_log_dir
from skyplane.cli import config_path, cloud_config
from skyplane.config import _map_type
from skyplane.replicate.replicator_client import TransferStats
from skyplane.utils import logger
Expand Down
6 changes: 4 additions & 2 deletions skyplane/compute/aws/aws_auth.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import Optional

from skyplane import aws_config_path
from skyplane import config_path
from skyplane import __config_root__
from skyplane.cli import config_path
from skyplane.config import SkyplaneConfig
from skyplane.utils import imports

aws_config_path = __config_root__ / "aws_config"


class AWSAuthentication:
def __init__(self, config: Optional[SkyplaneConfig] = None, access_key: Optional[str] = None, secret_key: Optional[str] = None):
Expand Down
2 changes: 1 addition & 1 deletion skyplane/compute/aws/aws_key_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from pathlib import Path

from skyplane import exceptions as skyplane_exceptions
from skyplane import key_root
from skyplane.compute.server import key_root
from skyplane.compute.aws.aws_auth import AWSAuthentication
from skyplane.utils import logger

Expand Down
4 changes: 2 additions & 2 deletions skyplane/compute/aws/aws_pricing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from skyplane import skyplane_root
from skyplane import __root__
from skyplane.utils import logger

try:
Expand All @@ -16,7 +16,7 @@ def __init__(self):
def transfer_df(self):
if pd:
if not self._transfer_df:
self._transfer_df = pd.read_csv(skyplane_root / "profiles" / "aws_transfer_costs.csv").set_index(["src", "dst"])
self._transfer_df = pd.read_csv(__root__ / "profiles" / "aws_transfer_costs.csv").set_index(["src", "dst"])
return self._transfer_df
else:
return None
Expand Down
4 changes: 2 additions & 2 deletions skyplane/compute/aws/aws_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
warnings.filterwarnings("ignore", category=CryptographyDeprecationWarning)
import paramiko

from skyplane import exceptions, key_root
from skyplane import exceptions
from skyplane.compute.aws.aws_auth import AWSAuthentication
from skyplane.compute.server import Server, ServerState
from skyplane.compute.server import Server, ServerState, key_root
from skyplane.utils import imports
from skyplane.utils.cache import ignore_lru_cache

Expand Down
11 changes: 7 additions & 4 deletions skyplane/compute/azure/azure_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
import subprocess
from typing import Dict, List, Optional

from skyplane import azure_config_path
from skyplane import azure_sku_path
from skyplane import config_path
from skyplane import is_gateway_env
from skyplane import __config_root__
from skyplane.cli import config_path
from skyplane.utils.definitions import is_gateway_env
from skyplane.compute.const_cmds import query_which_cloud
from skyplane.config import SkyplaneConfig
from skyplane.utils import imports
Expand Down Expand Up @@ -168,3 +167,7 @@ def get_container_client(ContainerClient, self, account_url: str, container_name
@imports.inject("azure.storage.blob.BlobServiceClient", pip_extra="azure")
def get_blob_service_client(BlobServiceClient, self, account_url: str):
return BlobServiceClient(account_url=account_url, credential=self.credential)


azure_config_path = __config_root__ / "azure_config"
azure_sku_path = __config_root__ / "azure_sku_mapping"
4 changes: 3 additions & 1 deletion skyplane/compute/azure/azure_cloud_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
warnings.filterwarnings("ignore", category=CryptographyDeprecationWarning)
import paramiko

from skyplane import cloud_config, exceptions, key_root
from skyplane import exceptions
from skyplane.cli import cloud_config
from skyplane.compute.azure.azure_auth import AzureAuthentication
from skyplane.compute.azure.azure_server import AzureServer
from skyplane.compute.cloud_providers import CloudProvider
from skyplane.compute.server import key_root
from skyplane.utils import logger, imports
from skyplane.utils.timer import Timer

Expand Down
Loading

0 comments on commit 2e3b2c0

Please sign in to comment.