diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 905d14296d98..d7df2a4bb690 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -16,7 +16,6 @@ # under the License. # pylint: disable=redefined-builtin, invalid-name """Operators used in TIR expression.""" -import warnings from typing import Any, Optional import tvm._ffi @@ -251,7 +250,7 @@ def call_llvm_intrin(dtype, name, *args, span=None): The name of the llvm intrinsic function. args : list - Poistional arguments. + Positional arguments. span : Optional[Span] The location of this operator in the source code. @@ -271,7 +270,7 @@ def call_llvm_intrin(dtype, name, *args, span=None): else: llvm_id = name if llvm_id == 0: - warnings.warn(f"Unknown llvm intrinsic function {name}, falling back to 0") + raise ValueError(f"Unknown llvm intrinsic function {name}") return call_intrin( dtype, Op.get("tir.call_llvm_intrin"), @@ -293,7 +292,7 @@ def call_llvm_pure_intrin(dtype, name, *args, span=None): The name of the llvm intrinsic function. args : list - Poistional arguments. + Positional arguments. span : Optional[Span] The location of this operator in the source code. @@ -313,7 +312,7 @@ def call_llvm_pure_intrin(dtype, name, *args, span=None): else: llvm_id = name if llvm_id == 0: - warnings.warn(f"Unknown llvm intrinsic function {name}, falling back to 0") + raise ValueError(f"Unknown llvm intrinsic function {name}") return call_intrin( dtype, Op.get("tir.call_llvm_pure_intrin"), diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc index 8de142f8613e..e25b074401d4 100644 --- a/src/script/printer/tir/expr.cc +++ b/src/script/printer/tir/expr.cc @@ -250,6 +250,31 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) dtype_print_location = static_cast(dtype_locations[op].IntValue()); } + if (name == "call_llvm_pure_intrin" || name == "call_llvm_intrin") { + int n_args = call->args.size(); + int64_t id = call->args[0].as()->value; + auto f_llvm_lookup_intrinsic_name = + tvm::runtime::Registry::Get("target.llvm_get_intrinsic_name"); + + Array args; + args.reserve(n_args + 1); + if (dtype_print_location == tir::ScriptDtypePrintLocation::kFirst) { + args.push_back(LiteralDoc::DataType(call->dtype, call_p->Attr("dtype"))); + } + + for (int i = 0; i < n_args; ++i) { + if ((i == 0) && (f_llvm_lookup_intrinsic_name)) { + String name = (*f_llvm_lookup_intrinsic_name)(id); + args.push_back(LiteralDoc::Str(name.c_str(), call_p->Attr("args")->ArrayIndex(i))); + } else { + args.push_back(d->AsDoc(call->args[i], call_p->Attr("args")->ArrayIndex(i))); + } + } + if (dtype_print_location == tir::ScriptDtypePrintLocation::kLast) { + args.push_back(LiteralDoc::DataType(call->dtype, call_p->Attr("dtype"))); + } + return prefix->Call(args); + } } else if (call->op.as()) { prefix = d->AsDoc(call->op, call_p->Attr("op")); } else { diff --git a/tests/python/unittest/test_tir_ops.py b/tests/python/unittest/test_tir_ops.py index 21981d1f0ba1..8cffe8171a23 100644 --- a/tests/python/unittest/test_tir_ops.py +++ b/tests/python/unittest/test_tir_ops.py @@ -234,5 +234,12 @@ def test_comm_reducer(num_args): assert tvm.tir.max(*range(num_args)) == num_args - 1 +def test_llvm_intrin(): + with pytest.raises(ValueError, match=r"Unknown llvm intrinsic function llvm.dummy"): + a = tvm.tir.call_llvm_intrin("int32x4", "llvm.dummy", 0) + with pytest.raises(ValueError, match=r"Unknown llvm intrinsic function llvm.dummy"): + a = tvm.tir.call_llvm_pure_intrin("int32x4", "llvm.dummy", 0) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py b/tests/python/unittest/test_tvmscript_printer_tir.py index 70d56e6903b7..0636a79334a5 100644 --- a/tests/python/unittest/test_tvmscript_printer_tir.py +++ b/tests/python/unittest/test_tvmscript_printer_tir.py @@ -504,6 +504,13 @@ def test_cast(): ) +def test_llvm_intrin_imm(): + a = tir.call_llvm_intrin("int32x4", "llvm.donothing", T.uint32(0)) + _assert_print(a, 'T.call_llvm_intrin("int32x4", "llvm.donothing", T.uint32(0))') + a = tir.call_llvm_pure_intrin("int32x4", "llvm.donothing", T.uint32(0)) + _assert_print(a, 'T.call_llvm_pure_intrin("int32x4", "llvm.donothing", T.uint32(0))') + + def test_binary_arith(): a = tir.Var("a", "int32") b = tir.Var("b", "int32")