From ca8153d502b93e653f7a961309110455b90a0e22 Mon Sep 17 00:00:00 2001 From: Mehrdad Hessar Date: Fri, 17 Mar 2023 10:29:06 -0700 Subject: [PATCH] [AOT]Raise error when input name is not valid (#14322) This PR fixes #13013. --- src/runtime/aot_executor/aot_executor.cc | 2 +- tests/python/relay/aot/test_cpp_aot.py | 38 ++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/src/runtime/aot_executor/aot_executor.cc b/src/runtime/aot_executor/aot_executor.cc index 39d5570030d6..1fed42bf04b0 100644 --- a/src/runtime/aot_executor/aot_executor.cc +++ b/src/runtime/aot_executor/aot_executor.cc @@ -191,7 +191,7 @@ int AotExecutor::GetInputIndex(const std::string& name) { return i; } } - return -1; + ICHECK(false) << "Invalid input name."; } std::string AotExecutor::GetInputName(int index) { diff --git a/tests/python/relay/aot/test_cpp_aot.py b/tests/python/relay/aot/test_cpp_aot.py index 3c7a3bc0ca12..c1b4fd817a84 100644 --- a/tests/python/relay/aot/test_cpp_aot.py +++ b/tests/python/relay/aot/test_cpp_aot.py @@ -248,5 +248,43 @@ def test_aot_input_name_with_special_character(target_kind: str, input_name: str assert (runner.get_output(0).asnumpy() == expected_output).all() +@pytest.mark.parametrize("target_kind", ["c", "llvm"]) +def test_aot_incorrect_input_name(target_kind: str): + """Test passing incorrect input name.""" + dtype = "float32" + correct_input_name = "input" + incorrect_input_name = "input1" + input1 = relay.var(correct_input_name, shape=(10, 5), dtype=dtype) + weight = relay.var("weight", shape=(1, 5), dtype=dtype) + output = relay.add(input1, weight) + func = relay.Function([input1, weight], output) + + input_data = np.random.rand(10, 5).astype(dtype) + weight_data = np.random.rand(1, 5).astype(dtype) + params = {"weight": weight_data} + + with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): + mod = tvm.relay.build( + tvm.IRModule.from_expr(func), + target=target_kind, + params=params, + executor=tvm.relay.backend.Executor("aot", {"interface-api": "packed"}), + ) + temp_dir = tvm.contrib.utils.TempDirectory() + test_so_path = temp_dir / "test.so" + mod.export_library(test_so_path, cc="c++", options=["-std=gnu++17", "-g3", "-O0"]) + + loaded_mod = tvm.runtime.load_module(test_so_path) + runner = tvm.runtime.executor.AotModule(loaded_mod["default"](tvm.cpu(0))) + inputs = {incorrect_input_name: input_data} + + error_regex = r"Invalid input name." + with pytest.raises(tvm.TVMError, match=error_regex): + runner.set_input(**inputs) + + with pytest.raises(tvm.TVMError, match=error_regex): + runner.get_input_index(incorrect_input_name) + + if __name__ == "__main__": tvm.testing.main()