Skip to content

Commit

Permalink
[FFI] Re-introduce the boxed primitive values
Browse files Browse the repository at this point in the history
Initially introduced in apache#16183,
these changes were reverted in
apache#17252 due to performance
degredation in some Relax models.  This could occur when a model
contained a large number of calls to `"vm.builtin.tuple_getitem"`,
which may occur when model weights are provided as a tuple.

This PR re-applies the changes from
apache#16183, but with the performance
degredation resolved.  The root cause was unnecessary type-checking
when converting from an untyped `tvm::ArrayNode*` to the typed
`tvm::Array<T>`, in the case where `T` is `ObjectRef`.
  • Loading branch information
Lunderberg committed Aug 8, 2024
1 parent ee9b8f1 commit c855809
Showing 1 changed file with 81 additions and 22 deletions.
103 changes: 81 additions & 22 deletions include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,11 @@ struct ObjectTypeChecker<Array<T>> {
if (!ptr->IsInstance<ArrayNode>()) {
return String(ptr->GetTypeKey());
}

if constexpr (std::is_same_v<T, ObjectRef>) {
return NullOpt;
}

const ArrayNode* n = static_cast<const ArrayNode*>(ptr);
for (size_t i = 0; i < n->size(); i++) {
const ObjectRef& p = (*n)[i];
Expand All @@ -504,6 +509,8 @@ struct ObjectTypeChecker<Array<T>> {
static bool Check(const Object* ptr) {
if (ptr == nullptr) return true;
if (!ptr->IsInstance<ArrayNode>()) return false;
if constexpr (std::is_same_v<T, ObjectRef>) return true;

const ArrayNode* n = static_cast<const ArrayNode*>(ptr);
for (const ObjectRef& p : *n) {
if (!ObjectTypeChecker<T>::Check(p.get())) {
Expand All @@ -520,10 +527,21 @@ struct ObjectTypeChecker<Map<K, V>> {
static Optional<String> CheckAndGetMismatch(const Object* ptr) {
if (ptr == nullptr) return NullOpt;
if (!ptr->IsInstance<MapNode>()) return String(ptr->GetTypeKey());

if constexpr (std::is_same_v<K, ObjectRef> && std::is_same_v<V, ObjectRef>) {
return NullOpt;
}

const MapNode* n = static_cast<const MapNode*>(ptr);
for (const auto& kv : *n) {
Optional<String> key_type = ObjectTypeChecker<K>::CheckAndGetMismatch(kv.first.get());
Optional<String> value_type = ObjectTypeChecker<K>::CheckAndGetMismatch(kv.first.get());
Optional<String> key_type = NullOpt;
if constexpr (!std::is_same_v<K, ObjectRef>) {
key_type = ObjectTypeChecker<K>::CheckAndGetMismatch(kv.first.get());
}
Optional<String> value_type = NullOpt;
if constexpr (!std::is_same_v<V, ObjectRef>) {
value_type = ObjectTypeChecker<K>::CheckAndGetMismatch(kv.first.get());
}
if (key_type.defined() || value_type.defined()) {
std::string key_name =
key_type.defined() ? std::string(key_type.value()) : ObjectTypeChecker<K>::TypeName();
Expand All @@ -537,10 +555,19 @@ struct ObjectTypeChecker<Map<K, V>> {
static bool Check(const Object* ptr) {
if (ptr == nullptr) return true;
if (!ptr->IsInstance<MapNode>()) return false;

if constexpr (std::is_same_v<K, ObjectRef> && std::is_same_v<V, ObjectRef>) {
return true;
}

const MapNode* n = static_cast<const MapNode*>(ptr);
for (const auto& kv : *n) {
if (!ObjectTypeChecker<K>::Check(kv.first.get())) return false;
if (!ObjectTypeChecker<V>::Check(kv.second.get())) return false;
if constexpr (!std::is_same_v<K, ObjectRef>) {
if (!ObjectTypeChecker<K>::Check(kv.first.get())) return false;
}
if constexpr (!std::is_same_v<V, ObjectRef>) {
if (!ObjectTypeChecker<V>::Check(kv.second.get())) return false;
}
}
return true;
}
Expand Down Expand Up @@ -2454,6 +2481,10 @@ struct PackedFuncValueConverter<Array<T>> {
static Array<T> From(const TVMArgValue& val) {
auto untyped_array = val.AsObjectRef<Array<ObjectRef>>();

if constexpr (std::is_same_v<T, ObjectRef>) {
return untyped_array;
}

// Attempt to convert each item of the array into the desired
// type. If the items do not require a conversion, no copies are
// made.
Expand Down Expand Up @@ -2491,6 +2522,10 @@ struct PackedFuncValueConverter<Array<T>> {
static Array<T> From(const TVMRetValue& val) {
auto untyped_array = val.AsObjectRef<Array<ObjectRef>>();

if constexpr (std::is_same_v<T, ObjectRef>) {
return untyped_array;
}

return untyped_array.Map([](ObjectRef item) {
TVMRetValue item_val;
item_val = std::move(item);
Expand All @@ -2504,6 +2539,10 @@ struct PackedFuncValueConverter<Map<T, U>> {
static Map<T, U> From(const TVMArgValue& val) {
auto untyped_map = val.AsObjectRef<Map<ObjectRef, ObjectRef>>();

if constexpr (std::is_same_v<T, ObjectRef> && std::is_same_v<U, ObjectRef>) {
return Downcast<Map<T, U>>(untyped_map);
}

if (ObjectTypeChecker<Map<T, U>>::Check(untyped_map.get())) {
// Early bail-out for common case where no type conversions are
// required.
Expand All @@ -2513,20 +2552,28 @@ struct PackedFuncValueConverter<Map<T, U>> {
Map<T, U> output;
for (const auto& kv : untyped_map) {
T new_key = [&]() {
TVMValue pod_value;
int type_code;
TVMArgsSetter setter(&pod_value, &type_code);
setter(0, kv.first);
TVMArgValue pod_arg(pod_value, type_code);
return PackedFuncValueConverter<T>::From(pod_arg);
if constexpr (std::is_same_v<T, ObjectRef>) {
return kv.first;
} else {
TVMValue pod_value;
int type_code;
TVMArgsSetter setter(&pod_value, &type_code);
setter(0, kv.first);
TVMArgValue pod_arg(pod_value, type_code);
return PackedFuncValueConverter<T>::From(pod_arg);
}
}();
U new_value = [&]() {
TVMValue pod_value;
int type_code;
TVMArgsSetter setter(&pod_value, &type_code);
setter(0, kv.second);
TVMArgValue key_arg(pod_value, type_code);
return PackedFuncValueConverter<U>::From(key_arg);
if constexpr (std::is_same_v<U, ObjectRef>) {
return kv.second;
} else {
TVMValue pod_value;
int type_code;
TVMArgsSetter setter(&pod_value, &type_code);
setter(0, kv.second);
TVMArgValue key_arg(pod_value, type_code);
return PackedFuncValueConverter<U>::From(key_arg);
}
}();
output.Set(new_key, new_value);
}
Expand All @@ -2535,6 +2582,10 @@ struct PackedFuncValueConverter<Map<T, U>> {
static Map<T, U> From(const TVMRetValue& val) {
auto untyped_map = val.AsObjectRef<Map<ObjectRef, ObjectRef>>();

if constexpr (std::is_same_v<T, ObjectRef> && std::is_same_v<U, ObjectRef>) {
return Downcast<Map<T, U>>(untyped_map);
}

if (ObjectTypeChecker<Map<T, U>>::Check(untyped_map.get())) {
// Early bail-out for common case where no type conversions are
// required.
Expand All @@ -2544,14 +2595,22 @@ struct PackedFuncValueConverter<Map<T, U>> {
Map<T, U> output;
for (const auto& kv : untyped_map) {
T new_key = [&]() {
TVMRetValue pod;
pod = kv.first;
return PackedFuncValueConverter<T>::From(pod);
if constexpr (std::is_same_v<T, ObjectRef>) {
return kv.first;
} else {
TVMRetValue pod;
pod = kv.first;
return PackedFuncValueConverter<T>::From(pod);
}
}();
U new_value = [&]() {
TVMRetValue pod;
pod = kv.second;
return PackedFuncValueConverter<U>::From(pod);
if constexpr (std::is_same_v<T, ObjectRef>) {
return kv.second;
} else {
TVMRetValue pod;
pod = kv.second;
return PackedFuncValueConverter<U>::From(pod);
}
}();
output.Set(new_key, new_value);
}
Expand Down

0 comments on commit c855809

Please sign in to comment.