diff --git a/xdsl_pdl/dialects/irdl_extension.py b/xdsl_pdl/dialects/irdl_extension.py index 8c14954..5bd3cad 100644 --- a/xdsl_pdl/dialects/irdl_extension.py +++ b/xdsl_pdl/dialects/irdl_extension.py @@ -2,13 +2,13 @@ from typing import Sequence from xdsl.irdl import ( - AttrSizedOperandSegments, IRDLOperation, irdl_op_definition, + operand_def, region_def, var_operand_def, ) -from xdsl.ir import Dialect, Region, SSAValue +from xdsl.ir import Dialect, IsTerminator, Region, SSAValue from xdsl.dialects.irdl import AttributeType from xdsl.parser import DictionaryAttr, Parser @@ -43,12 +43,33 @@ def print(self, printer: Printer) -> None: printer.print_op_attributes(self.attributes) +@irdl_op_definition +class MatchOp(IRDLOperation): + name = "irdl_ext.match" + + arg = operand_def(AttributeType()) + + assembly_format = "attr-dict $arg" + + def __init__( + self, + arg: SSAValue, + attr_dict: DictionaryAttr | None = None, + ): + super().__init__( + operands=[arg], + attributes=attr_dict.data if attr_dict is not None else None, + ) + + @irdl_op_definition class YieldOp(IRDLOperation): name = "irdl_ext.yield" args = var_operand_def(AttributeType()) + traits = frozenset({IsTerminator()}) + assembly_format = "attr-dict $args" def __init__( @@ -81,4 +102,4 @@ def __init__( ) -IRDLExtension = Dialect("irdl_ext", [CheckSubsetOp, YieldOp, EqOp]) +IRDLExtension = Dialect("irdl_ext", [CheckSubsetOp, MatchOp, YieldOp, EqOp]) diff --git a/xdsl_pdl/passes/optimize_irdl.py b/xdsl_pdl/passes/optimize_irdl.py index c2b4e11..1e41b61 100644 --- a/xdsl_pdl/passes/optimize_irdl.py +++ b/xdsl_pdl/passes/optimize_irdl.py @@ -3,7 +3,9 @@ from xdsl.ir import MLContext, Operation, SSAValue, Use from xdsl.dialects import irdl -from xdsl.rewriter import InsertPoint +from xdsl.printer import Printer +from xdsl.rewriter import InsertPoint, Rewriter +from xdsl.traits import IsTerminator from xdsl_pdl.dialects import irdl_extension from xdsl.dialects.builtin import ModuleOp from xdsl.pattern_rewriter import ( @@ -225,8 +227,49 @@ def match_and_rewrite(self, op: irdl_extension.EqOp, rewriter: PatternRewriter, ] +class RemoveDuplicateMatchOpPattern(RewritePattern): + @op_type_rewrite_pattern + def match_and_rewrite( + self, op: irdl_extension.CheckSubsetOp, rewriter: PatternRewriter, / + ): + for block in [op.lhs.block, op.rhs.block]: + match_ops: list[irdl_extension.MatchOp] = [] + for match_op in block.ops: + if isinstance(match_op, irdl_extension.MatchOp): + match_ops.append(match_op) + + # Detach the match operations + for match_op in match_ops: + match_op.detach() + + # Deduplicate them + dedup_match_ops: list[irdl_extension.MatchOp | None] = list(match_ops) + for index, match_op in enumerate(dedup_match_ops): + if match_op is None: + continue + for index2, match_op2 in list(enumerate(dedup_match_ops))[index + 1 :]: + if match_op2 is None: + continue + if match_op.arg == match_op2.arg: + match_op2.erase() + dedup_match_ops[index2] = None + + deduped_match_ops = [ + match_op for match_op in dedup_match_ops if match_op is not None + ] + if block.ops.last is not None and block.ops.last.has_trait(IsTerminator): + Rewriter.insert_ops_at_location( + deduped_match_ops, InsertPoint.before(block.ops.last) + ) + else: + Rewriter.insert_ops_at_location( + deduped_match_ops, InsertPoint.at_end(block) + ) + + class OptimizeIRDL(ModulePass): def apply(self, ctx: MLContext, op: ModuleOp): + walker = PatternRewriteWalker( GreedyRewritePatternApplier( [ @@ -240,6 +283,7 @@ def apply(self, ctx: MLContext, op: ModuleOp): AllOfIdenticalPattern(), RemoveEqOpPattern(), AllOfNestedPattern(), + RemoveDuplicateMatchOpPattern(), ] ) ) diff --git a/xdsl_pdl/passes/pdl_to_irdl.py b/xdsl_pdl/passes/pdl_to_irdl.py index 188402e..ea9c1db 100644 --- a/xdsl_pdl/passes/pdl_to_irdl.py +++ b/xdsl_pdl/passes/pdl_to_irdl.py @@ -29,7 +29,7 @@ from xdsl.traits import SymbolTable from xdsl.utils.hints import isa from z3 import Symbol -from xdsl_pdl.dialects.irdl_extension import CheckSubsetOp, EqOp, YieldOp +from xdsl_pdl.dialects.irdl_extension import CheckSubsetOp, EqOp, MatchOp, YieldOp def add_missing_pdl_result(program: PatternOp): @@ -271,6 +271,11 @@ def match_and_rewrite(self, op: OperationOp, rewriter: PatternRewriter, /): merge_op = EqOp([irdl_operand, pdl_operand]) rewriter.insert_op_before_matched_op(merge_op) + # Mark irdl_operand as matched. + # This ensures that the constraint will not be deleted, and will match + # an actual attribute (instead of holding no value). + rewriter.insert_op_before_matched_op(MatchOp(irdl_operand)) + for uses in list(op.op.uses): if not isinstance(uses.operation, ResultOp): raise Exception("Expected a `pdl.result` operation")