Skip to content

Commit

Permalink
fix insecure processing
Browse files Browse the repository at this point in the history
  • Loading branch information
yanchengnv committed Dec 9, 2024
1 parent 22f6fcd commit 6fa3c62
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 17 deletions.
41 changes: 24 additions & 17 deletions nvflare/lighter/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ def serialize_cert(cert):


def load_crt(path):
return load_crt_bytes(open(path, "rb").read())
with open(path, "rb") as f:
return load_crt_bytes(f.read())


def load_crt_bytes(data: bytes):
Expand Down Expand Up @@ -116,17 +117,19 @@ def sign_folders(folder, signing_pri_key, crt_path, max_depth=9999):
for file in files:
if file == NVFLARE_SIG_FILE or file == NVFLARE_SUBMITTER_CRT_FILE:
continue
signatures[file] = sign_content(
content=open(os.path.join(root, file), "rb").read(),
signing_pri_key=signing_pri_key,
)
with open(os.path.join(root, file), "rb") as f:
signatures[file] = sign_content(
content=f.read(),
signing_pri_key=signing_pri_key,
)
for folder in folders:
signatures[folder] = sign_content(
content=folder,
signing_pri_key=signing_pri_key,
)

json.dump(signatures, open(os.path.join(root, NVFLARE_SIG_FILE), "wt"))
with open(os.path.join(root, NVFLARE_SIG_FILE), "wt") as f:
json.dump(signatures, f)
shutil.copyfile(crt_path, os.path.join(root, NVFLARE_SUBMITTER_CRT_FILE))
if depth >= max_depth:
break
Expand All @@ -138,7 +141,8 @@ def verify_folder_signature(src_folder, root_ca_path):
root_ca_public_key = root_ca_cert.public_key()
for root, folders, files in os.walk(src_folder):
try:
signatures = json.load(open(os.path.join(root, NVFLARE_SIG_FILE), "rt"))
with open(os.path.join(root, NVFLARE_SIG_FILE), "rt") as f:
signatures = json.load(f)
cert = load_crt(os.path.join(root, NVFLARE_SUBMITTER_CRT_FILE))
public_key = cert.public_key()
except:
Expand All @@ -150,11 +154,12 @@ def verify_folder_signature(src_folder, root_ca_path):
continue
signature = signatures.get(file)
if signature:
verify_content(
content=open(os.path.join(root, file), "rb").read(),
signature=signature,
public_key=public_key,
)
with open(os.path.join(root, file), "rb") as f:
verify_content(
content=f.read(),
signature=signature,
public_key=public_key,
)
for folder in folders:
signature = signatures.get(folder)
if signature:
Expand All @@ -173,16 +178,18 @@ def sign_all(content_folder, signing_pri_key):
for f in os.listdir(content_folder):
path = os.path.join(content_folder, f)
if os.path.isfile(path):
signatures[f] = sign_content(
content=open(path, "rb").read(),
signing_pri_key=signing_pri_key,
)
with open(path, "rb") as file:
signatures[f] = sign_content(
content=file.read(),
signing_pri_key=signing_pri_key,
)
return signatures


def load_yaml(file):
if isinstance(file, str):
return yaml.safe_load(open(file, "r"))
with open(file, "r") as f:
return yaml.safe_load(f)
elif isinstance(file, bytes):
return yaml.safe_load(file)
else:
Expand Down
7 changes: 7 additions & 0 deletions nvflare/private/fed/server/fed_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,9 @@ def _generate_reply(self, headers, payload, fl_ctx: FLContext):
return return_message

def _get_id_asserter(self):
if not self.secure_train:
return None

if not self.id_asserter:
with self.engine.new_context() as fl_ctx:
server_config = fl_ctx.get_prop(FLContextKey.SERVER_CONFIG)
Expand Down Expand Up @@ -629,10 +632,14 @@ def _get_id_asserter(self):

def sign_auth_token(self, client_name: str, token: str):
id_asserter = self._get_id_asserter()
if not id_asserter:
return "NA"
return id_asserter.sign(client_name + token, return_str=True)

def verify_auth_token(self, client_name: str, token: str, signature):
id_asserter = self._get_id_asserter()
if not id_asserter:
return True
return id_asserter.verify_signature(client_name + token, signature)

def _ready_for_registration(self, fl_ctx: FLContext):
Expand Down

0 comments on commit 6fa3c62

Please sign in to comment.