Skip to content

Commit

Permalink
[MetaSchedule][ARM] Enable ARM CPU intrinsic for MetaSchedule (#14209)
Browse files Browse the repository at this point in the history
  • Loading branch information
dsbarinov1 authored Mar 31, 2023
1 parent 7831a79 commit b724c87
Show file tree
Hide file tree
Showing 6 changed files with 291 additions and 35 deletions.
2 changes: 2 additions & 0 deletions include/tvm/meta_schedule/schedule_rule.h
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,8 @@ class ScheduleRule : public runtime::ObjectRef {
TVM_DLL static Array<ScheduleRule, void> DefaultHexagon();
/*! \brief Create default schedule rules for Micro */
TVM_DLL static Array<ScheduleRule, void> DefaultMicro();
/*! \brief Create default schedule rules for ARM CPU (NEON and DOTPROD) */
TVM_DLL static Array<ScheduleRule, void> DefaultARM(const String& type);

TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ScheduleRule, ObjectRef, ScheduleRuleNode);
};
Expand Down
43 changes: 43 additions & 0 deletions include/tvm/runtime/container/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,36 @@ class Array : public ObjectRef {
}
}

template <typename... Args>
static size_t CalcCapacityImpl() {
return 0;
}

template <typename... Args>
static size_t CalcCapacityImpl(Array<T> value, Args... args) {
return value.size() + CalcCapacityImpl(args...);
}

template <typename... Args>
static size_t CalcCapacityImpl(T value, Args... args) {
return 1 + CalcCapacityImpl(args...);
}

template <typename... Args>
static void AgregateImpl(Array<T>& dest) {} // NOLINT(*)

template <typename... Args>
static void AgregateImpl(Array<T>& dest, Array<T> value, Args... args) { // NOLINT(*)
dest.insert(dest.end(), value.begin(), value.end());
AgregateImpl(dest, args...);
}

template <typename... Args>
static void AgregateImpl(Array<T>& dest, T value, Args... args) { // NOLINT(*)
dest.push_back(value);
AgregateImpl(dest, args...);
}

public:
// Array's own methods

Expand Down Expand Up @@ -680,6 +710,19 @@ class Array : public ObjectRef {
/*! \brief specify container node */
using ContainerType = ArrayNode;

/*!
 * \brief Agregate (concatenate) a mixed sequence of T and Array<T> arguments
 *        into a single flat Array<T>, preserving argument order.
 * \note The spelling "Agregate" (sic) is the public API name; kept as-is.
 * \param args sequence of T or Array<T> elements
 * \return Agregated Array<T> containing every element of every argument
 */
template <typename... Args>
static Array<T> Agregate(Args... args) {
  Array<T> result;
  // Pre-compute the exact total element count so the destination buffer is
  // allocated once, avoiding repeated growth during the appends below.
  result.reserve(CalcCapacityImpl(args...));
  AgregateImpl(result, args...);
  return result;
}

private:
/*!
* \brief Implement copy-on-write semantics, and ensures capacity is enough for extra elements.
Expand Down
99 changes: 64 additions & 35 deletions python/tvm/tir/tensor_intrin/arm_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@


@T.prim_func
def dot_product_4x4_i8i8i32_desc(
def neon_4x4_i8i8i32_desc(
A: T.Buffer((4,), "int8", offset_factor=1),
B: T.Buffer((4, 4), "int8", offset_factor=1),
C: T.Buffer((4,), "int32", offset_factor=1),
Expand All @@ -42,7 +42,7 @@ def dot_product_4x4_i8i8i32_desc(


@T.prim_func
def dot_product_4x4_i8i8i32_neon(
def neon_4x4_i8i8i32_impl(
A: T.Buffer((4,), "int8", offset_factor=1),
B: T.Buffer((4, 4), "int8", offset_factor=1),
C: T.Buffer((4,), "int32", offset_factor=1),
Expand Down Expand Up @@ -102,42 +102,71 @@ def dot_product_4x4_i8i8i32_neon(
)


@T.prim_func
def dot_product_4x4_i8i8i32_sdot(
A: T.Buffer((4,), "int8", offset_factor=1),
B: T.Buffer((4, 4), "int8", offset_factor=1),
C: T.Buffer((4,), "int32", offset_factor=1),
) -> None:
with T.block("root"):
T.reads(C[0:4], A[0:4], B[0:4, 0:4])
T.writes(C[0:4])

A_i8x4 = A.vload([0], "int8x4")
A_i32 = T.reinterpret(A_i8x4, dtype="int32")
vec_ai32 = T.broadcast(A_i32, 4)
vec_a = T.reinterpret(vec_ai32, dtype="int8x16")

vec_b = B.vload([0, 0], dtype="int8x16")

vec_c = C.vload([0], dtype="int32x4")

C[T.ramp(T.int32(0), 1, 4)] = T.call_llvm_pure_intrin(
T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.sdot.v4i32.v16i8"),
T.uint32(3),
vec_c,
vec_a,
vec_b,
dtype="int32x4",
)
def get_dotprod_intrin(in_dtype, out_dtype):
    """Build the (description, implementation) TIR intrinsic pair for a 4x4
    dot product lowered to the AArch64 dotprod instructions (sdot/udot).

    Parameters
    ----------
    in_dtype : str
        Element dtype of inputs A and B; "int8" or "uint8".
    out_dtype : str
        Element dtype of the accumulator C (e.g. "int32" or "uint32").

    Returns
    -------
    desc, impl
        ``desc`` expresses the computation C[i] += A[k] * B[i, k] for pattern
        matching; ``impl`` realizes it as one ``llvm.aarch64.neon.sdot``/
        ``udot`` call over 4-lane vectors.
    """
    # Unsigned inputs select udot, signed inputs select sdot; the suffix also
    # encodes the LLVM vector types (4 x 32-bit accumulator, 16 x 8-bit input).
    if in_dtype == "uint8":
        instr = "udot.v4u32.v16u8"
    else:  # if in_dtype == "int8"
        instr = "sdot.v4i32.v16i8"

    # Vector dtype strings derived from the scalar dtypes, e.g. "int8x16".
    in_dtype_x4 = "{TYPE}x4".format(TYPE=in_dtype)
    out_dtype_x4 = "{TYPE}x4".format(TYPE=out_dtype)
    in_dtype_x16 = "{TYPE}x16".format(TYPE=in_dtype)

    @T.prim_func
    def dot_prod_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
        # Semantic description: a plain 4x4 multiply-accumulate reduction.
        A = T.match_buffer(a, (4,), dtype=in_dtype, offset_factor=1)
        B = T.match_buffer(b, (4, 4), dtype=in_dtype, offset_factor=1)
        C = T.match_buffer(c, (4,), dtype=out_dtype, offset_factor=1)
        with T.block("root"):
            T.reads(C[0:4], A[0:4], B[0:4, 0:4])
            T.writes(C[0:4])
            for i in T.serial(0, 4):
                for k in T.serial(0, 4):
                    with T.block("update"):
                        vi, vk = T.axis.remap("SR", [i, k])
                        C[vi] = C[vi] + T.cast(A[vk], dtype=out_dtype) * T.cast(
                            B[vi, vk], dtype=out_dtype
                        )

    @T.prim_func
    def dot_prod_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
        # Hardware implementation via a single dotprod intrinsic call.
        A = T.match_buffer(a, (4,), dtype=in_dtype, offset_factor=1)
        B = T.match_buffer(b, (4, 4), dtype=in_dtype, offset_factor=1)
        C = T.match_buffer(c, (4,), dtype=out_dtype, offset_factor=1)
        with T.block("root"):
            T.reads(C[0:4], A[0:4], B[0:4, 0:4])
            T.writes(C[0:4])

            # NOTE: the `_i8`/`_i32` variable names reflect the signed case;
            # the actual dtypes follow in_dtype/out_dtype.
            # Broadcast the 4 input elements of A across all four lanes: load
            # them as one 32-bit scalar, splat it, and reinterpret back to a
            # 16-element input vector.
            A_i8x4 = A.vload([0], in_dtype_x4)
            A_i32 = T.reinterpret(A_i8x4, dtype=out_dtype)
            vec_ai32 = T.broadcast(A_i32, 4)
            vec_a = T.reinterpret(vec_ai32, dtype=in_dtype_x16)

            vec_b = B.vload([0, 0], dtype=in_dtype_x16)

            vec_c = C.vload([0], dtype=out_dtype_x4)

            # sdot/udot computes, per 32-bit lane, the sum of four 8-bit
            # products accumulated into vec_c.
            C[T.ramp(T.int32(0), 1, 4)] = T.call_llvm_pure_intrin(
                T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.{INSTR}".format(INSTR=instr)),
                T.uint32(3),
                vec_c,
                vec_a,
                vec_b,
                dtype=out_dtype_x4,
            )

    return dot_prod_desc, dot_prod_impl


# Public names of the ARM 4x4 dot-product tensor intrinsics.
ARM_DOT_4x4_i8_NEON_INTRIN = "dot_4x4_i8i8s32_neon"
ARM_DOT_4x4_i8_SDOT_INTRIN = "dot_4x4_i8i8s32_sdot"
ARM_DOT_4x4_u8_UDOT_INTRIN = "dot_4x4_u8u8u32_udot"
ARM_DOT_4x4_u8_HDOT_INTRIN = "dot_4x4_u8u8i32_hdot"

# Plain-NEON fallback intrinsic (no dotprod extension required).
TensorIntrin.register(ARM_DOT_4x4_i8_NEON_INTRIN, neon_4x4_i8i8i32_desc, neon_4x4_i8i8i32_impl)

# Dotprod-extension intrinsics, parameterized over input/accumulator dtypes.
# NOTE: the stale duplicate registrations referencing the removed
# dot_product_4x4_i8i8i32_* functions are dropped — they would raise
# NameError and re-register already-registered intrinsic names.
TensorIntrin.register(ARM_DOT_4x4_i8_SDOT_INTRIN, *get_dotprod_intrin("int8", "int32"))
TensorIntrin.register(ARM_DOT_4x4_u8_UDOT_INTRIN, *get_dotprod_intrin("uint8", "uint32"))
TensorIntrin.register(ARM_DOT_4x4_u8_HDOT_INTRIN, *get_dotprod_intrin("uint8", "int32"))
90 changes: 90 additions & 0 deletions src/meta_schedule/schedule_rule/schedule_rule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,94 @@ Array<ScheduleRule> ScheduleRule::DefaultMicro() {
};
}

/*! \brief Tiling rule targeting the plain-NEON 4x4 i8 dot-product intrinsic. */
Array<ScheduleRule> GetNeonSpecificRules() {
  // Cache-write configuration: writing through a cache stage is optional
  // ("may"), allowed at tiling levels 1 and 2, in global memory scope.
  Map<String, ObjectRef> reuse_write{{"req", String("may")},
                                     {"levels", Array<Integer>{1, 2}},
                                     {"scope", String("global")}};
  ScheduleRule neon_tiling = ScheduleRule::MultiLevelTilingWithIntrin(
      /*intrin_name=*/String("dot_4x4_i8i8s32_neon"),
      /*structure=*/"SSRSRS",
      /*tile_binds=*/NullOpt,
      /*max_innermost_factor=*/Integer(32),
      /*vector_load_lens=*/NullOpt,
      /*reuse_read=*/NullOpt,
      /*reuse_write=*/reuse_write);
  return {neon_tiling};
}

/*! \brief Tiling rules targeting the dotprod-extension tensor intrinsics. */
Array<ScheduleRule> GetDotprodSpecificRules() {
  // Cache-write configuration shared by all three rules: optional ("may"),
  // at tiling levels 1 and 2, in global memory scope.
  Map<String, ObjectRef> reuse_write{{"req", String("may")},
                                     {"levels", Array<Integer>{1, 2}},
                                     {"scope", String("global")}};
  // One identically-configured tiling rule per registered dotprod intrinsic.
  Array<ScheduleRule> rules;
  for (const char* intrin_name :
       {"dot_4x4_i8i8s32_sdot", "dot_4x4_u8u8u32_udot", "dot_4x4_u8u8i32_hdot"}) {
    rules.push_back(ScheduleRule::MultiLevelTilingWithIntrin(
        /*intrin_name=*/String(intrin_name),
        /*structure=*/"SSRSRS",
        /*tile_binds=*/NullOpt,
        /*max_innermost_factor=*/Integer(32),
        /*vector_load_lens=*/NullOpt,
        /*reuse_read=*/NullOpt,
        /*reuse_write=*/reuse_write));
  }
  return rules;
}

Array<ScheduleRule> ScheduleRule::DefaultARM(const String& type) {
  // Select the intrinsic-specific tiling rules matching the requested ISA
  // feature set; any other `type` contributes no extra rules.
  Array<ScheduleRule> intrin_rules;
  if (type == "neon") {
    intrin_rules = GetNeonSpecificRules();
  } else if (type == "dotprod") {
    intrin_rules = GetDotprodSpecificRules();
  }
  return Array<ScheduleRule>::Agregate(
      ScheduleRule::ApplyCustomRule(), ScheduleRule::InlineConstantScalars(),
      ScheduleRule::AutoInline(
          /*into_producer=*/false,
          /*into_consumer=*/true,
          /*inline_const_tensor=*/true,
          /*disallow_if_then_else=*/true,
          /*require_injective=*/true,
          /*require_ordered=*/true,
          /*disallow_op=*/Array<String>{"tir.exp"}),
      ScheduleRule::AddRFactor(
          /*max_jobs_per_core=*/8,
          /*max_innermost_factor=*/Integer(32)),
      intrin_rules,
      ScheduleRule::MultiLevelTiling(
          /*structure=*/"SSRSRS",
          /*tile_binds=*/NullOpt,
          /*max_innermost_factor=*/Integer(32),
          /*vector_load_lens=*/NullOpt,
          /*reuse_read=*/NullOpt,
          /*reuse_write=*/
          Map<String, ObjectRef>{{"req", String("may")},
                                 {"levels", Array<Integer>{1, 2}},
                                 {"scope", String("global")}}),
      ScheduleRule::ParallelizeVectorizeUnroll(
          /*max_jobs_per_core=*/8,
          /*max_vectorize_extent=*/32,
          /*unroll_max_steps=*/Array<Integer>{0, 8, 32, 256},
          /*unroll_explicit=*/true),
      ScheduleRule::RandomComputeLocation());
}

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PyScheduleRuleNode>([](const ObjectRef& n, ReprPrinter* p) {
const auto* self = n.as<PyScheduleRuleNode>();
Expand Down Expand Up @@ -325,6 +413,8 @@ TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleDefaultHexagon")
.set_body_typed(ScheduleRule::DefaultHexagon);
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleDefaultMicro")
.set_body_typed(ScheduleRule::DefaultMicro);
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleDefaultARM")
.set_body_typed(ScheduleRule::DefaultARM);

} // namespace meta_schedule
} // namespace tvm
19 changes: 19 additions & 0 deletions src/meta_schedule/space_generator/space_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
* specific language governing permissions and limitations
* under the License.
*/
#include "../../target/parsers/aprofile.h"
#include "../utils.h"

namespace tvm {
Expand All @@ -38,6 +39,16 @@ String GetRuleKindFromTarget(const Target& target) {
return "avx512";
}
}

TargetJSON target_json = target::parsers::aprofile::ParseTarget(target->Export());
TargetFeatures afeatures = Downcast<TargetFeatures>(target_json.at("features"));

if (Downcast<Bool>(afeatures.at("has_dotprod"))) {
return "dotprod";
}
if (Downcast<Bool>(afeatures.at("has_asimd"))) {
return "asimd";
}
return "llvm";
}
if (target->kind->name == "hexagon") {
Expand Down Expand Up @@ -110,6 +121,14 @@ void SpaceGeneratorNode::InitializeWithTuneContext(const TuneContext& context) {
default_sch_rules = ScheduleRule::DefaultMicro();
default_postprocs = Postproc::DefaultMicro();
default_mutator_probs = Mutator::DefaultMicro();
} else if (kind == "asimd") {
default_sch_rules = ScheduleRule::DefaultARM("neon");
default_postprocs = Postproc::DefaultCPUTensorization();
default_mutator_probs = Mutator::DefaultLLVM();
} else if (kind == "dotprod") {
default_sch_rules = ScheduleRule::DefaultARM("dotprod");
default_postprocs = Postproc::DefaultCPUTensorization();
default_mutator_probs = Mutator::DefaultLLVM();
} else {
LOG(FATAL) << "Unsupported kind: " << kind;
throw;
Expand Down
Loading

0 comments on commit b724c87

Please sign in to comment.