Skip to content

Commit

Permalink
develop binary download
Browse files Browse the repository at this point in the history
  • Loading branch information
yanchengnv committed Sep 12, 2023
1 parent 797224c commit 9771d36
Show file tree
Hide file tree
Showing 13 changed files with 545 additions and 18 deletions.
1 change: 1 addition & 0 deletions nvflare/apis/fl_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ class AdminCommandNames(object):
LIST_JOBS = "list_jobs"
GET_JOB_META = "get_job_meta"
DOWNLOAD_JOB = "download_job"
DOWNLOAD_JOB_FILE = "download_job_file"
ABORT_JOB = "abort_job"
DELETE_JOB = "delete_job"
CLONE_JOB = "clone_job"
Expand Down
27 changes: 27 additions & 0 deletions nvflare/fuel/hci/checksum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import zlib


class Checksum:
    """Incremental CRC-32 accumulator built on ``zlib.crc32``."""

    def __init__(self):
        # Running CRC value; 0 is the starting value handed to zlib.crc32.
        self.current_value = 0

    def update(self, data):
        """Fold ``data`` into the running CRC.

        Args:
            data: bytes-like object to add to the checksum.
        """
        running = self.current_value
        self.current_value = zlib.crc32(data, running)

    def result(self):
        """Return the checksum accumulated so far, masked to 32 bits."""
        mask = 0xFFFFFFFF
        return self.current_value & mask
216 changes: 216 additions & 0 deletions nvflare/fuel/hci/chunk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import struct

from .checksum import Checksum

# Wire layout of a chunk header (big-endian):
#   marker (1 byte), first number (4 bytes), second number (4 bytes).
# For a data chunk the numbers are (seq_num, size); see the marker values below.
HEADER_STRUCT = struct.Struct(">BII")
HEADER_LEN = HEADER_STRUCT.size

# Marker values that identify the kind of header.
MARKER_DATA, MARKER_END = 101, 102

# Upper bound on the payload size announced by a single data chunk (1 MiB).
MAX_CHUNK_SIZE = 1 << 20


def get_slice(buf, start: int, length: int):
    """Return the ``length`` items of ``buf`` beginning at index ``start``.

    Standard slicing rules apply, so a range that runs past the end of
    ``buf`` is truncated rather than raising.
    """
    window = slice(start, start + length)
    return buf[window]


class Header:
    """Fixed-size header that precedes each chunk on the wire.

    A header carries a marker plus two unsigned 32-bit numbers whose meaning
    depends on the marker:

    * ``MARKER_DATA``: ``num1`` is the sequence number, ``num2`` the payload size.
    * ``MARKER_END``: ``num1`` must be 0, ``num2`` is the checksum of the stream.
    """

    def __init__(self, marker, num1, num2):
        """Build a header from its raw fields.

        Args:
            marker: ``MARKER_DATA`` or ``MARKER_END``.
            num1: sequence number for a data header; must be 0 for an end header.
            num2: payload size for a data header; checksum for an end header.

        Raises:
            ValueError: if the marker is unknown, or ``num1`` is non-zero on an
                end header.
        """
        self.marker = marker
        self.checksum = 0
        self.seq = 0
        self.size = 0
        if marker == MARKER_DATA:
            self.seq = num1
            self.size = num2
        elif marker == MARKER_END:
            if num1 != 0:
                raise ValueError(f"num1 must be 0 for checksum but got {num1}")
            self.checksum = num2
        else:
            raise ValueError(f"invalid chunk marker {marker}")

    def __str__(self):
        d = {
            "marker": self.marker,
            "seq": self.seq,
            "size": self.size,
            "checksum": self.checksum,
        }
        return f"{d}"

    @classmethod
    def from_bytes(cls, buffer):
        """Parse a header from the first ``HEADER_LEN`` bytes of ``buffer``.

        Args:
            buffer: bytes-like object holding at least ``HEADER_LEN`` bytes.

        Returns:
            The decoded header (an instance of ``cls``).

        Raises:
            ValueError: if ``buffer`` is too short or the decoded fields are invalid.
        """
        if len(buffer) < HEADER_LEN:
            raise ValueError("Prefix too short")

        marker, num1, num2 = HEADER_STRUCT.unpack_from(buffer, 0)
        # Use cls (not the hard-coded class name) so subclasses round-trip.
        return cls(marker, num1, num2)

    def to_bytes(self):
        """Serialize this header into its ``HEADER_LEN``-byte wire form."""
        if self.marker == MARKER_DATA:
            num1 = self.seq
            num2 = self.size
        else:
            # End header: first number is always 0, second carries the checksum.
            num1 = 0
            num2 = self.checksum

        return HEADER_STRUCT.pack(self.marker, num1, num2)


class ChunkState:
    """Parsing progress for the chunk currently being received.

    Tracks the (possibly partial) header bytes collected so far, the decoded
    header once complete, how many payload bytes have been consumed, and the
    sequence number this chunk is required to carry.
    """

    def __init__(self, expect_seq=1):
        """Start a fresh state.

        Args:
            expect_seq: sequence number the next data header must have.
        """
        self.header_bytes = bytearray()  # raw header bytes gathered so far
        self.header = None  # decoded Header, set by unpack_header()
        self.received = 0  # payload bytes consumed for this chunk
        self.expect_seq = expect_seq

    def __str__(self):
        d = {
            "header": f"{self.header}",
            "header_bytes": f"{self.header_bytes}",
            "received": self.received,
            "expect_seq": self.expect_seq
        }
        return f"{d}"

    def unpack_header(self):
        """Decode ``header_bytes`` into ``header`` and validate a data header.

        Raises:
            ValueError: if the header bytes are too short or malformed.
            RuntimeError: if a data header has an unexpected sequence number
                or announces a size outside ``[0, MAX_CHUNK_SIZE]``.
        """
        self.header = Header.from_bytes(self.header_bytes)
        if self.header.marker == MARKER_DATA:
            if self.header.seq != self.expect_seq:
                raise RuntimeError(
                    f"Protocol Error: received seq {self.header.seq} does not match expected seq {self.expect_seq}")

            if self.header.size < 0 or self.header.size > MAX_CHUNK_SIZE:
                raise RuntimeError(
                    f"Protocol Error: received size {self.header.size} is not in [0, {MAX_CHUNK_SIZE}]")

    def is_last(self) -> bool:
        """Return True only when a header has been decoded and it is the end marker."""
        # Explicit None test so the result is always a real bool, never None.
        return self.header is not None and self.header.marker == MARKER_END


class Receiver:
    """Reassembles a chunked byte stream and verifies its trailing checksum."""

    def __init__(self, receive_data_func):
        """Create a receiver.

        Args:
            receive_data_func: optional callable invoked as
                ``receive_data_func(data, start, length)`` for every span of
                payload bytes found in a buffer passed to ``receive``.
                May be None, in which case payload is only checksummed.
        """
        self.receive_data_func = receive_data_func
        self.checksum = Checksum()
        self.current_state = ChunkState()
        self.done = False

    def receive(self, data) -> bool:
        """Consume the next buffer of raw stream bytes.

        Args:
            data: ``bytes`` or ``bytearray`` read from the stream.

        Returns:
            True once the end-of-stream header has been received and its
            checksum matches; False if more data is expected.

        Raises:
            RuntimeError: if called after the stream is done, or if the
                checksum in the end header does not match the payload received.
        """
        if self.done:
            raise RuntimeError("this receiver is already done")
        s = chunk_it(self.current_state, data, 0, self._process_chunk)
        self.current_state = s
        # Coerce so callers always get a real bool, matching the annotation.
        done = bool(s.is_last())
        if done:
            self.done = True
            # compare checksum
            expected_checksum = self.checksum.result()
            if expected_checksum != s.header.checksum:
                raise RuntimeError(f"checksum mismatch: expect {expected_checksum} but received {s.header.checksum}")
        return done

    def _process_chunk(self, c: ChunkState, data, start: int, length: int):
        """Checksum one payload span and forward it to the data callback."""
        self.checksum.update(get_slice(data, start, length))
        if self.receive_data_func:
            self.receive_data_func(data, start, length)


class Sender:
    """Writes a byte stream as sequenced chunks followed by a checksum header."""

    def __init__(self, send_data_func):
        """Create a sender.

        Args:
            send_data_func: callable invoked with each piece of bytes to put
                on the wire (header bytes and payload bytes are passed in
                separate calls).
        """
        self.send_data_func = send_data_func
        self.checksum = Checksum()
        self.next_seq = 1
        self.closed = False

    def send(self, data):
        """Send ``data`` as one or more data chunks.

        Payloads no larger than ``MAX_CHUNK_SIZE`` go out as a single chunk.
        Larger payloads are split into ``MAX_CHUNK_SIZE`` pieces, because a
        chunk that announces a bigger size is rejected by the receiving side
        as a protocol error.

        Args:
            data: bytes-like payload.

        Raises:
            RuntimeError: if the sender has already been closed.
        """
        if self.closed:
            raise RuntimeError("this sender is already closed")

        total = len(data)
        if total <= MAX_CHUNK_SIZE:
            # Common case (including empty data): one chunk, payload passed through untouched.
            self._send_chunk(data)
            return

        for offset in range(0, total, MAX_CHUNK_SIZE):
            self._send_chunk(data[offset:offset + MAX_CHUNK_SIZE])

    def _send_chunk(self, data):
        """Emit one data header followed by its payload and update seq/checksum."""
        header = Header(MARKER_DATA, self.next_seq, len(data))
        self.next_seq += 1
        self.checksum.update(data)
        header_bytes = header.to_bytes()
        self.send_data_func(header_bytes)
        self.send_data_func(data)

    def close(self):
        """Finish the stream by sending the end header carrying the checksum.

        Raises:
            RuntimeError: if the sender has already been closed.
        """
        if self.closed:
            raise RuntimeError("this sender is already closed")
        self.closed = True
        cs = self.checksum.result()
        header = Header(MARKER_END, 0, cs)
        header_bytes = header.to_bytes()
        self.send_data_func(header_bytes)


def chunk_it(c: ChunkState, data, cursor: int, process_chunk_func) -> ChunkState:
    """Feed a buffer of stream bytes through the chunk parser.

    Starting at ``cursor``, bytes of ``data`` are consumed to complete the
    current chunk's header and payload; whenever a chunk finishes, parsing
    continues with a fresh ``ChunkState`` for the next sequence number.

    The buffer is walked with a loop rather than recursion, so a buffer that
    contains a very large number of small chunks cannot exhaust the
    interpreter's recursion limit.

    Args:
        c: state of the chunk in progress when this buffer arrived.
        data: ``bytes`` or ``bytearray`` holding newly received stream bytes.
        cursor: index in ``data`` at which to start consuming.
        process_chunk_func: callable invoked as
            ``process_chunk_func(state, data, start, length)`` for each span
            of payload bytes found in ``data``.

    Returns:
        The state to pass in with the next buffer. If the end header has been
        parsed, the returned state's ``is_last()`` is true.

    Raises:
        ValueError: if ``data`` is not bytes/bytearray or ``cursor`` is out of range.
        RuntimeError: propagated from header validation on a protocol error.
    """
    if not isinstance(data, (bytearray, bytes)):
        raise ValueError(f"can only chunk bytes data but got {type(data)}")

    total_len = len(data)
    if total_len <= 0:
        return c

    if cursor < 0 or cursor >= total_len:
        raise ValueError(f"cursor {cursor} is out of data range [0, {total_len-1}]")

    while True:
        # Invariant: cursor < total_len here, so at least one byte remains.
        remaining = total_len - cursor

        num_bytes_needed = HEADER_LEN - len(c.header_bytes)
        if num_bytes_needed > 0:
            # header not completed yet
            if remaining < num_bytes_needed:
                # not enough bytes to finish the header: stash what we have
                c.header_bytes.extend(get_slice(data, cursor, remaining))
                return c

            c.header_bytes.extend(get_slice(data, cursor, num_bytes_needed))
            cursor += num_bytes_needed
            remaining -= num_bytes_needed
            c.unpack_header()  # header bytes are ready

        if remaining == 0 or c.is_last():
            return c

        lack = c.header.size - c.received
        if remaining <= lack:
            # everything left in this buffer belongs to the current chunk
            c.received += remaining
            process_chunk_func(c, data, cursor, remaining)
            if c.received == c.header.size:
                # this chunk is completed: start a new chunk
                return ChunkState(c.header.seq + 1)
            # this chunk is not done
            return c

        # The buffer holds the rest of this chunk plus the start of the next one.
        c.received += lack
        process_chunk_func(c, data, cursor, lack)
        cursor += lack
        c = ChunkState(c.header.seq + 1)
24 changes: 20 additions & 4 deletions nvflare/fuel/hci/client/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from nvflare.fuel.hci.client.event import EventContext, EventHandler, EventPropKey, EventType
from nvflare.fuel.hci.cmd_arg_utils import split_to_args
from nvflare.fuel.hci.conn import Connection, receive_and_process
from nvflare.fuel.hci.conn import Connection, receive_and_process, receive_bytes_and_process
from nvflare.fuel.hci.proto import ConfirmMethod, InternalCommands, MetaKey, ProtoKey, make_error
from nvflare.fuel.hci.reg import CommandEntry, CommandModule, CommandRegister
from nvflare.fuel.hci.table import Table
Expand Down Expand Up @@ -710,11 +710,24 @@ def _send_to_sock(self, sock, ctx: CommandContext):
conn.update_meta({MetaKey.CUSTOM_PROPS: custom_props})

conn.close()
ok = receive_and_process(sock, process_json_func)
receive_bytes_func = ctx.get_bytes_receiver()
if receive_bytes_func is not None:
print("receive_bytes_and_process ...")
ok = receive_bytes_and_process(sock, receive_bytes_func)
if ok:
ctx.set_command_result({"status": APIStatus.SUCCESS, "details": "OK"})
else:
ctx.set_command_result({"status": APIStatus.ERROR_RUNTIME, "details": "error receive_bytes"})
else:
print("receive_and_process ...")
ok = receive_and_process(sock, process_json_func)

if not ok:
process_json_func(
make_error("Failed to communicate with Admin Server {} on {}".format(self.host, self.port))
)
else:
print("reply received!")

def _try_command(self, cmd_ctx: CommandContext):
"""Try to execute a command on server side.
Expand Down Expand Up @@ -895,12 +908,15 @@ def do_command(self, command):

return self.server_execute(command, cmd_entry=ent)

def server_execute(self, command, reply_processor=None, cmd_entry=None):
def server_execute(self, command, reply_processor=None, cmd_entry=None, cmd_ctx=None):
if self.in_logout:
return {ResultKey.STATUS: APIStatus.SUCCESS, ResultKey.DETAILS: "session is logging out"}

args = split_to_args(command)
ctx = self._new_command_context(command, args, cmd_entry)
if cmd_ctx:
ctx = cmd_ctx
else:
ctx = self._new_command_context(command, args, cmd_entry)
start = time.time()
ctx.set_reply_processor(reply_processor)
self._try_command(ctx)
Expand Down
14 changes: 13 additions & 1 deletion nvflare/fuel/hci/client/api_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,17 @@ class CommandCtxKey(object):
JSON_PROCESSOR = "json_processor"
META = "meta"
CUSTOM_PROPS = "custom_props"
BYTES_RECEIVER = "bytes_receiver"


class CommandContext(SimpleContext):

def set_bytes_receiver(self, r):
self.set_prop(CommandCtxKey.BYTES_RECEIVER, r)

def get_bytes_receiver(self):
return self.get_prop(CommandCtxKey.BYTES_RECEIVER)

def set_command_result(self, result):
self.set_prop(CommandCtxKey.RESULT, result)

Expand Down Expand Up @@ -145,6 +153,9 @@ def protocol_error(self, ctx: CommandContext, err: str):
def reply_done(self, ctx: CommandContext):
pass

def process_bytes(self, ctx: CommandContext):
pass


class AdminAPISpec(ABC):
@abstractmethod
Expand All @@ -163,12 +174,13 @@ def do_command(self, command: str):
pass

@abstractmethod
def server_execute(self, command: str, reply_processor=None):
def server_execute(self, command: str, reply_processor=None, cmd_ctx=None):
"""Executes a command on server side.
Args:
command: The command to be executed.
reply_processor: processor to process reply from server
cmd_ctx: command context
"""
pass

Expand Down
Loading

0 comments on commit 9771d36

Please sign in to comment.