[TOPI] Add embed op and gradient.
The embed op is a specialization of take with a 2D lookup table.
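As a rough numpy sketch of the intended semantics (illustrative only, not part of this change), embed gathers rows exactly like take along axis 0:

import numpy as np

# embed(table, indices)[i, j] == table[indices[i], j], i.e. take along axis 0.
table = np.arange(12, dtype="float32").reshape(4, 3)  # M x N lookup table
indices = np.array([2, 0, 2], dtype="int32")          # length-K index vector
out = np.take(table, indices, axis=0)                 # K x N result
assert np.array_equal(out, table[indices])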
tkonolige committed Nov 5, 2020
1 parent 7291a92 commit 57e6506
Showing 20 changed files with 501 additions and 3 deletions.
8 changes: 8 additions & 0 deletions include/tvm/relay/attrs/nn.h
@@ -1324,6 +1324,14 @@ struct CorrelationAttrs : public tvm::AttrsNode<CorrelationAttrs> {
}
}; // struct CorrelationAttrs

struct EmbedAttrs : public tvm::AttrsNode<EmbedAttrs> {
TVM_DECLARE_ATTRS(EmbedAttrs, "relay.attrs.EmbedAttrs") {}
};

struct EmbedGradAttrs : public tvm::AttrsNode<EmbedGradAttrs> {
TVM_DECLARE_ATTRS(EmbedGradAttrs, "relay.attrs.EmbedGradAttrs") {}
};

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_NN_H_
2 changes: 1 addition & 1 deletion python/tvm/autotvm/task/topi_integration.py
@@ -229,7 +229,7 @@ def wrapper(outs, *args, **kwargs):
"""wrapper function for topi schedule"""
workload = get_workload(outs)
if workload is None:
raise RuntimeError("Cannot find workload in attribute of this schedule")
raise RuntimeError(f"Cannot find workload for {task_name}. You may need to register a compute function for it with `@tvm.autotvm.register_topi_compute(\"{task_name}\")`")
tgt = Target.current()
cfg = DispatchContext.current.query(tgt, workload)
return topi_schedule(cfg, outs, *args, **kwargs)
6 changes: 6 additions & 0 deletions python/tvm/relay/op/_tensor_grad.py
@@ -803,3 +803,9 @@ def arange_grad(orig, grad):
    grad_step = cast_like(_sum(grad_step), step)

    return [grad_start, grad_stop, grad_step]


@register_gradient("nn.embed")
def embed_grad(orig, grad):
    """Gradient of embed: scatter the incoming gradient into the table rows; indices get no gradient."""
    table, indices = orig.args
    return [_nn.embed_grad(table, indices, grad), zeros_like(indices)]
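If this registration needs a quick sanity check, something along these lines should work with the numerical gradient checker (a hedged sketch; check_grad and its inputs/test_inputs arguments are assumed from tvm.relay.testing, and test_inputs keeps the checker away from the integer indices):

import numpy as np
from tvm import relay
from tvm.relay.testing import check_grad

table = relay.var("table", shape=(5, 4), dtype="float64")
indices = relay.var("indices", shape=(3,), dtype="int32")
func = relay.Function([table, indices], relay.nn.embed(table, indices))

table_np = np.random.rand(5, 4)
indices_np = np.array([0, 2, 2], dtype="int32")
check_grad(func, inputs=[table_np, indices_np], test_inputs=[table_np])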
21 changes: 21 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
@@ -746,6 +746,27 @@ def compute_space_to_depth(attrs, inputs, out_dtype):
reg.register_pattern("nn.correlation", OpPattern.OUT_ELEMWISE_FUSABLE)


# embed
@reg.register_compute("nn.embed")
def compute_embed(attrs, inputs, out_type):
    """Compute definition of embed"""
    return [topi.nn.embed(inputs[0], inputs[1])]


reg.register_injective_schedule("nn.embed")
reg.register_pattern("nn.embed", OpPattern.INJECTIVE)


@reg.register_compute("nn.embed_grad")
def compute_embed_grad(attrs, inputs, out_type):
    """Compute definition of embed_grad"""
    return [topi.nn.embed_grad(inputs[0], inputs[1], inputs[2])]


reg.register_strategy("nn.embed_grad", strategy.embed_grad_strategy)
reg.register_pattern("nn.embed_grad", OpPattern.OUT_ELEMWISE_FUSABLE)


#####################
# Shape functions #
#####################
58 changes: 58 additions & 0 deletions python/tvm/relay/op/nn/nn.py
@@ -3158,3 +3158,61 @@ def correlation(
return _make.correlation(
data1, data2, kernel_size, max_displacement, stride1, stride2, padding, is_multiply, layout
)


def embed(table, indices):
    """Lookup indices in an embedding table.

    The embedding lookup is defined as:

    .. math::
        O[i,j] = T[I[i],j]

    where :math:`T` is the embedding table and :math:`I` is the indices to
    look up. This is a specialization of take with a two dimensional input
    and axis = 0.

    Parameters
    ----------
    table : tvm.relay.Expr
        M x N tensor of embedding locations.
    indices : tvm.relay.Expr
        Length K vector of indices to look up in `table`.

    Returns
    -------
    Output : tvm.relay.Expr
        K x N tensor corresponding to the rows of `table` indexed with `indices`.
    """
    return _make.embed(table, indices)


def embed_grad(table, indices, grad):
    """Gradient of :py:func:`embed`.

    The gradient of an embedding lookup is defined as:

    .. math::
        O[I[i],j] = G[i,j]

    where :math:`G` is the incoming gradient and :math:`I` is the indices to
    look up. Repeated indices accumulate their gradient contributions.

    Parameters
    ----------
    table : tvm.relay.Expr
        M x N tensor of embedding locations.
    indices : tvm.relay.Expr
        Length K vector of indices to look up in `table`.
    grad : tvm.relay.Expr
        K x N tensor of the gradient to propagate.

    Returns
    -------
    Output : tvm.relay.Expr
        M x N tensor of the gradient with respect to `table`.
    """
    return _make.embed_grad(table, indices, grad)
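A minimal usage sketch of the two wrappers above (shapes mirror the docstrings; illustrative only):

from tvm import relay

table = relay.var("table", shape=(10, 8), dtype="float32")   # M x N
indices = relay.var("indices", shape=(4,), dtype="int32")    # length K
grad = relay.var("grad", shape=(4, 8), dtype="float32")      # K x N

out = relay.nn.embed(table, indices)                 # K x N = (4, 8)
dtable = relay.nn.embed_grad(table, indices, grad)   # M x N = (10, 8)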
12 changes: 12 additions & 0 deletions python/tvm/relay/op/strategy/cuda.py
@@ -843,3 +843,15 @@ def correlation_strategy_cuda(attrs, inputs, out_type, target):
name="correlation.cuda",
)
return strategy


@embed_grad_strategy.register(["cuda", "gpu"])
def embed_grad_strategy_gpu(attrs, inputs, out_type, target):
    """gpu strategy for embed_grad"""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_embed_grad(topi.cuda.embed_grad),
        wrap_topi_schedule(topi.cuda.schedule_embed_grad),
        name="embed_grad.cuda",
    )
    return strategy
22 changes: 22 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
@@ -1187,3 +1187,25 @@ def correlation_strategy(attrs, inputs, out_type, target):
name="correlation.generic",
)
return strategy


def wrap_compute_embed_grad(topi_compute):
    """wrap embed_grad"""

    def _wrapped(attrs, inputs, out_type):
        return [topi_compute(inputs[0], inputs[1], inputs[2])]

    return _wrapped


@override_native_generic_func("embed_grad_strategy")
def embed_grad_strategy(attrs, inputs, out_type, target):
    """embed gradient generic strategy"""
    logger.warning("embed_grad is not optimized for this platform.")
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_embed_grad(topi.nn.embed_grad),
        wrap_topi_schedule(topi.generic.schedule_embed_grad),
        name="embed_grad.generic",
    )
    return strategy
12 changes: 12 additions & 0 deletions python/tvm/relay/op/strategy/x86.py
@@ -446,3 +446,15 @@ def bitserial_dense_strategy_cpu(attrs, inputs, out_type, target):
name="bitserial_dense.x86",
)
return strategy


@embed_grad_strategy.register("cpu")
def embed_grad_strategy_cpu(attrs, inputs, out_type, target):
    """x86 strategy for embed_grad"""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_embed_grad(topi.nn.embed_grad),
        wrap_topi_schedule(topi.x86.schedule_embed_grad),
        name="embed_grad.x86",
    )
    return strategy
5 changes: 5 additions & 0 deletions python/tvm/te/hybrid/calls.py
@@ -137,6 +137,11 @@ def _cast(func_id, args):
return _expr.Cast(func_id, args[0])


def cast(func_id, args):
    """Hybrid intrinsic for an explicit cast: cast(dtype, value)."""
    _internal_assert(len(args) == 2, "cast requires two arguments: dtype, value")
    return _expr.Cast(args[0], args[1])


float16 = float32 = float64 = _cast # pylint: disable=invalid-name
int8 = int16 = int32 = int64 = _cast # pylint: disable=invalid-name
uint8 = uint16 = uint32 = uint64 = _cast # pylint: disable=invalid-name
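With the new cast intrinsic, a hybrid script body can request an explicit dtype conversion directly; a hedged sketch (zeros_like_hybrid is a hypothetical helper, not part of this change), mirroring how embed_grad below uses cast:

from tvm import te

@te.hybrid.script
def zeros_like_hybrid(x):
    # Fill an output tensor with zeros of the input's dtype.
    out = output_tensor(x.shape, x.dtype)
    for i in range(x.shape[0]):
        for j in range(x.shape[1]):
            out[i, j] = cast(x.dtype, 0.0)
    return out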
6 changes: 6 additions & 0 deletions python/tvm/te/hybrid/runtime.py
@@ -111,6 +111,11 @@ def max_num_threads(allow_none=True):
return Target.current(allow_none).max_num_threads


def cast(dtype, value):
    """Numpy emulation of the hybrid `cast` intrinsic: convert `value` to `dtype`."""
    return getattr(numpy, str(dtype))(value)


HYBRID_GLOBALS = {
"unroll": range,
"vectorize": range,
@@ -142,6 +147,7 @@ def max_num_threads(allow_none=True):
"float64": numpy.float64,
"ceil_div": lambda a, b: (a + b - 1) // b,
"max_num_threads": max_num_threads,
"cast": cast,
}


29 changes: 29 additions & 0 deletions python/tvm/testing.py
@@ -714,4 +714,33 @@ def func(f):
return wrap(args)


def compare_numpy_tvm(inputs, output, target, ctx, compute, schedule):
    """Compare a function's numpy inputs and verified output to the results of the TVM version.

    Parameters
    ----------
    inputs : Sequence[numpy.ndarray]
        List of input numpy arrays to pass to the function.
    output : numpy.ndarray
        Verified correct function output.
    target : tvm.target.Target
        Target to run on.
    ctx : tvm.TVMContext
        Context to run on.
    compute : callable
        Topi compute function to test against.
    schedule : callable
        Topi scheduling function to test against.
    """
    te_inputs = [tvm.te.placeholder(shape=i.shape, dtype=str(i.dtype)) for i in inputs]
    te_out = tvm.nd.array(np.zeros(output.shape).astype(output.dtype), ctx=ctx)
    with tvm.target.Target(target):
        out = compute(*te_inputs)
        s = schedule([out])
        func = tvm.build(s, te_inputs + [out])
        arys = [tvm.nd.array(x, ctx=ctx) for x in inputs]
        func(*(arys + [te_out]))
        assert_allclose(output, te_out.asnumpy(), atol=1e-4, rtol=1e-4)


tvm._ffi._init_api("testing", __name__)
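A hedged usage sketch of the helper above, exercising the new topi embed compute on CPU (a plain default schedule stands in for a tuned one; values are illustrative):

import numpy as np
import tvm
from tvm import te, topi
from tvm.testing import compare_numpy_tvm

table = np.random.rand(6, 5).astype("float32")
indices = np.array([1, 4, 4], dtype="int32")
expected = table[indices]  # reference output computed with numpy

compare_numpy_tvm(
    [table, indices],
    expected,
    "llvm",
    tvm.cpu(0),
    topi.nn.embed,
    lambda outs: te.create_schedule([o.op for o in outs]),
)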
2 changes: 1 addition & 1 deletion python/tvm/topi/cuda/__init__.py
@@ -40,7 +40,7 @@
from .injective import schedule_injective, schedule_elemwise, schedule_broadcast
from .dense import *
from .pooling import *
from .nn import schedule_lrn
from .nn import schedule_lrn, schedule_embed_grad
from .batch_matmul import *
from .vision import *
from .ssd import *
41 changes: 41 additions & 0 deletions python/tvm/topi/cuda/nn.py
@@ -20,6 +20,8 @@

from .. import cpp

import re

import tvm
from tvm import te


def schedule_lrn(outs):
"""Schedule for LRN
@@ -36,3 +38,42 @@ def schedule_lrn(outs):
The computation schedule for the op.
"""
return cpp.cuda.schedule_lrn(outs)


def loads_per_thread(dtype):
    """Number of elements each thread should load to issue 16-byte vector loads."""
    match = re.search("[0-9]+", dtype)
    assert match is not None, "dtype must contain its bit width, e.g. float32"
    nbytes = int(match.group()) // 8
    return 16 // nbytes


def schedule_embed_grad(outs):
    """Schedule for embed_grad

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of embed_grad
        in the format of an array of tensors.

    Returns
    -------
    sch: Schedule
        The computation schedule for the op.
    """
    s = te.create_schedule([outs[0].op])
    # The vector width should be autotuned, but hybrid script ops cannot be tuned yet.
    vec_size = loads_per_thread(outs[0].dtype)
    warp_size = int(tvm.target.Target.current(allow_none=False).thread_warp_size)
    num_warps = 4
    out = s.outputs[0].output(0)
    i, j = s[out].op.axis
    # Vectorize the innermost loads, bind lanes to threads, and bind rows to blocks.
    jo, ji = s[out].split(j, factor=vec_size)
    s[out].vectorize(ji)
    joo, joi = s[out].split(jo, factor=warp_size)
    s[out].bind(joi, te.thread_axis("threadIdx.x"))
    jooo, jooi = s[out].split(joo, factor=num_warps)
    s[out].bind(jooi, te.thread_axis("threadIdx.y"))
    s[out].bind(i, te.thread_axis("blockIdx.x"))

    return s
17 changes: 17 additions & 0 deletions python/tvm/topi/generic/nn.py
@@ -744,3 +744,20 @@ def schedule_correlation_nchw(outs):
The computation schedule for the op.
"""
return _default_schedule(outs, False)


def schedule_embed_grad(outs):
    """Schedule for embed_grad

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of embed_grad
        in the format of an array of tensors.

    Returns
    -------
    sch: Schedule
        The computation schedule for the op.
    """
    return _default_schedule(outs, False)
1 change: 1 addition & 0 deletions python/tvm/topi/nn/__init__.py
@@ -26,6 +26,7 @@
from .deformable_conv2d import *
from .depthwise_conv2d import *
from .elemwise import *
from .embed import *
from .dilate import *
from .flatten import *
from .dense import *
41 changes: 41 additions & 0 deletions python/tvm/topi/nn/embed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Embedding operators"""

from tvm import te


@te.hybrid.script
def embed(table, indices):
    # Gather rows of the table: out[i, j] = table[indices[i], j].
    out = output_tensor((indices.shape[0], table.shape[1]), table.dtype)
    for i in range(indices.shape[0]):
        for j in range(table.shape[1]):
            out[i, j] = table[indices[i], j]
    return out


@te.hybrid.script
def embed_grad(table, indices, grad_in):
    # Scatter-add the incoming gradient rows into a zero-initialized table.
    grad_out = output_tensor(table.shape, table.dtype)
    for i in range(table.shape[0]):
        for j in range(table.shape[1]):
            grad_out[i, j] = cast(table.dtype, 0.0)
    for i in range(indices.shape[0]):
        for j in range(table.shape[1]):
            grad_out[indices[i], j] += grad_in[i, j]
    return grad_out
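For reference, a hedged numpy equivalent of the two kernels above (illustrative only):

import numpy as np

def embed_np(table, indices):
    # O[i, j] = T[I[i], j]
    return table[indices]

def embed_grad_np(table, indices, grad_in):
    # Scatter-add rows of grad_in into a zero table; repeated indices accumulate.
    grad_out = np.zeros_like(table)
    np.add.at(grad_out, indices, grad_in)
    return grad_out

table = np.zeros((4, 3), dtype="float32")
indices = np.array([1, 1, 3], dtype="int32")
grad = np.ones((3, 3), dtype="float32")
assert embed_grad_np(table, indices, grad)[1, 0] == 2.0  # two gradient rows hit index 1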