diff --git a/tests/filecheck/transforms/test-add-timers-to-top-level-funcs.mlir b/tests/filecheck/transforms/test-add-timers-to-top-level-funcs.mlir index bf14d55036..e7e62fb697 100644 --- a/tests/filecheck/transforms/test-add-timers-to-top-level-funcs.mlir +++ b/tests/filecheck/transforms/test-add-timers-to-top-level-funcs.mlir @@ -47,8 +47,8 @@ builtin.module { // CHECK-NEXT: "llvm.store"(%timediff, %timers) <{"ordering" = 0 : i64}> : (f64, !llvm.ptr) -> () // CHECK-NEXT: func.return // CHECK-NEXT: } - // CHECK-NEXT: func.func @timer_start() -> f64 - // CHECK-NEXT: func.func @timer_end(f64) -> f64 + // CHECK-NEXT: func.func private @timer_start() -> f64 + // CHECK-NEXT: func.func private @timer_end(f64) -> f64 // CHECK-NEXT: } func.func @has_no_timers(%arg0 : i32, %arg1 : i32) -> i32 { diff --git a/xdsl/dialects/stencil.py b/xdsl/dialects/stencil.py index fd157496cc..9dd4275244 100644 --- a/xdsl/dialects/stencil.py +++ b/xdsl/dialects/stencil.py @@ -689,6 +689,16 @@ class AllocOp(IRDLOperation): traits = traits_def(AllocOpEffect()) +class CastOpHasCanonicalizationPatternsTrait(HasCanonicalizationPatternsTrait): + @classmethod + def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]: + from xdsl.transforms.canonicalization_patterns.stencil import ( + RemoveCastWithNoEffect, + ) + + return (RemoveCastWithNoEffect(),) + + @irdl_op_definition class CastOp(IRDLOperation): """ @@ -721,7 +731,7 @@ class CastOp(IRDLOperation): "$field attr-dict-with-keyword `:` type($field) `->` type($result)" ) - traits = traits_def(NoMemoryEffect()) + traits = traits_def(NoMemoryEffect(), CastOpHasCanonicalizationPatternsTrait()) @staticmethod def get( diff --git a/xdsl/transforms/canonicalization_patterns/stencil.py b/xdsl/transforms/canonicalization_patterns/stencil.py index 8cee81a291..6a019275a1 100644 --- a/xdsl/transforms/canonicalization_patterns/stencil.py +++ b/xdsl/transforms/canonicalization_patterns/stencil.py @@ -109,3 +109,14 @@ def match_and_rewrite(self, op: stencil.ApplyOp, rewriter: PatternRewriter) -> N rewriter.replace_op(old_return, stencil.ReturnOp.get(return_args)) rewriter.replace_matched_op(new, replace_results) + + +class RemoveCastWithNoEffect(RewritePattern): + """ + Remove `stencil.cast` where input and output types are equal. + """ + + @op_type_rewrite_pattern + def match_and_rewrite(self, op: stencil.CastOp, rewriter: PatternRewriter) -> None: + if op.result.type == op.field.type: + rewriter.replace_matched_op([], new_results=[op.field]) diff --git a/xdsl/transforms/csl_stencil_handle_async_flow.py b/xdsl/transforms/csl_stencil_handle_async_flow.py index 8d7204e182..f9449a1922 100644 --- a/xdsl/transforms/csl_stencil_handle_async_flow.py +++ b/xdsl/transforms/csl_stencil_handle_async_flow.py @@ -263,6 +263,28 @@ def _is_inside_wrapper_outside_apply(op: Operation): return is_inside_wrapper and not is_inside_apply and has_apply_inside +@dataclass(frozen=True) +class CopyArithConstants(RewritePattern): + """ """ + + @op_type_rewrite_pattern + def match_and_rewrite(self, op: arith.Constant, rewriter: PatternRewriter, /): + if not (parent_func := self._get_enclosing_function(op)): + return + for use in list(op.result.uses): + use_func = self._get_enclosing_function(use.operation) + if use_func != parent_func: + rewriter.insert_op(cln := op.clone(), InsertPoint.before(use.operation)) + op.result.replace_by_if(cln.result, lambda x: x == use) + + @staticmethod + def _get_enclosing_function(op: Operation) -> csl.FuncOp | None: + parent = op.parent_op() + while parent and not isinstance(parent, csl.FuncOp): + parent = parent.parent_op() + return parent + + @dataclass(frozen=True) class CslStencilHandleAsyncControlFlow(ModulePass): """ @@ -283,3 +305,6 @@ def apply(self, ctx: MLContext, op: ModuleOp) -> None: apply_recursively=False, ) module_pass.rewrite_module(op) + PatternRewriteWalker( + CopyArithConstants(), apply_recursively=False + ).rewrite_module(op) diff --git a/xdsl/transforms/function_transformations.py b/xdsl/transforms/function_transformations.py index d333f463b7..d07af74467 100644 --- a/xdsl/transforms/function_transformations.py +++ b/xdsl/transforms/function_transformations.py @@ -95,8 +95,8 @@ def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: end_func_t = func.FunctionType.from_lists( [builtin.Float64Type()], [builtin.Float64Type()] ) - start_func = func.FuncOp(TIMER_START, start_func_t, Region([])) - end_func = func.FuncOp(TIMER_END, end_func_t, Region([])) + start_func = func.FuncOp(TIMER_START, start_func_t, Region([]), "private") + end_func = func.FuncOp(TIMER_END, end_func_t, Region([]), "private") PatternRewriteWalker( AddBenchTimersPattern(start_func_t, end_func_t), apply_recursively=False