Skip to content

Commit

Permalink
reformat
Browse files Browse the repository at this point in the history
  • Loading branch information
yanchengnv committed Sep 2, 2023
1 parent f65908b commit 1d65d27
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 11 deletions.
105 changes: 96 additions & 9 deletions nvflare/fuel/utils/validation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,19 @@

class DefaultPolicy:

"""
Defines policy for how to determine default value
"""

DISALLOW = "disallow"
ANY = "any"
EMPTY = "empty"
ALL = "all"

@classmethod
def valid_policy(cls, p: str):
return p in [cls.DISALLOW, cls.ANY, cls.EMPTY, cls.ALL]


def check_positive_int(name, value):
if not isinstance(value, int):
Expand Down Expand Up @@ -56,6 +64,13 @@ def check_str(name, value):
check_object_type(name, value, str)


def check_non_empty_str(name, value):
check_object_type(name, value, str)
v = value.strip()
if not v:
raise ValueError(f"{name} must not be empty")


def check_object_type(name, value, obj_type):
if not isinstance(value, obj_type):
raise TypeError(f"{name} must be {obj_type}, but got {type(value)}.")
Expand All @@ -66,20 +81,23 @@ def check_callable(name, value):
raise ValueError(f"{name} must be callable, but got {type(value)}.")


def _validate_candidates(var_name: str, candidates, base: list):
def _determine_candidates_value(var_name: str, candidates, base: list):
if not isinstance(base, list):
raise TypeError(f"base must be list but got {type(base)}")

if candidates is None:
return [] # empty
return None # empty

if isinstance(candidates, str):
c = candidates.lower().strip()
if not c:
return None
return []

if c == SYMBOL_ALL:
return candidates
return base
elif c == SYMBOL_NONE:
return None
elif c in candidates:
elif c in base:
return [c]
else:
raise ValueError(f"value of '{var_name}' ({candidates}) is invalid")
Expand All @@ -101,7 +119,43 @@ def _validate_candidates(var_name: str, candidates, base: list):


def validate_candidates(var_name: str, candidates, base: list, default_policy: str, allow_none: bool):
c = _validate_candidates(var_name, candidates, base)
"""Validate specified candidates against the items in the "base" list, based on specified policy
and returns determined value for the candidates.
The value of candidates could have the following cases:
1. Not explicitly specified (Python object None or empty list [])
In this case, the default_policy decides the final result:
- ANY: returns a list that contains a single item from the base
- EMPTY: returns an empty list
- ALL: returns the base list
- DISALLOW: raise exception - candidates must be explicitly specified
2. A list of string items
In this case, each item in the candidates list must be in the "base". Duplicates are removed.
3. A string with special value "@all" to mean "all items from the base"
Returns the base list.
4. A string with special value "@none" to mean "no items"
If allow_none is True, then returns an empty list; otherwise raise exception.
5. A string that is not a special value
If it is in the "base", return a list that contains this item; otherwise raise exception.
Args:
var_name: the name of the "candidates" var from the caller
candidates: the candidates to be validated
base: the base list that contains valid items
default_policy: policy for how to handle default value when "candidates" is not explicitly specified.
allow_none: whether "none" is allowed for candidates.
Returns:
"""
if not DefaultPolicy.valid_policy(default_policy):
raise ValueError(f"invalid default policy {default_policy}")

c = _determine_candidates_value(var_name, candidates, base)

if c is None:
if not allow_none:
Expand All @@ -119,11 +173,11 @@ def validate_candidates(var_name: str, candidates, base: list, default_policy: s
raise ValueError(f"invalid value '{candidates}' in '{var_name}': it must be subset of {base}")
else:
# any
return [candidates[0]]
return [base[0]]
return c


def _validate_candidate(var_name: str, candidate, base: list):
def _determine_candidate_value(var_name: str, candidate, base: list):
if candidate is None:
return None

Expand All @@ -143,7 +197,40 @@ def _validate_candidate(var_name: str, candidate, base: list):


def validate_candidate(var_name: str, candidate, base: list, default_policy: str, allow_none: bool):
c = _validate_candidate(var_name, candidate, base)
"""Validate specified candidate against the items in the "base" list, based on specified policy
and returns determined value for the candidate.
The value of candidate could have the following cases:
1. Not explicitly specified (Python object None or empty string)
In this case, the default_policy decides the final result:
- ANY: returns the first item from the base
- EMPTY: returns an empty str
- ALL or DISALLOW: raise exception - candidate must be explicitly specified
2. A string with special value "@none" to mean "nothing"
If allow_none is True, then returns an empty str; otherwise raise exception.
3. A string that is not a special value
If it is in the "base", return it; otherwise raise exception.
All other cases, raise exception.
NOTE: the final value is normalized (leading and trailing white spaces are removed).
Args:
var_name: the name of the "candidate" var from the caller
candidate: the candidate to be validated
base: the base list that contains valid items
default_policy: policy for how to handle default value when "candidates" is not explicitly specified.
allow_none: whether "none" is allowed for candidates.
Returns:
"""
if not DefaultPolicy.valid_policy(default_policy):
raise ValueError(f"invalid default policy {default_policy}")

c = _determine_candidate_value(var_name, candidate, base)
if c is None:
if not allow_none:
raise ValueError(f"{var_name} must be specified")
Expand Down
3 changes: 1 addition & 2 deletions tests/unit_test/fuel/utils/validation_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@


class TestValidationUtils:

@pytest.mark.parametrize(
"var_name, candidate, base, default_policy, allow_none, output",
[
Expand Down Expand Up @@ -84,4 +83,4 @@ def test_validate_candidates(self, var_name, candidates, base, default_policy, a
)
def test_validate_candidates_error(self, var_name, candidate, base, default_policy, allow_none):
with pytest.raises(ValueError):
validate_candidates(var_name, candidate, base, default_policy, allow_none)
validate_candidates(var_name, candidate, base, default_policy, allow_none)

0 comments on commit 1d65d27

Please sign in to comment.