[Fix,TOPI] Consolidate generic and x86 scatter nd
The generic scatter_nd was almost identical to the x86 one and was not
tested. They are now one and the same.
Tristan Konolige committed Jan 11, 2023
1 parent d979949 commit 97354cd
Showing 5 changed files with 25 additions and 154 deletions.
2 changes: 1 addition & 1 deletion python/tvm/relay/op/strategy/x86.py
@@ -772,7 +772,7 @@ def scatter_nd_strategy_cpu(attrs, inputs, out_type, target):
"""scatter_nd x86 strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_scatter_nd(topi.x86.scatter_nd),
wrap_compute_scatter_nd(topi.scatter_nd),
wrap_topi_schedule(topi.generic.schedule_extern),
name="scatter_nd.x86",
plevel=10,
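With this change the x86 strategy registers the shared topi.scatter_nd compute, so a Relay scatter_nd compiled for a CPU target goes through the same kernel as every other CPU target. The snippet below is a minimal sketch, not part of the commit, of compiling such a graph for llvm; the shapes, dtypes, and variable names are illustrative assumptions.

import tvm
from tvm import relay

# Illustrative shapes only: indices is (M, Y), updates is (Y,) for 2-D data.
data = relay.var("data", shape=(2, 2), dtype="float32")
indices = relay.var("indices", shape=(2, 3), dtype="int64")
updates = relay.var("updates", shape=(3,), dtype="float32")

func = relay.Function(
    [data, indices, updates], relay.scatter_nd(data, indices, updates, mode="add")
)
mod = tvm.IRModule.from_expr(func)

# On "llvm" this now lowers through the shared scatter_nd kernel rather than a
# separate x86 copy of it.
lib = relay.build(mod, target="llvm")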
53 changes: 24 additions & 29 deletions python/tvm/topi/scatter.py
@@ -268,63 +268,58 @@ def scatter_nd(data, indices, updates, mode):
     _verify_scatter_nd_inputs(data, indices, updates)
 
     def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
         # pylint: disable=invalid-name
         ib = ir_builder.create()
 
         data = ib.buffer_ptr(data_ptr)
         indices = ib.buffer_ptr(indices_ptr)
         updates = ib.buffer_ptr(updates_ptr)
         out = ib.buffer_ptr(out_ptr)
 
+        fused_shape = 1
+        for i in data.shape:
+            fused_shape *= i
+        with ib.for_range(0, fused_shape) as i:
+            out[i] = data[i]
+
         # We combine all the indices dimensions but the first one into a single
         # dimension so we can iterate it in single loop instead of an arbitrary
-        # number of loops. We do the same thing for all the data dimensions.
+        # number of loops. We do the same thing for all the update dimensions.
         fused_indices_dimension = 1
         for i in indices_ptr.shape[1:]:
             fused_indices_dimension *= i
 
-        fused_data_dimension = 1
-        for i in data_ptr.shape[len(indices_ptr.shape) - 1 :]:
-            fused_data_dimension *= i
+        fused_updates_dimension = 1
+        for i in updates_ptr.shape[len(indices_ptr.shape) - 1 :]:
+            fused_updates_dimension *= i
 
-        fused_shape = 1
-        for i in data_ptr.shape:
-            fused_shape *= i
-
-        with ib.for_range(0, fused_shape) as i:
-            out[i] = data[i]
-
-        with ib.for_range(0, fused_indices_dimension, name="i") as i:
-            with ib.for_range(0, fused_data_dimension, name="j") as j:
-                offset = fused_data_dimension
+        with ib.for_range(0, fused_indices_dimension) as i:
+            with ib.for_range(0, fused_updates_dimension, kind="parallel") as j:
+                offset = fused_updates_dimension
                 index = j  # This is x_M, .. x_{N-1} part of the index into out.
                 # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part
                 # of the index into out.
                 for l in reversed(range(indices_ptr.shape[0].value)):
                     # indices[i * l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}]
                     index += offset * indices[i + l * fused_indices_dimension]
-                    ib.emit(
-                        AssertStmt(
-                            indices[i + l * fused_indices_dimension] < shape[l],
-                            StringImm("index out of bounds"),
-                            Evaluate(0),
-                        )
-                    )
-                    offset *= shape[l]
-                if mode == "add":
-                    out[index] += updates[i * fused_data_dimension + j]
-                elif mode == "update":
-                    out[index] = updates[i * fused_data_dimension + j]
+                    offset *= data_ptr.shape[l]
+                if mode == "update":
+                    out[index] = updates[i * fused_updates_dimension + j]
+                elif mode == "add":
+                    out[index] += updates[i * fused_updates_dimension + j]
                 else:
                     raise NotImplementedError("scatter_nd mode not in [update, add]:", mode)
 
         return ib.get()
 
-    out_buf = decl_buffer(shape, data.dtype, "out_buf")
+    out_buf = decl_buffer(data.shape, data.dtype, "out_buf")
     return extern(
-        [shape],
+        [data.shape],
         [data, indices, updates],
         lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0]),
         dtype=data.dtype,
         out_buffers=[out_buf],
-        name="scatter_nd_generic",
-        tag="scatter_nd_generic",
+        name="scatter_nd.generic",
+        tag="scatter_nd.generic",
    )
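As a reading aid (not part of the commit), here is a NumPy sketch of the indexing scheme the consolidated gen_ir uses: all indices dimensions except the first are fused into one loop, all update dimensions into another, and the linear output index is built from the first indices dimension in reverse, exactly as in the TIR above. The helper name scatter_nd_ref and the example values are illustrative assumptions.

import numpy as np

def scatter_nd_ref(data, indices, updates, mode="update"):
    # Mirror of the fused loops in gen_ir, on flat NumPy arrays.
    out = data.copy().reshape(-1)
    flat_updates = updates.reshape(-1)
    m = indices.shape[0]
    fused_indices_dimension = int(np.prod(indices.shape[1:], dtype=np.int64))
    fused_updates_dimension = int(np.prod(updates.shape[len(indices.shape) - 1:], dtype=np.int64))
    flat_indices = indices.reshape(m, fused_indices_dimension)
    for i in range(fused_indices_dimension):
        for j in range(fused_updates_dimension):
            offset = fused_updates_dimension
            index = j  # the x_M .. x_{N-1} part of the output index
            for l in reversed(range(m)):
                index += offset * int(flat_indices[l, i])
                offset *= data.shape[l]
            if mode == "update":
                out[index] = flat_updates[i * fused_updates_dimension + j]
            elif mode == "add":
                out[index] += flat_updates[i * fused_updates_dimension + j]
            else:
                raise NotImplementedError("scatter_nd mode not in [update, add]:", mode)
    return out.reshape(data.shape)

# Small check: scatter three scalar updates into a 2x2 output with mode="add".
print(scatter_nd_ref(np.zeros((2, 2)), np.array([[1, 1, 0], [0, 1, 0]]), np.array([2.0, 3.0, 0.0]), mode="add"))
# -> [[0. 0.]
#     [2. 3.]]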
1 change: 0 additions & 1 deletion python/tvm/topi/x86/__init__.py
@@ -40,7 +40,6 @@
 from .sparse import *
 from .conv2d_alter_op import *
 from .dense_alter_op import *
-from .scatter import *
 from .group_conv2d import *
 from .math_alter_op import *
 from .concat import *
119 changes: 0 additions & 119 deletions python/tvm/topi/x86/scatter.py

This file was deleted.

4 changes: 0 additions & 4 deletions tests/python/topi/python/test_topi_scatter.py
@@ -33,10 +33,6 @@ def check_scatter_nd(data, indices, updates, out, mode="add"):
             lambda x, y, z: topi.cuda.scatter_nd(x, y, z, mode),
             topi.generic.schedule_extern,
         ),
-        "cpu": (
-            lambda x, y, z: topi.x86.scatter_nd(x, y, z, mode),
-            topi.generic.schedule_extern,
-        ),
     }
     fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations)
     tvm.topi.testing.compare_numpy_tvm(
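With the dedicated "cpu" entry gone, the dispatch above falls back to the shared implementation. For completeness, a hedged sketch (not from the test file) of driving the now-shared topi.scatter_nd directly through TE on LLVM, reusing the same small case as the NumPy reference; tensor names and values are illustrative assumptions.

import numpy as np
import tvm
from tvm import te, topi

data = te.placeholder((2, 2), name="data", dtype="float32")
indices = te.placeholder((2, 3), name="indices", dtype="int64")
updates = te.placeholder((3,), name="updates", dtype="float32")
out = topi.scatter_nd(data, indices, updates, "add")

s = te.create_schedule(out.op)  # extern op, so the default schedule suffices
f = tvm.build(s, [data, indices, updates, out], target="llvm")

dev = tvm.cpu(0)
out_nd = tvm.nd.empty((2, 2), dtype="float32", device=dev)
f(
    tvm.nd.array(np.zeros((2, 2), dtype="float32"), dev),
    tvm.nd.array(np.array([[1, 1, 0], [0, 1, 0]], dtype="int64"), dev),
    tvm.nd.array(np.array([2, 3, 0], dtype="float32"), dev),
    out_nd,
)
print(out_nd.numpy())  # expected [[0. 0.], [2. 3.]] with mode="add"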
