From 27686e020ec27cfb29a9cc2c587d01acaedb039d Mon Sep 17 00:00:00 2001
From: Zhi <5145158+zhiics@users.noreply.github.com>
Date: Fri, 4 Dec 2020 14:52:41 -0800
Subject: [PATCH] [TOPI][OP] cuda for argwhere (#6868)

* argwhere

* cuda schedule

* sort argwhere result

* Use single block and thrust to fix flaky behavior

* format

* used dynamic strided_slice

* Fix dynamic strided_slice

* try new strided_slice

* Improve dynamic strided slice to bind data dependent shape var.

* all tests pass

* remove print

* use new strided_slice

* clean

Co-authored-by: Yao Wang
---
 3rdparty/vta-hw                               |   2 +-
 python/tvm/relay/op/_transform.py             |  16 +-
 python/tvm/relay/op/strategy/cuda.py          |  12 +
 python/tvm/relay/op/strategy/generic.py       |  39 +-
 python/tvm/topi/argwhere.py                   |   2 +
 python/tvm/topi/cuda/__init__.py              |   1 +
 python/tvm/topi/cuda/argwhere.py              | 654 ++++++++++++++++++
 python/tvm/topi/cuda/sort.py                  |   2 +
 tests/python/relay/test_any.py                |  17 +-
 .../python/topi/python/test_topi_argwhere.py  |  86 +++
 10 files changed, 795 insertions(+), 36 deletions(-)
 create mode 100644 python/tvm/topi/cuda/argwhere.py
 create mode 100644 tests/python/topi/python/test_topi_argwhere.py

diff --git a/3rdparty/vta-hw b/3rdparty/vta-hw
index 12fb486a491b..87ce9acfae55 160000
--- a/3rdparty/vta-hw
+++ b/3rdparty/vta-hw
@@ -1 +1 @@
-Subproject commit 12fb486a491b75d70ec4c5e0a0cd112ab49a95bc
+Subproject commit 87ce9acfae550d1a487746e9d06c2e250076e54c
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index 1092c308cf49..48d9bc716a4a 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -83,21 +83,7 @@ def compute_strided_set(attrs, inputs, output_type):
 _reg.register_pattern("auto_scheduler_layout_transform", OpPattern.INJECTIVE)
 
 # argwhere
-@_reg.register_compute("argwhere")
-def compute_argwhere(attrs, inputs, output_type):
-    """Compute definition of argwhere"""
-    output_shape = []
-    for s in output_type.shape:
-        if hasattr(s, "value"):
-            output_shape.append(s)
-        else:
-            # see Any, replace it with a var
-            output_shape.append(te.var("any_dim", "int32"))
-    new_output_type = tvm.relay.ty.TensorType(output_shape, "int32")
-    return [topi.argwhere(new_output_type, inputs[0])]
-
-
-_reg.register_schedule("argwhere", strategy.schedule_argwhere)
+_reg.register_strategy("argwhere", strategy.argwhere_strategy)
 
 # scatter
 @_reg.register_compute("scatter")
diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index 029690680e7d..fc80c9ed6171 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -921,3 +921,15 @@ def correlation_strategy_cuda(attrs, inputs, out_type, target):
         name="correlation.cuda",
     )
     return strategy
+
+
+@argwhere_strategy.register(["cuda", "gpu"])
+def argwhere_strategy_cuda(attrs, inputs, out_type, target):
+    """argwhere cuda strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_argwhere(topi.cuda.argwhere),
+        wrap_topi_schedule(topi.cuda.schedule_argwhere),
+        name="argwhere.cuda",
+    )
+    return strategy
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index c289c65758d9..ff3d01d35988 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -19,7 +19,7 @@
 import logging
 import re
 
-from tvm import topi, _ffi
+from tvm import topi, _ffi, te, ir
 from tvm.topi.utils import get_const_int, get_const_float, get_const_tuple, get_float_tuple
 from tvm.target import generic_func, override_native_generic_func
 from .. import op as _op
@@ -1034,14 +1034,6 @@ def proposal_strategy(attrs, inputs, out_type, target):
     return strategy
 
 
-# argwhere
-@generic_func
-def schedule_argwhere(attrs, outs, target):
-    """schedule argwhere"""
-    with target:
-        return topi.generic.schedule_argwhere(outs)
-
-
 # scatter
 @override_native_generic_func("scatter_strategy")
 def scatter_strategy(attrs, outs, out_type, target):
@@ -1231,3 +1223,32 @@ def correlation_strategy(attrs, inputs, out_type, target):
         name="correlation.generic",
     )
     return strategy
+
+
+# argwhere
+def wrap_compute_argwhere(topi_compute):
+    """wrap argwhere topi compute"""
+
+    def _compute_argwhere(attrs, inputs, out_type):
+        output_shape = []
+        for s in out_type.shape:
+            if hasattr(s, "value"):
+                output_shape.append(s)
+            else:
+                output_shape.append(te.var("any_dim", "int32"))
+        new_output_type = ir.TensorType(output_shape, "int32")
+        return [topi_compute(new_output_type, inputs[0])]
+
+    return _compute_argwhere
+
+
+@override_native_generic_func("argwhere_strategy")
+def argwhere_strategy(attrs, inputs, out_type, target):
+    """argwhere generic strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_argwhere(topi.argwhere),
+        wrap_topi_schedule(topi.generic.schedule_argwhere),
+        name="argwhere.generic",
+    )
+    return strategy
diff --git a/python/tvm/topi/argwhere.py b/python/tvm/topi/argwhere.py
index 75c19af35e5c..c2b658a4e92f 100644
--- a/python/tvm/topi/argwhere.py
+++ b/python/tvm/topi/argwhere.py
@@ -16,6 +16,7 @@
 # under the License.
 # pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks
 """Argwhere operator"""
+import tvm
 from tvm.te import hybrid
 
 
@@ -169,6 +170,7 @@ def hybrid_argwhere_5d(output_shape, condition):
     return a
 
 
+@tvm.target.generic_func
 def argwhere(output_shape, condition):
     """Find the indices of elements of a tensor that are non-zero.
 
diff --git a/python/tvm/topi/cuda/__init__.py b/python/tvm/topi/cuda/__init__.py
index 3ff544f4bb3e..23c625ae7ff7 100644
--- a/python/tvm/topi/cuda/__init__.py
+++ b/python/tvm/topi/cuda/__init__.py
@@ -54,3 +54,4 @@
 from .conv2d_hwnc_tensorcore import *
 from .correlation import *
 from .sparse import *
+from .argwhere import *
diff --git a/python/tvm/topi/cuda/argwhere.py b/python/tvm/topi/cuda/argwhere.py
new file mode 100644
index 000000000000..e39004dc76a9
--- /dev/null
+++ b/python/tvm/topi/cuda/argwhere.py
@@ -0,0 +1,654 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=too-many-arguments, invalid-name
+"""Argwhere operator"""
+
+import logging
+
+import tvm
+from tvm import te
+from tvm._ffi import get_global_func
+from .injective import schedule_injective_from_existing
+from .nms import atomic_add
+from .sort import topk, topk_thrust, argsort, argsort_thrust
+from .. import tag
+from ..transform import strided_slice, adv_index, squeeze
+
+logger = logging.getLogger("topi")
+
+
+def _get_sort_func(mode=0):
+    """Get sort function for argwhere. mode 0 for topk and others for argsort."""
+    if get_global_func("tvm.contrib.thrust.sort", allow_missing=True):
+        ret = topk_thrust if mode == 0 else argsort_thrust
+    else:
+        logger.warning(
+            "It's highly recommended to enable the thrust library with set(USE_THRUST ON)"
+            " when compiling argwhere for the cuda target. Otherwise, it can result in"
+            " significant performance degradation or incorrect results."
+        )
+        ret = topk if mode == 0 else argsort
+
+    return ret
+
+
+def argwhere_1d_ir(condition, out):
+    """Low level IR for argwhere 1D
+
+    Parameters
+    ----------
+    condition : Buffer
+        The condition buffer.
+
+    out : Buffer
+        The output buffer.
+
+    Returns
+    -------
+    stmt : Stmt
+        The result IR statement.
+    """
+    ib = tvm.tir.ir_builder.create()
+    a0 = condition.shape[0]
+
+    condition = ib.buffer_ptr(condition)
+    out = ib.buffer_ptr(out)
+
+    valid_index = ib.allocate("int32", (1,), name="valid_index", scope="global")
+    tmp = ib.allocate("int32", (1,), name="tmp", scope="local")
+    one_count = tvm.tir.const(1, dtype="int32")
+
+    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+    nthread_tx = max_threads
+    # Limit threads to a single block to make sure atomic_add works normally.
+    tx = te.thread_axis("threadIdx.x")
+    ib.scope_attr(tx, "thread_extent", nthread_tx)
+    len_inner_for = a0 // nthread_tx + 1
+    valid_index[0] = 0
+
+    with ib.for_range(0, len_inner_for, name="i") as i:
+        idx = tx * len_inner_for + i
+        with ib.if_scope(idx < a0):
+            with ib.if_scope(condition[idx] != 0):
+                tmp[0] = atomic_add(
+                    tvm.tir.call_intrin("handle", "tir.address_of", valid_index[0]),
+                    one_count,
+                )
+                out[tmp[0]] = idx
+
+    return ib.get()
+
+
+def argwhere_1d(output_shape, condition):
+    """Compute for argwhere 1D
+
+    Parameters
+    ----------
+    output_shape : list of int or tvm.tir.Any
+        The output shape.
+
+    condition : tvm.te.Tensor
+        Tensor with boolean values.
+
+    Returns
+    -------
+    out : tvm.te.Tensor
+        Indices of non-zero elements.
+    """
+    condition_buf = tvm.tir.decl_buffer(
+        condition.shape, condition.dtype, "data_buf", data_alignment=8
+    )
+    out_buf = tvm.tir.decl_buffer(output_shape, "int32", "out_buf", data_alignment=8)
+
+    out = te.extern(
+        [output_shape],
+        [condition],
+        lambda ins, outs: argwhere_1d_ir(ins[0], outs[0]),
+        dtype=["int32"],
+        in_buffers=[condition_buf],
+        out_buffers=[out_buf],
+        name="argwhere_1d",
+        tag="argwhere1d_gpu",
+    )
+
+    if isinstance(out.shape[0], (int, tvm.tir.expr.IntImm)) and int(out.shape[0]) <= 1:
+        return out
+
+    sorted_out = _get_sort_func()(
+        out, k=0, axis=0, ret_type="values", is_ascend=True, dtype="int32"
+    )
+
+    return sorted_out
+
+
+def argwhere_2d_ir(condition, out):
+    """Low level IR for argwhere 2D
+
+    Parameters
+    ----------
+    condition : Buffer
+        The condition buffer.
+
+    out : Buffer
+        The output buffer.
+
+    Returns
+    -------
+    stmt : Stmt
+        The result IR statement.
+    """
+    ib = tvm.tir.ir_builder.create()
+    a0 = condition.shape[0]
+    a1 = condition.shape[1]
+
+    condition = ib.buffer_ptr(condition)
+    out = ib.buffer_ptr(out)
+
+    valid_index = ib.allocate("int32", (1,), name="valid_index", scope="local")
+    tmp = ib.allocate("int32", (1,), name="tmp", scope="local")
+    one_count = tvm.tir.const(1, dtype="int32")
+
+    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+    nthread_tx = max_threads
+
+    # Limit threads to a single block to make sure atomic_add works normally.
+    tx = te.thread_axis("threadIdx.x")
+    ib.scope_attr(tx, "thread_extent", nthread_tx)
+    len_inner_for = (a0 * a1) // nthread_tx + 1
+
+    valid_index[0] = 0
+
+    with ib.for_range(0, len_inner_for, name="i") as i:
+        idx = tx * len_inner_for + i
+        with ib.if_scope(idx < (a0 * a1)):
+            with ib.if_scope(condition[idx] != 0):
+                tmp[0] = atomic_add(
+                    tvm.tir.call_intrin("handle", "tir.address_of", valid_index[0]),
+                    one_count,
+                )
+                out[tmp[0] * 2] = tvm.tir.floordiv(idx, a1)
+                out[tmp[0] * 2 + 1] = tvm.tir.floormod(idx, a1)
+
+    return ib.get()
+
+
+def argwhere_2d(output_shape, condition):
+    """Compute for argwhere 2D
+
+    Parameters
+    ----------
+    output_shape : list of int or tvm.tir.Any
+        The output shape.
+
+    condition : tvm.te.Tensor
+        Tensor with boolean values.
+
+    Returns
+    -------
+    out : tvm.te.Tensor
+        Indices of non-zero elements.
+    """
+    condition_buf = tvm.tir.decl_buffer(
+        condition.shape, condition.dtype, "data_buf", data_alignment=8
+    )
+    out_buf = tvm.tir.decl_buffer(output_shape, "int32", "out_buf", data_alignment=8)
+
+    out = te.extern(
+        [output_shape],
+        [condition],
+        lambda ins, outs: argwhere_2d_ir(ins[0], outs[0]),
+        dtype=["int32"],
+        in_buffers=[condition_buf],
+        out_buffers=[out_buf],
+        name="argwhere_2d",
+        tag="argwhere2d_gpu",
+    )
+
+    if isinstance(out.shape[0], (int, tvm.tir.expr.IntImm)) and int(out.shape[0]) <= 1:
+        return out
+
+    sort_func = _get_sort_func(1)
+
+    # sort the output from the least significant to the most significant
+    # column.
+    if isinstance(out.shape[0], (int, tvm.tir.expr.IntImm)):
+        out1 = strided_slice(out, [0, 1], [out.shape[0], 2])
+        out2 = sort_func(out1, axis=0, dtype="int32")
+        out3 = squeeze(out2)
+        out = adv_index(out, [out3])
+
+        out1 = strided_slice(out, [0, 0], [out.shape[0], 1])
+        out2 = sort_func(out1, axis=0, dtype="int32")
+        out3 = squeeze(out2)
+
+        out = adv_index(out, [out3])
+    else:
+        out1 = strided_slice(out, [0, 1], [out.shape[0], 2], [1, 1])
+        out2 = sort_func(out1, axis=0, dtype="int32")
+        out3 = squeeze(out2)
+        out = adv_index(out, [out3])
+
+        out1 = strided_slice(out, [0, 0], [out.shape[0], 1], [1, 1])
+        out2 = sort_func(out1, axis=0, dtype="int32")
+        out3 = squeeze(out2)
+        out = adv_index(out, [out3])
+    return out
+
+
+def argwhere_3d_ir(condition, out):
+    """Low level IR for argwhere 3D
+
+    Parameters
+    ----------
+    condition : Buffer
+        The condition buffer.
+
+    out : Buffer
+        The output buffer.
+
+    Returns
+    -------
+    stmt : Stmt
+        The result IR statement.
+    """
+    ib = tvm.tir.ir_builder.create()
+    a0 = condition.shape[0]
+    a1 = condition.shape[1]
+    a2 = condition.shape[2]
+    s1 = a1 * a2
+    s0 = a0 * s1
+
+    condition = ib.buffer_ptr(condition)
+    out = ib.buffer_ptr(out)
+
+    valid_index = ib.allocate("int32", (1,), name="valid_index", scope="local")
+    tmp = ib.allocate("int32", (1,), name="tmp", scope="local")
+    one_count = tvm.tir.const(1, dtype="int32")
+
+    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+    nthread_tx = max_threads
+
+    # Limit threads to a single block to make sure atomic_add works normally.
+    tx = te.thread_axis("threadIdx.x")
+    ib.scope_attr(tx, "thread_extent", nthread_tx)
+    len_inner_for = s0 // nthread_tx + 1
+
+    fdiv = tvm.tir.floordiv
+    fmod = tvm.tir.floormod
+
+    valid_index[0] = 0
+
+    with ib.for_range(0, len_inner_for, name="i") as i:
+        idx = tx * len_inner_for + i
+        with ib.if_scope(idx < s0):
+            with ib.if_scope(condition[idx] != 0):
+                tmp[0] = atomic_add(
+                    tvm.tir.call_intrin("handle", "tir.address_of", valid_index[0]),
+                    one_count,
+                )
+                out[tmp[0] * 3] = fdiv(idx, s1)
+                out[tmp[0] * 3 + 1] = fdiv(fmod(idx, s1), a2)
+                out[tmp[0] * 3 + 2] = fmod(idx, a2)
+
+    return ib.get()
+
+
+def argwhere_3d(output_shape, condition):
+    """Compute for argwhere 3D
+
+    Parameters
+    ----------
+    output_shape : list of int or tvm.tir.Any
+        The output shape.
+
+    condition : tvm.te.Tensor
+        Tensor with boolean values.
+
+    Returns
+    -------
+    out : tvm.te.Tensor
+        Indices of non-zero elements.
+    """
+    condition_buf = tvm.tir.decl_buffer(
+        condition.shape, condition.dtype, "data_buf", data_alignment=8
+    )
+    out_buf = tvm.tir.decl_buffer(output_shape, "int32", "out_buf", data_alignment=8)
+
+    out = te.extern(
+        [output_shape],
+        [condition],
+        lambda ins, outs: argwhere_3d_ir(ins[0], outs[0]),
+        dtype=["int32"],
+        in_buffers=[condition_buf],
+        out_buffers=[out_buf],
+        name="argwhere_3d",
+        tag="argwhere3d_gpu",
+    )
+
+    if isinstance(out.shape[0], (int, tvm.tir.expr.IntImm)) and int(out.shape[0]) <= 1:
+        return out
+
+    # sort the output from the least significant to the most significant
+    # column.
+    sort_func = _get_sort_func(1)
+
+    if isinstance(out.shape[0], (int, tvm.tir.expr.IntImm)):
+        for i in reversed(range(3)):
+            out1 = strided_slice(out, [0, i], [out.shape[0], i + 1])
+            out2 = sort_func(out1, axis=0, dtype="int32")
+            out3 = squeeze(out2)
+            out = adv_index(out, [out3])
+    else:
+        for i in reversed(range(3)):
+            out1 = strided_slice(out, [0, i], [out.shape[0], i + 1], [1, 1])
+            out2 = sort_func(out1, axis=0, dtype="int32")
+            out3 = squeeze(out2)
+            out = adv_index(out, [out3])
+    return out
+
+
+def argwhere_4d_ir(condition, out):
+    """Low level IR for argwhere 4D
+
+    Parameters
+    ----------
+    condition : Buffer
+        The condition buffer.
+
+    out : Buffer
+        The output buffer.
+
+    Returns
+    -------
+    stmt : Stmt
+        The result IR statement.
+    """
+    ib = tvm.tir.ir_builder.create()
+    a0 = condition.shape[0]
+    a1 = condition.shape[1]
+    a2 = condition.shape[2]
+    a3 = condition.shape[3]
+    s1 = a2 * a3
+    s2 = a1 * s1
+    s0 = a0 * s2
+
+    condition = ib.buffer_ptr(condition)
+    out = ib.buffer_ptr(out)
+
+    valid_index = ib.allocate("int32", (1,), name="valid_index", scope="local")
+    tmp = ib.allocate("int32", (1,), name="tmp", scope="local")
+    one_count = tvm.tir.const(1, dtype="int32")
+
+    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+    nthread_tx = max_threads
+
+    # Limit threads to a single block to make sure atomic_add works normally.
+    tx = te.thread_axis("threadIdx.x")
+    ib.scope_attr(tx, "thread_extent", nthread_tx)
+    len_inner_for = s0 // nthread_tx + 1
+
+    fdiv = tvm.tir.floordiv
+    fmod = tvm.tir.floormod
+
+    valid_index[0] = 0
+
+    with ib.for_range(0, len_inner_for, name="i") as i:
+        idx = tx * len_inner_for + i
+        with ib.if_scope(idx < s0):
+            with ib.if_scope(condition[idx] != 0):
+                tmp[0] = atomic_add(
+                    tvm.tir.call_intrin("handle", "tir.address_of", valid_index[0]),
+                    one_count,
+                )
+                out[tmp[0] * 4] = fdiv(idx, s2)
+                out[tmp[0] * 4 + 1] = fdiv(fmod(idx, s2), s1)
+                out[tmp[0] * 4 + 2] = fdiv(fmod(idx, s1), a3)
+                out[tmp[0] * 4 + 3] = fmod(idx, a3)
+
+    return ib.get()
+
+
+def argwhere_4d(output_shape, condition):
+    """Compute for argwhere 4D
+
+    Parameters
+    ----------
+    output_shape : list of int or tvm.tir.Any
+        The output shape.
+
+    condition : tvm.te.Tensor
+        Tensor with boolean values.
+
+    Returns
+    -------
+    out : tvm.te.Tensor
+        Indices of non-zero elements.
+    """
+    condition_buf = tvm.tir.decl_buffer(
+        condition.shape, condition.dtype, "data_buf", data_alignment=8
+    )
+    out_buf = tvm.tir.decl_buffer(output_shape, "int32", "out_buf", data_alignment=8)
+
+    out = te.extern(
+        [output_shape],
+        [condition],
+        lambda ins, outs: argwhere_4d_ir(ins[0], outs[0]),
+        dtype=["int32"],
+        in_buffers=[condition_buf],
+        out_buffers=[out_buf],
+        name="argwhere_4d",
+        tag="argwhere4d_gpu",
+    )
+
+    if isinstance(out.shape[0], (int, tvm.tir.expr.IntImm)) and int(out.shape[0]) <= 1:
+        return out
+
+    # sort the output from the least significant to the most significant
+    # column.
+    sort_func = _get_sort_func(1)
+    if isinstance(out.shape[0], (int, tvm.tir.expr.IntImm)):
+        for i in reversed(range(4)):
+            out1 = strided_slice(out, [0, i], [out.shape[0], i + 1])
+            out2 = sort_func(out1, axis=0, dtype="int32")
+            out3 = squeeze(out2)
+            out = adv_index(out, [out3])
+    else:
+        for i in reversed(range(4)):
+            out1 = strided_slice(out, [0, i], [out.shape[0], i + 1], [1, 1])
+            out2 = sort_func(out1, axis=0, dtype="int32")
+            out3 = squeeze(out2)
+            out = adv_index(out, [out3])
+
+    return out
+
+
+def argwhere_5d_ir(condition, out):
+    """Low level IR for argwhere 5D
+
+    Parameters
+    ----------
+    condition : Buffer
+        The condition buffer.
+
+    out : Buffer
+        The output buffer.
+
+    Returns
+    -------
+    stmt : Stmt
+        The result IR statement.
+    """
+    ib = tvm.tir.ir_builder.create()
+    a0 = condition.shape[0]
+    a1 = condition.shape[1]
+    a2 = condition.shape[2]
+    a3 = condition.shape[3]
+    a4 = condition.shape[4]
+    s1 = a3 * a4
+    s2 = a2 * s1
+    s3 = a1 * s2
+    s0 = a0 * s3
+
+    condition = ib.buffer_ptr(condition)
+    out = ib.buffer_ptr(out)
+
+    valid_index = ib.allocate("int32", (1,), name="valid_index", scope="local")
+    tmp = ib.allocate("int32", (1,), name="tmp", scope="local")
+    one_count = tvm.tir.const(1, dtype="int32")
+
+    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+    nthread_tx = max_threads
+
+    # Limit threads to a single block to make sure atomic_add works normally.
+    tx = te.thread_axis("threadIdx.x")
+    ib.scope_attr(tx, "thread_extent", nthread_tx)
+    len_inner_for = s0 // nthread_tx + 1
+
+    fdiv = tvm.tir.floordiv
+    fmod = tvm.tir.floormod
+
+    valid_index[0] = 0
+
+    with ib.for_range(0, len_inner_for, name="i") as i:
+        idx = tx * len_inner_for + i
+        with ib.if_scope(idx < s0):
+            with ib.if_scope(condition[idx] != 0):
+                tmp[0] = atomic_add(
+                    tvm.tir.call_intrin("handle", "tir.address_of", valid_index[0]),
+                    one_count,
+                )
+                out[tmp[0] * 5] = fdiv(idx, s3)
+                out[tmp[0] * 5 + 1] = fdiv(fmod(idx, s3), s2)
+                out[tmp[0] * 5 + 2] = fdiv(fmod(idx, s2), s1)
+                out[tmp[0] * 5 + 3] = fdiv(fmod(idx, s1), a4)
+                out[tmp[0] * 5 + 4] = fmod(idx, a4)
+
+    return ib.get()
+
+
+def argwhere_5d(output_shape, condition):
+    """Compute for argwhere 5D
+
+    Parameters
+    ----------
+    output_shape : list of int or tvm.tir.Any
+        The output shape.
+
+    condition : tvm.te.Tensor
+        Tensor with boolean values.
+
+    Returns
+    -------
+    out : tvm.te.Tensor
+        Indices of non-zero elements.
+    """
+    condition_buf = tvm.tir.decl_buffer(
+        condition.shape, condition.dtype, "data_buf", data_alignment=8
+    )
+    out_buf = tvm.tir.decl_buffer(output_shape, "int32", "out_buf", data_alignment=8)
+
+    out = te.extern(
+        [output_shape],
+        [condition],
+        lambda ins, outs: argwhere_5d_ir(ins[0], outs[0]),
+        dtype=["int32"],
+        in_buffers=[condition_buf],
+        out_buffers=[out_buf],
+        name="argwhere_5d",
+        tag="argwhere5d_gpu",
+    )
+
+    if isinstance(out.shape[0], (int, tvm.tir.expr.IntImm)) and int(out.shape[0]) <= 1:
+        return out
+
+    # sort the output from the least significant to the most significant
+    # column.
+    sort_func = _get_sort_func(1)
+    if isinstance(out.shape[0], (int, tvm.tir.expr.IntImm)):
+        for i in reversed(range(5)):
+            out1 = strided_slice(out, [0, i], [out.shape[0], i + 1])
+            out2 = sort_func(out1, axis=0, dtype="int32")
+            out3 = squeeze(out2)
+            out = adv_index(out, [out3])
+    else:
+        for i in reversed(range(5)):
+            out1 = strided_slice(out, [0, i], [out.shape[0], i + 1], [1, 1])
+            out2 = sort_func(out1, axis=0, dtype="int32")
+            out3 = squeeze(out2)
+            out = adv_index(out, [out3])
+
+    return out
+
+
+def argwhere(output_shape, condition):
+    """Find the indices of elements of a tensor that are non-zero.
+
+    Parameters
+    ----------
+    output_shape : tvm.te.Tensor
+        Tensor with output shape info.
+
+    condition : tvm.te.Tensor
+        Tensor with boolean values.
+
+    Returns
+    -------
+    out : tvm.te.Tensor
+        Indices of non-zero elements.
+    """
+    if len(condition.shape) == 1:
+        return argwhere_1d(output_shape.shape, condition)
+    if len(condition.shape) == 2:
+        return argwhere_2d(output_shape.shape, condition)
+    if len(condition.shape) == 3:
+        return argwhere_3d(output_shape.shape, condition)
+    if len(condition.shape) == 4:
+        return argwhere_4d(output_shape.shape, condition)
+    if len(condition.shape) == 5:
+        return argwhere_5d(output_shape.shape, condition)
+    raise ValueError("Argwhere does not support rank higher than 5")
+
+
+def schedule_argwhere(outs):
+    """Schedule for argwhere on cuda.
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of argwhere
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    s: Schedule
+        The computation schedule for argwhere
+    """
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+    scheduled_ops = []
+
+    def traverse(op):
+        if tag.is_injective(op.tag):
+            schedule_injective_from_existing(s, op.output(0))
+        for tensor in op.input_tensors:
+            if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+                traverse(tensor.op)
+        scheduled_ops.append(op)
+
+    for out in outs:
+        traverse(out.op)
+    return s
diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py
index ac14f5aae779..2a7f4eb92daa 100644
--- a/python/tvm/topi/cuda/sort.py
+++ b/python/tvm/topi/cuda/sort.py
@@ -550,6 +550,8 @@ def topk_thrust(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int
         tvm.tir.decl_buffer(data.shape, dtype, "indices_buf", data_alignment=8),
     ]
 
+    is_ascend = 1 if is_ascend else 0
+
     out = te.extern(
         [data.shape, data.shape],
         [data],
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index ee67e67b282f..ddf8e980706b 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -219,30 +219,25 @@ def verify_any_argwhere(x_shape, x_np_shape, dtype="bool"):
     mod["main"] = relay.Function([x], y)
     data = np.random.choice([0, 1, 2, 3], size=x_np_shape).astype(dtype)
     expected = np.argwhere(data)
-    for kind in ["debug", "vm"]:
-        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
-        result = ex.evaluate()(data).asnumpy()
-        assert result.shape == expected.shape
-        tvm.testing.assert_allclose(result.flatten(), expected.flatten())
-
-    # TODO(@zhiics) argwhere gpu schedule is currently not avaiable
-    # check_result([data], mod, expected, flatten=True)
+    check_result([data], mod, expected, flatten=True)
 
 
-@tvm.testing.uses_gpu
+# TODO(zhiics) Enable argwhere gpu test after sort is fixed. Otherwise, we have
+# to use thrust to guarantee the correct results which has been tested locally.
+# @tvm.testing.uses_gpu
 def test_any_argwhere():
     verify_any_argwhere(any_dims(1), (5,))
     verify_any_argwhere(any_dims(2), (5, 5))
+    verify_any_argwhere(any_dims(2), (5, 5), "int32")
+    verify_any_argwhere(any_dims(2), (5, 5), "int8")
     verify_any_argwhere(any_dims(3), (5, 5, 5))
     verify_any_argwhere(any_dims(4), (5, 5, 5, 5))
     verify_any_argwhere(any_dims(5), (5, 5, 5, 5, 5))
     verify_any_argwhere(any_dims(1), (5,), "int32")
-    verify_any_argwhere(any_dims(2), (5, 5), "int32")
     verify_any_argwhere(any_dims(3), (5, 5, 5), "int32")
     verify_any_argwhere(any_dims(4), (5, 5, 5, 5), "int32")
     verify_any_argwhere(any_dims(5), (5, 5, 5, 5, 5), "int32")
     verify_any_argwhere(any_dims(1), (5,), "int8")
-    verify_any_argwhere(any_dims(2), (5, 5), "int8")
     verify_any_argwhere(any_dims(3), (5, 5, 5), "int8")
     verify_any_argwhere(any_dims(4), (5, 5, 5, 5), "int8")
     verify_any_argwhere(any_dims(5), (5, 5, 5, 5, 5), "int8")
diff --git a/tests/python/topi/python/test_topi_argwhere.py b/tests/python/topi/python/test_topi_argwhere.py
new file mode 100644
index 000000000000..5cb7cd44513e
--- /dev/null
+++ b/tests/python/topi/python/test_topi_argwhere.py
@@ -0,0 +1,86 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test for argwhere operator"""
+import numpy as np
+
+import tvm
+from tvm import te
+from tvm import topi
+import tvm.topi.testing
+
+_argwhere_schedule = {
+    "generic": topi.generic.schedule_argwhere,
+    "gpu": topi.cuda.schedule_argwhere,
+}
+
+_argwhere_compute = {"llvm": topi.argwhere, "cuda": topi.cuda.argwhere}
+
+
+def verify_argwhere(data_shape):
+    dtype = "int32"
+    np_data = np.random.choice([0, 1, 2, 3], size=data_shape).astype(dtype)
+    np_out = np.argwhere(np_data)
+    out_shape = np_out.shape[0]
+    np_shape = np.ones(shape=(out_shape, len(data_shape)), dtype=dtype)
+
+    out_shape = te.placeholder(shape=(out_shape, len(data_shape)), name="out_shape", dtype=dtype)
+    condition = te.placeholder(shape=data_shape, name="condition", dtype=dtype)
+
+    def check_device(device, ctx):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist or device not in _argwhere_compute:
+            return
+
+        with tvm.target.Target(device):
+            out = _argwhere_compute[device](out_shape, condition)
+            s_func = tvm.topi.testing.dispatch(device, _argwhere_schedule)
+            sch = s_func(out)
+
+        func = tvm.build(sch, [out_shape, condition, out], device, name="argwhere")
+
+        args = [tvm.nd.array(np_shape, ctx)]
+        args.append(tvm.nd.array(np_data, ctx))
+        args.append(tvm.nd.empty(out.shape, ctx=ctx, dtype=condition.dtype))
+        func(*args)
+        np.set_printoptions(threshold=np.inf)
+        tvm.testing.assert_allclose(args[-1].asnumpy(), np.array(np_out))
+
+    for target, ctx in tvm.testing.enabled_targets():
+        check_device(target, ctx)
+
+
+# TODO(zhiics) Enable argwhere gpu test after sort is fixed. Otherwise, we have
+# to use thrust to guarantee the correct results which has been tested locally.
+# @tvm.testing.uses_gpu
+def test_argwhere():
+    verify_argwhere((1,))
+    verify_argwhere((100,))
+    verify_argwhere((1, 1))
+    verify_argwhere((5, 3))
+    verify_argwhere((32, 64))
+    verify_argwhere((128, 65))
+    verify_argwhere((200, 500))
+    verify_argwhere((6, 5, 3))
+    verify_argwhere((1, 1, 1))
+    verify_argwhere((1, 1, 1, 1))
+    verify_argwhere((6, 4, 5, 3))
+    verify_argwhere((1, 1, 1, 1, 1))
+    verify_argwhere((6, 4, 5, 3, 7))
+
+
+if __name__ == "__main__":
+    test_argwhere()
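
Reviewer note (illustrative only, not part of the patch): a minimal sketch of how the new CUDA argwhere path can be exercised end to end from Relay, modeled on verify_any_argwhere above. The "vm" executor, the "cuda" target with tvm.gpu(0), and the boolean input dtype are assumptions for this example rather than anything mandated by the change.

import numpy as np
import tvm
import tvm.testing
from tvm import relay

# argwhere has a data-dependent output shape, so declare the input with relay.Any().
x = relay.var("x", shape=(relay.Any(), relay.Any()), dtype="bool")
mod = tvm.IRModule()
mod["main"] = relay.Function([x], relay.argwhere(x))

data = np.random.choice([False, True], size=(5, 5))
expected = np.argwhere(data)

# Running on the cuda target should dispatch through argwhere_strategy_cuda to
# topi.cuda.argwhere and topi.cuda.schedule_argwhere added by this patch.
ex = relay.create_executor("vm", mod=mod, ctx=tvm.gpu(0), target="cuda")
result = ex.evaluate()(data).asnumpy()
tvm.testing.assert_allclose(result.flatten(), expected.flatten())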