diff --git a/acto/post_process/post_diff_test.py b/acto/post_process/post_diff_test.py index 52a41e99cb..baabe98ac4 100644 --- a/acto/post_process/post_diff_test.py +++ b/acto/post_process/post_diff_test.py @@ -2,6 +2,7 @@ import difflib import glob import hashlib +import itertools import json import logging import multiprocessing @@ -960,24 +961,13 @@ def __get_diff_paths( original_result = self.trial_to_steps[trial_basename].steps[ str(gen) ] - args.append((diff_test_result, original_result, self.config)) + args.append([diff_test_result, original_result, self.config]) with multiprocessing.Pool(num_workers) as pool: - diff_results = pool.map(self.check_diff_test_step, args) + diff_results = pool.starmap(__get_diff_paths_helper, args) - diff_result = self.check_diff_test_step( - diff_test_result, original_result, self.config - ) - - for diff_result in diff_results: - if diff_result is not None: - for diff in diff_result.diff.values(): - if not isinstance(diff, list): - continue - for diff_item in diff: - if not isinstance(diff_item, DiffLevel): - continue - indeterministic_regex.add(diff_item.path()) + for diff_item in itertools.chain.from_iterable(diff_results): + indeterministic_regex.add(diff_item) # Handle the case where the name is not deterministic common_regex = compute_common_regex(list(indeterministic_regex)) @@ -985,6 +975,26 @@ def __get_diff_paths( return common_regex +def __get_diff_paths_helper( + diff_test_result: DiffTestResult, + original_result: Step, + config: OperatorConfig, +) -> list[str]: + diff_result = PostDiffTest.check_diff_test_step( + diff_test_result, original_result, config + ) + indeterministic_regex = set() + if diff_result is not None: + for diff in diff_result.diff.values(): + if not isinstance(diff, list): + continue + for diff_item in diff: + if not isinstance(diff_item, DiffLevel): + continue + indeterministic_regex.add(diff_item.path()) + return list(indeterministic_regex) + + def main(): """Main entry point.""" parser = argparse.ArgumentParser()