Improve performance of gateway by improving performance of GatewayDaemonAPI SSL (#418)

* Migrate retry_requests to urllib3 PoolManager

* Start implementing stunnel

* Revert to werkzeug and launch stunnel on container launch

* Update launch script

* Run in bash

* Escape bash launch command

* Format code
parasj authored Jun 16, 2022
1 parent ab46a41 commit 84623f4
Showing 10 changed files with 105 additions and 94 deletions.
27 changes: 20 additions & 7 deletions Dockerfile
@@ -1,6 +1,25 @@
# syntax=docker/dockerfile:1
FROM python:3.10-slim

# install apt packages
RUN --mount=type=cache,target=/var/cache/apt apt update \
&& apt-get install --no-install-recommends -y curl ca-certificates stunnel4 \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*

# configure stunnel
RUN mkdir -p /etc/stunnel \
&& openssl genrsa -out key.pem 2048 \
&& openssl req -new -x509 -key key.pem -out cert.pem -days 1095 -subj "/C=US/ST=California/L=San Francisco" \
&& cat key.pem cert.pem >> /etc/stunnel/stunnel.pem \
&& rm key.pem cert.pem \
&& mkdir -p /usr/local/var/run/ \
&& echo "client = no" >> /etc/stunnel/stunnel.conf \
&& echo "[gateway]" >> /etc/stunnel/stunnel.conf \
&& echo "accept = 8080" >> /etc/stunnel/stunnel.conf \
&& echo "connect = 8081" >> /etc/stunnel/stunnel.conf \
&& echo "cert = /etc/stunnel/stunnel.pem" >> /etc/stunnel/stunnel.conf

# increase number of open files and concurrent TCP connections
RUN (echo 'net.ipv4.ip_local_port_range = 12000 65535' >> /etc/sysctl.conf) \
&& (echo 'fs.file-max = 1048576' >> /etc/sysctl.conf) \
@@ -10,12 +29,6 @@ RUN (echo 'net.ipv4.ip_local_port_range = 12000 65535' >> /etc/sysctl.conf) \
&& (echo 'root soft nofile 1048576' >> /etc/security/limits.conf) \
&& (echo 'root hard nofile 1048576' >> /etc/security/limits.conf)

# install apt packages
RUN --mount=type=cache,target=/var/cache/apt apt update \
&& apt-get install --no-install-recommends -y curl ca-certificates \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*

# install gateway
COPY scripts/requirements-gateway.txt /tmp/requirements-gateway.txt
RUN --mount=type=cache,target=/root/.cache/pip pip3 install --no-cache-dir -r /tmp/requirements-gateway.txt && rm -r /tmp/requirements-gateway.txt
@@ -24,4 +37,4 @@ WORKDIR /pkg
COPY . .
RUN pip3 install --no-dependencies -e ".[gateway]"

CMD ["python3", "skyplane/gateway/gateway_daemon.py"]
CMD /etc/init.d/stunnel4 start; python3 /pkg/skyplane/gateway/gateway_daemon.py --chunk-dir /skyplane/chunks --outgoing-ports '{}' --region local
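
With this change, TLS for the gateway API is terminated by stunnel: it accepts TLS on port 8080 and forwards plaintext to the daemon API on 8081. A minimal client sketch, assuming the container is reachable at a placeholder GATEWAY_IP; verification is disabled because the certificate baked into the image is self-signed:

import json
import urllib3

# The certificate generated in the Dockerfile is self-signed, so disable
# verification (and the warning it triggers) when talking to stunnel.
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
pool = urllib3.PoolManager(cert_reqs="CERT_NONE")

# GATEWAY_IP is a placeholder for the gateway instance's address.
resp = pool.request("GET", "https://GATEWAY_IP:8080/api/v1/status")
print(json.loads(resp.data.decode("utf-8")))  # expected: {"status": "ok"}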
4 changes: 3 additions & 1 deletion skyplane/cli/cli_impl/cp_replicate.py
@@ -179,6 +179,7 @@ def launch_replication_job(
reuse_gateways: bool = False,
use_bbr: bool = False,
use_compression: bool = False,
verify_checksums: bool = True,
# cloud provider specific options
aws_instance_class: str = "m5.8xlarge",
azure_instance_class: str = "Standard_D32_v4",
@@ -254,7 +255,8 @@ def launch_replication_job(
signal.signal(signal.SIGINT, s)

# verify transfer
rc.verify_transfer(job)
if verify_checksums:
rc.verify_transfer(job)

stats = stats if stats else {}
stats["success"] = stats["monitor_status"] == "completed"
8 changes: 6 additions & 2 deletions skyplane/cli/cli_internal.py
@@ -24,6 +24,7 @@ def replicate_random(
chunk_size_mb: int = typer.Option(8, "--chunk-size-mb", help="Chunk size in MB."),
use_bbr: bool = typer.Option(True, help="If true, will use BBR congestion control"),
reuse_gateways: bool = False,
debug: bool = False,
):
"""Replicate objects from remote object store to another remote object store."""
print_header()
@@ -69,10 +70,11 @@
stats = launch_replication_job(
topo=topo,
job=job,
debug=False,
debug=debug,
reuse_gateways=reuse_gateways,
use_bbr=use_bbr,
use_compression=False,
verify_checksums=False,
)
return 0 if stats["success"] else 1

@@ -95,6 +97,7 @@ def replicate_random_solve(
skyplane_root / "profiles" / "throughput.csv", "--throughput-grid", help="Throughput grid file"
),
solver_verbose: bool = False,
debug: bool = False,
):
"""Replicate objects from remote object store to another remote object store."""
print_header()
@@ -153,9 +156,10 @@
stats = launch_replication_job(
topo=topo,
job=job,
debug=False,
debug=debug,
reuse_gateways=reuse_gateways,
use_bbr=use_bbr,
use_compression=False,
verify_checksums=False,
)
return 0 if stats["success"] else 1
23 changes: 12 additions & 11 deletions skyplane/compute/server.py
@@ -8,12 +8,11 @@
from pathlib import Path
from typing import Dict, Optional, Tuple

import paramiko
import urllib3
from skyplane import config_path, key_root
from skyplane.compute.const_cmds import make_autoshutdown_script, make_dozzle_command, make_sysctl_tcp_tuning_command
from skyplane.utils import logger
from skyplane.utils.fn import PathLike, wait_for
from skyplane.utils.net import retry_requests
from skyplane.utils.retry import retry_backoff
from skyplane.utils.timer import Timer

@@ -301,25 +300,30 @@ def check_stderr(tup):
docker_run_flags += f" -v /tmp/{service_key_file}:/pkg/data/{service_key_file}"

docker_run_flags += " " + " ".join(f"--env {k}={v}" for k, v in docker_envs.items())
gateway_daemon_cmd = f"python -u /pkg/skyplane/gateway/gateway_daemon.py --chunk-dir /skyplane/chunks --outgoing-ports '{json.dumps(outgoing_ports)}' --region {self.region_tag} {'--use-compression' if use_compression else ''}"
docker_launch_cmd = f"sudo docker run {docker_run_flags} --name skyplane_gateway {gateway_docker_image} {gateway_daemon_cmd}"
gateway_daemon_cmd = f"/etc/init.d/stunnel4 start && python -u /pkg/skyplane/gateway/gateway_daemon.py --chunk-dir /skyplane/chunks --outgoing-ports '{json.dumps(outgoing_ports)}' --region {self.region_tag} {'--use-compression' if use_compression else ''}"
escaped_gateway_daemon_cmd = gateway_daemon_cmd.replace('"', '\\"')
docker_launch_cmd = (
f'sudo docker run {docker_run_flags} --name skyplane_gateway {gateway_docker_image} /bin/bash -c "{escaped_gateway_daemon_cmd}"'
)
logger.fs.info(f"{desc_prefix}: {docker_launch_cmd}")
start_out, start_err = self.run_command(docker_launch_cmd)
logger.fs.debug(desc_prefix + f": Gateway started {start_out.strip()}")
assert not start_err.strip(), f"Error starting gateway: {start_err.strip()}"
assert not start_err.strip(), f"Error starting gateway:\n{start_out.strip()}\n{start_err.strip()}"

gateway_container_hash = start_out.strip().split("\n")[-1][:12]
self.gateway_log_viewer_url = f"http://127.0.0.1:{self.tunnel_port(8888)}/container/{gateway_container_hash}"
self.gateway_api_url = f"http://127.0.0.1:{self.tunnel_port(8080 + 1)}"

# wait for gateways to start (check status API)
http_pool = urllib3.PoolManager()

def is_api_ready():
try:
api_url = f"{self.gateway_api_url}/api/v1/status"
status_val = retry_requests().get(api_url)
is_up = status_val.json().get("status") == "ok"
status_val = json.loads(http_pool.request("GET", api_url).data.decode("utf-8"))
is_up = status_val.get("status") == "ok"
return is_up
except Exception as e:
logger.error(f"{desc_prefix}: Failed to check gateway status: {e}")
return False

try:
@@ -330,9 +334,6 @@ def is_api_ready():
logger.fs.warning(desc_prefix + " gateway launch command: " + docker_launch_cmd)
logs, err = self.run_command(f"sudo docker logs skyplane_gateway --tail=100")
logger.fs.error(f"Docker logs: {logs}\nerr: {err}")

out, err = self.run_command(docker_launch_cmd.replace(" -d ", " "))
logger.fs.error(f"Relaunching gateway in foreground\nout: {out}\nerr: {err}")
logger.fs.exception(e)
raise e
finally:
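The launch path now polls the daemon status endpoint through a plain urllib3.PoolManager over the SSH tunnel. A standalone sketch of that readiness-polling pattern, with a simple loop standing in for skyplane.utils.fn.wait_for and an illustrative URL:

import json
import time
import urllib3

http_pool = urllib3.PoolManager()

def is_api_ready(api_url: str) -> bool:
    try:
        status = json.loads(http_pool.request("GET", api_url).data.decode("utf-8"))
        return status.get("status") == "ok"
    except Exception:
        return False

def wait_until_ready(api_url: str, timeout: float = 30.0, interval: float = 0.5) -> None:
    # Stand-in for wait_for: poll until the daemon reports ok or time runs out.
    deadline = time.time() + timeout
    while time.time() < deadline:
        if is_api_ready(api_url):
            return
        time.sleep(interval)
    raise TimeoutError(f"gateway API at {api_url} did not become ready")

wait_until_ready("http://127.0.0.1:8081/api/v1/status")  # illustrative URL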
26 changes: 8 additions & 18 deletions skyplane/gateway/gateway_daemon_api.py
@@ -1,7 +1,6 @@
import logging
import logging.handlers
import os
import ssl
import threading
from multiprocessing import Queue
from queue import Empty
@@ -38,11 +37,10 @@ def __init__(
error_event,
error_queue: Queue,
host="0.0.0.0",
port=8080,
port=8081,
):
super().__init__()
self.app = Flask("gateway_metadata_server")
self.app_http = Flask("gateway_metadata_server_http")
self.chunk_store = chunk_store
self.gateway_receiver = gateway_receiver
self.error_event = error_event
@@ -51,19 +49,16 @@ def __init__(
self.error_list_lock = threading.Lock()

# load routes
for app in (self.app, self.app_http):
self.register_global_routes(app)
self.register_server_routes(app)
self.register_request_routes(app)
self.register_error_routes(app)
self.register_socket_profiling_routes(app)
self.register_global_routes(self.app)
self.register_server_routes(self.app)
self.register_request_routes(self.app)
self.register_error_routes(self.app)
self.register_socket_profiling_routes(self.app)

# make server
self.host = host
self.port = port
self.port_http = port + 1
self.url = "https://{}:{}".format(host, port)
self.url_http = "http://{}:{}".format(host, self.port_http)
self.url = "http://{}:{}".format(host, port)

# chunk status log
self.chunk_status_log: List[Dict] = []
@@ -76,18 +71,13 @@ def __init__(
self.receiver_socket_profiles_lock = threading.Lock()

logging.getLogger("werkzeug").setLevel(logging.WARNING)
self.server = make_server(host, port, self.app, threaded=True, ssl_context="adhoc")
self.server_http = make_server(host, self.port_http, self.app_http, threaded=True)
self.server = make_server(host, port, self.app, threaded=True)

def run(self):
self.server_http_thread = threading.Thread(target=self.server_http.serve_forever)
self.server_http_thread.start()
self.server.serve_forever()

def shutdown(self):
self.server.shutdown()
self.server_http.shutdown()
self.server_http_thread.join()

def register_global_routes(self, app):
# index route returns API version
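Previously the API ran two Flask servers (one HTTPS via werkzeug's adhoc SSL context on 8080, one plain HTTP on 8081); it now runs a single plain-HTTP server on 8081 and leaves TLS termination to stunnel. A minimal sketch of the resulting shape, reduced to the status route that the launch code polls:

import logging
from flask import Flask, jsonify
from werkzeug.serving import make_server

app = Flask("gateway_metadata_server")

@app.route("/api/v1/status", methods=["GET"])
def status():
    return jsonify({"status": "ok"})

logging.getLogger("werkzeug").setLevel(logging.WARNING)

# Plain HTTP on 8081; stunnel (configured in the Dockerfile above) accepts
# TLS on 8080 and forwards decrypted traffic to this port.
server = make_server("0.0.0.0", 8081, app, threaded=True)
server.serve_forever()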
43 changes: 25 additions & 18 deletions skyplane/gateway/gateway_sender.py
@@ -1,3 +1,4 @@
import json
import queue
import socket
import ssl
@@ -6,16 +7,15 @@
from functools import partial
from multiprocessing import Event, Process, Queue
from typing import Dict, List, Optional
import urllib3

import lz4.frame
import requests

from skyplane import MB
from skyplane.chunk import ChunkRequest
from skyplane.gateway.chunk_store import ChunkStore
from skyplane.utils import logger
from skyplane.utils.retry import retry_backoff
from skyplane.utils.net import retry_requests
from skyplane.utils.timer import Timer


@@ -30,6 +30,7 @@ def __init__(
use_tls: bool = True,
use_compression: bool = True,
):
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
self.region = region
self.chunk_store = chunk_store
self.error_event = error_event
@@ -58,6 +59,7 @@ def __init__(
self.destination_ports: Dict[str, int] = {} # ip_address -> int
self.destination_sockets: Dict[str, socket.socket] = {} # ip_address -> socket
self.sent_chunk_ids: Dict[str, List[int]] = {} # ip_address -> list of chunk_ids
self.http_pool = urllib3.PoolManager(retries=urllib3.Retry(total=3), cert_reqs="CERT_NONE")

def start_workers(self):
for ip, num_connections in self.outgoing_ports.items():
@@ -103,9 +105,9 @@ def worker_loop(self, worker_id: int, dest_ip: str):
def wait_for_chunks():
cr_status = {}
for ip, ip_chunk_ids in self.sent_chunk_ids.items():
response = retry_requests().get(f"https://{ip}:8080/api/v1/incomplete_chunk_requests", verify=False)
assert response.status_code == 200, f"{response.status_code} {response.text}"
host_state = response.json()["chunk_requests"]
response = self.http_pool.request("GET", f"https://{ip}:8080/api/v1/incomplete_chunk_requests")
assert response.status == 200, f"{response.status} {response.data.decode('utf-8')}"
host_state = json.loads(response.data.decode("utf-8"))["chunk_requests"]
for chunk_id in ip_chunk_ids:
if chunk_id in host_state:
cr_status[chunk_id] = host_state[chunk_id]["state"]
@@ -125,17 +127,17 @@ def wait_for_chunks():
# close servers
logger.info(f"[sender:{worker_id}] exiting, closing servers")
for dst_host, dst_port in self.destination_ports.items():
response = retry_requests().delete(f"https://{dst_host}:8080/api/v1/servers/{dst_port}", verify=False)
assert response.status_code == 200 and response.json() == {"status": "ok"}, response.json()
response = self.http_pool.request("DELETE", f"https://{dst_host}:8080/api/v1/servers/{dst_port}")
assert response.status == 200 and json.loads(response.data.decode("utf-8")) == {"status": "ok"}
logger.info(f"[sender:{worker_id}] closed destination socket {dst_host}:{dst_port}")

def queue_request(self, chunk_request: ChunkRequest):
self.worker_queue.put(chunk_request.chunk.chunk_id)

def make_socket(self, dst_host):
response = retry_requests().post(f"https://{dst_host}:8080/api/v1/servers", verify=False)
assert response.status_code == 200, f"{response.status_code} {response.text}"
self.destination_ports[dst_host] = int(response.json()["server_port"])
response = self.http_pool.request("POST", f"https://{dst_host}:8080/api/v1/servers")
assert response.status == 200, f"{response.status} {response.data.decode('utf-8')}"
self.destination_ports[dst_host] = int(json.loads(response.data.decode("utf-8"))["server_port"])
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect((dst_host, self.destination_ports[dst_host]))
original_timeout = sock.gettimeout()
@@ -152,14 +154,19 @@ def make_socket(self, dst_host):
def send_chunks(self, chunk_ids: List[int], dst_host: str):
"""Send list of chunks to gateway server, pipelining small chunks together into a single socket stream."""
# notify server of upcoming ChunkRequests
logger.debug(f"[sender:{self.worker_id}]:{chunk_ids} pre-registering chunks")
chunk_reqs = [self.chunk_store.get_chunk_request(chunk_id) for chunk_id in chunk_ids]
post_req = lambda: retry_requests().post(
f"https://{dst_host}:8080/api/v1/chunk_requests", json=[c.as_dict() for c in chunk_reqs], verify=False
)
response = retry_backoff(post_req, exception_class=requests.exceptions.ConnectionError)
assert response.status_code == 200 and response.json()["status"] == "ok"
logger.debug(f"[sender:{self.worker_id}]:{chunk_ids} registered chunks")
with Timer(f"prepare to pre-register chunks {chunk_ids} to {dst_host}"):
logger.debug(f"[sender:{self.worker_id}]:{chunk_ids} pre-registering chunks")
chunk_reqs = [self.chunk_store.get_chunk_request(chunk_id) for chunk_id in chunk_ids]
register_body = json.dumps([c.as_dict() for c in chunk_reqs]).encode("utf-8")
with Timer(f"pre-register chunks {chunk_ids} to {dst_host}"):
response = self.http_pool.request(
"POST",
f"https://{dst_host}:8080/api/v1/chunk_requests",
body=register_body,
headers={"Content-Type": "application/json"},
)
assert response.status == 200 and json.loads(response.data.decode("utf-8")).get("status") == "ok"
logger.debug(f"[sender:{self.worker_id}]:{chunk_ids} registered chunks")

# contact server to set up socket connection
if self.destination_ports.get(dst_host) is None:
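Chunk registration now goes through a long-lived urllib3 PoolManager with built-in retries instead of per-call requests sessions; certificate verification is disabled because the peer presents stunnel's self-signed certificate. A condensed sketch of the POST pattern, with a placeholder DST_HOST and payload:

import json
import urllib3

urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
http_pool = urllib3.PoolManager(retries=urllib3.Retry(total=3), cert_reqs="CERT_NONE")

# Placeholder body; the real payload is a list of serialized ChunkRequests.
payload = json.dumps([{"chunk_id": 0}]).encode("utf-8")
response = http_pool.request(
    "POST",
    "https://DST_HOST:8080/api/v1/chunk_requests",  # DST_HOST is a placeholder
    body=payload,
    headers={"Content-Type": "application/json"},
)
assert response.status == 200, f"{response.status} {response.data.decode('utf-8')}"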
3 changes: 2 additions & 1 deletion skyplane/obj_store/gcs_interface.py
@@ -25,6 +25,7 @@ def __init__(self, bucket_name, gcp_region="infer", create_bucket=False):
self.auth = GCPAuthentication()
# self.auth.set_service_account_credentials("skyplane1") # use service account credentials
self._gcs_client = self.auth.get_storage_client()
self._requests_session = requests.Session()
try:
self.gcp_region = self.infer_gcp_region(bucket_name) if gcp_region is None or gcp_region == "infer" else gcp_region
if not self.bucket_exists():
@@ -139,7 +140,7 @@ def send_xml_request(
req = requests.Request(method, url, headers=headers)

prepared = req.prepare()
response = requests.Session().send(prepared)
response = self._requests_session.send(prepared)

if not response.ok:
raise ValueError(f"Invalid status code {response.status_code}: {response.text}")
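Constructing a fresh requests.Session per XML request discards the connection after a single use; caching one session on the interface lets keep-alive reuse the TCP and TLS handshake across requests. A small illustration (the URL is arbitrary):

import requests

session = requests.Session()
for _ in range(2):
    # Both requests can reuse one pooled connection instead of
    # re-handshaking; with a throwaway Session each send() starts cold.
    prepared = requests.Request("GET", "https://storage.googleapis.com").prepare()
    response = session.send(prepared)
    print(response.status_code)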
2 changes: 1 addition & 1 deletion skyplane/obj_store/s3_interface.py
@@ -24,7 +24,7 @@ def __init__(self, bucket_name, aws_region="infer", create_bucket=False):
try:
self.aws_region = self.infer_s3_region(bucket_name) if aws_region is None or aws_region == "infer" else aws_region
if not self.bucket_exists():
raise exceptions.MissingBucketException()
raise exceptions.MissingBucketException(f"Bucket {bucket_name} does not exist")
except exceptions.MissingBucketException:
if create_bucket:
assert aws_region is not None and aws_region != "infer", "Must specify AWS region when creating bucket"
(2 additional changed files not shown)
