From 9771d3661dc73e5e2d1d859eed70373503dc6586 Mon Sep 17 00:00:00 2001 From: Yan Cheng Date: Tue, 12 Sep 2023 11:40:24 -0400 Subject: [PATCH] develop binary download --- nvflare/apis/fl_constant.py | 1 + nvflare/fuel/hci/checksum.py | 27 +++ nvflare/fuel/hci/chunk.py | 216 +++++++++++++++++++++ nvflare/fuel/hci/client/api.py | 24 ++- nvflare/fuel/hci/client/api_spec.py | 14 +- nvflare/fuel/hci/client/file_transfer.py | 85 +++++++- nvflare/fuel/hci/conn.py | 23 ++- nvflare/fuel/hci/file_transfer_defs.py | 5 +- nvflare/fuel/hci/proto.py | 3 + nvflare/fuel/hci/server/binary_transfer.py | 91 +++++++++ nvflare/fuel/hci/server/file_transfer.py | 7 - nvflare/private/fed/server/job_cmds.py | 64 +++++- nvflare/security/security.py | 3 + 13 files changed, 545 insertions(+), 18 deletions(-) create mode 100644 nvflare/fuel/hci/checksum.py create mode 100644 nvflare/fuel/hci/chunk.py create mode 100644 nvflare/fuel/hci/server/binary_transfer.py diff --git a/nvflare/apis/fl_constant.py b/nvflare/apis/fl_constant.py index 47b6ba8fbf..c14713eacc 100644 --- a/nvflare/apis/fl_constant.py +++ b/nvflare/apis/fl_constant.py @@ -172,6 +172,7 @@ class AdminCommandNames(object): LIST_JOBS = "list_jobs" GET_JOB_META = "get_job_meta" DOWNLOAD_JOB = "download_job" + DOWNLOAD_JOB_FILE = "download_job_file" ABORT_JOB = "abort_job" DELETE_JOB = "delete_job" CLONE_JOB = "clone_job" diff --git a/nvflare/fuel/hci/checksum.py b/nvflare/fuel/hci/checksum.py new file mode 100644 index 0000000000..53e3f28138 --- /dev/null +++ b/nvflare/fuel/hci/checksum.py @@ -0,0 +1,27 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import zlib + + +class Checksum: + + def __init__(self): + self.current_value = 0 + + def update(self, data): + self.current_value = zlib.crc32(data, self.current_value) + + def result(self): + return self.current_value & 0xFFFFFFFF diff --git a/nvflare/fuel/hci/chunk.py b/nvflare/fuel/hci/chunk.py new file mode 100644 index 0000000000..de6911f493 --- /dev/null +++ b/nvflare/fuel/hci/chunk.py @@ -0,0 +1,216 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import struct + +from .checksum import Checksum + +HEADER_STRUCT = struct.Struct(">BII") # marker(1), seq_num(4), size(4) +HEADER_LEN = HEADER_STRUCT.size + +MARKER_DATA = 101 +MARKER_END = 102 + +MAX_CHUNK_SIZE = 1024 * 1024 + + +def get_slice(buf, start: int, length: int): + return buf[start:start+length] + + +class Header: + + def __init__(self, marker, num1, num2): + self.marker = marker + self.checksum = 0 + self.seq = 0 + self.size = 0 + if marker == MARKER_DATA: + self.seq = num1 + self.size = num2 + elif marker == MARKER_END: + if num1 != 0: + raise ValueError(f"num1 must be 0 for checksum but got {num1}") + self.checksum = num2 + else: + raise ValueError(f"invalid chunk marker {marker}") + + def __str__(self): + d = { + "marker": self.marker, + "seq": self.seq, + "size": self.size, + "checksum": self.checksum, + } + return f"{d}" + + @classmethod + def from_bytes(cls, buffer: bytes): + if len(buffer) < HEADER_LEN: + raise ValueError(f"Prefix too short") + + marker, num1, num2 = HEADER_STRUCT.unpack_from(buffer, 0) + return Header(marker, num1, num2) + + def to_bytes(self): + print(f"header: {self}") + if self.marker == MARKER_DATA: + num1 = self.seq + num2 = self.size + else: + num1 = 0 + num2 = self.checksum + + return HEADER_STRUCT.pack(self.marker, num1, num2) + + +class ChunkState: + + def __init__(self, expect_seq=1): + self.header_bytes = bytearray() + self.header = None + self.received = 0 + self.expect_seq = expect_seq + + def __str__(self): + d = { + "header": f"{self.header}", + "header_bytes": f"{self.header_bytes}", + "received": self.received, + "expect_seq": self.expect_seq + } + return f"{d}" + + def unpack_header(self): + self.header = Header.from_bytes(self.header_bytes) + if self.header.marker == MARKER_DATA: + if self.header.seq != self.expect_seq: + raise RuntimeError( + f"Protocol Error: received seq {self.header.seq} does not match expected seq {self.expect_seq}") + + if self.header.size < 0 or self.header.size > MAX_CHUNK_SIZE: + raise RuntimeError( + f"Protocol Error: received size {self.header.size} is not in [0, {MAX_CHUNK_SIZE}]") + + def is_last(self): + return self.header and self.header.marker == MARKER_END + + +class Receiver: + + def __init__(self, receive_data_func): + self.receive_data_func = receive_data_func + self.checksum = Checksum() + self.current_state = ChunkState() + self.done = False + + def receive(self, data) -> bool: + if self.done: + raise RuntimeError("this receiver is already done") + s = chunk_it(self.current_state, data, 0, self._process_chunk) + self.current_state = s + done = s.is_last() + if done: + self.done = True + # compare checksum + expected_checksum = self.checksum.result() + if expected_checksum != s.header.checksum: + raise RuntimeError(f"checksum mismatch: expect {expected_checksum} but received {s.header.checksum}") + else: + print("checksum matched!") + return done + + def _process_chunk(self, c: ChunkState, data, start: int, length: int): + self.checksum.update(get_slice(data, start, length)) + if self.receive_data_func: + self.receive_data_func(data, start, length) + + +class Sender: + + def __init__(self, send_data_func): + self.send_data_func = send_data_func + self.checksum = Checksum() + self.next_seq = 1 + self.closed = False + + def send(self, data): + if self.closed: + raise RuntimeError("this sender is already closed") + header = Header(MARKER_DATA, self.next_seq, len(data)) + self.next_seq += 1 + self.checksum.update(data) + header_bytes = header.to_bytes() + self.send_data_func(header_bytes) + self.send_data_func(data) + + def close(self): + if self.closed: + raise RuntimeError("this sender is already closed") + self.closed = True + cs = self.checksum.result() + print(f"sending checksum: {cs}") + header = Header(MARKER_END, 0, cs) + header_bytes = header.to_bytes() + print(f"checksum headerbytes: {header_bytes}") + self.send_data_func(header_bytes) + + +def chunk_it(c: ChunkState, data, cursor: int, process_chunk_func) -> ChunkState: + if not isinstance(data, (bytearray, bytes)): + raise ValueError(f"can only chunk bytes data but got {type(data)}") + + data_len = len(data) + if data_len <= 0: + return c + + if cursor < 0 or cursor >= data_len: + raise ValueError(f"cursor {cursor} is out of data range [0, {data_len-1}]") + data_len -= cursor + + header_bytes_len = len(c.header_bytes) + if header_bytes_len < HEADER_LEN: + # header not completed yet + num_bytes_needed = HEADER_LEN - header_bytes_len + # need this many bytes for header + if data_len >= num_bytes_needed: + # data has enough bytes + c.header_bytes.extend(get_slice(data, cursor, num_bytes_needed)) + cursor += num_bytes_needed + data_len -= num_bytes_needed + c.unpack_header() # header bytes are ready + else: + c.header_bytes.extend(get_slice(data, cursor, data_len)) + return c + + if data_len == 0 or c.is_last(): + return c + + lack = c.header.size - c.received + if data_len <= lack: + # remaining data is part of the chunk + c.received += data_len + process_chunk_func(c, data, cursor, data_len) + if c.received == c.header.size: + # this chunk is completed: start a new chunk + return ChunkState(c.header.seq + 1) + else: + # this chunk is not done + return c + else: + # some remaining data is part of the chunk, but after that belongs to next chunk + c.received += lack + process_chunk_func(c, data, cursor, lack) + cursor += lack + next_chunk = ChunkState(c.header.seq + 1) + return chunk_it(next_chunk, data, cursor, process_chunk_func) diff --git a/nvflare/fuel/hci/client/api.py b/nvflare/fuel/hci/client/api.py index 1f2415d6d3..9aafddc298 100644 --- a/nvflare/fuel/hci/client/api.py +++ b/nvflare/fuel/hci/client/api.py @@ -23,7 +23,7 @@ from nvflare.fuel.hci.client.event import EventContext, EventHandler, EventPropKey, EventType from nvflare.fuel.hci.cmd_arg_utils import split_to_args -from nvflare.fuel.hci.conn import Connection, receive_and_process +from nvflare.fuel.hci.conn import Connection, receive_and_process, receive_bytes_and_process from nvflare.fuel.hci.proto import ConfirmMethod, InternalCommands, MetaKey, ProtoKey, make_error from nvflare.fuel.hci.reg import CommandEntry, CommandModule, CommandRegister from nvflare.fuel.hci.table import Table @@ -710,11 +710,24 @@ def _send_to_sock(self, sock, ctx: CommandContext): conn.update_meta({MetaKey.CUSTOM_PROPS: custom_props}) conn.close() - ok = receive_and_process(sock, process_json_func) + receive_bytes_func = ctx.get_bytes_receiver() + if receive_bytes_func is not None: + print("receive_bytes_and_process ...") + ok = receive_bytes_and_process(sock, receive_bytes_func) + if ok: + ctx.set_command_result({"status": APIStatus.SUCCESS, "details": "OK"}) + else: + ctx.set_command_result({"status": APIStatus.ERROR_RUNTIME, "details": "error receive_bytes"}) + else: + print("receive_and_process ...") + ok = receive_and_process(sock, process_json_func) + if not ok: process_json_func( make_error("Failed to communicate with Admin Server {} on {}".format(self.host, self.port)) ) + else: + print("reply received!") def _try_command(self, cmd_ctx: CommandContext): """Try to execute a command on server side. @@ -895,12 +908,15 @@ def do_command(self, command): return self.server_execute(command, cmd_entry=ent) - def server_execute(self, command, reply_processor=None, cmd_entry=None): + def server_execute(self, command, reply_processor=None, cmd_entry=None, cmd_ctx=None): if self.in_logout: return {ResultKey.STATUS: APIStatus.SUCCESS, ResultKey.DETAILS: "session is logging out"} args = split_to_args(command) - ctx = self._new_command_context(command, args, cmd_entry) + if cmd_ctx: + ctx = cmd_ctx + else: + ctx = self._new_command_context(command, args, cmd_entry) start = time.time() ctx.set_reply_processor(reply_processor) self._try_command(ctx) diff --git a/nvflare/fuel/hci/client/api_spec.py b/nvflare/fuel/hci/client/api_spec.py index 415d377ef4..92728ca785 100644 --- a/nvflare/fuel/hci/client/api_spec.py +++ b/nvflare/fuel/hci/client/api_spec.py @@ -33,9 +33,17 @@ class CommandCtxKey(object): JSON_PROCESSOR = "json_processor" META = "meta" CUSTOM_PROPS = "custom_props" + BYTES_RECEIVER = "bytes_receiver" class CommandContext(SimpleContext): + + def set_bytes_receiver(self, r): + self.set_prop(CommandCtxKey.BYTES_RECEIVER, r) + + def get_bytes_receiver(self): + return self.get_prop(CommandCtxKey.BYTES_RECEIVER) + def set_command_result(self, result): self.set_prop(CommandCtxKey.RESULT, result) @@ -145,6 +153,9 @@ def protocol_error(self, ctx: CommandContext, err: str): def reply_done(self, ctx: CommandContext): pass + def process_bytes(self, ctx: CommandContext): + pass + class AdminAPISpec(ABC): @abstractmethod @@ -163,12 +174,13 @@ def do_command(self, command: str): pass @abstractmethod - def server_execute(self, command: str, reply_processor=None): + def server_execute(self, command: str, reply_processor=None, cmd_ctx=None): """Executes a command on server side. Args: command: The command to be executed. reply_processor: processor to process reply from server + cmd_ctx: command context """ pass diff --git a/nvflare/fuel/hci/client/file_transfer.py b/nvflare/fuel/hci/client/file_transfer.py index 9551669685..aaa1b20375 100644 --- a/nvflare/fuel/hci/client/file_transfer.py +++ b/nvflare/fuel/hci/client/file_transfer.py @@ -14,6 +14,8 @@ import os +from nvflare.fuel.hci.proto import MetaKey, ProtoKey + import nvflare.fuel.hci.file_transfer_defs as ftd from nvflare.fuel.hci.base64_utils import ( b64str_to_binary_file, @@ -157,6 +159,27 @@ def process_string(self, ctx: CommandContext, item: str): ) +class _FileReceiver: + + def __init__(self, file_path): + self.file_path = file_path + self.tmp_name = f"{file_path}.tmp" + dir_name = os.path.dirname(file_path) + if not os.path.exists(dir_name): + os.makedirs(dir_name) + if os.path.exists(file_path): + # remove existing file + os.remove(file_path) + self.tmp_file = open(self.tmp_name, "ab") + + def close(self): + self.tmp_file.close() + os.rename(self.tmp_name, self.file_path) + + def receive_data(self, data, start: int, length: int): + self.tmp_file.write(data[start:start+length]) + + class FileTransferModule(CommandModule): """Command module with commands relevant to file transfer.""" @@ -173,6 +196,8 @@ def __init__(self, upload_dir: str, download_dir: str): self.cmd_handlers = { ftd.UPLOAD_FOLDER_FQN: self.upload_folder, ftd.DOWNLOAD_FOLDER_FQN: self.download_folder, + ftd.PULL_BINARY_FQN: self.pull_binary_file, + ftd.PULL_FOLDER_FQN: self.pull_folder, } def get_spec(self): @@ -207,6 +232,13 @@ def get_spec(self): handler_func=self.download_binary_file, visible=False, ), + CommandSpec( + name="pull_binary", + description="download one binary files in the download_dir", + usage="pull_binary control_id file_name", + handler_func=self.pull_binary_file, + visible=False, + ), CommandSpec( name="upload_folder", description="Submit application to the server", @@ -246,8 +278,10 @@ def generate_module_spec(self, server_cmd_spec: CommandSpec): handler = self.cmd_handlers.get(server_cmd_spec.client_cmd) if handler is None: - # print('no cmd handler found for {}'.format(server_cmd_spec.client_cmd)) + print('no cmd handler found for {}'.format(server_cmd_spec.client_cmd)) return None + else: + print('cmd handler found for {}'.format(server_cmd_spec.client_cmd)) return CommandModuleSpec( name=server_cmd_spec.scope_name, @@ -309,6 +343,55 @@ def download_text_file(self, args, ctx: CommandContext): def download_binary_file(self, args, ctx: CommandContext): return self.download_file(args, ctx, ftd.SERVER_CMD_DOWNLOAD_BINARY, b64str_to_binary_file) + def pull_binary_file(self, args, ctx: CommandContext): + cmd_entry = ctx.get_command_entry() + if len(args) != 3: + return {ProtoKey.STATUS: APIStatus.ERROR_SYNTAX, + ProtoKey.DETAILS: "usage: {}".format(cmd_entry.usage)} + file_name = args[2] + control_id = args[1] + parts = [cmd_entry.full_command_name(), control_id, file_name] + command = join_args(parts) + file_path = os.path.join(self.download_dir, file_name) + receiver = _FileReceiver(file_path) + api = ctx.get_api() + ctx.set_bytes_receiver(receiver.receive_data) + result = api.server_execute(command, cmd_ctx=ctx) + if result.get(ProtoKey.STATUS) == APIStatus.SUCCESS: + receiver.close() + return result + + def pull_folder(self, args, ctx: CommandContext): + cmd_entry = ctx.get_command_entry() + if len(args) != 2: + return {ProtoKey.STATUS: APIStatus.ERROR_SYNTAX, + ProtoKey.DETAILS: "usage: {}".format(cmd_entry.usage)} + folder_name = args[1] + parts = [cmd_entry.full_command_name(), folder_name] + command = join_args(parts) + api = ctx.get_api() + result = api.server_execute(command) + meta = result.get(ProtoKey.META) + if not meta: + return result + + file_names = meta.get(MetaKey.FILES) + ctl_id = meta.get(MetaKey.CONTROL_ID) + print(f"received ctl_id {ctl_id}, file names: {file_names}") + if not file_names: + return result + + cmd_name = meta.get(MetaKey.CMD_NAME) + + for file_name in file_names: + command = f"{cmd_name} {ctl_id} {file_name}" + print(f"sending command: {command}") + reply = api.do_command(command) + if reply.get(ProtoKey.STATUS) != APIStatus.SUCCESS: + return reply + + return {ProtoKey.STATUS: APIStatus.SUCCESS, ProtoKey.DETAILS: "OK"} + def upload_folder(self, args, ctx: CommandContext): cmd_entry = ctx.get_command_entry() assert isinstance(cmd_entry, CommandEntry) diff --git a/nvflare/fuel/hci/conn.py b/nvflare/fuel/hci/conn.py index ce0944bb3c..a2fc3e100e 100644 --- a/nvflare/fuel/hci/conn.py +++ b/nvflare/fuel/hci/conn.py @@ -18,6 +18,8 @@ from .proto import Buffer, validate_proto from .table import Table +from .chunk import Receiver + # ASCII Message Format: # @@ -30,6 +32,7 @@ MAX_MSG_SIZE = 1024 +MAX_BYTES_SIZE = 1024 * 1024 def receive_til_end(sock, end=ALL_END): @@ -72,6 +75,17 @@ def _process_one_line(line: str, process_json_func): process_json_func(json_data) +def receive_bytes_and_process(sock, receive_bytes_func): + receiver = Receiver(receive_data_func=receive_bytes_func) + while True: + data = sock.recv(MAX_BYTES_SIZE) + if not data: + return False + done = receiver.receive(data) + if done: + return True + + def receive_and_process(sock, process_json_func): """Receives and sends lines to process with process_json_func.""" leftover = "" @@ -115,6 +129,7 @@ def __init__(self, sock, server): self.command = None self.args = None self.buffer = Buffer() + self.binary_mode = False def _send_line(self, line: str, all_end=False): """If not ``self.ended``, send line with sock.""" @@ -129,6 +144,9 @@ def _send_line(self, line: str, all_end=False): self.sock.sendall(bytes(line + end, "utf-8")) + def flush_bytes(self, data): + self.sock.sendall(data) + def append_table(self, headers: List[str], name=None) -> Table: return self.buffer.append_table(headers, name=name) @@ -190,5 +208,6 @@ def flush(self): self._send_line(line, all_end=False) def close(self): - self.flush() - self._send_line("", all_end=True) + if not self.binary_mode: + self.flush() + self._send_line("", all_end=True) diff --git a/nvflare/fuel/hci/file_transfer_defs.py b/nvflare/fuel/hci/file_transfer_defs.py index 6d3c45e87e..5ddf2519ef 100644 --- a/nvflare/fuel/hci/file_transfer_defs.py +++ b/nvflare/fuel/hci/file_transfer_defs.py @@ -22,9 +22,12 @@ SERVER_CMD_UPLOAD_FOLDER = "_upload_folder" SERVER_CMD_SUBMIT_JOB = "_submit_job" SERVER_CMD_DOWNLOAD_JOB = "_download_job" -SERVER_CMD_DOWNLOAD_JOB_SINGLE_FILE = "_download_job_single_file" SERVER_CMD_INFO = "_info" +SERVER_CMD_PULL_BINARY = "_pull_binary_file" + DOWNLOAD_URL_MARKER = "Download_URL:" UPLOAD_FOLDER_FQN = "file_transfer.upload_folder" DOWNLOAD_FOLDER_FQN = "file_transfer.download_folder" +PULL_FOLDER_FQN = "file_transfer.pull_folder" +PULL_BINARY_FQN = "file_transfer.pull_binary" diff --git a/nvflare/fuel/hci/proto.py b/nvflare/fuel/hci/proto.py index dea3fac5d7..22aa9689bb 100644 --- a/nvflare/fuel/hci/proto.py +++ b/nvflare/fuel/hci/proto.py @@ -59,6 +59,9 @@ class MetaKey(object): DURATION = "duration" CMD_TIMEOUT = "cmd_timeout" CUSTOM_PROPS = "custom_props" + FILES = "files" + CMD_NAME = "cmd_name" + CONTROL_ID = "control_id" class MetaStatusValue(object): diff --git a/nvflare/fuel/hci/server/binary_transfer.py b/nvflare/fuel/hci/server/binary_transfer.py new file mode 100644 index 0000000000..a421fa83c9 --- /dev/null +++ b/nvflare/fuel/hci/server/binary_transfer.py @@ -0,0 +1,91 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os + +from nvflare.fuel.hci.conn import Connection +from nvflare.fuel.hci.proto import MetaKey, MetaStatusValue, make_meta +from nvflare.fuel.hci.server.constants import ConnProps +from nvflare.fuel.hci.chunk import Sender + +BINARY_CHUNK_SIZE = 1024 * 1024 # 1M + + +class _BytesSender: + + def __init__(self, conn: Connection): + self.conn = conn + + def send(self, data): + self.conn.flush_bytes(data) + + +class BinaryTransfer: + + def __init__(self): + self.logger = logging.getLogger(self.__class__.__name__) + + def download_file(self, conn: Connection, file_name): + download_dir = conn.get_prop(ConnProps.DOWNLOAD_DIR) + conn.binary_mode = True + full_path = os.path.join(download_dir, file_name) + if not os.path.exists(full_path): + self.logger.error(f"no such file: {full_path}") + return + + if not os.path.isfile(full_path): + self.logger.error(f"not a file: {full_path}") + return + + self.logger.info(f"called to send {full_path} ...") + bytes_sender = _BytesSender(conn) + sender = Sender(send_data_func=bytes_sender.send) + buffer_size = BINARY_CHUNK_SIZE + bytes_sent = 0 + with open(full_path, mode="rb") as f: + chunk = f.read(buffer_size) + while chunk: + sender.send(chunk) + bytes_sent += len(chunk) + chunk = f.read(buffer_size) + sender.close() + self.logger.info(f"finished sending {full_path}: {bytes_sent} bytes sent") + + def download_folder(self, conn: Connection, folder_name: str, download_file_cmd_name: str, control_id: str): + download_dir = conn.get_prop(ConnProps.DOWNLOAD_DIR) + folder_path = os.path.join(download_dir, folder_name) + self.logger.info(f"download_folder called for {folder_name}") + + # return list of the files + files = [] + for (dir_path, dir_names, file_names) in os.walk(folder_path): + for f in file_names: + p = os.path.join(dir_path, f) + p = os.path.relpath(p, folder_path) + p = os.path.join(folder_name, p) + files.append(p) + + self.logger.info(f"files of the folder: {files}") + conn.append_string( + "OK", + meta=make_meta( + MetaStatusValue.OK, + extra={ + MetaKey.FILES: files, + MetaKey.CONTROL_ID: control_id, + MetaKey.CMD_NAME: download_file_cmd_name + } + ) + ) diff --git a/nvflare/fuel/hci/server/file_transfer.py b/nvflare/fuel/hci/server/file_transfer.py index e2f2c828c6..fba1ee99bd 100644 --- a/nvflare/fuel/hci/server/file_transfer.py +++ b/nvflare/fuel/hci/server/file_transfer.py @@ -89,13 +89,6 @@ def get_spec(self): handler_func=self.upload_folder, visible=False, ), - CommandSpec( - name=ftd.SERVER_CMD_DOWNLOAD_JOB_SINGLE_FILE, - description="download a single file from a completed job in the job store", - usage="download_job_single_file job_id file_path", - handler_func=self.download_job_single_file, - visible=False, - ), CommandSpec( name=ftd.SERVER_CMD_INFO, description="show info", diff --git a/nvflare/private/fed/server/job_cmds.py b/nvflare/private/fed/server/job_cmds.py index 19976595f9..531f73e2ed 100644 --- a/nvflare/private/fed/server/job_cmds.py +++ b/nvflare/private/fed/server/job_cmds.py @@ -40,6 +40,7 @@ from nvflare.private.fed.server.server_engine import ServerEngine from nvflare.private.fed.server.server_engine_internal_spec import ServerEngineInternalSpec from nvflare.security.logging import secure_format_exception, secure_log_traceback +from nvflare.fuel.hci.server.binary_transfer import BinaryTransfer from .cmd_utils import CommandUtil @@ -72,7 +73,7 @@ def _create_list_job_cmd_parser(): return parser -class JobCommandModule(CommandModule, CommandUtil): +class JobCommandModule(CommandModule, CommandUtil, BinaryTransfer): """Command module with commands for job management.""" def __init__(self): @@ -145,16 +146,43 @@ def get_spec(self): client_cmd=ftd.UPLOAD_FOLDER_FQN, ), CommandSpec( - name=AdminCommandNames.DOWNLOAD_JOB, + name="old_download_job", description="download a specified job", usage=f"{AdminCommandNames.DOWNLOAD_JOB} job_id", handler_func=self.download_job, authz_func=self.authorize_job, client_cmd=ftd.DOWNLOAD_FOLDER_FQN, + visible=False, + ), + + CommandSpec( + name=AdminCommandNames.DOWNLOAD_JOB, + description="download a specified job", + usage=f"{AdminCommandNames.DOWNLOAD_JOB} job_id", + handler_func=self.pull_job, + authz_func=self.authorize_job, + client_cmd=ftd.PULL_FOLDER_FQN, + ), + + CommandSpec( + name=AdminCommandNames.DOWNLOAD_JOB_FILE, + description="download a specified job file", + usage=f"{AdminCommandNames.DOWNLOAD_JOB_FILE} job_id file_name", + handler_func=self.pull_file, + authz_func=self.authorize_job_file, + client_cmd=ftd.PULL_BINARY_FQN, ), ], ) + def authorize_job_file(self, conn: Connection, args: List[str]): + if len(args) < 2: + conn.append_error( + "syntax error: missing job_id", meta=make_meta(MetaStatusValue.SYNTAX_ERROR, "missing job_id") + ) + return PreAuthzReturnCode.ERROR + return self.authorize_job(conn, args[0:2]) + def authorize_job(self, conn: Connection, args: List[str]): if len(args) < 2: conn.append_error( @@ -613,6 +641,38 @@ def _unzip_data(self, download_dir, job_data, job_id): os.mkdir(workspace_dir) if workspace_bytes is not None: unzip_all_from_bytes(workspace_bytes, workspace_dir) + return job_id_dir + + def pull_file(self, conn: Connection, args: List[str]): + if len(args) != 3: + self.logger.error("syntax error: missing file name") + return + self.download_file(conn, file_name=args[2]) + + def pull_job(self, conn: Connection, args: List[str]): + job_id = args[1] + download_dir = conn.get_prop(ConnProps.DOWNLOAD_DIR) + self.logger.info(f"pull_job called for {job_id}") + + engine = conn.app_ctx + job_def_manager = engine.job_def_manager + if not isinstance(job_def_manager, JobDefManagerSpec): + self.logger.error( + f"job_def_manager in engine is not of type JobDefManagerSpec, but got {type(job_def_manager)}" + ) + conn.append_error( + "internal error", + meta=make_meta(MetaStatusValue.INTERNAL_ERROR) + ) + return + + with engine.new_context() as fl_ctx: + job_data = job_def_manager.get_job_data(job_id, fl_ctx) + self._unzip_data(download_dir, job_data, job_id) + self.download_folder( + conn, job_id, + download_file_cmd_name=AdminCommandNames.DOWNLOAD_JOB_FILE, + control_id=job_id) def download_job(self, conn: Connection, args: List[str]): job_id = args[1] diff --git a/nvflare/security/security.py b/nvflare/security/security.py index 1eb1f21a0e..05cd39dfd8 100644 --- a/nvflare/security/security.py +++ b/nvflare/security/security.py @@ -24,6 +24,7 @@ class CommandCategory(object): OPERATE = "operate" VIEW = "view" SHELL_COMMANDS = "shell_commands" + DOWNLOAD_JOB = "download_job" COMMAND_CATEGORIES = { @@ -52,6 +53,8 @@ class CommandCategory(object): AC.SHELL_LS: CommandCategory.SHELL_COMMANDS, AC.SHELL_PWD: CommandCategory.SHELL_COMMANDS, AC.SHELL_TAIL: CommandCategory.SHELL_COMMANDS, + AC.DOWNLOAD_JOB: CommandCategory.DOWNLOAD_JOB, + AC.DOWNLOAD_JOB_FILE: CommandCategory.DOWNLOAD_JOB, }