diff --git a/python/tvm/autotvm/tophub.py b/python/tvm/autotvm/tophub.py index f705d591e6ee..99dd312d870a 100644 --- a/python/tvm/autotvm/tophub.py +++ b/python/tvm/autotvm/tophub.py @@ -106,10 +106,20 @@ def context(target, extra_files=None): if isinstance(tgt, str): tgt = Target(tgt) + # The TOPHUB file names rely on Target's device or kind. Both these types of + # information exist in Target.keys, but rules of filling this filed is not explicitly + # defined, we are afraid to rely only on Target.keys. At the same time Target.device + # is filled only if device was pointed explicitly in target string, that is not mandatory + # and in some cases we need to get information about device from Target.keys + # In priority order we verify: + # 1) Target.device + # 2) Target.keys + # 3) Target.kind possible_names = [] device = tgt.attrs.get("device", "") if device != "": possible_names.append(_alias(device)) + possible_names.extend(tgt.keys) possible_names.append(tgt.kind.name) all_packages = list(PACKAGE_VERSION.keys()) diff --git a/tests/python/unittest/test_autotvm_dispatch_context.py b/tests/python/unittest/test_autotvm_dispatch_context.py index 6ca062047fd7..ba75992128a8 100644 --- a/tests/python/unittest/test_autotvm_dispatch_context.py +++ b/tests/python/unittest/test_autotvm_dispatch_context.py @@ -19,6 +19,7 @@ to the parameters of workload""" from tvm import autotvm +import tvm @autotvm.template("testing/dispatch_fallback") @@ -31,5 +32,20 @@ def test_fallback(): simple_template(2, 3) +def test_tophub_kinds_match(): + def verify_arm_cpu(target): + best_by_targetkey = autotvm.tophub.context(target).best_by_targetkey + assert len(best_by_targetkey) + found_arm_cpu = False + for a, _ in best_by_targetkey: + if "arm_cpu" in a: + found_arm_cpu = True + break + assert found_arm_cpu + + verify_arm_cpu("llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon,+v8.2a,+dotprod") + verify_arm_cpu("llvm -model=snapdragon835 -mtriple=arm64-linux-android -mattr=+neon") + + if __name__ == "__main__": test_fallback()