[Fix,TOPI] Consolidate generic and x86 scatter nd (apache#13755)
The generic scatter_nd was almost identical to the x86 one and was not tested. They are now one and the same.
Tristan Konolige authored and fzi-peccia committed Mar 27, 2023
1 parent 7eae01f commit 3f8926a
Showing 5 changed files with 26 additions and 155 deletions.
2 changes: 1 addition & 1 deletion python/tvm/relay/op/strategy/x86.py
@@ -768,7 +768,7 @@ def scatter_nd_strategy_cpu(attrs, inputs, out_type, target):
     """scatter_nd x86 strategy"""
     strategy = _op.OpStrategy()
     strategy.add_implementation(
-        wrap_compute_scatter_nd(topi.x86.scatter_nd),
+        wrap_compute_scatter_nd(topi.scatter_nd),
         wrap_topi_schedule(topi.generic.schedule_extern),
         name="scatter_nd.x86",
         plevel=10,
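
With this patch, CPU compilation of scatter_nd routes through the shared topi.scatter_nd compute. A minimal sketch of exercising that path end to end through Relay (assuming an LLVM-enabled TVM build; the shapes and mode="add" are illustrative, not taken from the patch):

    import numpy as np
    import tvm
    from tvm import relay

    # data is (2, 2); indices has shape (M=2, Y_0=2), i.e. one column per
    # selected output position; updates holds one value per selected position.
    data = relay.var("data", shape=(2, 2), dtype="float32")
    indices = relay.var("indices", shape=(2, 2), dtype="int64")
    updates = relay.var("updates", shape=(2,), dtype="float32")
    func = relay.Function(
        [data, indices, updates], relay.scatter_nd(data, indices, updates, mode="add")
    )

    # Building for a CPU target now selects the consolidated implementation.
    lib = relay.build(tvm.IRModule.from_expr(func), target="llvm")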
55 changes: 25 additions & 30 deletions python/tvm/topi/scatter.py
@@ -16,8 +16,8 @@
 # under the License.
 # pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks
 """Scatter operator"""
-from ..tir import decl_buffer, ir_builder, AssertStmt, StringImm, Evaluate, expr
 from ..te import extern, hybrid
+from ..tir import decl_buffer, expr, ir_builder


 @hybrid.script
@@ -268,63 +268,58 @@ def scatter_nd(data, indices, updates, mode):
     _verify_scatter_nd_inputs(data, indices, updates)

     def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
         # pylint: disable=invalid-name
         ib = ir_builder.create()

         data = ib.buffer_ptr(data_ptr)
         indices = ib.buffer_ptr(indices_ptr)
         updates = ib.buffer_ptr(updates_ptr)
         out = ib.buffer_ptr(out_ptr)

-        fused_shape = 1
-        for i in data.shape:
-            fused_shape *= i
-        with ib.for_range(0, fused_shape) as i:
-            out[i] = data[i]
-
         # We combine all the indices dimensions but the first one into a single
         # dimension so we can iterate it in single loop instead of an arbitrary
-        # number of loops. We do the same thing for all the data dimensions.
+        # number of loops. We do the same thing for all the update dimensions.
         fused_indices_dimension = 1
         for i in indices_ptr.shape[1:]:
             fused_indices_dimension *= i

-        fused_data_dimension = 1
-        for i in data_ptr.shape[len(indices_ptr.shape) - 1 :]:
-            fused_data_dimension *= i
+        fused_updates_dimension = 1
+        for i in updates_ptr.shape[len(indices_ptr.shape) - 1 :]:
+            fused_updates_dimension *= i
+
+        fused_shape = 1
+        for i in data_ptr.shape:
+            fused_shape *= i
+
+        with ib.for_range(0, fused_shape) as i:
+            out[i] = data[i]

-        with ib.for_range(0, fused_indices_dimension, name="i") as i:
-            with ib.for_range(0, fused_data_dimension, name="j") as j:
-                offset = fused_data_dimension
+        with ib.for_range(0, fused_indices_dimension) as i:
+            with ib.for_range(0, fused_updates_dimension, kind="parallel") as j:
+                offset = fused_updates_dimension
                 index = j  # This is x_M, .. x_{N-1} part of the index into out.
                 # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part
                 # of the index into out.
                 for l in reversed(range(indices_ptr.shape[0].value)):
                     # indices[i * l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}]
                     index += offset * indices[i + l * fused_indices_dimension]
-                    ib.emit(
-                        AssertStmt(
-                            indices[i + l * fused_indices_dimension] < shape[l],
-                            StringImm("index out of bounds"),
-                            Evaluate(0),
-                        )
-                    )
-                    offset *= shape[l]
-                if mode == "add":
-                    out[index] += updates[i * fused_data_dimension + j]
-                elif mode == "update":
-                    out[index] = updates[i * fused_data_dimension + j]
+                    offset *= data_ptr.shape[l]
+                if mode == "update":
+                    out[index] = updates[i * fused_updates_dimension + j]
+                elif mode == "add":
+                    out[index] += updates[i * fused_updates_dimension + j]
                 else:
                     raise NotImplementedError("scatter_nd mode not in [update, add]:", mode)

         return ib.get()

-    out_buf = decl_buffer(shape, data.dtype, "out_buf")
+    out_buf = decl_buffer(data.shape, data.dtype, "out_buf")
     return extern(
-        [shape],
+        [data.shape],
         [data, indices, updates],
         lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0]),
         dtype=data.dtype,
         out_buffers=[out_buf],
-        name="scatter_nd_generic",
-        tag="scatter_nd_generic",
+        name="scatter_nd.generic",
+        tag="scatter_nd.generic",
     )
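
For reference, the semantics this IR implements, as a plain NumPy sketch (scatter_nd_ref is a hypothetical helper for illustration, not part of TVM): indices has shape (M, Y_0, ..., Y_{K-1}), and each index tuple y selects M output coordinates telling where the matching slice of updates lands. The fused i loop above walks the y tuples; the parallel j loop walks the trailing update elements.

    import numpy as np

    def scatter_nd_ref(data, indices, updates, mode="update"):
        # indices: (M, Y_0, ..., Y_{K-1})
        # updates: (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1})
        out = data.copy()
        m = indices.shape[0]
        for y in np.ndindex(*indices.shape[1:]):
            # Gather the M output coordinates chosen by this y tuple.
            idx = tuple(indices[(l,) + y] for l in range(m))
            if mode == "update":
                out[idx] = updates[y]
            elif mode == "add":
                out[idx] += updates[y]
            else:
                raise NotImplementedError("scatter_nd mode not in [update, add]:", mode)
        return out

    # Writes 1.0 into out[0, 0] and 2.0 into out[1, 1]:
    data = np.zeros((2, 2), dtype="float32")
    indices = np.array([[0, 1], [0, 1]])  # shape (M=2, Y_0=2)
    updates = np.array([1.0, 2.0], dtype="float32")
    print(scatter_nd_ref(data, indices, updates, mode="add"))
    # [[1. 0.]
    #  [0. 2.]]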
1 change: 0 additions & 1 deletion python/tvm/topi/x86/__init__.py
@@ -40,7 +40,6 @@
 from .sparse import *
 from .conv2d_alter_op import *
 from .dense_alter_op import *
-from .scatter import *
 from .group_conv2d import *
 from .math_alter_op import *
 from .concat import *
119 changes: 0 additions & 119 deletions python/tvm/topi/x86/scatter.py

This file was deleted.

4 changes: 0 additions & 4 deletions tests/python/topi/python/test_topi_scatter.py
@@ -33,10 +33,6 @@ def check_scatter_nd(data, indices, updates, out, mode="add"):
             lambda x, y, z: topi.cuda.scatter_nd(x, y, z, mode),
             topi.generic.schedule_extern,
         ),
-        "cpu": (
-            lambda x, y, z: topi.x86.scatter_nd(x, y, z, mode),
-            topi.generic.schedule_extern,
-        ),
     }
     fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations)
     tvm.topi.testing.compare_numpy_tvm(
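
With the x86 entry removed, tvm.topi.testing.dispatch falls back to the generic implementation on CPU, so the consolidated kernel is exactly what this test now covers. For completeness, a sketch of driving topi.scatter_nd directly through TE (assuming an LLVM-enabled build; shapes and values are illustrative):

    import numpy as np
    import tvm
    from tvm import te, topi

    data = te.placeholder((2, 2), name="data", dtype="float32")
    indices = te.placeholder((2, 2), name="indices", dtype="int64")
    updates = te.placeholder((2,), name="updates", dtype="float32")
    out = topi.scatter_nd(data, indices, updates, "add")

    # scatter_nd lowers to an extern op, so the default schedule is enough.
    s = te.create_schedule(out.op)
    f = tvm.build(s, [data, indices, updates, out], target="llvm")

    dev = tvm.cpu()
    out_nd = tvm.nd.empty((2, 2), "float32", dev)
    f(
        tvm.nd.array(np.zeros((2, 2), dtype="float32"), dev),
        tvm.nd.array(np.array([[0, 1], [0, 1]], dtype="int64"), dev),
        tvm.nd.array(np.array([1.0, 2.0], dtype="float32"), dev),
        out_nd,
    )
    print(out_nd.numpy())  # [[1. 0.], [0. 2.]]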
