diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 98196c13af7f..0cf3da2511ef 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -491,6 +491,11 @@ struct ObjectTypeChecker> { if (!ptr->IsInstance()) { return String(ptr->GetTypeKey()); } + + if constexpr (std::is_same_v) { + return NullOpt; + } + const ArrayNode* n = static_cast(ptr); for (size_t i = 0; i < n->size(); i++) { const ObjectRef& p = (*n)[i]; @@ -504,6 +509,8 @@ struct ObjectTypeChecker> { static bool Check(const Object* ptr) { if (ptr == nullptr) return true; if (!ptr->IsInstance()) return false; + if constexpr (std::is_same_v) return true; + const ArrayNode* n = static_cast(ptr); for (const ObjectRef& p : *n) { if (!ObjectTypeChecker::Check(p.get())) { @@ -520,10 +527,21 @@ struct ObjectTypeChecker> { static Optional CheckAndGetMismatch(const Object* ptr) { if (ptr == nullptr) return NullOpt; if (!ptr->IsInstance()) return String(ptr->GetTypeKey()); + + if constexpr (std::is_same_v && std::is_same_v) { + return NullOpt; + } + const MapNode* n = static_cast(ptr); for (const auto& kv : *n) { - Optional key_type = ObjectTypeChecker::CheckAndGetMismatch(kv.first.get()); - Optional value_type = ObjectTypeChecker::CheckAndGetMismatch(kv.first.get()); + Optional key_type = NullOpt; + if constexpr (!std::is_same_v) { + key_type = ObjectTypeChecker::CheckAndGetMismatch(kv.first.get()); + } + Optional value_type = NullOpt; + if constexpr (!std::is_same_v) { + value_type = ObjectTypeChecker::CheckAndGetMismatch(kv.first.get()); + } if (key_type.defined() || value_type.defined()) { std::string key_name = key_type.defined() ? std::string(key_type.value()) : ObjectTypeChecker::TypeName(); @@ -537,10 +555,19 @@ struct ObjectTypeChecker> { static bool Check(const Object* ptr) { if (ptr == nullptr) return true; if (!ptr->IsInstance()) return false; + + if constexpr (std::is_same_v && std::is_same_v) { + return true; + } + const MapNode* n = static_cast(ptr); for (const auto& kv : *n) { - if (!ObjectTypeChecker::Check(kv.first.get())) return false; - if (!ObjectTypeChecker::Check(kv.second.get())) return false; + if constexpr (!std::is_same_v) { + if (!ObjectTypeChecker::Check(kv.first.get())) return false; + } + if constexpr (!std::is_same_v) { + if (!ObjectTypeChecker::Check(kv.second.get())) return false; + } } return true; } @@ -2454,6 +2481,10 @@ struct PackedFuncValueConverter> { static Array From(const TVMArgValue& val) { auto untyped_array = val.AsObjectRef>(); + if constexpr (std::is_same_v) { + return untyped_array; + } + // Attempt to convert each item of the array into the desired // type. If the items do not require a conversion, no copies are // made. @@ -2491,6 +2522,10 @@ struct PackedFuncValueConverter> { static Array From(const TVMRetValue& val) { auto untyped_array = val.AsObjectRef>(); + if constexpr (std::is_same_v) { + return untyped_array; + } + return untyped_array.Map([](ObjectRef item) { TVMRetValue item_val; item_val = std::move(item); @@ -2504,6 +2539,10 @@ struct PackedFuncValueConverter> { static Map From(const TVMArgValue& val) { auto untyped_map = val.AsObjectRef>(); + if constexpr (std::is_same_v && std::is_same_v) { + return Downcast>(untyped_map); + } + if (ObjectTypeChecker>::Check(untyped_map.get())) { // Early bail-out for common case where no type conversions are // required. @@ -2513,20 +2552,28 @@ struct PackedFuncValueConverter> { Map output; for (const auto& kv : untyped_map) { T new_key = [&]() { - TVMValue pod_value; - int type_code; - TVMArgsSetter setter(&pod_value, &type_code); - setter(0, kv.first); - TVMArgValue pod_arg(pod_value, type_code); - return PackedFuncValueConverter::From(pod_arg); + if constexpr (std::is_same_v) { + return kv.first; + } else { + TVMValue pod_value; + int type_code; + TVMArgsSetter setter(&pod_value, &type_code); + setter(0, kv.first); + TVMArgValue pod_arg(pod_value, type_code); + return PackedFuncValueConverter::From(pod_arg); + } }(); U new_value = [&]() { - TVMValue pod_value; - int type_code; - TVMArgsSetter setter(&pod_value, &type_code); - setter(0, kv.second); - TVMArgValue key_arg(pod_value, type_code); - return PackedFuncValueConverter::From(key_arg); + if constexpr (std::is_same_v) { + return kv.second; + } else { + TVMValue pod_value; + int type_code; + TVMArgsSetter setter(&pod_value, &type_code); + setter(0, kv.second); + TVMArgValue key_arg(pod_value, type_code); + return PackedFuncValueConverter::From(key_arg); + } }(); output.Set(new_key, new_value); } @@ -2535,6 +2582,10 @@ struct PackedFuncValueConverter> { static Map From(const TVMRetValue& val) { auto untyped_map = val.AsObjectRef>(); + if constexpr (std::is_same_v && std::is_same_v) { + return Downcast>(untyped_map); + } + if (ObjectTypeChecker>::Check(untyped_map.get())) { // Early bail-out for common case where no type conversions are // required. @@ -2544,14 +2595,22 @@ struct PackedFuncValueConverter> { Map output; for (const auto& kv : untyped_map) { T new_key = [&]() { - TVMRetValue pod; - pod = kv.first; - return PackedFuncValueConverter::From(pod); + if constexpr (std::is_same_v) { + return kv.first; + } else { + TVMRetValue pod; + pod = kv.first; + return PackedFuncValueConverter::From(pod); + } }(); U new_value = [&]() { - TVMRetValue pod; - pod = kv.second; - return PackedFuncValueConverter::From(pod); + if constexpr (std::is_same_v) { + return kv.second; + } else { + TVMRetValue pod; + pod = kv.second; + return PackedFuncValueConverter::From(pod); + } }(); output.Set(new_key, new_value); }