Skip to content

Commit

Permalink
Allow implicit conversion in TVM FFI to tvm::Bool (#5907)
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao authored Jun 24, 2020
1 parent a1fb841 commit e8ccfd0
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 21 deletions.
31 changes: 30 additions & 1 deletion include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -483,9 +483,9 @@ inline const TTypeNode* RelayExprNode::type_as() const {

namespace tvm {
namespace runtime {
// common rule for RetValue and ArgValue
template <>
struct PackedFuncValueConverter<PrimExpr> {
// common rule for both RetValue and ArgValue.
static PrimExpr From(const TVMPODValue_& val) {
if (val.type_code() == kTVMNullptr) {
return PrimExpr(ObjectPtr<Object>(nullptr));
Expand All @@ -500,6 +500,35 @@ struct PackedFuncValueConverter<PrimExpr> {
return PrimExpr::FromObject_(val.AsObjectRef<ObjectRef>());
}
};

template <>
struct PackedFuncValueConverter<tvm::Integer> {
static tvm::Integer From(const TVMPODValue_& val) {
if (val.type_code() == kTVMNullptr) {
return Integer(ObjectPtr<Object>(nullptr));
}
if (val.type_code() == kTVMArgInt) {
return Integer(val.operator int());
}
return val.AsObjectRef<tvm::Integer>();
}
};

template <>
struct PackedFuncValueConverter<tvm::Bool> {
static tvm::Bool From(const TVMPODValue_& val) {
if (val.type_code() == kTVMNullptr) {
return Bool(ObjectPtr<Object>(nullptr));
}
if (val.type_code() == kTVMArgInt) {
int v = val.operator int();
CHECK(v == 0 || v == 1) << "ValueError: boolean value can only be 0 or 1, but get " << v;
return Bool(static_cast<bool>(v));
}
return val.AsObjectRef<tvm::Bool>();
}
};

} // namespace runtime
} // namespace tvm
#endif // TVM_IR_EXPR_H_
20 changes: 0 additions & 20 deletions include/tvm/tir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -1147,26 +1147,6 @@ inline std::unordered_map<K, V> as_unordered_map(const Map<K, V>& dmap) {
} // namespace tir
} // namespace tvm

namespace tvm {
namespace runtime {
// Additional implementattion overloads for PackedFunc.

template <>
struct PackedFuncValueConverter<tvm::Integer> {
// common rule for RetValue and ArgValue
static tvm::Integer From(const TVMPODValue_& val) {
if (val.type_code() == kTVMNullptr) {
return Integer(ObjectPtr<Object>(nullptr));
}
if (val.type_code() == kDLInt) {
return Integer(val.operator int());
}
return val.AsObjectRef<tvm::Integer>();
}
};
} // namespace runtime
} // namespace tvm

namespace std {
template <>
struct hash<::tvm::tir::IterVar> : public ::tvm::ObjectPtrHash {};
Expand Down

0 comments on commit e8ccfd0

Please sign in to comment.