Connection pooling + tune batch size to eliminate stragglers #38

Merged · Jan 9, 2022 · 3 commits (diff below shows changes from 1 commit)
scripts/test_gateway.sh (2 additions, 1 deletion)
@@ -1,3 +1,4 @@
 #!/bin/bash
+set -x
 sudo docker build -t gateway_test .
-sudo docker run --rm --ipc=host --network=host --name=skylark_gateway gateway_test /env/bin/python /pkg/skylark/gateway/gateway_daemon.py
+sudo docker run --rm --ipc=host --network=host --name=skylark_gateway gateway_test /env/bin/python /pkg/skylark/gateway/gateway_daemon.py --chunk-dir /dev/shm/skylark_test/chunks
skylark/compute/server.py (1 addition, 1 deletion)
@@ -265,7 +265,7 @@ def start_gateway(
         gateway_daemon_cmd = f"/env/bin/python /pkg/skylark/gateway/gateway_daemon.py --debug --chunk-dir /dev/shm/skylark/chunks --outgoing-connections {num_outgoing_connections}"
         docker_launch_cmd = f"sudo docker run {docker_run_flags} --name skylark_gateway {gateway_docker_image} {gateway_daemon_cmd}"
         start_out, start_err = self.run_command(docker_launch_cmd)
-        assert not start_err, f"Error starting gateway: {start_err}"
+        assert not start_err.strip(), f"Error starting gateway: {start_err}"
         gateway_container_hash = start_out.strip().split("\n")[-1][:12]
         self.gateway_api_url = f"http://{self.public_ip()}:8080/api/v1"
         self.gateway_log_viewer_url = f"http://{self.public_ip()}:8888/container/{gateway_container_hash}"
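Why the .strip() matters: run_command can surface a stderr stream that is non-empty but contains only whitespace (for example a trailing newline from docker), and a whitespace-only string is still truthy in Python, so the old assert fired on otherwise-clean launches. A minimal illustration, using a hypothetical stderr value:

start_err = "\n"              # hypothetical: docker emitting a bare newline on stderr
assert start_err              # a non-empty string is truthy, so the old `assert not start_err` would fail here
assert not start_err.strip()  # whitespace-only stderr is now treated as a clean start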
skylark/gateway/chunk.py (14 additions, 13 deletions)
@@ -23,15 +23,14 @@
     dst_object_store_bucket: str

 As compared to a ChunkRequest, the WireProtocolHeader is solely used to manage transfers over network sockets. It identifies the ID and
-length of the upcoming stream of data (contents of the Chunk) on the socket. An end_of_stream flag is used to indicate that this is the
-last transfer over a socket while a magic int (SKY_LARK) as well as the protocol version are used to enable wire protocol upgrades.
+length of the upcoming stream of data (contents of the Chunk) on the socket.

 WireProtocolHeader:
     magic: int
     protocol_version: int
     chunk_id: int
     chunk_len: int
-    end_of_stream: bool
+    n_chunks_left_on_socket: int
 """

 from functools import total_ordering
@@ -50,8 +49,10 @@ class Chunk:
     file_offset_bytes: int
     chunk_length_bytes: int

-    def to_wire_header(self, end_of_stream: bool = False):
-        return WireProtocolHeader(chunk_id=self.chunk_id, chunk_len=self.chunk_length_bytes, end_of_stream=end_of_stream)
+    def to_wire_header(self, n_chunks_left_on_socket):
+        return WireProtocolHeader(
+            chunk_id=self.chunk_id, chunk_len=self.chunk_length_bytes, n_chunks_left_on_socket=n_chunks_left_on_socket
+        )

     def as_dict(self):
         return asdict(self)
@@ -136,9 +137,9 @@ def __lt__(self, other):
 class WireProtocolHeader:
     """Lightweight wire protocol header for chunk transfers along socket."""

-    chunk_id: int  # unsigned long
-    chunk_len: int  # unsigned long
-    end_of_stream: bool = False  # false by default, but true if this is the last chunk
+    chunk_id: int  # long
+    chunk_len: int  # long
+    n_chunks_left_on_socket: int  # long

     @staticmethod
     def magic_hex():
@@ -150,8 +151,8 @@ def protocol_version():

     @staticmethod
     def length_bytes():
-        # magic (8) + protocol_version (4) + chunk_id (8) + chunk_len (8) + end_of_stream (1)
-        return 8 + 4 + 8 + 8 + 1
+        # magic (8) + protocol_version (4) + chunk_id (8) + chunk_len (8) + n_chunks_left_on_socket (8)
+        return 8 + 4 + 8 + 8 + 8

     @staticmethod
     def from_bytes(data: bytes):
@@ -164,16 +165,16 @@
             raise ValueError("Invalid protocol version")
         chunk_id = int.from_bytes(data[12:20], byteorder="big")
         chunk_len = int.from_bytes(data[20:28], byteorder="big")
-        end_of_stream = bool(data[28])
-        return WireProtocolHeader(chunk_id=chunk_id, chunk_len=chunk_len, end_of_stream=end_of_stream)
+        n_chunks_left_on_socket = int.from_bytes(data[28:36], byteorder="big")
+        return WireProtocolHeader(chunk_id=chunk_id, chunk_len=chunk_len, n_chunks_left_on_socket=n_chunks_left_on_socket)

     def to_bytes(self):
         out_bytes = b""
         out_bytes += self.magic_hex().to_bytes(8, byteorder="big")
         out_bytes += self.protocol_version().to_bytes(4, byteorder="big")
         out_bytes += self.chunk_id.to_bytes(8, byteorder="big")
         out_bytes += self.chunk_len.to_bytes(8, byteorder="big")
-        out_bytes += bytes([int(self.end_of_stream)])
+        out_bytes += self.n_chunks_left_on_socket.to_bytes(8, byteorder="big")
         assert len(out_bytes) == WireProtocolHeader.length_bytes(), f"{len(out_bytes)} != {WireProtocolHeader.length_bytes()}"
         return out_bytes

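For reference, the new header is a fixed 36 bytes on the wire: magic (8) + protocol_version (4) + chunk_id (8) + chunk_len (8) + n_chunks_left_on_socket (8), all big-endian. Below is a minimal standalone sketch of this framing as a round trip; the MAGIC value assumes the ASCII bytes of "SKY_LARK" that the docstring alludes to, and VERSION is a placeholder, since the actual constants live in magic_hex() and protocol_version() in the repo:

import struct

MAGIC = 0x534B595F4C41524B  # assumption: ASCII for "SKY_LARK"; the repo defines the real magic int
VERSION = 1                 # assumption: placeholder protocol version

def pack_header(chunk_id: int, chunk_len: int, n_chunks_left_on_socket: int) -> bytes:
    # big-endian: magic (8) + version (4) + chunk_id (8) + chunk_len (8) + n_chunks_left (8) = 36 bytes
    return struct.pack(">QIQQQ", MAGIC, VERSION, chunk_id, chunk_len, n_chunks_left_on_socket)

def unpack_header(data: bytes):
    magic, version, chunk_id, chunk_len, n_left = struct.unpack(">QIQQQ", data)
    assert magic == MAGIC, "invalid magic"
    assert version == VERSION, "invalid protocol version"
    return chunk_id, chunk_len, n_left

hdr = pack_header(chunk_id=7, chunk_len=64 * 1024 * 1024, n_chunks_left_on_socket=2)
assert len(hdr) == 36  # matches length_bytes() above
assert unpack_header(hdr) == (7, 64 * 1024 * 1024, 2)  # lossless round trip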
skylark/gateway/gateway_daemon.py (3 additions, 3 deletions)
@@ -22,16 +22,16 @@


 class GatewayDaemon:
-    def __init__(self, chunk_dir: PathLike, debug=False, log_dir: Optional[PathLike] = None, outgoing_connections=1, outgoing_batch_size=1):
+    def __init__(self, chunk_dir: PathLike, debug=False, log_dir: Optional[PathLike] = None, outgoing_connections=1):
         if log_dir is not None:
             log_dir = Path(log_dir)
             log_dir.mkdir(exist_ok=True)
             logger.remove()
             logger.add(log_dir / "gateway_daemon.log", rotation="10 MB", enqueue=True)
-            logger.add(sys.stderr, colorize=True, format="{function:>15}:{line:<3} {level:<8} {message}", level="DEBUG")
+            logger.add(sys.stderr, colorize=True, format="{function:>15}:{line:<3} {level:<8} {message}", level="DEBUG", enqueue=True)
         self.chunk_store = ChunkStore(chunk_dir)
         self.gateway_receiver = GatewayReceiver(chunk_store=self.chunk_store)
-        self.gateway_sender = GatewaySender(chunk_store=self.chunk_store, n_processes=outgoing_connections, batch_size=outgoing_batch_size)
+        self.gateway_sender = GatewaySender(chunk_store=self.chunk_store, n_processes=outgoing_connections)

         # API server
         self.api_server = GatewayDaemonAPI(self.chunk_store, self.gateway_receiver, debug=debug, log_dir=log_dir)
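Note on the added enqueue=True: loguru sinks are not multiprocess-safe by default, and the daemon forks receiver and sender worker processes that all log to the same stderr. With enqueue=True, records pass through a multiprocessing-safe queue before reaching the sink, so concurrent workers cannot interleave partial lines. A minimal usage sketch:

import sys
from loguru import logger

logger.remove()
# enqueue=True funnels log records through a multiprocess-safe queue,
# so writes from forked workers are serialized rather than interleaved
logger.add(sys.stderr, level="DEBUG", enqueue=True)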
skylark/gateway/gateway_reciever.py (14 additions, 9 deletions)
@@ -9,14 +9,15 @@

 import setproctitle
 from loguru import logger
+from skylark import MB

 from skylark.gateway.chunk import WireProtocolHeader
 from skylark.gateway.chunk_store import ChunkStore
 from skylark.utils.utils import Timer


 class GatewayReceiver:
-    def __init__(self, chunk_store: ChunkStore, server_blk_size=4096 * 16):
+    def __init__(self, chunk_store: ChunkStore, server_blk_size=1 * MB):
         self.chunk_store = chunk_store
         self.server_blk_size = server_blk_size
@@ -50,13 +51,12 @@ def signal_handler(signal, frame):
             if exit_flag.value == 1:
                 logger.warning(f"[server:{socket_port}] Exiting on signal")
                 return
+
             # Wait for a connection with a timeout of 1 second w/ select
             readable, _, _ = select.select([sock], [], [], 1)
             if readable:
                 conn, addr = sock.accept()
-                chunks_received = self.recv_chunks(conn, addr)
-                conn.close()
-                logger.debug(f"[receiver] {chunks_received} chunks received")
+                self.recv_chunks(conn, addr)

     p = Process(target=server_worker)
     p.start()
@@ -95,12 +95,15 @@ def recv_chunks(self, conn: socket.socket, addr: Tuple[str, int]):
         while True:
             # receive header and write data to file
             chunk_header = WireProtocolHeader.from_socket(conn)
-            self.chunk_store.state_start_download(chunk_header.chunk_id)
             logger.debug(f"[server:{server_port}] Got chunk header {chunk_header.chunk_id}: {chunk_header}")
+            self.chunk_store.state_start_download(chunk_header.chunk_id)
+
             # get data
             with Timer() as t:
                 chunk_data_size = chunk_header.chunk_len
                 chunk_received_size = 0
                 chunk_file_path = self.chunk_store.get_chunk_file_path(chunk_header.chunk_id)
+
                 with chunk_file_path.open("wb") as f:
                     while chunk_data_size > 0:
                         data = conn.recv(min(chunk_data_size, self.server_blk_size))
@@ -110,13 +113,15 @@
                         logger.debug(
                             f"[receiver:{server_port}] {chunk_header.chunk_id} chunk received {chunk_received_size}/{chunk_header.chunk_len}"
                         )
-            # todo check hash, update status and close socket if transfer is complete
+
+            # todo check hash
             self.chunk_store.state_finish_download(chunk_header.chunk_id, t.elapsed)
             chunks_received.append(chunk_header.chunk_id)
             logger.info(
                 f"[receiver:{server_port}] Received chunk {chunk_header.chunk_id} ({chunk_received_size} bytes) in {t.elapsed:.2f} seconds"
             )
-            if chunk_header.end_of_stream:
+
+            if chunk_header.n_chunks_left_on_socket == 0:
                 conn.close()
-                logger.debug(f"[receiver:{server_port}] End of stream reached")
-                return chunks_received
+                logger.debug(f"[receiver:{server_port}] End of stream reached, closing connection and waiting for another")
+                return
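With n_chunks_left_on_socket replacing the old end_of_stream flag, the receiver can drain any number of chunks from one pooled connection and close it only when the sender signals that nothing remains. A compact sketch of that loop, reusing unpack_header from the framing sketch above (the real recv_chunks streams each payload into a chunk file and updates ChunkStore state instead of buffering it in memory):

import socket

HEADER_LEN = 36  # fixed header size, per WireProtocolHeader.length_bytes()

def recv_exact(conn: socket.socket, n: int) -> bytes:
    # recv() may return fewer bytes than requested, so loop until n bytes arrive
    buf = b""
    while len(buf) < n:
        data = conn.recv(n - len(buf))
        if not data:
            raise ConnectionError("socket closed mid-stream")
        buf += data
    return buf

def serve_socket(conn: socket.socket) -> None:
    while True:
        chunk_id, chunk_len, n_left = unpack_header(recv_exact(conn, HEADER_LEN))
        payload = recv_exact(conn, chunk_len)  # sketch only: real code writes to disk in blocks
        if n_left == 0:  # sender has no more chunks queued on this socket
            conn.close()
            return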
skylark/gateway/gateway_sender.py (53 additions, 36 deletions)
@@ -3,21 +3,22 @@
 import time
 from contextlib import closing
 from multiprocessing import Event, Manager, Process, Value
-from typing import List
+from typing import Dict, List

 import requests
 import setproctitle
 from loguru import logger
+from skylark import MB

-from skylark.gateway.chunk import ChunkRequest
+from skylark.gateway.chunk import ChunkRequest, WireProtocolHeader
 from skylark.gateway.chunk_store import ChunkStore


 class GatewaySender:
-    def __init__(self, chunk_store: ChunkStore, n_processes=1, batch_size=1):
+    def __init__(self, chunk_store: ChunkStore, n_processes=1, max_batch_size_mb=16):
         self.chunk_store = chunk_store
         self.n_processes = n_processes
-        self.batch_size = batch_size
+        self.max_batch_size_bytes = max_batch_size_mb * MB
         self.processes = []

         # shared state
@@ -26,6 +27,10 @@ def __init__(self, chunk_store: ChunkStore, n_processes=1, batch_size=1):
         self.worker_queue: queue.Queue[int] = self.manager.Queue()
         self.exit_flags = [Event() for _ in range(self.n_processes)]

+        # process-local state
+        self.worker_id: int = None
+        self.destination_ports: Dict[str, int] = None  # ip_address -> port
+
     def start_workers(self):
         for i in range(self.n_processes):
             p = Process(target=self.worker_loop, args=(i,))
@@ -41,71 +46,83 @@ def stop_workers(self):

     def worker_loop(self, worker_id: int):
         setproctitle.setproctitle(f"skylark-gateway-sender:{worker_id}")
+        self.worker_id = worker_id
+        self.destination_ports = {}
         while not self.exit_flags[worker_id].is_set():
-            # get up to pipeline_batch_size chunks from the queue
+            # get all items from queue
             chunk_ids_to_send = []
-            while len(chunk_ids_to_send) < self.batch_size:
+            total_bytes = 0.0
+            while True:
                 try:
                     chunk_ids_to_send.append(self.worker_queue.get_nowait())
+                    total_bytes = sum(
+                        self.chunk_store.get_chunk_request(chunk_id).chunk.chunk_length_bytes for chunk_id in chunk_ids_to_send
+                    )
+                    if total_bytes > self.max_batch_size_bytes:
+                        break
                 except queue.Empty:
                     break

-            # check next hop is the same for all chunks in the batch
-            if chunk_ids_to_send:
-                logger.debug(f"worker {worker_id} sending {len(chunk_ids_to_send)} chunks")
-                chunks = []
-                for idx in chunk_ids_to_send:
-                    self.chunk_store.pop_chunk_request_path(idx)
-                    req = self.chunk_store.get_chunk_request(idx)
-                    chunks.append(req)
-                next_hop = chunks[0].path[0]
-                assert all(next_hop.hop_cloud_region == chunk.path[0].hop_cloud_region for chunk in chunks)
-                assert all(next_hop.hop_ip_address == chunk.path[0].hop_ip_address for chunk in chunks)
-
-                # send chunks
-                chunk_ids = [req.chunk.chunk_id for req in chunks]
-                self.send_chunks(chunk_ids, next_hop.hop_ip_address)
-            time.sleep(0.1)  # short interval to batch requests
+            if len(chunk_ids_to_send) > 0:
+                # check next hop is the same for all chunks in the batch
+                if chunk_ids_to_send:
+                    logger.debug(f"[sender:{worker_id}] sending {len(chunk_ids_to_send)} chunks, {chunk_ids_to_send}")
+                    chunks = []
+                    for idx in chunk_ids_to_send:
+                        self.chunk_store.pop_chunk_request_path(idx)
+                        req = self.chunk_store.get_chunk_request(idx)
+                        chunks.append(req)
+                    next_hop = chunks[0].path[0]
+                    assert all(next_hop.hop_cloud_region == chunk.path[0].hop_cloud_region for chunk in chunks)
+                    assert all(next_hop.hop_ip_address == chunk.path[0].hop_ip_address for chunk in chunks)
+
+                    # send chunks
+                    logger.debug(
+                        f"[sender:{worker_id}] Sending {len(chunks)} chunks ({total_bytes / MB:.2f}MB) to {next_hop.hop_ip_address}, {chunk_ids_to_send}"
+                    )
+                    chunk_ids = [req.chunk.chunk_id for req in chunks]
+                    self.send_chunks(chunk_ids, next_hop.hop_ip_address)
+
+        # close destination sockets
+        for dst_host, dst_port in self.destination_ports.items():
+            response = requests.delete(f"http://{dst_host}:8080/api/v1/servers/{dst_port}")
+            assert response.status_code == 200 and response.json() == {"status": "ok"}, response.json()
+            logger.info(f"[sender:{worker_id}] closed destination socket {dst_host}:{dst_port}")

     def queue_request(self, chunk_request: ChunkRequest):
         logger.debug(f"queuing chunk request {chunk_request.chunk.chunk_id}")
         self.worker_queue.put(chunk_request.chunk.chunk_id)

-    def send_chunks(self, chunk_ids: List[int], dst_host="127.0.0.1"):
+    def send_chunks(self, chunk_ids: List[int], dst_host: str):
         """Send list of chunks to gateway server, pipelining small chunks together into a single socket stream."""
         # notify server of upcoming ChunkRequests
         # pop chunk_req.path[0] to remove self
         chunk_reqs = [self.chunk_store.get_chunk_request(chunk_id) for chunk_id in chunk_ids]
         response = requests.post(f"http://{dst_host}:8080/api/v1/chunk_requests", json=[c.as_dict() for c in chunk_reqs])
         assert response.status_code == 200 and response.json()["status"] == "ok"

         # contact server to set up socket connection
-        response = requests.post(f"http://{dst_host}:8080/api/v1/servers")
-        assert response.status_code == 200
-        dst_port = int(response.json()["server_port"])
+        if self.destination_ports.get(dst_host) is None:
+            response = requests.post(f"http://{dst_host}:8080/api/v1/servers")
+            assert response.status_code == 200
+            self.destination_ports[dst_host] = int(response.json()["server_port"])
+            logger.info(f"[sender:{self.worker_id}] started new server connection to {dst_host}:{self.destination_ports[dst_host]}")
+        dst_port = self.destination_ports[dst_host]

         with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
             sock.connect((dst_host, dst_port))
             for idx, chunk_id in enumerate(chunk_ids):
                 sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_CORK, 1)  # disable Nagle's algorithm
                 logger.warning(f"[sender -> {dst_port}] Sending chunk {chunk_id} to {dst_host}:{dst_port}")
                 self.chunk_store.state_start_upload(chunk_id)
                 chunk = self.chunk_store.get_chunk_request(chunk_id).chunk

                 # send chunk header
-                chunk.to_wire_header(end_of_stream=idx == len(chunk_ids) - 1).to_socket(sock)
+                chunk.to_wire_header(n_chunks_left_on_socket=len(chunk_ids) - idx - 1).to_socket(sock)

                 # send chunk data
                 chunk_file_path = self.chunk_store.get_chunk_file_path(chunk_id)
                 assert chunk_file_path.exists(), f"chunk file {chunk_file_path} does not exist"
                 with open(chunk_file_path, "rb") as fd:
                     bytes_sent = sock.sendfile(fd)
                     logger.debug(f"[sender -> {dst_port}] Sent {bytes_sent} bytes of data")

                 self.chunk_store.state_finish_upload(chunk_id)
                 chunk_file_path.unlink()
                 sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_CORK, 0)  # send remaining packets

-        # close server
-        response = requests.delete(f"http://{dst_host}:8080/api/v1/servers/{dst_port}")
-        assert response.status_code == 200 and response.json() == {"status": "ok"}, response.json()
+        logger.info(f"[sender:{self.worker_id} -> {dst_port}] Sent {len(chunk_ids)} chunks, {chunk_ids}")
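The straggler fix in this file is the batching policy in worker_loop: a worker now drains its queue until it is empty or the accumulated batch exceeds max_batch_size_bytes, so many small chunks share one socket stream while a single large chunk no longer holds a fixed-size batch hostage. Restated as a standalone helper (a sketch, not repo code: chunk_len_bytes stands in for the ChunkStore lookup, and a running total replaces the re-summation in worker_loop):

import queue
from typing import Callable, List, Tuple

def drain_batch(q: "queue.Queue[int]", chunk_len_bytes: Callable[[int], int], max_batch_bytes: int) -> Tuple[List[int], int]:
    batch: List[int] = []
    total = 0
    while True:
        try:
            chunk_id = q.get_nowait()  # non-blocking: an empty queue ends the batch
        except queue.Empty:
            break
        batch.append(chunk_id)
        total += chunk_len_bytes(chunk_id)
        if total > max_batch_bytes:  # byte budget reached, stop growing the batch
            break
    return batch, total

The pooling half of the change is the destination_ports dict: each worker performs the POST /api/v1/servers handshake once per destination host, reuses that registered socket endpoint for every subsequent batch, and issues the DELETE /api/v1/servers/{port} teardown only when the worker exits, rather than once per batch as before.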
skylark/test/test_replicator_client.py (1 addition, 1 deletion)
@@ -122,7 +122,7 @@ def main(args):
     crs = rc.run_replication_plan(job)
     logger.info(f"{total_bytes / GB:.2f}GByte replication job launched")
     stats = rc.monitor_transfer(crs)
-    logger.info(f"Replication completed in {stats['total_runtime_s']:.2f}s ({stats['throughput_gbits']:.2f}GByte/s)")
+    logger.info(f"Replication completed in {stats['total_runtime_s']:.2f}s ({stats['throughput_gbits']:.2f}Gbit/s)")


 if __name__ == "__main__":
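Context for the label fix: stats["throughput_gbits"] is reported in gigabits per second, and since 1 GByte/s equals 8 Gbit/s, the old GByte/s label overstated the measured throughput by a factor of 8.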