From eed75e9f4629c5dad0bc4844abfd781ee296d2e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20Lu=CC=88cke?= Date: Fri, 12 Jan 2024 18:31:53 +0100 Subject: [PATCH] specify return types of static and dynamic execution more precisely --- xdsl_pdl/analysis/mlir_analysis.py | 13 +++++--- xdsl_pdl/analysis/pdl_analysis.py | 5 +-- xdsl_pdl/tools/generate_pdl_matches.py | 3 +- xdsl_pdl/tools/generate_table.py | 45 +++++++++++++++++--------- 4 files changed, 44 insertions(+), 22 deletions(-) diff --git a/xdsl_pdl/analysis/mlir_analysis.py b/xdsl_pdl/analysis/mlir_analysis.py index adc6de6..795f3c6 100644 --- a/xdsl_pdl/analysis/mlir_analysis.py +++ b/xdsl_pdl/analysis/mlir_analysis.py @@ -30,6 +30,11 @@ class MLIRNoMatch(Exception): error_msg: str +@dataclass +class MLIRSuccess: + pass + + def run_with_mlir( program: Operation, pattern: PatternOp, mlir_executable_path: str ) -> str: @@ -70,16 +75,16 @@ def run_with_mlir( ) except subprocess.TimeoutExpired: raise MLIRInfiniteLoop(mlir_input.getvalue()) - if "RecordMatch" not in res.stderr: - raise MLIRNoMatch(mlir_input.getvalue(), res.stderr) if res.returncode != 0: raise MLIRFailure(mlir_input.getvalue(), res.stderr) + if "RecordMatch" not in res.stderr: + raise MLIRNoMatch(mlir_input.getvalue(), res.stderr) return res.stdout def analyze_with_mlir( pattern: PatternOp, ctx: MLContext, randgen: Random, mlir_executable_path: str -) -> MLIRFailure | MLIRInfiniteLoop | MLIRNoMatch | None: +) -> MLIRFailure | MLIRInfiniteLoop | MLIRNoMatch | MLIRSuccess: """ Run the pattern on multiple examples with MLIR. If MLIR returns an error in any of the examples, returns the error. @@ -105,4 +110,4 @@ def analyze_with_mlir( return e except MLIRNoMatch as e: return e - return None + return MLIRSuccess() diff --git a/xdsl_pdl/analysis/pdl_analysis.py b/xdsl_pdl/analysis/pdl_analysis.py index 1996f4e..8ba9549 100644 --- a/xdsl_pdl/analysis/pdl_analysis.py +++ b/xdsl_pdl/analysis/pdl_analysis.py @@ -541,8 +541,9 @@ def _analyze_rhs_op(self, rhs_op: Operation) -> AnalyzedPDLOperation | None: assert isinstance(rhs_op.op_value, OpResult) assert isinstance(rhs_op.op_value.op, pdl.OperationOp) if (analyzed_op := self.get_analysis(rhs_op.op_value.op)) is None: - raise PDLAnalysisException( - rhs_op, "Unknown pdl.Operation to be erased!" + raise PDLAnalysisAborted( + rhs_op, + "pdl.Operation to be erased is not part of the matching DAG!", ) analyzed_op.erased_by = rhs_op else: diff --git a/xdsl_pdl/tools/generate_pdl_matches.py b/xdsl_pdl/tools/generate_pdl_matches.py index 892a02e..db96a52 100644 --- a/xdsl_pdl/tools/generate_pdl_matches.py +++ b/xdsl_pdl/tools/generate_pdl_matches.py @@ -22,6 +22,7 @@ ) from xdsl_pdl.analysis.mlir_analysis import ( MLIRFailure, + MLIRSuccess, analyze_with_mlir, ) @@ -57,7 +58,7 @@ def fuzz_pdl_matches( mlir_analysis = analyze_with_mlir( module.ops.first, ctx, Random(seed), mlir_executable_path ) - if mlir_analysis is None: + if isinstance(mlir_analysis, MLIRSuccess): print("MLIR analysis succeeded") else: print("MLIR analysis failed") diff --git a/xdsl_pdl/tools/generate_table.py b/xdsl_pdl/tools/generate_table.py index 28fb37f..f41f366 100644 --- a/xdsl_pdl/tools/generate_table.py +++ b/xdsl_pdl/tools/generate_table.py @@ -16,9 +16,16 @@ from xdsl.dialects.pdl import ( PatternOp, ) -from xdsl_pdl.analysis.pdl_analysis import PDLAnalysisAborted, pdl_analysis_pass +from xdsl_pdl.analysis.pdl_analysis import ( + PDLAnalysisAborted, + PDLAnalysisException, + pdl_analysis_pass, +) from xdsl_pdl.analysis.mlir_analysis import ( + MLIRFailure, + MLIRInfiniteLoop, MLIRNoMatch, + MLIRSuccess, analyze_with_mlir, ) @@ -28,7 +35,9 @@ def fuzz_pdl_matches( module: ModuleOp, ctx: MLContext, randgen: Random, mlir_executable_path: str -) -> tuple[bool, bool] | None: +) -> tuple[ + bool | Exception, MLIRNoMatch | MLIRSuccess | MLIRFailure | MLIRInfiniteLoop +]: """ Returns the result of the PDL analysis, and the result of the analysis using program fuzzing and MLIR. @@ -37,20 +46,17 @@ def fuzz_pdl_matches( raise Exception("Expected a single toplevel pattern op") # Check if the pattern is valid - analysis_correct = True + analysis_correct: bool | Exception = True try: pdl_analysis_pass(ctx, module) - except PDLAnalysisAborted: - analysis_correct = False - except Exception: - return None + except Exception as e: + analysis_correct = e mlir_analysis = analyze_with_mlir( module.ops.first, ctx, randgen, mlir_executable_path ) - if isinstance(mlir_analysis, MLIRNoMatch): - return None - return analysis_correct, mlir_analysis is None + + return analysis_correct, mlir_analysis class GenerateTableMain(xDSLOptMain): @@ -62,7 +68,8 @@ def __init__(self): super().__init__() self.ctx.allow_unregistered = True self.num_tested = 0 - self.failed_analyses = [] + self.failed_analyses: list[int] = [] + self.no_mlir_matches: list[int] = [] self.values = (([], []), ([], [])) def register_all_dialects(self): @@ -85,10 +92,15 @@ def run_one_thread(self, seed: int): ) self.num_tested += 1 print(f"Tested {self.num_tested} patterns", end="\r") - if test_res is None: + if isinstance(test_res[0], PDLAnalysisException): self.failed_analyses.append(seed) - return - self.values[int(test_res[0])][int(test_res[1])].append(seed) + + if isinstance(test_res[1], MLIRNoMatch): + self.no_mlir_matches.append(seed) + + self.values[int(isinstance(test_res[0], bool) and bool(test_res[0]))][ + int(isinstance(test_res[1], MLIRSuccess)) + ].append(seed) def run(self): randgen = Random() @@ -112,12 +124,15 @@ def run(self): print( f"PDL Analysis raised an exception: {len(self.failed_analyses)}: {self.failed_analyses} \n" ) + print( + f"No MLIR matches generated: {len(self.no_mlir_matches)}: {self.no_mlir_matches} \n" + ) print( f"Total: s fail d fail, s succ d succ, s fail d succ, s succ d fail, failed analyses" ) print( - f"occurences: {len(self.values[0][0])}, {len(self.values[1][1])}, {len(self.values[0][1])},{len(self.values[1][0])},{len(self.failed_analyses)}" + f"categories: {len(self.values[0][0])}, {len(self.values[1][1])}, {len(self.values[0][1])},{len(self.values[1][0])},{len(self.failed_analyses)}" ) print_results(