Skip to content

Commit

Permalink
[Relay/TOPI][Op] Add batch_matmul in relay and TOPI (apache#2561)
Browse files Browse the repository at this point in the history
* Add batch_dot and cpu schedule

* Add relay support for batch_dot

* Rename batch_dot to batch_matmul

* nits

* Add missing file

* Put batch_matmul and dense x86 schedule in separate files

* Fix pylint

* Remove unused import

* Add cuda schedule for batch_matmul

* Add test case with larger batch size

* Add batch_matmul in api doc

* Fix quantize pass rounding error

* Fix pylint and minor change

* bug fix
  • Loading branch information
icemelon authored and wweic committed Mar 9, 2019
1 parent 1184dae commit 8e3058d
Show file tree
Hide file tree
Showing 23 changed files with 715 additions and 212 deletions.
2 changes: 2 additions & 0 deletions docs/api/python/topi.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ List of operators
topi.nn.upsampling
topi.nn.softmax
topi.nn.dense
topi.nn.batch_matmul
topi.nn.log_softmax
topi.nn.conv2d_nchw
topi.nn.conv2d_hwcn
Expand Down Expand Up @@ -138,6 +139,7 @@ topi.nn
.. autofunction:: topi.nn.upsampling
.. autofunction:: topi.nn.softmax
.. autofunction:: topi.nn.dense
.. autofunction:: topi.nn.batch_matmul
.. autofunction:: topi.nn.log_softmax
.. autofunction:: topi.nn.conv2d_nchw
.. autofunction:: topi.nn.conv2d_hwcn
Expand Down
2 changes: 2 additions & 0 deletions docs/langref/relay_op.rst
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ This level support backpropagation of broadcast operators. It is temporary.
tvm.relay.device_copy
tvm.relay.annotation.on_device
tvm.relay.reverse_reshape
tvm.relay.nn.batch_matmul


Level 1 Definitions
Expand Down Expand Up @@ -264,3 +265,4 @@ Level 10 Definitions
.. autofunction:: tvm.relay.device_copy
.. autofunction:: tvm.relay.annotation.on_device
.. autofunction:: tvm.relay.reverse_reshape
.. autofunction:: tvm.relay.nn.batch_matmul
14 changes: 13 additions & 1 deletion python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,18 @@ def _mx_multibox_detection(inputs, attrs):
return _op.vision.nms(ret[0], ret[1], **new_attrs1)


def _mx_batch_dot(inputs, attrs):
assert len(inputs) == 2
a, b = inputs
transpose_a = attrs.get_bool("transpose_a", False)
transpose_b = attrs.get_bool("transpose_b", False)
if transpose_a is True:
raise RuntimeError("batch_dot: only support transpose_a=False")
if transpose_b is False:
b = _op.transpose(b, axes=[0, 2, 1])
return _op.batch_matmul(a, b)


def _mx_arange(inputs, attrs):
assert len(inputs) == 0
if attrs.get_int("repeat", 1) != 1:
Expand Down Expand Up @@ -389,6 +401,7 @@ def _mx_roi_align(inputs, attrs):
"expand_dims" : _mx_expand_dims,
"Concat" : _mx_concat,
"concat" : _mx_concat,
"batch_dot" : _mx_batch_dot,
"LeakyReLU" : _mx_leaky_relu,
"_arange" : _mx_arange,
"SoftmaxOutput" : _mx_softmax_output,
Expand All @@ -403,7 +416,6 @@ def _mx_roi_align(inputs, attrs):
# "broadcast_to",
# "gather_nd",
# "Crop" : _crop_like,

}

# set identity list
Expand Down
15 changes: 15 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,21 @@ def schedule_dense(attrs, outputs, target):
reg.register_pattern("nn.dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)


# batch_matmul
@reg.register_compute("nn.batch_matmul")
def compute_batch_matmul(attrs, inputs, out_type, target):
"""Compute definition of batch_matmul"""
return [topi.nn.batch_matmul(inputs[0], inputs[1])]

@reg.register_schedule("nn.batch_matmul")
def schedule_batch_matmul(attrs, outputs, target):
"""Schedule definition of batch_matmul"""
with target:
return topi.generic.schedule_batch_matmul(outputs)

reg.register_pattern("nn.batch_matmul", reg.OpPattern.OUT_ELEMWISE_FUSABLE)


# conv2d
@reg.register_compute("nn.conv2d")
def compute_conv2d(attrs, inputs, out_type, target):
Expand Down
25 changes: 25 additions & 0 deletions python/tvm/relay/op/nn/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,6 +767,31 @@ def batch_norm(data,
return TupleWrapper(result, 3)


def batch_matmul(x, y):
r"""
Computes batch matrix multiplication of `x` and `y` when `x` and `y` are data
in batch.
.. math::
\mbox{batch_matmul}(x, y)[i, :, :] = \mbox{matmul}(x[i, :, :], y[i, :, :]^T)
Parameters
----------
x : tvm.relay.Expr
The first input.
y : tvm.relay.Expr
The second input.
Returns
-------
result: tvm.relay.Expr
The computed result.
"""
return _make.batch_matmul(x, y)


def contrib_conv2d_winograd_without_weight_transform(data,
weight,
tile_size,
Expand Down
63 changes: 63 additions & 0 deletions src/relay/op/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -654,5 +654,68 @@ axis to be the last item in the input shape.
.set_support_level(1)
.add_type_rel("BatchNorm", BatchNormRel);


// relay.nn.batch_matmul
bool BatchMatmulRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* x = types[0].as<TensorTypeNode>();
const auto* y = types[1].as<TensorTypeNode>();
if (x == nullptr || y == nullptr) return false;
if (x->shape.size() != 3 || y->shape.size() != 3) return false;
CHECK(reporter->AssertEQ(x->shape[0], y->shape[0]))
<< "BatchDot: batch dimension doesn't match, "
<< " x shape=" << x->shape
<< ", y shape=" << y->shape;
CHECK(reporter->AssertEQ(x->shape[2], y->shape[2]))
<< "BatchDot: shapes of x and y is inconsistent, "
<< " x shape=" << x->shape
<< ", y shape=" << y->shape;

Array<tvm::Expr> oshape = x->shape;
oshape.Set(2, y->shape[1]);

// assign output type
reporter->Assign(types[2], TensorTypeNode::make(oshape, x->dtype));
return true;
}


// Positional relay function to create batch_matmul operator used by frontend FFI.
Expr MakeBatchMatmul(Expr x,
Expr y) {
static const Op& op = Op::Get("nn.batch_matmul");
return CallNode::make(op, {x, y}, Attrs(), {});
}


TVM_REGISTER_API("relay.op.nn._make.batch_matmul")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeBatchMatmul, args, rv);
});


RELAY_REGISTER_OP("nn.batch_matmul")
.describe(R"code(Computes matrix multiplication of `x` and `y` when `x` and `y`
are data in batch.
.. math::
batch\_matmul(x, y)[i, :, :] = matmul(x[i, :, :], y[i, :, :]^T)
- **x**: `(b, m, k)`
- **y**: `(b, n, k)`
- **out**: `(b, m, n)`.
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("x", "3D Tensor", "First input.")
.add_argument("y", "3D Tensor", "Second input.")
.set_support_level(10)
.add_type_rel("BatchMatmul", BatchMatmulRel);


} // namespace relay
} // namespace tvm
1 change: 0 additions & 1 deletion tests/python/relay/test_op_level1.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,6 @@ def test_dense():
tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5)



if __name__ == "__main__":
test_concatenate()
test_bias_add()
Expand Down
36 changes: 35 additions & 1 deletion tests/python/relay/test_op_level10.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import tvm
from tvm import relay
from tvm.relay.testing import ctx_list
import topi
import topi.testing

def test_collapse_sum_like():
shape = (3, 4, 5, 6)
Expand Down Expand Up @@ -126,7 +128,6 @@ def verify_reverse_reshape(shape, newshape, oshape):
x = relay.var("x", relay.TensorType(shape, "float32"))
z = relay.reverse_reshape(x, newshape=newshape)
zz = relay.ir_pass.infer_type(z)
print(zz.checked_type)
assert "newshape=" in z.astext()
assert zz.checked_type == relay.ty.TensorType(oshape, "float32")

Expand All @@ -144,8 +145,41 @@ def verify_reverse_reshape(shape, newshape, oshape):
verify_reverse_reshape((2, 3, 4), (-1, 0), (6, 4))
verify_reverse_reshape((2, 3, 4), (0, -3), (2, 12))

def verify_batch_matmul(x_shape, y_shape, out_shape, dtype="float32"):
x = relay.var("x", relay.TensorType(x_shape, dtype))
y = relay.var("y", relay.TensorType(y_shape, dtype))
z = relay.nn.batch_matmul(x, y)
zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.ty.TensorType(out_shape, dtype)

func = relay.Function([x, y], z)
x_np = np.random.uniform(size=x_shape).astype(dtype)
y_np = np.random.uniform(size=y_shape).astype(dtype)
z_np = topi.testing.batch_matmul(x_np, y_np)

for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
z = intrp.evaluate(func)(x_np, y_np)
tvm.testing.assert_allclose(z.asnumpy(), z_np, rtol=1e-5)

def test_batch_matmul():
b, m, n, k = tvm.var("b"), tvm.var("m"), tvm.var("n"), tvm.var("k")
x = relay.var("x", relay.TensorType((b, m, k), "float32"))
y = relay.var("y", relay.TensorType((b, n, k), "float32"))
z = relay.nn.batch_matmul(x, y)
zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.TensorType((b, m, n), "float32")

verify_batch_matmul((1, 16, 32), (1, 16, 32), (1, 16, 16))
verify_batch_matmul((5, 16, 32), (5, 16, 32), (5, 16, 16))
verify_batch_matmul((5, 16, 32), (5, 20, 32), (5, 16, 20))
verify_batch_matmul((30, 16, 32), (30, 20, 32), (30, 16, 20))


if __name__ == "__main__":
test_collapse_sum_like()
test_broadcast_to_like()
test_slice_like()
test_reverse_reshape()
test_batch_matmul()
2 changes: 1 addition & 1 deletion tests/python/relay/test_pass_quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def make_qgraph(data, weight):
graph = relay.create_executor('graph')
res0 = graph.evaluate(qgraph0)(dataset[0]['data'])
res1 = graph.evaluate(qgraph1)(dataset[0]['data'])
tvm.testing.assert_allclose(res0.asnumpy(), res1.asnumpy())
tvm.testing.assert_allclose(res0.asnumpy(), res1.asnumpy(), rtol=1e-3)


if __name__ == "__main__":
Expand Down
49 changes: 49 additions & 0 deletions topi/include/topi/nn/batch_matmul.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*!
* Copyright (c) 2019 by Contributors
* \brief Batch matmul op constructions
* \file nn/batch_matmul.h
*/
#ifndef TOPI_NN_BATCH_MATMUL_H_
#define TOPI_NN_BATCH_MATMUL_H_

#include <string>

#include "topi/tags.h"
#include "tvm/tvm.h"

namespace topi {
namespace nn {
using namespace tvm;

/*!
* \brief Creates an operation that calculates matrix multiplication in batch.
*
* \param x Tensor with shape [batch, M, K]
* \param y Tensor with shape [batch, N, K]
*
* \return Tensor with shape [batch, M, N]
*/
inline tvm::Tensor batch_matmul(const tvm::Tensor& x,
const tvm::Tensor& y) {
CHECK_EQ(x->shape.size(), 3) << "batch_matmul requires 3-D data";
CHECK_EQ(y->shape.size(), 3) << "batch_matmul requires 3-D data";

auto batch = x->shape[0];
auto M = x->shape[1];
auto K = x->shape[2];
auto N = y->shape[1];

auto k = tvm::reduce_axis(Range(0, K), "k");
auto result = tvm::compute(
{ batch, M, N },
[&](Var b, Var i, Var j) {
return tvm::sum(x(b, i, k) * y(b, j, k), { k });
}, "tensor", "batch_matmul");

return result;
}

} // namespace nn
} // namespace topi

#endif // TOPI_NN_BATCH_MATMUL_H_
1 change: 1 addition & 0 deletions topi/python/topi/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .pooling import schedule_pool, schedule_global_pool
from .extern import schedule_extern
from .nn import schedule_lrn, schedule_l2_normalize
from .batch_matmul import schedule_batch_matmul
from .vision import *
from . import ssd
from .ssd import *
Expand Down
Loading

0 comments on commit 8e3058d

Please sign in to comment.