diff --git a/acto/__main__.py b/acto/__main__.py index 0ca73a0166..d01ddd2fc4 100644 --- a/acto/__main__.py +++ b/acto/__main__.py @@ -120,6 +120,7 @@ if "monkey_patch" in config: del config["monkey_patch"] config = OperatorConfig.model_validate(config) + logger.info("Acto started with [%s]", sys.argv) logger.info("Operator config: %s", config) @@ -137,7 +138,6 @@ workdir_path=args.workdir_path, operator_config=config, cluster_runtime="KIND", - preload_images_=None, context_file=context_cache, helper_crd=args.helper_crd, num_workers=args.num_workers, @@ -147,13 +147,12 @@ is_reproduce=False, input_model=DeterministicInputModel, apply_testcase_f=apply_testcase_f, - delta_from=None, focus_fields=config.focus_fields, ) generation_time = datetime.now() logger.info("Acto initialization finished in %s", generation_time - start_time) if not args.learn: - acto.run(modes=["normal"]) + acto.run() normal_finish_time = datetime.now() logger.info("Acto normal run finished in %s", normal_finish_time - start_time) logger.info("Start post processing steps") diff --git a/acto/common.py b/acto/common.py index 74eba6cc9e..9ad528b236 100644 --- a/acto/common.py +++ b/acto/common.py @@ -11,6 +11,7 @@ import kubernetes import pydantic from deepdiff.helper import NotPresent +from typing_extensions import Self from acto.utils.thread_logger import get_thread_logger @@ -47,6 +48,26 @@ def __len__(self): def __contains__(self, item: PathSegment): return item in self.path + @classmethod + def from_json_patch_string(cls, patch_path: str) -> Self: + """Convert a JSON patch string to a PropertyPath object""" + items = patch_path.split("/") + return cls(items[1:]) + + +class HashableDict(dict): + """Hashable dict""" + + def __hash__(self): + return hash(json.dumps(self, sort_keys=True)) + + +class HashableList(list): + """Hashable list""" + + def __hash__(self): + return hash(json.dumps(self, sort_keys=True)) + class Diff(pydantic.BaseModel): """Class for storing the diff between two values""" @@ -78,358 +99,6 @@ def value_eq_with_not_present(a: Any, b: Any) -> bool: ) -# class Oracle(str, enum.Enum): -# """Enum for different oracle types""" - -# ERROR_LOG = "ErrorLog" -# SYSTEM_STATE = "SystemState" -# SYSTEM_HEALTH = "SystemHealth" -# RECOVERY = "Recovery" -# CRASH = "Crash" -# CUSTOM = "Custom" - - -# class RunResult: -# """Result of a single run of a testcase""" - -# def __init__( -# self, revert, generation: int, testcase_signature: dict -# ) -> None: -# self.crash_result: OracleResult = None -# self.input_result: OracleResult = None -# self.health_result: OracleResult = None -# self.state_result: OracleResult = None -# self.log_result: OracleResult = None -# self.custom_result: OracleResult = None -# self.misc_result: OracleResult = None -# self.recovery_result: OracleResult = None -# self.other_results: Dict[str, OracleResult] = {} - -# self.generation = generation - -# self.revert = revert -# self.testcase_signature = testcase_signature - -# def set_result(self, result_name: str, value: "OracleResult"): -# """Set result of a specific oracle""" -# # TODO, store all results in a dict -# if result_name == "crash": -# self.crash_result = value -# elif result_name == "input": -# self.input_result = value -# elif result_name == "health": -# self.health_result = value -# elif result_name == "state": -# self.state_result = value -# elif result_name == "log": -# self.log_result = value -# else: -# self.other_results[result_name] = value -# self.custom_result = value - -# def is_pass(self) -> bool: -# """Returns if the run 
is a pass""" -# if ( -# not isinstance(self.crash_result, PassResult) -# and self.crash_result is not None -# ): -# return False -# if ( -# not isinstance(self.health_result, PassResult) -# and self.health_result is not None -# ): -# return False -# if ( -# not isinstance(self.custom_result, PassResult) -# and self.custom_result is not None -# ): -# return False - -# if isinstance(self.state_result, PassResult): -# return True -# if ACTO_CONFIG.alarms.invalid_input and isinstance( -# self.log_result, InvalidInputResult -# ): -# return True -# return False - -# def is_invalid(self) -> Tuple[bool, "InvalidInputResult"]: -# """Returns if the run is invalid input""" -# if isinstance(self.input_result, InvalidInputResult): -# return True, self.input_result -# if isinstance(self.log_result, InvalidInputResult): -# return True, self.log_result -# if isinstance(self.misc_result, InvalidInputResult): -# return True, self.misc_result -# return False, None - -# def is_connection_refused(self) -> bool: -# """Returns if the run is connection refused""" -# return isinstance(self.input_result, ConnectionRefusedResult) - -# def is_unchanged(self) -> bool: -# """Returns if kubectl reports unchanged input""" -# return isinstance(self.input_result, UnchangedInputResult) - -# def is_error(self) -> bool: -# """Returns if the run raises an error""" -# if isinstance(self.crash_result, ErrorResult): -# return True -# if isinstance(self.health_result, ErrorResult): -# return True -# if isinstance(self.custom_result, ErrorResult): -# return True -# if isinstance(self.recovery_result, ErrorResult): -# return True -# if not isinstance(self.state_result, ErrorResult): -# return False - -# if ACTO_CONFIG.alarms.invalid_input and isinstance( -# self.log_result, InvalidInputResult -# ): -# return False -# return True - -# def is_basic_error(self) -> bool: -# """Returns if the run raises an error of basic type""" -# if isinstance(self.crash_result, ErrorResult): -# return True -# if isinstance(self.health_result, ErrorResult): -# return True -# if isinstance(self.custom_result, ErrorResult): -# return True -# if isinstance(self.recovery_result, ErrorResult): -# return True -# return False - -# def to_dict(self): -# """serialize RunResult object""" -# return { -# "revert": self.revert, -# "generation": self.generation, -# "testcase": self.testcase_signature, -# "crash_result": self.crash_result.to_dict() -# if self.crash_result -# else None, -# "input_result": self.input_result.to_dict() -# if self.input_result -# else None, -# "health_result": self.health_result.to_dict() -# if self.health_result -# else None, -# "state_result": self.state_result.to_dict() -# if self.state_result -# else None, -# "log_result": self.log_result.to_dict() -# if self.log_result -# else None, -# "custom_result": self.custom_result.to_dict() -# if self.custom_result -# else None, -# "misc_result": self.misc_result.to_dict() -# if self.misc_result -# else None, -# "recovery_result": self.recovery_result.to_dict() -# if self.recovery_result -# else None, -# } - -# @staticmethod -# def from_dict(d: dict) -> "RunResult": -# """deserialize RunResult object""" - -# result = RunResult(d["revert"], d["generation"], d["testcase"]) -# result.crash_result = oracle_result_from_dict(d["crash_result"]) -# result.input_result = oracle_result_from_dict(d["input_result"]) -# result.health_result = oracle_result_from_dict(d["health_result"]) -# result.state_result = oracle_result_from_dict(d["state_result"]) -# result.log_result = 
oracle_result_from_dict(d["log_result"]) -# result.custom_result = oracle_result_from_dict(d["custom_result"]) -# result.misc_result = oracle_result_from_dict(d["misc_result"]) -# result.recovery_result = oracle_result_from_dict(d["recovery_result"]) -# return result - - -# class OracleResult: -# """Base class for oracle results""" - -# @abstractmethod -# def to_dict(self): -# """serialize OracleResult object""" -# return {} - - -# class PassResult(OracleResult): -# """Indicates the oracle passes""" - -# def to_dict(self): -# return "Pass" - -# def __eq__(self, other): -# return isinstance(other, PassResult) - - -# class InvalidInputResult(OracleResult): -# """Indicates the input is invalid""" - -# def __init__(self, responsible_field: list) -> None: -# self.responsible_field = responsible_field - -# def to_dict(self): -# return {"responsible_field": self.responsible_field} - -# def __eq__(self, other): -# return ( -# isinstance(other, InvalidInputResult) -# and self.responsible_field == other.responsible_field -# ) - - -# class UnchangedInputResult(OracleResult): -# """Indicates the input is unchanged""" - -# def to_dict(self): -# return "UnchangedInput" - - -# class ConnectionRefusedResult(OracleResult): -# """Indicates the connection is refused""" - -# def to_dict(self): -# return "ConnectionRefused" - - -# class ErrorResult(OracleResult, Exception): -# """Base class for error results""" - -# def __init__(self, oracle: Oracle, msg: str) -> None: -# self.oracle = oracle -# self.message = msg - -# def to_dict(self): -# return {"oracle": self.oracle, "message": self.message} - -# @staticmethod -# def from_dict(d: dict): -# """deserialize ErrorResult object""" -# return ErrorResult(d["oracle"], d["message"]) - - -# class StateResult(ErrorResult): -# """Result of system state oracle""" - -# def __init__( -# self, -# oracle: Oracle, -# msg: str, -# input_delta: Diff = None, -# matched_system_delta: Diff = None, -# ) -> None: -# super().__init__(oracle, msg) -# self.input_delta = input_delta -# self.matched_system_delta = matched_system_delta - -# def to_dict(self): -# return { -# "oracle": self.oracle, -# "message": self.message, -# "input_delta": self.input_delta.to_dict() -# if self.input_delta -# else None, -# "matched_system_delta": self.matched_system_delta.to_dict() -# if self.matched_system_delta -# else None, -# } - -# @staticmethod -# def from_dict(d: dict) -> "StateResult": -# result = StateResult(d["oracle"], d["message"]) -# result.input_delta = ( -# Diff.from_dict(d["input_delta"]) if d["input_delta"] else None -# ) -# result.matched_system_delta = ( -# Diff.from_dict(d["matched_system_delta"]) -# if d["matched_system_delta"] -# else None -# ) -# return result - -# def __eq__(self, other): -# return ( -# isinstance(other, StateResult) -# and self.oracle == other.oracle -# and self.message == other.message -# and self.input_delta == other.input_delta -# and self.matched_system_delta == other.matched_system_delta -# ) - - -# class UnhealthyResult(ErrorResult): -# """Result of system health oracle""" - -# def to_dict(self): -# return {"oracle": self.oracle, "message": self.message} - -# @staticmethod -# def from_dict(d: dict): -# return UnhealthyResult(d["oracle"], d["message"]) - - -# class RecoveryResult(ErrorResult): -# """Result of recovery oracle""" - -# def __init__(self, delta, from_, to_) -> None: -# super().__init__(Oracle.RECOVERY, "Recovery") -# self.delta = delta -# self.from_ = from_ -# self.to_ = to_ - -# def to_dict(self): -# return { -# "oracle": self.oracle, -# 
"delta": json.loads( -# self.delta.to_json( -# default_mapping={datetime: lambda x: x.isoformat()} -# ) -# ), -# "from": self.from_, -# "to": self.to_, -# } - -# @staticmethod -# def from_dict(d: dict) -> "RecoveryResult": -# result = RecoveryResult(d["delta"], d["from"], d["to"]) -# return result - - -# def oracle_result_from_dict(d: dict) -> OracleResult: -# """deserialize OracleResult object""" -# if d is None: -# return PassResult() -# if d == "Pass": -# return PassResult() -# if d == "UnchangedInput": -# return UnchangedInputResult() -# if d == "ConnectionRefused": -# return ConnectionRefusedResult() - -# if "responsible_field" in d: -# return InvalidInputResult(d["responsible_field"]) -# if "oracle" in d: -# if d["oracle"] == Oracle.SYSTEM_STATE: -# return StateResult.from_dict(d) -# if d["oracle"] == Oracle.SYSTEM_HEALTH: -# return UnhealthyResult.from_dict(d) -# if d["oracle"] == Oracle.RECOVERY: -# return RecoveryResult.from_dict(d) -# if d["oracle"] == Oracle.CRASH: -# return UnhealthyResult.from_dict(d) -# if d["oracle"] == Oracle.CUSTOM: -# return ErrorResult.from_dict(d) - -# raise ValueError(f"Invalid oracle result dict: {d}") - - def flatten_list(l: list, curr_path: list) -> list: """Convert list into list of tuples (path, value) diff --git a/acto/deploy.py b/acto/deploy.py index c03f11bf42..bf5265a658 100644 --- a/acto/deploy.py +++ b/acto/deploy.py @@ -1,20 +1,29 @@ import logging +import subprocess import time +from typing import Optional -import yaml - -import acto.utils as utils +from acto import utils from acto.common import kubernetes_client, print_event +from acto.kubectl_client.helm import Helm from acto.kubectl_client.kubectl import KubectlClient from acto.lib.operator_config import DELEGATED_NAMESPACE, DeployConfig from acto.utils import get_thread_logger +from acto.utils.k8s_helper import ( + get_deployment_name_from_yaml, + get_yaml_existing_namespace, +) from acto.utils.preprocess import add_acto_label def wait_for_pod_ready(kubectl_client: KubectlClient) -> bool: """Wait for all pods to be ready""" now = time.time() - p = kubectl_client.wait_for_all_pods(timeout=600) + try: + p = kubectl_client.wait_for_all_pods(timeout=600) + except subprocess.TimeoutExpired: + logging.error("Timeout waiting for all pods to be ready") + return False if p.returncode != 0: logging.error( "Failed to wait for all pods to be ready due to error from kubectl" @@ -35,13 +44,16 @@ class Deploy: def __init__(self, deploy_config: DeployConfig) -> None: self._deploy_config = deploy_config - self._operator_yaml: str = None + self._operator_existing_namespace: Optional[str] = None for step in self._deploy_config.steps: if step.apply and step.apply.operator: - self._operator_yaml = step.apply.file + self._operator_existing_namespace = get_yaml_existing_namespace( + step.apply.file + ) + break + if step.helm_install and step.helm_install.operator: + self._operator_existing_namespace = None break - else: - raise RuntimeError("No operator yaml found in deploy config") # Extract the operator_container_name from config self._operator_container_name = None @@ -51,11 +63,30 @@ def __init__(self, deploy_config: DeployConfig) -> None: step.apply.operator_container_name ) break + if step.helm_install and step.helm_install.operator: + self._operator_container_name = ( + step.helm_install.operator_container_name + ) + break + + self._operator_deployment_name = None + for step in self._deploy_config.steps: + if step.apply and step.apply.operator: + if ( + ret := 
get_deployment_name_from_yaml(step.apply.file)
+                ) is not None:
+                    self._operator_deployment_name = ret
+                break
+            if step.helm_install and step.helm_install.operator:
+                self._operator_deployment_name = (
+                    step.helm_install.operator_deployment_name
+                )
+                break
 
     @property
-    def operator_yaml(self) -> str:
-        """Get the operator yaml file path"""
-        return self._operator_yaml
+    def operator_existing_namespace(self) -> Optional[str]:
+        """Get the namespace the operator yaml explicitly specifies, if any"""
+        return self._operator_existing_namespace
 
     def deploy(
         self,
@@ -106,6 +137,35 @@
         elif step.wait:
             # Simply wait for the specified duration
             time.sleep(step.wait.duration)
+        elif step.helm_install:
+            # Use the namespace from the argument if the namespace is delegated.
+            # If the namespace from the config is explicitly specified,
+            # use the specified namespace.
+            # If the namespace from the config is set to None, fall back
+            # to the "default" namespace.
+            release_namespace = "default"
+            if step.helm_install.namespace == DELEGATED_NAMESPACE:
+                release_namespace = namespace
+            elif step.helm_install.namespace is not None:
+                release_namespace = step.helm_install.namespace
+
+            # Install the helm chart
+            helm = Helm(kubeconfig, context_name)
+            p = helm.install(
+                release_name=step.helm_install.release_name,
+                chart=step.helm_install.chart,
+                namespace=release_namespace,
+                repo=step.helm_install.repo,
+                version=step.helm_install.version,
+            )
+            if p.returncode != 0:
+                logger.error(
+                    "Failed to deploy operator due to error from helm"
+                    + f" (returncode={p.returncode})"
+                    + f" (stdout={p.stdout})"
+                    + f" (stderr={p.stderr})"
+                )
+                return False
 
         # Add acto label to the operator pod
         add_acto_label(api_client, namespace)
@@ -135,16 +195,11 @@
             logger.error("Failed to deploy operator, retrying...")
         return False
 
-    def operator_name(self) -> str:
+    def operator_name(self) -> Optional[str]:
         """Get the name of the operator deployment"""
-        with open(self._operator_yaml, "r", encoding="utf-8") as f:
-            operator_yamls = yaml.load_all(f, Loader=yaml.FullLoader)
-            for yaml_ in operator_yamls:
-                if yaml_["kind"] == "Deployment":
-                    return yaml_["metadata"]["name"]
-        return None
+        return self._operator_deployment_name
 
     @property
-    def operator_container_name(self) -> str:
+    def operator_container_name(self) -> Optional[str]:
         """Get the name of the operator container"""
         return self._operator_container_name
diff --git a/acto/engine.py b/acto/engine.py
index 5086cae2a6..11c5cf4b6e 100644
--- a/acto/engine.py
+++ b/acto/engine.py
@@ -18,9 +18,15 @@
 from acto.checker.checker_set import CheckerSet
 from acto.checker.impl.health import HealthChecker
-from acto.common import kubernetes_client, print_event
+from acto.common import (
+    PropertyPath,
+    kubernetes_client,
+    postprocess_diff,
+    print_event,
+)
 from acto.constant import CONST
 from acto.deploy import Deploy
+from acto.input.constraint import XorCondition
 from acto.input.input import DeterministicInputModel, InputModel
 from acto.input.testcase import TestCase
 from acto.input.testplan import TestGroup
@@ -31,6 +37,7 @@
 from acto.oracle_handle import OracleHandle
 from acto.result import (
     CliStatus,
+    DeletionOracleResult,
     DifferentialOracleResult,
     OracleResults,
     RunResult,
@@ -41,12 +48,8 @@
 from acto.runner import Runner
 from acto.serialization import ActoEncoder, ContextEncoder
 from acto.snapshot import Snapshot
-from acto.utils import (
-    delete_operator_pod,
-    get_yaml_existing_namespace,
-    process_crd,
-    update_preload_images,
-)
+from acto.utils import delete_operator_pod, process_crd
+from acto.utils.preprocess import
get_existing_images from acto.utils.thread_logger import get_thread_logger, set_thread_logger_prefix from ssa.analysis import analyze @@ -58,8 +61,9 @@ def apply_testcase( path: list, testcase: TestCase, setup: bool = False, + constraints: Optional[list[XorCondition]] = None, ) -> jsonpatch.JsonPatch: - """Apply a testcase to a value""" + """This function realizes the testcase onto the current valueWithSchema""" logger = get_thread_logger(with_prefix=True) prev = value_with_schema.raw_value() @@ -78,7 +82,27 @@ def apply_testcase( ) curr = value_with_schema.raw_value() - patch = jsonpatch.make_patch(prev, curr) + # Satisfy constraints + assumptions: list[tuple[PropertyPath, bool]] = [] + input_change = postprocess_diff( + deepdiff.DeepDiff( + prev, + curr, + view="tree", + ) + ) + for changes in input_change.values(): + for diff in changes.values(): + if isinstance(diff.curr, bool): + assumptions.append((diff.path, diff.curr)) + if constraints is not None: + for constraint in constraints: + result = constraint.solve(assumptions) + if result is not None: + value_with_schema.set_value_by_path(result[1], result[0].path) + + curr = value_with_schema.raw_value() + patch: list[dict] = jsonpatch.make_patch(prev, curr) logger.info("JSON patch: %s", patch) return patch @@ -236,6 +260,7 @@ def __init__( apply_testcase_f: Callable, acto_namespace: int, additional_exclude_paths: Optional[list[str]] = None, + constraints: Optional[list[XorCondition]] = None, ) -> None: self.context = context self.workdir = workdir @@ -273,6 +298,7 @@ def __init__( self.discarded_testcases: dict[str, list[TestCase]] = {} self.apply_testcase_f = apply_testcase_f + self.constraints = constraints self.curr_trial = 0 def run( @@ -415,6 +441,7 @@ def run_trial( generation = 0 trial_id = f"trial-{self.worker_id + self.sequence_base:02d}-{self.curr_trial:04d}" + trial_err: Optional[OracleResults] = None while ( generation < num_mutation ): # every iteration gets a new list of next tests @@ -436,12 +463,9 @@ def run_trial( # if test_group is None, it means this group is exhausted # break and move to the next trial if test_groups is None: - return TrialResult( - trial_id=trial_id, - duration=time.time() - trial_start_time, - error=None, - ) + break + setup_fail = False # to break the loop if setup fails # First make sure all the next tests are valid for ( group, @@ -457,6 +481,14 @@ def run_trial( list(field_path) ) + logger.info( + "Path [%s] has examples: %s", + field_path, + self.input_model.get_schema_by_path( + field_path + ).examples, + ) + if testcase.test_precondition(field_curr_value): # precondition of this testcase satisfies logger.info("Precondition of %s satisfies", field_path) @@ -473,6 +505,7 @@ def run_trial( field_path, testcase, setup=True, + constraints=self.constraints, ) if not testcase.test_precondition( @@ -500,11 +533,9 @@ def run_trial( == CliStatus.CONNECTION_REFUSED ): logger.error("Connection refused, exiting") - return TrialResult( - trial_id=trial_id, - duration=time.time() - trial_start_time, - error=None, - ) + trial_err = None + setup_fail = True + break if ( run_result.is_invalid_input() and run_result.oracle_result.health is None @@ -525,17 +556,18 @@ def run_trial( runner ) generation += 1 - return TrialResult( - trial_id=trial_id, - duration=time.time() - trial_start_time, - error=run_result.oracle_result, - ) + trial_err = run_result.oracle_result + setup_fail = True + break elif run_result.cli_status == CliStatus.UNCHANGED: logger.info("Setup produced unchanged input") 
group.discard_testcase(self.discarded_testcases) else: ready_testcases.append((group, testcase_with_path)) + if setup_fail: + break + if len(ready_testcases) == 0: logger.info("All setups failed") continue @@ -555,22 +587,24 @@ def run_trial( runner ) generation += 1 - - return TrialResult( - trial_id=f"trial-{self.worker_id + self.sequence_base:02d}" - + f"-{self.curr_trial:04d}", - duration=time.time() - trial_start_time, - error=run_result.oracle_result, - ) + trial_err = run_result.oracle_result + break if self.input_model.is_empty(): logger.info("Input model is empty, break") + trial_err = None break + if trial_err is not None: + trial_err.deletion = self.run_delete(runner, generation=generation) + else: + trial_err = OracleResults() + trial_err.deletion = self.run_delete(runner, generation=generation) + return TrialResult( - trial_id=f"trial-{self.worker_id + self.sequence_base:02d}-{self.curr_trial:04d}", + trial_id=trial_id, duration=time.time() - trial_start_time, - error=None, + error=trial_err, ) def run_testcases( @@ -594,7 +628,10 @@ def run_testcases( "testcase": str(testcase), } patch = self.apply_testcase_f( - curr_input_with_schema, field_path, testcase + curr_input_with_schema, + field_path, + testcase, + constraints=self.constraints, ) # field_node.get_testcases().pop() # finish testcase @@ -723,6 +760,20 @@ def run_recovery( else: return None + def run_delete( + self, runner: Runner, generation: int + ) -> Optional[DeletionOracleResult]: + """Runs the deletion test to check if the operator can properly handle deletion""" + logger = get_thread_logger(with_prefix=True) + + logger.debug("Running delete") + success = runner.delete(generation=generation) + + if not success: + return DeletionOracleResult(message="Deletion test case") + else: + return None + def revert(self, runner, checker, generation) -> ValueWithSchema: """Revert to the previous system state""" curr_input_with_schema = attach_schema_to_value( @@ -751,7 +802,6 @@ def __init__( workdir_path: str, operator_config: OperatorConfig, cluster_runtime: str, - preload_images_: Optional[list], context_file: str, helper_crd: Optional[str], num_workers: int, @@ -761,7 +811,6 @@ def __init__( is_reproduce: bool, input_model: type[DeterministicInputModel], apply_testcase_f: Callable, - delta_from: Optional[str] = None, mount: Optional[list] = None, focus_fields: Optional[list] = None, acto_namespace: int = 0, @@ -773,6 +822,7 @@ def __init__( operator_config.seed_custom_resource, "r", encoding="utf-8" ) as cr_file: self.seed = yaml.load(cr_file, Loader=yaml.FullLoader) + self.seed["metadata"]["name"] = "test-cluster" except yaml.YAMLError as e: logger.error("Failed to read seed yaml, aborting: %s", e) sys.exit(1) @@ -802,6 +852,7 @@ def __init__( self.deploy = deploy self.operator_config = operator_config self.crd_name = operator_config.crd_name + self.crd_version = operator_config.crd_version self.workdir_path = workdir_path self.images_archive = os.path.join(workdir_path, "images.tar") self.num_workers = num_workers @@ -820,10 +871,6 @@ def __init__( analysis_only=analysis_only, ) - # Add additional preload images from arguments - if preload_images_ is not None: - self.context["preload_images"].update(preload_images_) - self.input_model: DeterministicInputModel = input_model( crd=self.context["crd"]["body"], seed_input=self.seed, @@ -835,7 +882,7 @@ def __init__( custom_module_path=operator_config.custom_module, ) - self.sequence_base = 20 if delta_from else 0 + self.sequence_base = 0 if operator_config.custom_oracle is 
not None: module = importlib.import_module(operator_config.custom_oracle) @@ -846,11 +893,8 @@ def __init__( self.custom_on_init = None # Generate test cases - testplan_path = None - if delta_from is not None: - testplan_path = os.path.join(delta_from, "test_plan.json") self.test_plan = self.input_model.generate_test_plan( - testplan_path, focus_fields=focus_fields + focus_fields=focus_fields ) with open( os.path.join(self.workdir_path, "test_plan.json"), @@ -935,9 +979,13 @@ def __learn(self, context_file, helper_crd, analysis_only=False): ) self.cluster.restart_cluster("learn", learn_kubeconfig) + + existing_images = get_existing_images( + self.cluster.get_node_list("learn") + ) + namespace = ( - get_yaml_existing_namespace(self.deploy.operator_yaml) - or CONST.ACTO_NAMESPACE + self.deploy.operator_existing_namespace or CONST.ACTO_NAMESPACE ) self.context["namespace"] = namespace kubectl_client = KubectlClient(learn_kubeconfig, learn_context_name) @@ -949,7 +997,7 @@ def __learn(self, context_file, helper_crd, analysis_only=False): ) if not deployed: raise RuntimeError( - f"Failed to deploy operator due to max retry exceed" + "Failed to deploy operator due to max retry exceed" ) apiclient = kubernetes_client(learn_kubeconfig, learn_context_name) @@ -958,6 +1006,7 @@ def __learn(self, context_file, helper_crd, analysis_only=False): apiclient, KubectlClient(learn_kubeconfig, learn_context_name), self.crd_name, + self.crd_version, helper_crd, ) @@ -982,9 +1031,13 @@ def __learn(self, context_file, helper_crd, analysis_only=False): "Please make sure the operator config is correct" ) - update_preload_images( - self.context, self.cluster.get_node_list("learn") + current_images = get_existing_images( + self.cluster.get_node_list("learn") ) + for current_image in current_images: + if current_image not in existing_images: + self.context["preload_images"].add(current_image) + self.cluster.delete_cluster("learn", learn_kubeconfig) run_end_time = time.time() @@ -1037,9 +1090,7 @@ def __learn(self, context_file, helper_crd, analysis_only=False): sort_keys=True, ) - def run( - self, modes: list = ["normal", "overspecified", "copiedover"] - ) -> list[OracleResults]: + def run(self) -> list[OracleResults]: """Run the test cases""" logger = get_thread_logger(with_prefix=True) @@ -1084,66 +1135,23 @@ def run( self.apply_testcase_f, self.acto_namespace, self.operator_config.diff_ignore_fields, + constraints=self.operator_config.constraints, ) runners.append(runner) - if "normal" in modes: - threads = [] - for runner in runners: - t = threading.Thread( - target=runner.run, args=[errors, InputModel.NORMAL] - ) - t.start() - threads.append(t) + threads = [] + for runner in runners: + t = threading.Thread( + target=runner.run, args=[errors, InputModel.NORMAL] + ) + t.start() + threads.append(t) - for t in threads: - t.join() + for t in threads: + t.join() normal_time = time.time() - if "overspecified" in modes: - threads = [] - for runner in runners: - t = threading.Thread( - target=runner.run, args=([errors, InputModel.OVERSPECIFIED]) - ) - t.start() - threads.append(t) - - for t in threads: - t.join() - - overspecified_time = time.time() - - if "copiedover" in modes: - threads = [] - for runner in runners: - t = threading.Thread( - target=runner.run, args=([errors, InputModel.COPIED_OVER]) - ) - t.start() - threads.append(t) - - for t in threads: - t.join() - - additional_semantic_time = time.time() - - if InputModel.ADDITIONAL_SEMANTIC in modes: - threads = [] - for runner in runners: - t = 
threading.Thread( - target=runner.run, - args=([errors, InputModel.ADDITIONAL_SEMANTIC]), - ) - t.start() - threads.append(t) - - for t in threads: - t.join() - - end_time = time.time() - num_total_failed = 0 for runner in runners: for testcases in runner.discarded_testcases.values(): @@ -1151,10 +1159,6 @@ def run( testrun_info = { "normal_duration": normal_time - start_time, - "overspecified_duration": overspecified_time - normal_time, - "copied_over_duration": additional_semantic_time - - overspecified_time, - "additional_semantic_duration": end_time - additional_semantic_time, "num_workers": self.num_workers, "num_total_testcases": self.input_model.metadata, "num_total_failed": num_total_failed, diff --git a/acto/input/input.py b/acto/input/input.py index f7271f28f0..10a2ead2c1 100644 --- a/acto/input/input.py +++ b/acto/input/input.py @@ -4,10 +4,11 @@ import json import logging import operator +import os import random import threading from functools import reduce -from typing import List, Optional, Tuple +from typing import Optional, Tuple import pydantic import yaml @@ -17,7 +18,7 @@ from acto.input import k8s_schemas, property_attribute from acto.input.get_matched_schemas import find_matched_schema from acto.input.test_generators.generator import get_testcases -from acto.schema import BaseSchema +from acto.schema import BaseSchema, BooleanSchema, IntegerSchema from acto.schema.schema import extract_schema from acto.utils import get_thread_logger @@ -67,7 +68,6 @@ def set_worker_id(self, worker_id: int): @abc.abstractmethod def generate_test_plan( self, - delta_from: Optional[str] = None, focus_fields: Optional[list] = None, ) -> dict: """Generate test plan based on CRD""" @@ -116,7 +116,7 @@ def discard_test_case(self): @abc.abstractmethod def next_test( self, - ) -> Optional[List[Tuple[TestGroup, tuple[str, TestCase]]]]: + ) -> Optional[list[Tuple[TestGroup, tuple[str, TestCase]]]]: """Selects next test case to run from the test plan Instead of random, it selects the next test case from the group. 
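A note on the new constraint hook in `apply_testcase` (engine.py above): it collects boolean property assignments from the input delta as assumptions and asks each `XorCondition` to resolve them, writing any forced value back with `set_value_by_path`. The sketch below is illustrative only; the real class lives in `acto.input.constraint`, and the `solve` signature is inferred from the call site (`result[0].path`, `result[1]`), not copied from the actual implementation. `PropertyPath` here is likewise a simplified stand-in for `acto.common.PropertyPath`.

```python
from typing import Optional


class PropertyPath:
    """Simplified stand-in for acto.common.PropertyPath (illustration only)."""

    def __init__(self, path: list):
        self.path = path


class XorCondition:
    """Hypothetical mutual-exclusion constraint between two boolean properties."""

    def __init__(self, left: PropertyPath, right: PropertyPath):
        self.left = left
        self.right = right

    def solve(
        self, assumptions: list[tuple[PropertyPath, bool]]
    ) -> Optional[tuple[PropertyPath, bool]]:
        # If one side is assumed true, force the other side to false,
        # mirroring how engine.py applies result[1] at result[0].path.
        for path, value in assumptions:
            if value and path.path == self.left.path:
                return (self.right, False)
            if value and path.path == self.right.path:
                return (self.left, False)
        return None
```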
@@ -157,7 +157,9 @@ def __init__(
         self.example_dir = example_dir
         example_docs = []
         if self.example_dir is not None:
-            for example_filepath in glob.glob(self.example_dir + "*.yaml"):
+            for example_filepath in glob.glob(
+                os.path.join(self.example_dir, "*.yaml")
+            ):
                 with open(
                     example_filepath, "r", encoding="utf-8"
                 ) as example_file:
@@ -165,6 +167,8 @@
                     for doc in docs:
                         example_docs.append(doc)
+        logger = get_thread_logger(with_prefix=True)
         for example_doc in example_docs:
+            logger.info("Loading example document %s", example_doc)
             self.root_schema.load_examples(example_doc)
 
         self.num_workers = num_workers
@@ -281,7 +285,6 @@ def set_worker_id(self, worker_id: int):
 
     def generate_test_plan(
         self,
-        delta_from: Optional[str] = None,
         focus_fields: Optional[list] = None,
     ) -> dict:
         """Generate test plan based on CRD"""
@@ -303,6 +306,7 @@
         num_semantic_test_cases = 0
         num_misoperations = 0
         num_pruned_test_cases = 0
+        missing_examples = []
         for path, test_case_list in test_cases:
             # First, check if the path is in the focus fields
             if focus_fields is not None:
@@ -314,6 +318,23 @@
                 if not focused:
                     continue
 
+            schema = self.get_schema_by_path(path)
+            if (
+                not isinstance(schema, BooleanSchema)
+                and not isinstance(schema, IntegerSchema)
+                and len(schema.examples) == 0
+            ):
+                logger.info("No examples for %s", path)
+                info = [
+                    ".".join(path),
+                    schema.raw_schema.get("description"),
+                    schema.raw_schema.get("type", "opaque"),
+                    schema.raw_schema.get("properties"),
+                    schema.raw_schema.get("required"),
+                ]
+
+                missing_examples.append(info)
+
             path_str = (
                 json.dumps(path)
                 .replace('"ITEM"', "0")
@@ -340,6 +358,18 @@
 
             normal_testcases[path_str] = filtered_test_case_list
 
+        logger.info(
+            "There are %d properties that do not have examples",
+            len(missing_examples),
+        )
+        if self.example_dir is not None:
+            with open(
+                os.path.join(self.example_dir, "missing_fields.json"),
+                "w",
+                encoding="utf-8",
+            ) as f:
+                json.dump(missing_examples, f, indent=2)
+
         self.metadata.total_number_of_test_cases = num_test_cases
         self.metadata.number_of_run_test_cases = num_run_test_cases
         self.metadata.number_of_primitive_test_cases = num_pruned_test_cases
@@ -347,6 +377,7 @@
         self.metadata.number_of_misoperations = num_misoperations
         self.metadata.number_of_pruned_test_cases = num_pruned_test_cases
 
+        logger.info("Got %d schemas to focus on", len(normal_testcases))
         logger.info("Generated %d test cases in total", num_test_cases)
         logger.info("Generated %d test cases to run", num_run_test_cases)
         logger.info(
@@ -395,7 +426,7 @@ def split_into_subgroups(
 
     def next_test(
         self,
-    ) -> Optional[List[Tuple[TestGroup, tuple[str, TestCase]]]]:
+    ) -> Optional[list[Tuple[TestGroup, tuple[str, TestCase]]]]:
         """Selects next test case to run from the test plan
 
         Instead of random, it selects the next test case from the group.
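The schema changes below (and in `acto/schema/*` further down) turn `BaseSchema.examples` from a list into a set, so examples gathered from many YAML files deduplicate automatically. Plain dicts and lists are unhashable, hence the `HashableDict`/`HashableList` wrappers added in `acto/common.py`, which hash a canonical JSON encoding. A small self-contained illustration (the `__hash__` body is copied from the `acto/common.py` hunk above; the example values are invented):

```python
import json


class HashableDict(dict):
    """Hashable dict (as added in acto/common.py)."""

    def __hash__(self):
        # Canonical JSON encoding: key order no longer affects the hash.
        return hash(json.dumps(self, sort_keys=True))


examples = set()
examples.add(HashableDict({"replicas": 3, "storage": "10Gi"}))
# Same content, different key order: sort_keys=True makes the hashes equal,
# so the duplicate collapses instead of raising TypeError the way a plain
# dict would when added to a set.
examples.add(HashableDict({"storage": "10Gi", "replicas": 3}))
assert len(examples) == 1
```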
diff --git a/acto/input/test_generators/primitive.py b/acto/input/test_generators/primitive.py index 2079a53e43..f37ec7552b 100644 --- a/acto/input/test_generators/primitive.py +++ b/acto/input/test_generators/primitive.py @@ -96,11 +96,6 @@ def push_setup(prev): if len(schema.examples) > 0: for example in schema.examples: if len(example) > 1: - logger.info( - "Using example for setting up field [%s]: [%s]", - schema.path, - schema.examples[0], - ) return example if prev is None: return schema.gen(minimum=True) @@ -120,16 +115,9 @@ def pop_mutator(prev): return prev def pop_setup(prev): - logger = get_thread_logger(with_prefix=True) - if len(schema.examples) > 0: for example in schema.examples: if len(example) > 1: - logger.info( - "Using example for setting up field [%s]: [%s]", - schema.path, - schema.examples[0], - ) return example if prev is None: return schema.gen(size=schema.min_items + 1) @@ -142,7 +130,7 @@ def empty_mutator(prev): return [] def empty_setup(prev): - return prev + return schema.gen(size=1) def delete(prev): return schema.empty_value() @@ -155,15 +143,9 @@ def delete_precondition(prev): ) def delete_setup(prev): - logger = get_thread_logger(with_prefix=True) if len(schema.examples) > 0: - logger.info( - "Using example for setting up field [%s]: [%s]", - schema.path, - schema.examples[0], - ) example_without_default = [ - x for x in schema.enum if x != schema.default + x for x in schema.examples if x != schema.default ] if len(example_without_default) > 0: return random.choice(example_without_default) @@ -261,15 +243,9 @@ def delete_precondition(prev): ) def delete_setup(prev): - logger = get_thread_logger(with_prefix=True) if len(schema.examples) > 0: - logger.info( - "Using example for setting up field [%s]: [%s]", - schema.path, - schema.examples[0], - ) example_without_default = [ - x for x in schema.enum if x != schema.default + x for x in schema.examples if x != schema.default ] if len(example_without_default) > 0: return random.choice(example_without_default) @@ -519,22 +495,7 @@ def delete_precondition(prev): ) def delete_setup(prev): - logger = get_thread_logger(with_prefix=True) - if len(schema.examples) > 0: - logger.info( - "Using example for setting up field [%s]: [%s]", - schema.path, - schema.examples[0], - ) - example_without_default = [ - x for x in schema.enum if x != schema.default - ] - if len(example_without_default) > 0: - return random.choice(example_without_default) - else: - return schema.gen(exclude_value=schema.default) - else: - return schema.gen(exclude_value=schema.default) + return schema.gen(exclude_value=schema.default) testcases = [ TestCase( @@ -579,6 +540,7 @@ def object_tests(schema: ObjectSchema): DELETION_TEST = "object-deletion" EMPTY_TEST = "object-empty" + CHANGE_TEST = "object-change" def empty_precondition(prev): return prev != {} @@ -587,7 +549,7 @@ def empty_mutator(prev): return {} def empty_setup(prev): - return prev + return schema.gen(exclude_value=schema.default) def delete(prev): return schema.empty_value() @@ -600,15 +562,9 @@ def delete_precondition(prev): ) def delete_setup(prev): - logger = get_thread_logger(with_prefix=True) if len(schema.examples) > 0: - logger.info( - "Using example for setting up field [%s]: [%s]", - schema.path, - schema.examples[0], - ) example_without_default = [ - x for x in schema.enum if x != schema.default + x for x in schema.examples if x != schema.default ] if len(example_without_default) > 0: return random.choice(example_without_default) @@ -617,6 +573,9 @@ def delete_setup(prev): 
else: return schema.gen(exclude_value=schema.default) + def change_precondition(prev): + return prev is not None + ret = [ TestCase( DELETION_TEST, @@ -639,13 +598,53 @@ def delete_setup(prev): primitive=True, ) ) + + if schema.examples is not None and len(schema.examples) > 1: + example_list = list(schema.examples) + ret.append( + TestCase( + CHANGE_TEST, + change_precondition, + lambda prev: example_list[1], + lambda prev: example_list[0], + primitive=True, + ) + ) return ret @test_generator(property_type="Opaque", priority=Priority.PRIMITIVE) -def opaque_gen(schema: OpaqueSchema): - """Opaque schema to handle the fields that do not have a schema""" - return [] +def opaque_tests(schema: OpaqueSchema): + """Opaque schema to handle the fields that do not have a schema + + It only generates testcases if there are examples provided + """ + DELETION_TEST = "opaque-deletion" + CHANGE_TEST = "opaque-change" + ret = [] + if schema.examples is not None and len(schema.examples) > 0: + example_list = list(schema.examples) + ret.append( + TestCase( + DELETION_TEST, + lambda prev: prev is not None, + lambda prev: None, + lambda prev: example_list[0], + primitive=True, + ) + ) + + if len(schema.examples) > 1: + ret.append( + TestCase( + CHANGE_TEST, + lambda prev: prev is not None, + lambda prev: example_list[1], + lambda prev: example_list[0], + primitive=True, + ) + ) + return ret @test_generator(property_type="String", priority=Priority.PRIMITIVE) @@ -669,9 +668,11 @@ def change(prev): """Test case to change the value to another one""" logger = get_thread_logger(with_prefix=True) if schema.enum is not None: - logger.fatal( + logger.critical( "String field with enum should not call change to mutate" ) + if schema.examples is not None and len(schema.examples) > 0: + return schema.gen(exclude_value=prev) if schema.pattern is not None: new_string = exrex.getone(schema.pattern, schema.max_length) else: @@ -683,16 +684,7 @@ def change(prev): return new_string def change_setup(prev): - logger = get_thread_logger(with_prefix=True) - if len(schema.examples) > 0: - logger.info( - "Using example for setting up field [%s]: [%s]", - schema.path, - schema.examples[0], - ) - return schema.examples[0] - else: - return schema.gen() + return schema.gen() def empty_precondition(prev): return prev != "" @@ -701,7 +693,7 @@ def empty_mutator(prev): return "" def empty_setup(prev): - return prev + return schema.gen(exclude_value=schema.default) def delete(prev): return schema.empty_value() @@ -716,13 +708,8 @@ def delete_precondition(prev): def delete_setup(prev): logger = get_thread_logger(with_prefix=True) if len(schema.examples) > 0: - logger.info( - "Using example for setting up field [%s]: [%s]", - schema.path, - schema.examples[0], - ) example_without_default = [ - x for x in schema.enum if x != schema.default + x for x in schema.examples if x != schema.default ] if len(example_without_default) > 0: return random.choice(example_without_default) diff --git a/acto/input/test_generators/stateful_set.py b/acto/input/test_generators/stateful_set.py index e41748e794..92c4179154 100644 --- a/acto/input/test_generators/stateful_set.py +++ b/acto/input/test_generators/stateful_set.py @@ -58,7 +58,7 @@ def replicas_tests(schema: IntegerSchema) -> list[TestCase]: invalid=True, semantic=True, ) - return [invalid_test, scale_down_up_test, scale_up_down_test, overload_test] + return [invalid_test, scale_down_up_test, scale_up_down_test] @test_generator( diff --git a/acto/input/value_with_schema.py 
b/acto/input/value_with_schema.py index c47819100f..545c6c0173 100644 --- a/acto/input/value_with_schema.py +++ b/acto/input/value_with_schema.py @@ -131,9 +131,9 @@ def mutate(self, p_delete=0.05, p_replace=0.1): else: letters = string.ascii_lowercase key = "".join(random.choice(letters) for i in range(5)) - self[ - key - ] = self.schema.get_additional_properties().gen() + self[key] = ( + self.schema.get_additional_properties().gen() + ) else: child_key = random.choice( list(self.schema.get_properties()) @@ -175,20 +175,20 @@ def create_path(self, path: list): """Ensures the path exists""" if len(path) == 0: return - key = path.pop(0) + key = path[0] if self.store is None: self.update(self.schema.gen(minimum=True)) self[key] = None elif key not in self.store: self[key] = None - self.store[key].create_path(path) + self.store[key].create_path(path[1:]) def set_value_by_path(self, value, path): if len(path) == 0: self.update(value) else: - key = path.pop(0) - self.store[key].set_value_by_path(value, path) + key = path[0] + self.store[key].set_value_by_path(value, path[1:]) def __getitem__(self, key): return self.store[key] @@ -206,7 +206,7 @@ def __contains__(self, item: str): class ValueWithArraySchema(ValueWithSchema): """Value with ArraySchema attached""" - def __init__(self, value, schema) -> None: + def __init__(self, value, schema: ArraySchema) -> None: self.schema = schema if value is None: self.store = None @@ -306,7 +306,7 @@ def create_path(self, path: list): """Ensures the path exists""" if len(path) == 0: return - key = path.pop(0) + key = path[0] if self.store is None: self.store = [] for _ in range(0, key): @@ -316,14 +316,14 @@ def create_path(self, path: list): for _ in range(len(self.store), key): self.append(None) self.append(None) - self.store[key].create_path(path) + self.store[key].create_path(path[1:]) def set_value_by_path(self, value, path): if len(path) == 0: self.update(value) else: - key = path.pop(0) - self.store[key].set_value_by_path(value, path) + key = path[0] + self.store[key].set_value_by_path(value, path[1:]) def __getitem__(self, key): return self.store[key] diff --git a/acto/result.py b/acto/result.py index 3527fb9bfd..555db7fa45 100644 --- a/acto/result.py +++ b/acto/result.py @@ -94,6 +94,15 @@ class InvalidInputResult(OracleResult): ) +class DeletionOracleResult(OracleResult): + """Model for the result of a deletion oracle run""" + + message: str = pydantic.Field( + description="The message of the oracle run", + default="Deletion failed", + ) + + class OracleResults(pydantic.BaseModel): """The results of a collection of oracles""" @@ -119,6 +128,10 @@ class OracleResults(pydantic.BaseModel): description="The result of the differential oracle", default=None, ) + deletion: Optional[OracleResult] = pydantic.Field( + description="The result of the deletion oracle", + default=None, + ) custom: Optional[OracleResult] = pydantic.Field( description="The result of the health oracle", default=None, diff --git a/acto/schema/array.py b/acto/schema/array.py index 24b5522902..2595e7c661 100644 --- a/acto/schema/array.py +++ b/acto/schema/array.py @@ -1,5 +1,8 @@ import random -from typing import List, Tuple +from typing import Any, List, Optional, Tuple + +from acto.common import HashableList +from acto.utils.thread_logger import get_thread_logger from .base import BaseSchema, TreeNode @@ -100,10 +103,19 @@ def to_tree(self) -> TreeNode: node.add_child("ITEM", self.item_schema.to_tree()) return node - def load_examples(self, example: list): - 
self.examples.append(example) - for item in example: - self.item_schema.load_examples(item) + def load_examples(self, example: Optional[List[Any]]): + if example is not None: + logger = get_thread_logger(with_prefix=True) + logger.debug("Loading example %s into %s", example, self.path) + + if isinstance(example, list): + self.examples.add(HashableList(example)) + for item in example: + self.item_schema.load_examples(item) + else: + raise TypeError( + f"Expected example to be of type list, got {type(example)}" + ) def set_default(self, instance): self.default = instance @@ -112,6 +124,10 @@ def empty_value(self): return [] def gen(self, exclude_value=None, minimum: bool = False, **kwargs) -> list: + num = 0 + if "size" in kwargs and kwargs["size"] is not None: + num = kwargs["size"] + if self.enum is not None: if exclude_value is not None: return random.choice( @@ -119,18 +135,23 @@ def gen(self, exclude_value=None, minimum: bool = False, **kwargs) -> list: ) else: return random.choice(self.enum) + + if self.examples and len(self.examples) > 0: + candidates = [ + x for x in self.examples if x != exclude_value and len(x) > num + ] + if candidates: + return random.choice(candidates)[num:] + + # XXX: need to handle exclude_value, but not important for now for array types + result = [] + if minimum: + num = self.min_items else: - # XXX: need to handle exclude_value, but not important for now for array types - result = [] - if "size" in kwargs and kwargs["size"] is not None: - num = kwargs["size"] - elif minimum: - num = self.min_items - else: - num = random.randint(self.min_items, self.max_items) - for _ in range(num): - result.append(self.item_schema.gen(minimum=minimum)) - return result + num = random.randint(self.min_items, self.max_items) + for _ in range(num): + result.append(self.item_schema.gen(minimum=minimum)) + return result def __str__(self) -> str: return "Array" diff --git a/acto/schema/base.py b/acto/schema/base.py index af172eb5af..06a6b4b1a3 100644 --- a/acto/schema/base.py +++ b/acto/schema/base.py @@ -149,7 +149,7 @@ def __init__(self, path: list, schema: dict) -> None: self.raw_schema = schema self.default = None if "default" not in schema else schema["default"] self.enum = None if "enum" not in schema else schema["enum"] - self.examples: list[Any] = [] + self.examples: set[Any] = set() self.attributes = PropertyAttribute(value=0) self.copied_over = False diff --git a/acto/schema/integer.py b/acto/schema/integer.py index beb8c1cbf3..3ea41bdc55 100644 --- a/acto/schema/integer.py +++ b/acto/schema/integer.py @@ -1,5 +1,5 @@ import random -from typing import Any, List, Tuple +from typing import Any, List, Optional, Tuple from .base import BaseSchema, TreeNode from .number import NumberSchema @@ -28,9 +28,9 @@ def get_normal_semantic_schemas( def to_tree(self) -> TreeNode: return TreeNode(self.path) - def load_examples(self, example: Any): + def load_examples(self, example: Optional[Any]): if isinstance(example, int): - self.examples.append(example) + self.examples.add(example) def set_default(self, instance): self.default = int(instance) diff --git a/acto/schema/number.py b/acto/schema/number.py index 434ff056fd..8de08f536b 100644 --- a/acto/schema/number.py +++ b/acto/schema/number.py @@ -1,5 +1,5 @@ import random -from typing import List, Tuple +from typing import List, Optional, Tuple from .base import BaseSchema, TreeNode @@ -59,8 +59,15 @@ def get_normal_semantic_schemas( def to_tree(self) -> TreeNode: return TreeNode(self.path) - def load_examples(self, example: float): - 
self.examples.append(example) + def load_examples(self, example: Optional[float]): + if isinstance(example, float): + self.examples.add(example) + elif isinstance(example, int): + self.examples.add(example) + else: + raise TypeError( + f"Expected float, got {type(example)} for property {self.path}" + ) def set_default(self, instance): self.default = float(instance) diff --git a/acto/schema/object.py b/acto/schema/object.py index 664c76ffb0..daf2093312 100644 --- a/acto/schema/object.py +++ b/acto/schema/object.py @@ -1,6 +1,7 @@ import random -from typing import List, Tuple +from typing import List, Optional, Tuple +from acto.common import HashableDict from acto.utils.thread_logger import get_thread_logger from .base import BaseSchema, TreeNode @@ -128,11 +129,20 @@ def to_tree(self) -> TreeNode: return node - def load_examples(self, example: dict): - self.examples.append(example) - for key, value in example.items(): - if key in self.properties: - self.properties[key].load_examples(value) + def load_examples(self, example: Optional[dict]): + if example is not None: + logger = get_thread_logger(with_prefix=True) + logger.debug("Loading example %s into %s", example, self.path) + + if isinstance(example, dict): + self.examples.add(HashableDict(example)) + for key, value in example.items(): + if key in self.properties: + self.properties[key].load_examples(value) + else: + raise TypeError( + f"Example {example} is not a dictionary, cannot load it into {self.path}" + ) def set_default(self, instance): self.default = instance @@ -173,6 +183,16 @@ def gen(self, exclude_value=None, minimum: bool = False, **kwargs): else: return random.choice(self.enum) + if self.examples: + if exclude_value is not None: + example_without_exclude = [ + x for x in self.examples if x != exclude_value + ] + if example_without_exclude: + return random.choice(example_without_exclude) + else: + return random.choice(list(self.examples)) + # XXX: need to handle exclude_value, but not important for now for object types result = {} if len(self.properties) == 0: diff --git a/acto/schema/opaque.py b/acto/schema/opaque.py index 94d24761d4..b11f8f8d76 100644 --- a/acto/schema/opaque.py +++ b/acto/schema/opaque.py @@ -1,4 +1,7 @@ -from typing import List, Tuple +from typing import Any, List, Optional, Tuple + +from acto.common import HashableDict, HashableList +from acto.utils.thread_logger import get_thread_logger from .base import BaseSchema, TreeNode @@ -17,8 +20,17 @@ def get_normal_semantic_schemas( def to_tree(self) -> TreeNode: return TreeNode(self.path) - def load_examples(self, example): - pass + def load_examples(self, example: Optional[Any]): + if example is None: + return + logger = get_thread_logger(with_prefix=True) + logger.debug("Loading example %s into %s", example, self.path) + if isinstance(example, dict): + self.examples.add(HashableDict(example)) + elif isinstance(example, list): + self.examples.add(HashableList(example)) + else: + self.examples.add(example) def set_default(self, instance): self.default = instance diff --git a/acto/schema/string.py b/acto/schema/string.py index bad2e6fa81..fb4e99616c 100644 --- a/acto/schema/string.py +++ b/acto/schema/string.py @@ -1,9 +1,10 @@ import random -from typing import List, Tuple +from typing import List, Optional, Tuple import exrex from acto.common import random_string +from acto.utils.thread_logger import get_thread_logger from .base import BaseSchema, TreeNode @@ -46,8 +47,16 @@ def get_normal_semantic_schemas( def to_tree(self) -> TreeNode: return 
TreeNode(self.path) - def load_examples(self, example: str): - self.examples.append(example) + def load_examples(self, example: Optional[str]): + if example is not None: + if isinstance(example, str): + logger = get_thread_logger(with_prefix=True) + logger.debug("Loading example %s into %s", example, self.path) + self.examples.add(example) + else: + raise TypeError( + f"Expected string, got {type(example)} for {self.path}" + ) def set_default(self, instance): self.default = str(instance) @@ -55,7 +64,12 @@ def set_default(self, instance): def empty_value(self): return "" - def gen(self, exclude_value=None, minimum: bool = False, **kwargs): + def gen( + self, + exclude_value: Optional[str] = None, + minimum: bool = False, + **kwargs, + ): # TODO: Use minLength: the exrex does not support minLength if self.enum is not None: if exclude_value is not None: @@ -64,12 +78,21 @@ def gen(self, exclude_value=None, minimum: bool = False, **kwargs): ) else: return random.choice(self.enum) + if self.examples: + if exclude_value is not None: + example_without_exclude = [ + x for x in self.examples if x != exclude_value + ] + if len(example_without_exclude) > 0: + return random.choice(example_without_exclude) + else: + return random.choice(list(self.examples)) if self.pattern is not None: - # XXX: since it's random, we don't need to exclude the value + # Since it's random, we don't need to exclude the value return exrex.getone(self.pattern, self.max_length) if minimum: return random_string(self.min_length) # type: ignore - return "ACTOKEY" + return "ACTOSTRING" def __str__(self) -> str: return "String" diff --git a/acto/utils/k8s_helper.py b/acto/utils/k8s_helper.py index 24f39ffd8f..e04e7e56d0 100644 --- a/acto/utils/k8s_helper.py +++ b/acto/utils/k8s_helper.py @@ -3,89 +3,119 @@ import kubernetes import yaml -from kubernetes.client.models import (V1Deployment, V1Namespace, V1ObjectMeta, - V1Pod, V1StatefulSet) +from kubernetes.client.models import ( + V1Deployment, + V1Namespace, + V1ObjectMeta, + V1Pod, + V1StatefulSet, +) from .thread_logger import get_thread_logger def is_pod_ready(pod: V1Pod) -> bool: - '''Check if the pod is ready + """Check if the pod is ready Args: pod: Pod object in kubernetes Returns: if the pod is ready - ''' + """ if pod.status is None or pod.status.conditions is None: return False for condition in pod.status.conditions: - if condition.type == 'Ready' and condition.status == 'True': + if condition.type == "Ready" and condition.status == "True": return True return False def get_deployment_available_status(deployment: V1Deployment) -> bool: - '''Get availability status from deployment condition + """Get availability status from deployment condition Args: deployment: Deployment object in kubernetes Returns: if the deployment is available - ''' + """ if deployment.status is None or deployment.status.conditions is None: return False for condition in deployment.status.conditions: - if condition.type == 'Available' and condition.status == 'True': + if condition.type == "Available" and condition.status == "True": return True return False def get_stateful_set_available_status(stateful_set: V1StatefulSet) -> bool: - '''Get availability status from stateful set condition + """Get availability status from stateful set condition Args: stateful_set: stateful set object in kubernetes Returns: if the stateful set is available - ''' + """ if stateful_set.status is None: return False - if stateful_set.status.replicas > 0 and stateful_set.status.current_replicas == 
stateful_set.status.replicas: + if ( + stateful_set.status.replicas > 0 + and stateful_set.status.current_replicas == stateful_set.status.replicas + ): return True return False def get_yaml_existing_namespace(fn: str) -> Optional[str]: - '''Get yaml's existing namespace + """Get yaml's existing namespace Args: fn (str): Yaml file path Returns: bool: True if yaml has namespace - ''' - with open(fn, 'r') as operator_yaml: - parsed_operator_documents = yaml.load_all(operator_yaml, - Loader=yaml.FullLoader) + """ + with open(fn, "r", encoding="utf-8") as operator_yaml: + parsed_operator_documents = yaml.load_all( + operator_yaml, Loader=yaml.FullLoader + ) for document in parsed_operator_documents: - if document != None and 'metadata' in document and 'namespace' in document['metadata']: - return document['metadata']['namespace'] + if ( + document is not None + and "metadata" in document + and "namespace" in document["metadata"] + ): + return document["metadata"]["namespace"] + return None + + +def get_deployment_name_from_yaml(fn: str) -> Optional[str]: + """Get deployment name from yaml file""" + with open(fn, "r", encoding="utf-8") as operator_yaml: + parsed_operator_documents = yaml.load_all( + operator_yaml, Loader=yaml.FullLoader + ) + for document in parsed_operator_documents: + if ( + document is not None + and "kind" in document + and document["kind"] == "Deployment" + ): + return document["metadata"]["name"] return None def create_namespace(apiclient, name: str) -> V1Namespace: logger = get_thread_logger(with_prefix=False) - corev1Api = kubernetes.client.CoreV1Api(apiclient) + corev1_api = kubernetes.client.CoreV1Api(apiclient) namespace = None try: - namespace = corev1Api.create_namespace( - V1Namespace(metadata=V1ObjectMeta(name=name))) + namespace = corev1_api.create_namespace( + V1Namespace(metadata=V1ObjectMeta(name=name)) + ) except Exception as e: logger.error(e) return namespace @@ -93,9 +123,9 @@ def create_namespace(apiclient, name: str) -> V1Namespace: def delete_namespace(apiclient, name: str) -> bool: logger = get_thread_logger(with_prefix=False) - corev1Api = kubernetes.client.CoreV1Api(apiclient) + corev1_api = kubernetes.client.CoreV1Api(apiclient) try: - corev1Api.delete_namespace(name=name) + corev1_api.delete_namespace(name=name) except Exception as e: logger.error(e) return False @@ -105,22 +135,25 @@ def delete_namespace(apiclient, name: str) -> bool: def delete_operator_pod(apiclient, namespace: str) -> bool: logger = get_thread_logger(with_prefix=False) - coreV1Api = kubernetes.client.CoreV1Api(apiclient) + corev1_api = kubernetes.client.CoreV1Api(apiclient) - operator_pod_list = coreV1Api.list_namespaced_pod( - namespace=namespace, watch=False, label_selector="acto/tag=operator-pod").items + operator_pod_list = corev1_api.list_namespaced_pod( + namespace=namespace, watch=False, label_selector="acto/tag=operator-pod" + ).items if len(operator_pod_list) >= 1: - logger.debug('Got operator pod: pod name:' + - operator_pod_list[0].metadata.name) + logger.debug( + "Got operator pod: pod name: %s", operator_pod_list[0].metadata.name + ) else: - logger.error('Failed to find operator pod') + logger.error("Failed to find operator pod") return False # TODO: refine what should be done if no operator pod can be found - pod = coreV1Api.delete_namespaced_pod(name=operator_pod_list[0].metadata.name, - namespace=namespace) - if pod == None: + pod = corev1_api.delete_namespaced_pod( + name=operator_pod_list[0].metadata.name, namespace=namespace + ) + if pod is None: return False 
diff --git a/acto/utils/preprocess.py b/acto/utils/preprocess.py
index 77f233caab..1f05d6880e 100644
--- a/acto/utils/preprocess.py
+++ b/acto/utils/preprocess.py
@@ -12,45 +12,9 @@
 from .thread_logger import get_thread_logger
 
 
-def update_preload_images(context: dict, worker_list):
-    """Get used images from pod"""
-    logger = get_thread_logger(with_prefix=False)
-
-    namespace = context.get("namespace", "")
-    if not namespace:
-        return
-
-    # block list when getting the operator specific images
-    k8s_images = [
-        "docker.io/kindest/kindnetd",
-        "docker.io/rancher/local-path-provisioner",
-        "docker.io/kindest/local-path-provisioner",
-        "docker.io/kindest/local-path-helper",
-        "k8s.gcr.io/build-image/debian-base",
-        "k8s.gcr.io/coredns/coredns",
-        "k8s.gcr.io/etcd",
-        "k8s.gcr.io/kube-apiserver",
-        "k8s.gcr.io/kube-controller-manager",
-        "k8s.gcr.io/kube-proxy",
-        "k8s.gcr.io/kube-scheduler",
-        "k8s.gcr.io/pause",
-        "docker.io/rancher/klipper-helm",
-        "docker.io/rancher/klipper-lb",
-        "docker.io/rancher/mirrored-coredns-coredns",
-        "docker.io/rancher/mirrored-library-busybox",
-        "docker.io/rancher/mirrored-library-traefik",
-        "docker.io/rancher/mirrored-metrics-server",
-        "docker.io/rancher/mirrored-paus",
-        # new k8s images
-        "registry.k8s.io/etcd",
-        "registry.k8s.io/kube-controller-manager",
-        "registry.k8s.io/pause",
-        "registry.k8s.io/kube-proxy",
-        "registry.k8s.io/coredns/coredns",
-        "registry.k8s.io/kube-apiserver",
-        "registry.k8s.io/kube-scheduler",
-    ]
-
+def get_existing_images(worker_list: list[str]) -> set[str]:
+    """Get existing images from pods"""
+    existing_images = set()
     for worker in worker_list:
         p = subprocess.run(
             [
@@ -69,24 +33,17 @@ def update_preload_images(context: dict, worker_list):
         output = p.stdout.strip()
         for line in output.split("\n")[1:]:
             items = line.split()
-            if items[0] in k8s_images:
-                continue
             if "none" not in items[1]:
                 image = f"{items[0]}:{items[1]}"
-            else:
-                logger.warning(
-                    "image %s has no tag, Acto will not preload this image for this run",
-                    items[0],
-                )
-                continue
-
-            context["preload_images"].add(image)
+                existing_images.add(image)
+    return existing_images
 
 
 def process_crd(
     apiclient: kubernetes.client.ApiClient,
     kubectl_client: KubectlClient,
     crd_name: Optional[str] = None,
+    crd_version: Optional[str] = None,
     helper_crd: Optional[str] = None,
 ) -> dict:
     """Get crd from k8s and set context['crd']
@@ -106,9 +63,9 @@ def process_crd(
 
     if helper_crd is None:
         apiextensions_v1_api = kubernetes.client.ApiextensionsV1Api(apiclient)
-        crds: list[
-            k8s_models.V1CustomResourceDefinition
-        ] = apiextensions_v1_api.list_custom_resource_definition().items
+        crds: list[k8s_models.V1CustomResourceDefinition] = (
+            apiextensions_v1_api.list_custom_resource_definition().items
+        )
         crd: Optional[k8s_models.V1CustomResourceDefinition] = None
         if len(crds) == 0:
             logger.error("No crd is found")
@@ -139,8 +96,11 @@ def process_crd(
             crd_data = {
                 "group": spec.group,
                 "plural": spec.names.plural,
-                # TODO: Handle multiple versions
-                "version": spec.versions[0].name,
+                "version": (
+                    spec.versions[-1].name
+                    if crd_version is None
+                    else crd_version
+                ),
                 "body": crd_obj,
             }
             return crd_data
@@ -149,13 +109,24 @@ def process_crd(
             sys.exit(1)
     else:
         with open(helper_crd, "r", encoding="utf-8") as helper_crd_f:
-            helper_crd_doc = yaml.load(helper_crd_f, Loader=yaml.FullLoader)
+            helper_crd_docs = list(yaml.safe_load_all(helper_crd_f))
+
+        if crd_name:
+            for doc in helper_crd_docs:
+                if doc["metadata"]["name"] == crd_name:
+                    helper_crd_doc = doc
+                    break
+        else:
+            helper_crd_doc = helper_crd_docs[0]
+
         crd_data = {
             "group": helper_crd_doc["spec"]["group"],
             "plural": helper_crd_doc["spec"]["names"]["plural"],
-            "version": helper_crd_doc["spec"]["versions"][-1][
-                "name"
-            ],  # TODO: Handle multiple versions
+            "version": (
+                helper_crd_doc["spec"]["versions"][-1]["name"]
+                if crd_version is None
+                else crd_version
+            ),
             "body": helper_crd_doc,
         }
         return crd_data
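Version selection in process_crd is now uniform across both branches: the last entry of spec.versions wins unless the caller pins one through the new crd_version parameter. A self-contained check of that rule (the CRD fragment is invented for illustration):

    from typing import Optional

    def pick_version(crd_doc: dict, crd_version: Optional[str] = None) -> str:
        # Mirrors the expression used in both branches above: an explicit
        # override wins, otherwise fall back to the last declared version.
        return (
            crd_doc["spec"]["versions"][-1]["name"]
            if crd_version is None
            else crd_version
        )

    crd = {"spec": {"versions": [{"name": "v1alpha1"}, {"name": "v1"}]}}
    assert pick_version(crd) == "v1"
    assert pick_version(crd, "v1alpha1") == "v1alpha1"

Worth noting: Kubernetes does not promise that the last element of spec.versions is the served or storage version, so passing crd_version explicitly remains the reliable choice for CRDs that declare several versions.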
diff --git a/acto/utils/process_with_except.py b/acto/utils/process_with_except.py
index df125378e4..6b6aa25f01 100644
--- a/acto/utils/process_with_except.py
+++ b/acto/utils/process_with_except.py
@@ -4,10 +4,10 @@
 
 
 class MyProcess(Process):
-    '''Process class with excepthook'''
+    """Process class with excepthook"""
 
     def run(self):
         try:
             super().run()
         except Exception:
-            excepthook(*sys.exc_info())
\ No newline at end of file
+            excepthook(*sys.exc_info())
diff --git a/acto/utils/thread_logger.py b/acto/utils/thread_logger.py
index 2a507c90f3..c567db9f1b 100644
--- a/acto/utils/thread_logger.py
+++ b/acto/utils/thread_logger.py
@@ -1,31 +1,37 @@
 import logging
 import threading
-from typing import Tuple
+from typing import Union
 
 
 class PrefixLoggerAdapter(logging.LoggerAdapter):
-    """ A logger adapter that adds a prefix to every message """
-    def process(self, msg: str, kwargs: dict) -> Tuple[str, dict]:
-        return (f'[{self.extra["prefix"]}] {msg}', kwargs)
+    """A logger adapter that adds a prefix to every message"""
+
+    def process(self, msg: str, kwargs) -> tuple[str, dict]:
+        """Add the prefix to the message"""
+        if self.extra is not None and "prefix" in self.extra:
+            return (f'[{self.extra["prefix"]}] {msg}', kwargs)
+        return (msg, kwargs)
+
 
 logger_prefix = threading.local()
 
+
 def set_thread_logger_prefix(prefix: str) -> None:
-    '''
-    Store the prefix in the thread local storag,
-    invoke get_thread_logger_with_prefix to get the updated logger
-    '''
+    """
+    Store the prefix in thread-local storage;
+    invoke get_thread_logger to get the updated logger
+    """
     logger_prefix.prefix = prefix
 
-def get_thread_logger(with_prefix: bool) -> logging.LoggerAdapter:
-    '''Get the logger with the prefix from the thread local storage'''
+
+def get_thread_logger(
+    with_prefix: bool = True,
+) -> Union[logging.LoggerAdapter, logging.Logger]:
+    """Get the logger with the prefix from the thread local storage"""
     logger = logging.getLogger(threading.current_thread().name)
     logger.setLevel(logging.DEBUG)
 
     # if the prefix is not set, return the original logger
-    if not with_prefix or not hasattr(logger_prefix, 'prefix'):
+    if not with_prefix or not hasattr(logger_prefix, "prefix"):
         return logger
 
-    return PrefixLoggerAdapter(logger, extra={'prefix': logger_prefix.prefix})
-
-
-
+    return PrefixLoggerAdapter(logger, extra={"prefix": logger_prefix.prefix})
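With with_prefix defaulting to True and the adapter tolerating a missing extra dict, call sites can drop the boilerplate argument. Intended usage looks roughly like this (the prefix string is illustrative):

    import logging

    from acto.utils.thread_logger import (
        get_thread_logger,
        set_thread_logger_prefix,
    )

    logging.basicConfig(level=logging.DEBUG)

    set_thread_logger_prefix("worker-0")  # stored in thread-local storage
    logger = get_thread_logger()  # adapter that prepends "[worker-0]"
    logger.info("running test case")

    plain = get_thread_logger(with_prefix=False)  # raw per-thread logger

Because the prefix lives in threading.local(), each worker thread only ever sees the prefix it registered itself, which keeps interleaved log lines from concurrent workers attributable.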
diff --git a/scripts/crawl_examples.py b/scripts/crawl_examples.py
index f862fb9a8a..5446a9fb96 100644
--- a/scripts/crawl_examples.py
+++ b/scripts/crawl_examples.py
@@ -1,12 +1,21 @@
+"""Crawls yaml examples from target repo as testing material"""
+
 import argparse
-import glob, os
-import yaml
+import glob
+import os
 
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='Crawl CR examples in the project repo')
+import yaml
 
-    parser.add_argument('--dir', dest='dir', help='Project repo dir', required=True)
-    parser.add_argument('--kind', '-k', dest='kind', help='CR kind', required=True)
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Crawl CR examples in the project repo"
+    )
+    parser.add_argument(
+        "--dir", dest="dir", help="Project repo dir", required=True
+    )
+    parser.add_argument(
+        "--kind", "-k", dest="kind", help="CR kind", required=False
+    )
     # parser.add_argument('--dest',
     #                     dest='dest',
     #                     help='Directory to store the crawlled examples',
@@ -14,18 +23,52 @@
 
     args = parser.parse_args()
 
-    results = []
+    main_results = []
+    aux_results = []
 
-    for file in glob.glob(os.path.join(args.dir, '**', '*.yaml'), recursive=True):
-        with open(file, 'r') as yaml_file:
+    for file in glob.glob(
+        os.path.join(args.dir, "**", "*.yaml"), recursive=True
+    ):
+        with open(file, "r", encoding="utf-8") as yaml_file:
             try:
                 file_content = yaml.load(yaml_file, Loader=yaml.FullLoader)
-            except:
+            except yaml.YAMLError:
                 continue
 
-            if 'kind' in file_content and file_content['kind'] == args.kind:
-                print(file)
-                results.append(file_content)
+            if file_content and "kind" in file_content:
+                try:
+                    if file_content["namespace"] != "":
+                        file_content.pop("namespace", None)
+                except KeyError:
+                    print("No namespace, no op")
+
+                try:
+                    if file_content["metadata"]["namespace"] != "":
+                        file_content["metadata"].pop("namespace", None)
+                except KeyError:
+                    print("No metadata.namespace, no op")
+
+                try:
+                    if file_content["spec"]["datacenter"]["namespace"] != "":
+                        file_content["spec"]["datacenter"].pop(
+                            "namespace", None
+                        )
+                except KeyError:
+                    print("No spec.datacenter.namespace, no op")
 
-    with open('examples.yaml', 'w') as out_file:
-        yaml.dump_all(results, out_file)
+                if file_content["kind"] == "Deployment":
+                    continue
+                elif file_content["kind"] == args.kind:
+                    print(file)
+                    main_results.append(file_content)
+                elif file_content["kind"] == "StorageClass":
+                    print(file)
+                    file_content["provisioner"] = "rancher.io/local-path"
+                    aux_results.append(file_content)
+                elif file_content["kind"] != "CustomResourceDefinition":
+                    print(file)
+                    aux_results.append(file_content)
+    with open("examples.yaml", "w", encoding="utf-8") as out_file:
+        yaml.dump_all(main_results, out_file)
+    with open("aux-examples.yaml", "w", encoding="utf-8") as out_file:
+        yaml.dump_all(aux_results, out_file)
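The crawler now emits two files: examples.yaml holds the CRs matching --kind, and aux-examples.yaml holds supporting manifests. Deployments and CustomResourceDefinitions are dropped, and StorageClass provisioners are rewritten to rancher.io/local-path, presumably so the examples run on the kind clusters Acto targets. A quick post-crawl sanity check, assuming a run such as python scripts/crawl_examples.py --dir <repo> --kind <Kind> (repo path and kind are placeholders):

    import yaml

    with open("examples.yaml", "r", encoding="utf-8") as f:
        main_docs = [d for d in yaml.safe_load_all(f) if d]
    with open("aux-examples.yaml", "r", encoding="utf-8") as f:
        aux_docs = [d for d in yaml.safe_load_all(f) if d]

    print(f"{len(main_docs)} CR example(s), {len(aux_docs)} auxiliary manifest(s)")
    # Deployments and CRDs were filtered out, so neither kind should appear.
    print(sorted({d.get("kind", "<missing>") for d in aux_docs}))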