diff --git a/scripts/test_gateway.sh b/scripts/test_gateway.sh index 718486c4f..39c77c540 100644 --- a/scripts/test_gateway.sh +++ b/scripts/test_gateway.sh @@ -1,3 +1,4 @@ #!/bin/bash +set -x sudo docker build -t gateway_test . -sudo docker run --rm --ipc=host --network=host --name=skylark_gateway gateway_test /env/bin/python /pkg/skylark/gateway/gateway_daemon.py \ No newline at end of file +sudo docker run --rm --ipc=host --network=host --name=skylark_gateway gateway_test /env/bin/python /pkg/skylark/gateway/gateway_daemon.py --chunk-dir /dev/shm/skylark_test/chunks \ No newline at end of file diff --git a/skylark/compute/server.py b/skylark/compute/server.py index 635928ac5..07fca4d23 100644 --- a/skylark/compute/server.py +++ b/skylark/compute/server.py @@ -266,7 +266,7 @@ def start_gateway( gateway_daemon_cmd = f"/env/bin/python /pkg/skylark/gateway/gateway_daemon.py --debug --chunk-dir /dev/shm/skylark/chunks --outgoing-connections {num_outgoing_connections}" docker_launch_cmd = f"sudo docker run {docker_run_flags} --name skylark_gateway {gateway_docker_image} {gateway_daemon_cmd}" start_out, start_err = self.run_command(docker_launch_cmd) - assert not start_err, f"Error starting gateway: {start_err}" + assert not start_err.strip(), f"Error starting gateway: {start_err}" gateway_container_hash = start_out.strip().split("\n")[-1][:12] self.gateway_api_url = f"http://{self.public_ip()}:8080/api/v1" self.gateway_log_viewer_url = f"http://{self.public_ip()}:8888/container/{gateway_container_hash}" diff --git a/skylark/gateway/chunk.py b/skylark/gateway/chunk.py index 23b0ea9f5..fff0fd7bd 100644 --- a/skylark/gateway/chunk.py +++ b/skylark/gateway/chunk.py @@ -23,15 +23,14 @@ dst_object_store_bucket: str As compared to a ChunkRequest, the WireProtocolHeader is solely used to manage transfers over network sockets. It identifies the ID and -length of the upcoming stream of data (contents of the Chunk) on the socket. An end_of_stream flag is used to indicate that this is the -last transfer over a socket while a magic int (SKY_LARK) as well as the protocol version are used to enable wire protocol upgrades. +length of the upcoming stream of data (contents of the Chunk) on the socket. WireProtocolHeader: magic: int protocol_version: int chunk_id: int chunk_len: int - end_of_stream: bool + n_chunks_left_on_socket: int """ from functools import total_ordering @@ -50,8 +49,10 @@ class Chunk: file_offset_bytes: int chunk_length_bytes: int - def to_wire_header(self, end_of_stream: bool = False): - return WireProtocolHeader(chunk_id=self.chunk_id, chunk_len=self.chunk_length_bytes, end_of_stream=end_of_stream) + def to_wire_header(self, n_chunks_left_on_socket): + return WireProtocolHeader( + chunk_id=self.chunk_id, chunk_len=self.chunk_length_bytes, n_chunks_left_on_socket=n_chunks_left_on_socket + ) def as_dict(self): return asdict(self) @@ -136,9 +137,9 @@ def __lt__(self, other): class WireProtocolHeader: """Lightweight wire protocol header for chunk transfers along socket.""" - chunk_id: int # unsigned long - chunk_len: int # unsigned long - end_of_stream: bool = False # false by default, but true if this is the last chunk + chunk_id: int # long + chunk_len: int # long + n_chunks_left_on_socket: int # long @staticmethod def magic_hex(): @@ -150,8 +151,8 @@ def protocol_version(): @staticmethod def length_bytes(): - # magic (8) + protocol_version (4) + chunk_id (8) + chunk_len (8) + end_of_stream (1) - return 8 + 4 + 8 + 8 + 1 + # magic (8) + protocol_version (4) + chunk_id (8) + chunk_len (8) + n_chunks_left_on_socket (8) + return 8 + 4 + 8 + 8 + 8 @staticmethod def from_bytes(data: bytes): @@ -164,8 +165,8 @@ def from_bytes(data: bytes): raise ValueError("Invalid protocol version") chunk_id = int.from_bytes(data[12:20], byteorder="big") chunk_len = int.from_bytes(data[20:28], byteorder="big") - end_of_stream = bool(data[28]) - return WireProtocolHeader(chunk_id=chunk_id, chunk_len=chunk_len, end_of_stream=end_of_stream) + n_chunks_left_on_socket = int.from_bytes(data[28:36], byteorder="big") + return WireProtocolHeader(chunk_id=chunk_id, chunk_len=chunk_len, n_chunks_left_on_socket=n_chunks_left_on_socket) def to_bytes(self): out_bytes = b"" @@ -173,7 +174,7 @@ def to_bytes(self): out_bytes += self.protocol_version().to_bytes(4, byteorder="big") out_bytes += self.chunk_id.to_bytes(8, byteorder="big") out_bytes += self.chunk_len.to_bytes(8, byteorder="big") - out_bytes += bytes([int(self.end_of_stream)]) + out_bytes += self.n_chunks_left_on_socket.to_bytes(8, byteorder="big") assert len(out_bytes) == WireProtocolHeader.length_bytes(), f"{len(out_bytes)} != {WireProtocolHeader.length_bytes()}" return out_bytes diff --git a/skylark/gateway/gateway_daemon.py b/skylark/gateway/gateway_daemon.py index 4d74e67e1..dd68acb0b 100644 --- a/skylark/gateway/gateway_daemon.py +++ b/skylark/gateway/gateway_daemon.py @@ -22,16 +22,16 @@ class GatewayDaemon: - def __init__(self, chunk_dir: PathLike, debug=False, log_dir: Optional[PathLike] = None, outgoing_connections=1, outgoing_batch_size=1): + def __init__(self, chunk_dir: PathLike, debug=False, log_dir: Optional[PathLike] = None, outgoing_connections=1): if log_dir is not None: log_dir = Path(log_dir) log_dir.mkdir(exist_ok=True) logger.remove() logger.add(log_dir / "gateway_daemon.log", rotation="10 MB", enqueue=True) - logger.add(sys.stderr, colorize=True, format="{function:>15}:{line:<3} {level:<8} {message}", level="DEBUG") + logger.add(sys.stderr, colorize=True, format="{function:>15}:{line:<3} {level:<8} {message}", level="DEBUG", enqueue=True) self.chunk_store = ChunkStore(chunk_dir) self.gateway_receiver = GatewayReceiver(chunk_store=self.chunk_store) - self.gateway_sender = GatewaySender(chunk_store=self.chunk_store, n_processes=outgoing_connections, batch_size=outgoing_batch_size) + self.gateway_sender = GatewaySender(chunk_store=self.chunk_store, n_processes=outgoing_connections) # API server self.api_server = GatewayDaemonAPI(self.chunk_store, self.gateway_receiver, debug=debug, log_dir=log_dir) diff --git a/skylark/gateway/gateway_reciever.py b/skylark/gateway/gateway_reciever.py index f777a33b3..79320bd90 100644 --- a/skylark/gateway/gateway_reciever.py +++ b/skylark/gateway/gateway_reciever.py @@ -9,6 +9,7 @@ import setproctitle from loguru import logger +from skylark import MB from skylark.gateway.chunk import WireProtocolHeader from skylark.gateway.chunk_store import ChunkStore @@ -16,9 +17,9 @@ class GatewayReceiver: - def __init__(self, chunk_store: ChunkStore, server_blk_size=4096 * 16): + def __init__(self, chunk_store: ChunkStore, write_back_block_size=1 * MB): self.chunk_store = chunk_store - self.server_blk_size = server_blk_size + self.write_back_block_size = write_back_block_size # shared state self.manager = Manager() @@ -50,13 +51,12 @@ def signal_handler(signal, frame): if exit_flag.value == 1: logger.warning(f"[server:{socket_port}] Exiting on signal") return + # Wait for a connection with a timeout of 1 second w/ select readable, _, _ = select.select([sock], [], [], 1) if readable: conn, addr = sock.accept() - chunks_received = self.recv_chunks(conn, addr) - conn.close() - logger.debug(f"[receiver] {chunks_received} chunks received") + self.recv_chunks(conn, addr) p = Process(target=server_worker) p.start() @@ -95,28 +95,33 @@ def recv_chunks(self, conn: socket.socket, addr: Tuple[str, int]): while True: # receive header and write data to file chunk_header = WireProtocolHeader.from_socket(conn) - self.chunk_store.state_start_download(chunk_header.chunk_id) logger.debug(f"[server:{server_port}] Got chunk header {chunk_header.chunk_id}: {chunk_header}") + self.chunk_store.state_start_download(chunk_header.chunk_id) + + # get data with Timer() as t: chunk_data_size = chunk_header.chunk_len chunk_received_size = 0 chunk_file_path = self.chunk_store.get_chunk_file_path(chunk_header.chunk_id) + with chunk_file_path.open("wb") as f: while chunk_data_size > 0: - data = conn.recv(min(chunk_data_size, self.server_blk_size)) + data = conn.recv(min(chunk_data_size, self.write_back_block_size)) f.write(data) chunk_data_size -= len(data) chunk_received_size += len(data) logger.debug( f"[receiver:{server_port}] {chunk_header.chunk_id} chunk received {chunk_received_size}/{chunk_header.chunk_len}" ) - # todo check hash, update status and close socket if transfer is complete + + # todo check hash self.chunk_store.state_finish_download(chunk_header.chunk_id, t.elapsed) chunks_received.append(chunk_header.chunk_id) logger.info( f"[receiver:{server_port}] Received chunk {chunk_header.chunk_id} ({chunk_received_size} bytes) in {t.elapsed:.2f} seconds" ) - if chunk_header.end_of_stream: + + if chunk_header.n_chunks_left_on_socket == 0: conn.close() - logger.debug(f"[receiver:{server_port}] End of stream reached") - return chunks_received + logger.debug(f"[receiver:{server_port}] End of stream reached, closing connection and waiting for another") + return diff --git a/skylark/gateway/gateway_sender.py b/skylark/gateway/gateway_sender.py index d47fe5f8a..0aaa683be 100644 --- a/skylark/gateway/gateway_sender.py +++ b/skylark/gateway/gateway_sender.py @@ -1,23 +1,24 @@ import queue import socket -import time from contextlib import closing from multiprocessing import Event, Manager, Process, Value -from typing import List +from typing import Dict, List import requests import setproctitle from loguru import logger +from skylark import MB from skylark.gateway.chunk import ChunkRequest from skylark.gateway.chunk_store import ChunkStore +from skylark.utils.utils import Timer class GatewaySender: - def __init__(self, chunk_store: ChunkStore, n_processes=1, batch_size=1): + def __init__(self, chunk_store: ChunkStore, n_processes=1, max_batch_size_bytes=64 * MB): self.chunk_store = chunk_store self.n_processes = n_processes - self.batch_size = batch_size + self.max_batch_size_bytes = max_batch_size_bytes self.processes = [] # shared state @@ -26,6 +27,10 @@ def __init__(self, chunk_store: ChunkStore, n_processes=1, batch_size=1): self.worker_queue: queue.Queue[int] = self.manager.Queue() self.exit_flags = [Event() for _ in range(self.n_processes)] + # process-local state + self.worker_id: int = None + self.destination_ports: Dict[str, int] = None # ip_address -> port + def start_workers(self): for i in range(self.n_processes): p = Process(target=self.worker_loop, args=(i,)) @@ -41,71 +46,81 @@ def stop_workers(self): def worker_loop(self, worker_id: int): setproctitle.setproctitle(f"skylark-gateway-sender:{worker_id}") + self.worker_id = worker_id + self.destination_ports = {} while not self.exit_flags[worker_id].is_set(): - # get up to pipeline_batch_size chunks from the queue + # get all items from queue chunk_ids_to_send = [] - while len(chunk_ids_to_send) < self.batch_size: + total_bytes = 0.0 + while True: try: chunk_ids_to_send.append(self.worker_queue.get_nowait()) + total_bytes = sum( + self.chunk_store.get_chunk_request(chunk_id).chunk.chunk_length_bytes for chunk_id in chunk_ids_to_send + ) + if total_bytes > self.max_batch_size_bytes: + break except queue.Empty: break - # check next hop is the same for all chunks in the batch - if chunk_ids_to_send: - logger.debug(f"worker {worker_id} sending {len(chunk_ids_to_send)} chunks") - chunks = [] - for idx in chunk_ids_to_send: - self.chunk_store.pop_chunk_request_path(idx) - req = self.chunk_store.get_chunk_request(idx) - chunks.append(req) - next_hop = chunks[0].path[0] - assert all(next_hop.hop_cloud_region == chunk.path[0].hop_cloud_region for chunk in chunks) - assert all(next_hop.hop_ip_address == chunk.path[0].hop_ip_address for chunk in chunks) - - # send chunks - chunk_ids = [req.chunk.chunk_id for req in chunks] - self.send_chunks(chunk_ids, next_hop.hop_ip_address) - time.sleep(0.1) # short interval to batch requests + if len(chunk_ids_to_send) > 0: + # check next hop is the same for all chunks in the batch + if chunk_ids_to_send: + logger.debug(f"[sender:{worker_id}] sending {len(chunk_ids_to_send)} chunks, {chunk_ids_to_send}") + chunks = [] + for idx in chunk_ids_to_send: + self.chunk_store.pop_chunk_request_path(idx) + req = self.chunk_store.get_chunk_request(idx) + chunks.append(req) + next_hop = chunks[0].path[0] + assert all(next_hop.hop_cloud_region == chunk.path[0].hop_cloud_region for chunk in chunks) + assert all(next_hop.hop_ip_address == chunk.path[0].hop_ip_address for chunk in chunks) + + # send chunks + chunk_ids = [req.chunk.chunk_id for req in chunks] + self.send_chunks(chunk_ids, next_hop.hop_ip_address) + + # close destination sockets + for dst_host, dst_port in self.destination_ports.items(): + response = requests.delete(f"http://{dst_host}:8080/api/v1/servers/{dst_port}") + assert response.status_code == 200 and response.json() == {"status": "ok"}, response.json() + logger.info(f"[sender:{worker_id}] closed destination socket {dst_host}:{dst_port}") def queue_request(self, chunk_request: ChunkRequest): - logger.debug(f"queuing chunk request {chunk_request.chunk.chunk_id}") self.worker_queue.put(chunk_request.chunk.chunk_id) - def send_chunks(self, chunk_ids: List[int], dst_host="127.0.0.1"): + def send_chunks(self, chunk_ids: List[int], dst_host: str): """Send list of chunks to gateway server, pipelining small chunks together into a single socket stream.""" # notify server of upcoming ChunkRequests - # pop chunk_req.path[0] to remove self chunk_reqs = [self.chunk_store.get_chunk_request(chunk_id) for chunk_id in chunk_ids] response = requests.post(f"http://{dst_host}:8080/api/v1/chunk_requests", json=[c.as_dict() for c in chunk_reqs]) assert response.status_code == 200 and response.json()["status"] == "ok" # contact server to set up socket connection - response = requests.post(f"http://{dst_host}:8080/api/v1/servers") - assert response.status_code == 200 - dst_port = int(response.json()["server_port"]) + if self.destination_ports.get(dst_host) is None: + response = requests.post(f"http://{dst_host}:8080/api/v1/servers") + assert response.status_code == 200 + self.destination_ports[dst_host] = int(response.json()["server_port"]) + logger.info(f"[sender:{self.worker_id}] started new server connection to {dst_host}:{self.destination_ports[dst_host]}") + dst_port = self.destination_ports[dst_host] with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock: sock.connect((dst_host, dst_port)) - for idx, chunk_id in enumerate(chunk_ids): - sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_CORK, 1) # disable Nagle's algorithm - logger.warning(f"[sender -> {dst_port}] Sending chunk {chunk_id} to {dst_host}:{dst_port}") - self.chunk_store.state_start_upload(chunk_id) - chunk = self.chunk_store.get_chunk_request(chunk_id).chunk - - # send chunk header - chunk.to_wire_header(end_of_stream=idx == len(chunk_ids) - 1).to_socket(sock) - - # send chunk data - chunk_file_path = self.chunk_store.get_chunk_file_path(chunk_id) - assert chunk_file_path.exists(), f"chunk file {chunk_file_path} does not exist" - with open(chunk_file_path, "rb") as fd: - bytes_sent = sock.sendfile(fd) - logger.debug(f"[sender -> {dst_port}] Sent {bytes_sent} bytes of data") - - self.chunk_store.state_finish_upload(chunk_id) - chunk_file_path.unlink() - sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_CORK, 0) # send remaining packets - - # close server - response = requests.delete(f"http://{dst_host}:8080/api/v1/servers/{dst_port}") - assert response.status_code == 200 and response.json() == {"status": "ok"}, response.json() + with Timer() as t: + for idx, chunk_id in enumerate(chunk_ids): + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_CORK, 1) # disable Nagle's algorithm + self.chunk_store.state_start_upload(chunk_id) + chunk = self.chunk_store.get_chunk_request(chunk_id).chunk + + # send chunk header + chunk.to_wire_header(n_chunks_left_on_socket=len(chunk_ids) - idx - 1).to_socket(sock) + + # send chunk data + chunk_file_path = self.chunk_store.get_chunk_file_path(chunk_id) + assert chunk_file_path.exists(), f"chunk file {chunk_file_path} does not exist" + with open(chunk_file_path, "rb") as fd: + bytes_sent = sock.sendfile(fd) + self.chunk_store.state_finish_upload(chunk_id) + chunk_file_path.unlink() + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_CORK, 0) # send remaining packets + logger.info(f"[sender:{self.worker_id} -> {dst_port}] Sent {len(chunk_ids)} chunks in {t.elapsed:.2}s, {chunk_ids}") diff --git a/skylark/test/test_replicator_client.py b/skylark/test/test_replicator_client.py index f51f44921..630b3ab09 100644 --- a/skylark/test/test_replicator_client.py +++ b/skylark/test/test_replicator_client.py @@ -122,7 +122,7 @@ def main(args): crs = rc.run_replication_plan(job) logger.info(f"{total_bytes / GB:.2f}GByte replication job launched") stats = rc.monitor_transfer(crs) - logger.info(f"Replication completed in {stats['total_runtime_s']:.2f}s ({stats['throughput_gbits']:.2f}GByte/s)") + logger.info(f"Replication completed in {stats['total_runtime_s']:.2f}s ({stats['throughput_gbits']:.2f}Gbit/s)") if __name__ == "__main__":