diff --git a/acto/input/input.py b/acto/input/input.py index 3f2b406ae5..920c0c46b3 100644 --- a/acto/input/input.py +++ b/acto/input/input.py @@ -251,6 +251,11 @@ def __init__( ) ) + for base_schema, k8s_schema_name in self.full_matched_schemas: + base_schema.attributes |= ( + property_attribute.PropertyAttribute.Mapped + ) + # Apply custom property attributes based on the property_attribute module self.apply_custom_field() diff --git a/acto/post_process/post_diff_test.py b/acto/post_process/post_diff_test.py index 3a61ddbf35..52a41e99cb 100644 --- a/acto/post_process/post_diff_test.py +++ b/acto/post_process/post_diff_test.py @@ -725,7 +725,9 @@ def check(self, workdir: str, num_workers: int = 1): diff_test_result_path ) if diff_test_result.input_digest == seed_input_digest: - diff_skip_regex = self.__get_diff_paths(diff_test_result) + diff_skip_regex = self.__get_diff_paths( + diff_test_result, num_workers + ) logger.info( "Seed input digest: %s, diff_skip_regex: %s", seed_input_digest, @@ -920,7 +922,9 @@ def check_diff_test_step( to_state=diff_test_result.snapshot.system_state, ) - def __get_diff_paths(self, diff_test_result: DiffTestResult) -> list[str]: + def __get_diff_paths( + self, diff_test_result: DiffTestResult, num_workers: int + ) -> list[str]: """Get the diff paths from a diff test result Algorithm: Iterate on the original trials, in principle they should be the same @@ -946,9 +950,9 @@ def __get_diff_paths(self, diff_test_result: DiffTestResult) -> list[str]: list[str]: The list of diff paths """ - initial_regex: set[str] = set() indeterministic_regex: set[str] = set() - first_step = True + + args = [] for original in diff_test_result.originals: trial = original["trial"] gen = original["gen"] @@ -956,27 +960,29 @@ def __get_diff_paths(self, diff_test_result: DiffTestResult) -> list[str]: original_result = self.trial_to_steps[trial_basename].steps[ str(gen) ] + args.append((diff_test_result, original_result, self.config)) + + with multiprocessing.Pool(num_workers) as pool: + diff_results = pool.map(self.check_diff_test_step, args) + diff_result = self.check_diff_test_step( diff_test_result, original_result, self.config ) - if diff_result is not None: - for diff in diff_result.diff.values(): - if not isinstance(diff, list): - continue - for diff_item in diff: - if not isinstance(diff_item, DiffLevel): + for diff_result in diff_results: + if diff_result is not None: + for diff in diff_result.diff.values(): + if not isinstance(diff, list): continue - if first_step: - initial_regex.add(re.escape(diff_item.path())) - else: + for diff_item in diff: + if not isinstance(diff_item, DiffLevel): + continue indeterministic_regex.add(diff_item.path()) - first_step = False # Handle the case where the name is not deterministic common_regex = compute_common_regex(list(indeterministic_regex)) - return list(initial_regex) + common_regex + return common_regex def main(): diff --git a/acto/reproduce.py b/acto/reproduce.py index 9544ad21a1..791f26418d 100644 --- a/acto/reproduce.py +++ b/acto/reproduce.py @@ -1,5 +1,6 @@ import argparse import functools +import importlib import json import logging import os @@ -13,14 +14,16 @@ from acto import DEFAULT_KUBERNETES_VERSION from acto.engine import Acto -from acto.input.input import DeterministicInputModel +from acto.input import k8s_schemas, property_attribute +from acto.input.input import CustomKubernetesMapping, DeterministicInputModel from acto.input.testcase import TestCase from acto.input.testplan import TestGroup from acto.input.value_with_schema import ValueWithSchema -from acto.input.valuegenerator import extract_schema_with_value_generator from acto.lib.operator_config import OperatorConfig from acto.post_process.post_diff_test import PostDiffTest from acto.result import OracleResults +from acto.schema.base import BaseSchema +from acto.schema.schema import extract_schema from acto.utils import get_thread_logger @@ -81,9 +84,7 @@ def __init__( custom_module_path: Optional[str] = None, ) -> None: logger = get_thread_logger(with_prefix=True) - # WARNING: Not sure the initialization is correct - # TODO: The line below need to be reviewed. - self.root_schema = extract_schema_with_value_generator( + self.root_schema = extract_schema( [], crd["spec"]["versions"][-1]["schema"]["openAPIV3Schema"] ) self.testcases = [] @@ -107,6 +108,69 @@ def __init__( self.num_workers = 1 self.metadata = {} + override_matches: Optional[list[tuple[BaseSchema, str]]] = None + if custom_module_path is not None: + custom_module = importlib.import_module(custom_module_path) + + # We need to do very careful sanitization here because we are + # loading user-provided module + if hasattr(custom_module, "KUBERNETES_TYPE_MAPPING"): + custum_kubernetes_type_mapping = ( + custom_module.KUBERNETES_TYPE_MAPPING + ) + if isinstance(custum_kubernetes_type_mapping, list): + override_matches = [] + for custom_mapping in custum_kubernetes_type_mapping: + if isinstance(custom_mapping, CustomKubernetesMapping): + try: + schema = self.get_schema_by_path( + custom_mapping.schema_path + ) + except KeyError as exc: + raise RuntimeError( + "Schema path of the custom mapping is invalid: " + f"{custom_mapping.schema_path}" + ) from exc + + override_matches.append( + (schema, custom_mapping.kubernetes_schema_name) + ) + else: + raise TypeError( + "Expected CustomKubernetesMapping in KUBERNETES_TYPE_MAPPING, " + f"but got {type(custom_mapping)}" + ) + + # Do the matching from CRD to Kubernetes schemas + # Match the Kubernetes schemas to subproperties of the root schema + kubernetes_schema_matcher = k8s_schemas.K8sSchemaMatcher.from_version( + kubernetes_version, None + ) + top_matched_schemas = ( + kubernetes_schema_matcher.find_top_level_matched_schemas( + self.root_schema + ) + ) + for base_schema, k8s_schema_name in top_matched_schemas: + logging.info( + "Matched schema %s to k8s schema %s", + base_schema.get_path(), + k8s_schema_name, + ) + self.full_matched_schemas = ( + kubernetes_schema_matcher.expand_top_level_matched_schemas( + top_matched_schemas + ) + ) + + for base_schema, k8s_schema_name in self.full_matched_schemas: + base_schema.attributes |= ( + property_attribute.PropertyAttribute.Mapped + ) + + # Apply custom property attributes based on the property_attribute module + self.apply_custom_field() + def initialize(self, initial_value: dict): """Override"""