Skip to content

Commit

Permalink
Add the irdl_ext.match_op operation
Browse files Browse the repository at this point in the history
  • Loading branch information
math-fehr committed May 21, 2024
1 parent 0089d70 commit 35139c1
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 5 deletions.
27 changes: 24 additions & 3 deletions xdsl_pdl/dialects/irdl_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
from typing import Sequence

from xdsl.irdl import (
AttrSizedOperandSegments,
IRDLOperation,
irdl_op_definition,
operand_def,
region_def,
var_operand_def,
)
from xdsl.ir import Dialect, Region, SSAValue
from xdsl.ir import Dialect, IsTerminator, Region, SSAValue

from xdsl.dialects.irdl import AttributeType
from xdsl.parser import DictionaryAttr, Parser
Expand Down Expand Up @@ -43,12 +43,33 @@ def print(self, printer: Printer) -> None:
printer.print_op_attributes(self.attributes)


@irdl_op_definition
class MatchOp(IRDLOperation):
name = "irdl_ext.match"

arg = operand_def(AttributeType())

assembly_format = "attr-dict $arg"

def __init__(
self,
arg: SSAValue,
attr_dict: DictionaryAttr | None = None,
):
super().__init__(
operands=[arg],
attributes=attr_dict.data if attr_dict is not None else None,
)


@irdl_op_definition
class YieldOp(IRDLOperation):
name = "irdl_ext.yield"

args = var_operand_def(AttributeType())

traits = frozenset({IsTerminator()})

assembly_format = "attr-dict $args"

def __init__(
Expand Down Expand Up @@ -81,4 +102,4 @@ def __init__(
)


IRDLExtension = Dialect("irdl_ext", [CheckSubsetOp, YieldOp, EqOp])
IRDLExtension = Dialect("irdl_ext", [CheckSubsetOp, MatchOp, YieldOp, EqOp])
46 changes: 45 additions & 1 deletion xdsl_pdl/passes/optimize_irdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

from xdsl.ir import MLContext, Operation, SSAValue, Use
from xdsl.dialects import irdl
from xdsl.rewriter import InsertPoint
from xdsl.printer import Printer
from xdsl.rewriter import InsertPoint, Rewriter
from xdsl.traits import IsTerminator
from xdsl_pdl.dialects import irdl_extension
from xdsl.dialects.builtin import ModuleOp
from xdsl.pattern_rewriter import (
Expand Down Expand Up @@ -225,8 +227,49 @@ def match_and_rewrite(self, op: irdl_extension.EqOp, rewriter: PatternRewriter,
]


class RemoveDuplicateMatchOpPattern(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(
self, op: irdl_extension.CheckSubsetOp, rewriter: PatternRewriter, /
):
for block in [op.lhs.block, op.rhs.block]:
match_ops: list[irdl_extension.MatchOp] = []
for match_op in block.ops:
if isinstance(match_op, irdl_extension.MatchOp):
match_ops.append(match_op)

# Detach the match operations
for match_op in match_ops:
match_op.detach()

# Deduplicate them
dedup_match_ops: list[irdl_extension.MatchOp | None] = list(match_ops)
for index, match_op in enumerate(dedup_match_ops):
if match_op is None:
continue
for index2, match_op2 in list(enumerate(dedup_match_ops))[index + 1 :]:
if match_op2 is None:
continue
if match_op.arg == match_op2.arg:
match_op2.erase()
dedup_match_ops[index2] = None

deduped_match_ops = [
match_op for match_op in dedup_match_ops if match_op is not None
]
if block.ops.last is not None and block.ops.last.has_trait(IsTerminator):
Rewriter.insert_ops_at_location(
deduped_match_ops, InsertPoint.before(block.ops.last)
)
else:
Rewriter.insert_ops_at_location(
deduped_match_ops, InsertPoint.at_end(block)
)


class OptimizeIRDL(ModulePass):
def apply(self, ctx: MLContext, op: ModuleOp):

walker = PatternRewriteWalker(
GreedyRewritePatternApplier(
[
Expand All @@ -240,6 +283,7 @@ def apply(self, ctx: MLContext, op: ModuleOp):
AllOfIdenticalPattern(),
RemoveEqOpPattern(),
AllOfNestedPattern(),
RemoveDuplicateMatchOpPattern(),
]
)
)
Expand Down
7 changes: 6 additions & 1 deletion xdsl_pdl/passes/pdl_to_irdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from xdsl.traits import SymbolTable
from xdsl.utils.hints import isa
from z3 import Symbol
from xdsl_pdl.dialects.irdl_extension import CheckSubsetOp, EqOp, YieldOp
from xdsl_pdl.dialects.irdl_extension import CheckSubsetOp, EqOp, MatchOp, YieldOp


def add_missing_pdl_result(program: PatternOp):
Expand Down Expand Up @@ -271,6 +271,11 @@ def match_and_rewrite(self, op: OperationOp, rewriter: PatternRewriter, /):
merge_op = EqOp([irdl_operand, pdl_operand])
rewriter.insert_op_before_matched_op(merge_op)

# Mark irdl_operand as matched.
# This ensures that the constraint will not be deleted, and will match
# an actual attribute (instead of holding no value).
rewriter.insert_op_before_matched_op(MatchOp(irdl_operand))

for uses in list(op.op.uses):
if not isinstance(uses.operation, ResultOp):
raise Exception("Expected a `pdl.result` operation")
Expand Down

0 comments on commit 35139c1

Please sign in to comment.