Skip to content

Commit

Permalink
address pr comments
Browse files Browse the repository at this point in the history
  • Loading branch information
yanchengnv committed Dec 5, 2024
1 parent 947f8e0 commit 9aff240
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 11 deletions.
6 changes: 6 additions & 0 deletions nvflare/lighter/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,9 @@ class ProvFileName:
CHART_YAML = "Chart.yaml"
VALUES_YAML = "values.yaml"
HELM_CHART_TEMPLATES_DIR = "templates"


class CertFileBasename:
CLIENT = "client"
SERVER = "server"
OVERSEER = "overseer"
18 changes: 7 additions & 11 deletions nvflare/lighter/impl/cert.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,12 @@
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import NameOID

from nvflare.lighter.constants import CtxKey, ParticipantType, PropKey
from nvflare.lighter.constants import CertFileBasename, CtxKey, ParticipantType, PropKey
from nvflare.lighter.ctx import ProvisionContext
from nvflare.lighter.entity import Participant, Project
from nvflare.lighter.spec import Builder
from nvflare.lighter.utils import serialize_cert, serialize_pri_key

_CERT_BASE_NAME_CLIENT = "client"
_CERT_BASE_NAME_SERVER = "server"
_CERT_BASE_NAME_OVERSEER = "overseer"


class _CertState:

Expand Down Expand Up @@ -180,7 +176,7 @@ def _build_write_cert_pair(self, participant: Participant, base_name, ctx: Provi
with open(os.path.join(dest_dir, f"{base_name}.key"), "wb") as f:
f.write(serialize_pri_key(pri_key))

if base_name == _CERT_BASE_NAME_CLIENT and (listening_host := participant.get_prop(PropKey.LISTENING_HOST)):
if base_name == CertFileBasename.CLIENT and (listening_host := participant.get_prop(PropKey.LISTENING_HOST)):
project = ctx.get_project()
tmp_participant = Participant(
type=ParticipantType.SERVER,
Expand All @@ -190,7 +186,7 @@ def _build_write_cert_pair(self, participant: Participant, base_name, ctx: Provi
props={PropKey.DEFAULT_HOST: listening_host},
)
tmp_pri_key, tmp_cert = self.get_pri_key_cert(tmp_participant)
bn = _CERT_BASE_NAME_SERVER
bn = CertFileBasename.SERVER
with open(os.path.join(dest_dir, f"{bn}.crt"), "wb") as f:
f.write(serialize_cert(tmp_cert))
with open(os.path.join(dest_dir, f"{bn}.key"), "wb") as f:
Expand All @@ -206,17 +202,17 @@ def build(self, project: Project, ctx: ProvisionContext):

overseer = project.get_overseer()
if overseer:
self._build_write_cert_pair(overseer, _CERT_BASE_NAME_OVERSEER, ctx)
self._build_write_cert_pair(overseer, CertFileBasename.OVERSEER, ctx)

server = project.get_server()
if server:
self._build_write_cert_pair(server, _CERT_BASE_NAME_SERVER, ctx)
self._build_write_cert_pair(server, CertFileBasename.SERVER, ctx)

for client in project.get_clients():
self._build_write_cert_pair(client, _CERT_BASE_NAME_CLIENT, ctx)
self._build_write_cert_pair(client, CertFileBasename.CLIENT, ctx)

for admin in project.get_admins():
self._build_write_cert_pair(admin, _CERT_BASE_NAME_CLIENT, ctx)
self._build_write_cert_pair(admin, CertFileBasename.CLIENT, ctx)

def get_pri_key_cert(self, participant: Participant):
pri_key, pub_key = self._generate_keys()
Expand Down

0 comments on commit 9aff240

Please sign in to comment.