
[Fix,TOPI] Consolidate generic and x86 scatter nd #13755

Merged (2 commits, Jan 11, 2023)
python/tvm/relay/op/strategy/x86.py (1 addition, 1 deletion)

@@ -772,7 +772,7 @@ def scatter_nd_strategy_cpu(attrs, inputs, out_type, target):
"""scatter_nd x86 strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_scatter_nd(topi.x86.scatter_nd),
wrap_compute_scatter_nd(topi.scatter_nd),
wrap_topi_schedule(topi.generic.schedule_extern),
name="scatter_nd.x86",
plevel=10,
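With this change, the CPU strategy registers the generic TOPI kernel rather than the x86 copy that this PR deletes. For readers unfamiliar with the op, here is a minimal NumPy sketch of scatter_nd semantics under TVM's layout convention, where data has shape (X_0, ..., X_{N-1}), indices has shape (M, Y_0, ..., Y_{K-1}), and updates has shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}). The helper name scatter_nd_ref is ours, not part of the PR:

```python
import numpy as np

def scatter_nd_ref(data, indices, updates, mode="update"):
    """Reference semantics for scatter_nd (illustrative only)."""
    out = data.copy()
    M = indices.shape[0]
    # Each fused position y = (y_0, ..., y_{K-1}) selects one output slice.
    for y in np.ndindex(*indices.shape[1:]):
        idx = tuple(int(indices[(l,) + y]) for l in range(M))
        if mode == "update":
            out[idx] = updates[y]
        elif mode == "add":
            out[idx] += updates[y]
        else:
            raise NotImplementedError(mode)
    return out

# Rows 0 and 3 of a 4x4 zero matrix receive the two update rows.
print(scatter_nd_ref(np.zeros((4, 4)), np.array([[0, 3]]), np.ones((2, 4)), "add"))
```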
python/tvm/topi/scatter.py (25 additions, 30 deletions)

@@ -16,8 +16,8 @@
 # under the License.
 # pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks
 """Scatter operator"""
-from ..tir import decl_buffer, ir_builder, AssertStmt, StringImm, Evaluate, expr
 from ..te import extern, hybrid
+from ..tir import decl_buffer, expr, ir_builder


 @hybrid.script

@@ -268,63 +268,58 @@ def scatter_nd(data, indices, updates, mode):
     _verify_scatter_nd_inputs(data, indices, updates)

     def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
         # pylint: disable=invalid-name
    [Review thread on the "# pylint: disable=invalid-name" line]

    Contributor: nit: just give it a good name
    Author: This is from the old code.
    Contributor: can you fix it while you are here, small change
         ib = ir_builder.create()

         data = ib.buffer_ptr(data_ptr)
         indices = ib.buffer_ptr(indices_ptr)
         updates = ib.buffer_ptr(updates_ptr)
         out = ib.buffer_ptr(out_ptr)

-        fused_shape = 1
-        for i in data.shape:
-            fused_shape *= i
-        with ib.for_range(0, fused_shape) as i:
-            out[i] = data[i]

         # We combine all the indices dimensions but the first one into a single
         # dimension so we can iterate it in single loop instead of an arbitrary
-        # number of loops. We do the same thing for all the data dimensions.
+        # number of loops. We do the same thing for all the update dimensions.
         fused_indices_dimension = 1
         for i in indices_ptr.shape[1:]:
             fused_indices_dimension *= i

-        fused_data_dimension = 1
-        for i in data_ptr.shape[len(indices_ptr.shape) - 1 :]:
-            fused_data_dimension *= i
+        fused_updates_dimension = 1
+        for i in updates_ptr.shape[len(indices_ptr.shape) - 1 :]:
+            fused_updates_dimension *= i

+        fused_shape = 1
+        for i in data_ptr.shape:
+            fused_shape *= i
+
+        with ib.for_range(0, fused_shape) as i:
+            out[i] = data[i]
+
-        with ib.for_range(0, fused_indices_dimension, name="i") as i:
-            with ib.for_range(0, fused_data_dimension, name="j") as j:
-                offset = fused_data_dimension
+        with ib.for_range(0, fused_indices_dimension) as i:
+            with ib.for_range(0, fused_updates_dimension, kind="parallel") as j:
+                offset = fused_updates_dimension
                 index = j  # This is x_M, .. x_{N-1} part of the index into out.
                 # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part
                 # of the index into out.
                 for l in reversed(range(indices_ptr.shape[0].value)):
                     # indices[i * l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}]
                     index += offset * indices[i + l * fused_indices_dimension]
-                    ib.emit(
-                        AssertStmt(
-                            indices[i + l * fused_indices_dimension] < shape[l],
-                            StringImm("index out of bounds"),
-                            Evaluate(0),
-                        )
-                    )
-                    offset *= shape[l]
-                if mode == "add":
-                    out[index] += updates[i * fused_data_dimension + j]
-                elif mode == "update":
-                    out[index] = updates[i * fused_data_dimension + j]
+                    offset *= data_ptr.shape[l]
+                if mode == "update":
+                    out[index] = updates[i * fused_updates_dimension + j]
+                elif mode == "add":
+                    out[index] += updates[i * fused_updates_dimension + j]
+                else:
+                    raise NotImplementedError("scatter_nd mode not in [update, add]:", mode)

         return ib.get()

-    out_buf = decl_buffer(shape, data.dtype, "out_buf")
+    out_buf = decl_buffer(data.shape, data.dtype, "out_buf")
     return extern(
-        [shape],
+        [data.shape],
         [data, indices, updates],
         lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0]),
         dtype=data.dtype,
         out_buffers=[out_buf],
-        name="scatter_nd_generic",
-        tag="scatter_nd_generic",
+        name="scatter_nd.generic",
+        tag="scatter_nd.generic",
     )
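To make the rewritten gen_ir easier to follow: the kernel views indices and updates through flat buffers, fuses all trailing dimensions into single loops, and recovers the row-major output offset by walking the M index rows from last to first. Below is a plain-Python mirror of that arithmetic on a concrete shape; the loop-nest variable names follow the diff, while the NumPy scaffolding around them is our own illustration:

```python
import numpy as np

# data: (X_0, X_1, X_2); indices: (M=2, Y_0=3); updates: (Y_0, X_2).
data = np.zeros((5, 6, 7), dtype="float32")
indices = np.array([[0, 2, 4], [1, 3, 5]], dtype="int64")
updates = np.ones((3, 7), dtype="float32")

M = indices.shape[0]
fused_indices_dimension = indices[0].size                       # prod(indices.shape[1:])
fused_updates_dimension = int(np.prod(updates.shape[indices.ndim - 1:]))

out, flat_indices, flat_updates = data.ravel().copy(), indices.ravel(), updates.ravel()

for i in range(fused_indices_dimension):
    for j in range(fused_updates_dimension):  # the loop gen_ir marks kind="parallel"
        offset, index = fused_updates_dimension, j
        for l in reversed(range(M)):
            # flat_indices[i + l * fused_indices_dimension] is indices[l, y_0, ..., y_{K-1}]
            index += offset * flat_indices[i + l * fused_indices_dimension]
            offset *= data.shape[l]
        out[index] += flat_updates[i * fused_updates_dimension + j]  # mode="add"

# out.reshape(data.shape)[0, 1], [2, 3], and [4, 5] are now all-ones rows.
```

Two points worth noting about the new kernel. Parallelizing over j is safe even for mode="add" because, for a fixed i, distinct j values write distinct output elements, and the outer i loop stays serial. Separately, the old AssertStmt bounds check is gone, so out-of-range indices are no longer trapped inside the generated IR; that is the main behavioral difference visible in this diff.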
python/tvm/topi/x86/__init__.py (0 additions, 1 deletion)

@@ -40,7 +40,6 @@
 from .sparse import *
 from .conv2d_alter_op import *
 from .dense_alter_op import *
-from .scatter import *
 from .group_conv2d import *
 from .math_alter_op import *
 from .concat import *
python/tvm/topi/x86/scatter.py (0 additions, 119 deletions)

This file was deleted.

tests/python/topi/python/test_topi_scatter.py (0 additions, 4 deletions)

@@ -33,10 +33,6 @@ def check_scatter_nd(data, indices, updates, out, mode="add"):
             lambda x, y, z: topi.cuda.scatter_nd(x, y, z, mode),
             topi.generic.schedule_extern,
         ),
-        "cpu": (
-            lambda x, y, z: topi.x86.scatter_nd(x, y, z, mode),
-            topi.generic.schedule_extern,
-        ),
     }
     fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations)
     tvm.topi.testing.compare_numpy_tvm(
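With the "cpu" entry gone, CPU test runs fall through to the same generic kernel that the strategy now registers. A quick end-to-end check one could run against a mainline TVM build; this snippet is our own sketch, not part of the PR's test changes, and assumes the standard te/topi APIs:

```python
import numpy as np
import tvm
from tvm import te, topi

# data: (4, 4); indices: (M=1, Y_0=2); updates: (Y_0, X_1) = (2, 4).
data = te.placeholder((4, 4), name="data", dtype="float32")
indices = te.placeholder((1, 2), name="indices", dtype="int64")
updates = te.placeholder((2, 4), name="updates", dtype="float32")

out = topi.scatter_nd(data, indices, updates, mode="add")
s = te.create_schedule(out.op)  # extern op, so the schedule is trivial
f = tvm.build(s, [data, indices, updates, out], target="llvm")

dev = tvm.cpu()
res = tvm.nd.empty((4, 4), "float32", dev)
f(
    tvm.nd.array(np.zeros((4, 4), "float32"), dev),
    tvm.nd.array(np.array([[0, 3]], "int64"), dev),
    tvm.nd.array(np.ones((2, 4), "float32"), dev),
    res,
)
print(res.numpy())  # rows 0 and 3 are all ones
```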