From 3d9af5fb67366b19118509d685716654bd7de9e1 Mon Sep 17 00:00:00 2001 From: n-io Date: Wed, 16 Oct 2024 19:19:30 +0200 Subject: [PATCH 1/2] transformations: Support timers in the csl pipeline --- xdsl/transforms/csl_stencil_to_csl_wrapper.py | 142 +++++++++++++++++- .../stencil_tensorize_z_dimension.py | 7 +- xdsl/transforms/stencil_shape_minimize.py | 3 +- 3 files changed, 145 insertions(+), 7 deletions(-) diff --git a/xdsl/transforms/csl_stencil_to_csl_wrapper.py b/xdsl/transforms/csl_stencil_to_csl_wrapper.py index 695b9cd9c6..1090596021 100644 --- a/xdsl/transforms/csl_stencil_to_csl_wrapper.py +++ b/xdsl/transforms/csl_stencil_to_csl_wrapper.py @@ -3,16 +3,20 @@ from xdsl.builder import ImplicitBuilder from xdsl.context import MLContext -from xdsl.dialects import arith, builtin, func, memref, stencil +from xdsl.dialects import arith, builtin, func, llvm, memref, stencil from xdsl.dialects.builtin import ( + AnyMemRefType, AnyMemRefTypeConstr, AnyTensorTypeConstr, + IndexType, IntegerAttr, + IntegerType, ShapedType, + Signedness, TensorType, ) from xdsl.dialects.csl import csl, csl_stencil, csl_wrapper -from xdsl.ir import Attribute, BlockArgument, Operation, SSAValue +from xdsl.ir import Attribute, BlockArgument, Operation, OpResult, SSAValue from xdsl.passes import ModulePass from xdsl.pattern_rewriter import ( GreedyRewritePatternApplier, @@ -26,6 +30,22 @@ from xdsl.utils.hints import isa from xdsl.utils.isattr import isattr +_TIMER_START = "timer_start" +_TIMER_END = "timer_end" +_TIMER_FUNC_NAMES = [_TIMER_START, _TIMER_END] + + +def _get_module_wrapper(op: Operation) -> csl_wrapper.ModuleOp | None: + """ + Return the enclosing csl_wrapper.module + """ + parent_op = op.parent_op() + while parent_op: + if isinstance(parent_op, csl_wrapper.ModuleOp): + return parent_op + parent_op = parent_op.parent_op() + return None + @dataclass(frozen=True) class ConvertStencilFuncToModuleWrappedPattern(RewritePattern): @@ -176,6 +196,7 @@ def _translate_function_args( ptr_converts: list[Operation] = [] export_ops: list[Operation] = [] cast_ops: list[Operation] = [] + import_ops: list[Operation] = [] for arg in args: arg_name = arg.name_hint or ("arg" + str(args.index(arg))) @@ -215,8 +236,49 @@ def _translate_function_args( arg_op_mapping.append(cast_op.outputs[0]) else: arg_op_mapping.append(alloc.memref) + # check if this looks like a timer + elif isinstance(arg.type, llvm.LLVMPointerType) and all( + isinstance(u.operation, llvm.StoreOp) + and isinstance(u.operation.value, OpResult) + and isinstance(u.operation.value.op, func.Call) + and u.operation.value.op.callee.string_value() == _TIMER_END + for u in arg.uses + ): + start_end_size = 3 + arg_t = memref.MemRefType( + IntegerType(16, Signedness.UNSIGNED), (2 * start_end_size,) + ) + arg_ops.append(alloc := memref.Alloc([], [], arg_t)) + ptr_converts.append( + address := csl.AddressOfOp( + operands=[alloc], + result_types=[ + csl.PtrType( + [ + arg_t.get_element_type(), + csl.PtrKindAttr(csl.PtrKind.MANY), + csl.PtrConstAttr(csl.PtrConst.VAR), + ] + ) + ], + ) + ) + export_ops.append(csl.SymbolExportOp(arg_name, SSAValue.get(address))) + arg_op_mapping.append(alloc.memref) + import_ops.append( + csl_wrapper.ImportOp( + "