Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[broadcast] Fix issues with gateway merge to main #689

Merged
merged 3 commits into from
Nov 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pytype.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,4 @@ pythonpath =
; protocols = False

# Experimental: Only load submodules that are explicitly imported.
; strict_import = False
; strict_import = False
19 changes: 9 additions & 10 deletions skyplane/broadcast/gateway/chunk_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@
from pathlib import Path
from typing import Dict, Optional

from skyplane.broadcast.gateway.gateway_queue import GatewayQueue
from skyplane.chunk import ChunkRequest, ChunkState
from skyplane.utils import logger

from skyplane.gateway.gateway_queue import GatewayQueue


class ChunkStore:
def __init__(self, chunk_dir: PathLike):
Expand All @@ -30,19 +29,19 @@ def __init__(self, chunk_dir: PathLike):
# self.chunk_requests: Dict[int, ChunkRequest] = {} # type: ignore

# queues of incoming chunk requests for each partition from gateway API (passed to operator graph)
self.chunk_requests: Dict[int, GatewayQueue] = {}
self.chunk_requests: Dict[str, GatewayQueue] = {}

# queue of chunk status updates coming from operators (passed to gateway API)
self.chunk_status_queue: Queue[Dict] = Queue()

self.chunk_completions = defaultdict(list)

def add_partition(self, partition_id: int):
if partition_id in self.chunk_requests:
raise ValueError(f"Partition {partition_id} already exists")
self.chunk_requests[partition_id] = GatewayQueue()
def add_partition(self, partition: str):
    """Register a new partition, giving it a dedicated request queue.

    Args:
        partition: identifier of the partition to register.

    Raises:
        ValueError: if a queue already exists for this partition.
    """
    if partition not in self.chunk_requests:
        self.chunk_requests[partition] = GatewayQueue()
        return
    raise ValueError(f"Partition {partition} already exists")

def get_chunk_file_path(self, chunk_id: int) -> Path:
def get_chunk_file_path(self, chunk_id: str) -> Path:
    """Return the on-disk path for a chunk's data file.

    The PR changed ``chunk_id`` from int to str, but kept the
    ``{chunk_id:05d}`` format spec, which only accepts ints and raises
    ``ValueError`` for strings. Use the id directly in the filename.
    """
    return self.chunk_dir / f"{chunk_id}.chunk"

###
Expand All @@ -51,7 +50,7 @@ def get_chunk_file_path(self, chunk_id: int) -> Path:
def log_chunk_state(self, chunk_req: ChunkRequest, new_status: ChunkState, metadata: Optional[Dict] = None):
rec = {
"chunk_id": chunk_req.chunk.chunk_id,
"partition": chunk_req.chunk.partition_id,
"partition": chunk_req.chunk.partition,
"state": new_status.name,
"time": str(datetime.utcnow().isoformat()),
}
Expand All @@ -67,7 +66,7 @@ def log_chunk_state(self, chunk_req: ChunkRequest, new_status: ChunkState, metad
###
def add_chunk_request(self, chunk_request: ChunkRequest, state: ChunkState = ChunkState.registered):

self.chunk_requests[chunk_request.chunk.partition_id].put(chunk_request)
self.chunk_requests[chunk_request.chunk.partition].put(chunk_request)
# TODO: consider adding to partition queues here?

# update state
Expand Down
26 changes: 12 additions & 14 deletions skyplane/broadcast/gateway/gateway_daemon.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,31 @@
import argparse
from pprint import pprint
import atexit
import json
import os
import signal
import sys
import time
from collections import defaultdict
from multiprocessing import Event, Queue
from os import PathLike
from pathlib import Path
from pprint import pprint
from typing import Dict

from skyplane.gateway.chunk_store import ChunkStore
from skyplane.gateway.gateway_daemon_api import GatewayDaemonAPI
from skyplane.utils import logger

from skyplane.gateway.gateway_queue import GatewayANDQueue, GatewayORQueue

from skyplane.gateway.operators.gateway_operator import (
GatewaySender,
GatewayRandomDataGen,
GatewayWriteLocal,
from skyplane.broadcast.gateway.chunk_store import ChunkStore
from skyplane.broadcast.gateway.gateway_daemon_api import GatewayDaemonAPI
from skyplane.broadcast.gateway.gateway_queue import GatewayANDQueue, GatewayORQueue
from skyplane.broadcast.gateway.operators.gateway_operator import (
GatewayWaitReciever,
GatewayObjStoreReadOperator,
GatewayRandomDataGen,
GatewaySender,
GatewayObjStoreWriteOperator,
GatewayWaitReciever,
GatewayWriteLocal,
)
from skyplane.gateway.operators.gateway_receiver import GatewayReceiver
from skyplane.broadcast.gateway.operators.gateway_receiver import GatewayReceiver
from skyplane.utils import logger

from collections import defaultdict

# TODO: add default partition ID to main
# create gateway broadcast
Expand Down
23 changes: 11 additions & 12 deletions skyplane/broadcast/gateway/gateway_daemon_api.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
import logging
from collections import defaultdict
import logging.handlers
import os
import threading
from collections import defaultdict
from multiprocessing import Queue
from queue import Empty
from traceback import TracebackException
from typing import Dict, List
from typing import Dict, List, Optional

from flask import Flask, jsonify, request
from werkzeug.serving import make_server

from skyplane.broadcast.gateway.chunk_store import ChunkStore
from skyplane.broadcast.gateway.operators.gateway_receiver import GatewayReceiver
from skyplane.chunk import ChunkRequest, ChunkState
from skyplane.gateway.chunk_store import ChunkStore
from skyplane.gateway.operators.gateway_receiver import GatewayReceiver
from skyplane.utils import logger


Expand All @@ -37,7 +37,7 @@ def __init__(
gateway_receiver: GatewayReceiver,
error_event,
error_queue: Queue,
terminal_operators: Dict[str, List[str]] = None,
terminal_operators: Optional[Dict[str, List[str]]] = None,
host="0.0.0.0",
port=8081,
):
Expand Down Expand Up @@ -81,7 +81,7 @@ def __init__(
logging.getLogger("werkzeug").setLevel(logging.WARNING)
self.server = make_server(host, port, self.app, threaded=True)

def pull_chunk_status_queue(self) -> List[Dict]:
def pull_chunk_status_queue(self):
print("pulling queue")
out_events = []
while True:
Expand Down Expand Up @@ -188,16 +188,15 @@ def remove_server(port: int):
def register_request_routes(self, app):
def make_chunk_req_payload(chunk_req: ChunkRequest):
    """Serialize a ChunkRequest together with its current transfer state.

    The scraped diff retained both the pre-change line (``state.name``)
    and the post-change line; only the post-change version is kept here.
    """
    state = self.chunk_status[chunk_req.chunk.chunk_id]
    # NOTE(review): the diff dropped `.name`, implying chunk_status now
    # stores plain strings rather than ChunkState enums — TODO confirm.
    state_name = state if state is not None else "unknown"
    return {"req": chunk_req.as_dict(), "state": state_name}

def get_chunk_reqs(state=None) -> Dict[str, Dict]:
    """Build a chunk_id -> payload map, optionally filtered by state.

    Bug fix: the diffed code looked up ``chunk_req`` but then called
    ``make_chunk_req_payload(chunk_id)`` with the bare id, which would
    fail inside the helper (it dereferences ``chunk_req.chunk``). Pass
    the looked-up request object instead.

    Args:
        state: if given, only chunks whose recorded state equals it
            are included; if None, all chunks are returned.
    """
    out = {}
    # chunk_status maps chunk_id -> state; chunk_requests maps
    # chunk_id -> ChunkRequest (see compression-profile handler below).
    for chunk_id, chunk_state in self.chunk_status.items():
        if state is None or chunk_state == state:
            chunk_req = self.chunk_requests[chunk_id]
            out[chunk_id] = make_chunk_req_payload(chunk_req)
    return out

def add_chunk_req(body, state):
Expand Down Expand Up @@ -303,7 +302,7 @@ def get_receiver_compression_profile():
for k, v in self.sender_compressed_sizes.items():
total_size_compressed_bytes += v
# TODO: figure out how to get final size of chunks
total_size_uncompressed_bytes += self.chunk_store.get_chunk_request(k).chunk.chunk_length_bytes
total_size_uncompressed_bytes += self.chunk_requests[k].chunk.chunk_length_bytes
return jsonify(
{
"compressed_bytes_sent": total_size_compressed_bytes,
Expand Down
2 changes: 1 addition & 1 deletion skyplane/broadcast/gateway/gateway_program.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Optional, List
import json
from collections import defaultdict
from typing import Optional, List


class GatewayOperator:
Expand Down
1 change: 1 addition & 0 deletions skyplane/broadcast/gateway/gateway_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def __init__(self, maxsize=0):

class GatewayANDQueue(GatewayQueue):
def __init__(self, maxsize=0):
    """Initialize an AND-queue that fans chunks out to multiple consumers.

    Args:
        maxsize: capacity forwarded to the base GatewayQueue initializer
            (0 means unbounded, per multiprocessing.Queue convention —
            TODO confirm against GatewayQueue).
    """
    super().__init__(maxsize)
    # NOTE(review): deliberately replaces whatever queue attribute the base
    # __init__ created — self.q becomes a dict of per-handle queues here.
    self.q = {}
    self.maxsize = maxsize

Expand Down
59 changes: 27 additions & 32 deletions skyplane/broadcast/gateway/operators/gateway_operator.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,28 @@
import json
import os
from typing import List
import queue
import socket
import ssl
import time
import traceback
from abc import ABC, abstractmethod
from functools import partial
from multiprocessing import Event, Process
from typing import Dict, List, Optional

import nacl.secret
import urllib3
from abc import ABC, abstractmethod

from skyplane import MB, cloud_config
from skyplane.broadcast.gateway.chunk_store import ChunkStore
from skyplane.broadcast.gateway.gateway_queue import GatewayQueue
from skyplane.chunk import ChunkRequest
from skyplane.gateway.chunk_store import ChunkStore
from skyplane.chunk import ChunkState
from skyplane.config_paths import cloud_config
from skyplane.obj_store.object_store_interface import ObjectStoreInterface
from skyplane.utils import logger
from skyplane.utils.definitions import MB
from skyplane.utils.retry import retry_backoff
from skyplane.utils.timer import Timer
from skyplane.obj_store.object_store_interface import ObjectStoreInterface

from skyplane.gateway.gateway_queue import GatewayQueue
from skyplane.chunk import ChunkState


class GatewayOperator(ABC):
Expand Down Expand Up @@ -73,7 +72,7 @@ def stop_workers(self):
p.join()
self.processes = []

def worker_loop(self, worker_id: int, *args):
def worker_loop(self, worker_id: int, **kwargs):
self.worker_id = worker_id
while not self.exit_flags[worker_id].is_set() and not self.error_event.is_set():
try:
Expand All @@ -85,7 +84,7 @@ def worker_loop(self, worker_id: int, *args):
continue

# process chunk
succ = self.process(chunk_req, *args)
succ = self.process(chunk_req, **kwargs)

# place in output queue
if succ:
Expand Down Expand Up @@ -114,15 +113,15 @@ def worker_exit(self, worker_id: int):
pass

@abstractmethod
def process(self, chunk_req: ChunkRequest, **kwargs):
    """Process one chunk request; subclasses return True on success.

    The scraped diff kept both the old ``**args`` signature line and the
    new ``**kwargs`` one; only the post-change signature is kept here.
    """
    pass


class GatewayWaitReciever(GatewayOperator):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def process(self, chunk_req: ChunkRequest):
def process(self, chunk_req: ChunkRequest, **kwargs):
chunk_file_path = self.chunk_store.get_chunk_file_path(chunk_req.chunk.chunk_id)
if not os.path.exists(chunk_file_path): # chunk still not downloaded, re-queue
print("Chunk not downloaded yet, re-queueing", chunk_req.chunk.chunk_id, chunk_file_path)
Expand Down Expand Up @@ -167,7 +166,7 @@ def __init__(

# SSL context
if use_tls:
self.ssl_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
self.ssl_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) # type: ignore
self.ssl_context.check_hostname = False
self.ssl_context.verify_mode = ssl.CERT_NONE
logger.info(f"Using {str(ssl.OPENSSL_VERSION)}")
Expand Down Expand Up @@ -235,10 +234,9 @@ def make_socket(self, dst_host):
return sock

# send chunks to other instances
def process(self, chunk_req: ChunkRequest, dst_host: str):
def process(self, chunk_req: ChunkRequest, dst_host: str, **kwargs):
"""Send list of chunks to gateway server, pipelining small chunks together into a single socket stream."""
# notify server of upcoming ChunkRequests

logger.debug(f"[sender:{self.worker_id}] Sending chunk ID {chunk_req.chunk.chunk_id} to IP {dst_host}")

chunk_ids = [chunk_req.chunk.chunk_id]
Expand Down Expand Up @@ -330,8 +328,7 @@ def __init__(
super().__init__(handle, region, input_queue, output_queue, error_event, error_queue, chunk_store, n_processes)
self.size_mb = size_mb

def process(self, chunk_req: ChunkRequest):

def process(self, chunk_req: ChunkRequest, **kwargs):
# wait until enough space available
fpath = str(self.chunk_store.get_chunk_file_path(chunk_req.chunk.chunk_id).absolute())
size_bytes = int(self.size_mb * MB)
Expand All @@ -357,10 +354,8 @@ def process(self, chunk_req: ChunkRequest):
# os.system(f"fallocate -l {size_bytes} {fpath}")
# file_size = os.path.getsize(fpath)
# assert file_size == size_bytes, f"File {fpath} has size {file_size} but should be {size_bytes} - chunk store remaining size: {self.chunk_store.remaining_bytes()}"

print(f"Wrote chunk {chunk_req.chunk.chunk_id} with size {file_size} to {fpath}")
print(f"Wrote chunk {chunk_req.chunk.chunk_id} to {fpath}")
chunk_req.chunk.chunk_length_bytes = os.path.getsize(fpath)

return True


Expand All @@ -378,7 +373,7 @@ def __init__(
):
super().__init__(handle, region, input_queue, output_queue, error_event, error_queue, chunk_store, n_processes)

def process(self, chunk_req: ChunkRequest, **kwargs):
    """No-op: the chunk is already on local disk (written upstream).

    Returns True so the operator pipeline marks the chunk complete.
    The scraped diff kept both the old no-kwargs signature line and the
    new one; only the post-change signature is kept here.
    """
    return True

Expand All @@ -393,9 +388,9 @@ def __init__(
error_event,
error_queue: GatewayQueue,
n_processes: int = 1,
chunk_store: ChunkStore = None,
bucket_name: str = None,
bucket_region: str = None,
chunk_store: Optional[ChunkStore] = None,
bucket_name: Optional[str] = None,
bucket_region: Optional[str] = None,
):
super().__init__(handle, region, input_queue, output_queue, error_event, error_queue, chunk_store, n_processes)
self.bucket_name = bucket_name
Expand Down Expand Up @@ -428,9 +423,9 @@ def __init__(
error_event,
error_queue: GatewayQueue,
n_processes: int = 32,
chunk_store: ChunkStore = None,
bucket_name: str = None,
bucket_region: str = None,
chunk_store: Optional[ChunkStore] = None,
bucket_name: Optional[str] = None,
bucket_region: Optional[str] = None,
):
super().__init__(
handle, region, input_queue, output_queue, error_event, error_queue, n_processes, chunk_store, bucket_name, bucket_region
Expand Down Expand Up @@ -486,17 +481,17 @@ def __init__(
error_event,
error_queue: GatewayQueue,
n_processes: int = 32,
chunk_store: ChunkStore = None,
bucket_name: str = None,
bucket_region: str = None,
chunk_store: Optional[ChunkStore] = None,
bucket_name: Optional[str] = None,
bucket_region: Optional[str] = None,
):
super().__init__(
handle, region, input_queue, output_queue, error_event, error_queue, n_processes, chunk_store, bucket_name, bucket_region
)

def process(self, chunk_req: ChunkRequest):
def process(self, chunk_req: ChunkRequest, **kwargs):
fpath = str(self.chunk_store.get_chunk_file_path(chunk_req.chunk.chunk_id).absolute())
print("writing", chunk_req.chunk.dest_key, self.bucket_name, self.bucket_region, chunk_req.chunk.chunk_size)
print("writing", chunk_req.chunk.dest_key, self.bucket_name, self.bucket_region, chunk_req.chunk.chunk_length_bytes)
logger.debug(
f"[obj_store:{self.worker_id}] Start upload {chunk_req.chunk.chunk_id} to {self.bucket_name}, key {chunk_req.chunk.dest_key}"
)
Expand Down
Loading