Skip to content

Commit

Permalink
transformations: (memref-to-dsd) Handle csl variables (#3236)
Browse files Browse the repository at this point in the history
Co-authored-by: n-io <[email protected]>
  • Loading branch information
n-io and n-io authored Oct 2, 2024
1 parent 7e5543b commit b986f4a
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 3 deletions.
8 changes: 8 additions & 0 deletions tests/filecheck/transforms/memref-to-dsd.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,14 @@ builtin.module {
// CHECK-NEXT: %26 = "test.op"() : () -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: "csl.fadds"(%26, %26, %26) : (!csl<dsd mem1d_dsd>, !csl<dsd mem1d_dsd>, !csl<dsd mem1d_dsd>) -> ()

%33 = "csl.variable"() : () -> !csl.var<memref<512xf32>>
%34 = "csl.load_var"(%33) : (!csl.var<memref<512xf32>>) -> memref<512xf32>
"csl.store_var"(%33, %34) : (!csl.var<memref<512xf32>>, memref<512xf32>) -> ()

// CHECK-NEXT: %27 = "csl.variable"() : () -> !csl.var<!csl<dsd mem1d_dsd>>
// CHECK-NEXT: %28 = "csl.load_var"(%27) : (!csl.var<!csl<dsd mem1d_dsd>>) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: "csl.store_var"(%27, %28) : (!csl.var<!csl<dsd mem1d_dsd>>, !csl<dsd mem1d_dsd>) -> ()

}) {sym_name = "program"} : () -> ()
}
// CHECK-NEXT: }) {"sym_name" = "program"} : () -> ()
Expand Down
9 changes: 7 additions & 2 deletions xdsl/dialects/csl/csl.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,10 +435,15 @@ class LoadVarOp(IRDLOperation):
var = operand_def(VarType)
res = result_def()

def __init__(self, var: VariableOp):
def __init__(self, var: VariableOp | SSAValue):
if isinstance(var, SSAValue):
assert isinstance(var.type, VarType)
result_t = var.type.get_element_type()
else:
result_t = var.get_element_type()
super().__init__(
operands=[var],
result_types=[var.get_element_type()],
result_types=[result_t],
)

def verify_(self) -> None:
Expand Down
38 changes: 37 additions & 1 deletion xdsl/transforms/memref_to_dsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,12 +261,46 @@ def match_and_rewrite(self, op: csl.AddressOfOp, rewriter: PatternRewriter, /):
)


class CslVarUpdate(RewritePattern):
"""Update CSL Variable Definitions."""

@op_type_rewrite_pattern
def match_and_rewrite(self, op: csl.VariableOp, rewriter: PatternRewriter, /):
if (
not isinstance(op.res.type, csl.VarType)
or not isa(elem_t := op.res.type.get_element_type(), MemRefType[Attribute])
or op.default
):
return
dsd_t = csl.DsdType(
csl.DsdKind.mem1d_dsd if len(elem_t.shape) == 1 else csl.DsdKind.mem4d_dsd
)
rewriter.replace_matched_op(csl.VariableOp.from_type(dsd_t))


class CslVarLoad(RewritePattern):
"""Update CSL Load Variables."""

@op_type_rewrite_pattern
def match_and_rewrite(self, op: csl.LoadVarOp, rewriter: PatternRewriter, /):
if (
not isa(op.res.type, MemRefType[Attribute])
or not isinstance(op.var.type, csl.VarType)
or not isa(op.var.type.get_element_type(), csl.DsdType)
):
return
rewriter.replace_matched_op(csl.LoadVarOp(op.var))


@dataclass(frozen=True)
class MemrefToDsdPass(ModulePass):
"""
Lowers memref ops to CSL DSDs.
Note, that CSL uses memref types in some places
Note, that CSL uses memref types in some places.
This performs a backwards pass translating memref-consuming ops to dsd-consuming ops when all memref type
information is known. A second forward pass translates memref-generating ops to dsd-generating ops.
"""

name = "memref-to-dsd"
Expand All @@ -287,6 +321,8 @@ def apply(self, ctx: MLContext, op: ModuleOp) -> None:
forward_pass = PatternRewriteWalker(
GreedyRewritePatternApplier(
[
CslVarUpdate(),
CslVarLoad(),
LowerAllocOpPass(),
DsdOpUpdateType(),
RetainAddressOfOpPass(),
Expand Down

0 comments on commit b986f4a

Please sign in to comment.