Skip to content

Commit

Permalink
specify return types of static and dynamic execution more precisely
Browse files Browse the repository at this point in the history
  • Loading branch information
martin-luecke committed Jan 12, 2024
1 parent 0c8136b commit eed75e9
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 22 deletions.
13 changes: 9 additions & 4 deletions xdsl_pdl/analysis/mlir_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ class MLIRNoMatch(Exception):
error_msg: str


@dataclass
class MLIRSuccess:
pass


def run_with_mlir(
program: Operation, pattern: PatternOp, mlir_executable_path: str
) -> str:
Expand Down Expand Up @@ -70,16 +75,16 @@ def run_with_mlir(
)
except subprocess.TimeoutExpired:
raise MLIRInfiniteLoop(mlir_input.getvalue())
if "RecordMatch" not in res.stderr:
raise MLIRNoMatch(mlir_input.getvalue(), res.stderr)
if res.returncode != 0:
raise MLIRFailure(mlir_input.getvalue(), res.stderr)
if "RecordMatch" not in res.stderr:
raise MLIRNoMatch(mlir_input.getvalue(), res.stderr)
return res.stdout


def analyze_with_mlir(
pattern: PatternOp, ctx: MLContext, randgen: Random, mlir_executable_path: str
) -> MLIRFailure | MLIRInfiniteLoop | MLIRNoMatch | None:
) -> MLIRFailure | MLIRInfiniteLoop | MLIRNoMatch | MLIRSuccess:
"""
Run the pattern on multiple examples with MLIR.
If MLIR returns an error in any of the examples, returns the error.
Expand All @@ -105,4 +110,4 @@ def analyze_with_mlir(
return e
except MLIRNoMatch as e:
return e
return None
return MLIRSuccess()
5 changes: 3 additions & 2 deletions xdsl_pdl/analysis/pdl_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,8 +541,9 @@ def _analyze_rhs_op(self, rhs_op: Operation) -> AnalyzedPDLOperation | None:
assert isinstance(rhs_op.op_value, OpResult)
assert isinstance(rhs_op.op_value.op, pdl.OperationOp)
if (analyzed_op := self.get_analysis(rhs_op.op_value.op)) is None:
raise PDLAnalysisException(
rhs_op, "Unknown pdl.Operation to be erased!"
raise PDLAnalysisAborted(
rhs_op,
"pdl.Operation to be erased is not part of the matching DAG!",
)
analyzed_op.erased_by = rhs_op
else:
Expand Down
3 changes: 2 additions & 1 deletion xdsl_pdl/tools/generate_pdl_matches.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
)
from xdsl_pdl.analysis.mlir_analysis import (
MLIRFailure,
MLIRSuccess,
analyze_with_mlir,
)

Expand Down Expand Up @@ -57,7 +58,7 @@ def fuzz_pdl_matches(
mlir_analysis = analyze_with_mlir(
module.ops.first, ctx, Random(seed), mlir_executable_path
)
if mlir_analysis is None:
if isinstance(mlir_analysis, MLIRSuccess):
print("MLIR analysis succeeded")
else:
print("MLIR analysis failed")
Expand Down
45 changes: 30 additions & 15 deletions xdsl_pdl/tools/generate_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,16 @@
from xdsl.dialects.pdl import (
PatternOp,
)
from xdsl_pdl.analysis.pdl_analysis import PDLAnalysisAborted, pdl_analysis_pass
from xdsl_pdl.analysis.pdl_analysis import (
PDLAnalysisAborted,
PDLAnalysisException,
pdl_analysis_pass,
)
from xdsl_pdl.analysis.mlir_analysis import (
MLIRFailure,
MLIRInfiniteLoop,
MLIRNoMatch,
MLIRSuccess,
analyze_with_mlir,
)

Expand All @@ -28,7 +35,9 @@

def fuzz_pdl_matches(
module: ModuleOp, ctx: MLContext, randgen: Random, mlir_executable_path: str
) -> tuple[bool, bool] | None:
) -> tuple[
bool | Exception, MLIRNoMatch | MLIRSuccess | MLIRFailure | MLIRInfiniteLoop
]:
"""
Returns the result of the PDL analysis, and the result of the analysis using
program fuzzing and MLIR.
Expand All @@ -37,20 +46,17 @@ def fuzz_pdl_matches(
raise Exception("Expected a single toplevel pattern op")

# Check if the pattern is valid
analysis_correct = True
analysis_correct: bool | Exception = True
try:
pdl_analysis_pass(ctx, module)
except PDLAnalysisAborted:
analysis_correct = False
except Exception:
return None
except Exception as e:
analysis_correct = e

mlir_analysis = analyze_with_mlir(
module.ops.first, ctx, randgen, mlir_executable_path
)
if isinstance(mlir_analysis, MLIRNoMatch):
return None
return analysis_correct, mlir_analysis is None

return analysis_correct, mlir_analysis


class GenerateTableMain(xDSLOptMain):
Expand All @@ -62,7 +68,8 @@ def __init__(self):
super().__init__()
self.ctx.allow_unregistered = True
self.num_tested = 0
self.failed_analyses = []
self.failed_analyses: list[int] = []
self.no_mlir_matches: list[int] = []
self.values = (([], []), ([], []))

def register_all_dialects(self):
Expand All @@ -85,10 +92,15 @@ def run_one_thread(self, seed: int):
)
self.num_tested += 1
print(f"Tested {self.num_tested} patterns", end="\r")
if test_res is None:
if isinstance(test_res[0], PDLAnalysisException):
self.failed_analyses.append(seed)
return
self.values[int(test_res[0])][int(test_res[1])].append(seed)

if isinstance(test_res[1], MLIRNoMatch):
self.no_mlir_matches.append(seed)

self.values[int(isinstance(test_res[0], bool) and bool(test_res[0]))][
int(isinstance(test_res[1], MLIRSuccess))
].append(seed)

def run(self):
randgen = Random()
Expand All @@ -112,12 +124,15 @@ def run(self):
print(
f"PDL Analysis raised an exception: {len(self.failed_analyses)}: {self.failed_analyses} \n"
)
print(
f"No MLIR matches generated: {len(self.no_mlir_matches)}: {self.no_mlir_matches} \n"
)

print(
f"Total: s fail d fail, s succ d succ, s fail d succ, s succ d fail, failed analyses"
)
print(
f"occurences: {len(self.values[0][0])}, {len(self.values[1][1])}, {len(self.values[0][1])},{len(self.values[1][0])},{len(self.failed_analyses)}"
f"categories: {len(self.values[0][0])}, {len(self.values[1][1])}, {len(self.values[0][1])},{len(self.values[1][0])},{len(self.failed_analyses)}"
)

print_results(
Expand Down

0 comments on commit eed75e9

Please sign in to comment.