diff --git a/.pytype.cfg b/.pytype.cfg index 1d4017332..17396fe59 100644 --- a/.pytype.cfg +++ b/.pytype.cfg @@ -56,4 +56,4 @@ pythonpath = ; protocols = False # Experimental: Only load submodules that are explicitly imported. -; strict_import = False \ No newline at end of file +; strict_import = False diff --git a/skyplane/broadcast/gateway/chunk_store.py b/skyplane/broadcast/gateway/chunk_store.py index a48b40af4..db52920b4 100644 --- a/skyplane/broadcast/gateway/chunk_store.py +++ b/skyplane/broadcast/gateway/chunk_store.py @@ -6,11 +6,10 @@ from pathlib import Path from typing import Dict, Optional +from skyplane.broadcast.gateway.gateway_queue import GatewayQueue from skyplane.chunk import ChunkRequest, ChunkState from skyplane.utils import logger -from skyplane.gateway.gateway_queue import GatewayQueue - class ChunkStore: def __init__(self, chunk_dir: PathLike): @@ -30,19 +29,19 @@ def __init__(self, chunk_dir: PathLike): # self.chunk_requests: Dict[int, ChunkRequest] = {} # type: ignore # queues of incoming chunk requests for each partition from gateway API (passed to operator graph) - self.chunk_requests: Dict[int, GatewayQueue] = {} + self.chunk_requests: Dict[str, GatewayQueue] = {} # queue of chunk status updates coming from operators (passed to gateway API) self.chunk_status_queue: Queue[Dict] = Queue() self.chunk_completions = defaultdict(list) - def add_partition(self, partition_id: int): - if partition_id in self.chunk_requests: - raise ValueError(f"Partition {partition_id} already exists") - self.chunk_requests[partition_id] = GatewayQueue() + def add_partition(self, partition: str): + if partition in self.chunk_requests: + raise ValueError(f"Partition {partition} already exists") + self.chunk_requests[partition] = GatewayQueue() - def get_chunk_file_path(self, chunk_id: int) -> Path: + def get_chunk_file_path(self, chunk_id: str) -> Path: return self.chunk_dir / f"{chunk_id:05d}.chunk" ### @@ -51,7 +50,7 @@ def get_chunk_file_path(self, chunk_id: int) -> Path: def log_chunk_state(self, chunk_req: ChunkRequest, new_status: ChunkState, metadata: Optional[Dict] = None): rec = { "chunk_id": chunk_req.chunk.chunk_id, - "partition": chunk_req.chunk.partition_id, + "partition": chunk_req.chunk.partition, "state": new_status.name, "time": str(datetime.utcnow().isoformat()), } @@ -67,7 +66,7 @@ def log_chunk_state(self, chunk_req: ChunkRequest, new_status: ChunkState, metad ### def add_chunk_request(self, chunk_request: ChunkRequest, state: ChunkState = ChunkState.registered): - self.chunk_requests[chunk_request.chunk.partition_id].put(chunk_request) + self.chunk_requests[chunk_request.chunk.partition].put(chunk_request) # TODO: consider adding to partition queues here? # update state diff --git a/skyplane/broadcast/gateway/gateway_daemon.py b/skyplane/broadcast/gateway/gateway_daemon.py index 53f60312e..fe8cb9383 100644 --- a/skyplane/broadcast/gateway/gateway_daemon.py +++ b/skyplane/broadcast/gateway/gateway_daemon.py @@ -1,33 +1,31 @@ import argparse -from pprint import pprint import atexit import json import os import signal import sys import time +from collections import defaultdict from multiprocessing import Event, Queue from os import PathLike from pathlib import Path +from pprint import pprint from typing import Dict -from skyplane.gateway.chunk_store import ChunkStore -from skyplane.gateway.gateway_daemon_api import GatewayDaemonAPI -from skyplane.utils import logger - -from skyplane.gateway.gateway_queue import GatewayANDQueue, GatewayORQueue - -from skyplane.gateway.operators.gateway_operator import ( - GatewaySender, - GatewayRandomDataGen, - GatewayWriteLocal, +from skyplane.broadcast.gateway.chunk_store import ChunkStore +from skyplane.broadcast.gateway.gateway_daemon_api import GatewayDaemonAPI +from skyplane.broadcast.gateway.gateway_queue import GatewayANDQueue, GatewayORQueue +from skyplane.broadcast.gateway.operators.gateway_operator import ( + GatewayWaitReciever, GatewayObjStoreReadOperator, + GatewayRandomDataGen, + GatewaySender, GatewayObjStoreWriteOperator, - GatewayWaitReciever, + GatewayWriteLocal, ) -from skyplane.gateway.operators.gateway_receiver import GatewayReceiver +from skyplane.broadcast.gateway.operators.gateway_receiver import GatewayReceiver +from skyplane.utils import logger -from collections import defaultdict # TODO: add default partition ID to main # create gateway broadcast diff --git a/skyplane/broadcast/gateway/gateway_daemon_api.py b/skyplane/broadcast/gateway/gateway_daemon_api.py index 3281e3b26..8a456056c 100644 --- a/skyplane/broadcast/gateway/gateway_daemon_api.py +++ b/skyplane/broadcast/gateway/gateway_daemon_api.py @@ -1,19 +1,19 @@ import logging -from collections import defaultdict import logging.handlers import os import threading +from collections import defaultdict from multiprocessing import Queue from queue import Empty from traceback import TracebackException -from typing import Dict, List +from typing import Dict, List, Optional from flask import Flask, jsonify, request from werkzeug.serving import make_server +from skyplane.broadcast.gateway.chunk_store import ChunkStore +from skyplane.broadcast.gateway.operators.gateway_receiver import GatewayReceiver from skyplane.chunk import ChunkRequest, ChunkState -from skyplane.gateway.chunk_store import ChunkStore -from skyplane.gateway.operators.gateway_receiver import GatewayReceiver from skyplane.utils import logger @@ -37,7 +37,7 @@ def __init__( gateway_receiver: GatewayReceiver, error_event, error_queue: Queue, - terminal_operators: Dict[str, List[str]] = None, + terminal_operators: Optional[Dict[str, List[str]]] = None, host="0.0.0.0", port=8081, ): @@ -81,7 +81,7 @@ def __init__( logging.getLogger("werkzeug").setLevel(logging.WARNING) self.server = make_server(host, port, self.app, threaded=True) - def pull_chunk_status_queue(self) -> List[Dict]: + def pull_chunk_status_queue(self): print("pulling queue") out_events = [] while True: @@ -188,16 +188,15 @@ def remove_server(port: int): def register_request_routes(self, app): def make_chunk_req_payload(chunk_req: ChunkRequest): state = self.chunk_status[chunk_req.chunk.chunk_id] - state_name = state.name if state is not None else "unknown" + state_name = state if state is not None else "unknown" return {"req": chunk_req.as_dict(), "state": state_name} def get_chunk_reqs(state=None) -> Dict[int, Dict]: out = {} - for chunk_req, chunk_state in self.chunk_status.items(): + for chunk_id, chunk_state in self.chunk_status.items(): if state is None or chunk_state == state: - out[chunk_req.chunk.chunk_id] = make_chunk_req_payload(chunk_req) - # for chunk_req in self.chunk_store.get_chunk_requests(state): - # out[chunk_req.chunk.chunk_id] = make_chunk_req_payload(chunk_req) + chunk_req = self.chunk_requests[chunk_id] + out[chunk_id] = make_chunk_req_payload(chunk_id) return out def add_chunk_req(body, state): @@ -303,7 +302,7 @@ def get_receiver_compression_profile(): for k, v in self.sender_compressed_sizes.items(): total_size_compressed_bytes += v # TODO: figure out how to get final size of chunks - total_size_uncompressed_bytes += self.chunk_store.get_chunk_request(k).chunk.chunk_length_bytes + total_size_uncompressed_bytes += self.chunk_requests[k].chunk.chunk_length_bytes return jsonify( { "compressed_bytes_sent": total_size_compressed_bytes, diff --git a/skyplane/broadcast/gateway/gateway_program.py b/skyplane/broadcast/gateway/gateway_program.py index 817911ad8..3969eb03d 100644 --- a/skyplane/broadcast/gateway/gateway_program.py +++ b/skyplane/broadcast/gateway/gateway_program.py @@ -1,6 +1,6 @@ -from typing import Optional, List import json from collections import defaultdict +from typing import Optional, List class GatewayOperator: diff --git a/skyplane/broadcast/gateway/gateway_queue.py b/skyplane/broadcast/gateway/gateway_queue.py index fa7f6e37d..92262d697 100644 --- a/skyplane/broadcast/gateway/gateway_queue.py +++ b/skyplane/broadcast/gateway/gateway_queue.py @@ -29,6 +29,7 @@ def __init__(self, maxsize=0): class GatewayANDQueue(GatewayQueue): def __init__(self, maxsize=0): + super().__init__(maxsize) self.q = {} self.maxsize = maxsize diff --git a/skyplane/broadcast/gateway/operators/gateway_operator.py b/skyplane/broadcast/gateway/operators/gateway_operator.py index 86919a393..6e0772894 100644 --- a/skyplane/broadcast/gateway/operators/gateway_operator.py +++ b/skyplane/broadcast/gateway/operators/gateway_operator.py @@ -1,29 +1,28 @@ import json import os -from typing import List import queue import socket import ssl import time import traceback +from abc import ABC, abstractmethod from functools import partial from multiprocessing import Event, Process from typing import Dict, List, Optional import nacl.secret import urllib3 -from abc import ABC, abstractmethod -from skyplane import MB, cloud_config +from skyplane.broadcast.gateway.chunk_store import ChunkStore +from skyplane.broadcast.gateway.gateway_queue import GatewayQueue from skyplane.chunk import ChunkRequest -from skyplane.gateway.chunk_store import ChunkStore +from skyplane.chunk import ChunkState +from skyplane.config_paths import cloud_config +from skyplane.obj_store.object_store_interface import ObjectStoreInterface from skyplane.utils import logger +from skyplane.utils.definitions import MB from skyplane.utils.retry import retry_backoff from skyplane.utils.timer import Timer -from skyplane.obj_store.object_store_interface import ObjectStoreInterface - -from skyplane.gateway.gateway_queue import GatewayQueue -from skyplane.chunk import ChunkState class GatewayOperator(ABC): @@ -73,7 +72,7 @@ def stop_workers(self): p.join() self.processes = [] - def worker_loop(self, worker_id: int, *args): + def worker_loop(self, worker_id: int, **kwargs): self.worker_id = worker_id while not self.exit_flags[worker_id].is_set() and not self.error_event.is_set(): try: @@ -85,7 +84,7 @@ def worker_loop(self, worker_id: int, *args): continue # process chunk - succ = self.process(chunk_req, *args) + succ = self.process(chunk_req, **kwargs) # place in output queue if succ: @@ -114,7 +113,7 @@ def worker_exit(self, worker_id: int): pass @abstractmethod - def process(self, chunk_req: ChunkRequest, **args): + def process(self, chunk_req: ChunkRequest, **kwargs): pass @@ -122,7 +121,7 @@ class GatewayWaitReciever(GatewayOperator): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def process(self, chunk_req: ChunkRequest): + def process(self, chunk_req: ChunkRequest, **kwargs): chunk_file_path = self.chunk_store.get_chunk_file_path(chunk_req.chunk.chunk_id) if not os.path.exists(chunk_file_path): # chunk still not downloaded, re-queue print("Chunk not downloaded yet, re-queueing", chunk_req.chunk.chunk_id, chunk_file_path) @@ -167,7 +166,7 @@ def __init__( # SSL context if use_tls: - self.ssl_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) + self.ssl_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) # type: ignore self.ssl_context.check_hostname = False self.ssl_context.verify_mode = ssl.CERT_NONE logger.info(f"Using {str(ssl.OPENSSL_VERSION)}") @@ -235,10 +234,9 @@ def make_socket(self, dst_host): return sock # send chunks to other instances - def process(self, chunk_req: ChunkRequest, dst_host: str): + def process(self, chunk_req: ChunkRequest, dst_host: str, **kwargs): """Send list of chunks to gateway server, pipelining small chunks together into a single socket stream.""" # notify server of upcoming ChunkRequests - logger.debug(f"[sender:{self.worker_id}] Sending chunk ID {chunk_req.chunk.chunk_id} to IP {dst_host}") chunk_ids = [chunk_req.chunk.chunk_id] @@ -330,8 +328,7 @@ def __init__( super().__init__(handle, region, input_queue, output_queue, error_event, error_queue, chunk_store, n_processes) self.size_mb = size_mb - def process(self, chunk_req: ChunkRequest): - + def process(self, chunk_req: ChunkRequest, **kwargs): # wait until enough space available fpath = str(self.chunk_store.get_chunk_file_path(chunk_req.chunk.chunk_id).absolute()) size_bytes = int(self.size_mb * MB) @@ -357,10 +354,8 @@ def process(self, chunk_req: ChunkRequest): # os.system(f"fallocate -l {size_bytes} {fpath}") # file_size = os.path.getsize(fpath) # assert file_size == size_bytes, f"File {fpath} has size {file_size} but should be {size_bytes} - chunk store remaining size: {self.chunk_store.remaining_bytes()}" - - print(f"Wrote chunk {chunk_req.chunk.chunk_id} with size {file_size} to {fpath}") + print(f"Wrote chunk {chunk_req.chunk.chunk_id} to {fpath}") chunk_req.chunk.chunk_length_bytes = os.path.getsize(fpath) - return True @@ -378,7 +373,7 @@ def __init__( ): super().__init__(handle, region, input_queue, output_queue, error_event, error_queue, chunk_store, n_processes) - def process(self, chunk_req: ChunkRequest): + def process(self, chunk_req: ChunkRequest, **kwargs): # do nothing (already written locally) return True @@ -393,9 +388,9 @@ def __init__( error_event, error_queue: GatewayQueue, n_processes: int = 1, - chunk_store: ChunkStore = None, - bucket_name: str = None, - bucket_region: str = None, + chunk_store: Optional[ChunkStore] = None, + bucket_name: Optional[str] = None, + bucket_region: Optional[str] = None, ): super().__init__(handle, region, input_queue, output_queue, error_event, error_queue, chunk_store, n_processes) self.bucket_name = bucket_name @@ -428,9 +423,9 @@ def __init__( error_event, error_queue: GatewayQueue, n_processes: int = 32, - chunk_store: ChunkStore = None, - bucket_name: str = None, - bucket_region: str = None, + chunk_store: Optional[ChunkStore] = None, + bucket_name: Optional[str] = None, + bucket_region: Optional[str] = None, ): super().__init__( handle, region, input_queue, output_queue, error_event, error_queue, n_processes, chunk_store, bucket_name, bucket_region @@ -486,17 +481,17 @@ def __init__( error_event, error_queue: GatewayQueue, n_processes: int = 32, - chunk_store: ChunkStore = None, - bucket_name: str = None, - bucket_region: str = None, + chunk_store: Optional[ChunkStore] = None, + bucket_name: Optional[str] = None, + bucket_region: Optional[str] = None, ): super().__init__( handle, region, input_queue, output_queue, error_event, error_queue, n_processes, chunk_store, bucket_name, bucket_region ) - def process(self, chunk_req: ChunkRequest): + def process(self, chunk_req: ChunkRequest, **kwargs): fpath = str(self.chunk_store.get_chunk_file_path(chunk_req.chunk.chunk_id).absolute()) - print("writing", chunk_req.chunk.dest_key, self.bucket_name, self.bucket_region, chunk_req.chunk.chunk_size) + print("writing", chunk_req.chunk.dest_key, self.bucket_name, self.bucket_region, chunk_req.chunk.chunk_length_bytes) logger.debug( f"[obj_store:{self.worker_id}] Start upload {chunk_req.chunk.chunk_id} to {self.bucket_name}, key {chunk_req.chunk.dest_key}" ) diff --git a/skyplane/broadcast/gateway/operators/gateway_receiver.py b/skyplane/broadcast/gateway/operators/gateway_receiver.py index e93ce4b59..7dde575cf 100644 --- a/skyplane/broadcast/gateway/operators/gateway_receiver.py +++ b/skyplane/broadcast/gateway/operators/gateway_receiver.py @@ -10,17 +10,14 @@ import nacl.secret -from skyplane import MB +from skyplane.broadcast.gateway.cert import generate_self_signed_certificate +from skyplane.broadcast.gateway.chunk_store import ChunkStore from skyplane.chunk import WireProtocolHeader -from skyplane.gateway.cert import generate_self_signed_certificate -from skyplane.gateway.chunk_store import ChunkStore from skyplane.utils import logger +from skyplane.utils.definitions import MB from skyplane.utils.timer import Timer - - - class GatewayReceiver: def __init__( self, @@ -138,7 +135,7 @@ def stop_servers(self): assert len(self.server_processes) == 0 def stop_workers(self): - self.stop_servers(self) + self.stop_servers() def recv_chunks(self, conn: socket.socket, addr: Tuple[str, int]): server_port = conn.getsockname()[1]