Skip to content

Commit

Permalink
reorg file structure
Browse files Browse the repository at this point in the history
  • Loading branch information
yanchengnv committed Dec 5, 2024
1 parent 0534ee1 commit 947f8e0
Show file tree
Hide file tree
Showing 14 changed files with 530 additions and 477 deletions.
110 changes: 110 additions & 0 deletions nvflare/lighter/ctx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os

import yaml

from nvflare.lighter import utils

from .constants import CtxKey, PropKey, ProvisionMode
from .entity import Entity, Project


class ProvisionContext(dict):
def __init__(self, workspace_root_dir: str, project: Project):
super().__init__()
self[CtxKey.WORKSPACE] = workspace_root_dir

wip_dir = os.path.join(workspace_root_dir, "wip")
state_dir = os.path.join(workspace_root_dir, "state")
resources_dir = os.path.join(workspace_root_dir, "resources")
self.update({CtxKey.WIP: wip_dir, CtxKey.STATE: state_dir, CtxKey.RESOURCES: resources_dir})
dirs = [workspace_root_dir, resources_dir, wip_dir, state_dir]
utils.make_dirs(dirs)

# set commonly used data into ctx
self[CtxKey.PROJECT] = project

server = project.get_server()
admin_port = server.get_prop(PropKey.ADMIN_PORT, 8003)
self[CtxKey.ADMIN_PORT] = admin_port
fed_learn_port = server.get_prop(PropKey.FED_LEARN_PORT, 8002)
self[CtxKey.FED_LEARN_PORT] = fed_learn_port
self[CtxKey.SERVER_NAME] = server.name

def get_project(self):
return self.get(CtxKey.PROJECT)

def set_template(self, template: dict):
self[CtxKey.TEMPLATE] = template

def get_template(self):
return self.get(CtxKey.TEMPLATE)

def get_template_section(self, section_key: str):
template = self.get_template()
if not template:
raise RuntimeError("template is not available")

section = template.get(section_key)
if not section:
raise RuntimeError(f"missing section {section} in template")

return section

def set_provision_mode(self, mode: str):
valid_modes = [ProvisionMode.POC, ProvisionMode.NORMAL]
if mode not in valid_modes:
raise ValueError(f"invalid provision mode {mode}: must be one of {valid_modes}")
self[CtxKey.PROVISION_MODE] = mode

def get_provision_mode(self):
return self.get(CtxKey.PROVISION_MODE)

def get_wip_dir(self):
return self.get(CtxKey.WIP)

def get_ws_dir(self, entity: Entity):
return os.path.join(self.get_wip_dir(), entity.name)

def get_kit_dir(self, entity: Entity):
return os.path.join(self.get_ws_dir(entity), "startup")

def get_transfer_dir(self, entity: Entity):
return os.path.join(self.get_ws_dir(entity), "transfer")

def get_local_dir(self, entity: Entity):
return os.path.join(self.get_ws_dir(entity), "local")

def get_state_dir(self):
return self.get(CtxKey.STATE)

def get_resources_dir(self):
return self.get(CtxKey.RESOURCES)

def get_workspace(self):
return self.get(CtxKey.WORKSPACE)

def yaml_load_template_section(self, section_key: str):
return yaml.safe_load(self.get_template_section(section_key))

def json_load_template_section(self, section_key: str):
return json.loads(self.get_template_section(section_key))

def build_from_template(self, dest_dir: str, temp_section: str, file_name, replacement=None, mode="t", exe=False):
section = self.get_template_section(temp_section)
if replacement:
section = utils.sh_replace(section, replacement)
utils.write(os.path.join(dest_dir, file_name), section, mode, exe=exe)
261 changes: 261 additions & 0 deletions nvflare/lighter/entity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,261 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from nvflare.apis.utils.format_check import name_check

from .constants import AdminRole, ParticipantType, PropKey


def _check_host_name(scope: str, prop_key: str, value):
err, reason = name_check(value, "host_name")
if err:
raise ValueError(f"bad value for {prop_key} '{value}' in {scope}: {reason}")


def _check_host_names(scope: str, prop_key: str, value):
if isinstance(value, str):
_check_host_name(scope, prop_key, value)
elif isinstance(value, list):
for v in value:
_check_host_name(scope, prop_key, v)


def _check_admin_role(scope: str, prop_key: str, value):
valid_roles = [AdminRole.PROJECT_ADMIN, AdminRole.ORG_ADMIN, AdminRole.LEAD, AdminRole.MEMBER]
if value not in valid_roles:
raise ValueError(f"bad value for {prop_key} '{value}' in {scope}: must be one of {valid_roles}")


# validator functions for common properties
# Validator function must follow this signature:
# func(scope: str, prop_key: str, value)
_PROP_VALIDATORS = {
PropKey.HOST_NAMES: _check_host_names,
PropKey.CONNECT_TO: _check_host_name,
PropKey.LISTENING_HOST: _check_host_name,
PropKey.DEFAULT_HOST: _check_host_name,
PropKey.ROLE: _check_admin_role,
}


class Entity:
def __init__(self, scope: str, name: str, props: dict, parent=None):
if not props:
props = {}

for k, v in props.items():
validator = _PROP_VALIDATORS.get(k)
if validator is not None:
validator(scope, k, v)
self.name = name
self.props = props
self.parent = parent

def get_prop(self, key: str, default=None):
return self.props.get(key, default)

def get_prop_fb(self, key: str, fb_key=None, default=None):
"""Get property value with fallback.
If I have the property, then return it.
If not, I return the fallback property of my parent. If I don't have parent, return default.
Args:
key: key of the property
fb_key: key of the fallback property.
default: value to return if no one has the property
Returns: property value
"""
value = self.get_prop(key)
if value:
return value
elif not self.parent:
return default
else:
# get the value from the parent
if not fb_key:
fb_key = key
return self.parent.get_prop(fb_key, default)


class Participant(Entity):
def __init__(self, type: str, name: str, org: str, props: dict = None, project: Entity = None):
"""Class to represent a participant.
Each participant communicates to other participant. Therefore, each participant has its
own name, type, organization it belongs to, rules and other information.
Args:
type (str): server, client, admin or other string that builders can handle
name (str): system-wide unique name
org (str): system-wide unique organization
props (dict): properties
project: the project that the participant belongs to
Raises:
ValueError: if name or org is not compliant with characters or format specification.
"""
Entity.__init__(self, f"{type}::{name}", name, props, parent=project)

err, reason = name_check(name, type)
if err:
raise ValueError(reason)

err, reason = name_check(org, "org")
if err:
raise ValueError(reason)

self.type = type
self.org = org
self.subject = name

def get_default_host(self) -> str:
"""Get the default host name for accessing this participant (server).
If the "default_host" attribute is explicitly specified, then it's the default host.
If the "default_host" attribute is not explicitly specified, then use the "name" attribute.
Returns: a host name
"""
h = self.get_prop(PropKey.DEFAULT_HOST)
if h:
return h
else:
return self.name


class Project(Entity):
def __init__(
self,
name: str,
description: str,
participants=None,
props: dict = None,
serialized_root_cert=None,
serialized_root_private_key=None,
):
"""A container class to hold information about this FL project.
This class only holds information. It does not drive the workflow.
Args:
name (str): the project name
description (str): brief description on this name
participants: if provided, list of participants of the project
props: properties of the project
serialized_root_cert: if provided, the root cert to be used for the project
serialized_root_private_key: if provided, the root private key for signing certs of sites and admins
Raises:
ValueError: when participant criteria is violated
"""
Entity.__init__(self, "project", name, props)

if serialized_root_cert:
if not serialized_root_private_key:
raise ValueError("missing serialized_root_private_key while serialized_root_cert is provided")

self.description = description
self.serialized_root_cert = serialized_root_cert
self.serialized_root_private_key = serialized_root_private_key
self.server = None
self.overseer = None
self.clients = []
self.admins = []
self.all_names = {}

if participants:
if not isinstance(participants, list):
raise ValueError(f"participants must be a list of Participant but got {type(participants)}")

for p in participants:
if not isinstance(p, Participant):
raise ValueError(f"bad item in participants: must be Participant but got {type(p)}")

if p.type == ParticipantType.SERVER:
self.set_server(p.name, p.org, p.props)
elif p.type == ParticipantType.ADMIN:
self.add_admin(p.name, p.org, p.props)
elif p.type == ParticipantType.CLIENT:
self.add_client(p.name, p.org, p.props)
elif p.type == ParticipantType.OVERSEER:
self.set_overseer(p.name, p.org, p.props)
else:
raise ValueError(f"invalid value for ParticipantType: {p.type}")

def _check_unique_name(self, name: str):
if name in self.all_names:
raise ValueError(f"the project {self.name} already has a participant with the name '{name}'")

def set_server(self, name: str, org: str, props: dict):
if self.server:
raise ValueError(f"project {self.name} already has a server defined")
self._check_unique_name(name)
self.server = Participant(ParticipantType.SERVER, name, org, props, self)
self.all_names[name] = True

def get_server(self):
"""Get the server definition. Only one server is supported!
Returns: server participant
"""
return self.server

def set_overseer(self, name: str, org: str, props: dict):
if self.overseer:
raise ValueError(f"project {self.name} already has an overseer defined")
self._check_unique_name(name)
self.overseer = Participant(ParticipantType.OVERSEER, name, org, props, self)
self.all_names[name] = True

def get_overseer(self):
"""Get the overseer definition. Only one overseer is supported!
Returns: overseer participant
"""
return self.overseer

def add_client(self, name: str, org: str, props: dict):
self._check_unique_name(name)
self.clients.append(Participant(ParticipantType.CLIENT, name, org, props, self))
self.all_names[name] = True

def get_clients(self):
return self.clients

def add_admin(self, name: str, org: str, props: dict):
self._check_unique_name(name)
admin = Participant(ParticipantType.ADMIN, name, org, props, self)
role = admin.get_prop(PropKey.ROLE)
if not role:
raise ValueError(f"missing role in admin '{name}'")
self.admins.append(admin)
self.all_names[name] = True

def get_admins(self):
return self.admins

def get_all_participants(self):
result = []
if self.server:
result.append(self.server)

if self.overseer:
result.append(self.overseer)

result.extend(self.clients)
result.extend(self.admins)
return result
6 changes: 4 additions & 2 deletions nvflare/lighter/impl/cert.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import NameOID

from nvflare.lighter.constants import CtxKey
from nvflare.lighter.spec import Builder, Participant, ParticipantType, Project, PropKey, ProvisionContext
from nvflare.lighter.constants import CtxKey, ParticipantType, PropKey
from nvflare.lighter.ctx import ProvisionContext
from nvflare.lighter.entity import Participant, Project
from nvflare.lighter.spec import Builder
from nvflare.lighter.utils import serialize_cert, serialize_pri_key

_CERT_BASE_NAME_CLIENT = "client"
Expand Down
Loading

0 comments on commit 947f8e0

Please sign in to comment.