Skip to content

Commit

Permalink
WIP tirith parametrization
Browse files Browse the repository at this point in the history
  • Loading branch information
refeed committed Sep 4, 2024
1 parent 56d83f8 commit 9898bb1
Show file tree
Hide file tree
Showing 5 changed files with 149 additions and 45 deletions.
20 changes: 19 additions & 1 deletion src/tirith/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,24 @@ def __init__(self, prog="PROG") -> None:
dest="inputPath",
help="Input file path",
)
parser.add_argument(
"-var-path",
metavar="PATH",
type=str,
default=[],
action="append",
dest="varPaths",
help="Variable file path(s)",
)
parser.add_argument(
"-var",
metavar="PATH",
type=str,
default=[],
action="append",
dest="inlineVars",
help="Inline variable(s)",
)
parser.add_argument(
"--json",
dest="json",
Expand Down Expand Up @@ -111,7 +129,7 @@ def __init__(self, prog="PROG") -> None:
setup_logging(verbose=args.verbose)

try:
result = start_policy_evaluation(args.policyPath, args.inputPath)
result = start_policy_evaluation(args.policyPath, args.inputPath, args.varPaths, args.inlineVars)

if args.json:
formatted_result = json.dumps(result, indent=3)
Expand Down
55 changes: 52 additions & 3 deletions src/tirith/core/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from tirith.providers.common import ProviderError
from ..providers import PROVIDERS_DICT
from .evaluators import EVALUATORS_DICT
from .policy_parameterization import get_policy_with_vars_replaced


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -204,7 +205,17 @@ def final_evaluator(eval_string: str, eval_id_values: Dict[str, Optional[bool]])
return final_eval_result, []


def start_policy_evaluation(policy_path: str, input_path: str) -> Dict:
def start_policy_evaluation(
policy_path: str, input_path: str, var_paths: List[str] = [], inline_vars: List[str] = []
) -> Dict:
"""
Start Tirith policy evaluation from policy file, input file, and optional variable files.
:param policy_path: Path to the policy file
:param input_path: Path to the input file
:param var_paths: List of paths to the variable files
:return: Policy evaluation result
"""
with open(policy_path) as f:
policy_data = json.load(f)
# TODO: validate policy_data against schema
Expand All @@ -218,12 +229,50 @@ def start_policy_evaluation(policy_path: str, input_path: str) -> Dict:
input_data = json.load(f)
# TODO: validate input_data using the optionally available validate function in provider

return start_policy_evaluation_from_dict(policy_data, input_data)
# TODO: Move this logic into another module
# Merge policy variables into one dictionary
var_dicts = []
for var_path in var_paths:
with open(var_path, encoding="utf-8") as f:
var_dicts.append(json.load(f))

merged_var_dict = _merge_var_dicts(var_dicts)

variable_pattern = re.compile(r"(?P<var_name>\w+)=(?P<var_json>.+)")
for inline_var in inline_vars:
match = re.fullmatch(variable_pattern, inline_var)
if match:
try:
merged_var_dict[match.group("var_name")] = json.loads(match.group("var_json"))
except json.JSONDecodeError:
logger.error(f"Failed to parse inline variable: {inline_var}")
else:
logger.error(f"Invalid inline variable: {inline_var}")

return start_policy_evaluation_from_dict(policy_data, input_data, merged_var_dict)


def _merge_var_dicts(var_dicts: List[dict]) -> dict:
"""
Utility to merge var_dicts
:param var_dicts: List of var dictionaries
:return: A merged dictionary
"""
merged_var_dict = {}
for var_dict in var_dicts:
merged_var_dict.update(var_dict)
return merged_var_dict


def start_policy_evaluation_from_dict(policy_dict: Dict, input_dict: Dict, var_dict: Dict = {}) -> Dict:
policy_dict, not_found_vars = get_policy_with_vars_replaced(policy_dict, var_dict)
if not_found_vars:
return {"errors": [f"Variables not found: {', '.join(not_found_vars)}"]}

def start_policy_evaluation_from_dict(policy_dict: Dict, input_dict: Dict) -> Dict:
policy_meta = policy_dict.get("meta")
eval_objects = policy_dict.get("evaluators")

final_evaluation_policy_string = policy_dict.get("eval_expression")
provider_module = policy_meta.get("required_provider", "core")
# TODO: Write functionality for dynamically importing evaluators from other modules.
Expand Down
83 changes: 59 additions & 24 deletions src/tirith/core/policy_parameterization.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,74 @@
import re
import pydash

from typing import List, Tuple

_VAR_PATTERN = re.compile(r"{{\s*var\.([\w\.]+)\s*}}")

class PydashPathNotFound:
pass

class _VariableNotFound:
pass

def check_match(string: str, pattern: re.Pattern) -> re.Match:
match_ = re.fullmatch(pattern, string)
return match_

def _replace_vars_in_dict(dictionary: dict, var_dict: dict, not_found_vars: List[str]):
"""
Replace the variables in the dictionary with the values from the var_dict
def helper(dictionary: dict, var_pattern: re.Pattern, var_dict: dict):
:param dictionary: The dictionary to replace the variables in
:param var_pattern: The pattern to match the variables
:param var_dict: The dictionary containing the variables
"""
for key, value in dictionary.items():
if isinstance(value, str):
match = check_match(value, var_pattern)
if bool(match):
dictionary[key] = pydash.get(var_dict, match.group(1), default=PydashPathNotFound)
if not isinstance(value, str):
continue
_replace_var_in_dict(dictionary, key, var_dict, not_found_vars)


def replace_vars(policy_dict: dict, var_dict: dict) -> dict:
var_pattern = re.compile(r"{{var\.([\w\.]+)}}")
def _replace_var_in_dict(dictionary: dict, key: str, var_dict: dict, not_found_vars: list):
"""
Replace the variable in the dictionary with the value from the var_dict
This only replaces single dictionary key
evaluators = policy_dict["evaluators"]
helper(policy_dict["meta"], var_pattern, var_dict)
for i in range(len(evaluators)):
match = check_match(evaluators[i]["id"], var_pattern)
if bool(match):
evaluators[i]["id"] = pydash.get(var_dict, match.group(1), default=PydashPathNotFound)
:param dictionary: The dictionary to replace the variable in
:param key: The key of the param `dictionary` to replace the variable in
:param var_dict: The dictionary containing the variables
:param not_found_vars: The list to store the variables that are not found in
"""
var_expression = dictionary[key]

match = _VAR_PATTERN.match(var_expression)
if not match:
return

helper(evaluators[i]["condition"], var_pattern, var_dict)
helper(evaluators[i]["provider_args"], var_pattern, var_dict)
var_name = match.group(1)
var_value = pydash.get(var_dict, var_name, default=_VariableNotFound)
if var_value is _VariableNotFound:
not_found_vars.append(var_name)
return
dictionary[key] = var_value


def get_policy_with_vars_replaced(policy_dict: dict, var_dict: dict) -> Tuple[dict, List[str]]:
"""
Replace the variables in the policy_dict with the values from the var_dict
:param policy_dict: The policy dictionary
:param var_dict: The dictionary containing the variables
:return: The policy dictionary with the variables replaced
and the list of variables that are not found
"""
not_found_vars = []
# Replace vars in the meta key
_replace_vars_in_dict(policy_dict["meta"], var_dict, not_found_vars)

# Replace vars in the evaluators
evaluators = policy_dict["evaluators"]
for evaluator in evaluators:
_replace_var_in_dict(evaluator, "id", var_dict, not_found_vars)
_replace_vars_in_dict(evaluator["provider_args"], var_dict, not_found_vars)
_replace_vars_in_dict(evaluator["condition"], var_dict, not_found_vars)

match = check_match(policy_dict["eval_expression"], var_pattern)
if bool(match):
policy_dict["eval_expression"] = pydash.get(var_dict, match.group(1), default=PydashPathNotFound)
# Replace vars in the eval_expression
_replace_var_in_dict(policy_dict, "eval_expression", var_dict, not_found_vars)

return policy_dict
return policy_dict, not_found_vars
19 changes: 11 additions & 8 deletions src/tirith/prettyprinter.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def pretty_print_result_dict(final_result_dict: Dict) -> None:
:param final_result_dict: Result dictionary generated by core.
"""
checks = final_result_dict["evaluators"]
checks = final_result_dict.get("evaluators", [])
num_passed_checks = 0
num_failed_checks = 0
num_skipped_checks = 0
Expand Down Expand Up @@ -116,10 +116,13 @@ def pretty_print_result_dict(final_result_dict: Dict) -> None:

print(f"Passed: {num_passed_checks} Failed: {num_failed_checks} Skipped: {num_skipped_checks}")
print()
print(f"Final expression used:\n-> {TermStyle.grey(final_result_dict['eval_expression'])}")
if final_result_dict["final_result"]:
print(TermStyle.success("✔ Passed final evaluator"))
elif final_result_dict["final_result"] is None:
print(TermStyle.skipped("= Skipped final evaluator"))
else:
print(TermStyle.fail("✘ Failed final evaluation"))
if 'eval_expression' in final_result_dict:
print(f"Final expression used:\n-> {TermStyle.grey(final_result_dict['eval_expression'])}")

if 'final_result' in final_result_dict:
if final_result_dict["final_result"]:
print(TermStyle.success("✔ Passed final evaluator"))
elif final_result_dict["final_result"] is None:
print(TermStyle.skipped("= Skipped final evaluator"))
else:
print(TermStyle.fail("✘ Failed final evaluation"))
17 changes: 8 additions & 9 deletions tests/core/test_policy_parameterization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from tirith.core.policy_parameterization import replace_vars, PydashPathNotFound
from tirith.core.policy_parameterization import get_policy_with_vars_replaced, _VariableNotFound


@pytest.fixture
Expand All @@ -27,20 +27,19 @@ def processed_policy():
}

# Run the function once and return the result
return replace_vars(input_dict, var_dict)
return get_policy_with_vars_replaced(input_dict, var_dict)


def test_nested_dict(processed_policy):
assert processed_policy["meta"]["required_provider"] == "stackguardian/json"


def test_path_not_found(processed_policy):
assert processed_policy["evaluators"][0]["provider_args"]["key_path"] == PydashPathNotFound
assert processed_policy[0]["meta"]["required_provider"] == "stackguardian/json"


def test_var_value_in_list(processed_policy):
assert processed_policy["evaluators"][0]["condition"]["value"] == 2
assert processed_policy[0]["evaluators"][0]["condition"]["value"] == 2


def test_eval_expression_parameterization(processed_policy):
assert processed_policy["eval_expression"] == "check0"
assert processed_policy[0]["eval_expression"] == "check0"

def test_not_found_variable(processed_policy):
assert processed_policy[1] == ["key_path"]

0 comments on commit 9898bb1

Please sign in to comment.