diff --git a/shark_turbine/kernel/wave/codegen.py b/shark_turbine/kernel/wave/codegen.py
index 3587e0f3..adcd69b8 100644
--- a/shark_turbine/kernel/wave/codegen.py
+++ b/shark_turbine/kernel/wave/codegen.py
@@ -12,6 +12,7 @@ from dataclasses import dataclass
 
 import torch.fx as fx
 import torch.utils._pytree as pytree
+from collections import namedtuple
 
 from ..compiler.ir import (
     Attribute,
@@ -200,76 +201,139 @@ def add_emitter_subs(emitter: WaveEmitter) -> dict[IndexSymbol, Any]:
     return dynamics
 
 
+_Rational = namedtuple("_Rational", ["numerator", "denominator"])
+
+
 def gen_sympy_index(dynamics: dict[IndexSymbol, Any], expr: sympy.Expr) -> OpResult:
     stack: list[OpResult] = []
 
-    def _broadcast(a, b):
-        if not isinstance(a, (Value, OpResult)):
-            a = a.result
+    def _get_ir_value(arg):
+        if not isinstance(arg, (Value, OpResult)):
+            arg = arg.result
+
+        return arg
 
-        if not isinstance(b, (Value, OpResult)):
-            b = b.result
+    def _check_vec_scalar(a, b):
+        return isinstance(a.type, VectorType) and a.type.element_type == b.type
+
+    def _broadcast(a, b):
+        a = _get_ir_value(a)
+        b = _get_ir_value(b)
 
         if a.type == b.type:
             return a, b
 
-        if isinstance(a.type, VectorType) and isinstance(
-            b.type, (IndexType, IntegerType)
-        ):
-            assert a.type.element_type == b.type
+        if _check_vec_scalar(a, b):
             b = vector_d.splat(a.type, b)
             return a, b
 
-        if isinstance(a.type, (IndexType, IntegerType)) and isinstance(
-            b.type, VectorType
-        ):
-            assert b.type.element_type == a.type
+        if _check_vec_scalar(b, a):
             a = vector_d.splat(b.type, a)
             return a, b
 
         raise CodegenError(f"Cannot broadcast {a.type} and {b.type}")
 
-    def _process_mul_add_ops(term, is_mul):
-        args = []
-        callables = []
-        for _ in range(len(term.args)):
-            val = stack.pop()
-            if callable(val):
-                callables.append(val)
-            else:
-                args.append(val)
-        operation = None
-        for arg in args:
-            if operation is None:
-                operation = arg
-                continue
+    def get_const_val(arg):
+        if isinstance(arg, OpResult):
+            arg = arg.owner.opview
 
-            if is_mul:
-                operation = arith_d.MulIOp(*_broadcast(operation, arg))
-            else:
-                operation = arith_d.AddIOp(*_broadcast(operation, arg))
+        if isinstance(arg, arith_d.ConstantOp):
+            value = arg.attributes["value"]
+            if isinstance(value, IntegerAttr):
+                return int(value)
 
-        for arg in callables:
-            operation = arg(operation, is_mul)
+        return None
 
-        stack.append(operation)
+    def muli_fold(lhs, rhs):
+        if get_const_val(lhs) == 1:
+            return rhs
+
+        if get_const_val(rhs) == 1:
+            return lhs
+
+        return arith_d.muli(lhs, rhs)
+
+    # `x + (a/b)` transformed into `(x*b + a) / b`
+    def _add(lhs, rhs):
+        is_rational_lhs = isinstance(lhs, _Rational)
+        is_rational_rhs = isinstance(rhs, _Rational)
+        if is_rational_lhs and not is_rational_rhs:
+            numerator = muli_fold(*_broadcast(lhs.denominator, rhs))
+            numerator = arith_d.addi(*_broadcast(numerator, lhs.numerator))
+            return _Rational(numerator, lhs.denominator)
+        elif not is_rational_lhs and is_rational_rhs:
+            numerator = muli_fold(*_broadcast(lhs, rhs.denominator))
+            numerator = arith_d.addi(*_broadcast(numerator, rhs.numerator))
+            return _Rational(numerator, rhs.denominator)
+        elif is_rational_lhs and is_rational_rhs:
+            lhs_numerator = muli_fold(*_broadcast(lhs.numerator, rhs.denominator))
+            rhs_numerator = muli_fold(*_broadcast(rhs.numerator, lhs.denominator))
+            numerator = arith_d.addi(*_broadcast(lhs_numerator, rhs_numerator))
+            denominator = muli_fold(*_broadcast(lhs.denominator, rhs.denominator))
+            return _Rational(numerator, denominator)
+        else:
+            return arith_d.addi(*_broadcast(lhs, rhs))
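+
+    # For example, `_add(x, _Rational(c1, c3))` (where c1 and c3 are the emitted
+    # constants 1 and 3, i.e. sympy's 1/3) yields `_Rational(x*3 + 1, 3)`; no
+    # division is emitted here, the division is only materialized later by
+    # `_floor`/`_ceiling`. `muli_fold` elides the multiply whenever a constant
+    # operand is 1, e.g. for a denominator of 1.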
+    # `x * (a/b)` transformed into `(x * a) / b`
+    def _mul(lhs, rhs):
+        is_rational_lhs = isinstance(lhs, _Rational)
+        is_rational_rhs = isinstance(rhs, _Rational)
+        if is_rational_lhs and not is_rational_rhs:
+            numerator = muli_fold(*_broadcast(lhs.numerator, rhs))
+            return _Rational(numerator, lhs.denominator)
+        elif not is_rational_lhs and is_rational_rhs:
+            numerator = muli_fold(*_broadcast(lhs, rhs.numerator))
+            return _Rational(numerator, rhs.denominator)
+        elif is_rational_lhs and is_rational_rhs:
+            numerator = muli_fold(*_broadcast(lhs.numerator, rhs.numerator))
+            denominator = muli_fold(*_broadcast(lhs.denominator, rhs.denominator))
+            return _Rational(numerator, denominator)
+        else:
+            return muli_fold(*_broadcast(lhs, rhs))
 
-    def _get_mul(numerator):
-        return lambda x: arith_d.MulIOp(*_broadcast(x, numerator))
+    def _floor(value):
+        if isinstance(value, _Rational):
+            value = arith_d.divsi(*_broadcast(value.numerator, value.denominator))
 
-    def _get_add(numerator, denominator):
-        return lambda x: arith_d.AddIOp(
-            *_broadcast(arith_d.MulIOp(*_broadcast(x, denominator)), numerator)
-        )
+        return value
 
-    def _get_div(mul, add, denominator):
-        return lambda x, is_mul: arith_d.DivSIOp(
-            *_broadcast(mul(x) if is_mul else add(x), denominator)
-        )
+    def _ceiling(value):
+        if isinstance(value, _Rational):
+            value = arith_d.ceildivsi(*_broadcast(value.numerator, value.denominator))
+
+        return value
+
+    def _group_rationals(stack, count):
+        """Group rational and non-rational args into 2 contiguous sets.
+
+        This lets us mul/add all non-rationals first, reducing the total number of ops.
+        """
+        rationals = []
+        non_rationals = []
+        for _ in range(count):
+            val = stack.pop()
+            if isinstance(val, _Rational):
+                rationals.append(val)
+            else:
+                non_rationals.append(val)
+
+        return non_rationals + rationals
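+
+    # For example, for `x + y + 1/2` this ordering lets `_apply` fold `x + y`
+    # with a single addi before the rational path kicks in:
+    # (x + y) + 1/2 -> ((x + y)*2 + 1) / 2. If the rational came first, each
+    # later plain term would be rescaled by the denominator separately (two
+    # multiplies here instead of one).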
+
+    def _apply(args, func):
+        assert len(args) > 0
+        value = args[0]
+        for val in args[1:]:
+            value = func(value, val)
+
+        return value
+
+    def _enforce_non_rational(val, term):
+        if isinstance(val, _Rational):
+            raise CodegenError(f"Rational is not supported yet in '{type(term)}'")
 
     def _get_const(val):
         if isinstance(val, int):
-            return arith_d.constant(IndexType.get(), res)
+            return arith_d.constant(IndexType.get(), val)
 
         if isinstance(val, (tuple, list)):
             vec_type = VectorType.get([len(val)], IndexType.get())
@@ -296,56 +360,50 @@ def _get_const(val):
                 else:
                     raise CodegenError(f"Unknown symbol {term}")
             case sympy.Integer():
-                stack.append(arith_d.constant(IndexType.get(), int(term)))
+                stack.append(_get_const(int(term)))
             case sympy.Mul():
-                _process_mul_add_ops(term, is_mul=True)
+                args = _group_rationals(stack, len(term.args))
+                stack.append(_apply(args, _mul))
            case sympy.Add():
-                _process_mul_add_ops(term, is_mul=False)
+                args = _group_rationals(stack, len(term.args))
+                stack.append(_apply(args, _add))
             case sympy.Mod():
                 rhs = stack.pop()
                 lhs = stack.pop()
-                mod = arith_d.RemSIOp(*_broadcast(lhs, rhs))
+                _enforce_non_rational(rhs, term)
+                _enforce_non_rational(lhs, term)
+                mod = arith_d.remsi(*_broadcast(lhs, rhs))
                 stack.append(mod)
             case sympy.floor():
-                # TODO: Since divsi rounds to zero, this seems to work.
-                # But check whether floordivsi is needed.
-                stack.append(stack.pop())
+                stack.append(_floor(stack.pop()))
+            case sympy.ceiling():
+                stack.append(_ceiling(stack.pop()))
             case sympy.Rational():
-                # `x * (a/b)` transformed into `(x * a) / b`
-                # `x + (a/b)` transformed into `(x*b + a) / b`
-                numerator = arith_d.constant(IndexType.get(), abs(term.p))
-                denominator = arith_d.constant(IndexType.get(), abs(term.q))
-                # Assumes that the negative term is always carried on the numerator
-                if abs(term.p) > term.p:
-                    zero = arith_d.constant(IndexType.get(), int(0))
-                    numerator = arith_d.SubIOp(*_broadcast(zero, numerator))
-                mul = lambda x: x
-                if abs(term.p) != 1:
-                    mul = _get_mul(numerator)
-                add = _get_add(numerator, denominator)
-                operation = _get_div(mul, add, denominator)
-                stack.append(operation)
+                numerator = _get_const(term.p)
+                denominator = _get_const(term.q)
+                stack.append(_Rational(numerator, denominator))
             case sympy.StrictLessThan():
                 rhs = stack.pop()
                 lhs = stack.pop()
+                _enforce_non_rational(rhs, term)
+                _enforce_non_rational(lhs, term)
                 res = arith_d.cmpi(arith_d.CmpIPredicate.slt, *_broadcast(lhs, rhs))
                 stack.append(res)
             case sympy.And():
                 rhs = stack.pop()
                 lhs = stack.pop()
+                _enforce_non_rational(rhs, term)
+                _enforce_non_rational(lhs, term)
                 res = arith_d.andi(*_broadcast(lhs, rhs))
                 stack.append(res)
-            case sympy.ceiling():
-                value = stack.pop()
-                if not isinstance(value, arith_d.DivSIOp):
-                    raise CodegenError(f"Cannot handle ceil({value}) yet")
-                stack.append(arith_d.CeilDivSIOp(value.lhs, value.rhs))
             case sympy.UnevaluatedExpr():
                 continue
             case _:
                 raise CodegenError(f"Can not handle {type(term)} : {term}")
-    if len(stack) != 1:
+
+    if len(stack) != 1 or isinstance(stack[0], _Rational):
         raise CodegenError(f"Expected single result, got {len(stack)}")
+
     return stack[0]
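Note (illustration only, not part of the patch): the helpers above keep division out of the emitted IR until a `floor`/`ceiling` forces it. The sketch below mirrors that bookkeeping with plain Python ints instead of `arith` ops; the names `Rational`, `add` and `floor` are made up for the example, and Python's `//` floors while the IR path uses `arith.divsi`, which rounds toward zero, so the two only coincide for non-negative operands.

from collections import namedtuple

Rational = namedtuple("Rational", ["numerator", "denominator"])


def add(lhs, rhs):
    # x + a/b  ->  (x*b + a) / b, kept as an unevaluated Rational
    if isinstance(lhs, Rational) and not isinstance(rhs, Rational):
        return Rational(lhs.denominator * rhs + lhs.numerator, lhs.denominator)
    if not isinstance(lhs, Rational) and isinstance(rhs, Rational):
        return Rational(lhs * rhs.denominator + rhs.numerator, rhs.denominator)
    if isinstance(lhs, Rational) and isinstance(rhs, Rational):
        return Rational(
            lhs.numerator * rhs.denominator + rhs.numerator * lhs.denominator,
            lhs.denominator * rhs.denominator,
        )
    return lhs + rhs


def floor(value):
    # The division only happens here, mirroring `_floor` above.
    if isinstance(value, Rational):
        return value.numerator // value.denominator
    return value


assert floor(add(7, Rational(1, 2))) == 7  # floor(7 + 1/2) == 7
assert floor(add(Rational(1, 3), add(5, Rational(1, 3)))) == 5  # floor(5 + 2/3)
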
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 00000000..93bd8d6f
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,31 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import pytest
+
+
+def pytest_addoption(parser):
+    parser.addoption(
+        "--runperf", action="store_true", default=False, help="run performance tests"
+    )
+
+
+def pytest_configure(config):
+    config.addinivalue_line(
+        "markers", "perf_only: performance test, runs only with '--runperf'"
+    )
+
+
+def pytest_collection_modifyitems(config, items):
+    run_perf = config.getoption("--runperf")
+    for item in items:
+        is_perf_only = next(item.iter_markers("perf_only"), None) is not None
+        if run_perf:
+            if not is_perf_only:
+                item.add_marker(pytest.mark.skip("skip non-perf test"))
+        else:
+            if is_perf_only:
+                item.add_marker(pytest.mark.skip("skip perf test"))
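Note (usage sketch, not part of the patch; the test names below are hypothetical): with this conftest, a test opts into the perf bucket via the marker, and the two buckets are selected from the command line.

import pytest


@pytest.mark.perf_only
def test_large_conv_perf():
    ...  # collected only when invoked as `pytest --runperf`


def test_small_conv():
    ...  # skipped under `--runperf`, runs in the default configuration
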
diff --git a/tests/kernel/wave/wave_e2e_test.py b/tests/kernel/wave/wave_e2e_test.py
index 5b6aa640..69ba718a 100644
--- a/tests/kernel/wave/wave_e2e_test.py
+++ b/tests/kernel/wave/wave_e2e_test.py
@@ -641,15 +641,193 @@ def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]:
     assert_allclose(out, out_ref, rtol=1e-05, atol=1e-05)
 
 
+_igemm_cases = [
+    (4, 5, 5, 10, 2, 2, 16, 3),
+    (2, 5, 5, 10, 2, 2, 16, 3),
+    (1, 5, 5, 10, 2, 2, 16, 3),
+    (4, 5, 5, 4, 2, 2, 16, 3),
+    (1, 5, 5, 4, 2, 2, 16, 3),
+    (1, 5, 5, 3, 2, 2, 16, 3),
+    (2, 5, 5, 1, 2, 2, 16, 3),
+    (4, 5, 5, 10, 2, 2, 2, 3),
+    (2, 5, 5, 10, 2, 2, 2, 3),
+    (1, 5, 5, 10, 2, 2, 2, 3),
+    (4, 5, 5, 4, 2, 2, 2, 3),
+    (2, 5, 5, 4, 2, 2, 2, 3),
+    (1, 5, 5, 3, 2, 2, 2, 3),
+    (2, 5, 5, 1, 2, 2, 2, 3),
+    (1, 5, 5, 1, 2, 2, 2, 3),
+    (4, 5, 5, 10, 2, 2, 1, 3),
+    (2, 5, 5, 10, 2, 2, 1, 3),
+    (1, 5, 5, 10, 2, 2, 1, 3),
+    (4, 5, 5, 4, 2, 2, 1, 3),
+    (2, 5, 5, 4, 2, 2, 1, 3),
+    (1, 5, 5, 4, 2, 2, 1, 3),
+    (2, 5, 5, 3, 2, 2, 1, 3),
+    (4, 5, 5, 1, 2, 2, 1, 3),
+    (2, 5, 5, 1, 2, 2, 1, 3),
+    (1, 5, 5, 1, 2, 2, 1, 3),
+    (4, 5, 5, 10, 2, 2, 16, 2),
+    (2, 5, 5, 10, 2, 2, 16, 2),
+    (1, 5, 5, 10, 2, 2, 16, 2),
+    (4, 5, 5, 4, 2, 2, 16, 2),
+    (1, 5, 5, 4, 2, 2, 16, 2),
+    (4, 5, 5, 3, 2, 2, 16, 2),
+    (4, 5, 5, 1, 2, 2, 16, 2),
+    (1, 5, 5, 1, 2, 2, 16, 2),
+    (4, 5, 5, 10, 2, 2, 2, 2),
+    (2, 5, 5, 10, 2, 2, 2, 2),
+    (1, 5, 5, 10, 2, 2, 2, 2),
+    (4, 5, 5, 4, 2, 2, 2, 2),
+    (2, 5, 5, 4, 2, 2, 2, 2),
+    (2, 5, 5, 3, 2, 2, 2, 2),
+    (2, 5, 5, 1, 2, 2, 2, 2),
+    (1, 5, 5, 1, 2, 2, 2, 2),
+    (4, 5, 5, 10, 2, 2, 1, 2),
+    (2, 5, 5, 10, 2, 2, 1, 2),
+    (1, 5, 5, 10, 2, 2, 1, 2),
+    (4, 5, 5, 4, 2, 2, 1, 2),
+    (2, 5, 5, 4, 2, 2, 1, 2),
+    (1, 5, 5, 4, 2, 2, 1, 2),
+    (4, 5, 5, 1, 2, 2, 1, 2),
+    (1, 5, 5, 1, 2, 2, 1, 2),
+    (4, 5, 5, 10, 2, 2, 16, 1),
+    (2, 5, 5, 10, 2, 2, 16, 1),
+    (4, 5, 5, 4, 2, 2, 16, 1),
+    (2, 5, 5, 4, 2, 2, 16, 1),
+    (1, 5, 5, 4, 2, 2, 16, 1),
+    (4, 5, 5, 3, 2, 2, 16, 1),
+    (1, 5, 5, 3, 2, 2, 16, 1),
+    (2, 5, 5, 1, 2, 2, 16, 1),
+    (1, 5, 5, 1, 2, 2, 16, 1),
+    (4, 5, 5, 10, 2, 2, 2, 1),
+    (2, 5, 5, 10, 2, 2, 2, 1),
+    (1, 5, 5, 10, 2, 2, 2, 1),
+    (4, 5, 5, 4, 2, 2, 2, 1),
+    (2, 5, 5, 4, 2, 2, 2, 1),
+    (1, 5, 5, 4, 2, 2, 2, 1),
+    (1, 5, 5, 3, 2, 2, 2, 1),
+    (2, 5, 5, 1, 2, 2, 2, 1),
+    (1, 5, 5, 1, 2, 2, 2, 1),
+    (4, 5, 5, 10, 2, 2, 1, 1),
+    (2, 5, 5, 10, 2, 2, 1, 1),
+    (4, 5, 5, 4, 2, 2, 1, 1),
+    (2, 5, 5, 4, 2, 2, 1, 1),
+    (1, 5, 5, 4, 2, 2, 1, 1),
+    (2, 5, 5, 1, 2, 2, 1, 1),
+    (1, 5, 5, 1, 2, 2, 1, 1),
+    (4, 5, 5, 10, 2, 2, 16, 3),
+    (2, 5, 5, 10, 2, 2, 16, 3),
+    (1, 5, 5, 10, 2, 2, 16, 3),
+    (4, 5, 5, 4, 2, 2, 16, 3),
+    (2, 5, 5, 4, 2, 2, 16, 3),
+    (1, 5, 5, 4, 2, 2, 16, 3),
+    (4, 5, 5, 1, 2, 2, 16, 3),
+    (1, 5, 5, 1, 2, 2, 16, 3),
+    (4, 5, 5, 10, 2, 2, 2, 3),
+    (1, 5, 5, 10, 2, 2, 2, 3),
+    (2, 5, 5, 4, 2, 2, 2, 3),
+    (1, 5, 5, 4, 2, 2, 2, 3),
+    (2, 5, 5, 3, 2, 2, 2, 3),
+    (4, 5, 5, 1, 2, 2, 2, 3),
+    (2, 5, 5, 1, 2, 2, 2, 3),
+    (1, 5, 5, 1, 2, 2, 2, 3),
+    (4, 5, 5, 10, 2, 2, 1, 3),
+    (2, 5, 5, 10, 2, 2, 1, 3),
+    (1, 5, 5, 10, 2, 2, 1, 3),
+    (4, 5, 5, 4, 2, 2, 1, 3),
+    (2, 5, 5, 4, 2, 2, 1, 3),
+    (1, 5, 5, 4, 2, 2, 1, 3),
+    (4, 5, 5, 1, 2, 2, 1, 3),
+    (2, 5, 5, 1, 2, 2, 1, 3),
+    (1, 5, 5, 1, 2, 2, 1, 3),
+    (4, 5, 5, 10, 2, 2, 16, 2),
+    (2, 5, 5, 10, 2, 2, 16, 2),
+    (1, 5, 5, 10, 2, 2, 16, 2),
+    (4, 5, 5, 4, 2, 2, 16, 2),
+    (1, 5, 5, 4, 2, 2, 16, 2),
+    (4, 5, 5, 1, 2, 2, 16, 2),
+    (2, 5, 5, 1, 2, 2, 16, 2),
+    (4, 5, 5, 10, 2, 2, 2, 2),
+    (2, 5, 5, 10, 2, 2, 2, 2),
+    (1, 5, 5, 10, 2, 2, 2, 2),
+    (4, 5, 5, 4, 2, 2, 2, 2),
+    (2, 5, 5, 4, 2, 2, 2, 2),
+    (1, 5, 5, 4, 2, 2, 2, 2),
+    (1, 5, 5, 3, 2, 2, 2, 2),
+    (2, 5, 5, 1, 2, 2, 2, 2),
+    (1, 5, 5, 1, 2, 2, 2, 2),
+    (2, 5, 5, 10, 2, 2, 1, 2),
+    (1, 5, 5, 10, 2, 2, 1, 2),
+    (4, 5, 5, 4, 2, 2, 1, 2),
+    (2, 5, 5, 4, 2, 2, 1, 2),
+    (1, 5, 5, 4, 2, 2, 1, 2),
+    (1, 5, 5, 3, 2, 2, 1, 2),
+    (2, 5, 5, 1, 2, 2, 1, 2),
+    (1, 5, 5, 1, 2, 2, 1, 2),
+    (4, 5, 5, 10, 2, 2, 16, 1),
+    (2, 5, 5, 10, 2, 2, 16, 1),
+    (1, 5, 5, 10, 2, 2, 16, 1),
+    (4, 5, 5, 4, 2, 2, 16, 1),
+    (2, 5, 5, 4, 2, 2, 16, 1),
+    (1, 5, 5, 4, 2, 2, 16, 1),
+    (2, 5, 5, 3, 2, 2, 16, 1),
+    (1, 5, 5, 3, 2, 2, 16, 1),
+    (4, 5, 5, 1, 2, 2, 16, 1),
+    (1, 5, 5, 1, 2, 2, 16, 1),
+    (4, 5, 5, 10, 2, 2, 2, 1),
+    (1, 5, 5, 10, 2, 2, 2, 1),
+    (4, 5, 5, 4, 2, 2, 2, 1),
+    (2, 5, 5, 4, 2, 2, 2, 1),
+    (1, 5, 5, 4, 2, 2, 2, 1),
+    (4, 5, 5, 3, 2, 2, 2, 1),
+    (4, 5, 5, 1, 2, 2, 2, 1),
+    (2, 5, 5, 1, 2, 2, 2, 1),
+    (1, 5, 5, 1, 2, 2, 2, 1),
+    (4, 5, 5, 10, 2, 2, 1, 1),
+    (2, 5, 5, 10, 2, 2, 1, 1),
+    (4, 5, 5, 4, 2, 2, 1, 1),
+    (2, 5, 5, 4, 2, 2, 1, 1),
+    (1, 5, 5, 4, 2, 2, 1, 1),
+    (4, 5, 5, 3, 2, 2, 1, 1),
+    (2, 5, 5, 3, 2, 2, 1, 1),
+    (1, 5, 5, 3, 2, 2, 1, 1),
+    (2, 5, 5, 1, 2, 2, 1, 1),
+    (1, 5, 5, 1, 2, 2, 1, 1),
+    (1, 5, 5, 1, 3, 3, 1, 1),
+]
+
+perf_test = lambda *a: pytest.param(*a, marks=pytest.mark.perf_only)
+
+_igemm_cases += [
+    perf_test(2, 128, 128, 16, 3, 3, 320, 1),
+    perf_test(2, 128, 128, 320, 1, 1, 640, 1),
+    perf_test(2, 128, 128, 320, 1, 1, 960, 1),
+    perf_test(2, 128, 128, 320, 3, 3, 16, 1),
+    perf_test(2, 128, 128, 320, 3, 3, 320, 1),
+    perf_test(2, 32, 32, 1280, 1, 1, 1920, 1),
+    perf_test(2, 32, 32, 1280, 1, 1, 2560, 1),
+    perf_test(2, 32, 32, 1280, 1, 1, 640, 1),
+    perf_test(2, 32, 32, 1280, 3, 3, 1280, 1),
+    perf_test(2, 32, 32, 1280, 3, 3, 1920, 1),
+    perf_test(2, 32, 32, 1280, 3, 3, 2560, 1),
+    perf_test(2, 32, 32, 1280, 3, 3, 640, 1),
+    perf_test(2, 32, 32, 640, 3, 3, 640, 1),
+    perf_test(2, 64, 64, 320, 3, 3, 320, 1),
+    perf_test(2, 64, 64, 640, 1, 1, 1280, 1),
+    perf_test(2, 64, 64, 640, 1, 1, 1920, 1),
+    perf_test(2, 64, 64, 640, 1, 1, 320, 1),
+    perf_test(2, 64, 64, 640, 1, 1, 960, 1),
+    perf_test(2, 64, 64, 640, 3, 3, 320, 1),
+    perf_test(2, 64, 64, 640, 3, 3, 640, 1),
+]
+
+
 @require_e2e
-@pytest.mark.parametrize("n", [1, 2, 4])
-@pytest.mark.parametrize("c", [1, 3, 4, 10])
-@pytest.mark.parametrize("nf", [1, 2, 16])
-@pytest.mark.parametrize("stride", [1, 2, 3])
+@pytest.mark.parametrize("n, h, w, c, hf, wf, nf, stride", _igemm_cases)
 @pytest.mark.parametrize("mem_space", [GLOBAL_ADDRESS_SPACE, SHARED_ADDRESS_SPACE])
-def test_igemm_conv(n, c, nf, stride, mem_space):
-    h, w = 5, 5  # Image.
-    cf, hf, wf = c, 2, 2  # Filters.
+def test_igemm_conv(n, h, w, c, hf, wf, nf, stride, mem_space):
+    cf = c
     padding = 0  # TODO: only pad=0 is supported for now
 
     torch.manual_seed(1)
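Note (for reference, not part of the patch): each `perf_test(...)` entry above is shorthand for attaching the `perf_only` marker to a parametrize case, e.g. `perf_test(2, 128, 128, 16, 3, 3, 320, 1)` expands to:

pytest.param(2, 128, 128, 16, 3, 3, 320, 1, marks=pytest.mark.perf_only)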