From 36aa2cc92be6d8769137064b8d13a86941e6d404 Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Tue, 20 Oct 2020 22:37:13 +0000 Subject: [PATCH] Only use thrust for cuda target --- python/tvm/relay/op/strategy/cuda.py | 8 ++++++-- python/tvm/topi/cuda/nms.py | 7 ++++++- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 7031365251aa..7346b6fe37e7 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -655,7 +655,9 @@ def argsort_strategy_cuda(attrs, inputs, out_type, target): wrap_topi_schedule(topi.cuda.schedule_argsort), name="argsort.cuda", ) - if get_global_func("tvm.contrib.thrust.sort", allow_missing=True): + if target.kind.name == "cuda" and get_global_func( + "tvm.contrib.thrust.sort", allow_missing=True + ): strategy.add_implementation( wrap_compute_argsort(topi.cuda.argsort_thrust), wrap_topi_schedule(topi.cuda.schedule_argsort), @@ -674,7 +676,9 @@ def topk_strategy_cuda(attrs, inputs, out_type, target): wrap_topi_schedule(topi.cuda.schedule_topk), name="topk.cuda", ) - if get_global_func("tvm.contrib.thrust.sort", allow_missing=True): + if target.kind.name == "cuda" and get_global_func( + "tvm.contrib.thrust.sort", allow_missing=True + ): strategy.add_implementation( wrap_compute_topk(topi.cuda.topk_thrust), wrap_topi_schedule(topi.cuda.schedule_topk), diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 2041f4c232a2..ed6e8f086a0d 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -483,7 +483,12 @@ def non_max_suppression( score_axis = score_index score_shape = (batch_size, num_anchors) score_tensor = te.compute(score_shape, lambda i, j: data[i, j, score_axis], tag=tag.ELEMWISE) - if tvm.get_global_func("tvm.contrib.thrust.sort_nms", allow_missing=True): + target = tvm.target.Target.current() + if ( + target + and target.kind.name == "cuda" + and tvm.get_global_func("tvm.contrib.thrust.sort_nms", allow_missing=True) + ): sort_tensor = argsort_thrust( score_tensor, valid_count=None, axis=1, is_ascend=False, dtype=valid_count_dtype )