diff --git a/tests/filecheck/transforms/linalg-fuse-multiply-add.mlir b/tests/filecheck/transforms/linalg-fuse-multiply-add.mlir
new file mode 100644
index 0000000000..f9b34f384a
--- /dev/null
+++ b/tests/filecheck/transforms/linalg-fuse-multiply-add.mlir
@@ -0,0 +1,63 @@
// RUN: xdsl-opt %s -p linalg-fuse-multiply-add | filecheck %s
// RUN: xdsl-opt %s -p linalg-fuse-multiply-add{require_scalar_factor=true} | filecheck %s --check-prefix=SCALAR
// RUN: xdsl-opt %s -p linalg-fuse-multiply-add{require_erasable_mul=true} | filecheck %s --check-prefix=FOLD-MUL

builtin.module {
  %t0, %t1, %t2, %t3 = "test.op"() : () -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  %c = arith.constant dense<2.997925e+08> : tensor<8xf32>
  %0 = linalg.mul ins(%t0, %t1 : tensor<8xf32>, tensor<8xf32>) outs(%t0 : tensor<8xf32>) -> tensor<8xf32>
  %1 = linalg.mul ins(%c, %t1 : tensor<8xf32>, tensor<8xf32>) outs(%c : tensor<8xf32>) -> tensor<8xf32>
  %2 = linalg.add ins(%0, %t2 : tensor<8xf32>, tensor<8xf32>) outs(%0 : tensor<8xf32>) -> tensor<8xf32>
  %3 = linalg.add ins(%1, %t3 : tensor<8xf32>, tensor<8xf32>) outs(%1 : tensor<8xf32>) -> tensor<8xf32>
  %4 = linalg.sub ins(%1, %t3 : tensor<8xf32>, tensor<8xf32>) outs(%1 : tensor<8xf32>) -> tensor<8xf32>
}

// CHECK-NEXT: builtin.module {
// CHECK-NEXT: %t0, %t1, %t2, %t3 = "test.op"() : () -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
// CHECK-NEXT: %c = arith.constant dense<2.997925e+08> : tensor<8xf32>
// CHECK-NEXT: %0 = linalg.mul ins(%c, %t1 : tensor<8xf32>, tensor<8xf32>) outs(%c : tensor<8xf32>) -> tensor<8xf32>
// CHECK-NEXT: %1 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%t0, %t1, %t2 : tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) outs(%t0 : tensor<8xf32>) {
// CHECK-NEXT: ^0(%2 : f32, %3 : f32, %4 : f32, %5 : f32):
// CHECK-NEXT: %6 = arith.mulf %2, %3 : f32
// CHECK-NEXT: %7 = arith.addf %6, %4 : f32
// CHECK-NEXT: linalg.yield %7 : f32
// CHECK-NEXT: } -> tensor<8xf32>
// CHECK-NEXT: %8 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%c, %t1, %t3 : tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) outs(%c : tensor<8xf32>) {
// CHECK-NEXT: ^1(%9 : f32, %10 : f32, %11 : f32, %12 : f32):
// CHECK-NEXT: %13 = arith.mulf %9, %10 : f32
// CHECK-NEXT: %14 = arith.addf %13, %11 : f32
// CHECK-NEXT: linalg.yield %14 : f32
// CHECK-NEXT: } -> tensor<8xf32>
// CHECK-NEXT: %15 = linalg.sub ins(%0, %t3 : tensor<8xf32>, tensor<8xf32>) outs(%0 : tensor<8xf32>) -> tensor<8xf32>
// CHECK-NEXT: }


// SCALAR-NEXT: builtin.module {
// SCALAR-NEXT: %t0, %t1, %t2, %t3 = "test.op"() : () -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
// SCALAR-NEXT: %c = arith.constant dense<2.997925e+08> : tensor<8xf32>
// SCALAR-NEXT: %0 = linalg.mul ins(%t0, %t1 : tensor<8xf32>, tensor<8xf32>) outs(%t0 : tensor<8xf32>) -> tensor<8xf32>
// SCALAR-NEXT: %1 = linalg.mul ins(%c, %t1 : tensor<8xf32>, tensor<8xf32>) outs(%c : tensor<8xf32>) -> tensor<8xf32>
// SCALAR-NEXT: %2 = linalg.add ins(%0, %t2 : tensor<8xf32>, tensor<8xf32>) outs(%0 : tensor<8xf32>) -> tensor<8xf32>
// SCALAR-NEXT: %3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%c, %t1, %t3 : tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) outs(%c : tensor<8xf32>) {
// SCALAR-NEXT: ^0(%4 : f32, %5 : f32, %6 : f32, %7 : f32):
// SCALAR-NEXT: %8 = arith.mulf %4, %5 : f32
// SCALAR-NEXT: %9 = arith.addf %8, %6 : f32
// SCALAR-NEXT: linalg.yield %9 : f32
// SCALAR-NEXT: } -> tensor<8xf32>
// SCALAR-NEXT: %10 = linalg.sub ins(%1, %t3 : tensor<8xf32>, tensor<8xf32>) outs(%1 : tensor<8xf32>) -> tensor<8xf32>
// SCALAR-NEXT: }


// FOLD-MUL-NEXT: builtin.module {
// FOLD-MUL-NEXT: %t0, %t1, %t2, %t3 = "test.op"() : () -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
// FOLD-MUL-NEXT: %c = arith.constant dense<2.997925e+08> : tensor<8xf32>
// FOLD-MUL-NEXT: %0 = linalg.mul ins(%c, %t1 : tensor<8xf32>, tensor<8xf32>) outs(%c : tensor<8xf32>) -> tensor<8xf32>
// FOLD-MUL-NEXT: %1 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%t0, %t1, %t2 : tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) outs(%t0 : tensor<8xf32>) {
// FOLD-MUL-NEXT: ^0(%2 : f32, %3 : f32, %4 : f32, %5 : f32):
// FOLD-MUL-NEXT: %6 = arith.mulf %2, %3 : f32
// FOLD-MUL-NEXT: %7 = arith.addf %6, %4 : f32
// FOLD-MUL-NEXT: linalg.yield %7 : f32
// FOLD-MUL-NEXT: } -> tensor<8xf32>
// FOLD-MUL-NEXT: %8 = linalg.add ins(%0, %t3 : tensor<8xf32>, tensor<8xf32>) outs(%0 : tensor<8xf32>) -> tensor<8xf32>
// FOLD-MUL-NEXT: %9 = linalg.sub ins(%0, %t3 : tensor<8xf32>, tensor<8xf32>) outs(%0 : tensor<8xf32>) -> tensor<8xf32>
// FOLD-MUL-NEXT: }
diff --git a/xdsl/tools/command_line_tool.py b/xdsl/tools/command_line_tool.py
index 615431443b..a8f24322fa
--- a/xdsl/tools/command_line_tool.py
+++ b/xdsl/tools/command_line_tool.py
@@ -166,6 +166,11 @@ def get_lift_arith_to_linalg():
         return LiftArithToLinalg

+    def get_linalg_fuse_multiply_add():
+        from xdsl.transforms.linalg_transformations import LinalgFuseMultiplyAddPass
+
+        return LinalgFuseMultiplyAddPass
+
     def get_linalg_to_csl():
         from xdsl.transforms.linalg_to_csl import LinalgToCsl
@@ -482,6 +487,7 @@ def get_stencil_shape_minimize():
         "hls-convert-stencil-to-ll-mlir": get_hls_convert_stencil_to_ll_mlir,
         "apply-individual-rewrite": get_individual_rewrite,
         "lift-arith-to-linalg": get_lift_arith_to_linalg,
+        "linalg-fuse-multiply-add": get_linalg_fuse_multiply_add,
         "linalg-to-csl": get_linalg_to_csl,
         "lower-affine": get_lower_affine,
         "lower-csl-stencil": get_lower_csl_stencil,
diff --git a/xdsl/transforms/linalg_transformations.py b/xdsl/transforms/linalg_transformations.py
new file mode 100644
index 0000000000..bbf26aa9d4
--- /dev/null
+++ b/xdsl/transforms/linalg_transformations.py
@@ -0,0 +1,120 @@
from dataclasses import dataclass

from xdsl.builder import Builder
from xdsl.context import MLContext
from xdsl.dialects import arith, linalg
from xdsl.dialects.builtin import AffineMapAttr, DenseIntOrFPElementsAttr, ModuleOp
from xdsl.ir import BlockArgument, OpResult, SSAValue
from xdsl.ir.affine import AffineMap
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
    PatternRewriter,
    PatternRewriteWalker,
    RewritePattern,
    op_type_rewrite_pattern,
)


def build_generic_fma(
    mul_op1: SSAValue, mul_op2: SSAValue, add_op: SSAValue, out: SSAValue
) -> linalg.Generic:
    inputs = (mul_op1, mul_op2, add_op)
    outputs = (out,)

    arg_types = linalg.NamedOpBase.body_arg_types((*inputs, *outputs))

    @Builder.implicit_region(arg_types)
    def body(args: tuple[BlockArgument, ...]) -> None:
        m = arith.Mulf(args[0], args[1])
        a = arith.Addf(m, args[2])
        linalg.YieldOp(a)

    return linalg.Generic(
        inputs,
        outputs,
        body,
        4 * [AffineMapAttr(AffineMap.from_callable(lambda i: (i,)))],
        [linalg.IteratorTypeAttr.parallel()],
        [out.type],
    )


@dataclass(frozen=True)
class FuseMultiplyAddPass(RewritePattern):
    require_scalar_factor: bool
    require_erasable_mul: bool

    @op_type_rewrite_pattern
    def match_and_rewrite(self, mul: linalg.MulOp, rewriter: PatternRewriter, /):
        if len(mul.res) != 1 or (
            self.require_erasable_mul
            and len(set(use.operation for use in mul.res[0].uses)) != 1
        ):
            return

        for add in set(
            use.operation
            for use in mul.res[0].uses
            if isinstance(use.operation, linalg.AddOp)
            and mul.res[0] in use.operation.inputs
        ):
            # if the `require_scalar_factor` flag is set, only fuse when one of the
            # operands of `mul` is a scalar constant
            if (
                self.require_scalar_factor
                and not self.is_scalar_constant(mul.inputs[0])
                and not self.is_scalar_constant(mul.inputs[1])
            ):
                return

            # the operand of `add` that is not the `mul` result
            add_operand = (
                add.inputs[0] if mul.res[0] == add.inputs[1] else add.inputs[1]
            )

            # build the fused fma op
            fma = build_generic_fma(
                mul.inputs[0], mul.inputs[1], add_operand, mul.outputs[0]
            )

            # replace the add op in place with the fused op
            rewriter.replace_op(add, fma)
            if len(mul.res[0].uses) == 0:
                rewriter.erase_matched_op()

    @staticmethod
    def is_scalar_constant(op: SSAValue) -> bool:
        """
        Returns whether the value is a scalar constant. This currently checks for
        `arith.constant` values that are plain scalars or splat tensors, and could in
        the future be extended to check for dynamically provided scalar values expanded
        via linalg.fill.
        """
        return (
            isinstance(op, OpResult)
            and isinstance(op.op, arith.Constant)
            and (
                not isinstance(v := op.op.value, DenseIntOrFPElementsAttr)
                or v.data.data.count(v.data.data[0]) == len(v.data.data)
            )
        )


@dataclass(frozen=True)
class LinalgFuseMultiplyAddPass(ModulePass):
    """
    Pass that fuses linalg multiply and add ops into a `generic` fma.
    """

    name = "linalg-fuse-multiply-add"

    require_scalar_factor: bool = False
    """Set to require one of the mul factors to be a scalar constant."""

    require_erasable_mul: bool = False
    """Set to fuse ops only if the multiply has no other use and can be erased."""

    def apply(self, ctx: MLContext, op: ModuleOp) -> None:
        module_pass = PatternRewriteWalker(
            FuseMultiplyAddPass(self.require_scalar_factor, self.require_erasable_mul),
            apply_recursively=False,
        )
        module_pass.rewrite_module(op)
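
Note on usage: the RUN lines in the test exercise the pass through xdsl-opt; the snippet below is a minimal sketch of driving LinalgFuseMultiplyAddPass programmatically instead. The Parser class, the ctx.load_dialect calls, the Arith/Builtin/Linalg/Test dialect objects, and printing the module via print() are assumptions about the surrounding xdsl API, not part of this diff.

# Sketch only: assumes xdsl's MLContext.load_dialect and Parser APIs behave as named here.
from xdsl.context import MLContext
from xdsl.dialects.arith import Arith
from xdsl.dialects.builtin import Builtin
from xdsl.dialects.linalg import Linalg
from xdsl.dialects.test import Test
from xdsl.parser import Parser
from xdsl.transforms.linalg_transformations import LinalgFuseMultiplyAddPass

PROGRAM = """
builtin.module {
  %t0, %t1, %t2 = "test.op"() : () -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  %0 = linalg.mul ins(%t0, %t1 : tensor<8xf32>, tensor<8xf32>) outs(%t0 : tensor<8xf32>) -> tensor<8xf32>
  %1 = linalg.add ins(%0, %t2 : tensor<8xf32>, tensor<8xf32>) outs(%0 : tensor<8xf32>) -> tensor<8xf32>
}
"""

ctx = MLContext()
for dialect in (Builtin, Arith, Linalg, Test):
    ctx.load_dialect(dialect)

module = Parser(ctx, PROGRAM).parse_module()
# Rewrites the mul/add pair into a single linalg.generic whose body is mulf + addf;
# the mul is erased afterwards because the add was its only user.
LinalgFuseMultiplyAddPass().apply(ctx, module)
print(module)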