diff --git a/nvflare/apis/utils/format_check.py b/nvflare/apis/utils/format_check.py index 49760c958e..7e333db1d1 100644 --- a/nvflare/apis/utils/format_check.py +++ b/nvflare/apis/utils/format_check.py @@ -35,7 +35,9 @@ def name_check(name: str, entity_type: str): if re.match(regex_pattern, name): return False, "name={} passed on regex_pattern={} check".format(name, regex_pattern) else: - return True, "name={} is ill-formatted based on regex_pattern={}".format(name, regex_pattern) + return True, "name={} is ill-formatted for entity_type={} based on regex_pattern={}".format( + name, entity_type, regex_pattern + ) def validate_class_methods_args(cls): diff --git a/nvflare/lighter/constants.py b/nvflare/lighter/constants.py new file mode 100644 index 0000000000..97b18d0595 --- /dev/null +++ b/nvflare/lighter/constants.py @@ -0,0 +1,144 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class WorkDir: + WORKSPACE = "workspace" + WIP = "wip_dir" + STATE = "state_dir" + RESOURCES = "resources_dir" + CURRENT_PROD_DIR = "current_prod_dir" + + +class ParticipantType: + SERVER = "server" + CLIENT = "client" + ADMIN = "admin" + OVERSEER = "overseer" + + +class PropKey: + API_VERSION = "api_version" + NAME = "name" + DESCRIPTION = "description" + ROLE = "role" + HOST_NAMES = "host_names" + CONNECT_TO = "connect_to" + LISTENING_HOST = "listening_host" + DEFAULT_HOST = "default_host" + PROTOCOL = "protocol" + API_ROOT = "api_root" + PORT = "port" + OVERSEER_END_POINT = "overseer_end_point" + ADMIN_PORT = "admin_port" + FED_LEARN_PORT = "fed_learn_port" + + +class CtxKey(WorkDir, PropKey): + PROJECT = "__project__" + TEMPLATE = "__template__" + PROVISION_MODE = "__provision_model__" + LAST_PROD_STAGE = "last_prod_stage" + TEMPLATE_FILES = "template_files" + SERVER_NAME = "server_name" + ROOT_CERT = "root_cert" + ROOT_PRI_KEY = "root_pri_key" + + +class ProvisionMode: + POC = "poc" + NORMAL = "normal" + + +class AdminRole: + PROJECT_ADMIN = "project_admin" + ORG_ADMIN = "org_admin" + LEAD = "lead" + MEMBER = "member" + + +class OverseerRole: + SERVER = "server" + CLIENT = "client" + ADMIN = "admin" + + +class TemplateSectionKey: + START_SERVER_SH = "start_svr_sh" + START_CLIENT_SH = "start_cln_sh" + DOCKER_SERVER_SH = "docker_svr_sh" + DOCKER_CLIENT_SH = "docker_cln_sh" + DOCKER_ADMIN_SH = "docker_adm_sh" + GUNICORN_CONF_PY = "gunicorn_conf_py" + START_OVERSEER_SH = "start_ovsr_sh" + FED_SERVER = "fed_server" + FED_CLIENT = "fed_client" + SUB_START_SH = "sub_start_sh" + STOP_FL_SH = "stop_fl_sh" + LOG_CONFIG = "log_config" + LOCAL_SERVER_RESOURCES = "local_server_resources" + LOCAL_CLIENT_RESOURCES = "local_client_resources" + SAMPLE_PRIVACY = "sample_privacy" + DEFAULT_AUTHZ = "default_authz" + SERVER_README = "readme_fs" + CLIENT_README = "readme_fc" + ADMIN_README = "readme_am" + FL_ADMIN_SH = "fl_admin_sh" + FED_ADMIN = "fed_admin" + COMPOSE_YAML = "compose_yaml" + DOCKERFILE = "dockerfile" + HELM_CHART_CHART = "helm_chart_chart" + 
HELM_CHART_VALUES = "helm_chart_values" + HELM_CHART_SERVICE_OVERSEER = "helm_chart_service_overseer" + HELM_CHART_SERVICE_SERVER = "helm_chart_service_server" + HELM_CHART_DEPLOYMENT_OVERSEER = "helm_chart_deployment_overseer" + HELM_CHART_DEPLOYMENT_SERVER = "helm_chart_deployment_server" + + +class ProvFileName: + START_SH = "start.sh" + SUB_START_SH = "sub_start.sh" + PRIVILEGE_YML = "privilege.yml" + DOCKER_SH = "docker.sh" + GUNICORN_CONF_PY = "gunicorn.conf.py" + FED_SERVER_JSON = "fed_server.json" + FED_CLIENT_JSON = "fed_client.json" + STOP_FL_SH = "stop_fl.sh" + LOG_CONFIG_DEFAULT = "log.config.default" + RESOURCES_JSON_DEFAULT = "resources.json.default" + PRIVACY_JSON_SAMPLE = "privacy.json.sample" + AUTHORIZATION_JSON_DEFAULT = "authorization.json.default" + README_TXT = "readme.txt" + FED_ADMIN_JSON = "fed_admin.json" + FL_ADMIN_SH = "fl_admin.sh" + SIGNATURE_JSON = "signature.json" + COMPOSE_YAML = "compose.yaml" + ENV = ".env" + COMPOSE_BUILD_DIR = "nvflare_compose" + DOCKERFILE = "Dockerfile" + REQUIREMENTS_TXT = "requirements.txt" + SERVER_CONTEXT_TENSEAL = "server_context.tenseal" + CLIENT_CONTEXT_TENSEAL = "client_context.tenseal" + HELM_CHART_DIR = "nvflare_hc" + DEPLOYMENT_OVERSEER_YAML = "deployment_overseer.yaml" + SERVICE_OVERSEER_YAML = "service_overseer.yaml" + CHART_YAML = "Chart.yaml" + VALUES_YAML = "values.yaml" + HELM_CHART_TEMPLATES_DIR = "templates" + + +class CertFileBasename: + CLIENT = "client" + SERVER = "server" + OVERSEER = "overseer" diff --git a/nvflare/lighter/ctx.py b/nvflare/lighter/ctx.py new file mode 100644 index 0000000000..f55cdfd303 --- /dev/null +++ b/nvflare/lighter/ctx.py @@ -0,0 +1,110 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import json +import os + +import yaml + +from nvflare.lighter import utils + +from .constants import CtxKey, PropKey, ProvisionMode +from .entity import Entity, Project + + +class ProvisionContext(dict): + def __init__(self, workspace_root_dir: str, project: Project): + super().__init__() + self[CtxKey.WORKSPACE] = workspace_root_dir + + wip_dir = os.path.join(workspace_root_dir, "wip") + state_dir = os.path.join(workspace_root_dir, "state") + resources_dir = os.path.join(workspace_root_dir, "resources") + self.update({CtxKey.WIP: wip_dir, CtxKey.STATE: state_dir, CtxKey.RESOURCES: resources_dir}) + dirs = [workspace_root_dir, resources_dir, wip_dir, state_dir] + utils.make_dirs(dirs) + + # set commonly used data into ctx + self[CtxKey.PROJECT] = project + + server = project.get_server() + admin_port = server.get_prop(PropKey.ADMIN_PORT, 8003) + self[CtxKey.ADMIN_PORT] = admin_port + fed_learn_port = server.get_prop(PropKey.FED_LEARN_PORT, 8002) + self[CtxKey.FED_LEARN_PORT] = fed_learn_port + self[CtxKey.SERVER_NAME] = server.name + + def get_project(self): + return self.get(CtxKey.PROJECT) + + def set_template(self, template: dict): + self[CtxKey.TEMPLATE] = template + + def get_template(self): + return self.get(CtxKey.TEMPLATE) + + def get_template_section(self, section_key: str): + template = self.get_template() + if not template: + raise RuntimeError("template is not available") + + section = template.get(section_key) + if not section: + raise RuntimeError(f"missing section {section} in template") + + return section + + def set_provision_mode(self, mode: str): + valid_modes = [ProvisionMode.POC, ProvisionMode.NORMAL] + if mode not in valid_modes: + raise ValueError(f"invalid provision mode {mode}: must be one of {valid_modes}") + self[CtxKey.PROVISION_MODE] = mode + + def get_provision_mode(self): + return self.get(CtxKey.PROVISION_MODE) + + def get_wip_dir(self): + return self.get(CtxKey.WIP) + + def get_ws_dir(self, entity: Entity): + return os.path.join(self.get_wip_dir(), entity.name) + + def get_kit_dir(self, entity: Entity): + return os.path.join(self.get_ws_dir(entity), "startup") + + def get_transfer_dir(self, entity: Entity): + return os.path.join(self.get_ws_dir(entity), "transfer") + + def get_local_dir(self, entity: Entity): + return os.path.join(self.get_ws_dir(entity), "local") + + def get_state_dir(self): + return self.get(CtxKey.STATE) + + def get_resources_dir(self): + return self.get(CtxKey.RESOURCES) + + def get_workspace(self): + return self.get(CtxKey.WORKSPACE) + + def yaml_load_template_section(self, section_key: str): + return yaml.safe_load(self.get_template_section(section_key)) + + def json_load_template_section(self, section_key: str): + return json.loads(self.get_template_section(section_key)) + + def build_from_template(self, dest_dir: str, temp_section: str, file_name, replacement=None, mode="t", exe=False): + section = self.get_template_section(temp_section) + if replacement: + section = utils.sh_replace(section, replacement) + utils.write(os.path.join(dest_dir, file_name), section, mode, exe=exe) diff --git a/nvflare/lighter/dummy_project.yml b/nvflare/lighter/dummy_project.yml index b3d9d454ba..69ca96fdcc 100644 --- a/nvflare/lighter/dummy_project.yml +++ b/nvflare/lighter/dummy_project.yml @@ -35,7 +35,6 @@ builders: - master_template.yml - aws_template.yml - azure_template.yml - - path: nvflare.lighter.impl.template.TemplateBuilder - path: nvflare.lighter.impl.static_file.StaticFileBuilder args: # config_folder can be set to inform NVIDIA FLARE 
where to get configuration diff --git a/nvflare/lighter/entity.py b/nvflare/lighter/entity.py new file mode 100644 index 0000000000..6b40a60c56 --- /dev/null +++ b/nvflare/lighter/entity.py @@ -0,0 +1,261 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from nvflare.apis.utils.format_check import name_check + +from .constants import AdminRole, ParticipantType, PropKey + + +def _check_host_name(scope: str, prop_key: str, value): + err, reason = name_check(value, "host_name") + if err: + raise ValueError(f"bad value for {prop_key} '{value}' in {scope}: {reason}") + + +def _check_host_names(scope: str, prop_key: str, value): + if isinstance(value, str): + _check_host_name(scope, prop_key, value) + elif isinstance(value, list): + for v in value: + _check_host_name(scope, prop_key, v) + + +def _check_admin_role(scope: str, prop_key: str, value): + valid_roles = [AdminRole.PROJECT_ADMIN, AdminRole.ORG_ADMIN, AdminRole.LEAD, AdminRole.MEMBER] + if value not in valid_roles: + raise ValueError(f"bad value for {prop_key} '{value}' in {scope}: must be one of {valid_roles}") + + +# validator functions for common properties +# Validator function must follow this signature: +# func(scope: str, prop_key: str, value) +_PROP_VALIDATORS = { + PropKey.HOST_NAMES: _check_host_names, + PropKey.CONNECT_TO: _check_host_name, + PropKey.LISTENING_HOST: _check_host_name, + PropKey.DEFAULT_HOST: _check_host_name, + PropKey.ROLE: _check_admin_role, +} + + +class Entity: + def __init__(self, scope: str, name: str, props: dict, parent=None): + if not props: + props = {} + + for k, v in props.items(): + validator = _PROP_VALIDATORS.get(k) + if validator is not None: + validator(scope, k, v) + self.name = name + self.props = props + self.parent = parent + + def get_prop(self, key: str, default=None): + return self.props.get(key, default) + + def get_prop_fb(self, key: str, fb_key=None, default=None): + """Get property value with fallback. + If I have the property, then return it. + If not, I return the fallback property of my parent. If I don't have parent, return default. + + Args: + key: key of the property + fb_key: key of the fallback property. + default: value to return if no one has the property + + Returns: property value + + """ + value = self.get_prop(key) + if value: + return value + elif not self.parent: + return default + else: + # get the value from the parent + if not fb_key: + fb_key = key + return self.parent.get_prop(fb_key, default) + + +class Participant(Entity): + def __init__(self, type: str, name: str, org: str, props: dict = None, project: Entity = None): + """Class to represent a participant. + + Each participant communicates to other participant. Therefore, each participant has its + own name, type, organization it belongs to, rules and other information. 
+ + Args: + type (str): server, client, admin or other string that builders can handle + name (str): system-wide unique name + org (str): system-wide unique organization + props (dict): properties + project: the project that the participant belongs to + + Raises: + ValueError: if name or org is not compliant with characters or format specification. + """ + Entity.__init__(self, f"{type}::{name}", name, props, parent=project) + + err, reason = name_check(name, type) + if err: + raise ValueError(reason) + + err, reason = name_check(org, "org") + if err: + raise ValueError(reason) + + self.type = type + self.org = org + self.subject = name + + def get_default_host(self) -> str: + """Get the default host name for accessing this participant (server). + If the "default_host" attribute is explicitly specified, then it's the default host. + If the "default_host" attribute is not explicitly specified, then use the "name" attribute. + + Returns: a host name + + """ + h = self.get_prop(PropKey.DEFAULT_HOST) + if h: + return h + else: + return self.name + + +class Project(Entity): + def __init__( + self, + name: str, + description: str, + participants=None, + props: dict = None, + serialized_root_cert=None, + serialized_root_private_key=None, + ): + """A container class to hold information about this FL project. + + This class only holds information. It does not drive the workflow. + + Args: + name (str): the project name + description (str): brief description on this name + participants: if provided, list of participants of the project + props: properties of the project + serialized_root_cert: if provided, the root cert to be used for the project + serialized_root_private_key: if provided, the root private key for signing certs of sites and admins + + Raises: + ValueError: when participant criteria is violated + """ + Entity.__init__(self, "project", name, props) + + if serialized_root_cert: + if not serialized_root_private_key: + raise ValueError("missing serialized_root_private_key while serialized_root_cert is provided") + + self.description = description + self.serialized_root_cert = serialized_root_cert + self.serialized_root_private_key = serialized_root_private_key + self.server = None + self.overseer = None + self.clients = [] + self.admins = [] + self.all_names = {} + + if participants: + if not isinstance(participants, list): + raise ValueError(f"participants must be a list of Participant but got {type(participants)}") + + for p in participants: + if not isinstance(p, Participant): + raise ValueError(f"bad item in participants: must be Participant but got {type(p)}") + + if p.type == ParticipantType.SERVER: + self.set_server(p.name, p.org, p.props) + elif p.type == ParticipantType.ADMIN: + self.add_admin(p.name, p.org, p.props) + elif p.type == ParticipantType.CLIENT: + self.add_client(p.name, p.org, p.props) + elif p.type == ParticipantType.OVERSEER: + self.set_overseer(p.name, p.org, p.props) + else: + raise ValueError(f"invalid value for ParticipantType: {p.type}") + + def _check_unique_name(self, name: str): + if name in self.all_names: + raise ValueError(f"the project {self.name} already has a participant with the name '{name}'") + + def set_server(self, name: str, org: str, props: dict): + if self.server: + raise ValueError(f"project {self.name} already has a server defined") + self._check_unique_name(name) + self.server = Participant(ParticipantType.SERVER, name, org, props, self) + self.all_names[name] = True + + def get_server(self): + """Get the server definition. 
Only one server is supported! + + Returns: server participant + + """ + return self.server + + def set_overseer(self, name: str, org: str, props: dict): + if self.overseer: + raise ValueError(f"project {self.name} already has an overseer defined") + self._check_unique_name(name) + self.overseer = Participant(ParticipantType.OVERSEER, name, org, props, self) + self.all_names[name] = True + + def get_overseer(self): + """Get the overseer definition. Only one overseer is supported! + + Returns: overseer participant + + """ + return self.overseer + + def add_client(self, name: str, org: str, props: dict): + self._check_unique_name(name) + self.clients.append(Participant(ParticipantType.CLIENT, name, org, props, self)) + self.all_names[name] = True + + def get_clients(self): + return self.clients + + def add_admin(self, name: str, org: str, props: dict): + self._check_unique_name(name) + admin = Participant(ParticipantType.ADMIN, name, org, props, self) + role = admin.get_prop(PropKey.ROLE) + if not role: + raise ValueError(f"missing role in admin '{name}'") + self.admins.append(admin) + self.all_names[name] = True + + def get_admins(self): + return self.admins + + def get_all_participants(self): + result = [] + if self.server: + result.append(self.server) + + if self.overseer: + result.append(self.overseer) + + result.extend(self.clients) + result.extend(self.admins) + return result diff --git a/nvflare/lighter/impl/cert.py b/nvflare/lighter/impl/cert.py index fba9dd5297..95b363876a 100644 --- a/nvflare/lighter/impl/cert.py +++ b/nvflare/lighter/impl/cert.py @@ -22,10 +22,78 @@ from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.x509.oid import NameOID -from nvflare.lighter.spec import Builder, Participant +from nvflare.lighter.constants import CertFileBasename, CtxKey, ParticipantType, PropKey +from nvflare.lighter.ctx import ProvisionContext +from nvflare.lighter.entity import Participant, Project +from nvflare.lighter.spec import Builder from nvflare.lighter.utils import serialize_cert, serialize_pri_key +class _CertState: + + CERT_STATE_FILE = "cert.json" + + PROP_ROOT_CERT = CtxKey.ROOT_CERT + PROP_ROOT_PRI_KEY = CtxKey.ROOT_PRI_KEY + PROP_CERT = "cert" + PROP_PRI_KEY = "pri_key" + + def __init__(self, state_dir: str): + self.is_available = False + self.state_dir = state_dir + self.content = {} + cert_file = os.path.join(state_dir, self.CERT_STATE_FILE) + if os.path.exists(cert_file): + self.is_available = True + with open(cert_file, "rt") as f: + self.content.update(json.load(f)) + + def get_root_cert(self): + return self.content.get(self.PROP_ROOT_CERT) + + def set_root_cert(self, cert): + self.content[self.PROP_ROOT_CERT] = cert + + def get_root_pri_key(self): + return self.content.get(self.PROP_ROOT_PRI_KEY) + + def set_root_pri_key(self, key): + self.content[self.PROP_ROOT_PRI_KEY] = key + + def has_subject(self, subject: str): + return subject in self.content + + def _add_subject_prop(self, subject: str, key: str, value): + subject_data = self.content.get(subject) + if not subject_data: + subject_data = {} + self.content[subject] = subject_data + subject_data[key] = value + + def _get_subject_prop(self, subject: str, key: str): + subject_data = self.content.get(subject) + if not subject_data: + return None + return subject_data.get(key) + + def add_subject_cert(self, subject: str, cert): + self._add_subject_prop(subject, self.PROP_CERT, cert) + + def get_subject_cert(self, subject: str): + return self._get_subject_prop(subject, self.PROP_CERT) + + def 
add_subject_pri_key(self, subject: str, pri_key): + self._add_subject_prop(subject, self.PROP_PRI_KEY, pri_key) + + def get_subject_pri_key(self, subject: str): + return self._get_subject_prop(subject, self.PROP_PRI_KEY) + + def persist(self): + cert_file = os.path.join(self.state_dir, self.CERT_STATE_FILE) + with open(cert_file, "wt") as f: + json.dump(self.content, f) + + class CertBuilder(Builder): def __init__(self): """Build certificate chain for every participant. @@ -35,47 +103,62 @@ def __init__(self): information about previously generated certs, it loads them back and reuses them. """ self.root_cert = None - self.persistent_state = dict() + self.persistent_state = None + self.serialized_cert = None + self.pri_key = None + self.pub_key = None + self.subject = None + self.issuer = None - def initialize(self, ctx): - state_dir = self.get_state_dir(ctx) - cert_file = os.path.join(state_dir, "cert.json") - if os.path.exists(cert_file): - self.persistent_state = json.load(open(cert_file, "rt")) - self.serialized_cert = self.persistent_state["root_cert"].encode("ascii") + def initialize(self, project: Project, ctx: ProvisionContext): + state_dir = ctx.get_state_dir() + self.persistent_state = _CertState(state_dir) + state = self.persistent_state + + if state.is_available: + state_root_cert = state.get_root_cert() + self.serialized_cert = state_root_cert.encode("ascii") self.root_cert = x509.load_pem_x509_certificate(self.serialized_cert, default_backend()) + + state_pri_key = state.get_root_pri_key() self.pri_key = serialization.load_pem_private_key( - self.persistent_state["root_pri_key"].encode("ascii"), password=None, backend=default_backend() + state_pri_key.encode("ascii"), password=None, backend=default_backend() ) + self.pub_key = self.pri_key.public_key() self.subject = self.root_cert.subject self.issuer = self.subject.get_attributes_for_oid(NameOID.COMMON_NAME)[0].value def _build_root(self, subject, subject_org): - if not self.persistent_state: + assert isinstance(self.persistent_state, _CertState) + if not self.persistent_state.is_available: pri_key, pub_key = self._generate_keys() self.issuer = subject self.root_cert = self._generate_cert(subject, subject_org, self.issuer, pri_key, pub_key, ca=True) self.pri_key = pri_key self.pub_key = pub_key self.serialized_cert = serialize_cert(self.root_cert) - self.persistent_state["root_cert"] = self.serialized_cert.decode("ascii") - self.persistent_state["root_pri_key"] = serialize_pri_key(self.pri_key).decode("ascii") - - def _build_write_cert_pair(self, participant, base_name, ctx): - subject = self.get_subject(participant) - if self.persistent_state and subject in self.persistent_state: - cert = x509.load_pem_x509_certificate( - self.persistent_state[subject]["cert"].encode("ascii"), default_backend() - ) + + self.persistent_state.set_root_cert(self.serialized_cert.decode("ascii")) + self.persistent_state.set_root_pri_key(serialize_pri_key(self.pri_key).decode("ascii")) + + def _build_write_cert_pair(self, participant: Participant, base_name, ctx: ProvisionContext): + assert isinstance(self.persistent_state, _CertState) + subject = participant.subject + if self.persistent_state.has_subject(subject): + subject_cert = self.persistent_state.get_subject_cert(subject) + cert = x509.load_pem_x509_certificate(subject_cert.encode("ascii"), default_backend()) + + subject_pri_key = self.persistent_state.get_subject_pri_key(subject) pri_key = serialization.load_pem_private_key( - self.persistent_state[subject]["pri_key"].encode("ascii"), 
password=None, backend=default_backend() + subject_pri_key.encode("ascii"), password=None, backend=default_backend() ) - if participant.type == "admin": + + if participant.type == ParticipantType.ADMIN: cn_list = cert.subject.get_attributes_for_oid(NameOID.UNSTRUCTURED_NAME) for cn in cn_list: role = cn.value - new_role = participant.props.get("role") + new_role = participant.get_prop(PropKey.ROLE) if role != new_role: err_msg = ( f"{participant.name}'s previous role is {role} but is now {new_role}.\n" @@ -84,58 +167,63 @@ def _build_write_cert_pair(self, participant, base_name, ctx): raise RuntimeError(err_msg) else: pri_key, cert = self.get_pri_key_cert(participant) - self.persistent_state[subject] = dict( - cert=serialize_cert(cert).decode("ascii"), pri_key=serialize_pri_key(pri_key).decode("ascii") - ) - dest_dir = self.get_kit_dir(participant, ctx) + self.persistent_state.add_subject_cert(subject, serialize_cert(cert).decode("ascii")) + self.persistent_state.add_subject_pri_key(subject, serialize_pri_key(pri_key).decode("ascii")) + + dest_dir = ctx.get_kit_dir(participant) with open(os.path.join(dest_dir, f"{base_name}.crt"), "wb") as f: f.write(serialize_cert(cert)) with open(os.path.join(dest_dir, f"{base_name}.key"), "wb") as f: f.write(serialize_pri_key(pri_key)) - if base_name == "client" and (listening_host := participant.get_listening_host()): + + if base_name == CertFileBasename.CLIENT and (listening_host := participant.get_prop(PropKey.LISTENING_HOST)): + project = ctx.get_project() tmp_participant = Participant( - type="server", + type=ParticipantType.SERVER, name=participant.name, org=participant.org, - default_host=listening_host, + project=project, + props={PropKey.DEFAULT_HOST: listening_host}, ) tmp_pri_key, tmp_cert = self.get_pri_key_cert(tmp_participant) - with open(os.path.join(dest_dir, "server.crt"), "wb") as f: + bn = CertFileBasename.SERVER + with open(os.path.join(dest_dir, f"{bn}.crt"), "wb") as f: f.write(serialize_cert(tmp_cert)) - with open(os.path.join(dest_dir, "server.key"), "wb") as f: + with open(os.path.join(dest_dir, f"{bn}.key"), "wb") as f: f.write(serialize_pri_key(tmp_pri_key)) with open(os.path.join(dest_dir, "rootCA.pem"), "wb") as f: f.write(self.serialized_cert) - def build(self, project, ctx): + def build(self, project: Project, ctx: ProvisionContext): self._build_root(project.name, subject_org=None) - ctx["root_cert"] = self.root_cert - ctx["root_pri_key"] = self.pri_key - overseer = project.get_participants_by_type("overseer") + ctx[CtxKey.ROOT_CERT] = self.root_cert + ctx[CtxKey.ROOT_PRI_KEY] = self.pri_key + + overseer = project.get_overseer() if overseer: - self._build_write_cert_pair(overseer, "overseer", ctx) + self._build_write_cert_pair(overseer, CertFileBasename.OVERSEER, ctx) - servers = project.get_participants_by_type("server", first_only=False) - for server in servers: - self._build_write_cert_pair(server, "server", ctx) + server = project.get_server() + if server: + self._build_write_cert_pair(server, CertFileBasename.SERVER, ctx) - for client in project.get_participants_by_type("client", first_only=False): - self._build_write_cert_pair(client, "client", ctx) + for client in project.get_clients(): + self._build_write_cert_pair(client, CertFileBasename.CLIENT, ctx) - for admin in project.get_participants_by_type("admin", first_only=False): - self._build_write_cert_pair(admin, "client", ctx) + for admin in project.get_admins(): + self._build_write_cert_pair(admin, CertFileBasename.CLIENT, ctx) - def get_pri_key_cert(self, 
participant): + def get_pri_key_cert(self, participant: Participant): pri_key, pub_key = self._generate_keys() - subject = self.get_subject(participant) + subject = participant.subject subject_org = participant.org - if participant.type == "admin": - role = participant.get_prop("role") + if participant.type == ParticipantType.ADMIN: + role = participant.get_prop(PropKey.ROLE) else: role = None - server = participant if participant.type == "server" else None + server = participant if participant.type == ParticipantType.SERVER else None cert = self._generate_cert( subject, subject_org, @@ -147,9 +235,6 @@ def get_pri_key_cert(self, participant): ) return pri_key, cert - def get_subject(self, participant): - return participant.subject - def _generate_keys(self): pri_key = rsa.generate_private_key(public_exponent=65537, key_size=2048, backend=default_backend()) pub_key = pri_key.public_key() @@ -201,7 +286,7 @@ def _generate_cert( # This is to generate a server cert. # Use SubjectAlternativeName for all host names default_host = server.get_default_host() - host_names = server.get_host_names() + host_names = server.get_prop(PropKey.HOST_NAMES) sans = [x509.DNSName(default_host)] if host_names: for h in host_names: @@ -220,7 +305,6 @@ def _x509_name(self, cn_name, org_name=None, role=None): name.append(x509.NameAttribute(NameOID.UNSTRUCTURED_NAME, role)) return x509.Name(name) - def finalize(self, ctx): - state_dir = self.get_state_dir(ctx) - cert_file = os.path.join(state_dir, "cert.json") - json.dump(self.persistent_state, open(cert_file, "wt")) + def finalize(self, project: Project, ctx: ProvisionContext): + assert isinstance(self.persistent_state, _CertState) + self.persistent_state.persist() diff --git a/nvflare/lighter/impl/docker.py b/nvflare/lighter/impl/docker.py index 3cbbf7829f..8773b64dc7 100644 --- a/nvflare/lighter/impl/docker.py +++ b/nvflare/lighter/impl/docker.py @@ -18,7 +18,8 @@ import yaml -from nvflare.lighter.spec import Builder +from nvflare.lighter.constants import CtxKey, ProvFileName, TemplateSectionKey +from nvflare.lighter.spec import Builder, Project, ProvisionContext class DockerBuilder(Builder): @@ -26,8 +27,10 @@ def __init__(self, base_image="python:3.8", requirements_file="requirements.txt" """Build docker compose file.""" self.base_image = base_image self.requirements_file = requirements_file + self.services = {} + self.compose_file_path = None - def _build_overseer(self, overseer, ctx): + def _build_overseer(self, overseer): protocol = overseer.props.get("protocol", "http") default_port = "443" if protocol == "https" else "80" port = overseer.props.get("port", default_port) @@ -38,9 +41,9 @@ def _build_overseer(self, overseer, ctx): info_dict["container_name"] = overseer.name self.services[overseer.name] = info_dict - def _build_server(self, server, ctx): - fed_learn_port = server.props.get("fed_learn_port", 8002) - admin_port = server.props.get("admin_port", 8003) + def _build_server(self, server, ctx: ProvisionContext): + fed_learn_port = ctx.get(CtxKey.FED_LEARN_PORT) + admin_port = ctx.get(CtxKey.ADMIN_PORT) info_dict = copy.deepcopy(self.services["__flserver__"]) info_dict["volumes"][0] = f"./{server.name}:" + "${WORKSPACE}" @@ -54,7 +57,7 @@ def _build_server(self, server, ctx): info_dict["container_name"] = server.name self.services[server.name] = info_dict - def _build_client(self, client, ctx): + def _build_client(self, client): info_dict = copy.deepcopy(self.services["__flclient__"]) info_dict["volumes"] = [f"./{client.name}:" + "${WORKSPACE}"] 
info_dict["build"] = "nvflare_compose" @@ -68,37 +71,38 @@ def _build_client(self, client, ctx): info_dict["container_name"] = client.name self.services[client.name] = info_dict - def build(self, project, ctx): - self.template = ctx.get("template") - self.compose = yaml.safe_load(self.template.get("compose_yaml")) - self.services = self.compose.get("services") - self.compose_file_path = os.path.join(self.get_wip_dir(ctx), "compose.yaml") - overseer = project.get_participants_by_type("overseer") + def build(self, project: Project, ctx: ProvisionContext): + compose = ctx.yaml_load_template_section(TemplateSectionKey.COMPOSE_YAML) + self.services = compose.get("services") + self.compose_file_path = os.path.join(ctx.get_wip_dir(), ProvFileName.COMPOSE_YAML) + overseer = project.get_overseer() if overseer: - self._build_overseer(overseer, ctx) - servers = project.get_participants_by_type("server", first_only=False) - for server in servers: + self._build_overseer(overseer) + server = project.get_server() + if server: self._build_server(server, ctx) - for client in project.get_participants_by_type("client", first_only=False): - self._build_client(client, ctx) + + for client in project.get_clients(): + self._build_client(client) + self.services.pop("__overseer__", None) self.services.pop("__flserver__", None) self.services.pop("__flclient__", None) - self.compose["services"] = self.services + compose["services"] = self.services with open(self.compose_file_path, "wt") as f: - yaml.dump(self.compose, f) - env_file_path = os.path.join(self.get_wip_dir(ctx), ".env") + yaml.dump(compose, f) + env_file_path = os.path.join(ctx.get_wip_dir(), ProvFileName.ENV) with open(env_file_path, "wt") as f: f.write("WORKSPACE=/workspace\n") f.write("PYTHON_EXECUTABLE=/usr/local/bin/python3\n") f.write("IMAGE_NAME=nvflare-service\n") - compose_build_dir = os.path.join(self.get_wip_dir(ctx), "nvflare_compose") + compose_build_dir = os.path.join(ctx.get_wip_dir(), ProvFileName.COMPOSE_BUILD_DIR) os.mkdir(compose_build_dir) - with open(os.path.join(compose_build_dir, "Dockerfile"), "wt") as f: + with open(os.path.join(compose_build_dir, ProvFileName.DOCKERFILE), "wt") as f: f.write(f"FROM {self.base_image}\n") - f.write(self.template.get("dockerfile")) + f.write(ctx.get_template_section(TemplateSectionKey.DOCKERFILE)) try: - shutil.copyfile(self.requirements_file, os.path.join(compose_build_dir, "requirements.txt")) + shutil.copyfile(self.requirements_file, os.path.join(compose_build_dir, ProvFileName.REQUIREMENTS_TXT)) except Exception: - f = open(os.path.join(compose_build_dir, "requirements.txt"), "wt") + f = open(os.path.join(compose_build_dir, ProvFileName.REQUIREMENTS_TXT), "wt") f.close() diff --git a/nvflare/lighter/impl/he.py b/nvflare/lighter/impl/he.py index 61f3760b01..65051b5dda 100644 --- a/nvflare/lighter/impl/he.py +++ b/nvflare/lighter/impl/he.py @@ -16,14 +16,15 @@ import tenseal as ts -from nvflare.lighter.spec import Builder +from nvflare.lighter.constants import ProvFileName +from nvflare.lighter.spec import Builder, Project, ProvisionContext class HEBuilder(Builder): def __init__( self, poly_modulus_degree=8192, - coeff_mod_bit_sizes=[60, 40, 40], + coeff_mod_bit_sizes=None, scale_bits=40, scheme="CKKS", ): @@ -38,6 +39,9 @@ def __init__( scale_bits: defaults to 40. scheme: defaults to "CKKS". 
""" + if not coeff_mod_bit_sizes: + coeff_mod_bit_sizes = [60, 40, 40] + self._context = None self.scheme_type_mapping = { "CKKS": ts.SCHEME_TYPE.CKKS, @@ -51,7 +55,7 @@ def __init__( self.scheme_type = self.scheme_type_mapping[_scheme] self.serialized = None - def initialize(self, ctx): + def initialize(self, project: Project, ctx: ProvisionContext): self._context = ts.context( self.scheme_type, poly_modulus_degree=self.poly_modulus_degree, @@ -63,15 +67,15 @@ def initialize(self, ctx): self._context.generate_relin_keys() self._context.global_scale = 2**self.scale_bits - def build(self, project, ctx): - servers = project.get_participants_by_type("server", first_only=False) - for server in servers: - dest_dir = self.get_kit_dir(server, ctx) - with open(os.path.join(dest_dir, "server_context.tenseal"), "wb") as f: + def build(self, project: Project, ctx: ProvisionContext): + server = project.get_server() + if server: + dest_dir = ctx.get_kit_dir(server) + with open(os.path.join(dest_dir, ProvFileName.SERVER_CONTEXT_TENSEAL), "wb") as f: f.write(self.get_serialized_context()) - for client in project.get_participants_by_type("client", first_only=False): - dest_dir = self.get_kit_dir(client, ctx) - with open(os.path.join(dest_dir, "client_context.tenseal"), "wb") as f: + for client in project.get_clients(): + dest_dir = ctx.get_kit_dir(client) + with open(os.path.join(dest_dir, ProvFileName.CLIENT_CONTEXT_TENSEAL), "wb") as f: f.write(self.get_serialized_context(is_client=True)) def get_serialized_context(self, is_client=False): diff --git a/nvflare/lighter/impl/helm_chart.py b/nvflare/lighter/impl/helm_chart.py index 6d2438521f..570e1113eb 100644 --- a/nvflare/lighter/impl/helm_chart.py +++ b/nvflare/lighter/impl/helm_chart.py @@ -16,22 +16,30 @@ import yaml -from nvflare.lighter.spec import Builder +from nvflare.lighter.constants import CtxKey, PropKey, ProvFileName, TemplateSectionKey +from nvflare.lighter.entity import Participant +from nvflare.lighter.spec import Builder, Project, ProvisionContext class HelmChartBuilder(Builder): def __init__(self, docker_image): """Build Helm Chart.""" self.docker_image = docker_image - - def initialize(self, ctx): - self.helm_chart_directory = os.path.join(self.get_wip_dir(ctx), "nvflare_hc") + self.helm_chart_directory = None + self.service_overseer = None + self.service_server = None + self.deployment_server = None + self.deployment_overseer = None + self.helm_chart_templates_directory = None + + def initialize(self, project: Project, ctx: ProvisionContext): + self.helm_chart_directory = os.path.join(ctx.get_wip_dir(), ProvFileName.HELM_CHART_DIR) os.mkdir(self.helm_chart_directory) - def _build_overseer(self, overseer, ctx): - protocol = overseer.props.get("protocol", "http") + def _build_overseer(self, overseer: Participant): + protocol = overseer.get_prop(PropKey.PROTOCOL, "http") default_port = "443" if protocol == "https" else "80" - port = overseer.props.get("port", default_port) + port = overseer.get_prop(PropKey.PORT, default_port) self.deployment_overseer["spec"]["template"]["spec"]["volumes"][0]["hostPath"][ "path" ] = "{{ .Values.workspace }}" @@ -40,18 +48,17 @@ def _build_overseer(self, overseer, ctx): self.deployment_overseer["spec"]["template"]["spec"]["containers"][0]["command"][ 0 ] = f"/workspace/{overseer.name}/startup/start.sh" - with open(os.path.join(self.helm_chart_templates_directory, "deployment_overseer.yaml"), "wt") as f: + with open(os.path.join(self.helm_chart_templates_directory, ProvFileName.DEPLOYMENT_OVERSEER_YAML), 
"wt") as f: yaml.dump(self.deployment_overseer, f) self.service_overseer["spec"]["ports"][0]["port"] = port self.service_overseer["spec"]["ports"][0]["targetPort"] = port - with open(os.path.join(self.helm_chart_templates_directory, "service_overseer.yaml"), "wt") as f: + with open(os.path.join(self.helm_chart_templates_directory, ProvFileName.SERVICE_OVERSEER_YAML), "wt") as f: yaml.dump(self.service_overseer, f) - def _build_server(self, server, ctx): - fed_learn_port = server.props.get("fed_learn_port", 30002) - admin_port = server.props.get("admin_port", 30003) - idx = ctx["index"] + def _build_server(self, server: Participant, ctx: ProvisionContext, idx: int): + fed_learn_port = ctx.get(CtxKey.FED_LEARN_PORT, 30002) + admin_port = ctx.get(CtxKey.ADMIN_PORT, 30003) self.deployment_server["metadata"]["name"] = f"{server.name}" self.deployment_server["metadata"]["labels"]["system"] = f"{server.name}" @@ -91,25 +98,26 @@ def _build_server(self, server, ctx): with open(os.path.join(self.helm_chart_templates_directory, f"service_server{idx}.yaml"), "wt") as f: yaml.dump(self.service_server, f) - def build(self, project, ctx): - self.template = ctx.get("template") - with open(os.path.join(self.helm_chart_directory, "Chart.yaml"), "wt") as f: - yaml.dump(yaml.safe_load(self.template.get("helm_chart_chart")), f) - - with open(os.path.join(self.helm_chart_directory, "values.yaml"), "wt") as f: - yaml.dump(yaml.safe_load(self.template.get("helm_chart_values")), f) + def build(self, project: Project, ctx: ProvisionContext): + with open(os.path.join(self.helm_chart_directory, ProvFileName.CHART_YAML), "wt") as f: + yaml.dump(ctx.yaml_load_template_section(TemplateSectionKey.HELM_CHART_CHART), f) - self.service_overseer = yaml.safe_load(self.template.get("helm_chart_service_overseer")) - self.service_server = yaml.safe_load(self.template.get("helm_chart_service_server")) + with open(os.path.join(self.helm_chart_directory, ProvFileName.VALUES_YAML), "wt") as f: + yaml.dump(ctx.yaml_load_template_section(TemplateSectionKey.HELM_CHART_VALUES), f) - self.deployment_overseer = yaml.safe_load(self.template.get("helm_chart_deployment_overseer")) - self.deployment_server = yaml.safe_load(self.template.get("helm_chart_deployment_server")) + self.service_overseer = ctx.yaml_load_template_section(TemplateSectionKey.HELM_CHART_SERVICE_OVERSEER) + self.service_server = ctx.yaml_load_template_section(TemplateSectionKey.HELM_CHART_SERVICE_SERVER) - self.helm_chart_templates_directory = os.path.join(self.helm_chart_directory, "templates") + self.deployment_overseer = ctx.yaml_load_template_section(TemplateSectionKey.HELM_CHART_DEPLOYMENT_OVERSEER) + self.deployment_server = ctx.yaml_load_template_section(TemplateSectionKey.HELM_CHART_DEPLOYMENT_SERVER) + self.helm_chart_templates_directory = os.path.join( + self.helm_chart_directory, ProvFileName.HELM_CHART_TEMPLATES_DIR + ) os.mkdir(self.helm_chart_templates_directory) - overseer = project.get_participants_by_type("overseer") - self._build_overseer(overseer, ctx) - servers = project.get_participants_by_type("server", first_only=False) - for index, server in enumerate(servers): - ctx["index"] = index - self._build_server(server, ctx) + overseer = project.get_overseer() + if overseer: + self._build_overseer(overseer) + + server = project.get_server() + if server: + self._build_server(server, ctx, 0) diff --git a/nvflare/lighter/impl/signature.py b/nvflare/lighter/impl/signature.py index f45e915dc4..48943c902c 100644 --- a/nvflare/lighter/impl/signature.py +++ 
b/nvflare/lighter/impl/signature.py @@ -15,7 +15,8 @@ import json import os -from nvflare.lighter.spec import Builder, Project +from nvflare.lighter.constants import CtxKey, ProvFileName +from nvflare.lighter.spec import Builder, Project, ProvisionContext from nvflare.lighter.utils import sign_all @@ -26,23 +27,31 @@ class SignatureBuilder(Builder): can be cryptographically verified to ensure any tampering is detected. This builder writes the signature.json file. """ - def _do_sign(self, root_pri_key, dest_dir): + @staticmethod + def _do_sign(root_pri_key, dest_dir): signatures = sign_all(dest_dir, root_pri_key) - json.dump(signatures, open(os.path.join(dest_dir, "signature.json"), "wt")) + with open(os.path.join(dest_dir, ProvFileName.SIGNATURE_JSON), "wt") as f: + json.dump(signatures, f) - def build(self, project: Project, ctx: dict): - root_pri_key = ctx.get("root_pri_key") + def build(self, project: Project, ctx: ProvisionContext): + root_pri_key = ctx.get(CtxKey.ROOT_PRI_KEY) + if not root_pri_key: + raise RuntimeError(f"missing {CtxKey.ROOT_PRI_KEY} in ProvisionContext") - overseer = project.get_participants_by_type("overseer") + overseer = project.get_overseer() if overseer: - dest_dir = self.get_kit_dir(overseer, ctx) + dest_dir = ctx.get_kit_dir(overseer) self._do_sign(root_pri_key, dest_dir) - servers = project.get_participants_by_type("server", first_only=False) - for server in servers: - dest_dir = self.get_kit_dir(server, ctx) + server = project.get_server() + if server: + dest_dir = ctx.get_kit_dir(server) self._do_sign(root_pri_key, dest_dir) - for p in project.get_participants_by_type("client", first_only=False): - dest_dir = self.get_kit_dir(p, ctx) + for p in project.get_clients(): + dest_dir = ctx.get_kit_dir(p) + self._do_sign(root_pri_key, dest_dir) + + for admin in project.get_admins(): + dest_dir = ctx.get_kit_dir(admin) self._do_sign(root_pri_key, dest_dir) diff --git a/nvflare/lighter/impl/static_file.py b/nvflare/lighter/impl/static_file.py index e6feb7ee77..742f2eaac2 100644 --- a/nvflare/lighter/impl/static_file.py +++ b/nvflare/lighter/impl/static_file.py @@ -19,20 +19,20 @@ import yaml from nvflare.lighter import utils -from nvflare.lighter.spec import Builder, Participant +from nvflare.lighter.constants import CtxKey, OverseerRole, PropKey, ProvFileName, ProvisionMode, TemplateSectionKey +from nvflare.lighter.entity import Participant +from nvflare.lighter.spec import Builder, Project, ProvisionContext class StaticFileBuilder(Builder): def __init__( self, - enable_byoc=False, config_folder="", scheme="grpc", app_validator="", download_job_url="", docker_image="", - snapshot_persistor="", - overseer_agent="", + overseer_agent: dict = None, components="", ): """Build all static files from template. @@ -46,97 +46,80 @@ def __init__( file and string replacement to generate those static files for each participant. 
Args: - enable_byoc: for each participant, true to enable loading of code in the custom folder of applications config_folder: usually "config" app_validator: optional path to an app validator to verify that uploaded app has the expected structure - docker_image: when docker_image is set to a docker image name, docker.sh will be generated on server/client/admin + docker_image: when docker_image is set to a docker image name, docker.sh will be generated on + server/client/admin """ - self.enable_byoc = enable_byoc self.config_folder = config_folder self.scheme = scheme self.docker_image = docker_image self.download_job_url = download_job_url self.app_validator = app_validator self.overseer_agent = overseer_agent - self.snapshot_persistor = snapshot_persistor self.components = components - def get_server_name(self, server): - return server.name - - def get_overseer_name(self, overseer): - return overseer.name - - def _build_overseer(self, overseer, ctx): - dest_dir = self.get_kit_dir(overseer, ctx) - utils._write( - os.path.join(dest_dir, "start.sh"), - self.template["start_svr_sh"], - "t", - exe=True, - ) - protocol = overseer.props.get("protocol", "http") - api_root = overseer.props.get("api_root", "/api/v1/") + def _build_overseer(self, overseer: Participant, ctx: ProvisionContext): + dest_dir = ctx.get_kit_dir(overseer) + protocol = overseer.get_prop(PropKey.PROTOCOL, "http") + api_root = overseer.get_prop(PropKey.API_ROOT, "/api/v1/") default_port = "443" if protocol == "https" else "80" - port = overseer.props.get("port", default_port) - replacement_dict = {"port": port, "hostname": self.get_overseer_name(overseer)} - admins = self.project.get_participants_by_type("admin", first_only=False) + port = overseer.get_prop(PropKey.PORT, default_port) + replacement_dict = {"port": port, "hostname": overseer.name} + + project = ctx.get_project() + admins = project.get_admins() privilege_dict = dict() for admin in admins: - role = admin.props.get("role") + role = admin.get_prop(PropKey.ROLE) if role in privilege_dict: privilege_dict[role].append(admin.subject) else: privilege_dict[role] = [admin.subject] - utils._write( - os.path.join(dest_dir, "privilege.yml"), + + utils.write( + os.path.join(dest_dir, ProvFileName.PRIVILEGE_YML), yaml.dump(privilege_dict, Dumper=yaml.Dumper), "t", exe=False, ) if self.docker_image: - utils._write( - os.path.join(dest_dir, "docker.sh"), - utils.sh_replace(self.template["docker_svr_sh"], replacement_dict), - "t", - exe=True, + ctx.build_from_template( + dest_dir, TemplateSectionKey.DOCKER_SERVER_SH, ProvFileName.DOCKER_SH, replacement_dict, exe=True ) - utils._write( - os.path.join(dest_dir, "gunicorn.conf.py"), - utils.sh_replace(self.template["gunicorn_conf_py"], replacement_dict), - "t", + + ctx.build_from_template( + dest_dir, + TemplateSectionKey.GUNICORN_CONF_PY, + ProvFileName.GUNICORN_CONF_PY, + replacement_dict, exe=False, ) - utils._write( - os.path.join(dest_dir, "start.sh"), - self.template["start_ovsr_sh"], - "t", - exe=True, - ) + + ctx.build_from_template(dest_dir, TemplateSectionKey.START_OVERSEER_SH, ProvFileName.START_SH, exe=True) + if port: - ctx["overseer_end_point"] = f"{protocol}://{self.get_overseer_name(overseer)}:{port}{api_root}" + ctx[PropKey.OVERSEER_END_POINT] = f"{protocol}://{overseer.name}:{port}{api_root}" else: - ctx["overseer_end_point"] = f"{protocol}://{self.get_overseer_name(overseer)}{api_root}" + ctx[PropKey.OVERSEER_END_POINT] = f"{protocol}://{overseer.name}{api_root}" - def _build_server(self, server, ctx): - config = 
json.loads(self.template["fed_server"]) - dest_dir = self.get_kit_dir(server, ctx) + def _build_server(self, server: Participant, ctx: ProvisionContext): + project = ctx.get_project() + config = ctx.json_load_template_section(TemplateSectionKey.FED_SERVER) + dest_dir = ctx.get_kit_dir(server) server_0 = config["servers"][0] - server_0["name"] = self.project_name - admin_port = server.get_prop("admin_port", 8003) - ctx["admin_port"] = admin_port - fed_learn_port = server.get_prop("fed_learn_port", 8002) - ctx["fed_learn_port"] = fed_learn_port - ctx["server_name"] = self.get_server_name(server) - server_0["service"]["target"] = f"{self.get_server_name(server)}:{fed_learn_port}" + server_0["name"] = project.name + admin_port = ctx.get(CtxKey.ADMIN_PORT) + fed_learn_port = ctx.get(CtxKey.FED_LEARN_PORT) + server_0["service"]["target"] = f"{server.name}:{fed_learn_port}" server_0["service"]["scheme"] = self.scheme - server_0["admin_host"] = self.get_server_name(server) + server_0["admin_host"] = server.name server_0["admin_port"] = admin_port - self._prepare_overseer_agent(server, config, "server", ctx) + self._prepare_overseer_agent(server, config, OverseerRole.SERVER, ctx) + utils.write(os.path.join(dest_dir, ProvFileName.FED_SERVER_JSON), json.dumps(config, indent=2), "t") - utils._write(os.path.join(dest_dir, "fed_server.json"), json.dumps(config, indent=2), "t") replacement_dict = { "admin_port": admin_port, "fed_learn_port": fed_learn_port, @@ -146,70 +129,58 @@ def _build_server(self, server, ctx): "type": "server", "cln_uid": "", } + if self.docker_image: - utils._write( - os.path.join(dest_dir, "docker.sh"), - utils.sh_replace(self.template["docker_svr_sh"], replacement_dict), - "t", + ctx.build_from_template( + dest_dir, + TemplateSectionKey.DOCKER_SERVER_SH, + ProvFileName.DOCKER_SH, + replacement=replacement_dict, exe=True, ) - utils._write( - os.path.join(dest_dir, "start.sh"), - self.template["start_svr_sh"], - "t", - exe=True, - ) - utils._write( - os.path.join(dest_dir, "sub_start.sh"), - utils.sh_replace(self.template["sub_start_sh"], replacement_dict), - "t", - exe=True, - ) - utils._write( - os.path.join(dest_dir, "stop_fl.sh"), - self.template["stop_fl_sh"], - "t", + + ctx.build_from_template(dest_dir, TemplateSectionKey.START_SERVER_SH, ProvFileName.START_SH, exe=True) + + ctx.build_from_template( + dest_dir, + TemplateSectionKey.SUB_START_SH, + ProvFileName.SUB_START_SH, + replacement=replacement_dict, exe=True, ) + + ctx.build_from_template(dest_dir, TemplateSectionKey.STOP_FL_SH, ProvFileName.STOP_FL_SH, exe=True) + # local folder creation - dest_dir = self.get_local_dir(server, ctx) - utils._write( - os.path.join(dest_dir, "log.config.default"), - self.template["log_config"], - "t", - ) - utils._write( - os.path.join(dest_dir, "resources.json.default"), - self.template["local_server_resources"], - "t", + dest_dir = ctx.get_local_dir(server) + + ctx.build_from_template(dest_dir, TemplateSectionKey.LOG_CONFIG, ProvFileName.LOG_CONFIG_DEFAULT, exe=False) + + ctx.build_from_template( + dest_dir, TemplateSectionKey.LOCAL_SERVER_RESOURCES, ProvFileName.RESOURCES_JSON_DEFAULT, exe=False ) - utils._write( - os.path.join(dest_dir, "privacy.json.sample"), - self.template["sample_privacy"], - "t", + + ctx.build_from_template( + dest_dir, TemplateSectionKey.SAMPLE_PRIVACY, ProvFileName.PRIVACY_JSON_SAMPLE, exe=False ) - utils._write( - os.path.join(dest_dir, "authorization.json.default"), - self.template["default_authz"], - "t", + + ctx.build_from_template( + dest_dir, 
TemplateSectionKey.DEFAULT_AUTHZ, ProvFileName.AUTHORIZATION_JSON_DEFAULT, exe=False ) # workspace folder file - utils._write( - os.path.join(self.get_ws_dir(server, ctx), "readme.txt"), - self.template["readme_fs"], - "t", - ) + dest_dir = ctx.get_ws_dir(server) + ctx.build_from_template(dest_dir, TemplateSectionKey.SERVER_README, ProvFileName.README_TXT, exe=False) def _build_client(self, client, ctx): - project = ctx["project"] + project = ctx.get_project() server = project.get_server() if not server: raise ValueError("missing server definition in project") - config = json.loads(self.template["fed_client"]) - dest_dir = self.get_kit_dir(client, ctx) + config = ctx.json_load_template_section(TemplateSectionKey.FED_CLIENT) + dest_dir = ctx.get_kit_dir(client) config["servers"][0]["service"]["scheme"] = self.scheme - config["servers"][0]["name"] = self.project_name + config["servers"][0]["name"] = project.name config["servers"][0]["identity"] = server.name # the official identity of the server replacement_dict = { "client_name": f"{client.subject}", @@ -220,93 +191,79 @@ def _build_client(self, client, ctx): "cln_uid": f"uid={client.subject}", } - self._prepare_overseer_agent(client, config, "client", ctx) + self._prepare_overseer_agent(client, config, OverseerRole.CLIENT, ctx) + + utils.write(os.path.join(dest_dir, ProvFileName.FED_CLIENT_JSON), json.dumps(config, indent=2), "t") - utils._write(os.path.join(dest_dir, "fed_client.json"), json.dumps(config, indent=2), "t") if self.docker_image: - utils._write( - os.path.join(dest_dir, "docker.sh"), - utils.sh_replace(self.template["docker_cln_sh"], replacement_dict), - "t", + ctx.build_from_template( + dest_dir, + TemplateSectionKey.DOCKER_CLIENT_SH, + ProvFileName.DOCKER_SH, + replacement_dict, exe=True, ) - utils._write( - os.path.join(dest_dir, "start.sh"), - self.template["start_cln_sh"], - "t", - exe=True, - ) - utils._write( - os.path.join(dest_dir, "sub_start.sh"), - utils.sh_replace(self.template["sub_start_sh"], replacement_dict), - "t", - exe=True, - ) - utils._write( - os.path.join(dest_dir, "stop_fl.sh"), - self.template["stop_fl_sh"], - "t", - exe=True, + + ctx.build_from_template(dest_dir, TemplateSectionKey.START_CLIENT_SH, ProvFileName.START_SH, exe=True) + + ctx.build_from_template( + dest_dir, TemplateSectionKey.SUB_START_SH, ProvFileName.SUB_START_SH, replacement_dict, exe=True ) + + ctx.build_from_template(dest_dir, TemplateSectionKey.STOP_FL_SH, ProvFileName.STOP_FL_SH, exe=True) + # local folder creation - dest_dir = self.get_local_dir(client, ctx) - utils._write( - os.path.join(dest_dir, "log.config.default"), - self.template["log_config"], - "t", - ) - utils._write( - os.path.join(dest_dir, "resources.json.default"), - self.template["local_client_resources"], - "t", - ) - utils._write( - os.path.join(dest_dir, "privacy.json.sample"), - self.template["sample_privacy"], - "t", + dest_dir = ctx.get_local_dir(client) + + ctx.build_from_template(dest_dir, TemplateSectionKey.LOG_CONFIG, ProvFileName.LOG_CONFIG_DEFAULT) + + ctx.build_from_template( + dest_dir, TemplateSectionKey.LOCAL_CLIENT_RESOURCES, ProvFileName.RESOURCES_JSON_DEFAULT ) - utils._write( - os.path.join(dest_dir, "authorization.json.default"), - self.template["default_authz"], - "t", + + ctx.build_from_template( + dest_dir, + TemplateSectionKey.SAMPLE_PRIVACY, + ProvFileName.PRIVACY_JSON_SAMPLE, ) + ctx.build_from_template(dest_dir, TemplateSectionKey.DEFAULT_AUTHZ, ProvFileName.AUTHORIZATION_JSON_DEFAULT) + # workspace folder file - utils._write( - 
os.path.join(self.get_ws_dir(client, ctx), "readme.txt"), - self.template["readme_fc"], - "t", - ) + dest_dir = ctx.get_ws_dir(client) + ctx.build_from_template(dest_dir, TemplateSectionKey.CLIENT_README, ProvFileName.README_TXT) - def _check_host_name(self, host_name: str, server: Participant) -> str: + @staticmethod + def _check_host_name(host_name: str, server: Participant) -> str: if host_name == server.get_default_host(): # Use the default host - OK return "" - available_host_names = server.get_host_names() + available_host_names = server.get_prop(PropKey.HOST_NAMES) if available_host_names and host_name in available_host_names: # use alternative host name - OK return "" return f"unknown host name '{host_name}'" - def _prepare_overseer_agent(self, participant, config, role, ctx): - project = ctx["project"] + def _prepare_overseer_agent(self, participant, config, role, ctx: ProvisionContext): + project = ctx.get_project() server = project.get_server() if not server: raise ValueError(f"Missing server definition in project {project.name}") - fl_port = server.get_prop("fed_learn_port", 8002) - admin_port = server.get_prop("admin_port", 8003) + # The properties CtxKey.FED_LEARN_PORT and CtxKey.ADMIN_PORT are guaranteed to exist + fl_port = ctx.get(CtxKey.FED_LEARN_PORT) + admin_port = ctx.get(CtxKey.ADMIN_PORT) if self.overseer_agent: overseer_agent = copy.deepcopy(self.overseer_agent) if overseer_agent.get("overseer_exists", True): - if role == "server": + if role == OverseerRole.SERVER: overseer_agent["args"] = { "role": role, "overseer_end_point": ctx.get("overseer_end_point", ""), - "project": self.project_name, + "project": project.name, "name": server.name, "fl_port": str(fl_port), "admin_port": str(admin_port), @@ -315,18 +272,18 @@ def _prepare_overseer_agent(self, participant, config, role, ctx): overseer_agent["args"] = { "role": role, "overseer_end_point": ctx.get("overseer_end_point", ""), - "project": self.project_name, + "project": project.name, "name": participant.subject, } else: # do not use overseer system # Dummy overseer agent is used here - if role == "server": + if role == OverseerRole.SERVER: # the server expects the "connect_to" to be the same as its name # otherwise the host name generated by the dummy agent won't be accepted! 
connect_to = server.name else: - connect_to = participant.get_connect_to() + connect_to = participant.get_prop(PropKey.CONNECT_TO) if connect_to: err = self._check_host_name(connect_to, server) if err: @@ -348,10 +305,10 @@ def _prepare_overseer_agent(self, participant, config, role, ctx): overseer_agent.pop("overseer_exists", None) config["overseer_agent"] = overseer_agent - def _build_admin(self, admin, ctx): - dest_dir = self.get_kit_dir(admin, ctx) - admin_port = ctx.get("admin_port") - server_name = ctx.get("server_name") + def _build_admin(self, admin: Participant, ctx: ProvisionContext): + dest_dir = ctx.get_kit_dir(admin) + admin_port = ctx.get(CtxKey.ADMIN_PORT) + server_name = ctx.get(CtxKey.SERVER_NAME) replacement_dict = { "cn": f"{server_name}", @@ -361,53 +318,48 @@ def _build_admin(self, admin, ctx): config = self.prepare_admin_config(admin, ctx) - utils._write(os.path.join(dest_dir, "fed_admin.json"), json.dumps(config, indent=2), "t") + utils.write(os.path.join(dest_dir, ProvFileName.FED_ADMIN_JSON), json.dumps(config, indent=2), "t") + if self.docker_image: - utils._write( - os.path.join(dest_dir, "docker.sh"), - utils.sh_replace(self.template["docker_adm_sh"], replacement_dict), - "t", - exe=True, + ctx.build_from_template( + dest_dir, TemplateSectionKey.DOCKER_ADMIN_SH, ProvFileName.DOCKER_SH, replacement_dict, exe=True ) - utils._write( - os.path.join(dest_dir, "fl_admin.sh"), - utils.sh_replace(self.template["fl_admin_sh"], replacement_dict), - "t", + + ctx.build_from_template( + dest_dir, + TemplateSectionKey.FL_ADMIN_SH, + ProvFileName.FL_ADMIN_SH, + replacement=replacement_dict, exe=True, ) - utils._write( - os.path.join(dest_dir, "readme.txt"), - self.template["readme_am"], - "t", - ) - def prepare_admin_config(self, admin, ctx): - config = json.loads(self.template["fed_admin"]) + ctx.build_from_template(dest_dir, TemplateSectionKey.ADMIN_README, ProvFileName.README_TXT) + + def prepare_admin_config(self, admin, ctx: ProvisionContext): + config = ctx.json_load_template_section(TemplateSectionKey.FED_ADMIN) agent_config = dict() - self._prepare_overseer_agent(admin, agent_config, "admin", ctx) + self._prepare_overseer_agent(admin, agent_config, OverseerRole.ADMIN, ctx) config["admin"].update(agent_config) - provision_mode = ctx.get("provision_mode") - if provision_mode == "poc": + provision_mode = ctx.get_provision_mode() + if provision_mode == ProvisionMode.POC: # in poc mode, we change to use "local_cert" as the cred_type so that the user won't be # prompted for username when starting the admin console config["admin"]["username"] = admin.name config["admin"]["cred_type"] = "local_cert" return config - def build(self, project, ctx): - self.template = ctx.get("template") - self.project_name = project.name - self.project = project - overseer = project.get_participants_by_type("overseer") + def build(self, project: Project, ctx: ProvisionContext): + overseer = project.get_overseer() if overseer: self._build_overseer(overseer, ctx) - servers = project.get_participants_by_type("server", first_only=False) - for server in servers: + + server = project.get_server() + if server: self._build_server(server, ctx) - for client in project.get_participants_by_type("client", first_only=False): + for client in project.get_clients(): self._build_client(client, ctx) - for admin in project.get_participants_by_type("admin", first_only=False): + for admin in project.get_admins(): self._build_admin(admin, ctx) diff --git a/nvflare/lighter/impl/template.py 
b/nvflare/lighter/impl/template.py index e3a19e8261..a7a52e0223 100644 --- a/nvflare/lighter/impl/template.py +++ b/nvflare/lighter/impl/template.py @@ -12,10 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os - -from nvflare.lighter.spec import Builder -from nvflare.lighter.utils import load_yaml +from nvflare.lighter.spec import Builder, Project, ProvisionContext class TemplateBuilder(Builder): @@ -24,10 +21,5 @@ class TemplateBuilder(Builder): Loads the content of the template_file into the key-value pair (template) in the build context. """ - def initialize(self, ctx): - resource_dir = self.get_resources_dir(ctx) - template_files = ctx.get("template_files") - template = dict() - for tplt_file in template_files: - template.update(load_yaml(os.path.join(resource_dir, tplt_file))) - ctx["template"] = template + def initialize(self, project: Project, ctx: ProvisionContext): + print("TemplateBuilder is obsolete!") diff --git a/nvflare/lighter/impl/workspace.py b/nvflare/lighter/impl/workspace.py index 6b203227df..811d548716 100644 --- a/nvflare/lighter/impl/workspace.py +++ b/nvflare/lighter/impl/workspace.py @@ -13,14 +13,16 @@ # limitations under the License. import os -import pathlib import shutil -from nvflare.lighter.spec import Builder, Project +import nvflare.lighter as prov +from nvflare.lighter.constants import CtxKey +from nvflare.lighter.spec import Builder, Project, ProvisionContext +from nvflare.lighter.utils import load_yaml, make_dirs class WorkspaceBuilder(Builder): - def __init__(self, template_file): + def __init__(self, template_file=None): """Manages the folder structure for provisioned projects. Sets the template_file containing scripts and configs to put into startup folders, creates directories for the @@ -43,50 +45,56 @@ def __init__(self, template_file): wip/ <--- this is only used during runtime, and will be removed when the provision command exits Args: - template_file: name(s) of template file(s) containing scripts and configs to put into startup folders + template_file: one or more template file names """ self.template_files = template_file - def _make_dir(self, dirs): - for dir in dirs: - if not os.path.exists(dir): - os.makedirs(dir) + def _build_template(self, ctx: ProvisionContext): + prov_folder = os.path.dirname(prov.__file__) + temp_folder = os.path.join(prov_folder, "templates") - def initialize(self, ctx): - workspace_dir = ctx["workspace"] + temp_files_to_load = self.template_files + if not temp_files_to_load: + # load everything + temp_files_to_load = [f for f in os.listdir(temp_folder) if os.path.isfile(f)] + elif isinstance(temp_files_to_load, str): + temp_files_to_load = [temp_files_to_load] + + template = dict() + for f in temp_files_to_load: + template.update(load_yaml(os.path.join(temp_folder, f))) + ctx.set_template(template) + + def initialize(self, project: Project, ctx: ProvisionContext): + workspace_dir = ctx.get_workspace() prod_dirs = [_ for _ in os.listdir(workspace_dir) if _.startswith("prod_")] last = -1 - for dir in prod_dirs: - stage = int(dir.split("_")[-1]) + for d in prod_dirs: + stage = int(d.split("_")[-1]) if stage > last: last = stage - ctx["last_prod_stage"] = last - if not isinstance(self.template_files, list): - self.template_files = [self.template_files] - tplt_file_list = [] - for tplt_file in self.template_files: - tplt_file_full_path = os.path.join(self.get_resources_dir(ctx), tplt_file) - file_path = pathlib.Path(__file__).parent.absolute() - 
shutil.copyfile(os.path.join(file_path, tplt_file), tplt_file_full_path) - tplt_file_list.append(tplt_file) - ctx["template_files"] = tplt_file_list - - def build(self, project: Project, ctx: dict): - dirs = [self.get_kit_dir(p, ctx) for p in project.participants] - self._make_dir(dirs) - dirs = [self.get_transfer_dir(p, ctx) for p in project.participants] - self._make_dir(dirs) - dirs = [self.get_local_dir(p, ctx) for p in project.participants] - self._make_dir(dirs) - - def finalize(self, ctx: dict): - if ctx["last_prod_stage"] >= 99: + ctx[CtxKey.LAST_PROD_STAGE] = last + self._build_template(ctx) + + def build(self, project: Project, ctx: ProvisionContext): + participants = project.get_all_participants() + dirs = [ctx.get_kit_dir(p) for p in participants] + make_dirs(dirs) + + dirs = [ctx.get_transfer_dir(p) for p in participants] + make_dirs(dirs) + + dirs = [ctx.get_local_dir(p) for p in participants] + make_dirs(dirs) + + def finalize(self, project: Project, ctx: ProvisionContext): + if ctx[CtxKey.LAST_PROD_STAGE] >= 99: print(f"Please clean up {ctx['workspace']} by removing prod_N folders") print("After clean-up, rerun the provision command.") else: - current_prod_stage = str(ctx["last_prod_stage"] + 1).zfill(2) - current_prod_dir = os.path.join(ctx["workspace"], f"prod_{current_prod_stage}") - shutil.move(self.get_wip_dir(ctx), current_prod_dir) - ctx.pop("wip_dir", None) + current_prod_stage = str(ctx[CtxKey.LAST_PROD_STAGE] + 1).zfill(2) + current_prod_dir = os.path.join(ctx.get_workspace(), f"prod_{current_prod_stage}") + shutil.move(ctx.get_wip_dir(), current_prod_dir) + ctx.pop(CtxKey.WIP, None) print(f"Generated results can be found under {current_prod_dir}. ") - ctx["current_prod_dir"] = current_prod_dir + ctx[CtxKey.CURRENT_PROD_DIR] = current_prod_dir diff --git a/nvflare/lighter/provision.py b/nvflare/lighter/provision.py index 8f6c3a69eb..bb9fb85131 100644 --- a/nvflare/lighter/provision.py +++ b/nvflare/lighter/provision.py @@ -22,7 +22,9 @@ from typing import Optional from nvflare.fuel.utils.class_utils import instantiate_class -from nvflare.lighter.spec import Participant, Project, Provisioner +from nvflare.lighter.constants import ParticipantType, PropKey +from nvflare.lighter.provisioner import Provisioner +from nvflare.lighter.spec import Project from nvflare.lighter.utils import load_yaml adding_client_error_msg = """ @@ -128,33 +130,50 @@ def prepare_builders(project_dict): return builders +def _must_get(participant_def: dict, key: str): + v = participant_def.get(key) + if not v: + raise ValueError(f"missing property '{key}' from participant definition") + return v + + def prepare_project(project_dict, add_user_file_path=None, add_client_file_path=None): - api_version = project_dict.get("api_version") + api_version = project_dict.get(PropKey.API_VERSION) if api_version not in [3]: raise ValueError(f"API version expected 3 but found {api_version}") - project_name = project_dict.get("name") - project_description = project_dict.get("description", "") - participants = list() - for p in project_dict.get("participants"): - participants.append(Participant(**p)) + project_name = project_dict.get(PropKey.NAME) + project_description = project_dict.get(PropKey.DESCRIPTION, "") + project = Project(name=project_name, description=project_description, props=project_dict) + participant_defs = project_dict.get("participants") + if add_user_file_path: - add_extra_users(add_user_file_path, participants) + add_extra_users(add_user_file_path, participant_defs) + if 
add_client_file_path: - add_extra_clients(add_client_file_path, participants) - project = Project(name=project_name, description=project_description, participants=participants) - n_servers = len(project.get_participants_by_type("server", first_only=False)) - if n_servers > 2: - raise ValueError( - f"Configuration error: Expect 2 or 1 server to be provisioned. project contains {n_servers} servers." - ) + add_extra_clients(add_client_file_path, participant_defs) + + for p in participant_defs: + participant_type = _must_get(p, "type") + name = _must_get(p, "name") + org = _must_get(p, "org") + if participant_type == ParticipantType.SERVER: + project.set_server(name, org, props=p) + elif participant_type == ParticipantType.CLIENT: + project.add_client(name, org, p) + elif participant_type == ParticipantType.ADMIN: + project.add_admin(name, org, p) + elif participant_type == ParticipantType.OVERSEER: + project.set_overseer(name, org, p) + else: + raise ValueError(f"invalid participant_type '{participant_type}'") return project -def add_extra_clients(add_client_file_path, participants): +def add_extra_clients(add_client_file_path, participant_defs): try: extra = load_yaml(add_client_file_path) extra.update({"type": "client"}) - participants.append(Participant(**extra)) + participant_defs.append(extra) except Exception as e: print("** Error during adding client **") print("The yaml file format is") @@ -162,11 +181,11 @@ def add_extra_clients(add_client_file_path, participants): exit(0) -def add_extra_users(add_user_file_path, participants): +def add_extra_users(add_user_file_path, participant_defs): try: extra = load_yaml(add_user_file_path) extra.update({"type": "admin"}) - participants.append(Participant(**extra)) + participant_defs.append(extra) except Exception: print("** Error during adding user **") print("The yaml file format is") diff --git a/nvflare/lighter/provisioner.py b/nvflare/lighter/provisioner.py new file mode 100644 index 0000000000..093cf17af2 --- /dev/null +++ b/nvflare/lighter/provisioner.py @@ -0,0 +1,93 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import shutil +import traceback +from typing import List + +from .constants import ProvisionMode, WorkDir +from .ctx import ProvisionContext +from .entity import Project +from .spec import Builder + + +class Provisioner: + def __init__(self, root_dir: str, builders: List[Builder]): + """Workflow class that drive the provision process. + + Provisioner's tasks: + + - Maintain the provision workspace folder structure; + - Invoke Builders to generate the content of each startup kit + + ROOT_WORKSPACE Folder Structure:: + + root_workspace_dir_name: this is the root of the workspace + project_dir_name: the root dir of the project, could be named after the project + resources: stores resource files (templates, configs, etc.) 
of the Provisioner and Builders + prod: stores the current set of startup kits (production) + participate_dir: stores content files generated by builders + wip: stores the set of startup kits to be created (WIP) + participate_dir: stores content files generated by builders + state: stores the persistent state of the Builders + + Args: + root_dir (str): the directory path to hold all generated or intermediate folders + builders (List[Builder]): all builders that will be called to build the content + """ + self.root_dir = root_dir + self.builders = builders + self.template = {} + + def add_template(self, template: dict): + if not isinstance(template, dict): + raise ValueError(f"template must be a dict but got {type(template)}") + self.template.update(template) + + def provision(self, project: Project, mode=None): + server = project.get_server() + if not server: + raise RuntimeError("missing server from the project") + + workspace_root_dir = os.path.join(self.root_dir, project.name) + ctx = ProvisionContext(workspace_root_dir, project) + if self.template: + ctx.set_template(self.template) + + if not mode: + mode = ProvisionMode.NORMAL + ctx.set_provision_mode(mode) + + try: + for b in self.builders: + b.initialize(project, ctx) + + # call builders! + for b in self.builders: + b.build(project, ctx) + + for b in self.builders[::-1]: + b.finalize(project, ctx) + + except Exception: + prod_dir = ctx.get(WorkDir.CURRENT_PROD_DIR) + if prod_dir: + shutil.rmtree(prod_dir) + print("Exception raised during provision. Incomplete prod_n folder removed.") + traceback.print_exc() + finally: + wip_dir = ctx.get(WorkDir.WIP) + if wip_dir: + shutil.rmtree(wip_dir) + return ctx diff --git a/nvflare/lighter/spec.py b/nvflare/lighter/spec.py index f2e5369866..aaad6a76a2 100644 --- a/nvflare/lighter/spec.py +++ b/nvflare/lighter/spec.py @@ -11,270 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -import os -import shutil -import traceback from abc import ABC -from typing import List - -from nvflare.apis.utils.format_check import name_check - - -class Participant(object): - def __init__(self, type: str, name: str, org: str, enable_byoc: bool = False, *args, **kwargs): - """Class to represent a participant. - - Each participant communicates to other participant. Therefore, each participant has its - own name, type, organization it belongs to, rules and other information. - - Args: - type (str): server, client, admin or other string that builders can handle - name (str): system-wide unique name - org (str): system-wide unique organization - enable_byoc (bool, optional): whether this participant allows byoc codes to be loaded. Defaults to False. - - Raises: - ValueError: if name or org is not compliant with characters or format specification. 
- """ - err, reason = name_check(name, type) - if err: - raise ValueError(reason) - err, reason = name_check(org, "org") - if err: - raise ValueError(reason) - self.type = type - self.name = name - self.org = org - self.subject = name - self.enable_byoc = enable_byoc - self.props = kwargs - - # check validity of properties - host_names = self.get_host_names() - if host_names: - for n in host_names: - err, reason = name_check(n, "host_name") - if err: - raise ValueError(f"bad host name '{n}' in {self.name}: {reason}") - - self._check_host_name("connect_to") - self._check_host_name("listening_host") - self._check_host_name("default_host") - - def _check_host_name(self, prop_name: str): - host_name = self.get_prop(prop_name) - if host_name: - err, reason = name_check(host_name, "host_name") - if err: - raise ValueError(f"bad {prop_name} '{host_name}' in {self.name}: {reason}") - - def get_host_names(self): - """Get the "host_names" attribute of this participant (server). - This attribute specifies additional host names for clients to access the FL Server. - Each name could be a domain name or IP address. - - Returns: a list of host names or None if not specified. - - """ - host_names = self.get_prop("host_names") - if not host_names: - return None - - if isinstance(host_names, str): - return [host_names] - - if not isinstance(host_names, list): - raise ValueError( - f"bad host_names in {self.subject}: must be a str or list of str, but got {type(host_names)}" - ) - - return host_names - - def get_connect_to(self): - """Get the "connect_to" attribute of this participant (client). - This value is for the client to connect to the FL server. - If not specified, then the client will connect to the FL server via its default host. - - Returns: the value of "connect_to" attribute or None if not specified. - - """ - return self.get_prop("connect_to") - - def get_listening_host(self): - """Get the "listening_host" attribute of this participant (client). - When specified, the client will be listening and other parties will use the specified value to connect to - this client. This client will receive a "server" cert in its startup kit. - - Returns: the value of "listening_host" attribute or None if not specified. - - """ - return self.get_prop("listening_host") - - def get_default_host(self) -> str: - """Get the default host name for accessing this participant (server). - If the "default_host" attribute is explicitly specified, then it's the default host. - If the "default_host" attribute is not explicitly specified, then use the "name" attribute. - - Returns: a host name - - """ - h = self.get_prop("default_host") - if h: - return h - else: - return self.name - - def get_prop(self, key: str, default=None): - return self.props.get(key, default) - - -class Project(object): - def __init__(self, name: str, description: str, participants: List[Participant]): - """A container class to hold information about this FL project. - This class only holds information. It does not drive the workflow. 
- - Args: - name (str): the project name - description (str): brief description on this name - participants (List[Participant]): All the participants that will join this project - - Raises: - ValueError: when duplicate name found in participants list - """ - self.name = name - all_names = list() - for p in participants: - if p.name in all_names: - raise ValueError(f"Unable to add a duplicate name {p.name} into this project.") - else: - all_names.append(p.name) - self.description = description - self.participants = participants - - def get_participants_by_type(self, type, first_only=True): - found = list() - for p in self.participants: - if p.type == type: - if first_only: - return p - else: - found.append(p) - return found - - def get_server(self): - """Get the server definition. Only one server is supported! - - Returns: server participant - - """ - return self.get_participants_by_type("server", first_only=True) +from .ctx import ProvisionContext +from .entity import Project class Builder(ABC): - def initialize(self, ctx: dict): + def initialize(self, project: Project, ctx: ProvisionContext): pass - def build(self, project: Project, ctx: dict): + def build(self, project: Project, ctx: ProvisionContext): pass - def finalize(self, ctx: dict): + def finalize(self, project: Project, ctx: ProvisionContext): pass - - def get_wip_dir(self, ctx: dict): - return ctx.get("wip_dir") - - def get_ws_dir(self, participate: Participant, ctx: dict): - return os.path.join(self.get_wip_dir(ctx), participate.name) - - def get_kit_dir(self, participant: Participant, ctx: dict): - return os.path.join(self.get_ws_dir(participant, ctx), "startup") - - def get_transfer_dir(self, participant: Participant, ctx: dict): - return os.path.join(self.get_ws_dir(participant, ctx), "transfer") - - def get_local_dir(self, participant: Participant, ctx: dict): - return os.path.join(self.get_ws_dir(participant, ctx), "local") - - def get_state_dir(self, ctx: dict): - return ctx.get("state_dir") - - def get_resources_dir(self, ctx: dict): - return ctx.get("resources_dir") - - -class Provisioner(object): - def __init__(self, root_dir: str, builders: List[Builder]): - """Workflow class that drive the provision process. - - Provisioner's tasks: - - - Maintain the provision workspace folder structure; - - Invoke Builders to generate the content of each startup kit - - ROOT_WORKSPACE Folder Structure:: - - root_workspace_dir_name: this is the root of the workspace - project_dir_name: the root dir of the project, could be named after the project - resources: stores resource files (templates, configs, etc.) 
of the Provisioner and Builders - prod: stores the current set of startup kits (production) - participate_dir: stores content files generated by builders - wip: stores the set of startup kits to be created (WIP) - participate_dir: stores content files generated by builders - state: stores the persistent state of the Builders - - Args: - root_dir (str): the directory path to hold all generated or intermediate folders - builders (List[Builder]): all builders that will be called to build the content - """ - self.root_dir = root_dir - self.builders = builders - self.ctx = None - - def _make_dir(self, dirs): - for dir in dirs: - if not os.path.exists(dir): - os.makedirs(dir) - - def _prepare_workspace(self, ctx): - workspace = ctx.get("workspace") - wip_dir = os.path.join(workspace, "wip") - state_dir = os.path.join(workspace, "state") - resources_dir = os.path.join(workspace, "resources") - ctx.update(dict(wip_dir=wip_dir, state_dir=state_dir, resources_dir=resources_dir)) - dirs = [workspace, resources_dir, wip_dir, state_dir] - self._make_dir(dirs) - - def provision(self, project: Project, mode=None): - # ctx = {"workspace": os.path.join(self.root_dir, project.name), "project": project} - workspace = os.path.join(self.root_dir, project.name) - ctx = {"workspace": workspace} # project is more static information while ctx is dynamic - self._prepare_workspace(ctx) - ctx["project"] = project - - if mode: - ctx["provision_mode"] = mode - - try: - for b in self.builders: - b.initialize(ctx) - - # call builders! - for b in self.builders: - b.build(project, ctx) - - for b in self.builders[::-1]: - b.finalize(ctx) - - except Exception as ex: - prod_dir = ctx.get("current_prod_dir") - if prod_dir: - shutil.rmtree(prod_dir) - print("Exception raised during provision. Incomplete prod_n folder removed.") - traceback.print_exc() - finally: - wip_dir = ctx.get("wip_dir") - if wip_dir: - shutil.rmtree(wip_dir) - return ctx diff --git a/nvflare/lighter/templates/aws_template.yml b/nvflare/lighter/templates/aws_template.yml new file mode 100644 index 0000000000..bd6a918906 --- /dev/null +++ b/nvflare/lighter/templates/aws_template.yml @@ -0,0 +1,437 @@ +aws_start_sh: | + + function find_ec2_gpu_instance_type() { + local gpucnt=0 + local gpumem=0 + if rfile=$(get_resources_file) + then + # Parse the number of GPUs and memory per GPU from the resource_manager component in local/resources.json + gpucnt=$(jq -r '.components[] | select(.id == "resource_manager") | .args.num_of_gpus' "${rfile}") + if [ ${gpucnt} -gt 0 ] + then + gpumem=$(jq -r '.components[] | select(.id == "resource_manager") | .args.mem_per_gpu_in_GiB' "${rfile}") + if [ ${gpumem} -gt 0 ] + then + gpumem=$(( ${gpumem}*1024 )) + printf " finding smallest instance type with ${gpucnt} GPUs and ${gpumem} MiB VRAM ... 
" + gpu_types=$(aws ec2 describe-instance-types --region ${REGION} --query 'InstanceTypes[?GpuInfo.Gpus[?Manufacturer==`NVIDIA`]].{InstanceType: InstanceType, GPU: GpuInfo.Gpus[*].{Name: Name, GpuMemoryMiB: MemoryInfo.SizeInMiB, GpuCount: Count}, Architecture: ProcessorInfo.SupportedArchitectures, VCpuCount: VCpuInfo.DefaultVCpus, MemoryMiB: MemoryInfo.SizeInMiB}' --output json) + filtered_gpu_types=$(echo ${gpu_types} | jq "[.[] | select(.GPU | any(.GpuCount == ${gpucnt} and .GpuMemoryMiB >= ${gpumem})) | select(.Architecture | index(\"${ARCH}\"))]") + smallest_gpu_type=$(echo ${filtered_gpu_types} | jq -r 'min_by(.VCpuCount).InstanceType') + if [ ${smallest_gpu_type} = null ] + then + echo "failed finding a GPU instance, EC2_TYPE unchanged." + else + echo "${smallest_gpu_type} found" + EC2_TYPE=${smallest_gpu_type} + fi + fi + fi + fi + } + + VM_NAME=nvflare_{~~type~~} + SECURITY_GROUP=nvflare_{~~type~~}_sg_$RANDOM + DEST_FOLDER=/var/tmp/cloud + KEY_PAIR=NVFlare{~~type~~}KeyPair + KEY_FILE=${KEY_PAIR}.pem + AMI_IMAGE_OWNER="099720109477" # Owner account id=Amazon + AMI_NAME="ubuntu-*-22.04-amd64-pro-server" + ARCH=x86_64 + AMI_IMAGE=ami-03c983f9003cb9cd1 # 22.04 20.04:ami-04bad3c587fe60d89 24.04:ami-0406d1fdd021121cd + EC2_TYPE=t2.small + EC2_TYPE_ARM=t4g.small + TMPDIR="${TMPDIR:-/tmp}" + LOGFILE=$(mktemp "${TMPDIR}/nvflare-aws-XXX") + + echo "This script requires aws (AWS CLI), sshpass, dig and jq. Now checking if they are installed." + + check_binary aws "Please see https://docs.aws.amazon.com/cli/latest/userguide/getting-started-install.html on how to install it on your system." + check_binary sshpass "Please install it first." + check_binary dig "Please install it first." + check_binary jq "Please install it first." + + REGION=$(aws configure get region 2>/dev/null) + : "${REGION:=us-west-2}" + : "${AWS_DEFAULT_REGION:=$REGION}" + : "${AWS_REGION:=$AWS_DEFAULT_REGION}" + REGION=${AWS_REGION} + + echo "Note: run this command first for a different AWS profile:" + echo " export AWS_PROFILE=your-profile-name." + + echo -e "\nChecking AWS identity ... \n" + aws_identity=$(aws sts get-caller-identity) + if [[ $? -ne 0 ]]; then + echo "" + exit 1 + fi + + if [ -z ${vpc_id+x} ] + then + using_default_vpc=true + else + using_default_vpc=false + fi + + if [ -z ${image_name+x} ] + then + container=false + else + container=true + fi + + if [ $container == "true" ] + then + AMI_IMAGE=ami-06b8d5099f3a8d79d + EC2_TYPE=t2.xlarge + fi + + if [ -z ${config_file+x} ] + then + useDefault=true + else + useDefault=false + . $config_file + report_status "$?" "Loading config file" + fi + + if [ $useDefault == true ] + then + while true + do + prompt REGION "* Cloud EC2 region, press ENTER to accept default" "${REGION}" + if [ ${container} = false ] + then + prompt AMI_NAME "* Cloud AMI image name (use amd64 or arm64), press ENTER to accept default" "${AMI_NAME}" + printf " retrieving AMI ID for ${AMI_NAME} ... 
" + IMAGES=$(aws ec2 describe-images --region ${REGION} --owners ${AMI_IMAGE_OWNER} --filters "Name=name,Values=*${AMI_NAME}*" --output json) + if [ "${#IMAGES}" -lt 30 ] + then + echo -e "\nNo images found, starting over\n" + continue + fi + AMI_IMAGE=$(echo $IMAGES | jq -r '.Images | sort_by(.CreationDate) | last(.[]).ImageId') + echo "${AMI_IMAGE} found" + if [[ "$AMI_NAME" == *"arm64"* ]] + then + ARCH="arm64" + EC2_TYPE=${EC2_TYPE_ARM} + fi + find_ec2_gpu_instance_type + fi + prompt AMI_IMAGE "* Cloud AMI image, press ENTER to accept default" + prompt EC2_TYPE "* Cloud EC2 type, press ENTER to accept default" "${EC2_TYPE}" + prompt ans "region = ${REGION}, ami image = ${AMI_IMAGE}, EC2 type = ${EC2_TYPE}, OK? (Y/n)" + if [[ $ans = "" ]] || [[ $ans =~ ^(y|Y)$ ]] + then + break + fi + done + fi + + if [ $container == false ] + then + echo "If the {~~type~~} requires additional Python packages, please add them to: " + echo " ${DIR}/requirements.txt" + prompt ans "Press ENTER when it's done or no additional dependencies. " + fi + + # Check if default VPC exists + if [ $using_default_vpc == true ] + then + echo "Checking if default VPC exists" + found_default_vpc=$(aws ec2 describe-vpcs --region ${REGION} | jq '.Vpcs[] | select(.IsDefault == true)') + if [ -z "${found_default_vpc}" ] + then + echo "No default VPC found. Please create one before running this script with the following command." + echo "aws ec2 create-default-vpc --region ${REGION}" + echo "or specify your own vpc and subnet with --vpc-id and --subnet-id" + exit + else + echo "Default VPC found" + fi + else + echo "Please check the vpc-id $vpc_id and subnet-id $subnet_id are correct and they support EC2 with public IP and internet gateway with proper routing." + echo "This script will use the above info to create EC2 instance." + fi + + cd $DIR/.. + # Generate key pair + + echo "Generating key pair for VM" + + aws ec2 delete-key-pair --region ${REGION} --key-name $KEY_PAIR > /dev/null 2>&1 + rm -rf $KEY_FILE + aws ec2 create-key-pair --region ${REGION} --key-name $KEY_PAIR --query 'KeyMaterial' --output text > $KEY_FILE + report_status "$?" "creating key pair" + chmod 400 $KEY_FILE + + # Generate Security Group + # Try not reusing existing security group because we have to modify it for our own need. + if [ $using_default_vpc == true ] + then + sg_id=$(aws ec2 create-security-group --region ${REGION} --group-name $SECURITY_GROUP --description "NVFlare security group" | jq -r .GroupId) + else + sg_id=$(aws ec2 create-security-group --region ${REGION} --group-name $SECURITY_GROUP --description "NVFlare security group" --vpc-id $vpc_id | jq -r .GroupId) + fi + report_status "$?" "creating security group" + my_public_ip=$(dig +short myip.opendns.com @resolver1.opendns.com) + if [ "$?" -eq 0 ] && [[ "$my_public_ip" =~ ^(([1-9]?[0-9]|1[0-9][0-9]|2([0-4][0-9]|5[0-5]))\.){3}([1-9]?[0-9]|1[0-9][0-9]|2([0-4][0-9]|5[0-5]))$ ]] + then + aws ec2 authorize-security-group-ingress --region ${REGION} --group-id $sg_id --protocol tcp --port 22 --cidr ${my_public_ip}/32 > ${LOGFILE}.sec_grp.log + else + echo "getting my public IP failed, please manually configure the inbound rule to limit SSH access" + aws ec2 authorize-security-group-ingress --region ${REGION} --group-id $sg_id --protocol tcp --port 22 --cidr 0.0.0.0/0 > ${LOGFILE}.sec_grp.log + fi + {~~inbound_rule~~} + report_status "$?" "creating security group rules" + + # Start provisioning + + echo "Creating VM at region ${REGION}, may take a few minutes." 
+ + ami_info=$(aws ec2 describe-images --region ${REGION} --image-ids $AMI_IMAGE --output json) + amidevice=$(echo $ami_info | jq -r '.Images[0].BlockDeviceMappings[0].DeviceName') + block_device_mappings=$(echo $ami_info | jq -r '.Images[0].BlockDeviceMappings') + original_size=$(echo $block_device_mappings | jq -r '.[0].Ebs.VolumeSize') + original_volume_type=$(echo $block_device_mappings | jq -r '.[0].Ebs.VolumeType') + new_size=$((original_size + 8)) # increase disk size by 8GB for nvflare, torch, etc + bdmap='[{"DeviceName":"'${amidevice}'","Ebs":{"VolumeSize":'${new_size}',"VolumeType":"'${original_volume_type}'","DeleteOnTermination":true}}]' + + if [ $using_default_vpc == true ] + then + aws ec2 run-instances --region ${REGION} --image-id $AMI_IMAGE --count 1 --instance-type $EC2_TYPE --key-name $KEY_PAIR --block-device-mappings $bdmap --security-group-ids $sg_id > vm_create.json + else + aws ec2 run-instances --region ${REGION} --image-id $AMI_IMAGE --count 1 --instance-type $EC2_TYPE --key-name $KEY_PAIR --block-device-mappings $bdmap --security-group-ids $sg_id --subnet-id $subnet_id > vm_create.json + fi + report_status "$?" "creating VM" + instance_id=$(jq -r .Instances[0].InstanceId vm_create.json) + + longkeyfile="$(pwd)/${KEY_PAIR}_${instance_id}.pem" + cp -f ${KEY_FILE} "${longkeyfile}" + chmod 400 "${longkeyfile}" + KEY_FILE="${longkeyfile}" + + aws ec2 wait instance-status-ok --region ${REGION} --instance-ids $instance_id + aws ec2 describe-instances --region ${REGION} --instance-ids $instance_id > vm_result.json + + IP_ADDRESS=$(jq -r .Reservations[0].Instances[0].PublicIpAddress vm_result.json) + + echo "VM created with IP address: ${IP_ADDRESS}" + + echo "Copying files to $VM_NAME" + DEST_SITE=ubuntu@${IP_ADDRESS} + DEST=${DEST_SITE}:${DEST_FOLDER} + echo "Destination folder is ${DEST}" + scp -q -i "${KEY_FILE}" -r -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null $PWD $DEST + report_status "$?" "copying startup kits to VM" + + rm -f ${LOGFILE}.log + if [ $container == true ] + then + echo "Launching container with docker option ${DOCKER_OPTION}." + ssh -f -i "${KEY_FILE}" -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null ${DEST_SITE} \ + "docker run -d -v ${DEST_FOLDER}:${DEST_FOLDER} --network host ${DOCKER_OPTION} ${image_name} \ + /bin/bash -c \"python -u -m nvflare.private.fed.app.{~~type~~}.{~~type~~}_train -m ${DEST_FOLDER} \ + -s fed_{~~type~~}.json --set {~~cln_uid~~} secure_train=true config_folder=config org={~~ORG~~} \" " > /tmp/nvflare.log 2>&1 + report_status "$?" "launching container" + else + echo "Installing os packages as root in $VM_NAME, may take a few minutes ... " + ssh -f -i "${KEY_FILE}" -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null ${DEST_SITE} \ + ' NVIDIA_OS_PKG="nvidia-driver-550-server" && sudo apt update && \ + sudo DEBIAN_FRONTEND=noninteractive apt install -y python3-dev gcc && \ + . /etc/os-release && if [ "${VERSION_ID}" \< "22.04" ]; then NVIDIA_OS_PKG="nvidia-driver-535-server"; fi && \ + if lspci | grep -i nvidia; then sudo DEBIAN_FRONTEND=noninteractive apt install -y ${NVIDIA_OS_PKG}; fi && \ + if lspci | grep -i nvidia; then sudo modprobe nvidia; fi && sleep 10 && \ + exit' >> ${LOGFILE}.log 2>&1 + report_status "$?" "installing os packages" + sleep 10 + echo "Installing user space packages in $VM_NAME, may take a few minutes ... 
" + ssh -f -i "${KEY_FILE}" -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null ${DEST_SITE} \ + ' echo "export PATH=~/.local/bin:$PATH" >> ~/.bashrc && \ + export PATH=/home/ubuntu/.local/bin:$PATH && \ + pwd && wget -q https://bootstrap.pypa.io/get-pip.py && \ + timeout 300 sh -c """until [ -f /usr/bin/gcc ]; do sleep 3; done""" && \ + python3 get-pip.py --break-system-packages && python3 -m pip install --break-system-packages nvflare && \ + touch /var/tmp/cloud/startup/requirements.txt && \ + printf "installing from requirements.txt: " && \ + cat /var/tmp/cloud/startup/requirements.txt | tr "\n" " " && \ + python3 -m pip install --break-system-packages --no-cache-dir -r /var/tmp/cloud/startup/requirements.txt && \ + (crontab -l 2>/dev/null; echo "@reboot /var/tmp/cloud/startup/start.sh >> /var/tmp/nvflare-start.log 2>&1") | crontab && \ + NVIDIAMOD="nvidia.ko.zst" && . /etc/os-release && if [ "${VERSION_ID}" \< "24.04" -a "${VERSION_ID}" \> "16.04" ]; then NVIDIAMOD="nvidia.ko"; fi && \ + if lspci | grep -i nvidia; then timeout 900 sh -c """until [ -f /lib/modules/$(uname -r)/updates/dkms/${NVIDIAMOD} ]; do sleep 3; done"""; fi && \ + sleep 60 && nohup /var/tmp/cloud/startup/start.sh && sleep 20 && \ + exit' >> ${LOGFILE}.log 2>&1 + report_status "$?" "installing user space packages" + sleep 10 + fi + + echo "System was provisioned, packages may continue to install in the background." + echo "To terminate the EC2 instance, run the following command:" + echo " aws ec2 terminate-instances --region ${REGION} --instance-ids ${instance_id}" + echo "Other resources provisioned" + echo "security group: ${SECURITY_GROUP}" + echo "key pair: ${KEY_PAIR}" + echo "review install progress:" + echo " tail -f ${LOGFILE}.log" + echo "login to instance:" + echo " ssh -i \"${KEY_FILE}\" ubuntu@${IP_ADDRESS}" + +aws_start_dsb_sh: | + VM_NAME=nvflare_dashboard + AMI_IMAGE=ami-04c7330a29e61bbca # 22.04 from https://cloud-images.ubuntu.com/locator/ec2/ + EC2_TYPE=t2.small + SECURITY_GROUP=nvflare_dashboard_sg_$RANDOM + REGION=us-west-2 + ADMIN_USERNAME=ubuntu + DEST_FOLDER=/home/${ADMIN_USERNAME} + KEY_PAIR=NVFlareDashboardKeyPair + KEY_FILE=${KEY_PAIR}.pem + + echo "This script requires aws (AWS CLI), sshpass, dig and jq. Now checking if they are installed." + + check_binary aws "Please see https://docs.aws.amazon.com/cli/latest/userguide/getting-started-install.html on how to install it on your system." + check_binary sshpass "Please install it first." + check_binary dig "Please install it first." + check_binary jq "Please install it first." + + if [ -z ${vpc_id+x} ] + then + using_default_vpc=true + else + using_default_vpc=false + fi + + echo "One initial user will be created when starting dashboard." + echo "Please enter the email address for this user." + read email + credential="${email}:$RANDOM" + + # Generate key pair + + echo "Generating key pair for VM" + + aws ec2 delete-key-pair --key-name $KEY_PAIR > /dev/null 2>&1 + rm -rf $KEY_FILE + aws ec2 create-key-pair --key-name $KEY_PAIR --query 'KeyMaterial' --output text > $KEY_FILE + report_status "$?" "creating key pair" + chmod 400 $KEY_FILE + + # Check if default VPC exists + if [ $using_default_vpc == true ] + then + echo "Checking if default VPC exists" + found_default_vpc=$(aws ec2 describe-vpcs | jq '.Vpcs[] | select(.IsDefault == true)') + if [ -z "${found_default_vpc}" ] + then + echo "No default VPC found. Please create one before running this script with the following command." 
+ echo "aws ec2 create-default-vpc" + echo "or specify your own vpc and subnet with --vpc-id and --subnet-id" + exit + else + echo "Default VPC found" + fi + else + echo "Please check the vpc-id $vpc_id and subnet-id $subnet_id are correct and they support EC2 with public IP and internet gateway with proper routing." + echo "This script will use the above info to create EC2 instance." + fi + + # Generate Security Group + # Try not reusing existing security group because we have to modify it for our own need. + if [ $using_default_vpc == true ] + then + sg_id=$(aws ec2 create-security-group --group-name $SECURITY_GROUP --description "NVFlare security group" | jq -r .GroupId) + else + sg_id=$(aws ec2 create-security-group --group-name $SECURITY_GROUP --description "NVFlare security group" --vpc-id $vpc_id | jq -r .GroupId) + fi + report_status "$?" "creating security group" + echo "Security group id: ${sg_id}" + my_public_ip=$(dig +short myip.opendns.com @resolver1.opendns.com) + if [ "$?" -eq 0 ] && [[ "$my_public_ip" =~ ^(([1-9]?[0-9]|1[0-9][0-9]|2([0-4][0-9]|5[0-5]))\.){3}([1-9]?[0-9]|1[0-9][0-9]|2([0-4][0-9]|5[0-5]))$ ]] + then + aws ec2 authorize-security-group-ingress --group-id $sg_id --protocol tcp --port 22 --cidr ${my_public_ip}/32 > /tmp/sec_grp.log + else + echo "getting my public IP failed, please manually configure the inbound rule to limit SSH access" + aws ec2 authorize-security-group-ingress --group-id $sg_id --protocol tcp --port 22 --cidr 0.0.0.0/0 > /tmp/sec_grp.log + fi + aws ec2 authorize-security-group-ingress --group-id $sg_id --protocol tcp --port 443 --cidr 0.0.0.0/0 >> /tmp/sec_grp.log + report_status "$?" "creating security group rules" + + # Start provisioning + + echo "Creating VM at region ${REGION}, may take a few minutes." + if [ $using_default_vpc == true ] + then + aws ec2 run-instances --region ${REGION} --image-id $AMI_IMAGE --count 1 --instance-type $EC2_TYPE --key-name $KEY_PAIR --security-group-ids $sg_id > vm_create.json + else + aws ec2 run-instances --region ${REGION} --image-id $AMI_IMAGE --count 1 --instance-type $EC2_TYPE --key-name $KEY_PAIR --security-group-ids $sg_id --subnet-id $subnet_id > vm_create.json + fi + report_status "$?" "creating VM" + instance_id=$(jq -r .Instances[0].InstanceId vm_create.json) + + aws ec2 wait instance-status-ok --instance-ids $instance_id + aws ec2 describe-instances --instance-ids $instance_id > vm_result.json + + IP_ADDRESS=$(jq -r .Reservations[0].Instances[0].PublicIpAddress vm_result.json) + + echo "VM created with IP address: ${IP_ADDRESS}" + + echo "Installing docker engine in $VM_NAME, may take a few minutes." + DEST_SITE=${ADMIN_USERNAME}@${IP_ADDRESS} + scripts=$(cat << 'EOF' + sudo apt-get update && \ + sudo DEBIAN_FRONTEND=noninteractive apt-get install -y ca-certificates curl gnupg lsb-release && \ + sudo mkdir -p /etc/apt/keyrings && \ + curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo gpg --dearmor -o /etc/apt/keyrings/docker.gpg && \ + echo \ + "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.gpg] https://download.docker.com/linux/ubuntu \ + $(lsb_release -cs) stable" | sudo tee /etc/apt/sources.list.d/docker.list > /dev/null && \ + sudo apt-get update && \ + sudo DEBIAN_FRONTEND=noninteractive apt-get install -y docker-ce docker-ce-cli containerd.io + EOF + ) + ssh -t -i $KEY_FILE -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null ${DEST_SITE} "$scripts" > /tmp/docker_engine.log + report_status "$?" 
"installing docker engine" + ssh -t -i $KEY_FILE -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null ${DEST_SITE} "sudo usermod -aG docker $ADMIN_USERNAME && exit" >> /tmp/docker_engine.log + report_status "$?" "installing docker engine" + + echo "Installing nvflare in $VM_NAME, may take a few minutes." + ssh -i $KEY_FILE -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null ${DEST_SITE} \ + "export PATH=/home/ubuntu/.local/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/snap/bin && \ + wget -q https://bootstrap.pypa.io/get-pip.py && python3 get-pip.py && \ + python3 -m pip install {~~NVFLARE~~} && \ + mkdir -p ./cert && \ + exit" > /tmp/nvflare.json + report_status "$?" "installing nvflare" + + echo "Checking if certificate (web.crt) and private key (web.key) are available" + if [[ -f "web.crt" && -f "web.key" ]]; then + CERT_FOLDER=${DEST_SITE}:${DEST_FOLDER}/cert + echo "Cert folder is ${CERT_FOLDER}" + scp -i $KEY_FILE -r -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null web.{crt,key} $CERT_FOLDER + report_status "$?" "copying cert/key to VM ${CERT_FOLDER} folder" + secure=true + else + echo "No web.crt and web.key found" + secure=false + fi + + echo "Starting dashboard" + ssh -i $KEY_FILE -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null ${DEST_SITE} \ + "export PATH=/home/ubuntu/.local/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/snap/bin && \ + python3 -m nvflare.dashboard.cli --start -f ${DEST_FOLDER} --cred ${credential} {~~START_OPT~~}" > /tmp/dashboard.json + + echo "Dashboard url is running at IP address ${IP_ADDRESS}, listening to port 443." + if [ "$secure" == true ] + then + echo "URL is https://${IP_ADDRESS}" + else + echo "URL is http://${IP_ADDRESS}:443" + fi + echo "Note: you may need to configure DNS server with your DNS hostname and the above IP address." + echo "Project admin credential (username:password) is ${credential} ." + echo "To terminate the EC2 instance, run the following command." + echo "aws ec2 terminate-instances --instance-ids ${instance_id}" + echo "Other resources provisioned" + echo "security group: ${SECURITY_GROUP}" + echo "key pair: ${KEY_PAIR}" diff --git a/nvflare/lighter/templates/azure_template.yml b/nvflare/lighter/templates/azure_template.yml new file mode 100644 index 0000000000..8a5c100121 --- /dev/null +++ b/nvflare/lighter/templates/azure_template.yml @@ -0,0 +1,517 @@ +azure_start_svr_header_sh: | + RESOURCE_GROUP=nvflare_rg + VM_NAME=nvflare_server + VM_IMAGE=Canonical:0001-com-ubuntu-server-jammy:22_04-lts-gen2:latest + VM_SIZE=Standard_B2ms + NSG_NAME=nvflare_nsgs + ADMIN_USERNAME=nvflare + PASSWORD="NVFl@r3_P@88"$RANDOM"w0rd" + DEST_FOLDER=/var/tmp/cloud + NIC_NAME=${VM_NAME}VMNic + SERVER_NAME={~~server_name~~} + FL_PORT=8002 + ADMIN_PORT=8003 + + echo "This script requires az (Azure CLI), sshpass and jq. Now checking if they are installed." + + check_binary az "Please see https://learn.microsoft.com/en-us/cli/azure/install-azure-cli on how to install it on your system." + check_binary sshpass "Please install it first." + check_binary jq "Please install it first." + + self_dns=true + if [[ "$SERVER_NAME" = *".cloudapp.azure.com"* ]] + then + DNS_TAG=$(echo $SERVER_NAME | cut -d "." -f 1) + DERIVED_LOCATION=$(echo $SERVER_NAME | cut -d "." -f 2) + LOCATION=$DERIVED_LOCATION + self_dns=false + else + echo "Warning: ${SERVER_NAME} does not end with .cloudapp.azure.com." + echo "The cloud launch process will not create the domain name for you." 
+ echo "Please use your own DNS to set the information." + LOCATION=westus2 + fi + + if [ -z ${image_name+x} ] + then + container=false + else + container=true + fi + + if [ $container == true ] + then + VM_IMAGE=Canonical:0001-com-ubuntu-server-jammy:22_04-lts-gen2:latest + VM_SIZE=Standard_D8s_v3 + else + VM_IMAGE=Canonical:0001-com-ubuntu-server-jammy:22_04-lts-gen2:latest + VM_SIZE=Standard_B2ms + fi + + if [ -z ${config_file+x} ] + then + useDefault=true + else + useDefault=false + . $config_file + report_status "$?" "Loading config file" + if [ $self_dns == false ] && [ $DERIVED_LOCATION != $LOCATION ] + then + echo "Server name implies LOCATION=${DERIVED_LOCATION} but the config file specifies LOCATION=${LOCATION}. Unable to continue provisioning." + exit 1 + fi + fi + + if [ $useDefault == true ] + then + while true + do + prompt VM_IMAGE "Cloud VM image, press ENTER to accept default" "${VM_IMAGE}" + prompt VM_SIZE "Cloud VM size, press ENTER to accept default" "${VM_SIZE}" + if [ $self_dns == true ] + then + prompt LOCATION "Cloud location, press ENTER to accept default" "${LOCATION}" + prompt ans "VM image = ${VM_IMAGE}, VM size = ${VM_SIZE}, location = ${LOCATION}, OK? (Y/n)" + else + prompt ans "VM image = ${VM_IMAGE}, VM size = ${VM_SIZE}, OK? (Y/n)" + fi + if [[ $ans = "" ]] || [[ $ans =~ ^(y|Y)$ ]]; then break; fi + done + fi + + if [ $container == false ] + then + echo "If the client requires additional dependencies, please copy the requirements.txt to ${DIR}." + prompt ans "Press ENTER when it's done or no additional dependencies." + fi + + az login --use-device-code -o none + report_status "$?" "login" + + # Start provisioning + + if [ $(az group exists -n $RESOURCE_GROUP) == 'false' ] + then + echo "Creating Resource Group $RESOURCE_GROUP at Location $LOCATION" + az group create --output none --name $RESOURCE_GROUP --location $LOCATION + report_status "$?" "creating resource group" + elif [ $useDefault == true ] + then + report_status "1" "Only one NVFL server VM and its resource group is allowed. $RESOURCE_GROUP exists and thus creating duplicate resource group" + else + echo "Users require to reuse Resource Group $RESOURCE_GROUP. This script will modify the group and may not work always." + fi + + echo "Creating Virtual Machine, will take a few minutes" + if [ $self_dns == true ] + then + az vm create \ + --output json \ + --resource-group $RESOURCE_GROUP \ + --location $LOCATION \ + --name $VM_NAME \ + --image $VM_IMAGE \ + --size $VM_SIZE \ + --admin-username $ADMIN_USERNAME \ + --admin-password $PASSWORD \ + --authentication-type password \ + --public-ip-address nvflare_server_ip \ + --public-ip-address-allocation static \ + --public-ip-sku Standard > /tmp/vm.json + else + az vm create \ + --output json \ + --resource-group $RESOURCE_GROUP \ + --location $LOCATION \ + --name $VM_NAME \ + --image $VM_IMAGE \ + --size $VM_SIZE \ + --admin-username $ADMIN_USERNAME \ + --admin-password $PASSWORD \ + --authentication-type password \ + --public-ip-address nvflare_server_ip \ + --public-ip-address-allocation static \ + --public-ip-sku Standard \ + --public-ip-address-dns-name $DNS_TAG > /tmp/vm.json + fi + report_status "$?" "creating virtual machine" + + IP_ADDRESS=$(jq -r .publicIpAddress /tmp/vm.json) + echo "Setting up network related configuration" + az network nsg create \ + --output none \ + --resource-group $RESOURCE_GROUP \ + --location $LOCATION \ + --name $NSG_NAME + report_status "$?" 
"creating network security group" + + az network nsg rule create \ + --output none \ + --resource-group $RESOURCE_GROUP \ + --name SSH \ + --nsg-name $NSG_NAME \ + --priority 1000 \ + --protocol Tcp \ + --destination-port-ranges 22 + report_status "$?" "creating network security group rule for SSH" + + az network nsg rule create \ + --output none \ + --resource-group $RESOURCE_GROUP \ + --name FL_PORT \ + --nsg-name $NSG_NAME \ + --priority 1001 \ + --protocol Tcp \ + --destination-port-ranges $FL_PORT + report_status "$?" "creating network security group rule for FL port" + + az network nsg rule create \ + --output none \ + --resource-group $RESOURCE_GROUP \ + --name ADMIN_PORT \ + --nsg-name $NSG_NAME \ + --priority 1002 \ + --protocol Tcp \ + --destination-port-ranges $ADMIN_PORT + report_status "$?" "creating network security group rule for Admin port" + +azure_start_cln_header_sh: | + RESOURCE_GROUP=nvflare_client_rg_${RANDOM}_${RANDOM} + VM_NAME=nvflare_client + VM_IMAGE=Canonical:0001-com-ubuntu-server-jammy:22_04-lts-gen2:latest + VM_SIZE=Standard_B2ms + NSG_NAME=nvflare_nsgc + ADMIN_USERNAME=nvflare + PASSWORD="NVFl@r3_P@88"$RANDOM"w0rd" + DEST_FOLDER=/var/tmp/cloud + LOCATION=westus2 + NIC_NAME=${VM_NAME}VMNic + echo "This script requires az (Azure CLI), sshpass and jq. Now checking if they are installed." + + check_binary az "Please see https://learn.microsoft.com/en-us/cli/azure/install-azure-cli on how to install it on your system." + check_binary sshpass "Please install it first." + check_binary jq "Please install it first." + + + if [ -z ${image_name+x} ] + then + container=false + else + container=true + fi + + if [ $container == true ] + then + VM_IMAGE=Canonical:0001-com-ubuntu-server-jammy:22_04-lts-gen2:latest + VM_SIZE=Standard_D8s_v3 + else + VM_IMAGE=Canonical:0001-com-ubuntu-server-jammy:22_04-lts-gen2:latest + VM_SIZE=Standard_B2ms + fi + if [ -z ${config_file+x} ] + then + useDefault=true + else + useDefault=false + . $config_file + report_status "$?" "Loading config file" + fi + + if [ $useDefault == true ] + then + while true + do + prompt LOCATION "Cloud location, press ENTER to accept default" "${LOCATION}" + prompt VM_IMAGE "Cloud VM image, press ENTER to accept default" "${VM_IMAGE}" + prompt VM_SIZE "Cloud VM size, press ENTER to accept default" "${VM_SIZE}" + prompt ans "location = ${LOCATION}, VM image = ${VM_IMAGE}, VM size = ${VM_SIZE}, OK? (Y/n) " + if [[ $ans = "" ]] || [[ $ans =~ ^(y|Y)$ ]]; then break; fi + done + fi + + if [ $container == false ] + then + echo "If the client requires additional dependencies, please copy the requirements.txt to ${DIR}." + prompt ans "Press ENTER when it's done or no additional dependencies." + fi + + az login --use-device-code -o none + report_status "$?" "login" + + # Start provisioning + + if [ $(az group exists -n $RESOURCE_GROUP) == 'false' ] + then + echo "Creating Resource Group $RESOURCE_GROUP at Location $LOCATION" + az group create --output none --name $RESOURCE_GROUP --location $LOCATION + report_status "$?" "creating resource group" + else + echo "Resource Group $RESOURCE_GROUP exists, will reuse it." + fi + + echo "Creating Virtual Machine, will take a few minutes" + az vm create \ + --output json \ + --resource-group $RESOURCE_GROUP \ + --location $LOCATION \ + --name $VM_NAME \ + --image $VM_IMAGE \ + --size $VM_SIZE \ + --admin-username $ADMIN_USERNAME \ + --admin-password $PASSWORD \ + --authentication-type password \ + --public-ip-sku Standard > /tmp/vm.json + report_status "$?" 
"creating virtual machine" + + IP_ADDRESS=$(jq -r .publicIpAddress /tmp/vm.json) + + echo "Setting up network related configuration" + + az network nsg create \ + --output none \ + --resource-group $RESOURCE_GROUP \ + --location $LOCATION \ + --name $NSG_NAME + report_status "$?" "creating network security group" + + az network nsg rule create \ + --output none \ + --resource-group $RESOURCE_GROUP \ + --name SSH \ + --nsg-name $NSG_NAME \ + --priority 1000 \ + --protocol Tcp \ + --destination-port-ranges 22 + report_status "$?" "creating network security group rule for SSH" + +azure_start_common_sh: | + az network nic update \ + --output none \ + --resource-group $RESOURCE_GROUP \ + --name $NIC_NAME \ + --network-security-group $NSG_NAME + report_status "$?" "updating network interface card" + + echo "Copying files to $VM_NAME" + DEST=$ADMIN_USERNAME@${IP_ADDRESS}:$DEST_FOLDER + echo "Destination folder is ${DEST}" + cd $DIR/.. && sshpass -p $PASSWORD scp -r -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null $PWD $DEST + report_status "$?" "copying startup kits to VM" + + if [ $container == true ] + then + echo "Installing and lauching container in $VM_NAME, may take a few minutes." + scripts=$(cat << 'EOF' + sudo apt-get update && \ + sudo DEBIAN_FRONTEND=noninteractive apt-get install -y ca-certificates curl gnupg lsb-release && \ + sudo mkdir -p /etc/apt/keyrings && \ + curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo gpg --dearmor -o /etc/apt/keyrings/docker.gpg && \ + echo \ + "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.gpg] https://download.docker.com/linux/ubuntu \ + $(lsb_release -cs) stable" | sudo tee /etc/apt/sources.list.d/docker.list > /dev/null && \ + sudo apt-get update && \ + sudo DEBIAN_FRONTEND=noninteractive apt-get install -y docker-ce docker-ce-cli containerd.io + EOF + ) + az vm run-command invoke \ + --output json \ + --resource-group $RESOURCE_GROUP \ + --command-id RunShellScript \ + --name $VM_NAME \ + --scripts \ + "$scripts" > /tmp/docker_engine.json + report_status "$?" "installing docker engine" + az vm run-command invoke \ + --output json \ + --resource-group $RESOURCE_GROUP \ + --command-id RunShellScript \ + --name $VM_NAME \ + --scripts \ + "sudo usermod -aG docker $ADMIN_USERNAME" >> /tmp/docker_engine.json + report_status "$?" "Setting user group" + az vm run-command invoke \ + --output json \ + --resource-group $RESOURCE_GROUP \ + --command-id RunShellScript \ + --name $VM_NAME \ + --scripts \ + "docker run -d -v ${DEST_FOLDER}:${DEST_FOLDER} {~~docker_network~~} ${image_name} /bin/bash -c \"python -u -m nvflare.private.fed.app.{~~type~~}.{~~type~~}_train -m ${DEST_FOLDER} -s fed_{~~type~~}.json --set {~~cln_uid~~} secure_train=true config_folder=config org={~~ORG~~} \" " > /tmp/vm_create.json 2>&1 + report_status "$?" "launching container" + else + echo "Installing packages in $VM_NAME, may take a few minutes." + az vm run-command invoke \ + --output json \ + --resource-group $RESOURCE_GROUP \ + --command-id RunShellScript \ + --name $VM_NAME \ + --scripts \ + "echo ${DEST_FOLDER} && wget -q https://bootstrap.pypa.io/get-pip.py && python3 get-pip.py && python3 -m pip install --ignore-installed nvflare && touch ${DEST_FOLDER}/startup/requirements.txt && python3 -m pip install -r ${DEST_FOLDER}/startup/requirements.txt && ${DEST_FOLDER}/startup/start.sh && sleep 20 && cat ${DEST_FOLDER}/log.txt" > /tmp/vm_create.json + report_status "$?" 
"installing packages" + fi + echo "System was provisioned" + echo "To delete the resource group (also delete the VM), run the following command" + echo "az group delete -n ${RESOURCE_GROUP}" + echo "To login to the VM with SSH, use ${ADMIN_USERNAME} : ${PASSWORD}" > vm_credential.txt + +azure_start_dsb_sh: | + RESOURCE_GROUP=nvflare_dashboard_rg_${RANDOM}_${RANDOM} + VM_NAME=nvflare_dashboard + VM_IMAGE=Canonical:0001-com-ubuntu-server-jammy:22_04-lts-gen2:latest + VM_SIZE=Standard_B2ms + NSG_NAME=nvflare_nsgc + ADMIN_USERNAME=nvflare + PASSWORD="NVFl@r3_P@88"$RANDOM"w0rd" + DEST_FOLDER=/var/tmp/cloud + LOCATION=westus2 + NIC_NAME=${VM_NAME}VMNic + + echo "This script requires az (Azure CLI), sshpass and jq. Now checking if they are installed." + + check_binary az "Please see https://learn.microsoft.com/en-us/cli/azure/install-azure-cli on how to install it on your system." + check_binary sshpass "Please install it first." + check_binary jq "Please install it first." + + echo "One initial user will be created when starting dashboard." + echo "Please enter the email address for this user." + read email + credential="${email}:$RANDOM" + + az login --use-device-code -o none + report_status "$?" "login" + + # Start provisioning + if [ $(az group exists -n $RESOURCE_GROUP) == 'false' ] + then + echo "Creating Resource Group $RESOURCE_GROUP at Location $LOCATION" + az group create --output none --name $RESOURCE_GROUP --location $LOCATION + report_status "$?" "creating resource group" + else + echo "Resource Group $RESOURCE_GROUP exists, will reuse it." + fi + + echo "Creating Virtual Machine, will take a few minutes" + az vm create \ + --output json \ + --resource-group $RESOURCE_GROUP \ + --location $LOCATION \ + --name $VM_NAME \ + --image $VM_IMAGE \ + --size $VM_SIZE \ + --admin-username $ADMIN_USERNAME \ + --admin-password $PASSWORD \ + --authentication-type password \ + --public-ip-sku Standard > /tmp/vm.json + report_status "$?" "creating virtual machine" + + IP_ADDRESS=$(jq -r .publicIpAddress /tmp/vm.json) + report_status "$?" "extracting ip address" + + echo "Setting up network related configuration" + az network nsg create \ + --output none \ + --resource-group $RESOURCE_GROUP \ + --location $LOCATION \ + --name $NSG_NAME + report_status "$?" "creating network security group" + + az network nsg rule create \ + --output none \ + --resource-group $RESOURCE_GROUP \ + --name SSH \ + --nsg-name $NSG_NAME \ + --priority 1000 \ + --protocol Tcp \ + --destination-port-ranges 22 + report_status "$?" "creating network security group rule for SSH" + + az network nsg rule create \ + --output none \ + --resource-group $RESOURCE_GROUP \ + --name HTTPS \ + --nsg-name $NSG_NAME \ + --priority 1001 \ + --protocol Tcp \ + --destination-port-ranges 443 + report_status "$?" "creating network security group rule for HTTPS" + + az network nic update \ + --output none \ + --resource-group $RESOURCE_GROUP \ + --name $NIC_NAME \ + --network-security-group $NSG_NAME + report_status "$?" "updating network interface card" + + echo "Installing docker engine in $VM_NAME, may take a few minutes." 
+ scripts=$(cat << 'EOF' + sudo apt-get update && \ + sudo DEBIAN_FRONTEND=noninteractive apt-get install -y ca-certificates curl gnupg lsb-release && \ + sudo mkdir -p /etc/apt/keyrings && \ + curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo gpg --dearmor -o /etc/apt/keyrings/docker.gpg && \ + echo \ + "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.gpg] https://download.docker.com/linux/ubuntu \ + $(lsb_release -cs) stable" | sudo tee /etc/apt/sources.list.d/docker.list > /dev/null && \ + sudo apt-get update && \ + sudo DEBIAN_FRONTEND=noninteractive apt-get install -y docker-ce docker-ce-cli containerd.io + EOF + ) + az vm run-command invoke \ + --output json \ + --resource-group $RESOURCE_GROUP \ + --command-id RunShellScript \ + --name $VM_NAME \ + --scripts \ + "$scripts" > /tmp/docker_engine.json + report_status "$?" "installing docker engine" + az vm run-command invoke \ + --output json \ + --resource-group $RESOURCE_GROUP \ + --command-id RunShellScript \ + --name $VM_NAME \ + --scripts \ + "sudo usermod -aG docker $ADMIN_USERNAME" >> /tmp/docker_engine.json + report_status "$?" "installing docker engine" + + DEST_FOLDER=/home/${ADMIN_USERNAME} + echo "Installing nvflare in $VM_NAME, may take a few minutes." + az vm run-command invoke \ + --output json \ + --resource-group $RESOURCE_GROUP \ + --command-id RunShellScript \ + --name $VM_NAME \ + --scripts \ + "echo ${DEST_FOLDER} && wget -q https://bootstrap.pypa.io/get-pip.py && python3 get-pip.py && python3 -m pip install --ignore-installed {~~NVFLARE~~} && mkdir -p ${DEST_FOLDER}/cert && chown -R ${ADMIN_USERNAME} ${DEST_FOLDER}" > /tmp/nvflare.json + report_status "$?" "installing nvflare" + + echo "Checking if certificate (web.crt) and private key (web.key) are available" + if [[ -f "web.crt" && -f "web.key" ]]; then + DEST=$ADMIN_USERNAME@$IP_ADDRESS:${DEST_FOLDER}/cert + echo "Destination folder is ${DEST}" + sshpass -p $PASSWORD scp -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null web.{crt,key} $DEST + report_status "$?" "copying cert/key to VM ${DEST} folder" + secure=true + else + echo "No web.crt and web.key found" + secure=false + fi + + echo "Starting dashboard" + az vm run-command invoke \ + --output json \ + --resource-group $RESOURCE_GROUP \ + --command-id RunShellScript \ + --name $VM_NAME \ + --scripts \ + "cd ${DEST_FOLDER} && python3 -m nvflare.dashboard.cli --start -f ${DEST_FOLDER} --cred ${credential} {~~START_OPT~~}" > /tmp/dashboard.json + + # credential=$(jq -r .value[0].message /tmp/dashboard.json | grep "Project admin") + # echo "The VM was created with user: ${ADMIN_USERNAME} and password: ${PASSWORD}" + if [ "$secure" == true ] + then + echo "URL is https://${IP_ADDRESS}" + else + echo "URL is http://${IP_ADDRESS}:443" + fi + echo "Note: you may need to configure DNS server with your DNS hostname and the above IP address." + echo "Project admin credential (username:password) is ${credential} ." 
+ echo "To stop the dashboard, run az group delete -n ${RESOURCE_GROUP}" + echo "To login to the VM with SSH, use ${ADMIN_USERNAME} : ${PASSWORD}" > vm_credential.txt diff --git a/nvflare/lighter/templates/master_template.yml b/nvflare/lighter/templates/master_template.yml new file mode 100644 index 0000000000..7b1e51af88 --- /dev/null +++ b/nvflare/lighter/templates/master_template.yml @@ -0,0 +1,998 @@ +readme_am: | + ********************************* + Admin Client package + ********************************* + The package includes at least the following files: + readme.txt + rootCA.pem + client.crt + client.key + fl_admin.sh + + Please install the nvflare package by 'python3 -m pip nvflare.' This will install a set of Python codes + in your environment. After installation, you can run the fl_admin.sh file to start communicating to the admin server. + + The rootCA.pem file is pointed by "ca_cert" in fl_admin.sh. If you plan to move/copy it to a different place, + you will need to modify fl_admin.sh. The same applies to the other two files, client.crt and client.key. + + The email in your submission to participate this Federated Learning project is embedded in the CN field of client + certificate, which uniquely identifies the participant. As such, please safeguard its private key, client.key. + +readme_fc: | + ********************************* + Federated Learning Client package + ********************************* + The package includes at least the following files: + readme.txt + rootCA.pem + client.crt + client.key + fed_client.json + start.sh + sub_start.sh + stop_fl.sh + + Run start.sh to start the client. + + The rootCA.pem file is pointed by "ssl_root_cert" in fed_client.json. If you plan to move/copy it to a different place, + you will need to modify fed_client.json. The same applies to the other two files, client.crt and client.key. + + The client name in your submission to participate this Federated Learning project is embedded in the CN field of client + certificate, which uniquely identifies the participant. As such, please safeguard its private key, client.key. + +readme_fs: | + ********************************* + Federated Learning Server package + ********************************* + The package includes at least the following files: + readme.txt + rootCA.pem + server.crt + server.key + authorization.json + fed_server.json + start.sh + sub_start.sh + stop_fl.sh + signature.json + + Run start.sh to start the server. + + The rootCA.pem file is pointed by "ssl_root_cert" in fed_server.json. If you plan to move/copy it to a different place, + you will need to modify fed_server.json. The same applies to the other two files, server.crt and server.key. + + Please always safeguard the server.key. 
+ +gunicorn_conf_py: | + bind="0.0.0.0:{~~port~~}" + cert_reqs=2 + do_handshake_on_connect=True + timeout=30 + worker_class="nvflare.ha.overseer.worker.ClientAuthWorker" + workers=1 + wsgi_app="nvflare.ha.overseer.overseer:app" + +local_client_resources: | + { + "format_version": 2, + "client": { + "retry_timeout": 30, + "compression": "Gzip" + }, + "components": [ + { + "id": "resource_manager", + "path": "nvflare.app_common.resource_managers.gpu_resource_manager.GPUResourceManager", + "args": { + "num_of_gpus": 0, + "mem_per_gpu_in_GiB": 0 + } + }, + { + "id": "resource_consumer", + "path": "nvflare.app_common.resource_consumers.gpu_resource_consumer.GPUResourceConsumer", + "args": {} + }, + { + "id": "process_launcher", + "path": "nvflare.app_common.job_launcher.client_process_launcher.ClientProcessJobLauncher", + "args": {} + } + ] + } + +fed_client: | + { + "format_version": 2, + "servers": [ + { + "name": "spleen_segmentation", + "service": { + } + } + ], + "client": { + "ssl_private_key": "client.key", + "ssl_cert": "client.crt", + "ssl_root_cert": "rootCA.pem" + } + } + +sample_privacy: | + { + "scopes": [ + { + "name": "public", + "properties": { + "train_dataset": "/data/public/train", + "val_dataset": "/data/public/val" + }, + "task_result_filters": [ + { + "name": "AddNoiseToMinMax", + "args": { + "min_noise_level": 0.2, + "max_noise_level": 0.2 + } + }, + { + "name": "PercentilePrivacy", + "args": { + "percentile": 10, + "gamma": 0.02 + } + } + ], + "task_data_filters": [ + { + "name": "BadModelDetector" + } + ] + }, + { + "name": "private", + "properties": { + "train_dataset": "/data/private/train", + "val_dataset": "/data/private/val" + }, + "task_result_filters": [ + { + "name": "AddNoiseToMinMax", + "args": { + "min_noise_level": 0.1, + "max_noise_level": 0.1 + } + }, + { + "name": "SVTPrivacy", + "args": { + "fraction": 0.1, + "epsilon": 0.2 + } + } + ] + } + ], + "default_scope": "public" + } + +local_server_resources: | + { + "format_version": 2, + "servers": [ + { + "admin_storage": "transfer", + "max_num_clients": 100, + "heart_beat_timeout": 600, + "num_server_workers": 4, + "download_job_url": "http://download.server.com/", + "compression": "Gzip" + } + ], + "snapshot_persistor": { + "path": "nvflare.app_common.state_persistors.storage_state_persistor.StorageStatePersistor", + "args": { + "uri_root": "/", + "storage": { + "path": "nvflare.app_common.storages.filesystem_storage.FilesystemStorage", + "args": { + "root_dir": "/tmp/nvflare/snapshot-storage", + "uri_root": "/" + } + } + } + }, + "components": [ + { + "id": "job_scheduler", + "path": "nvflare.app_common.job_schedulers.job_scheduler.DefaultJobScheduler", + "args": { + "max_jobs": 4 + } + }, + { + "id": "job_manager", + "path": "nvflare.apis.impl.job_def_manager.SimpleJobDefManager", + "args": { + "uri_root": "/tmp/nvflare/jobs-storage", + "job_store_id": "job_store" + } + }, + { + "id": "job_store", + "path": "nvflare.app_common.storages.filesystem_storage.FilesystemStorage" + }, + { + "id": "process_launcher", + "path": "nvflare.app_common.job_launcher.server_process_launcher.ServerProcessJobLauncher", + "args": {} + } + ] + } + +fed_server: | + { + "format_version": 2, + "servers": [ + { + "name": "spleen_segmentation", + "service": { + "target": "localhost:8002" + }, + "admin_host": "localhost", + "admin_port": 5005, + "ssl_private_key": "server.key", + "ssl_cert": "server.crt", + "ssl_root_cert": "rootCA.pem" + } + ] + } + +fed_admin: | + { + "format_version": 1, + "admin": { + "with_file_transfer": 
true, + "upload_dir": "transfer", + "download_dir": "transfer", + "with_login": true, + "with_ssl": true, + "cred_type": "cert", + "client_key": "client.key", + "client_cert": "client.crt", + "ca_cert": "rootCA.pem", + "prompt": "> " + } + } + +default_authz: | + { + "format_version": "1.0", + "permissions": { + "project_admin": "any", + "org_admin": { + "submit_job": "none", + "clone_job": "none", + "manage_job": "o:submitter", + "download_job": "o:submitter", + "view": "any", + "operate": "o:site", + "shell_commands": "o:site", + "byoc": "none" + }, + "lead": { + "submit_job": "any", + "clone_job": "n:submitter", + "manage_job": "n:submitter", + "download_job": "n:submitter", + "view": "any", + "operate": "o:site", + "shell_commands": "o:site", + "byoc": "any" + }, + "member": { + "view": "any" + } + } + } + +fl_admin_sh: | + #!/usr/bin/env bash + DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" + mkdir -p $DIR/../transfer + python3 -m nvflare.fuel.hci.tools.admin -m $DIR/.. -s fed_admin.json + +log_config: | + [loggers] + keys=root + + [handlers] + keys=consoleHandler,errorFileHandler + + [formatters] + keys=fullFormatter + + [logger_root] + level=INFO + handlers=consoleHandler,errorFileHandler + + [handler_consoleHandler] + class=StreamHandler + level=DEBUG + formatter=fullFormatter + args=(sys.stdout,) + + [handler_errorFileHandler] + class=FileHandler + level=ERROR + formatter=fullFormatter + args=('error.log', 'a') + + [formatter_fullFormatter] + format=%(asctime)s - %(name)s - %(levelname)s - %(message)s + +start_ovsr_sh: | + #!/usr/bin/env bash + DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" + NVFL_OVERSEER_HEARTBEAT_TIMEOUT=10 AUTHZ_FILE=$DIR/privilege.yml gunicorn -c $DIR/gunicorn.conf.py --keyfile $DIR/overseer.key --certfile $DIR/overseer.crt --ca-certs $DIR/rootCA.pem + +start_cln_sh: | + #!/usr/bin/env bash + DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" + all_arguments="${@}" + doCloud=false + # parse arguments + while [[ $# -gt 0 ]] + do + key="$1" + case $key in + --cloud) + doCloud=true + csp=$2 + shift + ;; + esac + shift + done + + if [ $doCloud == true ] + then + case $csp in + azure) + $DIR/azure_start.sh ${all_arguments} + ;; + aws) + $DIR/aws_start.sh ${all_arguments} + ;; + *) + echo "Only on-prem or azure or aws is currently supported." + esac + else + $DIR/sub_start.sh & + fi + +start_svr_sh: | + #!/usr/bin/env bash + DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" + all_arguments="${@}" + doCloud=false + ha_mode={~~ha_mode~~} + # parse arguments + while [[ $# -gt 0 ]] + do + key="$1" + case $key in + --cloud) + if [ $ha_mode == false ] + then + doCloud=true + csp=$2 + shift + else + echo "Cloud launch does not support NVFlare HA mode." + exit 1 + fi + ;; + esac + shift + done + + if [ $doCloud == true ] + then + case $csp in + azure) + $DIR/azure_start.sh ${all_arguments} + ;; + aws) + $DIR/aws_start.sh ${all_arguments} + ;; + *) + echo "Only on-prem or azure or aws is currently supported." + esac + else + $DIR/sub_start.sh & + fi + +stop_fl_sh: | + #!/usr/bin/env bash + DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" + echo "Please use FL admin console to issue shutdown client command to properly stop this client." + echo "This stop_fl.sh script can only be used as the last resort to stop this client." + echo "It will not properly deregister the client to the server." 
+ echo "The client status on the server after this shell script will be incorrect." + read -n1 -p "Would you like to continue (y/N)? " answer + case $answer in + y|Y) + echo + echo "Shutdown request created. Wait for local FL process to shutdown." + touch $DIR/../shutdown.fl + ;; + n|N|*) + echo + echo "Not continue" + ;; + esac + +sub_start_sh: | + #!/usr/bin/env bash + DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" + echo "WORKSPACE set to $DIR/.." + mkdir -p $DIR/../transfer + export PYTHONPATH=/local/custom:$PYTHONPATH + echo "PYTHONPATH is $PYTHONPATH" + + SECONDS=0 + lst=-400 + restart_count=0 + start_fl() { + if [[ $(( $SECONDS - $lst )) -lt 300 ]]; then + ((restart_count++)) + else + restart_count=0 + fi + if [[ $(($SECONDS - $lst )) -lt 300 && $restart_count -ge 5 ]]; then + echo "System is in trouble and unable to start the task!!!!!" + rm -f $DIR/../pid.fl $DIR/../shutdown.fl $DIR/../restart.fl $DIR/../daemon_pid.fl + exit + fi + lst=$SECONDS + ((python3 -u -m nvflare.private.fed.app.{~~type~~}.{~~type~~}_train -m $DIR/.. -s fed_{~~type~~}.json --set secure_train=true {~~cln_uid~~} org={~~org_name~~} config_folder={~~config_folder~~} 2>&1 & echo $! >&3 ) 3>$DIR/../pid.fl ) + pid=`cat $DIR/../pid.fl` + echo "new pid ${pid}" + } + + stop_fl() { + if [[ ! -f "$DIR/../pid.fl" ]]; then + echo "No pid.fl. No need to kill process." + return + fi + pid=`cat $DIR/../pid.fl` + sleep 5 + kill -0 ${pid} 2> /dev/null 1>&2 + if [[ $? -ne 0 ]]; then + echo "Process already terminated" + return + fi + kill -9 $pid + rm -f $DIR/../pid.fl $DIR/../shutdown.fl $DIR/../restart.fl 2> /dev/null 1>&2 + } + + if [[ -f "$DIR/../daemon_pid.fl" ]]; then + dpid=`cat $DIR/../daemon_pid.fl` + kill -0 ${dpid} 2> /dev/null 1>&2 + if [[ $? -eq 0 ]]; then + echo "There seems to be one instance, pid=$dpid, running." + echo "If you are sure it's not the case, please kill process $dpid and then remove daemon_pid.fl in $DIR/.." + exit + fi + rm -f $DIR/../daemon_pid.fl + fi + + echo $BASHPID > $DIR/../daemon_pid.fl + + while true + do + sleep 5 + if [[ ! -f "$DIR/../pid.fl" ]]; then + echo "start fl because of no pid.fl" + start_fl + continue + fi + pid=`cat $DIR/../pid.fl` + kill -0 ${pid} 2> /dev/null 1>&2 + if [[ $? -ne 0 ]]; then + if [[ -f "$DIR/../shutdown.fl" ]]; then + echo "Gracefully shutdown." + break + fi + echo "start fl because process of ${pid} does not exist" + start_fl + continue + fi + if [[ -f "$DIR/../shutdown.fl" ]]; then + echo "About to shutdown." + stop_fl + break + fi + if [[ -f "$DIR/../restart.fl" ]]; then + echo "About to restart." + stop_fl + fi + done + + rm -f $DIR/../pid.fl $DIR/../shutdown.fl $DIR/../restart.fl $DIR/../daemon_pid.fl + +docker_cln_sh: | + #!/usr/bin/env bash + DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" + # docker run script for FL client + # local data directory + : ${MY_DATA_DIR:="/home/flclient/data"} + # The syntax above is to set MY_DATA_DIR to /home/flcient/data if this + # environment variable is not set previously. + # Therefore, users can set their own MY_DATA_DIR with + # export MY_DATA_DIR=$SOME_DIRECTORY + # before running docker.sh + + # for all gpus use line below + #GPU2USE='--gpus=all' + # for 2 gpus use line below + #GPU2USE='--gpus=2' + # for specific gpus as gpu#0 and gpu#2 use line below + #GPU2USE='--gpus="device=0,2"' + # to use host network, use line below + NETARG="--net=host" + # FL clients do not need to open ports, so the following line is not needed. 
+ #NETARG="-p 443:443 -p 8003:8003" + DOCKER_IMAGE={~~docker_image~~} + echo "Starting docker with $DOCKER_IMAGE" + mode="${1:--r}" + if [ $mode = "-d" ] + then + docker run -d --rm --name={~~client_name~~} $GPU2USE -u $(id -u):$(id -g) \ + -v /etc/passwd:/etc/passwd -v /etc/group:/etc/group -v $DIR/..:/workspace/ \ + -v $MY_DATA_DIR:/data/:ro -w /workspace/ --ipc=host $NETARG $DOCKER_IMAGE \ + /bin/bash -c "python -u -m nvflare.private.fed.app.client.client_train -m /workspace -s fed_client.json --set uid={~~client_name~~} secure_train=true config_folder=config org={~~org_name~~}" + else + docker run --rm -it --name={~~client_name~~} $GPU2USE -u $(id -u):$(id -g) \ + -v /etc/passwd:/etc/passwd -v /etc/group:/etc/group -v $DIR/..:/workspace/ \ + -v $MY_DATA_DIR:/data/:ro -w /workspace/ --ipc=host $NETARG $DOCKER_IMAGE /bin/bash + fi + +docker_svr_sh: | + #!/usr/bin/env bash + DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" + # docker run script for FL server + # to use host network, use line below + NETARG="--net=host" + # or to expose specific ports, use line below + #NETARG="-p {~~admin_port~~}:{~~admin_port~~} -p {~~fed_learn_port~~}:{~~fed_learn_port~~}" + DOCKER_IMAGE={~~docker_image~~} + echo "Starting docker with $DOCKER_IMAGE" + svr_name="${SVR_NAME:-flserver}" + mode="${1:-r}" + if [ $mode = "-d" ] + then + docker run -d --rm --name=$svr_name -v $DIR/..:/workspace/ -w /workspace \ + --ipc=host $NETARG $DOCKER_IMAGE /bin/bash -c \ + "python -u -m nvflare.private.fed.app.server.server_train -m /workspace -s fed_server.json --set secure_train=true config_folder=config org={~~org_name~~}" + else + docker run --rm -it --name=$svr_name -v $DIR/..:/workspace/ -w /workspace/ --ipc=host $NETARG $DOCKER_IMAGE /bin/bash + fi + +docker_adm_sh: | + #!/usr/bin/env bash + DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" + # docker run script for FL admin + # to use host network, use line below + #NETARG="--net=host" + # Admin clients do not need to open ports, so the following line is not needed. 
+ #NETARG="-p 8003:8003" + DOCKER_IMAGE={~~docker_image~~} + echo "Starting docker with $DOCKER_IMAGE" + docker run --rm -it --name=fladmin -v $DIR/..:/workspace/ -w /workspace/ $DOCKER_IMAGE /bin/bash + +compose_yaml: | + services: + __overseer__: + build: ./nvflare + image: ${IMAGE_NAME} + volumes: + - .:/workspace + command: ["${WORKSPACE}/startup/start.sh"] + ports: + - "8443:8443" + + __flserver__: + image: ${IMAGE_NAME} + ports: + - "8002:8002" + - "8003:8003" + volumes: + - .:/workspace + - nvflare_svc_persist:/tmp/nvflare/ + command: ["${PYTHON_EXECUTABLE}", + "-u", + "-m", + "nvflare.private.fed.app.server.server_train", + "-m", + "${WORKSPACE}", + "-s", + "fed_server.json", + "--set", + "secure_train=true", + "config_folder=config", + "org=__org_name__", + ] + + __flclient__: + image: ${IMAGE_NAME} + volumes: + - .:/workspace + command: ["${PYTHON_EXECUTABLE}", + "-u", + "-m", + "nvflare.private.fed.app.client.client_train", + "-m", + "${WORKSPACE}", + "-s", + "fed_client.json", + "--set", + "secure_train=true", + "uid=__flclient__", + "org=__org_name__", + "config_folder=config", + ] + + volumes: + nvflare_svc_persist: + +dockerfile: | + RUN pip install -U pip + RUN pip install nvflare + COPY requirements.txt requirements.txt + RUN pip install -r requirements.txt + +helm_chart_chart: | + apiVersion: v2 + name: nvflare + description: A Helm chart for NVFlare overseer and servers + type: application + version: 0.1.0 + appVersion: "2.2.0" + +helm_chart_service_overseer: | + apiVersion: v1 + kind: Service + metadata: + name: overseer + spec: + selector: + system: overseer + ports: + - protocol: TCP + port: 8443 + targetPort: overseer-port + +helm_chart_service_server: | + apiVersion: v1 + kind: Service + metadata: + name: server + labels: + system: server + spec: + selector: + system: server + ports: + - name: fl-port + protocol: TCP + port: 8002 + targetPort: fl-port + - name: admin-port + protocol: TCP + port: 8003 + targetPort: admin-port + +helm_chart_deployment_overseer: | + apiVersion: apps/v1 + kind: Deployment + metadata: + name: overseer + labels: + system: overseer + spec: + replicas: 1 + selector: + matchLabels: + system: overseer + template: + metadata: + labels: + system: overseer + spec: + volumes: + - name: workspace + hostPath: + path: + type: Directory + containers: + - name: overseer + image: nvflare-min:2.2.0 + imagePullPolicy: IfNotPresent + volumeMounts: + - name: workspace + mountPath: /workspace + command: ["/workspace/overseer/startup/start.sh"] + ports: + - name: overseer-port + containerPort: 8443 + protocol: TCP +helm_chart_deployment_server: | + apiVersion: apps/v1 + kind: Deployment + metadata: + name: server + labels: + system: server + spec: + replicas: 1 + selector: + matchLabels: + system: server + template: + metadata: + labels: + system: server + spec: + volumes: + - name: workspace + hostPath: + path: + type: Directory + - name: persist + hostPath: + path: /tmp/nvflare + type: Directory + containers: + - name: server1 + image: nvflare-min:2.2.0 + imagePullPolicy: IfNotPresent + volumeMounts: + - name: workspace + mountPath: /workspace + - name: persist + mountPath: /tmp/nvflare + command: ["/usr/local/bin/python3"] + args: + [ + "-u", + "-m", + "nvflare.private.fed.app.server.server_train", + "-m", + "/workspace/server", + "-s", + "fed_server.json", + "--set", + "secure_train=true", + "config_folder=config", + "org=__org_name__", + + ] + ports: + - containerPort: 8002 + protocol: TCP + - containerPort: 8003 + protocol: TCP +helm_chart_values: | + 
workspace: /home/nvflare + persist: /home/nvflare + + +cloud_script_header: | + #!/usr/bin/env bash + + DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" + function report_status() { + status="$1" + if [ "${status}" -ne 0 ] + then + echo "$2 failed" + exit "${status}" + fi + } + + function check_binary() { + echo -n "Checking if $1 exists. => " + if ! command -v $1 &> /dev/null + then + echo "not found. $2" + exit 1 + else + echo "found" + fi + } + + function prompt() { + # usage: prompt NEW_VAR "Prompt message" ["${PROMPT_VALUE}"] + local __resultvar=$1 + local __prompt=$2 + local __default=${3:-} + local __result + if [[ ${BASH_VERSINFO[0]} -ge 4 && -n "$__default" ]] + then + read -e -i "$__default" -p "$__prompt: " __result + else + __default=${3:-${!__resultvar:-}} + if [[ -n $__default ]] + then + printf "%s [%s]: " "$__prompt" "$__default" + else + printf "%s: " "$__prompt" + fi + IFS= read -r __result + if [[ -z "$__result" && -n "$__default" ]] + then + __result="$__default" + fi + fi + eval $__resultvar="'$__result'" + } + + function get_resources_file() { + local rfile="${DIR}/../local/resources.json" + if [ -f "${rfile}" ] + then + echo "${rfile}" + elif [ -f "${rfile}.default" ] + then + echo "${rfile}.default" + else + echo "" + exit 1 + fi + } + + # parse arguments + while [[ $# -gt 0 ]] + do + key="$1" + case $key in + --config) + config_file=$2 + shift + ;; + --image) + image_name=$2 + shift + ;; + --vpc-id) + vpc_id=$2 + shift + ;; + --subnet-id) + subnet_id=$2 + shift + ;; + esac + shift + done + +adm_notebook: | + { + "cells": [ + { + "cell_type": "markdown", + "id": "b758695b", + "metadata": {}, + "source": [ + "# System Info" + ] + }, + { + "cell_type": "markdown", + "id": "9f7cd9e6", + "metadata": {}, + "source": [ + "In this notebook, System Info is checked with the FLARE API." + ] + }, + { + "cell_type": "markdown", + "id": "ea50ba28", + "metadata": {}, + "source": [ + "#### 1. Connect to the FL System with the FLARE API\n", + "\n", + "Use `new_secure_session()` to initiate a session connecting to the FL Server with the FLARE API. The necessary arguments are the username of the admin user you are using and the corresponding startup kit location.\n", + "\n", + "In the code example below, we get the `admin_user_dir` by concatenating the workspace root with the default directories that are created if you provision a project with a given project name. You can change the values to what applies to your system if needed.\n", + "\n", + "Note that if debug mode is not enabled, there is no output after initiating a session successfully, so instead we print the output of `get_system_info()`. If you are unable to connect and initiate a session, make sure that your FL Server is running and that the configurations are correct with the right path to the admin startup kit directory." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0166942d", + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "# Run this pip install if NVFlare is not installed in your Jupyter Notebook\n", + "\n", + "# !python3 -m pip install -U nvflare" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c3dbde69", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from nvflare.fuel.flare_api.flare_api import new_secure_session\n", + "\n", + "username = \"{~~admin_name~~}\" # change this to your own username\n", + "\n", + "sess = new_secure_session(\n", + " username=username,\n", + " startup_kit_location=os.getcwd()\n", + ")\n", + "print(sess.get_system_info())" + ] + }, + { + "cell_type": "markdown", + "id": "31ccb6a6", + "metadata": {}, + "source": [ + "### 2. Shutting Down the FL System\n", + "\n", + "As of now, there is no specific FLARE API command for shutting down the FL system, but the FLARE API can use the `do_command()` function of the underlying AdminAPI to submit any commands that the FLARE Console supports including shutdown commands to the clients and server:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b0d8aa9c", + "metadata": {}, + "outputs": [], + "source": [ + "print(sess.api.do_command(\"shutdown client\"))\n", + "print(sess.api.do_command(\"shutdown server\"))\n", + "\n", + "sess.close()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.13" + }, + "vscode": { + "interpreter": { + "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 + } + diff --git a/nvflare/lighter/utils.py b/nvflare/lighter/utils.py index a45b1cad45..2c26d9f016 100644 --- a/nvflare/lighter/utils.py +++ b/nvflare/lighter/utils.py @@ -273,6 +273,12 @@ def update_storage_locations( outfile.write(json_object) +def make_dirs(dirs): + for d in dirs: + if not os.path.exists(d): + os.makedirs(d) + + def _write(file_full_path, content, mode, exe=False): mode = mode + "w" with open(file_full_path, mode) as f: @@ -281,28 +287,32 @@ def _write(file_full_path, content, mode, exe=False): os.chmod(file_full_path, 0o755) +def write(file_full_path, content, mode, exe=False): + _write(file_full_path, content, mode, exe) + + def _write_common(type, dest_dir, template, tplt, replacement_dict, config): mapping = {"server": "svr", "client": "cln"} - _write(os.path.join(dest_dir, f"fed_{type}.json"), json.dumps(config, indent=2), "t") - _write( + write(os.path.join(dest_dir, f"fed_{type}.json"), json.dumps(config, indent=2), "t") + write( os.path.join(dest_dir, "docker.sh"), sh_replace(template[f"docker_{mapping[type]}_sh"], replacement_dict), "t", exe=True, ) - _write( + write( os.path.join(dest_dir, "start.sh"), sh_replace(template[f"start_{mapping[type]}_sh"], replacement_dict), "t", exe=True, ) - _write( + write( os.path.join(dest_dir, "sub_start.sh"), sh_replace(tplt.get_sub_start_sh(), replacement_dict), "t", exe=True, ) - _write( + write( os.path.join(dest_dir, "stop_fl.sh"), template["stop_fl_sh"], "t", @@ -311,17 +321,17 @@ def _write_common(type, dest_dir, template, tplt, replacement_dict, config): def 
_write_local(type, dest_dir, template, capacity=""): - _write( + write( os.path.join(dest_dir, "log.config.default"), template["log_config"], "t", ) - _write( + write( os.path.join(dest_dir, "privacy.json.sample"), template["sample_privacy"], "t", ) - _write( + write( os.path.join(dest_dir, "authorization.json.default"), template["default_authz"], "t", @@ -334,7 +344,7 @@ def _write_local(type, dest_dir, template, capacity=""): if "nvflare.app_common.resource_managers.gpu_resource_manager.GPUResourceManager" == component["path"]: component["args"] = json.loads(capacity) break - _write( + write( os.path.join(dest_dir, "resources.json.default"), json.dumps(resources, indent=2), "t", @@ -342,6 +352,6 @@ def _write_local(type, dest_dir, template, capacity=""): def _write_pki(type, dest_dir, cert_pair, root_cert): - _write(os.path.join(dest_dir, f"{type}.crt"), cert_pair.ser_cert, "b", exe=False) - _write(os.path.join(dest_dir, f"{type}.key"), cert_pair.ser_pri_key, "b", exe=False) - _write(os.path.join(dest_dir, "rootCA.pem"), root_cert, "b", exe=False) + write(os.path.join(dest_dir, f"{type}.crt"), cert_pair.ser_cert, "b", exe=False) + write(os.path.join(dest_dir, f"{type}.key"), cert_pair.ser_pri_key, "b", exe=False) + write(os.path.join(dest_dir, "rootCA.pem"), root_cert, "b", exe=False) diff --git a/nvflare/tool/poc/poc_commands.py b/nvflare/tool/poc/poc_commands.py index 0acc690a5c..5b88cbd927 100644 --- a/nvflare/tool/poc/poc_commands.py +++ b/nvflare/tool/poc/poc_commands.py @@ -30,8 +30,9 @@ from nvflare.fuel.utils.class_utils import instantiate_class from nvflare.fuel.utils.config import ConfigFormat from nvflare.fuel.utils.gpu_utils import get_host_gpu_ids +from nvflare.lighter.constants import ProvisionMode from nvflare.lighter.provision import gen_default_project_config, prepare_project -from nvflare.lighter.spec import Provisioner +from nvflare.lighter.provisioner import Provisioner from nvflare.lighter.utils import ( load_yaml, update_project_server_name_config, @@ -308,7 +309,7 @@ def local_provision( builders = prepare_builders(project_config) provisioner = Provisioner(workspace, builders) - provisioner.provision(project, mode="poc") + provisioner.provision(project, mode=ProvisionMode.POC) return project_config, service_config diff --git a/tests/unit_test/lighter/participant_test.py b/tests/unit_test/lighter/participant_test.py index 92d8132f46..cb0b3d8339 100644 --- a/tests/unit_test/lighter/participant_test.py +++ b/tests/unit_test/lighter/participant_test.py @@ -14,7 +14,7 @@ import pytest -from nvflare.lighter.spec import Participant +from nvflare.lighter.entity import Participant class TestParticipant: diff --git a/tests/unit_test/lighter/project_test.py b/tests/unit_test/lighter/project_test.py index ebbe7e775d..f462c954ec 100644 --- a/tests/unit_test/lighter/project_test.py +++ b/tests/unit_test/lighter/project_test.py @@ -14,31 +14,74 @@ import pytest -from nvflare.lighter.spec import Participant, Project +from nvflare.lighter.entity import Participant, Project -def create_participants(type, number, org, name): +def create_participants(type, number, org, name, props=None): p_list = list() for i in range(number): name = f"{name[:2]}{i}{name[2:]}" - p_list.append(Participant(name=name, org=org, type=type)) + p_list.append(Participant(name=name, org=org, type=type, props=props)) return p_list class TestProject: - def test_invalid_project(self): - p1 = create_participants("server", 3, "org", "server") - p2 = create_participants("server", 3, "org", "server") - p = p1 + p2 - 
with pytest.raises(ValueError, match=r".* se0rver .*"): + def test_single_server(self): + p1 = Participant(name="server1", org="org", type="server") + p2 = Participant(name="server2", org="org", type="server") + with pytest.raises(ValueError, match=r".* already has a server defined"): + _ = Project("name", "description", [p1, p2]) + + def test_single_overseer(self): + p1 = Participant(name="name1", org="org", type="overseer") + p2 = Participant(name="name2", org="org", type="overseer") + with pytest.raises(ValueError, match=r".* already has an overseer defined"): + _ = Project("name", "description", [p1, p2]) + + def test_get_clients(self): + p = create_participants(type="client", number=3, org="org", name="name") + prj = Project("name", "description", p) + c = prj.get_clients() + assert len(c) == len(p) + assert all(c[i].name == p[i].name and c[i].org == p[i].org for i in range(len(p))) + + def test_get_admins(self): + p = create_participants( + type="admin", number=3, org="org", name="admin@nvidia.com", props={"role": "project_admin"} + ) + prj = Project("name", "description", p) + c = prj.get_admins() + assert len(c) == len(p) + assert all(c[i].name == p[i].name and c[i].org == p[i].org for i in range(len(p))) + + def test_admin_role_required(self): + p = create_participants(type="admin", number=3, org="org", name="admin@nvidia.com") + with pytest.raises(ValueError, match=r"missing role *."): _ = Project("name", "description", p) + def test_bad_admin_role(self): + with pytest.raises(ValueError, match=r"bad value for role *."): + _ = create_participants( + type="admin", number=3, org="org", name="admin@nvidia.com", props={"role": "invalid"} + ) + @pytest.mark.parametrize( - "p_type,name", - [("server", "server"), ("client", "client"), ("admin", "admin@abc.com"), ("overseer", "overseer")], + "type1,type2", + [ + ("client", "client"), + ("server", "client"), + ("admin", "admin"), + ], ) - def test_get_participants_by_type(self, p_type, name): - p = create_participants(type=p_type, number=3, org="org", name=name) - prj = Project("name", "description", p) - assert prj.get_participants_by_type(p_type) == p[0] - assert prj.get_participants_by_type(p_type, first_only=False) == p + def test_dup_names(self, type1, type2): + if type1 == "admin": + name = "name@xyz.com" + props = {"role": "project_admin"} + else: + name = "name" + props = None + + p1 = Participant(name=name, org="org", type=type1, props=props) + p2 = Participant(name=name, org="org", type=type2, props=props) + with pytest.raises(ValueError, match=r".* already has a participant with the name *."): + _ = Project("name", "description", [p1, p2]) diff --git a/tests/unit_test/lighter/provision_test.py b/tests/unit_test/lighter/provision_test.py index 5d781e7250..82a4951de1 100644 --- a/tests/unit_test/lighter/provision_test.py +++ b/tests/unit_test/lighter/provision_test.py @@ -34,7 +34,5 @@ def test_prepare_project(self): ], } - with pytest.raises( - ValueError, match="Configuration error: Expect 2 or 1 server to be provisioned. project contains 3 servers." - ): + with pytest.raises(ValueError, match=".* already has a server defined"): prepare_project(project_dict=project_config)
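Taken together, the updated tests pin down the new entity-level validation: at most one server and one overseer per project, unique participant names, and a mandatory valid role for admin participants. Below is a minimal sketch of a project definition that satisfies these rules, assuming only the nvflare.lighter.entity API exactly as exercised in the tests above.

    from nvflare.lighter.entity import Participant, Project

    # One server, two clients, and a project admin; names are unique and the
    # admin carries the required "role" prop, so Project() construction succeeds.
    participants = [
        Participant(name="server1", org="nvidia", type="server"),
        Participant(name="site-1", org="org-a", type="client"),
        Participant(name="site-2", org="org-b", type="client"),
        Participant(name="admin@nvidia.com", org="nvidia", type="admin", props={"role": "project_admin"}),
    ]

    project = Project("example_project", "minimal valid project", participants)
    print([c.name for c in project.get_clients()])  # expected: ['site-1', 'site-2']
    print([a.name for a in project.get_admins()])   # expected: ['admin@nvidia.com']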