Patch notes (free-form text ahead of the first "diff --git" header; patch tools skip it).

What this patch does:
  * registers a new `gcp` Typer sub-app whose `ssh` command lists GCP instances,
    prompts for one, and spawns the instance's ssh command line;
  * adds Server.get_ssh_cmd() (base raises NotImplementedError) with AWS and GCP
    implementations, and makes the AWS `ssh` command use it;
  * changes the GCP default SSH user from $USER to "skylark" and loads the private
    key through paramiko.RSAKey.from_private_key_file;
  * stages the gateway config file under /tmp instead of /opt.

Review fixes folded in:
  * `typer.Abort()` was constructed but never raised, so the "No instances found"
    branch fell through to the interactive prompt with an empty choice list; both
    the new GCP command and the touched AWS hunk now `raise typer.Abort()`;
  * docstrings added to the new command and to each get_ssh_cmd(); hunk line
    counts and new-side offsets were adjusted for the added lines.

NOTE(review): the reviewed copy of this patch had its line breaks and indentation
collapsed. Context-line whitespace below is reconstructed from Python structure --
confirm against the tree (or apply with --ignore-whitespace) before merging.

diff --git a/skylark/cli/cli.py b/skylark/cli/cli.py
index 8b4439ddd..98a5e9013 100644
--- a/skylark/cli/cli.py
+++ b/skylark/cli/cli.py
@@ -21,6 +21,7 @@
 
 import skylark.cli.cli_aws
 import skylark.cli.cli_azure
+import skylark.cli.cli_gcp
 import skylark.cli.cli_solver
 import skylark.cli.experiments
 import typer
@@ -51,6 +52,7 @@
 app.add_typer(skylark.cli.experiments.app, name="experiments")
 app.add_typer(skylark.cli.cli_aws.app, name="aws")
 app.add_typer(skylark.cli.cli_azure.app, name="azure")
+app.add_typer(skylark.cli.cli_gcp.app, name="gcp")
 app.add_typer(skylark.cli.cli_solver.app, name="solver")
 
 
diff --git a/skylark/cli/cli_aws.py b/skylark/cli/cli_aws.py
index 2b658e675..155bbc871 100644
--- a/skylark/cli/cli_aws.py
+++ b/skylark/cli/cli_aws.py
@@ -40,7 +40,7 @@ def ssh(region: Optional[str] = None):
     typer.secho("Querying AWS for instances", fg="green")
     instances = aws.get_matching_instances(region=region)
     if len(instances) == 0:
-        typer.secho(f"No instancess found", fg="red")
-        typer.Abort()
+        typer.secho(f"No instances found", fg="red")
+        raise typer.Abort()
 
     instance_map = {f"{i.region()}, {i.public_ip()} ({i.instance_state()})": i for i in instances}
@@ -48,7 +48,8 @@ def ssh(region: Optional[str] = None):
     instance_name: AWSServer = questionary.select("Select an instance", choices=choices).ask()
     if instance_name is not None and instance_name in instance_map:
         instance = instance_map[instance_name]
-        proc = subprocess.Popen(split(f"ssh -i {str(instance.local_keyfile)} ec2-user@{instance.public_ip()}"))
+        cmd = instance.get_ssh_cmd()
+        proc = subprocess.Popen(split(cmd))
         proc.wait()
     else:
         typer.secho(f"No instance selected", fg="red")
diff --git a/skylark/cli/cli_gcp.py b/skylark/cli/cli_gcp.py
new file mode 100644
index 000000000..5c3f3b692
--- /dev/null
+++ b/skylark/cli/cli_gcp.py
@@ -0,0 +1,42 @@
+import os
+import subprocess
+from shlex import split
+from typing import Optional
+
+import questionary
+import typer
+from skylark.config import load_config
+
+from skylark.utils import logger
+from skylark.compute.gcp.gcp_cloud_provider import GCPCloudProvider
+from skylark.compute.gcp.gcp_server import GCPServer
+
+app = typer.Typer(name="skylark-gcp")
+
+
+@app.command()
+def ssh(
+    region: Optional[str] = None,
+    gcp_project: str = typer.Option("", "--gcp-project", help="GCP project ID"),
+):
+    """List matching GCP instances, prompt for one, and open an SSH session to it."""
+    config = load_config()
+    gcp_project = gcp_project or config.get("gcp_project_id")
+    typer.secho(f"Loaded from config file: gcp_project={gcp_project}", fg="blue")
+    gcp = GCPCloudProvider(gcp_project)
+    typer.secho("Querying GCP for instances", fg="green")
+    instances = gcp.get_matching_instances(region=region)
+    if len(instances) == 0:
+        typer.secho(f"No instances found", fg="red")
+        raise typer.Abort()
+
+    instance_map = {f"{i.region()}, {i.public_ip()} ({i.instance_state()})": i for i in instances}
+    choices = list(sorted(instance_map.keys()))
+    instance_name: GCPServer = questionary.select("Select an instance", choices=choices).ask()
+    if instance_name is not None and instance_name in instance_map:
+        cmd = instance_map[instance_name].get_ssh_cmd()
+        typer.secho(cmd, fg="green")
+        proc = subprocess.Popen(split(cmd))
+        proc.wait()
+    else:
+        typer.secho(f"No instance selected", fg="red")
diff --git a/skylark/compute/aws/aws_server.py b/skylark/compute/aws/aws_server.py
index 7bed09007..1b3e0677b 100644
--- a/skylark/compute/aws/aws_server.py
+++ b/skylark/compute/aws/aws_server.py
@@ -122,3 +122,7 @@ def get_ssh_client_impl(self):
             banner_timeout=200,
         )
         return client
+
+    def get_ssh_cmd(self) -> str:
+        """Return the ssh command line that logs in as ec2-user with the local keyfile."""
+        return f"ssh -i {self.local_keyfile} ec2-user@{self.public_ip()}"
diff --git a/skylark/compute/gcp/gcp_cloud_provider.py b/skylark/compute/gcp/gcp_cloud_provider.py
index 49ba12a26..872d2513f 100644
--- a/skylark/compute/gcp/gcp_cloud_provider.py
+++ b/skylark/compute/gcp/gcp_cloud_provider.py
@@ -247,7 +247,7 @@ def wait_for_operation_to_complete(self, zone, operation_name, timeout=120):
             time.sleep(time_intervals.pop(0))
 
     def provision_instance(
-        self, region, instance_class, name=None, premium_network=False, uname=os.environ.get("USER"), tags={"skylark": "true"}
+        self, region, instance_class, name=None, premium_network=False, uname="skylark", tags={"skylark": "true"}
     ) -> GCPServer:
         assert not region.startswith("gcp:"), "Region should be GCP region"
         if name is None:
diff --git a/skylark/compute/gcp/gcp_server.py b/skylark/compute/gcp/gcp_server.py
index e79134f96..f1b2678fa 100644
--- a/skylark/compute/gcp/gcp_server.py
+++ b/skylark/compute/gcp/gcp_server.py
@@ -96,16 +96,20 @@ def terminate_instance_impl(self):
         compute = self.get_gcp_client()
         compute.instances().delete(project=self.gcp_project, zone=self.gcp_region, instance=self.instance_name()).execute()
 
-    def get_ssh_client_impl(self, uname=os.environ.get("USER"), ssh_key_password="skylark"):
+    def get_ssh_client_impl(self, uname="skylark", ssh_key_password="skylark"):
         """Return paramiko client that connects to this instance."""
         ssh_client = paramiko.SSHClient()
         ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
         ssh_client.connect(
             hostname=self.public_ip(),
             username=uname,
-            key_filename=str(self.ssh_private_key),
-            passphrase=ssh_key_password,
+            pkey=paramiko.RSAKey.from_private_key_file(str(self.ssh_private_key), password=ssh_key_password),
             look_for_keys=False,
             banner_timeout=200,
         )
         return ssh_client
+
+    def get_ssh_cmd(self, uname="skylark", ssh_key_password="skylark"):
+        """Return an ssh command line for this instance with host-key checking disabled; ssh_key_password is not used yet."""
+        # todo can we include the key password inline?
+        return f"ssh -i {self.ssh_private_key} -o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no {uname}@{self.public_ip()}"
diff --git a/skylark/compute/server.py b/skylark/compute/server.py
index f95291738..aec54378d 100644
--- a/skylark/compute/server.py
+++ b/skylark/compute/server.py
@@ -93,6 +93,10 @@ def init_log_files(self, log_dir):
     def get_ssh_client_impl(self):
         raise NotImplementedError()
 
+    def get_ssh_cmd(self) -> str:
+        """Return a shell command that opens an SSH session to this server; subclasses must override."""
+        raise NotImplementedError()
+
     @property
     def ssh_client(self):
         """Create SSH client and cache."""
@@ -212,7 +216,7 @@ def check_stderr(tup):
         # copy config file
         config = config_file.read_text()[:-2] + "}"
         config = json.dumps(config)  # Convert to JSON string and remove trailing comma/new-line
-        self.run_command(f'mkdir -p /opt; echo "{config}" | sudo tee /opt/{config_file.name} > /dev/null')
+        self.run_command(f'mkdir -p /tmp; echo "{config}" | sudo tee /tmp/{config_file.name} > /dev/null')
 
         docker_envs = ""
         # If needed, add environment variables to docker command
@@ -224,7 +228,7 @@ def check_stderr(tup):
             f"-d --rm --log-driver=local --log-opt max-file=16 --ipc=host --network=host --ulimit nofile={1024 * 1024} {docker_envs}"
         )
         docker_run_flags += " --mount type=tmpfs,dst=/skylark,tmpfs-size=$(($(free -b | head -n2 | tail -n1 | awk '{print $2}')/2))"
-        docker_run_flags += f" -v /opt/{config_file.name}:/pkg/data/{config_file.name}"
+        docker_run_flags += f" -v /tmp/{config_file.name}:/pkg/data/{config_file.name}"
         gateway_daemon_cmd = f"python -u /pkg/skylark/gateway/gateway_daemon.py --chunk-dir /skylark/chunks --outgoing-ports '{json.dumps(outgoing_ports)}' --region {self.region_tag}"
         docker_launch_cmd = f"sudo docker run {docker_run_flags} --name skylark_gateway {gateway_docker_image} {gateway_daemon_cmd}"
         start_out, start_err = self.run_command(docker_launch_cmd)