From 2211e05144ad8885e5dde30c225f9c6a6fc700a1 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Tue, 7 Apr 2020 17:58:39 -0500 Subject: [PATCH 1/3] [LLVM] Fix generation of LLVM intrinsics The type list in the call to llvm::Intrinsic::getDeclaration is not the intrinsic's signature, it's the list of overloaded types. Without this fix, the updated unit test would cause the following error: TVMError: LLVM module verification failed with the following errors: Intrinsic name not mangled correctly for type arguments! Should be: llvm.ctlz.i32 i32 (i32, i1)* @llvm.ctlz.i32.i1 Special handling for llvm.prefetch, signature matching for overloaded intrinsics only The prefetch intrinsic returns void in LLVM, while it returns i32 in TVM. This case needs to be handled specially, because rule-based intrinsic translation would cause an invalid LLVM type to be created. Do the signature matching only for overloaded intrinsics. It's not needed for non-overloaded ones, so this can save a bit of compile-time. --- src/target/llvm/codegen_llvm.cc | 91 +++++++++++++++++-- src/target/llvm/codegen_llvm.h | 15 +++ src/target/llvm/intrin_rule_llvm.cc | 2 +- .../unittest/test_target_codegen_llvm.py | 29 +++++- topi/python/topi/arm_cpu/bitserial_conv2d.py | 9 +- 5 files changed, 130 insertions(+), 16 deletions(-) diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index bb0b7e46baf8..69312821645d 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -684,6 +684,74 @@ llvm::Value* CodeGenLLVM::CreateCallExtern(const CallNode* op) { return call; } +llvm::Function* CodeGenLLVM::GetIntrinsicDecl( + llvm::Intrinsic::ID id, llvm::Type* ret_type, + llvm::ArrayRef arg_types) { + llvm::Module* module = module_.get(); + + if (!llvm::Intrinsic::isOverloaded(id)) { + return llvm::Intrinsic::getDeclaration(module, id, {}); + } + + llvm::SmallVector infos; + llvm::Intrinsic::getIntrinsicInfoTableEntries(id, infos); + llvm::SmallVector overload_types; + +#if
TVM_LLVM_VERSION >= 90 + auto try_match = [&](llvm::FunctionType* f_ty, bool var_arg) { + overload_types.clear(); + llvm::ArrayRef ref(infos); + auto match = + llvm::Intrinsic::matchIntrinsicSignature(f_ty, ref, overload_types); + if (match == llvm::Intrinsic::MatchIntrinsicTypes_Match) { + bool error = llvm::Intrinsic::matchIntrinsicVarArg(var_arg, ref); + if (error) { + return llvm::Intrinsic::MatchIntrinsicTypes_NoMatchArg; + } + } + return match; + }; + + // First, try matching the signature assuming non-vararg case. + auto* fn_ty = llvm::FunctionType::get(ret_type, arg_types, false); + switch (try_match(fn_ty, false)) { + case llvm::Intrinsic::MatchIntrinsicTypes_NoMatchRet: + // The return type doesn't match, there is nothing else to do. + return nullptr; + case llvm::Intrinsic::MatchIntrinsicTypes_Match: + return llvm::Intrinsic::getDeclaration(module, id, overload_types); + case llvm::Intrinsic::MatchIntrinsicTypes_NoMatchArg: + break; + } + + // Keep adding one type at a time (starting from empty list), and + // try matching the vararg signature. + llvm::SmallVector var_types; + for (int i = 0, e = arg_types.size(); i <= e; ++i) { + if (i > 0) var_types.push_back(arg_types[i - 1]); + auto* ft = llvm::FunctionType::get(ret_type, var_types, true); + if (try_match(ft, true) == llvm::Intrinsic::MatchIntrinsicTypes_Match) { + return llvm::Intrinsic::getDeclaration(module, id, overload_types); + } + } + // Failed to identify the type. + return nullptr; + +#else // TVM_LLVM_VERSION + llvm::ArrayRef ref(infos); + // matchIntrinsicType returns true on error. 
+ if (llvm::Intrinsic::matchIntrinsicType(ret_type, ref, overload_types)) { + return nullptr; + } + for (llvm::Type* t : arg_types) { + if (llvm::Intrinsic::matchIntrinsicType(t, ref, overload_types)) { + return nullptr; + } + } + return llvm::Intrinsic::getDeclaration(module, id, overload_types); +#endif // TVM_LLVM_VERSION +} + llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { if (op->is_intrinsic("llvm_intrin")) { CHECK_GE(op->args.size(), 2U); @@ -691,19 +759,26 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { Downcast(op->args[0])->value); int64_t num_signature = Downcast(op->args[1])->value; std::vector arg_value; - std::vector sig_type; + std::vector arg_type; for (size_t i = 2; i < op->args.size(); ++i) { arg_value.push_back(MakeValue(op->args[i])); if (i - 2 < static_cast(num_signature)) { - sig_type.push_back(arg_value.back()->getType()); + arg_type.push_back(arg_value.back()->getType()); } } - llvm::Type *return_type = GetLLVMType(GetRef(op)); - if (sig_type.size() > 0 && return_type != sig_type[0]) { - sig_type.insert(sig_type.begin(), return_type); - } - llvm::Function* f = llvm::Intrinsic::getDeclaration( - module_.get(), id, sig_type); + // LLVM's prefetch intrinsic returns "void", while TVM's prefetch + // returns int32. This causes problems because prefetch is one of + // those intrinsics that is generated automatically via the + // tvm.intrin.rule mechanism. Any other intrinsic with a type + // mismatch will have to be treated specially here. + // TODO(kparzysz-quic): fix this once TVM prefetch uses the same + // type as LLVM. + llvm::Type *return_type = (id != llvm::Intrinsic::prefetch) + ? 
GetLLVMType(GetRef(op)) + : llvm::Type::getVoidTy(*ctx_); + + llvm::Function* f = GetIntrinsicDecl(id, return_type, arg_type); + CHECK(f) << "Cannot find intrinsic declaration, possible type mismatch"; return builder_->CreateCall(f, arg_value); } else if (op->is_intrinsic(CallNode::bitwise_and)) { return builder_->CreateAnd(MakeValue(op->args[0]), MakeValue(op->args[1])); diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 6249aa4f74bc..e785f3eab275 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -232,6 +232,21 @@ class CodeGenLLVM : * \param type The corresponding TVM Type. */ llvm::Type* GetLLVMType(const PrimExpr& expr) const; + /*! + * \brief Get the declaration of the LLVM intrinsic based on the intrinsic + * id, and the type of the actual call. + * + * \param id The intrinsic id. + * \param ret_type The call return type. + * \param arg_types The types of the call arguments. + * + * \return Return the llvm::Function pointer, or nullptr if the declaration + * could not be generated (e.g. if the argument/return types do not + * match). + */ + llvm::Function* GetIntrinsicDecl(llvm::Intrinsic::ID id, + llvm::Type* ret_type, + llvm::ArrayRef arg_types); // initialize the function state. void InitFuncState(); // Get alignment given index. 
diff --git a/src/target/llvm/intrin_rule_llvm.cc b/src/target/llvm/intrin_rule_llvm.cc index 880a0fe58000..1a2d4931c674 100644 --- a/src/target/llvm/intrin_rule_llvm.cc +++ b/src/target/llvm/intrin_rule_llvm.cc @@ -30,7 +30,7 @@ namespace codegen { namespace llvm { TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.prefetch") -.set_body(DispatchLLVMIntrin<::llvm::Intrinsic::prefetch, 0>); +.set_body(DispatchLLVMIntrin<::llvm::Intrinsic::prefetch, 4>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp") .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp, 1>); diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index 34135c6ef7ee..3de1d1679e70 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -40,6 +40,30 @@ def test_llvm_intrin(): fcode = tvm.build(func, None, "llvm") +def test_llvm_overloaded_intrin(): + # Name lookup for overloaded intrinsics in LLVM 4- requires a name + # that includes the overloaded types. 
+ if tvm.target.codegen.llvm_version_major() < 5: + return + + def use_llvm_intrinsic(A, C): + ib = tvm.tir.ir_builder.create() + L = A.vload((0,0)) + I = tvm.tir.call_llvm_intrin('int32', 'llvm.ctlz', + tvm.tir.const(2, 'uint32'), L, tvm.tir.const(0, 'int1')) + S = C.vstore((0,0), I) + ib.emit(S) + return ib.get() + + A = tvm.te.placeholder((1,1), dtype = 'int32', name = 'A') + C = tvm.te.extern((1,1), [A], + lambda ins, outs: use_llvm_intrinsic(ins[0], outs[0]), + name = 'C' , dtype = 'int32') + + s = tvm.te.create_schedule(C.op) + f = tvm.build(s, [A, C], target = 'llvm') + + def test_llvm_import(): # extern "C" is necessary to get the correct signature cc_code = """ @@ -82,9 +106,9 @@ def check_llvm(use_file): def test_llvm_lookup_intrin(): ib = tvm.tir.ir_builder.create() - m = te.size_var("m") A = ib.pointer("uint8x8", name="A") - x = tvm.tir.call_llvm_intrin("uint8x8", "llvm.ctpop.i8", tvm.tir.const(1, 'uint32'), A) + z = tvm.tir.const(0, 'int32') + x = tvm.tir.call_llvm_intrin("uint8x8", "llvm.ctpop.v8i8", tvm.tir.const(1, 'uint32'), A[z]) ib.emit(x) body = ib.get() func = tvm.testing.MakeAPILegacy(body, "ctpop", [A], 1, True) @@ -680,6 +704,7 @@ def vectorizer(op): test_llvm_vadd_pipeline() test_llvm_add_pipeline() test_llvm_intrin() + test_llvm_overloaded_intrin() test_llvm_flip_pipeline() test_llvm_madd_pipeline() test_llvm_temp_space() diff --git a/topi/python/topi/arm_cpu/bitserial_conv2d.py b/topi/python/topi/arm_cpu/bitserial_conv2d.py index bdda496f8fb8..b7da66f9168f 100644 --- a/topi/python/topi/arm_cpu/bitserial_conv2d.py +++ b/topi/python/topi/arm_cpu/bitserial_conv2d.py @@ -197,7 +197,6 @@ def _intrin_func(ins, outs): ww, xx = ins zz = outs[0] - args_1 = tvm.tir.const(1, 'uint32') args_2 = tvm.tir.const(2, 'uint32') if unipolar: @@ -237,10 +236,10 @@ def _instr(index): cnts8[i] = upper_half + lower_half for i in range(m//2): cnts4[i] = tvm.tir.call_llvm_intrin(half_dtype, vpadd, - args_1, cnts8[i*2], cnts8[i*2+1]) + args_2, cnts8[i*2], 
cnts8[i*2+1]) for i in range(m//4): cnts2[i] = tvm.tir.call_llvm_intrin(half_dtype, vpadd, - args_1, cnts4[i*2], cnts4[i*2+1]) + args_2, cnts4[i*2], cnts4[i*2+1]) cnts = tvm.tir.call_pure_intrin( full_dtype, 'vectorcombine', cnts2[0], cnts2[1]) shifted_cnts = cnts << tvm.tir.const(bw+bx, pack_dtype) @@ -257,10 +256,10 @@ def _instr(index): cnts8[i] = tvm.tir.popcount(w_ & x_) for i in range(m//2): cnts4[i] = tvm.tir.call_llvm_intrin(half_dtype, vpadd, - args_1, cnts8[i*2], cnts8[i*2+1]) + args_2, cnts8[i*2], cnts8[i*2+1]) for i in range(m//4): cnts2[i] = tvm.tir.call_llvm_intrin(half_dtype, vpadd, - args_1, cnts4[i*2], cnts4[i*2+1]) + args_2, cnts4[i*2], cnts4[i*2+1]) cnts = tvm.tir.call_pure_intrin( full_dtype, 'vectorcombine', cnts2[0], cnts2[1]) shifted_cnts = cnts << tvm.tir.const(bw+bx, pack_dtype) From f249fbb3f13d129bad8509db4850deb4c0da1f29 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Fri, 10 Apr 2020 16:41:05 -0500 Subject: [PATCH 2/3] Include intrinsic name in the error message --- src/target/llvm/codegen_llvm.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 69312821645d..7112691de1bc 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -778,7 +778,8 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { : llvm::Type::getVoidTy(*ctx_); llvm::Function* f = GetIntrinsicDecl(id, return_type, arg_type); - CHECK(f) << "Cannot find intrinsic declaration, possible type mismatch"; + CHECK(f) << "Cannot find intrinsic declaration, possible type mismatch: " + << llvm::Intrinsic::getName(id, {}); return builder_->CreateCall(f, arg_value); } else if (op->is_intrinsic(CallNode::bitwise_and)) { return builder_->CreateAnd(MakeValue(op->args[0]), MakeValue(op->args[1])); From f94f4a1aab89203881795a19471f7b33ccd81d0f Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Fri, 10 Apr 2020 16:41:37 -0500 Subject: [PATCH 
3/3] Fix number of arguments for llvm.fmuladd and llvm.pow --- src/target/llvm/intrin_rule_llvm.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/target/llvm/intrin_rule_llvm.cc b/src/target/llvm/intrin_rule_llvm.cc index 1a2d4931c674..58bfb371c577 100644 --- a/src/target/llvm/intrin_rule_llvm.cc +++ b/src/target/llvm/intrin_rule_llvm.cc @@ -53,7 +53,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp10") }); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.fma") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 1>); +.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 3>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log") .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log, 1>); @@ -109,7 +109,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tanh") }); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.pow") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 1>); +.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 2>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.popcount") .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ctpop, 1>);