Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

transformations: Support devito timers in the csl pipeline #3312

Merged
merged 3 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 27 additions & 8 deletions tests/filecheck/transforms/csl-stencil-to-csl-wrapper.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ func.func @gauss_seidel(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>,
// CHECK-NEXT: }) : () -> ()


func.func @bufferized(%arg0 : memref<512xf32>, %arg1 : memref<512xf32>) {
func.func @bufferized(%arg0 : memref<512xf32>, %arg1 : memref<512xf32>, %timers : !llvm.ptr) {
%start = func.call @timer_start() : () -> f64
%0 = memref.alloc() {"alignment" = 64 : i64} : memref<510xf32>
csl_stencil.apply(%arg0 : memref<512xf32>, %0 : memref<510xf32>) outs (%arg1 : memref<512xf32>) <{"bounds" = #stencil.bounds<[0, 0], [1, 1]>, "num_chunks" = 2 : i64, "operandSegmentSizes" = array<i32: 1, 1, 0, 1>, "swaps" = [#csl_stencil.exchange<to [1, 0]>, #csl_stencil.exchange<to [-1, 0]>, #csl_stencil.exchange<to [0, 1]>, #csl_stencil.exchange<to [0, -1]>], "topo" = #dmp.topo<1022x510>}> ({
^0(%arg2 : memref<4x255xf32>, %arg3 : index, %arg4 : memref<510xf32>):
Expand All @@ -120,8 +121,12 @@ func.func @bufferized(%arg0 : memref<512xf32>, %arg1 : memref<512xf32>) {
linalg.mul ins(%arg3_1, %8 : memref<510xf32>, memref<510xf32>) outs(%arg3_1 : memref<510xf32>)
csl_stencil.yield %arg3_1 : memref<510xf32>
}) to <[0, 0], [1, 1]>
%end = func.call @timer_end(%start) : (f64) -> f64
"llvm.store"(%end, %timers) <{"ordering" = 0 : i64}> : (f64, !llvm.ptr) -> ()
func.return
}
func.func private @timer_start() -> f64
func.func private @timer_end(f64) -> f64


// CHECK: "csl_wrapper.module"() <{"width" = 1024 : i16, "height" = 512 : i16, "params" = [#csl_wrapper.param<"z_dim" default=512 : i16>, #csl_wrapper.param<"pattern" default=2 : i16>, #csl_wrapper.param<"num_chunks" default=2 : i16>, #csl_wrapper.param<"chunk_size" default=255 : i16>, #csl_wrapper.param<"padded_z_dim" default=510 : i16>], "program_name" = "bufferized"}> ({
Expand Down Expand Up @@ -150,25 +155,39 @@ func.func @bufferized(%arg0 : memref<512xf32>, %arg1 : memref<512xf32>) {
// CHECK-NEXT: %77 = "csl_wrapper.import"(%72, %74, %stencil_comms_params_1) <{"module" = "stencil_comms.csl", "fields" = ["pattern", "chunkSize", ""]}> : (i16, i16, !csl.comptime_struct) -> !csl.imported_module
// CHECK-NEXT: %arg0 = memref.alloc() : memref<512xf32>
// CHECK-NEXT: %arg1 = memref.alloc() : memref<512xf32>
// CHECK-NEXT: %timers = memref.alloc() : memref<6xui16>
// CHECK-NEXT: %78 = "csl.addressof"(%arg0) : (memref<512xf32>) -> !csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const var>>
// CHECK-NEXT: %79 = "csl.addressof"(%arg1) : (memref<512xf32>) -> !csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const var>>
// CHECK-NEXT: %80 = "csl.addressof"(%timers) : (memref<6xui16>) -> !csl.ptr<ui16, #csl<ptr_kind many>, #csl<ptr_const var>>
// CHECK-NEXT: "csl.export"(%78) <{"var_name" = "arg0", "type" = !csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const var>>}> : (!csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const var>>) -> ()
// CHECK-NEXT: "csl.export"(%79) <{"var_name" = "arg1", "type" = !csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const var>>}> : (!csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const var>>) -> ()
// CHECK-NEXT: "csl.export"(%80) <{"var_name" = "timers", "type" = !csl.ptr<ui16, #csl<ptr_kind many>, #csl<ptr_const var>>}> : (!csl.ptr<ui16, #csl<ptr_kind many>, #csl<ptr_const var>>) -> ()
// CHECK-NEXT: %81 = "csl_wrapper.import"() <{"module" = "<time>", "fields" = []}> : () -> !csl.imported_module
// CHECK-NEXT: "csl.export"() <{"var_name" = @bufferized, "type" = () -> ()}> : () -> ()
// CHECK-NEXT: csl.func @bufferized() {
// CHECK-NEXT: %80 = memref.alloc() {"alignment" = 64 : i64} : memref<510xf32>
// CHECK-NEXT: csl_stencil.apply(%arg0 : memref<512xf32>, %80 : memref<510xf32>) outs (%arg1 : memref<512xf32>) <{"bounds" = #stencil.bounds<[0, 0], [1, 1]>, "num_chunks" = 2 : i64, "operandSegmentSizes" = array<i32: 1, 1, 0, 1>, "swaps" = [#csl_stencil.exchange<to [1, 0]>, #csl_stencil.exchange<to [-1, 0]>, #csl_stencil.exchange<to [0, 1]>, #csl_stencil.exchange<to [0, -1]>], "topo" = #dmp.topo<1022x510>}> ({
// CHECK-NEXT: %82 = "csl.addressof"(%timers) : (memref<6xui16>) -> !csl.ptr<ui16, #csl<ptr_kind many>, #csl<ptr_const var>>
// CHECK-NEXT: %83 = "csl.ptrcast"(%82) : (!csl.ptr<ui16, #csl<ptr_kind many>, #csl<ptr_const var>>) -> !csl.ptr<memref<3xui16>, #csl<ptr_kind single>, #csl<ptr_const var>>
// CHECK-NEXT: "csl.member_call"(%81) <{"field" = "enable_tsc"}> : (!csl.imported_module) -> ()
// CHECK-NEXT: "csl.member_call"(%81, %83) <{"field" = "get_timestamp"}> : (!csl.imported_module, !csl.ptr<memref<3xui16>, #csl<ptr_kind single>, #csl<ptr_const var>>) -> ()
// CHECK-NEXT: %84 = memref.alloc() {"alignment" = 64 : i64} : memref<510xf32>
// CHECK-NEXT: csl_stencil.apply(%arg0 : memref<512xf32>, %84 : memref<510xf32>) outs (%arg1 : memref<512xf32>) <{"bounds" = #stencil.bounds<[0, 0], [1, 1]>, "num_chunks" = 2 : i64, "operandSegmentSizes" = array<i32: 1, 1, 0, 1>, "swaps" = [#csl_stencil.exchange<to [1, 0]>, #csl_stencil.exchange<to [-1, 0]>, #csl_stencil.exchange<to [0, 1]>, #csl_stencil.exchange<to [0, -1]>], "topo" = #dmp.topo<1022x510>}> ({
// CHECK-NEXT: ^4(%arg2 : memref<4x255xf32>, %arg3 : index, %arg4 : memref<510xf32>):
// CHECK-NEXT: %81 = csl_stencil.access %arg2[1, 0] : memref<4x255xf32>
// CHECK-NEXT: %82 = memref.subview %arg4[%arg3] [255] [1] : memref<510xf32> to memref<255xf32, strided<[1], offset: ?>>
// CHECK-NEXT: "memref.copy"(%81, %82) : (memref<255xf32>, memref<255xf32, strided<[1], offset: ?>>) -> ()
// CHECK-NEXT: %85 = csl_stencil.access %arg2[1, 0] : memref<4x255xf32>
// CHECK-NEXT: %86 = memref.subview %arg4[%arg3] [255] [1] : memref<510xf32> to memref<255xf32, strided<[1], offset: ?>>
// CHECK-NEXT: "memref.copy"(%85, %86) : (memref<255xf32>, memref<255xf32, strided<[1], offset: ?>>) -> ()
// CHECK-NEXT: csl_stencil.yield %arg4 : memref<510xf32>
// CHECK-NEXT: }, {
// CHECK-NEXT: ^5(%arg2_1 : memref<512xf32>, %arg3_1 : memref<510xf32>):
// CHECK-NEXT: %83 = arith.constant dense<1.666600e-01> : memref<510xf32>
// CHECK-NEXT: linalg.mul ins(%arg3_1, %83 : memref<510xf32>, memref<510xf32>) outs(%arg3_1 : memref<510xf32>)
// CHECK-NEXT: %87 = arith.constant dense<1.666600e-01> : memref<510xf32>
// CHECK-NEXT: linalg.mul ins(%arg3_1, %87 : memref<510xf32>, memref<510xf32>) outs(%arg3_1 : memref<510xf32>)
// CHECK-NEXT: csl_stencil.yield %arg3_1 : memref<510xf32>
// CHECK-NEXT: }) to <[0, 0], [1, 1]>
// CHECK-NEXT: %85 = arith.constant 3 : index
// CHECK-NEXT: %86 = memref.load %timers[%85] : memref<6xui16>
// CHECK-NEXT: %87 = "csl.addressof"(%86) : (ui16) -> !csl.ptr<ui16, #csl<ptr_kind single>, #csl<ptr_const var>>
// CHECK-NEXT: %88 = "csl.ptrcast"(%87) : (!csl.ptr<ui16, #csl<ptr_kind single>, #csl<ptr_const var>>) -> !csl.ptr<memref<3xui16>, #csl<ptr_kind single>, #csl<ptr_const var>>
// CHECK-NEXT: "csl.member_call"(%81, %88) <{"field" = "get_timestamp"}> : (!csl.imported_module, !csl.ptr<memref<3xui16>, #csl<ptr_kind single>, #csl<ptr_const var>>) -> ()
// CHECK-NEXT: "csl.member_call"(%81) <{"field" = "disable_tsc"}> : (!csl.imported_module) -> ()
// CHECK-NEXT: "csl.member_call"(%76) <{"field" = "unblock_cmd_stream"}> : (!csl.imported_module) -> ()
// CHECK-NEXT: csl.return
// CHECK-NEXT: }
Expand Down
148 changes: 145 additions & 3 deletions xdsl/transforms/csl_stencil_to_csl_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,20 @@

from xdsl.builder import ImplicitBuilder
from xdsl.context import MLContext
from xdsl.dialects import arith, builtin, func, memref, stencil
from xdsl.dialects import arith, builtin, func, llvm, memref, stencil
from xdsl.dialects.builtin import (
AnyMemRefType,
AnyMemRefTypeConstr,
AnyTensorTypeConstr,
IndexType,
IntegerAttr,
IntegerType,
ShapedType,
Signedness,
TensorType,
)
from xdsl.dialects.csl import csl, csl_stencil, csl_wrapper
from xdsl.ir import Attribute, BlockArgument, Operation, SSAValue
from xdsl.ir import Attribute, BlockArgument, Operation, OpResult, SSAValue
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
GreedyRewritePatternApplier,
Expand All @@ -26,6 +30,22 @@
from xdsl.utils.hints import isa
from xdsl.utils.isattr import isattr

_TIMER_START = "timer_start"
_TIMER_END = "timer_end"
_TIMER_FUNC_NAMES = [_TIMER_START, _TIMER_END]


def _get_module_wrapper(op: Operation) -> csl_wrapper.ModuleOp | None:
"""
Return the enclosing csl_wrapper.module
"""
parent_op = op.parent_op()
while parent_op:
if isinstance(parent_op, csl_wrapper.ModuleOp):
return parent_op
parent_op = parent_op.parent_op()
return None


@dataclass(frozen=True)
class ConvertStencilFuncToModuleWrappedPattern(RewritePattern):
Expand All @@ -40,6 +60,10 @@ class ConvertStencilFuncToModuleWrappedPattern(RewritePattern):

@op_type_rewrite_pattern
def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter, /):
# erase timer stubs
if op.is_declaration and op.sym_name.data in _TIMER_FUNC_NAMES:
rewriter.erase_matched_op()
return
# find csl_stencil.apply ops, abort if there are none
apply_ops = self.get_csl_stencil_apply_ops(op)
if len(apply_ops) == 0:
Expand Down Expand Up @@ -176,6 +200,7 @@ def _translate_function_args(
ptr_converts: list[Operation] = []
export_ops: list[Operation] = []
cast_ops: list[Operation] = []
import_ops: list[Operation] = []

for arg in args:
arg_name = arg.name_hint or ("arg" + str(args.index(arg)))
Expand Down Expand Up @@ -215,8 +240,49 @@ def _translate_function_args(
arg_op_mapping.append(cast_op.outputs[0])
else:
arg_op_mapping.append(alloc.memref)
# check if this looks like a timer
elif isinstance(arg.type, llvm.LLVMPointerType) and all(
isinstance(u.operation, llvm.StoreOp)
and isinstance(u.operation.value, OpResult)
and isinstance(u.operation.value.op, func.Call)
and u.operation.value.op.callee.string_value() == _TIMER_END
for u in arg.uses
):
start_end_size = 3
arg_t = memref.MemRefType(
IntegerType(16, Signedness.UNSIGNED), (2 * start_end_size,)
)
arg_ops.append(alloc := memref.Alloc([], [], arg_t))
ptr_converts.append(
address := csl.AddressOfOp(
operands=[alloc],
result_types=[
csl.PtrType(
[
arg_t.get_element_type(),
csl.PtrKindAttr(csl.PtrKind.MANY),
csl.PtrConstAttr(csl.PtrConst.VAR),
]
)
],
)
)
export_ops.append(csl.SymbolExportOp(arg_name, SSAValue.get(address)))
arg_op_mapping.append(alloc.memref)
import_ops.append(
csl_wrapper.ImportOp(
"<time>",
field_name_mapping={},
)
)

return [*arg_ops, *cast_ops, *ptr_converts, *export_ops], arg_op_mapping
return [
*arg_ops,
*cast_ops,
*ptr_converts,
*export_ops,
*import_ops,
], arg_op_mapping

def initialise_layout_module(self, module_op: csl_wrapper.ModuleOp):
"""Initialises the layout_module (wrapper block) by setting up (esp. stencil-related) program params"""
Expand Down Expand Up @@ -319,6 +385,81 @@ def initialise_program_module(
module_op.program_module.block.add_op(csl_wrapper.YieldOp([], []))


@dataclass(frozen=True)
class LowerTimerFuncCall(RewritePattern):
"""
Lowers calls to the start and end timer to csl API calls.
"""

@op_type_rewrite_pattern
def match_and_rewrite(self, op: llvm.StoreOp, rewriter: PatternRewriter, /):
if (
not isinstance(end_call := op.value.owner, func.Call)
or not end_call.callee.string_value() == _TIMER_END
or not (isinstance(start_call := end_call.arguments[0].owner, func.Call))
or not start_call.callee.string_value() == _TIMER_START
or not (wrapper := _get_module_wrapper(op))
or not isa(op.ptr.type, AnyMemRefType)
):
return

time_lib = wrapper.get_program_import("<time>")

three_elem_ptr_type = csl.PtrType(
[
memref.MemRefType(op.ptr.type.get_element_type(), (3,)),
csl.PtrKindAttr(csl.PtrKind.SINGLE),
csl.PtrConstAttr(csl.PtrConst.VAR),
]
)

rewriter.insert_op(
[
three := arith.Constant.from_int_and_width(3, IndexType()),
load_three := memref.Load.get(op.ptr, [three]),
addr_of := csl.AddressOfOp(
operands=[load_three],
result_types=[
csl.PtrType(
[
op.ptr.type.get_element_type(),
csl.PtrKindAttr(csl.PtrKind.SINGLE),
csl.PtrConstAttr(csl.PtrConst.VAR),
]
)
],
),
ptrcast := csl.PtrCastOp(addr_of, three_elem_ptr_type),
csl.MemberCallOp("get_timestamp", None, time_lib, [ptrcast]),
csl.MemberCallOp("disable_tsc", None, time_lib, []),
],
InsertPoint.before(end_call),
)
rewriter.insert_op(
[
addr_of := csl.AddressOfOp(
operands=[op.ptr],
result_types=[
csl.PtrType(
[
op.ptr.type.get_element_type(),
csl.PtrKindAttr(csl.PtrKind.MANY),
csl.PtrConstAttr(csl.PtrConst.VAR),
]
)
],
),
ptrcast := csl.PtrCastOp(addr_of, three_elem_ptr_type),
csl.MemberCallOp("enable_tsc", None, time_lib, []),
csl.MemberCallOp("get_timestamp", None, time_lib, [ptrcast]),
],
InsertPoint.before(start_call),
)
rewriter.erase_op(op)
rewriter.erase_op(end_call)
rewriter.erase_op(start_call)


@dataclass(frozen=True)
class CslStencilToCslWrapperPass(ModulePass):
"""
Expand All @@ -333,6 +474,7 @@ def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
GreedyRewritePatternApplier(
[
ConvertStencilFuncToModuleWrappedPattern(),
LowerTimerFuncCall(),
]
),
apply_recursively=False,
Expand Down
7 changes: 4 additions & 3 deletions xdsl/transforms/experimental/stencil_tensorize_z_dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,9 +245,10 @@ def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter, /):
class FuncOpTensorize(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: FuncOp, rewriter: PatternRewriter, /):
for arg in op.args:
if isa(arg.type, FieldType[Attribute]):
op.replace_argument_type(arg, stencil_field_to_tensor(arg.type))
if not op.is_declaration:
for arg in op.args:
if isa(arg.type, FieldType[Attribute]):
op.replace_argument_type(arg, stencil_field_to_tensor(arg.type))


def is_tensorized(
Expand Down
3 changes: 2 additions & 1 deletion xdsl/transforms/stencil_shape_minimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ def convert_type(self, typ: stencil.FieldType[Attribute], /) -> Attribute | None
class FuncOpShapeUpdate(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter, /):
op.update_function_type()
if not op.is_declaration:
op.update_function_type()


@dataclass(frozen=True)
Expand Down
Loading