Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Add back compression and encryption #877

Merged
merged 18 commits into the base branch from the contributor's branch
Jul 11, 2023
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion skyplane/api/dataplane.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def _start_gateway(
gateway_docker_image=gateway_docker_image,
gateway_program_path=str(gateway_program_filename),
gateway_info_path=f"{gateway_log_dir}/gateway_info.json",
e2ee_key_bytes=None, # TODO: remove
e2ee_key_bytes=e2ee_key_bytes, # TODO: remove
use_bbr=self.transfer_config.use_bbr, # TODO: remove
use_compression=self.transfer_config.use_compression,
use_socket_tls=self.transfer_config.use_socket_tls,
Expand Down Expand Up @@ -237,6 +237,7 @@ def copy_log(instance):
instance.download_file("/tmp/gateway.stdout", out_file)
instance.download_file("/tmp/gateway.stderr", err_file)

print("COPY GATEWAY LOGS")
lynnliu030 marked this conversation as resolved.
Show resolved Hide resolved
do_parallel(copy_log, self.bound_nodes.values(), n=-1)

def deprovision(self, max_jobs: int = 64, spinner: bool = False):
Expand Down Expand Up @@ -309,6 +310,7 @@ def run_async(self, jobs: List[TransferJob], hooks: Optional[TransferHook] = Non
"""
if not self.provisioned:
logger.error("Dataplane must be pre-provisioned. Call dataplane.provision() before starting a transfer")
print("discord", jobs)
lynnliu030 marked this conversation as resolved.
Show resolved Hide resolved
tracker = TransferProgressTracker(self, jobs, self.transfer_config, hooks)
self.pending_transfers.append(tracker)
tracker.start()
Expand Down
8 changes: 5 additions & 3 deletions skyplane/api/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(
transfer_config: TransferConfig,
# cloud_regions: dict,
max_instances: Optional[int] = 1,
n_connections: Optional[int] = 64,
planning_algorithm: Optional[str] = "direct",
debug: Optional[bool] = False,
):
Expand All @@ -54,6 +55,7 @@ def __init__(
# self.cloud_regions = cloud_regions
# TODO: set max instances with VM CPU limits and/or config
self.max_instances = max_instances
self.n_connections = n_connections
self.provisioner = provisioner
self.transfer_config = transfer_config
self.http_pool = urllib3.PoolManager(retries=urllib3.Retry(total=3))
Expand All @@ -68,11 +70,11 @@ def __init__(
# planner
self.planning_algorithm = planning_algorithm
if self.planning_algorithm == "direct":
self.planner = MulticastDirectPlanner(self.max_instances, 64, self.transfer_config)
self.planner = MulticastDirectPlanner(self.max_instances, self.n_connections, self.transfer_config)
elif self.planning_algorithm == "src_one_sided":
self.planner = DirectPlannerSourceOneSided(self.max_instances, 64, self.transfer_config)
self.planner = DirectPlannerSourceOneSided(self.max_instances, self.n_connections, self.transfer_config)
elif self.planning_algorithm == "dst_one_sided":
self.planner = DirectPlannerDestOneSided(self.max_instances, 64, self.transfer_config)
self.planner = DirectPlannerDestOneSided(self.max_instances, self.n_connections, self.transfer_config)
else:
raise ValueError(f"No such planning algorithm {planning_algorithm}")

Expand Down
25 changes: 18 additions & 7 deletions skyplane/chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,13 @@ class Chunk:
part_number: Optional[int] = None
upload_id: Optional[str] = None # TODO: for broadcast, this is not used

def to_wire_header(self, n_chunks_left_on_socket: int, wire_length: int, is_compressed: bool = False):
def to_wire_header(self, n_chunks_left_on_socket: int, wire_length: int, raw_wire_length: int, is_compressed: bool = False):
return WireProtocolHeader(
chunk_id=self.chunk_id, data_len=wire_length, is_compressed=is_compressed, n_chunks_left_on_socket=n_chunks_left_on_socket
chunk_id=self.chunk_id,
data_len=wire_length,
raw_data_len=raw_wire_length,
is_compressed=is_compressed,
n_chunks_left_on_socket=n_chunks_left_on_socket,
)

def as_dict(self):
Expand Down Expand Up @@ -94,6 +98,7 @@ class WireProtocolHeader:

chunk_id: str # 128bit UUID
data_len: int # long
raw_data_len: int # long (uncompressed, unencrypted)
is_compressed: bool # char
n_chunks_left_on_socket: int # long

Expand All @@ -110,8 +115,8 @@ def protocol_version():

@staticmethod
def length_bytes():
# magic (8) + protocol_version (4) + chunk_id (16) + data_len (8) + is_compressed (1) + n_chunks_left_on_socket (8)
return 8 + 4 + 16 + 8 + 1 + 8
# magic (8) + protocol_version (4) + chunk_id (16) + data_len (8) + raw_data_len(8) + is_compressed (1) + n_chunks_left_on_socket (8)
return 8 + 4 + 16 + 8 + 8 + 1 + 8

@staticmethod
def from_bytes(data: bytes):
Expand All @@ -124,10 +129,15 @@ def from_bytes(data: bytes):
raise ValueError(f"Invalid protocol version, got {version} but expected {WireProtocolHeader.protocol_version()}")
chunk_id = data[12:28].hex()
chunk_len = int.from_bytes(data[28:36], byteorder="big")
is_compressed = bool(int.from_bytes(data[36:37], byteorder="big"))
n_chunks_left_on_socket = int.from_bytes(data[37:45], byteorder="big")
raw_chunk_len = int.from_bytes(data[36:44], byteorder="big")
is_compressed = bool(int.from_bytes(data[44:45], byteorder="big"))
n_chunks_left_on_socket = int.from_bytes(data[45:53], byteorder="big")
return WireProtocolHeader(
chunk_id=chunk_id, data_len=chunk_len, is_compressed=is_compressed, n_chunks_left_on_socket=n_chunks_left_on_socket
chunk_id=chunk_id,
data_len=chunk_len,
raw_data_len=raw_chunk_len,
is_compressed=is_compressed,
n_chunks_left_on_socket=n_chunks_left_on_socket,
)

def to_bytes(self):
Expand All @@ -138,6 +148,7 @@ def to_bytes(self):
assert len(chunk_id_bytes) == 16
out_bytes += chunk_id_bytes
out_bytes += self.data_len.to_bytes(8, byteorder="big")
out_bytes += self.raw_data_len.to_bytes(8, byteorder="big")
out_bytes += self.is_compressed.to_bytes(1, byteorder="big")
out_bytes += self.n_chunks_left_on_socket.to_bytes(8, byteorder="big")
assert len(out_bytes) == WireProtocolHeader.length_bytes(), f"{len(out_bytes)} != {WireProtocolHeader.length_bytes()}"
Expand Down
18 changes: 10 additions & 8 deletions skyplane/gateway/gateway_daemon.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
GatewayWriteLocal,
GatewayObjStoreReadOperator,
GatewayObjStoreWriteOperator,
GatewayWaitReciever,
GatewayWaitReceiver,
)
from skyplane.gateway.operators.gateway_receiver import GatewayReceiver
from skyplane.utils import logger
Expand All @@ -38,7 +38,8 @@ def __init__(
chunk_dir: PathLike,
max_incoming_ports=64,
use_tls=True,
use_e2ee=False,
use_e2ee=True, # TODO: read from operator field
use_compression=True, # TODO: read from operator field
):
# read gateway program
gateway_program_path = Path(os.environ["GATEWAY_PROGRAM_FILE"]).expanduser()
Expand Down Expand Up @@ -71,6 +72,7 @@ def __init__(
e2ee_key_path = Path(os.environ["E2EE_KEY_FILE"]).expanduser()
with open(e2ee_key_path, "rb") as f:
self.e2ee_key_bytes = f.read()
print("Server side E2EE key loaded: ", self.e2ee_key_bytes)
else:
self.e2ee_key_bytes = None

Expand All @@ -79,7 +81,7 @@ def __init__(
self.num_required_terminal = {}
self.operators = self.create_gateway_operators(gateway_program)

# single gateway reciever
# single gateway receiver
self.gateway_receiver = GatewayReceiver(
"reciever",
region=region,
Expand All @@ -88,7 +90,7 @@ def __init__(
error_queue=self.error_queue,
max_pending_chunks=max_incoming_ports,
use_tls=self.use_tls,
use_compression=False, # use_compression,
use_compression=use_compression,
e2ee_key_bytes=self.e2ee_key_bytes,
)

Expand Down Expand Up @@ -178,13 +180,13 @@ def create_gateway_operators_helper(input_queue, program: List[Dict], partition_

# create operators
if op["op_type"] == "receive":
# wait for chunks from reciever
operators[handle] = GatewayWaitReciever(
# wait for chunks from receiver
operators[handle] = GatewayWaitReceiver(
handle=handle,
region=self.region,
input_queue=input_queue,
output_queue=output_queue,
n_processes=1, # dummy wait thread, not actual reciever
n_processes=1, # dummy wait thread, not actual receiver
chunk_store=self.chunk_store,
error_event=self.error_event,
error_queue=self.error_queue,
Expand Down Expand Up @@ -230,7 +232,7 @@ def create_gateway_operators_helper(input_queue, program: List[Dict], partition_
error_queue=self.error_queue,
chunk_store=self.chunk_store,
use_tls=self.use_tls,
use_compression=False, # operator["compress"],
use_compression=op["compress"],
e2ee_key_bytes=self.e2ee_key_bytes,
n_processes=op["num_connections"],
)
Expand Down
40 changes: 24 additions & 16 deletions skyplane/gateway/operators/gateway_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import socket
import ssl
import time
import lz4.frame
import traceback
from functools import partial
from multiprocessing import Event, Process, Queue
Expand Down Expand Up @@ -121,11 +122,11 @@ def process(self, chunk_req: ChunkRequest, **args):
pass


class GatewayWaitReciever(GatewayOperator):
class GatewayWaitReceiver(GatewayOperator):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

# TODO: alternative (potentially better performnace) implementation: connect via queue with GatewayReciever to listen
# TODO: alternative (potentially better performance) implementation: connect via queue with GatewayReceiver to listen
# for download completion events - join with chunk request queue from ChunkStore
def process(self, chunk_req: ChunkRequest):
chunk_file_path = self.chunk_store.get_chunk_file_path(chunk_req.chunk.chunk_id)
Expand All @@ -134,7 +135,7 @@ def process(self, chunk_req: ChunkRequest):
return False

# check to see if file is completed downloading
# Successfully recieved chunk 38400a29812142a486eaefcdebedf371, 161867776 0, 67108864
# Successfully received chunk 38400a29812142a486eaefcdebedf371, 161867776 0, 67108864
with open(chunk_file_path, "rb") as f:
data = f.read()
if len(data) < chunk_req.chunk.chunk_length_bytes:
Expand All @@ -144,7 +145,7 @@ def process(self, chunk_req: ChunkRequest):
len(data) == chunk_req.chunk.chunk_length_bytes
), f"Downloaded chunk length does not match expected length: {len(data)}, {chunk_req.chunk.chunk_length_bytes}"
print(
f"[{self.handle}:{self.worker_id}] Successfully recieved chunk {chunk_req.chunk.chunk_id}, {len(data)}, {chunk_req.chunk.chunk_length_bytes}"
f"[{self.handle}:{self.worker_id}] Successfully received chunk {chunk_req.chunk.chunk_id}, {len(data)}, {chunk_req.chunk.chunk_length_bytes}"
)
return True

Expand Down Expand Up @@ -328,17 +329,24 @@ def process(self, chunk_req: ChunkRequest, dst_host: str):
assert len(data) == chunk.chunk_length_bytes, f"chunk {chunk_id} has size {len(data)} but should be {chunk.chunk_length_bytes}"

wire_length = len(data)
# compressed_length = None
# if self.use_compression and self.region == chunk_req.src_region:
# data = lz4.frame.compress(data)
# wire_length = len(data)
# compressed_length = wire_length
# if self.e2ee_secretbox is not None and self.region == chunk_req.src_region:
# data = self.e2ee_secretbox.encrypt(data)
# wire_length = len(data)
raw_wire_length = wire_length
compressed_length = None

if self.use_compression:
data = lz4.frame.compress(data)
wire_length = len(data)
compressed_length = wire_length
if self.e2ee_secretbox is not None:
data = self.e2ee_secretbox.encrypt(data)
wire_length = len(data)

# send chunk header
header = chunk.to_wire_header(n_chunks_left_on_socket=len(chunk_ids) - idx - 1, wire_length=wire_length, is_compressed=False)
header = chunk.to_wire_header(
n_chunks_left_on_socket=len(chunk_ids) - idx - 1,
wire_length=wire_length,
raw_wire_length=raw_wire_length,
is_compressed=(compressed_length is not None),
)
# print(f"[sender-{self.worker_id}]:{chunk_id} sending chunk header {header}")
header.to_socket(sock)
# print(f"[sender-{self.worker_id}]:{chunk_id} sent chunk header")
Expand Down Expand Up @@ -528,10 +536,10 @@ def process(self, chunk_req: ChunkRequest, **args):
# else:
# self.chunk_store.update_chunk_checksum(chunk_req.chunk.chunk_id, md5sum)

recieved_chunk_size = self.chunk_store.get_chunk_file_path(chunk_req.chunk.chunk_id).stat().st_size
received_chunk_size = self.chunk_store.get_chunk_file_path(chunk_req.chunk.chunk_id).stat().st_size
assert (
recieved_chunk_size == chunk_req.chunk.chunk_length_bytes
), f"Downloaded chunk {chunk_req.chunk.chunk_id} to {fpath} has incorrect size (expected {chunk_req.chunk.chunk_length_bytes} but got {recieved_chunk_size}, {chunk_req.chunk.chunk_length_bytes})"
received_chunk_size == chunk_req.chunk.chunk_length_bytes
), f"Downloaded chunk {chunk_req.chunk.chunk_id} to {fpath} has incorrect size (expected {chunk_req.chunk.chunk_length_bytes} but got {received_chunk_size}, {chunk_req.chunk.chunk_length_bytes})"
logger.debug(f"[obj_store:{self.worker_id}] Downloaded {chunk_req.chunk.chunk_id} from {self.bucket_name}")
return True

Expand Down
29 changes: 22 additions & 7 deletions skyplane/gateway/operators/gateway_receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import socket
import ssl
import time
import lz4.frame
import traceback
from contextlib import closing
from multiprocessing import Event, Process, Value, Queue
Expand Down Expand Up @@ -152,8 +153,8 @@ def recv_chunks(self, conn: socket.socket, addr: Tuple[str, int]):
# TODO: this wont work
# chunk_request = self.chunk_store.get_chunk_request(chunk_header.chunk_id)

# should_decrypt = self.e2ee_secretbox is not None and chunk_request.dst_region == self.region
# should_decompress = chunk_header.is_compressed and chunk_request.dst_region == self.region
should_decrypt = self.e2ee_secretbox is not None # and chunk_request.dst_region == self.region
should_decompress = chunk_header.is_compressed # and chunk_request.dst_region == self.region

# wait for space
# while self.chunk_store.remaining_bytes() < chunk_header.data_len * self.max_pending_chunks:
Expand All @@ -170,7 +171,7 @@ def recv_chunks(self, conn: socket.socket, addr: Tuple[str, int]):
fpath = self.chunk_store.get_chunk_file_path(chunk_header.chunk_id)
with fpath.open("wb") as f:
socket_data_len = chunk_header.data_len
chunk_received_size = 0
chunk_received_size, chunk_received_size_decompressed = 0, 0
to_write = bytearray(socket_data_len)
to_write_view = memoryview(to_write)
while socket_data_len > 0:
Expand All @@ -187,22 +188,36 @@ def recv_chunks(self, conn: socket.socket, addr: Tuple[str, int]):
)
to_write = bytes(to_write)

if should_decrypt:
to_write = self.e2ee_secretbox.decrypt(to_write)
print(f"[receiver:{server_port}]:{chunk_header.chunk_id} Decrypting {len(to_write)} bytes")

if should_decompress:
data_batch_decompressed = lz4.frame.decompress(to_write)
chunk_received_size_decompressed += len(data_batch_decompressed)
to_write = data_batch_decompressed
print(
f"[receiver:{server_port}]:{chunk_header.chunk_id} Decompressing {len(to_write)} bytes to {chunk_received_size_decompressed} bytes"
)

# try to write data until successful
while True:
try:
f.seek(0, 0)
f.write(to_write)
f.flush()

# check write succeeds
assert os.path.exists(fpath)

# check size
file_size = os.path.getsize(fpath)
if file_size == chunk_header.data_len:
if file_size == chunk_header.raw_data_len:
break
elif file_size >= chunk_header.data_len:
raise ValueError(f"[Gateway] File size {file_size} greater than chunk size {chunk_header.data_len}")
elif file_size >= chunk_header.raw_data_len:
raise ValueError(f"[Gateway] File size {file_size} greater than chunk size {chunk_header.raw_data_len}")
except Exception as e:
print(e)

print(
f"[receiver:{server_port}]: No remaining space with bytes {self.chunk_store.remaining_bytes()} data len {chunk_header.data_len} max pending {self.max_pending_chunks}, total space {init_space}"
)
Expand Down
Loading