From 97354cd781fb2da24beecbaec538e1c748622f9a Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Tue, 10 Jan 2023 16:00:23 -0800 Subject: [PATCH] [Fix,TOPI] Consolidate generic and x86 scatter nd The generic scatter nd was almost identical to the x86 one and was not tested. They now are one and the same. --- python/tvm/relay/op/strategy/x86.py | 2 +- python/tvm/topi/scatter.py | 53 ++++---- python/tvm/topi/x86/__init__.py | 1 - python/tvm/topi/x86/scatter.py | 119 ------------------ tests/python/topi/python/test_topi_scatter.py | 4 - 5 files changed, 25 insertions(+), 154 deletions(-) delete mode 100644 python/tvm/topi/x86/scatter.py diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index 4585809f63e1..f76a2daa35a1 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -772,7 +772,7 @@ def scatter_nd_strategy_cpu(attrs, inputs, out_type, target): """scatter_nd x86 strategy""" strategy = _op.OpStrategy() strategy.add_implementation( - wrap_compute_scatter_nd(topi.x86.scatter_nd), + wrap_compute_scatter_nd(topi.scatter_nd), wrap_topi_schedule(topi.generic.schedule_extern), name="scatter_nd.x86", plevel=10, diff --git a/python/tvm/topi/scatter.py b/python/tvm/topi/scatter.py index afb0d6633a2b..480ba830d119 100644 --- a/python/tvm/topi/scatter.py +++ b/python/tvm/topi/scatter.py @@ -268,6 +268,7 @@ def scatter_nd(data, indices, updates, mode): _verify_scatter_nd_inputs(data, indices, updates) def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): + # pylint: disable=invalid-name ib = ir_builder.create() data = ib.buffer_ptr(data_ptr) @@ -275,56 +276,50 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): updates = ib.buffer_ptr(updates_ptr) out = ib.buffer_ptr(out_ptr) - fused_shape = 1 - for i in data.shape: - fused_shape *= i - with ib.for_range(0, fused_shape) as i: - out[i] = data[i] - # We combine all the indices dimensions but the first one into a single # dimension so we can iterate it in single loop instead of an arbitrary - # number of loops. We do the same thing for all the data dimensions. + # number of loops. We do the same thing for all the update dimensions. fused_indices_dimension = 1 for i in indices_ptr.shape[1:]: fused_indices_dimension *= i - fused_data_dimension = 1 - for i in data_ptr.shape[len(indices_ptr.shape) - 1 :]: - fused_data_dimension *= i + fused_updates_dimension = 1 + for i in updates_ptr.shape[len(indices_ptr.shape) - 1 :]: + fused_updates_dimension *= i + + fused_shape = 1 + for i in data_ptr.shape: + fused_shape *= i + + with ib.for_range(0, fused_shape) as i: + out[i] = data[i] - with ib.for_range(0, fused_indices_dimension, name="i") as i: - with ib.for_range(0, fused_data_dimension, name="j") as j: - offset = fused_data_dimension + with ib.for_range(0, fused_indices_dimension) as i: + with ib.for_range(0, fused_updates_dimension, kind="parallel") as j: + offset = fused_updates_dimension index = j # This is x_M, .. x_{N-1} part of the index into out. # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part # of the index into out. for l in reversed(range(indices_ptr.shape[0].value)): # indices[i * l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}] index += offset * indices[i + l * fused_indices_dimension] - ib.emit( - AssertStmt( - indices[i + l * fused_indices_dimension] < shape[l], - StringImm("index out of bounds"), - Evaluate(0), - ) - ) - offset *= shape[l] - if mode == "add": - out[index] += updates[i * fused_data_dimension + j] - elif mode == "update": - out[index] = updates[i * fused_data_dimension + j] + offset *= data_ptr.shape[l] + if mode == "update": + out[index] = updates[i * fused_updates_dimension + j] + elif mode == "add": + out[index] += updates[i * fused_updates_dimension + j] else: raise NotImplementedError("scatter_nd mode not in [update, add]:", mode) return ib.get() - out_buf = decl_buffer(shape, data.dtype, "out_buf") + out_buf = decl_buffer(data.shape, data.dtype, "out_buf") return extern( - [shape], + [data.shape], [data, indices, updates], lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0]), dtype=data.dtype, out_buffers=[out_buf], - name="scatter_nd_generic", - tag="scatter_nd_generic", + name="scatter_nd.generic", + tag="scatter_nd.generic", ) diff --git a/python/tvm/topi/x86/__init__.py b/python/tvm/topi/x86/__init__.py index d075090f01ea..a54b156380d0 100644 --- a/python/tvm/topi/x86/__init__.py +++ b/python/tvm/topi/x86/__init__.py @@ -40,7 +40,6 @@ from .sparse import * from .conv2d_alter_op import * from .dense_alter_op import * -from .scatter import * from .group_conv2d import * from .math_alter_op import * from .concat import * diff --git a/python/tvm/topi/x86/scatter.py b/python/tvm/topi/x86/scatter.py deleted file mode 100644 index 5eb5e6e99b6c..000000000000 --- a/python/tvm/topi/x86/scatter.py +++ /dev/null @@ -1,119 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Scatter operators for x86""" -import tvm -from tvm import te -from ..scatter import _verify_scatter_nd_inputs - - -def scatter_nd(data, indices, updates, mode): - """Scatter elements from a n-dimension array. - - Given updates with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), indices with shape - (M, Y_0, ..., Y_{K-1}), and output copied from data with shape (X_0, X_1, ..., X_{N-1}), - scatter_nd computes - - .. code-block:: - - output[indices[0, y_0, ..., y_{K-1}], - ..., - indices[M-1, y_0, ..., y_{K-1}], - x_M, - ..., - x_{N-1} - ] = f(output[...], updates[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}]) - - where the update function f is determinted by the mode. - - Parameters - ---------- - data : tvm.te.Tensor - The source array. - - indices : tvm.te.Tensor - The indices of the values to extract. - - updates : tvm.te.Tensor - The updates to apply at the Indices - - mode : string - The update mode for the algorithm, either "update" or "add" - If update, the update values will replace the input data - If add, the update values will be added to the input data - - Returns - ------- - ret : tvm.te.Tensor - """ - _verify_scatter_nd_inputs(data, indices, updates) - - def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): - # pylint: disable=invalid-name - ib = tvm.tir.ir_builder.create() - - data = ib.buffer_ptr(data_ptr) - indices = ib.buffer_ptr(indices_ptr) - updates = ib.buffer_ptr(updates_ptr) - out = ib.buffer_ptr(out_ptr) - - # We combine all the indices dimensions but the first one into a single - # dimension so we can iterate it in single loop instead of an arbitrary - # number of loops. We do the same thing for all the update dimensions. - fused_indices_dimension = 1 - for i in indices_ptr.shape[1:]: - fused_indices_dimension *= i - - fused_updates_dimension = 1 - for i in updates_ptr.shape[len(indices_ptr.shape) - 1 :]: - fused_updates_dimension *= i - - fused_shape = 1 - for i in data_ptr.shape: - fused_shape *= i - - with ib.for_range(0, fused_shape) as i: - out[i] = data[i] - - with ib.for_range(0, fused_indices_dimension) as i: - with ib.for_range(0, fused_updates_dimension, kind="parallel") as j: - offset = fused_updates_dimension - index = j # This is x_M, .. x_{N-1} part of the index into out. - # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part - # of the index into out. - for l in reversed(range(indices_ptr.shape[0].value)): - # indices[i * l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}] - index += offset * indices[i + l * fused_indices_dimension] - offset *= data_ptr.shape[l] - if mode == "update": - out[index] = updates[i * fused_updates_dimension + j] - elif mode == "add": - out[index] += updates[i * fused_updates_dimension + j] - else: - raise NotImplementedError("scatter_nd mode not in [update, add]:", mode) - - return ib.get() - - out_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "out_buf") - return te.extern( - [data.shape], - [data, indices, updates], - lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0]), - dtype=data.dtype, - out_buffers=[out_buf], - name="scatter_nd_x86", - tag="scatter_nd_x86", - ) diff --git a/tests/python/topi/python/test_topi_scatter.py b/tests/python/topi/python/test_topi_scatter.py index 648ef62a04ee..025e44889d63 100644 --- a/tests/python/topi/python/test_topi_scatter.py +++ b/tests/python/topi/python/test_topi_scatter.py @@ -33,10 +33,6 @@ def check_scatter_nd(data, indices, updates, out, mode="add"): lambda x, y, z: topi.cuda.scatter_nd(x, y, z, mode), topi.generic.schedule_extern, ), - "cpu": ( - lambda x, y, z: topi.x86.scatter_nd(x, y, z, mode), - topi.generic.schedule_extern, - ), } fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations) tvm.topi.testing.compare_numpy_tvm(