Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

misc: Stencil and csl lowering fixes #3442

Merged
merged 3 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ builtin.module {
// CHECK-NEXT: "llvm.store"(%timediff, %timers) <{"ordering" = 0 : i64}> : (f64, !llvm.ptr) -> ()
// CHECK-NEXT: func.return
// CHECK-NEXT: }
// CHECK-NEXT: func.func @timer_start() -> f64
// CHECK-NEXT: func.func @timer_end(f64) -> f64
// CHECK-NEXT: func.func private @timer_start() -> f64
// CHECK-NEXT: func.func private @timer_end(f64) -> f64
// CHECK-NEXT: }

func.func @has_no_timers(%arg0 : i32, %arg1 : i32) -> i32 {
Expand Down
12 changes: 11 additions & 1 deletion xdsl/dialects/stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,6 +690,16 @@ class AllocOp(IRDLOperation):
traits = traits_def(AllocOpEffect())


class CastOpHasCanonicalizationPatternsTrait(HasCanonicalizationPatternsTrait):
@classmethod
def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]:
from xdsl.transforms.canonicalization_patterns.stencil import (
RemoveCastWithNoEffect,
)

return (RemoveCastWithNoEffect(),)


@irdl_op_definition
class CastOp(IRDLOperation):
"""
Expand Down Expand Up @@ -722,7 +732,7 @@ class CastOp(IRDLOperation):
"$field attr-dict-with-keyword `:` type($field) `->` type($result)"
)

traits = traits_def(NoMemoryEffect())
traits = traits_def(NoMemoryEffect(), CastOpHasCanonicalizationPatternsTrait())

@staticmethod
def get(
Expand Down
11 changes: 11 additions & 0 deletions xdsl/transforms/canonicalization_patterns/stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,14 @@ def match_and_rewrite(self, op: stencil.ApplyOp, rewriter: PatternRewriter) -> N

rewriter.replace_op(old_return, stencil.ReturnOp.get(return_args))
rewriter.replace_matched_op(new, replace_results)


class RemoveCastWithNoEffect(RewritePattern):
"""
Remove `stencil.cast` where input and output types are equal.
"""

@op_type_rewrite_pattern
def match_and_rewrite(self, op: stencil.CastOp, rewriter: PatternRewriter) -> None:
if op.result.type == op.field.type:
rewriter.replace_matched_op([], new_results=[op.field])
25 changes: 25 additions & 0 deletions xdsl/transforms/csl_stencil_handle_async_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,28 @@ def _is_inside_wrapper_outside_apply(op: Operation):
return is_inside_wrapper and not is_inside_apply and has_apply_inside


@dataclass(frozen=True)
class CopyArithConstants(RewritePattern):
""" """

@op_type_rewrite_pattern
def match_and_rewrite(self, op: arith.Constant, rewriter: PatternRewriter, /):
if not (parent_func := self._get_enclosing_function(op)):
return
for use in list(op.result.uses):
use_func = self._get_enclosing_function(use.operation)
if use_func != parent_func:
rewriter.insert_op(cln := op.clone(), InsertPoint.before(use.operation))
op.result.replace_by_if(cln.result, lambda x: x == use)

@staticmethod
def _get_enclosing_function(op: Operation) -> csl.FuncOp | None:
parent = op.parent_op()
while parent and not isinstance(parent, csl.FuncOp):
parent = parent.parent_op()
return parent


@dataclass(frozen=True)
class CslStencilHandleAsyncControlFlow(ModulePass):
"""
Expand All @@ -283,3 +305,6 @@ def apply(self, ctx: MLContext, op: ModuleOp) -> None:
apply_recursively=False,
)
module_pass.rewrite_module(op)
PatternRewriteWalker(
CopyArithConstants(), apply_recursively=False
).rewrite_module(op)
4 changes: 2 additions & 2 deletions xdsl/transforms/function_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
end_func_t = func.FunctionType.from_lists(
[builtin.Float64Type()], [builtin.Float64Type()]
)
start_func = func.FuncOp(TIMER_START, start_func_t, Region([]))
end_func = func.FuncOp(TIMER_END, end_func_t, Region([]))
start_func = func.FuncOp(TIMER_START, start_func_t, Region([]), "private")
end_func = func.FuncOp(TIMER_END, end_func_t, Region([]), "private")

PatternRewriteWalker(
AddBenchTimersPattern(start_func_t, end_func_t), apply_recursively=False
Expand Down
Loading