Skip to content

Commit

Permalink
support binary download to admin client
Browse files Browse the repository at this point in the history
  • Loading branch information
yanchengnv committed Sep 13, 2023
1 parent 434f116 commit b7d9256
Show file tree
Hide file tree
Showing 15 changed files with 130 additions and 153 deletions.
4 changes: 2 additions & 2 deletions nvflare/apis/impl/job_def_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import tempfile
import time
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union

from nvflare.apis.client_engine_spec import ClientEngineSpec
from nvflare.apis.fl_context import FLContext
Expand Down Expand Up @@ -282,7 +282,7 @@ def set_approval(
store.update_meta(self.job_uri(jid), updated_meta, replace=False)
return meta

def save_workspace(self, jid: str, data: Union[bytes, str], fl_ctx: FLContext):
    """Persist the workspace of job *jid* to the job store.

    Args:
        jid: Job ID whose workspace is being saved.
        data: workspace content as bytes, or the path of a file holding it.
        fl_ctx: FL context used to locate the job store.
    """
    job_store = self._get_job_store(fl_ctx)
    job_store.update_object(self.job_uri(jid), data, WORKSPACE)

Expand Down
6 changes: 3 additions & 3 deletions nvflare/apis/job_def_manager_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union

from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_context import FLContext
Expand Down Expand Up @@ -213,12 +213,12 @@ def delete(self, jid: str, fl_ctx: FLContext):
pass

@abstractmethod
def save_workspace(self, jid: str, data: bytes, fl_ctx: FLContext):
def save_workspace(self, jid: str, data: Union[bytes, str], fl_ctx: FLContext):
"""Save the job workspace to the job storage.
Args:
jid (str): Job ID
data: Job workspace data
data: Job workspace data or name of data file
fl_ctx (FLContext): FLContext information
"""
Expand Down
2 changes: 1 addition & 1 deletion nvflare/apis/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class StorageSpec(ABC):
"""

@abstractmethod
def create_object(self, uri: str, data: bytes, meta: dict, overwrite_existing: bool):
def create_object(self, uri: str, data: Union[bytes, str], meta: dict, overwrite_existing: bool):
"""Creates an object.
Examples of URI:
Expand Down
31 changes: 19 additions & 12 deletions nvflare/app_common/storages/filesystem_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import shutil
import uuid
from pathlib import Path
from typing import List, Tuple
from typing import List, Tuple, Union

from nvflare.apis.storage import DATA, MANIFEST, META, StorageException, StorageSpec
from nvflare.apis.utils.format_check import validate_class_methods_args
Expand Down Expand Up @@ -79,7 +79,20 @@ def __init__(self, root_dir=os.path.abspath(os.sep), uri_root="/"):
self.root_dir = root_dir
self.uri_root = uri_root

def create_object(self, uri: str, data: bytes, meta: dict, overwrite_existing: bool = False):
def _save_data(self, data: Union[str, bytes], destination: str):
if isinstance(data, bytes):
_write(destination, data)
elif isinstance(data, str):
# path to file that contains data
if not os.path.exists(data):
raise FileNotFoundError(f"file {data} does not exist")
if not os.path.isfile(data):
raise ValueError(f"{data} is not a valid file")
shutil.copyfile(data, destination)
else:
raise ValueError(f"expect data to be bytes or file name but got {type(data)}")

def create_object(self, uri: str, data: Union[bytes, str], meta: dict, overwrite_existing: bool = False):
"""Creates an object.
Args:
Expand Down Expand Up @@ -107,15 +120,12 @@ def create_object(self, uri: str, data: bytes, meta: dict, overwrite_existing: b

data_path = os.path.join(full_uri, DATA)
meta_path = os.path.join(full_uri, META)

tmp_data_path = data_path + "_" + str(uuid.uuid4())
_write(tmp_data_path, data)
self._save_data(data, data_path)
try:
_write(meta_path, json.dumps(str(meta)).encode("utf-8"))
except Exception as e:
os.remove(tmp_data_path)
os.remove(data_path)
raise e
os.rename(tmp_data_path, data_path)

manifest = os.path.join(full_uri, MANIFEST)
manifest_json = '{"data": {"description": "job definition","format": "bytes"},\
Expand All @@ -124,7 +134,7 @@ def create_object(self, uri: str, data: bytes, meta: dict, overwrite_existing: b

return full_uri

def update_object(self, uri: str, data: bytes, component_name: str = DATA):
def update_object(self, uri: str, data: Union[bytes, str], component_name: str = DATA):
"""Update the object
Args:
Expand All @@ -142,11 +152,8 @@ def update_object(self, uri: str, data: bytes, component_name: str = DATA):
if not StorageSpec.is_valid_component(component_name):
raise StorageException(f"{component_name } is not a valid component for storage object.")

if not isinstance(data, bytes):
raise StorageException(f"data must be in the type of bytes, got {type(data)}.")

component_path = os.path.join(full_dir_path, component_name)
_write(component_path, data)
self._save_data(data, component_path)

manifest = os.path.join(full_dir_path, MANIFEST)
with open(manifest) as manifest_file:
Expand Down
1 change: 0 additions & 1 deletion nvflare/fuel/hci/checksum.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@


class Checksum:

def __init__(self):
    # accumulated checksum state; starts at 0 before any data is folded in
    self.current_value = 0

Expand Down
21 changes: 6 additions & 15 deletions nvflare/fuel/hci/chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,10 @@


def get_slice(buf, start: int, length: int):
    """Return the sub-sequence of *buf* covering [start, start + length)."""
    end = start + length
    return buf[start:end]


class Header:

def __init__(self, marker, num1, num2):
self.marker = marker
self.checksum = 0
Expand Down Expand Up @@ -63,7 +62,6 @@ def from_bytes(cls, buffer: bytes):
return Header(marker, num1, num2)

def to_bytes(self):
print(f"header: {self}")
if self.marker == MARKER_DATA:
num1 = self.seq
num2 = self.size
Expand All @@ -75,7 +73,6 @@ def to_bytes(self):


class ChunkState:

def __init__(self, expect_seq=1):
self.header_bytes = bytearray()
self.header = None
Expand All @@ -87,7 +84,7 @@ def __str__(self):
"header": f"{self.header}",
"header_bytes": f"{self.header_bytes}",
"received": self.received,
"expect_seq": self.expect_seq
"expect_seq": self.expect_seq,
}
return f"{d}"

Expand All @@ -96,18 +93,17 @@ def unpack_header(self):
if self.header.marker == MARKER_DATA:
if self.header.seq != self.expect_seq:
raise RuntimeError(
f"Protocol Error: received seq {self.header.seq} does not match expected seq {self.expect_seq}")
f"Protocol Error: received seq {self.header.seq} does not match expected seq {self.expect_seq}"
)

if self.header.size < 0 or self.header.size > MAX_CHUNK_SIZE:
raise RuntimeError(
f"Protocol Error: received size {self.header.size} is not in [0, {MAX_CHUNK_SIZE}]")
raise RuntimeError(f"Protocol Error: received size {self.header.size} is not in [0, {MAX_CHUNK_SIZE}]")

def is_last(self):
    """Whether the final (END-marker) header of the stream has been seen."""
    h = self.header
    # short-circuit keeps the original falsy return value when no header yet
    return h and h.marker == MARKER_END


class Receiver:

def __init__(self, receive_data_func):
self.receive_data_func = receive_data_func
self.checksum = Checksum()
Expand All @@ -126,8 +122,6 @@ def receive(self, data) -> bool:
expected_checksum = self.checksum.result()
if expected_checksum != s.header.checksum:
raise RuntimeError(f"checksum mismatch: expect {expected_checksum} but received {s.header.checksum}")
else:
print("checksum matched!")
return done

def _process_chunk(self, c: ChunkState, data, start: int, length: int):
Expand All @@ -137,7 +131,6 @@ def _process_chunk(self, c: ChunkState, data, start: int, length: int):


class Sender:

def __init__(self, send_data_func):
self.send_data_func = send_data_func
self.checksum = Checksum()
Expand All @@ -159,10 +152,8 @@ def close(self):
raise RuntimeError("this sender is already closed")
self.closed = True
cs = self.checksum.result()
print(f"sending checksum: {cs}")
header = Header(MARKER_END, 0, cs)
header_bytes = header.to_bytes()
print(f"checksum headerbytes: {header_bytes}")
self.send_data_func(header_bytes)


Expand All @@ -188,7 +179,7 @@ def chunk_it(c: ChunkState, data, cursor: int, process_chunk_func) -> ChunkState
c.header_bytes.extend(get_slice(data, cursor, num_bytes_needed))
cursor += num_bytes_needed
data_len -= num_bytes_needed
c.unpack_header() # header bytes are ready
c.unpack_header() # header bytes are ready
else:
c.header_bytes.extend(get_slice(data, cursor, data_len))
return c
Expand Down
46 changes: 20 additions & 26 deletions nvflare/fuel/hci/client/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,6 @@ class _ServerReplyJsonProcessor(object):
def __init__(self, ctx: CommandContext):
if not isinstance(ctx, CommandContext):
raise TypeError(f"ctx is not an instance of CommandContext. but get {type(ctx)}")
api = ctx.get_api()
self.debug = api.debug
self.ctx = ctx

def process_server_reply(self, resp):
Expand All @@ -72,8 +70,8 @@ def process_server_reply(self, resp):
Args:
resp: The raw response that returns by the server.
"""
if self.debug:
print("DEBUG: Server Reply: {}".format(resp))
api = self.ctx.get_api()
api.debug("Server Reply: {}".format(resp))

ctx = self.ctx

Expand All @@ -87,11 +85,10 @@ def process_server_reply(self, resp):

if resp is not None:
data = resp[ProtoKey.DATA]
meta = resp[ProtoKey.META]
for item in data:
it = item[ProtoKey.TYPE]
if it == ProtoKey.STRING:
reply_processor.process_string(ctx, item[ProtoKey.DATA], meta)
reply_processor.process_string(ctx, item[ProtoKey.DATA])
elif it == ProtoKey.SUCCESS:
reply_processor.process_success(ctx, item[ProtoKey.DATA])
elif it == ProtoKey.ERROR:
Expand Down Expand Up @@ -130,7 +127,7 @@ def process_shutdown(self, ctx: CommandContext, msg: str):
class _LoginReplyProcessor(ReplyProcessor):
"""Reply processor for handling login and setting the token for the admin client."""

def process_string(self, ctx: CommandContext, item: str):
    """Record the login reply string from the server on the API object."""
    ctx.get_api().login_result = item

Expand Down Expand Up @@ -391,7 +388,7 @@ def __init__(
self.service_finder.set_secure_context(
ca_cert_path=self.ca_cert, cert_path=self.client_cert, private_key_path=self.client_key
)
self.debug = debug
self._debug = debug
self.cmd_timeout = None

# for login
Expand Down Expand Up @@ -436,9 +433,12 @@ def __init__(
self.service_finder.start(self._handle_sp_address_change)
self._start_session_monitor()

def debug(self, msg):
    """Print *msg* with a DEBUG prefix, but only when debug mode is on."""
    if not self._debug:
        return
    print(f"DEBUG: {msg}")

def fire_event(self, event_type: str, ctx: EventContext):
    """Dispatch *event_type* with *ctx* to every registered event handler."""
    self.debug(f"firing event {event_type}")
    for handler in self.event_handlers or []:
        handler.handle_event(event_type, ctx)
Expand Down Expand Up @@ -507,8 +507,7 @@ def _try_auto_login(self):
def auto_login(self):
try:
result = self._try_auto_login()
if self.debug:
print(f"DEBUG: login result is {result}")
self.debug(f"login result is {result}")
except Exception as e:
result = {
ResultKey.STATUS: APIStatus.ERROR_RUNTIME,
Expand Down Expand Up @@ -539,8 +538,7 @@ def _close_session_monitor(self):
self.sess_monitor_active = False
if self.sess_monitor_thread:
self.sess_monitor_thread = None
if self.debug:
print("DEBUG: session monitor closed!")
self.debug("session monitor closed!")

def check_session_status_on_server(self):
return self.server_execute("_check_session")
Expand Down Expand Up @@ -584,8 +582,7 @@ def _monitor_session(self, interval):
try:
self.fire_session_event(EventType.SESSION_CLOSED, msg)
except Exception as ex:
if self.debug:
print(f"exception occurred handling event {EventType.SESSION_CLOSED}: {secure_format_exception(ex)}")
self.debug(f"exception occurred handling event {EventType.SESSION_CLOSED}: {secure_format_exception(ex)}")
pass

# this is in the session_monitor thread - do not close the monitor, or we'll run into
Expand Down Expand Up @@ -713,22 +710,22 @@ def _send_to_sock(self, sock, ctx: CommandContext):
conn.close()
receive_bytes_func = ctx.get_bytes_receiver()
if receive_bytes_func is not None:
print("receive_bytes_and_process ...")
self.debug("receive_bytes_and_process ...")
ok = receive_bytes_and_process(sock, receive_bytes_func)
if ok:
ctx.set_command_result({"status": APIStatus.SUCCESS, "details": "OK"})
else:
ctx.set_command_result({"status": APIStatus.ERROR_RUNTIME, "details": "error receive_bytes"})
else:
print("receive_and_process ...")
self.debug("receive_and_process ...")
ok = receive_and_process(sock, process_json_func)

if not ok:
process_json_func(
make_error("Failed to communicate with Admin Server {} on {}".format(self.host, self.port))
)
else:
print("reply received!")
self.debug("reply received!")

def _try_command(self, cmd_ctx: CommandContext):
"""Try to execute a command on server side.
Expand All @@ -737,8 +734,7 @@ def _try_command(self, cmd_ctx: CommandContext):
cmd_ctx: The command to execute.
"""
# process_json_func can't return data because how "receive_and_process" is written.
if self.debug:
print(f"DEBUG: sending command '{cmd_ctx.get_command()}'")
self.debug(f"sending command '{cmd_ctx.get_command()}'")

json_processor = _ServerReplyJsonProcessor(cmd_ctx)
process_json_func = json_processor.process_server_reply
Expand Down Expand Up @@ -766,8 +762,7 @@ def _try_command(self, cmd_ctx: CommandContext):
sp_host = self.host
sp_port = self.port

if self.debug:
print(f"DEBUG: use server address {sp_host}:{sp_port}")
self.debug(f"use server address {sp_host}:{sp_port}")

try:
if not self.insecure:
Expand All @@ -790,7 +785,7 @@ def _try_command(self, cmd_ctx: CommandContext):
sock.connect((sp_host, sp_port))
self._send_to_sock(sock, cmd_ctx)
except Exception as e:
if self.debug:
if self._debug:
secure_log_traceback()

process_json_func(
Expand Down Expand Up @@ -924,8 +919,7 @@ def server_execute(self, command, reply_processor=None, cmd_entry=None, cmd_ctx=
secs = time.time() - start
usecs = int(secs * 1000000)

if self.debug:
print(f"DEBUG: server_execute Done [{usecs} usecs] {datetime.now()}")
self.debug(f"server_execute Done [{usecs} usecs] {datetime.now()}")

result = ctx.get_command_result()
meta = ctx.get_meta()
Expand Down
Loading

0 comments on commit b7d9256

Please sign in to comment.