Skip to content

Commit

Permalink
[PR-15983][FFI] Allow IntImm arguments to PackedFunc with int parameter
Browse files Browse the repository at this point in the history
TVM containers, such as tvm::runtime::Array, require the contained
objects to inherit from `ObjectRef`.  As a result, the wrapper types
`IntImm`, `FloatImm`, and `StringImm` are often used to allow native
types in the TVM containers.  Conversions into these wrapper type may
be required when using a container, and may be performed automatically
when passing an object across the FFI.  By also providing conversion
to an unwrapped type, these automatic conversions are transparent
become transparent to users.

The trait can be specialized to add type specific conversion logic
from the TVMArgvalue and TVMRetValue.
  • Loading branch information
Lunderberg committed Nov 7, 2023
1 parent eb20534 commit 61f6322
Show file tree
Hide file tree
Showing 3 changed files with 170 additions and 0 deletions.
36 changes: 36 additions & 0 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,24 @@ class IntImm : public PrimExpr {
TVM_DEFINE_OBJECT_REF_COW_METHOD(IntImmNode);
};

/* \brief FFI extention, ObjectRef to integer conversion
*
* If a PackedFunc expects an integer type, and the user passes an
* IntImm as the argument, this specialization allows it to be
* converted by the FFI.
*/
template <typename IntType>
struct runtime::PackedFuncObjectRefConverter<IntType,
std::enable_if_t<std::is_integral_v<IntType>>> {
static std::optional<IntType> TryFrom(const ObjectRef& obj) {
if (auto ptr = obj.as<IntImmNode>()) {
return ptr->value;
} else {
return std::nullopt;
}
}
};

/*!
* \brief Constant floating point literals in the program.
* \sa FloatImm
Expand Down Expand Up @@ -587,6 +605,24 @@ class FloatImm : public PrimExpr {
TVM_DEFINE_OBJECT_REF_COW_METHOD(FloatImmNode);
};

/* \brief FFI extention, ObjectRef to integer conversion
*
* If a PackedFunc expects an integer type, and the user passes an
* IntImm as the argument, this specialization allows it to be
* converted by the FFI.
*/
template <typename FloatType>
struct runtime::PackedFuncObjectRefConverter<
FloatType, std::enable_if_t<std::is_floating_point_v<FloatType>>> {
static std::optional<FloatType> TryFrom(const ObjectRef& obj) {
if (auto ptr = obj.as<FloatImmNode>()) {
return ptr->value;
} else {
return std::nullopt;
}
}
};

/*!
* \brief Boolean constant.
*
Expand Down
74 changes: 74 additions & 0 deletions include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include <functional>
#include <limits>
#include <memory>
#include <optional>
#include <string>
#include <tuple>
#include <type_traits>
Expand Down Expand Up @@ -537,6 +538,42 @@ struct ObjectTypeChecker<Map<K, V>> {
}
};

class TVMPODValue_;

/*!
* \brief Type trait to specify special value conversion rules from
* ObjectRef to primitive types.
*
* TVM containers, such as tvm::runtime::Array, require the contained
* objects to inherit from ObjectRef. As a result, the wrapper types
* IntImm, FloatImm, and StringImm are often used to hold primitive
* types inside a TVM container. Conversions into this type may be
* required when using a container, and may be performed
* automatically when passing an object across the FFI. By also
* handling conversions from wrapped to unwrapped types, these
* conversions can be transparent to users.
*
* The trait can be specialized to add type specific conversion logic
* from the TVMArgvalue and TVMRetValue.
*
* \tparam T The type (e.g. int64_t) which may be contained within the
* ObjectRef.
*
* \tparam (anonymous) An anonymous and unused type parameter, which
* may be used for SFINAE.
*/
template <typename T, typename = void>
struct PackedFuncObjectRefConverter {
/*!
* \brief Attempt to convert an ObjectRef from an argument value.
*
* \param obj The ObjectRef which may be convertible to T
*
* \return The converted result, or std::nullopt if not convertible.
*/
static std::optional<T> TryFrom(const ObjectRef& obj) { return std::nullopt; }
};

/*!
* \brief Internal base class to
* handle conversion to POD values.
Expand All @@ -549,25 +586,41 @@ class TVMPODValue_ {
// the frontend while the API expects a float.
if (type_code_ == kDLInt) {
return static_cast<double>(value_.v_int64);
} else if (auto opt = ThroughObjectRef<double>()) {
return opt.value();
} else if (auto opt = ThroughObjectRef<int64_t>()) {
return opt.value();
}
TVM_CHECK_TYPE_CODE(type_code_, kDLFloat);
return value_.v_float64;
}
operator int64_t() const {
if (auto opt = ThroughObjectRef<int64_t>()) {
return opt.value();
}
TVM_CHECK_TYPE_CODE(type_code_, kDLInt);
return value_.v_int64;
}
operator uint64_t() const {
if (auto opt = ThroughObjectRef<uint64_t>()) {
return opt.value();
}
TVM_CHECK_TYPE_CODE(type_code_, kDLInt);
return value_.v_int64;
}
operator int() const {
if (auto opt = ThroughObjectRef<int>()) {
return opt.value();
}
TVM_CHECK_TYPE_CODE(type_code_, kDLInt);
ICHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
ICHECK_GE(value_.v_int64, std::numeric_limits<int>::min());
return static_cast<int>(value_.v_int64);
}
operator bool() const {
if (auto opt = ThroughObjectRef<bool>()) {
return opt.value();
}
TVM_CHECK_TYPE_CODE(type_code_, kDLInt);
return value_.v_int64 != 0;
}
Expand Down Expand Up @@ -638,6 +691,27 @@ class TVMPODValue_ {
TVMValue value_;
/*! \brief the type code */
int type_code_;

private:
/* \brief A utility function to check for conversions through
* PackedFuncObjectRefConverter
*
* \tparam T The type to attempt to convert into
*
* \return The converted type, or std::nullopt if the value cannot
* be converted into T.
*/
template <typename T>
std::optional<T> ThroughObjectRef() const {
if (IsObjectRef<ObjectRef>()) {
if (std::optional<T> from_obj =
PackedFuncObjectRefConverter<T>::TryFrom(AsObjectRef<ObjectRef>())) {
return from_obj.value();
}
}

return std::nullopt;
}
};

/*!
Expand Down
60 changes: 60 additions & 0 deletions tests/cpp/packed_func_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -319,3 +319,63 @@ TEST(TypedPackedFunc, RValue) {
tf(1, true);
}
}

TEST(TypedPackedFunc, IntImmWrapper) {
using namespace tvm::runtime;

TypedPackedFunc<void(int)> typed_func = [](int x) {};
PackedFunc func = typed_func;

// Integer argument may be provided
func(5);

// IntImm argument may be provided, automatically unwrapped.
tvm::IntImm lvalue_intimm(DataType::Int(32), 10);
func(lvalue_intimm);

// Unwrapping of IntImm argument works for rvalues as well
func(tvm::IntImm(DataType::Int(32), 10));
}

TEST(TypedPackedFunc, FloatImmWrapper) {
using namespace tvm::runtime;

TypedPackedFunc<void(double)> typed_func = [](double x) {};
PackedFunc func = typed_func;

// Argument may be provided as a floating point. If provided as an
// integer, it will be converted to a float.
func(static_cast<double>(5.0));
func(static_cast<int>(5));

// IntImm and FloatImm arguments may be provided, and are
// automatically unwrapped. These arguments work correctly for
// either lvalue or rvalue arguments.

tvm::IntImm lvalue_intimm(DataType::Int(32), 10);
tvm::FloatImm lvalue_floatimm(DataType::Float(32), 10.5);

func(lvalue_intimm);
func(lvalue_floatimm);
func(tvm::IntImm(DataType::Int(32), 10));
func(tvm::FloatImm(DataType::Float(32), 10.5));
}

TEST(TypedPackedFunc, BoolWrapper) {
using namespace tvm::runtime;

TypedPackedFunc<void(bool)> typed_func = [](bool x) {};
PackedFunc func = typed_func;

// Argument may be provided as a floating point. If provided as an
// integer, it will be converted to a float.
func(true);

tvm::IntImm lvalue_intimm(DataType::Int(32), 10);
func(lvalue_intimm);
func(tvm::IntImm(DataType::Int(32), 10));

tvm::Bool lvalue_bool(false);
func(lvalue_bool);
func(tvm::Bool(true));
}

0 comments on commit 61f6322

Please sign in to comment.