[TOPI] Add embedding op and gradient #6794

Closed
wants to merge 4 commits into from
8 changes: 8 additions & 0 deletions include/tvm/relay/attrs/nn.h
@@ -1371,6 +1371,14 @@ struct BatchToSpaceNDAttrs : public tvm::AttrsNode<BatchToSpaceNDAttrs> {
}
}; // struct BatchToSpaceNDAttrs

/*! \brief Attributes used by the embed operator. */
struct EmbedAttrs : public tvm::AttrsNode<EmbedAttrs> {
  TVM_DECLARE_ATTRS(EmbedAttrs, "relay.attrs.EmbedAttrs") {}
}; // struct EmbedAttrs

/*! \brief Attributes used by the embed_grad operator. */
struct EmbedGradAttrs : public tvm::AttrsNode<EmbedGradAttrs> {
  TVM_DECLARE_ATTRS(EmbedGradAttrs, "relay.attrs.EmbedGradAttrs") {}
}; // struct EmbedGradAttrs

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_NN_H_
6 changes: 5 additions & 1 deletion python/tvm/autotvm/task/topi_integration.py
@@ -229,7 +229,11 @@ def wrapper(outs, *args, **kwargs):
            """wrapper function for topi schedule"""
            workload = get_workload(outs)
            if workload is None:
                raise RuntimeError("Cannot find workload in attribute of this schedule")
                raise RuntimeError(
                    f"Cannot find workload for {task_name}. You may need to "
                    "register a compute function for it with "
                    f'`@tvm.autotvm.register_topi_compute("{task_name}")`'
                )
            tgt = Target.current()
            cfg = DispatchContext.current.query(tgt, workload)
            return topi_schedule(cfg, outs, *args, **kwargs)
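
The new error message points at the autotvm registration decorators. As a hedged illustration (the `embed_grad.example` task name and the trivial bodies below are made up, not part of this PR), a compute/schedule pair would be registered like this so the workload lookup above succeeds:

from tvm import autotvm, te


@autotvm.register_topi_compute("embed_grad.example")
def embed_grad_compute(cfg, table, indices, grad):
    # Placeholder compute; the registration attaches the workload that the
    # schedule wrapper above queries.
    return te.compute(table.shape, lambda i, j: table[i, j] * 0, name="grad_out")


@autotvm.register_topi_schedule("embed_grad.example")
def embed_grad_schedule(cfg, outs):
    # Looks up the workload recorded by the compute and builds a schedule for it.
    return te.create_schedule([t.op for t in outs])
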
7 changes: 6 additions & 1 deletion python/tvm/relay/op/_tensor_grad.py
@@ -827,7 +827,6 @@ def arange_grad(orig, grad):

    return [grad_start, grad_stop, grad_step]


@register_gradient("gather_nd")
def gather_nd_grad(orig, grad):
"""
Expand Down Expand Up @@ -866,3 +865,9 @@ def less_equal_grad(orig, grad):
Returns the gradient of less_equal.
"""
return [zeros_like(orig.args[0]), zeros_like(orig.args[1])]


@register_gradient("nn.embed")
def embed_grad(orig, grad):
table, indices = orig.args
return [_nn.embed_grad(table, indices, grad), zeros_like(indices)]
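
As a hedged sketch of how this registration gets exercised (shapes are made up, and it assumes the C++ `_make.embed` registration elsewhere in this PR is in place):

import tvm
from tvm import relay

table = relay.var("table", shape=(10, 4), dtype="float32")
indices = relay.var("indices", shape=(3,), dtype="int32")
fwd = relay.Function([table, indices], relay.nn.embed(table, indices))
fwd = relay.transform.InferType()(tvm.IRModule.from_expr(fwd))["main"]
# gradient() consults the "nn.embed" rule above to build the backward pass.
bwd = relay.transform.gradient(fwd, mode="first_order")
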
21 changes: 21 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
@@ -895,6 +895,27 @@ def compute_space_to_depth(attrs, inputs, out_dtype):
reg.register_injective_schedule("nn.batch_to_space_nd")


# embed
@reg.register_compute("nn.embed")
def compute_embed(attrs, inputs, out_type):
    """Compute definition of embed"""
    return [topi.nn.embed(inputs[0], inputs[1])]


reg.register_injective_schedule("nn.embed")
reg.register_pattern("nn.embed", OpPattern.INJECTIVE)


@reg.register_compute("nn.embed_grad")
def compute_embed_grad(attrs, inputs, out_type):
    """Compute definition of embed_grad"""
    return [topi.nn.embed_grad(inputs[0], inputs[1], inputs[2])]


reg.register_strategy("nn.embed_grad", strategy.embed_grad_strategy)
reg.register_pattern("nn.embed_grad", OpPattern.OUT_ELEMWISE_FUSABLE)


#####################
# Shape functions #
#####################
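
A small sketch (not part of the diff, and assuming the C++ side of this PR registers both op names) of what the pattern registrations above mean for the fuser:

from tvm import relay

# TOpPattern is 2 (INJECTIVE) for nn.embed, so it fuses like any injective op,
# and 4 (OUT_ELEMWISE_FUSABLE) for nn.embed_grad, so only elementwise ops may
# be fused into its output.
print(relay.op.get("nn.embed").get_attr("TOpPattern"))
print(relay.op.get("nn.embed_grad").get_attr("TOpPattern"))
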
58 changes: 58 additions & 0 deletions python/tvm/relay/op/nn/nn.py
@@ -3307,3 +3307,61 @@ def batch_to_space_nd(data, block_shape, crops):
    """

    return _make.batch_to_space_nd(data, block_shape, crops)


def embed(table, indices):
    """Look up rows of an embedding table.

    The embedding lookup is defined as:

    .. math::

        O[i, j] = T[I[i], j]

    where :math:`T` is the embedding table and :math:`I` is the vector of
    indices to look up. This is a specialization of `take` with a
    two-dimensional input and axis = 0.

    Parameters
    ----------
    table : tvm.relay.Expr
        M x N embedding table.
    indices : tvm.relay.Expr
        Length-K vector of indices to look up in `table`.

    Returns
    -------
    result : tvm.relay.Expr
        K x N tensor containing the rows of `table` selected by `indices`.
    """
    return _make.embed(table, indices)


def embed_grad(table, indices, grad):
    """Gradient of :py:func:`embed` with respect to the table.

    The gradient is a scatter-add of the incoming gradient into a
    zero-initialized tensor shaped like the table:

    .. math::

        O[I[i], j] = G[i, j]

    where :math:`G` is the incoming gradient and :math:`I` is the vector of
    looked-up indices; contributions from repeated indices are summed.

    Parameters
    ----------
    table : tvm.relay.Expr
        M x N embedding table.
    indices : tvm.relay.Expr
        Length-K vector of indices that were looked up in `table`.
    grad : tvm.relay.Expr
        K x N gradient to propagate.

    Returns
    -------
    result : tvm.relay.Expr
        M x N tensor containing the propagated gradient.
    """
    return _make.embed_grad(table, indices, grad)
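
The semantics in the two docstrings above are easy to pin down with a NumPy reference; the following is only an illustrative sketch, not code from this PR:

import numpy as np


def embed_ref(table, indices):
    # O[i, j] = T[I[i], j]
    return table[indices]


def embed_grad_ref(table, indices, grad):
    # Scatter-add: rows repeated in `indices` accumulate their gradients.
    out = np.zeros_like(table)
    np.add.at(out, indices, grad)
    return out


table = np.arange(12, dtype="float32").reshape(4, 3)
indices = np.array([1, 1, 3], dtype="int32")
assert embed_ref(table, indices).shape == (3, 3)
assert embed_grad_ref(table, indices, np.ones((3, 3), "float32"))[1, 0] == 2.0
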
12 changes: 12 additions & 0 deletions python/tvm/relay/op/strategy/cuda.py
@@ -1017,3 +1017,15 @@ def cumsum_strategy_cuda(attrs, inputs, out_type, target):
        name="cumsum.cuda",
    )
    return strategy


@embed_grad_strategy.register(["cuda", "gpu"])
def embed_grad_strategy_gpu(attrs, inputs, out_type, target):
    """gpu strategy for embed_grad"""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_embed_grad(topi.cuda.embed_grad),
        wrap_topi_schedule(topi.cuda.schedule_embed_grad),
        name="embed_grad.cuda",
    )
    return strategy
22 changes: 22 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
@@ -1432,3 +1432,25 @@ def cumsum_strategy(attrs, inputs, out_type, target):
        name="cumsum.generic",
    )
    return strategy


def wrap_compute_embed_grad(topi_compute):
    """Wrap an embed_grad topi compute function."""

    def _wrapped(attrs, inputs, out_type):
        return [topi_compute(inputs[0], inputs[1], inputs[2])]

    return _wrapped


@override_native_generic_func("embed_grad_strategy")
def embed_grad_strategy(attrs, inputs, out_type, target):
    """embed_grad generic strategy"""
    logger.warning("embed_grad is not optimized for this platform.")
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_embed_grad(topi.nn.embed_grad),
        wrap_topi_schedule(topi.generic.schedule_embed_grad),
        name="embed_grad.generic",
    )
    return strategy
12 changes: 12 additions & 0 deletions python/tvm/relay/op/strategy/x86.py
@@ -569,3 +569,15 @@ def conv2d_winograd_without_weight_transfrom_strategy_cpu(attrs, inputs, out_typ
            "Unsupported conv2d_winograd_without_weight_transfrom layout {}".format(layout)
        )
    return strategy


@embed_grad_strategy.register("cpu")
def embed_grad_strategy_cpu(attrs, inputs, out_type, target):
    """x86 strategy for embed_grad"""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_embed_grad(topi.nn.embed_grad),
        wrap_topi_schedule(topi.x86.schedule_embed_grad),
        name="embed_grad.x86",
    )
    return strategy
5 changes: 5 additions & 0 deletions python/tvm/te/hybrid/calls.py
@@ -137,6 +137,11 @@ def _cast(func_id, args):
    return _expr.Cast(func_id, args[0])


def cast(_, args):
    """Hybrid intrinsic for casting: cast(dtype, value)."""
    _internal_assert(len(args) == 2, "cast requires two arguments: dtype, value")
    return _expr.Cast(args[0], args[1])


float16 = float32 = float64 = _cast # pylint: disable=invalid-name
int8 = int16 = int32 = int64 = _cast # pylint: disable=invalid-name
uint8 = uint16 = uint32 = uint64 = _cast # pylint: disable=invalid-name
10 changes: 10 additions & 0 deletions python/tvm/te/hybrid/runtime.py
@@ -119,6 +119,11 @@ def ninf(dtype):
    return numpy.iinfo(dtype).min


def cast(dtype, value):
    """Software-emulation counterpart of the hybrid `cast` intrinsic.

    Takes the dtype first, matching the compiled intrinsic in calls.py and the
    usage in the hybrid scripts added in this PR.
    """
    return numpy.dtype(dtype).type(value)


HYBRID_GLOBALS = {
    "unroll": range,
    "vectorize": range,
@@ -150,8 +155,13 @@ def ninf(dtype):
    "float64": numpy.float64,
    "ceil_div": lambda a, b: (a + b - 1) // b,
    "max_num_threads": max_num_threads,
    "inf": inf,
    "ninf": inf,
    "cast": cast,
}


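
The `cast` entry matters because a hybrid-script function called with plain NumPy arrays runs in the hybrid runtime's software-emulation mode, where intrinsics resolve to the entries of this table. A hedged sketch (the import path assumes the `embed.py` module added later in this diff, and the emulated `cast` taking the dtype first as fixed above):

import numpy as np
from tvm.topi.nn import embed_grad  # hybrid script; emulated for numpy inputs

table = np.zeros((4, 3), dtype="float32")
indices = np.array([1, 1, 3], dtype="int32")
grad_in = np.ones((3, 3), dtype="float32")
grad_out = embed_grad(table, indices, grad_in)  # plain numpy result
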
29 changes: 29 additions & 0 deletions python/tvm/testing.py
@@ -745,4 +745,33 @@ def terminate_self():
    sys.exit(-1)


def compare_numpy_tvm(inputs, output, target, ctx, compute, schedule):
    """Compare the TVM result of a compute/schedule pair against a known-correct numpy output.

    Parameters
    ----------
    inputs : Sequence[numpy.ndarray]
        List of input numpy arrays to pass to the function.
    output : numpy.ndarray
        Verified correct function output.
    target : tvm.target.Target
        Target to run on.
    ctx : tvm.TVMContext
        Context to run on.
    compute : callable
        Topi compute function to test against.
    schedule : callable
        Topi schedule function to test against.
    """
    te_inputs = [tvm.te.placeholder(shape=i.shape, dtype=str(i.dtype)) for i in inputs]
    te_out = tvm.nd.array(np.zeros(output.shape).astype(output.dtype), ctx=ctx)
    with tvm.target.Target(target):
        out = compute(*te_inputs)
        s = schedule([out])
        func = tvm.build(s, te_inputs + [out])
        arys = [tvm.nd.array(x, ctx=ctx) for x in inputs]
        func(*(arys + [te_out]))
        assert_allclose(output, te_out.asnumpy(), atol=1e-4, rtol=1e-4)


tvm._ffi._init_api("testing", __name__)
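
A hedged usage sketch of the helper above, wired to the new embed op with a plain te schedule (an llvm target and made-up shapes; not a test from this PR):

import numpy as np
import tvm
from tvm import topi

table = np.random.rand(16, 8).astype("float32")
indices = np.array([0, 3, 3, 15], dtype="int32")

tvm.testing.compare_numpy_tvm(
    [table, indices],
    table[indices],  # verified-correct numpy output
    "llvm",
    tvm.cpu(0),
    topi.nn.embed,
    lambda outs: tvm.te.create_schedule([o.op for o in outs]),
)
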
2 changes: 1 addition & 1 deletion python/tvm/topi/cuda/__init__.py
@@ -40,7 +40,7 @@
from .injective import schedule_injective, schedule_elemwise, schedule_broadcast
from .dense import *
from .pooling import *
from .nn import schedule_lrn
from .nn import schedule_lrn, schedule_embed_grad
from .batch_matmul import *
from .batch_matmul_tensorcore import *
from .vision import *
41 changes: 41 additions & 0 deletions python/tvm/topi/cuda/nn.py
@@ -18,6 +18,7 @@
"""scheduler functions for cuda backend"""
from __future__ import absolute_import as _abs

import re

import tvm
from tvm import te
from .. import cpp


@@ -36,3 +37,43 @@ def schedule_lrn(outs):
        The computation schedule for the op.
    """
    return cpp.cuda.schedule_lrn(outs)


def loads_per_thread(dtype):
    """Number of elements of `dtype` that fit in one 16-byte vectorized load."""
    s = re.search("[0-9]+", dtype)
    assert s is not None
    byts = int(s.group()) // 8
    return 16 // byts


def schedule_embed_grad(outs):
    """Schedule for embed_grad

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of embed_grad
        in the format of an array of tensors.

    Returns
    -------
    sch: Schedule
        The computation schedule for the op.
    """
    s = te.create_schedule([outs[0].op])
    # These factors should be autotuned, but we can't do that with hybrid script.
    vec_size = loads_per_thread(outs[0].dtype)
    warp_size = int(tvm.target.Target.current(allow_none=False).thread_warp_size)
    num_warps = 4
    out = s.outputs[0].output(0)
    i, j = s[out].op.axis
    # Vectorize the innermost chunk, bind a warp (threadIdx.x) and a group of
    # warps (threadIdx.y) along the embedding dimension, and one block per row.
    jo, ji = s[out].split(j, factor=vec_size)
    s[out].vectorize(ji)
    joo, joi = s[out].split(jo, factor=warp_size)
    s[out].bind(joi, te.thread_axis("threadIdx.x"))
    _, jooi = s[out].split(joo, factor=num_warps)
    s[out].bind(jooi, te.thread_axis("threadIdx.y"))
    s[out].bind(i, te.thread_axis("blockIdx.x"))

    return s
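
For illustration, a hedged sketch of driving this schedule the same way the CUDA strategy does at compile time (requires a CUDA-enabled build; the shapes are arbitrary, and float32 gives vec_size = 4):

import tvm
from tvm import te, topi

table = te.placeholder((1024, 128), dtype="float32", name="table")
indices = te.placeholder((32,), dtype="int32", name="indices")
grad_in = te.placeholder((32, 128), dtype="float32", name="grad_in")

with tvm.target.Target("cuda"):
    grad_out = topi.nn.embed_grad(table, indices, grad_in)
    s = topi.cuda.schedule_embed_grad([grad_out])
    mod = tvm.build(s, [table, indices, grad_in, grad_out])
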
17 changes: 17 additions & 0 deletions python/tvm/topi/generic/nn.py
@@ -762,3 +762,20 @@ def schedule_correlation_nchw(outs):
        The computation schedule for the op.
    """
    return _default_schedule(outs, False)


def schedule_embed_grad(outs):
    """Schedule for embed_grad

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of embed_grad
        in the format of an array of tensors.

    Returns
    -------
    sch: Schedule
        The computation schedule for the op.
    """
    return _default_schedule(outs, False)
1 change: 1 addition & 0 deletions python/tvm/topi/nn/__init__.py
@@ -26,6 +26,7 @@
from .deformable_conv2d import *
from .depthwise_conv2d import *
from .elemwise import *
from .embed import *
from .dilate import *
from .flatten import *
from .dense import *
41 changes: 41 additions & 0 deletions python/tvm/topi/nn/embed.py
@@ -0,0 +1,41 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Embedding operators"""

from tvm import te


@te.hybrid.script
def embed(table, indices):
    # out[i, j] = table[indices[i], j]
    out = output_tensor((indices.shape[0], table.shape[1]), table.dtype)
    for i in range(indices.shape[0]):
        for j in range(table.shape[1]):
            out[i, j] = table[indices[i], j]
    return out


@te.hybrid.script
def embed_grad(table, indices, grad_in):
    # Scatter-add of grad_in into a zero-initialized, table-shaped tensor;
    # repeated indices accumulate.
    grad_out = output_tensor(table.shape, table.dtype)
    for i in range(table.shape[0]):
        for j in range(table.shape[1]):
            grad_out[i, j] = cast(table.dtype, 0.0)
    for i in range(indices.shape[0]):
        for j in range(table.shape[1]):
            grad_out[indices[i], j] += grad_in[i, j]
    return grad_out
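
A hedged end-to-end sketch of the gradient kernel on CPU with a default te schedule (shapes and values are made up):

import numpy as np
import tvm
from tvm import te, topi

table = te.placeholder((8, 4), dtype="float32", name="table")
indices = te.placeholder((3,), dtype="int32", name="indices")
grad_in = te.placeholder((3, 4), dtype="float32", name="grad_in")
grad_out = topi.nn.embed_grad(table, indices, grad_in)
s = te.create_schedule(grad_out.op)
f = tvm.build(s, [table, indices, grad_in, grad_out], target="llvm")

idx = np.array([1, 1, 7], dtype="int32")
res = tvm.nd.array(np.zeros((8, 4), dtype="float32"))
f(
    tvm.nd.array(np.zeros((8, 4), dtype="float32")),
    tvm.nd.array(idx),
    tvm.nd.array(np.ones((3, 4), dtype="float32")),
    res,
)
# Row 1 was hit twice and accumulates to 2; row 7 once; all other rows stay 0.
assert res.asnumpy()[1, 0] == 2.0 and res.asnumpy()[7, 0] == 1.0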