diff --git a/nvflare/lighter/constants.py b/nvflare/lighter/constants.py index 2e3fe7515b..97b18d0595 100644 --- a/nvflare/lighter/constants.py +++ b/nvflare/lighter/constants.py @@ -136,3 +136,9 @@ class ProvFileName: CHART_YAML = "Chart.yaml" VALUES_YAML = "values.yaml" HELM_CHART_TEMPLATES_DIR = "templates" + + +class CertFileBasename: + CLIENT = "client" + SERVER = "server" + OVERSEER = "overseer" diff --git a/nvflare/lighter/impl/cert.py b/nvflare/lighter/impl/cert.py index 18ddf5d32e..95b363876a 100644 --- a/nvflare/lighter/impl/cert.py +++ b/nvflare/lighter/impl/cert.py @@ -22,16 +22,12 @@ from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.x509.oid import NameOID -from nvflare.lighter.constants import CtxKey, ParticipantType, PropKey +from nvflare.lighter.constants import CertFileBasename, CtxKey, ParticipantType, PropKey from nvflare.lighter.ctx import ProvisionContext from nvflare.lighter.entity import Participant, Project from nvflare.lighter.spec import Builder from nvflare.lighter.utils import serialize_cert, serialize_pri_key -_CERT_BASE_NAME_CLIENT = "client" -_CERT_BASE_NAME_SERVER = "server" -_CERT_BASE_NAME_OVERSEER = "overseer" - class _CertState: @@ -180,7 +176,7 @@ def _build_write_cert_pair(self, participant: Participant, base_name, ctx: Provi with open(os.path.join(dest_dir, f"{base_name}.key"), "wb") as f: f.write(serialize_pri_key(pri_key)) - if base_name == _CERT_BASE_NAME_CLIENT and (listening_host := participant.get_prop(PropKey.LISTENING_HOST)): + if base_name == CertFileBasename.CLIENT and (listening_host := participant.get_prop(PropKey.LISTENING_HOST)): project = ctx.get_project() tmp_participant = Participant( type=ParticipantType.SERVER, @@ -190,7 +186,7 @@ def _build_write_cert_pair(self, participant: Participant, base_name, ctx: Provi props={PropKey.DEFAULT_HOST: listening_host}, ) tmp_pri_key, tmp_cert = self.get_pri_key_cert(tmp_participant) - bn = _CERT_BASE_NAME_SERVER + bn = CertFileBasename.SERVER with open(os.path.join(dest_dir, f"{bn}.crt"), "wb") as f: f.write(serialize_cert(tmp_cert)) with open(os.path.join(dest_dir, f"{bn}.key"), "wb") as f: @@ -206,17 +202,17 @@ def build(self, project: Project, ctx: ProvisionContext): overseer = project.get_overseer() if overseer: - self._build_write_cert_pair(overseer, _CERT_BASE_NAME_OVERSEER, ctx) + self._build_write_cert_pair(overseer, CertFileBasename.OVERSEER, ctx) server = project.get_server() if server: - self._build_write_cert_pair(server, _CERT_BASE_NAME_SERVER, ctx) + self._build_write_cert_pair(server, CertFileBasename.SERVER, ctx) for client in project.get_clients(): - self._build_write_cert_pair(client, _CERT_BASE_NAME_CLIENT, ctx) + self._build_write_cert_pair(client, CertFileBasename.CLIENT, ctx) for admin in project.get_admins(): - self._build_write_cert_pair(admin, _CERT_BASE_NAME_CLIENT, ctx) + self._build_write_cert_pair(admin, CertFileBasename.CLIENT, ctx) def get_pri_key_cert(self, participant: Participant): pri_key, pub_key = self._generate_keys()