From dcd078eb82b7c9396d6c5e5622e06da705112c6f Mon Sep 17 00:00:00 2001 From: Mehrdad Hessar Date: Thu, 16 Mar 2023 12:17:01 -0700 Subject: [PATCH 1/2] show error when incorrect input name passed --- src/runtime/aot_executor/aot_executor.cc | 2 +- tests/python/relay/aot/test_cpp_aot.py | 38 ++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/src/runtime/aot_executor/aot_executor.cc b/src/runtime/aot_executor/aot_executor.cc index 39d5570030d6..1fed42bf04b0 100644 --- a/src/runtime/aot_executor/aot_executor.cc +++ b/src/runtime/aot_executor/aot_executor.cc @@ -191,7 +191,7 @@ int AotExecutor::GetInputIndex(const std::string& name) { return i; } } - return -1; + ICHECK(false) << "Invalid input name."; } std::string AotExecutor::GetInputName(int index) { diff --git a/tests/python/relay/aot/test_cpp_aot.py b/tests/python/relay/aot/test_cpp_aot.py index 3c7a3bc0ca12..5ae9864b2d80 100644 --- a/tests/python/relay/aot/test_cpp_aot.py +++ b/tests/python/relay/aot/test_cpp_aot.py @@ -248,5 +248,43 @@ def test_aot_input_name_with_special_character(target_kind: str, input_name: str assert (runner.get_output(0).asnumpy() == expected_output).all() +@pytest.mark.parametrize("target_kind", ["c", "llvm"]) +def test_aot_incorrect_input_name(target_kind: str): + """Test passing incorrect input name.""" + dtype = "float32" + correct_input_name = "input" + incorrect_input_name = "input1" + input = relay.var(correct_input_name, shape=(10, 5), dtype=dtype) + weight = relay.var("weight", shape=(1, 5), dtype=dtype) + output = relay.add(input, weight) + func = relay.Function([input, weight], output) + + input_data = np.random.rand(10, 5).astype(dtype) + weight_data = np.random.rand(1, 5).astype(dtype) + params = {"weight": weight_data} + + with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): + mod = tvm.relay.build( + tvm.IRModule.from_expr(func), + target=target_kind, + params=params, + executor=tvm.relay.backend.Executor("aot", {"interface-api": "packed"}), + ) + temp_dir = tvm.contrib.utils.TempDirectory() + test_so_path = temp_dir / "test.so" + mod.export_library(test_so_path, cc="c++", options=["-std=gnu++17", "-g3", "-O0"]) + + loaded_mod = tvm.runtime.load_module(test_so_path) + runner = tvm.runtime.executor.AotModule(loaded_mod["default"](tvm.cpu(0))) + inputs = {incorrect_input_name: input_data} + + error_regex = r"Invalid input name." + with pytest.raises(tvm.TVMError, match=error_regex): + runner.set_input(**inputs) + + with pytest.raises(tvm.TVMError, match=error_regex): + runner.get_input_index(incorrect_input_name) + + if __name__ == "__main__": tvm.testing.main() From b3e7033a1ba365fd84c103fa21e3ae4a087e5483 Mon Sep 17 00:00:00 2001 From: Mehrdad Hessar Date: Thu, 16 Mar 2023 13:44:36 -0700 Subject: [PATCH 2/2] fix name --- tests/python/relay/aot/test_cpp_aot.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python/relay/aot/test_cpp_aot.py b/tests/python/relay/aot/test_cpp_aot.py index 5ae9864b2d80..c1b4fd817a84 100644 --- a/tests/python/relay/aot/test_cpp_aot.py +++ b/tests/python/relay/aot/test_cpp_aot.py @@ -254,10 +254,10 @@ def test_aot_incorrect_input_name(target_kind: str): dtype = "float32" correct_input_name = "input" incorrect_input_name = "input1" - input = relay.var(correct_input_name, shape=(10, 5), dtype=dtype) + input1 = relay.var(correct_input_name, shape=(10, 5), dtype=dtype) weight = relay.var("weight", shape=(1, 5), dtype=dtype) - output = relay.add(input, weight) - func = relay.Function([input, weight], output) + output = relay.add(input1, weight) + func = relay.Function([input1, weight], output) input_data = np.random.rand(10, 5).astype(dtype) weight_data = np.random.rand(1, 5).astype(dtype)