diff --git a/example/machsuite b/example/machsuite index 7b84770bdcf..0263a70968f 160000 --- a/example/machsuite +++ b/example/machsuite @@ -1 +1 @@ -Subproject commit 7b84770bdcfa6f8a06bae077e0f47b2fe649985e +Subproject commit 0263a70968f524affdaacdc5e0a408bb21f9d929 diff --git a/lib/mlir/Transforms/LiftMemRefSubview.cc b/lib/mlir/Transforms/LiftMemRefSubview.cc index 0fc430f5bd3..f427d575a5d 100644 --- a/lib/mlir/Transforms/LiftMemRefSubview.cc +++ b/lib/mlir/Transforms/LiftMemRefSubview.cc @@ -630,6 +630,31 @@ static LogicalResult aliasMemRef(FuncOp callee, return success(); } +static LogicalResult matchCallerToCallee(CallOp caller, FuncOp callee) { + if (caller.getCallee() != callee.getName()) + return failure(); + + OpBuilder b(callee.getContext()); + for (auto it : enumerate(caller.getOperands())) { + auto operand = it.value(); + auto targetType = callee.getArgument(it.index()).getType(); + if (operand.getType() != targetType) { + assert(operand.isa()); + + auto f = dyn_cast( + operand.cast().getOwner()->getParentOp()); + assert(f && "Parent of the type-mismatched value should be a function."); + + operand.setType(targetType); + auto entry = &f.getBlocks().front(); + f.setType(b.getFunctionType(entry->getArgumentTypes(), + f.getType().getResults())); + } + } + + return success(); +} + static LogicalResult flattenPartitionDims(ModuleOp m) { LLVM_DEBUG(dbgs() << "Flattening ..." << m << '\n'); using OpPair = std::pair; @@ -804,6 +829,7 @@ static LogicalResult flattenPartitionDims(ModuleOp m) { subviewOp.erase(); } + /// Rewrite the function interface with the flattened memref type. BlockArgument arg = memref.dyn_cast(); assert(arg); @@ -824,6 +850,12 @@ static LogicalResult flattenPartitionDims(ModuleOp m) { arg.setType(newSrcTy); f.setType(b.getFunctionType(arg.getOwner()->getArgumentTypes(), f.getType().getResults())); + + // Fix the callers to this function. + m.walk([&](CallOp caller) { + if (caller.getCallee() == f.getName()) + assert(succeeded(matchCallerToCallee(caller, f))); + }); } return success(); diff --git a/pyphism/machsuite/ms_flow.py b/pyphism/machsuite/ms_flow.py index b0fcb7ffb4e..f312b1d7099 100644 --- a/pyphism/machsuite/ms_flow.py +++ b/pyphism/machsuite/ms_flow.py @@ -96,7 +96,7 @@ def run(self): .polymer_opt() .phism_fold_if() .phism_loop_transforms() - .phism_array_partition() + .phism_array_partition(flatten=True) # .lower_scf() .lower_llvm() .phism_vitis_opt()