From 47a344c2fd14bed3f8218bf60442ca486438b727 Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Fri, 1 Mar 2019 14:53:46 -0800 Subject: [PATCH] [Relay/TOPI][Op] Add batch_matmul in relay and TOPI (#2561) * Add batch_dot and cpu schedule * Add relay support for batch_dot * Rename batch_dot to batch_matmul * nits * Add missing file * Put batch_matmul and dense x86 schedule in separate files * Fix pylint * Remove unused import * Add cuda schedule for batch_matmul * Add test case with larger batch size * Add batch_matmul in api doc * Fix quantize pass rounding error * Fix pylint and minor change * bug fix --- docs/api/python/topi.rst | 2 + docs/langref/relay_op.rst | 2 + python/tvm/relay/frontend/mxnet.py | 14 +- python/tvm/relay/op/nn/_nn.py | 15 ++ python/tvm/relay/op/nn/nn.py | 25 +++ src/relay/op/nn/nn.cc | 63 ++++++ tests/python/relay/test_op_level1.py | 1 - tests/python/relay/test_op_level10.py | 36 +++- tests/python/relay/test_pass_quantize.py | 2 +- topi/include/topi/nn/batch_matmul.h | 49 +++++ topi/python/topi/cuda/__init__.py | 1 + topi/python/topi/cuda/batch_matmul.py | 89 +++++++++ topi/python/topi/generic/nn.py | 6 + topi/python/topi/nn/__init__.py | 1 + topi/python/topi/nn/batch_matmul.py | 35 ++++ topi/python/topi/testing/__init__.py | 1 + topi/python/topi/testing/batch_matmul.py | 26 +++ topi/python/topi/util.py | 26 +++ topi/python/topi/x86/batch_matmul.py | 53 +++++ topi/python/topi/x86/dense.py | 208 +++++++++++++++++++ topi/python/topi/x86/nn.py | 209 +------------------- topi/src/topi.cc | 10 + topi/tests/python/test_topi_batch_matmul.py | 53 +++++ 23 files changed, 715 insertions(+), 212 deletions(-) create mode 100644 topi/include/topi/nn/batch_matmul.h create mode 100644 topi/python/topi/cuda/batch_matmul.py create mode 100644 topi/python/topi/nn/batch_matmul.py create mode 100644 topi/python/topi/testing/batch_matmul.py create mode 100644 topi/python/topi/x86/batch_matmul.py create mode 100644 topi/python/topi/x86/dense.py create mode 100644 topi/tests/python/test_topi_batch_matmul.py diff --git a/docs/api/python/topi.rst b/docs/api/python/topi.rst index 9680adc1231b..e8b63637ffb5 100644 --- a/docs/api/python/topi.rst +++ b/docs/api/python/topi.rst @@ -41,6 +41,7 @@ List of operators topi.nn.upsampling topi.nn.softmax topi.nn.dense + topi.nn.batch_matmul topi.nn.log_softmax topi.nn.conv2d_nchw topi.nn.conv2d_hwcn @@ -138,6 +139,7 @@ topi.nn .. autofunction:: topi.nn.upsampling .. autofunction:: topi.nn.softmax .. autofunction:: topi.nn.dense +.. autofunction:: topi.nn.batch_matmul .. autofunction:: topi.nn.log_softmax .. autofunction:: topi.nn.conv2d_nchw .. autofunction:: topi.nn.conv2d_hwcn diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst index e2da42b6ab32..7958d6cbe553 100644 --- a/docs/langref/relay_op.rst +++ b/docs/langref/relay_op.rst @@ -152,6 +152,7 @@ This level support backpropagation of broadcast operators. It is temporary. tvm.relay.device_copy tvm.relay.annotation.on_device tvm.relay.reverse_reshape + tvm.relay.nn.batch_matmul Level 1 Definitions @@ -264,3 +265,4 @@ Level 10 Definitions .. autofunction:: tvm.relay.device_copy .. autofunction:: tvm.relay.annotation.on_device .. autofunction:: tvm.relay.reverse_reshape +.. 
autofunction:: tvm.relay.nn.batch_matmul diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 9ef5f626393a..3d3bb8e4fd84 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -283,6 +283,18 @@ def _mx_multibox_detection(inputs, attrs): return _op.vision.nms(ret[0], ret[1], **new_attrs1) +def _mx_batch_dot(inputs, attrs): + assert len(inputs) == 2 + a, b = inputs + transpose_a = attrs.get_bool("transpose_a", False) + transpose_b = attrs.get_bool("transpose_b", False) + if transpose_a: + raise RuntimeError("batch_dot: only supports transpose_a=False") + if not transpose_b: + b = _op.transpose(b, axes=[0, 2, 1]) + return _op.nn.batch_matmul(a, b) + + def _mx_arange(inputs, attrs): assert len(inputs) == 0 if attrs.get_int("repeat", 1) != 1: @@ -389,6 +401,7 @@ def _mx_roi_align(inputs, attrs): "expand_dims" : _mx_expand_dims, "Concat" : _mx_concat, "concat" : _mx_concat, + "batch_dot" : _mx_batch_dot, "LeakyReLU" : _mx_leaky_relu, "_arange" : _mx_arange, "SoftmaxOutput" : _mx_softmax_output, @@ -403,7 +416,6 @@ def _mx_roi_align(inputs, attrs): # "broadcast_to", # "gather_nd", # "Crop" : _crop_like, - } # set identity list diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index a4b41d92371e..0c2733ecae92 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -46,6 +46,21 @@ def schedule_dense(attrs, outputs, target): reg.register_pattern("nn.dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE) +# batch_matmul +@reg.register_compute("nn.batch_matmul") +def compute_batch_matmul(attrs, inputs, out_type, target): + """Compute definition of batch_matmul""" + return [topi.nn.batch_matmul(inputs[0], inputs[1])] + +@reg.register_schedule("nn.batch_matmul") +def schedule_batch_matmul(attrs, outputs, target): + """Schedule definition of batch_matmul""" + with target: + return topi.generic.schedule_batch_matmul(outputs) + +reg.register_pattern("nn.batch_matmul", reg.OpPattern.OUT_ELEMWISE_FUSABLE) + + # conv2d @reg.register_compute("nn.conv2d") def compute_conv2d(attrs, inputs, out_type, target): diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index 06cd79a8ff8b..41b2148ec390 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -767,6 +767,31 @@ def batch_norm(data, return TupleWrapper(result, 3) +def batch_matmul(x, y): + r""" + Computes batch matrix multiplication of `x` and `y`, where `x` and `y` are batches + of matrices held in 3-D tensors; `y` is transposed before the multiplication. + + .. math:: + + \mbox{batch_matmul}(x, y)[i, :, :] = \mbox{matmul}(x[i, :, :], y[i, :, :]^T) + + Parameters + ---------- + x : tvm.relay.Expr + The first input, of shape `(batch, M, K)`. + + y : tvm.relay.Expr + The second input, of shape `(batch, N, K)`. + + Returns + ------- + result: tvm.relay.Expr + The computed result, of shape `(batch, M, N)`. + """ + return _make.batch_matmul(x, y) + + def contrib_conv2d_winograd_without_weight_transform(data, weight, tile_size, diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index 9ab841cf4286..59f68d9d8880 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -654,5 +654,68 @@ axis to be the last item in the input shape.
.set_support_level(1) .add_type_rel("BatchNorm", BatchNormRel); + +// relay.nn.batch_matmul +bool BatchMatmulRel(const Array<Type>& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 3); + const auto* x = types[0].as<TensorTypeNode>(); + const auto* y = types[1].as<TensorTypeNode>(); + if (x == nullptr || y == nullptr) return false; + if (x->shape.size() != 3 || y->shape.size() != 3) return false; + CHECK(reporter->AssertEQ(x->shape[0], y->shape[0])) + << "BatchDot: batch dimension doesn't match, " + << " x shape=" << x->shape + << ", y shape=" << y->shape; + CHECK(reporter->AssertEQ(x->shape[2], y->shape[2])) + << "BatchDot: shapes of x and y are inconsistent, " + << " x shape=" << x->shape + << ", y shape=" << y->shape; + + Array<IndexExpr> oshape = x->shape; + oshape.Set(2, y->shape[1]); + + // assign output type + reporter->Assign(types[2], TensorTypeNode::make(oshape, x->dtype)); + return true; +} + + +// Positional relay function to create batch_matmul operator used by frontend FFI. +Expr MakeBatchMatmul(Expr x, + Expr y) { + static const Op& op = Op::Get("nn.batch_matmul"); + return CallNode::make(op, {x, y}, Attrs(), {}); +} + + +TVM_REGISTER_API("relay.op.nn._make.batch_matmul") +.set_body([](const TVMArgs& args, TVMRetValue* rv) { + runtime::detail::unpack_call<Expr, 2>(MakeBatchMatmul, args, rv); + }); + + +RELAY_REGISTER_OP("nn.batch_matmul") +.describe(R"code(Computes batch matrix multiplication of `x` and `y`, where `x` and `y` +are 3-D tensors of matrices; `y` is transposed before the multiplication. + +.. math:: + + batch\_matmul(x, y)[i, :, :] = matmul(x[i, :, :], y[i, :, :]^T) + +- **x**: `(b, m, k)` +- **y**: `(b, n, k)` +- **out**: `(b, m, n)`. + +)code" TVM_ADD_FILELINE) +.set_num_inputs(2) +.add_argument("x", "3D Tensor", "First input.") +.add_argument("y", "3D Tensor", "Second input.") +.set_support_level(10) +.add_type_rel("BatchMatmul", BatchMatmulRel); + + } // namespace relay } // namespace tvm diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py index d29b808be0d1..b954e42bf1ab 100644 --- a/tests/python/relay/test_op_level1.py +++ b/tests/python/relay/test_op_level1.py @@ -306,7 +306,6 @@ def test_dense(): tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5) - if __name__ == "__main__": test_concatenate() test_bias_add() diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py index a6e169e23a6c..34285d2b18dd 100644 --- a/tests/python/relay/test_op_level10.py +++ b/tests/python/relay/test_op_level10.py @@ -4,6 +4,8 @@ import tvm from tvm import relay from tvm.relay.testing import ctx_list +import topi +import topi.testing def test_collapse_sum_like(): shape = (3, 4, 5, 6) @@ -126,7 +128,6 @@ def verify_reverse_reshape(shape, newshape, oshape): x = relay.var("x", relay.TensorType(shape, "float32")) z = relay.reverse_reshape(x, newshape=newshape) zz = relay.ir_pass.infer_type(z) - print(zz.checked_type) assert "newshape=" in z.astext() assert zz.checked_type == relay.ty.TensorType(oshape, "float32") @@ -144,8 +145,41 @@ def verify_reverse_reshape(shape, newshape, oshape): verify_reverse_reshape((2, 3, 4), (-1, 0), (6, 4)) verify_reverse_reshape((2, 3, 4), (0, -3), (2, 12)) +def verify_batch_matmul(x_shape, y_shape, out_shape, dtype="float32"): + x = relay.var("x", relay.TensorType(x_shape, dtype)) + y = relay.var("y", relay.TensorType(y_shape, dtype)) + z = relay.nn.batch_matmul(x, y) + zz = relay.ir_pass.infer_type(z) + assert zz.checked_type == relay.ty.TensorType(out_shape, dtype) + + func = relay.Function([x, y], z) + x_np = 
np.random.uniform(size=x_shape).astype(dtype) + y_np = np.random.uniform(size=y_shape).astype(dtype) + z_np = topi.testing.batch_matmul(x_np, y_np) + + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + z = intrp.evaluate(func)(x_np, y_np) + tvm.testing.assert_allclose(z.asnumpy(), z_np, rtol=1e-5) + +def test_batch_matmul(): + b, m, n, k = tvm.var("b"), tvm.var("m"), tvm.var("n"), tvm.var("k") + x = relay.var("x", relay.TensorType((b, m, k), "float32")) + y = relay.var("y", relay.TensorType((b, n, k), "float32")) + z = relay.nn.batch_matmul(x, y) + zz = relay.ir_pass.infer_type(z) + assert zz.checked_type == relay.TensorType((b, m, n), "float32") + + verify_batch_matmul((1, 16, 32), (1, 16, 32), (1, 16, 16)) + verify_batch_matmul((5, 16, 32), (5, 16, 32), (5, 16, 16)) + verify_batch_matmul((5, 16, 32), (5, 20, 32), (5, 16, 20)) + verify_batch_matmul((30, 16, 32), (30, 20, 32), (30, 16, 20)) + + if __name__ == "__main__": test_collapse_sum_like() test_broadcast_to_like() test_slice_like() test_reverse_reshape() + test_batch_matmul() diff --git a/tests/python/relay/test_pass_quantize.py b/tests/python/relay/test_pass_quantize.py index 6d65d7b2d9ee..2e2389d16244 100644 --- a/tests/python/relay/test_pass_quantize.py +++ b/tests/python/relay/test_pass_quantize.py @@ -75,7 +75,7 @@ def make_qgraph(data, weight): graph = relay.create_executor('graph') res0 = graph.evaluate(qgraph0)(dataset[0]['data']) res1 = graph.evaluate(qgraph1)(dataset[0]['data']) - tvm.testing.assert_allclose(res0.asnumpy(), res1.asnumpy()) + tvm.testing.assert_allclose(res0.asnumpy(), res1.asnumpy(), rtol=1e-3) if __name__ == "__main__": diff --git a/topi/include/topi/nn/batch_matmul.h b/topi/include/topi/nn/batch_matmul.h new file mode 100644 index 000000000000..968e1b0c697c --- /dev/null +++ b/topi/include/topi/nn/batch_matmul.h @@ -0,0 +1,49 @@ +/*! + * Copyright (c) 2019 by Contributors + * \brief Batch matmul op constructions + * \file nn/batch_matmul.h + */ +#ifndef TOPI_NN_BATCH_MATMUL_H_ +#define TOPI_NN_BATCH_MATMUL_H_ + +#include + +#include "topi/tags.h" +#include "tvm/tvm.h" + +namespace topi { +namespace nn { +using namespace tvm; + +/*! +* \brief Creates an operation that calculates matrix multiplication in batch. +* +* \param x Tensor with shape [batch, M, K] +* \param y Tensor with shape [batch, N, K] +* +* \return Tensor with shape [batch, M, N] +*/ +inline tvm::Tensor batch_matmul(const tvm::Tensor& x, + const tvm::Tensor& y) { + CHECK_EQ(x->shape.size(), 3) << "batch_matmul requires 3-D data"; + CHECK_EQ(y->shape.size(), 3) << "batch_matmul requires 3-D data"; + + auto batch = x->shape[0]; + auto M = x->shape[1]; + auto K = x->shape[2]; + auto N = y->shape[1]; + + auto k = tvm::reduce_axis(Range(0, K), "k"); + auto result = tvm::compute( + { batch, M, N }, + [&](Var b, Var i, Var j) { + return tvm::sum(x(b, i, k) * y(b, j, k), { k }); + }, "tensor", "batch_matmul"); + + return result; +} + +} // namespace nn +} // namespace topi + +#endif // TOPI_NN_BATCH_MATMUL_H_ diff --git a/topi/python/topi/cuda/__init__.py b/topi/python/topi/cuda/__init__.py index 91c2235fcf70..ba577cd944f0 100644 --- a/topi/python/topi/cuda/__init__.py +++ b/topi/python/topi/cuda/__init__.py @@ -14,6 +14,7 @@ from .pooling import schedule_pool, schedule_global_pool from .extern import schedule_extern from .nn import schedule_lrn, schedule_l2_normalize +from .batch_matmul import schedule_batch_matmul from .vision import * from . 
import ssd from .ssd import * diff --git a/topi/python/topi/cuda/batch_matmul.py b/topi/python/topi/cuda/batch_matmul.py new file mode 100644 index 000000000000..a1fa256028da --- /dev/null +++ b/topi/python/topi/cuda/batch_matmul.py @@ -0,0 +1,89 @@ +# pylint: disable=invalid-name,too-many-locals,unused-variable +"""cuda batch_matmul operators""" +from __future__ import absolute_import as _abs +import tvm + +from .. import generic +from ..util import traverse_inline, get_const_tuple, get_max_power2_factor + + +@generic.schedule_batch_matmul.register(["cuda", "gpu"]) +def schedule_batch_matmul(outs): + """Schedule for batch_matmul + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of batch_matmul + in the format of an array of tensors. + + Returns + ------- + s: Schedule + The computation schedule for the op. + """ + s = tvm.create_schedule([x.op for x in outs]) + + def _schedule(op): + C = op.output(0) + A, B = s[C].op.input_tensors + _, M, N = get_const_tuple(C.shape) + AA = s.cache_read(A, "shared", [C]) + AL = s.cache_read(AA, "local", [C]) + BB = s.cache_read(B, "shared", [C]) + BL = s.cache_read(BB, "local", [C]) + CC = s.cache_write(C, "local") + + b, y, x = s[C].op.axis + y_bn = get_max_power2_factor(M, 64) + x_bn = get_max_power2_factor(N, 64) + by, y = s[C].split(y, y_bn) + bx, x = s[C].split(x, x_bn) + y_nthreads = min(y_bn, 8) + x_nthreads = min(x_bn, 8) + ty, yi = s[C].split(y, nparts=y_nthreads) + tx, xi = s[C].split(x, nparts=x_nthreads) + thread_x = tvm.thread_axis((0, x_nthreads), "threadIdx.x") + thread_y = tvm.thread_axis((0, y_nthreads), "threadIdx.y") + + s[C].reorder(b, by, bx, ty, tx, yi, xi) + s[C].bind(b, tvm.thread_axis("blockIdx.z")) + s[C].bind(by, tvm.thread_axis("blockIdx.y")) + s[C].bind(bx, tvm.thread_axis("blockIdx.x")) + s[C].bind(ty, thread_y) + s[C].bind(tx, thread_x) + s[C].pragma(yi, "auto_unroll_max_step", 16) + + s[CC].compute_at(s[C], tx) + _, yi, xi = s[CC].op.axis + k, = s[CC].op.reduce_axis + ko, ki = s[CC].split(k, 8) + s[CC].reorder(ko, ki, yi, xi) + s[CC].pragma(ki, "auto_unroll_max_step", 16) + + s[AA].compute_at(s[CC], ko) + s[AL].compute_at(s[CC], ki) + s[BB].compute_at(s[CC], ko) + s[BL].compute_at(s[CC], ki) + _, y, k = s[AA].op.axis + ty, yi = s[AA].split(y, nparts=y_nthreads) + tx, ki = s[AA].split(k, nparts=x_nthreads) + s[AA].reorder(ty, tx, yi, ki) + s[AA].bind(ty, thread_y) + s[AA].bind(tx, thread_x) + s[AA].pragma(yi, "auto_unroll_max_step", 16) + + _, x, k = s[BB].op.axis + ty, xi = s[BB].split(x, nparts=y_nthreads) + tx, ki = s[BB].split(k, nparts=x_nthreads) + s[BB].bind(ty, thread_y) + s[BB].bind(tx, thread_x) + s[BB].reorder(ty, tx, xi, ki) + s[BB].pragma(xi, "auto_unroll_max_step", 16) + + def _callback(op): + if "batch_matmul" in op.tag: + _schedule(op) + + traverse_inline(s, outs[0].op, _callback) + return s diff --git a/topi/python/topi/generic/nn.py b/topi/python/topi/generic/nn.py index 8c303e5be182..00b742f24e64 100644 --- a/topi/python/topi/generic/nn.py +++ b/topi/python/topi/generic/nn.py @@ -410,3 +410,9 @@ def schedule_l2_normalize(outs): target = tvm.target.current_target(allow_none=False) cpp_target = cpp.TEST_create_target(target.target_name) return cpp.generic.default_schedule(cpp_target, outs, False) + +@tvm.target.generic_func +def schedule_batch_matmul(outs): + target = tvm.target.current_target(allow_none=False) + cpp_target = cpp.TEST_create_target(target.target_name) + return cpp.generic.default_schedule(cpp_target, outs, False) diff --git 
a/topi/python/topi/nn/__init__.py b/topi/python/topi/nn/__init__.py index cfb9e566279a..941fec91a6bd 100644 --- a/topi/python/topi/nn/__init__.py +++ b/topi/python/topi/nn/__init__.py @@ -17,3 +17,4 @@ from .local_response_norm import * from .bitserial_conv2d import * from .l2_normalize import * +from .batch_matmul import * diff --git a/topi/python/topi/nn/batch_matmul.py b/topi/python/topi/nn/batch_matmul.py new file mode 100644 index 000000000000..07e363868b05 --- /dev/null +++ b/topi/python/topi/nn/batch_matmul.py @@ -0,0 +1,35 @@ +"""Batch matrix multiplication operators""" +# pylint: disable=invalid-name +from __future__ import absolute_import as _abs +import tvm +from ..util import get_const_tuple + + +def batch_matmul(x, y): + """Computes batch matrix multiplication of `x` and `y`, where `x` and `y` are + batches of matrices; `y` is multiplied in transposed form. + + Parameters + ---------- + x : tvm.Tensor + 3-D with shape [batch, M, K] + + y : tvm.Tensor + 3-D with shape [batch, N, K] + + Returns + ------- + output : tvm.Tensor + 3-D with shape [batch, M, N] + """ + assert len(x.shape) == 3 and len(y.shape) == 3, "only supports 3-dim batch_matmul" + x_shape = get_const_tuple(x.shape) + y_shape = get_const_tuple(y.shape) + assert x_shape[0] == y_shape[0], "batch dimension doesn't match" + assert x_shape[2] == y_shape[2], "shapes of x and y are inconsistent" + batch, M, K = x.shape + N = y.shape[1] + k = tvm.reduce_axis((0, K), name='k') + return tvm.compute((batch, M, N), + lambda b, i, j: tvm.sum(x[b, i, k] * y[b, j, k], axis=k), + tag='batch_matmul') diff --git a/topi/python/topi/testing/__init__.py b/topi/python/topi/testing/__init__.py index 81dd379257e0..0ccc422010c1 100644 --- a/topi/python/topi/testing/__init__.py +++ b/topi/python/topi/testing/__init__.py @@ -19,3 +19,4 @@ from .l2_normalize_python import l2_normalize_python from .gather_nd_python import gather_nd_python from .strided_slice_python import strided_slice_python +from .batch_matmul import batch_matmul diff --git a/topi/python/topi/testing/batch_matmul.py b/topi/python/topi/testing/batch_matmul.py new file mode 100644 index 000000000000..a7b2f9344f29 --- /dev/null +++ b/topi/python/topi/testing/batch_matmul.py @@ -0,0 +1,26 @@ +# pylint: disable=invalid-name +"""Batch matmul in python""" +import numpy as np + +def batch_matmul(x, y): + """batch_matmul operator implemented in numpy. + + Parameters + ---------- + x : numpy.ndarray + 3-D with shape [batch, M, K] + + y : numpy.ndarray + 3-D with shape [batch, N, K] + + Returns + ------- + out : numpy.ndarray + 3-D with shape [batch, M, N] + """ + batch, M, _ = x.shape + N = y.shape[1] + out = np.zeros((batch, M, N)).astype(x.dtype) + for i in range(batch): + out[i] = np.dot(x[i], y[i].T) + return out diff --git a/topi/python/topi/util.py b/topi/python/topi/util.py index 6d7326580f6d..d630628b4379 100644 --- a/topi/python/topi/util.py +++ b/topi/python/topi/util.py @@ -255,3 +255,29 @@ def select_array(i, j): return now return tvm.compute(matrix.shape, select_array, name=name) + + +def get_max_power2_factor(n, max_value=None): + """Get the largest power-of-2 factor of n. If max_value is specified, the + factor will be no larger than max_value. + + Parameters + ---------- + n : int + The input value + + max_value : int, optional + The max value for the factor + + Returns + ------- + factor : int + The max factor in power of 2. 
+ """ + x = 1 + while n % 2 == 0: + if max_value is not None and max_value < x * 2: + break + x *= 2 + n /= 2 + return x diff --git a/topi/python/topi/x86/batch_matmul.py b/topi/python/topi/x86/batch_matmul.py new file mode 100644 index 000000000000..37890e389366 --- /dev/null +++ b/topi/python/topi/x86/batch_matmul.py @@ -0,0 +1,53 @@ +# pylint: disable=invalid-name,too-many-locals,unused-variable +"""x86 batch_matmul operators""" +from __future__ import absolute_import as _abs +import tvm + +from .. import generic +from ..util import traverse_inline, get_const_tuple, get_max_power2_factor + + +@generic.schedule_batch_matmul.register(["cpu"]) +def schedule_batch_matmul(outs): + """Schedule for batch_matmul + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of batch_matmul + in the format of an array of tensors. + + Returns + ------- + sch: Schedule + The computation schedule for the op. + """ + s = tvm.create_schedule([x.op for x in outs]) + + def _callback(op): + if "batch_matmul" in op.tag: + C = op.output(0) + A, B = s[C].op.input_tensors + _, M, N = get_const_tuple(C.shape) + k, = s[C].op.reduce_axis + ko, ki = s[C].split(k, 16) + CC = s.rfactor(C, ki) + + b, y, x = s[C].op.axis + y_bn = get_max_power2_factor(M, 8) + x_bn = get_max_power2_factor(N, 8) + yo, yi = s[C].split(y, y_bn) + xo, xi = s[C].split(x, x_bn) + s[C].reorder(b, yo, xo, yi, xi) + bxyo = s[C].fuse(b, yo, xo) + s[C].parallel(bxyo) + s[C].fuse(yi, xi) + + s[CC].compute_at(s[C], bxyo) + _, _, y, x = s[CC].op.axis + s[CC].fuse(y, x) + s[CC].vectorize(s[CC].op.axis[0]) + s[C].pragma(bxyo, 'auto_unroll_max_step', 16) + + traverse_inline(s, outs[0].op, _callback) + return s diff --git a/topi/python/topi/x86/dense.py b/topi/python/topi/x86/dense.py new file mode 100644 index 000000000000..33575b4c399d --- /dev/null +++ b/topi/python/topi/x86/dense.py @@ -0,0 +1,208 @@ +# pylint: disable=invalid-name,too-many-locals,unused-variable +"""x86 dense operators""" +from __future__ import absolute_import as _abs +import tvm +from tvm import autotvm +from tvm.autotvm.task.space import SplitEntity + +from .util import get_fp32_len +from .. 
import generic, tag, nn +from ..util import traverse_inline, get_const_tuple + +@autotvm.register_topi_compute(nn.dense, "cpu", "direct") +def _declaration_dense(cfg, data, weight, bias=None): + batch, _ = get_const_tuple(data.shape) + + # For small batch sizes, don't pack weight into cache-friendly layout + # because of overhead in packing and limited reuse from batch dimension + # TODO(icemelon9): use a more systematic way to determine which schedule to use + if batch <= 16: + return _declaration_dense_nopack(cfg, data, weight, bias) + return _declaration_dense_pack(cfg, data, weight, bias) + + +# Declare dense compute with packing weight into cache-friendly layout +@autotvm.register_topi_compute(nn.dense, "cpu", "direct_pack") +def _declaration_dense_pack(cfg, data, weight, bias=None): + batch, in_dim = get_const_tuple(data.shape) + out_dim, _ = get_const_tuple(weight.shape) + # create tuning space + cfg.define_split("tile_y", batch, num_outputs=3) + cfg.define_split("tile_x", out_dim, num_outputs=3) + cfg.define_split("tile_k", in_dim, num_outputs=2) + if cfg.is_fallback: + _default_dense_pack_config(cfg, batch, out_dim, in_dim) + + packw_bn = cfg["tile_x"].size[-1] + packw_shape = (out_dim // packw_bn, in_dim, packw_bn) + packw = tvm.compute(packw_shape, + lambda z, y, x: weight[z * packw_bn + x, y], name="packed_weight") + + k = tvm.reduce_axis((0, in_dim), name="k") + C = tvm.compute((batch, out_dim), + lambda y, x: tvm.sum( + data[y, k] * packw[x // packw_bn, k, x % packw_bn], + axis=k), + tag="dense_pack") + if bias is not None: + C = tvm.compute((batch, out_dim), lambda i, j: C[i, j] + bias[j], + tag=tag.BROADCAST) + return C + + +# Declare dense compute without packing weight +@autotvm.register_topi_compute(nn.dense, "cpu", "direct_nopack") +def _declaration_dense_nopack(cfg, data, weight, bias=None): + batch, in_dim = get_const_tuple(data.shape) + out_dim, _ = get_const_tuple(weight.shape) + # create tuning space + cfg.define_split("tile_x", out_dim, num_outputs=2) + cfg.define_split("tile_y", batch, num_outputs=2) + cfg.define_split("tile_k", in_dim, num_outputs=2) + if cfg.is_fallback: + _default_dense_nopack_config(cfg, batch, out_dim, in_dim) + + vec = cfg["tile_k"].size[-1] + k = tvm.reduce_axis((0, in_dim // vec), "k") + CC = tvm.compute((batch, out_dim, vec), + lambda z, y, x: tvm.sum( + data[z, k * vec + x] * weight[y, k * vec + x], axis=k)) + + kk = tvm.reduce_axis((0, vec), "kk") + C = tvm.compute((batch, out_dim), + lambda y, x: tvm.sum(CC[y, x, kk], axis=kk), + tag="dense_nopack") + if bias is not None: + C = tvm.compute((batch, out_dim), lambda i, j: C[i, j] + bias[j], + tag=tag.BROADCAST) + + return C + + +@autotvm.register_topi_schedule(generic.schedule_dense, "cpu", "direct") +def _schedule_dense(cfg, outs): + s = tvm.create_schedule([x.op for x in outs]) + + def _callback(op): + if "dense_pack" in op.tag: + _schedule_dense_pack_template(cfg, s, op.output(0)) + elif 'dense_nopack' in op.tag: + _schedule_dense_nopack_template(cfg, s, op.output(0)) + traverse_inline(s, outs[0].op, _callback) + return s + + +@autotvm.register_topi_schedule(generic.schedule_dense, "cpu", "direct_pack") +def _schedule_dense_pack(cfg, outs): + s = tvm.create_schedule([x.op for x in outs]) + + def _callback(op): + if "dense_pack" in op.tag: + _schedule_dense_pack_template(cfg, s, op.output(0)) + traverse_inline(s, outs[0].op, _callback) + return s + + +@autotvm.register_topi_schedule(generic.schedule_dense, "cpu", "direct_nopack") +def _schedule_dense_nopack(cfg, outs): + s = 
tvm.create_schedule([x.op for x in outs]) + + def _callback(op): + if 'dense_nopack' in op.tag: + _schedule_dense_nopack_template(cfg, s, op.output(0)) + traverse_inline(s, outs[0].op, _callback) + return s + + +def _schedule_dense_pack_template(cfg, s, C): + A, packedB = s[C].op.input_tensors + + CC = s.cache_write(C, "global") + y, x = s[C].op.axis + k, = s[CC].op.reduce_axis + + yt, yo, yi = cfg["tile_y"].apply(s, C, y) + xt, xo, xi = cfg["tile_x"].apply(s, C, x) + s[C].reorder(yt, xt, yo, xo, yi, xi) + xyt = s[C].fuse(yt, xt) + s[C].parallel(xyt) + xyo = s[C].fuse(yo, xo) + s[C].unroll(yi) + s[C].vectorize(xi) + + s[CC].compute_at(s[C], xyo) + y, x = s[CC].op.axis + ko, ki = cfg["tile_k"].apply(s, CC, k) + s[CC].reorder(ko, ki, y, x) + s[CC].vectorize(x) + s[CC].unroll(y) + s[CC].unroll(ki) + + z, y, x = s[packedB].op.axis + s[packedB].reorder(z, x, y) + s[packedB].parallel(z) + s[packedB].vectorize(y) + return s + + +def _schedule_dense_nopack_template(cfg, s, C): + y, x = s[C].op.axis + kk, = s[C].op.reduce_axis + yo, yi = cfg["tile_y"].apply(s, C, y) + xo, xi = cfg["tile_x"].apply(s, C, x) + s[C].reorder(yo, xo, yi, xi) + xyo = s[C].fuse(yo, xo) + s[C].parallel(xyo) + s[C].unroll(kk) + + CC, = s[C].op.input_tensors + s[CC].compute_at(s[C], xyo) + z, y, x = s[CC].op.axis + k, = s[CC].op.reduce_axis + yz = s[CC].fuse(z, y) + s[CC].reorder(k, yz, x) + s[CC].unroll(yz) + s[CC].vectorize(x) + return s + + +def _default_dense_pack_config(cfg, M, N, K): + vec_width = get_fp32_len() + + tilex_ii = 1 + for bn in range(vec_width*2, 0, -1): + if N % bn == 0: + tilex_ii = bn + break + NN = N // tilex_ii + tilex_oi = 1 + while NN // tilex_oi > 4: + if (NN // tilex_oi) % 2 == 1: + break + tilex_oi *= 2 + + tiley_ii = 8 + while M % tiley_ii != 0: + tiley_ii //= 2 + MM = M // tiley_ii + tiley_oi = 1 + while MM // tiley_oi > 4: + if (MM // tiley_oi) % 2 == 1: + break + tiley_oi *= 2 + + cfg["tile_y"] = SplitEntity([MM // tiley_oi, tiley_oi, tiley_ii]) + cfg["tile_x"] = SplitEntity([NN // tilex_oi, tilex_oi, tilex_ii]) + cfg["tile_k"] = SplitEntity([K, 1]) + + +def _default_dense_nopack_config(cfg, M, N, K): + vec_width = get_fp32_len() + tilek_bn = 1 + for bn in range(vec_width*2, 0, -1): + if K % bn == 0: + tilek_bn = bn + break + cfg["tile_k"] = SplitEntity([K // tilek_bn, tilek_bn]) + cfg["tile_x"] = SplitEntity([N, 1]) + cfg["tile_y"] = SplitEntity([1, M]) diff --git a/topi/python/topi/x86/nn.py b/topi/python/topi/x86/nn.py index ab6dda40cc9d..73463242e96d 100644 --- a/topi/python/topi/x86/nn.py +++ b/topi/python/topi/x86/nn.py @@ -2,12 +2,7 @@ """x86 nn operators""" from __future__ import absolute_import as _abs import tvm -from tvm import autotvm -from tvm.autotvm.task.space import SplitEntity - -from .util import get_fp32_len -from .. import generic, tag, nn -from ..util import traverse_inline, get_const_tuple +from .. 
import generic @generic.schedule_softmax.register(["cpu"]) def schedule_softmax(outs): @@ -37,205 +32,3 @@ def schedule_softmax(outs): else: s[x].parallel(s[x].op.axis[0]) return s - - -@autotvm.register_topi_compute(nn.dense, "cpu", "direct") -def _declaration_dense(cfg, data, weight, bias=None): - batch, _ = get_const_tuple(data.shape) - - # For small batch sizes, don't pack weight into cache-friendly layout - # because of overhead in packing and limited reuse from batch dimension - # TODO(icemelon9): use a more systematic way to determine which schedule to use - if batch <= 16: - return _declaration_dense_nopack(cfg, data, weight, bias) - return _declaration_dense_pack(cfg, data, weight, bias) - - -# Declare dense compute with packing weight into cache-friendly layout -@autotvm.register_topi_compute(nn.dense, "cpu", "direct_pack") -def _declaration_dense_pack(cfg, data, weight, bias=None): - batch, in_dim = get_const_tuple(data.shape) - out_dim, _ = get_const_tuple(weight.shape) - # create tuning space - cfg.define_split("tile_y", batch, num_outputs=3) - cfg.define_split("tile_x", out_dim, num_outputs=3) - cfg.define_split("tile_k", in_dim, num_outputs=2) - if cfg.is_fallback: - _default_dense_pack_config(cfg, batch, out_dim, in_dim) - - packw_bn = cfg["tile_x"].size[-1] - packw_shape = (out_dim // packw_bn, in_dim, packw_bn) - packw = tvm.compute(packw_shape, - lambda z, y, x: weight[z * packw_bn + x, y], name="packed_weight") - - k = tvm.reduce_axis((0, in_dim), name="k") - C = tvm.compute((batch, out_dim), - lambda y, x: tvm.sum( - data[y, k] * packw[x // packw_bn, k, x % packw_bn], - axis=k), - tag="dense_pack") - if bias is not None: - C = tvm.compute((batch, out_dim), lambda i, j: C[i, j] + bias[j], - tag=tag.BROADCAST) - return C - - -# Declare dense compute without packing weight -@autotvm.register_topi_compute(nn.dense, "cpu", "direct_nopack") -def _declaration_dense_nopack(cfg, data, weight, bias=None): - batch, in_dim = get_const_tuple(data.shape) - out_dim, _ = get_const_tuple(weight.shape) - # create tuning space - cfg.define_split("tile_x", out_dim, num_outputs=2) - cfg.define_split("tile_y", batch, num_outputs=2) - cfg.define_split("tile_k", in_dim, num_outputs=2) - if cfg.is_fallback: - _default_dense_nopack_config(cfg, batch, out_dim, in_dim) - - vec = cfg["tile_k"].size[-1] - k = tvm.reduce_axis((0, in_dim // vec), "k") - CC = tvm.compute((batch, out_dim, vec), - lambda z, y, x: tvm.sum( - data[z, k * vec + x] * weight[y, k * vec + x], axis=k)) - - kk = tvm.reduce_axis((0, vec), "kk") - C = tvm.compute((batch, out_dim), - lambda y, x: tvm.sum(CC[y, x, kk], axis=kk), - tag="dense_nopack") - if bias is not None: - C = tvm.compute((batch, out_dim), lambda i, j: C[i, j] + bias[j], - tag=tag.BROADCAST) - - return C - - -@autotvm.register_topi_schedule(generic.schedule_dense, "cpu", "direct") -def _schedule_dense(cfg, outs): - s = tvm.create_schedule([x.op for x in outs]) - scheduled_ops = [] - - def _callback(op): - if "dense_pack" in op.tag: - _schedule_dense_pack_template(cfg, s, op.output(0)) - elif 'dense_nopack' in op.tag: - _schedule_dense_nopack_template(cfg, s, op.output(0)) - traverse_inline(s, outs[0].op, _callback) - return s - - -@autotvm.register_topi_schedule(generic.schedule_dense, "cpu", "direct_pack") -def _schedule_dense_pack(cfg, outs): - s = tvm.create_schedule([x.op for x in outs]) - scheduled_ops = [] - - def _callback(op): - if "dense_pack" in op.tag: - _schedule_dense_pack_template(cfg, s, op.output(0)) - traverse_inline(s, outs[0].op, _callback) - 
return s - - -@autotvm.register_topi_schedule(generic.schedule_dense, "cpu", "direct_nopack") -def _schedule_dense_nopack(cfg, outs): - s = tvm.create_schedule([x.op for x in outs]) - scheduled_ops = [] - - def _callback(op): - if 'dense_nopack' in op.tag: - _schedule_dense_nopack_template(cfg, s, op.output(0)) - traverse_inline(s, outs[0].op, _callback) - return s - - -def _schedule_dense_pack_template(cfg, s, C): - A, packedB = s[C].op.input_tensors - - CC = s.cache_write(C, "global") - y, x = s[C].op.axis - k, = s[CC].op.reduce_axis - - yt, yo, yi = cfg["tile_y"].apply(s, C, y) - xt, xo, xi = cfg["tile_x"].apply(s, C, x) - s[C].reorder(yt, xt, yo, xo, yi, xi) - xyt = s[C].fuse(yt, xt) - s[C].parallel(xyt) - xyo = s[C].fuse(yo, xo) - s[C].unroll(yi) - s[C].vectorize(xi) - - s[CC].compute_at(s[C], xyo) - y, x = s[CC].op.axis - ko, ki = cfg["tile_k"].apply(s, CC, k) - s[CC].reorder(ko, ki, y, x) - s[CC].vectorize(x) - s[CC].unroll(y) - s[CC].unroll(ki) - - z, y, x = s[packedB].op.axis - s[packedB].reorder(z, x, y) - s[packedB].parallel(z) - s[packedB].vectorize(y) - return s - - -def _schedule_dense_nopack_template(cfg, s, C): - y, x = s[C].op.axis - kk, = s[C].op.reduce_axis - yo, yi = cfg["tile_y"].apply(s, C, y) - xo, xi = cfg["tile_x"].apply(s, C, x) - s[C].reorder(yo, xo, yi, xi) - xyo = s[C].fuse(yo, xo) - s[C].parallel(xyo) - s[C].unroll(kk) - - CC, = s[C].op.input_tensors - s[CC].compute_at(s[C], xyo) - z, y, x = s[CC].op.axis - k, = s[CC].op.reduce_axis - yz = s[CC].fuse(z, y) - s[CC].reorder(k, yz, x) - s[CC].unroll(yz) - s[CC].vectorize(x) - return s - - -def _default_dense_pack_config(cfg, M, N, K): - vec_width = get_fp32_len() - - tilex_ii = 1 - for bn in range(vec_width*2, 0, -1): - if N % bn == 0: - tilex_ii = bn - break - NN = N // tilex_ii - tilex_oi = 1 - while NN // tilex_oi > 4: - if (NN // tilex_oi) % 2 == 1: - break - tilex_oi *= 2 - - tiley_ii = 8 - while M % tiley_ii != 0: - tiley_ii //= 2 - MM = M // tiley_ii - tiley_oi = 1 - while MM // tiley_oi > 4: - if (MM // tiley_oi) % 2 == 1: - break - tiley_oi *= 2 - - cfg["tile_y"] = SplitEntity([MM // tiley_oi, tiley_oi, tiley_ii]) - cfg["tile_x"] = SplitEntity([NN // tilex_oi, tilex_oi, tilex_ii]) - cfg["tile_k"] = SplitEntity([K, 1]) - - -def _default_dense_nopack_config(cfg, M, N, K): - vec_width = get_fp32_len() - tilek_bn = 1 - for bn in range(vec_width*2, 0, -1): - if K % bn == 0: - tilek_bn = bn - break - cfg["tile_k"] = SplitEntity([K // tilek_bn, tilek_bn]) - cfg["tile_x"] = SplitEntity([N, 1]) - cfg["tile_y"] = SplitEntity([1, M]) diff --git a/topi/src/topi.cc b/topi/src/topi.cc index aac2d1653c78..6fa748547cd9 100644 --- a/topi/src/topi.cc +++ b/topi/src/topi.cc @@ -27,6 +27,7 @@ #include #include #include +#include #include #include @@ -351,6 +352,12 @@ TVM_REGISTER_GLOBAL("topi.nn.dense") *rv = nn::dense(args[0], args[1], args[2]); }); +/* Ops from nn/batch_matmul.h */ +TVM_REGISTER_GLOBAL("topi.nn.batch_matmul") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = nn::batch_matmul(args[0], args[1]); + }); + /* Ops from nn/dilate.h */ TVM_REGISTER_GLOBAL("topi.nn.dilate") .set_body([](TVMArgs args, TVMRetValue *rv) { @@ -589,6 +596,9 @@ TVM_REGISTER_GENERIC_FUNC(schedule_dense) .register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_dense)) .register_func({ "rocm" }, WrapSchedule(topi::rocm::schedule_dense)); +TVM_REGISTER_GENERIC_FUNC(schedule_batch_matmul) +.set_default(WrapSchedule(topi::generic::default_schedule)); + TVM_REGISTER_GENERIC_FUNC(schedule_pool) 
.set_default(WrapSchedule(topi::generic::default_schedule)) .register_func({ "cpu" }, WrapSchedule(topi::x86::default_schedule)) diff --git a/topi/tests/python/test_topi_batch_matmul.py b/topi/tests/python/test_topi_batch_matmul.py new file mode 100644 index 000000000000..f699d6aa8dcb --- /dev/null +++ b/topi/tests/python/test_topi_batch_matmul.py @@ -0,0 +1,53 @@ +"""Test code for batch_matmul operator""" +import numpy as np +import tvm +import topi +import topi.testing +from topi.util import get_const_tuple +from tvm.contrib.pickle_memoize import memoize + +from common import get_all_backend + +def verify_batch_matmul(batch, M, N, K): + x = tvm.placeholder((batch, M, K), name='x') + y = tvm.placeholder((batch, N, K), name='y') + dtype = x.dtype + + # use memoize to pickle the test data for next time use + @memoize("topi.tests.test_topi_batch_matmul") + def get_ref_data(): + a_np = np.random.uniform(size=(batch, M, K)).astype(dtype) + b_np = np.random.uniform(size=(batch, N, K)).astype(dtype) + c_np = topi.testing.batch_matmul(a_np, b_np) + return (a_np, b_np, c_np) + # get the test data + a_np, b_np, c_np = get_ref_data() + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + with tvm.target.create(device): + out = topi.nn.batch_matmul(x, y) + s = topi.generic.schedule_batch_matmul([out]) + a = tvm.nd.array(a_np, ctx) + b = tvm.nd.array(b_np, ctx) + c = tvm.nd.array(np.zeros(get_const_tuple(out.shape), dtype=dtype), ctx) + f = tvm.build(s, [x, y, out], device, name="dense") + f(a, b, c) + tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5) + + for device in get_all_backend(): + check_device(device) + +def test_batch_matmul(): + verify_batch_matmul(1, 16, 16, 32) + verify_batch_matmul(5, 16, 16, 32) + verify_batch_matmul(5, 16, 20, 32) + verify_batch_matmul(30, 16, 20, 32) + + +if __name__ == "__main__": + test_batch_matmul()
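
For reference, the following is a minimal usage sketch of the operator added by this patch (editorial note, not part of the diff). It assumes a TVM build at this revision with the LLVM target enabled, and the shapes are arbitrary illustrative values following the convention above: `x` is `(batch, M, K)`, `y` is `(batch, N, K)`, and the result is `(batch, M, N)`, with `y` transposed internally.

import numpy as np
import tvm
from tvm import relay
import topi.testing

# Declare a small batched matmul: (2, 3, 4) x (2, 5, 4) -> (2, 3, 5).
x = relay.var("x", shape=(2, 3, 4), dtype="float32")
y = relay.var("y", shape=(2, 5, 4), dtype="float32")
func = relay.Function([x, y], relay.nn.batch_matmul(x, y))

x_np = np.random.uniform(size=(2, 3, 4)).astype("float32")
y_np = np.random.uniform(size=(2, 5, 4)).astype("float32")

# Run on CPU and compare against the numpy reference added in topi.testing
# (sketch only; target/context choices are assumptions, not mandated by the patch).
intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
out = intrp.evaluate(func)(x_np, y_np)
tvm.testing.assert_allclose(out.asnumpy(), topi.testing.batch_matmul(x_np, y_np), rtol=1e-5)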