Skip to content

Commit

Permalink
Only use thrust for cuda target
Browse files Browse the repository at this point in the history
  • Loading branch information
Trevor Morris committed Oct 23, 2020
1 parent 839c06f commit 36aa2cc
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 3 deletions.
8 changes: 6 additions & 2 deletions python/tvm/relay/op/strategy/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,7 +655,9 @@ def argsort_strategy_cuda(attrs, inputs, out_type, target):
wrap_topi_schedule(topi.cuda.schedule_argsort),
name="argsort.cuda",
)
if get_global_func("tvm.contrib.thrust.sort", allow_missing=True):
if target.kind.name == "cuda" and get_global_func(
"tvm.contrib.thrust.sort", allow_missing=True
):
strategy.add_implementation(
wrap_compute_argsort(topi.cuda.argsort_thrust),
wrap_topi_schedule(topi.cuda.schedule_argsort),
Expand All @@ -674,7 +676,9 @@ def topk_strategy_cuda(attrs, inputs, out_type, target):
wrap_topi_schedule(topi.cuda.schedule_topk),
name="topk.cuda",
)
if get_global_func("tvm.contrib.thrust.sort", allow_missing=True):
if target.kind.name == "cuda" and get_global_func(
"tvm.contrib.thrust.sort", allow_missing=True
):
strategy.add_implementation(
wrap_compute_topk(topi.cuda.topk_thrust),
wrap_topi_schedule(topi.cuda.schedule_topk),
Expand Down
7 changes: 6 additions & 1 deletion python/tvm/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,12 @@ def non_max_suppression(
score_axis = score_index
score_shape = (batch_size, num_anchors)
score_tensor = te.compute(score_shape, lambda i, j: data[i, j, score_axis], tag=tag.ELEMWISE)
if tvm.get_global_func("tvm.contrib.thrust.sort_nms", allow_missing=True):
target = tvm.target.Target.current()
if (
target
and target.kind.name == "cuda"
and tvm.get_global_func("tvm.contrib.thrust.sort_nms", allow_missing=True)
):
sort_tensor = argsort_thrust(
score_tensor, valid_count=None, axis=1, is_ascend=False, dtype=valid_count_dtype
)
Expand Down

0 comments on commit 36aa2cc

Please sign in to comment.