diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index b2ce50d91f58..b6083c88cfd8 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -483,9 +483,9 @@ inline const TTypeNode* RelayExprNode::type_as() const { namespace tvm { namespace runtime { +// common rule for RetValue and ArgValue template <> struct PackedFuncValueConverter { - // common rule for both RetValue and ArgValue. static PrimExpr From(const TVMPODValue_& val) { if (val.type_code() == kTVMNullptr) { return PrimExpr(ObjectPtr(nullptr)); @@ -500,6 +500,35 @@ struct PackedFuncValueConverter { return PrimExpr::FromObject_(val.AsObjectRef()); } }; + +template <> +struct PackedFuncValueConverter { + static tvm::Integer From(const TVMPODValue_& val) { + if (val.type_code() == kTVMNullptr) { + return Integer(ObjectPtr(nullptr)); + } + if (val.type_code() == kTVMArgInt) { + return Integer(val.operator int()); + } + return val.AsObjectRef(); + } +}; + +template <> +struct PackedFuncValueConverter { + static tvm::Bool From(const TVMPODValue_& val) { + if (val.type_code() == kTVMNullptr) { + return Bool(ObjectPtr(nullptr)); + } + if (val.type_code() == kTVMArgInt) { + int v = val.operator int(); + CHECK(v == 0 || v == 1) << "ValueError: boolean value can only be 0 or 1, but get " << v; + return Bool(static_cast(v)); + } + return val.AsObjectRef(); + } +}; + } // namespace runtime } // namespace tvm #endif // TVM_IR_EXPR_H_ diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index a51f70984011..f0e6d898d7d8 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -1147,26 +1147,6 @@ inline std::unordered_map as_unordered_map(const Map& dmap) { } // namespace tir } // namespace tvm -namespace tvm { -namespace runtime { -// Additional implementattion overloads for PackedFunc. - -template <> -struct PackedFuncValueConverter { - // common rule for RetValue and ArgValue - static tvm::Integer From(const TVMPODValue_& val) { - if (val.type_code() == kTVMNullptr) { - return Integer(ObjectPtr(nullptr)); - } - if (val.type_code() == kDLInt) { - return Integer(val.operator int()); - } - return val.AsObjectRef(); - } -}; -} // namespace runtime -} // namespace tvm - namespace std { template <> struct hash<::tvm::tir::IterVar> : public ::tvm::ObjectPtrHash {};