From ac9b1665728bd433d10b1f0ec224c27636bf0277 Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 7 Apr 2020 13:33:56 -0700 Subject: [PATCH] [RUNTIME] Quick fix PackedFunc String passing --- include/tvm/runtime/packed_func.h | 14 ++++++++++---- tests/cpp/packed_func_test.cc | 6 ++++++ 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index d5c017502426..1b3ad571dbaa 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -513,8 +513,11 @@ class TVMArgValue : public TVMPODValue_ { } } operator tvm::runtime::String() const { - // directly use the std::string constructor for now. - return tvm::runtime::String(operator std::string()); + if (IsObjectRef()) { + return AsObjectRef(); + } else { + return tvm::runtime::String(operator std::string()); + } } operator DLDataType() const { if (type_code_ == kTVMStr) { @@ -605,8 +608,11 @@ class TVMRetValue : public TVMPODValue_ { return *ptr(); } operator tvm::runtime::String() const { - // directly use the std::string constructor for now. - return tvm::runtime::String(operator std::string()); + if (IsObjectRef()) { + return AsObjectRef(); + } else { + return tvm::runtime::String(operator std::string()); + } } operator DLDataType() const { if (type_code_ == kTVMStr) { diff --git a/tests/cpp/packed_func_test.cc b/tests/cpp/packed_func_test.cc index 4a815ffd5d7d..d0313c60d984 100644 --- a/tests/cpp/packed_func_test.cc +++ b/tests/cpp/packed_func_test.cc @@ -95,6 +95,12 @@ TEST(PackedFunc, str) { CHECK(y == "hello"); *rv = x; })("hello"); + + PackedFunc([&](TVMArgs args, TVMRetValue* rv) { + CHECK(args.num_args == 1); + runtime::String s = args[0]; + CHECK(s == "hello"); + })(runtime::String("hello")); }