diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h
index 7995d1fceeb6..d91812fb55cb 100644
--- a/include/tvm/meta_schedule/schedule_rule.h
+++ b/include/tvm/meta_schedule/schedule_rule.h
@@ -300,6 +300,8 @@ class ScheduleRule : public runtime::ObjectRef {
   TVM_DLL static Array<ScheduleRule> DefaultHexagon();
   /*! \brief Create default schedule rules for Micro */
   TVM_DLL static Array<ScheduleRule> DefaultMicro();
+  /*! \brief Create default schedule rules for ARM CPU (NEON and DOTPROD) */
+  TVM_DLL static Array<ScheduleRule> DefaultARM(const String& type);
 
   TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ScheduleRule, ObjectRef, ScheduleRuleNode);
 };
diff --git a/include/tvm/runtime/container/array.h b/include/tvm/runtime/container/array.h
index d1d5422a436f..ff0bd03ab9cb 100644
--- a/include/tvm/runtime/container/array.h
+++ b/include/tvm/runtime/container/array.h
@@ -580,6 +580,36 @@ class Array : public ObjectRef {
     }
   }
 
+  template <typename... Args>
+  static size_t CalcCapacityImpl() {
+    return 0;
+  }
+
+  template <typename... Args>
+  static size_t CalcCapacityImpl(Array<T> value, Args... args) {
+    return value.size() + CalcCapacityImpl(args...);
+  }
+
+  template <typename... Args>
+  static size_t CalcCapacityImpl(T value, Args... args) {
+    return 1 + CalcCapacityImpl(args...);
+  }
+
+  template <typename... Args>
+  static void AggregateImpl(Array<T>& dest) {}  // NOLINT(*)
+
+  template <typename... Args>
+  static void AggregateImpl(Array<T>& dest, Array<T> value, Args... args) {  // NOLINT(*)
+    dest.insert(dest.end(), value.begin(), value.end());
+    AggregateImpl(dest, args...);
+  }
+
+  template <typename... Args>
+  static void AggregateImpl(Array<T>& dest, T value, Args... args) {  // NOLINT(*)
+    dest.push_back(value);
+    AggregateImpl(dest, args...);
+  }
+
  public:
   // Array's own methods
 
@@ -680,6 +710,19 @@ class Array : public ObjectRef {
   /*! \brief specify container node */
   using ContainerType = ArrayNode;
 
+  /*!
+   * \brief Aggregate arguments into a single Array
+   * \param args sequence of T or Array<T> elements
+   * \return Aggregated Array
+   */
+  template <typename... Args>
+  static Array<T> Aggregate(Args... args) {
+    Array<T> result;
+    result.reserve(CalcCapacityImpl(args...));
+    AggregateImpl(result, args...);
+    return result;
+  }
+
  private:
   /*!
    * \brief Implement copy-on-write semantics, and ensures capacity is enough for extra elements.
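Note on the new helper: Array<T>::Aggregate accepts any mix of T values and Array<T> values and flattens them, in argument order, into a single Array<T>; ScheduleRule::DefaultARM below uses this to splice in an ISA-specific rule subset that may be empty. A minimal usage sketch (the rule helpers referenced here are the ones added later in this patch):

    // Single elements and whole arrays may be interleaved freely.
    Array<ScheduleRule> rules = Array<ScheduleRule>::Aggregate(
        ScheduleRule::ApplyCustomRule(),        // appended as one element
        GetNeonSpecificRules(),                 // Array<ScheduleRule>, spliced element-wise
        ScheduleRule::RandomComputeLocation()); // appended as one element
    // rules.size() == 2 + GetNeonSpecificRules().size()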
diff --git a/python/tvm/tir/tensor_intrin/arm_cpu.py b/python/tvm/tir/tensor_intrin/arm_cpu.py
index 9357f0ceb28a..521d882e24eb 100644
--- a/python/tvm/tir/tensor_intrin/arm_cpu.py
+++ b/python/tvm/tir/tensor_intrin/arm_cpu.py
@@ -26,7 +26,7 @@
 
 
 @T.prim_func
-def dot_product_4x4_i8i8i32_desc(
+def neon_4x4_i8i8i32_desc(
     A: T.Buffer((4,), "int8", offset_factor=1),
     B: T.Buffer((4, 4), "int8", offset_factor=1),
     C: T.Buffer((4,), "int32", offset_factor=1),
@@ -42,7 +42,7 @@ def dot_product_4x4_i8i8i32_desc(
 
 
 @T.prim_func
-def dot_product_4x4_i8i8i32_neon(
+def neon_4x4_i8i8i32_impl(
     A: T.Buffer((4,), "int8", offset_factor=1),
     B: T.Buffer((4, 4), "int8", offset_factor=1),
     C: T.Buffer((4,), "int32", offset_factor=1),
@@ -102,42 +102,71 @@ def dot_product_4x4_i8i8i32_neon(
     )
 
 
-@T.prim_func
-def dot_product_4x4_i8i8i32_sdot(
-    A: T.Buffer((4,), "int8", offset_factor=1),
-    B: T.Buffer((4, 4), "int8", offset_factor=1),
-    C: T.Buffer((4,), "int32", offset_factor=1),
-) -> None:
-    with T.block("root"):
-        T.reads(C[0:4], A[0:4], B[0:4, 0:4])
-        T.writes(C[0:4])
-
-        A_i8x4 = A.vload([0], "int8x4")
-        A_i32 = T.reinterpret(A_i8x4, dtype="int32")
-        vec_ai32 = T.broadcast(A_i32, 4)
-        vec_a = T.reinterpret(vec_ai32, dtype="int8x16")
-
-        vec_b = B.vload([0, 0], dtype="int8x16")
-
-        vec_c = C.vload([0], dtype="int32x4")
-
-        C[T.ramp(T.int32(0), 1, 4)] = T.call_llvm_pure_intrin(
-            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.sdot.v4i32.v16i8"),
-            T.uint32(3),
-            vec_c,
-            vec_a,
-            vec_b,
-            dtype="int32x4",
-        )
+def get_dotprod_intrin(in_dtype, out_dtype):
+    if in_dtype == "uint8":
+        instr = "udot.v4u32.v16u8"
+    else:  # in_dtype == "int8"
+        instr = "sdot.v4i32.v16i8"
+
+    in_dtype_x4 = "{TYPE}x4".format(TYPE=in_dtype)
+    out_dtype_x4 = "{TYPE}x4".format(TYPE=out_dtype)
+    in_dtype_x16 = "{TYPE}x16".format(TYPE=in_dtype)
+
+    @T.prim_func
+    def dot_prod_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(a, (4,), dtype=in_dtype, offset_factor=1)
+        B = T.match_buffer(b, (4, 4), dtype=in_dtype, offset_factor=1)
+        C = T.match_buffer(c, (4,), dtype=out_dtype, offset_factor=1)
+        with T.block("root"):
+            T.reads(C[0:4], A[0:4], B[0:4, 0:4])
+            T.writes(C[0:4])
+            for i in T.serial(0, 4):
+                for k in T.serial(0, 4):
+                    with T.block("update"):
+                        vi, vk = T.axis.remap("SR", [i, k])
+                        C[vi] = C[vi] + T.cast(A[vk], dtype=out_dtype) * T.cast(
+                            B[vi, vk], dtype=out_dtype
+                        )
+
+    @T.prim_func
+    def dot_prod_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(a, (4,), dtype=in_dtype, offset_factor=1)
+        B = T.match_buffer(b, (4, 4), dtype=in_dtype, offset_factor=1)
+        C = T.match_buffer(c, (4,), dtype=out_dtype, offset_factor=1)
+        with T.block("root"):
+            T.reads(C[0:4], A[0:4], B[0:4, 0:4])
+            T.writes(C[0:4])
+
+            A_i8x4 = A.vload([0], in_dtype_x4)
+            A_i32 = T.reinterpret(A_i8x4, dtype=out_dtype)
+            vec_ai32 = T.broadcast(A_i32, 4)
+            vec_a = T.reinterpret(vec_ai32, dtype=in_dtype_x16)
+
+            vec_b = B.vload([0, 0], dtype=in_dtype_x16)
+
+            vec_c = C.vload([0], dtype=out_dtype_x4)
+
+            C[T.ramp(T.int32(0), 1, 4)] = T.call_llvm_pure_intrin(
+                T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.{INSTR}".format(INSTR=instr)),
+                T.uint32(3),
+                vec_c,
+                vec_a,
+                vec_b,
+                dtype=out_dtype_x4,
+            )
+
+    return dot_prod_desc, dot_prod_impl
 
 
 ARM_DOT_4x4_i8_NEON_INTRIN = "dot_4x4_i8i8s32_neon"
 ARM_DOT_4x4_i8_SDOT_INTRIN = "dot_4x4_i8i8s32_sdot"
+ARM_DOT_4x4_u8_UDOT_INTRIN = "dot_4x4_u8u8u32_udot"
+ARM_DOT_4x4_u8_HDOT_INTRIN = "dot_4x4_u8u8i32_hdot"
+
+TensorIntrin.register(ARM_DOT_4x4_i8_NEON_INTRIN, neon_4x4_i8i8i32_desc, neon_4x4_i8i8i32_impl)
+
+TensorIntrin.register(ARM_DOT_4x4_i8_SDOT_INTRIN, *get_dotprod_intrin("int8", "int32"))
 
-TensorIntrin.register(
-    ARM_DOT_4x4_i8_NEON_INTRIN, dot_product_4x4_i8i8i32_desc, dot_product_4x4_i8i8i32_neon
-)
+TensorIntrin.register(ARM_DOT_4x4_u8_UDOT_INTRIN, *get_dotprod_intrin("uint8", "uint32"))
 
-TensorIntrin.register(
-    ARM_DOT_4x4_i8_SDOT_INTRIN, dot_product_4x4_i8i8i32_desc, dot_product_4x4_i8i8i32_sdot
-)
+TensorIntrin.register(ARM_DOT_4x4_u8_HDOT_INTRIN, *get_dotprod_intrin("uint8", "int32"))
diff --git a/src/meta_schedule/schedule_rule/schedule_rule.cc b/src/meta_schedule/schedule_rule/schedule_rule.cc
index 49a7c9911c01..35f1151c9c1d 100644
--- a/src/meta_schedule/schedule_rule/schedule_rule.cc
+++ b/src/meta_schedule/schedule_rule/schedule_rule.cc
@@ -295,6 +295,94 @@ Array<ScheduleRule> ScheduleRule::DefaultMicro() {
   };
 }
 
+Array<ScheduleRule> GetNeonSpecificRules() {
+  return {
+      ScheduleRule::MultiLevelTilingWithIntrin(
+          /*intrin_name=*/String("dot_4x4_i8i8s32_neon"),
+          /*structure=*/"SSRSRS",
+          /*tile_binds=*/NullOpt,
+          /*max_innermost_factor=*/Integer(32),
+          /*vector_load_lens=*/NullOpt,
+          /*reuse_read=*/NullOpt,
+          /*reuse_write=*/
+          Map<String, ObjectRef>{{"req", String("may")},
+                                 {"levels", Array<Integer>{1, 2}},
+                                 {"scope", String("global")}}),
+  };
+}
+
+Array<ScheduleRule> GetDotprodSpecificRules() {
+  return {
+      ScheduleRule::MultiLevelTilingWithIntrin(
+          /*intrin_name=*/String("dot_4x4_i8i8s32_sdot"),
+          /*structure=*/"SSRSRS",
+          /*tile_binds=*/NullOpt,
+          /*max_innermost_factor=*/Integer(32),
+          /*vector_load_lens=*/NullOpt,
+          /*reuse_read=*/NullOpt,
+          /*reuse_write=*/
+          Map<String, ObjectRef>{{"req", String("may")},
+                                 {"levels", Array<Integer>{1, 2}},
+                                 {"scope", String("global")}}),
+      ScheduleRule::MultiLevelTilingWithIntrin(
+          /*intrin_name=*/String("dot_4x4_u8u8u32_udot"),
+          /*structure=*/"SSRSRS",
+          /*tile_binds=*/NullOpt,
+          /*max_innermost_factor=*/Integer(32),
+          /*vector_load_lens=*/NullOpt,
+          /*reuse_read=*/NullOpt,
+          /*reuse_write=*/
+          Map<String, ObjectRef>{{"req", String("may")},
+                                 {"levels", Array<Integer>{1, 2}},
+                                 {"scope", String("global")}}),
+      ScheduleRule::MultiLevelTilingWithIntrin(
+          /*intrin_name=*/String("dot_4x4_u8u8i32_hdot"),
+          /*structure=*/"SSRSRS",
+          /*tile_binds=*/NullOpt,
+          /*max_innermost_factor=*/Integer(32),
+          /*vector_load_lens=*/NullOpt,
+          /*reuse_read=*/NullOpt,
+          /*reuse_write=*/
+          Map<String, ObjectRef>{{"req", String("may")},
+                                 {"levels", Array<Integer>{1, 2}},
+                                 {"scope", String("global")}}),
+  };
+}
+
+Array<ScheduleRule> ScheduleRule::DefaultARM(const String& type) {
+  return Array<ScheduleRule>::Aggregate(
+      ScheduleRule::ApplyCustomRule(), ScheduleRule::InlineConstantScalars(),
+      ScheduleRule::AutoInline(
+          /*into_producer=*/false,
+          /*into_consumer=*/true,
+          /*inline_const_tensor=*/true,
+          /*disallow_if_then_else=*/true,
+          /*require_injective=*/true,
+          /*require_ordered=*/true,
+          /*disallow_op=*/Array<String>{"tir.exp"}),
+      ScheduleRule::AddRFactor(
+          /*max_jobs_per_core=*/8,
+          /*max_innermost_factor=*/Integer(32)),
+      type == "neon" ? GetNeonSpecificRules() : Array<ScheduleRule>{},
+      type == "dotprod" ? GetDotprodSpecificRules() : Array<ScheduleRule>{},
+      ScheduleRule::MultiLevelTiling(
+          /*structure=*/"SSRSRS",
+          /*tile_binds=*/NullOpt,
+          /*max_innermost_factor=*/Integer(32),
+          /*vector_load_lens=*/NullOpt,
+          /*reuse_read=*/NullOpt,
+          /*reuse_write=*/
+          Map<String, ObjectRef>{{"req", String("may")},
+                                 {"levels", Array<Integer>{1, 2}},
+                                 {"scope", String("global")}}),
+      ScheduleRule::ParallelizeVectorizeUnroll(
+          /*max_jobs_per_core=*/8,
+          /*max_vectorize_extent=*/32,
+          /*unroll_max_steps=*/Array<Integer>{0, 8, 32, 256},
+          /*unroll_explicit=*/true),
+      ScheduleRule::RandomComputeLocation());
+}
+
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch<ScheduleRuleNode>([](const ObjectRef& n, ReprPrinter* p) {
       const auto* self = n.as<ScheduleRuleNode>();
@@ -325,6 +413,8 @@ TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleDefaultHexagon")
     .set_body_typed(ScheduleRule::DefaultHexagon);
 TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleDefaultMicro")
     .set_body_typed(ScheduleRule::DefaultMicro);
+TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleDefaultARM")
+    .set_body_typed(ScheduleRule::DefaultARM);
 
 }  // namespace meta_schedule
 }  // namespace tvm
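The three dotprod rules above differ only in the intrinsic name. A possible follow-up refactor, sketch only (TilingRuleForIntrin is a hypothetical helper that is not part of this patch), would factor out the shared configuration:

    // Hypothetical helper: the shared tiling rule, parameterized by intrinsic name.
    ScheduleRule TilingRuleForIntrin(const String& intrin_name) {
      return ScheduleRule::MultiLevelTilingWithIntrin(
          /*intrin_name=*/intrin_name,
          /*structure=*/"SSRSRS",
          /*tile_binds=*/NullOpt,
          /*max_innermost_factor=*/Integer(32),
          /*vector_load_lens=*/NullOpt,
          /*reuse_read=*/NullOpt,
          /*reuse_write=*/
          Map<String, ObjectRef>{{"req", String("may")},
                                 {"levels", Array<Integer>{1, 2}},
                                 {"scope", String("global")}});
    }

    Array<ScheduleRule> GetDotprodSpecificRules() {
      return {TilingRuleForIntrin("dot_4x4_i8i8s32_sdot"),
              TilingRuleForIntrin("dot_4x4_u8u8u32_udot"),
              TilingRuleForIntrin("dot_4x4_u8u8i32_hdot")};
    }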
diff --git a/src/meta_schedule/space_generator/space_generator.cc b/src/meta_schedule/space_generator/space_generator.cc
index a3669e996f40..c96554e6a2d6 100644
--- a/src/meta_schedule/space_generator/space_generator.cc
+++ b/src/meta_schedule/space_generator/space_generator.cc
@@ -16,6 +16,7 @@
  * specific language governing permissions and limitations
  * under the License.
  */
+#include "../../target/parsers/aprofile.h"
 #include "../utils.h"
 
 namespace tvm {
@@ -38,6 +39,16 @@ String GetRuleKindFromTarget(const Target& target) {
         return "avx512";
       }
     }
+
+    TargetJSON target_json = target::parsers::aprofile::ParseTarget(target->Export());
+    TargetFeatures afeatures = Downcast<TargetFeatures>(target_json.at("features"));
+
+    if (Downcast<Bool>(afeatures.at("has_dotprod"))) {
+      return "dotprod";
+    }
+    if (Downcast<Bool>(afeatures.at("has_asimd"))) {
+      return "asimd";
+    }
     return "llvm";
   }
   if (target->kind->name == "hexagon") {
@@ -110,6 +121,14 @@ void SpaceGeneratorNode::InitializeWithTuneContext(const TuneContext& context) {
     default_sch_rules = ScheduleRule::DefaultMicro();
     default_postprocs = Postproc::DefaultMicro();
     default_mutator_probs = Mutator::DefaultMicro();
+  } else if (kind == "asimd") {
+    default_sch_rules = ScheduleRule::DefaultARM("neon");
+    default_postprocs = Postproc::DefaultCPUTensorization();
+    default_mutator_probs = Mutator::DefaultLLVM();
+  } else if (kind == "dotprod") {
+    default_sch_rules = ScheduleRule::DefaultARM("dotprod");
+    default_postprocs = Postproc::DefaultCPUTensorization();
+    default_mutator_probs = Mutator::DefaultLLVM();
   } else {
     LOG(FATAL) << "Unsupported kind: " << kind;
     throw;
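To make the dispatch concrete: with the parser hook above, the -mattr flags of an aarch64 target decide which default rule set InitializeWithTuneContext installs. Illustration only (target strings copied from the test below; the mapping assumes the aprofile parser derives has_asimd/has_dotprod from these flags):

    Target neon("llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon");
    Target dotprod("llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon,+v8.2a,+dotprod");
    // GetRuleKindFromTarget(neon)    == "asimd"   -> ScheduleRule::DefaultARM("neon")
    // GetRuleKindFromTarget(dotprod) == "dotprod" -> ScheduleRule::DefaultARM("dotprod")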
diff --git a/tests/python/unittest/test_meta_schedule_post_order_apply.py b/tests/python/unittest/test_meta_schedule_post_order_apply.py
index 716f829653f3..6c069dc6bf0a 100644
--- a/tests/python/unittest/test_meta_schedule_post_order_apply.py
+++ b/tests/python/unittest/test_meta_schedule_post_order_apply.py
@@ -23,6 +23,8 @@
 import pytest
 import tvm
 import tvm.testing
+from tvm import te
+from tvm.ir.module import IRModule
 from tvm._ffi import register_func
 from tvm.error import TVMError
 from tvm.meta_schedule import TuneContext
@@ -36,6 +38,23 @@
 # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,
 # fmt: off
 
+
+def get_matmul_packed(m, n, k, lhs_dtype="int8", rhs_dtype="int8", acc_dtype="int32"):
+    X = te.placeholder((m, k), name="X", dtype=lhs_dtype)
+    W = te.placeholder((n, k), name="W", dtype=rhs_dtype)
+
+    ak = te.reduce_axis((0, k), name="k")
+    matmul = te.compute(
+        (m, n),
+        lambda i, j: te.sum(
+            X[i, ak].astype(acc_dtype) * W[j, ak].astype(acc_dtype),
+            axis=ak,
+        ),
+        name="compute",
+    )
+    return te.create_prim_func([X, W, matmul])
+
+
 @tvm.script.ir_module
 class Matmul:
     @T.prim_func
@@ -404,6 +423,60 @@ def _get_sch(filter_fn):
     assert len(schs) == 8
 
 
+@pytest.mark.parametrize(
+    "target,mod,expected_intr",
+    [
+        (
+            Target("llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon -num-cores 2"),
+            IRModule({"main": get_matmul_packed(128, 128, 128, "int8", "int8", "int32")}),
+            "dot_4x4_i8i8s32_neon",
+        ),
+        (
+            Target(
+                "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon,+v8.2a,+dotprod -num-cores 2"
+            ),
+            IRModule({"main": get_matmul_packed(128, 128, 128, "int8", "int8", "int32")}),
+            "dot_4x4_i8i8s32_sdot",
+        ),
+        (
+            Target(
+                "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon,+v8.2a,+dotprod -num-cores 2"
+            ),
+            IRModule({"main": get_matmul_packed(128, 128, 128, "uint8", "uint8", "uint32")}),
+            "dot_4x4_u8u8u32_udot",
+        ),
+        (
+            Target(
+                "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon,+v8.2a,+dotprod -num-cores 2"
+            ),
+            IRModule({"main": get_matmul_packed(128, 128, 128, "uint8", "uint8", "int32")}),
+            "dot_4x4_u8u8i32_hdot",
+        ),
+    ],
+)
+def test_meta_schedule_post_order_apply_arm_intrin(target, mod, expected_intr):
+    context = TuneContext(
+        mod=mod,
+        target=target,
+        task_name="Arm Intrinsic Task",
+        space_generator=PostOrderApply(),  # no rules given: falls back to the target's defaults
+        rand_state=1,  # fixed seed so the generated design space is deterministic
+    )
+    post_order_apply = context.space_generator
+    schs = post_order_apply.generate_design_space(mod)
+
+    assert len(schs) != 0
+
+    for sch in schs:
+        sch.enter_postproc()
+
+        for proc in context.space_generator.postprocs:
+            proc.apply(sch)
+
+    assert any("call_llvm_pure_intrin" in sch.mod.script() for sch in schs)
+    assert any(expected_intr in str(sch.trace) for sch in schs)
+
+
 def test_meta_schedule_derived_object():
     @derived_object
     class RemoveBlock(PyScheduleRule):