From 947f8e080e3abd816e44e2919a6e045ec87ae1dd Mon Sep 17 00:00:00 2001 From: Yan Cheng Date: Thu, 5 Dec 2024 11:43:41 -0500 Subject: [PATCH] reorg file structure --- nvflare/lighter/ctx.py | 110 +++++ nvflare/lighter/entity.py | 261 ++++++++++++ nvflare/lighter/impl/cert.py | 6 +- nvflare/lighter/impl/docker.py | 9 +- nvflare/lighter/impl/he.py | 5 +- nvflare/lighter/impl/helm_chart.py | 12 +- nvflare/lighter/impl/static_file.py | 75 ++-- nvflare/lighter/impl/workspace.py | 7 +- nvflare/lighter/provision.py | 3 +- nvflare/lighter/provisioner.py | 93 +++++ nvflare/lighter/spec.py | 420 +------------------- nvflare/tool/poc/poc_commands.py | 2 +- tests/unit_test/lighter/participant_test.py | 2 +- tests/unit_test/lighter/project_test.py | 2 +- 14 files changed, 530 insertions(+), 477 deletions(-) create mode 100644 nvflare/lighter/ctx.py create mode 100644 nvflare/lighter/entity.py create mode 100644 nvflare/lighter/provisioner.py diff --git a/nvflare/lighter/ctx.py b/nvflare/lighter/ctx.py new file mode 100644 index 0000000000..f55cdfd303 --- /dev/null +++ b/nvflare/lighter/ctx.py @@ -0,0 +1,110 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import os + +import yaml + +from nvflare.lighter import utils + +from .constants import CtxKey, PropKey, ProvisionMode +from .entity import Entity, Project + + +class ProvisionContext(dict): + def __init__(self, workspace_root_dir: str, project: Project): + super().__init__() + self[CtxKey.WORKSPACE] = workspace_root_dir + + wip_dir = os.path.join(workspace_root_dir, "wip") + state_dir = os.path.join(workspace_root_dir, "state") + resources_dir = os.path.join(workspace_root_dir, "resources") + self.update({CtxKey.WIP: wip_dir, CtxKey.STATE: state_dir, CtxKey.RESOURCES: resources_dir}) + dirs = [workspace_root_dir, resources_dir, wip_dir, state_dir] + utils.make_dirs(dirs) + + # set commonly used data into ctx + self[CtxKey.PROJECT] = project + + server = project.get_server() + admin_port = server.get_prop(PropKey.ADMIN_PORT, 8003) + self[CtxKey.ADMIN_PORT] = admin_port + fed_learn_port = server.get_prop(PropKey.FED_LEARN_PORT, 8002) + self[CtxKey.FED_LEARN_PORT] = fed_learn_port + self[CtxKey.SERVER_NAME] = server.name + + def get_project(self): + return self.get(CtxKey.PROJECT) + + def set_template(self, template: dict): + self[CtxKey.TEMPLATE] = template + + def get_template(self): + return self.get(CtxKey.TEMPLATE) + + def get_template_section(self, section_key: str): + template = self.get_template() + if not template: + raise RuntimeError("template is not available") + + section = template.get(section_key) + if not section: + raise RuntimeError(f"missing section {section} in template") + + return section + + def set_provision_mode(self, mode: str): + valid_modes = [ProvisionMode.POC, ProvisionMode.NORMAL] + if mode not in valid_modes: + raise ValueError(f"invalid provision mode {mode}: must be one of {valid_modes}") + self[CtxKey.PROVISION_MODE] = mode + + def get_provision_mode(self): + return self.get(CtxKey.PROVISION_MODE) + + def get_wip_dir(self): + return self.get(CtxKey.WIP) + + def get_ws_dir(self, entity: Entity): + return os.path.join(self.get_wip_dir(), entity.name) + + def get_kit_dir(self, entity: Entity): + return os.path.join(self.get_ws_dir(entity), "startup") + + def get_transfer_dir(self, entity: Entity): + return os.path.join(self.get_ws_dir(entity), "transfer") + + def get_local_dir(self, entity: Entity): + return os.path.join(self.get_ws_dir(entity), "local") + + def get_state_dir(self): + return self.get(CtxKey.STATE) + + def get_resources_dir(self): + return self.get(CtxKey.RESOURCES) + + def get_workspace(self): + return self.get(CtxKey.WORKSPACE) + + def yaml_load_template_section(self, section_key: str): + return yaml.safe_load(self.get_template_section(section_key)) + + def json_load_template_section(self, section_key: str): + return json.loads(self.get_template_section(section_key)) + + def build_from_template(self, dest_dir: str, temp_section: str, file_name, replacement=None, mode="t", exe=False): + section = self.get_template_section(temp_section) + if replacement: + section = utils.sh_replace(section, replacement) + utils.write(os.path.join(dest_dir, file_name), section, mode, exe=exe) diff --git a/nvflare/lighter/entity.py b/nvflare/lighter/entity.py new file mode 100644 index 0000000000..6b40a60c56 --- /dev/null +++ b/nvflare/lighter/entity.py @@ -0,0 +1,261 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from nvflare.apis.utils.format_check import name_check + +from .constants import AdminRole, ParticipantType, PropKey + + +def _check_host_name(scope: str, prop_key: str, value): + err, reason = name_check(value, "host_name") + if err: + raise ValueError(f"bad value for {prop_key} '{value}' in {scope}: {reason}") + + +def _check_host_names(scope: str, prop_key: str, value): + if isinstance(value, str): + _check_host_name(scope, prop_key, value) + elif isinstance(value, list): + for v in value: + _check_host_name(scope, prop_key, v) + + +def _check_admin_role(scope: str, prop_key: str, value): + valid_roles = [AdminRole.PROJECT_ADMIN, AdminRole.ORG_ADMIN, AdminRole.LEAD, AdminRole.MEMBER] + if value not in valid_roles: + raise ValueError(f"bad value for {prop_key} '{value}' in {scope}: must be one of {valid_roles}") + + +# validator functions for common properties +# Validator function must follow this signature: +# func(scope: str, prop_key: str, value) +_PROP_VALIDATORS = { + PropKey.HOST_NAMES: _check_host_names, + PropKey.CONNECT_TO: _check_host_name, + PropKey.LISTENING_HOST: _check_host_name, + PropKey.DEFAULT_HOST: _check_host_name, + PropKey.ROLE: _check_admin_role, +} + + +class Entity: + def __init__(self, scope: str, name: str, props: dict, parent=None): + if not props: + props = {} + + for k, v in props.items(): + validator = _PROP_VALIDATORS.get(k) + if validator is not None: + validator(scope, k, v) + self.name = name + self.props = props + self.parent = parent + + def get_prop(self, key: str, default=None): + return self.props.get(key, default) + + def get_prop_fb(self, key: str, fb_key=None, default=None): + """Get property value with fallback. + If I have the property, then return it. + If not, I return the fallback property of my parent. If I don't have parent, return default. + + Args: + key: key of the property + fb_key: key of the fallback property. + default: value to return if no one has the property + + Returns: property value + + """ + value = self.get_prop(key) + if value: + return value + elif not self.parent: + return default + else: + # get the value from the parent + if not fb_key: + fb_key = key + return self.parent.get_prop(fb_key, default) + + +class Participant(Entity): + def __init__(self, type: str, name: str, org: str, props: dict = None, project: Entity = None): + """Class to represent a participant. + + Each participant communicates to other participant. Therefore, each participant has its + own name, type, organization it belongs to, rules and other information. + + Args: + type (str): server, client, admin or other string that builders can handle + name (str): system-wide unique name + org (str): system-wide unique organization + props (dict): properties + project: the project that the participant belongs to + + Raises: + ValueError: if name or org is not compliant with characters or format specification. + """ + Entity.__init__(self, f"{type}::{name}", name, props, parent=project) + + err, reason = name_check(name, type) + if err: + raise ValueError(reason) + + err, reason = name_check(org, "org") + if err: + raise ValueError(reason) + + self.type = type + self.org = org + self.subject = name + + def get_default_host(self) -> str: + """Get the default host name for accessing this participant (server). + If the "default_host" attribute is explicitly specified, then it's the default host. + If the "default_host" attribute is not explicitly specified, then use the "name" attribute. + + Returns: a host name + + """ + h = self.get_prop(PropKey.DEFAULT_HOST) + if h: + return h + else: + return self.name + + +class Project(Entity): + def __init__( + self, + name: str, + description: str, + participants=None, + props: dict = None, + serialized_root_cert=None, + serialized_root_private_key=None, + ): + """A container class to hold information about this FL project. + + This class only holds information. It does not drive the workflow. + + Args: + name (str): the project name + description (str): brief description on this name + participants: if provided, list of participants of the project + props: properties of the project + serialized_root_cert: if provided, the root cert to be used for the project + serialized_root_private_key: if provided, the root private key for signing certs of sites and admins + + Raises: + ValueError: when participant criteria is violated + """ + Entity.__init__(self, "project", name, props) + + if serialized_root_cert: + if not serialized_root_private_key: + raise ValueError("missing serialized_root_private_key while serialized_root_cert is provided") + + self.description = description + self.serialized_root_cert = serialized_root_cert + self.serialized_root_private_key = serialized_root_private_key + self.server = None + self.overseer = None + self.clients = [] + self.admins = [] + self.all_names = {} + + if participants: + if not isinstance(participants, list): + raise ValueError(f"participants must be a list of Participant but got {type(participants)}") + + for p in participants: + if not isinstance(p, Participant): + raise ValueError(f"bad item in participants: must be Participant but got {type(p)}") + + if p.type == ParticipantType.SERVER: + self.set_server(p.name, p.org, p.props) + elif p.type == ParticipantType.ADMIN: + self.add_admin(p.name, p.org, p.props) + elif p.type == ParticipantType.CLIENT: + self.add_client(p.name, p.org, p.props) + elif p.type == ParticipantType.OVERSEER: + self.set_overseer(p.name, p.org, p.props) + else: + raise ValueError(f"invalid value for ParticipantType: {p.type}") + + def _check_unique_name(self, name: str): + if name in self.all_names: + raise ValueError(f"the project {self.name} already has a participant with the name '{name}'") + + def set_server(self, name: str, org: str, props: dict): + if self.server: + raise ValueError(f"project {self.name} already has a server defined") + self._check_unique_name(name) + self.server = Participant(ParticipantType.SERVER, name, org, props, self) + self.all_names[name] = True + + def get_server(self): + """Get the server definition. Only one server is supported! + + Returns: server participant + + """ + return self.server + + def set_overseer(self, name: str, org: str, props: dict): + if self.overseer: + raise ValueError(f"project {self.name} already has an overseer defined") + self._check_unique_name(name) + self.overseer = Participant(ParticipantType.OVERSEER, name, org, props, self) + self.all_names[name] = True + + def get_overseer(self): + """Get the overseer definition. Only one overseer is supported! + + Returns: overseer participant + + """ + return self.overseer + + def add_client(self, name: str, org: str, props: dict): + self._check_unique_name(name) + self.clients.append(Participant(ParticipantType.CLIENT, name, org, props, self)) + self.all_names[name] = True + + def get_clients(self): + return self.clients + + def add_admin(self, name: str, org: str, props: dict): + self._check_unique_name(name) + admin = Participant(ParticipantType.ADMIN, name, org, props, self) + role = admin.get_prop(PropKey.ROLE) + if not role: + raise ValueError(f"missing role in admin '{name}'") + self.admins.append(admin) + self.all_names[name] = True + + def get_admins(self): + return self.admins + + def get_all_participants(self): + result = [] + if self.server: + result.append(self.server) + + if self.overseer: + result.append(self.overseer) + + result.extend(self.clients) + result.extend(self.admins) + return result diff --git a/nvflare/lighter/impl/cert.py b/nvflare/lighter/impl/cert.py index cc05f79ca1..18ddf5d32e 100644 --- a/nvflare/lighter/impl/cert.py +++ b/nvflare/lighter/impl/cert.py @@ -22,8 +22,10 @@ from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.x509.oid import NameOID -from nvflare.lighter.constants import CtxKey -from nvflare.lighter.spec import Builder, Participant, ParticipantType, Project, PropKey, ProvisionContext +from nvflare.lighter.constants import CtxKey, ParticipantType, PropKey +from nvflare.lighter.ctx import ProvisionContext +from nvflare.lighter.entity import Participant, Project +from nvflare.lighter.spec import Builder from nvflare.lighter.utils import serialize_cert, serialize_pri_key _CERT_BASE_NAME_CLIENT = "client" diff --git a/nvflare/lighter/impl/docker.py b/nvflare/lighter/impl/docker.py index 6a373c3740..8773b64dc7 100644 --- a/nvflare/lighter/impl/docker.py +++ b/nvflare/lighter/impl/docker.py @@ -28,8 +28,9 @@ def __init__(self, base_image="python:3.8", requirements_file="requirements.txt" self.base_image = base_image self.requirements_file = requirements_file self.services = {} + self.compose_file_path = None - def _build_overseer(self, overseer, ctx): + def _build_overseer(self, overseer): protocol = overseer.props.get("protocol", "http") default_port = "443" if protocol == "https" else "80" port = overseer.props.get("port", default_port) @@ -56,7 +57,7 @@ def _build_server(self, server, ctx: ProvisionContext): info_dict["container_name"] = server.name self.services[server.name] = info_dict - def _build_client(self, client, ctx): + def _build_client(self, client): info_dict = copy.deepcopy(self.services["__flclient__"]) info_dict["volumes"] = [f"./{client.name}:" + "${WORKSPACE}"] info_dict["build"] = "nvflare_compose" @@ -76,13 +77,13 @@ def build(self, project: Project, ctx: ProvisionContext): self.compose_file_path = os.path.join(ctx.get_wip_dir(), ProvFileName.COMPOSE_YAML) overseer = project.get_overseer() if overseer: - self._build_overseer(overseer, ctx) + self._build_overseer(overseer) server = project.get_server() if server: self._build_server(server, ctx) for client in project.get_clients(): - self._build_client(client, ctx) + self._build_client(client) self.services.pop("__overseer__", None) self.services.pop("__flserver__", None) diff --git a/nvflare/lighter/impl/he.py b/nvflare/lighter/impl/he.py index 573058ec59..65051b5dda 100644 --- a/nvflare/lighter/impl/he.py +++ b/nvflare/lighter/impl/he.py @@ -24,7 +24,7 @@ class HEBuilder(Builder): def __init__( self, poly_modulus_degree=8192, - coeff_mod_bit_sizes=[60, 40, 40], + coeff_mod_bit_sizes=None, scale_bits=40, scheme="CKKS", ): @@ -39,6 +39,9 @@ def __init__( scale_bits: defaults to 40. scheme: defaults to "CKKS". """ + if not coeff_mod_bit_sizes: + coeff_mod_bit_sizes = [60, 40, 40] + self._context = None self.scheme_type_mapping = { "CKKS": ts.SCHEME_TYPE.CKKS, diff --git a/nvflare/lighter/impl/helm_chart.py b/nvflare/lighter/impl/helm_chart.py index 3c5eddeb07..570e1113eb 100644 --- a/nvflare/lighter/impl/helm_chart.py +++ b/nvflare/lighter/impl/helm_chart.py @@ -17,7 +17,8 @@ import yaml from nvflare.lighter.constants import CtxKey, PropKey, ProvFileName, TemplateSectionKey -from nvflare.lighter.spec import Builder, Participant, Project, ProvisionContext +from nvflare.lighter.entity import Participant +from nvflare.lighter.spec import Builder, Project, ProvisionContext class HelmChartBuilder(Builder): @@ -25,12 +26,17 @@ def __init__(self, docker_image): """Build Helm Chart.""" self.docker_image = docker_image self.helm_chart_directory = None + self.service_overseer = None + self.service_server = None + self.deployment_server = None + self.deployment_overseer = None + self.helm_chart_templates_directory = None def initialize(self, project: Project, ctx: ProvisionContext): self.helm_chart_directory = os.path.join(ctx.get_wip_dir(), ProvFileName.HELM_CHART_DIR) os.mkdir(self.helm_chart_directory) - def _build_overseer(self, overseer: Participant, ctx): + def _build_overseer(self, overseer: Participant): protocol = overseer.get_prop(PropKey.PROTOCOL, "http") default_port = "443" if protocol == "https" else "80" port = overseer.get_prop(PropKey.PORT, default_port) @@ -110,7 +116,7 @@ def build(self, project: Project, ctx: ProvisionContext): os.mkdir(self.helm_chart_templates_directory) overseer = project.get_overseer() if overseer: - self._build_overseer(overseer, ctx) + self._build_overseer(overseer) server = project.get_server() if server: diff --git a/nvflare/lighter/impl/static_file.py b/nvflare/lighter/impl/static_file.py index 09f7a1102b..742f2eaac2 100644 --- a/nvflare/lighter/impl/static_file.py +++ b/nvflare/lighter/impl/static_file.py @@ -20,7 +20,8 @@ from nvflare.lighter import utils from nvflare.lighter.constants import CtxKey, OverseerRole, PropKey, ProvFileName, ProvisionMode, TemplateSectionKey -from nvflare.lighter.spec import Builder, Participant, Project, ProvisionContext +from nvflare.lighter.entity import Participant +from nvflare.lighter.spec import Builder, Project, ProvisionContext class StaticFileBuilder(Builder): @@ -84,12 +85,11 @@ def _build_overseer(self, overseer: Participant, ctx: ProvisionContext): ) if self.docker_image: - self.build_from_template( - ctx, dest_dir, TemplateSectionKey.DOCKER_SERVER_SH, ProvFileName.DOCKER_SH, replacement_dict, exe=True + ctx.build_from_template( + dest_dir, TemplateSectionKey.DOCKER_SERVER_SH, ProvFileName.DOCKER_SH, replacement_dict, exe=True ) - self.build_from_template( - ctx, + ctx.build_from_template( dest_dir, TemplateSectionKey.GUNICORN_CONF_PY, ProvFileName.GUNICORN_CONF_PY, @@ -97,7 +97,7 @@ def _build_overseer(self, overseer: Participant, ctx: ProvisionContext): exe=False, ) - self.build_from_template(ctx, dest_dir, TemplateSectionKey.START_OVERSEER_SH, ProvFileName.START_SH, exe=True) + ctx.build_from_template(dest_dir, TemplateSectionKey.START_OVERSEER_SH, ProvFileName.START_SH, exe=True) if port: ctx[PropKey.OVERSEER_END_POINT] = f"{protocol}://{overseer.name}:{port}{api_root}" @@ -131,8 +131,7 @@ def _build_server(self, server: Participant, ctx: ProvisionContext): } if self.docker_image: - self.build_from_template( - ctx, + ctx.build_from_template( dest_dir, TemplateSectionKey.DOCKER_SERVER_SH, ProvFileName.DOCKER_SH, @@ -140,10 +139,9 @@ def _build_server(self, server: Participant, ctx: ProvisionContext): exe=True, ) - self.build_from_template(ctx, dest_dir, TemplateSectionKey.START_SERVER_SH, ProvFileName.START_SH, exe=True) + ctx.build_from_template(dest_dir, TemplateSectionKey.START_SERVER_SH, ProvFileName.START_SH, exe=True) - self.build_from_template( - ctx, + ctx.build_from_template( dest_dir, TemplateSectionKey.SUB_START_SH, ProvFileName.SUB_START_SH, @@ -151,30 +149,28 @@ def _build_server(self, server: Participant, ctx: ProvisionContext): exe=True, ) - self.build_from_template(ctx, dest_dir, TemplateSectionKey.STOP_FL_SH, ProvFileName.STOP_FL_SH, exe=True) + ctx.build_from_template(dest_dir, TemplateSectionKey.STOP_FL_SH, ProvFileName.STOP_FL_SH, exe=True) # local folder creation dest_dir = ctx.get_local_dir(server) - self.build_from_template( - ctx, dest_dir, TemplateSectionKey.LOG_CONFIG, ProvFileName.LOG_CONFIG_DEFAULT, exe=False - ) + ctx.build_from_template(dest_dir, TemplateSectionKey.LOG_CONFIG, ProvFileName.LOG_CONFIG_DEFAULT, exe=False) - self.build_from_template( - ctx, dest_dir, TemplateSectionKey.LOCAL_SERVER_RESOURCES, ProvFileName.RESOURCES_JSON_DEFAULT, exe=False + ctx.build_from_template( + dest_dir, TemplateSectionKey.LOCAL_SERVER_RESOURCES, ProvFileName.RESOURCES_JSON_DEFAULT, exe=False ) - self.build_from_template( - ctx, dest_dir, TemplateSectionKey.SAMPLE_PRIVACY, ProvFileName.PRIVACY_JSON_SAMPLE, exe=False + ctx.build_from_template( + dest_dir, TemplateSectionKey.SAMPLE_PRIVACY, ProvFileName.PRIVACY_JSON_SAMPLE, exe=False ) - self.build_from_template( - ctx, dest_dir, TemplateSectionKey.DEFAULT_AUTHZ, ProvFileName.AUTHORIZATION_JSON_DEFAULT, exe=False + ctx.build_from_template( + dest_dir, TemplateSectionKey.DEFAULT_AUTHZ, ProvFileName.AUTHORIZATION_JSON_DEFAULT, exe=False ) # workspace folder file dest_dir = ctx.get_ws_dir(server) - self.build_from_template(ctx, dest_dir, TemplateSectionKey.SERVER_README, ProvFileName.README_TXT, exe=False) + ctx.build_from_template(dest_dir, TemplateSectionKey.SERVER_README, ProvFileName.README_TXT, exe=False) def _build_client(self, client, ctx): project = ctx.get_project() @@ -200,8 +196,7 @@ def _build_client(self, client, ctx): utils.write(os.path.join(dest_dir, ProvFileName.FED_CLIENT_JSON), json.dumps(config, indent=2), "t") if self.docker_image: - self.build_from_template( - ctx, + ctx.build_from_template( dest_dir, TemplateSectionKey.DOCKER_CLIENT_SH, ProvFileName.DOCKER_SH, @@ -209,37 +204,34 @@ def _build_client(self, client, ctx): exe=True, ) - self.build_from_template(ctx, dest_dir, TemplateSectionKey.START_CLIENT_SH, ProvFileName.START_SH, exe=True) + ctx.build_from_template(dest_dir, TemplateSectionKey.START_CLIENT_SH, ProvFileName.START_SH, exe=True) - self.build_from_template( - ctx, dest_dir, TemplateSectionKey.SUB_START_SH, ProvFileName.SUB_START_SH, replacement_dict, exe=True + ctx.build_from_template( + dest_dir, TemplateSectionKey.SUB_START_SH, ProvFileName.SUB_START_SH, replacement_dict, exe=True ) - self.build_from_template(ctx, dest_dir, TemplateSectionKey.STOP_FL_SH, ProvFileName.STOP_FL_SH, exe=True) + ctx.build_from_template(dest_dir, TemplateSectionKey.STOP_FL_SH, ProvFileName.STOP_FL_SH, exe=True) # local folder creation dest_dir = ctx.get_local_dir(client) - self.build_from_template(ctx, dest_dir, TemplateSectionKey.LOG_CONFIG, ProvFileName.LOG_CONFIG_DEFAULT) + ctx.build_from_template(dest_dir, TemplateSectionKey.LOG_CONFIG, ProvFileName.LOG_CONFIG_DEFAULT) - self.build_from_template( - ctx, dest_dir, TemplateSectionKey.LOCAL_CLIENT_RESOURCES, ProvFileName.RESOURCES_JSON_DEFAULT + ctx.build_from_template( + dest_dir, TemplateSectionKey.LOCAL_CLIENT_RESOURCES, ProvFileName.RESOURCES_JSON_DEFAULT ) - self.build_from_template( - ctx, + ctx.build_from_template( dest_dir, TemplateSectionKey.SAMPLE_PRIVACY, ProvFileName.PRIVACY_JSON_SAMPLE, ) - self.build_from_template( - ctx, dest_dir, TemplateSectionKey.DEFAULT_AUTHZ, ProvFileName.AUTHORIZATION_JSON_DEFAULT - ) + ctx.build_from_template(dest_dir, TemplateSectionKey.DEFAULT_AUTHZ, ProvFileName.AUTHORIZATION_JSON_DEFAULT) # workspace folder file dest_dir = ctx.get_ws_dir(client) - self.build_from_template(ctx, dest_dir, TemplateSectionKey.CLIENT_README, ProvFileName.README_TXT) + ctx.build_from_template(dest_dir, TemplateSectionKey.CLIENT_README, ProvFileName.README_TXT) @staticmethod def _check_host_name(host_name: str, server: Participant) -> str: @@ -329,12 +321,11 @@ def _build_admin(self, admin: Participant, ctx: ProvisionContext): utils.write(os.path.join(dest_dir, ProvFileName.FED_ADMIN_JSON), json.dumps(config, indent=2), "t") if self.docker_image: - self.build_from_template( - ctx, dest_dir, TemplateSectionKey.DOCKER_ADMIN_SH, ProvFileName.DOCKER_SH, replacement_dict, exe=True + ctx.build_from_template( + dest_dir, TemplateSectionKey.DOCKER_ADMIN_SH, ProvFileName.DOCKER_SH, replacement_dict, exe=True ) - self.build_from_template( - ctx, + ctx.build_from_template( dest_dir, TemplateSectionKey.FL_ADMIN_SH, ProvFileName.FL_ADMIN_SH, @@ -342,7 +333,7 @@ def _build_admin(self, admin: Participant, ctx: ProvisionContext): exe=True, ) - self.build_from_template(ctx, dest_dir, TemplateSectionKey.ADMIN_README, ProvFileName.README_TXT) + ctx.build_from_template(dest_dir, TemplateSectionKey.ADMIN_README, ProvFileName.README_TXT) def prepare_admin_config(self, admin, ctx: ProvisionContext): config = ctx.json_load_template_section(TemplateSectionKey.FED_ADMIN) diff --git a/nvflare/lighter/impl/workspace.py b/nvflare/lighter/impl/workspace.py index fcfa43d0bf..811d548716 100644 --- a/nvflare/lighter/impl/workspace.py +++ b/nvflare/lighter/impl/workspace.py @@ -16,7 +16,8 @@ import shutil import nvflare.lighter as prov -from nvflare.lighter.spec import Builder, CtxKey, Project, ProvisionContext +from nvflare.lighter.constants import CtxKey +from nvflare.lighter.spec import Builder, Project, ProvisionContext from nvflare.lighter.utils import load_yaml, make_dirs @@ -68,8 +69,8 @@ def initialize(self, project: Project, ctx: ProvisionContext): workspace_dir = ctx.get_workspace() prod_dirs = [_ for _ in os.listdir(workspace_dir) if _.startswith("prod_")] last = -1 - for dir in prod_dirs: - stage = int(dir.split("_")[-1]) + for d in prod_dirs: + stage = int(d.split("_")[-1]) if stage > last: last = stage ctx[CtxKey.LAST_PROD_STAGE] = last diff --git a/nvflare/lighter/provision.py b/nvflare/lighter/provision.py index 99191bb76a..bb9fb85131 100644 --- a/nvflare/lighter/provision.py +++ b/nvflare/lighter/provision.py @@ -23,7 +23,8 @@ from nvflare.fuel.utils.class_utils import instantiate_class from nvflare.lighter.constants import ParticipantType, PropKey -from nvflare.lighter.spec import Project, Provisioner +from nvflare.lighter.provisioner import Provisioner +from nvflare.lighter.spec import Project from nvflare.lighter.utils import load_yaml adding_client_error_msg = """ diff --git a/nvflare/lighter/provisioner.py b/nvflare/lighter/provisioner.py new file mode 100644 index 0000000000..093cf17af2 --- /dev/null +++ b/nvflare/lighter/provisioner.py @@ -0,0 +1,93 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import shutil +import traceback +from typing import List + +from .constants import ProvisionMode, WorkDir +from .ctx import ProvisionContext +from .entity import Project +from .spec import Builder + + +class Provisioner: + def __init__(self, root_dir: str, builders: List[Builder]): + """Workflow class that drive the provision process. + + Provisioner's tasks: + + - Maintain the provision workspace folder structure; + - Invoke Builders to generate the content of each startup kit + + ROOT_WORKSPACE Folder Structure:: + + root_workspace_dir_name: this is the root of the workspace + project_dir_name: the root dir of the project, could be named after the project + resources: stores resource files (templates, configs, etc.) of the Provisioner and Builders + prod: stores the current set of startup kits (production) + participate_dir: stores content files generated by builders + wip: stores the set of startup kits to be created (WIP) + participate_dir: stores content files generated by builders + state: stores the persistent state of the Builders + + Args: + root_dir (str): the directory path to hold all generated or intermediate folders + builders (List[Builder]): all builders that will be called to build the content + """ + self.root_dir = root_dir + self.builders = builders + self.template = {} + + def add_template(self, template: dict): + if not isinstance(template, dict): + raise ValueError(f"template must be a dict but got {type(template)}") + self.template.update(template) + + def provision(self, project: Project, mode=None): + server = project.get_server() + if not server: + raise RuntimeError("missing server from the project") + + workspace_root_dir = os.path.join(self.root_dir, project.name) + ctx = ProvisionContext(workspace_root_dir, project) + if self.template: + ctx.set_template(self.template) + + if not mode: + mode = ProvisionMode.NORMAL + ctx.set_provision_mode(mode) + + try: + for b in self.builders: + b.initialize(project, ctx) + + # call builders! + for b in self.builders: + b.build(project, ctx) + + for b in self.builders[::-1]: + b.finalize(project, ctx) + + except Exception: + prod_dir = ctx.get(WorkDir.CURRENT_PROD_DIR) + if prod_dir: + shutil.rmtree(prod_dir) + print("Exception raised during provision. Incomplete prod_n folder removed.") + traceback.print_exc() + finally: + wip_dir = ctx.get(WorkDir.WIP) + if wip_dir: + shutil.rmtree(wip_dir) + return ctx diff --git a/nvflare/lighter/spec.py b/nvflare/lighter/spec.py index 33f085657d..aaad6a76a2 100644 --- a/nvflare/lighter/spec.py +++ b/nvflare/lighter/spec.py @@ -11,346 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import json -import os -import shutil -import traceback from abc import ABC -from typing import List -import yaml - -from nvflare.apis.utils.format_check import name_check -from nvflare.lighter import utils - -from .constants import AdminRole, CtxKey, ParticipantType, PropKey, ProvisionMode, WorkDir - - -def _check_host_name(scope: str, prop_key: str, value): - err, reason = name_check(value, "host_name") - if err: - raise ValueError(f"bad value for {prop_key} '{value}' in {scope}: {reason}") - - -def _check_host_names(scope: str, prop_key: str, value): - if isinstance(value, str): - _check_host_name(scope, prop_key, value) - elif isinstance(value, list): - for v in value: - _check_host_name(scope, prop_key, v) - - -def _check_admin_role(scope: str, prop_key: str, value): - valid_roles = [AdminRole.PROJECT_ADMIN, AdminRole.ORG_ADMIN, AdminRole.LEAD, AdminRole.MEMBER] - if value not in valid_roles: - raise ValueError(f"bad value for {prop_key} '{value}' in {scope}: must be one of {valid_roles}") - - -# validator functions for common properties -# Validator function must follow this signature: -# func(scope: str, prop_key: str, value) -_PROP_VALIDATORS = { - PropKey.HOST_NAMES: _check_host_names, - PropKey.CONNECT_TO: _check_host_name, - PropKey.LISTENING_HOST: _check_host_name, - PropKey.DEFAULT_HOST: _check_host_name, - PropKey.ROLE: _check_admin_role, -} - - -class Entity: - def __init__(self, scope: str, name: str, props: dict, parent=None): - if not props: - props = {} - - for k, v in props.items(): - validator = _PROP_VALIDATORS.get(k) - if validator is not None: - validator(scope, k, v) - self.name = name - self.props = props - self.parent = parent - - def get_prop(self, key: str, default=None): - return self.props.get(key, default) - - def get_prop_fb(self, key: str, fb_key=None, default=None): - """Get property value with fallback. - If I have the property, then return it. - If not, I return the fallback property of my parent. If I don't have parent, return default. - - Args: - key: key of the property - fb_key: key of the fallback property. - default: value to return if no one has the property - - Returns: property value - - """ - value = self.get_prop(key) - if value: - return value - elif not self.parent: - return default - else: - # get the value from the parent - if not fb_key: - fb_key = key - return self.parent.get_prop(fb_key, default) - - -class Participant(Entity): - def __init__(self, type: str, name: str, org: str, props: dict = None, project: Entity = None): - """Class to represent a participant. - - Each participant communicates to other participant. Therefore, each participant has its - own name, type, organization it belongs to, rules and other information. - - Args: - type (str): server, client, admin or other string that builders can handle - name (str): system-wide unique name - org (str): system-wide unique organization - props (dict): properties - project: the project that the participant belongs to - - Raises: - ValueError: if name or org is not compliant with characters or format specification. - """ - Entity.__init__(self, f"{type}::{name}", name, props, parent=project) - - err, reason = name_check(name, type) - if err: - raise ValueError(reason) - - err, reason = name_check(org, "org") - if err: - raise ValueError(reason) - - self.type = type - self.org = org - self.subject = name - - def get_default_host(self) -> str: - """Get the default host name for accessing this participant (server). - If the "default_host" attribute is explicitly specified, then it's the default host. - If the "default_host" attribute is not explicitly specified, then use the "name" attribute. - - Returns: a host name - - """ - h = self.get_prop(PropKey.DEFAULT_HOST) - if h: - return h - else: - return self.name - - -class Project(Entity): - def __init__( - self, - name: str, - description: str, - participants=None, - props: dict = None, - serialized_root_cert=None, - serialized_root_private_key=None, - ): - """A container class to hold information about this FL project. - - This class only holds information. It does not drive the workflow. - - Args: - name (str): the project name - description (str): brief description on this name - participants: if provided, list of participants of the project - props: properties of the project - serialized_root_cert: if provided, the root cert to be used for the project - serialized_root_private_key: if provided, the root private key for signing certs of sites and admins - - Raises: - ValueError: when participant criteria is violated - """ - Entity.__init__(self, "project", name, props) - - if serialized_root_cert: - if not serialized_root_private_key: - raise ValueError("missing serialized_root_private_key while serialized_root_cert is provided") - - self.description = description - self.serialized_root_cert = serialized_root_cert - self.serialized_root_private_key = serialized_root_private_key - self.server = None - self.overseer = None - self.clients = [] - self.admins = [] - self.all_names = {} - - if participants: - if not isinstance(participants, list): - raise ValueError(f"participants must be a list of Participant but got {type(participants)}") - - for p in participants: - if not isinstance(p, Participant): - raise ValueError(f"bad item in participants: must be Participant but got {type(p)}") - - if p.type == ParticipantType.SERVER: - self.set_server(p.name, p.org, p.props) - elif p.type == ParticipantType.ADMIN: - self.add_admin(p.name, p.org, p.props) - elif p.type == ParticipantType.CLIENT: - self.add_client(p.name, p.org, p.props) - elif p.type == ParticipantType.OVERSEER: - self.set_overseer(p.name, p.org, p.props) - else: - raise ValueError(f"invalid value for ParticipantType: {p.type}") - - def _check_unique_name(self, name: str): - if name in self.all_names: - raise ValueError(f"the project {self.name} already has a participant with the name '{name}'") - - def set_server(self, name: str, org: str, props: dict): - if self.server: - raise ValueError(f"project {self.name} already has a server defined") - self._check_unique_name(name) - self.server = Participant(ParticipantType.SERVER, name, org, props, self) - self.all_names[name] = True - - def get_server(self): - """Get the server definition. Only one server is supported! - - Returns: server participant - - """ - return self.server - - def set_overseer(self, name: str, org: str, props: dict): - if self.overseer: - raise ValueError(f"project {self.name} already has an overseer defined") - self._check_unique_name(name) - self.overseer = Participant(ParticipantType.OVERSEER, name, org, props, self) - self.all_names[name] = True - - def get_overseer(self): - """Get the overseer definition. Only one overseer is supported! - - Returns: overseer participant - - """ - return self.overseer - - def add_client(self, name: str, org: str, props: dict): - self._check_unique_name(name) - self.clients.append(Participant(ParticipantType.CLIENT, name, org, props, self)) - self.all_names[name] = True - - def get_clients(self): - return self.clients - - def add_admin(self, name: str, org: str, props: dict): - self._check_unique_name(name) - admin = Participant(ParticipantType.ADMIN, name, org, props, self) - role = admin.get_prop(PropKey.ROLE) - if not role: - raise ValueError(f"missing role in admin '{name}'") - self.admins.append(admin) - self.all_names[name] = True - - def get_admins(self): - return self.admins - - def get_all_participants(self): - result = [] - if self.server: - result.append(self.server) - - if self.overseer: - result.append(self.overseer) - - result.extend(self.clients) - result.extend(self.admins) - return result - - -class ProvisionContext(dict): - def __init__(self, workspace_root_dir: str, project: Project): - super().__init__() - self[CtxKey.WORKSPACE] = workspace_root_dir - - wip_dir = os.path.join(workspace_root_dir, "wip") - state_dir = os.path.join(workspace_root_dir, "state") - resources_dir = os.path.join(workspace_root_dir, "resources") - self.update({CtxKey.WIP: wip_dir, CtxKey.STATE: state_dir, CtxKey.RESOURCES: resources_dir}) - dirs = [workspace_root_dir, resources_dir, wip_dir, state_dir] - utils.make_dirs(dirs) - - # set commonly used data into ctx - self[CtxKey.PROJECT] = project - - server = project.get_server() - admin_port = server.get_prop(PropKey.ADMIN_PORT, 8003) - self[CtxKey.ADMIN_PORT] = admin_port - fed_learn_port = server.get_prop(PropKey.FED_LEARN_PORT, 8002) - self[CtxKey.FED_LEARN_PORT] = fed_learn_port - self[CtxKey.SERVER_NAME] = server.name - - def get_project(self): - return self.get(CtxKey.PROJECT) - - def set_template(self, template: dict): - self[CtxKey.TEMPLATE] = template - - def get_template(self): - return self.get(CtxKey.TEMPLATE) - - def get_template_section(self, section_key: str): - template = self.get_template() - if not template: - raise RuntimeError("template is not available") - - section = template.get(section_key) - if not section: - raise RuntimeError(f"missing section {section} in template") - - return section - - def set_provision_mode(self, mode: str): - valid_modes = [ProvisionMode.POC, ProvisionMode.NORMAL] - if mode not in valid_modes: - raise ValueError(f"invalid provision mode {mode}: must be one of {valid_modes}") - self[CtxKey.PROVISION_MODE] = mode - - def get_provision_mode(self): - return self.get(CtxKey.PROVISION_MODE) - - def get_wip_dir(self): - return self.get(CtxKey.WIP) - - def get_ws_dir(self, entity: Entity): - return os.path.join(self.get_wip_dir(), entity.name) - - def get_kit_dir(self, entity: Entity): - return os.path.join(self.get_ws_dir(entity), "startup") - - def get_transfer_dir(self, entity: Entity): - return os.path.join(self.get_ws_dir(entity), "transfer") - - def get_local_dir(self, entity: Entity): - return os.path.join(self.get_ws_dir(entity), "local") - - def get_state_dir(self): - return self.get(CtxKey.STATE) - - def get_resources_dir(self): - return self.get(CtxKey.RESOURCES) - - def get_workspace(self): - return self.get(CtxKey.WORKSPACE) - - def yaml_load_template_section(self, section_key: str): - return yaml.safe_load(self.get_template_section(section_key)) - - def json_load_template_section(self, section_key: str): - return json.loads(self.get_template_section(section_key)) +from .ctx import ProvisionContext +from .entity import Project class Builder(ABC): @@ -362,83 +26,3 @@ def build(self, project: Project, ctx: ProvisionContext): def finalize(self, project: Project, ctx: ProvisionContext): pass - - @staticmethod - def build_from_template( - ctx: ProvisionContext, dest_dir: str, temp_section: str, file_name, replacement=None, mode="t", exe=False - ): - section = ctx.get_template_section(temp_section) - if replacement: - section = utils.sh_replace(section, replacement) - utils.write(os.path.join(dest_dir, file_name), section, mode, exe=exe) - - -class Provisioner(object): - def __init__(self, root_dir: str, builders: List[Builder]): - """Workflow class that drive the provision process. - - Provisioner's tasks: - - - Maintain the provision workspace folder structure; - - Invoke Builders to generate the content of each startup kit - - ROOT_WORKSPACE Folder Structure:: - - root_workspace_dir_name: this is the root of the workspace - project_dir_name: the root dir of the project, could be named after the project - resources: stores resource files (templates, configs, etc.) of the Provisioner and Builders - prod: stores the current set of startup kits (production) - participate_dir: stores content files generated by builders - wip: stores the set of startup kits to be created (WIP) - participate_dir: stores content files generated by builders - state: stores the persistent state of the Builders - - Args: - root_dir (str): the directory path to hold all generated or intermediate folders - builders (List[Builder]): all builders that will be called to build the content - """ - self.root_dir = root_dir - self.builders = builders - self.template = {} - - def add_template(self, template: dict): - if not isinstance(template, dict): - raise ValueError(f"template must be a dict but got {type(template)}") - self.template.update(template) - - def provision(self, project: Project, mode=None): - server = project.get_server() - if not server: - raise RuntimeError("missing server from the project") - - workspace_root_dir = os.path.join(self.root_dir, project.name) - ctx = ProvisionContext(workspace_root_dir, project) - if self.template: - ctx.set_template(self.template) - - if not mode: - mode = ProvisionMode.NORMAL - ctx.set_provision_mode(mode) - - try: - for b in self.builders: - b.initialize(project, ctx) - - # call builders! - for b in self.builders: - b.build(project, ctx) - - for b in self.builders[::-1]: - b.finalize(project, ctx) - - except Exception: - prod_dir = ctx.get(WorkDir.CURRENT_PROD_DIR) - if prod_dir: - shutil.rmtree(prod_dir) - print("Exception raised during provision. Incomplete prod_n folder removed.") - traceback.print_exc() - finally: - wip_dir = ctx.get(WorkDir.WIP) - if wip_dir: - shutil.rmtree(wip_dir) - return ctx diff --git a/nvflare/tool/poc/poc_commands.py b/nvflare/tool/poc/poc_commands.py index 89e1650380..5b88cbd927 100644 --- a/nvflare/tool/poc/poc_commands.py +++ b/nvflare/tool/poc/poc_commands.py @@ -32,7 +32,7 @@ from nvflare.fuel.utils.gpu_utils import get_host_gpu_ids from nvflare.lighter.constants import ProvisionMode from nvflare.lighter.provision import gen_default_project_config, prepare_project -from nvflare.lighter.spec import Provisioner +from nvflare.lighter.provisioner import Provisioner from nvflare.lighter.utils import ( load_yaml, update_project_server_name_config, diff --git a/tests/unit_test/lighter/participant_test.py b/tests/unit_test/lighter/participant_test.py index 92d8132f46..cb0b3d8339 100644 --- a/tests/unit_test/lighter/participant_test.py +++ b/tests/unit_test/lighter/participant_test.py @@ -14,7 +14,7 @@ import pytest -from nvflare.lighter.spec import Participant +from nvflare.lighter.entity import Participant class TestParticipant: diff --git a/tests/unit_test/lighter/project_test.py b/tests/unit_test/lighter/project_test.py index 4f810c1b60..f462c954ec 100644 --- a/tests/unit_test/lighter/project_test.py +++ b/tests/unit_test/lighter/project_test.py @@ -14,7 +14,7 @@ import pytest -from nvflare.lighter.spec import Participant, Project +from nvflare.lighter.entity import Participant, Project def create_participants(type, number, org, name, props=None):