Skip to content

Commit

Permalink
misc: Stencil and csl lowering fixes (xdslproject#3442)
Browse files Browse the repository at this point in the history
Various minor fixes:
* `stencil`: Add canonicalisation pattern for `stencil.cast` where input
type == output type
* `csl-stencil-handle-async-flow`: When creating a function call graph,
copy arith constants to the function in which they are used, undoing
mlir-opt canonicalisation moving them around.
* `test-add-timers-to-top-level-funcs`: Generate timers as `private`
functions

---------

Co-authored-by: n-io <[email protected]>
  • Loading branch information
2 people authored and EdmundGoodman committed Dec 6, 2024
1 parent 5c08e4e commit b632d12
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ builtin.module {
// CHECK-NEXT: "llvm.store"(%timediff, %timers) <{"ordering" = 0 : i64}> : (f64, !llvm.ptr) -> ()
// CHECK-NEXT: func.return
// CHECK-NEXT: }
// CHECK-NEXT: func.func @timer_start() -> f64
// CHECK-NEXT: func.func @timer_end(f64) -> f64
// CHECK-NEXT: func.func private @timer_start() -> f64
// CHECK-NEXT: func.func private @timer_end(f64) -> f64
// CHECK-NEXT: }

func.func @has_no_timers(%arg0 : i32, %arg1 : i32) -> i32 {
Expand Down
12 changes: 11 additions & 1 deletion xdsl/dialects/stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,16 @@ class AllocOp(IRDLOperation):
traits = traits_def(AllocOpEffect())


class CastOpHasCanonicalizationPatternsTrait(HasCanonicalizationPatternsTrait):
@classmethod
def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]:
from xdsl.transforms.canonicalization_patterns.stencil import (
RemoveCastWithNoEffect,
)

return (RemoveCastWithNoEffect(),)


@irdl_op_definition
class CastOp(IRDLOperation):
"""
Expand Down Expand Up @@ -721,7 +731,7 @@ class CastOp(IRDLOperation):
"$field attr-dict-with-keyword `:` type($field) `->` type($result)"
)

traits = traits_def(NoMemoryEffect())
traits = traits_def(NoMemoryEffect(), CastOpHasCanonicalizationPatternsTrait())

@staticmethod
def get(
Expand Down
11 changes: 11 additions & 0 deletions xdsl/transforms/canonicalization_patterns/stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,14 @@ def match_and_rewrite(self, op: stencil.ApplyOp, rewriter: PatternRewriter) -> N

rewriter.replace_op(old_return, stencil.ReturnOp.get(return_args))
rewriter.replace_matched_op(new, replace_results)


class RemoveCastWithNoEffect(RewritePattern):
"""
Remove `stencil.cast` where input and output types are equal.
"""

@op_type_rewrite_pattern
def match_and_rewrite(self, op: stencil.CastOp, rewriter: PatternRewriter) -> None:
if op.result.type == op.field.type:
rewriter.replace_matched_op([], new_results=[op.field])
25 changes: 25 additions & 0 deletions xdsl/transforms/csl_stencil_handle_async_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,28 @@ def _is_inside_wrapper_outside_apply(op: Operation):
return is_inside_wrapper and not is_inside_apply and has_apply_inside


@dataclass(frozen=True)
class CopyArithConstants(RewritePattern):
""" """

@op_type_rewrite_pattern
def match_and_rewrite(self, op: arith.Constant, rewriter: PatternRewriter, /):
if not (parent_func := self._get_enclosing_function(op)):
return
for use in list(op.result.uses):
use_func = self._get_enclosing_function(use.operation)
if use_func != parent_func:
rewriter.insert_op(cln := op.clone(), InsertPoint.before(use.operation))
op.result.replace_by_if(cln.result, lambda x: x == use)

@staticmethod
def _get_enclosing_function(op: Operation) -> csl.FuncOp | None:
parent = op.parent_op()
while parent and not isinstance(parent, csl.FuncOp):
parent = parent.parent_op()
return parent


@dataclass(frozen=True)
class CslStencilHandleAsyncControlFlow(ModulePass):
"""
Expand All @@ -283,3 +305,6 @@ def apply(self, ctx: MLContext, op: ModuleOp) -> None:
apply_recursively=False,
)
module_pass.rewrite_module(op)
PatternRewriteWalker(
CopyArithConstants(), apply_recursively=False
).rewrite_module(op)
4 changes: 2 additions & 2 deletions xdsl/transforms/function_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
end_func_t = func.FunctionType.from_lists(
[builtin.Float64Type()], [builtin.Float64Type()]
)
start_func = func.FuncOp(TIMER_START, start_func_t, Region([]))
end_func = func.FuncOp(TIMER_END, end_func_t, Region([]))
start_func = func.FuncOp(TIMER_START, start_func_t, Region([]), "private")
end_func = func.FuncOp(TIMER_END, end_func_t, Region([]), "private")

PatternRewriteWalker(
AddBenchTimersPattern(start_func_t, end_func_t), apply_recursively=False
Expand Down

0 comments on commit b632d12

Please sign in to comment.